import { EType, Expr, Field, Param, Stmt } from "./ast.ts";

/**
 * Return type shared by every {@link AstVisitor} hook.
 *
 * A hook returns `"stop"` to tell the walker not to continue into the node it
 * was called for; returning nothing lets the walk proceed normally.
 */
export type VisitRes = void | "stop";

/**
 * Optional callbacks invoked while walking the AST with {@link visitStmts},
 * {@link visitStmt}, {@link visitExpr}, {@link visitParam}, {@link visitField}
 * and {@link visitEType}.
 *
 * Every hook is optional. For each node the walker first calls the generic
 * hook for its category (`visitStmt`, `visitExpr`, `visitEType`), then the hook
 * for the node's specific kind, then descends into the children. Returning
 * `"stop"` from any of these skips whatever would have come after it for that
 * node; returning nothing continues the walk.
 *
 * Members use method syntax (not function-typed properties) so implementers
 * keep method-style parameter checking.
 *
 * @typeParam Args - Extra arguments threaded unchanged through every hook.
 */
export interface AstVisitor<Args extends unknown[] = []> {
    // ---- Statements -------------------------------------------------------

    /**
     * Called once for each statement list passed to `visitStmts` (including
     * `mod_block` and `mod` bodies). Not called for the statements of a block
     * expression, which are walked one at a time.
     */
    visitStmts?(stmts: Stmt[], ...args: Args): VisitRes;
    /** Generic hook, called before the kind-specific statement hook. */
    visitStmt?(stmt: Stmt, ...args: Args): VisitRes;
    visitErrorStmt?(stmt: Stmt, ...args: Args): VisitRes;
    visitModFileStmt?(stmt: Stmt, ...args: Args): VisitRes;
    visitModBlockStmt?(stmt: Stmt, ...args: Args): VisitRes;
    visitModStmt?(stmt: Stmt, ...args: Args): VisitRes;
    visitBreakStmt?(stmt: Stmt, ...args: Args): VisitRes;
    visitReturnStmt?(stmt: Stmt, ...args: Args): VisitRes;
    visitFnStmt?(stmt: Stmt, ...args: Args): VisitRes;
    visitLetStmt?(stmt: Stmt, ...args: Args): VisitRes;
    visitTypeAliasStmt?(stmt: Stmt, ...args: Args): VisitRes;
    visitAssignStmt?(stmt: Stmt, ...args: Args): VisitRes;
    visitExprStmt?(stmt: Stmt, ...args: Args): VisitRes;

    // ---- Expressions ------------------------------------------------------

    /** Generic hook, called before the kind-specific expression hook. */
    visitExpr?(expr: Expr, ...args: Args): VisitRes;
    visitErrorExpr?(expr: Expr, ...args: Args): VisitRes;
    visitStringExpr?(expr: Expr, ...args: Args): VisitRes;
    visitIntExpr?(expr: Expr, ...args: Args): VisitRes;
    visitIdentExpr?(expr: Expr, ...args: Args): VisitRes;
    visitGroupExpr?(expr: Expr, ...args: Args): VisitRes;
    visitRefExpr?(expr: Expr, ...args: Args): VisitRes;
    visitRefMutExpr?(expr: Expr, ...args: Args): VisitRes;
    visitDerefExpr?(expr: Expr, ...args: Args): VisitRes;
    visitFieldExpr?(expr: Expr, ...args: Args): VisitRes;
    visitIndexExpr?(expr: Expr, ...args: Args): VisitRes;
    visitCallExpr?(expr: Expr, ...args: Args): VisitRes;
    visitPathExpr?(expr: Expr, ...args: Args): VisitRes;
    visitETypeArgsExpr?(expr: Expr, ...args: Args): VisitRes;
    visitUnaryExpr?(expr: Expr, ...args: Args): VisitRes;
    visitBinaryExpr?(expr: Expr, ...args: Args): VisitRes;
    visitArrayExpr?(expr: Expr, ...args: Args): VisitRes;
    visitStructExpr?(expr: Expr, ...args: Args): VisitRes;
    visitIfExpr?(expr: Expr, ...args: Args): VisitRes;
    visitBoolExpr?(expr: Expr, ...args: Args): VisitRes;
    visitNullExpr?(expr: Expr, ...args: Args): VisitRes;
    visitLoopExpr?(expr: Expr, ...args: Args): VisitRes;
    visitWhileExpr?(expr: Expr, ...args: Args): VisitRes;
    visitForInExpr?(expr: Expr, ...args: Args): VisitRes;
    visitForExpr?(expr: Expr, ...args: Args): VisitRes;
    visitBlockExpr?(expr: Expr, ...args: Args): VisitRes;
    visitSymExpr?(expr: Expr, ...args: Args): VisitRes;

