from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional, Tuple, Union

from parsimonious import NodeVisitor  # type: ignore

from .scanner import scan
from .parser import parse_header


@dataclass
class Type:
    """Base class for every type description produced by the visitor."""
    pass


@dataclass
class BasicType(Type):
    """A named (non-pointer, non-array) type."""
    name: str


@dataclass
class PointerType(Type):
    """A pointer to ``target``."""
    target: Type


@dataclass
class ArrayType(Type):
    """An array holding ``size`` elements of type ``contents``."""
    contents: Type
    size: int


@dataclass
class Declaration:
    """Base class for every named declaration."""
    name: str


@dataclass
class VariableDeclaration(Declaration):
    """Represents the declaration of a variable."""
    type: Type
    # Initializer as produced by the visitor; None for a bare declaration.
    value: Optional[str]


@dataclass
class Declarations:
    """A collection of declarations."""
    included: List[Declaration]


@dataclass
class StructDeclaration(Declaration):
    """Represents the declaration of a struct type."""
    # None marks an opaque struct (declared without a body).
    fields: Optional[List[VariableDeclaration]]


@dataclass
class EnumDeclaration(Declaration):
    """Represents the declaration of an enum type."""
    # (member name, explicit value or None) pairs in source order.
    values: List[Tuple[str, Optional[int]]]


@dataclass
class UnionDeclaration(Declaration):
    """Represents the declaration of a union type."""
    # The tag variable of a robust union; None for a fragile union.
    tag: Optional[VariableDeclaration]
    # Fragile unions: plain variable declarations.
    # Robust unions: (case label, variable or None) pairs.
    cases: Union[List[VariableDeclaration],
                 List[Tuple[str, Optional[VariableDeclaration]]]]


@dataclass
class FunctionDeclaration(Declaration):
    """Represents the declaration of a function."""
    return_type: Type
    args: List[VariableDeclaration]


