import { Expr, Stmt } from "./ast.ts";
import { AstVisitor, visitExpr, VisitRes, visitStmts } from "./ast_visitor.ts";
import { GenericArgsMap, VType } from "./vtype.ts";

/**
 * Creates concrete function instances for everything reachable from `main`.
 *
 * Starting at the entry function, every call expression found in a function
 * body is resolved to the `fn` statement it targets. The callee is then
 * instantiated with the generic arguments of that call, and the id of the
 * call expression is recorded in the call map under the generated name of
 * the instance.
 */
export class Monomorphizer {
    private fnIdCounter = 0;
    private fns: MonoFnsMap = {};
    private callMap: MonoCallNameGenMap = {};
    private allFns: Map<number, Stmt>;
    private entryFn: Stmt;
    /**
     * Instances created so far, keyed by source function and generic
     * arguments. Lets repeated and recursive calls reuse an instance.
     */
    private instances = new Map<string, MonoFn>();

    /**
     * @param ast Program statements; all `fn` statements in it are collected.
     * @throws Error (via `findMain`) when no function named `main` exists.
     */
    constructor(private ast: Stmt[]) {
        this.allFns = new AllFnsCollector().collect(this.ast);
        this.entryFn = findMain(this.allFns);
    }

    /**
     * Runs the pass from the entry function.
     *
     * @returns The generated instances and the call-id to name mapping.
     */
    public monomorphize(): MonoResult {
        this.monomorphizeFn(this.entryFn);
        return { monoFns: this.fns, callMap: this.callMap };
    }

    /**
     * Returns the instance of `stmt` for `genericArgs`, creating it (and,
     * transitively, everything it calls) on first use.
     *
     * @param stmt The `fn` statement to instantiate.
     * @param genericArgs Concrete generic arguments, if the function has any.
     */
    private monomorphizeFn(
        stmt: Stmt,
        genericArgs?: GenericArgsMap,
    ): MonoFn {
        // The generated name embeds a fresh counter value, so it cannot be
        // used to recognise an instance that already exists. The lookup key
        // uses the statement id instead, which is stable per source function.
        const key = monoFnNameGen(stmt.id, stmt, genericArgs);
        const existing = this.instances.get(key);
        if (existing !== undefined) {
            return existing;
        }

        const id = this.fnIdCounter;
        this.fnIdCounter += 1;
        const nameGen = monoFnNameGen(id, stmt, genericArgs);
        const monoFn: MonoFn = { id, nameGen, stmt, genericArgs };
        // Register before walking the body so that recursive and mutually
        // recursive calls resolve to this instance instead of recursing
        // without end.
        this.instances.set(key, monoFn);
        this.fns[nameGen] = monoFn;

        // NOTE(review): the call map is keyed by expression id only. When a
        // generic function is instantiated more than once, the calls in its
        // body share ids and the instance processed last wins.
        for (const call of new CallCollector().collect(stmt)) {
            const callee = this.monomorphizeCallee(call, genericArgs);
            this.callMap[call.id] = callee.nameGen;
        }
        return monoFn;
    }

    /**
     * Resolves the function a call expression targets and instantiates it.
     *
     * @param call The call expression.
     * @param enclosingArgs Generic arguments of the instance whose body
     *     contains the call.
     * @throws Error when the subject of the call is not a function, or when
     *     a generic function is called without generic arguments.
     */
    private monomorphizeCallee(
        call: Expr,
        enclosingArgs?: GenericArgsMap,
    ): MonoFn {
        if (call.kind.type !== "call") {
            throw new Error(`expression ${call.id} is not a call`);
        }
        const kind = call.kind;
        const subjectType = kind.subject.vtype;

        if (subjectType?.type === "fn") {
            const target = this.lookupFn(subjectType.stmtId);
            if (subjectType.genericParams === undefined) {
                return this.monomorphizeFn(target);
            }
            if (kind.genericArgs === undefined) {
                throw new Error(
                    `call ${call.id} to a generic function has no generic arguments`,
                );
            }
            const monoArgs = this.resolveGenericArgs(
                kind.genericArgs,
                enclosingArgs,
            );
            return this.monomorphizeFn(target, monoArgs);
        }

        if (subjectType?.type === "generic_spec") {
            const fnType = subjectType.subject;
            if (fnType.type !== "fn") {
                throw new Error(
                    `call ${call.id} specializes something that is not a function`,
                );
            }
            const monoArgs = this.resolveGenericArgs(
                subjectType.genericArgs,
                enclosingArgs,
            );
            return this.monomorphizeFn(this.lookupFn(fnType.stmtId), monoArgs);
        }

        throw new Error(`subject of call ${call.id} is not a function`);
    }

