import { Builtins } from "./arch.ts";
import { EType, Expr, Stmt } from "./ast.ts";
import { printStackTrace, Reporter } from "./info.ts";
import { Pos } from "./token.ts";
import { VType, VTypeParam, vtypesEqual, vtypeToString } from "./vtype.ts";

/**
 * Type checker pass.
 *
 * Walks statements and expressions, computes a `VType` for every expression
 * (stored on `expr.vtype`), annotates `fn` statements and `let`/`fn`
 * parameters with their types, and reports type errors through the given
 * `Reporter`.
 *
 * Identifiers are expected to have been replaced by `sym` expressions before
 * this pass runs: `checkExpr` throws on a raw `ident` (or `error`) expression.
 */
export class Checker {
    /** Declared return types of the functions being checked, innermost last. */
    private fnReturnStack: VType[] = [];
    /**
     * One frame per enclosing loop, innermost last. Each frame collects the
     * types of the values yielded by the accepted `break` statements of that
     * loop.
     */
    private loopBreakStack: VType[][] = [];

    public constructor(private reporter: Reporter) {}

    /**
     * Checks a list of statements.
     *
     * Function headers are processed first so that calls to functions
     * declared later in the list already see a function type.
     */
    public check(stmts: Stmt[]) {
        this.checkFnHeaders(stmts);
        for (const stmt of stmts) {
            this.checkStmt(stmt);
        }
    }

    /**
     * Computes and stores the `fn` type of every `fn` statement in `stmts`
     * (parameter types and return type) without looking at the bodies.
     *
     * A missing return type means `null`. A parameter without a type is
     * reported and given the `error` type; the statement still receives a
     * `fn` type because `checkFnStmt` and `checkSymExpr` require one.
     */
    private checkFnHeaders(stmts: Stmt[]) {
        for (const stmt of stmts) {
            if (stmt.kind.type !== "fn") {
                continue;
            }
            const returnType: VType = stmt.kind.returnType
                ? this.checkEType(stmt.kind.returnType)
                : { type: "null" };
            const params: VTypeParam[] = [];
            for (const param of stmt.kind.params) {
                let vtype: VType;
                if (param.etype === undefined) {
                    this.report("parameter types must be defined", param.pos);
                    // Do not hand `undefined` to checkEType; it would crash.
                    vtype = { type: "error" };
                } else {
                    vtype = this.checkEType(param.etype);
                }
                param.vtype = vtype;
                params.push({ ident: param.ident, vtype });
            }
            stmt.kind.vtype = { type: "fn", params, returnType };
        }
    }

    /** Dispatches to the checker matching the statement kind. */
    public checkStmt(stmt: Stmt) {
        switch (stmt.kind.type) {
            case "error":
                return { type: "error" };
            case "break":
                return this.checkBreakStmt(stmt);
            case "return":
                return this.checkReturnStmt(stmt);
            case "fn":
                return this.checkFnStmt(stmt);
            case "let":
                return this.checkLetStmt(stmt);
            case "assign":
                return this.checkAssignStmt(stmt);
            case "expr":
                return this.checkExpr(stmt.kind.expr);
        }
    }

    /**
     * Checks a `break` statement.
     *
     * Reports a break outside of any loop. Otherwise the type of the break
     * value (`null` when absent) is recorded in the innermost loop frame,
     * provided it agrees with the previously recorded break of that loop.
     *
     * @throws Error if `stmt` is not a `break` statement.
     */
    public checkBreakStmt(stmt: Stmt) {
        if (stmt.kind.type !== "break") {
            throw new Error();
        }
        const pos = stmt.pos;
        if (this.loopBreakStack.length === 0) {
            this.report("cannot break outside loop context", pos);
            return;
        }
        const exprType: VType = stmt.kind.expr
            ? this.checkExpr(stmt.kind.expr)
            : { type: "null" };
        const breakTypes = this.loopBreakStack.at(-1)!;
        if (breakTypes.length === 0) {
            breakTypes.push(exprType);
            return;
        }
        const prevBreakType = breakTypes.at(-1)!;
        if (!vtypesEqual(prevBreakType, exprType)) {
            // Format through vtypeToString; interpolating the object itself
            // would print "[object Object]".
            this.report(
                `incompatible types for break` +
                    `, got '${vtypeToString(exprType)}'` +
                    ` incompatible with '${vtypeToString(prevBreakType)}'`,
                pos,
            );
            return;
        }
        breakTypes.push(exprType);
    }

