import * as Ast from "../ast.ts";
import { AllFnsCollector } from "../mono.ts";
import { VType, vtypesEqual } from "../vtype.ts";
import {
    Block,
    BlockId,
    Fn,
    Local,
    LocalId,
    Mir,
    OpKind,
    RValue,
    Ter,
    TerKind,
} from "./mir.ts";

/**
 * Lowers a module's top-level statements into MIR.
 *
 * Expects identifiers to already be resolved to symbols and vtypes to be
 * filled in; the lowerer throws otherwise.
 */
export function lowerAst(ast: Ast.Stmt[]): Mir {
    const lowerer = new AstLowerer(ast);
    return lowerer.lower();
}

/** Drives lowering of every function found in a list of statements. */
class AstLowerer {
    public constructor(private ast: Ast.Stmt[]) {}

    /**
     * Lowers each function statement reported by `AllFnsCollector`,
     * independently and in collection order.
     */
    public lower(): Mir {
        const fns: Fn[] = [];
        for (const fnAst of new AllFnsCollector().collect(this.ast).values()) {
            fns.push(new FnAstLowerer(fnAst).lower());
        }
        return { fns };
    }
}

/**
 * Hands out sequentially numbered locals for one function body.
 * A local's id equals its index in the array returned by `finish`.
 */
class LocalAllocator {
    private locals: Local[] = [];

    /** Allocates a new local with `mut: false`. */
    public alloc(vtype: VType, sym?: Ast.Sym): LocalId {
        return this.push(false, vtype, sym);
    }

    /** Allocates a new local with `mut: true`. */
    public allocMut(vtype: VType, sym?: Ast.Sym): LocalId {
        return this.push(true, vtype, sym);
    }

    /** Returns all locals allocated so far (the internal array, not a copy). */
    public finish(): Local[] {
        return this.locals;
    }

    /** Appends a local whose id is the next free index and returns that id. */
    private push(mut: boolean, vtype: VType, sym?: Ast.Sym): LocalId {
        const id = this.locals.length;
        this.locals.push({ id, mut, vtype, sym });
        return id;
    }
}

/**
 * Lowers a single `fn` statement into a MIR `Fn`.
 *
 * The first local allocated holds the function's return value. It is followed
 * by one mutable local per parameter. Every other local is a temporary or a
 * `let` binding allocated while walking the body. Ops are appended to the
 * "current" block. Control-flow constructs (`if`, `loop`, `break`) open new
 * blocks and wire up their terminators.
 */
class FnAstLowerer {
    private locals = new LocalAllocator();
    /** Next block id to hand out; shared by `pushBlock` and `reserveBlock`. */
    private blockIdCounter = 0;
    /** Block that `addOp` and `setTer` write to. */
    private currentBlockId = 0;
    private blocks = new Map<BlockId, Block>();

    /** Maps a parameter's index to the local holding that parameter. */
    private fnParamIndexLocals = new Map<number, LocalId>();
    /** Maps a `let` statement's id to the local holding its binding. */
    private letStmtIdLocals = new Map<number, LocalId>();

    /**
     * Enclosing loops, innermost last: the local a `break` stores its value
     * into, and the block it jumps to.
     */
    private breakStack: { local: LocalId; block: BlockId }[] = [];

    public constructor(private ast: Ast.Stmt) {}

    /**
     * Lowers the wrapped `fn` statement.
     *
     * The body's value is copied into the return local. The block that is
     * current afterwards gets a `return` terminator and is reported as `exit`.
     *
     * @throws Error if the statement is not a `fn` or its vtype is not a fn type.
     */
    public lower(): Fn {
        const stmt = this.ast;
        if (stmt.kind.type !== "fn") {
            throw new Error();
        }
        const vtype = stmt.kind.vtype;
        if (vtype?.type !== "fn") {
            throw new Error();
        }

        const rLoc = this.locals.alloc(vtype.returnType);
        for (const param of stmt.kind.params) {
            const id = this.locals.allocMut(param.vtype!);
            this.fnParamIndexLocals.set(param.index!, id);
        }

        const entry = this.pushBlock();
        const rVal = this.lowerBlockExpr(stmt.kind.body);
        this.addOp({ type: "assign", dst: rLoc, src: local(rVal) });
        this.setTer({ type: "return" });
        const exit = this.currentBlock();

        const locals = this.locals.finish();
        const blocks = this.blocks.values().toArray();
        return { stmt, locals, blocks, entry, exit };
    }

