fluent-codegen 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fluent_codegen/__init__.py +1 -0
- fluent_codegen/ast_compat.py +93 -0
- fluent_codegen/codegen.py +1653 -0
- fluent_codegen/py.typed +0 -0
- fluent_codegen/utils.py +27 -0
- fluent_codegen-0.1.0.dist-info/METADATA +124 -0
- fluent_codegen-0.1.0.dist-info/RECORD +10 -0
- fluent_codegen-0.1.0.dist-info/WHEEL +5 -0
- fluent_codegen-0.1.0.dist-info/licenses/LICENSE +13 -0
- fluent_codegen-0.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,1653 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Utilities for doing Python code generation
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
import builtins
|
|
8
|
+
import enum
|
|
9
|
+
import keyword
|
|
10
|
+
import re
|
|
11
|
+
from abc import ABC, abstractmethod
|
|
12
|
+
from collections.abc import Callable, Sequence
|
|
13
|
+
from dataclasses import dataclass
|
|
14
|
+
from typing import ClassVar, Protocol, assert_never, overload, runtime_checkable
|
|
15
|
+
|
|
16
|
+
from . import ast_compat as py_ast
|
|
17
|
+
from .ast_compat import (
|
|
18
|
+
DEFAULT_AST_ARGS,
|
|
19
|
+
DEFAULT_AST_ARGS_ADD,
|
|
20
|
+
DEFAULT_AST_ARGS_ARGUMENTS,
|
|
21
|
+
DEFAULT_AST_ARGS_MODULE,
|
|
22
|
+
)
|
|
23
|
+
from .utils import allowable_keyword_arg_name, allowable_name
|
|
24
|
+
|
|
25
|
+
# This module provides simple utilities for building up Python source code.
|
|
26
|
+
# The design originally came from fluent-compiler, so had the following aims
|
|
27
|
+
# and constraints:
|
|
28
|
+
#
|
|
29
|
+
# 1. Performance.
|
|
30
|
+
#
|
|
31
|
+
# The resulting Python code should do as little as possible, especially for
|
|
32
|
+
# simple cases.
|
|
33
|
+
#
|
|
34
|
+
# 2. Correctness (obviously)
|
|
35
|
+
#
|
|
36
|
+
# In particular, we should try to make it hard to generate code that is
|
|
37
|
+
# syntactically correct and therefore compiles but doesn't work. We try to
|
|
38
|
+
# make it hard to generate accidental name clashes, or use variables that are
|
|
39
|
+
# not defined.
|
|
40
|
+
#
|
|
41
|
+
# Correctness also has a security implication, since the result of this code
|
|
42
|
+
# might be 'exec'ed. To that end:
|
|
43
|
+
# * We build up AST, rather than strings. This eliminates many
|
|
44
|
+
# potential bugs caused by wrong escaping/interpolation.
|
|
45
|
+
# * the `as_ast()` methods are paranoid about input, and do many asserts.
|
|
46
|
+
# We do this even though other layers will usually have checked the
|
|
47
|
+
# input, to allow us to reason locally when checking these methods. These
|
|
48
|
+
# asserts must also have 100% code coverage.
|
|
49
|
+
#
|
|
50
|
+
# 3. Simplicity
|
|
51
|
+
#
|
|
52
|
+
# The resulting Python code should be easy to read and understand.
|
|
53
|
+
#
|
|
54
|
+
# 4. Predictability
|
|
55
|
+
#
|
|
56
|
+
# Since we want to test the resulting source code, we have made some design
|
|
57
|
+
# decisions that aim to ensure things like function argument names are
|
|
58
|
+
# consistent and so can be predicted easily.
|
|
59
|
+
|
|
60
|
+
# Outside of fluent-compiler, this code will likely be useful for situations
|
|
61
|
+
# which have similar aims.
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
PROPERTY_TYPE: str = "PROPERTY_TYPE"
PROPERTY_RETURN_TYPE: str = "PROPERTY_RETURN_TYPE"

# UNKNOWN_TYPE is nothing more than `object`, given a clearer name for use sites.
# Code relying on it needs it to be a real `type` -- and the most general one --
# so that is checked once, at import time.
UNKNOWN_TYPE: type = object
assert isinstance(UNKNOWN_TYPE, type)
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
# Builtin functions that generated code should never call. This is a
# defense-in-depth mechanism to stop code generation from becoming a code
# execution vulnerability; higher-level code should also ensure we never
# generate calls to arbitrary Python functions. The list is not exhaustive:
# it names functions we definitely don't need and that are the most likely
# to be used to execute remote code or to get around safety mechanisms.
SENSITIVE_FUNCTIONS = {
    "__build_class__",
    "__import__",
    "apply",
    "compile",
    "eval",
    "exec",
    "execfile",
    "exit",
    "file",
    "globals",
    "locals",
    "object",
    "open",
    "reload",
    "type",
}
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
class CodeGenAst(ABC):
|
|
100
|
+
"""
|
|
101
|
+
Base class representing a simplified Python AST (not the real one).
|
|
102
|
+
Generates real `ast.*` nodes via `as_ast()` method.
|
|
103
|
+
"""
|
|
104
|
+
|
|
105
|
+
@abstractmethod
|
|
106
|
+
def as_ast(self) -> py_ast.AST: ...
|
|
107
|
+
|
|
108
|
+
child_elements: ClassVar[list[str]]
|
|
109
|
+
|
|
110
|
+
def as_python_source(self) -> str:
|
|
111
|
+
"""Return the Python source code for this AST node."""
|
|
112
|
+
node = self.as_ast()
|
|
113
|
+
py_ast.fix_missing_locations(node)
|
|
114
|
+
return py_ast.unparse(node)
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
class CodeGenAstList(ABC):
|
|
118
|
+
"""
|
|
119
|
+
Alternative base class to CodeGenAst when we have code that wants to return a
|
|
120
|
+
list of AST objects. These must also be `stmt` objects.
|
|
121
|
+
"""
|
|
122
|
+
|
|
123
|
+
@abstractmethod
|
|
124
|
+
def as_ast_list(self, allow_empty: bool = True) -> list[py_ast.stmt]: ...
|
|
125
|
+
|
|
126
|
+
child_elements: ClassVar[list[str]]
|
|
127
|
+
|
|
128
|
+
def as_python_source(self) -> str:
|
|
129
|
+
"""Return the Python source code for this AST list."""
|
|
130
|
+
mod = py_ast.Module(body=self.as_ast_list(), type_ignores=[], **DEFAULT_AST_ARGS_MODULE)
|
|
131
|
+
py_ast.fix_missing_locations(mod)
|
|
132
|
+
return py_ast.unparse(mod)
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
CodeGenAstType = CodeGenAst | CodeGenAstList
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
class Scope:
    """
    A namespace of reserved identifiers, optionally chained to an enclosing scope.

    Lookups (``is_name_in_use`` and friends) consult this scope first and then
    walk outward through ``parent_scope``.
    """

    def __init__(self, parent_scope: Scope | None = None):
        self.parent_scope = parent_scope
        # Names reserved for use in this scope.
        self.names: set[str] = set()
        # Names set aside for later use *only* as function arguments.
        self._function_arg_reserved_names: set[str] = set()
        # Names for which an assignment has been registered in this scope.
        self._assignments: set[str] = set()

    def is_name_in_use(self, name: str) -> bool:
        """Return True if `name` is reserved here or in any enclosing scope."""
        if name in self.names:
            return True
        parent = self.parent_scope
        return parent is not None and parent.is_name_in_use(name)

    def is_name_reserved_function_arg(self, name: str) -> bool:
        """Return True if `name` is set aside as a function argument here or in any enclosing scope."""
        if name in self._function_arg_reserved_names:
            return True
        parent = self.parent_scope
        return parent is not None and parent.is_name_reserved_function_arg(name)

    def is_name_reserved(self, name: str) -> bool:
        """Return True if `name` is unavailable for any new use."""
        if self.is_name_in_use(name):
            return True
        return self.is_name_reserved_function_arg(name)

    def reserve_name(
        self,
        requested: str,
        function_arg: bool = False,
        is_builtin: bool = False,
    ):
        """
        Reserve a name as being in use in this scope and return the name actually used.

        With ``function_arg=True`` the exact requested name is used when it was
        previously set aside via ``reserve_function_arg_name``; otherwise it must
        not already be reserved (AssertionError). In all other cases the name is
        sanitised with ``cleanup_name`` and, if that is taken (or is a keyword,
        unless ``is_builtin`` is set), suffixed with 2, 3, ... until it is free.
        Parent scopes are taken into account so local names do not shadow
        outer ones.
        """
        if function_arg:
            if self.is_name_reserved_function_arg(requested):
                assert not self.is_name_in_use(requested)
                self.names.add(requested)
                return requested
            if self.is_name_reserved(requested):
                raise AssertionError(f"Cannot use '{requested}' as argument name as it is already in use")

        def _is_name_allowed(candidate: str) -> bool:
            # Keywords ('class', 'def', ...) always count as taken -- except when
            # reserving builtins, because some builtins ('None', 'True', ...) are
            # themselves keywords.
            if not is_builtin and keyword.iskeyword(candidate):
                return False
            return not self.is_name_reserved(candidate)

        base = cleanup_name(requested)
        candidate = base
        suffix = 2  # the unsuffixed name counts as instance number 1
        while not _is_name_allowed(candidate):
            candidate = f"{base}{suffix}"
            suffix += 1

        self.names.add(candidate)
        return candidate

    def reserve_function_arg_name(self, name: str):
        """
        Set `name` aside for *later* use as a function argument.

        The name is not marked as in use in this scope, but it can no longer be
        handed out for any purpose other than a function argument. Names are
        kept exact (no renaming) so generated signatures stay predictable.
        Raises AssertionError if the name is already reserved.
        """
        if self.is_name_reserved(name):
            raise AssertionError(f"Can't reserve '{name}' as function arg name as it is already reserved")
        self._function_arg_reserved_names.add(name)

    def has_assignment(self, name: str) -> bool:
        """Return True if an assignment to `name` was registered in *this* scope."""
        return name in self._assignments

    def register_assignment(self, name: str) -> None:
        """Record that `name` has been assigned in this scope."""
        self._assignments.add(name)

    def create_name(self, name: str) -> Name:
        """Reserve `name` (possibly renamed) and return a Name bound to this scope."""
        return Name(self.reserve_name(name), self)

    def name(self, name: str) -> Name:
        """Convenience: wrap an existing `name` in a Name for this scope, without reserving it."""
        return Name(name, self)
