vex-ast 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vex_ast/__init__.py +65 -0
- vex_ast/ast/__init__.py +75 -0
- vex_ast/ast/core.py +71 -0
- vex_ast/ast/expressions.py +233 -0
- vex_ast/ast/interfaces.py +192 -0
- vex_ast/ast/literals.py +80 -0
- vex_ast/ast/navigator.py +213 -0
- vex_ast/ast/operators.py +136 -0
- vex_ast/ast/statements.py +351 -0
- vex_ast/ast/validators.py +114 -0
- vex_ast/ast/vex_nodes.py +241 -0
- vex_ast/parser/__init__.py +0 -0
- vex_ast/parser/factory.py +179 -0
- vex_ast/parser/interfaces.py +35 -0
- vex_ast/parser/python_parser.py +725 -0
- vex_ast/parser/strategies.py +0 -0
- vex_ast/registry/__init__.py +51 -0
- vex_ast/registry/api.py +155 -0
- vex_ast/registry/categories.py +136 -0
- vex_ast/registry/language_map.py +78 -0
- vex_ast/registry/registry.py +153 -0
- vex_ast/registry/signature.py +143 -0
- vex_ast/registry/simulation_behavior.py +9 -0
- vex_ast/registry/validation.py +44 -0
- vex_ast/serialization/__init__.py +37 -0
- vex_ast/serialization/json_deserializer.py +264 -0
- vex_ast/serialization/json_serializer.py +148 -0
- vex_ast/serialization/schema.py +471 -0
- vex_ast/utils/__init__.py +0 -0
- vex_ast/utils/errors.py +112 -0
- vex_ast/utils/source_location.py +39 -0
- vex_ast/utils/type_definitions.py +0 -0
- vex_ast/visitors/__init__.py +0 -0
- vex_ast/visitors/analyzer.py +103 -0
- vex_ast/visitors/base.py +130 -0
- vex_ast/visitors/printer.py +145 -0
- vex_ast/visitors/transformer.py +0 -0
- vex_ast-0.1.0.dist-info/METADATA +176 -0
- vex_ast-0.1.0.dist-info/RECORD +41 -0
- vex_ast-0.1.0.dist-info/WHEEL +5 -0
- vex_ast-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,351 @@
|
|
1
|
+
"""Statement nodes for the AST."""
|
2
|
+
|
3
|
+
from typing import Dict, List, Optional, Union, cast, Any
|
4
|
+
|
5
|
+
from .interfaces import IAstNode, IExpression, IStatement, IVisitor, T_VisitorResult, IAssignment
|
6
|
+
from .core import Statement, Expression
|
7
|
+
from .expressions import Identifier
|
8
|
+
from ..utils.source_location import SourceLocation
|
9
|
+
|
10
|
+
class ExpressionStatement(Statement):
    """Statement node that wraps a single expression.

    Used wherever an expression appears in statement position.
    """

    _fields = ('expression',)

    def __init__(self, expression: IExpression, location: Optional[SourceLocation] = None):
        super().__init__(location)
        self.expression = expression
        # Back-links are only established for concrete Expression instances.
        if not isinstance(expression, Expression):
            return
        expression.set_parent(self)

    def get_children(self) -> List[IAstNode]:
        """Return a one-element list holding the wrapped expression."""
        children: List[IAstNode] = []
        children.append(cast(IAstNode, self.expression))
        return children

    def accept(self, visitor: IVisitor[T_VisitorResult]) -> T_VisitorResult:
        """Dispatch to ``visitor.visit_expressionstatement``."""
        handler = visitor.visit_expressionstatement
        return handler(self)

    def get_expression(self) -> IExpression:
        """Return the wrapped expression."""
        return self.expression
