import { BinaryType, Stmt, Sym } from "../ast.ts";
import { VType, vtypeToString } from "../vtype.ts";

/** Root MIR container: the list of lowered functions. */
export type Mir = {
    fns: Array<Fn>;
};

/**
 * One lowered function: its local slots and a list of basic blocks.
 *
 * `printMir` treats `locals[1..=params.length]` as the parameters
 * (presumably `locals[0]` is the return slot — TODO confirm).
 */
export type Fn = {
    /** Source statement this function was lowered from (`printMir` requires a `fn` stmt). */
    stmt: Stmt;
    locals: Array<Local>;
    blocks: Array<Block>;
    /** Presumably the block where execution starts — confirm against the lowerer. */
    entry: BlockId;
    /** Presumably the block that returns — confirm against the lowerer. */
    exit: BlockId;
};

/** Identifier of a local slot; rendered as `_<id>` in printed MIR. */
export type LocalId = number;

/** A local slot belonging to a `Fn`. */
export type Local = {
    id: LocalId;
    /** Mutability flag; `printMir` renders it as `mut`. */
    mut: boolean;
    /** Value type of the slot. */
    vtype: VType;
    /** Source symbol this local corresponds to, when there is one. */
    sym?: Sym;
};

/** Identifier of a basic block within a `Fn`. */
export type BlockId = number;

/** A basic block: a run of ops followed by exactly one terminator. */
export type Block = {
    id: BlockId;
    ops: Array<Op>;
    ter: Ter;
    /** Optional display name; `printMir` prints it in place of `bb<id>`. */
    label?: string;
};

/** A single non-terminating MIR instruction; the variants live in `OpKind`. */
export type Op = { kind: OpKind };

// Terse aliases so the variant tables below stay compact.
type L = LocalId;
type R = RValue;

/**
 * Variants of a MIR operation. The comment on each variant shows how
 * `printMir` renders it. Variants with a `dst` write that local; the
 * `assign_*` variants store through `subject` instead and define no local.
 */
export type OpKind =
    /** `<error>;` */
    | { type: "error" }
    /** `_dst = src;` */
    | { type: "assign"; dst: L; src: R }
    /** `_dst = &_src;` */
    | { type: "ref"; dst: L; src: L }
    /** `_dst = &mut _src;` */
    | { type: "ref_mut"; dst: L; src: L }
    /** `_dst = *_src;` */
    | { type: "ptr"; dst: L; src: L }
    /** `_dst = *mut _src;` */
    | { type: "ptr_mut"; dst: L; src: L }
    /** `_dst = *src;` */
    | { type: "deref"; dst: L; src: R }
    /** `*subject = src;` */
    | { type: "assign_deref"; subject: R; src: R }
    /** `_dst = subject.ident;` */
    | { type: "field"; dst: L; subject: R; ident: string }
    /** `subject.ident = src;` */
    | { type: "assign_field"; subject: R; ident: string; src: R }
    /** `_dst = subject[index];` */
    | { type: "index"; dst: L; subject: R; index: R }
    /** `subject[index] = src;` */
    | { type: "assign_index"; subject: R; index: R; src: R }
    /** `_dst = call subject(args...);` */
    | { type: "call_val"; dst: L; subject: R; args: Array<R> }
    /** `_dst = left <binaryType> right;` */
    | { type: "binary"; binaryType: BinaryType; dst: L; left: R; right: R };

/** The control-flow instruction that ends a `Block`; variants live in `TerKind`. */
export type Ter = { kind: TerKind };

/**
 * Terminator variants. `jump` names a single successor block; `if` names
 * two successors, `truthy` and `falsy`, selected by `cond`.
 */
export type TerKind =
    | { type: "error" }
    | { type: "return" }
    | {
        type: "jump";
        target: BlockId;
    }
    | {
        type: "if";
        cond: R;
        truthy: BlockId;
        falsy: BlockId;
    };

/**
 * An operand read by an op or terminator.
 *
 * `copy` and `move` read a local slot. Their `id` is a `LocalId`: it is
 * printed as `_<id>`, the same notation used for `Local.id`. The remaining
 * variants are constants; `fn` refers to a function by its defining
 * statement.
 */
export type RValue =
    | { type: "error" }
    | { type: "copy"; id: LocalId }
    | { type: "move"; id: LocalId }
    | { type: "null" }
    | { type: "bool"; val: boolean }
    | { type: "int"; val: number }
    | { type: "string"; val: string }
    | { type: "fn"; stmt: Stmt };

