import { EType, Expr, Stmt } from "./ast.ts";
import { printStackTrace, Reporter } from "./info.ts";
import { Pos } from "./token.ts";
import {
    extractGenericType,
    GenericArgsMap,
    VType,
    VTypeGenericParam,
    VTypeParam,
    vtypesEqual,
    vtypeToString,
} from "./vtype.ts";

/**
 * Type checker. Walks statements and expressions, stores the computed
 * `VType` on every visited expression (`expr.vtype`) and on fn/let/param
 * nodes, and reports type errors through the supplied `Reporter`.
 */
export class Checker {
    /** Declared return types of the fns being checked; innermost last. */
    private fnReturnStack: VType[] = [];
    /** One entry per enclosing loop, collecting the types of its `break`s. */
    private loopBreakStack: VType[][] = [];
    /** Maps a generic param's global (AST) id to its fn-local param. */
    private globalIdToGenericParamMap = new Map<unknown, VTypeGenericParam>();

    public constructor(private reporter: Reporter) {}

    /**
     * Checks a list of statements. All fn signatures are resolved first, so
     * a fn can be referenced by statements that precede its declaration.
     */
    public check(stmts: Stmt[]) {
        this.checkFnHeaders(stmts);
        for (const stmt of stmts) {
            this.checkStmt(stmt);
        }
    }

    /**
     * Resolves the signature (generic params, param types, return type) of
     * every `fn` statement in `stmts` and stores it on `stmt.kind.vtype`.
     * Fn bodies are not checked here.
     */
    private checkFnHeaders(stmts: Stmt[]) {
        for (const stmt of stmts) {
            if (stmt.kind.type !== "fn") {
                continue;
            }
            let genericParams: VTypeGenericParam[] | undefined;
            if (stmt.kind.genericParams !== undefined) {
                genericParams = [];
                for (const etypeParam of stmt.kind.genericParams) {
                    // Local id = position in this fn's generic param list.
                    const id = genericParams.length;
                    const globalId = etypeParam.id;
                    const param = { id, ident: etypeParam.ident };
                    genericParams.push(param);
                    this.globalIdToGenericParamMap.set(globalId, param);
                }
            }
            const params: VTypeParam[] = [];
            for (const param of stmt.kind.params) {
                if (param.etype === undefined) {
                    this.report("parameter types must be defined", param.pos);
                    // Mark only this param as erroneous. The fn itself must
                    // keep a "fn" vtype, because checkFnStmt and checkSymExpr
                    // throw on anything else.
                    const errorType: VType = { type: "error" };
                    param.vtype = errorType;
                    params.push({ ident: param.ident, vtype: errorType });
                    continue;
                }
                const vtype = this.checkEType(param.etype);
                param.vtype = vtype;
                params.push({ ident: param.ident, vtype });
            }
            const returnType: VType = stmt.kind.returnType
                ? this.checkEType(stmt.kind.returnType)
                : { type: "null" };
            stmt.kind.vtype = {
                type: "fn",
                genericParams,
                params,
                returnType,
                stmtId: stmt.id,
            };
        }
    }

    /** Dispatches to the checker for the statement's kind. */
    public checkStmt(stmt: Stmt) {
        switch (stmt.kind.type) {
            case "error":
                return { type: "error" };
            case "break":
                return this.checkBreakStmt(stmt);
            case "return":
                return this.checkReturnStmt(stmt);
            case "fn":
                return this.checkFnStmt(stmt);
            case "let":
                return this.checkLetStmt(stmt);
            case "assign":
                return this.checkAssignStmt(stmt);
            case "expr":
                return this.checkExpr(stmt.kind.expr);
        }
    }

    /**
     * Checks a `break` and records its value type (or `null` when it has no
     * value) on the innermost loop. Every break of a loop must agree with the
     * break recorded before it.
     */
    public checkBreakStmt(stmt: Stmt) {
        if (stmt.kind.type !== "break") {
            throw new Error();
        }
        const pos = stmt.pos;
        if (this.loopBreakStack.length === 0) {
            this.report("cannot break outside loop context", pos);
            return;
        }
        const exprType: VType = stmt.kind.expr
            ? this.checkExpr(stmt.kind.expr)
            : { type: "null" };
        const breakTypes = this.loopBreakStack.at(-1)!;
        if (breakTypes.length === 0) {
            breakTypes.push(exprType);
            return;
        }
        const prevBreakType = breakTypes.at(-1)!;
        if (!vtypesEqual(prevBreakType, exprType)) {
            this.report(
                `incompatible types for break` +
                    `, got '${vtypeToString(exprType)}'` +
                    ` incompatible with '${vtypeToString(prevBreakType)}'`,
                pos,
            );
            return;
        }
        breakTypes.push(exprType);
    }