|
|
238
|
+
|
|
239
|
+
|
|
240
|
+
_IDENTIFIER_SANITIZER_RE = re.compile("[^a-zA-Z0-9_]")
_IDENTIFIER_START_RE = re.compile("^[a-zA-Z_]")


def cleanup_name(name: str) -> str:
    """
    Turn an arbitrary string into an ASCII-only Python identifier.

    Every character outside ``[a-zA-Z0-9_]`` is dropped. If what remains does
    not start with a letter or underscore (including when it is empty), it is
    prefixed with ``"n"``. Keywords are *not* handled here.

    See https://docs.python.org/3/reference/lexical_analysis.html#identifiers
    """
    sanitized = _IDENTIFIER_SANITIZER_RE.sub("", name)
    if _IDENTIFIER_START_RE.match(sanitized):
        return sanitized
    return "n" + sanitized
|
|
253
|
+
|
|
254
|
+
|
|
255
|
+
class Statement(CodeGenAst):
    """Base class for nodes whose ``as_ast()`` yields an ``ast.stmt``."""
|
|
257
|
+
|
|
258
|
+
|
|
259
|
+
@runtime_checkable
class SupportsNameAssignment(Protocol):
    """
    Structural type for objects that can say whether they assign a given name.

    Because it is ``runtime_checkable``, ``isinstance`` only verifies that a
    ``has_assignment_for_name`` attribute exists, not its signature.
    """

    def has_assignment_for_name(self, name: str) -> bool:
        """Return True if this object assigns `name`."""
        ...
|
|
262
|
+
|
|
263
|
+
|
|
264
|
+
class _Annotation(Statement):
    """A bare type annotation with no value, e.g. ``x: int``."""

    child_elements = []

    def __init__(self, name: str, annotation: Expression):
        self.name = name
        self.annotation = annotation

    def as_ast(self):
        """Build an ``ast.AnnAssign`` with ``value=None``; raises AssertionError for an invalid name."""
        target_name = self.name
        if not allowable_name(target_name):
            raise AssertionError(f"Expected {target_name} to be a valid Python identifier")
        target = py_ast.Name(id=target_name, ctx=py_ast.Store(), **DEFAULT_AST_ARGS)
        return py_ast.AnnAssign(
            target=target,
            annotation=self.annotation.as_ast(),
            value=None,
            simple=1,
            **DEFAULT_AST_ARGS,
        )

    def has_assignment_for_name(self, name: str) -> bool:
        """Return True if this annotation declares `name`."""
        return name == self.name
|
|
286
|
+
|
|
287
|
+
|
|
288
|
+
class _Assignment(Statement):
    """Assignment to a single name: ``name = value``, or ``name: hint = value`` when a hint is given."""

    child_elements = ["value"]

    def __init__(self, name: str, value: Expression, /, *, type_hint: Expression | None = None):
        self.name = name
        self.value = value
        self.type_hint = type_hint

    def as_ast(self):
        """Build an ``ast.Assign`` (no hint) or ``ast.AnnAssign`` (with hint)."""
        if not allowable_name(self.name):
            raise AssertionError(f"Expected {self.name} to be a valid Python identifier")
        target = py_ast.Name(id=self.name, ctx=py_ast.Store(), **DEFAULT_AST_ARGS)
        if self.type_hint is not None:
            return py_ast.AnnAssign(
                target=target,
                annotation=self.type_hint.as_ast(),
                # simple=1 marks the target as a plain, unparenthesized Name
                # (as opposed to an attribute/subscript or parenthesized target).
                simple=1,
                value=self.value.as_ast(),
                **DEFAULT_AST_ARGS,
            )
        return py_ast.Assign(
            targets=[target],
            value=self.value.as_ast(),
            **DEFAULT_AST_ARGS,
        )

    def has_assignment_for_name(self, name: str) -> bool:
        """Return True if this statement assigns `name`."""
        return name == self.name
