aboutsummaryrefslogtreecommitdiff
path: root/crowbar_reference_compiler/declarations.py
diff options
context:
space:
mode:
Diffstat (limited to 'crowbar_reference_compiler/declarations.py')
-rw-r--r--crowbar_reference_compiler/declarations.py127
1 files changed, 127 insertions, 0 deletions
diff --git a/crowbar_reference_compiler/declarations.py b/crowbar_reference_compiler/declarations.py
new file mode 100644
index 0000000..9f1a6b2
--- /dev/null
+++ b/crowbar_reference_compiler/declarations.py
@@ -0,0 +1,127 @@
+from pathlib import Path
+
+from parsimonious import NodeVisitor # type: ignore
+
+from .scanner import scan
+from .parser import parse_header
+
+
+class DeclarationVisitor(NodeVisitor):
+ def __init__(self, include_folders):
+ self.data = []
+ self.include_folders = include_folders
+
+ def visit_HeaderFile(self, node, visited_children):
+ includes, elements = visited_children
+ return elements
+
+ def visit_ImplementationFile(self, node, visited_children):
+ return [x for x in visited_children if x is not None]
+
+ def visit_IncludeStatement(self, node, visited_children):
+ include, included_header, semicolon = visited_children
+ assert include.text[0].type == 'include'
+ assert included_header.type == 'string_literal'
+ included_header = included_header.data
+ assert semicolon.text[0].type == ';'
+ for include_folder in self.include_folders:
+ header = Path(include_folder) / included_header
+ if header.exists():
+ with open(header, 'r', encoding='utf-8') as header_file:
+ header_text = header_file.read()
+ header_parse_tree = parse_header(scan(header_text))
+ return self.visit(header_parse_tree)
+ raise FileNotFoundError(included_header)
+
+ def visit_NormalStructDefinition(self, node, visited_children):
+ struct, name, lbrace, fields, rbrace = visited_children
+ assert struct.text[0].type == 'struct'
+ assert lbrace.text[0].type == '{'
+ assert rbrace.text[0].type == '}'
+ name = name.data
+ return f"struct {name}"
+
+ def visit_OpaqueStructDefinition(self, node, visited_children):
+ opaque, struct, name, semi = visited_children
+ assert opaque.text[0].type == 'opaque'
+ assert struct.text[0].type == 'struct'
+ assert semi.text[0].type == ';'
+ name = name.data
+ return f"struct {name}"
+
+ def visit_EnumDefinition(self, node, visited_children):
+ enum, name, lbrace, first_member, extra_members, trailing_comma, rbrace = visited_children
+ assert enum.text[0].type == 'enum'
+ assert lbrace.text[0].type == '{'
+ assert rbrace.text[0].type == '}'
+ name = name.data
+ return f"enum {name}"
+
+ def visit_RobustUnionDefinition(self, node, visited_children):
+ union, name, lbrace, tag, body, rbrace = visited_children
+ assert union.text[0].type == 'union'
+ assert lbrace.text[0].type == '{'
+ assert rbrace.text[0].type == '}'
+ name = name.data
+ return f"union {name}"
+
+ def visit_FragileUnionDefinition(self, node, visited_children):
+ fragile, union, name, lbrace, body, rbrace = visited_children
+ assert fragile.text[0].type == 'fragile'
+ assert union.text[0].type == 'union'
+ assert lbrace.text[0].type == '{'
+ assert rbrace.text[0].type == '}'
+ name = name.data
+ return f"union {name}"
+
+ def visit_FunctionDeclaration(self, node, visited_children):
+ signature, semi = visited_children
+ assert semi.text[0].type == ';'
+ return signature
+
+ def visit_VariableDefinition(self, node, visited_children):
+ type, name, eq, value, semi = visited_children
+ assert eq.text[0].type == '='
+ assert semi.text[0].type == ';'
+ name = name.data
+ return name
+
+ def visit_VariableDeclaration(self, node, visited_children):
+ type, name, semi = visited_children
+ assert semi.text[0].type == ';'
+ name = name.data
+ return name
+
+ def visit_FunctionDefinition(self, node, visited_children):
+ signature, body = visited_children
+ return signature
+
+ def visit_FunctionSignature(self, node, visited_children):
+ return_type, name, lparen, args, rparen = visited_children
+ assert name.type == 'identifier'
+ name = name.data
+ assert lparen.text[0].type == '('
+ assert rparen.text[0].type == ')'
+ return return_type, name, args
+
+ def visit_constant(self, node, visited_children):
+ return node.text[0]
+
+ def visit_string_literal(self, node, visited_children):
+ return node.text[0]
+
+ def visit_identifier(self, node, visited_children):
+ return node.text[0]
+
+ def generic_visit(self, node, visited_children):
+ """ The generic visit method. """
+ if not visited_children:
+ return node
+ if len(visited_children) == 1:
+ return visited_children[0]
+ return visited_children
+
+
+def load_declarations(parse_tree):
+ declarations = DeclarationVisitor([])
+ return declarations.visit(parse_tree)