    /**
     * Checks a `return` against the declared return type of the innermost
     * fn. A `return` without a value is treated as returning `null`.
     */
    public checkReturnStmt(stmt: Stmt) {
        if (stmt.kind.type !== "return") {
            throw new Error();
        }
        const pos = stmt.pos;
        if (this.fnReturnStack.length === 0) {
            this.report("cannot return outside fn context", pos);
            return;
        }
        const exprType: VType = stmt.kind.expr
            ? this.checkExpr(stmt.kind.expr)
            : { type: "null" };
        // The error has already been reported where it originated.
        if (exprType.type === "error") return;
        const returnType = this.fnReturnStack.at(-1)!;
        if (!vtypesEqual(exprType, returnType)) {
            this.report(
                `incompatible return type` +
                    `, got '${vtypeToString(exprType)}'` +
                    `, expected '${vtypeToString(returnType)}'`,
                pos,
            );
        }
    }

    /**
     * Checks a fn body against the signature computed by `checkFnHeaders`.
     * Bodies of fns annotated `remainder` or `builtin` are skipped.
     */
    public checkFnStmt(stmt: Stmt) {
        if (stmt.kind.type !== "fn") {
            throw new Error();
        }
        const pos = stmt.pos;
        const vtype = stmt.kind.vtype;
        if (vtype?.type !== "fn") {
            throw new Error();
        }
        if (
            stmt.kind.anno?.ident === "remainder" ||
            stmt.kind.anno?.ident === "builtin"
        ) {
            return;
        }
        const { returnType } = vtype;
        if (returnType.type === "error") return returnType;

        // A fn body is its own break context: a `break` inside it must not
        // target a loop that merely encloses the fn declaration.
        const enclosingLoops = this.loopBreakStack;
        this.loopBreakStack = [];
        this.fnReturnStack.push(returnType);
        const body = this.checkExpr(stmt.kind.body);
        // Restore both stacks before any early return.
        this.fnReturnStack.pop();
        this.loopBreakStack = enclosingLoops;

        if (body.type === "error") return body;
        if (!vtypesEqual(returnType, body)) {
            this.report(
                `incompatible return type` +
                    `, expected '${vtypeToString(returnType)}'` +
                    `, got '${vtypeToString(body)}'`,
                pos,
            );
        }
    }

    /**
     * Checks a `let` binding and stores the binding's type on
     * `stmt.kind.param.vtype`. The type is always stored, even on error, so
     * later symbol lookups never see `undefined`.
     */
    public checkLetStmt(stmt: Stmt) {
        if (stmt.kind.type !== "let") {
            throw new Error();
        }
        const pos = stmt.pos;
        const value = this.checkExpr(stmt.kind.value);
        if (value.type === "error") {
            return stmt.kind.param.vtype = value;
        }
        if (stmt.kind.param.etype) {
            const paramVType = this.checkEType(stmt.kind.param.etype);
            if (paramVType.type === "error") {
                return stmt.kind.param.vtype = paramVType;
            }
            if (!vtypesEqual(value, paramVType)) {
                this.report(
                    `incompatible value type` +
                        `, got '${vtypeToString(value)}'` +
                        `, expected '${vtypeToString(paramVType)}'`,
                    pos,
                );
                // Keep the declared type so later uses are checked against it.
                stmt.kind.param.vtype = paramVType;
                return;
            }
        }
        stmt.kind.param.vtype = value;
    }

