aboutsummaryrefslogtreecommitdiff
path: root/crowbar_reference_compiler/ssagen.py
diff options
context:
space:
mode:
authorMelody Horn <melody@boringcactus.com>2020-11-05 01:27:58 -0700
committerMelody Horn <melody@boringcactus.com>2020-11-05 01:27:58 -0700
commitb6258d36b6534d521e9cdf1307665c38e5ae409d (patch)
tree7e98020964173869f2799533553bcdcd1f64ea1a /crowbar_reference_compiler/ssagen.py
parent9dfc552c0703c5e14ea472eb5431719b2e0d6400 (diff)
downloadreference-compiler-b6258d36b6534d521e9cdf1307665c38e5ae409d.tar.gz
reference-compiler-b6258d36b6534d521e9cdf1307665c38e5ae409d.zip
compile based on the fancy new AST
Diffstat (limited to 'crowbar_reference_compiler/ssagen.py')
-rw-r--r--crowbar_reference_compiler/ssagen.py237
1 files changed, 113 insertions, 124 deletions
diff --git a/crowbar_reference_compiler/ssagen.py b/crowbar_reference_compiler/ssagen.py
index 8d4b6be..3c12c85 100644
--- a/crowbar_reference_compiler/ssagen.py
+++ b/crowbar_reference_compiler/ssagen.py
@@ -1,124 +1,113 @@
-from parsimonious import NodeVisitor # type: ignore
-from parsimonious.nodes import Node # type: ignore
-
-
-class SsaGenVisitor(NodeVisitor):
- def __init__(self):
- self.data = []
-
- def visit_ImplementationFile(self, node, visited_children):
- data = '\n'.join(self.data)
- functions = '\n'.join(visited_children)
- return data + '\n' + functions
-
- def visit_IncludeStatement(self, node, visited_children):
- include, included_header, semicolon = visited_children
- assert include.text[0].type == 'include'
- assert included_header.type == 'string_literal'
- included_header = included_header.data
- assert semicolon.text[0].type == ';'
- return ''
-
- def visit_FunctionDefinition(self, node, visited_children):
- signature, body = visited_children
- return_type, name, args = signature
- body = '\n'.join(' ' + instr for instr in body)
- return f"export function w ${name}() {{\n@start\n{body}\n}}"
-
- def visit_FunctionSignature(self, node, visited_children):
- return_type, name, lparen, args, rparen = visited_children
- assert name.type == 'identifier'
- name = name.data
- assert lparen.text[0].type == '('
- assert rparen.text[0].type == ')'
- return return_type, name, args
-
- def visit_Block(self, node, visited_children):
- lbrace, statements, rbrace = visited_children
- return statements
-
- def visit_Statement(self, node, visited_children):
- return visited_children[0]
-
- def visit_ExpressionStatement(self, node, visited_children):
- expression, semicolon = visited_children
- assert semicolon.text[0].type == ';'
- return expression
-
- def visit_Expression(self, node, visited_children):
- # TODO handle logical and/or
- return visited_children[0]
-
- def visit_ComparisonExpression(self, node, visited_children):
- # TODO handle comparisons
- return visited_children[0]
-
- def visit_BitwiseOpExpression(self, node, visited_children):
- # TODO handle bitwise operations
- return visited_children[0]
-
- def visit_ArithmeticExpression(self, node, visited_children):
- # TODO handle addition/subtraction
- return visited_children[0]
-
- def visit_TermExpression(self, node, visited_children):
- # TODO handle multiplication/division/modulus
- return visited_children[0]
-
- def visit_FactorExpression(self, node, visited_children):
- # TODO handle casts/address-of/pointer-dereference/unary ops/sizeof
- return visited_children[0]
-
- def visit_ObjectExpression(self, node, visited_children):
- # TODO handle array literals
- # TODO handle struct literals
- base, suffices = visited_children[0]
- if isinstance(suffices, Node):
- suffices = suffices.children
- if len(suffices) == 0:
- return base
- if base.type == 'identifier' and suffices[0].text[0].type == '(':
- arguments = suffices[1]
- if arguments[0].type == 'string_literal':
- data = arguments[0].data
- name = f"$data{len(self.data)}"
- # TODO handle non-variadic functions
- arguments = [f"l {name}", '...']
- self.data.append(f"data {name} = {{ b {data}, b 0 }}")
- return f"call ${base.data}({', '.join(arguments)})"
- print(base)
- print(suffices[0])
-
- def visit_AtomicExpression(self, node, visited_children):
- # TODO handle parenthesized subexpressions
- return visited_children[0]
-
- def visit_FlowControlStatement(self, node, visited_children):
- # TODO handle break/continue
- ret, arg, semicolon = visited_children[0]
- assert ret.text[0].type == 'return'
- assert semicolon.text[0].type == ';'
- if arg.type == 'constant':
- return f"ret {arg.data}"
-
- def visit_constant(self, node, visited_children):
- return node.text[0]
-
- def visit_string_literal(self, node, visited_children):
- return node.text[0]
-
- def visit_identifier(self, node, visited_children):
- return node.text[0]
-
- def generic_visit(self, node, visited_children):
- """ The generic visit method. """
- if not visited_children:
- return node
- if len(visited_children) == 1:
- return visited_children[0]
- return visited_children
-
-
-def compile_to_ssa(parse_tree):
- ssa_gen = SsaGenVisitor()
- return ssa_gen.visit(parse_tree)
+from dataclasses import dataclass
+from functools import singledispatch
+from typing import Dict, List
+
+from .ast import ImplementationFile, FunctionDefinition, ExpressionStatement, FunctionCallExpression, \
+ VariableExpression, ConstantExpression, ReturnStatement, BasicType
+
+
+@dataclass
+class SsaResult:
+ data: List[str]
+ code: List[str]
+
+
+@dataclass
+class CompileContext:
+ next_data: int = 0
+ next_temp: int = 0
+
+
+def build_ssa(file: ImplementationFile) -> str:
+ result = compile_to_ssa(file, CompileContext())
+ data = '\n'.join(result.data)
+ code = '\n'.join(result.code)
+ return data + '\n\n' + code
+
+
+@singledispatch
+def compile_to_ssa(target, context: CompileContext) -> SsaResult:
+ raise NotImplementedError('unannotated compile on ' + str(type(target)))
+
+
+@compile_to_ssa.register
+def _(target: ImplementationFile, context: CompileContext):
+ data = []
+ code = []
+ for target in target.contents:
+ result = compile_to_ssa(target, context)
+ data += result.data
+ code += result.code
+ return SsaResult(data, code)
+
+
+@compile_to_ssa.register
+def _(target: FunctionDefinition, context: CompileContext) -> SsaResult:
+ data = []
+ code = []
+ for statement in target.body:
+ result = compile_to_ssa(statement, context)
+ data += result.data
+ code += result.code
+ code = [' ' + instr for instr in code]
+ assert len(target.args) == 0
+ assert target.return_type == BasicType('int32')
+ code = [f"export function w ${target.name}() {{", "@start", *code, "}"]
+ return SsaResult(data, code)
+
+
+@compile_to_ssa.register
+def _(target: ExpressionStatement, context: CompileContext) -> SsaResult:
+ return compile_to_ssa(target.body, context)
+
+
+@compile_to_ssa.register
+def _(target: FunctionCallExpression, context: CompileContext) -> SsaResult:
+ assert isinstance(target.function, VariableExpression)
+ data = []
+ code = []
+ args = []
+ for i, expr in enumerate(target.arguments):
+ arg_dest = context.next_temp
+ result = compile_to_ssa(expr, context)
+ data += result.data
+ code += result.code
+ args += [f"l %t{arg_dest}"]
+ code += [f"call ${target.function.name}({','.join(args)}, ...)"]
+ return SsaResult(data, code)
+
+
+@compile_to_ssa.register
+def _(target: ConstantExpression, context: CompileContext) -> SsaResult:
+ if target.value.startswith('"'):
+ data_dest = context.next_data
+ context.next_data += 1
+ data = [f"data $data{data_dest} = {{ b {target.value}, b 0 }}"]
+ temp = context.next_temp
+ context.next_temp += 1
+ code = [f"%t{temp} =l copy $data{data_dest}"]
+ else:
+ assert not target.value.startswith('0b')
+ assert not target.value.startswith('0B')
+ assert not target.value.startswith('0o')
+ assert not target.value.startswith('0x')
+ assert not target.value.startswith('0X')
+ assert not target.value.startswith('0f')
+ assert not target.value.startswith('0F')
+ assert '.' not in target.value
+ assert not target.value.startswith("'")
+ data = []
+ temp = context.next_temp
+ context.next_temp += 1
+ code = [f"%t{temp} =w copy {target.value}"]
+ return SsaResult(data, code)
+
+
+@compile_to_ssa.register
+def _(target: ReturnStatement, context: CompileContext) -> SsaResult:
+ if target.body is None:
+ return SsaResult([], ['ret'])
+ ret_val_dest = context.next_temp
+ result = compile_to_ssa(target.body, context)
+ result.code.append(f"ret %t{ret_val_dest}")
+ return result