    // ---- Parameters and struct-literal fields -----------------------------

    visitParam?(param: Param, ...args: Args): VisitRes;
    visitField?(field: Field, ...args: Args): VisitRes;

    // ---- Type expressions -------------------------------------------------

    /** Generic hook, called before the kind-specific type hook. */
    visitEType?(etype: EType, ...args: Args): VisitRes;
    visitErrorEType?(etype: EType, ...args: Args): VisitRes;
    visitStringEType?(etype: EType, ...args: Args): VisitRes;
    visitNullEType?(etype: EType, ...args: Args): VisitRes;
    visitIntEType?(etype: EType, ...args: Args): VisitRes;
    visitBoolEType?(etype: EType, ...args: Args): VisitRes;
    visitIdentEType?(etype: EType, ...args: Args): VisitRes;
    visitSymEType?(etype: EType, ...args: Args): VisitRes;
    visitRefEType?(etype: EType, ...args: Args): VisitRes;
    visitRefMutEType?(etype: EType, ...args: Args): VisitRes;
    visitPtrEType?(etype: EType, ...args: Args): VisitRes;
    visitPtrMutEType?(etype: EType, ...args: Args): VisitRes;
    visitArrayEType?(etype: EType, ...args: Args): VisitRes;
    visitStructEType?(etype: EType, ...args: Args): VisitRes;
    visitTypeOfEType?(etype: EType, ...args: Args): VisitRes;
    /**
     * Declared for visitor implementations, but none of the walker functions
     * in this module call it.
     */
    visitAnno?(etype: EType, ...args: Args): VisitRes;
}

/**
 * Walks a list of statements with the given visitor.
 *
 * `v.visitStmts` is called once for the whole list; if it returns `"stop"`,
 * none of the statements are visited. Otherwise each statement is passed to
 * {@link visitStmt} in order. A `"stop"` returned while visiting one statement
 * only prunes that statement's subtree; the remaining siblings are still
 * visited.
 *
 * @param stmts - Statements to walk.
 * @param v - Visitor whose optional hooks are invoked.
 * @param args - Extra arguments forwarded unchanged to every hook.
 */
export function visitStmts<Args extends unknown[] = []>(
    stmts: Stmt[],
    v: AstVisitor<Args>,
    ...args: Args
): void {
    if (v.visitStmts?.(stmts, ...args) === "stop") return;
    // Iterate purely for side effects; no result array is needed.
    for (const stmt of stmts) {
        visitStmt(stmt, v, ...args);
    }
}

