aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--crowbar_reference_compiler/__init__.py8
-rw-r--r--crowbar_reference_compiler/ast.py3
-rw-r--r--crowbar_reference_compiler/ssagen.py237
-rw-r--r--tests/test_hello_world.py13
4 files changed, 125 insertions, 136 deletions
diff --git a/crowbar_reference_compiler/__init__.py b/crowbar_reference_compiler/__init__.py
index 53d942f..1410bf7 100644
--- a/crowbar_reference_compiler/__init__.py
+++ b/crowbar_reference_compiler/__init__.py
@@ -1,10 +1,7 @@
-import dataclasses
-from pprint import pprint
-
from .ast import build_ast
from .parser import parse_header, parse_implementation
from .scanner import scan
-from .ssagen import compile_to_ssa
+from .ssagen import build_ssa
def main():
@@ -37,9 +34,8 @@ def main():
return
full_ast = build_ast(parse_tree, args.include_dir)
- pprint(dataclasses.asdict(full_ast))
- ssa = compile_to_ssa(parse_tree)
+ ssa = build_ssa(full_ast)
if args.stop_at_qbe_ssa:
if args.out is None:
args.out = args.input.replace('.cro', '.ssa')
diff --git a/crowbar_reference_compiler/ast.py b/crowbar_reference_compiler/ast.py
index 2a6963c..bb5a4e1 100644
--- a/crowbar_reference_compiler/ast.py
+++ b/crowbar_reference_compiler/ast.py
@@ -530,7 +530,8 @@ class ASTBuilder(NodeVisitor):
return ConstantExpression(body.data)
if body.type in ['true', 'false']:
return ConstantExpression(body.type)
- raise NotImplementedError()
+ if body.type == 'string_literal':
+ return ConstantExpression(body.data)
def visit_StructPointerElementSuffix(self, node, visited_children):
separator, element = visited_children
diff --git a/crowbar_reference_compiler/ssagen.py b/crowbar_reference_compiler/ssagen.py
index 8d4b6be..3c12c85 100644
--- a/crowbar_reference_compiler/ssagen.py
+++ b/crowbar_reference_compiler/ssagen.py
@@ -1,124 +1,113 @@
-from parsimonious import NodeVisitor # type: ignore
-from parsimonious.nodes import Node # type: ignore
-
-
-class SsaGenVisitor(NodeVisitor):
- def __init__(self):
- self.data = []
-
- def visit_ImplementationFile(self, node, visited_children):
- data = '\n'.join(self.data)
- functions = '\n'.join(visited_children)
- return data + '\n' + functions
-
- def visit_IncludeStatement(self, node, visited_children):
- include, included_header, semicolon = visited_children
- assert include.text[0].type == 'include'
- assert included_header.type == 'string_literal'
- included_header = included_header.data
- assert semicolon.text[0].type == ';'
- return ''
-
- def visit_FunctionDefinition(self, node, visited_children):
- signature, body = visited_children
- return_type, name, args = signature
- body = '\n'.join(' ' + instr for instr in body)
- return f"export function w ${name}() {{\n@start\n{body}\n}}"
-
- def visit_FunctionSignature(self, node, visited_children):
- return_type, name, lparen, args, rparen = visited_children
- assert name.type == 'identifier'
- name = name.data
- assert lparen.text[0].type == '('
- assert rparen.text[0].type == ')'
- return return_type, name, args
-
- def visit_Block(self, node, visited_children):
- lbrace, statements, rbrace = visited_children
- return statements
-
- def visit_Statement(self, node, visited_children):
- return visited_children[0]
-
- def visit_ExpressionStatement(self, node, visited_children):
- expression, semicolon = visited_children
- assert semicolon.text[0].type == ';'
- return expression
-
- def visit_Expression(self, node, visited_children):
- # TODO handle logical and/or
- return visited_children[0]
-
- def visit_ComparisonExpression(self, node, visited_children):
- # TODO handle comparisons
- return visited_children[0]
-
- def visit_BitwiseOpExpression(self, node, visited_children):
- # TODO handle bitwise operations
- return visited_children[0]
-
- def visit_ArithmeticExpression(self, node, visited_children):
- # TODO handle addition/subtraction
- return visited_children[0]
-
- def visit_TermExpression(self, node, visited_children):
- # TODO handle multiplication/division/modulus
- return visited_children[0]
-
- def visit_FactorExpression(self, node, visited_children):
- # TODO handle casts/address-of/pointer-dereference/unary ops/sizeof
- return visited_children[0]
-
- def visit_ObjectExpression(self, node, visited_children):
- # TODO handle array literals
- # TODO handle struct literals
- base, suffices = visited_children[0]
- if isinstance(suffices, Node):
- suffices = suffices.children
- if len(suffices) == 0:
- return base
- if base.type == 'identifier' and suffices[0].text[0].type == '(':
- arguments = suffices[1]
- if arguments[0].type == 'string_literal':
- data = arguments[0].data
- name = f"$data{len(self.data)}"
- # TODO handle non-variadic functions
- arguments = [f"l {name}", '...']
- self.data.append(f"data {name} = {{ b {data}, b 0 }}")
- return f"call ${base.data}({', '.join(arguments)})"
- print(base)
- print(suffices[0])
-
- def visit_AtomicExpression(self, node, visited_children):
- # TODO handle parenthesized subexpressions
- return visited_children[0]
-
- def visit_FlowControlStatement(self, node, visited_children):
- # TODO handle break/continue
- ret, arg, semicolon = visited_children[0]
- assert ret.text[0].type == 'return'
- assert semicolon.text[0].type == ';'
- if arg.type == 'constant':
- return f"ret {arg.data}"
-
- def visit_constant(self, node, visited_children):
- return node.text[0]
-
- def visit_string_literal(self, node, visited_children):
- return node.text[0]
-
- def visit_identifier(self, node, visited_children):
- return node.text[0]
-
- def generic_visit(self, node, visited_children):
- """ The generic visit method. """
- if not visited_children:
- return node
- if len(visited_children) == 1:
- return visited_children[0]
- return visited_children
-
-
-def compile_to_ssa(parse_tree):
- ssa_gen = SsaGenVisitor()
- return ssa_gen.visit(parse_tree)
+from dataclasses import dataclass
+from functools import singledispatch
+from typing import Dict, List
+
+from .ast import ImplementationFile, FunctionDefinition, ExpressionStatement, FunctionCallExpression, \
+ VariableExpression, ConstantExpression, ReturnStatement, BasicType
+
+
+@dataclass
+class SsaResult:
+ data: List[str]
+ code: List[str]
+
+
+@dataclass
+class CompileContext:
+ next_data: int = 0
+ next_temp: int = 0
+
+
+def build_ssa(file: ImplementationFile) -> str:
+ result = compile_to_ssa(file, CompileContext())
+ data = '\n'.join(result.data)
+ code = '\n'.join(result.code)
+ return data + '\n\n' + code
+
+
+@singledispatch
+def compile_to_ssa(target, context: CompileContext) -> SsaResult:
+ raise NotImplementedError('unannotated compile on ' + str(type(target)))
+
+
+@compile_to_ssa.register
+def _(target: ImplementationFile, context: CompileContext):
+ data = []
+ code = []
+ for target in target.contents:
+ result = compile_to_ssa(target, context)
+ data += result.data
+ code += result.code
+ return SsaResult(data, code)
+
+
+@compile_to_ssa.register
+def _(target: FunctionDefinition, context: CompileContext) -> SsaResult:
+ data = []
+ code = []
+ for statement in target.body:
+ result = compile_to_ssa(statement, context)
+ data += result.data
+ code += result.code
+ code = [' ' + instr for instr in code]
+ assert len(target.args) == 0
+ assert target.return_type == BasicType('int32')
+ code = [f"export function w ${target.name}() {{", "@start", *code, "}"]
+ return SsaResult(data, code)
+
+
+@compile_to_ssa.register
+def _(target: ExpressionStatement, context: CompileContext) -> SsaResult:
+ return compile_to_ssa(target.body, context)
+
+
+@compile_to_ssa.register
+def _(target: FunctionCallExpression, context: CompileContext) -> SsaResult:
+ assert isinstance(target.function, VariableExpression)
+ data = []
+ code = []
+ args = []
+ for i, expr in enumerate(target.arguments):
+ arg_dest = context.next_temp
+ result = compile_to_ssa(expr, context)
+ data += result.data
+ code += result.code
+ args += [f"l %t{arg_dest}"]
+ code += [f"call ${target.function.name}({','.join(args)}, ...)"]
+ return SsaResult(data, code)
+
+
+@compile_to_ssa.register
+def _(target: ConstantExpression, context: CompileContext) -> SsaResult:
+ if target.value.startswith('"'):
+ data_dest = context.next_data
+ context.next_data += 1
+ data = [f"data $data{data_dest} = {{ b {target.value}, b 0 }}"]
+ temp = context.next_temp
+ context.next_temp += 1
+ code = [f"%t{temp} =l copy $data{data_dest}"]
+ else:
+ assert not target.value.startswith('0b')
+ assert not target.value.startswith('0B')
+ assert not target.value.startswith('0o')
+ assert not target.value.startswith('0x')
+ assert not target.value.startswith('0X')
+ assert not target.value.startswith('0f')
+ assert not target.value.startswith('0F')
+ assert '.' not in target.value
+ assert not target.value.startswith("'")
+ data = []
+ temp = context.next_temp
+ context.next_temp += 1
+ code = [f"%t{temp} =w copy {target.value}"]
+ return SsaResult(data, code)
+
+
+@compile_to_ssa.register
+def _(target: ReturnStatement, context: CompileContext) -> SsaResult:
+ if target.body is None:
+ return SsaResult([], ['ret'])
+ ret_val_dest = context.next_temp
+ result = compile_to_ssa(target.body, context)
+ result.code.append(f"ret %t{ret_val_dest}")
+ return result
diff --git a/tests/test_hello_world.py b/tests/test_hello_world.py
index 594ab4f..0377392 100644
--- a/tests/test_hello_world.py
+++ b/tests/test_hello_world.py
@@ -1,12 +1,12 @@
import unittest
-from crowbar_reference_compiler import compile_to_ssa, parse_header, parse_implementation, scan
+from crowbar_reference_compiler import build_ast, build_ssa, parse_header, parse_implementation, scan
class TestHelloWorld(unittest.TestCase):
def test_ssa(self):
code = r"""
-include "stdio.hro";
+//include "stdio.hro";
int32 main() {
printf("Hello, world!\n");
@@ -15,14 +15,17 @@ int32 main() {
"""
tokens = scan(code)
parse_tree = parse_implementation(tokens)
- actual_ssa = compile_to_ssa(parse_tree)
+ ast = build_ast(parse_tree, [])
+ actual_ssa = build_ssa(ast)
expected_ssa = r"""
data $data0 = { b "Hello, world!\n", b 0 }
export function w $main() {
@start
- call $printf(l $data0, ...)
- ret 0
+ %t0 =l copy $data0
+ call $printf(l %t0, ...)
+ %t1 =w copy 0
+ ret %t1
}
""".strip()
self.assertEqual(expected_ssa, actual_ssa)