import { Builtins } from "./arch.ts";
import { Expr, Stmt } from "./ast.ts";
import { LocalLeaf, Locals, LocalsFnRoot } from "./lowerer_locals.ts";
import { Ops } from "./mod.ts";
import { Assembler, Label } from "./assembler.ts";
import { vtypeToString } from "./vtype.ts";

/**
 * Lowers a checked AST into bytecode through an {@link Assembler}.
 *
 * Conventions used throughout this class:
 * - Each lowered expression is expected to leave exactly one value on the
 *   stack. Expression statements pop it again, and blocks without a tail
 *   expression push `null`.
 * - Jumps are emitted as `PushPtr <target>` followed by `Jump` /
 *   `JumpIfTrue`, with the target on the stack rather than as an operand.
 * - Top-level functions are emitted first and skipped by an initial jump to
 *   `_start`. `_start` calls the `main` label with zero arguments and pops the
 *   result.
 * - `vtype` must already be set on expressions; binary and loop lowering read
 *   it with a non-null assertion.
 */
export class Lowerer {
    /** Assembler currently receiving instructions; swapped while lowering a function body. */
    private program = new Assembler();
    /** Innermost local-variable scope. */
    private locals: Locals = new LocalsFnRoot();
    /** fn statement id -> assembler label, filled by `scoutFnHeaders`. */
    private fnStmtIdLabelMap: { [key: number]: string } = {};
    /** Break targets of the enclosing loops in the current function, innermost last. */
    private breakStack: Label[] = [];

    /**
     * Lowers a whole program. Only `fn` statements are accepted at the top
     * level (see `lowerStaticStmt`). The `_start` stub calls the label `main`.
     */
    public lower(stmts: Stmt[]) {
        // Skip over all function bodies to the entry stub.
        this.program.add(Ops.PushPtr, { label: "_start" });
        this.program.add(Ops.Jump);
        this.scoutFnHeaders(stmts);
        for (const stmt of stmts) {
            this.lowerStaticStmt(stmt);
        }
        this.program.setLabel({ label: "_start" });
        this.program.add(Ops.PushPtr, { label: "main" });
        this.program.add(Ops.Call, 0);
        // Discard main's return value.
        this.program.add(Ops.Pop);
    }

    /** Assembles everything lowered so far into bytecode words. */
    public finish(): number[] {
        return this.program.assemble();
    }

    /**
     * Records the label of every `fn` statement in `stmts`. Because this runs
     * before the statements are lowered, functions can be referenced ahead
     * of their declaration.
     */
    private scoutFnHeaders(stmts: Stmt[]) {
        for (const stmt of stmts) {
            if (stmt.kind.type !== "fn") {
                continue;
            }
            this.fnStmtIdLabelMap[stmt.id] = this.fnLabel(stmt);
        }
    }

    /**
     * Label for a fn statement. `main` keeps its plain name so `_start` can
     * reference it. Every other function is suffixed with its statement id so
     * that same-named functions get distinct labels.
     */
    private fnLabel(stmt: Stmt): string {
        if (stmt.kind.type !== "fn") {
            throw new Error();
        }
        return stmt.kind.ident === "main"
            ? "main"
            : `${stmt.kind.ident}_${stmt.id}`;
    }

    /** Lowers a top-level statement; only function declarations are allowed here. */
    private lowerStaticStmt(stmt: Stmt) {
        switch (stmt.kind.type) {
            case "fn":
                return this.lowerFnStmt(stmt);
            case "error":
            case "break":
            case "return":
            case "let":
            case "assign":
            case "expr":
        }
        throw new Error(`unhandled static statement '${stmt.kind.type}'`);
    }

    /** Lowers a statement inside a block. Statements are meant to be stack-neutral. */
    private lowerStmt(stmt: Stmt) {
        switch (stmt.kind.type) {
            case "error":
                break;
            case "break":
                return this.lowerBreakStmt(stmt);
            case "return":
                break;
            case "fn": {
                // Nested functions are emitted inline in the enclosing code.
                // Jump over the body so execution does not fall into it and
                // return from the enclosing function early.
                const afterFnLabel = this.program.makeLabel();
                this.program.add(Ops.PushPtr, afterFnLabel);
                this.program.add(Ops.Jump);
                this.lowerFnStmt(stmt);
                this.program.setLabel(afterFnLabel);
                return;
            }
            case "let":
                return this.lowerLetStmt(stmt);
            case "assign":
                return this.lowerAssignStmt(stmt);
            case "expr":
                this.lowerExpr(stmt.kind.expr);
                // The expression's value is unused.
                this.program.add(Ops.Pop);
                return;
        }
        throw new Error(`unhandled stmt '${stmt.kind.type}'`);
    }