    /**
     * Checks a `return` statement against the declared return type of the
     * innermost function being checked. A missing value counts as `null`.
     *
     * @throws Error if `stmt` is not a `return` statement.
     */
    public checkReturnStmt(stmt: Stmt) {
        if (stmt.kind.type !== "return") {
            throw new Error();
        }
        const pos = stmt.pos;
        if (this.fnReturnStack.length === 0) {
            this.report("cannot return outside fn context", pos);
            return;
        }
        const exprType: VType = stmt.kind.expr
            ? this.checkExpr(stmt.kind.expr)
            : { type: "null" };
        const returnType = this.fnReturnStack.at(-1)!;
        if (!vtypesEqual(exprType, returnType)) {
            this.report(
                `incompatible return type` +
                    `, got '${vtypeToString(exprType)}'` +
                    `, expected '${vtypeToString(returnType)}'`,
                pos,
            );
        }
    }

    /**
     * Checks the body of a `fn` statement against its declared return type.
     *
     * The statement must already carry the `fn` type assigned by
     * `checkFnHeaders`. Functions annotated `builtin` get their body replaced
     * by an empty block; since such a block always has type `null`, the
     * body/return-type comparison is skipped for them.
     *
     * @throws Error if `stmt` is not a `fn` statement or has no `fn` type.
     */
    public checkFnStmt(stmt: Stmt) {
        if (stmt.kind.type !== "fn") {
            throw new Error();
        }
        const pos = stmt.pos;
        if (stmt.kind.vtype!.type !== "fn") {
            throw new Error();
        }
        const { returnType } = stmt.kind.vtype!;
        this.fnReturnStack.push(returnType);

        const isBuiltin = stmt.kind.anno && stmt.kind.anno.ident === "builtin";
        if (isBuiltin) {
            stmt.kind.body.kind = { type: "block", stmts: [] };
        }
        const body = this.checkExpr(stmt.kind.body);

        this.fnReturnStack.pop();
        // The placeholder body of a builtin says nothing about what the
        // builtin actually returns, so only real bodies are compared.
        if (!isBuiltin && !vtypesEqual(returnType, body)) {
            this.report(
                `incompatible return type` +
                    `, expected '${vtypeToString(returnType)}'` +
                    `, got '${vtypeToString(body)}'`,
                pos,
            );
        }
    }

    /**
     * Checks a `let` statement and records the variable's type on its param.
     *
     * Without an explicit type the variable takes the type of its value. With
     * an explicit type the value must match it; on a mismatch the error is
     * reported and the declared type is recorded, so later uses of the symbol
     * still find a type.
     *
     * @throws Error if `stmt` is not a `let` statement.
     */
    public checkLetStmt(stmt: Stmt) {
        if (stmt.kind.type !== "let") {
            throw new Error();
        }
        const pos = stmt.pos;
        const value = this.checkExpr(stmt.kind.value);
        if (stmt.kind.param.etype) {
            const paramVtype = this.checkEType(stmt.kind.param.etype);
            if (!vtypesEqual(value, paramVtype)) {
                this.report(
                    `incompatible value type` +
                        `, got '${vtypeToString(value)}'` +
                        `, expected '${vtypeToString(paramVtype)}'`,
                    pos,
                );
                // Leaving vtype unset would make every later reference to
                // this symbol evaluate to `undefined`.
                stmt.kind.param.vtype = paramVtype;
                return;
            }
        }
        stmt.kind.param.vtype = value;
    }