    /**
     * Checks a plain `=` assignment to a struct field, an array/string index,
     * or a let-bound symbol. Compound assignments must already be desugared.
     */
    public checkAssignStmt(stmt: Stmt) {
        if (stmt.kind.type !== "assign") {
            throw new Error();
        }
        const pos = stmt.pos;
        if (stmt.kind.assignType !== "=") {
            throw new Error("invalid ast: compound assign should be desugered");
        }
        const value = this.checkExpr(stmt.kind.value);
        switch (stmt.kind.subject.kind.type) {
            case "field": {
                const subject = this.checkExpr(stmt.kind.subject.kind.subject);
                if (subject.type === "error") return subject;
                if (subject.type !== "struct") {
                    this.report("cannot use field on non-struct", pos);
                    return { type: "error" };
                }
                const fieldValue = stmt.kind.subject.kind.ident;
                const found = subject.fields.find((param) =>
                    param.ident === fieldValue
                );
                if (!found) {
                    this.report(
                        `no field named '${stmt.kind.subject.kind.ident}' on struct`,
                        pos,
                    );
                    return { type: "error" };
                }
                if (!vtypesEqual(found.vtype, value)) {
                    this.report(
                        `cannot assign incompatible type to field '${found.ident}'` +
                            `, got '${vtypeToString(value)}'` +
                            `, expected '${vtypeToString(found.vtype)}'`,
                        pos,
                    );
                    return;
                }
                return;
            }
            case "index": {
                const subject = this.checkExpr(stmt.kind.subject.kind.subject);
                if (subject.type === "error") return subject;
                if (subject.type !== "array" && subject.type !== "string") {
                    this.report(
                        `cannot index on non-array, got: ${subject.type}`,
                        pos,
                    );
                    return { type: "error" };
                }
                const indexValue = this.checkExpr(stmt.kind.subject.kind.value);
                if (indexValue.type === "error") return indexValue;
                if (indexValue.type !== "int") {
                    this.report("cannot index on array with non-int", pos);
                    return { type: "error" };
                }
                if (
                    subject.type === "array" &&
                    !vtypesEqual(subject.inner, value)
                ) {
                    this.report(
                        `cannot assign incompatible type to array ` +
                            `'${vtypeToString(subject)}'` +
                            `, got '${vtypeToString(value)}'`,
                        pos,
                    );
                    return;
                }
                return;
            }
            case "sym": {
                const sym = stmt.kind.subject.kind.sym;
                if (sym.type !== "let") {
                    this.report("can only assign to let-symbol", pos);
                    return { type: "error" };
                }
                const target = sym.param.vtype!;
                if (!vtypesEqual(target, value)) {
                    this.report(
                        `cannot assign to incompatible type` +
                            `, got '${vtypeToString(value)}'` +
                            `, expected '${vtypeToString(target)}'`,
                        pos,
                    );
                    return;
                }
                return;
            }
            default:
                this.report("unassignable expression", pos);
                return;
        }
    }

    /** Computes the type of `expr`, stores it on `expr.vtype` and returns it. */
    public checkExpr(expr: Expr): VType {
        const vtype = ((): VType => {
            switch (expr.kind.type) {
                case "error":
                    throw new Error("error in AST");
                case "ident":
                    // Unresolved identifiers are expected only after an
                    // earlier stage has already reported an error.
                    if (this.reporter.errorOccured()) {
                        return { type: "error" };
                    }
                    throw new Error("ident expr in AST");
                case "sym":
                    return this.checkSymExpr(expr);
                case "null":
                    return { type: "null" };
                case "int":
                    return { type: "int" };
                case "bool":
                    return { type: "bool" };
                case "string":
                    return { type: "string" };
                case "group":
                    return this.checkExpr(expr.kind.expr);
                case "field":
                    return this.checkFieldExpr(expr);
                case "index":
                    return this.checkIndexExpr(expr);
                case "call":
                    return this.checkCallExpr(expr);
                case "path":
                    return this.checkPathExpr(expr);
                case "etype_args":
                    return this.checkETypeArgsExpr(expr);
                case "unary":
                    return this.checkUnaryExpr(expr);
                case "binary":
                    return this.checkBinaryExpr(expr);
                case "if":
                    return this.checkIfExpr(expr);
                case "loop":
                    return this.checkLoopExpr(expr);
                case "while":
                case "for_in":
                case "for":
                    throw new Error(
                        "invalid ast: special loops should be desugered",
                    );
                case "block":
                    return this.checkBlockExpr(expr);
            }
            // throw new Error(`unhandled type ${expr.kind.type}`);
        })();
        return expr.vtype = vtype;
    }

    /** Returns the already-resolved type of the symbol `expr` refers to. */
    public checkSymExpr(expr: Expr): VType {
        if (expr.kind.type !== "sym") {
            throw new Error();
        }
        switch (expr.kind.sym.type) {
            case "let":
                return expr.kind.sym.param.vtype!;
            case "fn": {
                const fnStmt = expr.kind.sym.stmt!;
                if (fnStmt.kind.type !== "fn") {
                    throw new Error();
                }
                const vtype = fnStmt.kind.vtype!;
                if (vtype.type !== "fn") {
                    throw new Error();
                }
                return vtype;
            }
            case "fn_param":
                return expr.kind.sym.param.vtype!;
            case "let_static":
            case "closure":
            case "generic":
                throw new Error(
                    `not implemented, sym type '${expr.kind.sym.type}'`,
                );
        }
    }

