From c7de22f575607a7966b6b592dbf81bd3f867a2e4 Mon Sep 17 00:00:00 2001 From: Melody Horn Date: Wed, 23 Dec 2020 00:46:52 -0700 Subject: implement a bunch more stuff --- crowbar_reference_compiler/ast.py | 90 ++++++++++++++++++++++++++++++++++++++- 1 file changed, 88 insertions(+), 2 deletions(-) (limited to 'crowbar_reference_compiler/ast.py') diff --git a/crowbar_reference_compiler/ast.py b/crowbar_reference_compiler/ast.py index 86e64fe..37ce0da 100644 --- a/crowbar_reference_compiler/ast.py +++ b/crowbar_reference_compiler/ast.py @@ -12,23 +12,48 @@ from .parser import parse_header @dataclass class Type: - pass + def size_bytes(self, declarations: List['Declaration']) -> int: + raise NotImplementedError('type.size_bytes() on ' + str(type(self)) + ' not implemented') @dataclass class Expression: - pass + def type(self, declarations: List['Declaration']) -> Type: + raise NotImplementedError('expression.type() on ' + str(type(self)) + ' not implemented') @dataclass class ConstantExpression(Expression): value: str + def type(self, _: List['Declaration']) -> Type: + if self.value.startswith('"'): + return PointerType(ConstType(BasicType('char'))) + elif self.value.startswith("'"): + return BasicType('char') + elif self.value in ['true', 'false']: + return BasicType('bool') + elif '.' in self.value: + return BasicType('float?') # TODO infer size + else: + return BasicType('int?') # TODO infer size and signedness + @dataclass class VariableExpression(Expression): name: str + def type(self, declarations: List['Declaration']) -> Type: + for decl in declarations: + if decl.name == self.name: + if isinstance(decl, VariableDeclaration): + return decl.type + elif isinstance(decl, VariableDefinition): + return decl.type + elif isinstance(decl, FunctionDeclaration) or isinstance(decl, FunctionDefinition): + return FunctionType(decl.return_type, [arg.type for arg in decl.args]) + raise KeyError('unknown variable ' + self.name) + @dataclass class AddExpression(Expression): @@ -36,6 +61,12 @@ class AddExpression(Expression): term2: Expression +@dataclass +class SubtractExpression(Expression): + term1: Expression + term2: Expression + + @dataclass class MultiplyExpression(Expression): factor1: Expression @@ -47,6 +78,22 @@ class StructPointerElementExpression(Expression): base: Expression element: str + def type(self, declarations: List['Declaration']) -> Type: + base_type = self.base.type(declarations) + assert isinstance(base_type, PointerType) + assert isinstance(base_type.target, BasicType) + hopefully_struct, struct_name = base_type.target.name.split(' ') + assert hopefully_struct == 'struct' + for decl in declarations: + if isinstance(decl, StructDeclaration) and decl.name == struct_name: + if decl.fields is None: + raise KeyError('struct ' + struct_name + ' is opaque') + for elem in decl.fields: + if elem.name == self.element: + return elem.type + raise KeyError('element ' + self.element + ' not found in struct ' + struct_name) + raise KeyError('struct ' + struct_name + ' not found') + @dataclass class ArrayIndexExpression(Expression): @@ -91,6 +138,20 @@ class ComparisonExpression(Expression): class BasicType(Type): name: str + def size_bytes(self, declarations: List['Declaration']) -> int: + if self.name == 'uint8': + return 1 + elif self.name == 'uintsize': + return 8 + elif self.name.startswith('struct'): + _, struct_name = self.name.split(' ') + for decl in declarations: + if isinstance(decl, StructDeclaration) and decl.name == struct_name: + if decl.fields is None: + raise KeyError('struct ' + struct_name + ' is opaque') + return sum(field.type.size_bytes(declarations) for field in decl.fields) + raise NotImplementedError('size of ' + str(self) + ' not yet found') + @dataclass class ConstType(Type): @@ -101,6 +162,9 @@ class ConstType(Type): class PointerType(Type): target: Type + def size_bytes(self, declarations: List['Declaration']) -> int: + return 8 # TODO figure out 32 bit vs 64 bit + @dataclass class ArrayType(Type): @@ -226,6 +290,14 @@ class UpdateAssignment(AssignmentStatement): operation: str value: Expression + def deconstruct(self) -> DirectAssignment: + if self.operation == '+=': + return DirectAssignment(self.destination, AddExpression(self.destination, self.value)) + elif self.operation == '*=': + return DirectAssignment(self.destination, MultiplyExpression(self.destination, self.value)) + else: + raise NotImplementedError('UpdateAssignment deconstruct with ' + self.operation) + @dataclass class CrementAssignment(AssignmentStatement): @@ -273,12 +345,24 @@ class HeaderFile: includes: List['HeaderFile'] contents: List[HeaderFileElement] + def get_declarations(self) -> List[Declaration]: + included_declarations = [x.get_declarations() for x in self.includes] + own_declarations = [x for x in self.contents if isinstance(x, Declaration)] + all_declarations = included_declarations + [own_declarations] + return [x for l in all_declarations for x in l] + @dataclass class ImplementationFile: includes: List[HeaderFile] contents: List[ImplementationFileElement] + def get_declarations(self) -> List[Declaration]: + included_declarations = [x.get_declarations() for x in self.includes] + own_declarations = [x for x in self.contents if isinstance(x, Declaration)] + all_declarations = included_declarations + [own_declarations] + return [x for l in all_declarations for x in l] + # noinspection PyPep8Naming,PyMethodMayBeStatic,PyUnusedLocal class ASTBuilder(NodeVisitor): @@ -605,6 +689,8 @@ class ASTBuilder(NodeVisitor): for op, term in suffix: if op.type == '+': base = AddExpression(base, term) + elif op.type == '-': + base = SubtractExpression(base, term) else: raise NotImplementedError('arithmetic suffix ' + op) return base -- cgit v1.2.3