import { Builtins, Ops } from "./arch.ts";
import { Assembler, Label } from "./assembler.ts";
import { AnnoView, Expr, Stmt } from "./ast.ts";
import { LocalLeaf, Locals, LocalsFnRoot } from "./lowerer_locals.ts";
import { MonoCallNameGenMap, MonoFn, MonoFnsMap } from "./mono.ts";
import { Pos } from "./token.ts";
import { vtypeToString } from "./vtype.ts";

/** Maps the program counter at which a function's code starts to that function's name. */
export type FnNamesMap = Record<number, string>;

/**
 * Lowers a complete set of monomorphized functions into one flat program.
 *
 * Emitted layout:
 *   1. a preamble that resets the source map, calls `main` with zero
 *      arguments and then jumps to `_exit`;
 *   2. the code of every monomorphized function, one after another;
 *   3. the `_exit` epilogue.
 */
export class Lowerer {
    private program = Assembler.newRoot();

    public constructor(
        private monoFns: MonoFnsMap,
        private callMap: MonoCallNameGenMap,
        private lastPos: Pos,
    ) {}

    /**
     * Lowers and assembles the whole program.
     *
     * @returns the assembled instruction words, together with a map from the
     *          location of each function label to that function's name.
     */
    public lower(): { program: number[]; fnNames: FnNamesMap } {
        // A function's label is exactly its generated name.
        const labelToName: FnLabelMap = {};
        for (const nameGen in this.monoFns) {
            labelToName[nameGen] = nameGen;
        }

        this.emitPreamble();
        for (const monoFn of Object.values(this.monoFns)) {
            // Each function is lowered into its own fork and then spliced back.
            const fnLowerer = new MonoFnLowerer(
                monoFn,
                this.program.fork(),
                this.callMap,
            );
            this.program.join(fnLowerer.lower());
        }
        this.emitEpilogue();

        const assembled = this.program.assemble();
        const fnNames: FnNamesMap = {};
        for (const label in assembled.locs) {
            // Only function labels are reported; internal labels are skipped.
            if (!(label in labelToName)) {
                continue;
            }
            fnNames[assembled.locs[label]] = labelToName[label];
        }
        return { program: assembled.program, fnNames };
    }

    /** Entry sequence: call `main`, then skip over all function bodies. */
    private emitPreamble() {
        this.emitSourceMapReset();
        this.program.add(Ops.PushPtr, { label: "main" });
        this.program.add(Ops.Call, 0);
        this.program.add(Ops.PushPtr, { label: "_exit" });
        this.program.add(Ops.Jump);
    }

    /** Exit sequence reached after `main` returns. */
    private emitEpilogue() {
        this.program.setLabel({ label: "_exit" });
        this.emitSourceMap(this.lastPos);
        // Presumably discards the value `main` left on the stack.
        this.program.add(Ops.Pop);
    }

    /** Emits a `SourceMap` instruction pointing at `pos`. */
    private emitSourceMap(pos: Pos) {
        const { index, line, col } = pos;
        this.program.add(Ops.SourceMap, index, line, col);
    }

    /** Emits a `SourceMap` instruction for index 0, line 1, column 1. */
    private emitSourceMapReset() {
        this.program.add(Ops.SourceMap, 0, 1, 1);
    }

    /** Prints the (unassembled) program via the assembler. */
    public printProgram() {
        this.program.printProgram();
    }
}

/** Maps an assembler label to the generated name of the function it marks. */
type FnLabelMap = Record<string, string>;

/**
 * Lowers a single monomorphized function into assembler instructions.
 *
 * The instance writes into the forked `Assembler` it is given. `lower()`
 * returns that assembler so the caller can `join` it into the root program.
 */
class MonoFnLowerer {
    private locals: Locals = new LocalsFnRoot();
    /** Return-sequence labels; `return` jumps to the innermost one. */
    private returnStack: Label[] = [];
    /** Loop-exit labels; `break` jumps to the innermost one. */
    private breakStack: Label[] = [];

    public constructor(
        private fn: MonoFn,
        private program: Assembler,
        private callMap: MonoCallNameGenMap,
    ) {}

    /** Lowers `this.fn` and returns the assembler holding its code. */
    public lower(): Assembler {
        this.lowerFnStmt(this.fn.stmt);
        return this.program;
    }