|
30
|
+
|
31
|
+
class Assignment(Statement, IAssignment):
    """Statement binding ``value`` to ``target`` (source form ``target = value``)."""

    _fields = ('target', 'value')

    def __init__(self, target: IExpression, value: IExpression,
                 location: Optional[SourceLocation] = None):
        super().__init__(location)
        self.target = target
        self.value = value
        # Link both operands back to this statement, target first, but only
        # when they are concrete Expression instances.
        for operand in (target, value):
            if isinstance(operand, Expression):
                operand.set_parent(self)

    def get_children(self) -> List[IAstNode]:
        """Return ``[target, value]`` in source order."""
        return [cast(IAstNode, node) for node in (self.target, self.value)]

    def accept(self, visitor: IVisitor[T_VisitorResult]) -> T_VisitorResult:
        """Dispatch to ``visitor.visit_assignment``."""
        handler = visitor.visit_assignment
        return handler(self)

    def get_target(self) -> IExpression:
        """Return the left-hand side of the assignment."""
        return self.target

    def get_value(self) -> IExpression:
        """Return the right-hand side of the assignment."""
        return self.value
|
59
|
+
|
60
|
+
class Argument(Statement):
    """A single parameter in a function definition.

    Attributes:
        name: The parameter name (a plain string, not a child node).
        annotation: Optional type-annotation expression.
        default: Optional default-value expression.
    """

    _fields = ('name', 'annotation', 'default')

    def __init__(self, name: str, annotation: Optional[IExpression] = None,
                 default: Optional[IExpression] = None,
                 location: Optional[SourceLocation] = None):
        super().__init__(location)
        self.name = name
        self.annotation = annotation
        self.default = default

        # Parent links are only set on concrete Expression instances.
        if isinstance(annotation, Expression):
            annotation.set_parent(self)
        if isinstance(default, Expression):
            default.set_parent(self)

    def get_children(self) -> List[IAstNode]:
        """Return the annotation and default (in that order), omitting absent ones.

        Presence is tested with ``is not None`` rather than truthiness. A node
        object that evaluates as false (e.g. one defining ``__len__`` or
        ``__bool__``) is therefore still reported as a child instead of being
        silently skipped during traversal.
        """
        return [
            cast(IAstNode, child)
            for child in (self.annotation, self.default)
            if child is not None
        ]

    def accept(self, visitor: IVisitor[T_VisitorResult]) -> T_VisitorResult:
        """Dispatch to ``visitor.visit_argument``."""
        return visitor.visit_argument(self)

    def get_name(self) -> str:
        """Get the argument name."""
        return self.name

    def get_annotation(self) -> Optional[IExpression]:
        """Get the type annotation, if any."""
        return self.annotation

    def get_default(self) -> Optional[IExpression]:
        """Get the default value, if any."""
        return self.default
|
100
|
+
|
101
|
+
class FunctionDefinition(Statement):
    """A function definition: name, parameters, body and optional return annotation.

    Attributes:
        name: The function name.
        args: Parameter nodes, in declaration order.
        body: Statements making up the function body.
        return_annotation: Optional return-type annotation expression.
    """

    _fields = ('name', 'args', 'body', 'return_annotation')

    def __init__(self, name: str, args: List[Argument], body: List[IStatement],
                 return_annotation: Optional[IExpression] = None,
                 location: Optional[SourceLocation] = None):
        super().__init__(location)
        self.name = name
        self.args = args
        self.body = body
        self.return_annotation = return_annotation

        # Set parent references (only on concrete Statement / Expression nodes).
        for arg in self.args:
            if isinstance(arg, Statement):
                arg.set_parent(self)

        for stmt in self.body:
            if isinstance(stmt, Statement):
                stmt.set_parent(self)

        if isinstance(return_annotation, Expression):
            return_annotation.set_parent(self)

    def get_children(self) -> List[IAstNode]:
        """Return arguments, then body statements, then the return annotation.

        A fresh list is built on every call. ``cast`` does not copy, so
        extending ``self.args`` directly would permanently append the body
        and annotation to the argument list.
        """
        result: List[IAstNode] = list(cast(List[IAstNode], self.args))
        result.extend(cast(List[IAstNode], self.body))
        if self.return_annotation is not None:
            result.append(cast(IAstNode, self.return_annotation))
        return result

    def accept(self, visitor: IVisitor[T_VisitorResult]) -> T_VisitorResult:
        """Dispatch to ``visitor.visit_functiondefinition``."""
        return visitor.visit_functiondefinition(self)

    def get_name(self) -> str:
        """Get the function name."""
        return self.name

    def get_arguments(self) -> List[Argument]:
        """Get the function arguments."""
        return self.args

    def get_body(self) -> List[IStatement]:
        """Get the function body."""
        return self.body

    def get_return_annotation(self) -> Optional[IExpression]:
        """Get the return type annotation, if any."""
        return self.return_annotation

    def add_statement(self, statement: IStatement) -> None:
        """Append *statement* to the body and link it back to this node."""
        self.body.append(statement)
        if isinstance(statement, Statement):
            statement.set_parent(self)