    /**
     * Lowers `subject = value`. The value is evaluated first, then the
     * target: struct fields via the `StructSet` builtin (object, field name),
     * array elements via `ArraySet` (array, index), and symbols via
     * `StoreLocal`.
     */
    private lowerAssignStmt(stmt: Stmt) {
        if (stmt.kind.type !== "assign") {
            throw new Error();
        }
        const subjectType = stmt.kind.subject.kind.type;
        this.lowerExpr(stmt.kind.value);
        switch (stmt.kind.subject.kind.type) {
            case "field": {
                this.lowerExpr(stmt.kind.subject.kind.subject);
                this.program.add(Ops.PushString, stmt.kind.subject.kind.value);
                this.program.add(Ops.Builtin, Builtins.StructSet);
                return;
            }
            case "index": {
                this.lowerExpr(stmt.kind.subject.kind.subject);
                this.lowerExpr(stmt.kind.subject.kind.value);
                this.program.add(Ops.Builtin, Builtins.ArraySet);
                return;
            }
            case "sym": {
                this.program.add(
                    Ops.StoreLocal,
                    this.locals.symId(stmt.kind.subject.kind.sym.ident),
                );
                return;
            }
            default:
                throw new Error(
                    `unhandled assignment subject '${subjectType}'`,
                );
        }
    }

    /**
     * Lowers `break [expr]`: pushes the optional value and jumps to the
     * innermost loop's break label.
     */
    private lowerBreakStmt(stmt: Stmt) {
        if (stmt.kind.type !== "break") {
            throw new Error();
        }
        const breakLabel = this.breakStack.at(-1);
        if (breakLabel === undefined) {
            throw new Error("'break' outside of loop");
        }
        if (stmt.kind.expr) {
            this.lowerExpr(stmt.kind.expr);
        }
        // Like every other jump here, the target goes on the stack.
        this.program.add(Ops.PushPtr, breakLabel);
        this.program.add(Ops.Jump);
    }

    /**
     * Emits the body of a `#[builtin(name)]` function. It expects exactly one
     * identifier argument naming the builtin.
     */
    private lowerBuiltinAnno(annoArgs: Expr[]) {
        if (annoArgs.length !== 1) {
            throw new Error("invalid # of arguments to builtin annotation");
        }
        const anno = annoArgs[0];
        if (anno.kind.type !== "ident") {
            throw new Error(
                `unexpected argument type '${anno.kind.type}' expected 'ident'`,
            );
        }
        const value = anno.kind.value;
        switch (value) {
            case "print": {
                this.program.add(Ops.Builtin, Builtins.Print);
                break;
            }
            default: {
                throw new Error(
                    `unrecognized builtin '${value}'`,
                );
            }
        }
    }

    /**
     * Emits a function at the current position: its label, one `PushNull`
     * per non-parameter local slot, the body, and a trailing `Return`.
     *
     * The body is lowered into a fresh Assembler first, because the number of
     * local slots (`fnRoot.stackReserved()`) is only known afterwards. The
     * reservation pushes are then emitted and the body appended. Callers must
     * keep control flow from falling into the emitted code: `lower` jumps over
     * top-level functions, and `lowerStmt` jumps over nested ones.
     */
    private lowerFnStmt(stmt: Stmt) {
        if (stmt.kind.type !== "fn") {
            throw new Error();
        }
        const outerProgram = this.program;
        const outerLocals = this.locals;
        const outerBreakStack = this.breakStack;

        outerProgram.setLabel({ label: this.fnLabel(stmt) });

        const fnRoot = new LocalsFnRoot(outerLocals);
        this.program = new Assembler();
        this.locals = fnRoot;
        // A `break` must never target a loop outside this function.
        this.breakStack = [];
        // Parameters are allocated first, in declaration order.
        for (const { ident } of stmt.kind.params) {
            this.locals.allocSym(ident);
        }
        if (stmt.kind.anno?.ident === "builtin") {
            this.lowerBuiltinAnno(stmt.kind.anno.values);
        } else {
            this.lowerExpr(stmt.kind.body);
        }
        this.program.add(Ops.Return);
        const bodyProgram = this.program;

        this.program = outerProgram;
        this.locals = outerLocals;
        this.breakStack = outerBreakStack;

        // Reserve the remaining (non-parameter) local slots ahead of the body.
        const localAmount = fnRoot.stackReserved() -
            stmt.kind.params.length;
        for (let i = 0; i < localAmount; ++i) {
            outerProgram.add(Ops.PushNull);
        }
        outerProgram.concat(bodyProgram);
    }

