from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional, Tuple, Union

from parsimonious import NodeVisitor  # type: ignore

from .scanner import scan
from .parser import parse_header


@dataclass
class Type:
    pass


@dataclass
class Expression:
    pass


@dataclass
class ConstantExpression(Expression):
    value: str


@dataclass
class VariableExpression(Expression):
    name: str


@dataclass
class BasicType(Type):
    name: str


@dataclass
class ConstType(Type):
    target: Type


@dataclass
class PointerType(Type):
    target: Type


@dataclass
class ArrayType(Type):
    contents: Type
    size: Expression


@dataclass
class FunctionType(Type):
    return_type: Type
    args: List[Type]


@dataclass
class HeaderFileElement:
    pass


@dataclass
class ImplementationFileElement:
    pass


@dataclass
class Statement:
    pass


@dataclass
class EmptyStatement(Statement):
    pass


@dataclass
class FragileStatement(Statement):
    body: Statement


@dataclass
class ExpressionStatement(Statement):
    body: Expression


@dataclass
class IfStatement(Statement):
    condition: Expression
    then: List[Statement]
    els: Optional[List[Statement]]


@dataclass
class SwitchStatement(Statement):
    expression: Expression
    body: List[Union[Optional[Expression], Statement]]


@dataclass
class WhileStatement(Statement):
    condition: Expression
    body: List[Statement]


@dataclass
class DoWhileStatement(Statement):
    condition: Expression
    body: List[Statement]


@dataclass
class Declaration:
    name: str


@dataclass
class VariableDeclaration(Declaration, HeaderFileElement):
    """Represents the declaration of a variable."""
    type: Type


@dataclass
class VariableDefinition(Declaration, ImplementationFileElement, Statement):
    """Represents the definition of a variable."""
    type: Type
    value: Expression


@dataclass
class AssignmentStatement(Statement):
    pass


@dataclass
class ForStatement(Statement):
    init: List[VariableDefinition]
    condition: Expression
    update: List[AssignmentStatement]


@dataclass
class ContinueStatement(Statement):
    pass


@dataclass
class BreakStatement(Statement):
    pass


@dataclass
class ReturnStatement(Statement):
    body: Optional[Expression]


@dataclass
class DirectAssignment(AssignmentStatement):
    destination: Expression
    value: Expression


@dataclass
class UpdateAssignment(AssignmentStatement):
    destination: Expression
    operation: str
    value: Expression


@dataclass
class CrementAssignment(AssignmentStatement):
    destination: Expression
    operation: str


@dataclass
class StructDeclaration(Declaration, HeaderFileElement, ImplementationFileElement):
    """Represents the declaration of a struct type."""
    # None for an opaque struct (see ASTBuilder.visit_OpaqueStructDefinition).
    fields: Optional[List[VariableDeclaration]]


@dataclass
class EnumDeclaration(Declaration, HeaderFileElement, ImplementationFileElement):
    """Represents the declaration of an enum type."""
    # (member name, explicit value or None); values come from
    # ASTBuilder.visit_EnumMember, which does not convert them to int.
    values: List[Tuple[str, Optional[Expression]]]


@dataclass
class UnionDeclaration(Declaration, HeaderFileElement, ImplementationFileElement):
    """Represents the declaration of a union type."""
    # Set for robust (tagged) unions, None for fragile unions.
    tag: Optional[VariableDeclaration]
    # Fragile unions: plain field list. Robust unions: (case, field-or-None) pairs.
    cases: Union[List[VariableDeclaration], List[Tuple[Expression, Optional[VariableDeclaration]]]]


@dataclass
class FunctionDeclaration(Declaration, HeaderFileElement):
    """Represents the declaration of a function."""
    return_type: Type
    args: List[VariableDeclaration]


@dataclass
class FunctionDefinition(Declaration, HeaderFileElement, ImplementationFileElement):
    """Represents the definition of a function."""
    return_type: Type
    args: List[VariableDeclaration]
    body: List[Statement]


