from dataclasses import dataclass
from functools import singledispatch
from typing import List

from .ast import ImplementationFile, FunctionDefinition, ExpressionStatement, FunctionCallExpression, \
    VariableExpression, ConstantExpression, ReturnStatement, BasicType, IfStatement, ComparisonExpression, AddExpression


# Opcodes emitted by this module that end a QBE block. QBE requires a label (or the
# closing brace) right after any of them, and the last block of a function must end in one.
_BLOCK_TERMINATORS = frozenset({'ret', 'jmp', 'jnz'})


def _ends_block(code: List[str]) -> bool:
    """Return True if the last emitted line is already a block-terminating jump."""
    if not code:
        return False
    words = code[-1].split()
    return bool(words) and words[0] in _BLOCK_TERMINATORS


def _value_class(expr) -> str:
    """Return the QBE class ('l' or 'w') of the temp that compiling *expr* leaves behind.

    Mirrors the handlers below: string constants are copied into an ``l`` temp
    (``=l copy $dataN``); every other expression kind handled here writes a ``w`` temp.
    """
    if isinstance(expr, ConstantExpression) and expr.value.startswith('"'):
        return 'l'
    return 'w'


@dataclass
class SsaResult:
    """Output of compiling one AST node: top-level data definitions plus instruction lines."""
    data: List[str]  # top-level QBE `data` definitions (string literals)
    code: List[str]  # instruction / label lines, in emission order

    def __add__(self, other: 'SsaResult') -> 'SsaResult':
        if not isinstance(other, SsaResult):
            return NotImplemented
        return SsaResult(self.data + other.data, self.code + other.code)

    def __radd__(self, other: 'SsaResult') -> 'SsaResult':
        # Reflected form of ``other + self``: ``other``'s lines come first. Must return a
        # new object -- returning None would make the expression evaluate to None.
        if not isinstance(other, SsaResult):
            return NotImplemented
        return SsaResult(other.data + self.data, other.code + self.code)


@dataclass
class CompileContext:
    """Counters for generating fresh names while compiling one file."""
    next_data: int = 0   # next $dataN index
    next_temp: int = 0   # next %tN index; an expression's value is left in %t{next_temp - 1}
    next_label: int = 0  # next @lN index


def build_ssa(file: ImplementationFile) -> str:
    """Compile *file* to QBE IL text: all data definitions, a blank line, then all code."""
    result = compile_to_ssa(file, CompileContext())
    data = '\n'.join(result.data)
    code = '\n'.join(result.code)
    return data + '\n\n' + code


@singledispatch
def compile_to_ssa(target, context: CompileContext) -> SsaResult:
    """Compile one AST node; dispatched on the node's type via the registrations below."""
    raise NotImplementedError('unannotated compile on ' + str(type(target)))


@compile_to_ssa.register
def _(target: ImplementationFile, context: CompileContext):
    result = SsaResult([], [])
    for item in target.contents:
        result += compile_to_ssa(item, context)
    return result


@compile_to_ssa.register
def _(target: FunctionDefinition, context: CompileContext) -> SsaResult:
    """Emit an exported, argument-less function returning ``w``."""
    result = SsaResult([], [])
    for statement in target.body:
        result += compile_to_ssa(statement, context)
    assert len(target.args) == 0
    assert target.return_type == BasicType('int32')
    body = result.code
    if not _ends_block(body):
        # QBE rejects a function whose last block has no jump ("last block misses jump"):
        # an empty body, control falling off the end, or a trailing label after an if
        # whose branches both return. Close it with an explicit `ret 0`.
        body = [*body, 'ret 0']
    code = [f"export function w ${target.name}() {{", "@start", *('    ' + instr for instr in body), "}"]
    return SsaResult(result.data, code)


@compile_to_ssa.register
def _(target: ExpressionStatement, context: CompileContext) -> SsaResult:
    return compile_to_ssa(target.body, context)