    /**
     * Lowers `let name = value`. The value is lowered before the symbol is
     * allocated, so the initializer cannot see the new binding.
     */
    private lowerLetStmt(stmt: Stmt) {
        if (stmt.kind.type !== "let") {
            throw new Error();
        }
        this.lowerExpr(stmt.kind.value);
        this.locals.allocSym(stmt.kind.param.ident);
        this.program.add(
            Ops.StoreLocal,
            this.locals.symId(stmt.kind.param.ident),
        );
    }

    /** Dispatches on expression kind; unsupported kinds throw. */
    private lowerExpr(expr: Expr) {
        switch (expr.kind.type) {
            case "error":
                break;
            case "sym":
                return this.lowerSymExpr(expr);
            case "null":
                this.program.add(Ops.PushNull);
                return;
            case "int":
                return this.lowerIntExpr(expr);
            case "bool":
                break;
            case "string":
                return this.lowerStringExpr(expr);
            case "ident":
                break;
            case "group":
                break;
            case "field":
                break;
            case "index":
                break;
            case "call":
                return this.lowerCallExpr(expr);
            case "unary":
                break;
            case "binary":
                return this.lowerBinaryExpr(expr);
            case "if":
                return this.lowerIfExpr(expr);
            case "loop":
                return this.lowerLoopExpr(expr);
            case "block":
                return this.lowerBlockExpr(expr);
        }
        throw new Error(`unhandled expr '${expr.kind.type}'`);
    }

    /**
     * Pushes a symbol's value. Locals and parameters are loaded from their
     * local slot; functions are pushed as a pointer to their label.
     */
    private lowerSymExpr(expr: Expr) {
        if (expr.kind.type !== "sym") {
            throw new Error();
        }
        const sym = expr.kind.sym;
        if (sym.type === "let" || sym.type === "fn_param") {
            this.program.add(
                Ops.LoadLocal,
                this.locals.symId(expr.kind.ident),
            );
            return;
        }
        if (sym.type === "fn") {
            const label = this.fnStmtIdLabelMap[sym.stmt.id];
            if (label === undefined) {
                throw new Error(`no label for fn '${expr.kind.ident}'`);
            }
            this.program.add(Ops.PushPtr, { label });
            return;
        }
        throw new Error(`unhandled sym type '${sym.type}'`);
    }

    /** Pushes an integer literal. */
    private lowerIntExpr(expr: Expr) {
        if (expr.kind.type !== "int") {
            throw new Error();
        }
        this.program.add(Ops.PushInt, expr.kind.value);
    }

    /** Pushes a string literal. */
    private lowerStringExpr(expr: Expr) {
        if (expr.kind.type !== "string") {
            throw new Error();
        }
        this.program.add(Ops.PushString, expr.kind.value);
    }

