aboutsummaryrefslogtreecommitdiff
path: root/crowbar_reference_compiler/ssagen.py
diff options
context:
space:
mode:
authorMelody Horn <melody@boringcactus.com>2020-12-23 00:46:52 -0700
committerMelody Horn <melody@boringcactus.com>2020-12-23 00:46:52 -0700
commitc7de22f575607a7966b6b592dbf81bd3f867a2e4 (patch)
tree2b554674bcbdfd808681ce2e5b4395148c78a865 /crowbar_reference_compiler/ssagen.py
parentbd933eeef043e0bc5e3ddd6557c9884593b59d3b (diff)
downloadreference-compiler-c7de22f575607a7966b6b592dbf81bd3f867a2e4.tar.gz
reference-compiler-c7de22f575607a7966b6b592dbf81bd3f867a2e4.zip
implement a bunch more stuff
Diffstat (limited to 'crowbar_reference_compiler/ssagen.py')
-rw-r--r--crowbar_reference_compiler/ssagen.py221
1 files changed, 206 insertions, 15 deletions
diff --git a/crowbar_reference_compiler/ssagen.py b/crowbar_reference_compiler/ssagen.py
index 508025e..b326239 100644
--- a/crowbar_reference_compiler/ssagen.py
+++ b/crowbar_reference_compiler/ssagen.py
@@ -1,9 +1,13 @@
+import dataclasses
from dataclasses import dataclass
from functools import singledispatch
from typing import List
from .ast import ImplementationFile, FunctionDefinition, ExpressionStatement, FunctionCallExpression, \
- VariableExpression, ConstantExpression, ReturnStatement, BasicType, IfStatement, ComparisonExpression, AddExpression
+ VariableExpression, ConstantExpression, ReturnStatement, BasicType, IfStatement, ComparisonExpression, \
+ AddExpression, StructPointerElementExpression, Declaration, PointerType, StructDeclaration, VariableDefinition, \
+ MultiplyExpression, LogicalNotExpression, DirectAssignment, UpdateAssignment, SizeofExpression, Expression, \
+ ConstType, ArrayIndexExpression, ArrayType, NegativeExpression, SubtractExpression, AddressOfExpression
@dataclass
@@ -25,13 +29,14 @@ class SsaResult:
@dataclass
class CompileContext:
+ declarations: List[Declaration]
next_data: int = 0
next_temp: int = 0
next_label: int = 0
def build_ssa(file: ImplementationFile) -> str:
- result = compile_to_ssa(file, CompileContext())
+ result = compile_to_ssa(file, CompileContext(file.get_declarations()))
data = '\n'.join(result.data)
code = '\n'.join(result.code)
return data + '\n\n' + code
@@ -53,12 +58,20 @@ def _(target: ImplementationFile, context: CompileContext):
@compile_to_ssa.register
def _(target: FunctionDefinition, context: CompileContext) -> SsaResult:
result = SsaResult([], [])
+ context = dataclasses.replace(context, declarations=target.args + context.declarations)
for statement in target.body:
result += compile_to_ssa(statement, context)
+ if isinstance(statement, Declaration):
+ context = dataclasses.replace(context, declarations=[statement]+context.declarations)
+ if not result.code[-1].startswith('ret'):
+ result.code.append('ret')
code = [' ' + instr for instr in result.code]
- assert len(target.args) == 0
- assert target.return_type == BasicType('int32')
- code = [f"export function w ${target.name}() {{", "@start", *code, "}"]
+ # TODO types
+ args = ','.join(f"l %{x.name}" for x in target.args)
+ ret_type = ''
+ if target.return_type != BasicType('void'):
+ ret_type = 'l'
+ code = [f"export function {ret_type} ${target.name}({args}) {{", "@start", *code, "}"]
return SsaResult(result.data, code)
@@ -82,14 +95,28 @@ def _(target: FunctionCallExpression, context: CompileContext) -> SsaResult:
@compile_to_ssa.register
def _(target: ConstantExpression, context: CompileContext) -> SsaResult:
- if target.value.startswith('"'):
+ if target.type(context.declarations) == PointerType(ConstType(BasicType('char'))):
data_dest = context.next_data
context.next_data += 1
data = [f"data $data{data_dest} = {{ b {target.value}, b 0 }}"]
temp = context.next_temp
context.next_temp += 1
code = [f"%t{temp} =l copy $data{data_dest}"]
- else:
+ elif target.type(context.declarations) == BasicType('char'):
+ data = []
+ temp = context.next_temp
+ context.next_temp += 1
+ code = [f"%t{temp} =l copy {ord(target.value[1])}"] # TODO handle escape sequences
+ elif target.type(context.declarations) == BasicType('bool'):
+ data = []
+ temp = context.next_temp
+ context.next_temp += 1
+ if target.value == 'true':
+ value = 1
+ else:
+ value = 0
+ code = [f"%t{temp} =l copy {value}"]
+ elif target.type(context.declarations) == BasicType('int?'):
assert not target.value.startswith('0b')
assert not target.value.startswith('0B')
assert not target.value.startswith('0o')
@@ -102,7 +129,9 @@ def _(target: ConstantExpression, context: CompileContext) -> SsaResult:
data = []
temp = context.next_temp
context.next_temp += 1
- code = [f"%t{temp} =w copy {target.value}"]
+ code = [f"%t{temp} =l copy {target.value}"]
+ else:
+ raise NotImplementedError('compiling ' + str(target))
return SsaResult(data, code)
@@ -130,11 +159,14 @@ def _(target: IfStatement, context: CompileContext) -> SsaResult:
result.code.append(f"@l{true_label}")
for statement in target.then:
result += compile_to_ssa(statement, context)
- result.code.append(f"jmp @l{after_label}")
+ if not result.code[-1].startswith('ret'):
+ result.code.append(f"jmp @l{after_label}")
result.code.append(f"@l{false_label}")
- for statement in target.els:
- result += compile_to_ssa(statement, context)
- result.code.append(f"jmp @l{after_label}")
+ if target.els is not None:
+ for statement in target.els:
+ result += compile_to_ssa(statement, context)
+ if not result.code[-1].startswith('ret'):
+ result.code.append(f"jmp @l{after_label}")
result.code.append(f"@l{after_label}")
return result
@@ -147,10 +179,16 @@ def _(target: ComparisonExpression, context: CompileContext) -> SsaResult:
value2_dest = context.next_temp - 1
result_dest = context.next_temp
context.next_temp += 1
+ # TODO types, and signedness
if target.op == '==':
- result.code.append(f"%t{result_dest} =w ceq %t{value1_dest}, %t{value2_dest}")
+ op = "ceqw"
+ elif target.op == '>=':
+ op = "cugew"
+ elif target.op == '<=':
+ op = "culew"
else:
raise NotImplementedError('comparison ' + target.op)
+ result.code.append(f"%t{result_dest} =l {op} %t{value1_dest}, %t{value2_dest}")
return result
@@ -163,7 +201,33 @@ def _(target: AddExpression, context: CompileContext) -> SsaResult:
result_reg = context.next_temp
context.next_temp += 1
# TODO make sure the types are correct
- result.code.append(f"%t{result_reg} =w add %t{value1_dest}, %t{value2_dest}")
+ result.code.append(f"%t{result_reg} =l add %t{value1_dest}, %t{value2_dest}")
+ return result
+
+
+@compile_to_ssa.register
+def _(target: SubtractExpression, context: CompileContext) -> SsaResult:
+ result = compile_to_ssa(target.term1, context)
+ value1_dest = context.next_temp - 1
+ result += compile_to_ssa(target.term2, context)
+ value2_dest = context.next_temp - 1
+ result_reg = context.next_temp
+ context.next_temp += 1
+ # TODO make sure the types are correct
+ result.code.append(f"%t{result_reg} =l sub %t{value1_dest}, %t{value2_dest}")
+ return result
+
+
+@compile_to_ssa.register
+def _(target: MultiplyExpression, context: CompileContext) -> SsaResult:
+ result = compile_to_ssa(target.factor1, context)
+ value1_dest = context.next_temp - 1
+ result += compile_to_ssa(target.factor2, context)
+ value2_dest = context.next_temp - 1
+ result_reg = context.next_temp
+ context.next_temp += 1
+ # TODO make sure the types are correct
+ result.code.append(f"%t{result_reg} =l mul %t{value1_dest}, %t{value2_dest}")
return result
@@ -172,4 +236,131 @@ def _(target: VariableExpression, context: CompileContext) -> SsaResult:
# TODO make sure any of this is reasonable
result = context.next_temp
context.next_temp += 1
- return SsaResult([], [f"%t{result} =w copy %{target.name}"])
+ return SsaResult([], [f"%t{result} =l copy %{target.name}"])
+
+
+@compile_to_ssa.register
+def _(target: VariableDefinition, context: CompileContext) -> SsaResult:
+ # TODO figure some shit out
+ result = compile_to_ssa(target.value, context)
+ result_dest = context.next_temp - 1
+ result.code.append(f"%{target.name} =l copy %t{result_dest}")
+ return result
+
+
+@compile_to_ssa.register
+def _(target: LogicalNotExpression, context: CompileContext) -> SsaResult:
+ result = compile_to_ssa(target.body, context)
+ inner_result_dest = context.next_temp - 1
+ result_dest = context.next_temp
+ context.next_temp += 1
+ result.code.append(f"%t{result_dest} =l ceqw %t{inner_result_dest}, 0")
+ return result
+
+
+@compile_to_ssa.register
+def _(target: NegativeExpression, context: CompileContext) -> SsaResult:
+ return compile_to_ssa(SubtractExpression(ConstantExpression('0'), target.body), context)
+
+
+@compile_to_ssa.register
+def _(target: ArrayIndexExpression, context: CompileContext) -> SsaResult:
+ result = compile_to_ssa(target.array, context)
+ base = context.next_temp - 1
+ result += compile_to_ssa(target.index, context)
+ index = context.next_temp - 1
+ array_type = target.array.type(context.declarations)
+ if isinstance(array_type, PointerType):
+ array_type = array_type.target
+ assert isinstance(array_type, ArrayType)
+ content_type = array_type.contents
+ scale = content_type.size_bytes(context.declarations)
+ offset = context.next_temp
+ context.next_temp += 1
+ address = context.next_temp
+ context.next_temp += 1
+ dest = context.next_temp
+ context.next_temp += 1
+ # TODO types
+ result.code.append(f"%t{offset} =l mul %t{index}, {scale}")
+ result.code.append(f"%t{address} =l add %t{base}, %t{offset}")
+ result.code.append(f"%t{dest} =l loadsw %t{address}")
+ return result
+
+
+@compile_to_ssa.register
+def _(target: StructPointerElementExpression, context: CompileContext) -> SsaResult:
+ result = compile_to_ssa(target.base, context)
+ base_dest = context.next_temp - 1
+ # hoooo boy.
+ base_type = target.base.type(context.declarations)
+ assert isinstance(base_type, PointerType)
+ assert isinstance(base_type.target, BasicType)
+ hopefully_struct, struct_name = base_type.target.name.split(' ')
+ assert hopefully_struct == 'struct'
+ target_struct = None
+ for decl in context.declarations:
+ if isinstance(decl, StructDeclaration) and decl.name == struct_name:
+ if decl.fields is None:
+ raise KeyError('struct ' + struct_name + ' is opaque')
+ target_struct = decl
+ break
+ if target_struct is None:
+ raise KeyError('struct ' + struct_name + ' not found')
+ offset = 0
+ for field in target_struct.fields:
+ if field.name == target.element:
+ break
+ else:
+ offset += field.type.size_bytes(context.declarations)
+ temp = context.next_temp
+ context.next_temp += 1
+ result_dest = context.next_temp
+ context.next_temp += 1
+ # TODO types
+ result.code.append(f"%t{temp} =l add %t{base_dest}, {offset}")
+ result.code.append(f"%t{result_dest} =l loadsw %t{temp}")
+ return result
+
+
+@compile_to_ssa.register
+def _(target: AddressOfExpression, context: CompileContext) -> SsaResult:
+ if isinstance(target.body, StructPointerElementExpression) or isinstance(target.body, ArrayIndexExpression):
+ result = compile_to_ssa(target.body, context)
+ result.code.pop()
+ context.next_temp -= 1
+ else:
+ raise NotImplementedError('address of ' + str(type(target.body)))
+ return result
+
+
+@compile_to_ssa.register
+def _(target: DirectAssignment, context: CompileContext) -> SsaResult:
+ result = compile_to_ssa(target.value, context)
+ result_dest = context.next_temp - 1
+ if isinstance(target.destination, VariableExpression):
+ raise NotImplementedError('assign directly to variable')
+ elif isinstance(target.destination, StructPointerElementExpression) or isinstance(target.destination, ArrayIndexExpression):
+ sub_result = compile_to_ssa(target.destination, context)
+ last_instr = sub_result.code.pop()
+ _, _, _, location = last_instr.split(' ')
+ # TODO type
+ sub_result.code.append(f"storew %t{result_dest}, {location}")
+ result += sub_result
+ else:
+ raise NotImplementedError('assign to ' + str(type(target.destination)))
+ return result
+
+
+@compile_to_ssa.register
+def _(target: UpdateAssignment, context: CompileContext) -> SsaResult:
+ return compile_to_ssa(target.deconstruct(), context)
+
+
+@compile_to_ssa.register
+def _(target: SizeofExpression, context: CompileContext) -> SsaResult:
+ target = target.body
+ if isinstance(target, Expression):
+ target = target.type(context.declarations)
+ size = target.size_bytes(context.declarations)
+ return compile_to_ssa(ConstantExpression(str(size)), context)