    /** Checks `subject.ident` and returns the field's type. */
    public checkFieldExpr(expr: Expr): VType {
        if (expr.kind.type !== "field") {
            throw new Error();
        }
        const pos = expr.pos;
        const subject = this.checkExpr(expr.kind.subject);
        if (subject.type === "error") return subject;
        if (subject.type !== "struct") {
            this.report("cannot use field on non-struct", pos);
            return { type: "error" };
        }
        const value = expr.kind.ident;
        const found = subject.fields.find((param) => param.ident === value);
        if (!found) {
            this.report(
                `no field named '${expr.kind.ident}' on struct`,
                pos,
            );
            return { type: "error" };
        }
        return found.vtype;
    }

    /**
     * Checks `subject[value]`. Arrays yield their element type; strings
     * yield `int`.
     */
    public checkIndexExpr(expr: Expr): VType {
        if (expr.kind.type !== "index") {
            throw new Error();
        }
        const pos = expr.pos;
        const subject = this.checkExpr(expr.kind.subject);
        if (subject.type === "error") return subject;
        if (subject.type !== "array" && subject.type !== "string") {
            this.report(`cannot index on non-array, got: ${subject.type}`, pos);
            return { type: "error" };
        }
        const value = this.checkExpr(expr.kind.value);
        if (value.type === "error") return value;
        if (value.type !== "int") {
            this.report("cannot index on array with non-int", pos);
            return { type: "error" };
        }
        if (subject.type === "array") {
            return subject.inner;
        }
        return { type: "int" };
    }

    /**
     * Checks a call of a plain fn, a generic fn with inferred type arguments,
     * or a generic fn with explicit type arguments (`generic_spec`).
     */
    public checkCallExpr(expr: Expr): VType {
        if (expr.kind.type !== "call") {
            throw new Error();
        }
        const pos = expr.pos;
        const subject = this.checkExpr(expr.kind.subject);
        if (subject.type === "error") return subject;
        if (subject.type === "fn") {
            const arityMatches =
                expr.kind.args.length === subject.params.length;
            if (!arityMatches) {
                this.report(
                    `incorrect number of arguments` +
                        `, expected ${subject.params.length}`,
                    pos,
                );
            }
            // Arguments are checked regardless, so their vtypes get recorded.
            const args = expr.kind.args.map((arg) => this.checkExpr(arg));
            if (!arityMatches) {
                // The tails below pair params[i] with args[i]; a count
                // mismatch would index past the end of one of them.
                return { type: "error" };
            }
            if (subject.genericParams === undefined) {
                return this.checkCallExprNoGenericsTail(
                    expr,
                    subject,
                    args,
                    pos,
                );
            }
            return this.checkCallExprInferredGenericsTail(
                expr,
                subject,
                args,
                pos,
            );
        }
        if (subject.type === "generic_spec" && subject.subject.type === "fn") {
            return this.checkCallExprExplicitGenericsTail(expr, subject);
        }
        this.report("cannot call non-fn", pos);
        return { type: "error" };
    }

    /**
     * Checks the arguments of a call to a non-generic fn. Requires
     * `args.length === subject.params.length`.
     */
    private checkCallExprNoGenericsTail(
        expr: Expr,
        subject: VType,
        args: VType[],
        pos: Pos,
    ): VType {
        if (
            expr.kind.type !== "call" || subject.type !== "fn"
        ) {
            throw new Error();
        }
        for (let i = 0; i < args.length; ++i) {
            if (this.vtypeContainsGeneric(args[i])) {
                this.report(
                    `amfibious generic parameter for argument ${i}, please specify generic types explicitly`,
                    pos,
                );
                return { type: "error" };
            }
        }
        for (let i = 0; i < args.length; ++i) {
            if (!vtypesEqual(args[i], subject.params[i].vtype)) {
                this.report(
                    `incorrect argument ${i} '${subject.params[i].ident}'` +
                        `, expected ${vtypeToString(subject.params[i].vtype)}` +
                        `, got ${vtypeToString(args[i])}`,
                    pos,
                );
                break;
            }
        }
        return subject.returnType;
    }