    /**
     * Emits, in order:
     *   `<nameGen>:` SourceMap, one PushNull per non-parameter local,
     *   the body, `<returnLabel>:` Return.
     *
     * The number of locals is only known after the body has been lowered.
     * The body is therefore lowered into a fork and joined after the
     * PushNulls have been added to the outer program.
     */
    private lowerFnStmt(stmt: Stmt) {
        if (stmt.kind.type !== "fn") {
            throw new Error();
        }
        const label = this.fn.nameGen;
        this.program.setLabel({ label });
        this.addSourceMap(stmt.pos);

        const outerLocals = this.locals;
        const fnRoot = new LocalsFnRoot(outerLocals);
        const outerProgram = this.program;

        const returnLabel = this.program.makeLabel();
        this.returnStack.push(returnLabel);

        this.program = outerProgram.fork();
        this.locals = fnRoot;
        for (const { ident } of stmt.kind.params) {
            this.locals.allocSym(ident);
        }

        const annos = new AnnoView(stmt.details);
        if (annos.has("builtin")) {
            const anno = annos.get("builtin");
            if (!anno) {
                throw new Error();
            }
            this.lowerFnBuiltinBody(anno.args);
        } else if (annos.has("remainder")) {
            this.program.add(Ops.Remainder);
        } else {
            this.lowerExpr(stmt.kind.body);
        }
        this.locals = outerLocals;

        // Parameters are allocated first and are not reserved here; only the
        // remaining slots get a PushNull. Presumably the caller's arguments
        // already occupy the parameter slots — confirm against the VM's Call.
        const localAmount = fnRoot.stackReserved() -
            stmt.kind.params.length;
        for (let i = 0; i < localAmount; ++i) {
            outerProgram.add(Ops.PushNull);
        }

        this.returnStack.pop();
        this.program.setLabel(returnLabel);
        this.program.add(Ops.Return);

        outerProgram.join(this.program);
        this.program = outerProgram;
    }

    /** Emits a `SourceMap` instruction pointing at `pos`. */
    private addSourceMap({ index, line, col }: Pos) {
        this.program.add(Ops.SourceMap, index, line, col);
    }

    /** Emits a `SourceMap` instruction for index 0, line 1, column 1. */
    private addClearingSourceMap() {
        this.program.add(Ops.SourceMap, 0, 1, 1);
    }

    /** Renders an expression's checked type for use in error messages. */
    private vtypeName(expr: Expr): string {
        return expr.vtype ? vtypeToString(expr.vtype) : "<unchecked>";
    }

    private lowerStmt(stmt: Stmt) {
        switch (stmt.kind.type) {
            case "error":
                break;
            case "break":
                return this.lowerBreakStmt(stmt);
            case "return":
                return this.lowerReturnStmt(stmt);
            case "fn":
                // NOTE(review): lowerFnStmt labels the code with
                // `this.fn.nameGen` (the enclosing function's name) and emits
                // the nested body inline. Confirm nested fns never reach this
                // path, e.g. because they are lowered as their own MonoFns.
                return this.lowerFnStmt(stmt);
            case "let":
                return this.lowerLetStmt(stmt);
            case "assign":
                return this.lowerAssignStmt(stmt);
            case "expr":
                // Expression statements discard their value.
                this.lowerExpr(stmt.kind.expr);
                this.program.add(Ops.Pop);
                return;
        }
        throw new Error(`unhandled stmt '${stmt.kind.type}'`);
    }

    /** Emits the new value, then the store for the assignment target. */
    private lowerAssignStmt(stmt: Stmt) {
        if (stmt.kind.type !== "assign") {
            throw new Error();
        }
        this.lowerExpr(stmt.kind.value);
        switch (stmt.kind.subject.kind.type) {
            case "field": {
                this.lowerExpr(stmt.kind.subject.kind.subject);
                this.program.add(Ops.PushString, stmt.kind.subject.kind.ident);
                this.program.add(Ops.Builtin, Builtins.StructSet);
                return;
            }
            case "index": {
                this.lowerExpr(stmt.kind.subject.kind.subject);
                this.lowerExpr(stmt.kind.subject.kind.value);
                this.program.add(Ops.Builtin, Builtins.ArraySet);
                return;
            }
            case "sym": {
                this.program.add(
                    Ops.StoreLocal,
                    this.locals.symId(stmt.kind.subject.kind.sym.ident),
                );
                return;
            }
            default:
                throw new Error();
        }
    }

