1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
|
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional, Tuple, Union
from parsimonious import NodeVisitor # type: ignore
from .scanner import scan
from .parser import parse_header
@dataclass
class Type:
    """Common base of the type representations (BasicType, PointerType, ArrayType)."""
@dataclass
class BasicType(Type):
    """A type named by a single string.

    The visitor builds it either from one token's type or from a
    ``"<category> <name>"`` pair of tokens.
    """

    name: str
@dataclass
class PointerType(Type):
    """A pointer to another type."""

    target: Type  # the pointed-to type
@dataclass
class ArrayType(Type):
    """A fixed-size array of some element type."""

    contents: Type  # element type
    size: int  # number of elements
@dataclass
class Declaration:
    """Base of every named declaration produced by the visitor."""

    name: str
@dataclass
class VariableDeclaration(Declaration):
    """A variable: its type plus an optional initial value.

    ``value`` is None when the variable was declared without an initializer.
    """

    type: Type
    value: Optional[str]
@dataclass
class Declarations:
    """A container holding a list of declarations."""

    included: List[Declaration]
@dataclass
class StructDeclaration(Declaration):
    """A struct type.

    ``fields`` lists the member variables, or is None for an opaque struct
    whose members are not given.
    """

    fields: Optional[List[VariableDeclaration]]
@dataclass
class EnumDeclaration(Declaration):
    """An enum type as an ordered list of members.

    Each entry pairs a member name with its explicit value, or with None
    when the member has no explicit value.
    """

    values: List[Tuple[str, Optional[int]]]
@dataclass
class UnionDeclaration(Declaration):
    """A union type.

    ``tag`` is the discriminating variable of a tagged union, or None.
    ``cases`` is either a plain list of member variables or a list of
    ``(case_label, member_or_None)`` pairs.
    """

    tag: Optional[VariableDeclaration]
    cases: Union[
        List[VariableDeclaration],
        List[Tuple[str, Optional[VariableDeclaration]]],
    ]
@dataclass
class FunctionDeclaration(Declaration):
    """A function signature: return type plus parameter list."""

    return_type: Type
    args: List[VariableDeclaration]  # parameters, in declaration order
