413 lines
12 KiB
Rust
413 lines
12 KiB
Rust
#![allow(unused_variables)]
|
|
|
|
use std::collections::HashMap;
|
|
|
|
use crate::{
|
|
checked::{self, Node, NodeKind},
|
|
ir::{Data, DataKind, Fn, Op},
|
|
sym::Syms,
|
|
};
|
|
|
|
pub struct Program {
|
|
pub fns: Vec<Fn>,
|
|
pub entry: u64,
|
|
pub sym: Syms,
|
|
}
|
|
|
|
struct Compiler {
|
|
syms: Syms,
|
|
}
|
|
|
|
impl Compiler {
|
|
pub fn new(syms: Syms) -> Self {
|
|
Self { syms }
|
|
}
|
|
|
|
pub fn compile(self, ast: &[Node]) -> Result<Program, ()> {
|
|
let fns = FnCompiler::new(&self.syms).compile_entry(ast)?;
|
|
Ok(Program {
|
|
fns,
|
|
sym: self.syms,
|
|
entry: 0,
|
|
})
|
|
}
|
|
|
|
fn error<S: Into<String>>(&mut self, msg: S) {
|
|
println!("ir compiler: {}", msg.into());
|
|
}
|
|
}
|
|
|
|
struct FnCompiler<'syms> {
|
|
syms: &'syms Syms,
|
|
ops: Vec<Op>,
|
|
data: Vec<Data>,
|
|
fns: Vec<Fn>,
|
|
local_count: i32,
|
|
local_map: HashMap<u64, i32>,
|
|
break_stack: Vec<usize>,
|
|
}
|
|
|
|
impl<'syms> FnCompiler<'syms> {
|
|
pub fn new(syms: &'syms Syms) -> Self {
|
|
Self {
|
|
syms,
|
|
ops: Vec::new(),
|
|
fns: Vec::new(),
|
|
data: Vec::new(),
|
|
local_count: 0,
|
|
local_map: HashMap::new(),
|
|
break_stack: Vec::new(),
|
|
}
|
|
}
|
|
|
|
pub fn compile_entry(mut self, ast: &[Node]) -> Result<Vec<Fn>, ()> {
|
|
for stmt in ast {
|
|
self.compile_stmt(stmt)?;
|
|
if stmt.typ != checked::Type::Unit {
|
|
self.ops.push(Op::Pop);
|
|
}
|
|
}
|
|
self.fns.push(Fn {
|
|
ops: self.ops,
|
|
data: self.data,
|
|
id: 0,
|
|
arg_count: 0,
|
|
local_count: self.local_count,
|
|
});
|
|
Ok(self.fns)
|
|
}
|
|
|
|
pub fn compile(mut self, node: &Node) -> Result<Vec<Fn>, ()> {
|
|
let Node {
|
|
kind:
|
|
NodeKind::Fn {
|
|
subject,
|
|
params,
|
|
return_typ: _,
|
|
body,
|
|
id,
|
|
},
|
|
..
|
|
} = node
|
|
else {
|
|
unreachable!()
|
|
};
|
|
for param in params {
|
|
let NodeKind::Param { subject, typ } = ¶m.kind else { unreachable!() };
|
|
let NodeKind::Id(param_id) = subject.kind else { unreachable!() };
|
|
let sym_table = self.syms.view(body.table_id);
|
|
let sym = sym_table.get(param_id).unwrap();
|
|
let local_id = self.local_count;
|
|
self.local_count += 1;
|
|
self.local_map.insert(sym.uid, local_id);
|
|
}
|
|
let NodeKind::Block(body) = &body.kind else {
|
|
unreachable!()
|
|
};
|
|
for stmt in body {
|
|
self.compile_stmt(stmt)?;
|
|
if stmt.typ != checked::Type::Unit {
|
|
self.push(Op::Pop);
|
|
}
|
|
}
|
|
self.push(Op::Return);
|
|
self.fns.push(Fn {
|
|
ops: self.ops,
|
|
data: self.data,
|
|
id: *id,
|
|
arg_count: params.len() as i32,
|
|
local_count: self.local_count,
|
|
});
|
|
Ok(self.fns)
|
|
}
|
|
|
|
fn compile_stmt(&mut self, stmt: &Node) -> Result<(), ()> {
|
|
let syms = self.syms.view(stmt.table_id);
|
|
match &stmt.kind {
|
|
NodeKind::Error => return Err(()),
|
|
NodeKind::Break => {
|
|
let addr = self.ops.len();
|
|
self.push(Op::Jump(0));
|
|
self.break_stack.push(addr);
|
|
}
|
|
NodeKind::Assign { subject, value } => {
|
|
self.compile_expr(value)?;
|
|
match subject.kind {
|
|
NodeKind::Error => {
|
|
return Err(());
|
|
}
|
|
NodeKind::Id(id) => {
|
|
let sym_table = self.syms.view(stmt.table_id);
|
|
let sym = sym_table.get(id).unwrap();
|
|
let local_id = self.local_map.get(&sym.uid).unwrap();
|
|
self.push(Op::StoreLocal(*local_id));
|
|
}
|
|
_ => unimplemented!(),
|
|
}
|
|
}
|
|
NodeKind::Let { subject, value } => {
|
|
let NodeKind::Param { subject, typ: _ } = &subject.kind else { unreachable!() };
|
|
let NodeKind::Id(subject_id) = subject.kind else { unreachable!()};
|
|
let sym_table = self.syms.view(stmt.table_id);
|
|
let sym = sym_table.get(subject_id).unwrap();
|
|
let local_id = self.local_count;
|
|
self.local_map.insert(sym.uid, local_id);
|
|
self.local_count += 1;
|
|
self.compile_expr(value)?;
|
|
self.push(Op::StoreLocal(local_id));
|
|
}
|
|
NodeKind::Fn { subject, .. } => {
|
|
let mut compiled_fn = FnCompiler::new(self.syms).compile(stmt)?;
|
|
self.fns.append(&mut compiled_fn);
|
|
let NodeKind::Id(subject_id) = subject.kind else { unreachable!()};
|
|
let sym_table = self.syms.view(stmt.table_id);
|
|
let sym = sym_table.get(subject_id).unwrap();
|
|
self.local_map.insert(sym.uid, self.local_count);
|
|
self.local_count += 1;
|
|
}
|
|
NodeKind::Return { value } => {
|
|
self.ops.push(Op::Return);
|
|
}
|
|
_ => {
|
|
self.compile_expr(stmt)?;
|
|
}
|
|
}
|
|
Ok(())
|
|
}
|
|
|
|
fn compile_expr(&mut self, expr: &Node) -> Result<(), ()> {
|
|
let syms = self.syms.view(expr.table_id);
|
|
match &expr.kind {
|
|
NodeKind::Error => return Err(()),
|
|
NodeKind::Id(id) => {
|
|
let sym = syms.get(*id).unwrap();
|
|
let local_id = self.local_map.get(&sym.uid).unwrap();
|
|
self.push(Op::LoadLocal(*local_id));
|
|
}
|
|
NodeKind::Int(value) => match expr.typ {
|
|
checked::Type::I32 => self.push(Op::PushI32(*value as i32)),
|
|
checked::Type::U32 => self.push(Op::PushU32(*value as u32)),
|
|
_ => unreachable!(),
|
|
},
|
|
NodeKind::Str(value) => {
|
|
let id = self.data.len() as u64;
|
|
let bytes = value.bytes().collect();
|
|
self.data.push(Data {
|
|
kind: DataKind::U8(bytes),
|
|
id,
|
|
});
|
|
self.ops.push(Op::PushStaticPtr(id));
|
|
}
|
|
NodeKind::Group(expr) => {
|
|
self.compile_expr(expr)?;
|
|
}
|
|
NodeKind::Block(stmts) => {
|
|
let mut last_typ = None;
|
|
for stmt in stmts {
|
|
if last_typ.filter(|typ| *typ != checked::Type::Unit).is_some() {
|
|
self.push(Op::Pop);
|
|
}
|
|
last_typ = Some(stmt.typ.clone());
|
|
self.compile_stmt(stmt)?;
|
|
}
|
|
}
|
|
NodeKind::Call { subject, args } => {
|
|
self.compile_expr(subject)?;
|
|
for arg in args {
|
|
self.compile_expr(arg)?;
|
|
}
|
|
match subject.typ {
|
|
checked::Type::Fn {
|
|
id,
|
|
params: _,
|
|
return_typ: _,
|
|
} => {
|
|
self.push(Op::CallFn(id, args.len() as i32));
|
|
}
|
|
_ => {
|
|
self.push(Op::CallPtr(args.len() as i32));
|
|
}
|
|
}
|
|
}
|
|
NodeKind::If {
|
|
cond,
|
|
truthy,
|
|
falsy,
|
|
} => {
|
|
self.compile_expr(cond)?;
|
|
let l0 = self.ops.len();
|
|
self.push(Op::JumpIfFalse(0));
|
|
self.compile_expr(truthy)?;
|
|
let l1 = self.ops.len();
|
|
if let Op::JumpIfFalse(ref mut addr) = self.ops[l0] {
|
|
*addr = l1 as u64;
|
|
}
|
|
if let Some(falsy) = falsy {
|
|
self.push(Op::Jump(0));
|
|
self.compile_expr(falsy)?;
|
|
let l2 = self.ops.len();
|
|
if let Op::Jump(ref mut addr) = self.ops[l1] {
|
|
*addr = l2 as u64;
|
|
}
|
|
}
|
|
}
|
|
NodeKind::Loop { body } => {
|
|
let l0 = self.ops.len() as u64;
|
|
self.compile_expr(body)?;
|
|
let l1 = self.ops.len() as u64;
|
|
for op in self.break_stack.drain(..) {
|
|
let Op::Jump(ref mut addr) = self.ops[op] else { unreachable!() };
|
|
*addr = l1;
|
|
}
|
|
self.push(Op::Jump(l0));
|
|
}
|
|
_ => unreachable!(),
|
|
}
|
|
Ok(())
|
|
}
|
|
|
|
fn push(&mut self, op: Op) {
|
|
self.ops.push(op);
|
|
}
|
|
|
|
fn error<S: Into<String>>(&mut self, msg: S) {
|
|
println!("ir compiler: {}", msg.into());
|
|
}
|
|
}
|
|
|
|
/// End-to-end check: parse, type-check and compile small programs, then compare the
/// emitted functions against the expected IR.
#[test]
fn test_compiler() {
    use crate::checker::{Checker, IdGen};
    use crate::parser::Parser;
    use Op::*;

    use pretty_assertions::assert_eq;

    /// Id generator that hands out 0, 1, 2, ... so function ids are predictable.
    struct SeqIdGen(u64);
    impl IdGen for SeqIdGen {
        fn new() -> Self {
            Self(0)
        }

        fn gen(&mut self) -> u64 {
            let next = self.0;
            self.0 += 1;
            next
        }
    }

    // Runs the whole pipeline on `text` and returns only the compiled functions.
    let compile = |text| {
        let mut checker = Checker::<SeqIdGen>::new_with_fn_id_gen();
        let checked = checker.check(&Parser::new(text).parse());
        let syms = checker.finish();
        Compiler::new(syms)
            .compile(&checked)
            .map(|program| program.fns)
    };

    // The entry function always has id 0 and takes no arguments.
    let entry = |ops: Vec<Op>, data: Vec<Data>, local_count: i32| Fn {
        ops,
        data,
        id: 0,
        arg_count: 0,
        local_count,
    };

    assert_eq!(
        compile("fn test(a: i32) -> i32 { a; } test(123);"),
        Ok(vec![
            Fn {
                // `FnCompiler::compile` always appends `Return` after the body.
                ops: vec![LoadLocal(0), Pop, Return],
                data: vec![],
                id: 0,
                arg_count: 1,
                local_count: 1
            },
            entry(
                vec![LoadLocal(0), PushI32(123), CallFn(0, 1), Pop],
                vec![],
                1
            ),
        ])
    );

    assert_eq!(
        compile("let a = 5; let b = 3; a; b; a = b;"),
        Ok(vec![entry(
            vec![
                PushI32(5),
                StoreLocal(0),
                PushI32(3),
                StoreLocal(1),
                LoadLocal(0),
                Pop,
                LoadLocal(1),
                Pop,
                LoadLocal(1),
                StoreLocal(0),
            ],
            vec![],
            2
        )])
    );

    assert_eq!(
        compile("let a = \"hello\";"),
        Ok(vec![entry(
            vec![PushStaticPtr(0), StoreLocal(0)],
            vec![Data {
                kind: DataKind::U8(vec![104, 101, 108, 108, 111]),
                id: 0
            }],
            1
        )])
    );

    assert_eq!(
        compile("if 1 { 2; } if 1 { 2; } else { 3; }"),
        Ok(vec![entry(
            vec![
                PushI32(1),
                JumpIfFalse(3),
                PushI32(2),
                PushI32(1),
                JumpIfFalse(6),
                PushI32(2),
                Jump(8),
                PushI32(3),
                Pop,
            ],
            vec![],
            0
        )])
    );

    assert_eq!(
        compile("let a = if 1 { 2; } else { 3; };"),
        Ok(vec![entry(
            vec![
                PushI32(1),
                JumpIfFalse(3),
                PushI32(2),
                Jump(5),
                PushI32(3),
                StoreLocal(0),
            ],
            vec![],
            1
        )])
    );

    assert_eq!(
        compile("loop { break; }"),
        Ok(vec![entry(vec![Jump(1), Jump(0)], vec![], 0)])
    );
}
|