# Reconstructed from the patch that created crowbar_reference_compiler/ast.py:
#   commit 35979953a534bcdb2185de0a934e7937f319d687
#   author Melody Horn, Wed, 4 Nov 2020 19:24:09 -0700
#   subject "switch from specific declarations to generic AST"
"""AST node definitions and a parse-tree visitor that builds them."""
from dataclasses import dataclass, field
from pathlib import Path
from typing import List, Optional, Tuple, Union

from parsimonious import NodeVisitor  # type: ignore

from .scanner import scan
from .parser import parse_header


def _as_list(value):
    """Return *value* unchanged if it is a list, otherwise wrap it in one.

    ``ASTBuilder.generic_visit`` collapses a node with exactly one visited
    child into that child, so a repetition that matched once arrives as a
    bare object instead of a one-element list.
    """
    return value if isinstance(value, list) else [value]


@dataclass
class Type:
    """Base class for all type nodes."""
    pass


@dataclass
class Expression:
    """Base class for all expression nodes."""
    pass


@dataclass
class ConstantExpression(Expression):
    """An expression holding a constant, stored as its source text."""
    value: str


@dataclass
class VariableExpression(Expression):
    """An expression that refers to a variable by name."""
    name: str


@dataclass
class BasicType(Type):
    """A type identified by name alone."""
    name: str


@dataclass
class ConstType(Type):
    """A const-qualified wrapper around another type."""
    target: Type


@dataclass
class PointerType(Type):
    """A pointer to another type."""
    target: Type


@dataclass
class ArrayType(Type):
    """An array of ``contents`` whose length is given by ``size``."""
    contents: Type
    size: Expression


@dataclass
class FunctionType(Type):
    """A function type with a return type and argument types."""
    return_type: Type
    args: List[Type]


@dataclass
class HeaderFileElement:
    """Marker base class for nodes that may appear in a header file."""
    pass


@dataclass
class ImplementationFileElement:
    """Marker base class for nodes that may appear in an implementation file."""
    pass


@dataclass
class Statement:
    """Base class for all statement nodes."""
    pass


@dataclass
class EmptyStatement(Statement):
    """A statement with no content."""
    pass


@dataclass
class FragileStatement(Statement):
    """A wrapper around another statement (presumably one marked ``fragile``)."""
    body: Statement


@dataclass
class ExpressionStatement(Statement):
    """A statement consisting of a single expression."""
    body: Expression


@dataclass
class IfStatement(Statement):
    """A conditional with a ``then`` branch and an optional ``els`` branch."""
    condition: Expression
    then: List[Statement]
    els: Optional[List[Statement]]


@dataclass
class SwitchStatement(Statement):
    """A switch over ``expression``; ``body`` mixes case labels and statements."""
    expression: Expression
    body: List[Union[Optional[Expression], Statement]]


@dataclass
class WhileStatement(Statement):
    """A loop with a condition and a body."""
    condition: Expression
    body: List[Statement]


@dataclass
class DoWhileStatement(Statement):
    """A do-while loop with a condition and a body."""
    condition: Expression
    body: List[Statement]


@dataclass
class Declaration:
    """Base class for anything that introduces a name."""
    name: str


@dataclass
class VariableDeclaration(Declaration, HeaderFileElement):
    """Represents the declaration of a variable."""
    type: Type


@dataclass
class VariableDefinition(Declaration, ImplementationFileElement, Statement):
    """Represents the definition of a variable."""
    type: Type
    value: Expression


@dataclass
class AssignmentStatement(Statement):
    """Base class for the assignment statement variants."""
    pass


@dataclass
class ForStatement(Statement):
    """A for loop: initialisers, a condition, update assignments and a body.

    ``body`` defaults to an empty list so that code constructing the node
    with only the first three fields keeps working.
    """
    init: List[VariableDefinition]
    condition: Expression
    update: List[AssignmentStatement]
    body: List[Statement] = field(default_factory=list)


@dataclass
class ContinueStatement(Statement):
    """A ``continue`` statement."""
    pass


@dataclass
class BreakStatement(Statement):
    """A ``break`` statement."""
    pass


@dataclass
class ReturnStatement(Statement):
    """A ``return`` statement with an optional value expression."""
    body: Optional[Expression]