/**
 * Walks a single statement and its children.
 *
 * Hooks fire in this order: `v.visitStmt` (generic), then the hook for
 * `stmt.kind.type` (e.g. `v.visitLetStmt`), then the children. If either hook
 * returns `"stop"`, the remaining hooks and children of this statement are
 * skipped.
 *
 * Children visited per kind:
 * - `mod_block`: `kind.stmts` via {@link visitStmts}
 * - `mod`: `kind.mod.ast` via {@link visitStmts}
 * - `break` / `return`: the optional `kind.expr`
 * - `fn`: each param, the optional return type, then the body
 * - `let`: the param, then the value
 * - `type_alias`: the param
 * - `assign`: the subject, then the value
 * - `expr`: the wrapped expression
 * - `error` / `mod_file`: none
 *
 * @param stmt - Statement to walk.
 * @param v - Visitor whose optional hooks are invoked.
 * @param args - Extra arguments forwarded unchanged to every hook.
 * @throws Error if `stmt.kind.type` is not one of the handled kinds.
 */
export function visitStmt<Args extends unknown[] = []>(
    stmt: Stmt,
    v: AstVisitor<Args>,
    ...args: Args
): void {
    if (v.visitStmt?.(stmt, ...args) === "stop") return;
    switch (stmt.kind.type) {
        case "error":
            if (v.visitErrorStmt?.(stmt, ...args) === "stop") return;
            break;
        case "mod_file":
            if (v.visitModFileStmt?.(stmt, ...args) === "stop") return;
            break;
        case "mod_block":
            if (v.visitModBlockStmt?.(stmt, ...args) === "stop") return;
            visitStmts(stmt.kind.stmts, v, ...args);
            break;
        case "mod":
            if (v.visitModStmt?.(stmt, ...args) === "stop") return;
            visitStmts(stmt.kind.mod.ast, v, ...args);
            break;
        case "break":
            if (v.visitBreakStmt?.(stmt, ...args) === "stop") return;
            if (stmt.kind.expr) visitExpr(stmt.kind.expr, v, ...args);
            break;
        case "return":
            if (v.visitReturnStmt?.(stmt, ...args) === "stop") return;
            if (stmt.kind.expr) visitExpr(stmt.kind.expr, v, ...args);
            break;
        case "fn":
            if (v.visitFnStmt?.(stmt, ...args) === "stop") return;
            for (const param of stmt.kind.params) {
                visitParam(param, v, ...args);
            }
            if (stmt.kind.returnType) {
                visitEType(stmt.kind.returnType, v, ...args);
            }
            visitExpr(stmt.kind.body, v, ...args);
            break;
        case "let":
            if (v.visitLetStmt?.(stmt, ...args) === "stop") return;
            visitParam(stmt.kind.param, v, ...args);
            visitExpr(stmt.kind.value, v, ...args);
            break;
        case "type_alias":
            if (v.visitTypeAliasStmt?.(stmt, ...args) === "stop") return;
            visitParam(stmt.kind.param, v, ...args);
            break;
        case "assign":
            if (v.visitAssignStmt?.(stmt, ...args) === "stop") return;
            visitExpr(stmt.kind.subject, v, ...args);
            visitExpr(stmt.kind.value, v, ...args);
            break;
        case "expr":
            if (v.visitExprStmt?.(stmt, ...args) === "stop") return;
            visitExpr(stmt.kind.expr, v, ...args);
            break;
        default:
            // Cast because an exhaustive switch narrows `stmt.kind` to `never` here.
            throw new Error(
                `statement '${
                    (stmt.kind as { type: string }).type
                }' not implemented`,
            );
    }
}

/**
 * Walks a single expression and its children.
 *
 * Hooks fire in this order: `v.visitExpr` (generic), then the hook for
 * `expr.kind.type` (e.g. `v.visitBinaryExpr`), then the children. If either
 * hook returns `"stop"`, the remaining hooks and children of this expression
 * are skipped; siblings walked by the caller are unaffected.
 *
 * The kinds `error`, `string`, `int`, `ident`, `bool`, `null` and `sym` have
 * no children. Statements inside a `block` are walked one at a time with
 * {@link visitStmt}, so `v.visitStmts` is not invoked for them.
 *
 * @param expr - Expression to walk.
 * @param v - Visitor whose optional hooks are invoked.
 * @param args - Extra arguments forwarded unchanged to every hook.
 * @throws Error if `expr.kind.type` is not one of the handled kinds.
 */