|
|
316
|
+
|
|
317
|
+
|
|
318
|
+
class Block(CodeGenAstList):
    """
    An ordered sequence of statements sharing a Scope, e.g. a function or class body.

    Blocks can be nested: a child Block added via ``add_statement`` records this
    block as its ``parent_block``, and its statements are flattened into ours
    when rendering.
    """

    child_elements = ["statements"]

    def __init__(self, scope: Scope, parent_block: Block | None = None):
        self.scope = scope
        # We allow `Expression` here for things like MethodCall which
        # are bare expressions that are still useful for side effects
        self.statements: list[Block | Statement | Expression] = []
        self.parent_block = parent_block

    def as_ast_list(self, allow_empty: bool = True) -> list[py_ast.stmt]:
        """
        Render the statements as a flat list of ``ast.stmt`` nodes.

        Nested blocks are flattened in place and bare expressions are wrapped in
        ``ast.Expr``. If the result would be empty and ``allow_empty`` is False,
        a single ``pass`` is returned instead (function/class bodies need this).

        Raises AssertionError if a Statement renders to something other than an
        ``ast.stmt``.
        """
        retval: list[py_ast.stmt] = []
        for s in self.statements:
            if isinstance(s, CodeGenAstList):
                retval.extend(s.as_ast_list(allow_empty=True))
            elif isinstance(s, Statement):
                ast_obj = s.as_ast()
                # Explicit raise rather than `assert`, so the check survives
                # `python -O`, like the other paranoid checks in this module.
                if not isinstance(ast_obj, py_ast.stmt):
                    raise AssertionError(f"Statement object returned {ast_obj} which is not a subclass of py_ast.stmt")
                retval.append(ast_obj)
            else:
                # Things like bare function/method calls need to be wrapped
                # in `Expr` to match the way Python parses.
                retval.append(py_ast.Expr(s.as_ast(), **DEFAULT_AST_ARGS))

        if not retval and not allow_empty:
            return [py_ast.Pass(**DEFAULT_AST_ARGS)]
        return retval

    def add_statement(self, statement: Statement | Block | Expression) -> None:
        """
        Append a statement to this block.

        A Block with no parent is adopted by this block. A Block that already
        belongs to a different parent is rejected with AssertionError. The check
        happens before appending, so a rejected block is not left in
        ``statements``.
        """
        if isinstance(statement, Block):
            if statement.parent_block is None:
                statement.parent_block = self
            elif statement.parent_block is not self:
                raise AssertionError(
                    f"Block {statement} is already child of {statement.parent_block}, can't reassign to {self}"
                )
        self.statements.append(statement)

    def create_import(self, module: str, as_: str | None = None) -> tuple[Import, Name]:
        """
        Add ``import module`` (or ``import module as as_``) to the block.

        Returns the Import statement and the Name that the import binds: `as_`
        if given, otherwise the first component of the dotted module path.
        Raises AssertionError for disallowed names or an `as_` already in use.
        """
        return_name_object: Name
        if as_ is not None:
            # "import foo as bar" results in `bar` name being assigned.
            if not allowable_name(as_):
                raise AssertionError(f"{as_!r} is not an allowable 'as' name")
            if self.scope.is_name_in_use(as_):
                raise AssertionError(f"{as_!r} is already assigned in the scope")
            as_name_object = self.scope.create_name(as_)
            return_name_object = as_name_object
        else:
            as_name_object = None
            # "import foo" results in `foo` name being assigned
            # "import foo.bar" also results in `foo` being reserved.
            dotted_parts = module.split(".")
            for part in dotted_parts:
                if not allowable_name(part):
                    raise AssertionError(f"{module!r} not an allowable 'import' name")
            name_to_assign = dotted_parts[0]
            # We can't rename, so don't use `reserve_name` or `create_name`.
            # We also need to allow for multiple imports, like `import foo.bar` then `import foo.baz`
            if not self.scope.is_name_in_use(name_to_assign):
                self.scope.reserve_name(name_to_assign)
            return_name_object = self.scope.name(name_to_assign)

        import_statement = Import(module=module, as_=as_name_object)
        self.add_statement(import_statement)
        return import_statement, return_name_object

    def create_import_from(self, *, from_: str, import_: str, as_: str | None = None) -> tuple[ImportFrom, Name]:
        """
        Add ``from from_ import import_`` (optionally ``as as_``) to the block.

        Returns the ImportFrom statement and the Name the import binds. Raises
        AssertionError for disallowed names or for a bound name already in use.
        """
        return_name_object: Name
        if as_ is not None:
            # "from foo import bar as baz" results in `baz` name being assigned.
            if not allowable_name(as_):
                raise AssertionError(f"{as_!r} is not an allowable 'as' name")
            if self.scope.is_name_in_use(as_):
                raise AssertionError(f"{as_!r} is already assigned in the scope")
            as_name_object = self.scope.create_name(as_)
            return_name_object = as_name_object
        else:
            as_name_object = None
            # Check the dotted bit.
            dotted_parts = from_.split(".")
            for part in dotted_parts:
                if not allowable_name(part):
                    raise AssertionError(f"{from_!r} not an allowable 'import' name")

            # Check the `import_` for clashes.
            name_to_assign = import_
            if self.scope.is_name_in_use(name_to_assign):
                raise AssertionError(f"{name_to_assign!r} is already assigned in the scope")
            return_name_object = self.scope.create_name(name_to_assign)

        import_statement = ImportFrom(from_module=from_, import_=import_, as_=as_name_object)
        self.add_statement(import_statement)
        return import_statement, return_name_object

    # Safe alternatives to Block.statements being manipulated directly:
    def create_assignment(
        self, name: str | Name, value: Expression, *, type_hint: Expression | None = None, allow_multiple: bool = False
    ):
        """
        Adds an assignment of the form:

            x = value

        (or ``x: type_hint = value``). The name must already be reserved in the
        scope. A second assignment to the same name in this scope raises
        AssertionError unless ``allow_multiple`` is True.
        """
        if isinstance(name, Name):
            name = name.name
        if not self.scope.is_name_in_use(name):
            raise AssertionError(f"Cannot assign to unreserved name '{name}'")

        if self.scope.has_assignment(name):
            if not allow_multiple:
                raise AssertionError(f"Have already assigned to '{name}' in this scope")
        else:
            self.scope.register_assignment(name)

        self.add_statement(_Assignment(name, value, type_hint=type_hint))

    def create_annotation(self, name: str, annotation: Expression) -> Name:
        """
        Adds a bare type annotation of the form::

            x: int

        Reserves the name and adds the annotation statement to the block.
        """
        name_obj = self.scope.create_name(name)
        self.scope.register_assignment(name_obj.name)
        self.add_statement(_Annotation(name_obj.name, annotation))
        return name_obj

    def create_field(self, name: str, annotation: Expression, *, default: Expression | None = None) -> Name:
        """
        Create a typed field, typically used in dataclass bodies.

        If *default* is provided, creates an annotated assignment::

            x: int = 0

        Otherwise, creates a bare annotation::

            x: int
        """
        if default is not None:
            name_obj = self.scope.create_name(name)
            self.scope.register_assignment(name_obj.name)
            self.add_statement(_Assignment(name_obj.name, default, type_hint=annotation))
            return name_obj
        else:
            return self.create_annotation(name, annotation)

    def create_function(
        self,
        name: str,
        args: Sequence[str | FunctionArg],
        decorators: Sequence[Expression] | None = None,
        return_type: Expression | None = None,
    ) -> tuple[Function, Name]:
        """
        Reserve a name for a function, create the Function and add the function statement
        to the block.
        """
        name_obj = self.scope.create_name(name)
        func = Function(
            name_obj.name, args=args, parent_scope=self.scope, decorators=decorators, return_type=return_type
        )
        self.add_statement(func)
        return func, name_obj

    def create_class(
        self,
        name: str,
        bases: Sequence[Expression] | None = None,
        decorators: Sequence[Expression] | None = None,
    ) -> tuple[Class, Name]:
        """
        Reserve a name for a class, create the Class and add the class statement
        to the block.
        """
        name_obj = self.scope.create_name(name)
        cls = Class(name_obj.name, parent_scope=self.scope, bases=bases, decorators=decorators)
        self.add_statement(cls)
        return cls, name_obj

    def create_return(self, value: Expression) -> None:
        """Add ``return value`` to the block."""
        self.add_statement(Return(value))

    def create_if(self) -> If:
        """
        Create an If statement, add it to this block, and return it.

        Usage::

            if_stmt = block.create_if()
            if_block = if_stmt.add_if(condition)
            if_block.create_return(value)
        """
        if_statement = If(self.scope, parent_block=self)
        self.add_statement(if_statement)
        return if_statement

    def create_with(self, context_expr: Expression, target: Name | None = None) -> With:
        """
        Create a With statement, add it to this block, and return it.

        ``target`` is a ``Name`` object (not a str), e.g. one obtained from
        ``block.scope.create_name("f")``.

        Usage::

            f_name = block.scope.create_name("f")
            with_stmt = block.create_with(expr, f_name)
            with_stmt.body.create_return(value)
        """
        with_statement = With(context_expr, target=target, parent_scope=self.scope, parent_block=self)
        self.add_statement(with_statement)
        return with_statement

    def has_assignment_for_name(self, name: str) -> bool:
        """
        Return True if `name` is assigned by a statement in this block, in a
        Block nested inside it, or in any enclosing (parent) block.
        """
        block: Block | None = self
        while block is not None:
            if block._has_local_assignment_for_name(name):
                return True
            block = block.parent_block
        return False

    def _has_local_assignment_for_name(self, name: str) -> bool:
        """Check this block's own statements (descending into nested Blocks) without walking up to parents."""
        for s in self.statements:
            if isinstance(s, Block):
                # A nested Block's `parent_block` points back at us, so calling
                # its public `has_assignment_for_name` would walk straight back
                # here and recurse forever when the name is not found.
                if s._has_local_assignment_for_name(name):
                    return True
            elif isinstance(s, SupportsNameAssignment) and s.has_assignment_for_name(name):
                # NOTE(review): statements that delegate to their own child
                # blocks' public `has_assignment_for_name` (e.g. If, if it does
                # so) could still walk back up here -- confirm.
                return True
        return False
|
|
546
|
+
|
|
547
|
+
|
|
548
|
+
class Module(Block, CodeGenAst):
    """
    A complete Python module: a top-level Block with a root Scope, plus comment
    lines emitted at the top of the rendered source.
    """

    def __init__(self, reserve_builtins: bool = True):
        scope = Scope(parent_scope=None)
        if reserve_builtins:
            # Reserve every builtin name so newly created names never shadow them.
            for name in dir(builtins):
                scope.reserve_name(name, is_builtin=True)
        Block.__init__(self, scope)
        self.file_comments: list[str] = []

    def as_ast(self) -> py_ast.Module:
        return py_ast.Module(body=self.as_ast_list(), type_ignores=[], **DEFAULT_AST_ARGS_MODULE)

    def as_python_source(self) -> str:
        """Render the module source, preceded by the file comments."""
        main = super().as_python_source()
        # Comments are not part of the AST, so they are prepended as text. Each
        # physical line must begin with '#': a comment containing a line break
        # would otherwise let the text after it escape into executable code.
        # `or [""]` keeps an empty comment rendering as a bare "# " line.
        file_comments = "".join(
            f"# {line}\n" for comment in self.file_comments for line in (comment.splitlines() or [""])
        )
        return file_comments + main

    def add_file_comment(self, comment: str) -> None:
        """Queue a comment to appear at the top of the generated source (multi-line text is allowed)."""
        self.file_comments.append(comment)
