From b6258d36b6534d521e9cdf1307665c38e5ae409d Mon Sep 17 00:00:00 2001
From: Melody Horn
Date: Thu, 5 Nov 2020 01:27:58 -0700
Subject: compile based on the fancy new AST

---
Review changes relative to the patch as first posted:
- ast.py: keep the trailing `raise NotImplementedError()` so an unhandled
  atom type still fails loudly instead of returning None.
- ssagen.py: a call without arguments no longer renders as `call $f(, ...)`;
  the variadic marker is now joined with the arguments.
- ssagen.py: "not supported yet" checks raise NotImplementedError instead of
  using assert, which `python -O` strips.
- ssagen.py: the file-level handler no longer rebinds its `target` parameter
  inside the loop; docstrings and notes on the first-temporary assumption
  were added.
- Hunk headers and the diffstat are updated to match; the `index` blob ids
  are still those of the first posting.

 crowbar_reference_compiler/__init__.py |   8 +-
 crowbar_reference_compiler/ast.py      |   2 +
 crowbar_reference_compiler/ssagen.py   | 248 ++++++++++++++++-----------------
 tests/test_hello_world.py              |  13 +-
 4 files changed, 136 insertions(+), 135 deletions(-)

diff --git a/crowbar_reference_compiler/__init__.py b/crowbar_reference_compiler/__init__.py
index 53d942f..1410bf7 100644
--- a/crowbar_reference_compiler/__init__.py
+++ b/crowbar_reference_compiler/__init__.py
@@ -1,10 +1,7 @@
-import dataclasses
-from pprint import pprint
-
 from .ast import build_ast
 from .parser import parse_header, parse_implementation
 from .scanner import scan
-from .ssagen import compile_to_ssa
+from .ssagen import build_ssa
 
 
 def main():
@@ -37,9 +34,8 @@ def main():
         return
 
     full_ast = build_ast(parse_tree, args.include_dir)
-    pprint(dataclasses.asdict(full_ast))
 
-    ssa = compile_to_ssa(parse_tree)
+    ssa = build_ssa(full_ast)
     if args.stop_at_qbe_ssa:
         if args.out is None:
             args.out = args.input.replace('.cro', '.ssa')
diff --git a/crowbar_reference_compiler/ast.py b/crowbar_reference_compiler/ast.py
index 2a6963c..bb5a4e1 100644
--- a/crowbar_reference_compiler/ast.py
+++ b/crowbar_reference_compiler/ast.py
@@ -530,7 +530,9 @@ class ASTBuilder(NodeVisitor):
             return ConstantExpression(body.data)
         if body.type in ['true', 'false']:
             return ConstantExpression(body.type)
+        if body.type == 'string_literal':
+            return ConstantExpression(body.data)
         raise NotImplementedError()
 
     def visit_StructPointerElementSuffix(self, node, visited_children):
         separator, element = visited_children