    /**
     * Checks an assignment. The target must be a struct field, an array
     * element or a `let` symbol, and the assigned value must have the
     * target's type.
     *
     * @throws Error if `stmt` is not an `assign` statement.
     */
    public checkAssignStmt(stmt: Stmt) {
        if (stmt.kind.type !== "assign") {
            throw new Error();
        }
        const pos = stmt.pos;
        const value = this.checkExpr(stmt.kind.value);
        switch (stmt.kind.subject.kind.type) {
            case "field": {
                const subject = this.checkExpr(stmt.kind.subject.kind.subject);
                if (subject.type !== "struct") {
                    this.report("cannot use field on non-struct", pos);
                    return { type: "error" };
                }
                const fieldValue = stmt.kind.subject.kind.value;
                const found = subject.fields.find((param) =>
                    param.ident === fieldValue
                );
                if (!found) {
                    this.report(
                        `no field named '${stmt.kind.subject.kind.value}' on struct`,
                        pos,
                    );
                    return { type: "error" };
                }
                if (!vtypesEqual(found.vtype, value)) {
                    this.report(
                        `cannot assign incompatible type to field '${found.ident}'` +
                            `, got '${vtypeToString(value)}'` +
                            `, expected '${vtypeToString(found.vtype)}'`,
                        pos,
                    );
                    return;
                }
                return;
            }
            case "index": {
                const subject = this.checkExpr(stmt.kind.subject.kind.subject);
                if (subject.type !== "array") {
                    this.report("cannot index on non-array", pos);
                    return { type: "error" };
                }
                const indexValue = this.checkExpr(stmt.kind.subject.kind.value);
                if (indexValue.type !== "int") {
                    this.report("cannot index on array with non-int", pos);
                    return { type: "error" };
                }
                if (!vtypesEqual(subject.inner, value)) {
                    this.report(
                        `cannot assign incompatible type to array ` +
                            `'${vtypeToString(subject)}'` +
                            `, got '${vtypeToString(value)}'`,
                        pos,
                    );
                    return;
                }
                return;
            }
            case "sym": {
                if (stmt.kind.subject.kind.sym.type !== "let") {
                    this.report("can only assign to let-symbol", pos);
                    return { type: "error" };
                }
                if (
                    !vtypesEqual(stmt.kind.subject.kind.sym.param.vtype!, value)
                ) {
                    this.report(
                        `cannot assign to incompatible type` +
                            `, got '${vtypeToString(value)}'` +
                            `, expected '${
                                vtypeToString(
                                    stmt.kind.subject.kind.sym.param.vtype!,
                                )
                            }'`,
                        pos,
                    );
                    return;
                }
                return;
            }
            default:
                this.report("unassignable expression", pos);
                return;
        }
    }

    /**
     * Computes the type of an expression, stores it on `expr.vtype` and
     * returns it.
     *
     * @throws Error on `error` and `ident` expressions, which are not
     *     expected in the AST at this stage.
     */
    public checkExpr(expr: Expr): VType {
        const vtype = ((): VType => {
            switch (expr.kind.type) {
                case "error":
                    throw new Error("error in AST");
                case "ident":
                    throw new Error("ident expr in AST");
                case "sym":
                    return this.checkSymExpr(expr);
                case "null":
                    return { type: "null" };
                case "int":
                    return { type: "int" };
                case "bool":
                    return { type: "bool" };
                case "string":
                    return { type: "string" };
                case "group":
                    return this.checkExpr(expr.kind.expr);
                case "field":
                    return this.checkFieldExpr(expr);
                case "index":
                    return this.checkIndexExpr(expr);
                case "call":
                    return this.checkCallExpr(expr);
                case "unary":
                    return this.checkUnaryExpr(expr);
                case "binary":
                    return this.checkBinaryExpr(expr);
                case "if":
                    return this.checkIfExpr(expr);
                case "loop":
                    return this.checkLoopExpr(expr);
                case "block":
                    return this.checkBlockExpr(expr);
            }
            // throw new Error(`unhandled type ${expr.kind.type}`);
        })();
        return expr.vtype = vtype;
    }

    /**
     * Returns the type of a resolved symbol: the recorded parameter type for
     * `let` and `fn_param` symbols, the function type for `fn` symbols.
     *
     * @throws Error if `expr` is not a `sym` expression, if a `fn` symbol's
     *     statement has no `fn` type, or for the symbol kinds that are not
     *     implemented (`builtin`, `let_static`, `closure`).
     */
    public checkSymExpr(expr: Expr): VType {
        if (expr.kind.type !== "sym") {
            throw new Error();
        }
        switch (expr.kind.sym.type) {
            case "let":
                return expr.kind.sym.param.vtype!;
            case "fn": {
                const fnStmt = expr.kind.sym.stmt!;
                if (fnStmt.kind.type !== "fn") {
                    throw new Error();
                }
                const vtype = fnStmt.kind.vtype!;
                if (vtype.type !== "fn") {
                    throw new Error();
                }
                const { params, returnType } = vtype;
                return { type: "fn", params, returnType };
            }
            case "fn_param":
                return expr.kind.sym.param.vtype!;
            case "builtin":
            case "let_static":
            case "closure":
                throw new Error(
                    `not implemented, sym type '${expr.kind.sym.type}'`,
                );
        }
    }