|
|
567
|
+
|
|
568
|
+
|
|
569
|
+
class ArgKind(enum.Enum):
|
|
570
|
+
"""The kind of a function argument."""
|
|
571
|
+
|
|
572
|
+
POSITIONAL_ONLY = "positional_only"
|
|
573
|
+
POSITIONAL_OR_KEYWORD = "positional_or_keyword"
|
|
574
|
+
KEYWORD_ONLY = "keyword_only"
|
|
575
|
+
|
|
576
|
+
|
|
577
|
+
@dataclass(frozen=True)
|
|
578
|
+
class FunctionArg:
|
|
579
|
+
"""A function argument with a name, kind, and optional default value."""
|
|
580
|
+
|
|
581
|
+
name: str
|
|
582
|
+
kind: ArgKind = ArgKind.POSITIONAL_OR_KEYWORD
|
|
583
|
+
default: Expression | None = None
|
|
584
|
+
annotation: Expression | None = None
|
|
585
|
+
|
|
586
|
+
@classmethod
|
|
587
|
+
def positional(
|
|
588
|
+
cls, name: str, *, default: Expression | None = None, annotation: Expression | None = None
|
|
589
|
+
) -> FunctionArg:
|
|
590
|
+
"""Create a positional-only argument."""
|
|
591
|
+
return cls(name=name, kind=ArgKind.POSITIONAL_ONLY, default=default, annotation=annotation)
|
|
592
|
+
|
|
593
|
+
@classmethod
|
|
594
|
+
def keyword(
|
|
595
|
+
cls, name: str, *, default: Expression | None = None, annotation: Expression | None = None
|
|
596
|
+
) -> FunctionArg:
|
|
597
|
+
"""Create a keyword-only argument."""
|
|
598
|
+
return cls(name=name, kind=ArgKind.KEYWORD_ONLY, default=default, annotation=annotation)
|
|
599
|
+
|
|
600
|
+
@classmethod
|
|
601
|
+
def standard(
|
|
602
|
+
cls, name: str, *, default: Expression | None = None, annotation: Expression | None = None
|
|
603
|
+
) -> FunctionArg:
|
|
604
|
+
"""Create a positional-or-keyword argument (the Python default)."""
|
|
605
|
+
return cls(name=name, kind=ArgKind.POSITIONAL_OR_KEYWORD, default=default, annotation=annotation)
|
|
606
|
+
|
|
607
|
+
|
|
608
|
+
def _normalize_args(args: Sequence[str | FunctionArg]) -> list[FunctionArg]:
|
|
609
|
+
"""Normalize a mixed list of str and FunctionArg into a list of FunctionArg."""
|
|
610
|
+
return [FunctionArg(name=a) if isinstance(a, str) else a for a in args]
|
|
611
|
+
|
|
612
|
+
|
|
613
|
+
def _validate_arg_order(args: list[FunctionArg]) -> None:
|
|
614
|
+
"""Validate that args are in the correct order:
|
|
615
|
+
positional-only, then positional-or-keyword, then keyword-only.
|
|
616
|
+
Within each group, defaults must come after non-defaults.
|
|
617
|
+
"""
|
|
618
|
+
# Check kind ordering
|
|
619
|
+
KIND_ORDER = {
|
|
620
|
+
ArgKind.POSITIONAL_ONLY: 0,
|
|
621
|
+
ArgKind.POSITIONAL_OR_KEYWORD: 1,
|
|
622
|
+
ArgKind.KEYWORD_ONLY: 2,
|
|
623
|
+
}
|
|
624
|
+
prev_order = -1
|
|
625
|
+
for arg in args:
|
|
626
|
+
order = KIND_ORDER[arg.kind]
|
|
627
|
+
if order < prev_order:
|
|
628
|
+
raise ValueError(
|
|
629
|
+
f"Argument '{arg.name}' of kind {arg.kind.value} "
|
|
630
|
+
f"is out of order: positional-only args must come first, "
|
|
631
|
+
f"then positional-or-keyword, then keyword-only"
|
|
632
|
+
)
|
|
633
|
+
prev_order = order
|
|
634
|
+
|
|
635
|
+
# Check default ordering within positional groups
|
|
636
|
+
# (positional-only and positional-or-keyword share defaults list,
|
|
637
|
+
# so non-default can't follow default across these groups)
|
|
638
|
+
seen_default_in_positional = False
|
|
639
|
+
for arg in args:
|
|
640
|
+
if arg.kind in (ArgKind.POSITIONAL_ONLY, ArgKind.POSITIONAL_OR_KEYWORD):
|
|
641
|
+
if arg.default is not None:
|
|
642
|
+
seen_default_in_positional = True
|
|
643
|
+
elif seen_default_in_positional:
|
|
644
|
+
raise ValueError(f"Non-default argument '{arg.name}' follows default argument in positional arguments")
|
|
645
|
+
|
|
646
|
+
# keyword-only args can have defaults in any order (Python allows it)
|
|
647
|
+
|
|
648
|
+
|
|
649
|
+
class Function(Scope, Statement):
    """
    A ``def`` statement. The function is itself a Scope (holding its argument
    and local names) and owns a body Block.
    """

    child_elements = ["body"]

    def __init__(
        self,
        name: str,
        args: Sequence[str | FunctionArg] | None = None,
        parent_scope: Scope | None = None,
        decorators: Sequence[Expression] | None = None,
        return_type: Expression | None = None,
    ):
        super().__init__(parent_scope=parent_scope)
        self.body = Block(self)
        self.func_name = name
        self.decorators: list[Expression] = list(decorators) if decorators else []
        self.return_type: Expression | None = return_type
        self._args: list[FunctionArg] = []
        if args is not None:
            self.add_args(args)

    def add_args(self, args: Sequence[str | FunctionArg]) -> None:
        """
        Append arguments to the signature (plain strings become positional-or-keyword args).

        Raises:
            ValueError: if the combined argument list is in an invalid order.
            AssertionError: if a new name is already in use here or in an
                enclosing scope, or is repeated within `args`.

        All checks run before any name is reserved, so on failure the scope and
        the signature are left unchanged.
        """
        normalized = _normalize_args(args)
        combined = self._args + normalized
        _validate_arg_order(combined)

        # Validate every name up front: reserving as we go would leave earlier
        # args reserved in the scope but missing from `self._args` if a later
        # one failed.
        pending: set[str] = set()
        for arg in normalized:
            if arg.name in pending or self.is_name_in_use(arg.name):
                raise AssertionError(f"Can't use '{arg.name}' as function argument name because it shadows other names")
            pending.add(arg.name)

        for arg in normalized:
            self.reserve_name(arg.name, function_arg=True)
        self._args = combined

    def as_ast(self) -> py_ast.stmt:
        """Build the ``ast.FunctionDef``; raises AssertionError for invalid function or argument names."""
        if not allowable_name(self.func_name):
            raise AssertionError(f"Expected '{self.func_name}' to be a valid Python identifier")
        for arg in self._args:
            if not allowable_name(arg.name):
                raise AssertionError(f"Expected '{arg.name}' to be a valid Python identifier")

        def _make_arg(a: FunctionArg) -> py_ast.arg:
            return py_ast.arg(
                arg=a.name,
                annotation=a.annotation.as_ast() if a.annotation is not None else None,
                **DEFAULT_AST_ARGS,
            )

        posonlyargs = [_make_arg(a) for a in self._args if a.kind == ArgKind.POSITIONAL_ONLY]
        regular_args = [_make_arg(a) for a in self._args if a.kind == ArgKind.POSITIONAL_OR_KEYWORD]
        kwonlyargs = [_make_arg(a) for a in self._args if a.kind == ArgKind.KEYWORD_ONLY]

        # defaults: right-aligned to posonlyargs + regular_args
        positional_all = [a for a in self._args if a.kind in (ArgKind.POSITIONAL_ONLY, ArgKind.POSITIONAL_OR_KEYWORD)]
        defaults = [a.default.as_ast() for a in positional_all if a.default is not None]

        # kw_defaults: one entry per kwonlyarg, None if no default
        kw_defaults: list[py_ast.expr | None] = [
            a.default.as_ast() if a.default is not None else None for a in self._args if a.kind == ArgKind.KEYWORD_ONLY
        ]

        return py_ast.FunctionDef(
            name=self.func_name,
            args=py_ast.arguments(
                posonlyargs=posonlyargs,
                args=regular_args,
                vararg=None,
                kwonlyargs=kwonlyargs,
                kw_defaults=kw_defaults,
                kwarg=None,
                defaults=defaults,
                **DEFAULT_AST_ARGS_ARGUMENTS,
            ),
            body=self.body.as_ast_list(allow_empty=False),
            decorator_list=[d.as_ast() for d in self.decorators],
            type_params=[],
            returns=self.return_type.as_ast() if self.return_type is not None else None,
            **DEFAULT_AST_ARGS,
        )

    def create_return(self, value: Expression):
        """Add ``return value`` to the function body."""
        self.body.create_return(value)
|
|
728
|
+
|
|
729
|
+
|
|
730
|
+
class Class(Scope, Statement):
    """A ``class`` statement: its own Scope plus a body Block, bases and decorators."""

    child_elements = ["body"]

    def __init__(
        self,
        name: str,
        parent_scope: Scope | None = None,
        bases: Sequence[Expression] | None = None,
        decorators: Sequence[Expression] | None = None,
    ):
        super().__init__(parent_scope=parent_scope)
        self.body = Block(self)
        self.class_name = name
        self.bases: list[Expression] = [] if not bases else list(bases)
        self.decorators: list[Expression] = [] if not decorators else list(decorators)

    def as_ast(self) -> py_ast.stmt:
        """Build the ``ast.ClassDef``; an empty body renders as ``pass``."""
        if not allowable_name(self.class_name):
            raise AssertionError(f"Expected '{self.class_name}' to be a valid Python identifier")
        base_nodes = [base.as_ast() for base in self.bases]
        body_nodes = self.body.as_ast_list(allow_empty=False)
        decorator_nodes = [decorator.as_ast() for decorator in self.decorators]
        return py_ast.ClassDef(
            name=self.class_name,
            bases=base_nodes,
            keywords=[],
            body=body_nodes,
            decorator_list=decorator_nodes,
            type_params=[],
            **DEFAULT_AST_ARGS,
        )
|
|
758
|
+
|
|
759
|
+
|
|
760
|
+
class Return(Statement):
    """
    A ``return <value>`` statement.
    """

    child_elements = ["value"]

    def __init__(self, value: Expression):
        self.value = value

    def as_ast(self) -> py_ast.Return:
        return py_ast.Return(self.value.as_ast(), **DEFAULT_AST_ARGS)

    def __repr__(self):
        # Balanced parentheses, matching the other node reprs (e.g. "Name('x')").
        return f"Return({self.value!r})"