    /**
     * Lowers a binary expression. The left operand is evaluated first, then
     * the right. The instruction is chosen from the left operand's vtype.
     */
    private lowerBinaryExpr(expr: Expr) {
        if (expr.kind.type !== "binary") {
            throw new Error();
        }
        this.lowerExpr(expr.kind.left);
        this.lowerExpr(expr.kind.right);
        const vtype = expr.kind.left.vtype!;
        if (vtype.type === "int") {
            switch (expr.kind.binaryType) {
                case "+":
                    this.program.add(Ops.Add);
                    return;
                case "*":
                    this.program.add(Ops.Multiply);
                    return;
                case "==":
                    this.program.add(Ops.Equal);
                    return;
                case "!=":
                    this.program.add(Ops.Equal);
                    this.program.add(Ops.Not);
                    return;
                case ">=":
                    // a >= b  <=>  !(a < b)
                    this.program.add(Ops.LessThan);
                    this.program.add(Ops.Not);
                    return;
                default:
            }
        }
        if (vtype.type === "string") {
            if (expr.kind.binaryType === "+") {
                this.program.add(Ops.Builtin, Builtins.StringConcat);
                return;
            }
            if (expr.kind.binaryType === "==") {
                this.program.add(Ops.Builtin, Builtins.StringEqual);
                return;
            }
            if (expr.kind.binaryType === "!=") {
                this.program.add(Ops.Builtin, Builtins.StringEqual);
                this.program.add(Ops.Not);
                return;
            }
        }
        throw new Error(
            `unhandled binaryType` +
                ` '${vtypeToString(expr.vtype!)}' aka.` +
                ` '${vtypeToString(expr.kind.left.vtype!)}'` +
                ` ${expr.kind.binaryType}` +
                ` '${vtypeToString(expr.kind.right.vtype!)}'`,
        );
    }

    /**
     * Lowers a call. Arguments are pushed left to right, then the callee, and
     * `Call` is given the argument count.
     */
    private lowerCallExpr(expr: Expr) {
        if (expr.kind.type !== "call") {
            throw new Error();
        }
        for (const arg of expr.kind.args) {
            this.lowerExpr(arg);
        }
        this.lowerExpr(expr.kind.subject);
        this.program.add(Ops.Call, expr.kind.args.length);
    }

    /**
     * Lowers `if cond { truthy } else { falsy }`. The condition is negated so
     * that `JumpIfTrue` branches to the else arm. A missing else arm yields
     * `null`.
     */
    private lowerIfExpr(expr: Expr) {
        if (expr.kind.type !== "if") {
            throw new Error();
        }

        const falseLabel = this.program.makeLabel();
        const doneLabel = this.program.makeLabel();

        this.lowerExpr(expr.kind.cond);

        this.program.add(Ops.Not);
        this.program.add(Ops.PushPtr, falseLabel);
        this.program.add(Ops.JumpIfTrue);

        this.lowerExpr(expr.kind.truthy);

        this.program.add(Ops.PushPtr, doneLabel);
        this.program.add(Ops.Jump);

        this.program.setLabel(falseLabel);

        if (expr.kind.falsy) {
            this.lowerExpr(expr.kind.falsy);
        } else {
            this.program.add(Ops.PushNull);
        }

        this.program.setLabel(doneLabel);
    }

    /**
     * Lowers `loop { body }`. The body repeats until a `break` jumps to the
     * break label. The loop's value is the break value, or `null` when the
     * loop's vtype is null.
     */
    private lowerLoopExpr(expr: Expr) {
        if (expr.kind.type !== "loop") {
            throw new Error();
        }
        const continueLabel = this.program.makeLabel();
        const breakLabel = this.program.makeLabel();

        this.breakStack.push(breakLabel);

        this.program.setLabel(continueLabel);
        this.lowerExpr(expr.kind.body);
        // The body's value is unused; drop it so the stack does not grow
        // by one slot per iteration.
        this.program.add(Ops.Pop);
        // Repeat the body.
        this.program.add(Ops.PushPtr, continueLabel);
        this.program.add(Ops.Jump);
        this.program.setLabel(breakLabel);
        // A value-less `break` leaves nothing on the stack, so supply the
        // loop's null result.
        if (expr.vtype!.type === "null") {
            this.program.add(Ops.PushNull);
        }

        this.breakStack.pop();
    }

    /**
     * Lowers `{ stmts; expr? }` in a new local scope. Fn labels are scouted
     * first so functions can be referenced before their declaration. The
     * block's value is the tail expression, or `null` if there is none.
     */
    private lowerBlockExpr(expr: Expr) {
        if (expr.kind.type !== "block") {
            throw new Error();
        }
        const outerLocals = this.locals;
        this.locals = new LocalLeaf(this.locals);
        this.scoutFnHeaders(expr.kind.stmts);
        for (const stmt of expr.kind.stmts) {
            this.lowerStmt(stmt);
        }
        if (expr.kind.expr) {
            this.lowerExpr(expr.kind.expr);
        } else {
            this.program.add(Ops.PushNull);
        }
        this.locals = outerLocals;
    }

    /** Prints the current program via the Assembler. */
    public printProgram() {
        this.program.printProgram();
    }
}