# noinspection PyPep8Naming,PyMethodMayBeStatic,PyUnusedLocal
class DeclarationVisitor(NodeVisitor):
    """Collect declarations from a parse tree built over a token sequence.

    Each ``visit_<Rule>`` method receives the already-visited children of a
    grammar rule.  Keyword and punctuation children arrive as raw nodes (see
    ``generic_visit``) whose ``text[0]`` is the matched token; identifiers,
    constants and string literals arrive as the token itself.
    """

    def __init__(self, include_folders):
        """Create a visitor that resolves includes against ``include_folders``.

        :param include_folders: iterable of directories searched, in order,
            for each included header.
        """
        self.data = []
        self.include_folders = include_folders
        # Resolved paths of the headers currently being expanded; used to
        # break include cycles instead of recursing without bound.
        self._active_includes = set()

    def visit_HeaderFile(self, node, visited_children):
        """Return the header's elements; its own includes are discarded."""
        _includes, elements = visited_children
        return elements

    def visit_ImplementationFile(self, node, visited_children):
        """Return included declarations followed by the file's own elements.

        ``None`` entries are dropped from the result.
        """
        includes, elements = visited_children
        # NOTE(review): generic_visit collapses single-element child lists, so
        # a file with exactly one include (or one element) may not arrive
        # here as a list of lists -- confirm against the grammar.
        includes = [y for x in includes for y in x]
        return [x for x in includes + elements if x is not None]

    def visit_IncludeStatement(self, node, visited_children):
        """Locate, scan, parse and visit the included header.

        The include folders are searched in order and the first regular file
        found wins.  A header that is already being expanded further up the
        include chain contributes nothing, which breaks include cycles.

        :raises FileNotFoundError: if no include folder contains the header.
        """
        include, included_header, semicolon = visited_children
        assert include.text[0].type == 'include'
        assert included_header.type == 'string_literal'
        included_header = included_header.data.strip('"')
        assert semicolon.text[0].type == ';'
        for include_folder in self.include_folders:
            header = Path(include_folder) / included_header
            # is_file() rather than exists(): a directory of the same name
            # must not shadow a real header in a later include folder.
            if not header.is_file():
                continue
            key = header.resolve()
            if key in self._active_includes:
                # Cyclic include: this header is already being processed.
                return []
            self._active_includes.add(key)
            try:
                header_text = header.read_text(encoding='utf-8')
                header_parse_tree = parse_header(scan(header_text))
                return self.visit(header_parse_tree)
            finally:
                self._active_includes.discard(key)
        raise FileNotFoundError(included_header)

    def visit_NormalStructDefinition(self, node, visited_children):
        """Build a ``StructDeclaration`` with its list of fields."""
        struct, name, lbrace, fields, rbrace = visited_children
        assert struct.text[0].type == 'struct'
        assert lbrace.text[0].type == '{'
        assert rbrace.text[0].type == '}'
        name = name.data
        # A single field is collapsed by generic_visit; re-wrap it.
        if not isinstance(fields, list):
            fields = [fields]
        return StructDeclaration(name, fields)

    def visit_OpaqueStructDefinition(self, node, visited_children):
        """Build a ``StructDeclaration`` without fields (``fields=None``)."""
        opaque, struct, name, semi = visited_children
        assert opaque.text[0].type == 'opaque'
        assert struct.text[0].type == 'struct'
        assert semi.text[0].type == ';'
        name = name.data
        return StructDeclaration(name, None)

    def visit_EnumDefinition(self, node, visited_children):
        """Build an ``EnumDeclaration`` from the first and following members."""
        (enum, name, lbrace, first_member, extra_members,
         _trailing_comma, rbrace) = visited_children
        assert enum.text[0].type == 'enum'
        assert lbrace.text[0].type == '{'
        assert rbrace.text[0].type == '}'
        name = name.data
        # generic_visit collapses a one-element child list, so a single
        # ", member" repetition arrives as the bare pair [comma, member]
        # rather than [[comma, member]].  Members are tuples (see
        # visit_EnumMember), which tells the two shapes apart.
        # NOTE(review): assumes the grammar is
        # EnumMember ("," EnumMember)* -- confirm.
        if (isinstance(extra_members, list) and len(extra_members) == 2
                and isinstance(extra_members[1], tuple)):
            extra_members = [extra_members]
        values = [first_member]
        values.extend(member for _, member in extra_members)
        return EnumDeclaration(name, values)

    def visit_EnumMember(self, node, visited_children):
        """Return ``(name, value)``; ``value`` is None without an initializer."""
        name, equals_value = visited_children
        name = name.data
        # Without an "= value" part the child is a bare node, not a list.
        if not isinstance(equals_value, list):
            return name, None
        _, value = equals_value
        return name, value

    def visit_RobustUnionDefinition(self, node, visited_children):
        """Build a tagged ``UnionDeclaration``.

        :raises NameError: if the switch argument differs from the tag name.
        """
        union, name, lbrace, tag, body, rbrace = visited_children
        assert union.text[0].type == 'union'
        assert lbrace.text[0].type == '{'
        assert rbrace.text[0].type == '}'
        name = name.data
        expected_tagname, body = body
        if tag.name != expected_tagname:
            raise NameError(f"tag {tag} does not match switch argument {expected_tagname}")
        # A single case set is collapsed by generic_visit; re-wrap it.
        if not isinstance(body, list):
            body = [body]
        return UnionDeclaration(name, tag, body)

    def visit_UnionBody(self, node, visited_children):
        """Return ``(switch argument name, visited case sets)``."""
        switch, lparen, tag, rparen, lbrace, body, rbrace = visited_children
        assert switch.text[0].type == 'switch'
        assert lparen.text[0].type == '('
        assert rparen.text[0].type == ')'
        assert lbrace.text[0].type == '{'
        assert rbrace.text[0].type == '}'
        return tag.data, body

    def visit_UnionBodySet(self, node, visited_children):
        """Return ``(case label, variable)``; variable is None if absent."""
        cases, var = visited_children
        # Only the first case specifier of the set is kept.
        if isinstance(cases, list):
            cases = cases[0]
        if isinstance(var, VariableDeclaration):
            return cases, var
        return cases, None

    def visit_CaseSpecifier(self, node, visited_children):
        """Return the ``data`` of the token used as the case expression."""
        while isinstance(visited_children, list) and len(visited_children) == 1:
            visited_children = visited_children[0]
        # TODO don't explode on 'default:'
        _case, expr, _colon = visited_children
        while isinstance(expr, list):
            expr = expr[0]
        # TODO don't explode on nontrivial expression
        return expr.data

    def visit_FragileUnionDefinition(self, node, visited_children):
        """Build an untagged ``UnionDeclaration`` (``tag=None``)."""
        fragile, union, name, lbrace, body, rbrace = visited_children
        assert fragile.text[0].type == 'fragile'
        assert union.text[0].type == 'union'
        assert lbrace.text[0].type == '{'
        assert rbrace.text[0].type == '}'
        name = name.data
        # A single member is collapsed by generic_visit; re-wrap it so that
        # ``cases`` is a list, as the struct and robust-union visitors do.
        if isinstance(body, VariableDeclaration):
            body = [body]
        return UnionDeclaration(name, None, body)

    def visit_FunctionDeclaration(self, node, visited_children):
        """Return the visited signature of a function declaration."""
        signature, semi = visited_children
        assert semi.text[0].type == ';'
        return signature

    def visit_VariableDefinition(self, node, visited_children):
        """Build a ``VariableDeclaration`` carrying its initializer."""
        var_type, name, eq, value, semi = visited_children
        assert eq.text[0].type == '='
        assert semi.text[0].type == ';'
        name = name.data
        return VariableDeclaration(name, var_type, value)

    def visit_VariableDeclaration(self, node, visited_children):
        """Build a ``VariableDeclaration`` without an initializer."""
        var_type, name, semi = visited_children
        assert semi.text[0].type == ';'
        name = name.data
        return VariableDeclaration(name, var_type, None)

    def visit_FunctionDefinition(self, node, visited_children):
        """Return the visited signature; the function body is ignored."""
        signature, _body = visited_children
        return signature

    def visit_FunctionSignature(self, node, visited_children):
        """Build a ``FunctionDeclaration`` from a signature."""
        return_type, name, lparen, args, rparen = visited_children
        assert name.type == 'identifier'
        name = name.data
        assert lparen.text[0].type == '('
        assert rparen.text[0].type == ')'
        return FunctionDeclaration(name, return_type, args)

    def visit_BasicType(self, node, visited_children):
        """Return the type described by a basic-type node.

        Three shapes are handled once single-element lists are unwrapped:
        a parenthesized type (returned as is), a ``category name`` pair, and
        a single keyword node.
        """
        while isinstance(visited_children, list) and len(visited_children) == 1:
            visited_children = visited_children[0]
        if isinstance(visited_children, list):
            if len(visited_children) == 3:
                # parenthesized!
                lparen, ty, rparen = visited_children
                assert lparen.text[0].type == '('
                assert rparen.text[0].type == ')'
                return ty
            category, name = visited_children
            category = category.text[0].type
            name = name.data
            return BasicType(f"{category} {name}")
        return BasicType(visited_children.text[0].type)

    def visit_ArrayType(self, node, visited_children):
        """Build an ``ArrayType``; the size must be a single integer token."""
        contents, lbracket, size, rbracket = visited_children
        assert lbracket.text[0].type == '['
        assert rbracket.text[0].type == ']'
        # TODO don't explode on nontrivial expression
        while isinstance(size, list):
            size = size[0]
        size = int(size.data)
        return ArrayType(contents, size)

    def visit_PointerType(self, node, visited_children):
        """Build a ``PointerType`` pointing at the visited contents."""
        contents, splat = visited_children
        assert splat.text[0].type == '*'
        return PointerType(contents)

    def visit_constant(self, node, visited_children):
        """Return the first token matched by the node."""
        return node.text[0]

    def visit_string_literal(self, node, visited_children):
        """Return the first token matched by the node."""
        return node.text[0]

    def visit_identifier(self, node, visited_children):
        """Return the first token matched by the node."""
        return node.text[0]

    def generic_visit(self, node, visited_children):
        """
        The generic visit method.

        Returns the node itself when it has no visited children, the only
        child when there is exactly one, and the list of children otherwise.
        """
        if not visited_children:
            return node
        if len(visited_children) == 1:
            return visited_children[0]
        return visited_children


def load_declarations(parse_tree, include_dirs):
    """Visit ``parse_tree`` and return what the visitor produces for it.

    :param parse_tree: root node of the tree to visit.
    :param include_dirs: directories searched for included headers.
    """
    declarations = DeclarationVisitor(include_dirs)
    return declarations.visit(parse_tree)