From bd933eeef043e0bc5e3ddd6557c9884593b59d3b Mon Sep 17 00:00:00 2001 From: Melody Horn Date: Thu, 5 Nov 2020 21:01:38 -0700 Subject: add some more instructions --- crowbar_reference_compiler/ssagen.py | 110 +++++++++++++++++++++++++++-------- 1 file changed, 86 insertions(+), 24 deletions(-) diff --git a/crowbar_reference_compiler/ssagen.py b/crowbar_reference_compiler/ssagen.py index 3c12c85..508025e 100644 --- a/crowbar_reference_compiler/ssagen.py +++ b/crowbar_reference_compiler/ssagen.py @@ -1,9 +1,9 @@ from dataclasses import dataclass from functools import singledispatch -from typing import Dict, List +from typing import List from .ast import ImplementationFile, FunctionDefinition, ExpressionStatement, FunctionCallExpression, \ - VariableExpression, ConstantExpression, ReturnStatement, BasicType + VariableExpression, ConstantExpression, ReturnStatement, BasicType, IfStatement, ComparisonExpression, AddExpression @dataclass @@ -11,11 +11,23 @@ class SsaResult: data: List[str] code: List[str] + def __add__(self, other: 'SsaResult') -> 'SsaResult': + if not isinstance(other, SsaResult): + return NotImplemented + return SsaResult(self.data + other.data, self.code + other.code) + + def __radd__(self, other: 'SsaResult'): + if not isinstance(other, SsaResult): + return NotImplemented + self.data += other.data + self.code += other.code + @dataclass class CompileContext: next_data: int = 0 next_temp: int = 0 + next_label: int = 0 def build_ssa(file: ImplementationFile) -> str: @@ -32,28 +44,22 @@ def compile_to_ssa(target, context: CompileContext) -> SsaResult: @compile_to_ssa.register def _(target: ImplementationFile, context: CompileContext): - data = [] - code = [] + result = SsaResult([], []) for target in target.contents: - result = compile_to_ssa(target, context) - data += result.data - code += result.code - return SsaResult(data, code) + result += compile_to_ssa(target, context) + return result @compile_to_ssa.register def _(target: FunctionDefinition, context: CompileContext) -> SsaResult: - data = [] - code = [] + result = SsaResult([], []) for statement in target.body: - result = compile_to_ssa(statement, context) - data += result.data - code += result.code - code = [' ' + instr for instr in code] + result += compile_to_ssa(statement, context) + code = [' ' + instr for instr in result.code] assert len(target.args) == 0 assert target.return_type == BasicType('int32') code = [f"export function w ${target.name}() {{", "@start", *code, "}"] - return SsaResult(data, code) + return SsaResult(result.data, code) @compile_to_ssa.register @@ -64,17 +70,14 @@ def _(target: ExpressionStatement, context: CompileContext) -> SsaResult: @compile_to_ssa.register def _(target: FunctionCallExpression, context: CompileContext) -> SsaResult: assert isinstance(target.function, VariableExpression) - data = [] - code = [] + result = SsaResult([], []) args = [] for i, expr in enumerate(target.arguments): - arg_dest = context.next_temp - result = compile_to_ssa(expr, context) - data += result.data - code += result.code + result += compile_to_ssa(expr, context) + arg_dest = context.next_temp - 1 args += [f"l %t{arg_dest}"] - code += [f"call ${target.function.name}({','.join(args)}, ...)"] - return SsaResult(data, code) + result.code.append(f"call ${target.function.name}({','.join(args)}, ...)") + return result @compile_to_ssa.register @@ -107,7 +110,66 @@ def _(target: ConstantExpression, context: CompileContext) -> SsaResult: def _(target: ReturnStatement, context: CompileContext) -> SsaResult: if target.body is None: return SsaResult([], ['ret']) - ret_val_dest = context.next_temp result = compile_to_ssa(target.body, context) + ret_val_dest = context.next_temp - 1 result.code.append(f"ret %t{ret_val_dest}") return result + + +@compile_to_ssa.register +def _(target: IfStatement, context: CompileContext) -> SsaResult: + result = compile_to_ssa(target.condition, context) + condition_dest = context.next_temp - 1 + true_label = context.next_label + context.next_label += 1 + false_label = context.next_label + context.next_label += 1 + after_label = context.next_label + context.next_label += 1 + result.code.append(f"jnz %t{condition_dest}, @l{true_label}, @l{false_label}") + result.code.append(f"@l{true_label}") + for statement in target.then: + result += compile_to_ssa(statement, context) + result.code.append(f"jmp @l{after_label}") + result.code.append(f"@l{false_label}") + for statement in target.els: + result += compile_to_ssa(statement, context) + result.code.append(f"jmp @l{after_label}") + result.code.append(f"@l{after_label}") + return result + + +@compile_to_ssa.register +def _(target: ComparisonExpression, context: CompileContext) -> SsaResult: + result = compile_to_ssa(target.value1, context) + value1_dest = context.next_temp - 1 + result += compile_to_ssa(target.value2, context) + value2_dest = context.next_temp - 1 + result_dest = context.next_temp + context.next_temp += 1 + if target.op == '==': + result.code.append(f"%t{result_dest} =w ceq %t{value1_dest}, %t{value2_dest}") + else: + raise NotImplementedError('comparison ' + target.op) + return result + + +@compile_to_ssa.register +def _(target: AddExpression, context: CompileContext) -> SsaResult: + result = compile_to_ssa(target.term1, context) + value1_dest = context.next_temp - 1 + result += compile_to_ssa(target.term2, context) + value2_dest = context.next_temp - 1 + result_reg = context.next_temp + context.next_temp += 1 + # TODO make sure the types are correct + result.code.append(f"%t{result_reg} =w add %t{value1_dest}, %t{value2_dest}") + return result + + +@compile_to_ssa.register +def _(target: VariableExpression, context: CompileContext) -> SsaResult: + # TODO make sure any of this is reasonable + result = context.next_temp + context.next_temp += 1 + return SsaResult([], [f"%t{result} =w copy %{target.name}"]) -- cgit v1.2.3