|
|
771
|
+
|
|
772
|
+
|
|
773
|
+
class If(Statement):
    """
    A compound ``if`` / ``elif`` ... / ``else`` statement.

    Internally this is modelled as parallel lists: one condition and one
    Block per ``if``/``elif`` branch, plus a single trailing ``else`` Block.
    Python's AST instead nests every ``elif`` as an ``If`` node inside the
    ``orelse`` of the previous one, so `as_ast` performs that conversion.
    """

    child_elements = ["if_blocks", "conditions", "else_block"]

    def __init__(self, parent_scope: Scope, parent_block: Block | None = None):
        self.if_blocks: list[Block] = []
        self.conditions: list[Expression] = []
        self._parent_block = parent_block
        self.else_block = Block(parent_scope, parent_block=self._parent_block)
        self._parent_scope = parent_scope

    def create_if_branch(self, condition: Expression) -> Block:
        """
        Add a branch guarded by `condition` and return its new, empty Block.

        The first call creates the ``if`` branch and later calls create
        ``elif`` branches.
        """
        branch = Block(self._parent_scope, parent_block=self._parent_block)
        self.if_blocks.append(branch)
        self.conditions.append(condition)
        return branch

    def finalize(self) -> Block | Statement:
        """
        Return the simplest equivalent node.

        With no conditional branches only the ``else`` part remains, so the
        bare else Block is returned. Otherwise this If is returned unchanged.
        """
        return self if self.if_blocks else self.else_block

    def as_ast(self) -> py_ast.If:
        """
        Render as a chain of ``ast.If`` nodes.

        Raises AssertionError if there are no branches; call `finalize` first
        in that case.
        """
        if len(self.if_blocks) == 0:
            raise AssertionError("Should have called `finalize` on If")
        root = empty_If()
        node = root
        prev: py_ast.If | None = None
        for condition, block in zip(self.conditions, self.if_blocks):
            node.test = condition.as_ast()
            node.body = block.as_ast_list()
            if prev is not None:
                # Each later branch becomes an `elif`: an If nested in the previous orelse.
                prev.orelse.append(node)
            prev, node = node, empty_If()

        if self.else_block.statements:
            assert prev is not None
            prev.orelse = self.else_block.as_ast_list()

        return root
|
|
824
|
+
|
|
825
|
+
|
|
826
|
+
class With(Statement):
    """
    A single-item ``with <context_expr> [as <target>]:`` statement.

    The body is a fresh Block created in `parent_scope`.
    """

    child_elements = ["context_expr", "body"]

    def __init__(
        self,
        context_expr: Expression,
        target: Name | None = None,
        *,
        parent_scope: Scope,
        parent_block: Block | None = None,
    ):
        self.context_expr = context_expr
        self.target = target
        self._parent_scope = parent_scope
        self._parent_block = parent_block
        self.body = Block(parent_scope, parent_block=parent_block)

    def as_ast(self) -> py_ast.With:
        # The `as` target is a store-context Name, or absent entirely.
        target = self.target
        optional_vars = (
            None if target is None else py_ast.Name(id=target.name, ctx=py_ast.Store(), **DEFAULT_AST_ARGS)
        )
        item = py_ast.withitem(
            context_expr=self.context_expr.as_ast(),
            optional_vars=optional_vars,
        )
        return py_ast.With(
            items=[item],
            body=self.body.as_ast_list(allow_empty=False),
            **DEFAULT_AST_ARGS,
        )
|
|
858
|
+
|
|
859
|
+
|
|
860
|
+
class Try(Statement):
    """
    A ``try`` / ``except`` / ``else`` statement.

    There is a single ``except`` clause catching every expression in
    `catch_exceptions` (a tuple when there is more than one), and no
    ``finally`` clause.
    """

    child_elements = ["catch_exceptions", "try_block", "except_block", "else_block"]

    def __init__(self, catch_exceptions: Sequence[Expression], parent_scope: Scope):
        self.catch_exceptions = catch_exceptions
        self.try_block, self.except_block, self.else_block = (
            Block(parent_scope),
            Block(parent_scope),
            Block(parent_scope),
        )

    def as_ast(self) -> py_ast.Try:
        try_body = self.try_block.as_ast_list(allow_empty=False)
        if len(self.catch_exceptions) == 1:
            exc_type = self.catch_exceptions[0].as_ast()
        else:
            exc_type = py_ast.Tuple(
                elts=[exc.as_ast() for exc in self.catch_exceptions],
                ctx=py_ast.Load(),
                **DEFAULT_AST_ARGS,
            )
        handler = py_ast.ExceptHandler(
            type=exc_type,
            name=None,
            body=self.except_block.as_ast_list(allow_empty=False),
            **DEFAULT_AST_ARGS,
        )
        return py_ast.Try(
            body=try_body,
            handlers=[handler],
            orelse=self.else_block.as_ast_list(allow_empty=True),
            finalbody=[],
            **DEFAULT_AST_ARGS,
        )

    def has_assignment_for_name(self, name: str) -> bool:
        """
        Return True only if `name` is assigned in the except block AND in the
        try block or the else block.

        The except block must assign it itself because the try block may
        raise before reaching its own assignment.
        """
        assigned_without_error = self.try_block.has_assignment_for_name(
            name
        ) or self.else_block.has_assignment_for_name(name)
        return bool(assigned_without_error and self.except_block.has_assignment_for_name(name))
|
|
899
|
+
|
|
900
|
+
|
|
901
|
+
class Import(Statement):
    """
    Simple import statements, supporting:
    - import foo
    - import foo as bar
    - import foo.bar
    - import foo.bar as baz

    Use via `Block.create_import`

    We deliberately don't support multiple imports - these should
    be cleaned up later using a linter on the generated code, if desired.
    """

    # Leaf statement with nothing to descend into. Declared explicitly, like the
    # other nodes, because rewriting_traverse reads `child_elements` on every
    # node it visits.
    child_elements = []

    def __init__(self, module: str, as_: Name | None) -> None:
        self.module = module
        self.as_ = as_

    def as_ast(self) -> py_ast.AST:
        # Only the alias differs between the plain and the `as` form.
        if self.as_ is None:
            alias = py_ast.alias(name=self.module)
        else:
            alias = py_ast.alias(name=self.module, asname=self.as_.name)
        return py_ast.Import(names=[alias], **DEFAULT_AST_ARGS)
|
|
925
|
+
|
|
926
|
+
|
|
927
|
+
class ImportFrom(Statement):
    """
    ``from ... import ...`` statement, supporting:
    - from foo import bar
    - from foo import bar as baz

    Use via `Block.create_import_from`

    Only absolute imports (``level=0``) are generated. We deliberately don't
    support multiple imports - these should be cleaned up later using a
    linter on the generated code, if desired.
    """

    # Leaf statement with nothing to descend into. Declared explicitly, like the
    # other nodes, because rewriting_traverse reads `child_elements` on every
    # node it visits.
    child_elements = []

    def __init__(self, from_module: str, import_: str, as_: Name | None) -> None:
        self.from_module = from_module
        self.import_ = import_
        self.as_ = as_

    def as_ast(self) -> py_ast.AST:
        # Only the alias differs between the plain and the `as` form.
        if self.as_ is None:
            alias = py_ast.alias(name=self.import_)
        else:
            alias = py_ast.alias(name=self.import_, asname=self.as_.name)
        return py_ast.ImportFrom(
            module=self.from_module,
            names=[alias],
            level=0,
            **DEFAULT_AST_ARGS,
        )
|
|
960
|
+
|
|
961
|
+
|
|
962
|
+
class Expression(CodeGenAst):
    """
    Base class for codegen nodes that render to a Python expression.

    Besides `as_ast`, it offers small factory helpers so expressions can be
    composed fluently, e.g. ``obj.attr("x").call([arg]).eq(other)``. Each
    helper just builds and returns the corresponding node, with `self` as
    the receiver or left operand.
    """

    @abstractmethod
    def as_ast(self) -> py_ast.expr: ...

    # --- access and calls ---

    def attr(self, attribute: str, /) -> Attr:
        return Attr(value=self, attribute=attribute)

    def call(
        self,
        args: Sequence[Expression] | None = None,
        kwargs: dict[str, Expression] | None = None,
    ) -> Call:
        # Missing or empty argument containers become fresh [] / {}.
        return Call(value=self, args=args or [], kwargs=kwargs or {})

    def method_call(
        self,
        attribute: str,
        args: Sequence[Expression] | None = None,
        kwargs: dict[str, Expression] | None = None,
    ) -> Call:
        bound = self.attr(attribute)
        return bound.call(args, kwargs)

    # --- arithmetic (ast.BinOp) ---

    def add(self, other: Expression, /) -> Add:
        return Add(left=self, right=other)

    def sub(self, other: Expression, /) -> Sub:
        return Sub(left=self, right=other)

    def mul(self, other: Expression, /) -> Mul:
        return Mul(left=self, right=other)

    def div(self, other: Expression, /) -> Div:
        return Div(left=self, right=other)

    def floordiv(self, other: Expression, /) -> FloorDiv:
        return FloorDiv(left=self, right=other)

    def mod(self, other: Expression, /) -> Mod:
        return Mod(left=self, right=other)

    def pow(self, other: Expression, /) -> Pow:
        return Pow(left=self, right=other)

    def matmul(self, other: Expression, /) -> MatMul:
        return MatMul(left=self, right=other)

    # --- comparisons (ast.Compare) ---

    def eq(self, other: Expression, /) -> Equals:
        return Equals(left=self, right=other)

    def ne(self, other: Expression, /) -> NotEquals:
        return NotEquals(left=self, right=other)

    def lt(self, other: Expression, /) -> Lt:
        return Lt(left=self, right=other)

    def gt(self, other: Expression, /) -> Gt:
        return Gt(left=self, right=other)

    def le(self, other: Expression, /) -> LtE:
        return LtE(left=self, right=other)

    def ge(self, other: Expression, /) -> GtE:
        return GtE(left=self, right=other)

    # --- boolean (ast.BoolOp) ---

    def and_(self, other: Expression, /) -> And:
        return And(left=self, right=other)

    def or_(self, other: Expression, /) -> Or:
        return Or(left=self, right=other)

    # --- membership (ast.Compare with In / NotIn) ---

    def in_(self, other: Expression, /) -> In:
        return In(left=self, right=other)

    def not_in(self, other: Expression, /) -> NotIn:
        return NotIn(left=self, right=other)

    # --- unpacking ---

    def starred(self) -> Starred:
        return Starred(value=self)