/**
 * Calls `visit` once for every op in `block` that writes a local (`dst`).
 *
 * `error` ops and the store-through forms (`assign_deref`, `assign_field`,
 * `assign_index`) define no local and are skipped.
 *
 * @param block block whose ops are scanned in order
 * @param visit receives the destination local and the op's index in `block.ops`
 * @throws Error if an op kind is not handled here
 */
export function visitBlockDsts(
    block: Block,
    visit: (local: LocalId, index: number) => void,
) {
    for (const [i, op] of block.ops.entries()) {
        const ok = op.kind;
        switch (ok.type) {
            case "error":
                break;
            case "assign":
            case "ref":
            case "ref_mut":
            case "ptr":
            case "ptr_mut":
            case "deref":
            case "field":
            case "index":
            case "call_val":
            case "binary":
                visit(ok.dst, i);
                break;
            case "assign_deref":
            case "assign_field":
            case "assign_index":
                break;
            default: {
                // Compile-time exhaustiveness check; reaching here means a new
                // OpKind variant was added without updating this visitor.
                const unknownKind: never = ok;
                throw new Error(
                    `visitBlockDsts: unhandled op kind '${
                        (unknownKind as { type: string }).type
                    }'`,
                );
            }
        }
    }
}

/**
 * Rewrites, in place, every RValue operand of `block` by passing it through
 * `replace` and storing the result back into the op or terminator.
 *
 * Operands are processed in op order, then the `if` terminator's `cond`.
 * The sources of `ref`, `ref_mut`, `ptr` and `ptr_mut` are plain locals,
 * not RValues, so they are left untouched. For `call_val`, `args` is
 * replaced with a new array.
 *
 * @param block block to mutate
 * @param replace mapping applied to each RValue operand
 * @throws Error if an op or terminator kind is not handled here
 */
export function replaceBlockSrcs(
    block: Block,
    replace: (src: RValue) => RValue,
) {
    for (const op of block.ops) {
        const ok = op.kind;
        switch (ok.type) {
            case "error":
            case "ref":
            case "ref_mut":
            case "ptr":
            case "ptr_mut":
                break;
            case "assign":
            case "deref":
                ok.src = replace(ok.src);
                break;
            case "assign_deref":
            case "assign_field":
                ok.subject = replace(ok.subject);
                ok.src = replace(ok.src);
                break;
            case "field":
                ok.subject = replace(ok.subject);
                break;
            case "index":
                ok.subject = replace(ok.subject);
                ok.index = replace(ok.index);
                break;
            case "assign_index":
                ok.subject = replace(ok.subject);
                ok.index = replace(ok.index);
                ok.src = replace(ok.src);
                break;
            case "call_val":
                ok.subject = replace(ok.subject);
                ok.args = ok.args.map((arg) => replace(arg));
                break;
            case "binary":
                ok.left = replace(ok.left);
                ok.right = replace(ok.right);
                break;
            default: {
                const unknownKind: never = ok;
                throw new Error(
                    `replaceBlockSrcs: unhandled op kind '${
                        (unknownKind as { type: string }).type
                    }'`,
                );
            }
        }
    }
    const tk = block.ter.kind;
    switch (tk.type) {
        case "error":
        case "return":
        case "jump":
            break;
        case "if":
            tk.cond = replace(tk.cond);
            break;
        default: {
            // Previously unknown terminators were silently skipped; fail loudly
            // like the op switch above so new variants are not missed.
            const unknownKind: never = tk;
            throw new Error(
                `replaceBlockSrcs: unhandled terminator kind '${
                    (unknownKind as { type: string }).type
                }'`,
            );
        }
    }
}

/**
 * Calls `visitor` on every RValue operand read by `block`: first the op
 * operands in op order, then the `if` terminator's `cond`.
 *
 * For op operands the visitor receives the owning op and its index in
 * `block.ops`. For the terminator condition it receives `undefined` for both,
 * plus the terminator itself. The sources of `ref`, `ref_mut`, `ptr` and
 * `ptr_mut` are plain locals, not RValues, so they are not visited.
 *
 * @throws Error if an op or terminator kind is not handled here
 */