# noinspection PyPep8Naming,PyMethodMayBeStatic,PyUnusedLocal
class DeclarationVisitor(NodeVisitor):
    """Build ``Declaration`` objects from a parsimonious parse tree.

    ``NodeVisitor.visit`` sends each node to ``visit_<rule name>`` when such
    a method exists and to :meth:`generic_visit` otherwise. Every
    ``visit_*`` method receives the already-visited children of its node,
    in grammar order.
    """

    def __init__(self, include_folders):
        """Create a visitor.

        :param include_folders: directories searched, in order, when
            resolving an include statement.
        """
        # NOTE(review): ``data`` is not read or written anywhere in this class.
        self.data = []
        self.include_folders = include_folders

    @staticmethod
    def _repeated(node, visited):
        """Return the visited results of a repetition node as a list.

        :meth:`generic_visit` collapses its result. A repetition with no
        matches comes back as the bare node. A repetition with exactly one
        match comes back as that match's value instead of a one-element
        list. Counting ``node.children`` recovers the real number of matches.

        :param node: the repetition node (e.g. ``Rule*``) in the parse tree.
        :param visited: the value the visitor produced for ``node``.
        :return: a list with one entry per match.
        """
        matches = len(node.children)
        if matches == 0:
            return []
        if matches == 1:
            return [visited]
        return list(visited)

    def visit_HeaderFile(self, node, visited_children):
        """Return the list of top-level declarations in a header file.

        Included headers are still resolved and visited, so a missing header
        raises. Their declarations are not added to the result.
        """
        includes, elements = visited_children
        # NOTE(review): declarations from headers included by this header are
        # discarded, so includes are not transitive. Confirm this is intended.
        return self._repeated(node.children[1], elements)

    def visit_ImplementationFile(self, node, visited_children):
        """Return all declarations visible in an implementation file.

        The result holds the declarations of every directly included header,
        followed by the file's own top-level declarations. ``None`` entries
        are dropped.
        """
        includes, elements = visited_children
        headers = self._repeated(node.children[0], includes)
        elements = self._repeated(node.children[1], elements)
        # Each include yields a list of declarations; flatten them in order.
        included = [decl for header_decls in headers for decl in header_decls]
        return [decl for decl in included + elements if decl is not None]

    def visit_IncludeStatement(self, node, visited_children):
        """Resolve ``include "<header>";`` and return the header's declarations.

        Folders in ``include_folders`` are tried in order. The first folder
        containing the header wins; that header is scanned, parsed and
        visited.

        :raises FileNotFoundError: if no include folder contains the header.
        """
        include, included_header, semicolon = visited_children
        assert include.text[0].type == 'include'
        assert included_header.type == 'string_literal'
        included_header = included_header.data.strip('"')
        assert semicolon.text[0].type == ';'
        for include_folder in self.include_folders:
            header = Path(include_folder) / included_header
            # is_file() rather than exists(): a directory with the header's
            # name must not shadow a real header in a later folder.
            if header.is_file():
                header_text = header.read_text(encoding='utf-8')
                # NOTE(review): there is no include guard. A header that
                # (indirectly) includes itself recurses until RecursionError.
                header_parse_tree = parse_header(scan(header_text))
                return self.visit(header_parse_tree)
        raise FileNotFoundError(included_header)

    def visit_NormalStructDefinition(self, node, visited_children):
        """Return a ``StructDeclaration`` for ``struct Name { fields }``."""
        struct, name, lbrace, fields, rbrace = visited_children
        assert struct.text[0].type == 'struct'
        assert lbrace.text[0].type == '{'
        assert rbrace.text[0].type == '}'
        fields = self._repeated(node.children[3], fields)
        return StructDeclaration(name.data, fields)

    def visit_OpaqueStructDefinition(self, node, visited_children):
        """Return a ``StructDeclaration`` with ``fields=None`` for
        ``opaque struct Name;``."""
        opaque, struct, name, semi = visited_children
        assert opaque.text[0].type == 'opaque'
        assert struct.text[0].type == 'struct'
        assert semi.text[0].type == ';'
        return StructDeclaration(name.data, None)

    def visit_EnumDefinition(self, node, visited_children):
        """Return an ``EnumDeclaration`` for ``enum Name { member, ... }``.

        The trailing-comma slot is ignored.
        """
        enum, name, lbrace, first_member, extra_members, trailing_comma, rbrace = visited_children
        assert enum.text[0].type == 'enum'
        assert lbrace.text[0].type == '{'
        assert rbrace.text[0].type == '}'
        # Each extra member is visited as a [comma, member] pair.
        extra_members = self._repeated(node.children[4], extra_members)
        values = [first_member] + [member for _, member in extra_members]
        return EnumDeclaration(name.data, values)

    def visit_EnumMember(self, node, visited_children):
        """Return ``(name, value)`` for one enum member.

        ``value`` is None when the ``= value`` part is absent.
        """
        name, equals_value = visited_children
        name = name.data
        if not isinstance(equals_value, list):
            # The optional "= value" part matched nothing.
            return name, None
        _, value = equals_value
        # NOTE(review): ``value`` is whatever the visitor produced for the
        # expression (for a constant, the token from visit_constant). It is
        # not converted to the ``int`` that EnumDeclaration.values' annotation
        # suggests.
        return name, value

    def visit_RobustUnionDefinition(self, node, visited_children):
        """Return a tagged ``UnionDeclaration``.

        :raises NameError: if the declared tag's name differs from the
            identifier used in the ``switch (...)``.
        """
        union, name, lbrace, tag, body, rbrace = visited_children
        assert union.text[0].type == 'union'
        assert lbrace.text[0].type == '{'
        assert rbrace.text[0].type == '}'
        # visit_UnionBody always returns (switch identifier, list of cases).
        expected_tagname, cases = body
        if tag.name != expected_tagname:
            raise NameError(f"tag {tag} does not match switch argument {expected_tagname}")
        return UnionDeclaration(name.data, tag, cases)

    def visit_UnionBody(self, node, visited_children):
        """Return ``(switch_identifier, [(case, variable_or_None), ...])``."""
        switch, lparen, tag, rparen, lbrace, body, rbrace = visited_children
        assert switch.text[0].type == 'switch'
        assert lparen.text[0].type == '('
        assert rparen.text[0].type == ')'
        assert lbrace.text[0].type == '{'
        assert rbrace.text[0].type == '}'
        return tag.data, self._repeated(node.children[5], body)

    def visit_UnionBodySet(self, node, visited_children):
        """Return ``(case, variable)`` for one switch arm.

        ``variable`` is None when the arm's second child did not visit to a
        VariableDeclaration.
        """
        cases, var = visited_children
        if isinstance(cases, list):
            # NOTE(review): presumably several stacked case labels; only the
            # first one is kept.
            cases = cases[0]
        if isinstance(var, VariableDeclaration):
            return cases, var
        else:
            return cases, None

    def visit_CaseSpecifier(self, node, visited_children):
        """Return the token data of the expression in a ``case <expr>:`` label."""
        while isinstance(visited_children, list) and len(visited_children) == 1:
            visited_children = visited_children[0]
        # TODO don't explode on 'default:'
        case, expr, colon = visited_children
        while isinstance(expr, list):
            expr = expr[0]
        # TODO don't explode on nontrivial expression
        return expr.data

    def visit_FragileUnionDefinition(self, node, visited_children):
        """Return an untagged ``UnionDeclaration`` for
        ``fragile union Name { members }``."""
        fragile, union, name, lbrace, body, rbrace = visited_children
        assert fragile.text[0].type == 'fragile'
        assert union.text[0].type == 'union'
        assert lbrace.text[0].type == '{'
        assert rbrace.text[0].type == '}'
        # Wrap single/empty member lists the same way structs and tagged unions do.
        cases = self._repeated(node.children[4], body)
        return UnionDeclaration(name.data, None, cases)

    def visit_FunctionDeclaration(self, node, visited_children):
        """Return the ``FunctionDeclaration`` for a body-less ``signature;``."""
        signature, semi = visited_children
        assert semi.text[0].type == ';'
        return signature

    def visit_VariableDefinition(self, node, visited_children):
        """Return a ``VariableDeclaration`` for ``Type name = value;``."""
        var_type, name, eq, value, semi = visited_children
        assert eq.text[0].type == '='
        assert semi.text[0].type == ';'
        # NOTE(review): ``value`` is stored exactly as visited (presumably a
        # token object for a constant), not as the ``str`` that
        # VariableDeclaration.value's annotation suggests.
        return VariableDeclaration(name.data, var_type, value)

    def visit_VariableDeclaration(self, node, visited_children):
        """Return a ``VariableDeclaration`` without a value for ``Type name;``."""
        var_type, name, semi = visited_children
        assert semi.text[0].type == ';'
        return VariableDeclaration(name.data, var_type, None)

    def visit_FunctionDefinition(self, node, visited_children):
        """Return the signature's ``FunctionDeclaration``; the body is discarded."""
        signature, body = visited_children
        return signature

    def visit_FunctionSignature(self, node, visited_children):
        """Return a ``FunctionDeclaration`` for ``ReturnType name(args)``."""
        return_type, name, lparen, args, rparen = visited_children
        assert name.type == 'identifier'
        name = name.data
        assert lparen.text[0].type == '('
        assert rparen.text[0].type == ')'
        # NOTE(review): ``args`` is passed through exactly as visited. Because
        # generic_visit collapses results, it is only a list when the argument
        # node has several children. Confirm against the grammar.
        return FunctionDeclaration(name, return_type, args)

    def visit_BasicType(self, node, visited_children):
        """Return the ``Type`` for a basic type expression.

        After unwrapping single-element lists, three shapes are handled:

        - ``( Type )`` returns the inner type.
        - A ``<category> <name>`` pair becomes
          ``BasicType("<category token type> <name>")``.
        - A single token becomes ``BasicType(<token type>)``.
        """
        while isinstance(visited_children, list) and len(visited_children) == 1:
            visited_children = visited_children[0]
        if isinstance(visited_children, list):
            if len(visited_children) == 3:
                # parenthesized!
                lparen, ty, rparen = visited_children
                assert lparen.text[0].type == '('
                assert rparen.text[0].type == ')'
                return ty
            else:
                category, name = visited_children
                category = category.text[0].type
                name = name.data
                return BasicType(f"{category} {name}")
        return BasicType(visited_children.text[0].type)

    def visit_ArrayType(self, node, visited_children):
        """Return an ``ArrayType``.

        The size expression must reduce to a single token whose data
        ``int()`` can parse.
        """
        contents, lbracket, size, rbracket = visited_children
        assert lbracket.text[0].type == '['
        assert rbracket.text[0].type == ']'
        # TODO don't explode on nontrivial expression
        while isinstance(size, list):
            size = size[0]
        size = int(size.data)
        return ArrayType(contents, size)

    def visit_PointerType(self, node, visited_children):
        """Return a ``PointerType`` for ``Type *``."""
        contents, splat = visited_children
        assert splat.text[0].type == '*'
        return PointerType(contents)

    def visit_constant(self, node, visited_children):
        """Return the first token of the matched text."""
        return node.text[0]

    def visit_string_literal(self, node, visited_children):
        """Return the first token of the matched text."""
        return node.text[0]

    def visit_identifier(self, node, visited_children):
        """Return the first token of the matched text."""
        return node.text[0]

    def generic_visit(self, node, visited_children):
        """Default visit for rules without a dedicated method.

        Returns the node itself when it has no children, and the sole
        child's result when it has exactly one. Otherwise returns the list
        of the children's results. Repetition results go through
        :meth:`_repeated` to undo this collapsing.
        """
        if not visited_children:
            return node
        if len(visited_children) == 1:
            return visited_children[0]
        return visited_children
def load_declarations(parse_tree, include_dirs):
    """Visit *parse_tree* with a fresh DeclarationVisitor and return the result.

    :param parse_tree: parse tree to walk.
    :param include_dirs: directories searched, in order, for included headers.
    """
    return DeclarationVisitor(include_dirs).visit(parse_tree)
|