    /**
     * Lowers one statement into the current block.
     *
     * Supported: `break`, `fn` (emits nothing), `let`, `assign` and `expr`.
     * Every other statement kind (including `return`, `type_alias`, module
     * declarations and `error`) throws "not covered".
     */
    private lowerStmt(stmt: Ast.Stmt) {
        switch (stmt.kind.type) {
            case "error":
            case "mod_block":
            case "mod_file":
            case "mod":
                break;
            case "break": {
                const target = this.breakStack.at(-1);
                if (target === undefined) {
                    throw new Error("'break' outside of a loop");
                }
                const { local: dst, block } = target;
                if (stmt.kind.expr) {
                    const val = this.lowerExpr(stmt.kind.expr);
                    this.addOp({ type: "assign", dst, src: local(val) });
                } else {
                    this.addOp({ type: "assign", dst, src: { type: "null" } });
                }
                this.setTer({ type: "jump", target: block });
                // Code following the `break` goes into a fresh block that no
                // terminator targets.
                this.pushBlock();
                return;
            }
            case "return":
                break;
            case "fn":
                // nothing
                return;
            case "let":
                this.lowerLetStmt(stmt);
                return;
            case "type_alias":
                break;
            case "assign":
                return this.lowerAssign(stmt);
            case "expr": {
                this.lowerExpr(stmt.kind.expr);
                return;
            }
        }
        throw new Error(`statement type '${stmt.kind.type}' not covered`);
    }

    /**
     * Lowers a plain `=` assignment to a field, an index, or a `let`/parameter
     * symbol. The value is lowered before the target's subexpressions.
     *
     * @throws Error for compound assignments (expected to be desugared away)
     * and for unsupported assignment targets.
     */
    private lowerAssign(stmt: Ast.Stmt) {
        if (stmt.kind.type !== "assign") {
            throw new Error();
        }
        if (stmt.kind.assignType !== "=") {
            throw new Error("incomplete desugar");
        }
        const src = local(this.lowerExpr(stmt.kind.value));
        const s = stmt.kind.subject;
        switch (s.kind.type) {
            case "field": {
                const subject = local(this.lowerExpr(s.kind.subject));
                const ident = s.kind.ident;
                this.addOp({ type: "assign_field", subject, ident, src });
                return;
            }
            case "index": {
                const subject = local(this.lowerExpr(s.kind.subject));
                const index = local(this.lowerExpr(s.kind.value));
                this.addOp({ type: "assign_index", subject, index, src });
                return;
            }
            case "sym": {
                const sym = s.kind.sym;
                switch (sym.type) {
                    case "let": {
                        const dst = this.letStmtIdLocals.get(sym.stmt.id)!;
                        this.addOp({ type: "assign", dst, src });
                        return;
                    }
                    case "fn_param": {
                        const dst = this.fnParamIndexLocals.get(
                            sym.param.index!,
                        )!;
                        this.addOp({ type: "assign", dst, src });
                        return;
                    }
                }
                throw new Error(`symbol type '${sym.type}' not covered`);
            }
            default:
                throw new Error();
        }
    }

    /**
     * Lowers the initializer, copies it into a fresh mutable local, and
     * registers that local under the `let` statement's id.
     */
    private lowerLetStmt(stmt: Ast.Stmt) {
        if (stmt.kind.type !== "let") {
            throw new Error();
        }
        const srcId = this.lowerExpr(stmt.kind.value);
        const dst = this.locals.allocMut(
            stmt.kind.param.vtype!,
            stmt.kind.param.sym!,
        );
        this.addOp({ type: "assign", dst, src: local(srcId) });
        this.letStmtIdLocals.set(stmt.id, dst);
    }