|
158
|
+
|
159
|
+
class IfStatement(Statement):
    """Conditional statement ``if test: body`` with an optional alternative branch.

    ``orelse`` is ``None``, a list of statements, or another ``IfStatement``.
    The nested-``IfStatement`` form presumably represents ``elif`` chains;
    confirm against the parser.
    """

    _fields = ('test', 'body', 'orelse')

    def __init__(self, test: IExpression, body: List[IStatement],
                 orelse: Optional[Union[List[IStatement], 'IfStatement']] = None,
                 location: Optional[SourceLocation] = None):
        super().__init__(location)
        self.test = test
        self.body = body
        self.orelse = orelse

        if isinstance(test, Expression):
            test.set_parent(self)

        # Adopt the body statements followed by any statements in a list-form
        # else branch. A single-node else branch is handled afterwards.
        else_branch = self.orelse
        else_stmts = else_branch if isinstance(else_branch, list) else ()
        for child in (*self.body, *else_stmts):
            if isinstance(child, Statement):
                child.set_parent(self)
        if not isinstance(else_branch, list) and isinstance(else_branch, Statement):
            else_branch.set_parent(self)

    def get_children(self) -> List[IAstNode]:
        """Return the test, the body statements, then the else branch contents."""
        children: List[IAstNode] = [
            cast(IAstNode, self.test),
            *cast(List[IAstNode], self.body),
        ]
        else_branch = self.orelse
        if isinstance(else_branch, list):
            children += cast(List[IAstNode], else_branch)
        elif else_branch:
            children.append(cast(IAstNode, else_branch))
        return children

    def accept(self, visitor: IVisitor[T_VisitorResult]) -> T_VisitorResult:
        """Dispatch to ``visitor.visit_ifstatement``."""
        handler = visitor.visit_ifstatement
        return handler(self)

    def get_test(self) -> IExpression:
        """Return the condition expression."""
        return self.test

    def get_body(self) -> List[IStatement]:
        """Return the statements executed when the condition holds."""
        return self.body

    def get_else(self) -> Optional[Union[List[IStatement], 'IfStatement']]:
        """Return the alternative branch, or ``None`` when absent."""
        return self.orelse

    def add_statement(self, statement: IStatement) -> None:
        """Append *statement* to the ``if`` body and link it back to this node."""
        self.body.append(statement)
        if not isinstance(statement, Statement):
            return
        statement.set_parent(self)
|
216
|
+
|
217
|
+
class WhileLoop(Statement):
    """Loop node: ``while test: body``."""

    _fields = ('test', 'body')

    def __init__(self, test: IExpression, body: List[IStatement],
                 location: Optional[SourceLocation] = None):
        super().__init__(location)
        self.test = test
        self.body = body

        if isinstance(self.test, Expression):
            self.test.set_parent(self)

        # Only concrete Statement nodes in the body receive a parent link.
        adoptable = [node for node in self.body if isinstance(node, Statement)]
        for node in adoptable:
            node.set_parent(self)

    def get_children(self) -> List[IAstNode]:
        """Return the loop condition followed by every body statement."""
        return [cast(IAstNode, self.test)] + list(cast(List[IAstNode], self.body))

    def accept(self, visitor: IVisitor[T_VisitorResult]) -> T_VisitorResult:
        """Dispatch to ``visitor.visit_whileloop``."""
        handler = visitor.visit_whileloop
        return handler(self)

    def get_test(self) -> IExpression:
        """Return the loop condition."""
        return self.test

    def get_body(self) -> List[IStatement]:
        """Return the loop body statements."""
        return self.body

    def add_statement(self, statement: IStatement) -> None:
        """Append *statement* to the loop body and link it back to this node."""
        self.body.append(statement)
        if not isinstance(statement, Statement):
            return
        statement.set_parent(self)
