aboutsummaryrefslogtreecommitdiff
path: root/crowbar_reference_compiler/declarations.py
diff options
context:
space:
mode:
Diffstat (limited to 'crowbar_reference_compiler/declarations.py')
-rw-r--r--crowbar_reference_compiler/declarations.py171
1 files changed, 159 insertions, 12 deletions
diff --git a/crowbar_reference_compiler/declarations.py b/crowbar_reference_compiler/declarations.py
index 9f1a6b2..e9a7e4b 100644
--- a/crowbar_reference_compiler/declarations.py
+++ b/crowbar_reference_compiler/declarations.py
@@ -1,4 +1,6 @@
+from dataclasses import dataclass
from pathlib import Path
+from typing import List, Optional, Tuple, Union
from parsimonious import NodeVisitor # type: ignore
@@ -6,6 +8,71 @@ from .scanner import scan
from .parser import parse_header
+@dataclass
+class Type:
+ pass
+
+
+@dataclass
+class BasicType(Type):
+ name: str
+
+
+@dataclass
+class PointerType(Type):
+ target: Type
+
+
+@dataclass
+class ArrayType(Type):
+ contents: Type
+ size: int
+
+
+@dataclass
+class Declaration:
+ name: str
+
+
+@dataclass
+class VariableDeclaration(Declaration):
+ """Represents the declaration of a variable."""
+ type: Type
+ value: Optional[str]
+
+
+@dataclass
+class Declarations:
+ included: List[Declaration]
+
+
+@dataclass
+class StructDeclaration(Declaration):
+ """Represents the declaration of a struct type."""
+ fields: Optional[List[VariableDeclaration]]
+
+
+@dataclass
+class EnumDeclaration(Declaration):
+ """Represents the declaration of an enum type."""
+ values: List[Tuple[str, Optional[int]]]
+
+
+@dataclass
+class UnionDeclaration(Declaration):
+ """Represents the declaration of a union type."""
+ tag: Optional[VariableDeclaration]
+ cases: Union[List[VariableDeclaration], List[Tuple[str, Optional[VariableDeclaration]]]]
+
+
+@dataclass
+class FunctionDeclaration(Declaration):
+ """Represents the declaration of a function."""
+ return_type: Type
+ args: List[VariableDeclaration]
+
+
+# noinspection PyPep8Naming,PyMethodMayBeStatic,PyUnusedLocal
class DeclarationVisitor(NodeVisitor):
def __init__(self, include_folders):
self.data = []
@@ -16,13 +83,15 @@ class DeclarationVisitor(NodeVisitor):
return elements
def visit_ImplementationFile(self, node, visited_children):
- return [x for x in visited_children if x is not None]
+ includes, elements = visited_children
+ includes = [y for x in includes for y in x]
+ return [x for x in includes + elements if x is not None]
def visit_IncludeStatement(self, node, visited_children):
include, included_header, semicolon = visited_children
assert include.text[0].type == 'include'
assert included_header.type == 'string_literal'
- included_header = included_header.data
+ included_header = included_header.data.strip('"')
assert semicolon.text[0].type == ';'
for include_folder in self.include_folders:
header = Path(include_folder) / included_header
@@ -39,7 +108,9 @@ class DeclarationVisitor(NodeVisitor):
assert lbrace.text[0].type == '{'
assert rbrace.text[0].type == '}'
name = name.data
- return f"struct {name}"
+ if not isinstance(fields, list):
+ fields = [fields]
+ return StructDeclaration(name, fields)
def visit_OpaqueStructDefinition(self, node, visited_children):
opaque, struct, name, semi = visited_children
@@ -47,7 +118,7 @@ class DeclarationVisitor(NodeVisitor):
assert struct.text[0].type == 'struct'
assert semi.text[0].type == ';'
name = name.data
- return f"struct {name}"
+ return StructDeclaration(name, None)
def visit_EnumDefinition(self, node, visited_children):
enum, name, lbrace, first_member, extra_members, trailing_comma, rbrace = visited_children
@@ -55,7 +126,18 @@ class DeclarationVisitor(NodeVisitor):
assert lbrace.text[0].type == '{'
assert rbrace.text[0].type == '}'
name = name.data
- return f"enum {name}"
+ values = [first_member]
+ for _, v in extra_members:
+ values.append(v)
+ return EnumDeclaration(name, values)
+
+ def visit_EnumMember(self, node, visited_children):
+ name, equals_value = visited_children
+ name = name.data
+ if not isinstance(equals_value, list):
+ return name, None
+ _, value = equals_value
+ return name, value
def visit_RobustUnionDefinition(self, node, visited_children):
union, name, lbrace, tag, body, rbrace = visited_children
@@ -63,7 +145,40 @@ class DeclarationVisitor(NodeVisitor):
assert lbrace.text[0].type == '{'
assert rbrace.text[0].type == '}'
name = name.data
- return f"union {name}"
+ expected_tagname, body = body
+ if tag.name != expected_tagname:
+ raise NameError(f"tag {tag} does not match switch argument {expected_tagname}")
+ if not isinstance(body, list):
+ body = [body]
+ return UnionDeclaration(name, tag, body)
+
+ def visit_UnionBody(self, node, visited_children):
+ switch, lparen, tag, rparen, lbrace, body, rbrace = visited_children
+ assert switch.text[0].type == 'switch'
+ assert lparen.text[0].type == '('
+ assert rparen.text[0].type == ')'
+ assert lbrace.text[0].type == '{'
+ assert rbrace.text[0].type == '}'
+ return tag.data, body
+
+ def visit_UnionBodySet(self, node, visited_children):
+ cases, var = visited_children
+ if isinstance(cases, list):
+ cases = cases[0]
+ if isinstance(var, VariableDeclaration):
+ return cases, var
+ else:
+ return cases, None
+
+ def visit_CaseSpecifier(self, node, visited_children):
+ while isinstance(visited_children, list) and len(visited_children) == 1:
+ visited_children = visited_children[0]
+ # TODO don't explode on 'default:'
+ case, expr, colon = visited_children
+ while isinstance(expr, list):
+ expr = expr[0]
+ # TODO don't explode on nontrivial expression
+ return expr.data
def visit_FragileUnionDefinition(self, node, visited_children):
fragile, union, name, lbrace, body, rbrace = visited_children
@@ -72,7 +187,7 @@ class DeclarationVisitor(NodeVisitor):
assert lbrace.text[0].type == '{'
assert rbrace.text[0].type == '}'
name = name.data
- return f"union {name}"
+ return UnionDeclaration(name, None, body)
def visit_FunctionDeclaration(self, node, visited_children):
signature, semi = visited_children
@@ -84,13 +199,13 @@ class DeclarationVisitor(NodeVisitor):
assert eq.text[0].type == '='
assert semi.text[0].type == ';'
name = name.data
- return name
+ return VariableDeclaration(name, type, value)
def visit_VariableDeclaration(self, node, visited_children):
type, name, semi = visited_children
assert semi.text[0].type == ';'
name = name.data
- return name
+ return VariableDeclaration(name, type, None)
def visit_FunctionDefinition(self, node, visited_children):
signature, body = visited_children
@@ -102,7 +217,39 @@ class DeclarationVisitor(NodeVisitor):
name = name.data
assert lparen.text[0].type == '('
assert rparen.text[0].type == ')'
- return return_type, name, args
+ return FunctionDeclaration(name, return_type, args)
+
+ def visit_BasicType(self, node, visited_children):
+ while isinstance(visited_children, list) and len(visited_children) == 1:
+ visited_children = visited_children[0]
+ if isinstance(visited_children, list):
+ if len(visited_children) == 3:
+ # parenthesized!
+ lparen, ty, rparen = visited_children
+ assert lparen.text[0].type == '('
+ assert rparen.text[0].type == ')'
+ return ty
+ else:
+ category, name = visited_children
+ category = category.text[0].type
+ name = name.data
+ return BasicType(f"{category} {name}")
+ return BasicType(visited_children.text[0].type)
+
+ def visit_ArrayType(self, node, visited_children):
+ contents, lbracket, size, rbracket = visited_children
+ assert lbracket.text[0].type == '['
+ assert rbracket.text[0].type == ']'
+ # TODO don't explode on nontrivial expression
+ while isinstance(size, list):
+ size = size[0]
+ size = int(size.data)
+ return ArrayType(contents, size)
+
+ def visit_PointerType(self, node, visited_children):
+ contents, splat = visited_children
+ assert splat.text[0].type == '*'
+ return PointerType(contents)
def visit_constant(self, node, visited_children):
return node.text[0]
@@ -122,6 +269,6 @@ class DeclarationVisitor(NodeVisitor):
return visited_children
-def load_declarations(parse_tree):
- declarations = DeclarationVisitor([])
+def load_declarations(parse_tree, include_dirs):
+ declarations = DeclarationVisitor(include_dirs)
return declarations.visit(parse_tree)