about summary refs log tree commit diff
path: root/crowbar_reference_compiler
diff options
context:
space:
mode:
Diffstat (limited to 'crowbar_reference_compiler')
-rw-r--r-- crowbar_reference_compiler/__init__.py 9
-rw-r--r-- crowbar_reference_compiler/ast.py 432
-rw-r--r-- crowbar_reference_compiler/declarations.py 274
3 files changed, 438 insertions, 277 deletions
diff --git a/crowbar_reference_compiler/__init__.py b/crowbar_reference_compiler/__init__.py
index c7baeea..53d942f 100644
--- a/crowbar_reference_compiler/__init__.py
+++ b/crowbar_reference_compiler/__init__.py
@@ -1,4 +1,7 @@
-from .declarations import load_declarations
+import dataclasses
+from pprint import pprint
+
+from .ast import build_ast
from .parser import parse_header, parse_implementation
from .scanner import scan
from .ssagen import compile_to_ssa
@@ -33,8 +36,8 @@ def main():
output_file.write(str(parse_tree))
return
- decls = load_declarations(parse_tree, args.include_dir)
- print(decls)
+ full_ast = build_ast(parse_tree, args.include_dir)
+ pprint(dataclasses.asdict(full_ast))
ssa = compile_to_ssa(parse_tree)
if args.stop_at_qbe_ssa:
diff --git a/crowbar_reference_compiler/ast.py b/crowbar_reference_compiler/ast.py
new file mode 100644
index 0000000..4b9c21d
--- /dev/null
+++ b/crowbar_reference_compiler/ast.py
@@ -0,0 +1,432 @@
+from dataclasses import dataclass
+from pathlib import Path
+from typing import List, Optional, Tuple, Union
+
+from parsimonious import NodeVisitor # type: ignore
+
+from .scanner import scan
+from .parser import parse_header
+
+
+@dataclass
+class Type:
+ pass
+
+
+@dataclass
+class Expression:
+ pass
+
+
+@dataclass
+class ConstantExpression(Expression):
+ value: str
+
+
+@dataclass
+class VariableExpression(Expression):
+ name: str
+
+
+@dataclass
+class BasicType(Type):
+ name: str
+
+
+@dataclass
+class ConstType(Type):
+ target: Type
+
+
+@dataclass
+class PointerType(Type):
+ target: Type
+
+
+@dataclass
+class ArrayType(Type):
+ contents: Type
+ size: Expression
+
+
+@dataclass
+class FunctionType(Type):
+ return_type: Type
+ args: List[Type]
+
+
+@dataclass
+class HeaderFileElement:
+ pass
+
+
+@dataclass
+class ImplementationFileElement:
+ pass
+
+
+@dataclass
+class Statement:
+ pass
+
+
+@dataclass
+class EmptyStatement(Statement):
+ pass
+
+
+@dataclass
+class FragileStatement(Statement):
+ body: Statement
+
+
+@dataclass
+class ExpressionStatement(Statement):
+ body: Expression
+
+
+@dataclass
+class IfStatement(Statement):
+ condition: Expression
+ then: List[Statement]
+ els: Optional[List[Statement]]
+
+
+@dataclass
+class SwitchStatement(Statement):
+ expression: Expression
+ body: List[Union[Optional[Expression], Statement]]
+
+
+@dataclass
+class WhileStatement(Statement):
+ condition: Expression
+ body: List[Statement]
+
+
+@dataclass
+class DoWhileStatement(Statement):
+ condition: Expression
+ body: List[Statement]
+
+
+@dataclass
+class Declaration:
+ name: str
+
+
+@dataclass
+class VariableDeclaration(Declaration, HeaderFileElement):
+ """Represents the declaration of a variable."""
+ type: Type
+
+
+@dataclass
+class VariableDefinition(Declaration, ImplementationFileElement, Statement):
+ """Represents the definition of a variable."""
+ type: Type
+ value: Expression
+
+
+@dataclass
+class AssignmentStatement(Statement):
+ pass
+
+
+@dataclass
+class ForStatement(Statement):
+ init: List[VariableDefinition]
+ condition: Expression
+ update: List[AssignmentStatement]
+
+
+@dataclass
+class ContinueStatement(Statement):
+ pass
+
+
+@dataclass
+class BreakStatement(Statement):
+ pass
+
+
+@dataclass
+class ReturnStatement(Statement):
+ body: Optional[Expression]
+
+
+@dataclass
+class DirectAssignment(AssignmentStatement):
+ destination: Expression
+ value: Expression
+
+
+@dataclass
+class UpdateAssignment(AssignmentStatement):
+ destination: Expression
+ operation: str
+ value: Expression
+
+
+@dataclass
+class CrementAssignment(AssignmentStatement):
+ destination: Expression
+ operation: str
+
+
+@dataclass
+class StructDeclaration(Declaration, HeaderFileElement, ImplementationFileElement):
+ """Represents the declaration of a struct type."""
+ fields: Optional[List[VariableDeclaration]]
+
+
+@dataclass
+class EnumDeclaration(Declaration, HeaderFileElement, ImplementationFileElement):
+ """Represents the declaration of an enum type."""
+ values: List[Tuple[str, Optional[int]]]
+
+
+@dataclass
+class UnionDeclaration(Declaration, HeaderFileElement, ImplementationFileElement):
+ """Represents the declaration of a union type."""
+ tag: Optional[VariableDeclaration]
+ cases: Union[List[VariableDeclaration], List[Tuple[Expression, Optional[VariableDeclaration]]]]
+
+
+@dataclass
+class FunctionDeclaration(Declaration, HeaderFileElement):
+ """Represents the declaration of a function."""
+ return_type: Type
+ args: List[VariableDeclaration]
+
+
+@dataclass
+class FunctionDefinition(Declaration, HeaderFileElement, ImplementationFileElement):
+ """Represents the definition of a function."""
+ return_type: Type
+ args: List[VariableDeclaration]
+ body: List[Statement]
+
+
+@dataclass
+class HeaderFile:
+ includes: List['HeaderFile']
+ contents: List[HeaderFileElement]
+
+
+@dataclass
+class ImplementationFile:
+ includes: List[HeaderFile]
+ contents: List[ImplementationFileElement]
+
+
+# noinspection PyPep8Naming,PyMethodMayBeStatic,PyUnusedLocal
+class ASTBuilder(NodeVisitor):
+ def __init__(self, include_folders):
+ self.include_folders = include_folders
+
+ def visit_HeaderFile(self, node, visited_children) -> HeaderFile:
+ includes, elements = visited_children
+ if not isinstance(includes, list):
+ includes = []
+ return HeaderFile(includes, elements)
+
+ def visit_ImplementationFile(self, node, visited_children) -> ImplementationFile:
+ includes, elements = visited_children
+ return ImplementationFile(includes, elements)
+
+ def visit_IncludeStatement(self, node, visited_children) -> HeaderFile:
+ include, included_header, semicolon = visited_children
+ assert include.type == 'include'
+ assert included_header.type == 'string_literal'
+ included_header = included_header.data.strip('"')
+ assert semicolon.type == ';'
+ for include_folder in self.include_folders:
+ header = Path(include_folder) / included_header
+ if header.exists():
+ with open(header, 'r', encoding='utf-8') as header_file:
+ header_text = header_file.read()
+ header_parse_tree = parse_header(scan(header_text))
+ return self.visit(header_parse_tree)
+ raise FileNotFoundError(included_header)
+
+ def visit_NormalStructDefinition(self, node, visited_children) -> StructDeclaration:
+ struct, name, lbrace, fields, rbrace = visited_children
+ assert struct.type == 'struct'
+ assert lbrace.type == '{'
+ assert rbrace.type == '}'
+ name = name.data
+ if not isinstance(fields, list):
+ fields = [fields]
+ return StructDeclaration(name, fields)
+
+ def visit_OpaqueStructDefinition(self, node, visited_children) -> StructDeclaration:
+ opaque, struct, name, semi = visited_children
+ assert opaque.type == 'opaque'
+ assert struct.type == 'struct'
+ assert semi.type == ';'
+ name = name.data
+ return StructDeclaration(name, None)
+
+ def visit_EnumDefinition(self, node, visited_children) -> EnumDeclaration:
+ enum, name, lbrace, first_member, extra_members, trailing_comma, rbrace = visited_children
+ assert enum.type == 'enum'
+ assert lbrace.type == '{'
+ assert rbrace.type == '}'
+ name = name.data
+ values = [first_member]
+ for _, v in extra_members:
+ values.append(v)
+ return EnumDeclaration(name, values)
+
+ def visit_EnumMember(self, node, visited_children) -> Tuple[str, Optional[Expression]]:
+ name, equals_value = visited_children
+ name = name.data
+ if len(equals_value) == 0:
+ return name, None
+ _, value = equals_value
+ return name, value
+
+ def visit_RobustUnionDefinition(self, node, visited_children) -> UnionDeclaration:
+ union, name, lbrace, tag, body, rbrace = visited_children
+ assert union.type == 'union'
+ assert lbrace.type == '{'
+ assert rbrace.type == '}'
+ name = name.data
+ expected_tagname, body = body
+ if tag.name != expected_tagname:
+ raise NameError(f"tag {tag} does not match switch argument {expected_tagname}")
+ if not isinstance(body, list):
+ body = [body]
+ return UnionDeclaration(name, tag, body)
+
+ def visit_UnionBody(self, node, visited_children) -> Tuple[str, List[Tuple[Expression, Optional[VariableDeclaration]]]]:
+ switch, lparen, tag, rparen, lbrace, body, rbrace = visited_children
+ assert switch.type == 'switch'
+ assert lparen.type == '('
+ assert rparen.type == ')'
+ assert lbrace.type == '{'
+ assert rbrace.type == '}'
+ return tag.data, body
+
+ def visit_UnionBodySet(self, node, visited_children) -> Tuple[Expression, Optional[VariableDeclaration]]:
+ cases, var = visited_children
+ if isinstance(cases, list):
+ cases = cases[0]
+ if isinstance(var, VariableDeclaration):
+ return cases, var
+ else:
+ return cases, None
+
+ def visit_CaseSpecifier(self, node, visited_children) -> Expression:
+ while isinstance(visited_children, list) and len(visited_children) == 1:
+ visited_children = visited_children[0]
+ # TODO don't explode on 'default:'
+ case, expr, colon = visited_children
+ while isinstance(expr, list):
+ expr = expr[0]
+ return expr
+
+ def visit_FragileUnionDefinition(self, node, visited_children) -> UnionDeclaration:
+ fragile, union, name, lbrace, body, rbrace = visited_children
+ assert fragile.type == 'fragile'
+ assert union.type == 'union'
+ assert lbrace.type == '{'
+ assert rbrace.type == '}'
+ name = name.data
+ return UnionDeclaration(name, None, body)
+
+ def visit_FunctionDeclaration(self, node, visited_children) -> FunctionDeclaration:
+ signature, semi = visited_children
+ assert semi.type == ';'
+ return signature
+
+ def visit_VariableDefinition(self, node, visited_children) -> VariableDefinition:
+ type, name, eq, value, semi = visited_children
+ assert eq.type == '='
+ assert semi.type == ';'
+ name = name.data
+ return VariableDefinition(name, type, value)
+
+ def visit_VariableDeclaration(self, node, visited_children) -> VariableDeclaration:
+ type, name, semi = visited_children
+ assert semi.type == ';'
+ name = name.data
+ return VariableDeclaration(name, type)
+
+ def visit_FunctionDefinition(self, node, visited_children) -> FunctionDefinition:
+ signature, body = visited_children
+ return FunctionDefinition(signature.name, signature.return_type, signature.args, body)
+
+ def visit_FunctionSignature(self, node, visited_children) -> FunctionDeclaration:
+ return_type, name, lparen, args, rparen = visited_children
+ assert name.type == 'identifier'
+ name = name.data
+ assert lparen.type == '('
+ assert rparen.type == ')'
+ return FunctionDeclaration(name, return_type, args)
+
+ def visit_BasicType(self, node, visited_children) -> Type:
+ while isinstance(visited_children, list) and len(visited_children) == 1:
+ visited_children = visited_children[0]
+ if isinstance(visited_children, list):
+ if len(visited_children) == 3:
+ # parenthesized!
+ lparen, ty, rparen = visited_children
+ assert lparen.type == '('
+ assert rparen.type == ')'
+ return ty
+ else:
+ category, name = visited_children
+ category = category.type
+ name = name.data
+ return BasicType(f"{category} {name}")
+ return BasicType(visited_children.type)
+
+ def visit_ArrayType(self, node, visited_children) -> ArrayType:
+ contents, lbracket, size, rbracket = visited_children
+ assert lbracket.type == '['
+ assert rbracket.type == ']'
+ # TODO don't explode on nontrivial expression
+ while isinstance(size, list):
+ size = size[0]
+ return ArrayType(contents, size)
+
+ def visit_PointerType(self, node, visited_children) -> PointerType:
+ contents, splat = visited_children
+ assert splat.type == '*'
+ return PointerType(contents)
+
+ def visit_AtomicExpression(self, node, visited_children) -> Expression:
+ if isinstance(visited_children, list) and len(visited_children) == 3:
+ lparen, body, rparen = visited_children
+ assert lparen.type == '('
+ assert rparen.type == ')'
+ return body
+ body = visited_children
+ while isinstance(body, list):
+ body = body[0]
+ if body.type == 'identifier':
+ return VariableExpression(body.data)
+ if body.type == 'constant':
+ return ConstantExpression(body.data)
+ if body.type in ['true', 'false']:
+ return ConstantExpression(body.type)
+ raise NotImplementedError()
+
+ def generic_visit(self, node, visited_children):
+ """ The generic visit method. """
+ if not visited_children:
+ if len(node.text) == 0:
+ return []
+ if len(node.text) == 1:
+ return node.text[0]
+ raise ValueError('just a node: ' + str(node))
+ if len(visited_children) == 1:
+ return visited_children[0]
+ return visited_children
+
+
+def build_ast(parse_tree, include_dirs):
+ builder = ASTBuilder(include_dirs)
+ return builder.visit(parse_tree)
diff --git a/crowbar_reference_compiler/declarations.py b/crowbar_reference_compiler/declarations.py
deleted file mode 100644
index e9a7e4b..0000000
--- a/crowbar_reference_compiler/declarations.py
+++ /dev/null
@@ -1,274 +0,0 @@
-from dataclasses import dataclass
-from pathlib import Path
-from typing import List, Optional, Tuple, Union
-
-from parsimonious import NodeVisitor # type: ignore
-
-from .scanner import scan
-from .parser import parse_header
-
-
-@dataclass
-class Type:
- pass
-
-
-@dataclass
-class BasicType(Type):
- name: str
-
-
-@dataclass
-class PointerType(Type):
- target: Type
-
-
-@dataclass
-class ArrayType(Type):
- contents: Type
- size: int
-
-
-@dataclass
-class Declaration:
- name: str
-
-
-@dataclass
-class VariableDeclaration(Declaration):
- """Represents the declaration of a variable."""
- type: Type
- value: Optional[str]
-
-
-@dataclass
-class Declarations:
- included: List[Declaration]
-
-
-@dataclass
-class StructDeclaration(Declaration):
- """Represents the declaration of a struct type."""
- fields: Optional[List[VariableDeclaration]]
-
-
-@dataclass
-class EnumDeclaration(Declaration):
- """Represents the declaration of an enum type."""
- values: List[Tuple[str, Optional[int]]]
-
-
-@dataclass
-class UnionDeclaration(Declaration):
- """Represents the declaration of a union type."""
- tag: Optional[VariableDeclaration]
- cases: Union[List[VariableDeclaration], List[Tuple[str, Optional[VariableDeclaration]]]]
-
-
-@dataclass
-class FunctionDeclaration(Declaration):
- """Represents the declaration of a function."""
- return_type: Type
- args: List[VariableDeclaration]
-
-
-# noinspection PyPep8Naming,PyMethodMayBeStatic,PyUnusedLocal
-class DeclarationVisitor(NodeVisitor):
- def __init__(self, include_folders):
- self.data = []
- self.include_folders = include_folders
-
- def visit_HeaderFile(self, node, visited_children):
- includes, elements = visited_children
- return elements
-
- def visit_ImplementationFile(self, node, visited_children):
- includes, elements = visited_children
- includes = [y for x in includes for y in x]
- return [x for x in includes + elements if x is not None]
-
- def visit_IncludeStatement(self, node, visited_children):
- include, included_header, semicolon = visited_children
- assert include.text[0].type == 'include'
- assert included_header.type == 'string_literal'
- included_header = included_header.data.strip('"')
- assert semicolon.text[0].type == ';'
- for include_folder in self.include_folders:
- header = Path(include_folder) / included_header
- if header.exists():
- with open(header, 'r', encoding='utf-8') as header_file:
- header_text = header_file.read()
- header_parse_tree = parse_header(scan(header_text))
- return self.visit(header_parse_tree)
- raise FileNotFoundError(included_header)
-
- def visit_NormalStructDefinition(self, node, visited_children):
- struct, name, lbrace, fields, rbrace = visited_children
- assert struct.text[0].type == 'struct'
- assert lbrace.text[0].type == '{'
- assert rbrace.text[0].type == '}'
- name = name.data
- if not isinstance(fields, list):
- fields = [fields]
- return StructDeclaration(name, fields)
-
- def visit_OpaqueStructDefinition(self, node, visited_children):
- opaque, struct, name, semi = visited_children
- assert opaque.text[0].type == 'opaque'
- assert struct.text[0].type == 'struct'
- assert semi.text[0].type == ';'
- name = name.data
- return StructDeclaration(name, None)
-
- def visit_EnumDefinition(self, node, visited_children):
- enum, name, lbrace, first_member, extra_members, trailing_comma, rbrace = visited_children
- assert enum.text[0].type == 'enum'
- assert lbrace.text[0].type == '{'
- assert rbrace.text[0].type == '}'
- name = name.data
- values = [first_member]
- for _, v in extra_members:
- values.append(v)
- return EnumDeclaration(name, values)
-
- def visit_EnumMember(self, node, visited_children):
- name, equals_value = visited_children
- name = name.data
- if not isinstance(equals_value, list):
- return name, None
- _, value = equals_value
- return name, value
-
- def visit_RobustUnionDefinition(self, node, visited_children):
- union, name, lbrace, tag, body, rbrace = visited_children
- assert union.text[0].type == 'union'
- assert lbrace.text[0].type == '{'
- assert rbrace.text[0].type == '}'
- name = name.data
- expected_tagname, body = body
- if tag.name != expected_tagname:
- raise NameError(f"tag {tag} does not match switch argument {expected_tagname}")
- if not isinstance(body, list):
- body = [body]
- return UnionDeclaration(name, tag, body)
-
- def visit_UnionBody(self, node, visited_children):
- switch, lparen, tag, rparen, lbrace, body, rbrace = visited_children
- assert switch.text[0].type == 'switch'
- assert lparen.text[0].type == '('
- assert rparen.text[0].type == ')'
- assert lbrace.text[0].type == '{'
- assert rbrace.text[0].type == '}'
- return tag.data, body
-
- def visit_UnionBodySet(self, node, visited_children):
- cases, var = visited_children
- if isinstance(cases, list):
- cases = cases[0]
- if isinstance(var, VariableDeclaration):
- return cases, var
- else:
- return cases, None
-
- def visit_CaseSpecifier(self, node, visited_children):
- while isinstance(visited_children, list) and len(visited_children) == 1:
- visited_children = visited_children[0]
- # TODO don't explode on 'default:'
- case, expr, colon = visited_children
- while isinstance(expr, list):
- expr = expr[0]
- # TODO don't explode on nontrivial expression
- return expr.data
-
- def visit_FragileUnionDefinition(self, node, visited_children):
- fragile, union, name, lbrace, body, rbrace = visited_children
- assert fragile.text[0].type == 'fragile'
- assert union.text[0].type == 'union'
- assert lbrace.text[0].type == '{'
- assert rbrace.text[0].type == '}'
- name = name.data
- return UnionDeclaration(name, None, body)
-
- def visit_FunctionDeclaration(self, node, visited_children):
- signature, semi = visited_children
- assert semi.text[0].type == ';'
- return signature
-
- def visit_VariableDefinition(self, node, visited_children):
- type, name, eq, value, semi = visited_children
- assert eq.text[0].type == '='
- assert semi.text[0].type == ';'
- name = name.data
- return VariableDeclaration(name, type, value)
-
- def visit_VariableDeclaration(self, node, visited_children):
- type, name, semi = visited_children
- assert semi.text[0].type == ';'
- name = name.data
- return VariableDeclaration(name, type, None)
-
- def visit_FunctionDefinition(self, node, visited_children):
- signature, body = visited_children
- return signature
-
- def visit_FunctionSignature(self, node, visited_children):
- return_type, name, lparen, args, rparen = visited_children
- assert name.type == 'identifier'
- name = name.data
- assert lparen.text[0].type == '('
- assert rparen.text[0].type == ')'
- return FunctionDeclaration(name, return_type, args)
-
- def visit_BasicType(self, node, visited_children):
- while isinstance(visited_children, list) and len(visited_children) == 1:
- visited_children = visited_children[0]
- if isinstance(visited_children, list):
- if len(visited_children) == 3:
- # parenthesized!
- lparen, ty, rparen = visited_children
- assert lparen.text[0].type == '('
- assert rparen.text[0].type == ')'
- return ty
- else:
- category, name = visited_children
- category = category.text[0].type
- name = name.data
- return BasicType(f"{category} {name}")
- return BasicType(visited_children.text[0].type)
-
- def visit_ArrayType(self, node, visited_children):
- contents, lbracket, size, rbracket = visited_children
- assert lbracket.text[0].type == '['
- assert rbracket.text[0].type == ']'
- # TODO don't explode on nontrivial expression
- while isinstance(size, list):
- size = size[0]
- size = int(size.data)
- return ArrayType(contents, size)
-
- def visit_PointerType(self, node, visited_children):
- contents, splat = visited_children
- assert splat.text[0].type == '*'
- return PointerType(contents)
-
- def visit_constant(self, node, visited_children):
- return node.text[0]
-
- def visit_string_literal(self, node, visited_children):
- return node.text[0]
-
- def visit_identifier(self, node, visited_children):
- return node.text[0]
-
- def generic_visit(self, node, visited_children):
- """ The generic visit method. """
- if not visited_children:
- return node
- if len(visited_children) == 1:
- return visited_children[0]
- return visited_children
-
-
-def load_declarations(parse_tree, include_dirs):
- declarations = DeclarationVisitor(include_dirs)
- return declarations.visit(parse_tree)