from tokens import Token, TokenType, TokenIterator
from position import Position, Span, Node
from parsed import (
    Assign,
    AssignType,
    Binary,
    BinaryType,
    Block,
    Break,
    Call,
    Continue,
    Match,
    MatchArm,
    PatternError,
    Expr,
    For,
    Id,
    If,
    Index,
    Int,
    Char,
    Loop,
    Pattern,
    Return,
    String,
    ExprError,
    StructMember,
    Tuple,
    TupleMember,
    Unary,
    UnaryType,
    Unit,
    While,
)
from typing import Callable, Dict, List, Optional
from utils import Result, Ok, Err, unescape_string


class Parser:
    """Recursive-descent parser turning a token stream into expression nodes.

    Binding strength, loosest to tightest:
    assignment (statement level only), ``or``, ``and``, ``==``/``!=``,
    ``<``/``>``/``<=``/``>=``, ``+``/``-``, ``*``/``/``/``%``, unary ``-``,
    ``**`` (right-associative), prefix ``not``/``*``/``&``/``&mut``,
    postfix member access / indexing / calls, and finally operands.

    Syntax errors are reported by returning ``ExprError`` nodes rather than
    by raising.
    """

    def __init__(self, text: str, tokens: TokenIterator) -> None:
        """Create a parser over ``tokens``.

        Args:
            text: Source text; token spans are sliced out of it to obtain
                identifier names and literal values.
            tokens: Token source; the first token is read immediately.
        """
        self.text = text
        self.tokens = tokens
        self.current_token = tokens.next()

    def parse(self) -> List[Node[Expr]]:
        """Parse statements until end of input and return them in order."""
        statements: List[Node[Expr]] = []
        while not self.done():
            statements.append(self.parse_statement())
        return statements

    def parse_statement(self) -> Node[Expr]:
        """Parse one statement: an ``if`` or an (optionally assigning) expression."""
        if self.current_is(TokenType.KwIf):
            return self.parse_if()
        return self.parse_assign()

    def parse_assign(self) -> Node[Expr]:
        """Parse ``expr`` or ``expr = expr``; the latter becomes an Assign node."""
        subject = self.parse_expr()
        if not self.current_is(TokenType.Equal):
            return subject
        self.step()
        value = self.parse_expr()
        return Node(Assign(AssignType.Assign, subject, value), subject.span.to(value.span))

    def parse_expr(self) -> Node[Expr]:
        """Parse a full expression, starting at the loosest binary level (``or``)."""
        return self.parse_or()

    def _parse_binary_left(
        self,
        parse_next: Callable[[], Node[Expr]],
        operators: Dict[TokenType, BinaryType],
    ) -> Node[Expr]:
        """Parse a left-associative chain ``operand (op operand)*``.

        Args:
            parse_next: Parser for the next-tighter precedence level.
            operators: Maps each operator token accepted at this level to the
                BinaryType it produces.

        Returns:
            The operand alone, or a left-nested tree of Binary nodes.
        """
        left = parse_next()
        while not self.done():
            binary_type = operators.get(self.current().token_type)
            if binary_type is None:
                break
            self.step()
            right = parse_next()
            left = Node(Binary(binary_type, left, right), left.span.to(right.span))
        return left

    def parse_or(self) -> Node[Expr]:
        """Parse ``a or b or ...`` (left-associative)."""
        return self._parse_binary_left(self.parse_and, {TokenType.KwOr: BinaryType.Or})

    def parse_and(self) -> Node[Expr]:
        """Parse ``a and b and ...`` (left-associative)."""
        return self._parse_binary_left(self.parse_equal, {TokenType.KwAnd: BinaryType.And})

    def parse_equal(self) -> Node[Expr]:
        """Parse equality chains using ``==`` and ``!=``."""
        return self._parse_binary_left(
            self.parse_compare,
            {
                TokenType.EqualEqual: BinaryType.Equal,
                TokenType.ExclamationEqual: BinaryType.Inequal,
            },
        )

    def parse_compare(self) -> Node[Expr]:
        """Parse comparison chains using ``<``, ``>``, ``<=`` and ``>=``."""
        return self._parse_binary_left(
            self.parse_add_subtract,
            {
                TokenType.LT: BinaryType.LT,
                TokenType.GT: BinaryType.GT,
                TokenType.LTEqual: BinaryType.LTEqual,
                TokenType.GTEqual: BinaryType.GTEqual,
            },
        )

    def parse_add_subtract(self) -> Node[Expr]:
        """Parse additive chains using ``+`` and ``-``."""
        return self._parse_binary_left(
            self.parse_multiply_divide_modulo,
            {
                TokenType.Plus: BinaryType.Add,
                TokenType.Minus: BinaryType.Subtract,
            },
        )

    def parse_multiply_divide_modulo(self) -> Node[Expr]:
        """Parse multiplicative chains using ``*``, ``/`` and ``%``."""
        return self._parse_binary_left(
            self.parse_negate,
            {
                TokenType.Asterisk: BinaryType.Multiply,
                TokenType.Slash: BinaryType.Divide,
                TokenType.Percent: BinaryType.Modulo,
            },
        )

    def parse_negate(self) -> Node[Expr]:
        """Parse an optional single leading ``-``.

        The ``-`` applies to a whole exponent expression, so ``-a ** b``
        parses as ``-(a ** b)``.
        """
        if not self.current_is(TokenType.Minus):
            return self.parse_exponent()
        token_span = self.current().span
        self.step()
        subject = self.parse_exponent()
        return Node(Unary(UnaryType.Negate, subject), token_span.to(subject.span))

    def parse_exponent(self) -> Node[Expr]:
        """Parse ``a ** b``, which is right-associative (``a ** b ** c`` == ``a ** (b ** c)``)."""
        left = self.parse_unary()
        if not self.current_is(TokenType.AsteriskAsterisk):
            return left
        self.step()
        # Recursing on the right operand (not looping) makes ** right-associative.
        right = self.parse_exponent()
        return Node(Binary(BinaryType.Exponent, left, right), left.span.to(right.span))

    def parse_unary(self) -> Node[Expr]:
        """Parse prefix ``not``, ``*`` (dereference), ``&`` and ``&mut`` (references)."""
        if self.current_is(TokenType.KwNot):
            unary_type = UnaryType.Not
        elif self.current_is(TokenType.Asterisk):
            unary_type = UnaryType.Dereference
        elif self.current_is(TokenType.Ampersand):
            unary_type = UnaryType.Reference
        else:
            return self.parse_member_index_call()
        token_span = self.current().span
        self.step()
        if unary_type == UnaryType.Reference and self.current_is(TokenType.KwMut):
            # '&mut x' is a mutable reference; a bare '&x' is a shared one.
            self.step()
            unary_type = UnaryType.ReferenceMut
        subject = self.parse_unary()
        return Node(Unary(unary_type, subject), token_span.to(subject.span))

    def parse_member_index_call(self) -> Node[Expr]:
        """Parse an operand followed by any chain of postfix operations.

        Supported postfixes: ``.name`` (StructMember), ``.0`` (TupleMember),
        ``[expr]`` (Index) and ``(arg, ...)`` (Call; a trailing comma is
        allowed). On a malformed postfix an ExprError is returned immediately.
        """
        subject: Node[Expr] = self.parse_operand()
        while True:
            if self.current_is(TokenType.Dot):
                self.step()
                if self.current_is(TokenType.Id):
                    id_token = self.current()
                    self.step()
                    name = id_token.text_slice(self.text)
                    subject = Node(StructMember(subject, name), subject.span.to(id_token.span))
                elif self.current_is(TokenType.Int):
                    int_token = self.current()
                    self.step()
                    member_index = int(int_token.text_slice(self.text))
                    subject = Node(TupleMember(subject, member_index), subject.span.to(int_token.span))
                else:
                    return Node(ExprError("expected Int or Id"), subject.span)
            elif self.current_is(TokenType.LBracket):
                self.step()
                value = self.parse_expr()
                if not self.current_is(TokenType.RBracket):
                    return Node(ExprError("expected ']'"), subject.span.to(value.span))
                rbracket_token_span = self.current().span
                self.step()
                subject = Node(Index(subject, value), subject.span.to(rbracket_token_span))
            elif self.current_is(TokenType.LParen):
                self.step()
                arguments: List[Node[Expr]] = []
                if not self.done() and not self.current_is(TokenType.RParen):
                    arguments.append(self.parse_expr())
                    while self.current_is(TokenType.Comma):
                        self.step()
                        # A comma directly before ')' is a permitted trailing comma.
                        if self.done() or self.current_is(TokenType.RParen):
                            break
                        arguments.append(self.parse_expr())
                if not self.current_is(TokenType.RParen):
                    end = arguments[-1].span if arguments else subject.span
                    return Node(ExprError("expected ')'"), subject.span.to(end))
                end = self.current().span
                self.step()
                subject = Node(Call(subject, arguments), subject.span.to(end))
            else:
                return subject

    def parse_operand(self) -> Node[Expr]:
        """Parse a primary operand.

        Accepts identifiers, int/char/string literals, ``()``/groups/tuples,
        blocks, ``if``, ``loop``, ``while`` and ``for``. Any other token is
        consumed (unless at end of input) and reported as "expected value".
        """
        # NOTE(review): parse_match is not dispatched from here, so 'match'
        # expressions are currently unreachable; add a branch once the
        # keyword's TokenType name is confirmed.
        if self.current_is(TokenType.Id):
            token = self.current()
            self.step()
            return Node(Id(token.text_slice(self.text)), token.span)
        elif self.current_is(TokenType.Int):
            token = self.current()
            self.step()
            return Node(Int(int(token.text_slice(self.text))), token.span)
        elif self.current_is(TokenType.Char):
            token = self.current()
            self.step()
            # [1:-1] strips the surrounding quote characters before unescaping.
            value = unescape_string(token.text_slice(self.text)[1:-1])
            if not value.ok():
                return Node(ExprError(value.error()), token.span)
            return Node(Char(value.value()), token.span)
        elif self.current_is(TokenType.String):
            token = self.current()
            self.step()
            value = unescape_string(token.text_slice(self.text)[1:-1])
            if not value.ok():
                return Node(ExprError(value.error()), token.span)
            return Node(String(value.value()), token.span)
        elif self.current_is(TokenType.LParen):
            return self.parse_unit_group_tuple()
        elif self.current_is(TokenType.LBrace):
            return self.parse_block()
        elif self.current_is(TokenType.KwIf):
            return self.parse_if()
        elif self.current_is(TokenType.KwLoop):
            return self.parse_loop()
        elif self.current_is(TokenType.KwWhile):
            return self.parse_while()
        elif self.current_is(TokenType.KwFor):
            return self.parse_for()
        token = self.current()
        # Skip the offending token so callers make progress, but never step
        # past the end-of-input token.
        if not self.done():
            self.step()
        return Node(ExprError("expected value"), token.span)

    def parse_unit_group_tuple(self) -> Node[Expr]:
        """Parse ``()`` (Unit), ``(expr)`` (grouping) or ``(a, b, ...)`` (Tuple).

        A grouped expression keeps its inner value but takes the span of the
        parentheses. Tuples allow a trailing comma, so ``(a,)`` is a 1-tuple.
        """
        begin = self.current().span
        self.step()
        if self.current_is(TokenType.RParen):
            end = self.current().span
            self.step()
            return Node(Unit(), begin.to(end))
        first_expr = self.parse_expr()
        if self.current_is(TokenType.RParen):
            end = self.current().span
            self.step()
            return Node(first_expr.value, begin.to(end))
        if not self.current_is(TokenType.Comma):
            return Node(ExprError("expected ')' or ','"), begin.to(first_expr.span))
        values = [first_expr]
        end = self.current().span
        while self.current_is(TokenType.Comma):
            end = self.current().span
            self.step()
            if self.done() or self.current_is(TokenType.RParen):
                break
            value = self.parse_expr()
            end = value.span
            values.append(value)
        if not self.current_is(TokenType.RParen):
            return Node(ExprError("expected ')'"), begin.to(end))
        end = self.current().span
        self.step()
        return Node(Tuple(values), begin.to(end))

    def parse_block(self) -> Node[Expr]:
        """Parse ``{ stmt* }``; the current token must be ``{``.

        The last statement becomes the block's value (None for ``{}``); all
        earlier statements are collected in order.
        """
        begin = self.current().span
        self.step()
        statements: List[Node[Expr]] = []
        value: Optional[Node[Expr]] = None
        while not self.done() and not self.current_is(TokenType.RBrace):
            if value is not None:
                statements.append(value)
            value = self.parse_statement()
        if not self.current_is(TokenType.RBrace):
            end = value.span if value is not None else begin
            return Node(ExprError("expected '}'"), begin.to(end))
        end = self.current().span
        self.step()
        return Node(Block(statements, value), begin.to(end))

    def parse_if(self) -> Node[Expr]:
        """Parse ``if cond { ... }`` with an optional ``else { ... }``.

        Only a block may follow ``else``; ``else if`` is not accepted.
        """
        begin = self.current().span
        self.step()
        condition = self.parse_expr()
        if not self.current_is(TokenType.LBrace):
            return Node(ExprError("expected '{'"), begin.to(condition.span))
        truthy = self.parse_block()
        if not self.current_is(TokenType.KwElse):
            return Node(If(condition, truthy, None), begin.to(truthy.span))
        else_token_span = self.current().span
        self.step()
        if not self.current_is(TokenType.LBrace):
            return Node(ExprError("expected '{'"), begin.to(else_token_span))
        falsy = self.parse_block()
        return Node(If(condition, truthy, falsy), begin.to(falsy.span))

    def parse_match(self) -> Node[Expr]:
        """Parse ``match value { arm* }``; the current token is the keyword.

        The first failing arm aborts the whole match with an ExprError.
        """
        begin = self.current().span
        self.step()
        value = self.parse_expr()
        if not self.current_is(TokenType.LBrace):
            return Node(ExprError("expected '{'"), begin)
        lbrace_span = self.current().span
        self.step()
        arms: List[Node[MatchArm]] = []
        while not self.done() and not self.current_is(TokenType.RBrace):
            arm = self.parse_match_arm()
            if not arm.ok():
                return Node(ExprError(arm.error().value), begin.to(arm.error().span))
            arms.append(arm.value())
        if not self.current_is(TokenType.RBrace):
            end = arms[-1].span if arms else lbrace_span
            return Node(ExprError("expected '}'"), begin.to(end))
        rbrace_span = self.current().span
        self.step()
        return Node(Match(value, arms), begin.to(rbrace_span))

    def parse_match_arm(self) -> Result[Node[MatchArm], Node[str]]:
        """Parse ``pattern => expr,`` or ``pattern => { ... }`` (comma optional).

        Returns:
            Ok with the arm, or Err with a message node when ``=>`` or the
            required ``,`` after a non-block arm is missing.
        """
        pattern = self.parse_pattern()
        # NOTE(review): '=>' is matched as TokenType.EqualLT; confirm this is
        # the tokenizer's name for the '=>' token.
        if not self.current_is(TokenType.EqualLT):
            return Err(Node("expected '=>'", pattern.span.to(pattern.span)))
        self.step()
        if self.current_is(TokenType.LBrace):
            expr = self.parse_block()
            if self.current_is(TokenType.Comma):
                self.step()
        else:
            expr = self.parse_match_arm_expr()
            if not self.current_is(TokenType.Comma):
                return Err(Node("expected ','", pattern.span.to(expr.span)))
            self.step()
        return Ok(Node(MatchArm(pattern, expr), pattern.span.to(expr.span)))

    def parse_match_arm_expr(self) -> Node[Expr]:
        """Parse a match-arm body: ``return``/``break``/``continue`` or an expression."""
        if self.current_is(TokenType.KwReturn):
            return self.parse_return()
        elif self.current_is(TokenType.KwBreak):
            return self.parse_break()
        elif self.current_is(TokenType.KwContinue):
            return self.parse_continue()
        return self.parse_expr()

    def parse_loop(self) -> Node[Expr]:
        """Parse ``loop { ... }``."""
        begin = self.current().span
        self.step()
        if not self.current_is(TokenType.LBrace):
            return Node(ExprError("expected '{'"), begin)
        body = self.parse_block()
        return Node(Loop(body), begin.to(body.span))

    def parse_while(self) -> Node[Expr]:
        """Parse ``while cond { ... }``."""
        begin = self.current().span
        self.step()
        condition = self.parse_expr()
        if not self.current_is(TokenType.LBrace):
            return Node(ExprError("expected '{'"), begin.to(condition.span))
        # parse_block consumes the '{' itself.
        body = self.parse_block()
        return Node(While(condition, body), begin.to(body.span))

    def parse_for(self) -> Node[Expr]:
        """Parse ``for pattern in value { ... }``."""
        begin = self.current().span
        self.step()
        subject = self.parse_pattern()
        if not self.current_is(TokenType.KwIn):
            return Node(ExprError("expected 'in'"), begin.to(subject.span))
        self.step()
        value = self.parse_expr()
        if not self.current_is(TokenType.LBrace):
            return Node(ExprError("expected '{'"), begin.to(value.span))
        # parse_block consumes the '{' itself.
        body = self.parse_block()
        return Node(For(subject, value, body), begin.to(body.span))

    def _parse_optional_value(self) -> Optional[Node[Expr]]:
        """Parse the optional operand of ``return``/``break``/``continue``.

        The operand is treated as absent at end of input or when the next
        token is ``,`` or ``;``.
        """
        if self.done() or self.current_is(TokenType.Comma) or self.current_is(TokenType.Semicolon):
            return None
        return self.parse_expr()

    def parse_return(self) -> Node[Expr]:
        """Parse ``return [expr]``."""
        begin = self.current().span
        self.step()
        value = self._parse_optional_value()
        end = value.span if value is not None else begin
        return Node(Return(value), begin.to(end))

    def parse_break(self) -> Node[Expr]:
        """Parse ``break [expr]``."""
        begin = self.current().span
        self.step()
        value = self._parse_optional_value()
        end = value.span if value is not None else begin
        return Node(Break(value), begin.to(end))

    def parse_continue(self) -> Node[Expr]:
        """Parse ``continue [expr]``."""
        begin = self.current().span
        self.step()
        value = self._parse_optional_value()
        end = value.span if value is not None else begin
        return Node(Continue(value), begin.to(end))

    def parse_pattern(self) -> Node[Pattern]:
        """Placeholder: returns a PatternError without consuming any token."""
        return Node(PatternError("not implemented"), self.current().span)

    def step(self) -> None:
        """Advance to the next token from the underlying iterator."""
        self.current_token = self.tokens.next()

    def current_is(self, token_type: TokenType) -> bool:
        """Return True if not at end of input and the current token has ``token_type``."""
        return not self.done() and self.current().token_type == token_type

    def done(self) -> bool:
        """Return True once the current token is end-of-input (``Eof``)."""
        return self.current_token.token_type == TokenType.Eof

    def current(self) -> Token:
        """Return the current (not yet consumed) token."""
        return self.current_token