do more in the likes of generics

This commit is contained in:
sfja 2024-12-25 05:19:32 +01:00
parent cd923450f5
commit 7b5fee745d
5 changed files with 361 additions and 101 deletions

View File

@ -12,6 +12,7 @@ import {
export class Checker { export class Checker {
private fnReturnStack: VType[] = []; private fnReturnStack: VType[] = [];
private loopBreakStack: VType[][] = []; private loopBreakStack: VType[][] = [];
private structIdCounter = 0;
public constructor(private reporter: Reporter) {} public constructor(private reporter: Reporter) {}
@ -47,7 +48,13 @@ export class Checker {
param.vtype = vtype; param.vtype = vtype;
params.push({ ident: param.ident, vtype }); params.push({ ident: param.ident, vtype });
} }
stmt.kind.vtype = { type: "fn", genericParams, params, returnType }; stmt.kind.vtype = {
type: "fn",
genericParams,
params,
returnType,
fnStmtId: stmt.id,
};
} }
} }
@ -332,12 +339,7 @@ export class Checker {
if (fnStmt.kind.type !== "fn") { if (fnStmt.kind.type !== "fn") {
throw new Error(); throw new Error();
} }
const vtype = fnStmt.kind.vtype!; return fnStmt.kind.vtype!;
if (vtype.type !== "fn") {
throw new Error();
}
const { params, returnType } = vtype;
return { type: "fn", params, returnType };
} }
case "fn_param": case "fn_param":
return expr.kind.sym.param.vtype!; return expr.kind.sym.param.vtype!;
@ -445,10 +447,10 @@ export class Checker {
); );
return { type: "error" }; return { type: "error" };
} }
const genericParams = expr.kind.etypeArgs.map((arg) => const genericArgs = expr.kind.etypeArgs.map((arg) =>
this.checkEType(arg) this.checkEType(arg)
); );
if (genericParams.length !== subject.params.length) { if (genericArgs.length !== subject.params.length) {
this.report( this.report(
`incorrect number of arguments` + `incorrect number of arguments` +
`, expected ${subject.params.length}`, `, expected ${subject.params.length}`,
@ -456,9 +458,9 @@ export class Checker {
); );
} }
return { return {
type: "generic_spec", type: "generic_args",
subject, subject,
genericParams, genericArgs,
}; };
} }
@ -671,7 +673,9 @@ export class Checker {
ident: param.ident, ident: param.ident,
vtype: this.checkEType(param.etype!), vtype: this.checkEType(param.etype!),
})); }));
return { type: "struct", fields }; const structId = this.structIdCounter;
this.structIdCounter += 1;
return { type: "struct", structId, fields };
} }
throw new Error(`unknown explicit type ${etype.kind.type}`); throw new Error(`unknown explicit type ${etype.kind.type}`);
} }

View File

@ -5,6 +5,7 @@ import { SpecialLoopDesugarer } from "./desugar/special_loop.ts";
import { Reporter } from "./info.ts"; import { Reporter } from "./info.ts";
import { Lexer } from "./lexer.ts"; import { Lexer } from "./lexer.ts";
import { FnNamesMap, Lowerer } from "./lowerer.ts"; import { FnNamesMap, Lowerer } from "./lowerer.ts";
import { monomorphizeFunctionGraphs } from "./mfg.ts";
import { Parser } from "./parser.ts"; import { Parser } from "./parser.ts";
import { Resolver } from "./resolver.ts"; import { Resolver } from "./resolver.ts";
@ -45,9 +46,11 @@ export class Compiler {
Deno.exit(1); Deno.exit(1);
} }
const monomorphizedFns = monomorphizeFunctionGraphs(ast);
const lowerer = new Lowerer(lexer.currentPos()); const lowerer = new Lowerer(lexer.currentPos());
lowerer.lower(ast); lowerer.lower(monomorphizedFns);
// lowerer.printProgram(); lowerer.printProgram();
const { program, fnNames } = lowerer.finish(); const { program, fnNames } = lowerer.finish();
return { program, fnNames }; return { program, fnNames };

View File

