diff --git a/src/checked.rs b/src/checked.rs index c982e2b..6795737 100644 --- a/src/checked.rs +++ b/src/checked.rs @@ -10,7 +10,7 @@ pub enum NodeKind { Error, Id(u64), Int(i64), - String(String), + Str(String), Group(Box), Block(Vec), Call { @@ -56,10 +56,10 @@ pub enum Type { Unit, I32, U32, - String, + Str, Fn { id: u64, - params: Vec, + params: Vec, return_typ: Box, }, } diff --git a/src/checker.rs b/src/checker.rs index abf1931..c055cbe 100644 --- a/src/checker.rs +++ b/src/checker.rs @@ -27,6 +27,7 @@ pub struct Checker { syms: Syms, fn_id_gen: FnIdGen, } + impl Checker { pub fn new() -> Self { Self { @@ -71,7 +72,13 @@ impl Checker { continue; } - let params = params.into_iter().map(|(_, param)| param).collect(); + let params = params + .into_iter() + .map(|(_, param)| { + let NodeKind::Param { subject: _, typ } = param.kind else {unreachable!()}; + typ.as_ref().cloned().unwrap() + }) + .collect::>(); let parsed::Node::Id(id) = subject.as_ref() else { unreachable!() }; @@ -132,7 +139,7 @@ impl Checker { Type::I32 }, ), - parsed::Node::String(value) => self.node(NodeKind::String(value.clone()), Type::String), + parsed::Node::Str(value) => self.node(NodeKind::Str(value.clone()), Type::Str), parsed::Node::Group(expr) => { let expr = self.check_expr(expr); let typ = expr.typ.clone(); @@ -171,8 +178,8 @@ impl Checker { } if args .iter() - .zip(params) - .map(|(arg, param)| self.compatible(&arg.typ, ¶m.typ)) + .zip(params.clone()) + .map(|(arg, param)| self.compatible(&arg.typ, ¶m)) .any(|is_compatible| !is_compatible) { self.error("incorrect args"); @@ -228,13 +235,24 @@ impl Checker { let subject = Box::new(self.check_expr(subject)); let value = Box::new(self.check_expr(value)); - let typ = if !self.compatible(&subject.typ, &value.typ) { + match subject.kind { + NodeKind::Error => { + return *subject; + } + NodeKind::Id(_) => {} + _ => { + self.error("cannot assign to expr"); + return self.node(NodeKind::Error, Type::Error); + } + } + + let _typ = if !self.compatible(&subject.typ, &value.typ) { self.error("incompatible types #3"); Type::Error } else { subject.typ.clone() }; - self.node(NodeKind::Assign { subject, value }, typ) + self.node(NodeKind::Assign { subject, value }, Type::Unit) } parsed::Node::Let { subject, value } => { let (subject, subject_typ) = match subject.as_ref() { @@ -274,11 +292,11 @@ impl Checker { _ => unreachable!(), }; - self.node(NodeKind::Let { subject, value }, typ) + self.node(NodeKind::Let { subject, value }, Type::Unit) } parsed::Node::Fn { subject, - params: _, + params, return_typ: _, body, } => { @@ -289,20 +307,33 @@ impl Checker { return self.node(NodeKind::Error,Type::Error); }; - let Type::Fn { id: fn_id, params, return_typ } = sym.typ else { + let Type::Fn { id: fn_id, params: param_typs, return_typ } = sym.typ else { self.error("redefintion"); return self.node(NodeKind::Error,Type::Error); }; self.syms.enter_scope(); + let params = params + .iter() + .zip(param_typs) + .map(|(param, typ)| { + let parsed::Node::Param { subject, .. } = param else { unreachable!() }; + let parsed::Node::Id(id) = subject.as_ref() else { unreachable!() }; + self.node( + NodeKind::Param { + subject: Box::new(self.node(NodeKind::Id(*id), Type::Unit)), + typ: Some(typ), + }, + Type::Unit, + ) + }) + .collect::>(); + for param in ¶ms { - let NodeKind::Param { - ref subject, - typ: Some(ref typ), - } = param.kind else { unreachable!() }; + let NodeKind::Param { ref subject, ref typ } = param.kind else { unreachable!() }; let NodeKind::Id(id) = subject.kind else { unreachable!() }; - self.syms.define(id, typ.clone()); + self.syms.define(id, typ.as_ref().cloned().unwrap()); } let body = Box::new(self.check_expr(body)); @@ -415,7 +446,7 @@ fn test_checker() { table_id: 0, }) }, - typ: I32, + typ: Unit, table_id: 0, }, Node { @@ -449,7 +480,7 @@ fn test_checker() { table_id: 0, }) }, - typ: I32, + typ: Unit, table_id: 0, }, Node { @@ -464,6 +495,7 @@ fn test_checker() { ] ); + println!("intentionally undefined"); assert_eq!( check("let a = 5; a; { a; let b = 5; b; } a; b;"), vec![ @@ -487,7 +519,7 @@ fn test_checker() { table_id: 0, }) }, - typ: I32, + typ: Unit, table_id: 0, }, Node { @@ -522,7 +554,7 @@ fn test_checker() { table_id: 1, }) }, - typ: I32, + typ: Unit, table_id: 1, }, Node { diff --git a/src/ir.rs b/src/ir.rs index 053428a..1ba88fb 100644 --- a/src/ir.rs +++ b/src/ir.rs @@ -1,27 +1,12 @@ -#[derive(Clone, PartialEq, Debug)] -pub struct Data { - pub kind: DataKind, - pub id: u64, -} - -#[derive(Clone, PartialEq, Debug)] -pub enum DataKind { - U8(Vec), -} - #[derive(Clone, PartialEq, Debug)] pub struct Fn { - pub blocks: Vec, + pub ops: Vec, + pub data: Vec, pub id: u64, pub arg_count: i32, pub local_count: i32, } -#[derive(Clone, PartialEq, Debug)] -pub struct Block { - pub ops: Vec, -} - #[derive(Clone, PartialEq, Debug)] pub enum Op { Pop, @@ -31,8 +16,20 @@ pub enum Op { StoreLocal(i32), LoadLocal(i32), Jump(u64), + JumpIfFalse(u64), JumpIfTrue(u64), CallFn(u64, i32), CallPtr(i32), Return, } + +#[derive(Clone, PartialEq, Debug)] +pub struct Data { + pub kind: DataKind, + pub id: u64, +} + +#[derive(Clone, PartialEq, Debug)] +pub enum DataKind { + U8(Vec), +} diff --git a/src/ir_compiler.rs b/src/ir_compiler.rs index c0b2416..1a06013 100644 --- a/src/ir_compiler.rs +++ b/src/ir_compiler.rs @@ -4,13 +4,12 @@ use std::collections::HashMap; use crate::{ checked::{self, Node, NodeKind}, - ir::{Block, Data, DataKind, Fn, Op}, + ir::{Data, DataKind, Fn, Op}, sym::Syms, }; pub struct Program { pub fns: Vec, - pub data: Vec, pub entry: u64, pub sym: Syms, } @@ -25,10 +24,9 @@ impl Compiler { } pub fn compile(self, ast: &[Node]) -> Result { - let (fns, data) = FnCompiler::new(&self.syms).compile_entry(ast)?; + let fns = FnCompiler::new(&self.syms).compile_entry(ast)?; Ok(Program { fns, - data, sym: self.syms, entry: 0, }) @@ -42,10 +40,11 @@ impl Compiler { struct FnCompiler<'syms> { syms: &'syms Syms, ops: Vec, + data: Vec, + fns: Vec, local_count: i32, local_map: HashMap, - fns: Vec, - data: Vec, + break_stack: Vec, } impl<'syms> FnCompiler<'syms> { @@ -53,35 +52,36 @@ impl<'syms> FnCompiler<'syms> { Self { syms, ops: Vec::new(), - local_count: 0, - local_map: HashMap::new(), fns: Vec::new(), data: Vec::new(), + local_count: 0, + local_map: HashMap::new(), + break_stack: Vec::new(), } } - pub fn compile_entry(mut self, ast: &[Node]) -> Result<(Vec, Vec), ()> { - let mut blocks = Vec::new(); + pub fn compile_entry(mut self, ast: &[Node]) -> Result, ()> { for stmt in ast { - blocks.append(&mut self.compile_stmt(stmt)?); + self.compile_stmt(stmt)?; if stmt.typ != checked::Type::Unit { self.ops.push(Op::Pop); } } self.fns.push(Fn { - blocks, + ops: self.ops, + data: self.data, id: 0, arg_count: 0, local_count: self.local_count, }); - Ok((self.fns, self.data)) + Ok(self.fns) } - pub fn compile(mut self, node: &Node) -> Result { + pub fn compile(mut self, node: &Node) -> Result, ()> { let Node { kind: NodeKind::Fn { - subject: _, + subject, params, return_typ: _, body, @@ -92,69 +92,103 @@ impl<'syms> FnCompiler<'syms> { else { unreachable!() }; + for param in params { + let NodeKind::Param { subject, typ } = ¶m.kind else { unreachable!() }; + let NodeKind::Id(param_id) = subject.kind else { unreachable!() }; + let sym_table = self.syms.view(body.table_id); + let sym = sym_table.get(param_id).unwrap(); + let local_id = self.local_count; + self.local_count += 1; + self.local_map.insert(sym.uid, local_id); + } let NodeKind::Block(body) = &body.kind else { unreachable!() }; - let mut blocks = Vec::new(); for stmt in body { - blocks.append(&mut self.compile_stmt(stmt)?); + self.compile_stmt(stmt)?; if stmt.typ != checked::Type::Unit { - self.ops.push(Op::Pop); + self.push(Op::Pop); } } - self.push_ops(&mut blocks); - Ok(Fn { - blocks, + self.fns.push(Fn { + ops: self.ops, + data: self.data, id: *id, arg_count: params.len() as i32, local_count: self.local_count, - }) + }); + Ok(self.fns) } - pub fn compile_stmt(&mut self, stmt: &Node) -> Result, ()> { + fn compile_stmt(&mut self, stmt: &Node) -> Result<(), ()> { let syms = self.syms.view(stmt.table_id); - let mut blocks = Vec::new(); match &stmt.kind { NodeKind::Error => return Err(()), - NodeKind::Break => todo!(), - NodeKind::Assign { subject, value } => todo!(), - NodeKind::Let { subject, value } => todo!(), - NodeKind::Fn { - subject, - params, - return_typ, - body, - id, - } => todo!(), + NodeKind::Break => { + let addr = self.ops.len(); + self.push(Op::Jump(0)); + self.break_stack.push(addr); + } + NodeKind::Assign { subject, value } => { + self.compile_expr(value)?; + match subject.kind { + NodeKind::Error => { + return Err(()); + } + NodeKind::Id(id) => { + let sym_table = self.syms.view(stmt.table_id); + let sym = sym_table.get(id).unwrap(); + let local_id = self.local_map.get(&sym.uid).unwrap(); + self.push(Op::StoreLocal(*local_id)); + } + _ => unimplemented!(), + } + } + NodeKind::Let { subject, value } => { + let NodeKind::Param { subject, typ: _ } = &subject.kind else { unreachable!() }; + let NodeKind::Id(subject_id) = subject.kind else { unreachable!()}; + let sym_table = self.syms.view(stmt.table_id); + let sym = sym_table.get(subject_id).unwrap(); + let local_id = self.local_count; + self.local_map.insert(sym.uid, local_id); + self.local_count += 1; + self.compile_expr(value)?; + self.push(Op::StoreLocal(local_id)); + } + NodeKind::Fn { subject, .. } => { + let mut compiled_fn = FnCompiler::new(self.syms).compile(stmt)?; + self.fns.append(&mut compiled_fn); + let NodeKind::Id(subject_id) = subject.kind else { unreachable!()}; + let sym_table = self.syms.view(stmt.table_id); + let sym = sym_table.get(subject_id).unwrap(); + self.local_map.insert(sym.uid, self.local_count); + self.local_count += 1; + } NodeKind::Return { value } => { self.ops.push(Op::Return); } _ => { - blocks.append(&mut self.compile_expr(stmt)?); + self.compile_expr(stmt)?; } } - if stmt.typ != checked::Type::Unit { - self.ops.push(Op::Pop); - } - Ok(blocks) + Ok(()) } - fn compile_expr(&mut self, expr: &Node) -> Result, ()> { + fn compile_expr(&mut self, expr: &Node) -> Result<(), ()> { let syms = self.syms.view(expr.table_id); - let mut blocks = Vec::new(); match &expr.kind { NodeKind::Error => return Err(()), NodeKind::Id(id) => { let sym = syms.get(*id).unwrap(); let local_id = self.local_map.get(&sym.uid).unwrap(); - self.ops.push(Op::LoadLocal(*local_id)); + self.push(Op::LoadLocal(*local_id)); } NodeKind::Int(value) => match expr.typ { - checked::Type::I32 => self.ops.push(Op::PushI32(*value as i32)), - checked::Type::U32 => self.ops.push(Op::PushU32(*value as u32)), + checked::Type::I32 => self.push(Op::PushI32(*value as i32)), + checked::Type::U32 => self.push(Op::PushU32(*value as u32)), _ => unreachable!(), }, - NodeKind::String(value) => { + NodeKind::Str(value) => { let id = self.data.len() as u64; let bytes = value.bytes().collect(); self.data.push(Data { @@ -164,13 +198,13 @@ impl<'syms> FnCompiler<'syms> { self.ops.push(Op::PushStaticPtr(id)); } NodeKind::Group(expr) => { - blocks.append(&mut self.compile_expr(expr)?); + self.compile_expr(expr)?; } NodeKind::Block(stmts) => { let mut last_typ = None; for stmt in stmts { if last_typ.filter(|typ| *typ != checked::Type::Unit).is_some() { - self.ops.push(Op::Pop); + self.push(Op::Pop); } last_typ = Some(stmt.typ.clone()); self.compile_stmt(stmt)?; @@ -187,10 +221,10 @@ impl<'syms> FnCompiler<'syms> { params: _, return_typ: _, } => { - self.ops.push(Op::CallFn(id, args.len() as i32)); + self.push(Op::CallFn(id, args.len() as i32)); } _ => { - self.ops.push(Op::CallPtr(args.len() as i32)); + self.push(Op::CallPtr(args.len() as i32)); } } } @@ -200,59 +234,39 @@ impl<'syms> FnCompiler<'syms> { falsy, } => { self.compile_expr(cond)?; - let cond_idx = blocks.len(); - self.push_ops(&mut blocks); - - match falsy { - Some(falsy) => { - let truthy_first_idx = blocks.len() as u64; - let mut truthy = self.compile_expr(truthy)?; - let truthy_last_idx = blocks.len(); - self.push_ops(&mut truthy); - blocks.append(&mut truthy); - - let falsy_first_idx = blocks.len() as u64; - let mut falsy = self.compile_expr(falsy)?; - let falsy_last_idx = blocks.len(); - self.push_ops(&mut falsy); - blocks.append(&mut falsy); - - let after_idx = blocks.len() as u64; - - blocks[cond_idx].ops.push(Op::JumpIfTrue(truthy_first_idx)); - blocks[cond_idx].ops.push(Op::Jump(falsy_first_idx)); - blocks[truthy_last_idx].ops.push(Op::Jump(after_idx)); - blocks[falsy_last_idx].ops.push(Op::Jump(after_idx)); - } - None => { - let truthy_first_idx = blocks.len() as u64; - let mut truthy = self.compile_expr(truthy)?; - let truthy_last_idx = blocks.len(); - self.push_ops(&mut truthy); - blocks.append(&mut truthy); - - let after_idx = blocks.len() as u64; - - blocks[cond_idx].ops.push(Op::JumpIfTrue(truthy_first_idx)); - blocks[cond_idx].ops.push(Op::Jump(after_idx)); - blocks[truthy_last_idx].ops.push(Op::Jump(after_idx)); + let l0 = self.ops.len(); + self.push(Op::JumpIfFalse(0)); + self.compile_expr(truthy)?; + let l1 = self.ops.len(); + if let Op::JumpIfFalse(ref mut addr) = self.ops[l0] { + *addr = l1 as u64; + } + if let Some(falsy) = falsy { + self.push(Op::Jump(0)); + self.compile_expr(falsy)?; + let l2 = self.ops.len(); + if let Op::Jump(ref mut addr) = self.ops[l1] { + *addr = l2 as u64; } } } NodeKind::Loop { body } => { - let body = self.compile_expr(body)?; - let body_idx = blocks.len(); + let l0 = self.ops.len() as u64; + self.compile_expr(body)?; + let l1 = self.ops.len() as u64; + for op in self.break_stack.drain(..) { + let Op::Jump(ref mut addr) = self.ops[op] else { unreachable!() }; + *addr = l1; + } + self.push(Op::Jump(l0)); } _ => unreachable!(), } - self.push_ops(&mut blocks); - Ok(blocks) + Ok(()) } - fn push_ops(&mut self, blocks: &mut Vec) { - let mut ops = Vec::new(); - std::mem::swap(&mut self.ops, &mut ops); - blocks.push(Block { ops }); + fn push(&mut self, op: Op) { + self.ops.push(op); } fn error>(&mut self, msg: S) { @@ -264,9 +278,10 @@ impl<'syms> FnCompiler<'syms> { fn test_compiler() { use crate::checker::{Checker, IdGen}; use crate::parser::Parser; - use pretty_assertions::assert_eq; use Op::*; + use pretty_assertions::assert_eq; + struct SeqIdGen(u64); impl IdGen for SeqIdGen { fn new() -> Self { @@ -285,18 +300,112 @@ fn test_compiler() { let checked = checker.check(&Parser::new(text).parse()); let syms = checker.finish(); let compiled = Compiler::new(syms).compile(&checked); - compiled + compiled.map(|program| program.fns) }; assert_eq!( - compile("123;").map(|program| program.fns), + compile("fn test(a: i32) -> i32 { a; } test(123);"), + Ok(vec![ + Fn { + ops: vec![LoadLocal(0), Pop], + data: vec![], + id: 0, + arg_count: 1, + local_count: 1 + }, + Fn { + ops: vec![LoadLocal(0), PushI32(123), CallFn(0, 1), Pop], + data: vec![], + id: 0, + arg_count: 0, + local_count: 1 + } + ]) + ); + + assert_eq!( + compile("let a = 5; let b = 3; a; b; a = b;"), Ok(vec![Fn { - blocks: vec![Block { - ops: vec![PushI32(123)] + ops: vec![ + PushI32(5), + StoreLocal(0), + PushI32(3), + StoreLocal(1), + LoadLocal(0), + Pop, + LoadLocal(1), + Pop, + LoadLocal(1), + StoreLocal(0), + ], + data: vec![], + id: 0, + arg_count: 0, + local_count: 2, + }]) + ); + + assert_eq!( + compile("let a = \"hello\";"), + Ok(vec![Fn { + ops: vec![PushStaticPtr(0), StoreLocal(0)], + data: vec![Data { + kind: DataKind::U8(vec![104, 101, 108, 108, 111]), + id: 0 }], id: 0, arg_count: 0, + local_count: 1 + }]) + ); + + assert_eq!( + compile("if 1 { 2; } if 1 { 2; } else { 3; }"), + Ok(vec![Fn { + ops: vec![ + PushI32(1), + JumpIfFalse(3), + PushI32(2), + PushI32(1), + JumpIfFalse(6), + PushI32(2), + Jump(8), + PushI32(3), + Pop, + ], + data: vec![], + id: 0, + arg_count: 0, local_count: 0 }]) ); + + assert_eq!( + compile("let a = if 1 { 2; } else { 3; };"), + Ok(vec![Fn { + ops: vec![ + PushI32(1), + JumpIfFalse(3), + PushI32(2), + Jump(5), + PushI32(3), + StoreLocal(0), + ], + data: vec![], + id: 0, + arg_count: 0, + local_count: 1 + }]) + ); + + assert_eq!( + compile("loop { break; }"), + Ok(vec![Fn { + ops: vec![Jump(1), Jump(0)], + data: vec![], + id: 0, + arg_count: 0, + local_count: 0, + },]) + ); } diff --git a/src/lexer.rs b/src/lexer.rs index cbce116..bef64de 100644 --- a/src/lexer.rs +++ b/src/lexer.rs @@ -114,7 +114,7 @@ impl<'a> Lexer<'a> { Some('"') => { self.step(); break self - .token_with_value(TokenKind::String, TokenValue::String(value)); + .token_with_value(TokenKind::Str, TokenValue::Str(value)); } Some(ch) => { value.push(ch); @@ -261,18 +261,18 @@ fn test_lexer() { assert_eq!(lex("abc"), vec![(TK::Id, TV::Id(hash("abc")))]); assert_eq!(lex("123"), vec![(TK::Int, TV::Int(123))]); - assert_eq!(lex("\"\""), vec![(TK::String, TV::String("".to_string()))]); + assert_eq!(lex("\"\""), vec![(TK::Str, TV::Str("".to_string()))]); assert_eq!( lex("\"hello\""), - vec![(TK::String, TV::String("hello".to_string()))] + vec![(TK::Str, TV::Str("hello".to_string()))] ); assert_eq!( lex("\"new\\nline\""), - vec![(TK::String, TV::String("new\nline".to_string()))] + vec![(TK::Str, TV::Str("new\nline".to_string()))] ); assert_eq!( lex("\"backslash\\\\\""), - vec![(TK::String, TV::String("backslash\\".to_string()))] + vec![(TK::Str, TV::Str("backslash\\".to_string()))] ); assert_eq!(lex("->"), vec![(TK::MinusLt, TV::None)]); assert_eq!(lex("let"), vec![(TK::Let, TV::None)]); diff --git a/src/parsed.rs b/src/parsed.rs index 7074005..23ee9b3 100644 --- a/src/parsed.rs +++ b/src/parsed.rs @@ -3,7 +3,7 @@ pub enum Node { Error, Id(u64), Int(i64), - String(String), + Str(String), Group(Box), Block(Vec), Call { diff --git a/src/parser.rs b/src/parser.rs index 5b11b25..1bd41cd 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -111,10 +111,6 @@ impl<'a> Parser<'a> { self.step(); let mut params = Vec::new(); if !self.curr_is(TokenKind::RParen) { - if !self.curr_is(TokenKind::RParen) { - self.error("expected ')'"); - return Err(Node::Error); - } if !self.curr_is(TokenKind::Id) { self.error("expected id"); return Err(Node::Error); @@ -158,7 +154,7 @@ impl<'a> Parser<'a> { fn parse_param(&mut self) -> Node { let subject = Box::new(self.parse_id()); - let typ = if let Some(TokenKind::Comma) = self.curr_kind() { + let typ = if let Some(TokenKind::Colon) = self.curr_kind() { self.step(); Some(Box::new(self.parse_typ())) } else { @@ -239,7 +235,7 @@ impl<'a> Parser<'a> { match self.curr_kind() { Some(TokenKind::Id) => self.parse_id(), Some(TokenKind::Int) => self.parse_int(), - Some(TokenKind::String) => self.parse_string(), + Some(TokenKind::Str) => self.parse_string(), Some(TokenKind::LParen) => self.parse_group(), Some(TokenKind::LBrace) => self.parse_block(), Some(TokenKind::If) => self.parse_if(), @@ -280,15 +276,15 @@ impl<'a> Parser<'a> { fn parse_string(&mut self) -> Node { let Some(Token { - kind: TokenKind::String, - value: TokenValue::String(value), + kind: TokenKind::Str, + value: TokenValue::Str(value), .. }) = self.current.clone() else { unreachable!() }; self.step(); - Node::String(value.clone()) + Node::Str(value.clone()) } fn parse_group(&mut self) -> Node { @@ -307,6 +303,10 @@ impl<'a> Parser<'a> { let mut stmts = Vec::new(); loop { match self.curr_kind() { + None => { + self.error("expected ')'"); + break Node::Error; + } Some(TokenKind::RBrace) => { self.step(); break Node::Block(stmts); @@ -392,7 +392,7 @@ fn test_parser() { assert_eq!(Parser::new("123;").parse(), vec![Int(123)]); assert_eq!( Parser::new("\"hello\";").parse(), - vec![String("hello".to_string())] + vec![Str("hello".to_string())] ); assert_eq!(Parser::new("0;").parse(), vec![Int(0)]); assert_eq!(Parser::new("0;abc;").parse(), vec![Int(0), Id(hash("abc"))]); diff --git a/src/token.rs b/src/token.rs index a84f7c2..17f0f75 100644 --- a/src/token.rs +++ b/src/token.rs @@ -12,7 +12,7 @@ pub enum TokenKind { Error, Id, Int, - String, + Str, If, Else, Loop, @@ -36,5 +36,5 @@ pub enum TokenValue { None, Id(u64), Int(i64), - String(String), + Str(String), }