diff options
Diffstat (limited to 'crowbar_reference_compiler')
| -rw-r--r-- | crowbar_reference_compiler/__init__.py | 8 | ||||
| -rw-r--r-- | crowbar_reference_compiler/ast.py | 3 | ||||
| -rw-r--r-- | crowbar_reference_compiler/ssagen.py | 237 | 
3 files changed, 117 insertions, 131 deletions
| diff --git a/crowbar_reference_compiler/__init__.py b/crowbar_reference_compiler/__init__.py index 53d942f..1410bf7 100644 --- a/crowbar_reference_compiler/__init__.py +++ b/crowbar_reference_compiler/__init__.py @@ -1,10 +1,7 @@ -import dataclasses -from pprint import pprint -  from .ast import build_ast  from .parser import parse_header, parse_implementation  from .scanner import scan -from .ssagen import compile_to_ssa +from .ssagen import build_ssa  def main(): @@ -37,9 +34,8 @@ def main():          return      full_ast = build_ast(parse_tree, args.include_dir) -    pprint(dataclasses.asdict(full_ast)) -    ssa = compile_to_ssa(parse_tree) +    ssa = build_ssa(full_ast)      if args.stop_at_qbe_ssa:          if args.out is None:              args.out = args.input.replace('.cro', '.ssa') diff --git a/crowbar_reference_compiler/ast.py b/crowbar_reference_compiler/ast.py index 2a6963c..bb5a4e1 100644 --- a/crowbar_reference_compiler/ast.py +++ b/crowbar_reference_compiler/ast.py @@ -530,7 +530,8 @@ class ASTBuilder(NodeVisitor):              return ConstantExpression(body.data)          if body.type in ['true', 'false']:              return ConstantExpression(body.type) -        raise NotImplementedError() +        if body.type == 'string_literal': +            return ConstantExpression(body.data)      def visit_StructPointerElementSuffix(self, node, visited_children):          separator, element = visited_children diff --git a/crowbar_reference_compiler/ssagen.py b/crowbar_reference_compiler/ssagen.py index 8d4b6be..3c12c85 100644 --- a/crowbar_reference_compiler/ssagen.py +++ b/crowbar_reference_compiler/ssagen.py @@ -1,124 +1,113 @@ -from parsimonious import NodeVisitor  # type: ignore -from parsimonious.nodes import Node  # type: ignore - - -class SsaGenVisitor(NodeVisitor): -    def __init__(self): -        self.data = [] - -    def visit_ImplementationFile(self, node, visited_children): -        data = '\n'.join(self.data) -        functions = '\n'.join(visited_children) -        return data + '\n' + functions - -    def visit_IncludeStatement(self, node, visited_children): -        include, included_header, semicolon = visited_children -        assert include.text[0].type == 'include' -        assert included_header.type == 'string_literal' -        included_header = included_header.data -        assert semicolon.text[0].type == ';' -        return '' - -    def visit_FunctionDefinition(self, node, visited_children): -        signature, body = visited_children -        return_type, name, args = signature -        body = '\n'.join('    ' + instr for instr in body) -        return f"export function w ${name}() {{\n@start\n{body}\n}}" - -    def visit_FunctionSignature(self, node, visited_children): -        return_type, name, lparen, args, rparen = visited_children -        assert name.type == 'identifier' -        name = name.data -        assert lparen.text[0].type == '(' -        assert rparen.text[0].type == ')' -        return return_type, name, args - -    def visit_Block(self, node, visited_children): -        lbrace, statements, rbrace = visited_children -        return statements - -    def visit_Statement(self, node, visited_children): -        return visited_children[0] - -    def visit_ExpressionStatement(self, node, visited_children): -        expression, semicolon = visited_children -        assert semicolon.text[0].type == ';' -        return expression - -    def visit_Expression(self, node, visited_children): -        # TODO handle logical and/or -        return visited_children[0] - -    def visit_ComparisonExpression(self, node, visited_children): -        # TODO handle comparisons -        return visited_children[0] - -    def visit_BitwiseOpExpression(self, node, visited_children): -        # TODO handle bitwise operations -        return visited_children[0] - -    def visit_ArithmeticExpression(self, node, visited_children): -        # TODO handle addition/subtraction -        return visited_children[0] - -    def visit_TermExpression(self, node, visited_children): -        # TODO handle multiplication/division/modulus -        return visited_children[0] - -    def visit_FactorExpression(self, node, visited_children): -        # TODO handle casts/address-of/pointer-dereference/unary ops/sizeof -        return visited_children[0] - -    def visit_ObjectExpression(self, node, visited_children): -        # TODO handle array literals -        # TODO handle struct literals -        base, suffices = visited_children[0] -        if isinstance(suffices, Node): -            suffices = suffices.children -        if len(suffices) == 0: -            return base -        if base.type == 'identifier' and suffices[0].text[0].type == '(': -            arguments = suffices[1] -            if arguments[0].type == 'string_literal': -                data = arguments[0].data -                name = f"$data{len(self.data)}" -                # TODO handle non-variadic functions -                arguments = [f"l {name}", '...'] -                self.data.append(f"data {name} = {{ b {data}, b 0 }}") -            return f"call ${base.data}({', '.join(arguments)})" -        print(base) -        print(suffices[0]) - -    def visit_AtomicExpression(self, node, visited_children): -        # TODO handle parenthesized subexpressions -        return visited_children[0] - -    def visit_FlowControlStatement(self, node, visited_children): -        # TODO handle break/continue -        ret, arg, semicolon = visited_children[0] -        assert ret.text[0].type == 'return' -        assert semicolon.text[0].type == ';' -        if arg.type == 'constant': -            return f"ret {arg.data}" - -    def visit_constant(self, node, visited_children): -        return node.text[0] - -    def visit_string_literal(self, node, visited_children): -        return node.text[0] - -    def visit_identifier(self, node, visited_children): -        return node.text[0] - -    def generic_visit(self, node, visited_children): -        """ The generic visit method. """ -        if not visited_children: -            return node -        if len(visited_children) == 1: -            return visited_children[0] -        return visited_children - - -def compile_to_ssa(parse_tree): -    ssa_gen = SsaGenVisitor() -    return ssa_gen.visit(parse_tree) +from dataclasses import dataclass +from functools import singledispatch +from typing import Dict, List + +from .ast import ImplementationFile, FunctionDefinition, ExpressionStatement, FunctionCallExpression, \ +    VariableExpression, ConstantExpression, ReturnStatement, BasicType + + +@dataclass +class SsaResult: +    data: List[str] +    code: List[str] + + +@dataclass +class CompileContext: +    next_data: int = 0 +    next_temp: int = 0 + + +def build_ssa(file: ImplementationFile) -> str: +    result = compile_to_ssa(file, CompileContext()) +    data = '\n'.join(result.data) +    code = '\n'.join(result.code) +    return data + '\n\n' + code + + +@singledispatch +def compile_to_ssa(target, context: CompileContext) -> SsaResult: +    raise NotImplementedError('unannotated compile on ' + str(type(target))) + + +@compile_to_ssa.register +def _(target: ImplementationFile, context: CompileContext): +    data = [] +    code = [] +    for target in target.contents: +        result = compile_to_ssa(target, context) +        data += result.data +        code += result.code +    return SsaResult(data, code) + + +@compile_to_ssa.register +def _(target: FunctionDefinition, context: CompileContext) -> SsaResult: +    data = [] +    code = [] +    for statement in target.body: +        result = compile_to_ssa(statement, context) +        data += result.data +        code += result.code +    code = ['    ' + instr for instr in code] +    assert len(target.args) == 0 +    assert target.return_type == BasicType('int32') +    code = [f"export function w ${target.name}() {{", "@start", *code, "}"] +    return SsaResult(data, code) + + +@compile_to_ssa.register +def _(target: ExpressionStatement, context: CompileContext) -> SsaResult: +    return compile_to_ssa(target.body, context) + + +@compile_to_ssa.register +def _(target: FunctionCallExpression, context: CompileContext) -> SsaResult: +    assert isinstance(target.function, VariableExpression) +    data = [] +    code = [] +    args = [] +    for i, expr in enumerate(target.arguments): +        arg_dest = context.next_temp +        result = compile_to_ssa(expr, context) +        data += result.data +        code += result.code +        args += [f"l %t{arg_dest}"] +    code += [f"call ${target.function.name}({','.join(args)}, ...)"] +    return SsaResult(data, code) + + +@compile_to_ssa.register +def _(target: ConstantExpression, context: CompileContext) -> SsaResult: +    if target.value.startswith('"'): +        data_dest = context.next_data +        context.next_data += 1 +        data = [f"data $data{data_dest} = {{ b {target.value}, b 0 }}"] +        temp = context.next_temp +        context.next_temp += 1 +        code = [f"%t{temp} =l copy $data{data_dest}"] +    else: +        assert not target.value.startswith('0b') +        assert not target.value.startswith('0B') +        assert not target.value.startswith('0o') +        assert not target.value.startswith('0x') +        assert not target.value.startswith('0X') +        assert not target.value.startswith('0f') +        assert not target.value.startswith('0F') +        assert '.' not in target.value +        assert not target.value.startswith("'") +        data = [] +        temp = context.next_temp +        context.next_temp += 1 +        code = [f"%t{temp} =w copy {target.value}"] +    return SsaResult(data, code) + + +@compile_to_ssa.register +def _(target: ReturnStatement, context: CompileContext) -> SsaResult: +    if target.body is None: +        return SsaResult([], ['ret']) +    ret_val_dest = context.next_temp +    result = compile_to_ssa(target.body, context) +    result.code.append(f"ret %t{ret_val_dest}") +    return result |