    /**
     * Replaces generic parameters appearing as generic arguments of a call
     * with the concrete types bound in the enclosing instance.
     *
     * Only an argument that is itself a generic parameter is substituted; a
     * parameter nested inside another type is passed through unchanged.
     *
     * @param callArgs Generic arguments written at (or inferred for) the call.
     * @param enclosingArgs Bindings of the instance containing the call.
     * @throws Error when a generic parameter has no binding to substitute.
     */
    private resolveGenericArgs(
        callArgs: GenericArgsMap,
        enclosingArgs?: GenericArgsMap,
    ): GenericArgsMap {
        const resolved: GenericArgsMap = {};
        for (const key in callArgs) {
            const vtype = callArgs[key];
            if (vtype.type !== "generic") {
                resolved[key] = vtype;
                continue;
            }
            if (enclosingArgs === undefined) {
                throw new Error(
                    "generic parameter used in a function without generic arguments",
                );
            }
            const concrete = enclosingArgs[vtype.param.id];
            if (concrete === undefined) {
                throw new Error(
                    `no generic argument bound for parameter ${vtype.param.id}`,
                );
            }
            resolved[key] = concrete;
        }
        return resolved;
    }

    /**
     * Looks up a collected `fn` statement by id.
     *
     * @throws Error when no such statement was collected.
     */
    private lookupFn(stmtId: number): Stmt {
        const fn = this.allFns.get(stmtId);
        if (fn === undefined) {
            throw new Error(`no function statement with id ${stmtId}`);
        }
        return fn;
    }
}

/** Value returned by `Monomorphizer.monomorphize`. */
export type MonoResult = {
    /** Every generated function instance, keyed by its generated name. */
    monoFns: MonoFnsMap;
    /** Generated callee name for each call expression id that was visited. */
    callMap: MonoCallNameGenMap;
};

/** Generated function instances, looked up by generated name. */
export type MonoFnsMap = Record<string, MonoFn>;

/** One instance of a source function produced by the monomorphizer. */
export type MonoFn = {
    /** Counter value assigned by the monomorphizer for this instance. */
    id: number;
    /** Name produced by `monoFnNameGen`; also the key in `MonoFnsMap`. */
    nameGen: string;
    /** The `fn` statement this instance was created from. */
    stmt: Stmt;
    /** Generic arguments of this instance; absent when none were supplied. */
    genericArgs?: GenericArgsMap;
};

/** Generated callee name, looked up by the id of the call expression. */
export type MonoCallNameGenMap = Record<number, string>;

/**
 * Builds the generated name of a function instance.
 *
 * `main` always keeps the name `main`. Any other function is named
 * `<ident>_<id>`, followed by `_` and the `_`-joined name parts of its
 * generic arguments when a generic argument map is given.
 *
 * @param id Number embedded in the name after the identifier.
 * @param stmt The `fn` statement being named.
 * @param genericArgs Generic arguments of the instance, if any.
 * @returns The generated name.
 * @throws Error when `stmt` is not a `fn` statement.
 */
