slige/compiler/ast_visitor.ts
2025-01-17 11:50:14 +01:00

342 lines
13 KiB
TypeScript

import { EType, Expr, Field, Param, Stmt } from "./ast.ts";
/**
 * Value a visitor hook may return. Returning `"stop"` makes the walker that
 * invoked the hook return immediately, skipping the rest of that node
 * (including its children); returning nothing lets traversal continue.
 */
export type VisitRes = void | "stop";
/**
 * Optional hooks invoked by the `visit*` walkers in this module.
 *
 * Every hook receives the node being walked followed by the extra `Args`
 * that were passed to the walker. A hook returning `"stop"` makes the walker
 * return without walking the rest of that node (its kind-specific hook and
 * its children); returning nothing lets traversal continue.
 *
 * For statements, expressions and etypes the generic hook (`visitStmt`,
 * `visitExpr`, `visitEType`) runs first, then the hook for the node's
 * `kind.type` (e.g. `visitFnStmt`), then the node's children.
 */
export interface AstVisitor<Args extends unknown[] = []> {
// Called by the visitStmts walker with the whole list before any statement
// in it is walked (used for `mod_block` and `mod` bodies).
visitStmts?(stmts: Stmt[], ...args: Args): VisitRes;
// Generic statement hook; runs before the kind-specific statement hook.
visitStmt?(stmt: Stmt, ...args: Args): VisitRes;
// Kind-specific statement hooks, one per `stmt.kind.type`.
visitErrorStmt?(stmt: Stmt, ...args: Args): VisitRes;
visitModFileStmt?(stmt: Stmt, ...args: Args): VisitRes;
visitModBlockStmt?(stmt: Stmt, ...args: Args): VisitRes;
visitModStmt?(stmt: Stmt, ...args: Args): VisitRes;
visitBreakStmt?(stmt: Stmt, ...args: Args): VisitRes;
visitReturnStmt?(stmt: Stmt, ...args: Args): VisitRes;
visitFnStmt?(stmt: Stmt, ...args: Args): VisitRes;
visitLetStmt?(stmt: Stmt, ...args: Args): VisitRes;
visitTypeAliasStmt?(stmt: Stmt, ...args: Args): VisitRes;
visitAssignStmt?(stmt: Stmt, ...args: Args): VisitRes;
visitExprStmt?(stmt: Stmt, ...args: Args): VisitRes;
// Generic expression hook; runs before the kind-specific expression hook.
visitExpr?(expr: Expr, ...args: Args): VisitRes;
// Kind-specific expression hooks, one per `expr.kind.type`.
visitErrorExpr?(expr: Expr, ...args: Args): VisitRes;
visitIntExpr?(expr: Expr, ...args: Args): VisitRes;
visitStringExpr?(expr: Expr, ...args: Args): VisitRes;
visitIdentExpr?(expr: Expr, ...args: Args): VisitRes;
visitGroupExpr?(expr: Expr, ...args: Args): VisitRes;
visitArrayExpr?(expr: Expr, ...args: Args): VisitRes;
visitStructExpr?(expr: Expr, ...args: Args): VisitRes;
visitFieldExpr?(expr: Expr, ...args: Args): VisitRes;
visitIndexExpr?(expr: Expr, ...args: Args): VisitRes;
visitCallExpr?(expr: Expr, ...args: Args): VisitRes;
visitPathExpr?(expr: Expr, ...args: Args): VisitRes;
visitETypeArgsExpr?(expr: Expr, ...args: Args): VisitRes;
visitUnaryExpr?(expr: Expr, ...args: Args): VisitRes;
visitBinaryExpr?(expr: Expr, ...args: Args): VisitRes;
visitIfExpr?(expr: Expr, ...args: Args): VisitRes;
visitBoolExpr?(expr: Expr, ...args: Args): VisitRes;
visitNullExpr?(expr: Expr, ...args: Args): VisitRes;
visitLoopExpr?(expr: Expr, ...args: Args): VisitRes;
visitWhileExpr?(expr: Expr, ...args: Args): VisitRes;
visitForInExpr?(expr: Expr, ...args: Args): VisitRes;
visitForExpr?(expr: Expr, ...args: Args): VisitRes;
visitBlockExpr?(expr: Expr, ...args: Args): VisitRes;
visitSymExpr?(expr: Expr, ...args: Args): VisitRes;
// Called by visitParam before the parameter's etype (if present) is walked.
// Also used for the fields of struct etypes.
visitParam?(param: Param, ...args: Args): VisitRes;
// Called by visitField before the field's expression is walked.
visitField?(field: Field, ...args: Args): VisitRes;
// Generic etype hook; runs before the kind-specific etype hook.
visitEType?(etype: EType, ...args: Args): VisitRes;
// Kind-specific etype hooks, one per `etype.kind.type`.
visitErrorEType?(etype: EType, ...args: Args): VisitRes;
visitNullEType?(etype: EType, ...args: Args): VisitRes;
visitIntEType?(etype: EType, ...args: Args): VisitRes;
visitBoolEType?(etype: EType, ...args: Args): VisitRes;
visitStringEType?(etype: EType, ...args: Args): VisitRes;
visitIdentEType?(etype: EType, ...args: Args): VisitRes;
visitSymEType?(etype: EType, ...args: Args): VisitRes;
visitArrayEType?(etype: EType, ...args: Args): VisitRes;
visitStructEType?(etype: EType, ...args: Args): VisitRes;
visitTypeOfEType?(etype: EType, ...args: Args): VisitRes;
// NOTE(review): no walker in this module ever calls visitAnno; confirm
// whether it is dispatched elsewhere or is dead API.
visitAnno?(etype: EType, ...args: Args): VisitRes;
}
/**
 * Walks a list of statements with the given visitor.
 *
 * `v.visitStmts` is called first with the whole list; if it returns
 * `"stop"`, none of the statements are visited. Otherwise each statement is
 * passed to `visitStmt` in order.
 *
 * @param stmts statements to walk
 * @param v visitor whose optional hooks are invoked
 * @param args extra arguments forwarded unchanged to every hook
 */