|
|
1052
|
+
|
|
1053
|
+
|
|
1054
|
+
class String(Expression):
    """
    A ``str`` literal. Two String nodes compare equal when their values match.
    """

    child_elements = []

    type = str

    def __init__(self, string_value: str):
        self.string_value = string_value

    def as_ast(self) -> py_ast.expr:
        # kind=None indicates a literal with no string prefix; needed only for tests.
        return py_ast.Constant(self.string_value, kind=None, **DEFAULT_AST_ARGS)

    def __repr__(self):
        return f"String({self.string_value!r})"

    def __eq__(self, other: object):
        if not isinstance(other, String):
            return False
        return other.string_value == self.string_value
|
|
1074
|
+
|
|
1075
|
+
|
|
1076
|
+
class Bool(Expression):
    """
    A ``True`` / ``False`` literal.
    """

    child_elements = []

    type = bool

    def __init__(self, value: bool):
        self.value = value

    def as_ast(self) -> py_ast.expr:
        return py_ast.Constant(value=self.value, **DEFAULT_AST_ARGS)

    def __repr__(self):
        return f"Bool({repr(self.value)})"
|
|
1089
|
+
|
|
1090
|
+
|
|
1091
|
+
class Bytes(Expression):
    """
    A ``bytes`` literal.
    """

    child_elements = []

    type = bytes

    def __init__(self, value: bytes):
        self.value = value

    def as_ast(self) -> py_ast.expr:
        return py_ast.Constant(value=self.value, **DEFAULT_AST_ARGS)

    def __repr__(self):
        return f"Bytes({repr(self.value)})"
|
|
1104
|
+
|
|
1105
|
+
|
|
1106
|
+
class Number(Expression):
    """
    An ``int`` or ``float`` literal.
    """

    child_elements = []

    def __init__(self, number: int | float):
        self.number = number

    def as_ast(self) -> py_ast.expr:
        return py_ast.Constant(value=self.number, **DEFAULT_AST_ARGS)

    def __repr__(self):
        return f"Number({self.number!r})"
|
|
1117
|
+
|
|
1118
|
+
|
|
1119
|
+
class List(Expression):
    """
    A list display ``[item, ...]``.
    """

    child_elements = ["items"]

    def __init__(self, items: list[Expression]):
        self.items = items

    def as_ast(self) -> py_ast.expr:
        elements = [item.as_ast() for item in self.items]
        return py_ast.List(elts=elements, ctx=py_ast.Load(), **DEFAULT_AST_ARGS)
|
|
1127
|
+
|
|
1128
|
+
|
|
1129
|
+
class Tuple(Expression):
    """
    A tuple display ``(item, ...)``.
    """

    child_elements = ["items"]

    def __init__(self, items: Sequence[Expression]):
        self.items = items

    def as_ast(self) -> py_ast.expr:
        elements = [item.as_ast() for item in self.items]
        return py_ast.Tuple(elts=elements, ctx=py_ast.Load(), **DEFAULT_AST_ARGS)
|
|
1137
|
+
|
|
1138
|
+
|
|
1139
|
+
class Set(Expression):
    """
    A set display ``{item, ...}``.

    An empty set is rendered as ``set([])``.
    """

    child_elements = ["items"]

    def __init__(self, items: Sequence[Expression]):
        self.items = items

    def as_ast(self) -> py_ast.expr:
        if len(self.items) > 0:
            return py_ast.Set(elts=[item.as_ast() for item in self.items], **DEFAULT_AST_ARGS)
        # `{}` is an empty dict literal, so an empty set needs a call to `set`
        # instead; this emits `set([])`.
        set_builtin = py_ast.Name(id="set", ctx=py_ast.Load(), **DEFAULT_AST_ARGS)
        empty_list = py_ast.List(elts=[], ctx=py_ast.Load(), **DEFAULT_AST_ARGS)
        return py_ast.Call(
            func=set_builtin,
            args=[empty_list],
            keywords=[],
            **DEFAULT_AST_ARGS,
        )
|
|
1155
|
+
|
|
1156
|
+
|
|
1157
|
+
class Dict(Expression):
    """
    A dict display ``{key: value, ...}`` built from (key, value) expression pairs.
    """

    child_elements = ["pairs"]

    def __init__(self, pairs: Sequence[tuple[Expression, Expression]]):
        self.pairs = pairs

    def as_ast(self) -> py_ast.expr:
        # All keys are rendered before any values.
        key_asts = [key.as_ast() for key, _value in self.pairs]
        value_asts = [value.as_ast() for _key, value in self.pairs]
        return py_ast.Dict(keys=key_asts, values=value_asts, **DEFAULT_AST_ARGS)
|
|
1169
|
+
|
|
1170
|
+
|
|
1171
|
+
class StringJoinBase(Expression):
    """
    Base class for expressions that join several string-valued parts.

    Subclasses choose the rendering (f-string vs ``+``). Prefer `build` over
    the constructor: it merges adjacent literals and collapses trivial joins.
    """

    child_elements = ["parts"]

    type = str

    def __init__(self, parts: Sequence[Expression]):
        self.parts = parts

    def __repr__(self):
        inner = ", ".join(map(repr, self.parts))
        return f"{self.__class__.__name__}([{inner}])"

    @classmethod
    def build(cls: type[StringJoinBase], parts: Sequence[Expression]) -> StringJoinBase | Expression:
        """
        Build a join of `parts`, simplified where possible.

        Runs of adjacent String literals are merged into one String. Zero parts
        give ``String("")``, a single part is returned as-is, and otherwise an
        instance of `cls` is returned.
        """
        merged: list[Expression] = []
        for part in parts:
            previous = merged[-1] if merged else None
            if isinstance(previous, String) and isinstance(part, String):
                merged[-1] = String(previous.string_value + part.string_value)
            else:
                merged.append(part)

        if not merged:
            return String("")
        if len(merged) == 1:
            return merged[0]
        return cls(merged)
|
|
1202
|
+
|
|
1203
|
+
|
|
1204
|
+
class FStringJoin(StringJoinBase):
    """
    Render the join as an f-string (``ast.JoinedStr``).

    String parts become literal text. Every other part becomes a ``{...}``
    replacement field with no conversion and no format spec.
    """

    def as_ast(self) -> py_ast.expr:
        def to_value(part: Expression) -> py_ast.expr:
            if isinstance(part, String):
                return part.as_ast()
            return py_ast.FormattedValue(
                value=part.as_ast(),
                conversion=-1,  # -1: no !s / !r / !a conversion
                format_spec=None,
                **DEFAULT_AST_ARGS,
            )

        return py_ast.JoinedStr(values=[to_value(part) for part in self.parts], **DEFAULT_AST_ARGS)
|
|
1221
|
+
|
|
1222
|
+
|
|
1223
|
+
class ConcatJoin(StringJoinBase):
    """
    Render the join as left-associative ``+`` concatenation:
    ``((p0 + p1) + p2) + ...``.

    With no parts this renders the empty string literal, matching what
    `build` and FStringJoin produce, instead of failing on ``parts[0]``.
    """

    def as_ast(self) -> py_ast.expr:
        if len(self.parts) == 0:
            return String("").as_ast()
        result = self.parts[0].as_ast()
        for part in self.parts[1:]:
            result = py_ast.BinOp(
                left=result,
                op=py_ast.Add(**DEFAULT_AST_ARGS_ADD),
                right=part.as_ast(),
                **DEFAULT_AST_ARGS,
            )
        return result