|
257
|
+
|
258
|
+
class ForLoop(Statement):
    """Loop node: ``for target in iterable: body``."""

    _fields = ('target', 'iterable', 'body')

    def __init__(self, target: IExpression, iterable: IExpression,
                 body: List[IStatement], location: Optional[SourceLocation] = None):
        super().__init__(location)
        self.target = target
        self.iterable = iterable
        self.body = body

        # Link the loop header expressions (target first), then body statements.
        for header_expr in (target, iterable):
            if isinstance(header_expr, Expression):
                header_expr.set_parent(self)

        for node in self.body:
            if isinstance(node, Statement):
                node.set_parent(self)

    def get_children(self) -> List[IAstNode]:
        """Return target, iterable, then the body statements, in that order."""
        return [
            cast(IAstNode, self.target),
            cast(IAstNode, self.iterable),
            *cast(List[IAstNode], self.body),
        ]

    def accept(self, visitor: IVisitor[T_VisitorResult]) -> T_VisitorResult:
        """Dispatch to ``visitor.visit_forloop``."""
        handler = visitor.visit_forloop
        return handler(self)

    def get_target(self) -> IExpression:
        """Return the loop variable expression."""
        return self.target

    def get_iterable(self) -> IExpression:
        """Return the expression being iterated over."""
        return self.iterable

    def get_body(self) -> List[IStatement]:
        """Return the loop body statements."""
        return self.body

    def add_statement(self, statement: IStatement) -> None:
        """Append *statement* to the loop body and link it back to this node."""
        self.body.append(statement)
        if not isinstance(statement, Statement):
            return
        statement.set_parent(self)
|
308
|
+
|
309
|
+
class ReturnStatement(Statement):
    """A ``return`` statement with an optional value expression."""

    _fields = ('value',)

    def __init__(self, value: Optional[IExpression] = None,
                 location: Optional[SourceLocation] = None):
        super().__init__(location)
        self.value = value
        # Parent link only for concrete Expression instances.
        if isinstance(value, Expression):
            value.set_parent(self)

    def get_children(self) -> List[IAstNode]:
        """Return ``[value]`` when a value is present, else an empty list.

        Presence is tested with ``is not None`` so that a value node which
        happens to evaluate as false is still reported as a child.
        """
        if self.value is None:
            return []
        return [cast(IAstNode, self.value)]

    def accept(self, visitor: IVisitor[T_VisitorResult]) -> T_VisitorResult:
        """Dispatch to ``visitor.visit_returnstatement``."""
        return visitor.visit_returnstatement(self)

    def get_value(self) -> Optional[IExpression]:
        """Get the return value, if any."""
        return self.value
|
330
|
+
|
331
|
+
class BreakStatement(Statement):
    """AST node for a ``break`` statement; it carries no child nodes."""

    _fields = ()

    def get_children(self) -> List[IAstNode]:
        """Return an empty list: this node has no sub-nodes."""
        no_children: List[IAstNode] = []
        return no_children

    def accept(self, visitor: IVisitor[T_VisitorResult]) -> T_VisitorResult:
        """Dispatch to ``visitor.visit_breakstatement``."""
        dispatch = visitor.visit_breakstatement
        return dispatch(self)
|
341
|
+
|
342
|
+
class ContinueStatement(Statement):
    """AST node for a ``continue`` statement; it carries no child nodes."""

    _fields = ()

    def get_children(self) -> List[IAstNode]:
        """Return an empty list: this node has no sub-nodes."""
        no_children: List[IAstNode] = []
        return no_children

    def accept(self, visitor: IVisitor[T_VisitorResult]) -> T_VisitorResult:
        """Dispatch to ``visitor.visit_continuestatement``."""
        dispatch = visitor.visit_continuestatement
        return dispatch(self)
