diff options
Diffstat (limited to 'crowbar_reference_compiler')
| -rw-r--r-- | crowbar_reference_compiler/ssagen.py | 110 | 
1 files changed, 86 insertions, 24 deletions
diff --git a/crowbar_reference_compiler/ssagen.py b/crowbar_reference_compiler/ssagen.py index 3c12c85..508025e 100644 --- a/crowbar_reference_compiler/ssagen.py +++ b/crowbar_reference_compiler/ssagen.py @@ -1,9 +1,9 @@  from dataclasses import dataclass  from functools import singledispatch -from typing import Dict, List +from typing import List  from .ast import ImplementationFile, FunctionDefinition, ExpressionStatement, FunctionCallExpression, \ -    VariableExpression, ConstantExpression, ReturnStatement, BasicType +    VariableExpression, ConstantExpression, ReturnStatement, BasicType, IfStatement, ComparisonExpression, AddExpression  @dataclass @@ -11,11 +11,23 @@ class SsaResult:      data: List[str]      code: List[str] +    def __add__(self, other: 'SsaResult') -> 'SsaResult': +        if not isinstance(other, SsaResult): +            return NotImplemented +        return SsaResult(self.data + other.data, self.code + other.code) + +    def __radd__(self, other: 'SsaResult'): +        if not isinstance(other, SsaResult): +            return NotImplemented +        self.data += other.data +        self.code += other.code +  @dataclass  class CompileContext:      next_data: int = 0      next_temp: int = 0 +    next_label: int = 0  def build_ssa(file: ImplementationFile) -> str: @@ -32,28 +44,22 @@ def compile_to_ssa(target, context: CompileContext) -> SsaResult:  @compile_to_ssa.register  def _(target: ImplementationFile, context: CompileContext): -    data = [] -    code = [] +    result = SsaResult([], [])      for target in target.contents: -        result = compile_to_ssa(target, context) -        data += result.data -        code += result.code -    return SsaResult(data, code) +        result += compile_to_ssa(target, context) +    return result  @compile_to_ssa.register  def _(target: FunctionDefinition, context: CompileContext) -> SsaResult: -    data = [] -    code = [] +    result = SsaResult([], [])      for statement in target.body: -        result = compile_to_ssa(statement, context) -        data += result.data -        code += result.code -    code = ['    ' + instr for instr in code] +        result += compile_to_ssa(statement, context) +    code = ['    ' + instr for instr in result.code]      assert len(target.args) == 0      assert target.return_type == BasicType('int32')      code = [f"export function w ${target.name}() {{", "@start", *code, "}"] -    return SsaResult(data, code) +    return SsaResult(result.data, code)  @compile_to_ssa.register @@ -64,17 +70,14 @@ def _(target: ExpressionStatement, context: CompileContext) -> SsaResult:  @compile_to_ssa.register  def _(target: FunctionCallExpression, context: CompileContext) -> SsaResult:      assert isinstance(target.function, VariableExpression) -    data = [] -    code = [] +    result = SsaResult([], [])      args = []      for i, expr in enumerate(target.arguments): -        arg_dest = context.next_temp -        result = compile_to_ssa(expr, context) -        data += result.data -        code += result.code +        result += compile_to_ssa(expr, context) +        arg_dest = context.next_temp - 1          args += [f"l %t{arg_dest}"] -    code += [f"call ${target.function.name}({','.join(args)}, ...)"] -    return SsaResult(data, code) +    result.code.append(f"call ${target.function.name}({','.join(args)}, ...)") +    return result  @compile_to_ssa.register @@ -107,7 +110,66 @@ def _(target: ConstantExpression, context: CompileContext) -> SsaResult:  def _(target: ReturnStatement, context: CompileContext) -> SsaResult:      if target.body is None:          return SsaResult([], ['ret']) -    ret_val_dest = context.next_temp      result = compile_to_ssa(target.body, context) +    ret_val_dest = context.next_temp - 1      result.code.append(f"ret %t{ret_val_dest}")      return result + + +@compile_to_ssa.register +def _(target: IfStatement, context: CompileContext) -> SsaResult: +    result = compile_to_ssa(target.condition, context) +    condition_dest = context.next_temp - 1 +    true_label = context.next_label +    context.next_label += 1 +    false_label = context.next_label +    context.next_label += 1 +    after_label = context.next_label +    context.next_label += 1 +    result.code.append(f"jnz %t{condition_dest}, @l{true_label}, @l{false_label}") +    result.code.append(f"@l{true_label}") +    for statement in target.then: +        result += compile_to_ssa(statement, context) +    result.code.append(f"jmp @l{after_label}") +    result.code.append(f"@l{false_label}") +    for statement in target.els: +        result += compile_to_ssa(statement, context) +    result.code.append(f"jmp @l{after_label}") +    result.code.append(f"@l{after_label}") +    return result + + +@compile_to_ssa.register +def _(target: ComparisonExpression, context: CompileContext) -> SsaResult: +    result = compile_to_ssa(target.value1, context) +    value1_dest = context.next_temp - 1 +    result += compile_to_ssa(target.value2, context) +    value2_dest = context.next_temp - 1 +    result_dest = context.next_temp +    context.next_temp += 1 +    if target.op == '==': +        result.code.append(f"%t{result_dest} =w ceq %t{value1_dest}, %t{value2_dest}") +    else: +        raise NotImplementedError('comparison ' + target.op) +    return result + + +@compile_to_ssa.register +def _(target: AddExpression, context: CompileContext) -> SsaResult: +    result = compile_to_ssa(target.term1, context) +    value1_dest = context.next_temp - 1 +    result += compile_to_ssa(target.term2, context) +    value2_dest = context.next_temp - 1 +    result_reg = context.next_temp +    context.next_temp += 1 +    # TODO make sure the types are correct +    result.code.append(f"%t{result_reg} =w add %t{value1_dest}, %t{value2_dest}") +    return result + + +@compile_to_ssa.register +def _(target: VariableExpression, context: CompileContext) -> SsaResult: +    # TODO make sure any of this is reasonable +    result = context.next_temp +    context.next_temp += 1 +    return SsaResult([], [f"%t{result} =w copy %{target.name}"])  |