    /**
     * Checks a call to a generic fn whose type arguments are inferred from
     * the argument types. Records the inferred arguments on
     * `expr.kind.genericArgs`. Requires
     * `args.length === subject.params.length`.
     */
    private checkCallExprInferredGenericsTail(
        expr: Expr,
        subject: VType,
        args: VType[],
        pos: Pos,
    ): VType {
        if (
            expr.kind.type !== "call" || subject.type !== "fn" ||
            subject.genericParams === undefined
        ) {
            throw new Error();
        }
        const genericArgsRes = this.inferGenericArgs(
            subject.genericParams,
            subject.params,
            args,
            pos,
        );
        if (!genericArgsRes.ok) {
            return { type: "error" };
        }
        const genericArgs = genericArgsRes.value;
        for (let i = 0; i < args.length; ++i) {
            const vtypeCompatible = vtypesEqual(
                args[i],
                subject.params[i].vtype,
                genericArgs,
            );
            if (!vtypeCompatible) {
                this.report(
                    `incorrect argument ${i} '${subject.params[i].ident}'` +
                        `, expected ${
                            vtypeToString(
                                extractGenericType(
                                    subject.params[i].vtype,
                                    genericArgs,
                                ),
                            )
                        }` +
                        `, got ${vtypeToString(args[i])}`,
                    pos,
                );
                break;
            }
        }
        expr.kind.genericArgs = genericArgs;
        return this.concretizeVType(subject.returnType, genericArgs);
    }

    /**
     * Infers one concrete type per generic param by matching each
     * generic-containing param type against its argument type. Reports an
     * error and returns `{ ok: false }` on a shape mismatch, a conflicting
     * inference, or a generic param that no argument determines.
     */
    private inferGenericArgs(
        genericParams: VTypeGenericParam[],
        params: VTypeParam[],
        args: VType[],
        pos: Pos,
    ): { ok: true; value: GenericArgsMap } | { ok: false } {
        const genericArgs: GenericArgsMap = {};
        for (let i = 0; i < params.length; ++i) {
            if (!this.vtypeContainsGeneric(params[i].vtype)) {
                continue;
            }
            const {
                a: generic,
                b: concrete,
            } = this.reduceToSignificant(params[i].vtype, args[i]);
            if (generic.type !== "generic") {
                // The argument's shape differs from the param's (e.g. a
                // non-array passed where an array of a generic is expected),
                // so no binding for the generic exists.
                this.report(
                    `incorrect argument ${i} '${params[i].ident}'` +
                        `, expected ${vtypeToString(params[i].vtype)}` +
                        `, got ${vtypeToString(args[i])}`,
                    pos,
                );
                return { ok: false };
            }
            const paramId = generic.param.id;
            if (
                paramId in genericArgs &&
                !vtypesEqual(genericArgs[paramId], concrete)
            ) {
                this.report(
                    `according to inferrence, argument ${i} has a conflicting type`,
                    pos,
                );
                return { ok: false };
            }
            genericArgs[paramId] = concrete;
        }
        for (const param of genericParams) {
            if (!(param.id in genericArgs)) {
                this.report(`could not infer generic type ${param.ident}`, pos);
                return { ok: false };
            }
        }
        return { ok: true, value: genericArgs };
    }

    /**
     * Walks `a` (param type) and `b` (argument type) in lockstep through
     * matching array layers. Stops at the first point where their kinds
     * differ or both are generic. Other matching kinds are unsupported and
     * throw.
     */
    private reduceToSignificant(a: VType, b: VType): { a: VType; b: VType } {
        if (a.type !== b.type) {
            return { a, b };
        }
        if (a.type === "array" && b.type === "array") {
            return this.reduceToSignificant(a.inner, b.inner);
        }
        if (a.type === "generic" && b.type === "generic") {
            return { a, b };
        }
        throw new Error("idk what to do here");
    }

    /** Returns whether `vtype` mentions a generic param anywhere. */
    private vtypeContainsGeneric(vtype: VType): boolean {
        switch (vtype.type) {
            case "error":
            case "string":
            case "unknown":
            case "null":
            case "int":
            case "bool":
                return false;
            case "array":
                return this.vtypeContainsGeneric(vtype.inner);
            case "struct":
                return vtype.fields.some((field) =>
                    this.vtypeContainsGeneric(field.vtype)
                );
            case "fn":
                throw new Error("not implemented");
            case "generic":
                return true;
            case "generic_spec":
                throw new Error("listen kid, grrrrrrrr");
        }
    }