export function visitExpr<Args extends unknown[] = []>(
    expr: Expr,
    v: AstVisitor<Args>,
    ...args: Args
): void {
    if (v.visitExpr?.(expr, ...args) === "stop") return;
    switch (expr.kind.type) {
        case "error":
            if (v.visitErrorExpr?.(expr, ...args) === "stop") return;
            break;
        case "string":
            if (v.visitStringExpr?.(expr, ...args) === "stop") return;
            break;
        case "int":
            if (v.visitIntExpr?.(expr, ...args) === "stop") return;
            break;
        case "ident":
            if (v.visitIdentExpr?.(expr, ...args) === "stop") return;
            break;
        case "group":
            if (v.visitGroupExpr?.(expr, ...args) === "stop") return;
            visitExpr(expr.kind.expr, v, ...args);
            break;
        case "ref":
            if (v.visitRefExpr?.(expr, ...args) === "stop") return;
            visitExpr(expr.kind.subject, v, ...args);
            break;
        case "ref_mut":
            if (v.visitRefMutExpr?.(expr, ...args) === "stop") return;
            visitExpr(expr.kind.subject, v, ...args);
            break;
        case "deref":
            if (v.visitDerefExpr?.(expr, ...args) === "stop") return;
            visitExpr(expr.kind.subject, v, ...args);
            break;
        case "field":
            if (v.visitFieldExpr?.(expr, ...args) === "stop") return;
            visitExpr(expr.kind.subject, v, ...args);
            break;
        case "index":
            if (v.visitIndexExpr?.(expr, ...args) === "stop") return;
            visitExpr(expr.kind.subject, v, ...args);
            visitExpr(expr.kind.value, v, ...args);
            break;
        case "call":
            if (v.visitCallExpr?.(expr, ...args) === "stop") return;
            visitExpr(expr.kind.subject, v, ...args);
            for (const arg of expr.kind.args) {
                visitExpr(arg, v, ...args);
            }
            break;
        case "path":
            if (v.visitPathExpr?.(expr, ...args) === "stop") return;
            visitExpr(expr.kind.subject, v, ...args);
            break;
        case "etype_args":
            if (v.visitETypeArgsExpr?.(expr, ...args) === "stop") return;
            visitExpr(expr.kind.subject, v, ...args);
            for (const etypeArg of expr.kind.etypeArgs) {
                visitEType(etypeArg, v, ...args);
            }
            break;
        case "unary":
            if (v.visitUnaryExpr?.(expr, ...args) === "stop") return;
            visitExpr(expr.kind.subject, v, ...args);
            break;
        case "binary":
            if (v.visitBinaryExpr?.(expr, ...args) === "stop") return;
            visitExpr(expr.kind.left, v, ...args);
            visitExpr(expr.kind.right, v, ...args);
            break;
        case "array":
            if (v.visitArrayExpr?.(expr, ...args) === "stop") return;
            // `element` rather than `expr` to avoid shadowing the parameter.
            for (const element of expr.kind.exprs) {
                visitExpr(element, v, ...args);
            }
            break;
        case "struct":
            if (v.visitStructExpr?.(expr, ...args) === "stop") return;
            for (const field of expr.kind.fields) {
                visitField(field, v, ...args);
            }
            break;
        case "if":
            if (v.visitIfExpr?.(expr, ...args) === "stop") return;
            visitExpr(expr.kind.cond, v, ...args);
            visitExpr(expr.kind.truthy, v, ...args);
            if (expr.kind.falsy) visitExpr(expr.kind.falsy, v, ...args);
            break;
        case "bool":
            if (v.visitBoolExpr?.(expr, ...args) === "stop") return;
            break;
        case "null":
            if (v.visitNullExpr?.(expr, ...args) === "stop") return;
            break;
        case "loop":
            if (v.visitLoopExpr?.(expr, ...args) === "stop") return;
            visitExpr(expr.kind.body, v, ...args);
            break;
        case "while":
            if (v.visitWhileExpr?.(expr, ...args) === "stop") return;
            visitExpr(expr.kind.cond, v, ...args);
            visitExpr(expr.kind.body, v, ...args);
            break;
        case "for_in":
            if (v.visitForInExpr?.(expr, ...args) === "stop") return;
            visitParam(expr.kind.param, v, ...args);
            visitExpr(expr.kind.value, v, ...args);
            visitExpr(expr.kind.body, v, ...args);
            break;
        case "for":
            if (v.visitForExpr?.(expr, ...args) === "stop") return;
            if (expr.kind.decl) visitStmt(expr.kind.decl, v, ...args);
            if (expr.kind.cond) visitExpr(expr.kind.cond, v, ...args);
            if (expr.kind.incr) visitStmt(expr.kind.incr, v, ...args);
            visitExpr(expr.kind.body, v, ...args);
            break;
        case "block":
            if (v.visitBlockExpr?.(expr, ...args) === "stop") return;
            for (const stmt of expr.kind.stmts) {
                visitStmt(stmt, v, ...args);
            }
            if (expr.kind.expr) visitExpr(expr.kind.expr, v, ...args);
            break;
        case "sym":
            if (v.visitSymExpr?.(expr, ...args) === "stop") return;
            break;
        default:
            // Cast because an exhaustive switch narrows `expr.kind` to `never` here.
            throw new Error(
                `expression '${
                    (expr.kind as { type: string }).type
                }' not implemented`,
            );
    }
}