@dataclass
class DirectAssignment(AssignmentStatement):
    """An assignment of ``value`` to ``destination``."""
    destination: Expression
    value: Expression


@dataclass
class UpdateAssignment(AssignmentStatement):
    """An assignment that combines ``destination`` with ``value`` via ``operation``."""
    destination: Expression
    operation: str
    value: Expression


@dataclass
class CrementAssignment(AssignmentStatement):
    """An operand-less update of ``destination`` (presumably increment/decrement)."""
    destination: Expression
    operation: str


@dataclass
class StructDeclaration(Declaration, HeaderFileElement, ImplementationFileElement):
    """Represents the declaration of a struct type."""
    # None for an opaque struct, whose fields are not visible.
    fields: Optional[List[VariableDeclaration]]


@dataclass
class EnumDeclaration(Declaration, HeaderFileElement, ImplementationFileElement):
    """Represents the declaration of an enum type."""
    # Each entry is (member name, explicit value or None), as produced by
    # ASTBuilder.visit_EnumMember.
    values: List[Tuple[str, Optional[Expression]]]


@dataclass
class UnionDeclaration(Declaration, HeaderFileElement, ImplementationFileElement):
    """Represents the declaration of a union type."""
    # None for a fragile union, which is built without a tag.
    tag: Optional[VariableDeclaration]
    cases: Union[List[VariableDeclaration], List[Tuple[Expression, Optional[VariableDeclaration]]]]


@dataclass
class FunctionDeclaration(Declaration, HeaderFileElement):
    """Represents the declaration of a function."""
    return_type: Type
    args: List[VariableDeclaration]


@dataclass
class FunctionDefinition(Declaration, HeaderFileElement, ImplementationFileElement):
    """Represents the definition of a function."""
    return_type: Type
    args: List[VariableDeclaration]
    body: List[Statement]


@dataclass
class HeaderFile:
    """A parsed header file: the headers it includes and its own elements."""
    includes: List['HeaderFile']
    contents: List[HeaderFileElement]


@dataclass
class ImplementationFile:
    """A parsed implementation file: the headers it includes and its elements."""
    includes: List[HeaderFile]
    contents: List[ImplementationFileElement]