@compile_to_ssa.register
def _(target: FunctionCallExpression, context: CompileContext) -> SsaResult:
    """Emit a call whose ``w`` result is stored in a fresh temp, like every other expression."""
    assert isinstance(target.function, VariableExpression)
    result = SsaResult([], [])
    args = []
    for expr in target.arguments:
        result += compile_to_ssa(expr, context)
        # Pass each argument with the class of the temp that holds it: QBE rejects a
        # `w` temp used as an `l` argument.
        args.append(f"{_value_class(expr)} %t{context.next_temp - 1}")
    # `...` is kept as the final entry (as before); joining it with the arguments
    # avoids emitting `(, ...)` when there are no arguments.
    args.append('...')
    call_dest = context.next_temp
    context.next_temp += 1
    result.code.append(f"%t{call_dest} =w call ${target.function.name}({', '.join(args)})")
    return result


@compile_to_ssa.register
def _(target: ConstantExpression, context: CompileContext) -> SsaResult:
    """Emit a string literal (as a NUL-terminated data item) or a decimal integer constant."""
    if target.value.startswith('"'):
        data_dest = context.next_data
        context.next_data += 1
        data = [f"data $data{data_dest} = {{ b {target.value}, b 0 }}"]
        temp = context.next_temp
        context.next_temp += 1
        code = [f"%t{temp} =l copy $data{data_dest}"]
    else:
        # Only plain decimal integers are supported so far.
        assert not target.value.startswith(('0b', '0B', '0o', '0x', '0X', '0f', '0F'))
        assert '.' not in target.value
        assert not target.value.startswith("'")
        data = []
        temp = context.next_temp
        context.next_temp += 1
        code = [f"%t{temp} =w copy {target.value}"]
    return SsaResult(data, code)


@compile_to_ssa.register
def _(target: ReturnStatement, context: CompileContext) -> SsaResult:
    if target.body is None:
        return SsaResult([], ['ret'])
    result = compile_to_ssa(target.body, context)
    ret_val_dest = context.next_temp - 1
    result.code.append(f"ret %t{ret_val_dest}")
    return result


@compile_to_ssa.register
def _(target: IfStatement, context: CompileContext) -> SsaResult:
    """Emit a jnz diamond: condition, then-block, else-block, join label."""
    result = compile_to_ssa(target.condition, context)
    condition_dest = context.next_temp - 1
    true_label = context.next_label
    context.next_label += 1
    false_label = context.next_label
    context.next_label += 1
    after_label = context.next_label
    context.next_label += 1
    result.code.append(f"jnz %t{condition_dest}, @l{true_label}, @l{false_label}")
    result.code.append(f"@l{true_label}")
    for statement in target.then:
        result += compile_to_ssa(statement, context)
    # A branch that already ended in `ret` must not get a trailing `jmp`:
    # QBE requires a label immediately after any jump.
    if not _ends_block(result.code):
        result.code.append(f"jmp @l{after_label}")
    result.code.append(f"@l{false_label}")
    for statement in target.els:
        result += compile_to_ssa(statement, context)
    if not _ends_block(result.code):
        result.code.append(f"jmp @l{after_label}")
    result.code.append(f"@l{after_label}")
    return result


@compile_to_ssa.register
def _(target: ComparisonExpression, context: CompileContext) -> SsaResult:
    result = compile_to_ssa(target.value1, context)
    value1_dest = context.next_temp - 1
    result += compile_to_ssa(target.value2, context)
    value2_dest = context.next_temp - 1
    result_dest = context.next_temp
    context.next_temp += 1
    if target.op == '==':
        result.code.append(f"%t{result_dest} =w ceq %t{value1_dest}, %t{value2_dest}")
    else:
        raise NotImplementedError('comparison ' + target.op)
    return result


@compile_to_ssa.register
def _(target: AddExpression, context: CompileContext) -> SsaResult:
    result = compile_to_ssa(target.term1, context)
    value1_dest = context.next_temp - 1
    result += compile_to_ssa(target.term2, context)
    value2_dest = context.next_temp - 1
    result_reg = context.next_temp
    context.next_temp += 1
    # TODO make sure the types are correct
    result.code.append(f"%t{result_reg} =w add %t{value1_dest}, %t{value2_dest}")
    return result


@compile_to_ssa.register
def _(target: VariableExpression, context: CompileContext) -> SsaResult:
    # TODO make sure any of this is reasonable
    result = context.next_temp
    context.next_temp += 1
    return SsaResult([], [f"%t{result} =w copy %{target.name}"])