diff options
-rw-r--r-- | crowbar_reference_compiler/__init__.py | 7 | ||||
-rw-r--r-- | crowbar_reference_compiler/declarations.py | 171 | ||||
-rw-r--r-- | crowbar_reference_compiler/scanner.py | 2 | ||||
-rw-r--r-- | tests/test_declarations.py | 19 |
4 files changed, 181 insertions, 18 deletions
diff --git a/crowbar_reference_compiler/__init__.py b/crowbar_reference_compiler/__init__.py index 7a46e01..c7baeea 100644 --- a/crowbar_reference_compiler/__init__.py +++ b/crowbar_reference_compiler/__init__.py @@ -16,8 +16,8 @@ def main(): args.add_argument('--stop-at-qbe-ssa', action='store_true') args.add_argument('-S', '--stop-at-assembly', action='store_true') args.add_argument('-c', '--stop-at-object', action='store_true') - args.add_argument('-D', '--define-constant', help='define a constant with some literal value') - args.add_argument('-I', '--include-dir', help='folder to look for included headers within') + args.add_argument('-D', '--define-constant', action='append', help='define a constant with some literal value') + args.add_argument('-I', '--include-dir', action='append', help='folder to look for included headers within') args.add_argument('-o', '--out', help='output file') args.add_argument('input', help='input file') @@ -33,7 +33,8 @@ def main(): output_file.write(str(parse_tree)) return - decls = load_declarations(parse_tree) + decls = load_declarations(parse_tree, args.include_dir) + print(decls) ssa = compile_to_ssa(parse_tree) if args.stop_at_qbe_ssa: diff --git a/crowbar_reference_compiler/declarations.py b/crowbar_reference_compiler/declarations.py index 9f1a6b2..e9a7e4b 100644 --- a/crowbar_reference_compiler/declarations.py +++ b/crowbar_reference_compiler/declarations.py @@ -1,4 +1,6 @@ +from dataclasses import dataclass from pathlib import Path +from typing import List, Optional, Tuple, Union from parsimonious import NodeVisitor # type: ignore @@ -6,6 +8,71 @@ from .scanner import scan from .parser import parse_header +@dataclass +class Type: + pass + + +@dataclass +class BasicType(Type): + name: str + + +@dataclass +class PointerType(Type): + target: Type + + +@dataclass +class ArrayType(Type): + contents: Type + size: int + + +@dataclass +class Declaration: + name: str + + +@dataclass +class VariableDeclaration(Declaration): + """Represents the declaration of a variable.""" + type: Type + value: Optional[str] + + +@dataclass +class Declarations: + included: List[Declaration] + + +@dataclass +class StructDeclaration(Declaration): + """Represents the declaration of a struct type.""" + fields: Optional[List[VariableDeclaration]] + + +@dataclass +class EnumDeclaration(Declaration): + """Represents the declaration of an enum type.""" + values: List[Tuple[str, Optional[int]]] + + +@dataclass +class UnionDeclaration(Declaration): + """Represents the declaration of a union type.""" + tag: Optional[VariableDeclaration] + cases: Union[List[VariableDeclaration], List[Tuple[str, Optional[VariableDeclaration]]]] + + +@dataclass +class FunctionDeclaration(Declaration): + """Represents the declaration of a function.""" + return_type: Type + args: List[VariableDeclaration] + + +# noinspection PyPep8Naming,PyMethodMayBeStatic,PyUnusedLocal class DeclarationVisitor(NodeVisitor): def __init__(self, include_folders): self.data = [] @@ -16,13 +83,15 @@ class DeclarationVisitor(NodeVisitor): return elements def visit_ImplementationFile(self, node, visited_children): - return [x for x in visited_children if x is not None] + includes, elements = visited_children + includes = [y for x in includes for y in x] + return [x for x in includes + elements if x is not None] def visit_IncludeStatement(self, node, visited_children): include, included_header, semicolon = visited_children assert include.text[0].type == 'include' assert included_header.type == 'string_literal' - included_header = included_header.data + included_header = included_header.data.strip('"') assert semicolon.text[0].type == ';' for include_folder in self.include_folders: header = Path(include_folder) / included_header @@ -39,7 +108,9 @@ class DeclarationVisitor(NodeVisitor): assert lbrace.text[0].type == '{' assert rbrace.text[0].type == '}' name = name.data - return f"struct {name}" + if not isinstance(fields, list): + fields = [fields] + return StructDeclaration(name, fields) def visit_OpaqueStructDefinition(self, node, visited_children): opaque, struct, name, semi = visited_children @@ -47,7 +118,7 @@ class DeclarationVisitor(NodeVisitor): assert struct.text[0].type == 'struct' assert semi.text[0].type == ';' name = name.data - return f"struct {name}" + return StructDeclaration(name, None) def visit_EnumDefinition(self, node, visited_children): enum, name, lbrace, first_member, extra_members, trailing_comma, rbrace = visited_children @@ -55,7 +126,18 @@ class DeclarationVisitor(NodeVisitor): assert lbrace.text[0].type == '{' assert rbrace.text[0].type == '}' name = name.data - return f"enum {name}" + values = [first_member] + for _, v in extra_members: + values.append(v) + return EnumDeclaration(name, values) + + def visit_EnumMember(self, node, visited_children): + name, equals_value = visited_children + name = name.data + if not isinstance(equals_value, list): + return name, None + _, value = equals_value + return name, value def visit_RobustUnionDefinition(self, node, visited_children): union, name, lbrace, tag, body, rbrace = visited_children @@ -63,7 +145,40 @@ class DeclarationVisitor(NodeVisitor): assert lbrace.text[0].type == '{' assert rbrace.text[0].type == '}' name = name.data - return f"union {name}" + expected_tagname, body = body + if tag.name != expected_tagname: + raise NameError(f"tag {tag} does not match switch argument {expected_tagname}") + if not isinstance(body, list): + body = [body] + return UnionDeclaration(name, tag, body) + + def visit_UnionBody(self, node, visited_children): + switch, lparen, tag, rparen, lbrace, body, rbrace = visited_children + assert switch.text[0].type == 'switch' + assert lparen.text[0].type == '(' + assert rparen.text[0].type == ')' + assert lbrace.text[0].type == '{' + assert rbrace.text[0].type == '}' + return tag.data, body + + def visit_UnionBodySet(self, node, visited_children): + cases, var = visited_children + if isinstance(cases, list): + cases = cases[0] + if isinstance(var, VariableDeclaration): + return cases, var + else: + return cases, None + + def visit_CaseSpecifier(self, node, visited_children): + while isinstance(visited_children, list) and len(visited_children) == 1: + visited_children = visited_children[0] + # TODO don't explode on 'default:' + case, expr, colon = visited_children + while isinstance(expr, list): + expr = expr[0] + # TODO don't explode on nontrivial expression + return expr.data def visit_FragileUnionDefinition(self, node, visited_children): fragile, union, name, lbrace, body, rbrace = visited_children @@ -72,7 +187,7 @@ class DeclarationVisitor(NodeVisitor): assert lbrace.text[0].type == '{' assert rbrace.text[0].type == '}' name = name.data - return f"union {name}" + return UnionDeclaration(name, None, body) def visit_FunctionDeclaration(self, node, visited_children): signature, semi = visited_children @@ -84,13 +199,13 @@ class DeclarationVisitor(NodeVisitor): assert eq.text[0].type == '=' assert semi.text[0].type == ';' name = name.data - return name + return VariableDeclaration(name, type, value) def visit_VariableDeclaration(self, node, visited_children): type, name, semi = visited_children assert semi.text[0].type == ';' name = name.data - return name + return VariableDeclaration(name, type, None) def visit_FunctionDefinition(self, node, visited_children): signature, body = visited_children @@ -102,7 +217,39 @@ class DeclarationVisitor(NodeVisitor): name = name.data assert lparen.text[0].type == '(' assert rparen.text[0].type == ')' - return return_type, name, args + return FunctionDeclaration(name, return_type, args) + + def visit_BasicType(self, node, visited_children): + while isinstance(visited_children, list) and len(visited_children) == 1: + visited_children = visited_children[0] + if isinstance(visited_children, list): + if len(visited_children) == 3: + # parenthesized! + lparen, ty, rparen = visited_children + assert lparen.text[0].type == '(' + assert rparen.text[0].type == ')' + return ty + else: + category, name = visited_children + category = category.text[0].type + name = name.data + return BasicType(f"{category} {name}") + return BasicType(visited_children.text[0].type) + + def visit_ArrayType(self, node, visited_children): + contents, lbracket, size, rbracket = visited_children + assert lbracket.text[0].type == '[' + assert rbracket.text[0].type == ']' + # TODO don't explode on nontrivial expression + while isinstance(size, list): + size = size[0] + size = int(size.data) + return ArrayType(contents, size) + + def visit_PointerType(self, node, visited_children): + contents, splat = visited_children + assert splat.text[0].type == '*' + return PointerType(contents) def visit_constant(self, node, visited_children): return node.text[0] @@ -122,6 +269,6 @@ class DeclarationVisitor(NodeVisitor): return visited_children -def load_declarations(parse_tree): - declarations = DeclarationVisitor([]) +def load_declarations(parse_tree, include_dirs): + declarations = DeclarationVisitor(include_dirs) return declarations.visit(parse_tree) diff --git a/crowbar_reference_compiler/scanner.py b/crowbar_reference_compiler/scanner.py index abf274b..7cf5743 100644 --- a/crowbar_reference_compiler/scanner.py +++ b/crowbar_reference_compiler/scanner.py @@ -78,7 +78,7 @@ def scan(code): remaining = remaining[id_match.end():] continue was_constant = False - for constant in [DECIMAL_CONSTANT, BINARY_CONSTANT, OCTAL_CONSTANT, HEX_CONSTANT, FLOAT_CONSTANT, HEX_FLOAT_CONSTANT, CHAR_CONSTANT]: + for constant in [HEX_CONSTANT, BINARY_CONSTANT, OCTAL_CONSTANT, HEX_FLOAT_CONSTANT, FLOAT_CONSTANT, DECIMAL_CONSTANT, CHAR_CONSTANT]: match = constant.match(remaining) if match: result.append(Token('constant', match.group())) diff --git a/tests/test_declarations.py b/tests/test_declarations.py index 3fd6683..9eaf488 100644 --- a/tests/test_declarations.py +++ b/tests/test_declarations.py @@ -1,6 +1,8 @@ import unittest from crowbar_reference_compiler import compile_to_ssa, load_declarations, parse_header, parse_implementation, scan +from crowbar_reference_compiler.declarations import ArrayType, BasicType, EnumDeclaration, PointerType, \ + StructDeclaration, UnionDeclaration, VariableDeclaration class TestDeclarationLoading(unittest.TestCase): @@ -8,6 +10,7 @@ class TestDeclarationLoading(unittest.TestCase): code = r""" struct normal { bool fake; + (uint8[3])* data; } opaque struct ope; @@ -31,8 +34,20 @@ fragile union not_robust { """ tokens = scan(code) parse_tree = parse_header(tokens) - decls = load_declarations(parse_tree) - self.assertListEqual(decls, ['struct normal', 'struct ope', 'enum sample', 'union robust', 'union not_robust']) + decls = load_declarations(parse_tree, []) + normal = StructDeclaration('normal', [ + VariableDeclaration('fake', BasicType('bool'), None), + VariableDeclaration('data', PointerType(ArrayType(BasicType('uint8'), 3)), None), + ]) + ope = StructDeclaration('ope', None) + sample = EnumDeclaration('sample', [('Testing', None)]) + robust = UnionDeclaration('robust', VariableDeclaration('tag', BasicType('enum sample'), None), + [('Testing', VariableDeclaration('testPassed', BasicType('bool'), None))]) + not_robust = UnionDeclaration('not_robust', None, + [VariableDeclaration('sample', BasicType('int8'), None), + VariableDeclaration('nope', BasicType('bool'), None)]) + self.assertListEqual(decls, [normal, ope, sample, robust, not_robust]) + if __name__ == '__main__': unittest.main() |