import { exhausted } from "../util.ts";
import {
    ArrayTy,
    AssignStmt,
    BindPat,
    BreakStmt,
    EnumItem,
    Expr,
    ExprStmt,
    File,
    FnItem,
    Ident,
    Item,
    ItemStmt,
    LetStmt,
    ModBlockItem,
    ModFileItem,
    Pat,
    PtrTy,
    RefTy,
    ReturnStmt,
    SliceTy,
    StaticItem,
    Stmt,
    StructItem,
    Ty,
    TypeAliasItem,
    UseItem,
} from "./ast.ts";

export type VisitRes = "stop" | void;

type R = VisitRes;
type PM = unknown[];

export interface Visitor<
    P extends PM = [],
> {
    visitFile?(file: File, ...p: P): R;
    visitStmt?(stmt: Stmt, ...p: P): R;
    visitErrorStmt?(stmt: Stmt, ...p: P): R;
    visitItemStmt?(stmt: Stmt, kind: ItemStmt, ...p: P): R;
    visitLetStmt?(stmt: Stmt, kind: LetStmt, ...p: P): R;
    visitReturnStmt?(stmt: Stmt, kind: ReturnStmt, ...p: P): R;
    visitBreakStmt?(stmt: Stmt, kind: BreakStmt, ...p: P): R;
    visitAssignStmt?(stmt: Stmt, kind: AssignStmt, ...p: P): R;
    visitExprStmt?(stmt: Stmt, kind: ExprStmt, ...p: P): R;
    visitItem?(item: Item, ...p: P): R;
    visitErrorItem?(item: Item, ...p: P): R;
    visitModBlockItem?(item: Item, kind: ModBlockItem, ...p: P): R;
    visitModFileItem?(item: Item, kind: ModFileItem, ...p: P): R;
    visitEnumItem?(item: Item, kind: EnumItem, ...p: P): R;
    visitStructItem?(item: Item, kind: StructItem, ...p: P): R;
    visitFnItem?(item: Item, kind: FnItem, ...p: P): R;
    visitUseItem?(item: Item, kind: UseItem, ...p: P): R;
    visitStaticItem?(item: Item, kind: StaticItem, ...p: P): R;
    visitTypeAliasItem?(item: Item, kind: TypeAliasItem, ...p: P): R;
    visitExpr?(expr: Expr, ...p: P): R;
    visitErrorExpr?(expr: Expr, ...p: P): R;
    visitIdentExpr?(expr: Expr, kind: Ident, ...p: P): R;
    visitPat?(pat: Pat, ...p: P): R;
    visitErrorPat?(pat: Pat, ...p: P): R;
    visitBindPat?(pat: Pat, kind: BindPat, ...p: P): R;
    visitTy?(ty: Ty, ...p: P): R;
    visitErrorTy?(ty: Ty, ...p: P): R;
    visitIdentTy?(ty: Ty, kind: Ident, ...p: P): R;
    visitRefTy?(ty: Ty, kind: RefTy, ...p: P): R;
    visitPtrTy?(ty: Ty, kind: PtrTy, ...p: P): R;
    visitSliceTy?(ty: Ty, kind: SliceTy, ...p: P): R;
    visitArrayTy?(ty: Ty, kind: ArrayTy, ...p: P): R;
    visitIdent?(ident: Ident, ...p: P): R;
}

export function visitFile<
    P extends PM = [],
>(
    v: Visitor<P>,
    file: File,
    ...p: P
) {
    if (v.visitFile?.(file, ...p) === "stop") return;
    visitStmts(v, file.stmts, ...p);
}

export function visitStmts<
    P extends PM = [],
>(
    v: Visitor<P>,
    stmts: Stmt[],
    ...p: P
) {
    for (const stmt of stmts) {
        visitStmt(v, stmt, ...p);
    }
}

export function visitStmt<
    P extends PM = [],
>(
    v: Visitor<P>,
    stmt: Stmt,
    ...p: P
) {
    const kind = stmt.kind;
    switch (kind.tag) {
        case "error":
            if (v.visitErrorStmt?.(stmt, ...p) === "stop") return;
            return;
        case "item":
            if (v.visitItemStmt?.(stmt, kind, ...p) === "stop") return;
            return;
        case "let":
            if (v.visitLetStmt?.(stmt, kind, ...p) === "stop") return;
            return;
        case "return":
            if (v.visitReturnStmt?.(stmt, kind, ...p) === "stop") return;
            return;
        case "break":
            if (v.visitBreakStmt?.(stmt, kind, ...p) === "stop") return;
            return;
        case "assign":
            if (v.visitAssignStmt?.(stmt, kind, ...p) === "stop") return;
            return;
        case "expr":
            if (v.visitExprStmt?.(stmt, kind, ...p) === "stop") return;
            return;
    }
    exhausted(kind);
}

export function visitItem<
    P extends PM = [],
>(
    v: Visitor<P>,
    item: Item,
    ...p: P
) {
    const kind = item.kind;
    switch (kind.tag) {
        case "error":
            if (v.visitErrorItem?.(item, ...p) === "stop") return;
            return;
        case "mod_block":
            if (v.visitModBlockItem?.(item, kind, ...p) === "stop") return;
            return;
        case "mod_file":
            if (v.visitModFileItem?.(item, kind, ...p) === "stop") return;
            return;
        case "enum":
            if (v.visitEnumItem?.(item, kind, ...p) === "stop") return;
            return;
        case "struct":
            if (v.visitStructItem?.(item, kind, ...p) === "stop") return;
            return;
        case "fn":
            if (v.visitFnItem?.(item, kind, ...p) === "stop") return;
            return;
        case "use":
            if (v.visitUseItem?.(item, kind, ...p) === "stop") return;
            return;
        case "static":
            if (v.visitStaticItem?.(item, kind, ...p) === "stop") return;
            return;
        case "type_alias":
            if (v.visitTypeAliasItem?.(item, kind, ...p) === "stop") return;
            return;
    }
    exhausted(kind);
}

export function visitExpr<
    P extends PM = [],
>(
    v: Visitor<P>,
    expr: Expr,
    ...p: P
) {
    const kind = expr.kind;
    switch (kind.tag) {
        case "error":
            if (v.visitErrorExpr?.(expr, ...p) === "stop") return;
            return;
        case "ident":
            if (v.visitIdentExpr?.(expr, kind, ...p) === "stop") return;
            return;
    }
    exhausted(kind);
}

export function visitPat<
    P extends PM = [],
>(
    v: Visitor<P>,
    pat: Pat,
    ...p: P
) {
    const kind = pat.kind;
    switch (kind.tag) {
        case "error":
            if (v.visitErrorPat?.(pat, ...p) === "stop") return;
            return;
        case "bind":
            if (v.visitBindPat?.(pat, kind, ...p) === "stop") return;
            return;
    }
    exhausted(kind);
}

export function visitTy<
    P extends PM = [],
>(
    v: Visitor<P>,
    ty: Ty,
    ...p: P
) {
    const kind = ty.kind;
    switch (kind.tag) {
        case "error":
            if (v.visitErrorTy?.(ty, ...p) === "stop") return;
            return;
        case "ident":
            if (v.visitIdentTy?.(ty, kind, ...p) === "stop") return;
            return;
    }
    exhausted(kind);
}

export function visitIdent<
    P extends PM = [],
>(
    v: Visitor<P>,
    ident: Ident,
    ...p: P
) {
    v.visitIdent?.(ident, ...p);
}