import { BinaryType, Expr, Stmt } from "./ast.ts";
import { Ops } from "./mod.ts";
import { VType } from "./vtypes.ts";

/**
 * A lexical scope that maps symbol names to local-variable slot ids.
 *
 * Scopes form a chain: lookups that miss in one scope are forwarded to its
 * parent, and slot reservations are forwarded up to the function root, which
 * records how many slots the function needs.
 */
interface Locals {
    /** Records that slot `id` is in use, so at least `id + 1` slots are needed. */
    reserveId(id: number): void;
    /** Binds `ident` to the next free slot id in this scope (shadowing any previous binding). */
    allocSym(ident: string): void;
    /** Resolves `ident` to its slot id, searching enclosing scopes; throws if it is undefined. */
    symId(ident: string): number;
    /**
     * The next slot id this scope would hand out. A child scope starts its own
     * numbering here so its slots do not overlap slots still live in this scope.
     */
    currentLocalIdCounter(): number;
}

/**
 * The outermost scope of a function body. It owns the slot counter and the
 * high-water mark (`localsAmount`) of slots used by itself and nested scopes.
 */
class LocalsFnRoot implements Locals {
    /** Number of slots the function needs: highest reserved id + 1. */
    private localsAmount = 0;
    private localIdCounter = 0;
    // A Map rather than a plain object, so identifiers such as "toString" or
    // "constructor" don't resolve to inherited Object.prototype members.
    private symLocalMap = new Map<string, number>();

    constructor(private parent?: Locals) {
    }

    reserveId(id: number): void {
        this.localsAmount = Math.max(id + 1, this.localsAmount);
    }

    allocSym(ident: string): void {
        const id = this.localIdCounter++;
        this.symLocalMap.set(ident, id);
        // Reserve the id just handed out (not the next one), so that
        // `localsAmount` equals the number of slots actually in use.
        this.reserveId(id);
    }

    symId(ident: string): number {
        const id = this.symLocalMap.get(ident);
        if (id !== undefined) {
            return id;
        }
        if (this.parent) {
            return this.parent.symId(ident);
        }
        throw new Error(`undefined symbol '${ident}'`);
    }

    currentLocalIdCounter(): number {
        return this.localIdCounter;
    }
}

/**
 * A nested block scope inside a function. Its slots are numbered after the
 * slots the parent had allocated when this scope was created, and every
 * reservation is forwarded to the parent (ultimately the function root).
 */
class LocalLeaf implements Locals {
    private localIdCounter: number;
    // Map for the same prototype-safety reason as in LocalsFnRoot.
    private symLocalMap = new Map<string, number>();

    constructor(private parent: Locals) {
        // Starting at 0 would make this scope's locals overwrite slots that
        // are still live in the parent; continue from the parent's counter.
        this.localIdCounter = parent.currentLocalIdCounter();
    }

    reserveId(id: number): void {
        this.parent.reserveId(id);
    }

    allocSym(ident: string): void {
        const id = this.localIdCounter++;
        this.symLocalMap.set(ident, id);
        this.reserveId(id);
    }

    symId(ident: string): number {
        const id = this.symLocalMap.get(ident);
        if (id !== undefined) {
            return id;
        }
        return this.parent.symId(ident);
    }

    currentLocalIdCounter(): number {
        return this.localIdCounter;
    }
}

/**
 * Lowers AST statements into a flat bytecode program of numbers.
 *
 * Each instruction is an opcode from `Ops`, optionally followed by inline
 * operands (e.g. `PushInt <value>`, `StoreLocal <slot>`). Only a subset of
 * statement and expression kinds is supported so far; anything else throws.
 */
export class Lowerer {
    private program: number[] = [];
    private locals = new LocalsFnRoot();

