From cab2c9baa3fc27f6e0754c52e53e976c7d5ee253 Mon Sep 17 00:00:00 2001 From: sfja Date: Thu, 26 Dec 2024 01:51:05 +0100 Subject: [PATCH] generics work --- compiler/ast.ts | 1 + compiler/checker.ts | 120 ++++++-- compiler/compiler.ts | 16 +- compiler/info.ts | 2 +- compiler/lowerer.ts | 9 + compiler/mono.ts | 262 +++++++++++++++++ compiler/mono_lower.ts | 618 +++++++++++++++++++++++++++++++++++++++++ compiler/parser.ts | 6 +- compiler/vtype.ts | 53 +++- stdlib.slg | 2 + tests/generics.slg | 6 + 11 files changed, 1051 insertions(+), 44 deletions(-) create mode 100644 compiler/mono.ts create mode 100644 compiler/mono_lower.ts diff --git a/compiler/ast.ts b/compiler/ast.ts index 8a9c7b0..1bfdbd3 100644 --- a/compiler/ast.ts +++ b/compiler/ast.ts @@ -125,6 +125,7 @@ export type ETypeKind = | { type: "struct"; fields: Param[] }; export type GenericParam = { + id: number; ident: string; pos: Pos; vtype?: VType; diff --git a/compiler/checker.ts b/compiler/checker.ts index d300d9e..b86312f 100644 --- a/compiler/checker.ts +++ b/compiler/checker.ts @@ -2,6 +2,8 @@ import { EType, Expr, Stmt } from "./ast.ts"; import { printStackTrace, Reporter } from "./info.ts"; import { Pos } from "./token.ts"; import { + extractGenericType, + GenericArgsMap, VType, VTypeGenericParam, VTypeParam, @@ -34,7 +36,8 @@ export class Checker { if (stmt.kind.genericParams !== undefined) { genericParams = []; for (const etypeParam of stmt.kind.genericParams) { - genericParams.push({ ident: etypeParam.ident }); + const id = genericParams.length; + genericParams.push({ id, ident: etypeParam.ident }); } } const params: VTypeParam[] = []; @@ -47,7 +50,13 @@ export class Checker { param.vtype = vtype; params.push({ ident: param.ident, vtype }); } - stmt.kind.vtype = { type: "fn", genericParams, params, returnType }; + stmt.kind.vtype = { + type: "fn", + genericParams, + params, + returnType, + stmtId: stmt.id, + }; } } @@ -336,8 +345,7 @@ export class Checker { if (vtype.type !== "fn") { throw new Error(); } - const { params, returnType } = vtype; - return { type: "fn", params, returnType }; + return vtype; } case "fn_param": return expr.kind.sym.param.vtype!; @@ -399,30 +407,75 @@ export class Checker { } const pos = expr.pos; const subject = this.checkExpr(expr.kind.subject); - if (subject.type !== "fn") { - this.report("cannot call non-fn", pos); - return { type: "error" }; - } - const args = expr.kind.args.map((arg) => this.checkExpr(arg)); - if (args.length !== subject.params.length) { - this.report( - `incorrect number of arguments` + - `, expected ${subject.params.length}`, - pos, - ); - } - for (let i = 0; i < args.length; ++i) { - if (!vtypesEqual(args[i], subject.params[i].vtype)) { + if (subject.type === "error") return subject; + if (subject.type === "fn") { + if (subject.genericParams !== undefined) { + throw new Error("😭😭😭"); + } + const args = expr.kind.args.map((arg) => this.checkExpr(arg)); + if (args.length !== subject.params.length) { this.report( - `incorrect argument ${i} '${subject.params[i].ident}'` + - `, expected ${vtypeToString(subject.params[i].vtype)}` + - `, got ${vtypeToString(args[i])}`, + `incorrect number of arguments` + + `, expected ${subject.params.length}`, pos, ); - break; } + for (let i = 0; i < args.length; ++i) { + if (!vtypesEqual(args[i], subject.params[i].vtype)) { + this.report( + `incorrect argument ${i} '${subject.params[i].ident}'` + + `, expected ${ + vtypeToString(subject.params[i].vtype) + }` + + `, got ${vtypeToString(args[i])}`, + pos, + ); + break; + } + } + return subject.returnType; } - return subject.returnType; + if (subject.type === "generic_spec" && subject.subject.type === "fn") { + const inner = subject.subject; + const params = inner.params; + const args = expr.kind.args.map((arg) => this.checkExpr(arg)); + if (args.length !== params.length) { + this.report( + `incorrect number of arguments` + + `, expected ${params.length}`, + pos, + ); + } + for (let i = 0; i < args.length; ++i) { + const vtypeCompatible = vtypesEqual( + args[i], + params[i].vtype, + subject.genericArgs, + ); + if (!vtypeCompatible) { + this.report( + `incorrect argument ${i} '${inner.params[i].ident}'` + + `, expected ${ + vtypeToString( + extractGenericType( + params[i].vtype, + subject.genericArgs, + ), + ) + }` + + `, got ${vtypeToString(args[i])}`, + pos, + ); + break; + } + } + return extractGenericType( + subject.subject.returnType, + subject.genericArgs, + ); + } + this.report("cannot call non-fn", pos); + return { type: "error" }; } public checkPathExpr(expr: Expr): VType { @@ -445,20 +498,23 @@ export class Checker { ); return { type: "error" }; } - const genericParams = expr.kind.etypeArgs.map((arg) => - this.checkEType(arg) - ); - if (genericParams.length !== subject.params.length) { + const args = expr.kind.etypeArgs; + if (args.length !== subject.params.length) { this.report( `incorrect number of arguments` + `, expected ${subject.params.length}`, pos, ); } + const genericArgs: GenericArgsMap = {}; + for (let i = 0; i < args.length; ++i) { + const etype = this.checkEType(args[i]); + genericArgs[subject.genericParams[i].id] = etype; + } return { type: "generic_spec", subject, - genericParams, + genericArgs, }; } @@ -491,7 +547,9 @@ export class Checker { } const pos = expr.pos; const left = this.checkExpr(expr.kind.left); + if (left.type === "error") return left; const right = this.checkExpr(expr.kind.right); + if (right.type === "error") return right; for (const operation of simpleBinaryOperations) { if (operation.binaryType !== expr.kind.binaryType) { continue; @@ -520,10 +578,13 @@ export class Checker { } const pos = expr.pos; const cond = this.checkExpr(expr.kind.cond); + if (cond.type === "error") return cond; const truthy = this.checkExpr(expr.kind.truthy); + if (truthy.type === "error") return truthy; const falsy = expr.kind.falsy ? this.checkExpr(expr.kind.falsy) : undefined; + if (falsy?.type === "error") return falsy; if (cond.type !== "bool") { this.report( `if condition should be 'bool', got '${vtypeToString(cond)}'`, @@ -626,7 +687,8 @@ export class Checker { } if (etype.kind.type === "sym") { if (etype.kind.sym.type === "generic") { - return { type: "generic" }; + const { id, ident } = etype.kind.sym.genericParam; + return { type: "generic", param: { id, ident } }; } this.report(`sym type '${etype.kind.sym.type}' used as type`, pos); return { type: "error" }; diff --git a/compiler/compiler.ts b/compiler/compiler.ts index 2b44225..827d227 100644 --- a/compiler/compiler.ts +++ b/compiler/compiler.ts @@ -5,6 +5,8 @@ import { SpecialLoopDesugarer } from "./desugar/special_loop.ts"; import { Reporter } from "./info.ts"; import { Lexer } from "./lexer.ts"; import { FnNamesMap, Lowerer } from "./lowerer.ts"; +import { Monomorphizer } from "./mono.ts"; +import { MonoLowerer } from "./mono_lower.ts"; import { Parser } from "./parser.ts"; import { Resolver } from "./resolver.ts"; @@ -45,10 +47,16 @@ export class Compiler { Deno.exit(1); } - const lowerer = new Lowerer(lexer.currentPos()); - lowerer.lower(ast); - // lowerer.printProgram(); - const { program, fnNames } = lowerer.finish(); + const { monoFns, callMap } = new Monomorphizer(ast).monomorphize(); + + //const lowerer = new Lowerer(lexer.currentPos()); + //lowerer.lower(ast); + //// lowerer.printProgram(); + //const { program, fnNames } = lowerer.finish(); + + const lowerer = new MonoLowerer(monoFns, callMap, lexer.currentPos()); + const { program, fnNames } = lowerer.lower(); + lowerer.printProgram(); return { program, fnNames }; } diff --git a/compiler/info.ts b/compiler/info.ts index f51f991..55a3dad 100644 --- a/compiler/info.ts +++ b/compiler/info.ts @@ -42,7 +42,7 @@ export function printStackTrace() { } } try { - //throw new ReportNotAnError(); + throw new ReportNotAnError(); } catch (error) { if (!(error instanceof ReportNotAnError)) { throw error; diff --git a/compiler/lowerer.ts b/compiler/lowerer.ts index 64beadc..0b8b536 100644 --- a/compiler/lowerer.ts +++ b/compiler/lowerer.ts @@ -257,6 +257,8 @@ export class Lowerer { return this.lowerIndexExpr(expr); case "call": return this.lowerCallExpr(expr); + case "etype_args": + return this.lowerETypeArgsExpr(expr); case "unary": return this.lowerUnaryExpr(expr); case "binary": @@ -465,6 +467,13 @@ export class Lowerer { this.program.add(Ops.Call, expr.kind.args.length); } + private lowerETypeArgsExpr(expr: Expr) { + if (expr.kind.type !== "etype_args") { + throw new Error(); + } + throw new Error("not implemented"); + } + private lowerIfExpr(expr: Expr) { if (expr.kind.type !== "if") { throw new Error(); diff --git a/compiler/mono.ts b/compiler/mono.ts new file mode 100644 index 0000000..fa6d00a --- /dev/null +++ b/compiler/mono.ts @@ -0,0 +1,262 @@ +import { Expr, Stmt } from "./ast.ts"; +import { AstVisitor, visitExpr, VisitRes, visitStmts } from "./ast_visitor.ts"; +import { GenericArgsMap, VType } from "./vtype.ts"; + +export class Monomorphizer { + private fns: MonoFnsMap = {}; + private callMap: MonoCallNameGenMap = {}; + private allFns: Map; + private entryFn: Stmt; + + constructor(private ast: Stmt[]) { + this.allFns = new AllFnsCollector().collect(this.ast); + this.entryFn = findMain(this.allFns); + } + + public monomorphize(): MonoResult { + this.monomorphizeFn(this.entryFn); + return { monoFns: this.fns, callMap: this.callMap }; + } + + private monomorphizeFn( + stmt: Stmt, + genericArgs?: GenericArgsMap, + ): MonoFn { + const nameGen = monoFnNameGen(stmt, genericArgs); + if (nameGen in this.fns) { + return this.fns[nameGen]; + } + const monoFn = { nameGen, stmt, genericArgs }; + this.fns[nameGen] = monoFn; + const calls = new CallCollector().collect(stmt); + for (const call of calls) { + this.callMap[call.id] = nameGen; + if (call.kind.type !== "call") { + throw new Error(); + } + if (call.kind.subject.vtype?.type === "fn") { + const fn = this.allFns.get(call.kind.subject.vtype.stmtId); + if (fn === undefined) { + throw new Error(); + } + const monoFn = this.monomorphizeFn(fn); + this.callMap[call.id] = monoFn.nameGen; + continue; + } + if (call.kind.subject.vtype?.type === "generic_spec") { + const genericSpecType = call.kind.subject.vtype!; + if (genericSpecType.subject.type !== "fn") { + throw new Error(); + } + const fnType = genericSpecType.subject; + + const monoArgs: GenericArgsMap = {}; + for (const key in genericSpecType.genericArgs) { + const vtype = genericSpecType.genericArgs[key]; + if (vtype.type === "generic") { + if (genericArgs === undefined) { + throw new Error(); + } + monoArgs[key] = genericArgs[vtype.param.id]; + } else { + monoArgs[key] = vtype; + } + } + + const fn = this.allFns.get(fnType.stmtId); + if (fn === undefined) { + throw new Error(); + } + const monoFn = this.monomorphizeFn(fn, monoArgs); + this.callMap[call.id] = monoFn.nameGen; + continue; + } + throw new Error(); + } + return monoFn; + } +} + +export type MonoResult = { + monoFns: MonoFnsMap; + callMap: MonoCallNameGenMap; +}; + +export type MonoFnsMap = { [nameGen: string]: MonoFn }; + +export type MonoFn = { + nameGen: string; + stmt: Stmt; + genericArgs?: GenericArgsMap; +}; + +export type MonoCallNameGenMap = { [exprId: number]: string }; + +function monoFnNameGen(stmt: Stmt, genericArgs?: GenericArgsMap): string { + if (stmt.kind.type !== "fn") { + throw new Error(); + } + if (stmt.kind.ident === "main") { + return "main"; + } + if (genericArgs === undefined) { + return `${stmt.kind.ident}_${stmt.id}`; + } + const args = Object.values(genericArgs) + .map((arg) => vtypeNameGenPart(arg)) + .join("_"); + return `${stmt.kind.ident}_${stmt.id}_${args}`; +} + +function vtypeNameGenPart(vtype: VType): string { + switch (vtype.type) { + case "error": + throw new Error("error in type"); + case "string": + case "int": + case "bool": + case "null": + case "unknown": + return vtype.type; + case "array": + return `[${vtypeNameGenPart(vtype.inner)}]`; + case "struct": { + const fields = vtype.fields + .map((field) => + `${field.ident}, ${vtypeNameGenPart(field.vtype)}` + ) + .join(", "); + return `struct { ${fields} }`; + } + case "fn": + return `fn(${vtype.stmtId})`; + case "generic": + case "generic_spec": + throw new Error("cannot be monomorphized"); + } +} + +class AllFnsCollector implements AstVisitor { + private allFns = new Map(); + + public collect(ast: Stmt[]): Map { + visitStmts(ast, this); + return this.allFns; + } + + visitFnStmt(stmt: Stmt): VisitRes { + if (stmt.kind.type !== "fn") { + throw new Error(); + } + this.allFns.set(stmt.id, stmt); + } +} + +function findMain(fns: Map): Stmt { + const mainId = fns.values().find((stmt) => + stmt.kind.type === "fn" && stmt.kind.ident === "main" + ); + if (mainId === undefined) { + console.error("error: cannot find function 'main'"); + console.error(apology); + throw new Error("cannot find function 'main'"); + } + return mainId; +} + +class CallCollector implements AstVisitor { + private calls: Expr[] = []; + + public collect(fn: Stmt): Expr[] { + if (fn.kind.type !== "fn") { + throw new Error(); + } + visitExpr(fn.kind.body, this); + return this.calls; + } + + visitFnStmt(_stmt: Stmt): VisitRes { + return "stop"; + } + + visitCallExpr(expr: Expr): VisitRes { + if (expr.kind.type !== "call") { + throw new Error(); + } + this.calls.push(expr); + } +} + +const apology = ` + Hear me out. Monomorphization, meaning the process + inwich generic functions are stamped out into seperate + specialized functions is actually really hard, and I + have a really hard time right now, figuring out, how + to do it in a smart way. To really explain it, let's + imagine you have a function, you defined as a(). + For each call with seperate generics arguments given, + such as a::() and a::(), a specialized + function has to be 'stamped out', ie. created and put + into the compilation with the rest of the program. Now + to the reason as to why 'main' is needed. To do the + monomorphization, we have to do it recursively. To + explain this, imagine you have a generic function a + and inside the body of a, you call another generic + function such as b with the same generic type. This + means that the monomorphization process of b depends + on the monomorphization of a. What this essentially + means, is that the monomorphization process works on + the program as a call graph, meaning a graph or tree + structure where each represents a function call to + either another function or a recursive call to the + function itself. But a problem arises from doing it + this way, which is that a call graph will need an + entrypoint. The language, as it is currently, does + not really require a 'main'-function. Or maybe it + does, but that's beside the point. The point is that + we need a main function, to be the entry point for + the call graph. The monomorphization process then + runs through the program from that entry point. This + means that each function we call, will itself be + monomorphized and added to the compilation. It also + means that functions that are not called, will also + not be added to the compilation. This essentially + eliminates uncalled/dead functions. Is this + particularly smart to do in such a high level part + of the compilation process? I don't know. It's + obvious that we can't just use every function as + an entry point in the call graph, because we're + actively added new functions. Additionally, with + generic functions, we don't know, if they're the + entry point, what generic arguments, they should + be monomorphized with. We could do monomorphization + the same way C++ does it, where all non-generic + functions before monomorphization are treated as + entry points in the call graph. But this has the + drawback that generic and non-generic functions + are treated differently, which has many underlying + drawbacks, especially pertaining to the amount of + work needed to handle both in all proceeding steps + of the compiler. Anyways, I just wanted to yap and + complain about the way generics and monomorphization + has made the compiler 100x more complicated, and + that I find it really hard to implement in a way, + that is not either too simplistic or so complicated + and advanced I'm too dumb to implement it. So if + you would be so kind as to make it clear to the + compiler, what function it should designate as + the entry point to the call graph, it will use + for monomorphization, that would be very kind of + you. The way you do this, is by added or selecting + one of your current functions and giving it the + name of 'main'. This is spelled m-a-i-n. The word + is synonemous with the words primary and principle. + The name is meant to designate the entry point into + the program, which is why the monomorphization + process uses this specific function as the entry + point into the call graph, it generates. So if you + would be so kind as to do that, that would really + make my day. In any case, keep hacking ferociously + on whatever you're working on. I have monomorphizer + to implement. See ya. -Your favorite compiler girl <3 +`.replaceAll(" ", "").trim(); diff --git a/compiler/mono_lower.ts b/compiler/mono_lower.ts new file mode 100644 index 0000000..40dca45 --- /dev/null +++ b/compiler/mono_lower.ts @@ -0,0 +1,618 @@ +import { Builtins, Ops } from "./arch.ts"; +import { Assembler, Label } from "./assembler.ts"; +import { Expr, Stmt } from "./ast.ts"; +import { FnNamesMap } from "./lowerer.ts"; +import { LocalLeaf, Locals, LocalsFnRoot } from "./lowerer_locals.ts"; +import { MonoCallNameGenMap, MonoFn, MonoFnsMap } from "./mono.ts"; +import { Pos } from "./token.ts"; +import { vtypeToString } from "./vtype.ts"; + +export class MonoLowerer { + private program = Assembler.newRoot(); + + public constructor( + private monoFns: MonoFnsMap, + private callMap: MonoCallNameGenMap, + private lastPos: Pos, + ) {} + + public lower(): { program: number[]; fnNames: FnNamesMap } { + const fnLabelNameMap: FnLabelMap = {}; + for (const nameGen in this.monoFns) { + fnLabelNameMap[nameGen] = nameGen; + } + + this.addPrelimiary(); + + for (const fn of Object.values(this.monoFns)) { + const fnProgram = new MonoFnLowerer( + fn, + this.program.fork(), + this.callMap, + ).lower(); + this.program.join(fnProgram); + } + + this.addConcluding(); + + const { program, locs } = this.program.assemble(); + const fnNames: FnNamesMap = {}; + for (const label in locs) { + if (label in fnLabelNameMap) { + fnNames[locs[label]] = fnLabelNameMap[label]; + } + } + return { program, fnNames }; + } + + private addPrelimiary() { + this.addClearingSourceMap(); + this.program.add(Ops.PushPtr, { label: "main" }); + this.program.add(Ops.Call, 0); + this.program.add(Ops.PushPtr, { label: "_exit" }); + this.program.add(Ops.Jump); + } + + private addConcluding() { + this.program.setLabel({ label: "_exit" }); + this.addSourceMap(this.lastPos); + this.program.add(Ops.Pop); + } + + private addSourceMap({ index, line, col }: Pos) { + this.program.add(Ops.SourceMap, index, line, col); + } + + private addClearingSourceMap() { + this.program.add(Ops.SourceMap, 0, 1, 1); + } + + public printProgram() { + this.program.printProgram(); + } +} + +type FnLabelMap = { [nameGen: string]: string }; + +class MonoFnLowerer { + private locals: Locals = new LocalsFnRoot(); + private returnStack: Label[] = []; + private breakStack: Label[] = []; + + public constructor( + private fn: MonoFn, + private program: Assembler, + private callMap: MonoCallNameGenMap, + ) {} + + public lower(): Assembler { + this.lowerFnStmt(this.fn.stmt); + return this.program; + } + + private lowerFnStmt(stmt: Stmt) { + if (stmt.kind.type !== "fn") { + throw new Error(); + } + const label = this.fn.nameGen; + this.program.setLabel({ label }); + this.addSourceMap(stmt.pos); + + const outerLocals = this.locals; + const fnRoot = new LocalsFnRoot(outerLocals); + const outerProgram = this.program; + + const returnLabel = this.program.makeLabel(); + this.returnStack.push(returnLabel); + + this.program = outerProgram.fork(); + this.locals = fnRoot; + for (const { ident } of stmt.kind.params) { + this.locals.allocSym(ident); + } + if (stmt.kind.anno?.ident === "builtin") { + this.lowerFnBuiltinBody(stmt.kind.anno.values); + } else if (stmt.kind.anno?.ident === "remainder") { + this.program.add(Ops.Remainder); + } else { + this.lowerExpr(stmt.kind.body); + } + this.locals = outerLocals; + + const localAmount = fnRoot.stackReserved() - + stmt.kind.params.length; + for (let i = 0; i < localAmount; ++i) { + outerProgram.add(Ops.PushNull); + } + + this.returnStack.pop(); + this.program.setLabel(returnLabel); + this.program.add(Ops.Return); + + outerProgram.join(this.program); + this.program = outerProgram; + } + + private addSourceMap({ index, line, col }: Pos) { + this.program.add(Ops.SourceMap, index, line, col); + } + + private addClearingSourceMap() { + this.program.add(Ops.SourceMap, 0, 1, 1); + } + + private lowerStmt(stmt: Stmt) { + switch (stmt.kind.type) { + case "error": + break; + case "break": + return this.lowerBreakStmt(stmt); + case "return": + return this.lowerReturnStmt(stmt); + case "fn": + return this.lowerFnStmt(stmt); + case "let": + return this.lowerLetStmt(stmt); + case "assign": + return this.lowerAssignStmt(stmt); + case "expr": + this.lowerExpr(stmt.kind.expr); + this.program.add(Ops.Pop); + return; + } + throw new Error(`unhandled stmt '${stmt.kind.type}'`); + } + + private lowerAssignStmt(stmt: Stmt) { + if (stmt.kind.type !== "assign") { + throw new Error(); + } + this.lowerExpr(stmt.kind.value); + switch (stmt.kind.subject.kind.type) { + case "field": { + this.lowerExpr(stmt.kind.subject.kind.subject); + this.program.add(Ops.PushString, stmt.kind.subject.kind.ident); + this.program.add(Ops.Builtin, Builtins.StructSet); + return; + } + case "index": { + this.lowerExpr(stmt.kind.subject.kind.subject); + this.lowerExpr(stmt.kind.subject.kind.value); + this.program.add(Ops.Builtin, Builtins.ArraySet); + return; + } + case "sym": { + this.program.add( + Ops.StoreLocal, + this.locals.symId(stmt.kind.subject.kind.sym.ident), + ); + return; + } + default: + throw new Error(); + } + } + + private lowerReturnStmt(stmt: Stmt) { + if (stmt.kind.type !== "return") { + throw new Error(); + } + if (stmt.kind.expr) { + this.lowerExpr(stmt.kind.expr); + } + this.addClearingSourceMap(); + this.program.add(Ops.PushPtr, this.returnStack.at(-1)!); + this.program.add(Ops.Jump); + } + + private lowerBreakStmt(stmt: Stmt) { + if (stmt.kind.type !== "break") { + throw new Error(); + } + if (stmt.kind.expr) { + this.lowerExpr(stmt.kind.expr); + } + this.addClearingSourceMap(); + this.program.add(Ops.PushPtr, this.breakStack.at(-1)!); + this.program.add(Ops.Jump); + } + + private lowerFnBuiltinBody(annoArgs: Expr[]) { + if (annoArgs.length !== 1) { + throw new Error("invalid # of arguments to builtin annotation"); + } + const anno = annoArgs[0]; + if (anno.kind.type !== "ident") { + throw new Error( + `unexpected argument type '${anno.kind.type}' expected 'ident'`, + ); + } + const value = anno.kind.ident; + const builtin = Object.entries(Builtins).find((entry) => + entry[0] === value + )?.[1]; + if (builtin === undefined) { + throw new Error( + `unrecognized builtin '${value}'`, + ); + } + this.program.add(Ops.Builtin, builtin); + } + + private lowerLetStmt(stmt: Stmt) { + if (stmt.kind.type !== "let") { + throw new Error(); + } + this.lowerExpr(stmt.kind.value); + this.locals.allocSym(stmt.kind.param.ident); + this.program.add( + Ops.StoreLocal, + this.locals.symId(stmt.kind.param.ident), + ); + } + + private lowerExpr(expr: Expr) { + switch (expr.kind.type) { + case "error": + break; + case "sym": + return this.lowerSymExpr(expr); + case "null": + break; + case "int": + return this.lowerIntExpr(expr); + case "bool": + return this.lowerBoolExpr(expr); + case "string": + return this.lowerStringExpr(expr); + case "ident": + break; + case "group": + return void this.lowerExpr(expr.kind.expr); + case "field": + break; + case "index": + return this.lowerIndexExpr(expr); + case "call": + return this.lowerCallExpr(expr); + case "etype_args": + return this.lowerETypeArgsExpr(expr); + case "unary": + return this.lowerUnaryExpr(expr); + case "binary": + return this.lowerBinaryExpr(expr); + case "if": + return this.lowerIfExpr(expr); + case "loop": + return this.lowerLoopExpr(expr); + case "block": + return this.lowerBlockExpr(expr); + } + throw new Error(`unhandled expr '${expr.kind.type}'`); + } + + private lowerIndexExpr(expr: Expr) { + if (expr.kind.type !== "index") { + throw new Error(); + } + this.lowerExpr(expr.kind.subject); + this.lowerExpr(expr.kind.value); + + if (expr.kind.subject.vtype?.type == "array") { + this.program.add(Ops.Builtin, Builtins.ArrayAt); + return; + } + if (expr.kind.subject.vtype?.type == "string") { + this.program.add(Ops.Builtin, Builtins.StringCharAt); + return; + } + throw new Error(`unhandled index subject type '${expr.kind.subject}'`); + } + + private lowerSymExpr(expr: Expr) { + if (expr.kind.type !== "sym") { + throw new Error(); + } + if (expr.kind.sym.type === "let") { + const symId = this.locals.symId(expr.kind.ident); + this.program.add(Ops.LoadLocal, symId); + return; + } + if (expr.kind.sym.type === "fn_param") { + this.program.add( + Ops.LoadLocal, + this.locals.symId(expr.kind.ident), + ); + return; + } + if (expr.kind.sym.type === "fn") { + // Is this smart? Well, my presumption is + // that it isn't. The underlying problem, which + // this solutions raison d'être is to solve, is + // that the compiler, as it d'être's currently + // doesn't support checking and infering generic + // fn args all the way down to the sym. Therefore, + // when a sym is checked in a call expr, we can't + // really do anything useful. Instead the actual + // function pointer pointing to the actual + // monomorphized function is emplaced when + // lowering the call expression itself. But what + // should we do then, if the user decides to + // assign a function to a local? You might ask. + // You see, that's where the problem lies. + // My current, very thought out solution, as + // you can read below, is to push a null pointer, + // for it to then be replaced later. This will + // probably cause many hastles in the future + // for myself in particular, when trying to + // decipher the lowerer's output. So if you're + // the unlucky girl, who has tried for ages to + // decipher why a zero value is pushed and then + // later replaced, and then you finally + // stumbled upon this here implementation, + // let me first say, I'm so sorry. At the time + // of writing, I really haven't thought out + // very well, how the generic call system should + // work, and it's therefore a bit flaky, and the + // implementation kinda looks like it was + // implementated by a girl who didn't really + // understand very well what they were + // implementing at the time that they were + // implementing it. Anyway, I just wanted to + // apologize. Happy coding. + // -Your favorite compiler girl. + this.program.add(Ops.PushPtr, 0); + return; + } + throw new Error(`unhandled sym type '${expr.kind.sym.type}'`); + } + + private lowerIntExpr(expr: Expr) { + if (expr.kind.type !== "int") { + throw new Error(); + } + this.program.add(Ops.PushInt, expr.kind.value); + } + + private lowerBoolExpr(expr: Expr) { + if (expr.kind.type !== "bool") { + throw new Error(); + } + this.program.add(Ops.PushBool, expr.kind.value); + } + + private lowerStringExpr(expr: Expr) { + if (expr.kind.type !== "string") { + throw new Error(); + } + this.program.add(Ops.PushString, expr.kind.value); + } + + private lowerUnaryExpr(expr: Expr) { + if (expr.kind.type !== "unary") { + throw new Error(); + } + this.lowerExpr(expr.kind.subject); + const vtype = expr.kind.subject.vtype!; + if (vtype.type === "bool") { + switch (expr.kind.unaryType) { + case "not": + this.program.add(Ops.Not); + return; + default: + } + } + if (vtype.type === "int") { + switch (expr.kind.unaryType) { + case "-": { + this.program.add(Ops.PushInt, 0); + this.program.add(Ops.Swap); + this.program.add(Ops.Subtract); + return; + } + default: + } + } + throw new Error( + `unhandled unary` + + ` '${vtypeToString(expr.vtype!)}' aka. ` + + ` ${expr.kind.unaryType}` + + ` '${vtypeToString(expr.kind.subject.vtype!)}'`, + ); + } + + private lowerBinaryExpr(expr: Expr) { + if (expr.kind.type !== "binary") { + throw new Error(); + } + const vtype = expr.kind.left.vtype!; + if (vtype.type === "bool") { + if (["or", "and"].includes(expr.kind.binaryType)) { + const shortCircuitLabel = this.program.makeLabel(); + this.lowerExpr(expr.kind.left); + this.program.add(Ops.Duplicate); + if (expr.kind.binaryType === "and") { + this.program.add(Ops.Not); + } + this.program.add(Ops.PushPtr, shortCircuitLabel); + this.program.add(Ops.JumpIfTrue); + this.program.add(Ops.Pop); + this.lowerExpr(expr.kind.right); + this.program.setLabel(shortCircuitLabel); + return; + } + } + this.lowerExpr(expr.kind.left); + this.lowerExpr(expr.kind.right); + if (vtype.type === "int") { + switch (expr.kind.binaryType) { + case "+": + this.program.add(Ops.Add); + return; + case "-": + this.program.add(Ops.Subtract); + return; + case "*": + this.program.add(Ops.Multiply); + return; + case "/": + this.program.add(Ops.Multiply); + return; + case "==": + this.program.add(Ops.Equal); + return; + case "!=": + this.program.add(Ops.Equal); + this.program.add(Ops.Not); + return; + case "<": + this.program.add(Ops.LessThan); + return; + case ">": + this.program.add(Ops.Swap); + this.program.add(Ops.LessThan); + return; + case "<=": + this.program.add(Ops.Swap); + this.program.add(Ops.LessThan); + this.program.add(Ops.Not); + return; + case ">=": + this.program.add(Ops.LessThan); + this.program.add(Ops.Not); + return; + default: + } + } + if (vtype.type === "bool") { + switch (expr.kind.binaryType) { + case "==": + this.program.add(Ops.And); + return; + case "!=": + this.program.add(Ops.And); + this.program.add(Ops.Not); + return; + default: + } + } + if (vtype.type === "string") { + if (expr.kind.binaryType === "+") { + this.program.add(Ops.Builtin, Builtins.StringConcat); + return; + } + if (expr.kind.binaryType === "==") { + this.program.add(Ops.Builtin, Builtins.StringEqual); + return; + } + if (expr.kind.binaryType === "!=") { + this.program.add(Ops.Builtin, Builtins.StringEqual); + this.program.add(Ops.Not); + return; + } + } + throw new Error( + `unhandled binaryType` + + ` '${vtypeToString(expr.vtype!)}' aka. ` + + ` '${vtypeToString(expr.kind.left.vtype!)}'` + + ` ${expr.kind.binaryType}` + + ` '${vtypeToString(expr.kind.left.vtype!)}'`, + ); + } + + private lowerCallExpr(expr: Expr) { + if (expr.kind.type !== "call") { + throw new Error(); + } + for (const arg of expr.kind.args) { + this.lowerExpr(arg); + } + this.lowerExpr(expr.kind.subject); + this.program.add(Ops.Pop); + this.program.add(Ops.PushPtr, { label: this.callMap[expr.id] }); + this.program.add(Ops.Call, expr.kind.args.length); + } + + private lowerETypeArgsExpr(expr: Expr) { + if (expr.kind.type !== "etype_args") { + throw new Error(); + } + this.lowerExpr(expr.kind.subject); + } + + private lowerIfExpr(expr: Expr) { + if (expr.kind.type !== "if") { + throw new Error(); + } + + const falseLabel = this.program.makeLabel(); + const doneLabel = this.program.makeLabel(); + + this.lowerExpr(expr.kind.cond); + + this.program.add(Ops.Not); + this.addClearingSourceMap(); + this.program.add(Ops.PushPtr, falseLabel); + this.program.add(Ops.JumpIfTrue); + + this.addSourceMap(expr.kind.truthy.pos); + this.lowerExpr(expr.kind.truthy); + + this.addClearingSourceMap(); + this.program.add(Ops.PushPtr, doneLabel); + this.program.add(Ops.Jump); + + this.program.setLabel(falseLabel); + + if (expr.kind.falsy) { + this.addSourceMap(expr.kind.elsePos!); + this.lowerExpr(expr.kind.falsy); + } else { + this.program.add(Ops.PushNull); + } + + this.program.setLabel(doneLabel); + } + + private lowerLoopExpr(expr: Expr) { + if (expr.kind.type !== "loop") { + throw new Error(); + } + const continueLabel = this.program.makeLabel(); + const breakLabel = this.program.makeLabel(); + + this.breakStack.push(breakLabel); + + this.program.setLabel(continueLabel); + this.addSourceMap(expr.kind.body.pos); + this.lowerExpr(expr.kind.body); + this.program.add(Ops.Pop); + this.addClearingSourceMap(); + this.program.add(Ops.PushPtr, continueLabel); + this.program.add(Ops.Jump); + this.program.setLabel(breakLabel); + if (expr.vtype!.type === "null") { + this.program.add(Ops.PushNull); + } + this.breakStack.pop(); + } + + private lowerBlockExpr(expr: Expr) { + if (expr.kind.type !== "block") { + throw new Error(); + } + const outerLocals = this.locals; + this.locals = new LocalLeaf(this.locals); + for (const stmt of expr.kind.stmts) { + this.addSourceMap(stmt.pos); + this.lowerStmt(stmt); + } + if (expr.kind.expr) { + this.addSourceMap(expr.kind.expr.pos); + this.lowerExpr(expr.kind.expr); + } else { + this.program.add(Ops.PushNull); + } + this.locals = outerLocals; + } +} diff --git a/compiler/parser.ts b/compiler/parser.ts index c0c7ebe..9eb594c 100644 --- a/compiler/parser.ts +++ b/compiler/parser.ts @@ -269,12 +269,16 @@ export class Parser { return this.parseDelimitedList(this.parseETypeParam, ">", ","); } + private veryTemporaryETypeParamIdCounter = 0; + private parseETypeParam(): Res { const pos = this.pos(); if (this.test("ident")) { const ident = this.current().identValue!; this.step(); - return { ok: true, value: { ident, pos } }; + const id = this.veryTemporaryETypeParamIdCounter; + this.veryTemporaryETypeParamIdCounter += 1; + return { ok: true, value: { id, ident, pos } }; } this.report("expected generic parameter"); return { ok: false }; diff --git a/compiler/vtype.ts b/compiler/vtype.ts index 7363697..24bf59d 100644 --- a/compiler/vtype.ts +++ b/compiler/vtype.ts @@ -12,12 +12,13 @@ export type VType = genericParams?: VTypeGenericParam[]; params: VTypeParam[]; returnType: VType; + stmtId: number; } - | { type: "generic" } + | { type: "generic"; param: VTypeGenericParam } | { type: "generic_spec"; subject: VType; - genericParams: VType[]; + genericArgs: GenericArgsMap; }; export type VTypeParam = { @@ -26,21 +27,25 @@ export type VTypeParam = { }; export type VTypeGenericParam = { + id: number; ident: string; }; -export function vtypesEqual(a: VType, b: VType): boolean { - if (a.type !== b.type) { - return false; - } +export type GenericArgsMap = { [id: number]: VType }; + +export function vtypesEqual( + a: VType, + b: VType, + generics?: GenericArgsMap, +): boolean { if ( ["error", "unknown", "null", "int", "string", "bool"] - .includes(a.type) + .includes(a.type) && a.type === b.type ) { return true; } if (a.type === "array" && b.type === "array") { - return vtypesEqual(a.inner, b.inner); + return vtypesEqual(a.inner, b.inner, generics); } if (a.type === "fn" && b.type === "fn") { if (a.params.length !== b.params.length) { @@ -51,11 +56,41 @@ export function vtypesEqual(a: VType, b: VType): boolean { return false; } } - return vtypesEqual(a.returnType, b.returnType); + return vtypesEqual(a.returnType, b.returnType, generics); + } + if (a.type === "generic" && b.type === "generic") { + return a.param.id === b.param.id; + } + if ( + (a.type === "generic" || b.type === "generic") && + generics !== undefined + ) { + if (generics === undefined) { + throw new Error(); + } + + const generic = a.type === "generic" ? a : b; + const concrete = a.type === "generic" ? b : a; + + const genericType = extractGenericType(generic, generics); + return vtypesEqual(genericType, concrete, generics); } return false; } +export function extractGenericType( + generic: VType, + generics: GenericArgsMap, +): VType { + if (generic.type !== "generic") { + return generic; + } + if (!(generic.param.id in generics)) { + throw new Error("generic not found (not supposed to happen)"); + } + return generics[generic.param.id]; +} + export function vtypeToString(vtype: VType): string { if ( ["error", "unknown", "null", "int", "string", "bool"] diff --git a/stdlib.slg b/stdlib.slg index 566f169..1dba465 100644 --- a/stdlib.slg +++ b/stdlib.slg @@ -1,4 +1,6 @@ +fn exit(status_code: int) #[builtin(Exit)] {} + fn print(msg: string) #[builtin(Print)] {} fn println(msg: string) { print(msg + "\n") } diff --git a/tests/generics.slg b/tests/generics.slg index f396eed..4228a24 100644 --- a/tests/generics.slg +++ b/tests/generics.slg @@ -1,17 +1,23 @@ fn exit(status_code: int) #[builtin(Exit)] {} +fn print(msg: string) #[builtin(Print)] {} +fn println(msg: string) { print(msg + "\n") } + fn id(v: T) -> T { v } fn main() { + println("calling with int"); if id::(123) != 123 { exit(1); } + println("calling with bool"); if id::(true) != true { exit(1); } + println("all tests ran successfully"); exit(0); }