diff options
Diffstat (limited to 'crowbar_reference_compiler/declarations.py')
-rw-r--r-- | crowbar_reference_compiler/declarations.py | 127 |
1 files changed, 127 insertions, 0 deletions
diff --git a/crowbar_reference_compiler/declarations.py b/crowbar_reference_compiler/declarations.py new file mode 100644 index 0000000..9f1a6b2 --- /dev/null +++ b/crowbar_reference_compiler/declarations.py @@ -0,0 +1,127 @@ +from pathlib import Path + +from parsimonious import NodeVisitor # type: ignore + +from .scanner import scan +from .parser import parse_header + + +class DeclarationVisitor(NodeVisitor): + def __init__(self, include_folders): + self.data = [] + self.include_folders = include_folders + + def visit_HeaderFile(self, node, visited_children): + includes, elements = visited_children + return elements + + def visit_ImplementationFile(self, node, visited_children): + return [x for x in visited_children if x is not None] + + def visit_IncludeStatement(self, node, visited_children): + include, included_header, semicolon = visited_children + assert include.text[0].type == 'include' + assert included_header.type == 'string_literal' + included_header = included_header.data + assert semicolon.text[0].type == ';' + for include_folder in self.include_folders: + header = Path(include_folder) / included_header + if header.exists(): + with open(header, 'r', encoding='utf-8') as header_file: + header_text = header_file.read() + header_parse_tree = parse_header(scan(header_text)) + return self.visit(header_parse_tree) + raise FileNotFoundError(included_header) + + def visit_NormalStructDefinition(self, node, visited_children): + struct, name, lbrace, fields, rbrace = visited_children + assert struct.text[0].type == 'struct' + assert lbrace.text[0].type == '{' + assert rbrace.text[0].type == '}' + name = name.data + return f"struct {name}" + + def visit_OpaqueStructDefinition(self, node, visited_children): + opaque, struct, name, semi = visited_children + assert opaque.text[0].type == 'opaque' + assert struct.text[0].type == 'struct' + assert semi.text[0].type == ';' + name = name.data + return f"struct {name}" + + def visit_EnumDefinition(self, node, visited_children): + enum, name, lbrace, first_member, extra_members, trailing_comma, rbrace = visited_children + assert enum.text[0].type == 'enum' + assert lbrace.text[0].type == '{' + assert rbrace.text[0].type == '}' + name = name.data + return f"enum {name}" + + def visit_RobustUnionDefinition(self, node, visited_children): + union, name, lbrace, tag, body, rbrace = visited_children + assert union.text[0].type == 'union' + assert lbrace.text[0].type == '{' + assert rbrace.text[0].type == '}' + name = name.data + return f"union {name}" + + def visit_FragileUnionDefinition(self, node, visited_children): + fragile, union, name, lbrace, body, rbrace = visited_children + assert fragile.text[0].type == 'fragile' + assert union.text[0].type == 'union' + assert lbrace.text[0].type == '{' + assert rbrace.text[0].type == '}' + name = name.data + return f"union {name}" + + def visit_FunctionDeclaration(self, node, visited_children): + signature, semi = visited_children + assert semi.text[0].type == ';' + return signature + + def visit_VariableDefinition(self, node, visited_children): + type, name, eq, value, semi = visited_children + assert eq.text[0].type == '=' + assert semi.text[0].type == ';' + name = name.data + return name + + def visit_VariableDeclaration(self, node, visited_children): + type, name, semi = visited_children + assert semi.text[0].type == ';' + name = name.data + return name + + def visit_FunctionDefinition(self, node, visited_children): + signature, body = visited_children + return signature + + def visit_FunctionSignature(self, node, visited_children): + return_type, name, lparen, args, rparen = visited_children + assert name.type == 'identifier' + name = name.data + assert lparen.text[0].type == '(' + assert rparen.text[0].type == ')' + return return_type, name, args + + def visit_constant(self, node, visited_children): + return node.text[0] + + def visit_string_literal(self, node, visited_children): + return node.text[0] + + def visit_identifier(self, node, visited_children): + return node.text[0] + + def generic_visit(self, node, visited_children): + """ The generic visit method. """ + if not visited_children: + return node + if len(visited_children) == 1: + return visited_children[0] + return visited_children + + +def load_declarations(parse_tree): + declarations = DeclarationVisitor([]) + return declarations.visit(parse_tree) |