    /**
     * Leaves the return value on the stack and jumps to the innermost
     * function's return sequence.
     */
    private lowerReturnStmt(stmt: Stmt) {
        if (stmt.kind.type !== "return") {
            throw new Error();
        }
        if (stmt.kind.expr) {
            this.lowerExpr(stmt.kind.expr);
        } else {
            // Every other exit path leaves exactly one value: a block without
            // a tail expression pushes null, and callers Pop call results.
            // A bare `return` must therefore supply one too.
            this.program.add(Ops.PushNull);
        }
        this.addClearingSourceMap();
        this.program.add(Ops.PushPtr, this.returnStack.at(-1)!);
        this.program.add(Ops.Jump);
    }

    /**
     * Jumps to the innermost loop's exit label, optionally carrying a value.
     * A valueless break relies on lowerLoopExpr pushing null for null-typed
     * loops.
     */
    private lowerBreakStmt(stmt: Stmt) {
        if (stmt.kind.type !== "break") {
            throw new Error();
        }
        if (stmt.kind.expr) {
            this.lowerExpr(stmt.kind.expr);
        }
        this.addClearingSourceMap();
        this.program.add(Ops.PushPtr, this.breakStack.at(-1)!);
        this.program.add(Ops.Jump);
    }

    /**
     * Lowers a `#[builtin(Name)]` body to a single `Builtin` instruction.
     * `Name` is looked up as a key of the `Builtins` table.
     */
    private lowerFnBuiltinBody(annoArgs: Expr[]) {
        if (annoArgs.length !== 1) {
            throw new Error("invalid # of arguments to builtin annotation");
        }
        const anno = annoArgs[0];
        if (anno.kind.type !== "ident") {
            throw new Error(
                `unexpected argument type '${anno.kind.type}' expected 'ident'`,
            );
        }
        const value = anno.kind.ident;
        const builtin = Object.entries(Builtins).find((entry) =>
            entry[0] === value
        )?.[1];
        if (builtin === undefined) {
            throw new Error(
                `unrecognized builtin '${value}'`,
            );
        }
        this.program.add(Ops.Builtin, builtin);
    }

    private lowerLetStmt(stmt: Stmt) {
        if (stmt.kind.type !== "let") {
            throw new Error();
        }
        // The initializer is lowered before the new local is allocated.
        this.lowerExpr(stmt.kind.value);
        this.locals.allocSym(stmt.kind.param.ident);
        this.program.add(
            Ops.StoreLocal,
            this.locals.symId(stmt.kind.param.ident),
        );
    }

    /** Lowers an expression; each lowered expression pushes one value. */
    private lowerExpr(expr: Expr) {
        switch (expr.kind.type) {
            case "error":
                break;
            case "sym":
                return this.lowerSymExpr(expr);
            case "null":
                this.program.add(Ops.PushNull);
                return;
            case "int":
                return this.lowerIntExpr(expr);
            case "bool":
                return this.lowerBoolExpr(expr);
            case "string":
                return this.lowerStringExpr(expr);
            case "ident":
                break;
            case "group":
                return void this.lowerExpr(expr.kind.expr);
            case "field":
                return this.lowerFieldExpr(expr);
            case "index":
                return this.lowerIndexExpr(expr);
            case "call":
                return this.lowerCallExpr(expr);
            case "etype_args":
                return this.lowerETypeArgsExpr(expr);
            case "unary":
                return this.lowerUnaryExpr(expr);
            case "binary":
                return this.lowerBinaryExpr(expr);
            case "if":
                return this.lowerIfExpr(expr);
            case "loop":
                return this.lowerLoopExpr(expr);
            case "block":
                return this.lowerBlockExpr(expr);
        }
        throw new Error(`unhandled expr '${expr.kind.type}'`);
    }

