aboutsummaryrefslogtreecommitdiff
path: root/crowbar_reference_compiler/declarations.py
blob: e9a7e4b08fc9626894705eeaaa3d86b62a26cab6 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional, Tuple, Union

from parsimonious import NodeVisitor  # type: ignore

from .scanner import scan
from .parser import parse_header


@dataclass
class Type:
    """Common base class for every type description in this module."""


@dataclass
class BasicType(Type):
    """A type that is identified purely by its name."""
    # Either a bare type keyword or "<category> <identifier>", as built by
    # DeclarationVisitor.visit_BasicType.
    name: str


@dataclass
class PointerType(Type):
    """A pointer to another type."""
    # The type being pointed at.
    target: Type


@dataclass
class ArrayType(Type):
    """An array of some element type with an integer size."""
    # The element type.
    contents: Type
    # The size given between the brackets, already converted to an int.
    size: int


@dataclass
class Declaration:
    """Common base class for every named declaration."""
    # The identifier being declared.
    name: str


@dataclass
class VariableDeclaration(Declaration):
    """A variable declaration: a name, its type and an optional initializer."""
    # The declared type of the variable.
    type: Type
    # The initializer, or None when the variable is declared without one.
    value: Optional[str]


@dataclass
class Declarations:
    """A container holding a list of declarations."""
    # The declarations held by this container.
    included: List[Declaration]


@dataclass
class StructDeclaration(Declaration):
    """A struct type declaration."""
    # The struct's fields; None marks an opaque struct, which has no body.
    fields: Optional[List[VariableDeclaration]]


@dataclass
class EnumDeclaration(Declaration):
    """An enum type declaration."""
    # One (member name, explicit value) pair per member; the value is None
    # when the member has no "= value" part.
    values: List[Tuple[str, Optional[int]]]


@dataclass
class UnionDeclaration(Declaration):
    """A union type declaration, either robust (tagged) or fragile."""
    # The tag variable of a robust union; None for a fragile union.
    tag: Optional[VariableDeclaration]
    # Fragile unions: the member variables.
    # Robust unions: (case, member-or-None) pairs, one per switch case set.
    cases: Union[List[VariableDeclaration], List[Tuple[str, Optional[VariableDeclaration]]]]


@dataclass
class FunctionDeclaration(Declaration):
    """A function declaration: name, return type and arguments."""
    # The type the function returns.
    return_type: Type
    # The function's parameters.
    args: List[VariableDeclaration]


