aboutsummaryrefslogtreecommitdiff
path: root/crowbar_reference_compiler
diff options
context:
space:
mode:
Diffstat (limited to 'crowbar_reference_compiler')
-rw-r--r--crowbar_reference_compiler/ast.py283
1 files changed, 245 insertions, 38 deletions
diff --git a/crowbar_reference_compiler/ast.py b/crowbar_reference_compiler/ast.py
index 4b9c21d..2a6963c 100644
--- a/crowbar_reference_compiler/ast.py
+++ b/crowbar_reference_compiler/ast.py
@@ -1,8 +1,10 @@
from dataclasses import dataclass
from pathlib import Path
-from typing import List, Optional, Tuple, Union
+import typing
+from typing import ClassVar, List, Tuple, Union
from parsimonious import NodeVisitor # type: ignore
+from parsimonious.expressions import Compound, OneOf, Optional, Sequence, TokenMatcher, ZeroOrMore # type: ignore
from .scanner import scan
from .parser import parse_header
@@ -29,6 +31,63 @@ class VariableExpression(Expression):
@dataclass
+class AddExpression(Expression):
+ term1: Expression
+ term2: Expression
+
+
+@dataclass
+class MultiplyExpression(Expression):
+ factor1: Expression
+ factor2: Expression
+
+
+@dataclass
+class StructPointerElementExpression(Expression):
+ base: Expression
+ element: str
+
+
+@dataclass
+class ArrayIndexExpression(Expression):
+ array: Expression
+ index: Expression
+
+
+@dataclass
+class FunctionCallExpression(Expression):
+ function: Expression
+ arguments: List[Expression]
+
+
+@dataclass
+class LogicalNotExpression(Expression):
+ body: Expression
+
+
+@dataclass
+class NegativeExpression(Expression):
+ body: Expression
+
+
+@dataclass
+class AddressOfExpression(Expression):
+ body: Expression
+
+
+@dataclass
+class SizeofExpression(Expression):
+ body: Union[Type, Expression]
+
+
+@dataclass
+class ComparisonExpression(Expression):
+ value1: Expression
+ op: str
+ value2: Expression
+
+
+@dataclass
class BasicType(Type):
name: str
@@ -89,13 +148,13 @@ class ExpressionStatement(Statement):
class IfStatement(Statement):
condition: Expression
then: List[Statement]
- els: Optional[List[Statement]]
+ els: typing.Optional[List[Statement]]
@dataclass
class SwitchStatement(Statement):
expression: Expression
- body: List[Union[Optional[Expression], Statement]]
+ body: List[Union[typing.Optional[Expression], Statement]]
@dataclass
@@ -122,7 +181,7 @@ class VariableDeclaration(Declaration, HeaderFileElement):
@dataclass
-class VariableDefinition(Declaration, ImplementationFileElement, Statement):
+class VariableDefinition(Declaration, HeaderFileElement, ImplementationFileElement, Statement):
"""Represents the definition of a variable."""
type: Type
value: Expression
@@ -152,7 +211,7 @@ class BreakStatement(Statement):
@dataclass
class ReturnStatement(Statement):
- body: Optional[Expression]
+ body: typing.Optional[Expression]
@dataclass
@@ -177,20 +236,20 @@ class CrementAssignment(AssignmentStatement):
@dataclass
class StructDeclaration(Declaration, HeaderFileElement, ImplementationFileElement):
"""Represents the declaration of a struct type."""
- fields: Optional[List[VariableDeclaration]]
+ fields: typing.Optional[List[VariableDeclaration]]
@dataclass
class EnumDeclaration(Declaration, HeaderFileElement, ImplementationFileElement):
"""Represents the declaration of an enum type."""
- values: List[Tuple[str, Optional[int]]]
+ values: List[Tuple[str, typing.Optional[int]]]
@dataclass
class UnionDeclaration(Declaration, HeaderFileElement, ImplementationFileElement):
"""Represents the declaration of a union type."""
- tag: Optional[VariableDeclaration]
- cases: Union[List[VariableDeclaration], List[Tuple[Expression, Optional[VariableDeclaration]]]]
+ tag: typing.Optional[VariableDeclaration]
+ cases: Union[List[VariableDeclaration], List[Tuple[Expression, typing.Optional[VariableDeclaration]]]]
@dataclass
@@ -210,6 +269,7 @@ class FunctionDefinition(Declaration, HeaderFileElement, ImplementationFileEleme
@dataclass
class HeaderFile:
+ grammar: ClassVar[str] = "HeaderFile <- IncludeStatement* HeaderFileElement+"
includes: List['HeaderFile']
contents: List[HeaderFileElement]
@@ -227,8 +287,6 @@ class ASTBuilder(NodeVisitor):
def visit_HeaderFile(self, node, visited_children) -> HeaderFile:
includes, elements = visited_children
- if not isinstance(includes, list):
- includes = []
return HeaderFile(includes, elements)
def visit_ImplementationFile(self, node, visited_children) -> ImplementationFile:
@@ -253,36 +311,38 @@ class ASTBuilder(NodeVisitor):
def visit_NormalStructDefinition(self, node, visited_children) -> StructDeclaration:
struct, name, lbrace, fields, rbrace = visited_children
assert struct.type == 'struct'
+ assert name.type == 'identifier'
+ name = name.data
assert lbrace.type == '{'
assert rbrace.type == '}'
- name = name.data
- if not isinstance(fields, list):
- fields = [fields]
return StructDeclaration(name, fields)
def visit_OpaqueStructDefinition(self, node, visited_children) -> StructDeclaration:
opaque, struct, name, semi = visited_children
assert opaque.type == 'opaque'
assert struct.type == 'struct'
- assert semi.type == ';'
+ assert name.type == 'identifier'
name = name.data
+ assert semi.type == ';'
return StructDeclaration(name, None)
def visit_EnumDefinition(self, node, visited_children) -> EnumDeclaration:
enum, name, lbrace, first_member, extra_members, trailing_comma, rbrace = visited_children
assert enum.type == 'enum'
+ assert name.type == 'identifier'
+ name = name.data
assert lbrace.type == '{'
assert rbrace.type == '}'
- name = name.data
values = [first_member]
for _, v in extra_members:
values.append(v)
return EnumDeclaration(name, values)
- def visit_EnumMember(self, node, visited_children) -> Tuple[str, Optional[Expression]]:
+ def visit_EnumMember(self, node, visited_children) -> Tuple[str, typing.Optional[Expression]]:
name, equals_value = visited_children
+ assert name.type == 'identifier'
name = name.data
- if len(equals_value) == 0:
+ if equals_value is None:
return name, None
_, value = equals_value
return name, value
@@ -290,9 +350,10 @@ class ASTBuilder(NodeVisitor):
def visit_RobustUnionDefinition(self, node, visited_children) -> UnionDeclaration:
union, name, lbrace, tag, body, rbrace = visited_children
assert union.type == 'union'
+ assert name.type == 'identifier'
+ name = name.data
assert lbrace.type == '{'
assert rbrace.type == '}'
- name = name.data
expected_tagname, body = body
if tag.name != expected_tagname:
raise NameError(f"tag {tag} does not match switch argument {expected_tagname}")
@@ -300,7 +361,7 @@ class ASTBuilder(NodeVisitor):
body = [body]
return UnionDeclaration(name, tag, body)
- def visit_UnionBody(self, node, visited_children) -> Tuple[str, List[Tuple[Expression, Optional[VariableDeclaration]]]]:
+ def visit_UnionBody(self, node, visited_children) -> Tuple[str, List[Tuple[Expression, typing.Optional[VariableDeclaration]]]]:
switch, lparen, tag, rparen, lbrace, body, rbrace = visited_children
assert switch.type == 'switch'
assert lparen.type == '('
@@ -309,7 +370,7 @@ class ASTBuilder(NodeVisitor):
assert rbrace.type == '}'
return tag.data, body
- def visit_UnionBodySet(self, node, visited_children) -> Tuple[Expression, Optional[VariableDeclaration]]:
+ def visit_UnionBodySet(self, node, visited_children) -> Tuple[Expression, typing.Optional[VariableDeclaration]]:
cases, var = visited_children
if isinstance(cases, list):
cases = cases[0]
@@ -323,17 +384,16 @@ class ASTBuilder(NodeVisitor):
visited_children = visited_children[0]
# TODO don't explode on 'default:'
case, expr, colon = visited_children
- while isinstance(expr, list):
- expr = expr[0]
return expr
def visit_FragileUnionDefinition(self, node, visited_children) -> UnionDeclaration:
fragile, union, name, lbrace, body, rbrace = visited_children
assert fragile.type == 'fragile'
assert union.type == 'union'
+ assert name.type == 'identifier'
+ name = name.data
assert lbrace.type == '{'
assert rbrace.type == '}'
- name = name.data
return UnionDeclaration(name, None, body)
def visit_FunctionDeclaration(self, node, visited_children) -> FunctionDeclaration:
@@ -343,15 +403,17 @@ class ASTBuilder(NodeVisitor):
def visit_VariableDefinition(self, node, visited_children) -> VariableDefinition:
type, name, eq, value, semi = visited_children
+ assert name.type == 'identifier'
+ name = name.data
assert eq.type == '='
assert semi.type == ';'
- name = name.data
return VariableDefinition(name, type, value)
def visit_VariableDeclaration(self, node, visited_children) -> VariableDeclaration:
type, name, semi = visited_children
- assert semi.type == ';'
+ assert name.type == 'identifier'
name = name.data
+ assert semi.type == ';'
return VariableDeclaration(name, type)
def visit_FunctionDefinition(self, node, visited_children) -> FunctionDefinition:
@@ -363,9 +425,53 @@ class ASTBuilder(NodeVisitor):
assert name.type == 'identifier'
name = name.data
assert lparen.type == '('
+ if args is None:
+ args = []
assert rparen.type == ')'
return FunctionDeclaration(name, return_type, args)
+ def visit_SignatureArguments(self, node, visited_children) -> List[VariableDeclaration]:
+ first_type, first_name, rest, comma = visited_children
+ result = [VariableDeclaration(first_name.data, first_type)]
+ for comma, ty, name in rest:
+ result.append(VariableDeclaration(name.data, ty))
+ return result
+
+ def visit_IfStatement(self, node, visited_children):
+ kwd, lparen, condition, rparen, then, els = visited_children
+ assert kwd.type == 'if'
+ assert lparen.type == '('
+ assert rparen.type == ')'
+ if els is not None:
+ kwd, els = els
+ assert kwd.type == 'else'
+ return IfStatement(condition, then, els)
+
+ def visit_ReturnStatement(self, node, visited_children):
+ ret, body, semi = visited_children
+ assert ret.type == 'return'
+ assert semi.type == ';'
+ return ReturnStatement(body)
+
+ def visit_DirectAssignmentBody(self, node, visited_children):
+ dest, eq, value = visited_children
+ assert eq.type == '='
+ return DirectAssignment(dest, value)
+
+ def visit_UpdateAssignmentBody(self, node, visited_children):
+ dest, op, value = visited_children
+ return UpdateAssignment(dest, op.type, value)
+
+ def visit_AssignmentStatement(self, node, visited_children):
+ assignment, semi = visited_children
+ assert semi.type == ';'
+ return assignment
+
+ def visit_ExpressionStatement(self, node, visited_children):
+ expression, semi = visited_children
+ assert semi.type == ';'
+ return ExpressionStatement(expression)
+
def visit_BasicType(self, node, visited_children) -> Type:
while isinstance(visited_children, list) and len(visited_children) == 1:
visited_children = visited_children[0]
@@ -379,17 +485,23 @@ class ASTBuilder(NodeVisitor):
else:
category, name = visited_children
category = category.type
+ assert name.type == 'identifier'
name = name.data
return BasicType(f"{category} {name}")
return BasicType(visited_children.type)
+ def visit_ConstType(self, node, visited_children) -> ConstType:
+ const, contents = visited_children
+ assert const.type == 'const'
+ return ConstType(contents)
+
+ def visit_FunctionType(self, node, visited_children):
+ raise NotImplementedError('function types')
+
def visit_ArrayType(self, node, visited_children) -> ArrayType:
contents, lbracket, size, rbracket = visited_children
assert lbracket.type == '['
assert rbracket.type == ']'
- # TODO don't explode on nontrivial expression
- while isinstance(size, list):
- size = size[0]
return ArrayType(contents, size)
def visit_PointerType(self, node, visited_children) -> PointerType:
@@ -397,6 +509,12 @@ class ASTBuilder(NodeVisitor):
assert splat.type == '*'
return PointerType(contents)
+ def visit_Block(self, node, visited_children) -> List[Expression]:
+ lbrace, body, rbrace = visited_children
+ assert lbrace.type == '{'
+ assert rbrace.type == '}'
+ return body
+
def visit_AtomicExpression(self, node, visited_children) -> Expression:
if isinstance(visited_children, list) and len(visited_children) == 3:
lparen, body, rparen = visited_children
@@ -414,17 +532,106 @@ class ASTBuilder(NodeVisitor):
return ConstantExpression(body.type)
raise NotImplementedError()
+ def visit_StructPointerElementSuffix(self, node, visited_children):
+ separator, element = visited_children
+ assert separator.type == '->'
+ return lambda base: StructPointerElementExpression(base, element.data)
+
+ def visit_CommasExpressionList(self, node, visited_children):
+ first, rest, comma = visited_children
+ result = [first]
+ for comma, next in rest:
+ result.append(next)
+ return result
+
+ def visit_FunctionCallSuffix(self, node, visited_children):
+ lparen, args, rparen = visited_children
+ assert lparen.type == '('
+ assert rparen.type == ')'
+ if args is None:
+ args = []
+ return lambda base: FunctionCallExpression(base, args)
+
+ def visit_ArrayIndexSuffix(self, node, visited_children):
+ lbracket, index, rbracket = visited_children
+ assert lbracket.type == '['
+ assert rbracket.type == ']'
+ return lambda base: ArrayIndexExpression(base, index)
+
+ def visit_ObjectExpression(self, node, visited_children) -> Expression:
+ if isinstance(visited_children, list):
+ base, suffix = visited_children[0]
+ if len(suffix) > 0:
+ for suffix in suffix:
+ base = suffix(base)
+ return base
+ raise NotImplementedError('array/struct literals')
+
+ def visit_NegativeExpression(self, node, visited_children):
+ minus, body = visited_children
+ assert minus.type == '-'
+ return NegativeExpression(body)
+
+ def visit_AddressOfExpression(self, node, visited_children):
+ ampersand, body = visited_children
+ assert ampersand.type == '&'
+ return AddressOfExpression(body)
+
+ def visit_LogicalNotExpression(self, node, visited_children):
+ bang, body = visited_children
+ assert bang.type == '!'
+ return LogicalNotExpression(body)
+
+ def visit_SizeofExpression(self, node, visited_children):
+ sizeof, argument = visited_children[0]
+ assert sizeof.type == 'sizeof'
+ return SizeofExpression(argument)
+
+ def visit_TermExpression(self, node, visited_children) -> Expression:
+ base, suffix = visited_children
+ if suffix is not None:
+ for op, factor in suffix:
+ if op.type == '*':
+ base = MultiplyExpression(base, factor)
+ else:
+ raise NotImplementedError('term suffix ' + op)
+ return base
+
+ def visit_ArithmeticExpression(self, node, visited_children) -> Expression:
+ base, suffix = visited_children
+ if suffix is not None:
+ for op, term in suffix:
+ if op.type == '+':
+ base = AddExpression(base, term)
+ else:
+ raise NotImplementedError('arithmetic suffix ' + op)
+ return base
+
+ def visit_GreaterEqExpression(self, node, visited_children):
+ value1, op, value2 = visited_children
+ assert op.type == '>='
+ return ComparisonExpression(value1, '>=', value2)
+
+ def visit_LessEqExpression(self, node, visited_children):
+ value1, op, value2 = visited_children
+ assert op.type == '<='
+ return ComparisonExpression(value1, '<=', value2)
+
def generic_visit(self, node, visited_children):
- """ The generic visit method. """
- if not visited_children:
- if len(node.text) == 0:
- return []
- if len(node.text) == 1:
- return node.text[0]
- raise ValueError('just a node: ' + str(node))
- if len(visited_children) == 1:
+ if isinstance(node.expr, TokenMatcher):
+ return node.text[0]
+ if isinstance(node.expr, OneOf):
+ return visited_children[0]
+ if isinstance(node.expr, Optional):
+ if len(visited_children) == 0:
+ return None
return visited_children[0]
- return visited_children
+ if isinstance(node.expr, Sequence) and node.expr.name != '':
+ raise NotImplementedError('visit for sequence ' + str(node.expr))
+ if isinstance(node.expr, Compound):
+ return visited_children
+ print(node.expr)
+ return super(ASTBuilder, self).generic_visit(node, visited_children)
def build_ast(parse_tree, include_dirs):