    /**
     * Checks a field access: the subject must be a struct that has a field
     * with the requested name. Returns that field's type, or `error`.
     *
     * @throws Error if `expr` is not a `field` expression.
     */
    public checkFieldExpr(expr: Expr): VType {
        if (expr.kind.type !== "field") {
            throw new Error();
        }
        const pos = expr.pos;
        const subject = this.checkExpr(expr.kind.subject);
        if (subject.type !== "struct") {
            this.report("cannot use field on non-struct", pos);
            return { type: "error" };
        }
        const value = expr.kind.value;
        const found = subject.fields.find((param) => param.ident === value);
        if (!found) {
            this.report(
                `no field named '${expr.kind.value}' on struct`,
                pos,
            );
            return { type: "error" };
        }
        return found.vtype;
    }

    /**
     * Checks an index expression: the subject must be an array and the index
     * an `int`. Returns the array's element type, or `error`.
     *
     * @throws Error if `expr` is not an `index` expression.
     */
    public checkIndexExpr(expr: Expr): VType {
        if (expr.kind.type !== "index") {
            throw new Error();
        }
        const pos = expr.pos;
        const subject = this.checkExpr(expr.kind.subject);
        if (subject.type !== "array") {
            this.report("cannot index on non-array", pos);
            return { type: "error" };
        }
        const value = this.checkExpr(expr.kind.value);
        if (value.type !== "int") {
            this.report("cannot index on array with non-int", pos);
            return { type: "error" };
        }
        return subject.inner;
    }

    /**
     * Checks a call: the subject must be a function, the argument count must
     * match the parameter count, and each argument must have its parameter's
     * type. Only the first mismatching argument is reported.
     *
     * Returns the function's return type whenever the subject is a function
     * (even if the arguments are wrong), otherwise `error`.
     *
     * @throws Error if `expr` is not a `call` expression.
     */
    public checkCallExpr(expr: Expr): VType {
        if (expr.kind.type !== "call") {
            throw new Error();
        }
        const pos = expr.pos;
        const subject = this.checkExpr(expr.kind.subject);
        if (subject.type !== "fn") {
            this.report("cannot call non-fn", pos);
            return { type: "error" };
        }
        const args = expr.kind.args.map((arg) => this.checkExpr(arg));
        if (args.length !== subject.params.length) {
            this.report(
                `incorrect number of arguments` +
                    `, expected ${subject.params.length}` +
                    `, got ${args.length}`,
                pos,
            );
            // Arguments and parameters no longer line up; comparing them
            // pairwise would read past the end of the parameter list.
            return subject.returnType;
        }
        for (let i = 0; i < args.length; ++i) {
            const param = subject.params[i];
            if (!vtypesEqual(args[i], param.vtype)) {
                this.report(
                    `incorrect argument ${i} '${param.ident}'` +
                        `, expected ${vtypeToString(param.vtype)}` +
                        `, got ${vtypeToString(args[i])}`,
                    pos,
                );
                break;
            }
        }
        return subject.returnType;
    }

    /**
     * Checks a unary expression against `simpleUnaryOperations`. Returns the
     * operation's result type (the operand type when none is listed), or
     * `error` when no entry matches operator and operand type.
     *
     * @throws Error if `expr` is not a `unary` expression.
     */
    public checkUnaryExpr(expr: Expr): VType {
        if (expr.kind.type !== "unary") {
            throw new Error();
        }
        const pos = expr.pos;
        const subject = this.checkExpr(expr.kind.subject);
        for (const operation of simpleUnaryOperations) {
            if (operation.unaryType !== expr.kind.unaryType) {
                continue;
            }
            if (!vtypesEqual(operation.operand, subject)) {
                continue;
            }
            return operation.result ?? operation.operand;
        }
        this.report(
            `cannot apply unary operation '${expr.kind.unaryType}' ` +
                `on type '${vtypeToString(subject)}'`,
            pos,
        );
        return { type: "error" };
    }