/**
 * Walks a parameter.
 *
 * Calls `v.visitParam` first. Unless it returns `"stop"`, the parameter's
 * type annotation (`param.etype`) is walked with {@link visitEType} when
 * present.
 *
 * @param param - Parameter to walk.
 * @param v - Visitor whose optional hooks are invoked.
 * @param args - Extra arguments forwarded unchanged to every hook.
 */
export function visitParam<Args extends unknown[] = []>(
    param: Param,
    v: AstVisitor<Args>,
    ...args: Args
): void {
    if (v.visitParam?.(param, ...args) === "stop") return;
    if (param.etype) visitEType(param.etype, v, ...args);
}

/**
 * Walks a struct-literal field.
 *
 * Calls `v.visitField` first. Unless it returns `"stop"`, the field's value
 * expression (`field.expr`) is walked with {@link visitExpr}.
 *
 * @param field - Field to walk.
 * @param v - Visitor whose optional hooks are invoked.
 * @param args - Extra arguments forwarded unchanged to every hook.
 */
export function visitField<Args extends unknown[] = []>(
    field: Field,
    v: AstVisitor<Args>,
    ...args: Args
): void {
    if (v.visitField?.(field, ...args) === "stop") return;
    visitExpr(field.expr, v, ...args);
}

/**
 * Walks a type expression (`EType`) and its children.
 *
 * Hooks fire in this order: `v.visitEType` (generic), then the hook for
 * `etype.kind.type` (e.g. `v.visitRefEType`), then the children. If either
 * hook returns `"stop"`, the remaining hooks and children of this node are
 * skipped.
 *
 * Children visited per kind:
 * - `ref`, `ref_mut`, `ptr`, `ptr_mut`, `array`: the `subject` type
 * - `struct`: each field, via {@link visitParam}
 * - `type_of`: the inner expression, via {@link visitExpr}
 * - all other kinds: none
 *
 * `v.visitAnno` is not called by this walker.
 *
 * @param etype - Type expression to walk.
 * @param v - Visitor whose optional hooks are invoked.
 * @param args - Extra arguments forwarded unchanged to every hook.
 * @throws Error if `etype.kind.type` is not one of the handled kinds.
 */
