diff --git a/compiler/checker.ts b/compiler/checker.ts index d300d9e..e3673e1 100644 --- a/compiler/checker.ts +++ b/compiler/checker.ts @@ -12,6 +12,7 @@ import { export class Checker { private fnReturnStack: VType[] = []; private loopBreakStack: VType[][] = []; + private structIdCounter = 0; public constructor(private reporter: Reporter) {} @@ -47,7 +48,13 @@ export class Checker { param.vtype = vtype; params.push({ ident: param.ident, vtype }); } - stmt.kind.vtype = { type: "fn", genericParams, params, returnType }; + stmt.kind.vtype = { + type: "fn", + genericParams, + params, + returnType, + fnStmtId: stmt.id, + }; } } @@ -332,12 +339,7 @@ export class Checker { if (fnStmt.kind.type !== "fn") { throw new Error(); } - const vtype = fnStmt.kind.vtype!; - if (vtype.type !== "fn") { - throw new Error(); - } - const { params, returnType } = vtype; - return { type: "fn", params, returnType }; + return fnStmt.kind.vtype!; } case "fn_param": return expr.kind.sym.param.vtype!; @@ -445,10 +447,10 @@ export class Checker { ); return { type: "error" }; } - const genericParams = expr.kind.etypeArgs.map((arg) => + const genericArgs = expr.kind.etypeArgs.map((arg) => this.checkEType(arg) ); - if (genericParams.length !== subject.params.length) { + if (genericArgs.length !== subject.params.length) { this.report( `incorrect number of arguments` + `, expected ${subject.params.length}`, @@ -456,9 +458,9 @@ export class Checker { ); } return { - type: "generic_spec", + type: "generic_args", subject, - genericParams, + genericArgs, }; } @@ -671,7 +673,9 @@ export class Checker { ident: param.ident, vtype: this.checkEType(param.etype!), })); - return { type: "struct", fields }; + const structId = this.structIdCounter; + this.structIdCounter += 1; + return { type: "struct", structId, fields }; } throw new Error(`unknown explicit type ${etype.kind.type}`); } diff --git a/compiler/compiler.ts b/compiler/compiler.ts index 2b44225..fdd2bac 100644 --- a/compiler/compiler.ts +++ b/compiler/compiler.ts @@ -5,6 +5,7 @@ import { SpecialLoopDesugarer } from "./desugar/special_loop.ts"; import { Reporter } from "./info.ts"; import { Lexer } from "./lexer.ts"; import { FnNamesMap, Lowerer } from "./lowerer.ts"; +import { monomorphizeFunctionGraphs } from "./mfg.ts"; import { Parser } from "./parser.ts"; import { Resolver } from "./resolver.ts"; @@ -45,9 +46,11 @@ export class Compiler { Deno.exit(1); } + const monomorphizedFns = monomorphizeFunctionGraphs(ast); + const lowerer = new Lowerer(lexer.currentPos()); - lowerer.lower(ast); - // lowerer.printProgram(); + lowerer.lower(monomorphizedFns); + lowerer.printProgram(); const { program, fnNames } = lowerer.finish(); return { program, fnNames }; diff --git a/compiler/lowerer.ts b/compiler/lowerer.ts index 64beadc..d3c50dc 100644 --- a/compiler/lowerer.ts +++ b/compiler/lowerer.ts @@ -2,31 +2,42 @@ import { Builtins, Ops } from "./arch.ts"; import { Expr, Stmt } from "./ast.ts"; import { LocalLeaf, Locals, LocalsFnRoot } from "./lowerer_locals.ts"; import { Assembler, Label } from "./assembler.ts"; -import { vtypeToString } from "./vtype.ts"; +import { VType, vtypeToString } from "./vtype.ts"; import { Pos } from "./token.ts"; +import { fnCallMid, fnStmtMid, MonomorphizedFn } from "./mfg.ts"; export type FnNamesMap = { [pc: number]: string }; export class Lowerer { private program = Assembler.newRoot(); - private locals: Locals = new LocalsFnRoot(); - private fnStmtIdLabelMap: { [stmtId: number]: string } = {}; private fnLabelNameMap: { [name: string]: string } = {}; - private returnStack: Label[] = []; - private breakStack: Label[] = []; public constructor(private lastPos: Pos) {} - public lower(stmts: Stmt[]) { + public lower(fns: MonomorphizedFn[]) { this.addClearingSourceMap(); this.program.add(Ops.PushPtr, { label: "main" }); this.program.add(Ops.Call, 0); this.program.add(Ops.PushPtr, { label: "_exit" }); this.program.add(Ops.Jump); - this.scoutFnHeaders(stmts); - for (const stmt of stmts) { - this.lowerStaticStmt(stmt); + + const fnMidLabelMap: { [mid: string]: string } = {}; + for (const fn of fns) { + const mid = fnStmtMid(fn.stmt, fn.genericArgs); + fnMidLabelMap[mid] = mid; } + + for (const fn of fns) { + const fnProgram = new FnLowerer( + this.program.fork(), + fnMidLabelMap, + fn.stmt, + fn.genericArgs, + ) + .lowerFn(); + this.program.join(fnProgram); + } + this.program.setLabel({ label: "_exit" }); this.addSourceMap(this.lastPos); this.program.add(Ops.Pop); @@ -43,6 +54,10 @@ export class Lowerer { return { program, fnNames }; } + public printProgram() { + this.program.printProgram(); + } + private addSourceMap({ index, line, col }: Pos) { this.program.add(Ops.SourceMap, index, line, col); } @@ -50,31 +65,76 @@ export class Lowerer { private addClearingSourceMap() { this.program.add(Ops.SourceMap, 0, 1, 1); } +} - private scoutFnHeaders(stmts: Stmt[]) { - for (const stmt of stmts) { - if (stmt.kind.type !== "fn") { - continue; - } - const label = stmt.kind.ident === "main" - ? "main" - : `${stmt.kind.ident}_${stmt.id}`; - this.fnStmtIdLabelMap[stmt.id] = label; - } +class FnLowerer { + private locals: Locals = new LocalsFnRoot(); + private fnLabelNameMap: { [name: string]: string } = {}; + private returnStack: Label[] = []; + private breakStack: Label[] = []; + + public constructor( + private program: Assembler, + private fnMidLabelMap: { [mid: string]: string }, + private fnStmt: Stmt, + private genericArgs?: VType[], + ) {} + + public lowerFn(): Assembler { + this.lowerFnStmt(this.fnStmt); + return this.program; } - private lowerStaticStmt(stmt: Stmt) { - switch (stmt.kind.type) { - case "fn": - return this.lowerFnStmt(stmt); - case "error": - case "break": - case "return": - case "let": - case "assign": - case "expr": + private lowerFnStmt(stmt: Stmt) { + if (stmt.kind.type !== "fn") { + throw new Error(); } - throw new Error(`unhandled static statement '${stmt.kind.type}'`); + const label = fnStmtMid(stmt, this.genericArgs); + this.program.setLabel({ label }); + this.fnLabelNameMap[label] = stmt.kind.ident; + this.addSourceMap(stmt.pos); + + const outerLocals = this.locals; + const fnRoot = new LocalsFnRoot(outerLocals); + const outerProgram = this.program; + + const returnLabel = this.program.makeLabel(); + this.returnStack.push(returnLabel); + + this.program = outerProgram.fork(); + this.locals = fnRoot; + for (const { ident } of stmt.kind.params) { + this.locals.allocSym(ident); + } + if (stmt.kind.anno?.ident === "builtin") { + this.lowerFnBuiltinBody(stmt.kind.anno.values); + } else if (stmt.kind.anno?.ident === "remainder") { + this.program.add(Ops.Remainder); + } else { + this.lowerExpr(stmt.kind.body); + } + this.locals = outerLocals; + + const localAmount = fnRoot.stackReserved() - + stmt.kind.params.length; + for (let i = 0; i < localAmount; ++i) { + outerProgram.add(Ops.PushNull); + } + + this.returnStack.pop(); + this.program.setLabel(returnLabel); + this.program.add(Ops.Return); + + outerProgram.join(this.program); + this.program = outerProgram; + } + + private addSourceMap({ index, line, col }: Pos) { + this.program.add(Ops.SourceMap, index, line, col); + } + + private addClearingSourceMap() { + this.program.add(Ops.SourceMap, 0, 1, 1); } private lowerStmt(stmt: Stmt) { @@ -86,7 +146,7 @@ export class Lowerer { case "return": return this.lowerReturnStmt(stmt); case "fn": - return this.lowerFnStmt(stmt); + break; case "let": return this.lowerLetStmt(stmt); case "assign": @@ -153,52 +213,6 @@ export class Lowerer { this.program.add(Ops.Jump); } - private lowerFnStmt(stmt: Stmt) { - if (stmt.kind.type !== "fn") { - throw new Error(); - } - const label = stmt.kind.ident === "main" - ? "main" - : `${stmt.kind.ident}_${stmt.id}`; - this.program.setLabel({ label }); - this.fnLabelNameMap[label] = stmt.kind.ident; - this.addSourceMap(stmt.pos); - - const outerLocals = this.locals; - const fnRoot = new LocalsFnRoot(outerLocals); - const outerProgram = this.program; - - const returnLabel = this.program.makeLabel(); - this.returnStack.push(returnLabel); - - this.program = outerProgram.fork(); - this.locals = fnRoot; - for (const { ident } of stmt.kind.params) { - this.locals.allocSym(ident); - } - if (stmt.kind.anno?.ident === "builtin") { - this.lowerFnBuiltinBody(stmt.kind.anno.values); - } else if (stmt.kind.anno?.ident === "remainder") { - this.program.add(Ops.Remainder); - } else { - this.lowerExpr(stmt.kind.body); - } - this.locals = outerLocals; - - const localAmount = fnRoot.stackReserved() - - stmt.kind.params.length; - for (let i = 0; i < localAmount; ++i) { - outerProgram.add(Ops.PushNull); - } - - this.returnStack.pop(); - this.program.setLabel(returnLabel); - this.program.add(Ops.Return); - - outerProgram.join(this.program); - this.program = outerProgram; - } - private lowerFnBuiltinBody(annoArgs: Expr[]) { if (annoArgs.length !== 1) { throw new Error("invalid # of arguments to builtin annotation"); @@ -306,8 +320,7 @@ export class Lowerer { return; } if (expr.kind.sym.type === "fn") { - const label = this.fnStmtIdLabelMap[expr.kind.sym.stmt.id]; - this.program.add(Ops.PushPtr, { label }); + this.program.add(Ops.PushPtr, 0); return; } throw new Error(`unhandled sym type '${expr.kind.sym.type}'`); @@ -528,7 +541,6 @@ export class Lowerer { } const outerLocals = this.locals; this.locals = new LocalLeaf(this.locals); - this.scoutFnHeaders(expr.kind.stmts); for (const stmt of expr.kind.stmts) { this.addSourceMap(stmt.pos); this.lowerStmt(stmt); @@ -541,8 +553,4 @@ export class Lowerer { } this.locals = outerLocals; } - - public printProgram() { - this.program.printProgram(); - } } diff --git a/compiler/mfg.ts b/compiler/mfg.ts new file mode 100644 index 0000000..ba785fe --- /dev/null +++ b/compiler/mfg.ts @@ -0,0 +1,244 @@ +// monomorphized function (ast-)graphs + +import { Expr, Stmt } from "./ast.ts"; +import { AstVisitor, visitExpr, VisitRes, visitStmts } from "./ast_visitor.ts"; +import { VType } from "./vtype.ts"; + +export type MonomorphizedFn = { + mid: string; + stmt: Stmt; + genericArgs?: VType[]; +}; + +export function monomorphizeFunctionGraphs(ast: Stmt[]): MonomorphizedFn[] { + const allFns = new AllFnsCollector().collect(ast); + const mainFn = findMain(allFns); + return [ + ...new Monomorphizer(allFns) + .monomorphize(mainFn) + .values(), + ]; +} + +function findMain(fns: Map): Stmt { + const mainId = fns.values().find((stmt) => + stmt.kind.type === "fn" && stmt.kind.ident === "main" + ); + if (mainId === undefined) { + console.error("error: cannot find function 'main'"); + console.error( + ` + Hear me out. Monomorphization, meaning the process + inwich generic functions are stamped out into seperate + specialized functions is actually really hard, and I + have a really hard time right now, figuring out, how + to do it in a smart way. To really explain it, let's + imagine you have a function, you defined as a(). + For each call with seperate generics arguments given, + such as a::() and a::(), a specialized + function has to be 'stamped out', ie. created and put + into the compilation with the rest of the program. Now + to the reason as to why 'main' is needed. To do the + monomorphization, we have to do it recursively. To + explain this, imagine you have a generic function a + and inside the body of a, you call another generic + function such as b with the same generic type. This + means that the monomorphization process of b depends + on the monomorphization of a. What this essentially + means, is that the monomorphization process works on + the program as a call graph, meaning a graph or tree + structure where each represents a function call to + either another function or a recursive call to the + function itself. But a problem arises from doing it + this way, which is that a call graph will need an + entrypoint. The language, as it is currently, does + not really require a 'main'-function. Or maybe it + does, but that's beside the point. The point is that + we need a main function, to be the entry point for + the call graph. The monomorphization process then + runs through the program from that entry point. This + means that each function we call, will itself be + monomorphized and added to the compilation. It also + means that functions that are not called, will also + not be added to the compilation. This essentially + eliminates uncalled/dead functions. Is this + particularly smart to do in such a high level part + of the compilation process? I don't know. It's + obvious that we can't just use every function as + an entry point in the call graph, because we're + actively added new functions. Additionally, with + generic functions, we don't know, if they're the + entry point, what generic arguments, they should + be monomorphized with. We could do monomorphization + the same way C++ does it, where all non-generic + functions before monomorphization are treated as + entry points in the call graph. But this has the + drawback that generic and non-generic functions + are treated differently, which has many underlying + drawbacks, especially pertaining to the amount of + work needed to handle both in all proceeding steps + of the compiler. Anyways, I just wanted to yap and + complain about the way generics and monomorphization + has made the compiler 100x more complicated, and + that I find it really hard to implement in a way, + that is not either too simplistic or so complicated + and advanced I'm too dumb to implement it. So if + you would be so kind as to make it clear to the + compiler, what function it should designate as + the entry point to the call graph, it will use + for monomorphization, that would be very kind of + you. The way you do this, is by added or selecting + one of your current functions and giving it the + name of 'main'. This is spelled m-a-i-n. The word + is synonemous with the words primary and principle. + The name is meant to designate the entry point into + the program, which is why the monomorphization + process uses this specific function as the entry + point into the call graph, it generates. So if you + would be so kind as to do that, that would really + make my day. In any case, keep hacking ferociously + on whatever you're working on. I have monomorphizer + to implement. See ya. -Your favorite compiler girl <3 + `.replaceAll(" ", "").trim(), + ); + throw new Error("cannot find function 'main'"); + } + return mainId; +} + +class AllFnsCollector implements AstVisitor { + private allFns = new Map(); + + public collect(ast: Stmt[]): Map { + visitStmts(ast, this); + return this.allFns; + } + + visitFnStmt(stmt: Stmt): VisitRes { + if (stmt.kind.type !== "fn") { + throw new Error(); + } + this.allFns.set(stmt.id, stmt); + } +} + +class Monomorphizer { + private monomorphizedFns = new Map(); + + public constructor(private allFns: Map) {} + + public monomorphize(mainFn: Stmt): Map { + this.monomorphizeFn(mainFn); + return this.monomorphizedFns; + } + + private monomorphizeFn(stmt: Stmt, genericArgs?: VType[]) { + const calls = new FnBodyCallCollector().collect(stmt); + for (const expr of calls) { + if (expr.kind.type !== "call") { + throw new Error(); + } + const vtype = expr.kind.subject.vtype!; + if (vtype.type === "fn") { + const stmt = this.allFns.get(vtype.fnStmtId)!; + if (stmt.kind.type !== "fn") { + throw new Error(); + } + const mid = fnCallMid(expr, stmt); + if (!this.monomorphizedFns.has(mid)) { + this.monomorphizedFns.set(mid, { mid, stmt }); + this.monomorphizeFn(stmt); + } + return; + } else if (vtype.type === "generic_args") { + if (vtype.subject.type !== "fn") { + throw new Error(); + } + const stmt = this.allFns.get(vtype.subject.fnStmtId)!; + if (stmt.kind.type !== "fn") { + throw new Error(); + } + const mid = fnCallMid(expr, stmt); + if (!this.monomorphizedFns.has(mid)) { + this.monomorphizedFns.set(mid, { mid, stmt, genericArgs }); + this.monomorphizeFn(stmt, vtype.genericArgs); + } + return; + } + throw new Error(); + } + } +} + +class FnBodyCallCollector implements AstVisitor { + private calls: Expr[] = []; + + public collect(stmt: Stmt): Expr[] { + if (stmt.kind.type !== "fn") { + throw new Error(); + } + visitExpr(stmt.kind.body, this); + return this.calls; + } + + visitCallExpr(expr: Expr): VisitRes { + if (expr.kind.type !== "call") { + throw new Error(); + } + this.calls.push(expr); + } +} + +export function fnCallMid(expr: Expr, stmt: Stmt) { + console.log(expr); + if (expr.kind.type !== "call") { + throw new Error(); + } + const vtype = expr.kind.subject.vtype!; + if (vtype.type === "fn") { + return fnStmtMid(stmt); + } else if (vtype.type === "generic_args") { + if (vtype.subject.type !== "fn") { + throw new Error(); + } + return fnStmtMid(stmt, vtype.genericArgs); + } + throw new Error(); +} + +export function fnStmtMid(stmt: Stmt, genericArgs?: VType[]) { + if (stmt.kind.type !== "fn") { + throw new Error(); + } + const { kind: { ident }, id } = stmt; + if (genericArgs !== undefined) { + const genericArgsStr = genericArgs + .map((arg) => vtypeMidPart(arg)) + .join("_"); + return `${ident}_${id}_${genericArgsStr}`; + } else { + return ident === "main" ? "main" : `${ident}_${id}`; + } +} + +export function vtypeMidPart(vtype: VType): string { + switch (vtype.type) { + case "string": + case "int": + case "bool": + case "null": + case "unknown": + return vtype.type; + case "array": + return `array(${vtypeMidPart(vtype.inner)})`; + case "struct": + return `struct(${vtype.structId})`; + case "fn": + return `fn(${vtype.fnStmtId})`; + case "error": + throw new Error("error in type"); + case "generic": + case "generic_args": + throw new Error("cannot be monomorphized"); + } +} diff --git a/compiler/vtype.ts b/compiler/vtype.ts index 7363697..5533794 100644 --- a/compiler/vtype.ts +++ b/compiler/vtype.ts @@ -6,18 +6,19 @@ export type VType = | { type: "string" } | { type: "bool" } | { type: "array"; inner: VType } - | { type: "struct"; fields: VTypeParam[] } + | { type: "struct"; structId: number; fields: VTypeParam[] } | { type: "fn"; genericParams?: VTypeGenericParam[]; params: VTypeParam[]; returnType: VType; + fnStmtId: number; } | { type: "generic" } | { - type: "generic_spec"; + type: "generic_args"; subject: VType; - genericParams: VType[]; + genericArgs: VType[]; }; export type VTypeParam = {