function monoFnNameGen(
    id: number,
    stmt: Stmt,
    genericArgs?: GenericArgsMap,
): string {
    const kind = stmt.kind;
    if (kind.type !== "fn") {
        throw new Error();
    }
    const ident = kind.ident;
    if (ident === "main") {
        return "main";
    }
    const base = `${ident}_${id}`;
    if (genericArgs === undefined) {
        return base;
    }
    const parts: string[] = [];
    for (const arg of Object.values(genericArgs)) {
        parts.push(vtypeNameGenPart(arg));
    }
    return `${base}_${parts.join("_")}`;
}

/**
 * Renders a value type as the text used inside a generated function name.
 *
 * Primitive types render as their type tag, references and pointers as a
 * prefix followed by the rendered subject, arrays in square brackets,
 * structs as a field listing and function types as `fn(<stmtId>)`.
 *
 * @param vtype The type to render.
 * @returns The rendered text.
 * @throws Error for `error`, `generic` and `generic_spec` types.
 */
function vtypeNameGenPart(vtype: VType): string {
    // Text placed in front of the subject of each reference/pointer type.
    const subjectPrefix = {
        ref: "&",
        ref_mut: "&mut ",
        ptr: "*",
        ptr_mut: "*mut ",
    } as const;

    switch (vtype.type) {
        case "generic":
        case "generic_spec":
            throw new Error("cannot be monomorphized");
        case "error":
            throw new Error("error in type");
        case "fn":
            return `fn(${vtype.stmtId})`;
        case "struct": {
            const rendered: string[] = [];
            for (const field of vtype.fields) {
                rendered.push(
                    `${field.ident}, ${vtypeNameGenPart(field.vtype)}`,
                );
            }
            return `struct { ${rendered.join(", ")} }`;
        }
        case "array":
            return `[${vtypeNameGenPart(vtype.subject)}]`;
        case "ref":
        case "ref_mut":
        case "ptr":
        case "ptr_mut":
            return subjectPrefix[vtype.type] + vtypeNameGenPart(vtype.subject);
        case "null":
        case "unknown":
        case "string":
        case "int":
        case "bool":
            return vtype.type;
    }
}

/**
 * AST visitor that gathers every `fn` statement it is shown, indexed by
 * statement id.
 */
export class AllFnsCollector implements AstVisitor {
    private readonly fnsById: Map<number, Stmt> = new Map();

    /**
     * Visits `ast` and returns the map of collected `fn` statements.
     *
     * @param ast Statements to visit.
     * @returns Collected statements, keyed by statement id.
     */
    public collect(ast: Stmt[]): Map<number, Stmt> {
        visitStmts(ast, this);
        return this.fnsById;
    }

    /**
     * Records a `fn` statement under its id.
     *
     * @throws Error when handed a statement that is not a `fn`.
     */
    visitFnStmt(stmt: Stmt): VisitRes {
        if (stmt.kind.type === "fn") {
            this.fnsById.set(stmt.id, stmt);
            return;
        }
        throw new Error();
    }
}

/**
 * Returns the first collected `fn` statement whose identifier is `main`.
 *
 * @param fns Function statements to search, in the map's iteration order.
 * @returns The statement of the `main` function.
 * @throws Error when there is no `main`; an explanation is written with
 *     `console.error` before throwing.
 */
function findMain(fns: Map<number, Stmt>): Stmt {
    // A plain loop is used instead of `fns.values().find(...)`: `find` on an
    // iterator is an ES2025 iterator helper and is missing on older runtimes.
    for (const stmt of fns.values()) {
        if (stmt.kind.type === "fn" && stmt.kind.ident === "main") {
            return stmt;
        }
    }
    console.error("error: cannot find function 'main'");
    console.error(apology);
    throw new Error("cannot find function 'main'");
}

/**
 * AST visitor that gathers the call expressions in the body of one
 * function.
 */
class CallCollector implements AstVisitor {
    private readonly found: Expr[] = [];

    /**
     * Visits the body of `fn` and returns the call expressions seen.
     *
     * @param fn The `fn` statement whose body is visited.
     * @returns The collected call expressions, in visiting order.
     * @throws Error when `fn` is not a `fn` statement.
     */
    public collect(fn: Stmt): Expr[] {
        const kind = fn.kind;
        if (kind.type !== "fn") {
            throw new Error();
        }
        visitExpr(kind.body, this);
        return this.found;
    }