|
|
1236
|
+
|
|
1237
|
+
|
|
1238
|
+
# Default join implementation. On CPython, f-strings give a measurable
# improvement over `+` concatenation, so FStringJoin is used; ConcatJoin
# remains available as an alternative.
StringJoin = FStringJoin
|
|
1242
|
+
|
|
1243
|
+
|
|
1244
|
+
class Name(Expression):
    """
    A reference (load) to a name that `scope` already knows about.

    Construction fails with AssertionError unless ``scope.is_name_in_use``
    accepts the name. Two Names are equal when they have the same exact type
    and the same name.
    """

    child_elements = []

    def __init__(self, name: str, scope: Scope):
        if scope.is_name_in_use(name):
            self.name = name
        else:
            raise AssertionError(f"Cannot refer to undefined name '{name}'")

    def as_ast(self) -> py_ast.expr:
        # Builtins (e.g. `len`) are legitimate references here.
        if allowable_name(self.name, allow_builtin=True):
            return py_ast.Name(id=self.name, ctx=py_ast.Load(), **DEFAULT_AST_ARGS)
        raise AssertionError(f"Expected {self.name} to be a valid Python identifier")

    def __eq__(self, other: object):
        if type(other) is not type(self):
            return False
        return other.name == self.name

    def __repr__(self):
        return f"Name({self.name!r})"
|
|
1262
|
+
|
|
1263
|
+
|
|
1264
|
+
class Attr(Expression):
    """
    Attribute access ``<value>.<attribute>`` (load context).

    Raises AssertionError at construction if `attribute` is not an allowable
    identifier.
    """

    child_elements = ["value"]

    def __init__(self, value: Expression, attribute: str) -> None:
        self.value = value
        if not allowable_name(attribute, allow_builtin=True):
            raise AssertionError(f"Expected {attribute} to be a valid Python identifier")
        self.attribute = attribute

    def as_ast(self) -> py_ast.expr:
        # `ctx` must be given explicitly. Before Python 3.13, compile() rejects an
        # Attribute node without it, and this matches Name/Starred/Subscript here.
        return py_ast.Attribute(
            value=self.value.as_ast(),
            attr=self.attribute,
            ctx=py_ast.Load(),
            **DEFAULT_AST_ARGS,
        )

    def __repr__(self):
        return f"Attr({self.value!r}, {self.attribute!r})"
|
|
1275
|
+
|
|
1276
|
+
|
|
1277
|
+
class Starred(Expression):
    """
    Iterable unpacking ``*<value>`` (load context).
    """

    child_elements = ["value"]

    def __init__(self, value: Expression):
        self.value = value

    def as_ast(self) -> py_ast.expr:
        inner = self.value.as_ast()
        return py_ast.Starred(value=inner, ctx=py_ast.Load(), **DEFAULT_AST_ARGS)

    def __repr__(self):
        return f"Starred({repr(self.value)})"
|
|
1288
|
+
|
|
1289
|
+
|
|
1290
|
+
def function_call(
    function_name: str,
    args: Sequence[Expression],
    kwargs: dict[str, Expression],
    scope: Scope,
) -> Expression:
    """
    Build a Call of the function named `function_name`.

    Raises AssertionError if `scope` does not know the name, or if the name
    is listed in ``SENSITIVE_FUNCTIONS``. The scope check comes first.
    """
    if not scope.is_name_in_use(function_name):
        raise AssertionError(f"Cannot call unknown function '{function_name}'")
    if function_name in SENSITIVE_FUNCTIONS:
        raise AssertionError(f"Disallowing call to '{function_name}'")

    callee = Name(name=function_name, scope=scope)
    return callee.call(args, kwargs)
|
|
1302
|
+
|
|
1303
|
+
|
|
1304
|
+
class Call(Expression):
    """
    A call ``<value>(*args, **kwargs)``.

    Every keyword name must pass ``allowable_keyword_arg_name``. If any name
    is not a valid Python identifier (e.g. Fluent-style ``foo-bar``), all
    keywords are passed through a ``**{...}`` dict literal sorted by name.
    Otherwise normal ``name=value`` keywords are emitted in dict order.
    """

    child_elements = ["value", "args", "kwargs"]

    def __init__(
        self,
        value: Expression,
        args: Sequence[Expression],
        kwargs: dict[str, Expression],
    ):
        self.value = value
        self.args = list(args)
        self.kwargs = kwargs

    def as_ast(self) -> py_ast.expr:
        for name in self.kwargs:
            if not allowable_keyword_arg_name(name):
                raise AssertionError(f"Expected {name} to be a valid Fluent NamedArgument name")

        # Executing the AST would not need this workaround, since only the parser
        # rejects `foo-bar` as an identifier. It is needed so the AST can be
        # decompiled to valid Python source.
        needs_dict_splat = not all(allowable_name(name) for name in self.kwargs)

        func_ast = self.value.as_ast()
        arg_asts = [arg.as_ast() for arg in self.args]

        if needs_dict_splat:
            # my_function(**{'foo-bar': baz})
            ordered = sorted(self.kwargs.items())
            splat = py_ast.Dict(
                keys=[py_ast.Constant(key, kind=None, **DEFAULT_AST_ARGS) for key, _ in ordered],
                values=[value.as_ast() for _, value in ordered],
                **DEFAULT_AST_ARGS,
            )
            keywords = [py_ast.keyword(arg=None, value=splat, **DEFAULT_AST_ARGS)]
        else:
            # my_function(foo=bar)
            keywords = [
                py_ast.keyword(arg=name, value=value.as_ast(), **DEFAULT_AST_ARGS)
                for name, value in self.kwargs.items()
            ]

        return py_ast.Call(func=func_ast, args=arg_asts, keywords=keywords, **DEFAULT_AST_ARGS)

    def __repr__(self):
        return f"Call({self.value!r}, {self.args}, {self.kwargs})"
|
|
1365
|
+
|
|
1366
|
+
|
|
1367
|
+
def method_call(
    obj: Expression,
    method_name: str,
    args: Sequence[Expression],
    kwargs: dict[str, Expression],
):
    """
    Build ``obj.<method_name>(*args, **kwargs)``.

    Module-level equivalent of ``Expression.method_call``.
    """
    bound_method = obj.attr(method_name)
    return bound_method.call(args=args, kwargs=kwargs)
|
|
1374
|
+
|
|
1375
|
+
|
|
1376
|
+
class DictLookup(Expression):
    """
    Subscript lookup ``<lookup_obj>[<lookup_arg>]`` (load context).
    """

    child_elements = ["lookup_obj", "lookup_arg"]

    def __init__(self, lookup_obj: Expression, lookup_arg: Expression):
        self.lookup_obj = lookup_obj
        self.lookup_arg = lookup_arg

    def as_ast(self) -> py_ast.expr:
        container = self.lookup_obj.as_ast()
        # subscript_slice_object adapts the key to the slice form py_ast expects.
        key = py_ast.subscript_slice_object(self.lookup_arg.as_ast())
        return py_ast.Subscript(value=container, slice=key, ctx=py_ast.Load(), **DEFAULT_AST_ARGS)
|
|
1390
|
+
|
|
1391
|
+
|
|
1392
|
+
# Creating an instance is emitted exactly like a function call, so this is an
# alias. The class name goes through the same scope and SENSITIVE_FUNCTIONS
# checks as in function_call.
create_class_instance = function_call
|
|
1393
|
+
|
|
1394
|
+
|
|
1395
|
+
class NoneExpr(Expression):
    """
    The ``None`` literal.
    """

    # Leaf node with no codegen children. Declared explicitly, like every other
    # literal node, because rewriting_traverse reads `child_elements` on each
    # node it visits.
    child_elements = []

    type = type(None)

    def as_ast(self) -> py_ast.expr:
        return py_ast.Constant(value=None, **DEFAULT_AST_ARGS)

    def __repr__(self):
        return "NoneExpr()"
|
|
1400
|
+
|
|
1401
|
+
|
|
1402
|
+
class BinaryOperator(Expression):
    """
    Base class for expressions combining a `left` and a `right` operand.
    """

    child_elements = ["left", "right"]

    def __init__(self, left: Expression, right: Expression):
        self.left, self.right = left, right
|
|
1408
|
+
|
|
1409
|
+
|
|
1410
|
+
class ArithOp(BinaryOperator, ABC):
    """
    Arithmetic binary operator, rendered as ``ast.BinOp``.

    Concrete subclasses only set `op` to the ``ast.operator`` class to use.
    """

    op: ClassVar[type[py_ast.operator]]

    def as_ast(self) -> py_ast.expr:
        lhs = self.left.as_ast()
        op_node = self.op(**DEFAULT_AST_ARGS_ADD)
        rhs = self.right.as_ast()
        return py_ast.BinOp(left=lhs, op=op_node, right=rhs, **DEFAULT_AST_ARGS)


class Add(ArithOp):
    """``left + right``"""

    op = py_ast.Add


class Sub(ArithOp):
    """``left - right``"""

    op = py_ast.Sub


class Mul(ArithOp):
    """``left * right``"""

    op = py_ast.Mult


class Div(ArithOp):
    """``left / right``"""

    op = py_ast.Div


class FloorDiv(ArithOp):
    """``left // right``"""

    op = py_ast.FloorDiv


class Mod(ArithOp):
    """``left % right``"""

    op = py_ast.Mod


class Pow(ArithOp):
    """``left ** right``"""

    op = py_ast.Pow


class MatMul(ArithOp):
    """``left @ right``"""

    op = py_ast.MatMult