@ -2,31 +2,42 @@ import { Builtins, Ops } from "./arch.ts";
import { Expr, Stmt } from "./ast.ts"; import { Expr, Stmt } from "./ast.ts";
import { LocalLeaf, Locals, LocalsFnRoot } from "./lowerer_locals.ts"; import { LocalLeaf, Locals, LocalsFnRoot } from "./lowerer_locals.ts";
import { Assembler, Label } from "./assembler.ts"; import { Assembler, Label } from "./assembler.ts";
import { vtypeToString } from "./vtype.ts"; import { VType, vtypeToString } from "./vtype.ts";
import { Pos } from "./token.ts"; import { Pos } from "./token.ts";
import { fnCallMid, fnStmtMid, MonomorphizedFn } from "./mfg.ts";
export type FnNamesMap = { [pc: number]: string }; export type FnNamesMap = { [pc: number]: string };
export class Lowerer { export class Lowerer {
private program = Assembler.newRoot(); private program = Assembler.newRoot();
private locals: Locals = new LocalsFnRoot();
private fnStmtIdLabelMap: { [stmtId: number]: string } = {};
private fnLabelNameMap: { [name: string]: string } = {}; private fnLabelNameMap: { [name: string]: string } = {};
private returnStack: Label[] = [];
private breakStack: Label[] = [];
public constructor(private lastPos: Pos) {} public constructor(private lastPos: Pos) {}
public lower(stmts: Stmt[]) { public lower(fns: MonomorphizedFn[]) {
this.addClearingSourceMap(); this.addClearingSourceMap();
this.program.add(Ops.PushPtr, { label: "main" }); this.program.add(Ops.PushPtr, { label: "main" });
this.program.add(Ops.Call, 0); this.program.add(Ops.Call, 0);
this.program.add(Ops.PushPtr, { label: "_exit" }); this.program.add(Ops.PushPtr, { label: "_exit" });
this.program.add(Ops.Jump); this.program.add(Ops.Jump);
this.scoutFnHeaders(stmts);
for (const stmt of stmts) { const fnMidLabelMap: { [mid: string]: string } = {};
this.lowerStaticStmt(stmt); for (const fn of fns) {
const mid = fnStmtMid(fn.stmt, fn.genericArgs);
fnMidLabelMap[mid] = mid;
} }
for (const fn of fns) {
const fnProgram = new FnLowerer(
this.program.fork(),
fnMidLabelMap,
fn.stmt,
fn.genericArgs,
)
.lowerFn();
this.program.join(fnProgram);
}
this.program.setLabel({ label: "_exit" }); this.program.setLabel({ label: "_exit" });
this.addSourceMap(this.lastPos); this.addSourceMap(this.lastPos);
this.program.add(Ops.Pop); this.program.add(Ops.Pop);
@ -43,6 +54,10 @@ export class Lowerer {
return { program, fnNames }; return { program, fnNames };
} }
public printProgram() {
this.program.printProgram();
}
private addSourceMap({ index, line, col }: Pos) { private addSourceMap({ index, line, col }: Pos) {
this.program.add(Ops.SourceMap, index, line, col); this.program.add(Ops.SourceMap, index, line, col);
} }
@ -50,31 +65,76 @@ export class Lowerer {
private addClearingSourceMap() { private addClearingSourceMap() {
this.program.add(Ops.SourceMap, 0, 1, 1); this.program.add(Ops.SourceMap, 0, 1, 1);
} }
}
private scoutFnHeaders(stmts: Stmt[]) { class FnLowerer {
for (const stmt of stmts) { private locals: Locals = new LocalsFnRoot();
private fnLabelNameMap: { [name: string]: string } = {};
private returnStack: Label[] = [];
private breakStack: Label[] = [];
public constructor(
private program: Assembler,
private fnMidLabelMap: { [mid: string]: string },
private fnStmt: Stmt,
private genericArgs?: VType[],
) {}
public lowerFn(): Assembler {
this.lowerFnStmt(this.fnStmt);
return this.program;
}
private lowerFnStmt(stmt: Stmt) {
if (stmt.kind.type !== "fn") { if (stmt.kind.type !== "fn") {
continue; throw new Error();
} }
const label = stmt.kind.ident === "main" const label = fnStmtMid(stmt, this.genericArgs);
? "main" this.program.setLabel({ label });
: `${stmt.kind.ident}_${stmt.id}`; this.fnLabelNameMap[label] = stmt.kind.ident;
this.fnStmtIdLabelMap[stmt.id] = label; this.addSourceMap(stmt.pos);
const outerLocals = this.locals;
const fnRoot = new LocalsFnRoot(outerLocals);
const outerProgram = this.program;
const returnLabel = this.program.makeLabel();
this.returnStack.push(returnLabel);
this.program = outerProgram.fork();
this.locals = fnRoot;
for (const { ident } of stmt.kind.params) {
this.locals.allocSym(ident);
} }
if (stmt.kind.anno?.ident === "builtin") {
this.lowerFnBuiltinBody(stmt.kind.anno.values);
} else if (stmt.kind.anno?.ident === "remainder") {
this.program.add(Ops.Remainder);
} else {
this.lowerExpr(stmt.kind.body);
}
this.locals = outerLocals;
const localAmount = fnRoot.stackReserved() -
stmt.kind.params.length;
for (let i = 0; i < localAmount; ++i) {
outerProgram.add(Ops.PushNull);
} }
private lowerStaticStmt(stmt: Stmt) { this.returnStack.pop();
switch (stmt.kind.type) { this.program.setLabel(returnLabel);
case "fn": this.program.add(Ops.Return);
return this.lowerFnStmt(stmt);
case "error": outerProgram.join(this.program);
case "break": this.program = outerProgram;
case "return":
case "let":
case "assign":
case "expr":
} }
throw new Error(`unhandled static statement '${stmt.kind.type}'`);
private addSourceMap({ index, line, col }: Pos) {
this.program.add(Ops.SourceMap, index, line, col);
}
private addClearingSourceMap() {
this.program.add(Ops.SourceMap, 0, 1, 1);
} }
private lowerStmt(stmt: Stmt) { private lowerStmt(stmt: Stmt) {
@ -86,7 +146,7 @@ export class Lowerer {
case "return": case "return":
return this.lowerReturnStmt(stmt); return this.lowerReturnStmt(stmt);
case "fn": case "fn":
return this.lowerFnStmt(stmt); break;
case "let": case "let":
return this.lowerLetStmt(stmt); return this.lowerLetStmt(stmt);
case "assign": case "assign":
@ -153,52 +213,6 @@ export class Lowerer {
this.program.add(Ops.Jump); this.program.add(Ops.Jump);
} }
private lowerFnStmt(stmt: Stmt) {
if (stmt.kind.type !== "fn") {
throw new Error();
}
const label = stmt.kind.ident === "main"
? "main"
: `${stmt.kind.ident}_${stmt.id}`;
this.program.setLabel({ label });
this.fnLabelNameMap[label] = stmt.kind.ident;
this.addSourceMap(stmt.pos);
const outerLocals = this.locals;
const fnRoot = new LocalsFnRoot(outerLocals);
const outerProgram = this.program;
const returnLabel = this.program.makeLabel();
this.returnStack.push(returnLabel);
this.program = outerProgram.fork();
this.locals = fnRoot;
for (const { ident } of stmt.kind.params) {
this.locals.allocSym(ident);
}
if (stmt.kind.anno?.ident === "builtin") {
this.lowerFnBuiltinBody(stmt.kind.anno.values);
} else if (stmt.kind.anno?.ident === "remainder") {
this.program.add(Ops.Remainder);
} else {
this.lowerExpr(stmt.kind.body);
}
this.locals = outerLocals;
const localAmount = fnRoot.stackReserved() -
stmt.kind.params.length;
for (let i = 0; i < localAmount; ++i) {
outerProgram.add(Ops.PushNull);
}
this.returnStack.pop();
this.program.setLabel(returnLabel);
this.program.add(Ops.Return);
outerProgram.join(this.program);
this.program = outerProgram;
}
private lowerFnBuiltinBody(annoArgs: Expr[]) { private lowerFnBuiltinBody(annoArgs: Expr[]) {
if (annoArgs.length !== 1) { if (annoArgs.length !== 1) {
throw new Error("invalid # of arguments to builtin annotation"); throw new Error("invalid # of arguments to builtin annotation");
@ -306,8 +320,7 @@ export class Lowerer {
return; return;
} }
if (expr.kind.sym.type === "fn") { if (expr.kind.sym.type === "fn") {
const label = this.fnStmtIdLabelMap[expr.kind.sym.stmt.id]; this.program.add(Ops.PushPtr, 0);
this.program.add(Ops.PushPtr, { label });
return; return;
} }
throw new Error(`unhandled sym type '${expr.kind.sym.type}'`); throw new Error(`unhandled sym type '${expr.kind.sym.type}'`);
@ -528,7 +541,6 @@ export class Lowerer {
} }
const outerLocals = this.locals; const outerLocals = this.locals;
this.locals = new LocalLeaf(this.locals); this.locals = new LocalLeaf(this.locals);
this.scoutFnHeaders(expr.kind.stmts);
for (const stmt of expr.kind.stmts) { for (const stmt of expr.kind.stmts) {
this.addSourceMap(stmt.pos); this.addSourceMap(stmt.pos);
this.lowerStmt(stmt); this.lowerStmt(stmt);
@ -541,8 +553,4 @@ export class Lowerer {
} }
this.locals = outerLocals; this.locals = outerLocals;
} }
public printProgram() {
this.program.printProgram();
}
} }

244
compiler/mfg.ts Normal file
View File

@ -0,0 +1,244 @@
// monomorphized function (ast-)graphs
import { Expr, Stmt } from "./ast.ts";
import { AstVisitor, visitExpr, VisitRes, visitStmts } from "./ast_visitor.ts";
import { VType } from "./vtype.ts";
export type MonomorphizedFn = {
mid: string;
stmt: Stmt;
genericArgs?: VType[];
};
export function monomorphizeFunctionGraphs(ast: Stmt[]): MonomorphizedFn[] {
const allFns = new AllFnsCollector().collect(ast);
const mainFn = findMain(allFns);
return [
...new Monomorphizer(allFns)
.monomorphize(mainFn)
.values(),
];
}
function findMain(fns: Map<number, Stmt>): Stmt {
const mainId = fns.values().find((stmt) =>
stmt.kind.type === "fn" && stmt.kind.ident === "main"
);
if (mainId === undefined) {
console.error("error: cannot find function 'main'");
console.error(
`
Hear me out. Monomorphization, meaning the process
inwich generic functions are stamped out into seperate
specialized functions is actually really hard, and I
have a really hard time right now, figuring out, how
to do it in a smart way. To really explain it, let's
imagine you have a function, you defined as a<T>().
For each call with seperate generics arguments given,
such as a::<int>() and a::<string>(), a specialized
function has to be 'stamped out', ie. created and put
into the compilation with the rest of the program. Now
to the reason as to why 'main' is needed. To do the
monomorphization, we have to do it recursively. To
explain this, imagine you have a generic function a<T>
and inside the body of a<T>, you call another generic
function such as b<T> with the same generic type. This
means that the monomorphization process of b<T> depends
on the monomorphization of a<T>. What this essentially
means, is that the monomorphization process works on
the program as a call graph, meaning a graph or tree
structure where each represents a function call to
either another function or a recursive call to the
function itself. But a problem arises from doing it
this way, which is that a call graph will need an
entrypoint. The language, as it is currently, does
not really require a 'main'-function. Or maybe it
does, but that's beside the point. The point is that
we need a main function, to be the entry point for
the call graph. The monomorphization process then
runs through the program from that entry point. This
means that each function we call, will itself be
monomorphized and added to the compilation. It also
means that functions that are not called, will also
not be added to the compilation. This essentially
eliminates uncalled/dead functions. Is this
particularly smart to do in such a high level part
of the compilation process? I don't know. It's
obvious that we can't just use every function as
an entry point in the call graph, because we're
actively added new functions. Additionally, with
generic functions, we don't know, if they're the
entry point, what generic arguments, they should
be monomorphized with. We could do monomorphization
the same way C++ does it, where all non-generic
functions before monomorphization are treated as
entry points in the call graph. But this has the
drawback that generic and non-generic functions
are treated differently, which has many underlying
drawbacks, especially pertaining to the amount of
work needed to handle both in all proceeding steps
of the compiler. Anyways, I just wanted to yap and
complain about the way generics and monomorphization
has made the compiler 100x more complicated, and
that I find it really hard to implement in a way,
that is not either too simplistic or so complicated
and advanced I'm too dumb to implement it. So if
you would be so kind as to make it clear to the
compiler, what function it should designate as
the entry point to the call graph, it will use
for monomorphization, that would be very kind of
you. The way you do this, is by added or selecting
one of your current functions and giving it the
name of 'main'. This is spelled m-a-i-n. The word
is synonemous with the words primary and principle.
The name is meant to designate the entry point into
the program, which is why the monomorphization
process uses this specific function as the entry
point into the call graph, it generates. So if you
would be so kind as to do that, that would really
make my day. In any case, keep hacking ferociously
on whatever you're working on. I have monomorphizer
to implement. See ya. -Your favorite compiler girl <3
`.replaceAll(" ", "").trim(),
);
throw new Error("cannot find function 'main'");
}
return mainId;
}
class AllFnsCollector implements AstVisitor {
private allFns = new Map<number, Stmt>();
public collect(ast: Stmt[]): Map<number, Stmt> {
visitStmts(ast, this);
return this.allFns;
}
visitFnStmt(stmt: Stmt): VisitRes {
if (stmt.kind.type !== "fn") {
throw new Error();
}
this.allFns.set(stmt.id, stmt);
}
}
class Monomorphizer {
private monomorphizedFns = new Map<string, MonomorphizedFn>();
public constructor(private allFns: Map<number, Stmt>) {}
public monomorphize(mainFn: Stmt): Map<string, MonomorphizedFn> {
this.monomorphizeFn(mainFn);
return this.monomorphizedFns;
}
private monomorphizeFn(stmt: Stmt, genericArgs?: VType[]) {
const calls = new FnBodyCallCollector().collect(stmt);
for (const expr of calls) {
if (expr.kind.type !== "call") {
throw new Error();
}
const vtype = expr.kind.subject.vtype!;
if (vtype.type === "fn") {
const stmt = this.allFns.get(vtype.fnStmtId)!;
if (stmt.kind.type !== "fn") {
throw new Error();
}
const mid = fnCallMid(expr, stmt);
if (!this.monomorphizedFns.has(mid)) {
this.monomorphizedFns.set(mid, { mid, stmt });
this.monomorphizeFn(stmt);
}
return;
} else if (vtype.type === "generic_args") {
if (vtype.subject.type !== "fn") {
throw new Error();
}
const stmt = this.allFns.get(vtype.subject.fnStmtId)!;
if (stmt.kind.type !== "fn") {
throw new Error();
}
const mid = fnCallMid(expr, stmt);
if (!this.monomorphizedFns.has(mid)) {
this.monomorphizedFns.set(mid, { mid, stmt, genericArgs });
this.monomorphizeFn(stmt, vtype.genericArgs);
}
return;
}
throw new Error();
}
}
}
class FnBodyCallCollector implements AstVisitor {
private calls: Expr[] = [];
public collect(stmt: Stmt): Expr[] {
if (stmt.kind.type !== "fn") {
throw new Error();
}
visitExpr(stmt.kind.body, this);
return this.calls;
}
visitCallExpr(expr: Expr): VisitRes {
if (expr.kind.type !== "call") {
throw new Error();
}
this.calls.push(expr);
}
}
export function fnCallMid(expr: Expr, stmt: Stmt) {
console.log(expr);
if (expr.kind.type !== "call") {
throw new Error();
}
const vtype = expr.kind.subject.vtype!;
if (vtype.type === "fn") {
return fnStmtMid(stmt);
} else if (vtype.type === "generic_args") {
if (vtype.subject.type !== "fn") {
throw new Error();
}
return fnStmtMid(stmt, vtype.genericArgs);
}
throw new Error();
}
export function fnStmtMid(stmt: Stmt, genericArgs?: VType[]) {
if (stmt.kind.type !== "fn") {
throw new Error();
}
const { kind: { ident }, id } = stmt;
if (genericArgs !== undefined) {
const genericArgsStr = genericArgs
.map((arg) => vtypeMidPart(arg))
.join("_");
return `${ident}_${id}_${genericArgsStr}`;
} else {
return ident === "main" ? "main" : `${ident}_${id}`;
}
}
export function vtypeMidPart(vtype: VType): string {
switch (vtype.type) {
case "string":
case "int":
case "bool":
case "null":
case "unknown":
return vtype.type;
case "array":
return `array(${vtypeMidPart(vtype.inner)})`;
case "struct":
return `struct(${vtype.structId})`;
case "fn":
return `fn(${vtype.fnStmtId})`;
case "error":
throw new Error("error in type");
case "generic":
case "generic_args":
throw new Error("cannot be monomorphized");
}
}

View File

@ -6,18 +6,19 @@ export type VType =
| { type: "string" } | { type: "string" }
| { type: "bool" } | { type: "bool" }
| { type: "array"; inner: VType } | { type: "array"; inner: VType }
| { type: "struct"; fields: VTypeParam[] } | { type: "struct"; structId: number; fields: VTypeParam[] }
| { | {
type: "fn"; type: "fn";
genericParams?: VTypeGenericParam[]; genericParams?: VTypeGenericParam[];
params: VTypeParam[]; params: VTypeParam[];
returnType: VType; returnType: VType;
fnStmtId: number;
} }
| { type: "generic" } | { type: "generic" }
| { | {
type: "generic_spec"; type: "generic_args";
subject: VType; subject: VType;
genericParams: VType[]; genericArgs: VType[];
}; };
export type VTypeParam = { export type VTypeParam = {