    /**
     * Lowers an expression and returns the local that holds its value.
     *
     * Control-flow expressions may change the current block. Callers that need
     * "the block where this value is available" must read `currentBlock()`
     * after this call, not before.
     */
    private lowerExpr(expr: Ast.Expr): LocalId {
        switch (expr.kind.type) {
            case "error": {
                const dst = this.locals.alloc({ type: "error" });
                this.addOp({ type: "assign", dst, src: { type: "error" } });
                return dst;
            }
            case "null": {
                const dst = this.locals.alloc({ type: "null" });
                this.addOp({ type: "assign", dst, src: { type: "null" } });
                return dst;
            }
            case "bool": {
                const val = expr.kind.value;
                const dst = this.locals.alloc({ type: "bool" });
                this.addOp({ type: "assign", dst, src: { type: "bool", val } });
                return dst;
            }
            case "int": {
                const val = expr.kind.value;
                const dst = this.locals.alloc({ type: "int" });
                this.addOp({ type: "assign", dst, src: { type: "int", val } });
                return dst;
            }
            case "string": {
                const val = expr.kind.value;
                const dst = this.locals.alloc({ type: "string" });
                this.addOp({
                    type: "assign",
                    dst,
                    src: { type: "string", val },
                });
                return dst;
            }
            case "ident":
                throw new Error("should've been resolved");
            case "sym":
                return this.lowerSymExpr(expr);
            case "group":
                return this.lowerExpr(expr.kind.expr);
            case "ref": {
                const src = this.lowerExpr(expr.kind.subject);
                const dst = this.locals.alloc(expr.vtype!);
                this.addOp({ type: "ref", dst, src });
                return dst;
            }
            case "ref_mut": {
                const src = this.lowerExpr(expr.kind.subject);
                const dst = this.locals.alloc(expr.vtype!);
                this.addOp({ type: "ref_mut", dst, src });
                return dst;
            }
            case "deref": {
                const src = local(this.lowerExpr(expr.kind.subject));
                // The subject carries the reference type. The loaded value has
                // the deref expression's own type, mirroring how `ref` uses
                // `expr.vtype` for its result.
                const dst = this.locals.alloc(expr.vtype!);
                this.addOp({ type: "deref", dst, src });
                return dst;
            }
            case "array":
                throw new Error("incomplete desugar");
            case "struct":
                throw new Error("incomplete desugar");
            case "field":
                return this.lowerFieldExpr(expr);
            case "index":
                return this.lowerIndexExpr(expr);
            case "call":
                return this.lowerCallExpr(expr);
            case "path":
            case "etype_args":
            case "unary":
                break;
            case "binary":
                return this.lowerBinaryExpr(expr);
            case "if":
                return this.lowerIfExpr(expr);
            case "loop":
                return this.lowerLoopExpr(expr);
            case "block":
                return this.lowerBlockExpr(expr);
            case "while":
            case "for_in":
            case "for":
                throw new Error("incomplete desugar");
        }
        throw new Error(`expression type '${expr.kind.type}' not covered`);
    }

    /**
     * Lowers a symbol reference. `let` bindings and parameters return their
     * existing local directly (no copy). A `fn` symbol is materialized into a
     * new local holding the function value.
     */
    private lowerSymExpr(expr: Ast.Expr): LocalId {
        if (expr.kind.type !== "sym") {
            throw new Error();
        }
        const sym = expr.kind.sym;
        switch (sym.type) {
            case "let":
                return this.letStmtIdLocals.get(sym.stmt.id)!;
            case "let_static":
            case "type_alias":
                break;
            case "fn": {
                const stmt = sym.stmt;
                if (sym.stmt.kind.type !== "fn") {
                    throw new Error();
                }
                const dst = this.locals.alloc(sym.stmt.kind.vtype!);
                this.addOp({ type: "assign", dst, src: { type: "fn", stmt } });
                return dst;
            }
            case "fn_param": {
                return this.fnParamIndexLocals.get(sym.param.index!)!;
            }
            case "closure":
            case "generic":
            case "mod":
        }
        throw new Error(`symbol type '${sym.type}' not covered`);
    }

