import type { EType, Expr, Field, Param, Stmt } from "./ast.ts";

/**
 * Result of a visitor hook.
 *
 * Returning `"stop"` makes the walker return immediately for the current node:
 * no more specific hook is called and none of the node's children are visited.
 * Returning nothing (`undefined`) lets the walk continue.
 */
export type VisitRes = "stop" | void;

/**
 * Optional hooks invoked by the `visit*` walkers in this module.
 *
 * For each node a walker calls, in order: the generic hook for the node's
 * category (`visitStmt`, `visitExpr`, `visitEType`), then the kind-specific
 * hook (e.g. `visitLetStmt`), then recurses into the node's children.
 *
 * `Args` is an extra argument tuple that the walkers forward unchanged to
 * every hook and every recursive call (e.g. a context object). It defaults to
 * no extra arguments.
 */
export interface AstVisitor<Args extends unknown[] = []> {
  visitStmts?(stmts: Stmt[], ...args: Args): VisitRes;
  visitStmt?(stmt: Stmt, ...args: Args): VisitRes;
  visitErrorStmt?(stmt: Stmt, ...args: Args): VisitRes;
  visitModFileStmt?(stmt: Stmt, ...args: Args): VisitRes;
  visitModBlockStmt?(stmt: Stmt, ...args: Args): VisitRes;
  visitModStmt?(stmt: Stmt, ...args: Args): VisitRes;
  visitBreakStmt?(stmt: Stmt, ...args: Args): VisitRes;
  visitReturnStmt?(stmt: Stmt, ...args: Args): VisitRes;
  visitFnStmt?(stmt: Stmt, ...args: Args): VisitRes;
  visitLetStmt?(stmt: Stmt, ...args: Args): VisitRes;
  visitAssignStmt?(stmt: Stmt, ...args: Args): VisitRes;
  visitExprStmt?(stmt: Stmt, ...args: Args): VisitRes;
  visitExpr?(expr: Expr, ...args: Args): VisitRes;
  visitErrorExpr?(expr: Expr, ...args: Args): VisitRes;
  visitIntExpr?(expr: Expr, ...args: Args): VisitRes;
  visitStringExpr?(expr: Expr, ...args: Args): VisitRes;
  visitIdentExpr?(expr: Expr, ...args: Args): VisitRes;
  visitGroupExpr?(expr: Expr, ...args: Args): VisitRes;
  visitArrayExpr?(expr: Expr, ...args: Args): VisitRes;
  visitStructExpr?(expr: Expr, ...args: Args): VisitRes;
  visitFieldExpr?(expr: Expr, ...args: Args): VisitRes;
  visitIndexExpr?(expr: Expr, ...args: Args): VisitRes;
  visitCallExpr?(expr: Expr, ...args: Args): VisitRes;
  visitPathExpr?(expr: Expr, ...args: Args): VisitRes;
  visitETypeArgsExpr?(expr: Expr, ...args: Args): VisitRes;
  visitUnaryExpr?(expr: Expr, ...args: Args): VisitRes;
  visitBinaryExpr?(expr: Expr, ...args: Args): VisitRes;
  visitIfExpr?(expr: Expr, ...args: Args): VisitRes;
  visitBoolExpr?(expr: Expr, ...args: Args): VisitRes;
  visitNullExpr?(expr: Expr, ...args: Args): VisitRes;
  visitLoopExpr?(expr: Expr, ...args: Args): VisitRes;
  visitWhileExpr?(expr: Expr, ...args: Args): VisitRes;
  visitForInExpr?(expr: Expr, ...args: Args): VisitRes;
  visitForExpr?(expr: Expr, ...args: Args): VisitRes;
  visitBlockExpr?(expr: Expr, ...args: Args): VisitRes;
  visitSymExpr?(expr: Expr, ...args: Args): VisitRes;
  visitParam?(param: Param, ...args: Args): VisitRes;
  visitField?(field: Field, ...args: Args): VisitRes;
  visitEType?(etype: EType, ...args: Args): VisitRes;
  visitErrorEType?(etype: EType, ...args: Args): VisitRes;
  visitNullEType?(etype: EType, ...args: Args): VisitRes;
  visitIntEType?(etype: EType, ...args: Args): VisitRes;
  visitBoolEType?(etype: EType, ...args: Args): VisitRes;
  visitStringEType?(etype: EType, ...args: Args): VisitRes;
  visitIdentEType?(etype: EType, ...args: Args): VisitRes;
  visitSymEType?(etype: EType, ...args: Args): VisitRes;
  visitArrayEType?(etype: EType, ...args: Args): VisitRes;
  visitStructEType?(etype: EType, ...args: Args): VisitRes;
  visitTypeOfEType?(etype: EType, ...args: Args): VisitRes;
  /** NOTE(review): declared, but no walker in this module ever calls it. */
  visitAnno?(etype: EType, ...args: Args): VisitRes;
}