# noinspection PyPep8Naming,PyMethodMayBeStatic,PyUnusedLocal
class DeclarationVisitor(NodeVisitor):
    """Parse-tree visitor that turns definitions into ``Declaration`` objects.

    Leaf nodes are assumed to wrap scanner tokens: ``node.text[0]`` is the
    token, which exposes ``.type`` and (for identifiers, constants and string
    literals) ``.data``.  The ``assert`` statements in the ``visit_*`` methods
    are sanity checks on the shape of the parse tree, not input validation.

    Because :meth:`generic_visit` unwraps one-element child lists and returns
    the raw node for childless nodes, a repeated grammar element may reach a
    ``visit_*`` method as a node (no match), a bare item (one match) or a
    list (several matches).  Several methods normalise for that.
    """

    def __init__(self, include_folders):
        """Create a visitor.

        :param include_folders: iterable of directories searched, in order,
            when resolving an include statement.
        """
        # Not read anywhere in this class; kept for external users.
        self.data = []
        self.include_folders = include_folders

    def visit_HeaderFile(self, node, visited_children):
        """Return the header's own elements, dropping whatever it includes."""
        includes, elements = visited_children
        return elements

    def visit_ImplementationFile(self, node, visited_children):
        """Return included declarations followed by the file's own elements.

        ``None`` entries are filtered out of the combined list.
        """
        includes, elements = visited_children
        # every include yields the list of its header's elements; flatten them
        includes = [y for x in includes for y in x]
        return [x for x in includes + elements if x is not None]

    def visit_IncludeStatement(self, node, visited_children):
        """Locate, scan, parse and visit the header named by an include.

        The include folders are tried in order and the first one containing
        a regular file with the requested name wins.

        :returns: the result of visiting the header's parse tree.
        :raises FileNotFoundError: if no include folder contains the header.
        """
        include, included_header, semicolon = visited_children
        assert include.text[0].type == 'include'
        assert included_header.type == 'string_literal'
        included_header = included_header.data.strip('"')
        assert semicolon.text[0].type == ';'
        for include_folder in self.include_folders:
            header = Path(include_folder) / included_header
            # is_file() rather than exists(): a directory that happens to
            # share the header's name must not stop the search.
            if not header.is_file():
                continue
            header_text = header.read_text(encoding='utf-8')
            header_parse_tree = parse_header(scan(header_text))
            return self.visit(header_parse_tree)
        raise FileNotFoundError(included_header)

    def visit_NormalStructDefinition(self, node, visited_children):
        """Build a ``StructDeclaration`` with a list of fields."""
        struct, name, lbrace, fields, rbrace = visited_children
        assert struct.text[0].type == 'struct'
        assert lbrace.text[0].type == '{'
        assert rbrace.text[0].type == '}'
        name = name.data
        # a single field arrives unwrapped (see generic_visit)
        if not isinstance(fields, list):
            fields = [fields]
        return StructDeclaration(name, fields)

    def visit_OpaqueStructDefinition(self, node, visited_children):
        """Build a ``StructDeclaration`` whose ``fields`` is ``None``."""
        opaque, struct, name, semi = visited_children
        assert opaque.text[0].type == 'opaque'
        assert struct.text[0].type == 'struct'
        assert semi.text[0].type == ';'
        name = name.data
        return StructDeclaration(name, None)

    def visit_EnumDefinition(self, node, visited_children):
        """Build an ``EnumDeclaration`` from the first and any extra members."""
        (enum, name, lbrace, first_member, extra_members,
         trailing_comma, rbrace) = visited_children
        assert enum.text[0].type == 'enum'
        assert lbrace.text[0].type == '{'
        assert rbrace.text[0].type == '}'
        name = name.data
        # With exactly one extra member, generic_visit unwraps the list of
        # (comma, member) pairs into the bare pair itself; wrap it again so
        # the loop below unpacks pairs rather than the pair's elements.
        if (isinstance(extra_members, list) and extra_members
                and not isinstance(extra_members[0], list)):
            extra_members = [extra_members]
        values = [first_member]
        for _, member in extra_members:
            values.append(member)
        return EnumDeclaration(name, values)

    def visit_EnumMember(self, node, visited_children):
        """Return ``(name, value)``; ``value`` is ``None`` without ``= value``."""
        name, equals_value = visited_children
        name = name.data
        # an absent "= value" part comes through as a node, not a list
        if not isinstance(equals_value, list):
            return name, None
        _, value = equals_value
        return name, value

    def visit_RobustUnionDefinition(self, node, visited_children):
        """Build a tagged ``UnionDeclaration``.

        :raises NameError: if the tag variable's name differs from the name
            used as the switch argument in the union body.
        """
        union, name, lbrace, tag, body, rbrace = visited_children
        assert union.text[0].type == 'union'
        assert lbrace.text[0].type == '{'
        assert rbrace.text[0].type == '}'
        name = name.data
        expected_tagname, body = body
        if tag.name != expected_tagname:
            raise NameError(f"tag {tag} does not match switch argument {expected_tagname}")
        # a single case set arrives unwrapped (see generic_visit)
        if not isinstance(body, list):
            body = [body]
        return UnionDeclaration(name, tag, body)

    def visit_UnionBody(self, node, visited_children):
        """Return ``(switch argument name, visited case sets)``."""
        switch, lparen, tag, rparen, lbrace, body, rbrace = visited_children
        assert switch.text[0].type == 'switch'
        assert lparen.text[0].type == '('
        assert rparen.text[0].type == ')'
        assert lbrace.text[0].type == '{'
        assert rbrace.text[0].type == '}'
        return tag.data, body

    def visit_UnionBodySet(self, node, visited_children):
        """Return ``(case, variable)``; ``variable`` is ``None`` if absent.

        When several case specifiers are present only the first is kept.
        """
        cases, var = visited_children
        if isinstance(cases, list):
            cases = cases[0]
        if isinstance(var, VariableDeclaration):
            return cases, var
        else:
            return cases, None

    def visit_CaseSpecifier(self, node, visited_children):
        """Return the ``.data`` of the token used as the case expression."""
        # peel off single-element wrapper lists
        while isinstance(visited_children, list) and len(visited_children) == 1:
            visited_children = visited_children[0]
        # TODO don't explode on 'default:'
        case, expr, colon = visited_children
        while isinstance(expr, list):
            expr = expr[0]
        # TODO don't explode on nontrivial expression
        return expr.data

    def visit_FragileUnionDefinition(self, node, visited_children):
        """Build an untagged ``UnionDeclaration`` (``tag`` is ``None``)."""
        fragile, union, name, lbrace, body, rbrace = visited_children
        assert fragile.text[0].type == 'fragile'
        assert union.text[0].type == 'union'
        assert lbrace.text[0].type == '{'
        assert rbrace.text[0].type == '}'
        name = name.data
        # a single member arrives unwrapped (see generic_visit); normalise to
        # a list, matching the struct and robust-union handling
        if not isinstance(body, list):
            body = [body]
        return UnionDeclaration(name, None, body)

    def visit_FunctionDeclaration(self, node, visited_children):
        """Return the already-built declaration of the signature."""
        signature, semi = visited_children
        assert semi.text[0].type == ';'
        return signature

    def visit_VariableDefinition(self, node, visited_children):
        """Build a ``VariableDeclaration`` carrying its initializer."""
        type, name, eq, value, semi = visited_children
        assert eq.text[0].type == '='
        assert semi.text[0].type == ';'
        name = name.data
        return VariableDeclaration(name, type, value)

    def visit_VariableDeclaration(self, node, visited_children):
        """Build a ``VariableDeclaration`` whose ``value`` is ``None``."""
        type, name, semi = visited_children
        assert semi.text[0].type == ';'
        name = name.data
        return VariableDeclaration(name, type, None)

    def visit_FunctionDefinition(self, node, visited_children):
        """Return the signature's declaration; the body is ignored."""
        signature, body = visited_children
        return signature

    def visit_FunctionSignature(self, node, visited_children):
        """Build a ``FunctionDeclaration`` from a signature."""
        return_type, name, lparen, args, rparen = visited_children
        assert name.type == 'identifier'
        name = name.data
        assert lparen.text[0].type == '('
        assert rparen.text[0].type == ')'
        # NOTE(review): args is forwarded exactly as the child visit produced
        # it and is not normalised to a list here -- confirm against grammar.
        return FunctionDeclaration(name, return_type, args)

    def visit_BasicType(self, node, visited_children):
        """Return the type described by a basic-type node.

        Three shapes are handled: a parenthesised type (the inner type is
        returned as is), a two-part ``<category> <name>`` type, and a single
        token whose token type is used as the type name.
        """
        # peel off single-element wrapper lists
        while isinstance(visited_children, list) and len(visited_children) == 1:
            visited_children = visited_children[0]
        if isinstance(visited_children, list):
            if len(visited_children) == 3:
                # parenthesized!
                lparen, ty, rparen = visited_children
                assert lparen.text[0].type == '('
                assert rparen.text[0].type == ')'
                return ty
            else:
                category, name = visited_children
                category = category.text[0].type
                name = name.data
                return BasicType(f"{category} {name}")
        return BasicType(visited_children.text[0].type)

    def visit_ArrayType(self, node, visited_children):
        """Build an ``ArrayType``; the size token's data is parsed as an int."""
        contents, lbracket, size, rbracket = visited_children
        assert lbracket.text[0].type == '['
        assert rbracket.text[0].type == ']'
        # TODO don't explode on nontrivial expression
        while isinstance(size, list):
            size = size[0]
        size = int(size.data)
        return ArrayType(contents, size)

    def visit_PointerType(self, node, visited_children):
        """Build a ``PointerType`` pointing at the visited target type."""
        contents, splat = visited_children
        assert splat.text[0].type == '*'
        return PointerType(contents)

    def visit_constant(self, node, visited_children):
        """Return the token the node wraps."""
        return node.text[0]

    def visit_string_literal(self, node, visited_children):
        """Return the token the node wraps."""
        return node.text[0]

    def visit_identifier(self, node, visited_children):
        """Return the token the node wraps."""
        return node.text[0]

    def generic_visit(self, node, visited_children):
        """Fallback for nodes without a dedicated ``visit_*`` method.

        Returns the node itself when it has no visited children, the sole
        child's result when there is exactly one, and the list of results
        otherwise.
        """
        if not visited_children:
            return node
        if len(visited_children) == 1:
            return visited_children[0]
        return visited_children


def load_declarations(parse_tree, include_dirs):
    """Visit ``parse_tree`` and return the declarations collected from it.

    :param parse_tree: the parse tree to walk.
    :param include_dirs: directories searched when resolving includes.
    :returns: whatever ``DeclarationVisitor.visit`` yields for the tree.
    """
    return DeclarationVisitor(include_dirs).visit(parse_tree)