    /**
     * Checks a call to a generic fn whose type arguments were given
     * explicitly. Records them on `expr.kind.genericArgs`.
     */
    private checkCallExprExplicitGenericsTail(
        expr: Expr,
        subject: VType,
    ): VType {
        if (
            expr.kind.type !== "call" || subject.type !== "generic_spec" ||
            subject.subject.type !== "fn"
        ) {
            throw new Error();
        }
        const pos = expr.pos;
        const inner = subject.subject;
        const params = inner.params;
        const args = expr.kind.args.map((arg) => this.checkExpr(arg));
        if (args.length !== params.length) {
            this.report(
                `incorrect number of arguments` +
                    `, expected ${params.length}`,
                pos,
            );
            // The loop below pairs params[i] with args[i].
            return { type: "error" };
        }
        for (let i = 0; i < args.length; ++i) {
            const vtypeCompatible = vtypesEqual(
                args[i],
                params[i].vtype,
                subject.genericArgs,
            );
            if (!vtypeCompatible) {
                this.report(
                    `incorrect argument ${i} '${inner.params[i].ident}'` +
                        `, expected ${
                            vtypeToString(
                                extractGenericType(
                                    params[i].vtype,
                                    subject.genericArgs,
                                ),
                            )
                        }` +
                        `, got ${vtypeToString(args[i])}`,
                    pos,
                );
                break;
            }
        }
        expr.kind.genericArgs = subject.genericArgs;
        return this.concretizeVType(
            subject.subject.returnType,
            subject.genericArgs,
        );
    }

    /** Replaces every generic param inside `vtype` with its entry in `generics`. */
    private concretizeVType(
        vtype: VType,
        generics: GenericArgsMap,
    ): VType {
        switch (vtype.type) {
            case "error":
            case "unknown":
            case "string":
            case "null":
            case "int":
            case "bool":
                return vtype;
            case "array":
                return {
                    type: "array",
                    inner: this.concretizeVType(vtype.inner, generics),
                };
            case "struct":
                return {
                    type: "struct",
                    fields: vtype.fields.map((field) => ({
                        ...field,
                        vtype: this.concretizeVType(field.vtype, generics),
                    })),
                };
            case "fn":
                throw new Error("not implemented");
            case "generic":
                return generics[vtype.param.id];
            case "generic_spec":
                throw new Error("not implemented");
        }
    }

    /** Path expressions are not supported by the checker yet. */
    public checkPathExpr(expr: Expr): VType {
        if (expr.kind.type !== "path") {
            throw new Error();
        }
        throw new Error("not implemented");
    }

    /**
     * Applies explicit type arguments to a generic fn. Returns a
     * `generic_spec` pairing the fn with its generic-param-id -> type map.
     */
    public checkETypeArgsExpr(expr: Expr): VType {
        if (expr.kind.type !== "etype_args") {
            throw new Error();
        }
        const pos = expr.pos;
        const subject = this.checkExpr(expr.kind.subject);
        if (subject.type === "error") return subject;
        if (subject.type !== "fn" || subject.genericParams === undefined) {
            this.report(
                "etype arguments must only be applied to generic functions",
                expr.pos,
            );
            return { type: "error" };
        }
        const genericParams = subject.genericParams;
        const args = expr.kind.etypeArgs;
        if (args.length !== genericParams.length) {
            this.report(
                `incorrect number of arguments` +
                    `, expected ${genericParams.length}`,
                pos,
            );
            // Surplus args would index past genericParams; missing ones
            // would leave generics unbound.
            return { type: "error" };
        }
        const genericArgs: GenericArgsMap = {};
        for (let i = 0; i < args.length; ++i) {
            genericArgs[genericParams[i].id] = this.checkEType(args[i]);
        }
        return {
            type: "generic_spec",
            subject,
            genericArgs,
        };
    }

    /** Checks a unary operation against `simpleUnaryOperations`. */
    public checkUnaryExpr(expr: Expr): VType {
        if (expr.kind.type !== "unary") {
            throw new Error();
        }
        const pos = expr.pos;
        const subject = this.checkExpr(expr.kind.subject);
        if (subject.type === "error") return subject;
        for (const operation of simpleUnaryOperations) {
            if (operation.unaryType !== expr.kind.unaryType) {
                continue;
            }
            if (!vtypesEqual(operation.operand, subject)) {
                continue;
            }
            return operation.result ?? operation.operand;
        }
        this.report(
            `cannot apply unary operation '${expr.kind.unaryType}' ` +
                `on type '${vtypeToString(subject)}'`,
            pos,
        );
        return { type: "error" };
    }

