diff options
Diffstat (limited to 'crowbar_reference_compiler')
-rw-r--r-- | crowbar_reference_compiler/ast.py | 283 |
1 files changed, 245 insertions, 38 deletions
diff --git a/crowbar_reference_compiler/ast.py b/crowbar_reference_compiler/ast.py index 4b9c21d..2a6963c 100644 --- a/crowbar_reference_compiler/ast.py +++ b/crowbar_reference_compiler/ast.py @@ -1,8 +1,10 @@ from dataclasses import dataclass from pathlib import Path -from typing import List, Optional, Tuple, Union +import typing +from typing import ClassVar, List, Tuple, Union from parsimonious import NodeVisitor # type: ignore +from parsimonious.expressions import Compound, OneOf, Optional, Sequence, TokenMatcher, ZeroOrMore # type: ignore from .scanner import scan from .parser import parse_header @@ -29,6 +31,63 @@ class VariableExpression(Expression): @dataclass +class AddExpression(Expression): + term1: Expression + term2: Expression + + +@dataclass +class MultiplyExpression(Expression): + factor1: Expression + factor2: Expression + + +@dataclass +class StructPointerElementExpression(Expression): + base: Expression + element: str + + +@dataclass +class ArrayIndexExpression(Expression): + array: Expression + index: Expression + + +@dataclass +class FunctionCallExpression(Expression): + function: Expression + arguments: List[Expression] + + +@dataclass +class LogicalNotExpression(Expression): + body: Expression + + +@dataclass +class NegativeExpression(Expression): + body: Expression + + +@dataclass +class AddressOfExpression(Expression): + body: Expression + + +@dataclass +class SizeofExpression(Expression): + body: Union[Type, Expression] + + +@dataclass +class ComparisonExpression(Expression): + value1: Expression + op: str + value2: Expression + + +@dataclass class BasicType(Type): name: str @@ -89,13 +148,13 @@ class ExpressionStatement(Statement): class IfStatement(Statement): condition: Expression then: List[Statement] - els: Optional[List[Statement]] + els: typing.Optional[List[Statement]] @dataclass class SwitchStatement(Statement): expression: Expression - body: List[Union[Optional[Expression], Statement]] + body: List[Union[typing.Optional[Expression], Statement]] @dataclass @@ -122,7 +181,7 @@ class VariableDeclaration(Declaration, HeaderFileElement): @dataclass -class VariableDefinition(Declaration, ImplementationFileElement, Statement): +class VariableDefinition(Declaration, HeaderFileElement, ImplementationFileElement, Statement): """Represents the definition of a variable.""" type: Type value: Expression @@ -152,7 +211,7 @@ class BreakStatement(Statement): @dataclass class ReturnStatement(Statement): - body: Optional[Expression] + body: typing.Optional[Expression] @dataclass @@ -177,20 +236,20 @@ class CrementAssignment(AssignmentStatement): @dataclass class StructDeclaration(Declaration, HeaderFileElement, ImplementationFileElement): """Represents the declaration of a struct type.""" - fields: Optional[List[VariableDeclaration]] + fields: typing.Optional[List[VariableDeclaration]] @dataclass class EnumDeclaration(Declaration, HeaderFileElement, ImplementationFileElement): """Represents the declaration of an enum type.""" - values: List[Tuple[str, Optional[int]]] + values: List[Tuple[str, typing.Optional[int]]] @dataclass class UnionDeclaration(Declaration, HeaderFileElement, ImplementationFileElement): """Represents the declaration of a union type.""" - tag: Optional[VariableDeclaration] - cases: Union[List[VariableDeclaration], List[Tuple[Expression, Optional[VariableDeclaration]]]] + tag: typing.Optional[VariableDeclaration] + cases: Union[List[VariableDeclaration], List[Tuple[Expression, typing.Optional[VariableDeclaration]]]] @dataclass @@ -210,6 +269,7 @@ class FunctionDefinition(Declaration, HeaderFileElement, ImplementationFileEleme @dataclass class HeaderFile: + grammar: ClassVar[str] = "HeaderFile <- IncludeStatement* HeaderFileElement+" includes: List['HeaderFile'] contents: List[HeaderFileElement] @@ -227,8 +287,6 @@ class ASTBuilder(NodeVisitor): def visit_HeaderFile(self, node, visited_children) -> HeaderFile: includes, elements = visited_children - if not isinstance(includes, list): - includes = [] return HeaderFile(includes, elements) def visit_ImplementationFile(self, node, visited_children) -> ImplementationFile: @@ -253,36 +311,38 @@ class ASTBuilder(NodeVisitor): def visit_NormalStructDefinition(self, node, visited_children) -> StructDeclaration: struct, name, lbrace, fields, rbrace = visited_children assert struct.type == 'struct' + assert name.type == 'identifier' + name = name.data assert lbrace.type == '{' assert rbrace.type == '}' - name = name.data - if not isinstance(fields, list): - fields = [fields] return StructDeclaration(name, fields) def visit_OpaqueStructDefinition(self, node, visited_children) -> StructDeclaration: opaque, struct, name, semi = visited_children assert opaque.type == 'opaque' assert struct.type == 'struct' - assert semi.type == ';' + assert name.type == 'identifier' name = name.data + assert semi.type == ';' return StructDeclaration(name, None) def visit_EnumDefinition(self, node, visited_children) -> EnumDeclaration: enum, name, lbrace, first_member, extra_members, trailing_comma, rbrace = visited_children assert enum.type == 'enum' + assert name.type == 'identifier' + name = name.data assert lbrace.type == '{' assert rbrace.type == '}' - name = name.data values = [first_member] for _, v in extra_members: values.append(v) return EnumDeclaration(name, values) - def visit_EnumMember(self, node, visited_children) -> Tuple[str, Optional[Expression]]: + def visit_EnumMember(self, node, visited_children) -> Tuple[str, typing.Optional[Expression]]: name, equals_value = visited_children + assert name.type == 'identifier' name = name.data - if len(equals_value) == 0: + if equals_value is None: return name, None _, value = equals_value return name, value @@ -290,9 +350,10 @@ class ASTBuilder(NodeVisitor): def visit_RobustUnionDefinition(self, node, visited_children) -> UnionDeclaration: union, name, lbrace, tag, body, rbrace = visited_children assert union.type == 'union' + assert name.type == 'identifier' + name = name.data assert lbrace.type == '{' assert rbrace.type == '}' - name = name.data expected_tagname, body = body if tag.name != expected_tagname: raise NameError(f"tag {tag} does not match switch argument {expected_tagname}") @@ -300,7 +361,7 @@ class ASTBuilder(NodeVisitor): body = [body] return UnionDeclaration(name, tag, body) - def visit_UnionBody(self, node, visited_children) -> Tuple[str, List[Tuple[Expression, Optional[VariableDeclaration]]]]: + def visit_UnionBody(self, node, visited_children) -> Tuple[str, List[Tuple[Expression, typing.Optional[VariableDeclaration]]]]: switch, lparen, tag, rparen, lbrace, body, rbrace = visited_children assert switch.type == 'switch' assert lparen.type == '(' @@ -309,7 +370,7 @@ class ASTBuilder(NodeVisitor): assert rbrace.type == '}' return tag.data, body - def visit_UnionBodySet(self, node, visited_children) -> Tuple[Expression, Optional[VariableDeclaration]]: + def visit_UnionBodySet(self, node, visited_children) -> Tuple[Expression, typing.Optional[VariableDeclaration]]: cases, var = visited_children if isinstance(cases, list): cases = cases[0] @@ -323,17 +384,16 @@ class ASTBuilder(NodeVisitor): visited_children = visited_children[0] # TODO don't explode on 'default:' case, expr, colon = visited_children - while isinstance(expr, list): - expr = expr[0] return expr def visit_FragileUnionDefinition(self, node, visited_children) -> UnionDeclaration: fragile, union, name, lbrace, body, rbrace = visited_children assert fragile.type == 'fragile' assert union.type == 'union' + assert name.type == 'identifier' + name = name.data assert lbrace.type == '{' assert rbrace.type == '}' - name = name.data return UnionDeclaration(name, None, body) def visit_FunctionDeclaration(self, node, visited_children) -> FunctionDeclaration: @@ -343,15 +403,17 @@ class ASTBuilder(NodeVisitor): def visit_VariableDefinition(self, node, visited_children) -> VariableDefinition: type, name, eq, value, semi = visited_children + assert name.type == 'identifier' + name = name.data assert eq.type == '=' assert semi.type == ';' - name = name.data return VariableDefinition(name, type, value) def visit_VariableDeclaration(self, node, visited_children) -> VariableDeclaration: type, name, semi = visited_children - assert semi.type == ';' + assert name.type == 'identifier' name = name.data + assert semi.type == ';' return VariableDeclaration(name, type) def visit_FunctionDefinition(self, node, visited_children) -> FunctionDefinition: @@ -363,9 +425,53 @@ class ASTBuilder(NodeVisitor): assert name.type == 'identifier' name = name.data assert lparen.type == '(' + if args is None: + args = [] assert rparen.type == ')' return FunctionDeclaration(name, return_type, args) + def visit_SignatureArguments(self, node, visited_children) -> List[VariableDeclaration]: + first_type, first_name, rest, comma = visited_children + result = [VariableDeclaration(first_name.data, first_type)] + for comma, ty, name in rest: + result.append(VariableDeclaration(name.data, ty)) + return result + + def visit_IfStatement(self, node, visited_children): + kwd, lparen, condition, rparen, then, els = visited_children + assert kwd.type == 'if' + assert lparen.type == '(' + assert rparen.type == ')' + if els is not None: + kwd, els = els + assert kwd.type == 'else' + return IfStatement(condition, then, els) + + def visit_ReturnStatement(self, node, visited_children): + ret, body, semi = visited_children + assert ret.type == 'return' + assert semi.type == ';' + return ReturnStatement(body) + + def visit_DirectAssignmentBody(self, node, visited_children): + dest, eq, value = visited_children + assert eq.type == '=' + return DirectAssignment(dest, value) + + def visit_UpdateAssignmentBody(self, node, visited_children): + dest, op, value = visited_children + return UpdateAssignment(dest, op.type, value) + + def visit_AssignmentStatement(self, node, visited_children): + assignment, semi = visited_children + assert semi.type == ';' + return assignment + + def visit_ExpressionStatement(self, node, visited_children): + expression, semi = visited_children + assert semi.type == ';' + return ExpressionStatement(expression) + def visit_BasicType(self, node, visited_children) -> Type: while isinstance(visited_children, list) and len(visited_children) == 1: visited_children = visited_children[0] @@ -379,17 +485,23 @@ class ASTBuilder(NodeVisitor): else: category, name = visited_children category = category.type + assert name.type == 'identifier' name = name.data return BasicType(f"{category} {name}") return BasicType(visited_children.type) + def visit_ConstType(self, node, visited_children) -> ConstType: + const, contents = visited_children + assert const.type == 'const' + return ConstType(contents) + + def visit_FunctionType(self, node, visited_children): + raise NotImplementedError('function types') + def visit_ArrayType(self, node, visited_children) -> ArrayType: contents, lbracket, size, rbracket = visited_children assert lbracket.type == '[' assert rbracket.type == ']' - # TODO don't explode on nontrivial expression - while isinstance(size, list): - size = size[0] return ArrayType(contents, size) def visit_PointerType(self, node, visited_children) -> PointerType: @@ -397,6 +509,12 @@ class ASTBuilder(NodeVisitor): assert splat.type == '*' return PointerType(contents) + def visit_Block(self, node, visited_children) -> List[Expression]: + lbrace, body, rbrace = visited_children + assert lbrace.type == '{' + assert rbrace.type == '}' + return body + def visit_AtomicExpression(self, node, visited_children) -> Expression: if isinstance(visited_children, list) and len(visited_children) == 3: lparen, body, rparen = visited_children @@ -414,17 +532,106 @@ class ASTBuilder(NodeVisitor): return ConstantExpression(body.type) raise NotImplementedError() + def visit_StructPointerElementSuffix(self, node, visited_children): + separator, element = visited_children + assert separator.type == '->' + return lambda base: StructPointerElementExpression(base, element.data) + + def visit_CommasExpressionList(self, node, visited_children): + first, rest, comma = visited_children + result = [first] + for comma, next in rest: + result.append(next) + return result + + def visit_FunctionCallSuffix(self, node, visited_children): + lparen, args, rparen = visited_children + assert lparen.type == '(' + assert rparen.type == ')' + if args is None: + args = [] + return lambda base: FunctionCallExpression(base, args) + + def visit_ArrayIndexSuffix(self, node, visited_children): + lbracket, index, rbracket = visited_children + assert lbracket.type == '[' + assert rbracket.type == ']' + return lambda base: ArrayIndexExpression(base, index) + + def visit_ObjectExpression(self, node, visited_children) -> Expression: + if isinstance(visited_children, list): + base, suffix = visited_children[0] + if len(suffix) > 0: + for suffix in suffix: + base = suffix(base) + return base + raise NotImplementedError('array/struct literals') + + def visit_NegativeExpression(self, node, visited_children): + minus, body = visited_children + assert minus.type == '-' + return NegativeExpression(body) + + def visit_AddressOfExpression(self, node, visited_children): + ampersand, body = visited_children + assert ampersand.type == '&' + return AddressOfExpression(body) + + def visit_LogicalNotExpression(self, node, visited_children): + bang, body = visited_children + assert bang.type == '!' + return LogicalNotExpression(body) + + def visit_SizeofExpression(self, node, visited_children): + sizeof, argument = visited_children[0] + assert sizeof.type == 'sizeof' + return SizeofExpression(argument) + + def visit_TermExpression(self, node, visited_children) -> Expression: + base, suffix = visited_children + if suffix is not None: + for op, factor in suffix: + if op.type == '*': + base = MultiplyExpression(base, factor) + else: + raise NotImplementedError('term suffix ' + op) + return base + + def visit_ArithmeticExpression(self, node, visited_children) -> Expression: + base, suffix = visited_children + if suffix is not None: + for op, term in suffix: + if op.type == '+': + base = AddExpression(base, term) + else: + raise NotImplementedError('arithmetic suffix ' + op) + return base + + def visit_GreaterEqExpression(self, node, visited_children): + value1, op, value2 = visited_children + assert op.type == '>=' + return ComparisonExpression(value1, '>=', value2) + + def visit_LessEqExpression(self, node, visited_children): + value1, op, value2 = visited_children + assert op.type == '<=' + return ComparisonExpression(value1, '<=', value2) + def generic_visit(self, node, visited_children): - """ The generic visit method. """ - if not visited_children: - if len(node.text) == 0: - return [] - if len(node.text) == 1: - return node.text[0] - raise ValueError('just a node: ' + str(node)) - if len(visited_children) == 1: + if isinstance(node.expr, TokenMatcher): + return node.text[0] + if isinstance(node.expr, OneOf): + return visited_children[0] + if isinstance(node.expr, Optional): + if len(visited_children) == 0: + return None return visited_children[0] - return visited_children + if isinstance(node.expr, Sequence) and node.expr.name != '': + raise NotImplementedError('visit for sequence ' + str(node.expr)) + if isinstance(node.expr, Compound): + return visited_children + print(node.expr) + return super(ASTBuilder, self).generic_visit(node, visited_children) def build_ast(parse_tree, include_dirs): |