    private lowerFieldExpr(expr: Expr) {
        if (expr.kind.type !== "field") {
            throw new Error();
        }
        this.lowerExpr(expr.kind.subject);
        this.program.add(Ops.PushString, expr.kind.ident);

        if (expr.kind.subject.vtype?.type === "struct") {
            this.program.add(Ops.Builtin, Builtins.StructAt);
            return;
        }
        throw new Error(
            `unhandled field subject type '${
                this.vtypeName(expr.kind.subject)
            }'`,
        );
    }

    private lowerIndexExpr(expr: Expr) {
        if (expr.kind.type !== "index") {
            throw new Error();
        }
        this.lowerExpr(expr.kind.subject);
        this.lowerExpr(expr.kind.value);

        if (expr.kind.subject.vtype?.type === "array") {
            this.program.add(Ops.Builtin, Builtins.ArrayAt);
            return;
        }
        if (expr.kind.subject.vtype?.type === "string") {
            this.program.add(Ops.Builtin, Builtins.StringCharAt);
            return;
        }
        throw new Error(
            `unhandled index subject type '${
                this.vtypeName(expr.kind.subject)
            }'`,
        );
    }

    private lowerSymExpr(expr: Expr) {
        if (expr.kind.type !== "sym") {
            throw new Error();
        }
        if (expr.kind.sym.type === "let" || expr.kind.sym.type === "fn_param") {
            this.program.add(
                Ops.LoadLocal,
                this.locals.symId(expr.kind.ident),
            );
            return;
        }
        if (expr.kind.sym.type === "fn") {
            // Is this smart? Well, my presumption is
            // that it isn't. The underlying problem, which
            // this solution's raison d'être is to solve, is
            // that the compiler, as it d'être's currently,
            // doesn't support checking and inferring generic
            // fn args all the way down to the sym. Therefore,
            // when a sym is checked in a call expr, we can't
            // really do anything useful. Instead the actual
            // function pointer pointing to the actual
            // monomorphized function is emplaced when
            // lowering the call expression itself (see
            // lowerCallExpr, which pops this placeholder and
            // pushes the resolved callee label). But what
            // should we do then, if the user decides to
            // assign a function to a local? You might ask.
            // You see, that's where the problem lies.
            // My current, very thought out solution, as
            // you can read below, is to push a null pointer,
            // for it to then be replaced later. This will
            // probably cause many hassles in the future
            // for myself in particular, when trying to
            // decipher the lowerer's output. So if you're
            // the unlucky girl, who has tried for ages to
            // decipher why a zero value is pushed and then
            // later replaced, and then you finally
            // stumbled upon this here implementation,
            // let me first say, I'm so sorry. At the time
            // of writing, I really haven't thought out
            // very well, how the generic call system should
            // work, and it's therefore a bit flaky, and the
            // implementation kinda looks like it was
            // implemented by a girl who didn't really
            // understand very well what they were
            // implementing at the time that they were
            // implementing it. Anyway, I just wanted to
            // apologize. Happy coding.
            // -Your favorite compiler girl.
            this.program.add(Ops.PushPtr, 0);
            return;
        }
        throw new Error(`unhandled sym type '${expr.kind.sym.type}'`);
    }

    private lowerIntExpr(expr: Expr) {
        if (expr.kind.type !== "int") {
            throw new Error();
        }
        this.program.add(Ops.PushInt, expr.kind.value);
    }

    private lowerBoolExpr(expr: Expr) {
        if (expr.kind.type !== "bool") {
            throw new Error();
        }
        this.program.add(Ops.PushBool, expr.kind.value);
    }

    private lowerStringExpr(expr: Expr) {
        if (expr.kind.type !== "string") {
            throw new Error();
        }
        this.program.add(Ops.PushString, expr.kind.value);
    }

    private lowerUnaryExpr(expr: Expr) {
        if (expr.kind.type !== "unary") {
            throw new Error();
        }
        this.lowerExpr(expr.kind.subject);
        const vtype = expr.kind.subject.vtype!;
        if (vtype.type === "bool") {
            switch (expr.kind.unaryType) {
                case "not":
                    this.program.add(Ops.Not);
                    return;
                default:
            }
        }
        if (vtype.type === "int") {
            switch (expr.kind.unaryType) {
                case "-": {
                    // Negation is computed as `0 - x`.
                    this.program.add(Ops.PushInt, 0);
                    this.program.add(Ops.Swap);
                    this.program.add(Ops.Subtract);
                    return;
                }
                default:
            }
        }
        throw new Error(
            `unhandled unary` +
                ` '${vtypeToString(expr.vtype!)}' aka. ` +
                ` ${expr.kind.unaryType}` +
                ` '${vtypeToString(expr.kind.subject.vtype!)}'`,
        );
    }