    /** Lowers `subject.ident`; the subject must have a struct vtype. */
    private lowerFieldExpr(expr: Ast.Expr): LocalId {
        if (expr.kind.type !== "field") {
            throw new Error();
        }
        const ident = expr.kind.ident;
        const subject = local(this.lowerExpr(expr.kind.subject));

        const subjectVType = expr.kind.subject.vtype!;
        if (subjectVType.type !== "struct") {
            throw new Error();
        }
        const fieldVType = subjectVType.fields.find((field) =>
            field.ident === ident
        );
        if (fieldVType === undefined) {
            throw new Error();
        }

        const dst = this.locals.alloc(fieldVType.vtype);
        this.addOp({ type: "field", dst, subject, ident });
        return dst;
    }

    /**
     * Lowers `subject[index]`. Arrays yield their element vtype; strings
     * yield `int`. Any other subject vtype throws.
     */
    private lowerIndexExpr(expr: Ast.Expr): LocalId {
        if (expr.kind.type !== "index") {
            throw new Error();
        }
        const subject = local(this.lowerExpr(expr.kind.subject));
        const index = local(this.lowerExpr(expr.kind.value));

        const dstVType = ((): VType => {
            const outer = expr.kind.subject.vtype!;
            if (outer.type === "array") {
                return outer.subject;
            }
            if (outer.type === "string") {
                return { type: "int" };
            }
            throw new Error();
        })();

        const dst = this.locals.alloc(dstVType);
        this.addOp({ type: "index", dst, subject, index });
        return dst;
    }

    /**
     * Lowers a call through a function value. Arguments are lowered left to
     * right before the callee expression.
     */
    private lowerCallExpr(expr: Ast.Expr): LocalId {
        if (expr.kind.type !== "call") {
            throw new Error();
        }

        const args = expr.kind.args.map((arg) => local(this.lowerExpr(arg)));

        const subject = local(this.lowerExpr(expr.kind.subject));

        const subjectVType = expr.kind.subject.vtype!;
        if (subjectVType.type !== "fn") {
            throw new Error();
        }

        const dst = this.locals.alloc(subjectVType.returnType);
        this.addOp({ type: "call_val", dst, subject, args });
        return dst;
    }

    /**
     * Lowers a binary expression. Both operands are always evaluated, left
     * then right. Operand vtypes must be equal.
     */
    private lowerBinaryExpr(expr: Ast.Expr): LocalId {
        if (expr.kind.type !== "binary") {
            throw new Error();
        }
        const leftVType = expr.kind.left.vtype!;
        const rightVType = expr.kind.right.vtype!;
        if (!vtypesEqual(leftVType, rightVType)) {
            throw new Error();
        }

        const binaryType = expr.kind.binaryType;
        const left = local(this.lowerExpr(expr.kind.left));
        const right = local(this.lowerExpr(expr.kind.right));

        const dst = this.locals.alloc(expr.vtype!);
        this.addOp({ type: "binary", binaryType, dst, left, right });
        return dst;
    }

    /**
     * Lowers `if cond truthy [else falsy]`.
     *
     * Each branch assigns its value to a shared result local and jumps to a
     * join block, which becomes the current block. Without an `else`, the
     * false edge goes straight to the join block.
     */
    private lowerIfExpr(expr: Ast.Expr): LocalId {
        if (expr.kind.type !== "if") {
            throw new Error();
        }
        const cond = local(this.lowerExpr(expr.kind.cond));
        // Lowering the condition may itself open new blocks (a nested `if`,
        // a `loop`, ...). The branch terminator therefore belongs on the block
        // that is current *after* the condition is evaluated. Capturing it
        // earlier would clobber the nested construct's terminator.
        const condBlock = this.currentBlock();
        const end = this.reserveBlock();

        const val = this.locals.alloc(expr.vtype!);

        const truthy = this.pushBlock();
        const truthyVal = local(this.lowerExpr(expr.kind.truthy));
        this.addOp({ type: "assign", dst: val, src: truthyVal });
        this.setTer({ type: "jump", target: end });

        if (expr.kind.falsy) {
            const falsy = this.pushBlock();
            const falsyVal = local(this.lowerExpr(expr.kind.falsy));
            this.addOp({ type: "assign", dst: val, src: falsyVal });
            this.setTer({ type: "jump", target: end });

            this.setTerOn(condBlock, { type: "if", cond, truthy, falsy });
        } else {
            this.setTerOn(condBlock, { type: "if", cond, truthy, falsy: end });
        }

        this.pushBlockWithId(end);

        return val;
    }