export function visitBlockSrcs(
    block: Block,
    visitor: (src: RValue, op?: Op, index?: number, ter?: Ter) => void,
) {
    for (const [i, op] of block.ops.entries()) {
        const ok = op.kind;
        switch (ok.type) {
            case "error":
            case "ref":
            case "ref_mut":
            case "ptr":
            case "ptr_mut":
                break;
            case "assign":
            case "deref":
                visitor(ok.src, op, i);
                break;
            case "assign_deref":
                // NOTE(review): src is visited before subject here, whereas
                // assign_field/assign_index (and replaceBlockSrcs) go subject
                // first. Kept as-is in case a visitor relies on it — confirm.
                visitor(ok.src, op, i);
                visitor(ok.subject, op, i);
                break;
            case "field":
                visitor(ok.subject, op, i);
                break;
            case "assign_field":
                visitor(ok.subject, op, i);
                visitor(ok.src, op, i);
                break;
            case "index":
                visitor(ok.subject, op, i);
                visitor(ok.index, op, i);
                break;
            case "assign_index":
                visitor(ok.subject, op, i);
                visitor(ok.index, op, i);
                visitor(ok.src, op, i);
                break;
            case "call_val":
                visitor(ok.subject, op, i);
                // Plain loop: the previous `.map` built and discarded an array.
                for (const arg of ok.args) {
                    visitor(arg, op, i);
                }
                break;
            case "binary":
                visitor(ok.left, op, i);
                visitor(ok.right, op, i);
                break;
            default: {
                const unknownKind: never = ok;
                throw new Error(
                    `visitBlockSrcs: unhandled op kind '${
                        (unknownKind as { type: string }).type
                    }'`,
                );
            }
        }
    }
    const tk = block.ter.kind;
    switch (tk.type) {
        case "error":
        case "return":
        case "jump":
            break;
        case "if":
            visitor(tk.cond, undefined, undefined, block.ter);
            break;
        default: {
            const unknownKind: never = tk;
            throw new Error(
                `visitBlockSrcs: unhandled terminator kind '${
                    (unknownKind as { type: string }).type
                }'`,
            );
        }
    }
}

/**
 * Counts every instruction in `mir`: each block contributes its ops plus
 * one for its terminator.
 */
export function mirOpCount(mir: Mir): number {
    let total = 0;
    for (const fn of mir.fns) {
        for (const block of fn.blocks) {
            // `+ 1` accounts for the block's terminator.
            total += block.ops.length + 1;
        }
    }
    return total;
}

/**
 * Prints every function in `mir` to stdout (via `console.log`) in a
 * human-readable MIR text form, intended for debugging.
 *
 * For each function this emits a signature line
 * (`path<Generics>(mut _1: T, _2: U) -> R {`), then a `let` line for every
 * local that is not a parameter, then each block with its ops and terminator.
 * Parameters are taken to be `fn.locals[1..=params.length]`: those locals
 * appear in the signature and are skipped in the `let` list. A block is named
 * by its `label` when present, otherwise `bb<id>`. Jump and branch targets
 * use the same names, so they match the block headers.
 *
 * @throws Error if a function's `stmt` is not a `fn` statement, if its vtype
 *   is not a `fn` type, or if an op or terminator kind is not handled here
 */