    /**
     * Checks a binary operation against `simpleBinaryOperations`. Both
     * operands must have the same type.
     */
    public checkBinaryExpr(expr: Expr): VType {
        if (expr.kind.type !== "binary") {
            throw new Error();
        }
        const pos = expr.pos;
        const left = this.checkExpr(expr.kind.left);
        if (left.type === "error") return left;
        const right = this.checkExpr(expr.kind.right);
        if (right.type === "error") return right;
        for (const operation of simpleBinaryOperations) {
            if (operation.binaryType !== expr.kind.binaryType) {
                continue;
            }
            if (!vtypesEqual(operation.operand, left)) {
                continue;
            }
            if (!vtypesEqual(left, right)) {
                continue;
            }
            return operation.result ?? operation.operand;
        }
        this.report(
            `cannot apply binary operation '${expr.kind.binaryType}' ` +
                `on types '${vtypeToString(left)}' and '${
                    vtypeToString(right)
                }'`,
            pos,
        );
        return { type: "error" };
    }

    /**
     * Checks an `if` expression. The condition must be `bool`. Without an
     * else branch the result must be `null`; otherwise both branches must
     * have compatible types.
     */
    public checkIfExpr(expr: Expr): VType {
        if (expr.kind.type !== "if") {
            throw new Error();
        }
        const pos = expr.pos;
        const cond = this.checkExpr(expr.kind.cond);
        if (cond.type === "error") return cond;
        const truthy = this.checkExpr(expr.kind.truthy);
        if (truthy.type === "error") return truthy;
        const falsy = expr.kind.falsy
            ? this.checkExpr(expr.kind.falsy)
            : undefined;
        if (falsy?.type === "error") return falsy;
        if (cond.type !== "bool") {
            this.report(
                `if condition should be 'bool', got '${vtypeToString(cond)}'`,
                pos,
            );
            return { type: "error" };
        }
        if (falsy === undefined && truthy.type !== "null") {
            this.report(
                `if expressions without false-case must result in type 'null'` +
                    `, got '${vtypeToString(truthy)}'`,
                pos,
            );
            return { type: "error" };
        }
        if (falsy !== undefined && !vtypesEqual(truthy, falsy)) {
            this.report(
                `if cases must be compatible, got incompatible types` +
                    ` '${vtypeToString(truthy)}'` +
                    ` and '${vtypeToString(falsy)}'`,
                pos,
            );
            return { type: "error" };
        }
        return truthy;
    }

    /**
     * Checks a `loop`. The body must be `null`. The loop's type is the
     * common type of its `break`s, or `null` when it has none.
     */
    public checkLoopExpr(expr: Expr): VType {
        if (expr.kind.type !== "loop") {
            throw new Error();
        }
        const pos = expr.pos;
        this.loopBreakStack.push([]);
        const body = this.checkExpr(expr.kind.body);
        // Pop before any early return so this loop's break context never
        // leaks into the surrounding code.
        const loopBreakTypes = this.loopBreakStack.pop()!;
        if (body.type === "error") return body;
        if (body.type !== "null") {
            this.report(
                `loop body must result in type 'null'` +
                    `, got '${vtypeToString(body)}'`,
                pos,
            );
            return { type: "error" };
        }
        if (loopBreakTypes.length === 0) {
            return { type: "null" };
        }
        // Compare against the first break, not against `null`; otherwise
        // every loop breaking with a value would be rejected.
        const breakType = loopBreakTypes[0];
        const outlier = loopBreakTypes.find((vtype) =>
            !vtypesEqual(breakType, vtype)
        );
        if (outlier !== undefined) {
            this.report(
                `incompatible types in break statements` +
                    `, got '${vtypeToString(outlier)}'` +
                    ` incompatible with ${vtypeToString(breakType)}`,
                pos,
            );
            return { type: "error" };
        }
        return breakType;
    }

    /**
     * Checks a block: resolves its fn headers, checks each statement, and
     * yields the trailing expression's type (or `null` when there is none).
     */
    public checkBlockExpr(expr: Expr): VType {
        if (expr.kind.type !== "block") {
            throw new Error();
        }
        this.checkFnHeaders(expr.kind.stmts);
        for (const stmt of expr.kind.stmts) {
            this.checkStmt(stmt);
        }
        return expr.kind.expr
            ? this.checkExpr(expr.kind.expr)
            : { type: "null" };
    }

