parrot/parser.py
2023-04-08 19:21:37 +02:00

450 lines
18 KiB
Python

from tokens import Token, TokenType, TokenIterator
from position import Position, Span, Node
from parsed import Assign, AssignType, Binary, BinaryType, Block, Break, Call, Continue, Match, MatchArm, PatternError, Expr, For, Id, If, Index, Int, Char, Loop, Pattern, Return, String, ExprError, StructMember, Tuple, TupleMember, Unary, UnaryType, Unit, While
from typing import List, Optional
from utils import Result, Ok, Err
class Parser:
    """Recursive-descent parser that turns a token stream into expression nodes.

    Binding strength, from loosest to tightest:

    1. assignment ``a = b`` (statement level only; not chained)
    2. ``or``
    3. ``and``
    4. ``==`` / ``!=``
    5. ``<`` / ``>`` / ``<=`` / ``>=``
    6. ``+`` / ``-``
    7. ``*`` / ``/`` / ``%``
    8. prefix ``-`` (negation)
    9. ``**`` (right-associative)
    10. prefix ``not``, ``*`` (dereference), ``&`` and ``&mut`` (references)
    11. postfix ``.member``, ``.0``, ``[index]``, ``(call)``
    12. operands: identifiers, int/char/string literals, ``()``/groups/tuples,
        blocks, ``if``, ``loop``, ``while``, ``for``

    Syntax errors are not raised: they are returned as ``ExprError`` nodes that
    carry a message and the span of the offending source.
    """

    def __init__(self, text: str, tokens: TokenIterator) -> None:
        """Create a parser over ``tokens``, which were lexed from ``text``.

        ``text`` is kept so token contents can be recovered with
        ``Token.text_slice``. The first token is read eagerly so ``current()``
        is valid right away.
        """
        self.text = text
        self.tokens = tokens
        self.current_token = tokens.next()

    def parse(self) -> List[Node[Expr]]:
        """Parse statements until end of input and return them in source order."""
        statements: List[Node[Expr]] = []
        while not self.done():
            statements.append(self.parse_statement())
        return statements

    def parse_statement(self) -> Node[Expr]:
        """Parse one statement: an ``if`` expression or a possibly-assigning expression."""
        if self.current_is(TokenType.KwIf):
            return self.parse_if()
        return self.parse_assign()

    def parse_assign(self) -> Node[Expr]:
        """Parse ``expr`` or ``expr = expr`` (a single, non-chained assignment)."""
        subject = self.parse_expr()
        if not self.current_is(TokenType.Equal):
            return subject
        self.step()
        value = self.parse_expr()
        return Node(Assign(AssignType.Assign, subject, value), subject.span.to(value.span))

    def parse_expr(self) -> Node[Expr]:
        """Parse a full expression, starting at the loosest binary level (``or``)."""
        return self.parse_or()

    def parse_or(self) -> Node[Expr]:
        """Parse a left-associative chain of ``and``-level operands joined by ``or``."""
        left = self.parse_and()
        while self.current_is(TokenType.KwOr):
            self.step()
            right = self.parse_and()
            left = Node(Binary(BinaryType.Or, left, right), left.span.to(right.span))
        return left

    def parse_and(self) -> Node[Expr]:
        """Parse a left-associative chain of equality-level operands joined by ``and``."""
        left = self.parse_equal()
        # The ``and`` keyword, not ``or``: ``or`` is handled one level up in parse_or.
        while self.current_is(TokenType.KwAnd):
            self.step()
            right = self.parse_equal()
            left = Node(Binary(BinaryType.And, left, right), left.span.to(right.span))
        return left

    def parse_equal(self) -> Node[Expr]:
        """Parse a left-associative chain of ``==`` / ``!=`` comparisons."""
        left = self.parse_compare()
        while True:
            if self.current_is(TokenType.EqualEqual):
                binary_type = BinaryType.Equal
            elif self.current_is(TokenType.ExclamationEqual):
                binary_type = BinaryType.Inequal
            else:
                break
            self.step()
            right = self.parse_compare()
            left = Node(Binary(binary_type, left, right), left.span.to(right.span))
        return left

    def parse_compare(self) -> Node[Expr]:
        """Parse a left-associative chain of ``<``, ``>``, ``<=`` and ``>=`` comparisons."""
        left = self.parse_add_subtract()
        while True:
            if self.current_is(TokenType.LT):
                binary_type = BinaryType.LT
            elif self.current_is(TokenType.GT):
                binary_type = BinaryType.GT
            elif self.current_is(TokenType.LTEqual):
                binary_type = BinaryType.LTEqual
            elif self.current_is(TokenType.GTEqual):
                binary_type = BinaryType.GTEqual
            else:
                break
            self.step()
            right = self.parse_add_subtract()
            left = Node(Binary(binary_type, left, right), left.span.to(right.span))
        return left

    def parse_add_subtract(self) -> Node[Expr]:
        """Parse a left-associative chain of ``+`` / ``-``."""
        left = self.parse_multiply_divide_modulo()
        while True:
            if self.current_is(TokenType.Plus):
                binary_type = BinaryType.Add
            elif self.current_is(TokenType.Minus):
                binary_type = BinaryType.Subtract
            else:
                break
            self.step()
            right = self.parse_multiply_divide_modulo()
            left = Node(Binary(binary_type, left, right), left.span.to(right.span))
        return left

    def parse_multiply_divide_modulo(self) -> Node[Expr]:
        """Parse a left-associative chain of ``*``, ``/`` and ``%``."""
        left = self.parse_negate()
        while True:
            if self.current_is(TokenType.Asterisk):
                binary_type = BinaryType.Multiply
            elif self.current_is(TokenType.Slash):
                binary_type = BinaryType.Divide
            elif self.current_is(TokenType.Percent):
                binary_type = BinaryType.Modulo
            else:
                break
            self.step()
            right = self.parse_negate()
            left = Node(Binary(binary_type, left, right), left.span.to(right.span))
        return left

    def parse_negate(self) -> Node[Expr]:
        """Parse an optional single prefix ``-``.

        Negation binds looser than ``**``, so ``-a ** b`` parses as ``-(a ** b)``.
        """
        if not self.current_is(TokenType.Minus):
            return self.parse_exponent()
        token_span = self.current().span
        self.step()
        subject = self.parse_exponent()
        return Node(Unary(UnaryType.Negate, subject), token_span.to(subject.span))

    def parse_exponent(self) -> Node[Expr]:
        """Parse ``**``, right-associative: ``a ** b ** c`` is ``a ** (b ** c)``."""
        left = self.parse_unary()
        if not self.current_is(TokenType.AsteriskAsterisk):
            return left
        self.step()
        right = self.parse_exponent()
        return Node(Binary(BinaryType.Exponent, left, right), left.span.to(right.span))

    def parse_unary(self) -> Node[Expr]:
        """Parse nestable prefix operators: ``not``, ``*`` (dereference), ``&`` and ``&mut``."""
        if self.current_is(TokenType.KwNot):
            token_span = self.current().span
            self.step()
            subject = self.parse_unary()
            return Node(Unary(UnaryType.Not, subject), token_span.to(subject.span))
        if self.current_is(TokenType.Asterisk):
            token_span = self.current().span
            self.step()
            subject = self.parse_unary()
            return Node(Unary(UnaryType.Dereference, subject), token_span.to(subject.span))
        if self.current_is(TokenType.Ampersand):
            token_span = self.current().span
            self.step()
            if self.current_is(TokenType.KwMut):
                self.step()
                subject = self.parse_unary()
                return Node(Unary(UnaryType.ReferenceMut, subject), token_span.to(subject.span))
            # A plain '&' is a shared reference, distinct from '&mut'.
            subject = self.parse_unary()
            return Node(Unary(UnaryType.Reference, subject), token_span.to(subject.span))
        return self.parse_member_index_call()

    def parse_member_index_call(self) -> Node[Expr]:
        """Parse an operand followed by any chain of postfix operations.

        Supported postfixes: ``.name`` (struct member), ``.0`` (tuple member),
        ``[expr]`` (index) and ``(arg, ...)`` (call; a trailing comma is allowed).
        """
        subject: Node[Expr] = self.parse_operand()
        while not self.done():
            if self.current_is(TokenType.Dot):
                self.step()
                if self.current_is(TokenType.Id):
                    id_token = self.current()
                    self.step()
                    text = id_token.text_slice(self.text)
                    subject = Node(StructMember(subject, text), subject.span.to(id_token.span))
                elif self.current_is(TokenType.Int):
                    int_token = self.current()
                    self.step()
                    value = int(int_token.text_slice(self.text))
                    subject = Node(TupleMember(subject, value), subject.span.to(int_token.span))
                else:
                    return Node(ExprError("expected Int or Id"), subject.span)
            elif self.current_is(TokenType.LBracket):
                self.step()
                index = self.parse_expr()
                if not self.current_is(TokenType.RBracket):
                    return Node(ExprError("expected ']'"), subject.span.to(index.span))
                rbracket_token_span = self.current().span
                self.step()
                subject = Node(Index(subject, index), subject.span.to(rbracket_token_span))
            elif self.current_is(TokenType.LParen):
                self.step()
                arguments: List[Node[Expr]] = []
                # Compare token *types* (via current_is), not Token objects.
                if not self.done() and not self.current_is(TokenType.RParen):
                    arguments.append(self.parse_expr())
                    while self.current_is(TokenType.Comma):
                        self.step()
                        # A trailing comma directly before ')' is allowed.
                        if self.done() or self.current_is(TokenType.RParen):
                            break
                        arguments.append(self.parse_expr())
                if not self.current_is(TokenType.RParen):
                    end = arguments[-1].span if arguments else subject.span
                    return Node(ExprError("expected ')'"), subject.span.to(end))
                end = self.current().span
                self.step()
                subject = Node(Call(subject, arguments), subject.span.to(end))
            else:
                break
        return subject

    def parse_operand(self) -> Node[Expr]:
        """Parse a primary expression: a literal, identifier, group, block or control-flow form.

        NOTE(review): parse_match is never dispatched from here, so match
        expressions are currently unreachable; add a branch once the lexer's
        match-keyword token name is confirmed.
        """
        if self.current_is(TokenType.Id):
            token = self.current()
            value = token.text_slice(self.text)
            self.step()
            return Node(Id(value), token.span)
        if self.current_is(TokenType.Int):
            token = self.current()
            value = int(token.text_slice(self.text))
            self.step()
            return Node(Int(value), token.span)
        if self.current_is(TokenType.Char):
            token = self.current()
            value = token.text_slice(self.text)
            self.step()
            return Node(Char(value), token.span)
        if self.current_is(TokenType.String):
            token = self.current()
            value = token.text_slice(self.text)
            self.step()
            return Node(String(value), token.span)
        if self.current_is(TokenType.LParen):
            return self.parse_unit_group_tuple()
        if self.current_is(TokenType.LBrace):
            return self.parse_block()
        if self.current_is(TokenType.KwIf):
            return self.parse_if()
        if self.current_is(TokenType.KwLoop):
            return self.parse_loop()
        if self.current_is(TokenType.KwWhile):
            return self.parse_while()
        if self.current_is(TokenType.KwFor):
            return self.parse_for()
        token = self.current()
        # Skip the unexpected token so callers always make progress, but never
        # step past end of input: done() relies on the current token being Eof.
        if not self.done():
            self.step()
        return Node(ExprError("expected value"), token.span)

    def parse_unit_group_tuple(self) -> Node[Expr]:
        """Parse a parenthesised form: ``()`` unit, ``(expr)`` group, or ``(a, b, ...)`` tuple.

        A group yields the inner expression's value with its span widened to the
        parentheses. Tuples allow a trailing comma, so ``(a,)`` is a one-element tuple.
        """
        begin = self.current().span
        self.step()  # consume '('
        if self.current_is(TokenType.RParen):
            end = self.current().span
            self.step()
            return Node(Unit(), begin.to(end))
        first_expr = self.parse_expr()
        if self.current_is(TokenType.RParen):
            end = self.current().span
            self.step()
            return Node(first_expr.value, begin.to(end))
        if not self.current_is(TokenType.Comma):
            return Node(ExprError("expected ')' or ','"), begin.to(first_expr.span))
        values = [first_expr]
        end = self.current().span
        while self.current_is(TokenType.Comma):
            end = self.current().span
            self.step()
            if self.done() or self.current_is(TokenType.RParen):
                break
            value = self.parse_expr()
            end = value.span
            values.append(value)
        if not self.current_is(TokenType.RParen):
            return Node(ExprError("expected ')'"), begin.to(end))
        end = self.current().span
        self.step()
        return Node(Tuple(values), begin.to(end))

    def parse_block(self) -> Node[Expr]:
        """Parse ``{ statement* }``. The current token is consumed as the ``{`` unchecked.

        The last statement becomes the block's trailing value and the preceding
        ones its statement list; an empty block has no value (``None``).
        """
        begin = self.current().span
        self.step()  # consume '{' (callers check for it)
        statements: List[Node[Expr]] = []
        value: Optional[Node[Expr]] = None
        while not self.done() and not self.current_is(TokenType.RBrace):
            if value is not None:
                statements.append(value)
            value = self.parse_statement()
        if not self.current_is(TokenType.RBrace):
            end = value.span if value is not None else begin
            return Node(ExprError("expected '}'"), begin.to(end))
        end = self.current().span
        self.step()
        return Node(Block(statements, value), begin.to(end))

    def parse_if(self) -> Node[Expr]:
        """Parse ``if cond { ... }`` with an optional ``else { ... }``."""
        begin = self.current().span
        self.step()  # consume 'if'
        condition = self.parse_expr()
        if not self.current_is(TokenType.LBrace):
            return Node(ExprError("expected '{'"), begin.to(condition.span))
        truthy = self.parse_block()
        if not self.current_is(TokenType.KwElse):
            return Node(If(condition, truthy, None), begin.to(truthy.span))
        else_token_span = self.current().span
        self.step()
        if not self.current_is(TokenType.LBrace):
            return Node(ExprError("expected '{'"), begin.to(else_token_span))
        falsy = self.parse_block()
        return Node(If(condition, truthy, falsy), begin.to(falsy.span))

    def parse_match(self) -> Node[Expr]:
        """Parse a match expression: the current keyword token is skipped, then ``value { arm* }``.

        The first arm error aborts parsing and is returned as an ``ExprError``.
        """
        begin = self.current().span
        self.step()
        value = self.parse_expr()
        if not self.current_is(TokenType.LBrace):
            return Node(ExprError("expected '{'"), begin.to(value.span))
        lbrace_span = self.current().span
        self.step()
        arms: List[Node[MatchArm]] = []
        while not self.done() and not self.current_is(TokenType.RBrace):
            arm = self.parse_match_arm()
            if not arm.ok():
                error = arm.error()
                return Node(ExprError(error.value), begin.to(error.span))
            arms.append(arm.value())
        if not self.current_is(TokenType.RBrace):
            end = arms[-1].span if arms else lbrace_span
            return Node(ExprError("expected '}'"), begin.to(end))
        rbrace_span = self.current().span
        self.step()
        return Node(Match(value, arms), begin.to(rbrace_span))

    def parse_match_arm(self) -> Result[Node[MatchArm], Node[str]]:
        """Parse one ``pattern => body`` arm.

        A block body ``{ ... }`` may be followed by an optional ``,``; any other
        body must be followed by ``,``. Failures are returned as ``Err`` holding
        a message node rather than raised.
        """
        pattern = self.parse_pattern()
        # NOTE(review): '=>' is matched as TokenType.EqualLT, which reads like
        # '=<'; confirm this is the lexer's name for the arrow token.
        if not self.current_is(TokenType.EqualLT):
            return Err(Node("expected '=>'", pattern.span.to(pattern.span)))
        self.step()
        # Block bodies start with '{' (parse_block consumes it), not '('.
        if self.current_is(TokenType.LBrace):
            expr = self.parse_block()
            if self.current_is(TokenType.Comma):
                self.step()
        else:
            expr = self.parse_match_arm_expr()
            if not self.current_is(TokenType.Comma):
                return Err(Node("expected ','", pattern.span.to(expr.span)))
            self.step()
        return Ok(Node(MatchArm(pattern, expr), pattern.span.to(expr.span)))

    def parse_match_arm_expr(self) -> Node[Expr]:
        """Parse a non-block arm body: ``return``/``break``/``continue`` or an expression."""
        if self.current_is(TokenType.KwReturn):
            return self.parse_return()
        if self.current_is(TokenType.KwBreak):
            return self.parse_break()
        if self.current_is(TokenType.KwContinue):
            return self.parse_continue()
        return self.parse_expr()

    def parse_loop(self) -> Node[Expr]:
        """Parse ``loop { ... }``."""
        begin = self.current().span
        self.step()  # consume 'loop'
        if not self.current_is(TokenType.LBrace):
            return Node(ExprError("expected '{'"), begin)
        body = self.parse_block()
        return Node(Loop(body), begin.to(body.span))

    def parse_while(self) -> Node[Expr]:
        """Parse ``while cond { ... }``."""
        begin = self.current().span
        self.step()  # consume 'while'
        condition = self.parse_expr()
        if not self.current_is(TokenType.LBrace):
            return Node(ExprError("expected '{'"), begin.to(condition.span))
        # parse_block consumes the '{' itself; stepping here would skip the
        # first token of the body.
        body = self.parse_block()
        return Node(While(condition, body), begin.to(body.span))

    def parse_for(self) -> Node[Expr]:
        """Parse ``for pattern in value { ... }``."""
        begin = self.current().span
        self.step()  # consume 'for'
        subject = self.parse_pattern()
        if not self.current_is(TokenType.KwIn):
            return Node(ExprError("expected 'in'"), begin.to(subject.span))
        self.step()
        value = self.parse_expr()
        if not self.current_is(TokenType.LBrace):
            return Node(ExprError("expected '{'"), begin.to(value.span))
        # parse_block consumes the '{' itself.
        body = self.parse_block()
        return Node(For(subject, value, body), begin.to(body.span))

    def _parse_optional_value(self) -> Optional[Node[Expr]]:
        """Parse the optional operand of ``return``/``break``/``continue``.

        Returns ``None`` at end of input or when the next token is ``,`` or ``;``.
        """
        if self.done() or self.current_is(TokenType.Comma) or self.current_is(TokenType.Semicolon):
            return None
        return self.parse_expr()

    def parse_return(self) -> Node[Expr]:
        """Parse ``return`` with an optional value."""
        begin = self.current().span
        self.step()
        value = self._parse_optional_value()
        return Node(Return(value), begin.to(value.span if value is not None else begin))

    def parse_break(self) -> Node[Expr]:
        """Parse ``break`` with an optional value."""
        begin = self.current().span
        self.step()
        value = self._parse_optional_value()
        return Node(Break(value), begin.to(value.span if value is not None else begin))

    def parse_continue(self) -> Node[Expr]:
        """Parse ``continue`` with an optional value."""
        begin = self.current().span
        self.step()
        value = self._parse_optional_value()
        return Node(Continue(value), begin.to(value.span if value is not None else begin))

    def parse_pattern(self) -> Node[Pattern]:
        """Placeholder: patterns are not implemented.

        Returns a ``PatternError`` spanning the current token without consuming it.
        """
        return Node(PatternError("not implemented"), self.current().span)

    def step(self) -> None:
        """Advance to the next token from the iterator."""
        self.current_token = self.tokens.next()

    def current_is(self, token_type: TokenType) -> bool:
        """Return True if not at end of input and the current token has ``token_type``."""
        return not self.done() and self.current().token_type == token_type

    def done(self) -> bool:
        """Return True when the current token is the end-of-input (Eof) token."""
        return self.current_token.token_type == TokenType.Eof

    def current(self) -> Token:
        """Return the current (not yet consumed) token."""
        return self.current_token