    /**
     * Lowers `loop body`. The body block jumps back to itself. The loop's
     * value is whatever a `break` stores into the result local, and lowering
     * continues in the reserved break block.
     */
    private lowerLoopExpr(expr: Ast.Expr): LocalId {
        if (expr.kind.type !== "loop") {
            throw new Error();
        }

        const val = this.locals.alloc(expr.vtype!);
        const breakBlock = this.reserveBlock();
        this.breakStack.push({ local: val, block: breakBlock });

        const before = this.currentBlock();
        const body = this.pushBlock();
        this.setTerOn(before, { type: "jump", target: body });

        this.lowerExpr(expr.kind.body);
        this.setTer({ type: "jump", target: body });

        this.breakStack.pop();

        this.pushBlockWithId(breakBlock);
        return val;
    }

    /**
     * Lowers a block's statements in order, then its tail expression. A block
     * without a tail expression evaluates to a fresh `null` local.
     */
    private lowerBlockExpr(expr: Ast.Expr): LocalId {
        if (expr.kind.type !== "block") {
            throw new Error();
        }

        for (const stmt of expr.kind.stmts) {
            this.lowerStmt(stmt);
        }
        if (expr.kind.expr) {
            return this.lowerExpr(expr.kind.expr);
        } else {
            const dst = this.locals.alloc({ type: "null" });
            this.addOp({ type: "assign", dst, src: { type: "null" } });
            return dst;
        }
    }

    /** Appends an op to the current block. */
    private addOp(kind: OpKind) {
        this.blocks.get(this.currentBlockId)!.ops.push({ kind });
    }

    /** Appends an op to the given (already pushed) block. */
    private addOpOn(blockId: BlockId, kind: OpKind) {
        this.blocks.get(blockId)!.ops.push({ kind });
    }

    /** Replaces the current block's terminator. */
    private setTer(kind: TerKind) {
        this.blocks.get(this.currentBlockId)!.ter = { kind };
    }

    /** Replaces the terminator of the given (already pushed) block. */
    private setTerOn(blockId: BlockId, kind: TerKind) {
        this.blocks.get(blockId)!.ter = { kind };
    }

    private currentBlock(): BlockId {
        return this.currentBlockId;
    }

    /**
     * Claims a block id without creating the block. The block must later be
     * materialized with `pushBlockWithId`.
     */
    private reserveBlock(): BlockId {
        const id = this.blockIdCounter;
        this.blockIdCounter += 1;
        return id;
    }

    /**
     * Creates a new empty block with an `error` placeholder terminator and
     * makes it current.
     */
    private pushBlock(label?: string): BlockId {
        const id = this.blockIdCounter;
        this.blockIdCounter += 1;
        const ter: Ter = { kind: { type: "error" } };
        this.blocks.set(id, { id, ops: [], ter, label });
        this.currentBlockId = id;
        return id;
    }

    /**
     * Materializes a previously reserved block (empty, `error` terminator)
     * and makes it current.
     */
    private pushBlockWithId(id: BlockId): BlockId {
        const ter: Ter = { kind: { type: "error" } };
        this.blocks.set(id, { id, ops: [], ter });
        this.currentBlockId = id;
        return id;
    }
}

/** Wraps a local id as a `move` RValue operand. */
function local(id: LocalId): RValue {
    const operand: RValue = { type: "move", id };
    return operand;
}