export function visitEType<Args extends unknown[] = []>(
    etype: EType,
    v: AstVisitor<Args>,
    ...args: Args
): void {
    if (v.visitEType?.(etype, ...args) === "stop") return;
    switch (etype.kind.type) {
        case "error":
            if (v.visitErrorEType?.(etype, ...args) === "stop") return;
            break;
        case "string":
            if (v.visitStringEType?.(etype, ...args) === "stop") return;
            break;
        case "null":
            if (v.visitNullEType?.(etype, ...args) === "stop") return;
            break;
        case "int":
            if (v.visitIntEType?.(etype, ...args) === "stop") return;
            break;
        case "bool":
            if (v.visitBoolEType?.(etype, ...args) === "stop") return;
            break;
        case "ident":
            if (v.visitIdentEType?.(etype, ...args) === "stop") return;
            break;
        case "sym":
            if (v.visitSymEType?.(etype, ...args) === "stop") return;
            break;
        case "ref":
            if (v.visitRefEType?.(etype, ...args) === "stop") return;
            visitEType(etype.kind.subject, v, ...args);
            break;
        case "ref_mut":
            if (v.visitRefMutEType?.(etype, ...args) === "stop") return;
            visitEType(etype.kind.subject, v, ...args);
            break;
        case "ptr":
            if (v.visitPtrEType?.(etype, ...args) === "stop") return;
            visitEType(etype.kind.subject, v, ...args);
            break;
        case "ptr_mut":
            if (v.visitPtrMutEType?.(etype, ...args) === "stop") return;
            visitEType(etype.kind.subject, v, ...args);
            break;
        case "array":
            if (v.visitArrayEType?.(etype, ...args) === "stop") return;
            visitEType(etype.kind.subject, v, ...args);
            break;
        case "struct":
            if (v.visitStructEType?.(etype, ...args) === "stop") return;
            for (const field of etype.kind.fields) {
                visitParam(field, v, ...args);
            }
            break;
        case "type_of":
            if (v.visitTypeOfEType?.(etype, ...args) === "stop") return;
            visitExpr(etype.kind.expr, v, ...args);
            break;
        default:
            // Cast because an exhaustive switch narrows `etype.kind` to `never` here.
            throw new Error(
                `etype '${
                    (etype.kind as { type: string }).type
                }' not implemented`,
            );
    }
}

/**
 * Renders a statement as a compact debug string of the form
 * `<kind>:<line><body>`.
 *
 * Only `assign` statements get a real body,
 * `{ subject: <expr>, value: <expr> }`, with both sides rendered by
 * {@link exprToString}. Every other kind renders its body as
 * `(<not implemented>)`.
 *
 * @param stmt - Statement to render.
 * @returns The debug representation.
 */
export function stmtToString(stmt: Stmt): string {
    const { kind, pos } = stmt;
    const body = kind.type === "assign"
        ? `{ subject: ${exprToString(kind.subject)}, value: ${
            exprToString(kind.value)
        } }`
        : "(<not implemented>)";
    return `${kind.type}:${pos.line}${body}`;
}

/**
 * Renders an expression as a compact debug string of the form
 * `<kind>:<line><body>`.
 *
 * Bodies by kind:
 * - `binary`: `(left op right)`, with both operands rendered recursively.
 * - `sym`: `(ident)`.
 * - every other kind: `(<not implemented>)`.
 *
 * @param expr - Expression to render.
 * @returns The debug representation.
 */
export function exprToString(expr: Expr): string {
    const { kind, pos } = expr;
    let body: string;
    if (kind.type === "binary") {
        const lhs = exprToString(kind.left);
        const rhs = exprToString(kind.right);
        body = `(${lhs} ${kind.binaryType} ${rhs})`;
    } else if (kind.type === "sym") {
        body = `(${kind.ident})`;
    } else {
        body = "(<not implemented>)";
    }
    return `${kind.type}:${pos.line}${body}`;
}