#![allow(unused_variables)] use std::collections::HashMap; use crate::{ checked::{self, Node, NodeKind}, ir::{Data, DataKind, Fn, Op}, sym::Syms, }; pub struct Program { pub fns: Vec, pub entry: u64, pub sym: Syms, } struct Compiler { syms: Syms, } impl Compiler { pub fn new(syms: Syms) -> Self { Self { syms } } pub fn compile(self, ast: &[Node]) -> Result { let fns = FnCompiler::new(&self.syms).compile_entry(ast)?; Ok(Program { fns, sym: self.syms, entry: 0, }) } fn error>(&mut self, msg: S) { println!("ir compiler: {}", msg.into()); } } struct FnCompiler<'syms> { syms: &'syms Syms, ops: Vec, data: Vec, fns: Vec, local_count: i32, local_map: HashMap, break_stack: Vec, } impl<'syms> FnCompiler<'syms> { pub fn new(syms: &'syms Syms) -> Self { Self { syms, ops: Vec::new(), fns: Vec::new(), data: Vec::new(), local_count: 0, local_map: HashMap::new(), break_stack: Vec::new(), } } pub fn compile_entry(mut self, ast: &[Node]) -> Result, ()> { for stmt in ast { self.compile_stmt(stmt)?; if stmt.typ != checked::Type::Unit { self.ops.push(Op::Pop); } } self.fns.push(Fn { ops: self.ops, data: self.data, id: 0, arg_count: 0, local_count: self.local_count, }); Ok(self.fns) } pub fn compile(mut self, node: &Node) -> Result, ()> { let Node { kind: NodeKind::Fn { subject, params, return_typ: _, body, id, }, .. } = node else { unreachable!() }; for param in params { let NodeKind::Param { subject, typ } = ¶m.kind else { unreachable!() }; let NodeKind::Id(param_id) = subject.kind else { unreachable!() }; let sym_table = self.syms.view(body.table_id); let sym = sym_table.get(param_id).unwrap(); let local_id = self.local_count; self.local_count += 1; self.local_map.insert(sym.uid, local_id); } let NodeKind::Block(body) = &body.kind else { unreachable!() }; for stmt in body { self.compile_stmt(stmt)?; if stmt.typ != checked::Type::Unit { self.push(Op::Pop); } } self.fns.push(Fn { ops: self.ops, data: self.data, id: *id, arg_count: params.len() as i32, local_count: self.local_count, }); Ok(self.fns) } fn compile_stmt(&mut self, stmt: &Node) -> Result<(), ()> { let syms = self.syms.view(stmt.table_id); match &stmt.kind { NodeKind::Error => return Err(()), NodeKind::Break => { let addr = self.ops.len(); self.push(Op::Jump(0)); self.break_stack.push(addr); } NodeKind::Assign { subject, value } => { self.compile_expr(value)?; match subject.kind { NodeKind::Error => { return Err(()); } NodeKind::Id(id) => { let sym_table = self.syms.view(stmt.table_id); let sym = sym_table.get(id).unwrap(); let local_id = self.local_map.get(&sym.uid).unwrap(); self.push(Op::StoreLocal(*local_id)); } _ => unimplemented!(), } } NodeKind::Let { subject, value } => { let NodeKind::Param { subject, typ: _ } = &subject.kind else { unreachable!() }; let NodeKind::Id(subject_id) = subject.kind else { unreachable!()}; let sym_table = self.syms.view(stmt.table_id); let sym = sym_table.get(subject_id).unwrap(); let local_id = self.local_count; self.local_map.insert(sym.uid, local_id); self.local_count += 1; self.compile_expr(value)?; self.push(Op::StoreLocal(local_id)); } NodeKind::Fn { subject, .. } => { let mut compiled_fn = FnCompiler::new(self.syms).compile(stmt)?; self.fns.append(&mut compiled_fn); let NodeKind::Id(subject_id) = subject.kind else { unreachable!()}; let sym_table = self.syms.view(stmt.table_id); let sym = sym_table.get(subject_id).unwrap(); self.local_map.insert(sym.uid, self.local_count); self.local_count += 1; } NodeKind::Return { value } => { self.ops.push(Op::Return); } _ => { self.compile_expr(stmt)?; } } Ok(()) } fn compile_expr(&mut self, expr: &Node) -> Result<(), ()> { let syms = self.syms.view(expr.table_id); match &expr.kind { NodeKind::Error => return Err(()), NodeKind::Id(id) => { let sym = syms.get(*id).unwrap(); let local_id = self.local_map.get(&sym.uid).unwrap(); self.push(Op::LoadLocal(*local_id)); } NodeKind::Int(value) => match expr.typ { checked::Type::I32 => self.push(Op::PushI32(*value as i32)), checked::Type::U32 => self.push(Op::PushU32(*value as u32)), _ => unreachable!(), }, NodeKind::Str(value) => { let id = self.data.len() as u64; let bytes = value.bytes().collect(); self.data.push(Data { kind: DataKind::U8(bytes), id, }); self.ops.push(Op::PushStaticPtr(id)); } NodeKind::Group(expr) => { self.compile_expr(expr)?; } NodeKind::Block(stmts) => { let mut last_typ = None; for stmt in stmts { if last_typ.filter(|typ| *typ != checked::Type::Unit).is_some() { self.push(Op::Pop); } last_typ = Some(stmt.typ.clone()); self.compile_stmt(stmt)?; } } NodeKind::Call { subject, args } => { self.compile_expr(subject)?; for arg in args { self.compile_expr(arg)?; } match subject.typ { checked::Type::Fn { id, params: _, return_typ: _, } => { self.push(Op::CallFn(id, args.len() as i32)); } _ => { self.push(Op::CallPtr(args.len() as i32)); } } } NodeKind::If { cond, truthy, falsy, } => { self.compile_expr(cond)?; let l0 = self.ops.len(); self.push(Op::JumpIfFalse(0)); self.compile_expr(truthy)?; let l1 = self.ops.len(); if let Op::JumpIfFalse(ref mut addr) = self.ops[l0] { *addr = l1 as u64; } if let Some(falsy) = falsy { self.push(Op::Jump(0)); self.compile_expr(falsy)?; let l2 = self.ops.len(); if let Op::Jump(ref mut addr) = self.ops[l1] { *addr = l2 as u64; } } } NodeKind::Loop { body } => { let l0 = self.ops.len() as u64; self.compile_expr(body)?; let l1 = self.ops.len() as u64; for op in self.break_stack.drain(..) { let Op::Jump(ref mut addr) = self.ops[op] else { unreachable!() }; *addr = l1; } self.push(Op::Jump(l0)); } _ => unreachable!(), } Ok(()) } fn push(&mut self, op: Op) { self.ops.push(op); } fn error>(&mut self, msg: S) { println!("ir compiler: {}", msg.into()); } } #[test] fn test_compiler() { use crate::checker::{Checker, IdGen}; use crate::parser::Parser; use Op::*; use pretty_assertions::assert_eq; struct SeqIdGen(u64); impl IdGen for SeqIdGen { fn new() -> Self { Self(0) } fn gen(&mut self) -> u64 { let v = self.0; self.0 += 1; v } } let compile = |text| { let mut checker = Checker::::new_with_fn_id_gen(); let checked = checker.check(&Parser::new(text).parse()); let syms = checker.finish(); let compiled = Compiler::new(syms).compile(&checked); compiled.map(|program| program.fns) }; assert_eq!( compile("fn test(a: i32) -> i32 { a; } test(123);"), Ok(vec![ Fn { ops: vec![LoadLocal(0), Pop], data: vec![], id: 0, arg_count: 1, local_count: 1 }, Fn { ops: vec![LoadLocal(0), PushI32(123), CallFn(0, 1), Pop], data: vec![], id: 0, arg_count: 0, local_count: 1 } ]) ); assert_eq!( compile("let a = 5; let b = 3; a; b; a = b;"), Ok(vec![Fn { ops: vec![ PushI32(5), StoreLocal(0), PushI32(3), StoreLocal(1), LoadLocal(0), Pop, LoadLocal(1), Pop, LoadLocal(1), StoreLocal(0), ], data: vec![], id: 0, arg_count: 0, local_count: 2, }]) ); assert_eq!( compile("let a = \"hello\";"), Ok(vec![Fn { ops: vec![PushStaticPtr(0), StoreLocal(0)], data: vec![Data { kind: DataKind::U8(vec![104, 101, 108, 108, 111]), id: 0 }], id: 0, arg_count: 0, local_count: 1 }]) ); assert_eq!( compile("if 1 { 2; } if 1 { 2; } else { 3; }"), Ok(vec![Fn { ops: vec![ PushI32(1), JumpIfFalse(3), PushI32(2), PushI32(1), JumpIfFalse(6), PushI32(2), Jump(8), PushI32(3), Pop, ], data: vec![], id: 0, arg_count: 0, local_count: 0 }]) ); assert_eq!( compile("let a = if 1 { 2; } else { 3; };"), Ok(vec![Fn { ops: vec![ PushI32(1), JumpIfFalse(3), PushI32(2), Jump(5), PushI32(3), StoreLocal(0), ], data: vec![], id: 0, arg_count: 0, local_count: 1 }]) ); assert_eq!( compile("loop { break; }"), Ok(vec![Fn { ops: vec![Jump(1), Jump(0)], data: vec![], id: 0, arg_count: 0, local_count: 0, },]) ); }