    /**
     * Lowers `stmts` in order, appending their code to the program.
     *
     * @returns the accumulated program, including code from earlier calls.
     *   This is the lowerer's internal buffer, not a copy.
     * @throws Error on any statement or expression kind not yet supported.
     */
    lower(stmts: Stmt[]): number[] {
        for (const stmt of stmts) {
            this.lowerStmt(stmt);
        }
        return this.program;
    }

    /** Dispatches on statement kind; currently only `let` is implemented. */
    lowerStmt(stmt: Stmt) {
        switch (stmt.kind.type) {
            case "error":
            case "break":
            case "return":
            case "fn":
                break;
            case "let":
                return this.lowerLetStmt(stmt);
            case "assign":
            case "expr":
        }
        throw new Error(`Unhandled stmt ${stmt.kind.type}`);
    }

    /**
     * Emits `<value code>; StoreLocal <slot>` for a `let` statement.
     *
     * The initializer is lowered before the name is allocated, so a reference
     * to the same name inside the initializer resolves to an existing
     * (outer or earlier) binding rather than to the slot being declared.
     */
    lowerLetStmt(stmt: Stmt) {
        if (stmt.kind.type !== "let") {
            throw new Error();
        }
        this.lowerExpr(stmt.kind.value);
        this.locals.allocSym(stmt.kind.param.ident);
        this.program.push(Ops.StoreLocal);
        this.program.push(this.locals.symId(stmt.kind.param.ident));
    }

    /** Dispatches on expression kind; supports `int`, `binary` and `sym`. */
    lowerExpr(expr: Expr) {
        switch (expr.kind.type) {
            case "string":
            case "error":
                break;
            case "int":
                return this.lowerInt(expr);
            case "ident":
            case "group":
            case "field":
            case "index":
            case "call":
            case "unary":
                break;
            case "binary":
                return this.lowerBinaryExpr(expr);
            case "if":
            case "bool":
            case "null":
            case "loop":
            case "block":
                break;
            case "sym":
                return this.lowerSym(expr);
        }
        throw new Error(`Unhandled expr ${expr.kind.type}`);
    }

    /** Emits `PushInt <value>`. */
    lowerInt(expr: Expr) {
        if (expr.kind.type !== "int") {
            throw new Error();
        }
        this.program.push(Ops.PushInt);
        this.program.push(expr.kind.value);
    }

    /** Emits `LoadLocal <slot>` for a symbol bound by `let`. */
    lowerSym(expr: Expr) {
        if (expr.kind.type !== "sym") {
            throw new Error();
        }
        if (expr.kind.defType === "let") {
            this.program.push(Ops.LoadLocal);
            this.program.push(this.locals.symId(expr.kind.ident));
            return;
        }
        throw new Error(`Unhandled sym deftype ${expr.kind.defType}`);
    }

    /**
     * Emits both operands (left, then right) followed by the operator's opcode.
     * Only int `+` and `*` are implemented.
     *
     * @throws Error if the expression carries no `vtype`, or for any
     *   unsupported vtype/operator combination.
     */
    lowerBinaryExpr(expr: Expr) {
        if (expr.kind.type !== "binary") {
            throw new Error();
        }
        const vtype = expr.vtype;
        if (vtype == null) {
            // Fail with a clear message instead of a TypeError from `vtype.type`.
            throw new Error(
                `binary expression '${expr.kind.binaryType}' has no vtype`,
            );
        }
        this.lowerExpr(expr.kind.left);
        this.lowerExpr(expr.kind.right);
        if (vtype.type === "int") {
            switch (expr.kind.binaryType) {
                case "+":
                    this.program.push(Ops.Add);
                    return;
                case "*":
                    this.program.push(Ops.Multiply);
                    return;
                case "==":
                case "-":
                case "/":
                case "!=":
                case "<":
                case ">":
                case "<=":
                case ">=":
                case "or":
                case "and":
            }
        }
        throw new Error(
            `Unhandled vtype/binaryType '${
                vtype.type
            }/${expr.kind.binaryType}'`,
        );
    }
}