diff --git a/crowbar_reference_compiler/ssagen.py b/crowbar_reference_compiler/ssagen.py
index 8d4b6be..3c12c85 100644
--- a/crowbar_reference_compiler/ssagen.py
+++ b/crowbar_reference_compiler/ssagen.py
@@ -1,124 +1,124 @@
-from parsimonious import NodeVisitor  # type: ignore
-from parsimonious.nodes import Node  # type: ignore
-
-
-class SsaGenVisitor(NodeVisitor):
-    def __init__(self):
-        self.data = []
-
-    def visit_ImplementationFile(self, node, visited_children):
-        data = '\n'.join(self.data)
-        functions = '\n'.join(visited_children)
-        return data + '\n' + functions
-
-    def visit_IncludeStatement(self, node, visited_children):
-        include, included_header, semicolon = visited_children
-        assert include.text[0].type == 'include'
-        assert included_header.type == 'string_literal'
-        included_header = included_header.data
-        assert semicolon.text[0].type == ';'
-        return ''
-
-    def visit_FunctionDefinition(self, node, visited_children):
-        signature, body = visited_children
-        return_type, name, args = signature
-        body = '\n'.join(' ' + instr for instr in body)
-        return f"export function w ${name}() {{\n@start\n{body}\n}}"
-
-    def visit_FunctionSignature(self, node, visited_children):
-        return_type, name, lparen, args, rparen = visited_children
-        assert name.type == 'identifier'
-        name = name.data
-        assert lparen.text[0].type == '('
-        assert rparen.text[0].type == ')'
-        return return_type, name, args
-
-    def visit_Block(self, node, visited_children):
-        lbrace, statements, rbrace = visited_children
-        return statements
-
-    def visit_Statement(self, node, visited_children):
-        return visited_children[0]
-
-    def visit_ExpressionStatement(self, node, visited_children):
-        expression, semicolon = visited_children
-        assert semicolon.text[0].type == ';'
-        return expression
-
-    def visit_Expression(self, node, visited_children):
-        # TODO handle logical and/or
-        return visited_children[0]
-
-    def visit_ComparisonExpression(self, node, visited_children):
-        # TODO handle comparisons
-        return visited_children[0]
-
-    def visit_BitwiseOpExpression(self, node, visited_children):
-        # TODO handle bitwise operations
-        return visited_children[0]
-
-    def visit_ArithmeticExpression(self, node, visited_children):
-        # TODO handle addition/subtraction
-        return visited_children[0]
-
-    def visit_TermExpression(self, node, visited_children):
-        # TODO handle multiplication/division/modulus
-        return visited_children[0]
-
-    def visit_FactorExpression(self, node, visited_children):
-        # TODO handle casts/address-of/pointer-dereference/unary ops/sizeof
-        return visited_children[0]
-
-    def visit_ObjectExpression(self, node, visited_children):
-        # TODO handle array literals
-        # TODO handle struct literals
-        base, suffices = visited_children[0]
-        if isinstance(suffices, Node):
-            suffices = suffices.children
-        if len(suffices) == 0:
-            return base
-        if base.type == 'identifier' and suffices[0].text[0].type == '(':
-            arguments = suffices[1]
-            if arguments[0].type == 'string_literal':
-                data = arguments[0].data
-                name = f"$data{len(self.data)}"
-                # TODO handle non-variadic functions
-                arguments = [f"l {name}", '...']
-                self.data.append(f"data {name} = {{ b {data}, b 0 }}")
-            return f"call ${base.data}({', '.join(arguments)})"
-        print(base)
-        print(suffices[0])
-
-    def visit_AtomicExpression(self, node, visited_children):
-        # TODO handle parenthesized subexpressions
-        return visited_children[0]
-
-    def visit_FlowControlStatement(self, node, visited_children):
-        # TODO handle break/continue
-        ret, arg, semicolon = visited_children[0]
-        assert ret.text[0].type == 'return'
-        assert semicolon.text[0].type == ';'
-        if arg.type == 'constant':
-            return f"ret {arg.data}"
-
-    def visit_constant(self, node, visited_children):
-        return node.text[0]
-
-    def visit_string_literal(self, node, visited_children):
-        return node.text[0]
-
-    def visit_identifier(self, node, visited_children):
-        return node.text[0]
-
-    def generic_visit(self, node, visited_children):
-        """ The generic visit method. """
-        if not visited_children:
-            return node
-        if len(visited_children) == 1:
-            return visited_children[0]
-        return visited_children
-
-
-def compile_to_ssa(parse_tree):
-    ssa_gen = SsaGenVisitor()
-    return ssa_gen.visit(parse_tree)
+from dataclasses import dataclass
+from functools import singledispatch
+from typing import Dict, List
+
+from .ast import ImplementationFile, FunctionDefinition, ExpressionStatement, FunctionCallExpression, \
+    VariableExpression, ConstantExpression, ReturnStatement, BasicType
+
+
+@dataclass
+class SsaResult:
+    """QBE text produced for one AST node: data definitions and code lines."""
+    data: List[str]
+    code: List[str]
+
+
+@dataclass
+class CompileContext:
+    """Counters handing out the next unused `$dataN` and `%tN` names."""
+    next_data: int = 0
+    next_temp: int = 0
+
+
+def build_ssa(file: ImplementationFile) -> str:
+    """Compile a whole file: data definitions, a blank line, then the code."""
+    result = compile_to_ssa(file, CompileContext())
+    data = '\n'.join(result.data)
+    code = '\n'.join(result.code)
+    return data + '\n\n' + code
+
+
+@singledispatch
+def compile_to_ssa(target, context: CompileContext) -> SsaResult:
+    """Compile one AST node; types without a registered handler raise."""
+    raise NotImplementedError('unannotated compile on ' + str(type(target)))
+
+
+@compile_to_ssa.register
+def _(target: ImplementationFile, context: CompileContext) -> SsaResult:
+    data = []
+    code = []
+    # A separate loop name keeps the `target` parameter from being rebound.
+    for element in target.contents:
+        result = compile_to_ssa(element, context)
+        data += result.data
+        code += result.code
+    return SsaResult(data, code)
+
+
+@compile_to_ssa.register
+def _(target: FunctionDefinition, context: CompileContext) -> SsaResult:
+    # Only `int32 name()` can be emitted so far. Raise instead of assert so
+    # the check survives `python -O`.
+    if len(target.args) != 0:
+        raise NotImplementedError('function arguments are not supported')
+    if target.return_type != BasicType('int32'):
+        raise NotImplementedError('only int32 return types are supported')
+    data = []
+    code = []
+    for statement in target.body:
+        result = compile_to_ssa(statement, context)
+        data += result.data
+        code += result.code
+    code = [' ' + instr for instr in code]
+    code = [f"export function w ${target.name}() {{", "@start", *code, "}"]
+    return SsaResult(data, code)
+
+
+@compile_to_ssa.register
+def _(target: ExpressionStatement, context: CompileContext) -> SsaResult:
+    return compile_to_ssa(target.body, context)
+
+
+@compile_to_ssa.register
+def _(target: FunctionCallExpression, context: CompileContext) -> SsaResult:
+    if not isinstance(target.function, VariableExpression):
+        raise NotImplementedError('only calls to named functions are supported')
+    data = []
+    code = []
+    args = []
+    for expr in target.arguments:
+        # NOTE: assumes the argument's value ends up in the first temporary
+        # its compilation allocates; true for constants, not for nested calls.
+        arg_dest = context.next_temp
+        result = compile_to_ssa(expr, context)
+        data += result.data
+        code += result.code
+        args.append(f"l %t{arg_dest}")
+    # TODO handle non-variadic functions
+    # Joining the marker with the arguments keeps a call without arguments
+    # well-formed: `call $f(...)` instead of `call $f(, ...)`.
+    args.append('...')
+    code.append(f"call ${target.function.name}({', '.join(args)})")
+    return SsaResult(data, code)
+
+
+@compile_to_ssa.register
+def _(target: ConstantExpression, context: CompileContext) -> SsaResult:
+    if target.value.startswith('"'):
+        # String literal: emit the bytes plus a trailing 0 and copy the address.
+        data_dest = context.next_data
+        context.next_data += 1
+        data = [f"data $data{data_dest} = {{ b {target.value}, b 0 }}"]
+        kind, source = 'l', f"$data{data_dest}"
+    else:
+        # Reject the literal spellings this backend does not handle yet.
+        unsupported = ('0b', '0B', '0o', '0x', '0X', '0f', '0F', "'")
+        if target.value.startswith(unsupported) or '.' in target.value:
+            raise NotImplementedError('unsupported constant ' + target.value)
+        data = []
+        kind, source = 'w', target.value
+    temp = context.next_temp
+    context.next_temp += 1
+    return SsaResult(data, [f"%t{temp} ={kind} copy {source}"])
+
+
+@compile_to_ssa.register
+def _(target: ReturnStatement, context: CompileContext) -> SsaResult:
+    if target.body is None:
+        return SsaResult([], ['ret'])
+    # NOTE: same first-temporary assumption as for call arguments.
+    ret_val_dest = context.next_temp
+    result = compile_to_ssa(target.body, context)
+    result.code.append(f"ret %t{ret_val_dest}")
+    return result
diff --git a/tests/test_hello_world.py b/tests/test_hello_world.py
index 594ab4f..0377392 100644
--- a/tests/test_hello_world.py
+++ b/tests/test_hello_world.py
@@ -1,12 +1,12 @@
 import unittest
 
-from crowbar_reference_compiler import compile_to_ssa, parse_header, parse_implementation, scan
+from crowbar_reference_compiler import build_ast, build_ssa, parse_header, parse_implementation, scan
 
 
 class TestHelloWorld(unittest.TestCase):
     def test_ssa(self):
         code = r"""
-include "stdio.hro";
+//include "stdio.hro";
 
 int32 main() {
     printf("Hello, world!\n");
@@ -15,14 +15,17 @@ int32 main() {
 """
         tokens = scan(code)
         parse_tree = parse_implementation(tokens)
-        actual_ssa = compile_to_ssa(parse_tree)
+        ast = build_ast(parse_tree, [])
+        actual_ssa = build_ssa(ast)
         expected_ssa = r"""
 data $data0 = { b "Hello, world!\n", b 0 }
 
 export function w $main() {
 @start
- call $printf(l $data0, ...)
- ret 0
+ %t0 =l copy $data0
+ call $printf(l %t0, ...)
+ %t1 =w copy 0
+ ret %t1
 }
 """.strip()
         self.assertEqual(expected_ssa, actual_ssa)
-- 
cgit v1.2.3