From c7de22f575607a7966b6b592dbf81bd3f867a2e4 Mon Sep 17 00:00:00 2001 From: Melody Horn Date: Wed, 23 Dec 2020 00:46:52 -0700 Subject: implement a bunch more stuff --- crowbar_reference_compiler/__init__.py | 2 +- crowbar_reference_compiler/ast.py | 90 +++++++++++++- crowbar_reference_compiler/ssagen.py | 221 ++++++++++++++++++++++++++++++--- 3 files changed, 295 insertions(+), 18 deletions(-) diff --git a/crowbar_reference_compiler/__init__.py b/crowbar_reference_compiler/__init__.py index 1410bf7..46d0115 100644 --- a/crowbar_reference_compiler/__init__.py +++ b/crowbar_reference_compiler/__init__.py @@ -62,7 +62,7 @@ def main(): if args.out is None: args.out = args.input.replace('.cro', '.o') extra_gcc_flags.append('-c') - gcc_result = subprocess.run(['gcc', '-x', 'assembler', '-o', args.out, '-'], input=asm, text=True) + gcc_result = subprocess.run(['gcc', '-x', 'assembler', '-o', args.out, *extra_gcc_flags, '-'], input=asm, text=True) sys.exit(gcc_result.returncode) diff --git a/crowbar_reference_compiler/ast.py b/crowbar_reference_compiler/ast.py index 86e64fe..37ce0da 100644 --- a/crowbar_reference_compiler/ast.py +++ b/crowbar_reference_compiler/ast.py @@ -12,23 +12,48 @@ from .parser import parse_header @dataclass class Type: - pass + def size_bytes(self, declarations: List['Declaration']) -> int: + raise NotImplementedError('type.size_bytes() on ' + str(type(self)) + ' not implemented') @dataclass class Expression: - pass + def type(self, declarations: List['Declaration']) -> Type: + raise NotImplementedError('expression.type() on ' + str(type(self)) + ' not implemented') @dataclass class ConstantExpression(Expression): value: str + def type(self, _: List['Declaration']) -> Type: + if self.value.startswith('"'): + return PointerType(ConstType(BasicType('char'))) + elif self.value.startswith("'"): + return BasicType('char') + elif self.value in ['true', 'false']: + return BasicType('bool') + elif '.' in self.value: + return BasicType('float?') # TODO infer size + else: + return BasicType('int?') # TODO infer size and signedness + @dataclass class VariableExpression(Expression): name: str + def type(self, declarations: List['Declaration']) -> Type: + for decl in declarations: + if decl.name == self.name: + if isinstance(decl, VariableDeclaration): + return decl.type + elif isinstance(decl, VariableDefinition): + return decl.type + elif isinstance(decl, FunctionDeclaration) or isinstance(decl, FunctionDefinition): + return FunctionType(decl.return_type, [arg.type for arg in decl.args]) + raise KeyError('unknown variable ' + self.name) + @dataclass class AddExpression(Expression): @@ -36,6 +61,12 @@ class AddExpression(Expression): term2: Expression +@dataclass +class SubtractExpression(Expression): + term1: Expression + term2: Expression + + @dataclass class MultiplyExpression(Expression): factor1: Expression @@ -47,6 +78,22 @@ class StructPointerElementExpression(Expression): base: Expression element: str + def type(self, declarations: List['Declaration']) -> Type: + base_type = self.base.type(declarations) + assert isinstance(base_type, PointerType) + assert isinstance(base_type.target, BasicType) + hopefully_struct, struct_name = base_type.target.name.split(' ') + assert hopefully_struct == 'struct' + for decl in declarations: + if isinstance(decl, StructDeclaration) and decl.name == struct_name: + if decl.fields is None: + raise KeyError('struct ' + struct_name + ' is opaque') + for elem in decl.fields: + if elem.name == self.element: + return elem.type + raise KeyError('element ' + self.element + ' not found in struct ' + struct_name) + raise KeyError('struct ' + struct_name + ' not found') + @dataclass class ArrayIndexExpression(Expression): @@ -91,6 +138,20 @@ class ComparisonExpression(Expression): class BasicType(Type): name: str + def size_bytes(self, declarations: List['Declaration']) -> int: + if self.name == 'uint8': + return 1 + elif self.name == 'uintsize': + return 8 + elif self.name.startswith('struct'): + _, struct_name = self.name.split(' ') + for decl in declarations: + if isinstance(decl, StructDeclaration) and decl.name == struct_name: + if decl.fields is None: + raise KeyError('struct ' + struct_name + ' is opaque') + return sum(field.type.size_bytes(declarations) for field in decl.fields) + raise NotImplementedError('size of ' + str(self) + ' not yet found') + @dataclass class ConstType(Type): @@ -101,6 +162,9 @@ class ConstType(Type): class PointerType(Type): target: Type + def size_bytes(self, declarations: List['Declaration']) -> int: + return 8 # TODO figure out 32 bit vs 64 bit + @dataclass class ArrayType(Type): @@ -226,6 +290,14 @@ class UpdateAssignment(AssignmentStatement): operation: str value: Expression + def deconstruct(self) -> DirectAssignment: + if self.operation == '+=': + return DirectAssignment(self.destination, AddExpression(self.destination, self.value)) + elif self.operation == '*=': + return DirectAssignment(self.destination, MultiplyExpression(self.destination, self.value)) + else: + raise NotImplementedError('UpdateAssignment deconstruct with ' + self.operation) + @dataclass class CrementAssignment(AssignmentStatement): @@ -273,12 +345,24 @@ class HeaderFile: includes: List['HeaderFile'] contents: List[HeaderFileElement] + def get_declarations(self) -> List[Declaration]: + included_declarations = [x.get_declarations() for x in self.includes] + own_declarations = [x for x in self.contents if isinstance(x, Declaration)] + all_declarations = included_declarations + [own_declarations] + return [x for l in all_declarations for x in l] + @dataclass class ImplementationFile: includes: List[HeaderFile] contents: List[ImplementationFileElement] + def get_declarations(self) -> List[Declaration]: + included_declarations = [x.get_declarations() for x in self.includes] + own_declarations = [x for x in self.contents if isinstance(x, Declaration)] + all_declarations = included_declarations + [own_declarations] + return [x for l in all_declarations for x in l] + # noinspection PyPep8Naming,PyMethodMayBeStatic,PyUnusedLocal class ASTBuilder(NodeVisitor): @@ -605,6 +689,8 @@ class ASTBuilder(NodeVisitor): for op, term in suffix: if op.type == '+': base = AddExpression(base, term) + elif op.type == '-': + base = SubtractExpression(base, term) else: raise NotImplementedError('arithmetic suffix ' + op) return base diff --git a/crowbar_reference_compiler/ssagen.py b/crowbar_reference_compiler/ssagen.py index 508025e..b326239 100644 --- a/crowbar_reference_compiler/ssagen.py +++ b/crowbar_reference_compiler/ssagen.py @@ -1,9 +1,13 @@ +import dataclasses from dataclasses import dataclass from functools import singledispatch from typing import List from .ast import ImplementationFile, FunctionDefinition, ExpressionStatement, FunctionCallExpression, \ - VariableExpression, ConstantExpression, ReturnStatement, BasicType, IfStatement, ComparisonExpression, AddExpression + VariableExpression, ConstantExpression, ReturnStatement, BasicType, IfStatement, ComparisonExpression, \ + AddExpression, StructPointerElementExpression, Declaration, PointerType, StructDeclaration, VariableDefinition, \ + MultiplyExpression, LogicalNotExpression, DirectAssignment, UpdateAssignment, SizeofExpression, Expression, \ + ConstType, ArrayIndexExpression, ArrayType, NegativeExpression, SubtractExpression, AddressOfExpression @dataclass @@ -25,13 +29,14 @@ class SsaResult: @dataclass class CompileContext: + declarations: List[Declaration] next_data: int = 0 next_temp: int = 0 next_label: int = 0 def build_ssa(file: ImplementationFile) -> str: - result = compile_to_ssa(file, CompileContext()) + result = compile_to_ssa(file, CompileContext(file.get_declarations())) data = '\n'.join(result.data) code = '\n'.join(result.code) return data + '\n\n' + code @@ -53,12 +58,20 @@ def _(target: ImplementationFile, context: CompileContext): @compile_to_ssa.register def _(target: FunctionDefinition, context: CompileContext) -> SsaResult: result = SsaResult([], []) + context = dataclasses.replace(context, declarations=target.args + context.declarations) for statement in target.body: result += compile_to_ssa(statement, context) + if isinstance(statement, Declaration): + context = dataclasses.replace(context, declarations=[statement]+context.declarations) + if not result.code[-1].startswith('ret'): + result.code.append('ret') code = [' ' + instr for instr in result.code] - assert len(target.args) == 0 - assert target.return_type == BasicType('int32') - code = [f"export function w ${target.name}() {{", "@start", *code, "}"] + # TODO types + args = ','.join(f"l %{x.name}" for x in target.args) + ret_type = '' + if target.return_type != BasicType('void'): + ret_type = 'l' + code = [f"export function {ret_type} ${target.name}({args}) {{", "@start", *code, "}"] return SsaResult(result.data, code) @@ -82,14 +95,28 @@ def _(target: FunctionCallExpression, context: CompileContext) -> SsaResult: @compile_to_ssa.register def _(target: ConstantExpression, context: CompileContext) -> SsaResult: - if target.value.startswith('"'): + if target.type(context.declarations) == PointerType(ConstType(BasicType('char'))): data_dest = context.next_data context.next_data += 1 data = [f"data $data{data_dest} = {{ b {target.value}, b 0 }}"] temp = context.next_temp context.next_temp += 1 code = [f"%t{temp} =l copy $data{data_dest}"] - else: + elif target.type(context.declarations) == BasicType('char'): + data = [] + temp = context.next_temp + context.next_temp += 1 + code = [f"%t{temp} =l copy {ord(target.value[1])}"] # TODO handle escape sequences + elif target.type(context.declarations) == BasicType('bool'): + data = [] + temp = context.next_temp + context.next_temp += 1 + if target.value == 'true': + value = 1 + else: + value = 0 + code = [f"%t{temp} =l copy {value}"] + elif target.type(context.declarations) == BasicType('int?'): assert not target.value.startswith('0b') assert not target.value.startswith('0B') assert not target.value.startswith('0o') @@ -102,7 +129,9 @@ def _(target: ConstantExpression, context: CompileContext) -> SsaResult: data = [] temp = context.next_temp context.next_temp += 1 - code = [f"%t{temp} =w copy {target.value}"] + code = [f"%t{temp} =l copy {target.value}"] + else: + raise NotImplementedError('compiling ' + str(target)) return SsaResult(data, code) @@ -130,11 +159,14 @@ def _(target: IfStatement, context: CompileContext) -> SsaResult: result.code.append(f"@l{true_label}") for statement in target.then: result += compile_to_ssa(statement, context) - result.code.append(f"jmp @l{after_label}") + if not result.code[-1].startswith('ret'): + result.code.append(f"jmp @l{after_label}") result.code.append(f"@l{false_label}") - for statement in target.els: - result += compile_to_ssa(statement, context) - result.code.append(f"jmp @l{after_label}") + if target.els is not None: + for statement in target.els: + result += compile_to_ssa(statement, context) + if not result.code[-1].startswith('ret'): + result.code.append(f"jmp @l{after_label}") result.code.append(f"@l{after_label}") return result @@ -147,10 +179,16 @@ def _(target: ComparisonExpression, context: CompileContext) -> SsaResult: value2_dest = context.next_temp - 1 result_dest = context.next_temp context.next_temp += 1 + # TODO types, and signedness if target.op == '==': - result.code.append(f"%t{result_dest} =w ceq %t{value1_dest}, %t{value2_dest}") + op = "ceqw" + elif target.op == '>=': + op = "cugew" + elif target.op == '<=': + op = "culew" else: raise NotImplementedError('comparison ' + target.op) + result.code.append(f"%t{result_dest} =l {op} %t{value1_dest}, %t{value2_dest}") return result @@ -163,7 +201,33 @@ def _(target: AddExpression, context: CompileContext) -> SsaResult: result_reg = context.next_temp context.next_temp += 1 # TODO make sure the types are correct - result.code.append(f"%t{result_reg} =w add %t{value1_dest}, %t{value2_dest}") + result.code.append(f"%t{result_reg} =l add %t{value1_dest}, %t{value2_dest}") + return result + + +@compile_to_ssa.register +def _(target: SubtractExpression, context: CompileContext) -> SsaResult: + result = compile_to_ssa(target.term1, context) + value1_dest = context.next_temp - 1 + result += compile_to_ssa(target.term2, context) + value2_dest = context.next_temp - 1 + result_reg = context.next_temp + context.next_temp += 1 + # TODO make sure the types are correct + result.code.append(f"%t{result_reg} =l sub %t{value1_dest}, %t{value2_dest}") + return result + + +@compile_to_ssa.register +def _(target: MultiplyExpression, context: CompileContext) -> SsaResult: + result = compile_to_ssa(target.factor1, context) + value1_dest = context.next_temp - 1 + result += compile_to_ssa(target.factor2, context) + value2_dest = context.next_temp - 1 + result_reg = context.next_temp + context.next_temp += 1 + # TODO make sure the types are correct + result.code.append(f"%t{result_reg} =l mul %t{value1_dest}, %t{value2_dest}") return result @@ -172,4 +236,131 @@ def _(target: VariableExpression, context: CompileContext) -> SsaResult: # TODO make sure any of this is reasonable result = context.next_temp context.next_temp += 1 - return SsaResult([], [f"%t{result} =w copy %{target.name}"]) + return SsaResult([], [f"%t{result} =l copy %{target.name}"]) + + +@compile_to_ssa.register +def _(target: VariableDefinition, context: CompileContext) -> SsaResult: + # TODO figure some shit out + result = compile_to_ssa(target.value, context) + result_dest = context.next_temp - 1 + result.code.append(f"%{target.name} =l copy %t{result_dest}") + return result + + +@compile_to_ssa.register +def _(target: LogicalNotExpression, context: CompileContext) -> SsaResult: + result = compile_to_ssa(target.body, context) + inner_result_dest = context.next_temp - 1 + result_dest = context.next_temp + context.next_temp += 1 + result.code.append(f"%t{result_dest} =l ceqw %t{inner_result_dest}, 0") + return result + + +@compile_to_ssa.register +def _(target: NegativeExpression, context: CompileContext) -> SsaResult: + return compile_to_ssa(SubtractExpression(ConstantExpression('0'), target.body), context) + + +@compile_to_ssa.register +def _(target: ArrayIndexExpression, context: CompileContext) -> SsaResult: + result = compile_to_ssa(target.array, context) + base = context.next_temp - 1 + result += compile_to_ssa(target.index, context) + index = context.next_temp - 1 + array_type = target.array.type(context.declarations) + if isinstance(array_type, PointerType): + array_type = array_type.target + assert isinstance(array_type, ArrayType) + content_type = array_type.contents + scale = content_type.size_bytes(context.declarations) + offset = context.next_temp + context.next_temp += 1 + address = context.next_temp + context.next_temp += 1 + dest = context.next_temp + context.next_temp += 1 + # TODO types + result.code.append(f"%t{offset} =l mul %t{index}, {scale}") + result.code.append(f"%t{address} =l add %t{base}, %t{offset}") + result.code.append(f"%t{dest} =l loadsw %t{address}") + return result + + +@compile_to_ssa.register +def _(target: StructPointerElementExpression, context: CompileContext) -> SsaResult: + result = compile_to_ssa(target.base, context) + base_dest = context.next_temp - 1 + # hoooo boy. + base_type = target.base.type(context.declarations) + assert isinstance(base_type, PointerType) + assert isinstance(base_type.target, BasicType) + hopefully_struct, struct_name = base_type.target.name.split(' ') + assert hopefully_struct == 'struct' + target_struct = None + for decl in context.declarations: + if isinstance(decl, StructDeclaration) and decl.name == struct_name: + if decl.fields is None: + raise KeyError('struct ' + struct_name + ' is opaque') + target_struct = decl + break + if target_struct is None: + raise KeyError('struct ' + struct_name + ' not found') + offset = 0 + for field in target_struct.fields: + if field.name == target.element: + break + else: + offset += field.type.size_bytes(context.declarations) + temp = context.next_temp + context.next_temp += 1 + result_dest = context.next_temp + context.next_temp += 1 + # TODO types + result.code.append(f"%t{temp} =l add %t{base_dest}, {offset}") + result.code.append(f"%t{result_dest} =l loadsw %t{temp}") + return result + + +@compile_to_ssa.register +def _(target: AddressOfExpression, context: CompileContext) -> SsaResult: + if isinstance(target.body, StructPointerElementExpression) or isinstance(target.body, ArrayIndexExpression): + result = compile_to_ssa(target.body, context) + result.code.pop() + context.next_temp -= 1 + else: + raise NotImplementedError('address of ' + str(type(target.body))) + return result + + +@compile_to_ssa.register +def _(target: DirectAssignment, context: CompileContext) -> SsaResult: + result = compile_to_ssa(target.value, context) + result_dest = context.next_temp - 1 + if isinstance(target.destination, VariableExpression): + raise NotImplementedError('assign directly to variable') + elif isinstance(target.destination, StructPointerElementExpression) or isinstance(target.destination, ArrayIndexExpression): + sub_result = compile_to_ssa(target.destination, context) + last_instr = sub_result.code.pop() + _, _, _, location = last_instr.split(' ') + # TODO type + sub_result.code.append(f"storew %t{result_dest}, {location}") + result += sub_result + else: + raise NotImplementedError('assign to ' + str(type(target.destination))) + return result + + +@compile_to_ssa.register +def _(target: UpdateAssignment, context: CompileContext) -> SsaResult: + return compile_to_ssa(target.deconstruct(), context) + + +@compile_to_ssa.register +def _(target: SizeofExpression, context: CompileContext) -> SsaResult: + target = target.body + if isinstance(target, Expression): + target = target.type(context.declarations) + size = target.size_bytes(context.declarations) + return compile_to_ssa(ConstantExpression(str(size)), context) -- cgit v1.2.3