export function visitStmts<Args extends unknown[] = []>(
  stmts: Stmt[],
  v: AstVisitor<Args>,
  ...args: Args
): void {
  if (v.visitStmts?.(stmts, ...args) === "stop") return;
  // A plain loop rather than `.map`: the walk runs only for its side effects,
  // so there is no result array to build and throw away.
  for (const stmt of stmts) {
    visitStmt(stmt, v, ...args);
  }
}
/**
 * Walks a single statement with the given visitor.
 *
 * Hooks run in this order. Any of them may return `"stop"` to end the walk
 * of this statement immediately.
 *   1. `v.visitStmt` (for every statement)
 *   2. the hook matching `stmt.kind.type` (e.g. `v.visitFnStmt`)
 * After the hooks, the statement's child nodes (if any) are walked.
 *
 * @param stmt statement to walk
 * @param v visitor whose optional hooks are invoked
 * @param args extra arguments forwarded unchanged to every hook
 * @throws Error if `stmt.kind.type` is not one of the kinds handled below
 */
export function visitStmt<Args extends unknown[] = []>(
  stmt: Stmt,
  v: AstVisitor<Args>,
  ...args: Args
): void {
  if (v.visitStmt?.(stmt, ...args) === "stop") return;
  switch (stmt.kind.type) {
    case "error":
      if (v.visitErrorStmt?.(stmt, ...args) === "stop") return;
      break;
    case "mod_file":
      // Only the hook runs; no child nodes are walked for this kind here.
      if (v.visitModFileStmt?.(stmt, ...args) === "stop") return;
      break;
    case "mod_block":
      if (v.visitModBlockStmt?.(stmt, ...args) === "stop") return;
      visitStmts(stmt.kind.stmts, v, ...args);
      break;
    case "mod":
      if (v.visitModStmt?.(stmt, ...args) === "stop") return;
      visitStmts(stmt.kind.mod.ast, v, ...args);
      break;
    case "break":
      if (v.visitBreakStmt?.(stmt, ...args) === "stop") return;
      if (stmt.kind.expr) visitExpr(stmt.kind.expr, v, ...args);
      break;
    case "return":
      if (v.visitReturnStmt?.(stmt, ...args) === "stop") return;
      if (stmt.kind.expr) visitExpr(stmt.kind.expr, v, ...args);
      break;
    case "fn":
      if (v.visitFnStmt?.(stmt, ...args) === "stop") return;
      for (const param of stmt.kind.params) {
        visitParam(param, v, ...args);
      }
      if (stmt.kind.returnType) {
        visitEType(stmt.kind.returnType, v, ...args);
      }
      visitExpr(stmt.kind.body, v, ...args);
      break;
    case "let":
      if (v.visitLetStmt?.(stmt, ...args) === "stop") return;
      visitParam(stmt.kind.param, v, ...args);
      visitExpr(stmt.kind.value, v, ...args);
      break;
    case "type_alias":
      if (v.visitTypeAliasStmt?.(stmt, ...args) === "stop") return;
      visitParam(stmt.kind.param, v, ...args);
      break;
    case "assign":
      if (v.visitAssignStmt?.(stmt, ...args) === "stop") return;
      visitExpr(stmt.kind.subject, v, ...args);
      visitExpr(stmt.kind.value, v, ...args);
      break;
    case "expr":
      if (v.visitExprStmt?.(stmt, ...args) === "stop") return;
      visitExpr(stmt.kind.expr, v, ...args);
      break;
    default:
      // Reached for any statement kind not listed above.
      throw new Error(
        `statement '${
          (stmt.kind as { type: string }).type
        }' not implemented`,
      );
  }
}
/**
 * Walks a single expression with the given visitor.
 *
 * Hooks run in this order. Any of them may return `"stop"` to end the walk
 * of this expression immediately.
 *   1. `v.visitExpr` (for every expression)
 *   2. the hook matching `expr.kind.type` (e.g. `v.visitCallExpr`)
 * After the hooks, the expression's child nodes (if any) are walked,
 * left to right.
 *
 * @param expr expression to walk
 * @param v visitor whose optional hooks are invoked
 * @param args extra arguments forwarded unchanged to every hook
 * @throws Error if `expr.kind.type` is not one of the kinds handled below
 */