    /**
     * Checks a binary expression against `simpleBinaryOperations`. Both
     * operands must have the same type and that type must be listed for the
     * operator. Returns the operation's result type (the operand type when
     * none is listed), or `error`.
     *
     * @throws Error if `expr` is not a `binary` expression.
     */
    public checkBinaryExpr(expr: Expr): VType {
        if (expr.kind.type !== "binary") {
            throw new Error();
        }
        const pos = expr.pos;
        const left = this.checkExpr(expr.kind.left);
        const right = this.checkExpr(expr.kind.right);
        for (const operation of simpleBinaryOperations) {
            if (operation.binaryType !== expr.kind.binaryType) {
                continue;
            }
            if (!vtypesEqual(operation.operand, left)) {
                continue;
            }
            if (!vtypesEqual(left, right)) {
                continue;
            }
            return operation.result ?? operation.operand;
        }
        this.report(
            `cannot apply binary operation '${expr.kind.binaryType}' ` +
                `on types '${vtypeToString(left)}' and '${
                    vtypeToString(right)
                }'`,
            pos,
        );
        return { type: "error" };
    }

    /**
     * Checks an `if` expression. The condition must be `bool`; without a
     * false-case the true-case must be `null`; with one, both cases must have
     * the same type. Returns the type of the true-case, or `error`.
     *
     * @throws Error if `expr` is not an `if` expression.
     */
    public checkIfExpr(expr: Expr): VType {
        if (expr.kind.type !== "if") {
            throw new Error();
        }
        const pos = expr.pos;
        const cond = this.checkExpr(expr.kind.cond);
        const truthy = this.checkExpr(expr.kind.truthy);
        const falsy = expr.kind.falsy
            ? this.checkExpr(expr.kind.falsy)
            : undefined;
        if (cond.type !== "bool") {
            this.report(
                `if condition should be 'bool', got '${vtypeToString(cond)}'`,
                pos,
            );
            return { type: "error" };
        }
        if (falsy === undefined && truthy.type !== "null") {
            this.report(
                `if expressions without false-case must result in type 'null'` +
                    `, got '${vtypeToString(truthy)}'`,
                pos,
            );
            return { type: "error" };
        }
        if (falsy !== undefined && !vtypesEqual(truthy, falsy)) {
            this.report(
                `if cases must be compatible, got incompatible types` +
                    ` '${vtypeToString(truthy)}'` +
                    ` and '${vtypeToString(falsy)}'`,
                pos,
            );
            return { type: "error" };
        }
        return truthy;
    }

    /**
     * Checks a `loop` expression. The body must have type `null`. The loop
     * itself has the type yielded by its `break` statements, or `null` when
     * it has none.
     *
     * @throws Error if `expr` is not a `loop` expression.
     */
    public checkLoopExpr(expr: Expr): VType {
        if (expr.kind.type !== "loop") {
            throw new Error();
        }
        const pos = expr.pos;
        this.loopBreakStack.push([]);
        const body = this.checkExpr(expr.kind.body);
        // Pop before any early return so this loop's frame never leaks into
        // the checking of surrounding code.
        const loopBreakTypes = this.loopBreakStack.pop()!;
        if (body.type !== "null") {
            this.report(
                `loop body must result in type 'null'` +
                    `, got '${vtypeToString(body)}'`,
                pos,
            );
            return { type: "error" };
        }
        if (loopBreakTypes.length === 0) {
            return { type: "null" };
        }
        // The first recorded break decides the loop's type. checkBreakStmt
        // only records breaks that agree with the previous one, so the
        // search below is a safety net.
        const breakType = loopBreakTypes[0];
        const outlier = loopBreakTypes.find((candidate) =>
            !vtypesEqual(breakType, candidate)
        );
        if (outlier !== undefined) {
            this.report(
                `incompatible types in break statements` +
                    `, got '${vtypeToString(outlier)}'` +
                    ` incompatible with '${vtypeToString(breakType)}'`,
                pos,
            );
            return { type: "error" };
        }
        return breakType;
    }

    /**
     * Checks a block: function headers first, then every statement, then the
     * optional trailing expression. Returns the trailing expression's type,
     * or `null` when the block has none.
     *
     * @throws Error if `expr` is not a `block` expression.
     */
    public checkBlockExpr(expr: Expr): VType {
        if (expr.kind.type !== "block") {
            throw new Error();
        }
        this.checkFnHeaders(expr.kind.stmts);
        for (const stmt of expr.kind.stmts) {
            this.checkStmt(stmt);
        }
        return expr.kind.expr
            ? this.checkExpr(expr.kind.expr)
            : { type: "null" };
    }