# noinspection PyPep8Naming,PyMethodMayBeStatic,PyUnusedLocal
class ASTBuilder(NodeVisitor):
    """Visitor that turns a parse tree into the AST dataclasses above.

    Each ``visit_<Rule>`` method receives the parse node and the already
    visited children; see ``generic_visit`` for how unhandled rules are
    flattened.
    """

    def __init__(self, include_folders):
        """Remember the folders searched, in order, when resolving includes."""
        self.include_folders = include_folders

    def visit_HeaderFile(self, node, visited_children) -> HeaderFile:
        """Build a HeaderFile from its includes and elements."""
        includes, elements = visited_children
        if isinstance(includes, HeaderFile):
            # A single include is collapsed by generic_visit into the bare
            # HeaderFile; keep it instead of discarding it.
            includes = [includes]
        elif not isinstance(includes, list):
            includes = []
        return HeaderFile(includes, _as_list(elements))

    def visit_ImplementationFile(self, node, visited_children) -> ImplementationFile:
        """Build an ImplementationFile from its includes and elements."""
        includes, elements = visited_children
        return ImplementationFile(_as_list(includes), _as_list(elements))

    def visit_IncludeStatement(self, node, visited_children) -> HeaderFile:
        """Locate, scan, parse and visit the header named by an include.

        Raises:
            FileNotFoundError: if no include folder contains the header.
        """
        include, included_header, semicolon = visited_children
        assert include.type == 'include'
        assert included_header.type == 'string_literal'
        included_header = included_header.data.strip('"')
        assert semicolon.type == ';'
        for include_folder in self.include_folders:
            header = Path(include_folder) / included_header
            if header.exists():
                header_text = header.read_text(encoding='utf-8')
                header_parse_tree = parse_header(scan(header_text))
                return self.visit(header_parse_tree)
        raise FileNotFoundError(included_header)

    def visit_NormalStructDefinition(self, node, visited_children) -> StructDeclaration:
        """Build a StructDeclaration with its list of fields."""
        struct, name, lbrace, fields, rbrace = visited_children
        assert struct.type == 'struct'
        assert lbrace.type == '{'
        assert rbrace.type == '}'
        name = name.data
        return StructDeclaration(name, _as_list(fields))

    def visit_OpaqueStructDefinition(self, node, visited_children) -> StructDeclaration:
        """Build a StructDeclaration with no fields (``fields`` is None)."""
        opaque, struct, name, semi = visited_children
        assert opaque.type == 'opaque'
        assert struct.type == 'struct'
        assert semi.type == ';'
        name = name.data
        return StructDeclaration(name, None)

    def visit_EnumDefinition(self, node, visited_children) -> EnumDeclaration:
        """Build an EnumDeclaration from the first member and any extras."""
        enum, name, lbrace, first_member, extra_members, trailing_comma, rbrace = visited_children
        assert enum.type == 'enum'
        assert lbrace.type == '{'
        assert rbrace.type == '}'
        name = name.data
        # With exactly one extra member, generic_visit collapses the
        # repetition into a flat [comma, member] pair rather than a list of
        # such pairs; restore the list-of-pairs shape before iterating.
        if extra_members and not isinstance(extra_members[0], list):
            extra_members = [extra_members]
        values = [first_member]
        values.extend(member for _, member in extra_members)
        return EnumDeclaration(name, values)

    def visit_EnumMember(self, node, visited_children) -> Tuple[str, Optional[Expression]]:
        """Return ``(name, value)``; ``value`` is None without an explicit value."""
        name, equals_value = visited_children
        name = name.data
        if len(equals_value) == 0:
            return name, None
        _, value = equals_value
        return name, value

    def visit_RobustUnionDefinition(self, node, visited_children) -> UnionDeclaration:
        """Build a tagged UnionDeclaration.

        Raises:
            NameError: if the tag's name differs from the switch argument.
        """
        union, name, lbrace, tag, body, rbrace = visited_children
        assert union.type == 'union'
        assert lbrace.type == '{'
        assert rbrace.type == '}'
        name = name.data
        expected_tagname, body = body
        if tag.name != expected_tagname:
            raise NameError(f"tag {tag} does not match switch argument {expected_tagname}")
        return UnionDeclaration(name, tag, _as_list(body))

    def visit_UnionBody(self, node, visited_children) -> Tuple[str, List[Tuple[Expression, Optional[VariableDeclaration]]]]:
        """Return the switch argument's name and the visited switch body."""
        switch, lparen, tag, rparen, lbrace, body, rbrace = visited_children
        assert switch.type == 'switch'
        assert lparen.type == '('
        assert rparen.type == ')'
        assert lbrace.type == '{'
        assert rbrace.type == '}'
        return tag.data, body

    def visit_UnionBodySet(self, node, visited_children) -> Tuple[Expression, Optional[VariableDeclaration]]:
        """Return ``(case, variable)``; ``variable`` is None if none was declared."""
        cases, var = visited_children
        if isinstance(cases, list):
            # Only the first case of the set is kept.
            cases = cases[0]
        if isinstance(var, VariableDeclaration):
            return cases, var
        else:
            return cases, None

    def visit_CaseSpecifier(self, node, visited_children) -> Expression:
        """Return the expression of a ``case <expr>:`` specifier."""
        while isinstance(visited_children, list) and len(visited_children) == 1:
            visited_children = visited_children[0]
        # TODO don't explode on 'default:'
        case, expr, colon = visited_children
        while isinstance(expr, list):
            expr = expr[0]
        return expr

    def visit_FragileUnionDefinition(self, node, visited_children) -> UnionDeclaration:
        """Build an untagged UnionDeclaration (``tag`` is None)."""
        fragile, union, name, lbrace, body, rbrace = visited_children
        assert fragile.type == 'fragile'
        assert union.type == 'union'
        assert lbrace.type == '{'
        assert rbrace.type == '}'
        name = name.data
        return UnionDeclaration(name, None, body)

    def visit_FunctionDeclaration(self, node, visited_children) -> FunctionDeclaration:
        """Return the visited signature of a ``<signature>;`` declaration."""
        signature, semi = visited_children
        assert semi.type == ';'
        return signature

    def visit_VariableDefinition(self, node, visited_children) -> VariableDefinition:
        """Build a VariableDefinition from ``<type> <name> = <value>;``."""
        type_, name, eq, value, semi = visited_children
        assert eq.type == '='
        assert semi.type == ';'
        name = name.data
        return VariableDefinition(name, type_, value)

    def visit_VariableDeclaration(self, node, visited_children) -> VariableDeclaration:
        """Build a VariableDeclaration from ``<type> <name>;``."""
        type_, name, semi = visited_children
        assert semi.type == ';'
        name = name.data
        return VariableDeclaration(name, type_)

    def visit_FunctionDefinition(self, node, visited_children) -> FunctionDefinition:
        """Combine a visited signature with a body into a FunctionDefinition."""
        signature, body = visited_children
        return FunctionDefinition(signature.name, signature.return_type, signature.args, body)

    def visit_FunctionSignature(self, node, visited_children) -> FunctionDeclaration:
        """Build a FunctionDeclaration from ``<return type> <name>(<args>)``."""
        return_type, name, lparen, args, rparen = visited_children
        assert name.type == 'identifier'
        name = name.data
        assert lparen.type == '('
        assert rparen.type == ')'
        return FunctionDeclaration(name, return_type, args)

    def visit_BasicType(self, node, visited_children) -> Type:
        """Build a type from one token, a two-token pair, or a parenthesised type."""
        # Peel off single-element list wrappers left by nested rules.
        while isinstance(visited_children, list) and len(visited_children) == 1:
            visited_children = visited_children[0]
        if isinstance(visited_children, list):
            if len(visited_children) == 3:
                # parenthesized!
                lparen, ty, rparen = visited_children
                assert lparen.type == '('
                assert rparen.type == ')'
                return ty
            else:
                # Two tokens: the first token's type plus the second's data.
                category, name = visited_children
                category = category.type
                name = name.data
                return BasicType(f"{category} {name}")
        return BasicType(visited_children.type)

    def visit_ArrayType(self, node, visited_children) -> ArrayType:
        """Build an ArrayType from ``<contents>[<size>]``."""
        contents, lbracket, size, rbracket = visited_children
        assert lbracket.type == '['
        assert rbracket.type == ']'
        # TODO don't explode on nontrivial expression
        while isinstance(size, list):
            size = size[0]
        return ArrayType(contents, size)

    def visit_PointerType(self, node, visited_children) -> PointerType:
        """Build a PointerType from ``<contents>*``."""
        contents, splat = visited_children
        assert splat.type == '*'
        return PointerType(contents)

    def visit_AtomicExpression(self, node, visited_children) -> Expression:
        """Build an expression from an identifier, constant, boolean or parenthesised body.

        Raises:
            NotImplementedError: for any other token type.
        """
        if isinstance(visited_children, list) and len(visited_children) == 3:
            lparen, body, rparen = visited_children
            assert lparen.type == '('
            assert rparen.type == ')'
            return body
        body = visited_children
        while isinstance(body, list):
            body = body[0]
        if body.type == 'identifier':
            return VariableExpression(body.data)
        if body.type == 'constant':
            return ConstantExpression(body.data)
        if body.type in ['true', 'false']:
            return ConstantExpression(body.type)
        raise NotImplementedError()

    def generic_visit(self, node, visited_children):
        """Flatten nodes that have no dedicated ``visit_*`` method.

        A childless node yields ``[]`` when its text is empty and the single
        item of its text when it has exactly one. A node with one visited
        child yields that child; otherwise the list of children is returned.

        Raises:
            ValueError: for a childless node whose text has several items.
        """
        if not visited_children:
            if len(node.text) == 0:
                return []
            if len(node.text) == 1:
                return node.text[0]
            raise ValueError('just a node: ' + str(node))
        if len(visited_children) == 1:
            return visited_children[0]
        return visited_children


def build_ast(parse_tree, include_dirs):
    """Visit *parse_tree* with an ASTBuilder that searches *include_dirs*."""
    builder = ASTBuilder(include_dirs)
    return builder.visit(parse_tree)