/**
 * Walks a list of statements.
 *
 * Calls `v.visitStmts` with the whole list first; unless it returns `"stop"`,
 * each statement is then walked in order with {@link visitStmt}.
 */
export function visitStmts<Args extends unknown[] = []>(
  stmts: Stmt[],
  v: AstVisitor<Args>,
  ...args: Args
): void {
  if (v.visitStmts?.(stmts, ...args) === "stop") return;
  stmts.forEach((stmt) => visitStmt(stmt, v, ...args));
}

/**
 * Walks a single statement: `v.visitStmt`, then the kind-specific hook, then
 * the statement's children.
 *
 * Note that `mod_block` and `mod` bodies go through {@link visitStmts} (so
 * `v.visitStmts` fires for them), whereas block-expression bodies in
 * {@link visitExpr} visit their statements one by one.
 *
 * @throws Error if the statement kind is not handled here.
 */
export function visitStmt<Args extends unknown[] = []>(
  stmt: Stmt,
  v: AstVisitor<Args>,
  ...args: Args
): void {
  if (v.visitStmt?.(stmt, ...args) === "stop") return;
  switch (stmt.kind.type) {
    case "error":
      if (v.visitErrorStmt?.(stmt, ...args) === "stop") return;
      break;
    case "mod_file":
      if (v.visitModFileStmt?.(stmt, ...args) === "stop") return;
      break;
    case "mod_block":
      if (v.visitModBlockStmt?.(stmt, ...args) === "stop") return;
      visitStmts(stmt.kind.stmts, v, ...args);
      break;
    case "mod":
      if (v.visitModStmt?.(stmt, ...args) === "stop") return;
      visitStmts(stmt.kind.mod.ast, v, ...args);
      break;
    case "break":
      if (v.visitBreakStmt?.(stmt, ...args) === "stop") return;
      if (stmt.kind.expr) visitExpr(stmt.kind.expr, v, ...args);
      break;
    case "return":
      if (v.visitReturnStmt?.(stmt, ...args) === "stop") return;
      if (stmt.kind.expr) visitExpr(stmt.kind.expr, v, ...args);
      break;
    case "fn":
      if (v.visitFnStmt?.(stmt, ...args) === "stop") return;
      stmt.kind.params.forEach((param) => visitParam(param, v, ...args));
      if (stmt.kind.returnType) {
        visitEType(stmt.kind.returnType, v, ...args);
      }
      visitExpr(stmt.kind.body, v, ...args);
      break;
    case "let":
      if (v.visitLetStmt?.(stmt, ...args) === "stop") return;
      visitParam(stmt.kind.param, v, ...args);
      visitExpr(stmt.kind.value, v, ...args);
      break;
    case "assign":
      if (v.visitAssignStmt?.(stmt, ...args) === "stop") return;
      visitExpr(stmt.kind.subject, v, ...args);
      visitExpr(stmt.kind.value, v, ...args);
      break;
    case "expr":
      if (v.visitExprStmt?.(stmt, ...args) === "stop") return;
      visitExpr(stmt.kind.expr, v, ...args);
      break;
    default:
      // The cast keeps the message useful if `Stmt` gains a kind that this
      // walker has not been taught about yet.
      throw new Error(
        `statement '${(stmt.kind as { type: string }).type}' not implemented`,
      );
  }
}

/**
 * Walks a single expression: `v.visitExpr`, then the kind-specific hook, then
 * the expression's children (sub-expressions, statements, params, fields and
 * type arguments, as applicable to the kind).
 *
 * @throws Error if the expression kind is not handled here.
 */