|
|
1454
|
+
|
|
1455
|
+
|
|
1456
|
+
class CompareOp(BinaryOperator, ABC):
    """
    A single comparison ``left <op> right``, rendered as ``ast.Compare`` with
    exactly one operator and one comparator.
    """

    type = bool
    op: ClassVar[type[py_ast.cmpop]]

    def as_ast(self) -> py_ast.expr:
        lhs = self.left.as_ast()
        rhs = self.right.as_ast()
        return py_ast.Compare(left=lhs, comparators=[rhs], ops=[self.op()], **DEFAULT_AST_ARGS)


class Equals(CompareOp):
    """``left == right``"""

    op = py_ast.Eq


class NotEquals(CompareOp):
    """``left != right``"""

    op = py_ast.NotEq


class Lt(CompareOp):
    """``left < right``"""

    op = py_ast.Lt


class Gt(CompareOp):
    """``left > right``"""

    op = py_ast.Gt


class LtE(CompareOp):
    """``left <= right``"""

    op = py_ast.LtE


class GtE(CompareOp):
    """``left >= right``"""

    op = py_ast.GtE


class In(CompareOp):
    """``left in right``"""

    op = py_ast.In


class NotIn(CompareOp):
    """``left not in right``"""

    op = py_ast.NotIn
|
|
1501
|
+
|
|
1502
|
+
|
|
1503
|
+
class BoolOp(BinaryOperator, ABC):
    """
    Short-circuit boolean operator over two operands, rendered as ``ast.BoolOp``.

    NOTE(review): Python's ``and``/``or`` return one of the operands, so
    ``type = bool`` only strictly holds when both operands are bools.
    """

    type = bool
    op: ClassVar[type[py_ast.boolop]]

    def as_ast(self) -> py_ast.expr:
        operands = [self.left.as_ast(), self.right.as_ast()]
        return py_ast.BoolOp(op=self.op(), values=operands, **DEFAULT_AST_ARGS)


class And(BoolOp):
    """``left and right``"""

    op = py_ast.And


class Or(BoolOp):
    """``left or right``"""

    op = py_ast.Or
|
|
1521
|
+
|
|
1522
|
+
|
|
1523
|
+
def traverse(ast_node: py_ast.AST, func: Callable[[py_ast.AST], None]):
    """
    Call `func` on `ast_node` and on every real ``ast`` node beneath it.

    Nodes are visited in ``walk`` order. This operates on ``ast.*`` objects,
    not codegen nodes (see `rewriting_traverse` for those).
    """
    for descendant in py_ast.walk(ast_node):
        func(descendant)
|
|
1529
|
+
|
|
1530
|
+
|
|
1531
|
+
def simplify(codegen_ast: CodeGenAstType, simplifier: Callable[[CodeGenAstType, list[bool]], CodeGenAst]):
    """
    Apply `simplifier` to every node of `codegen_ast` until a full pass makes
    no changes, then return `codegen_ast` (rewritten in place).

    `simplifier(node, changes)` returns a replacement node, or `node` itself.
    It is expected to append a truthy value to `changes` whenever it alters
    something, which schedules another pass.
    """
    # Seeded with True so the first pass always runs.
    changes: list[bool] = [True]

    def rewriter(node: CodeGenAstType) -> CodeGenAstType:
        # rewriting_traverse passes only the node, so bind the shared `changes` list here.
        return simplifier(node, changes)

    while any(changes):
        changes.clear()
        rewriting_traverse(codegen_ast, rewriter)
    return codegen_ast
|
|
1543
|
+
|
|
1544
|
+
|
|
1545
|
+
def rewriting_traverse(
    node: CodeGenAstType | Sequence[CodeGenAstType],
    func: Callable[[CodeGenAstType], CodeGenAstType],
):
    """
    Apply `func` to `node` and, recursively, to every codegen node below it.

    Whenever `func` returns a different object, the original node is morphed
    into it in place (via `morph_into`), so references held by parents stay
    valid. Children are found through the attribute names in the (possibly
    new) node's `child_elements`. Plain lists and tuples are walked element
    by element, and any other value is ignored.
    """
    if isinstance(node, (CodeGenAst, CodeGenAstList)):
        replacement = func(node)
        if replacement is not node:
            morph_into(node, replacement)
        for attr_name in node.child_elements:
            rewriting_traverse(getattr(node, attr_name), func)
    elif isinstance(node, (list, tuple)):
        for child in node:
            rewriting_traverse(child, func)
|
|
1561
|
+
|
|
1562
|
+
|
|
1563
|
+
def morph_into(item: object, new_item: object) -> None:
    """
    Make `item` impersonate `new_item` while keeping `item`'s identity.

    `item` takes over `new_item`'s class and its instance ``__dict__``. The
    dict object itself is shared, not copied. Existing references to `item`
    therefore see the replacement, which lets a tree be rewritten in place
    instead of rebuilding parent objects.
    """
    replacement_class = new_item.__class__
    replacement_state = new_item.__dict__
    item.__class__ = replacement_class
    item.__dict__ = replacement_state
|
|
1569
|
+
|
|
1570
|
+
|
|
1571
|
+
def empty_If() -> py_ast.If:
    """
    Return a placeholder ``If`` node with ``test=None`` and an empty ``orelse``.

    ``body`` is not passed here. Callers must set ``test`` (and ``body``)
    before the node is usable.
    """
    node = py_ast.If(test=None, orelse=[], **DEFAULT_AST_ARGS)  # type: ignore[reportArgumentType]
    return node
|
|
1577
|
+
|
|
1578
|
+
|
|
1579
|
+
# Recursive type alias for the plain Python values accepted by `auto`:
# scalars (bool, str, bytes, int, float, None) plus lists, tuples, sets,
# frozensets and dicts nesting those same values.
type PythonObj = (
    bool
    | str
    | bytes
    | int
    | float
    | None
    | list[PythonObj]
    | tuple[PythonObj, ...]
    | set[PythonObj]
    | frozenset[PythonObj]
    | dict[PythonObj, PythonObj]
)
|
|
1592
|
+
|
|
1593
|
+
|
|
1594
|
+
@overload
|
|
1595
|
+
def auto(value: bool) -> Bool: ... # type: ignore[overload-overlap] # bool before int/float is intentional
|
|
1596
|
+
@overload
|
|
1597
|
+
def auto(value: str) -> String: ...
|
|
1598
|
+
@overload
|
|
1599
|
+
def auto(value: bytes) -> Bytes: ...
|
|
1600
|
+
@overload
|
|
1601
|
+
def auto(value: int) -> Number: ...
|
|
1602
|
+
@overload
|
|
1603
|
+
def auto(value: float) -> Number: ...
|
|
1604
|
+
@overload
|
|
1605
|
+
def auto(value: None) -> NoneExpr: ...
|
|
1606
|
+
@overload
|
|
1607
|
+
def auto(value: list[PythonObj]) -> List: ...
|
|
1608
|
+
@overload
|
|
1609
|
+
def auto(value: tuple[PythonObj, ...]) -> Tuple: ...
|
|
1610
|
+
@overload
|
|
1611
|
+
def auto(value: set[PythonObj]) -> Set: ...
|
|
1612
|
+
@overload
|
|
1613
|
+
def auto(value: frozenset[PythonObj]) -> Set: ...
|
|
1614
|
+
@overload
|
|
1615
|
+
def auto(value: dict[PythonObj, PythonObj]) -> Dict: ...
|
|
1616
|
+
|
|
1617
|
+
|
|
1618
|
+
def auto(value: PythonObj) -> Expression:
|
|
1619
|
+
"""
|
|
1620
|
+
Create a codegen Expression from a plain Python object.
|
|
1621
|
+
|
|
1622
|
+
Supports bool, str, bytes, int, float, None, and recursively
|
|
1623
|
+
list, tuple, set, frozenset, and dict.
|
|
1624
|
+
"""
|
|
1625
|
+
if isinstance(value, bool):
|
|
1626
|
+
return Bool(value)
|
|
1627
|
+
elif isinstance(value, str):
|
|
1628
|
+
return String(value)
|
|
1629
|
+
elif isinstance(value, bytes):
|
|
1630
|
+
return Bytes(value)
|
|
1631
|
+
elif isinstance(value, (int, float)):
|
|
1632
|
+
return Number(value)
|
|
1633
|
+
elif value is None:
|
|
1634
|
+
return NoneExpr()
|
|
1635
|
+
elif isinstance(value, list):
|
|
1636
|
+
return List([auto(item) for item in value])
|
|
1637
|
+
elif isinstance(value, tuple):
|
|
1638
|
+
return Tuple([auto(item) for item in value])
|
|
1639
|
+
elif isinstance(value, (set, frozenset)):
|
|
1640
|
+
return Set([auto(item) for item in sorted(value, key=repr)])
|
|
1641
|
+
elif isinstance(value, dict): # type: ignore[reportUnnecessaryIsInstance]
|
|
1642
|
+
return Dict([(auto(k), auto(v)) for k, v in value.items()])
|
|
1643
|
+
assert_never(value)
|
|
1644
|
+
|
|
1645
|
+
|
|
1646
|
+
class constants:
    """
    Ready-made, shared Expression instances for ``None``, ``True`` and ``False``.
    """

    None_: NoneExpr = NoneExpr()
    True_: Bool = Bool(True)
    False_: Bool = Bool(False)
|