export function visitExpr<Args extends unknown[] = []>(
  expr: Expr,
  v: AstVisitor<Args>,
  ...args: Args
): void {
  if (v.visitExpr?.(expr, ...args) === "stop") return;
  switch (expr.kind.type) {
    case "error":
      if (v.visitErrorExpr?.(expr, ...args) === "stop") return;
      break;
    case "string":
      if (v.visitStringExpr?.(expr, ...args) === "stop") return;
      break;
    case "int":
      if (v.visitIntExpr?.(expr, ...args) === "stop") return;
      break;
    case "ident":
      if (v.visitIdentExpr?.(expr, ...args) === "stop") return;
      break;
    case "group":
      if (v.visitGroupExpr?.(expr, ...args) === "stop") return;
      visitExpr(expr.kind.expr, v, ...args);
      break;
    case "field":
      if (v.visitFieldExpr?.(expr, ...args) === "stop") return;
      visitExpr(expr.kind.subject, v, ...args);
      break;
    case "index":
      if (v.visitIndexExpr?.(expr, ...args) === "stop") return;
      visitExpr(expr.kind.subject, v, ...args);
      visitExpr(expr.kind.value, v, ...args);
      break;
    case "call":
      if (v.visitCallExpr?.(expr, ...args) === "stop") return;
      visitExpr(expr.kind.subject, v, ...args);
      for (const arg of expr.kind.args) {
        visitExpr(arg, v, ...args);
      }
      break;
    case "path":
      if (v.visitPathExpr?.(expr, ...args) === "stop") return;
      visitExpr(expr.kind.subject, v, ...args);
      break;
    case "etype_args":
      if (v.visitETypeArgsExpr?.(expr, ...args) === "stop") return;
      visitExpr(expr.kind.subject, v, ...args);
      for (const etypeArg of expr.kind.etypeArgs) {
        visitEType(etypeArg, v, ...args);
      }
      break;
    case "unary":
      if (v.visitUnaryExpr?.(expr, ...args) === "stop") return;
      visitExpr(expr.kind.subject, v, ...args);
      break;
    case "binary":
      if (v.visitBinaryExpr?.(expr, ...args) === "stop") return;
      visitExpr(expr.kind.left, v, ...args);
      visitExpr(expr.kind.right, v, ...args);
      break;
    case "array":
      if (v.visitArrayExpr?.(expr, ...args) === "stop") return;
      // `element` rather than `expr`, so the outer parameter isn't shadowed.
      for (const element of expr.kind.exprs) {
        visitExpr(element, v, ...args);
      }
      break;
    case "struct":
      if (v.visitStructExpr?.(expr, ...args) === "stop") return;
      for (const field of expr.kind.fields) {
        visitField(field, v, ...args);
      }
      break;
    case "if":
      if (v.visitIfExpr?.(expr, ...args) === "stop") return;
      visitExpr(expr.kind.cond, v, ...args);
      visitExpr(expr.kind.truthy, v, ...args);
      if (expr.kind.falsy) visitExpr(expr.kind.falsy, v, ...args);
      break;
    case "bool":
      if (v.visitBoolExpr?.(expr, ...args) === "stop") return;
      break;
    case "null":
      if (v.visitNullExpr?.(expr, ...args) === "stop") return;
      break;
    case "loop":
      if (v.visitLoopExpr?.(expr, ...args) === "stop") return;
      visitExpr(expr.kind.body, v, ...args);
      break;
    case "while":
      if (v.visitWhileExpr?.(expr, ...args) === "stop") return;
      visitExpr(expr.kind.cond, v, ...args);
      visitExpr(expr.kind.body, v, ...args);
      break;
    case "for_in":
      if (v.visitForInExpr?.(expr, ...args) === "stop") return;
      visitParam(expr.kind.param, v, ...args);
      visitExpr(expr.kind.value, v, ...args);
      visitExpr(expr.kind.body, v, ...args);
      break;
    case "for":
      if (v.visitForExpr?.(expr, ...args) === "stop") return;
      // Each of the three header parts is optional; decl and incr are statements.
      if (expr.kind.decl) visitStmt(expr.kind.decl, v, ...args);
      if (expr.kind.cond) visitExpr(expr.kind.cond, v, ...args);
      if (expr.kind.incr) visitStmt(expr.kind.incr, v, ...args);
      visitExpr(expr.kind.body, v, ...args);
      break;
    case "block":
      if (v.visitBlockExpr?.(expr, ...args) === "stop") return;
      // NOTE(review): unlike `mod_block`, block statements are walked one by
      // one, so `v.visitStmts` is not invoked for them; confirm intentional.
      for (const stmt of expr.kind.stmts) {
        visitStmt(stmt, v, ...args);
      }
      if (expr.kind.expr) visitExpr(expr.kind.expr, v, ...args);
      break;
    case "sym":
      if (v.visitSymExpr?.(expr, ...args) === "stop") return;
      break;
    default:
      // Reached for any expression kind not listed above.
      throw new Error(
        `expression '${
          (expr.kind as { type: string }).type
        }' not implemented`,
      );
  }
}
/**
 * Walks a parameter. `v.visitParam` is called first. Unless it returns
 * `"stop"`, the parameter's etype is walked when one is present.
 *
 * @param param parameter to walk
 * @param v visitor whose optional hooks are invoked
 * @param args extra arguments forwarded unchanged to every hook
 */