export function visitExpr<Args extends unknown[] = []>(
  expr: Expr,
  v: AstVisitor<Args>,
  ...args: Args
): void {
  if (v.visitExpr?.(expr, ...args) === "stop") return;
  switch (expr.kind.type) {
    case "error":
      if (v.visitErrorExpr?.(expr, ...args) === "stop") return;
      break;
    case "string":
      if (v.visitStringExpr?.(expr, ...args) === "stop") return;
      break;
    case "int":
      if (v.visitIntExpr?.(expr, ...args) === "stop") return;
      break;
    case "ident":
      if (v.visitIdentExpr?.(expr, ...args) === "stop") return;
      break;
    case "group":
      if (v.visitGroupExpr?.(expr, ...args) === "stop") return;
      visitExpr(expr.kind.expr, v, ...args);
      break;
    case "field":
      if (v.visitFieldExpr?.(expr, ...args) === "stop") return;
      visitExpr(expr.kind.subject, v, ...args);
      break;
    case "index":
      if (v.visitIndexExpr?.(expr, ...args) === "stop") return;
      visitExpr(expr.kind.subject, v, ...args);
      visitExpr(expr.kind.value, v, ...args);
      break;
    case "call":
      if (v.visitCallExpr?.(expr, ...args) === "stop") return;
      visitExpr(expr.kind.subject, v, ...args);
      expr.kind.args.forEach((arg) => visitExpr(arg, v, ...args));
      break;
    case "path":
      if (v.visitPathExpr?.(expr, ...args) === "stop") return;
      visitExpr(expr.kind.subject, v, ...args);
      break;
    case "etype_args":
      if (v.visitETypeArgsExpr?.(expr, ...args) === "stop") return;
      visitExpr(expr.kind.subject, v, ...args);
      expr.kind.etypeArgs.forEach((arg) => visitEType(arg, v, ...args));
      break;
    case "unary":
      if (v.visitUnaryExpr?.(expr, ...args) === "stop") return;
      visitExpr(expr.kind.subject, v, ...args);
      break;
    case "binary":
      if (v.visitBinaryExpr?.(expr, ...args) === "stop") return;
      visitExpr(expr.kind.left, v, ...args);
      visitExpr(expr.kind.right, v, ...args);
      break;
    case "array":
      if (v.visitArrayExpr?.(expr, ...args) === "stop") return;
      expr.kind.exprs.forEach((element) => visitExpr(element, v, ...args));
      break;
    case "struct":
      if (v.visitStructExpr?.(expr, ...args) === "stop") return;
      expr.kind.fields.forEach((field) => visitField(field, v, ...args));
      break;
    case "if":
      if (v.visitIfExpr?.(expr, ...args) === "stop") return;
      visitExpr(expr.kind.cond, v, ...args);
      visitExpr(expr.kind.truthy, v, ...args);
      if (expr.kind.falsy) visitExpr(expr.kind.falsy, v, ...args);
      break;
    case "bool":
      if (v.visitBoolExpr?.(expr, ...args) === "stop") return;
      break;
    case "null":
      if (v.visitNullExpr?.(expr, ...args) === "stop") return;
      break;
    case "loop":
      if (v.visitLoopExpr?.(expr, ...args) === "stop") return;
      visitExpr(expr.kind.body, v, ...args);
      break;
    case "while":
      if (v.visitWhileExpr?.(expr, ...args) === "stop") return;
      visitExpr(expr.kind.cond, v, ...args);
      visitExpr(expr.kind.body, v, ...args);
      break;
    case "for_in":
      if (v.visitForInExpr?.(expr, ...args) === "stop") return;
      visitParam(expr.kind.param, v, ...args);
      visitExpr(expr.kind.value, v, ...args);
      visitExpr(expr.kind.body, v, ...args);
      break;
    case "for":
      if (v.visitForExpr?.(expr, ...args) === "stop") return;
      // Each clause of a C-style `for` is optional.
      if (expr.kind.decl) visitStmt(expr.kind.decl, v, ...args);
      if (expr.kind.cond) visitExpr(expr.kind.cond, v, ...args);
      if (expr.kind.incr) visitStmt(expr.kind.incr, v, ...args);
      visitExpr(expr.kind.body, v, ...args);
      break;
    case "block":
      if (v.visitBlockExpr?.(expr, ...args) === "stop") return;
      expr.kind.stmts.forEach((stmt) => visitStmt(stmt, v, ...args));
      if (expr.kind.expr) visitExpr(expr.kind.expr, v, ...args);
      break;
    case "sym":
      if (v.visitSymExpr?.(expr, ...args) === "stop") return;
      break;
    default:
      throw new Error(
        `expression '${(expr.kind as { type: string }).type}' not implemented`,
      );
  }
}