export function printMir(mir: Mir) {
    const l = (msg: string) => console.log(`        ${msg}`);
    const r = rvalueToString;

    for (const fn of mir.fns) {
        const stmt = fn.stmt;
        if (stmt.kind.type !== "fn") {
            throw new Error("printMir: function stmt is not a fn statement");
        }
        const name = stmt.kind.sym!.fullPath;

        const vtype = stmt.kind.vtype;
        if (vtype?.type !== "fn") {
            throw new Error("printMir: function vtype is not a fn type");
        }
        // Generic parameters are printed in angle brackets, and only when
        // there are any (previously they were emitted bare: `foo T, U(...)`).
        const genericParams = vtype.genericParams;
        const generics = genericParams && genericParams.length > 0
            ? `<${genericParams.map(({ ident }) => `${ident}`).join(", ")}>`
            : "";
        const params = vtype.params
            .map(({ mut, vtype: paramType }, i) =>
                `${mut ? "mut " : ""}_${fn.locals[i + 1].id}: ${
                    vtypeToString(paramType)
                }`
            )
            .join(", ");
        const returnType = vtypeToString(vtype.returnType);
        console.log(`${name}${generics}(${params}) -> ${returnType} {`);

        // Locals 1..=paramCount were already printed in the signature.
        const paramCount = vtype.params.length;
        for (const [i, local] of fn.locals.entries()) {
            if (i >= 1 && i <= paramCount) {
                continue;
            }
            const m = local.mut ? "mut " : "";
            const v = vtypeToString(local.vtype);
            console.log(`    let ${m}_${local.id}: ${v};`);
        }

        // A jump/branch target must print with the same name as the target
        // block's header, which is its label when it has one.
        const blockNames = new Map<number, string>();
        for (const block of fn.blocks) {
            blockNames.set(block.id, block.label ?? `bb${block.id}`);
        }
        const target = (id: number) => blockNames.get(id) ?? `bb${id}`;

        for (const block of fn.blocks) {
            console.log(`    ${block.label ?? `bb${block.id}`}: {`);
            for (const op of block.ops) {
                const k = op.kind;
                switch (k.type) {
                    case "error":
                        l(`<error>;`);
                        break;
                    case "assign":
                        l(`_${k.dst} = ${r(k.src)};`);
                        break;
                    case "ref":
                        l(`_${k.dst} = &_${k.src};`);
                        break;
                    case "ref_mut":
                        l(`_${k.dst} = &mut _${k.src};`);
                        break;
                    case "ptr":
                        l(`_${k.dst} = *_${k.src};`);
                        break;
                    case "ptr_mut":
                        l(`_${k.dst} = *mut _${k.src};`);
                        break;
                    case "deref":
                        l(`_${k.dst} = *${r(k.src)};`);
                        break;
                    case "assign_deref":
                        l(`*${r(k.subject)} = ${r(k.src)};`);
                        break;
                    case "field":
                        l(`_${k.dst} = ${r(k.subject)}.${k.ident};`);
                        break;
                    case "assign_field":
                        l(`${r(k.subject)}.${k.ident} = ${r(k.src)};`);
                        break;
                    case "index":
                        l(`_${k.dst} = ${r(k.subject)}[${r(k.index)}];`);
                        break;
                    case "assign_index":
                        l(`${r(k.subject)}[${r(k.index)}] = ${r(k.src)};`);
                        break;
                    case "call_val": {
                        const args = k.args.map((arg) => r(arg)).join(", ");
                        l(`_${k.dst} = call ${r(k.subject)}(${args});`);
                        break;
                    }
                    case "binary":
                        l(`_${k.dst} = ${r(k.left)} ${k.binaryType} ${
                            r(k.right)
                        };`);
                        break;
                    default: {
                        const unknownKind: never = k;
                        throw new Error(
                            `printMir: unhandled op kind '${
                                (unknownKind as { type: string }).type
                            }'`,
                        );
                    }
                }
            }
            const tk = block.ter.kind;
            switch (tk.type) {
                case "error":
                    l(`<error>;`);
                    break;
                case "return":
                    l(`return;`);
                    break;
                case "jump":
                    l(`jump ${target(tk.target)};`);
                    break;
                case "if":
                    l(`if ${r(tk.cond)}, true: ${target(tk.truthy)}, false: ${
                        target(tk.falsy)
                    };`);
                    break;
                default: {
                    const unknownKind: never = tk;
                    throw new Error(
                        `printMir: unhandled terminator kind '${
                            (unknownKind as { type: string }).type
                        }'`,
                    );
                }
            }
            console.log("    }");
        }
        console.log("}");
    }
}

/**
 * Renders an RValue in the textual MIR syntax used by `printMir`.
 *
 * - Local reads print as `copy _N` / `move _N`.
 * - String constants print as JSON string literals, so embedded quotes,
 *   backslashes and newlines are escaped and the output stays unambiguous
 *   and on one line.
 * - `fn` constants print as the function's `sym.fullPath`.
 *
 * @throws Error if a `fn` rvalue does not wrap a `fn` statement, or if the
 *   rvalue kind is not handled here
 */
export function rvalueToString(rvalue: RValue): string {
    switch (rvalue.type) {
        case "error":
            return `<error>`;
        case "copy":
            return `copy _${rvalue.id}`;
        case "move":
            return `move _${rvalue.id}`;
        case "null":
            return "null";
        case "bool":
        case "int":
            return `${rvalue.val}`;
        case "string":
            // Identical to `"${val}"` for plain text, but escapes `"`, `\`
            // and control characters instead of emitting them raw.
            return JSON.stringify(rvalue.val);
        case "fn": {
            const stmt = rvalue.stmt;
            if (stmt.kind.type !== "fn") {
                throw new Error(
                    "rvalueToString: fn rvalue does not wrap a fn statement",
                );
            }
            return stmt.kind.sym!.fullPath;
        }
        default: {
            // Previously an unknown kind fell through and returned undefined.
            const unknownKind: never = rvalue;
            throw new Error(
                `rvalueToString: unhandled rvalue kind '${
                    (unknownKind as { type: string }).type
                }'`,
            );
        }
    }
}