@dataclass
class HeaderFile:
    includes: List['HeaderFile']
    contents: List[HeaderFileElement]


@dataclass
class ImplementationFile:
    includes: List[HeaderFile]
    contents: List[ImplementationFileElement]


# noinspection PyPep8Naming,PyMethodMayBeStatic,PyUnusedLocal
class ASTBuilder(NodeVisitor):
    """Convert a parse tree into the AST dataclasses defined in this module.

    Note that ``generic_visit`` collapses any node with exactly one visited
    child into that child. Repetitions that matched exactly once therefore
    reach the ``visit_*`` methods as a bare item instead of a one-item list,
    and several visitors below re-wrap such items.
    """

    def __init__(self, include_folders):
        """
        :param include_folders: directories searched, in order, for the
            header named by each include statement.
        """
        self.include_folders = include_folders

    def visit_HeaderFile(self, node, visited_children) -> HeaderFile:
        """Build a HeaderFile from its include list and element list."""
        includes, elements = visited_children
        # A single include arrives as a bare HeaderFile; keep it instead of
        # discarding it. Any other non-list value means "no includes".
        if isinstance(includes, HeaderFile):
            includes = [includes]
        elif not isinstance(includes, list):
            includes = []
        if isinstance(elements, HeaderFileElement):
            elements = [elements]
        return HeaderFile(includes, elements)

    def visit_ImplementationFile(self, node, visited_children) -> ImplementationFile:
        """Build an ImplementationFile from its include list and element list."""
        includes, elements = visited_children
        # Same single-child normalization as visit_HeaderFile.
        if isinstance(includes, HeaderFile):
            includes = [includes]
        elif not isinstance(includes, list):
            includes = []
        if isinstance(elements, ImplementationFileElement):
            elements = [elements]
        return ImplementationFile(includes, elements)

    def visit_IncludeStatement(self, node, visited_children) -> HeaderFile:
        """Locate, scan, parse and visit the header named by an include.

        The first ``include_folders`` entry that contains a regular file with
        the quoted name wins.

        :raises FileNotFoundError: if no include folder contains the header.
        """
        include, included_header, semicolon = visited_children
        assert include.type == 'include'
        assert included_header.type == 'string_literal'
        included_header = included_header.data.strip('"')
        assert semicolon.type == ';'
        for include_folder in self.include_folders:
            header = Path(include_folder) / included_header
            # is_file() rather than exists(): a directory with the header's
            # name should not shadow a real header in a later folder.
            if header.is_file():
                header_text = header.read_text(encoding='utf-8')
                header_parse_tree = parse_header(scan(header_text))
                return self.visit(header_parse_tree)
        raise FileNotFoundError(included_header)

    def visit_NormalStructDefinition(self, node, visited_children) -> StructDeclaration:
        """Build a StructDeclaration with an explicit field list."""
        struct, name, lbrace, fields, rbrace = visited_children
        assert struct.type == 'struct'
        assert lbrace.type == '{'
        assert rbrace.type == '}'
        name = name.data
        if not isinstance(fields, list):
            fields = [fields]
        return StructDeclaration(name, fields)

    def visit_OpaqueStructDefinition(self, node, visited_children) -> StructDeclaration:
        """Build a StructDeclaration whose fields are unknown (None)."""
        opaque, struct, name, semi = visited_children
        assert opaque.type == 'opaque'
        assert struct.type == 'struct'
        assert semi.type == ';'
        name = name.data
        return StructDeclaration(name, None)

    def visit_EnumDefinition(self, node, visited_children) -> EnumDeclaration:
        """Build an EnumDeclaration from the first member plus ', member' repeats."""
        enum, name, lbrace, first_member, extra_members, _trailing_comma, rbrace = visited_children
        assert enum.type == 'enum'
        assert lbrace.type == '{'
        assert rbrace.type == '}'
        name = name.data
        # With exactly one extra member, generic_visit hands us the bare
        # [comma, member] pair rather than [[comma, member]]; re-wrap it so
        # the unpacking loop below sees pairs.
        if extra_members and not isinstance(extra_members[0], list):
            extra_members = [extra_members]
        values = [first_member]
        values.extend(member for _, member in extra_members)
        return EnumDeclaration(name, values)

    def visit_EnumMember(self, node, visited_children) -> Tuple[str, Optional[Expression]]:
        """Return (member name, explicit value), or (name, None) if no '= value'."""
        name, equals_value = visited_children
        name = name.data
        if len(equals_value) == 0:
            return name, None
        _, value = equals_value
        return name, value

    def visit_RobustUnionDefinition(self, node, visited_children) -> UnionDeclaration:
        """Build a tagged UnionDeclaration.

        :raises NameError: if the declared tag's name differs from the
            identifier used in the union's switch.
        """
        union, name, lbrace, tag, body, rbrace = visited_children
        assert union.type == 'union'
        assert lbrace.type == '{'
        assert rbrace.type == '}'
        name = name.data
        expected_tagname, body = body
        if tag.name != expected_tagname:
            raise NameError(f"tag {tag} does not match switch argument {expected_tagname}")
        if not isinstance(body, list):
            body = [body]
        return UnionDeclaration(name, tag, body)

    def visit_UnionBody(self, node, visited_children) -> Tuple[str, List[Tuple[Expression, Optional[VariableDeclaration]]]]:
        """Return (switched-on identifier, case sets) for 'switch (tag) { ... }'."""
        switch, lparen, tag, rparen, lbrace, body, rbrace = visited_children
        assert switch.type == 'switch'
        assert lparen.type == '('
        assert rparen.type == ')'
        assert lbrace.type == '{'
        assert rbrace.type == '}'
        return tag.data, body

    def visit_UnionBodySet(self, node, visited_children) -> Tuple[Expression, Optional[VariableDeclaration]]:
        """Return (case expression, field or None) for one case set.

        Only the first case label of a set is kept when several are stacked.
        """
        cases, var = visited_children
        if isinstance(cases, list):
            cases = cases[0]
        if isinstance(var, VariableDeclaration):
            return cases, var
        return cases, None

    def visit_CaseSpecifier(self, node, visited_children) -> Expression:
        """Return the expression of a 'case EXPR:' label."""
        while isinstance(visited_children, list) and len(visited_children) == 1:
            visited_children = visited_children[0]
        # TODO don't explode on 'default:'
        case, expr, colon = visited_children
        while isinstance(expr, list):
            expr = expr[0]
        return expr

    def visit_FragileUnionDefinition(self, node, visited_children) -> UnionDeclaration:
        """Build an untagged UnionDeclaration whose cases are a plain field list."""
        fragile, union, name, lbrace, body, rbrace = visited_children
        assert fragile.type == 'fragile'
        assert union.type == 'union'
        assert lbrace.type == '{'
        assert rbrace.type == '}'
        name = name.data
        # A single field arrives unwrapped; normalize like the struct visitor.
        if not isinstance(body, list):
            body = [body]
        return UnionDeclaration(name, None, body)

    def visit_FunctionDeclaration(self, node, visited_children) -> FunctionDeclaration:
        """Return the signature of a ';'-terminated function declaration."""
        signature, semi = visited_children
        assert semi.type == ';'
        return signature

    def visit_VariableDefinition(self, node, visited_children) -> VariableDefinition:
        """Build a VariableDefinition from 'TYPE NAME = VALUE;'."""
        var_type, name, eq, value, semi = visited_children
        assert eq.type == '='
        assert semi.type == ';'
        name = name.data
        return VariableDefinition(name, var_type, value)

    def visit_VariableDeclaration(self, node, visited_children) -> VariableDeclaration:
        """Build a VariableDeclaration from 'TYPE NAME;'."""
        var_type, name, semi = visited_children
        assert semi.type == ';'
        name = name.data
        return VariableDeclaration(name, var_type)

    def visit_FunctionDefinition(self, node, visited_children) -> FunctionDefinition:
        """Combine a visited signature with its body into a FunctionDefinition."""
        signature, body = visited_children
        return FunctionDefinition(signature.name, signature.return_type, signature.args, body)

    def visit_FunctionSignature(self, node, visited_children) -> FunctionDeclaration:
        """Build a FunctionDeclaration from 'RETURN_TYPE NAME ( ARGS )'."""
        return_type, name, lparen, args, rparen = visited_children
        assert name.type == 'identifier'
        name = name.data
        assert lparen.type == '('
        assert rparen.type == ')'
        return FunctionDeclaration(name, return_type, args)

    def visit_BasicType(self, node, visited_children) -> Type:
        """Build a basic type.

        Handles three shapes: a parenthesized type (returned as-is), a
        '<category> <name>' pair, or a single keyword token.
        """
        while isinstance(visited_children, list) and len(visited_children) == 1:
            visited_children = visited_children[0]
        if isinstance(visited_children, list):
            if len(visited_children) == 3:
                # parenthesized!
                lparen, ty, rparen = visited_children
                assert lparen.type == '('
                assert rparen.type == ')'
                return ty
            category, name = visited_children
            category = category.type
            name = name.data
            return BasicType(f"{category} {name}")
        return BasicType(visited_children.type)

    def visit_ArrayType(self, node, visited_children) -> ArrayType:
        """Build an ArrayType from 'CONTENTS [ SIZE ]'."""
        contents, lbracket, size, rbracket = visited_children
        assert lbracket.type == '['
        assert rbracket.type == ']'
        # TODO don't explode on nontrivial expression
        while isinstance(size, list):
            size = size[0]
        return ArrayType(contents, size)

    def visit_PointerType(self, node, visited_children) -> PointerType:
        """Build a PointerType from 'CONTENTS *'."""
        contents, splat = visited_children
        assert splat.type == '*'
        return PointerType(contents)

    def visit_AtomicExpression(self, node, visited_children) -> Expression:
        """Build a leaf expression: identifier, constant, true/false, or '( expr )'.

        :raises NotImplementedError: for any other token type.
        """
        if isinstance(visited_children, list) and len(visited_children) == 3:
            lparen, body, rparen = visited_children
            assert lparen.type == '('
            assert rparen.type == ')'
            return body
        body = visited_children
        while isinstance(body, list):
            body = body[0]
        if body.type == 'identifier':
            return VariableExpression(body.data)
        if body.type == 'constant':
            return ConstantExpression(body.data)
        if body.type in ['true', 'false']:
            return ConstantExpression(body.type)
        raise NotImplementedError(f"unsupported atomic expression token: {body.type!r}")

    def generic_visit(self, node, visited_children):
        """Fallback for nodes without a dedicated visit_* method.

        Behavior depends on the node:

        - A leaf (no visited children) with empty text becomes ``[]``.
        - A leaf with exactly one element of text becomes that element
          (presumably a scanner token, given the ``.type``/``.data`` access
          elsewhere).
        - Any other leaf raises ValueError.
        - A node with exactly one visited child collapses to that child.
        - Otherwise the list of visited children is returned.
        """
        if not visited_children:
            if len(node.text) == 0:
                return []
            if len(node.text) == 1:
                return node.text[0]
            raise ValueError('just a node: ' + str(node))
        if len(visited_children) == 1:
            return visited_children[0]
        return visited_children


def build_ast(parse_tree, include_dirs):
    """Visit ``parse_tree`` with an ASTBuilder searching ``include_dirs`` for headers."""
    builder = ASTBuilder(include_dirs)
    return builder.visit(parse_tree)