|
@@ -0,0 +1,114 @@
|
|
1
|
+
"""AST validators that use the function registry."""
|
2
|
+
|
3
|
+
from typing import List, Dict, Set, Optional, Tuple
|
4
|
+
from .core import Program
|
5
|
+
from .expressions import FunctionCall, AttributeAccess, Identifier
|
6
|
+
from .vex_nodes import VexAPICall
|
7
|
+
from ..visitors.base import AstVisitor
|
8
|
+
from ..registry.api import registry_api
|
9
|
+
|
10
|
+
class VexFunctionValidator(AstVisitor[List[Tuple[VexAPICall, str]]]):
    """Walk an AST and collect ``(call_node, message)`` pairs for invalid VEX calls.

    Every ``visit_*`` method returns the shared ``self.errors`` list, so the
    value returned from visiting the root holds every error found.
    """

    def __init__(self):
        self.errors: List[Tuple[VexAPICall, str]] = []

    def _visit_children(self, node) -> None:
        """Recurse into every node returned by ``node.get_children()``."""
        for sub in node.get_children():
            self.visit(sub)

    def _record_if_invalid(self, call: VexAPICall) -> None:
        """Run ``call.validate()`` and store the message when validation fails."""
        valid, message = call.validate()
        if not valid and message:
            self.errors.append((call, message))

    def generic_visit(self, node):
        """Descend into a node that needs no VEX-specific checking."""
        self._visit_children(node)
        return self.errors

    def visit_vexapicall(self, node: VexAPICall):
        """Validate an explicit VEX API call node, then its sub-nodes."""
        self._record_if_invalid(node)
        # Arguments may themselves contain further VEX calls.
        self._visit_children(node)
        return self.errors

    visit_program = generic_visit
    visit_expression = generic_visit
    visit_statement = generic_visit
    visit_identifier = generic_visit
    visit_variablereference = generic_visit
    visit_attributeaccess = generic_visit
    visit_binaryoperation = generic_visit
    visit_unaryoperation = generic_visit
    visit_keywordargument = generic_visit
    visit_numberliteral = generic_visit
    visit_stringliteral = generic_visit
    visit_booleanliteral = generic_visit
    visit_noneliteral = generic_visit
    visit_expressionstatement = generic_visit
    visit_assignment = generic_visit
    visit_ifstatement = generic_visit
    visit_whileloop = generic_visit
    visit_forloop = generic_visit
    visit_functiondefinition = generic_visit
    visit_argument = generic_visit
    visit_returnstatement = generic_visit
    visit_breakstatement = generic_visit
    visit_continuestatement = generic_visit
    visit_motorcontrol = visit_vexapicall
    visit_sensorreading = visit_vexapicall
    visit_timingcontrol = visit_vexapicall
    visit_displayoutput = visit_vexapicall

    def visit_functioncall(self, node: FunctionCall):
        """Treat a generic call as a VEX call when its name is in the registry."""
        callee = node.function
        qualified_name = None
        if isinstance(callee, Identifier):
            qualified_name = callee.name
        elif isinstance(callee, AttributeAccess) and isinstance(callee.object, Identifier):
            # Method call such as ``motor.spin()``.
            qualified_name = f"{callee.object.name}.{callee.attribute}"

        if qualified_name:
            # The receiver's type is unknown at this point, so for ``obj.method``
            # only the method name (text after the first dot) is looked up.
            _, dot, method = qualified_name.partition('.')
            lookup_key = method if dot else qualified_name
            if registry_api.get_function(lookup_key):
                # NOTE(review): this temporary VexAPICall is built from the
                # call's own child nodes. If VexAPICall's constructor re-parents
                # its children, the original nodes' parent links change. Confirm.
                self._record_if_invalid(
                    VexAPICall(node.function, node.args, node.keywords, node.location)
                )

        self._visit_children(node)
        return self.errors
|
110
|
+
|
111
|
+
def validate_vex_functions(ast: Program) -> List[Tuple[VexAPICall, str]]:
    """Run a fresh ``VexFunctionValidator`` over *ast* and return its error list.

    Each entry pairs the offending call node with its validation message.
    """
    return VexFunctionValidator().visit(ast)
|