diff --git a/src/ir.rs b/src/ir.rs index aed2e30..053428a 100644 --- a/src/ir.rs +++ b/src/ir.rs @@ -1,5 +1,11 @@ #[derive(Clone, PartialEq, Debug)] -pub enum Data { +pub struct Data { + pub kind: DataKind, + pub id: u64, +} + +#[derive(Clone, PartialEq, Debug)] +pub enum DataKind { U8(Vec), } @@ -26,6 +32,7 @@ pub enum Op { LoadLocal(i32), Jump(u64), JumpIfTrue(u64), - Call(u64, i32), + CallFn(u64, i32), + CallPtr(i32), Return, } diff --git a/src/ir_compiler.rs b/src/ir_compiler.rs index b4824fb..c0b2416 100644 --- a/src/ir_compiler.rs +++ b/src/ir_compiler.rs @@ -4,104 +4,255 @@ use std::collections::HashMap; use crate::{ checked::{self, Node, NodeKind}, - ir::{Block, Fn, Op}, + ir::{Block, Data, DataKind, Fn, Op}, sym::Syms, }; -pub struct IrProgram { +pub struct Program { pub fns: Vec, + pub data: Vec, pub entry: u64, pub sym: Syms, } struct Compiler { syms: Syms, - fns: Vec, } impl Compiler { pub fn new(syms: Syms) -> Self { - Self { - syms, - fns: Vec::new(), - } + Self { syms } } - pub fn compile(mut self, ast: &[Node]) -> Result { - let (blocks, local_count) = self.compile_fn_body(ast)?; - let entry_fn = Fn { - blocks, - id: 0, - arg_count: 0, - local_count, - }; - self.fns.push(entry_fn); - Ok(IrProgram { - fns: self.fns, + pub fn compile(self, ast: &[Node]) -> Result { + let (fns, data) = FnCompiler::new(&self.syms).compile_entry(ast)?; + Ok(Program { + fns, + data, sym: self.syms, entry: 0, }) } - fn compile_fn(&mut self, node: &Node) -> Result { - let Node { kind: NodeKind::Fn { subject: _, params, return_typ: _, body, id }, .. } = node else { + fn error>(&mut self, msg: S) { + println!("ir compiler: {}", msg.into()); + } +} + +struct FnCompiler<'syms> { + syms: &'syms Syms, + ops: Vec, + local_count: i32, + local_map: HashMap, + fns: Vec, + data: Vec, +} + +impl<'syms> FnCompiler<'syms> { + pub fn new(syms: &'syms Syms) -> Self { + Self { + syms, + ops: Vec::new(), + local_count: 0, + local_map: HashMap::new(), + fns: Vec::new(), + data: Vec::new(), + } + } + + pub fn compile_entry(mut self, ast: &[Node]) -> Result<(Vec, Vec), ()> { + let mut blocks = Vec::new(); + for stmt in ast { + blocks.append(&mut self.compile_stmt(stmt)?); + if stmt.typ != checked::Type::Unit { + self.ops.push(Op::Pop); + } + } + self.fns.push(Fn { + blocks, + id: 0, + arg_count: 0, + local_count: self.local_count, + }); + Ok((self.fns, self.data)) + } + + pub fn compile(mut self, node: &Node) -> Result { + let Node { + kind: + NodeKind::Fn { + subject: _, + params, + return_typ: _, + body, + id, + }, + .. + } = node + else { unreachable!() }; - let NodeKind::Block(body) = &body.kind else { unreachable!() }; - let (blocks, local_count) = self.compile_fn_body(body)?; + let NodeKind::Block(body) = &body.kind else { + unreachable!() + }; + let mut blocks = Vec::new(); + for stmt in body { + blocks.append(&mut self.compile_stmt(stmt)?); + if stmt.typ != checked::Type::Unit { + self.ops.push(Op::Pop); + } + } + self.push_ops(&mut blocks); Ok(Fn { blocks, id: *id, arg_count: params.len() as i32, - local_count, + local_count: self.local_count, }) } - fn compile_fn_body(&mut self, body: &[Node]) -> Result<(Vec, i32), ()> { + pub fn compile_stmt(&mut self, stmt: &Node) -> Result, ()> { + let syms = self.syms.view(stmt.table_id); let mut blocks = Vec::new(); - let mut ops = Vec::::new(); - let mut local_count = 0; - let mut local_map = HashMap::::new(); - for node in body { - let syms = self.syms.view(node.table_id); - match &node.kind { - NodeKind::Error => return Err(()), - NodeKind::Id(id) => { - let sym = syms.get(*id).unwrap(); - let local_id = local_map.get(&sym.uid).unwrap(); - ops.push(Op::LoadLocal(*local_id)); - } - NodeKind::Int(value) => match node.typ { - checked::Type::I32 => ops.push(Op::PushI32(*value as i32)), - checked::Type::U32 => ops.push(Op::PushU32(*value as u32)), - _ => unreachable!(), - }, - NodeKind::String(_) => todo!(), - NodeKind::Group(_) => todo!(), - NodeKind::Block(_) => todo!(), - NodeKind::Call { subject, args } => todo!(), - NodeKind::If { - cond, - truthy, - falsy, - } => todo!(), - NodeKind::Loop { body } => todo!(), - NodeKind::Break => todo!(), - NodeKind::Assign { subject, value } => todo!(), - NodeKind::Let { subject, value } => todo!(), - NodeKind::Fn { - subject, - params, - return_typ, - body, - id, - } => todo!(), - NodeKind::Return { value } => todo!(), - NodeKind::Param { subject, typ } => todo!(), + match &stmt.kind { + NodeKind::Error => return Err(()), + NodeKind::Break => todo!(), + NodeKind::Assign { subject, value } => todo!(), + NodeKind::Let { subject, value } => todo!(), + NodeKind::Fn { + subject, + params, + return_typ, + body, + id, + } => todo!(), + NodeKind::Return { value } => { + self.ops.push(Op::Return); + } + _ => { + blocks.append(&mut self.compile_expr(stmt)?); } } + if stmt.typ != checked::Type::Unit { + self.ops.push(Op::Pop); + } + Ok(blocks) + } + + fn compile_expr(&mut self, expr: &Node) -> Result, ()> { + let syms = self.syms.view(expr.table_id); + let mut blocks = Vec::new(); + match &expr.kind { + NodeKind::Error => return Err(()), + NodeKind::Id(id) => { + let sym = syms.get(*id).unwrap(); + let local_id = self.local_map.get(&sym.uid).unwrap(); + self.ops.push(Op::LoadLocal(*local_id)); + } + NodeKind::Int(value) => match expr.typ { + checked::Type::I32 => self.ops.push(Op::PushI32(*value as i32)), + checked::Type::U32 => self.ops.push(Op::PushU32(*value as u32)), + _ => unreachable!(), + }, + NodeKind::String(value) => { + let id = self.data.len() as u64; + let bytes = value.bytes().collect(); + self.data.push(Data { + kind: DataKind::U8(bytes), + id, + }); + self.ops.push(Op::PushStaticPtr(id)); + } + NodeKind::Group(expr) => { + blocks.append(&mut self.compile_expr(expr)?); + } + NodeKind::Block(stmts) => { + let mut last_typ = None; + for stmt in stmts { + if last_typ.filter(|typ| *typ != checked::Type::Unit).is_some() { + self.ops.push(Op::Pop); + } + last_typ = Some(stmt.typ.clone()); + self.compile_stmt(stmt)?; + } + } + NodeKind::Call { subject, args } => { + self.compile_expr(subject)?; + for arg in args { + self.compile_expr(arg)?; + } + match subject.typ { + checked::Type::Fn { + id, + params: _, + return_typ: _, + } => { + self.ops.push(Op::CallFn(id, args.len() as i32)); + } + _ => { + self.ops.push(Op::CallPtr(args.len() as i32)); + } + } + } + NodeKind::If { + cond, + truthy, + falsy, + } => { + self.compile_expr(cond)?; + let cond_idx = blocks.len(); + self.push_ops(&mut blocks); + + match falsy { + Some(falsy) => { + let truthy_first_idx = blocks.len() as u64; + let mut truthy = self.compile_expr(truthy)?; + let truthy_last_idx = blocks.len(); + self.push_ops(&mut truthy); + blocks.append(&mut truthy); + + let falsy_first_idx = blocks.len() as u64; + let mut falsy = self.compile_expr(falsy)?; + let falsy_last_idx = blocks.len(); + self.push_ops(&mut falsy); + blocks.append(&mut falsy); + + let after_idx = blocks.len() as u64; + + blocks[cond_idx].ops.push(Op::JumpIfTrue(truthy_first_idx)); + blocks[cond_idx].ops.push(Op::Jump(falsy_first_idx)); + blocks[truthy_last_idx].ops.push(Op::Jump(after_idx)); + blocks[falsy_last_idx].ops.push(Op::Jump(after_idx)); + } + None => { + let truthy_first_idx = blocks.len() as u64; + let mut truthy = self.compile_expr(truthy)?; + let truthy_last_idx = blocks.len(); + self.push_ops(&mut truthy); + blocks.append(&mut truthy); + + let after_idx = blocks.len() as u64; + + blocks[cond_idx].ops.push(Op::JumpIfTrue(truthy_first_idx)); + blocks[cond_idx].ops.push(Op::Jump(after_idx)); + blocks[truthy_last_idx].ops.push(Op::Jump(after_idx)); + } + } + } + NodeKind::Loop { body } => { + let body = self.compile_expr(body)?; + let body_idx = blocks.len(); + } + _ => unreachable!(), + } + self.push_ops(&mut blocks); + Ok(blocks) + } + + fn push_ops(&mut self, blocks: &mut Vec) { + let mut ops = Vec::new(); + std::mem::swap(&mut self.ops, &mut ops); blocks.push(Block { ops }); - Ok((blocks, local_count)) } fn error>(&mut self, msg: S) { @@ -110,7 +261,7 @@ impl Compiler { } #[test] -fn test_checker() { +fn test_compiler() { use crate::checker::{Checker, IdGen}; use crate::parser::Parser; use pretty_assertions::assert_eq; @@ -134,11 +285,11 @@ fn test_checker() { let checked = checker.check(&Parser::new(text).parse()); let syms = checker.finish(); let compiled = Compiler::new(syms).compile(&checked); - compiled.map(|program| program.fns) + compiled }; assert_eq!( - compile("123;"), + compile("123;").map(|program| program.fns), Ok(vec![Fn { blocks: vec![Block { ops: vec![PushI32(123)]