diff --git a/compiler/checker.ts b/compiler/checker.ts index e3673e1..d300d9e 100644 --- a/compiler/checker.ts +++ b/compiler/checker.ts @@ -12,7 +12,6 @@ import { export class Checker { private fnReturnStack: VType[] = []; private loopBreakStack: VType[][] = []; - private structIdCounter = 0; public constructor(private reporter: Reporter) {} @@ -48,13 +47,7 @@ export class Checker { param.vtype = vtype; params.push({ ident: param.ident, vtype }); } - stmt.kind.vtype = { - type: "fn", - genericParams, - params, - returnType, - fnStmtId: stmt.id, - }; + stmt.kind.vtype = { type: "fn", genericParams, params, returnType }; } } @@ -339,7 +332,12 @@ export class Checker { if (fnStmt.kind.type !== "fn") { throw new Error(); } - return fnStmt.kind.vtype!; + const vtype = fnStmt.kind.vtype!; + if (vtype.type !== "fn") { + throw new Error(); + } + const { params, returnType } = vtype; + return { type: "fn", params, returnType }; } case "fn_param": return expr.kind.sym.param.vtype!; @@ -447,10 +445,10 @@ export class Checker { ); return { type: "error" }; } - const genericArgs = expr.kind.etypeArgs.map((arg) => + const genericParams = expr.kind.etypeArgs.map((arg) => this.checkEType(arg) ); - if (genericArgs.length !== subject.params.length) { + if (genericParams.length !== subject.params.length) { this.report( `incorrect number of arguments` + `, expected ${subject.params.length}`, @@ -458,9 +456,9 @@ export class Checker { ); } return { - type: "generic_args", + type: "generic_spec", subject, - genericArgs, + genericParams, }; } @@ -673,9 +671,7 @@ export class Checker { ident: param.ident, vtype: this.checkEType(param.etype!), })); - const structId = this.structIdCounter; - this.structIdCounter += 1; - return { type: "struct", structId, fields }; + return { type: "struct", fields }; } throw new Error(`unknown explicit type ${etype.kind.type}`); } diff --git a/compiler/compiler.ts b/compiler/compiler.ts index fdd2bac..2b44225 100644 --- a/compiler/compiler.ts +++ b/compiler/compiler.ts @@ -5,7 +5,6 @@ import { SpecialLoopDesugarer } from "./desugar/special_loop.ts"; import { Reporter } from "./info.ts"; import { Lexer } from "./lexer.ts"; import { FnNamesMap, Lowerer } from "./lowerer.ts"; -import { monomorphizeFunctionGraphs } from "./mfg.ts"; import { Parser } from "./parser.ts"; import { Resolver } from "./resolver.ts"; @@ -46,11 +45,9 @@ export class Compiler { Deno.exit(1); } - const monomorphizedFns = monomorphizeFunctionGraphs(ast); - const lowerer = new Lowerer(lexer.currentPos()); - lowerer.lower(monomorphizedFns); - lowerer.printProgram(); + lowerer.lower(ast); + // lowerer.printProgram(); const { program, fnNames } = lowerer.finish(); return { program, fnNames }; diff --git a/compiler/lowerer.ts b/compiler/lowerer.ts index d3c50dc..64beadc 100644 --- a/compiler/lowerer.ts +++ b/compiler/lowerer.ts @@ -2,42 +2,31 @@ import { Builtins, Ops } from "./arch.ts"; import { Expr, Stmt } from "./ast.ts"; import { LocalLeaf, Locals, LocalsFnRoot } from "./lowerer_locals.ts"; import { Assembler, Label } from "./assembler.ts"; -import { VType, vtypeToString } from "./vtype.ts"; +import { vtypeToString } from "./vtype.ts"; import { Pos } from "./token.ts"; -import { fnCallMid, fnStmtMid, MonomorphizedFn } from "./mfg.ts"; export type FnNamesMap = { [pc: number]: string }; export class Lowerer { private program = Assembler.newRoot(); + private locals: Locals = new LocalsFnRoot(); + private fnStmtIdLabelMap: { [stmtId: number]: string } = {}; private fnLabelNameMap: { [name: string]: string } = {}; + private returnStack: Label[] = []; + private breakStack: Label[] = []; public constructor(private lastPos: Pos) {} - public lower(fns: MonomorphizedFn[]) { + public lower(stmts: Stmt[]) { this.addClearingSourceMap(); this.program.add(Ops.PushPtr, { label: "main" }); this.program.add(Ops.Call, 0); this.program.add(Ops.PushPtr, { label: "_exit" }); this.program.add(Ops.Jump); - - const fnMidLabelMap: { [mid: string]: string } = {}; - for (const fn of fns) { - const mid = fnStmtMid(fn.stmt, fn.genericArgs); - fnMidLabelMap[mid] = mid; + this.scoutFnHeaders(stmts); + for (const stmt of stmts) { + this.lowerStaticStmt(stmt); } - - for (const fn of fns) { - const fnProgram = new FnLowerer( - this.program.fork(), - fnMidLabelMap, - fn.stmt, - fn.genericArgs, - ) - .lowerFn(); - this.program.join(fnProgram); - } - this.program.setLabel({ label: "_exit" }); this.addSourceMap(this.lastPos); this.program.add(Ops.Pop); @@ -54,10 +43,6 @@ export class Lowerer { return { program, fnNames }; } - public printProgram() { - this.program.printProgram(); - } - private addSourceMap({ index, line, col }: Pos) { this.program.add(Ops.SourceMap, index, line, col); } @@ -65,76 +50,31 @@ export class Lowerer { private addClearingSourceMap() { this.program.add(Ops.SourceMap, 0, 1, 1); } -} -class FnLowerer { - private locals: Locals = new LocalsFnRoot(); - private fnLabelNameMap: { [name: string]: string } = {}; - private returnStack: Label[] = []; - private breakStack: Label[] = []; - - public constructor( - private program: Assembler, - private fnMidLabelMap: { [mid: string]: string }, - private fnStmt: Stmt, - private genericArgs?: VType[], - ) {} - - public lowerFn(): Assembler { - this.lowerFnStmt(this.fnStmt); - return this.program; + private scoutFnHeaders(stmts: Stmt[]) { + for (const stmt of stmts) { + if (stmt.kind.type !== "fn") { + continue; + } + const label = stmt.kind.ident === "main" + ? "main" + : `${stmt.kind.ident}_${stmt.id}`; + this.fnStmtIdLabelMap[stmt.id] = label; + } } - private lowerFnStmt(stmt: Stmt) { - if (stmt.kind.type !== "fn") { - throw new Error(); + private lowerStaticStmt(stmt: Stmt) { + switch (stmt.kind.type) { + case "fn": + return this.lowerFnStmt(stmt); + case "error": + case "break": + case "return": + case "let": + case "assign": + case "expr": } - const label = fnStmtMid(stmt, this.genericArgs); - this.program.setLabel({ label }); - this.fnLabelNameMap[label] = stmt.kind.ident; - this.addSourceMap(stmt.pos); - - const outerLocals = this.locals; - const fnRoot = new LocalsFnRoot(outerLocals); - const outerProgram = this.program; - - const returnLabel = this.program.makeLabel(); - this.returnStack.push(returnLabel); - - this.program = outerProgram.fork(); - this.locals = fnRoot; - for (const { ident } of stmt.kind.params) { - this.locals.allocSym(ident); - } - if (stmt.kind.anno?.ident === "builtin") { - this.lowerFnBuiltinBody(stmt.kind.anno.values); - } else if (stmt.kind.anno?.ident === "remainder") { - this.program.add(Ops.Remainder); - } else { - this.lowerExpr(stmt.kind.body); - } - this.locals = outerLocals; - - const localAmount = fnRoot.stackReserved() - - stmt.kind.params.length; - for (let i = 0; i < localAmount; ++i) { - outerProgram.add(Ops.PushNull); - } - - this.returnStack.pop(); - this.program.setLabel(returnLabel); - this.program.add(Ops.Return); - - outerProgram.join(this.program); - this.program = outerProgram; - } - - private addSourceMap({ index, line, col }: Pos) { - this.program.add(Ops.SourceMap, index, line, col); - } - - private addClearingSourceMap() { - this.program.add(Ops.SourceMap, 0, 1, 1); + throw new Error(`unhandled static statement '${stmt.kind.type}'`); } private lowerStmt(stmt: Stmt) { @@ -146,7 +86,7 @@ class FnLowerer { case "return": return this.lowerReturnStmt(stmt); case "fn": - break; + return this.lowerFnStmt(stmt); case "let": return this.lowerLetStmt(stmt); case "assign": @@ -213,6 +153,52 @@ class FnLowerer { this.program.add(Ops.Jump); } + private lowerFnStmt(stmt: Stmt) { + if (stmt.kind.type !== "fn") { + throw new Error(); + } + const label = stmt.kind.ident === "main" + ? "main" + : `${stmt.kind.ident}_${stmt.id}`; + this.program.setLabel({ label }); + this.fnLabelNameMap[label] = stmt.kind.ident; + this.addSourceMap(stmt.pos); + + const outerLocals = this.locals; + const fnRoot = new LocalsFnRoot(outerLocals); + const outerProgram = this.program; + + const returnLabel = this.program.makeLabel(); + this.returnStack.push(returnLabel); + + this.program = outerProgram.fork(); + this.locals = fnRoot; + for (const { ident } of stmt.kind.params) { + this.locals.allocSym(ident); + } + if (stmt.kind.anno?.ident === "builtin") { + this.lowerFnBuiltinBody(stmt.kind.anno.values); + } else if (stmt.kind.anno?.ident === "remainder") { + this.program.add(Ops.Remainder); + } else { + this.lowerExpr(stmt.kind.body); + } + this.locals = outerLocals; + + const localAmount = fnRoot.stackReserved() - + stmt.kind.params.length; + for (let i = 0; i < localAmount; ++i) { + outerProgram.add(Ops.PushNull); + } + + this.returnStack.pop(); + this.program.setLabel(returnLabel); + this.program.add(Ops.Return); + + outerProgram.join(this.program); + this.program = outerProgram; + } + private lowerFnBuiltinBody(annoArgs: Expr[]) { if (annoArgs.length !== 1) { throw new Error("invalid # of arguments to builtin annotation"); @@ -320,7 +306,8 @@ class FnLowerer { return; } if (expr.kind.sym.type === "fn") { - this.program.add(Ops.PushPtr, 0); + const label = this.fnStmtIdLabelMap[expr.kind.sym.stmt.id]; + this.program.add(Ops.PushPtr, { label }); return; } throw new Error(`unhandled sym type '${expr.kind.sym.type}'`); @@ -541,6 +528,7 @@ class FnLowerer { } const outerLocals = this.locals; this.locals = new LocalLeaf(this.locals); + this.scoutFnHeaders(expr.kind.stmts); for (const stmt of expr.kind.stmts) { this.addSourceMap(stmt.pos); this.lowerStmt(stmt); @@ -553,4 +541,8 @@ class FnLowerer { } this.locals = outerLocals; } + + public printProgram() { + this.program.printProgram(); + } } diff --git a/compiler/mfg.ts b/compiler/mfg.ts deleted file mode 100644 index ba785fe..0000000 --- a/compiler/mfg.ts +++ /dev/null @@ -1,244 +0,0 @@ -// monomorphized function (ast-)graphs - -import { Expr, Stmt } from "./ast.ts"; -import { AstVisitor, visitExpr, VisitRes, visitStmts } from "./ast_visitor.ts"; -import { VType } from "./vtype.ts"; - -export type MonomorphizedFn = { - mid: string; - stmt: Stmt; - genericArgs?: VType[]; -}; - -export function monomorphizeFunctionGraphs(ast: Stmt[]): MonomorphizedFn[] { - const allFns = new AllFnsCollector().collect(ast); - const mainFn = findMain(allFns); - return [ - ...new Monomorphizer(allFns) - .monomorphize(mainFn) - .values(), - ]; -} - -function findMain(fns: Map): Stmt { - const mainId = fns.values().find((stmt) => - stmt.kind.type === "fn" && stmt.kind.ident === "main" - ); - if (mainId === undefined) { - console.error("error: cannot find function 'main'"); - console.error( - ` - Hear me out. Monomorphization, meaning the process - inwich generic functions are stamped out into seperate - specialized functions is actually really hard, and I - have a really hard time right now, figuring out, how - to do it in a smart way. To really explain it, let's - imagine you have a function, you defined as a(). - For each call with seperate generics arguments given, - such as a::() and a::(), a specialized - function has to be 'stamped out', ie. created and put - into the compilation with the rest of the program. Now - to the reason as to why 'main' is needed. To do the - monomorphization, we have to do it recursively. To - explain this, imagine you have a generic function a - and inside the body of a, you call another generic - function such as b with the same generic type. This - means that the monomorphization process of b depends - on the monomorphization of a. What this essentially - means, is that the monomorphization process works on - the program as a call graph, meaning a graph or tree - structure where each represents a function call to - either another function or a recursive call to the - function itself. But a problem arises from doing it - this way, which is that a call graph will need an - entrypoint. The language, as it is currently, does - not really require a 'main'-function. Or maybe it - does, but that's beside the point. The point is that - we need a main function, to be the entry point for - the call graph. The monomorphization process then - runs through the program from that entry point. This - means that each function we call, will itself be - monomorphized and added to the compilation. It also - means that functions that are not called, will also - not be added to the compilation. This essentially - eliminates uncalled/dead functions. Is this - particularly smart to do in such a high level part - of the compilation process? I don't know. It's - obvious that we can't just use every function as - an entry point in the call graph, because we're - actively added new functions. Additionally, with - generic functions, we don't know, if they're the - entry point, what generic arguments, they should - be monomorphized with. We could do monomorphization - the same way C++ does it, where all non-generic - functions before monomorphization are treated as - entry points in the call graph. But this has the - drawback that generic and non-generic functions - are treated differently, which has many underlying - drawbacks, especially pertaining to the amount of - work needed to handle both in all proceeding steps - of the compiler. Anyways, I just wanted to yap and - complain about the way generics and monomorphization - has made the compiler 100x more complicated, and - that I find it really hard to implement in a way, - that is not either too simplistic or so complicated - and advanced I'm too dumb to implement it. So if - you would be so kind as to make it clear to the - compiler, what function it should designate as - the entry point to the call graph, it will use - for monomorphization, that would be very kind of - you. The way you do this, is by added or selecting - one of your current functions and giving it the - name of 'main'. This is spelled m-a-i-n. The word - is synonemous with the words primary and principle. - The name is meant to designate the entry point into - the program, which is why the monomorphization - process uses this specific function as the entry - point into the call graph, it generates. So if you - would be so kind as to do that, that would really - make my day. In any case, keep hacking ferociously - on whatever you're working on. I have monomorphizer - to implement. See ya. -Your favorite compiler girl <3 - `.replaceAll(" ", "").trim(), - ); - throw new Error("cannot find function 'main'"); - } - return mainId; -} - -class AllFnsCollector implements AstVisitor { - private allFns = new Map(); - - public collect(ast: Stmt[]): Map { - visitStmts(ast, this); - return this.allFns; - } - - visitFnStmt(stmt: Stmt): VisitRes { - if (stmt.kind.type !== "fn") { - throw new Error(); - } - this.allFns.set(stmt.id, stmt); - } -} - -class Monomorphizer { - private monomorphizedFns = new Map(); - - public constructor(private allFns: Map) {} - - public monomorphize(mainFn: Stmt): Map { - this.monomorphizeFn(mainFn); - return this.monomorphizedFns; - } - - private monomorphizeFn(stmt: Stmt, genericArgs?: VType[]) { - const calls = new FnBodyCallCollector().collect(stmt); - for (const expr of calls) { - if (expr.kind.type !== "call") { - throw new Error(); - } - const vtype = expr.kind.subject.vtype!; - if (vtype.type === "fn") { - const stmt = this.allFns.get(vtype.fnStmtId)!; - if (stmt.kind.type !== "fn") { - throw new Error(); - } - const mid = fnCallMid(expr, stmt); - if (!this.monomorphizedFns.has(mid)) { - this.monomorphizedFns.set(mid, { mid, stmt }); - this.monomorphizeFn(stmt); - } - return; - } else if (vtype.type === "generic_args") { - if (vtype.subject.type !== "fn") { - throw new Error(); - } - const stmt = this.allFns.get(vtype.subject.fnStmtId)!; - if (stmt.kind.type !== "fn") { - throw new Error(); - } - const mid = fnCallMid(expr, stmt); - if (!this.monomorphizedFns.has(mid)) { - this.monomorphizedFns.set(mid, { mid, stmt, genericArgs }); - this.monomorphizeFn(stmt, vtype.genericArgs); - } - return; - } - throw new Error(); - } - } -} - -class FnBodyCallCollector implements AstVisitor { - private calls: Expr[] = []; - - public collect(stmt: Stmt): Expr[] { - if (stmt.kind.type !== "fn") { - throw new Error(); - } - visitExpr(stmt.kind.body, this); - return this.calls; - } - - visitCallExpr(expr: Expr): VisitRes { - if (expr.kind.type !== "call") { - throw new Error(); - } - this.calls.push(expr); - } -} - -export function fnCallMid(expr: Expr, stmt: Stmt) { - console.log(expr); - if (expr.kind.type !== "call") { - throw new Error(); - } - const vtype = expr.kind.subject.vtype!; - if (vtype.type === "fn") { - return fnStmtMid(stmt); - } else if (vtype.type === "generic_args") { - if (vtype.subject.type !== "fn") { - throw new Error(); - } - return fnStmtMid(stmt, vtype.genericArgs); - } - throw new Error(); -} - -export function fnStmtMid(stmt: Stmt, genericArgs?: VType[]) { - if (stmt.kind.type !== "fn") { - throw new Error(); - } - const { kind: { ident }, id } = stmt; - if (genericArgs !== undefined) { - const genericArgsStr = genericArgs - .map((arg) => vtypeMidPart(arg)) - .join("_"); - return `${ident}_${id}_${genericArgsStr}`; - } else { - return ident === "main" ? "main" : `${ident}_${id}`; - } -} - -export function vtypeMidPart(vtype: VType): string { - switch (vtype.type) { - case "string": - case "int": - case "bool": - case "null": - case "unknown": - return vtype.type; - case "array": - return `array(${vtypeMidPart(vtype.inner)})`; - case "struct": - return `struct(${vtype.structId})`; - case "fn": - return `fn(${vtype.fnStmtId})`; - case "error": - throw new Error("error in type"); - case "generic": - case "generic_args": - throw new Error("cannot be monomorphized"); - } -} diff --git a/compiler/vtype.ts b/compiler/vtype.ts index 5533794..7363697 100644 --- a/compiler/vtype.ts +++ b/compiler/vtype.ts @@ -6,19 +6,18 @@ export type VType = | { type: "string" } | { type: "bool" } | { type: "array"; inner: VType } - | { type: "struct"; structId: number; fields: VTypeParam[] } + | { type: "struct"; fields: VTypeParam[] } | { type: "fn"; genericParams?: VTypeGenericParam[]; params: VTypeParam[]; returnType: VType; - fnStmtId: number; } | { type: "generic" } | { - type: "generic_args"; + type: "generic_spec"; subject: VType; - genericArgs: VType[]; + genericParams: VType[]; }; export type VTypeParam = {