/** Walks a parameter: `v.visitParam`, then its type annotation if present. */
export function visitParam<Args extends unknown[] = []>(
  param: Param,
  v: AstVisitor<Args>,
  ...args: Args
): void {
  if (v.visitParam?.(param, ...args) === "stop") return;
  if (param.etype) visitEType(param.etype, v, ...args);
}

/** Walks a struct-literal field: `v.visitField`, then its value expression. */
export function visitField<Args extends unknown[] = []>(
  field: Field,
  v: AstVisitor<Args>,
  ...args: Args
): void {
  if (v.visitField?.(field, ...args) === "stop") return;
  visitExpr(field.expr, v, ...args);
}

/**
 * Walks a type expression: `v.visitEType`, then the kind-specific hook, then
 * its children. Array inner types and struct field params are children, and so
 * is the expression of a `type_of` type.
 *
 * @throws Error if the type kind is not handled here.
 */
export function visitEType<Args extends unknown[] = []>(
  etype: EType,
  v: AstVisitor<Args>,
  ...args: Args
): void {
  if (v.visitEType?.(etype, ...args) === "stop") return;
  switch (etype.kind.type) {
    case "error":
      if (v.visitErrorEType?.(etype, ...args) === "stop") return;
      break;
    case "string":
      if (v.visitStringEType?.(etype, ...args) === "stop") return;
      break;
    case "null":
      if (v.visitNullEType?.(etype, ...args) === "stop") return;
      break;
    case "int":
      if (v.visitIntEType?.(etype, ...args) === "stop") return;
      break;
    case "bool":
      if (v.visitBoolEType?.(etype, ...args) === "stop") return;
      break;
    case "ident":
      if (v.visitIdentEType?.(etype, ...args) === "stop") return;
      break;
    case "sym":
      if (v.visitSymEType?.(etype, ...args) === "stop") return;
      break;
    case "array":
      if (v.visitArrayEType?.(etype, ...args) === "stop") return;
      if (etype.kind.inner) visitEType(etype.kind.inner, v, ...args);
      break;
    case "struct":
      if (v.visitStructEType?.(etype, ...args) === "stop") return;
      etype.kind.fields.forEach((field) => visitParam(field, v, ...args));
      break;
    case "type_of":
      if (v.visitTypeOfEType?.(etype, ...args) === "stop") return;
      visitExpr(etype.kind.expr, v, ...args);
      break;
    default:
      throw new Error(
        `etype '${(etype.kind as { type: string }).type}' not implemented`,
      );
  }
}

/**
 * Renders a statement as `<kind>:<line><body>`.
 *
 * Only `assign` statements get a detailed body (subject and value, rendered
 * via {@link exprToString}); every other kind renders its body as `()`.
 */
export function stmtToString(stmt: Stmt): string {
  const body = (() => {
    switch (stmt.kind.type) {
      case "assign":
        return `{ subject: ${exprToString(stmt.kind.subject)}, value: ${
          exprToString(stmt.kind.value)
        } }`;
    }
    return "()";
  })();
  const { line } = stmt.pos;
  return `${stmt.kind.type}:${line}${body}`;
}

/**
 * Renders an expression as `<kind>:<line><body>`.
 *
 * `binary` expressions render both operands recursively around the operator.
 * `sym` expressions render their identifier. Every other kind renders its body
 * as `()`.
 */
export function exprToString(expr: Expr): string {
  const body = (() => {
    switch (expr.kind.type) {
      case "binary":
        return `(${
          exprToString(expr.kind.left)
        } ${expr.kind.binaryType} ${exprToString(expr.kind.right)})`;
      case "sym":
        return `(${expr.kind.ident})`;
    }
    return "()";
  })();
  const { line } = expr.pos;
  return `${expr.kind.type}:${line}${body}`;
}