export function visitParam<Args extends unknown[] = []>(
  param: Param,
  v: AstVisitor<Args>,
  ...args: Args
) {
  const hookResult = v.visitParam?.(param, ...args);
  if (hookResult !== "stop" && param.etype) {
    visitEType(param.etype, v, ...args);
  }
}
/**
 * Walks a struct-literal field. `v.visitField` is called first. Unless it
 * returns `"stop"`, the field's expression is then walked.
 *
 * @param field field to walk
 * @param v visitor whose optional hooks are invoked
 * @param args extra arguments forwarded unchanged to every hook
 */
export function visitField<Args extends unknown[] = []>(
  field: Field,
  v: AstVisitor<Args>,
  ...args: Args
) {
  const stopped = v.visitField?.(field, ...args) === "stop";
  if (!stopped) {
    visitExpr(field.expr, v, ...args);
  }
}
/**
 * Walks a type expression (etype) with the given visitor.
 *
 * Hooks run in this order. Any of them may return `"stop"` to end the walk
 * of this etype immediately.
 *   1. `v.visitEType` (for every etype)
 *   2. the hook matching `etype.kind.type` (e.g. `v.visitArrayEType`)
 * Then the etype's children are walked:
 *   - an array's inner etype (when present)
 *   - a struct's fields, via `visitParam`
 *   - a `type_of` etype's expression
 *
 * @param etype type expression to walk
 * @param v visitor whose optional hooks are invoked
 * @param args extra arguments forwarded unchanged to every hook
 * @throws Error if `etype.kind.type` is not one of the kinds handled below
 */
