From a4c1b60a61250a45c42aeecdb94998890f5b1cf8 Mon Sep 17 00:00:00 2001 From: sfja Date: Thu, 26 Dec 2024 03:56:59 +0100 Subject: [PATCH] karlkode generic type inference --- compiler/ast.ts | 9 +- compiler/checker.ts | 317 ++++++++++++++++++++++++++++++------- compiler/compiler.ts | 2 +- compiler/mono.ts | 57 ++++++- examples/generic_array.slg | 10 +- slige-run.sh | 2 +- 6 files changed, 329 insertions(+), 68 deletions(-) diff --git a/compiler/ast.ts b/compiler/ast.ts index 8a342e7..975f762 100644 --- a/compiler/ast.ts +++ b/compiler/ast.ts @@ -1,5 +1,5 @@ import { Pos } from "./token.ts"; -import { VType } from "./vtype.ts"; +import { GenericArgsMap, VType } from "./vtype.ts"; export type Stmt = { kind: StmtKind; @@ -48,7 +48,12 @@ export type ExprKind = | { type: "group"; expr: Expr } | { type: "field"; subject: Expr; ident: string } | { type: "index"; subject: Expr; value: Expr } - | { type: "call"; subject: Expr; args: Expr[] } + | { + type: "call"; + subject: Expr; + args: Expr[]; + genericArgs?: GenericArgsMap; + } | { type: "path"; subject: Expr; ident: string } | { type: "etype_args"; subject: Expr; etypeArgs: EType[] } | { type: "unary"; unaryType: UnaryType; subject: Expr } diff --git a/compiler/checker.ts b/compiler/checker.ts index 07fe346..1da414f 100644 --- a/compiler/checker.ts +++ b/compiler/checker.ts @@ -414,75 +414,286 @@ export class Checker { const subject = this.checkExpr(expr.kind.subject); if (subject.type === "error") return subject; if (subject.type === "fn") { - if (subject.genericParams !== undefined) { - throw new Error("😭😭😭"); - } - const args = expr.kind.args.map((arg) => this.checkExpr(arg)); - if (args.length !== subject.params.length) { + if (expr.kind.args.length !== subject.params.length) { this.report( `incorrect number of arguments` + `, expected ${subject.params.length}`, pos, ); } - for (let i = 0; i < args.length; ++i) { - if (!vtypesEqual(args[i], subject.params[i].vtype)) { - this.report( - `incorrect argument ${i} 
'${subject.params[i].ident}'` + - `, expected ${ - vtypeToString(subject.params[i].vtype) - }` + - `, got ${vtypeToString(args[i])}`, - pos, - ); - break; - } - } - return subject.returnType; - } - if (subject.type === "generic_spec" && subject.subject.type === "fn") { - const inner = subject.subject; - const params = inner.params; const args = expr.kind.args.map((arg) => this.checkExpr(arg)); - if (args.length !== params.length) { - this.report( - `incorrect number of arguments` + - `, expected ${params.length}`, + if (subject.genericParams === undefined) { + return this.checkCallExprNoGenericsTail( + expr, + subject, + args, pos, ); } - for (let i = 0; i < args.length; ++i) { - const vtypeCompatible = vtypesEqual( - args[i], - params[i].vtype, - subject.genericArgs, - ); - if (!vtypeCompatible) { - this.report( - `incorrect argument ${i} '${inner.params[i].ident}'` + - `, expected ${ - vtypeToString( - extractGenericType( - params[i].vtype, - subject.genericArgs, - ), - ) - }` + - `, got ${vtypeToString(args[i])}`, - pos, - ); - break; - } - } - return extractGenericType( - subject.subject.returnType, - subject.genericArgs, + return this.checkCallExprInferredGenericsTail( + expr, + subject, + args, + pos, ); } + if (subject.type === "generic_spec" && subject.subject.type === "fn") { + return this.checkCallExprExplicitGenericsTail(expr, subject); + } this.report("cannot call non-fn", pos); return { type: "error" }; } + private checkCallExprNoGenericsTail( + expr: Expr, + subject: VType, + args: VType[], + pos: Pos, + ): VType { + if ( + expr.kind.type !== "call" || subject.type !== "fn" + ) { + throw new Error(); + } + for (let i = 0; i < args.length; ++i) { + if (this.vtypeContainsGeneric(args[i])) { + this.report( + `amfibious generic parameter for argument ${i}, please specify generic types explicitly`, + pos, + ); + return { type: "error" }; + } + } + + for (let i = 0; i < args.length; ++i) { + if (!vtypesEqual(args[i], subject.params[i].vtype)) { + 
this.report( + `incorrect argument ${i} '${subject.params[i].ident}'` + + `, expected ${vtypeToString(subject.params[i].vtype)}` + + `, got ${vtypeToString(args[i])}`, + pos, + ); + break; + } + } + + return subject.returnType; + } + + private checkCallExprInferredGenericsTail( + expr: Expr, + subject: VType, + args: VType[], + pos: Pos, + ): VType { + if ( + expr.kind.type !== "call" || subject.type !== "fn" || + subject.genericParams === undefined + ) { + throw new Error(); + } + const genericArgsRes = this.inferGenericArgs( + subject.genericParams, + subject.params, + args, + pos, + ); + if (!genericArgsRes.ok) { + return { type: "error" }; + } + const genericArgs = genericArgsRes.value; + + for (let i = 0; i < args.length; ++i) { + const vtypeCompatible = vtypesEqual( + args[i], + subject.params[i].vtype, + genericArgs, + ); + if (!vtypeCompatible) { + this.report( + `incorrect argument ${i} '${subject.params[i].ident}'` + + `, expected ${ + vtypeToString( + extractGenericType( + subject.params[i].vtype, + genericArgs, + ), + ) + }` + + `, got ${vtypeToString(args[i])}`, + pos, + ); + break; + } + } + + expr.kind.genericArgs = genericArgs; + return this.concretizeVType(subject.returnType, genericArgs); + } + + private inferGenericArgs( + genericParams: VTypeGenericParam[], + params: VTypeParam[], + args: VType[], + pos: Pos, + ): { ok: true; value: GenericArgsMap } | { ok: false } { + const genericArgs: GenericArgsMap = {}; + + for (let i = 0; i < params.length; ++i) { + if (!this.vtypeContainsGeneric(params[i].vtype)) { + continue; + } + const { + a: generic, + b: concrete, + } = this.reduceToSignificant(params[i].vtype, args[i]); + if (generic.type !== "generic") { + throw new Error(); + } + const paramId = generic.param.id; + if ( + paramId in genericArgs && + !vtypesEqual(genericArgs[paramId], concrete) + ) { + this.report( + `according to inferrence, argument ${i} has a conflicting type`, + pos, + ); + return { ok: false }; + } + genericArgs[paramId] = 
concrete; + } + + for (const param of genericParams) { + if (!(param.id in genericArgs)) { + this.report(`could not infer generic type ${param.ident}`, pos); + return { ok: false }; + } + } + + return { ok: true, value: genericArgs }; + } + + private reduceToSignificant(a: VType, b: VType): { a: VType; b: VType } { + if (a.type !== b.type) { + return { a, b }; + } + if (a.type === "array" && b.type === "array") { + return this.reduceToSignificant(a.inner, b.inner); + } + throw new Error("idk what to do here"); + } + + private vtypeContainsGeneric(vtype: VType): boolean { + switch (vtype.type) { + case "error": + case "string": + case "unknown": + case "null": + case "int": + case "bool": + return false; + case "array": + return this.vtypeContainsGeneric(vtype.inner); + case "struct": + return vtype.fields.some((field) => + this.vtypeContainsGeneric(field.vtype) + ); + case "fn": + throw new Error("not implemented"); + case "generic": + return true; + case "generic_spec": + throw new Error("listen kid, grrrrrrrr"); + } + } + + private checkCallExprExplicitGenericsTail( + expr: Expr, + subject: VType, + ): VType { + if ( + expr.kind.type !== "call" || subject.type !== "generic_spec" || + subject.subject.type !== "fn" + ) { + throw new Error(); + } + const pos = expr.pos; + const inner = subject.subject; + const params = inner.params; + const args = expr.kind.args.map((arg) => this.checkExpr(arg)); + if (args.length !== params.length) { + this.report( + `incorrect number of arguments` + + `, expected ${params.length}`, + pos, + ); + } + for (let i = 0; i < args.length; ++i) { + const vtypeCompatible = vtypesEqual( + args[i], + params[i].vtype, + subject.genericArgs, + ); + if (!vtypeCompatible) { + this.report( + `incorrect argument ${i} '${inner.params[i].ident}'` + + `, expected ${ + vtypeToString( + extractGenericType( + params[i].vtype, + subject.genericArgs, + ), + ) + }` + + `, got ${vtypeToString(args[i])}`, + pos, + ); + break; + } + } + + 
expr.kind.genericArgs = subject.genericArgs; + return this.concretizeVType( + subject.subject.returnType, + subject.genericArgs, + ); + } + + private concretizeVType( + vtype: VType, + generics: GenericArgsMap, + ): VType { + switch (vtype.type) { + case "error": + case "unknown": + case "string": + case "null": + case "int": + case "bool": + return vtype; + case "array": + return { + type: "array", + inner: this.concretizeVType(vtype.inner, generics), + }; + case "struct": + return { + type: "struct", + fields: vtype.fields.map((field) => ({ + ...field, + vtype: this.concretizeVType(field.vtype, generics), + })), + }; + case "fn": + throw new Error("not implemented"); + case "generic": + return generics[vtype.param.id]; + case "generic_spec": + throw new Error("not implemented"); + } + } + public checkPathExpr(expr: Expr): VType { if (expr.kind.type !== "path") { throw new Error(); diff --git a/compiler/compiler.ts b/compiler/compiler.ts index 026a173..aafc601 100644 --- a/compiler/compiler.ts +++ b/compiler/compiler.ts @@ -50,7 +50,7 @@ export class Compiler { const lowerer = new Lowerer(monoFns, callMap, lexer.currentPos()); const { program, fnNames } = lowerer.lower(); - //lowerer.printProgram(); + lowerer.printProgram(); return { program, fnNames }; } diff --git a/compiler/mono.ts b/compiler/mono.ts index fa6d00a..9ccc45b 100644 --- a/compiler/mono.ts +++ b/compiler/mono.ts @@ -3,6 +3,7 @@ import { AstVisitor, visitExpr, VisitRes, visitStmts } from "./ast_visitor.ts"; import { GenericArgsMap, VType } from "./vtype.ts"; export class Monomorphizer { + private fnIdCounter = 0; private fns: MonoFnsMap = {}; private callMap: MonoCallNameGenMap = {}; private allFns: Map; @@ -22,11 +23,13 @@ export class Monomorphizer { stmt: Stmt, genericArgs?: GenericArgsMap, ): MonoFn { - const nameGen = monoFnNameGen(stmt, genericArgs); + const id = this.fnIdCounter; + this.fnIdCounter += 1; + const nameGen = monoFnNameGen(id, stmt, genericArgs); if (nameGen in this.fns) { return 
this.fns[nameGen]; } - const monoFn = { nameGen, stmt, genericArgs }; + const monoFn = { id, nameGen, stmt, genericArgs }; this.fns[nameGen] = monoFn; const calls = new CallCollector().collect(stmt); for (const call of calls) { @@ -34,7 +37,10 @@ export class Monomorphizer { if (call.kind.type !== "call") { throw new Error(); } - if (call.kind.subject.vtype?.type === "fn") { + if ( + call.kind.subject.vtype?.type === "fn" && + call.kind.subject.vtype.genericParams === undefined + ) { const fn = this.allFns.get(call.kind.subject.vtype.stmtId); if (fn === undefined) { throw new Error(); @@ -43,6 +49,40 @@ export class Monomorphizer { this.callMap[call.id] = monoFn.nameGen; continue; } + if ( + call.kind.subject.vtype?.type === "fn" && + call.kind.subject.vtype.genericParams !== undefined + ) { + if (call.kind.genericArgs === undefined) { + throw new Error(); + } + const genericArgs = call.kind.genericArgs; + + const monoArgs: GenericArgsMap = {}; + for (const key in call.kind.genericArgs) { + const vtype = genericArgs[key]; + if (vtype.type === "generic") { + if (genericArgs === undefined) { + throw new Error(); + } + monoArgs[key] = genericArgs[vtype.param.id]; + } else { + monoArgs[key] = vtype; + } + } + + const fnType = call.kind.subject.vtype!; + if (fnType.type !== "fn") { + throw new Error(); + } + const fn = this.allFns.get(fnType.stmtId); + if (fn === undefined) { + throw new Error(); + } + const monoFn = this.monomorphizeFn(fn, monoArgs); + this.callMap[call.id] = monoFn.nameGen; + continue; + } if (call.kind.subject.vtype?.type === "generic_spec") { const genericSpecType = call.kind.subject.vtype!; if (genericSpecType.subject.type !== "fn") { @@ -85,6 +125,7 @@ export type MonoResult = { export type MonoFnsMap = { [nameGen: string]: MonoFn }; export type MonoFn = { + id: number; nameGen: string; stmt: Stmt; genericArgs?: GenericArgsMap; @@ -92,7 +133,11 @@ export type MonoFn = { export type MonoCallNameGenMap = { [exprId: number]: string }; -function 
monoFnNameGen(stmt: Stmt, genericArgs?: GenericArgsMap): string { +function monoFnNameGen( + id: number, + stmt: Stmt, + genericArgs?: GenericArgsMap, +): string { if (stmt.kind.type !== "fn") { throw new Error(); } @@ -100,12 +145,12 @@ function monoFnNameGen(stmt: Stmt, genericArgs?: GenericArgsMap): string { return "main"; } if (genericArgs === undefined) { - return `${stmt.kind.ident}_${stmt.id}`; + return `${stmt.kind.ident}_${id}`; } const args = Object.values(genericArgs) .map((arg) => vtypeNameGenPart(arg)) .join("_"); - return `${stmt.kind.ident}_${stmt.id}_${args}`; + return `${stmt.kind.ident}_${id}_${args}`; } function vtypeNameGenPart(vtype: VType): string { diff --git a/examples/generic_array.slg b/examples/generic_array.slg index 11cb0bf..82cf74e 100644 --- a/examples/generic_array.slg +++ b/examples/generic_array.slg @@ -1,4 +1,4 @@ - +// fn array_new() -> [T] #[builtin(ArrayNew)] {} fn array_push(array: [T], value: T) #[builtin(ArrayPush)] {} @@ -8,12 +8,12 @@ fn array_at(array: [T], index: int) -> string #[builtin(ArrayAt)] {} fn main() { let strings = array_new::(); - array_push::(strings, "hello"); - array_push::(strings, "world"); + array_push(strings, "hello"); + array_push(strings, "world"); let ints = array_new::(); - array_push::(ints, 1); - array_push::(ints, 2); + array_push(ints, 1); + array_push(ints, 2); } diff --git a/slige-run.sh b/slige-run.sh index 610eabd..056163a 100755 --- a/slige-run.sh +++ b/slige-run.sh @@ -13,7 +13,7 @@ fi echo Compiling $1... -deno run --allow-read --allow-write compiler/main.ts $1 +deno run --allow-read --allow-write --check compiler/main.ts $1 echo Running out.slgbc...