aboutsummaryrefslogtreecommitdiff
path: root/crowbar_reference_compiler/declarations.py
blob: 9f1a6b2382fe16b13d98aebf5cc130d393b2973a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
from pathlib import Path

from parsimonious import NodeVisitor  # type: ignore

from .scanner import scan
from .parser import parse_header


class DeclarationVisitor(NodeVisitor):
    """Walk a parse tree and collect what each top-level element declares.

    Type definitions are reduced to a string such as ``"struct name"``,
    variables to their name token's ``data``, and functions to a
    ``(return_type, name, args)`` tuple.  Include statements are resolved
    against ``include_folders`` and the included header is visited
    recursively with this same visitor.

    The ``assert`` statements are sanity checks on the shape of the parse
    tree (they are skipped when Python runs with ``-O``); they are not
    input validation.
    """

    def __init__(self, include_folders):
        """Create a visitor.

        :param include_folders: iterable of directories that are searched,
            in order, for the header named by an include statement.
        """
        self.data = []
        self.include_folders = include_folders

    @staticmethod
    def _expect(node, token_type):
        """Assert that the first token covered by ``node`` has ``token_type``.

        Assumes ``node.text`` is a sequence of tokens exposing ``.type``,
        which is how every visit method in this class reads it.
        """
        # Keep the whole check inside the assert so that nothing is
        # evaluated when assertions are disabled.
        assert node.text[0].type == token_type, (
            f"expected {token_type!r} token, found {node.text[0].type!r}"
        )

    def visit_HeaderFile(self, node, visited_children):
        """Return the visited elements of a header file.

        The first child (the visited include statements) is unpacked but
        not returned.
        """
        _includes, elements = visited_children
        return elements

    def visit_ImplementationFile(self, node, visited_children):
        """Return every visited child that is not ``None``."""
        return [child for child in visited_children if child is not None]

    def visit_IncludeStatement(self, node, visited_children):
        """Locate the included header, parse it and return its declarations.

        :raises FileNotFoundError: if no include folder contains a regular
            file with the requested name.
        """
        include, included_header, semicolon = visited_children
        self._expect(include, 'include')
        assert included_header.type == 'string_literal'
        header_name = included_header.data
        self._expect(semicolon, ';')
        for include_folder in self.include_folders:
            header = Path(include_folder) / header_name
            # is_file() rather than exists(): a directory that happens to
            # share the header's name must not stop the search with an
            # IsADirectoryError; keep looking in the remaining folders.
            if not header.is_file():
                continue
            header_text = header.read_text(encoding='utf-8')
            header_parse_tree = parse_header(scan(header_text))
            return self.visit(header_parse_tree)
        raise FileNotFoundError(header_name)

    def visit_NormalStructDefinition(self, node, visited_children):
        """Return ``"struct <name>"`` for a struct definition with a body."""
        struct, name, lbrace, _fields, rbrace = visited_children
        self._expect(struct, 'struct')
        self._expect(lbrace, '{')
        self._expect(rbrace, '}')
        return f"struct {name.data}"

    def visit_OpaqueStructDefinition(self, node, visited_children):
        """Return ``"struct <name>"`` for an opaque struct definition."""
        opaque, struct, name, semi = visited_children
        self._expect(opaque, 'opaque')
        self._expect(struct, 'struct')
        self._expect(semi, ';')
        return f"struct {name.data}"

    def visit_EnumDefinition(self, node, visited_children):
        """Return ``"enum <name>"``; the members are not inspected."""
        (enum, name, lbrace, _first_member, _extra_members,
         _trailing_comma, rbrace) = visited_children
        self._expect(enum, 'enum')
        self._expect(lbrace, '{')
        self._expect(rbrace, '}')
        return f"enum {name.data}"

    def visit_RobustUnionDefinition(self, node, visited_children):
        """Return ``"union <name>"`` for a union with a tag and a body."""
        union, name, lbrace, _tag, _body, rbrace = visited_children
        self._expect(union, 'union')
        self._expect(lbrace, '{')
        self._expect(rbrace, '}')
        return f"union {name.data}"

    def visit_FragileUnionDefinition(self, node, visited_children):
        """Return ``"union <name>"`` for a union marked ``fragile``."""
        fragile, union, name, lbrace, _body, rbrace = visited_children
        self._expect(fragile, 'fragile')
        self._expect(union, 'union')
        self._expect(lbrace, '{')
        self._expect(rbrace, '}')
        return f"union {name.data}"

    def visit_FunctionDeclaration(self, node, visited_children):
        """Return the visited signature of a function declaration."""
        signature, semi = visited_children
        self._expect(semi, ';')
        return signature

    def visit_VariableDefinition(self, node, visited_children):
        """Return the name of a variable defined with an initial value."""
        # The declared type and the initializer are not needed here; the
        # first local is deliberately not called ``type`` so the builtin
        # is not shadowed.
        _var_type, name, eq, _value, semi = visited_children
        self._expect(eq, '=')
        self._expect(semi, ';')
        return name.data

    def visit_VariableDeclaration(self, node, visited_children):
        """Return the name of a variable declared without a value."""
        _var_type, name, semi = visited_children
        self._expect(semi, ';')
        return name.data

    def visit_FunctionDefinition(self, node, visited_children):
        """Return the visited signature; the function body is ignored."""
        signature, _body = visited_children
        return signature

    def visit_FunctionSignature(self, node, visited_children):
        """Return ``(return_type, name, args)`` for a function signature."""
        return_type, name, lparen, args, rparen = visited_children
        # ``name`` is already a token here (see visit_identifier), so its
        # type is read directly instead of through ``.text[0]``.
        assert name.type == 'identifier'
        name = name.data
        self._expect(lparen, '(')
        self._expect(rparen, ')')
        return return_type, name, args

    def visit_constant(self, node, visited_children):
        """Return the first token of a constant node."""
        return node.text[0]

    def visit_string_literal(self, node, visited_children):
        """Return the first token of a string literal node."""
        return node.text[0]

    def visit_identifier(self, node, visited_children):
        """Return the first token of an identifier node."""
        return node.text[0]

    def generic_visit(self, node, visited_children):
        """Fallback for nodes without a dedicated ``visit_*`` method.

        Childless nodes are returned unchanged, a single visited child is
        returned unwrapped, and several children are returned as the list
        that was passed in.
        """
        if not visited_children:
            return node
        if len(visited_children) == 1:
            return visited_children[0]
        return visited_children


def load_declarations(parse_tree, include_folders=None):
    """Visit ``parse_tree`` with a DeclarationVisitor and return the result.

    :param parse_tree: the parse tree to visit.
    :param include_folders: optional iterable of directories searched, in
        order, when the tree contains include statements.  Defaults to no
        folders, in which case any include statement fails with
        ``FileNotFoundError`` (possibly wrapped by the visitor framework —
        TODO confirm against the parsimonious version in use).
    :returns: whatever the visitor produces for the root of ``parse_tree``.
    """
    if include_folders is None:
        # A fresh list per call; never share a mutable default.
        include_folders = []
    visitor = DeclarationVisitor(include_folders)
    return visitor.visit(parse_tree)