export function visitEType<Args extends unknown[] = []>(
  etype: EType,
  v: AstVisitor<Args>,
  ...args: Args
): void {
  if (v.visitEType?.(etype, ...args) === "stop") return;
  switch (etype.kind.type) {
    case "error":
      if (v.visitErrorEType?.(etype, ...args) === "stop") return;
      break;
    case "string":
      if (v.visitStringEType?.(etype, ...args) === "stop") return;
      break;
    case "null":
      if (v.visitNullEType?.(etype, ...args) === "stop") return;
      break;
    case "int":
      if (v.visitIntEType?.(etype, ...args) === "stop") return;
      break;
    case "bool":
      if (v.visitBoolEType?.(etype, ...args) === "stop") return;
      break;
    case "ident":
      if (v.visitIdentEType?.(etype, ...args) === "stop") return;
      break;
    case "sym":
      if (v.visitSymEType?.(etype, ...args) === "stop") return;
      break;
    case "array":
      if (v.visitArrayEType?.(etype, ...args) === "stop") return;
      if (etype.kind.inner) visitEType(etype.kind.inner, v, ...args);
      break;
    case "struct":
      if (v.visitStructEType?.(etype, ...args) === "stop") return;
      // Struct-type fields are Params, so they go through visitParam.
      for (const field of etype.kind.fields) {
        visitParam(field, v, ...args);
      }
      break;
    case "type_of":
      if (v.visitTypeOfEType?.(etype, ...args) === "stop") return;
      visitExpr(etype.kind.expr, v, ...args);
      break;
    default:
      // Reached for any etype kind not listed above.
      throw new Error(
        `etype '${
          (etype.kind as { type: string }).type
        }' not implemented`,
      );
  }
}
/**
 * Renders a statement as a short debug string: `<kind>:<line><body>`.
 *
 * Only `assign` statements get a detailed body
 * (`{ subject: ..., value: ... }`, built with `exprToString`). Every other
 * kind gets the placeholder `(<not implemented>)`.
 */
export function stmtToString(stmt: Stmt): string {
  const kind = stmt.kind;
  let body = "(<not implemented>)";
  if (kind.type === "assign") {
    const subjectText = exprToString(kind.subject);
    const valueText = exprToString(kind.value);
    body = `{ subject: ${subjectText}, value: ${valueText} }`;
  }
  return `${kind.type}:${stmt.pos.line}${body}`;
}
/**
 * Renders an expression as a short debug string: `<kind>:<line><body>`.
 *
 * `binary` expressions render both operands recursively around their
 * operator, e.g. `binary:1(sym:1(a) + sym:1(b))`. `sym` expressions render
 * their identifier. Every other kind gets the placeholder
 * `(<not implemented>)`.
 */
export function exprToString(expr: Expr): string {
  const kind = expr.kind;
  let body = "(<not implemented>)";
  if (kind.type === "binary") {
    const leftText = exprToString(kind.left);
    const rightText = exprToString(kind.right);
    body = `(${leftText} ${kind.binaryType} ${rightText})`;
  } else if (kind.type === "sym") {
    body = `(${kind.ident})`;
  }
  return `${kind.type}:${expr.pos.line}${body}`;
}