    /**
     * Converts an explicit type annotation into a `VType`.
     *
     * Recognised identifiers are `null`, `int`, `bool` and `string`; any
     * other identifier is reported as undefined. Array types are converted
     * recursively. Struct types require every field to be typed and field
     * names to be unique. Reported problems yield the `error` type.
     *
     * @throws Error for an annotation kind not handled here.
     */
    public checkEType(etype: EType): VType {
        const pos = etype.pos;
        if (etype.kind.type === "ident") {
            if (etype.kind.value === "null") {
                return { type: "null" };
            }
            if (etype.kind.value === "int") {
                return { type: "int" };
            }
            if (etype.kind.value === "bool") {
                return { type: "bool" };
            }
            if (etype.kind.value === "string") {
                return { type: "string" };
            }
            this.report(`undefined type '${etype.kind.value}'`, pos);
            return { type: "error" };
        }
        if (etype.kind.type === "array") {
            const inner = this.checkEType(etype.kind.inner);
            return { type: "array", inner };
        }
        if (etype.kind.type === "struct") {
            // Name the first field that actually lacks a type.
            const untyped = etype.kind.fields.find((param) => !param.etype);
            if (untyped) {
                this.report(
                    `field '${untyped.ident}' declared without type`,
                    pos,
                );
                return { type: "error" };
            }
            const seen = new Set<string>();
            for (const field of etype.kind.fields) {
                if (seen.has(field.ident)) {
                    this.report(`field ${field.ident} defined twice`, pos);
                    return { type: "error" };
                }
                seen.add(field.ident);
            }
            const fields = etype.kind.fields.map((param): VTypeParam => ({
                ident: param.ident,
                vtype: this.checkEType(param.etype!),
            }));
            return { type: "struct", fields };
        }
        throw new Error(`unknown explicit type ${etype.kind.type}`);
    }

    /** Forwards an error to the reporter and prints a stack trace. */
    private report(msg: string, pos: Pos) {
        this.reporter.reportError({ reporter: "Checker", msg, pos });
        printStackTrace();
    }
}

/**
 * Unary operators accepted by `Checker.checkUnaryExpr`.
 *
 * An entry applies when both the operator and the operand type match; the
 * expression then has type `result`, or the operand type when `result` is
 * omitted.
 */
const simpleUnaryOperations: Array<{
    unaryType: string;
    operand: VType;
    result?: VType;
}> = [
    // logical negation: bool -> bool
    { unaryType: "not", operand: { type: "bool" } },
];

/**
 * Binary operators accepted by `Checker.checkBinaryExpr`.
 *
 * An entry applies when the operator matches and both operands have the
 * entry's operand type; the expression then has type `result`, or the operand
 * type when `result` is omitted.
 */
const simpleBinaryOperations: {
    binaryType: string;
    operand: VType;
    result?: VType;
}[] = (() => {
    type Row = { binaryType: string; operand: VType; result?: VType };

    // Operator whose result has the same type as its operands.
    const closed = (binaryType: string, operand: VType): Row => ({
        binaryType,
        operand,
    });
    // Operator that yields a bool.
    const predicate = (binaryType: string, operand: VType): Row => ({
        binaryType,
        operand,
        result: { type: "bool" },
    });
    // Equality-style operator, defined for every primitive type.
    const onPrimitives = (binaryType: string): Row[] => [
        predicate(binaryType, { type: "null" }),
        predicate(binaryType, { type: "int" }),
        predicate(binaryType, { type: "string" }),
        predicate(binaryType, { type: "bool" }),
    ];

    return [
        // arithmetic
        closed("+", { type: "int" }),
        closed("+", { type: "string" }),
        closed("-", { type: "int" }),
        closed("*", { type: "int" }),
        closed("/", { type: "int" }),
        // logical
        closed("and", { type: "bool" }),
        closed("or", { type: "bool" }),
        // equality
        ...onPrimitives("=="),
        ...onPrimitives("!="),
        // comparison
        predicate("<", { type: "int" }),
        predicate(">", { type: "int" }),
        predicate("<=", { type: "int" }),
        predicate(">=", { type: "int" }),
    ];
})();