aboutsummaryrefslogtreecommitdiff
path: root/crowbar_reference_compiler/ast.py
diff options
context:
space:
mode:
Diffstat (limited to 'crowbar_reference_compiler/ast.py')
-rw-r--r--crowbar_reference_compiler/ast.py90
1 files changed, 88 insertions, 2 deletions
diff --git a/crowbar_reference_compiler/ast.py b/crowbar_reference_compiler/ast.py
index 86e64fe..37ce0da 100644
--- a/crowbar_reference_compiler/ast.py
+++ b/crowbar_reference_compiler/ast.py
@@ -12,23 +12,48 @@ from .parser import parse_header
@dataclass
class Type:
- pass
+ def size_bytes(self, declarations: List['Declaration']) -> int:
+ raise NotImplementedError('type.size_bytes() on ' + str(type(self)) + ' not implemented')
@dataclass
class Expression:
- pass
+ def type(self, declarations: List['Declaration']) -> Type:
+ raise NotImplementedError('expression.type() on ' + str(type(self)) + ' not implemented')
@dataclass
class ConstantExpression(Expression):
value: str
+ def type(self, _: List['Declaration']) -> Type:
+ if self.value.startswith('"'):
+ return PointerType(ConstType(BasicType('char')))
+ elif self.value.startswith("'"):
+ return BasicType('char')
+ elif self.value in ['true', 'false']:
+ return BasicType('bool')
+ elif '.' in self.value:
+ return BasicType('float?') # TODO infer size
+ else:
+ return BasicType('int?') # TODO infer size and signedness
+
@dataclass
class VariableExpression(Expression):
name: str
+ def type(self, declarations: List['Declaration']) -> Type:
+ for decl in declarations:
+ if decl.name == self.name:
+ if isinstance(decl, VariableDeclaration):
+ return decl.type
+ elif isinstance(decl, VariableDefinition):
+ return decl.type
+ elif isinstance(decl, FunctionDeclaration) or isinstance(decl, FunctionDefinition):
+ return FunctionType(decl.return_type, [arg.type for arg in decl.args])
+ raise KeyError('unknown variable ' + self.name)
+
@dataclass
class AddExpression(Expression):
@@ -37,6 +62,12 @@ class AddExpression(Expression):
@dataclass
+class SubtractExpression(Expression):
+ term1: Expression
+ term2: Expression
+
+
+@dataclass
class MultiplyExpression(Expression):
factor1: Expression
factor2: Expression
@@ -47,6 +78,22 @@ class StructPointerElementExpression(Expression):
base: Expression
element: str
+ def type(self, declarations: List['Declaration']) -> Type:
+ base_type = self.base.type(declarations)
+ assert isinstance(base_type, PointerType)
+ assert isinstance(base_type.target, BasicType)
+ hopefully_struct, struct_name = base_type.target.name.split(' ')
+ assert hopefully_struct == 'struct'
+ for decl in declarations:
+ if isinstance(decl, StructDeclaration) and decl.name == struct_name:
+ if decl.fields is None:
+ raise KeyError('struct ' + struct_name + ' is opaque')
+ for elem in decl.fields:
+ if elem.name == self.element:
+ return elem.type
+ raise KeyError('element ' + self.element + ' not found in struct ' + struct_name)
+ raise KeyError('struct ' + struct_name + ' not found')
+
@dataclass
class ArrayIndexExpression(Expression):
@@ -91,6 +138,20 @@ class ComparisonExpression(Expression):
class BasicType(Type):
name: str
+ def size_bytes(self, declarations: List['Declaration']) -> int:
+ if self.name == 'uint8':
+ return 1
+ elif self.name == 'uintsize':
+ return 8
+ elif self.name.startswith('struct'):
+ _, struct_name = self.name.split(' ')
+ for decl in declarations:
+ if isinstance(decl, StructDeclaration) and decl.name == struct_name:
+ if decl.fields is None:
+ raise KeyError('struct ' + struct_name + ' is opaque')
+ return sum(field.type.size_bytes(declarations) for field in decl.fields)
+ raise NotImplementedError('size of ' + str(self) + ' not yet found')
+
@dataclass
class ConstType(Type):
@@ -101,6 +162,9 @@ class ConstType(Type):
class PointerType(Type):
target: Type
+ def size_bytes(self, declarations: List['Declaration']) -> int:
+ return 8 # TODO figure out 32 bit vs 64 bit
+
@dataclass
class ArrayType(Type):
@@ -226,6 +290,14 @@ class UpdateAssignment(AssignmentStatement):
operation: str
value: Expression
+ def deconstruct(self) -> DirectAssignment:
+ if self.operation == '+=':
+ return DirectAssignment(self.destination, AddExpression(self.destination, self.value))
+ elif self.operation == '*=':
+ return DirectAssignment(self.destination, MultiplyExpression(self.destination, self.value))
+ else:
+ raise NotImplementedError('UpdateAssignment deconstruct with ' + self.operation)
+
@dataclass
class CrementAssignment(AssignmentStatement):
@@ -273,12 +345,24 @@ class HeaderFile:
includes: List['HeaderFile']
contents: List[HeaderFileElement]
+ def get_declarations(self) -> List[Declaration]:
+ included_declarations = [x.get_declarations() for x in self.includes]
+ own_declarations = [x for x in self.contents if isinstance(x, Declaration)]
+ all_declarations = included_declarations + [own_declarations]
+ return [x for l in all_declarations for x in l]
+
@dataclass
class ImplementationFile:
includes: List[HeaderFile]
contents: List[ImplementationFileElement]
+ def get_declarations(self) -> List[Declaration]:
+ included_declarations = [x.get_declarations() for x in self.includes]
+ own_declarations = [x for x in self.contents if isinstance(x, Declaration)]
+ all_declarations = included_declarations + [own_declarations]
+ return [x for l in all_declarations for x in l]
+
# noinspection PyPep8Naming,PyMethodMayBeStatic,PyUnusedLocal
class ASTBuilder(NodeVisitor):
@@ -605,6 +689,8 @@ class ASTBuilder(NodeVisitor):
for op, term in suffix:
if op.type == '+':
base = AddExpression(base, term)
+ elif op.type == '-':
+ base = SubtractExpression(base, term)
else:
raise NotImplementedError('arithmetic suffix ' + op)
return base