From 35979953a534bcdb2185de0a934e7937f319d687 Mon Sep 17 00:00:00 2001 From: Melody Horn Date: Wed, 4 Nov 2020 19:24:09 -0700 Subject: switch from specific declarations to generic AST --- crowbar_reference_compiler/__init__.py | 9 +- crowbar_reference_compiler/ast.py | 432 +++++++++++++++++++++++++++++ crowbar_reference_compiler/declarations.py | 274 ------------------ tests/test_ast.py | 54 ++++ tests/test_declarations.py | 53 ---- 5 files changed, 492 insertions(+), 330 deletions(-) create mode 100644 crowbar_reference_compiler/ast.py delete mode 100644 crowbar_reference_compiler/declarations.py create mode 100644 tests/test_ast.py delete mode 100644 tests/test_declarations.py diff --git a/crowbar_reference_compiler/__init__.py b/crowbar_reference_compiler/__init__.py index c7baeea..53d942f 100644 --- a/crowbar_reference_compiler/__init__.py +++ b/crowbar_reference_compiler/__init__.py @@ -1,4 +1,7 @@ -from .declarations import load_declarations +import dataclasses +from pprint import pprint + +from .ast import build_ast from .parser import parse_header, parse_implementation from .scanner import scan from .ssagen import compile_to_ssa @@ -33,8 +36,8 @@ def main(): output_file.write(str(parse_tree)) return - decls = load_declarations(parse_tree, args.include_dir) - print(decls) + full_ast = build_ast(parse_tree, args.include_dir) + pprint(dataclasses.asdict(full_ast)) ssa = compile_to_ssa(parse_tree) if args.stop_at_qbe_ssa: diff --git a/crowbar_reference_compiler/ast.py b/crowbar_reference_compiler/ast.py new file mode 100644 index 0000000..4b9c21d --- /dev/null +++ b/crowbar_reference_compiler/ast.py @@ -0,0 +1,432 @@ +from dataclasses import dataclass +from pathlib import Path +from typing import List, Optional, Tuple, Union + +from parsimonious import NodeVisitor # type: ignore + +from .scanner import scan +from .parser import parse_header + + +@dataclass +class Type: + pass + + +@dataclass +class Expression: + pass + + +@dataclass +class ConstantExpression(Expression): + value: str + + +@dataclass +class VariableExpression(Expression): + name: str + + +@dataclass +class BasicType(Type): + name: str + + +@dataclass +class ConstType(Type): + target: Type + + +@dataclass +class PointerType(Type): + target: Type + + +@dataclass +class ArrayType(Type): + contents: Type + size: Expression + + +@dataclass +class FunctionType(Type): + return_type: Type + args: List[Type] + + +@dataclass +class HeaderFileElement: + pass + + +@dataclass +class ImplementationFileElement: + pass + + +@dataclass +class Statement: + pass + + +@dataclass +class EmptyStatement(Statement): + pass + + +@dataclass +class FragileStatement(Statement): + body: Statement + + +@dataclass +class ExpressionStatement(Statement): + body: Expression + + +@dataclass +class IfStatement(Statement): + condition: Expression + then: List[Statement] + els: Optional[List[Statement]] + + +@dataclass +class SwitchStatement(Statement): + expression: Expression + body: List[Union[Optional[Expression], Statement]] + + +@dataclass +class WhileStatement(Statement): + condition: Expression + body: List[Statement] + + +@dataclass +class DoWhileStatement(Statement): + condition: Expression + body: List[Statement] + + +@dataclass +class Declaration: + name: str + + +@dataclass +class VariableDeclaration(Declaration, HeaderFileElement): + """Represents the declaration of a variable.""" + type: Type + + +@dataclass +class VariableDefinition(Declaration, ImplementationFileElement, Statement): + """Represents the definition of a variable.""" + type: Type + value: Expression + + +@dataclass +class AssignmentStatement(Statement): + pass + + +@dataclass +class ForStatement(Statement): + init: List[VariableDefinition] + condition: Expression + update: List[AssignmentStatement] + + +@dataclass +class ContinueStatement(Statement): + pass + + +@dataclass +class BreakStatement(Statement): + pass + + +@dataclass +class ReturnStatement(Statement): + body: Optional[Expression] + + +@dataclass +class DirectAssignment(AssignmentStatement): + destination: Expression + value: Expression + + +@dataclass +class UpdateAssignment(AssignmentStatement): + destination: Expression + operation: str + value: Expression + + +@dataclass +class CrementAssignment(AssignmentStatement): + destination: Expression + operation: str + + +@dataclass +class StructDeclaration(Declaration, HeaderFileElement, ImplementationFileElement): + """Represents the declaration of a struct type.""" + fields: Optional[List[VariableDeclaration]] + + +@dataclass +class EnumDeclaration(Declaration, HeaderFileElement, ImplementationFileElement): + """Represents the declaration of an enum type.""" + values: List[Tuple[str, Optional[int]]] + + +@dataclass +class UnionDeclaration(Declaration, HeaderFileElement, ImplementationFileElement): + """Represents the declaration of a union type.""" + tag: Optional[VariableDeclaration] + cases: Union[List[VariableDeclaration], List[Tuple[Expression, Optional[VariableDeclaration]]]] + + +@dataclass +class FunctionDeclaration(Declaration, HeaderFileElement): + """Represents the declaration of a function.""" + return_type: Type + args: List[VariableDeclaration] + + +@dataclass +class FunctionDefinition(Declaration, HeaderFileElement, ImplementationFileElement): + """Represents the definition of a function.""" + return_type: Type + args: List[VariableDeclaration] + body: List[Statement] + + +@dataclass +class HeaderFile: + includes: List['HeaderFile'] + contents: List[HeaderFileElement] + + +@dataclass +class ImplementationFile: + includes: List[HeaderFile] + contents: List[ImplementationFileElement] + + +# noinspection PyPep8Naming,PyMethodMayBeStatic,PyUnusedLocal +class ASTBuilder(NodeVisitor): + def __init__(self, include_folders): + self.include_folders = include_folders + + def visit_HeaderFile(self, node, visited_children) -> HeaderFile: + includes, elements = visited_children + if not isinstance(includes, list): + includes = [] + return HeaderFile(includes, elements) + + def visit_ImplementationFile(self, node, visited_children) -> ImplementationFile: + includes, elements = visited_children + return ImplementationFile(includes, elements) + + def visit_IncludeStatement(self, node, visited_children) -> HeaderFile: + include, included_header, semicolon = visited_children + assert include.type == 'include' + assert included_header.type == 'string_literal' + included_header = included_header.data.strip('"') + assert semicolon.type == ';' + for include_folder in self.include_folders: + header = Path(include_folder) / included_header + if header.exists(): + with open(header, 'r', encoding='utf-8') as header_file: + header_text = header_file.read() + header_parse_tree = parse_header(scan(header_text)) + return self.visit(header_parse_tree) + raise FileNotFoundError(included_header) + + def visit_NormalStructDefinition(self, node, visited_children) -> StructDeclaration: + struct, name, lbrace, fields, rbrace = visited_children + assert struct.type == 'struct' + assert lbrace.type == '{' + assert rbrace.type == '}' + name = name.data + if not isinstance(fields, list): + fields = [fields] + return StructDeclaration(name, fields) + + def visit_OpaqueStructDefinition(self, node, visited_children) -> StructDeclaration: + opaque, struct, name, semi = visited_children + assert opaque.type == 'opaque' + assert struct.type == 'struct' + assert semi.type == ';' + name = name.data + return StructDeclaration(name, None) + + def visit_EnumDefinition(self, node, visited_children) -> EnumDeclaration: + enum, name, lbrace, first_member, extra_members, trailing_comma, rbrace = visited_children + assert enum.type == 'enum' + assert lbrace.type == '{' + assert rbrace.type == '}' + name = name.data + values = [first_member] + for _, v in extra_members: + values.append(v) + return EnumDeclaration(name, values) + + def visit_EnumMember(self, node, visited_children) -> Tuple[str, Optional[Expression]]: + name, equals_value = visited_children + name = name.data + if len(equals_value) == 0: + return name, None + _, value = equals_value + return name, value + + def visit_RobustUnionDefinition(self, node, visited_children) -> UnionDeclaration: + union, name, lbrace, tag, body, rbrace = visited_children + assert union.type == 'union' + assert lbrace.type == '{' + assert rbrace.type == '}' + name = name.data + expected_tagname, body = body + if tag.name != expected_tagname: + raise NameError(f"tag {tag} does not match switch argument {expected_tagname}") + if not isinstance(body, list): + body = [body] + return UnionDeclaration(name, tag, body) + + def visit_UnionBody(self, node, visited_children) -> Tuple[str, List[Tuple[Expression, Optional[VariableDeclaration]]]]: + switch, lparen, tag, rparen, lbrace, body, rbrace = visited_children + assert switch.type == 'switch' + assert lparen.type == '(' + assert rparen.type == ')' + assert lbrace.type == '{' + assert rbrace.type == '}' + return tag.data, body + + def visit_UnionBodySet(self, node, visited_children) -> Tuple[Expression, Optional[VariableDeclaration]]: + cases, var = visited_children + if isinstance(cases, list): + cases = cases[0] + if isinstance(var, VariableDeclaration): + return cases, var + else: + return cases, None + + def visit_CaseSpecifier(self, node, visited_children) -> Expression: + while isinstance(visited_children, list) and len(visited_children) == 1: + visited_children = visited_children[0] + # TODO don't explode on 'default:' + case, expr, colon = visited_children + while isinstance(expr, list): + expr = expr[0] + return expr + + def visit_FragileUnionDefinition(self, node, visited_children) -> UnionDeclaration: + fragile, union, name, lbrace, body, rbrace = visited_children + assert fragile.type == 'fragile' + assert union.type == 'union' + assert lbrace.type == '{' + assert rbrace.type == '}' + name = name.data + return UnionDeclaration(name, None, body) + + def visit_FunctionDeclaration(self, node, visited_children) -> FunctionDeclaration: + signature, semi = visited_children + assert semi.type == ';' + return signature + + def visit_VariableDefinition(self, node, visited_children) -> VariableDefinition: + type, name, eq, value, semi = visited_children + assert eq.type == '=' + assert semi.type == ';' + name = name.data + return VariableDefinition(name, type, value) + + def visit_VariableDeclaration(self, node, visited_children) -> VariableDeclaration: + type, name, semi = visited_children + assert semi.type == ';' + name = name.data + return VariableDeclaration(name, type) + + def visit_FunctionDefinition(self, node, visited_children) -> FunctionDefinition: + signature, body = visited_children + return FunctionDefinition(signature.name, signature.return_type, signature.args, body) + + def visit_FunctionSignature(self, node, visited_children) -> FunctionDeclaration: + return_type, name, lparen, args, rparen = visited_children + assert name.type == 'identifier' + name = name.data + assert lparen.type == '(' + assert rparen.type == ')' + return FunctionDeclaration(name, return_type, args) + + def visit_BasicType(self, node, visited_children) -> Type: + while isinstance(visited_children, list) and len(visited_children) == 1: + visited_children = visited_children[0] + if isinstance(visited_children, list): + if len(visited_children) == 3: + # parenthesized! + lparen, ty, rparen = visited_children + assert lparen.type == '(' + assert rparen.type == ')' + return ty + else: + category, name = visited_children + category = category.type + name = name.data + return BasicType(f"{category} {name}") + return BasicType(visited_children.type) + + def visit_ArrayType(self, node, visited_children) -> ArrayType: + contents, lbracket, size, rbracket = visited_children + assert lbracket.type == '[' + assert rbracket.type == ']' + # TODO don't explode on nontrivial expression + while isinstance(size, list): + size = size[0] + return ArrayType(contents, size) + + def visit_PointerType(self, node, visited_children) -> PointerType: + contents, splat = visited_children + assert splat.type == '*' + return PointerType(contents) + + def visit_AtomicExpression(self, node, visited_children) -> Expression: + if isinstance(visited_children, list) and len(visited_children) == 3: + lparen, body, rparen = visited_children + assert lparen.type == '(' + assert rparen.type == ')' + return body + body = visited_children + while isinstance(body, list): + body = body[0] + if body.type == 'identifier': + return VariableExpression(body.data) + if body.type == 'constant': + return ConstantExpression(body.data) + if body.type in ['true', 'false']: + return ConstantExpression(body.type) + raise NotImplementedError() + + def generic_visit(self, node, visited_children): + """ The generic visit method. """ + if not visited_children: + if len(node.text) == 0: + return [] + if len(node.text) == 1: + return node.text[0] + raise ValueError('just a node: ' + str(node)) + if len(visited_children) == 1: + return visited_children[0] + return visited_children + + +def build_ast(parse_tree, include_dirs): + builder = ASTBuilder(include_dirs) + return builder.visit(parse_tree) diff --git a/crowbar_reference_compiler/declarations.py b/crowbar_reference_compiler/declarations.py deleted file mode 100644 index e9a7e4b..0000000 --- a/crowbar_reference_compiler/declarations.py +++ /dev/null @@ -1,274 +0,0 @@ -from dataclasses import dataclass -from pathlib import Path -from typing import List, Optional, Tuple, Union - -from parsimonious import NodeVisitor # type: ignore - -from .scanner import scan -from .parser import parse_header - - -@dataclass -class Type: - pass - - -@dataclass -class BasicType(Type): - name: str - - -@dataclass -class PointerType(Type): - target: Type - - -@dataclass -class ArrayType(Type): - contents: Type - size: int - - -@dataclass -class Declaration: - name: str - - -@dataclass -class VariableDeclaration(Declaration): - """Represents the declaration of a variable.""" - type: Type - value: Optional[str] - - -@dataclass -class Declarations: - included: List[Declaration] - - -@dataclass -class StructDeclaration(Declaration): - """Represents the declaration of a struct type.""" - fields: Optional[List[VariableDeclaration]] - - -@dataclass -class EnumDeclaration(Declaration): - """Represents the declaration of an enum type.""" - values: List[Tuple[str, Optional[int]]] - - -@dataclass -class UnionDeclaration(Declaration): - """Represents the declaration of a union type.""" - tag: Optional[VariableDeclaration] - cases: Union[List[VariableDeclaration], List[Tuple[str, Optional[VariableDeclaration]]]] - - -@dataclass -class FunctionDeclaration(Declaration): - """Represents the declaration of a function.""" - return_type: Type - args: List[VariableDeclaration] - - -# noinspection PyPep8Naming,PyMethodMayBeStatic,PyUnusedLocal -class DeclarationVisitor(NodeVisitor): - def __init__(self, include_folders): - self.data = [] - self.include_folders = include_folders - - def visit_HeaderFile(self, node, visited_children): - includes, elements = visited_children - return elements - - def visit_ImplementationFile(self, node, visited_children): - includes, elements = visited_children - includes = [y for x in includes for y in x] - return [x for x in includes + elements if x is not None] - - def visit_IncludeStatement(self, node, visited_children): - include, included_header, semicolon = visited_children - assert include.text[0].type == 'include' - assert included_header.type == 'string_literal' - included_header = included_header.data.strip('"') - assert semicolon.text[0].type == ';' - for include_folder in self.include_folders: - header = Path(include_folder) / included_header - if header.exists(): - with open(header, 'r', encoding='utf-8') as header_file: - header_text = header_file.read() - header_parse_tree = parse_header(scan(header_text)) - return self.visit(header_parse_tree) - raise FileNotFoundError(included_header) - - def visit_NormalStructDefinition(self, node, visited_children): - struct, name, lbrace, fields, rbrace = visited_children - assert struct.text[0].type == 'struct' - assert lbrace.text[0].type == '{' - assert rbrace.text[0].type == '}' - name = name.data - if not isinstance(fields, list): - fields = [fields] - return StructDeclaration(name, fields) - - def visit_OpaqueStructDefinition(self, node, visited_children): - opaque, struct, name, semi = visited_children - assert opaque.text[0].type == 'opaque' - assert struct.text[0].type == 'struct' - assert semi.text[0].type == ';' - name = name.data - return StructDeclaration(name, None) - - def visit_EnumDefinition(self, node, visited_children): - enum, name, lbrace, first_member, extra_members, trailing_comma, rbrace = visited_children - assert enum.text[0].type == 'enum' - assert lbrace.text[0].type == '{' - assert rbrace.text[0].type == '}' - name = name.data - values = [first_member] - for _, v in extra_members: - values.append(v) - return EnumDeclaration(name, values) - - def visit_EnumMember(self, node, visited_children): - name, equals_value = visited_children - name = name.data - if not isinstance(equals_value, list): - return name, None - _, value = equals_value - return name, value - - def visit_RobustUnionDefinition(self, node, visited_children): - union, name, lbrace, tag, body, rbrace = visited_children - assert union.text[0].type == 'union' - assert lbrace.text[0].type == '{' - assert rbrace.text[0].type == '}' - name = name.data - expected_tagname, body = body - if tag.name != expected_tagname: - raise NameError(f"tag {tag} does not match switch argument {expected_tagname}") - if not isinstance(body, list): - body = [body] - return UnionDeclaration(name, tag, body) - - def visit_UnionBody(self, node, visited_children): - switch, lparen, tag, rparen, lbrace, body, rbrace = visited_children - assert switch.text[0].type == 'switch' - assert lparen.text[0].type == '(' - assert rparen.text[0].type == ')' - assert lbrace.text[0].type == '{' - assert rbrace.text[0].type == '}' - return tag.data, body - - def visit_UnionBodySet(self, node, visited_children): - cases, var = visited_children - if isinstance(cases, list): - cases = cases[0] - if isinstance(var, VariableDeclaration): - return cases, var - else: - return cases, None - - def visit_CaseSpecifier(self, node, visited_children): - while isinstance(visited_children, list) and len(visited_children) == 1: - visited_children = visited_children[0] - # TODO don't explode on 'default:' - case, expr, colon = visited_children - while isinstance(expr, list): - expr = expr[0] - # TODO don't explode on nontrivial expression - return expr.data - - def visit_FragileUnionDefinition(self, node, visited_children): - fragile, union, name, lbrace, body, rbrace = visited_children - assert fragile.text[0].type == 'fragile' - assert union.text[0].type == 'union' - assert lbrace.text[0].type == '{' - assert rbrace.text[0].type == '}' - name = name.data - return UnionDeclaration(name, None, body) - - def visit_FunctionDeclaration(self, node, visited_children): - signature, semi = visited_children - assert semi.text[0].type == ';' - return signature - - def visit_VariableDefinition(self, node, visited_children): - type, name, eq, value, semi = visited_children - assert eq.text[0].type == '=' - assert semi.text[0].type == ';' - name = name.data - return VariableDeclaration(name, type, value) - - def visit_VariableDeclaration(self, node, visited_children): - type, name, semi = visited_children - assert semi.text[0].type == ';' - name = name.data - return VariableDeclaration(name, type, None) - - def visit_FunctionDefinition(self, node, visited_children): - signature, body = visited_children - return signature - - def visit_FunctionSignature(self, node, visited_children): - return_type, name, lparen, args, rparen = visited_children - assert name.type == 'identifier' - name = name.data - assert lparen.text[0].type == '(' - assert rparen.text[0].type == ')' - return FunctionDeclaration(name, return_type, args) - - def visit_BasicType(self, node, visited_children): - while isinstance(visited_children, list) and len(visited_children) == 1: - visited_children = visited_children[0] - if isinstance(visited_children, list): - if len(visited_children) == 3: - # parenthesized! - lparen, ty, rparen = visited_children - assert lparen.text[0].type == '(' - assert rparen.text[0].type == ')' - return ty - else: - category, name = visited_children - category = category.text[0].type - name = name.data - return BasicType(f"{category} {name}") - return BasicType(visited_children.text[0].type) - - def visit_ArrayType(self, node, visited_children): - contents, lbracket, size, rbracket = visited_children - assert lbracket.text[0].type == '[' - assert rbracket.text[0].type == ']' - # TODO don't explode on nontrivial expression - while isinstance(size, list): - size = size[0] - size = int(size.data) - return ArrayType(contents, size) - - def visit_PointerType(self, node, visited_children): - contents, splat = visited_children - assert splat.text[0].type == '*' - return PointerType(contents) - - def visit_constant(self, node, visited_children): - return node.text[0] - - def visit_string_literal(self, node, visited_children): - return node.text[0] - - def visit_identifier(self, node, visited_children): - return node.text[0] - - def generic_visit(self, node, visited_children): - """ The generic visit method. """ - if not visited_children: - return node - if len(visited_children) == 1: - return visited_children[0] - return visited_children - - -def load_declarations(parse_tree, include_dirs): - declarations = DeclarationVisitor(include_dirs) - return declarations.visit(parse_tree) diff --git a/tests/test_ast.py b/tests/test_ast.py new file mode 100644 index 0000000..b2c592d --- /dev/null +++ b/tests/test_ast.py @@ -0,0 +1,54 @@ +import unittest + +from crowbar_reference_compiler import build_ast, parse_header, scan +from crowbar_reference_compiler.ast import ArrayType, BasicType, ConstantExpression, EnumDeclaration, HeaderFile, \ + PointerType, StructDeclaration, UnionDeclaration, VariableDeclaration, VariableExpression + + +class TestAST(unittest.TestCase): + def test_kitchen_sink(self): + code = r""" +struct normal { + bool fake; + (uint8[3])* data; +} + +opaque struct ope; + +enum sample { + Testing, +} + +union robust { + enum sample tag; + + switch (tag) { + case Testing: bool testPassed; + } +} + +fragile union not_robust { + int8 sample; + bool nope; +} +""" + tokens = scan(code) + parse_tree = parse_header(tokens) + decls = build_ast(parse_tree, []) + normal = StructDeclaration('normal', [ + VariableDeclaration('fake', BasicType('bool')), + VariableDeclaration('data', PointerType(ArrayType(BasicType('uint8'), ConstantExpression('3')))), + ]) + ope = StructDeclaration('ope', None) + sample = EnumDeclaration('sample', [('Testing', None)]) + robust = UnionDeclaration('robust', VariableDeclaration('tag', BasicType('enum sample')), + [(VariableExpression('Testing'), VariableDeclaration('testPassed', BasicType('bool')))]) + not_robust = UnionDeclaration('not_robust', None, [ + VariableDeclaration('sample', BasicType('int8')), + VariableDeclaration('nope', BasicType('bool')), + ]) + self.assertEqual(decls, HeaderFile([], [normal, ope, sample, robust, not_robust])) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_declarations.py b/tests/test_declarations.py deleted file mode 100644 index 9eaf488..0000000 --- a/tests/test_declarations.py +++ /dev/null @@ -1,53 +0,0 @@ -import unittest - -from crowbar_reference_compiler import compile_to_ssa, load_declarations, parse_header, parse_implementation, scan -from crowbar_reference_compiler.declarations import ArrayType, BasicType, EnumDeclaration, PointerType, \ - StructDeclaration, UnionDeclaration, VariableDeclaration - - -class TestDeclarationLoading(unittest.TestCase): - def test_kitchen_sink(self): - code = r""" -struct normal { - bool fake; - (uint8[3])* data; -} - -opaque struct ope; - -enum sample { - Testing, -} - -union robust { - enum sample tag; - - switch (tag) { - case Testing: bool testPassed; - } -} - -fragile union not_robust { - int8 sample; - bool nope; -} -""" - tokens = scan(code) - parse_tree = parse_header(tokens) - decls = load_declarations(parse_tree, []) - normal = StructDeclaration('normal', [ - VariableDeclaration('fake', BasicType('bool'), None), - VariableDeclaration('data', PointerType(ArrayType(BasicType('uint8'), 3)), None), - ]) - ope = StructDeclaration('ope', None) - sample = EnumDeclaration('sample', [('Testing', None)]) - robust = UnionDeclaration('robust', VariableDeclaration('tag', BasicType('enum sample'), None), - [('Testing', VariableDeclaration('testPassed', BasicType('bool'), None))]) - not_robust = UnionDeclaration('not_robust', None, - [VariableDeclaration('sample', BasicType('int8'), None), - VariableDeclaration('nope', BasicType('bool'), None)]) - self.assertListEqual(decls, [normal, ope, sample, robust, not_robust]) - - -if __name__ == '__main__': - unittest.main() -- cgit v1.2.3