aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--crowbar_reference_compiler/__init__.py7
-rw-r--r--crowbar_reference_compiler/declarations.py171
-rw-r--r--crowbar_reference_compiler/scanner.py2
-rw-r--r--tests/test_declarations.py19
4 files changed, 181 insertions, 18 deletions
diff --git a/crowbar_reference_compiler/__init__.py b/crowbar_reference_compiler/__init__.py
index 7a46e01..c7baeea 100644
--- a/crowbar_reference_compiler/__init__.py
+++ b/crowbar_reference_compiler/__init__.py
@@ -16,8 +16,8 @@ def main():
args.add_argument('--stop-at-qbe-ssa', action='store_true')
args.add_argument('-S', '--stop-at-assembly', action='store_true')
args.add_argument('-c', '--stop-at-object', action='store_true')
- args.add_argument('-D', '--define-constant', help='define a constant with some literal value')
- args.add_argument('-I', '--include-dir', help='folder to look for included headers within')
+ args.add_argument('-D', '--define-constant', action='append', help='define a constant with some literal value')
+ args.add_argument('-I', '--include-dir', action='append', help='folder to look for included headers within')
args.add_argument('-o', '--out', help='output file')
args.add_argument('input', help='input file')
@@ -33,7 +33,8 @@ def main():
output_file.write(str(parse_tree))
return
- decls = load_declarations(parse_tree)
+ decls = load_declarations(parse_tree, args.include_dir)
+ print(decls)
ssa = compile_to_ssa(parse_tree)
if args.stop_at_qbe_ssa:
diff --git a/crowbar_reference_compiler/declarations.py b/crowbar_reference_compiler/declarations.py
index 9f1a6b2..e9a7e4b 100644
--- a/crowbar_reference_compiler/declarations.py
+++ b/crowbar_reference_compiler/declarations.py
@@ -1,4 +1,6 @@
+from dataclasses import dataclass
from pathlib import Path
+from typing import List, Optional, Tuple, Union
from parsimonious import NodeVisitor # type: ignore
@@ -6,6 +8,71 @@ from .scanner import scan
from .parser import parse_header
+@dataclass
+class Type:
+ pass
+
+
+@dataclass
+class BasicType(Type):
+ name: str
+
+
+@dataclass
+class PointerType(Type):
+ target: Type
+
+
+@dataclass
+class ArrayType(Type):
+ contents: Type
+ size: int
+
+
+@dataclass
+class Declaration:
+ name: str
+
+
+@dataclass
+class VariableDeclaration(Declaration):
+ """Represents the declaration of a variable."""
+ type: Type
+ value: Optional[str]
+
+
+@dataclass
+class Declarations:
+ included: List[Declaration]
+
+
+@dataclass
+class StructDeclaration(Declaration):
+ """Represents the declaration of a struct type."""
+ fields: Optional[List[VariableDeclaration]]
+
+
+@dataclass
+class EnumDeclaration(Declaration):
+ """Represents the declaration of an enum type."""
+ values: List[Tuple[str, Optional[int]]]
+
+
+@dataclass
+class UnionDeclaration(Declaration):
+ """Represents the declaration of a union type."""
+ tag: Optional[VariableDeclaration]
+ cases: Union[List[VariableDeclaration], List[Tuple[str, Optional[VariableDeclaration]]]]
+
+
+@dataclass
+class FunctionDeclaration(Declaration):
+ """Represents the declaration of a function."""
+ return_type: Type
+ args: List[VariableDeclaration]
+
+
+# noinspection PyPep8Naming,PyMethodMayBeStatic,PyUnusedLocal
class DeclarationVisitor(NodeVisitor):
def __init__(self, include_folders):
self.data = []
@@ -16,13 +83,15 @@ class DeclarationVisitor(NodeVisitor):
return elements
def visit_ImplementationFile(self, node, visited_children):
- return [x for x in visited_children if x is not None]
+ includes, elements = visited_children
+ includes = [y for x in includes for y in x]
+ return [x for x in includes + elements if x is not None]
def visit_IncludeStatement(self, node, visited_children):
include, included_header, semicolon = visited_children
assert include.text[0].type == 'include'
assert included_header.type == 'string_literal'
- included_header = included_header.data
+ included_header = included_header.data.strip('"')
assert semicolon.text[0].type == ';'
for include_folder in self.include_folders:
header = Path(include_folder) / included_header
@@ -39,7 +108,9 @@ class DeclarationVisitor(NodeVisitor):
assert lbrace.text[0].type == '{'
assert rbrace.text[0].type == '}'
name = name.data
- return f"struct {name}"
+ if not isinstance(fields, list):
+ fields = [fields]
+ return StructDeclaration(name, fields)
def visit_OpaqueStructDefinition(self, node, visited_children):
opaque, struct, name, semi = visited_children
@@ -47,7 +118,7 @@ class DeclarationVisitor(NodeVisitor):
assert struct.text[0].type == 'struct'
assert semi.text[0].type == ';'
name = name.data
- return f"struct {name}"
+ return StructDeclaration(name, None)
def visit_EnumDefinition(self, node, visited_children):
enum, name, lbrace, first_member, extra_members, trailing_comma, rbrace = visited_children
@@ -55,7 +126,18 @@ class DeclarationVisitor(NodeVisitor):
assert lbrace.text[0].type == '{'
assert rbrace.text[0].type == '}'
name = name.data
- return f"enum {name}"
+ values = [first_member]
+ for _, v in extra_members:
+ values.append(v)
+ return EnumDeclaration(name, values)
+
+ def visit_EnumMember(self, node, visited_children):
+ name, equals_value = visited_children
+ name = name.data
+ if not isinstance(equals_value, list):
+ return name, None
+ _, value = equals_value
+ return name, value
def visit_RobustUnionDefinition(self, node, visited_children):
union, name, lbrace, tag, body, rbrace = visited_children
@@ -63,7 +145,40 @@ class DeclarationVisitor(NodeVisitor):
assert lbrace.text[0].type == '{'
assert rbrace.text[0].type == '}'
name = name.data
- return f"union {name}"
+ expected_tagname, body = body
+ if tag.name != expected_tagname:
+ raise NameError(f"tag {tag} does not match switch argument {expected_tagname}")
+ if not isinstance(body, list):
+ body = [body]
+ return UnionDeclaration(name, tag, body)
+
+ def visit_UnionBody(self, node, visited_children):
+ switch, lparen, tag, rparen, lbrace, body, rbrace = visited_children
+ assert switch.text[0].type == 'switch'
+ assert lparen.text[0].type == '('
+ assert rparen.text[0].type == ')'
+ assert lbrace.text[0].type == '{'
+ assert rbrace.text[0].type == '}'
+ return tag.data, body
+
+ def visit_UnionBodySet(self, node, visited_children):
+ cases, var = visited_children
+ if isinstance(cases, list):
+ cases = cases[0]
+ if isinstance(var, VariableDeclaration):
+ return cases, var
+ else:
+ return cases, None
+
+ def visit_CaseSpecifier(self, node, visited_children):
+ while isinstance(visited_children, list) and len(visited_children) == 1:
+ visited_children = visited_children[0]
+ # TODO don't explode on 'default:'
+ case, expr, colon = visited_children
+ while isinstance(expr, list):
+ expr = expr[0]
+ # TODO don't explode on nontrivial expression
+ return expr.data
def visit_FragileUnionDefinition(self, node, visited_children):
fragile, union, name, lbrace, body, rbrace = visited_children
@@ -72,7 +187,7 @@ class DeclarationVisitor(NodeVisitor):
assert lbrace.text[0].type == '{'
assert rbrace.text[0].type == '}'
name = name.data
- return f"union {name}"
+ return UnionDeclaration(name, None, body)
def visit_FunctionDeclaration(self, node, visited_children):
signature, semi = visited_children
@@ -84,13 +199,13 @@ class DeclarationVisitor(NodeVisitor):
assert eq.text[0].type == '='
assert semi.text[0].type == ';'
name = name.data
- return name
+ return VariableDeclaration(name, type, value)
def visit_VariableDeclaration(self, node, visited_children):
type, name, semi = visited_children
assert semi.text[0].type == ';'
name = name.data
- return name
+ return VariableDeclaration(name, type, None)
def visit_FunctionDefinition(self, node, visited_children):
signature, body = visited_children
@@ -102,7 +217,39 @@ class DeclarationVisitor(NodeVisitor):
name = name.data
assert lparen.text[0].type == '('
assert rparen.text[0].type == ')'
- return return_type, name, args
+ return FunctionDeclaration(name, return_type, args)
+
+ def visit_BasicType(self, node, visited_children):
+ while isinstance(visited_children, list) and len(visited_children) == 1:
+ visited_children = visited_children[0]
+ if isinstance(visited_children, list):
+ if len(visited_children) == 3:
+ # parenthesized!
+ lparen, ty, rparen = visited_children
+ assert lparen.text[0].type == '('
+ assert rparen.text[0].type == ')'
+ return ty
+ else:
+ category, name = visited_children
+ category = category.text[0].type
+ name = name.data
+ return BasicType(f"{category} {name}")
+ return BasicType(visited_children.text[0].type)
+
+ def visit_ArrayType(self, node, visited_children):
+ contents, lbracket, size, rbracket = visited_children
+ assert lbracket.text[0].type == '['
+ assert rbracket.text[0].type == ']'
+ # TODO don't explode on nontrivial expression
+ while isinstance(size, list):
+ size = size[0]
+ size = int(size.data)
+ return ArrayType(contents, size)
+
+ def visit_PointerType(self, node, visited_children):
+ contents, splat = visited_children
+ assert splat.text[0].type == '*'
+ return PointerType(contents)
def visit_constant(self, node, visited_children):
return node.text[0]
@@ -122,6 +269,6 @@ class DeclarationVisitor(NodeVisitor):
return visited_children
-def load_declarations(parse_tree):
- declarations = DeclarationVisitor([])
+def load_declarations(parse_tree, include_dirs):
+ declarations = DeclarationVisitor(include_dirs)
return declarations.visit(parse_tree)
diff --git a/crowbar_reference_compiler/scanner.py b/crowbar_reference_compiler/scanner.py
index abf274b..7cf5743 100644
--- a/crowbar_reference_compiler/scanner.py
+++ b/crowbar_reference_compiler/scanner.py
@@ -78,7 +78,7 @@ def scan(code):
remaining = remaining[id_match.end():]
continue
was_constant = False
- for constant in [DECIMAL_CONSTANT, BINARY_CONSTANT, OCTAL_CONSTANT, HEX_CONSTANT, FLOAT_CONSTANT, HEX_FLOAT_CONSTANT, CHAR_CONSTANT]:
+ for constant in [HEX_CONSTANT, BINARY_CONSTANT, OCTAL_CONSTANT, HEX_FLOAT_CONSTANT, FLOAT_CONSTANT, DECIMAL_CONSTANT, CHAR_CONSTANT]:
match = constant.match(remaining)
if match:
result.append(Token('constant', match.group()))
diff --git a/tests/test_declarations.py b/tests/test_declarations.py
index 3fd6683..9eaf488 100644
--- a/tests/test_declarations.py
+++ b/tests/test_declarations.py
@@ -1,6 +1,8 @@
import unittest
from crowbar_reference_compiler import compile_to_ssa, load_declarations, parse_header, parse_implementation, scan
+from crowbar_reference_compiler.declarations import ArrayType, BasicType, EnumDeclaration, PointerType, \
+ StructDeclaration, UnionDeclaration, VariableDeclaration
class TestDeclarationLoading(unittest.TestCase):
@@ -8,6 +10,7 @@ class TestDeclarationLoading(unittest.TestCase):
code = r"""
struct normal {
bool fake;
+ (uint8[3])* data;
}
opaque struct ope;
@@ -31,8 +34,20 @@ fragile union not_robust {
"""
tokens = scan(code)
parse_tree = parse_header(tokens)
- decls = load_declarations(parse_tree)
- self.assertListEqual(decls, ['struct normal', 'struct ope', 'enum sample', 'union robust', 'union not_robust'])
+ decls = load_declarations(parse_tree, [])
+ normal = StructDeclaration('normal', [
+ VariableDeclaration('fake', BasicType('bool'), None),
+ VariableDeclaration('data', PointerType(ArrayType(BasicType('uint8'), 3)), None),
+ ])
+ ope = StructDeclaration('ope', None)
+ sample = EnumDeclaration('sample', [('Testing', None)])
+ robust = UnionDeclaration('robust', VariableDeclaration('tag', BasicType('enum sample'), None),
+ [('Testing', VariableDeclaration('testPassed', BasicType('bool'), None))])
+ not_robust = UnionDeclaration('not_robust', None,
+ [VariableDeclaration('sample', BasicType('int8'), None),
+ VariableDeclaration('nope', BasicType('bool'), None)])
+ self.assertListEqual(decls, [normal, ope, sample, robust, not_robust])
+
if __name__ == '__main__':
unittest.main()