    /**
     * Compares the two bools on top of the stack (`[left, right]`) and
     * leaves one bool: `left == right`, or `left != right` when `negate`.
     *
     * Implemented with a branch on `right`, so only ops whose bool semantics
     * this lowerer already relies on are used (JumpIfTrue, Not, Jump):
     *   left == right  ≡  right ? left : !left
     *   left != right  ≡  right ? !left : left
     */
    private lowerBoolEquality(negate: boolean) {
        const rightTrueLabel = this.program.makeLabel();
        const doneLabel = this.program.makeLabel();
        this.program.add(Ops.PushPtr, rightTrueLabel);
        this.program.add(Ops.JumpIfTrue);
        // right is false
        if (!negate) {
            this.program.add(Ops.Not);
        }
        this.program.add(Ops.PushPtr, doneLabel);
        this.program.add(Ops.Jump);
        this.program.setLabel(rightTrueLabel);
        // right is true
        if (negate) {
            this.program.add(Ops.Not);
        }
        this.program.setLabel(doneLabel);
    }

    private lowerBinaryExpr(expr: Expr) {
        if (expr.kind.type !== "binary") {
            throw new Error();
        }
        const vtype = expr.kind.left.vtype!;
        if (vtype.type === "bool") {
            if (["or", "and"].includes(expr.kind.binaryType)) {
                // Short circuit: keep a copy of `left` and jump straight to the
                // end if it already decides the result, leaving it as the value.
                const shortCircuitLabel = this.program.makeLabel();
                this.lowerExpr(expr.kind.left);
                this.program.add(Ops.Duplicate);
                if (expr.kind.binaryType === "and") {
                    this.program.add(Ops.Not);
                }
                this.program.add(Ops.PushPtr, shortCircuitLabel);
                this.program.add(Ops.JumpIfTrue);
                this.program.add(Ops.Pop);
                this.lowerExpr(expr.kind.right);
                this.program.setLabel(shortCircuitLabel);
                return;
            }
        }
        this.lowerExpr(expr.kind.left);
        this.lowerExpr(expr.kind.right);
        if (vtype.type === "int") {
            switch (expr.kind.binaryType) {
                case "+":
                    this.program.add(Ops.Add);
                    return;
                case "-":
                    this.program.add(Ops.Subtract);
                    return;
                case "*":
                    this.program.add(Ops.Multiply);
                    return;
                case "/":
                    this.program.add(Ops.Divide);
                    return;
                case "==":
                    this.program.add(Ops.Equal);
                    return;
                case "!=":
                    this.program.add(Ops.Equal);
                    this.program.add(Ops.Not);
                    return;
                case "<":
                    this.program.add(Ops.LessThan);
                    return;
                case ">":
                    // a > b  ≡  b < a
                    this.program.add(Ops.Swap);
                    this.program.add(Ops.LessThan);
                    return;
                case "<=":
                    // a <= b  ≡  !(b < a)
                    this.program.add(Ops.Swap);
                    this.program.add(Ops.LessThan);
                    this.program.add(Ops.Not);
                    return;
                case ">=":
                    // a >= b  ≡  !(a < b)
                    this.program.add(Ops.LessThan);
                    this.program.add(Ops.Not);
                    return;
                default:
            }
        }
        if (vtype.type === "bool") {
            switch (expr.kind.binaryType) {
                case "==":
                    this.lowerBoolEquality(false);
                    return;
                case "!=":
                    this.lowerBoolEquality(true);
                    return;
                default:
            }
        }
        if (vtype.type === "string") {
            if (expr.kind.binaryType === "+") {
                this.program.add(Ops.Builtin, Builtins.StringConcat);
                return;
            }
            if (expr.kind.binaryType === "==") {
                this.program.add(Ops.Builtin, Builtins.StringEqual);
                return;
            }
            if (expr.kind.binaryType === "!=") {
                this.program.add(Ops.Builtin, Builtins.StringEqual);
                this.program.add(Ops.Not);
                return;
            }
        }
        throw new Error(
            `unhandled binaryType` +
                ` '${vtypeToString(expr.vtype!)}' aka. ` +
                ` '${vtypeToString(expr.kind.left.vtype!)}'` +
                ` ${expr.kind.binaryType}` +
                ` '${vtypeToString(expr.kind.right.vtype!)}'`,
        );
    }