    /** Converts an explicit type annotation (`EType`) into a `VType`. */
    public checkEType(etype: EType): VType {
        const pos = etype.pos;
        switch (etype.kind.type) {
            case "null":
                return { type: "null" };
            case "int":
                return { type: "int" };
            case "bool":
                return { type: "bool" };
            case "string":
                return { type: "string" };
        }
        if (etype.kind.type === "ident") {
            this.report(`undefined type '${etype.kind.ident}'`, pos);
            return { type: "error" };
        }
        if (etype.kind.type === "sym") {
            if (etype.kind.sym.type === "generic") {
                const { id: globalId, ident } = etype.kind.sym.genericParam;
                const localParam = this.globalIdToGenericParamMap.get(globalId);
                if (localParam === undefined) {
                    throw new Error();
                }
                return { type: "generic", param: { id: localParam.id, ident } };
            }
            this.report(`sym type '${etype.kind.sym.type}' used as type`, pos);
            return { type: "error" };
        }
        if (etype.kind.type === "array") {
            const inner = this.checkEType(etype.kind.inner);
            return { type: "array", inner };
        }
        if (etype.kind.type === "struct") {
            const fieldEtypes = etype.kind.fields;
            // Report the field that actually lacks a type.
            const untyped = fieldEtypes.find((param) => !param.etype);
            if (untyped !== undefined) {
                this.report(
                    `field '${untyped.ident}' declared without type`,
                    pos,
                );
                return { type: "error" };
            }
            const seen = new Set<string>();
            for (const field of fieldEtypes) {
                if (seen.has(field.ident)) {
                    this.report(`field ${field.ident} defined twice`, pos);
                    return { type: "error" };
                }
                seen.add(field.ident);
            }
            const fields = fieldEtypes.map((param): VTypeParam => ({
                ident: param.ident,
                vtype: this.checkEType(param.etype!),
            }));
            return { type: "struct", fields };
        }
        throw new Error(`unknown explicit type ${etype.kind.type}`);
    }

    /** Reports a checker error at `pos` and prints a stack trace. */
    private report(msg: string, pos: Pos) {
        this.reporter.reportError({ reporter: "Checker", msg, pos });
        printStackTrace();
    }
}

/**
 * Supported unary operations. The result type defaults to the operand type
 * when `result` is omitted.
 */
const simpleUnaryOperations: {
    unaryType: string;
    operand: VType;
    result?: VType;
}[] = [
    { unaryType: "not", operand: { type: "bool" } },
    { unaryType: "-", operand: { type: "int" } },
];

/**
 * Supported binary operations. Both operands must equal `operand`. The
 * result type defaults to the operand type when `result` is omitted.
 */
const simpleBinaryOperations: {
    binaryType: string;
    operand: VType;
    result?: VType;
}[] = [
    // arithmetic
    { binaryType: "+", operand: { type: "int" } },
    { binaryType: "+", operand: { type: "string" } },
    { binaryType: "-", operand: { type: "int" } },
    { binaryType: "*", operand: { type: "int" } },
    { binaryType: "/", operand: { type: "int" } },
    // logical
    { binaryType: "and", operand: { type: "bool" } },
    { binaryType: "or", operand: { type: "bool" } },
    // equality
    { binaryType: "==", operand: { type: "null" }, result: { type: "bool" } },
    { binaryType: "==", operand: { type: "int" }, result: { type: "bool" } },
    { binaryType: "==", operand: { type: "string" }, result: { type: "bool" } },
    { binaryType: "==", operand: { type: "bool" }, result: { type: "bool" } },
    { binaryType: "!=", operand: { type: "null" }, result: { type: "bool" } },
    { binaryType: "!=", operand: { type: "int" }, result: { type: "bool" } },
    { binaryType: "!=", operand: { type: "string" }, result: { type: "bool" } },
    { binaryType: "!=", operand: { type: "bool" }, result: { type: "bool" } },
    // comparison
    { binaryType: "<", operand: { type: "int" }, result: { type: "bool" } },
    { binaryType: ">", operand: { type: "int" }, result: { type: "bool" } },
    { binaryType: "<=", operand: { type: "int" }, result: { type: "bool" } },
    { binaryType: ">=", operand: { type: "int" }, result: { type: "bool" } },
];