import dataclasses from dataclasses import dataclass from functools import singledispatch from typing import List from .ast import ImplementationFile, FunctionDefinition, ExpressionStatement, FunctionCallExpression, \ VariableExpression, ConstantExpression, ReturnStatement, BasicType, IfStatement, ComparisonExpression, \ AddExpression, StructPointerElementExpression, Declaration, PointerType, StructDeclaration, VariableDefinition, \ MultiplyExpression, LogicalNotExpression, DirectAssignment, UpdateAssignment, SizeofExpression, Expression, \ ConstType, ArrayIndexExpression, ArrayType, NegativeExpression, SubtractExpression, AddressOfExpression, \ WhileStatement @dataclass class SsaResult: data: List[str] code: List[str] def __add__(self, other: 'SsaResult') -> 'SsaResult': if not isinstance(other, SsaResult): return NotImplemented return SsaResult(self.data + other.data, self.code + other.code) def __radd__(self, other: 'SsaResult'): if not isinstance(other, SsaResult): return NotImplemented self.data += other.data self.code += other.code @dataclass class CompileContext: declarations: List[Declaration] next_data: int = 0 next_temp: int = 0 next_label: int = 0 def build_ssa(file: ImplementationFile) -> str: result = compile_to_ssa(file, CompileContext(file.get_declarations())) data = '\n'.join(result.data) code = '\n'.join(result.code) return data + '\n\n' + code @singledispatch def compile_to_ssa(target, context: CompileContext) -> SsaResult: raise NotImplementedError('unannotated compile on ' + str(type(target))) @compile_to_ssa.register def _(target: ImplementationFile, context: CompileContext): result = SsaResult([], []) for target in target.contents: result += compile_to_ssa(target, context) return result @compile_to_ssa.register def _(target: FunctionDefinition, context: CompileContext) -> SsaResult: result = SsaResult([], []) context = dataclasses.replace(context, declarations=target.args + context.declarations) for statement in target.body: result += compile_to_ssa(statement, context) if isinstance(statement, Declaration): context = dataclasses.replace(context, declarations=[statement]+context.declarations) if not result.code[-1].startswith('ret'): result.code.append('ret') code = [' ' + instr for instr in result.code] # TODO types args = ','.join(f"l %{x.name}" for x in target.args) ret_type = '' if target.return_type != BasicType('void'): ret_type = 'l' code = [f"export function {ret_type} ${target.name}({args}) {{", "@start", *code, "}"] return SsaResult(result.data, code) @compile_to_ssa.register def _(target: ExpressionStatement, context: CompileContext) -> SsaResult: return compile_to_ssa(target.body, context) @compile_to_ssa.register def _(target: FunctionCallExpression, context: CompileContext) -> SsaResult: assert isinstance(target.function, VariableExpression) result = SsaResult([], []) args = [] for i, expr in enumerate(target.arguments): result += compile_to_ssa(expr, context) arg_dest = context.next_temp - 1 args += [f"l %t{arg_dest}"] result_dest = context.next_temp context.next_temp += 1 # TODO size result.code.append(f"%t{result_dest} =l call ${target.function.name}({','.join(args)})") return result @compile_to_ssa.register def _(target: ConstantExpression, context: CompileContext) -> SsaResult: if target.type(context.declarations) == PointerType(ConstType(BasicType('char'))): data_dest = context.next_data context.next_data += 1 data = [f"data $data{data_dest} = {{ b {target.value}, b 0 }}"] temp = context.next_temp context.next_temp += 1 code = [f"%t{temp} =l copy $data{data_dest}"] elif target.type(context.declarations) == BasicType('char'): value = target.value.strip("'") if len(value) == 1: value = ord(value) elif value == r'\0': value = 0 else: raise NotImplementedError('escape sequence ' + value) data = [] temp = context.next_temp context.next_temp += 1 code = [f"%t{temp} =l copy {value}"] elif target.type(context.declarations) == BasicType('bool'): data = [] temp = context.next_temp context.next_temp += 1 if target.value == 'true': value = 1 else: value = 0 code = [f"%t{temp} =l copy {value}"] elif target.type(context.declarations) == BasicType('int?'): assert not target.value.startswith('0b') assert not target.value.startswith('0B') assert not target.value.startswith('0o') assert not target.value.startswith('0x') assert not target.value.startswith('0X') assert not target.value.startswith('0f') assert not target.value.startswith('0F') assert '.' not in target.value assert not target.value.startswith("'") data = [] temp = context.next_temp context.next_temp += 1 code = [f"%t{temp} =l copy {target.value}"] else: raise NotImplementedError('compiling ' + str(target)) return SsaResult(data, code) @compile_to_ssa.register def _(target: ReturnStatement, context: CompileContext) -> SsaResult: if target.body is None: return SsaResult([], ['ret']) result = compile_to_ssa(target.body, context) ret_val_dest = context.next_temp - 1 result.code.append(f"ret %t{ret_val_dest}") return result @compile_to_ssa.register def _(target: IfStatement, context: CompileContext) -> SsaResult: result = compile_to_ssa(target.condition, context) condition_dest = context.next_temp - 1 true_label = context.next_label context.next_label += 1 false_label = context.next_label context.next_label += 1 after_label = context.next_label context.next_label += 1 result.code.append(f"jnz %t{condition_dest}, @l{true_label}, @l{false_label}") result.code.append(f"@l{true_label}") for statement in target.then: result += compile_to_ssa(statement, context) if not result.code[-1].startswith('ret'): result.code.append(f"jmp @l{after_label}") result.code.append(f"@l{false_label}") if target.els is not None: for statement in target.els: result += compile_to_ssa(statement, context) if not result.code[-1].startswith('ret'): result.code.append(f"jmp @l{after_label}") result.code.append(f"@l{after_label}") return result @compile_to_ssa.register def _(target: WhileStatement, context: CompileContext) -> SsaResult: start_label = context.next_label context.next_label += 1 result = SsaResult([], [f"@l{start_label}"]) body_label = context.next_label context.next_label += 1 end_label = context.next_label context.next_label += 1 result += compile_to_ssa(target.condition, context) condition_dest = context.next_temp - 1 result.code.append(f"jnz %t{condition_dest}, @l{body_label}, @l{end_label}") result.code.append(f"@l{body_label}") for statement in target.body: result += compile_to_ssa(statement, context) result.code.append(f"jmp @l{start_label}") result.code.append(f"@l{end_label}") return result @compile_to_ssa.register def _(target: ComparisonExpression, context: CompileContext) -> SsaResult: result = compile_to_ssa(target.value1, context) value1_dest = context.next_temp - 1 result += compile_to_ssa(target.value2, context) value2_dest = context.next_temp - 1 result_dest = context.next_temp context.next_temp += 1 # TODO types, and signedness if target.op == '==': op = "ceqw" elif target.op == '>=': op = "cugew" elif target.op == '<=': op = "culew" else: raise NotImplementedError('comparison ' + target.op) result.code.append(f"%t{result_dest} =l {op} %t{value1_dest}, %t{value2_dest}") return result @compile_to_ssa.register def _(target: AddExpression, context: CompileContext) -> SsaResult: result = compile_to_ssa(target.term1, context) value1_dest = context.next_temp - 1 result += compile_to_ssa(target.term2, context) value2_dest = context.next_temp - 1 result_reg = context.next_temp context.next_temp += 1 # TODO make sure the types are correct result.code.append(f"%t{result_reg} =l add %t{value1_dest}, %t{value2_dest}") return result @compile_to_ssa.register def _(target: SubtractExpression, context: CompileContext) -> SsaResult: result = compile_to_ssa(target.term1, context) value1_dest = context.next_temp - 1 result += compile_to_ssa(target.term2, context) value2_dest = context.next_temp - 1 result_reg = context.next_temp context.next_temp += 1 # TODO make sure the types are correct result.code.append(f"%t{result_reg} =l sub %t{value1_dest}, %t{value2_dest}") return result @compile_to_ssa.register def _(target: MultiplyExpression, context: CompileContext) -> SsaResult: result = compile_to_ssa(target.factor1, context) value1_dest = context.next_temp - 1 result += compile_to_ssa(target.factor2, context) value2_dest = context.next_temp - 1 result_reg = context.next_temp context.next_temp += 1 # TODO make sure the types are correct result.code.append(f"%t{result_reg} =l mul %t{value1_dest}, %t{value2_dest}") return result @compile_to_ssa.register def _(target: VariableExpression, context: CompileContext) -> SsaResult: # TODO make sure any of this is reasonable result = context.next_temp context.next_temp += 1 return SsaResult([], [f"%t{result} =l copy %{target.name}"]) @compile_to_ssa.register def _(target: VariableDefinition, context: CompileContext) -> SsaResult: # TODO figure some shit out result = compile_to_ssa(target.value, context) result_dest = context.next_temp - 1 result.code.append(f"%{target.name} =l copy %t{result_dest}") return result @compile_to_ssa.register def _(target: LogicalNotExpression, context: CompileContext) -> SsaResult: result = compile_to_ssa(target.body, context) inner_result_dest = context.next_temp - 1 result_dest = context.next_temp context.next_temp += 1 result.code.append(f"%t{result_dest} =l ceqw %t{inner_result_dest}, 0") return result @compile_to_ssa.register def _(target: NegativeExpression, context: CompileContext) -> SsaResult: return compile_to_ssa(SubtractExpression(ConstantExpression('0'), target.body), context) @compile_to_ssa.register def _(target: ArrayIndexExpression, context: CompileContext) -> SsaResult: result = compile_to_ssa(target.array, context) base = context.next_temp - 1 result += compile_to_ssa(target.index, context) index = context.next_temp - 1 array_type = target.array.type(context.declarations) if isinstance(array_type, PointerType): array_type = array_type.target assert isinstance(array_type, ArrayType) content_type = array_type.contents scale = content_type.size_bytes(context.declarations) offset = context.next_temp context.next_temp += 1 address = context.next_temp context.next_temp += 1 dest = context.next_temp context.next_temp += 1 # TODO types result.code.append(f"%t{offset} =l mul %t{index}, {scale}") result.code.append(f"%t{address} =l add %t{base}, %t{offset}") result.code.append(f"%t{dest} =l loadl %t{address}") return result @compile_to_ssa.register def _(target: StructPointerElementExpression, context: CompileContext) -> SsaResult: result = compile_to_ssa(target.base, context) base_dest = context.next_temp - 1 # hoooo boy. base_type = target.base.type(context.declarations) assert isinstance(base_type, PointerType) assert isinstance(base_type.target, BasicType) hopefully_struct, struct_name = base_type.target.name.split(' ') assert hopefully_struct == 'struct' target_struct = None for decl in context.declarations: if isinstance(decl, StructDeclaration) and decl.name == struct_name: if decl.fields is None: raise KeyError('struct ' + struct_name + ' is opaque') target_struct = decl break if target_struct is None: raise KeyError('struct ' + struct_name + ' not found') offset = 0 for field in target_struct.fields: if field.name == target.element: break else: offset += field.type.size_bytes(context.declarations) temp = context.next_temp context.next_temp += 1 result_dest = context.next_temp context.next_temp += 1 # TODO types result.code.append(f"%t{temp} =l add %t{base_dest}, {offset}") result.code.append(f"%t{result_dest} =l loadl %t{temp}") return result @compile_to_ssa.register def _(target: AddressOfExpression, context: CompileContext) -> SsaResult: if isinstance(target.body, StructPointerElementExpression) or isinstance(target.body, ArrayIndexExpression): result = compile_to_ssa(target.body, context) result.code.pop() context.next_temp -= 1 else: raise NotImplementedError('address of ' + str(type(target.body))) return result @compile_to_ssa.register def _(target: DirectAssignment, context: CompileContext) -> SsaResult: result = compile_to_ssa(target.value, context) result_dest = context.next_temp - 1 if isinstance(target.destination, VariableExpression): raise NotImplementedError('assign directly to variable') elif isinstance(target.destination, StructPointerElementExpression) or isinstance(target.destination, ArrayIndexExpression): sub_result = compile_to_ssa(target.destination, context) last_instr = sub_result.code.pop() _, _, _, location = last_instr.split(' ') # TODO type sub_result.code.append(f"storel %t{result_dest}, {location}") result += sub_result else: raise NotImplementedError('assign to ' + str(type(target.destination))) return result @compile_to_ssa.register def _(target: UpdateAssignment, context: CompileContext) -> SsaResult: return compile_to_ssa(target.deconstruct(), context) @compile_to_ssa.register def _(target: SizeofExpression, context: CompileContext) -> SsaResult: target = target.body if isinstance(target, Expression): target = target.type(context.declarations) size = target.size_bytes(context.declarations) return compile_to_ssa(ConstantExpression(str(size)), context)