    /**
     * Pushes the arguments, then calls the monomorphized callee that
     * `callMap` records for this call expression. The subject is still
     * lowered (and its placeholder value popped) so that any code it emits
     * is kept.
     */
    private lowerCallExpr(expr: Expr) {
        if (expr.kind.type !== "call") {
            throw new Error();
        }
        const calleeLabel = this.callMap[expr.id];
        if (calleeLabel === undefined) {
            throw new Error(
                `no monomorphized callee recorded for call expr '${expr.id}'`,
            );
        }
        for (const arg of expr.kind.args) {
            this.lowerExpr(arg);
        }
        this.lowerExpr(expr.kind.subject);
        this.program.add(Ops.Pop);
        this.program.add(Ops.PushPtr, { label: calleeLabel });
        this.program.add(Ops.Call, expr.kind.args.length);
    }

    private lowerETypeArgsExpr(expr: Expr) {
        if (expr.kind.type !== "etype_args") {
            throw new Error();
        }
        this.lowerExpr(expr.kind.subject);
    }

    /** `cond ? truthy : falsy`; a missing else branch yields null. */
    private lowerIfExpr(expr: Expr) {
        if (expr.kind.type !== "if") {
            throw new Error();
        }

        const falseLabel = this.program.makeLabel();
        const doneLabel = this.program.makeLabel();

        this.lowerExpr(expr.kind.cond);

        this.program.add(Ops.Not);
        this.addClearingSourceMap();
        this.program.add(Ops.PushPtr, falseLabel);
        this.program.add(Ops.JumpIfTrue);

        this.addSourceMap(expr.kind.truthy.pos);
        this.lowerExpr(expr.kind.truthy);

        this.addClearingSourceMap();
        this.program.add(Ops.PushPtr, doneLabel);
        this.program.add(Ops.Jump);

        this.program.setLabel(falseLabel);

        if (expr.kind.falsy) {
            this.addSourceMap(expr.kind.elsePos!);
            this.lowerExpr(expr.kind.falsy);
        } else {
            this.program.add(Ops.PushNull);
        }

        this.program.setLabel(doneLabel);
    }

    /**
     * Infinite loop whose body value is discarded each iteration. It is
     * exited only via `break`. For null-typed loops null is pushed at the
     * exit label, since such breaks carry no value.
     */
    private lowerLoopExpr(expr: Expr) {
        if (expr.kind.type !== "loop") {
            throw new Error();
        }
        const continueLabel = this.program.makeLabel();
        const breakLabel = this.program.makeLabel();

        this.breakStack.push(breakLabel);

        this.program.setLabel(continueLabel);
        this.addSourceMap(expr.kind.body.pos);
        this.lowerExpr(expr.kind.body);
        this.program.add(Ops.Pop);
        this.addClearingSourceMap();
        this.program.add(Ops.PushPtr, continueLabel);
        this.program.add(Ops.Jump);
        this.program.setLabel(breakLabel);
        if (expr.vtype!.type === "null") {
            this.program.add(Ops.PushNull);
        }
        this.breakStack.pop();
    }

    /**
     * Lowers statements in a new local scope. The value is the tail
     * expression, or null when there is none.
     */
    private lowerBlockExpr(expr: Expr) {
        if (expr.kind.type !== "block") {
            throw new Error();
        }
        const outerLocals = this.locals;
        this.locals = new LocalLeaf(this.locals);
        for (const stmt of expr.kind.stmts) {
            this.addSourceMap(stmt.pos);
            this.lowerStmt(stmt);
        }
        if (expr.kind.expr) {
            this.addSourceMap(expr.kind.expr.pos);
            this.lowerExpr(expr.kind.expr);
        } else {
            this.program.add(Ops.PushNull);
        }
        this.locals = outerLocals;
    }
}