aboutsummaryrefslogtreecommitdiff
path: root/crowbar_reference_compiler
diff options
context:
space:
mode:
Diffstat (limited to 'crowbar_reference_compiler')
-rw-r--r--crowbar_reference_compiler/ssagen.py110
1 files changed, 86 insertions, 24 deletions
diff --git a/crowbar_reference_compiler/ssagen.py b/crowbar_reference_compiler/ssagen.py
index 3c12c85..508025e 100644
--- a/crowbar_reference_compiler/ssagen.py
+++ b/crowbar_reference_compiler/ssagen.py
@@ -1,9 +1,9 @@
from dataclasses import dataclass
from functools import singledispatch
-from typing import Dict, List
+from typing import List
from .ast import ImplementationFile, FunctionDefinition, ExpressionStatement, FunctionCallExpression, \
- VariableExpression, ConstantExpression, ReturnStatement, BasicType
+ VariableExpression, ConstantExpression, ReturnStatement, BasicType, IfStatement, ComparisonExpression, AddExpression
@dataclass
@@ -11,11 +11,23 @@ class SsaResult:
data: List[str]
code: List[str]
+ def __add__(self, other: 'SsaResult') -> 'SsaResult':
+ if not isinstance(other, SsaResult):
+ return NotImplemented
+ return SsaResult(self.data + other.data, self.code + other.code)
+
+ def __radd__(self, other: 'SsaResult'):
+ if not isinstance(other, SsaResult):
+ return NotImplemented
+ self.data += other.data
+ self.code += other.code
+
@dataclass
class CompileContext:
next_data: int = 0
next_temp: int = 0
+ next_label: int = 0
def build_ssa(file: ImplementationFile) -> str:
@@ -32,28 +44,22 @@ def compile_to_ssa(target, context: CompileContext) -> SsaResult:
@compile_to_ssa.register
def _(target: ImplementationFile, context: CompileContext):
- data = []
- code = []
+ result = SsaResult([], [])
for target in target.contents:
- result = compile_to_ssa(target, context)
- data += result.data
- code += result.code
- return SsaResult(data, code)
+ result += compile_to_ssa(target, context)
+ return result
@compile_to_ssa.register
def _(target: FunctionDefinition, context: CompileContext) -> SsaResult:
- data = []
- code = []
+ result = SsaResult([], [])
for statement in target.body:
- result = compile_to_ssa(statement, context)
- data += result.data
- code += result.code
- code = [' ' + instr for instr in code]
+ result += compile_to_ssa(statement, context)
+ code = [' ' + instr for instr in result.code]
assert len(target.args) == 0
assert target.return_type == BasicType('int32')
code = [f"export function w ${target.name}() {{", "@start", *code, "}"]
- return SsaResult(data, code)
+ return SsaResult(result.data, code)
@compile_to_ssa.register
@@ -64,17 +70,14 @@ def _(target: ExpressionStatement, context: CompileContext) -> SsaResult:
@compile_to_ssa.register
def _(target: FunctionCallExpression, context: CompileContext) -> SsaResult:
assert isinstance(target.function, VariableExpression)
- data = []
- code = []
+ result = SsaResult([], [])
args = []
for i, expr in enumerate(target.arguments):
- arg_dest = context.next_temp
- result = compile_to_ssa(expr, context)
- data += result.data
- code += result.code
+ result += compile_to_ssa(expr, context)
+ arg_dest = context.next_temp - 1
args += [f"l %t{arg_dest}"]
- code += [f"call ${target.function.name}({','.join(args)}, ...)"]
- return SsaResult(data, code)
+ result.code.append(f"call ${target.function.name}({','.join(args)}, ...)")
+ return result
@compile_to_ssa.register
@@ -107,7 +110,66 @@ def _(target: ConstantExpression, context: CompileContext) -> SsaResult:
def _(target: ReturnStatement, context: CompileContext) -> SsaResult:
if target.body is None:
return SsaResult([], ['ret'])
- ret_val_dest = context.next_temp
result = compile_to_ssa(target.body, context)
+ ret_val_dest = context.next_temp - 1
result.code.append(f"ret %t{ret_val_dest}")
return result
+
+
+@compile_to_ssa.register
+def _(target: IfStatement, context: CompileContext) -> SsaResult:
+ result = compile_to_ssa(target.condition, context)
+ condition_dest = context.next_temp - 1
+ true_label = context.next_label
+ context.next_label += 1
+ false_label = context.next_label
+ context.next_label += 1
+ after_label = context.next_label
+ context.next_label += 1
+ result.code.append(f"jnz %t{condition_dest}, @l{true_label}, @l{false_label}")
+ result.code.append(f"@l{true_label}")
+ for statement in target.then:
+ result += compile_to_ssa(statement, context)
+ result.code.append(f"jmp @l{after_label}")
+ result.code.append(f"@l{false_label}")
+ for statement in target.els:
+ result += compile_to_ssa(statement, context)
+ result.code.append(f"jmp @l{after_label}")
+ result.code.append(f"@l{after_label}")
+ return result
+
+
+@compile_to_ssa.register
+def _(target: ComparisonExpression, context: CompileContext) -> SsaResult:
+ result = compile_to_ssa(target.value1, context)
+ value1_dest = context.next_temp - 1
+ result += compile_to_ssa(target.value2, context)
+ value2_dest = context.next_temp - 1
+ result_dest = context.next_temp
+ context.next_temp += 1
+ if target.op == '==':
+ result.code.append(f"%t{result_dest} =w ceq %t{value1_dest}, %t{value2_dest}")
+ else:
+ raise NotImplementedError('comparison ' + target.op)
+ return result
+
+
+@compile_to_ssa.register
+def _(target: AddExpression, context: CompileContext) -> SsaResult:
+ result = compile_to_ssa(target.term1, context)
+ value1_dest = context.next_temp - 1
+ result += compile_to_ssa(target.term2, context)
+ value2_dest = context.next_temp - 1
+ result_reg = context.next_temp
+ context.next_temp += 1
+ # TODO make sure the types are correct
+ result.code.append(f"%t{result_reg} =w add %t{value1_dest}, %t{value2_dest}")
+ return result
+
+
+@compile_to_ssa.register
+def _(target: VariableExpression, context: CompileContext) -> SsaResult:
+ # TODO make sure any of this is reasonable
+ result = context.next_temp
+ context.next_temp += 1
+ return SsaResult([], [f"%t{result} =w copy %{target.name}"])