    /**
     * Answers "stop" for every `fn` statement met inside the body.
     * Presumably this keeps the visitor out of nested functions — confirm
     * against the visitor implementation.
     */
    visitFnStmt(_stmt: Stmt): VisitRes {
        return "stop";
    }

    /**
     * Records a call expression.
     *
     * @throws Error when handed an expression that is not a call.
     */
    visitCallExpr(expr: Expr): VisitRes {
        if (expr.kind.type === "call") {
            this.found.push(expr);
            return;
        }
        throw new Error();
    }
}

// Long-form explanation written with `console.error` by `findMain` when the
// program has no function named `main`. The text below is emitted at runtime
// as written. After the literal is built, every run of four spaces (the
// source indentation of each line) is removed and the result is trimmed of
// leading and trailing whitespace.
const apology = `
    Hear me out. Monomorphization, meaning the process
    inwich generic functions are stamped out into seperate
    specialized functions is actually really hard, and I
    have a really hard time right now, figuring out, how
    to do it in a smart way. To really explain it, let's
    imagine you have a function, you defined as a<T>().
    For each call with seperate generics arguments given,
    such as a::<int>() and a::<string>(), a specialized
    function has to be 'stamped out', ie. created and put
    into the compilation with the rest of the program. Now
    to the reason as to why 'main' is needed. To do the
    monomorphization, we have to do it recursively. To
    explain this, imagine you have a generic function a<T>
    and inside the body of a<T>, you call another generic
    function such as b<T> with the same generic type. This
    means that the monomorphization process of b<T> depends
    on the monomorphization of a<T>. What this essentially
    means, is that the monomorphization process works on
    the program as a call graph, meaning a graph or tree
    structure where each represents a function call to
    either another function or a recursive call to the
    function itself. But a problem arises from doing it
    this way, which is that a call graph will need an
    entrypoint. The language, as it is currently, does
    not really require a 'main'-function. Or maybe it
    does, but that's beside the point. The point is that
    we need a main function, to be the entry point for
    the call graph. The monomorphization process then
    runs through the program from that entry point. This
    means that each function we call, will itself be
    monomorphized and added to the compilation. It also
    means that functions that are not called, will also
    not be added to the compilation. This essentially
    eliminates uncalled/dead functions. Is this
    particularly smart to do in such a high level part
    of the compilation process? I don't know. It's
    obvious that we can't just use every function as
    an entry point in the call graph, because we're
    actively added new functions. Additionally, with
    generic functions, we don't know, if they're the
    entry point, what generic arguments, they should
    be monomorphized with. We could do monomorphization
    the same way C++ does it, where all non-generic
    functions before monomorphization are treated as
    entry points in the call graph. But this has the
    drawback that generic and non-generic functions
    are treated differently, which has many underlying
    drawbacks, especially pertaining to the amount of
    work needed to handle both in all proceeding steps
    of the compiler. Anyways, I just wanted to yap and
    complain about the way generics and monomorphization
    has made the compiler 100x more complicated, and
    that I find it really hard to implement in a way,
    that is not either too simplistic or so complicated
    and advanced I'm too dumb to implement it. So if
    you would be so kind as to make it clear to the
    compiler, what function it should designate as
    the entry point to the call graph, it will use
    for monomorphization, that would be very kind of
    you. The way you do this, is by added or selecting
    one of your current functions and giving it the
    name of 'main'. This is spelled m-a-i-n. The word
    is synonemous with the words primary and principle.
    The name is meant to designate the entry point into
    the program, which is why the monomorphization
    process uses this specific function as the entry
    point into the call graph, it generates. So if you
    would be so kind as to do that, that would really
    make my day. In any case, keep hacking ferociously
    on whatever you're working on. I have monomorphizer
    to implement. See ya. -Your favorite compiler girl <3
`.replaceAll("    ", "").trim();