fluent-codegen 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1653 @@
1
+ """
2
+ Utilities for doing Python code generation
3
+ """
4
+
5
+ from __future__ import annotations
6
+
7
+ import builtins
8
+ import enum
9
+ import keyword
10
+ import re
11
+ from abc import ABC, abstractmethod
12
+ from collections.abc import Callable, Sequence
13
+ from dataclasses import dataclass
14
+ from typing import ClassVar, Protocol, assert_never, overload, runtime_checkable
15
+
16
+ from . import ast_compat as py_ast
17
+ from .ast_compat import (
18
+ DEFAULT_AST_ARGS,
19
+ DEFAULT_AST_ARGS_ADD,
20
+ DEFAULT_AST_ARGS_ARGUMENTS,
21
+ DEFAULT_AST_ARGS_MODULE,
22
+ )
23
+ from .utils import allowable_keyword_arg_name, allowable_name
24
+
25
+ # This module provides simple utilities for building up Python source code.
26
+ # The design originally came from fluent-compiler, so had the following aims
27
+ # and constraints:
28
+ #
29
+ # 1. Performance.
30
+ #
31
+ # The resulting Python code should do as little as possible, especially for
32
+ # simple cases.
33
+ #
34
+ # 2. Correctness (obviously)
35
+ #
36
+ # In particular, we should try to make it hard to generate code that is
37
+ # syntactically correct and therefore compiles but doesn't work. We try to
38
+ # make it hard to generate accidental name clashes, or use variables that are
39
+ # not defined.
40
+ #
41
+ # Correctness also has a security implication, since the result of this code
42
+ # might be 'exec'ed. To that end:
43
+ # * We build up AST, rather than strings. This eliminates many
44
+ # potential bugs caused by wrong escaping/interpolation.
45
+ # * the `as_ast()` methods are paranoid about input, and do many asserts.
46
+ # We do this even though other layers will usually have checked the
47
+ # input, to allow us to reason locally when checking these methods. These
48
+ # asserts must also have 100% code coverage.
49
+ #
50
+ # 3. Simplicity
51
+ #
52
+ # The resulting Python code should be easy to read and understand.
53
+ #
54
+ # 4. Predictability
55
+ #
56
+ # Since we want to test the resulting source code, we have made some design
57
+ # decisions that aim to ensure things like function argument names are
58
+ # consistent and so can be predicted easily.
59
+
60
+ # Outside of fluent-compiler, this code will likely be useful for situations
61
+ # which have similar aims.
62
+
63
+
64
# String constants whose values equal their own names. How they are consumed
# is not visible in this part of the file -- presumably marker keys for
# type-property lookups (TODO confirm against callers).
PROPERTY_TYPE = "PROPERTY_TYPE"
PROPERTY_RETURN_TYPE = "PROPERTY_RETURN_TYPE"
# UNKNOWN_TYPE is just an alias for `object` for clarity.
UNKNOWN_TYPE: type = object
# It is important for our usage of it that UNKNOWN_TYPE is a `type`,
# and the most general `type`.
assert isinstance(UNKNOWN_TYPE, type)
71
+
72
+
73
SENSITIVE_FUNCTIONS = {
    # builtin functions that we should never be calling from our code
    # generation. This is a defense-in-depth mechanism to stop our code
    # generation becoming a code execution vulnerability. There should also be
    # higher level code that ensures we are not generating calls to arbitrary
    # Python functions. This is not a comprehensive list of functions we are not
    # using, but functions we definitely don't need and are most likely to be
    # used to execute remote code or to get around safety mechanisms.
    "__import__",
    "__build_class__",
    "apply",
    "compile",
    "eval",
    "exec",
    "execfile",
    "exit",
    "file",
    "globals",
    "locals",
    "open",
    "object",
    "reload",
    "type",
}
97
+
98
+
99
class CodeGenAst(ABC):
    """
    Root of the simplified code-generation tree (this is not Python's own AST).

    Concrete nodes turn themselves into real `ast.*` nodes through `as_ast()`.
    """

    # Declared here for subclasses to fill in; no value is assigned on the base.
    child_elements: ClassVar[list[str]]

    @abstractmethod
    def as_ast(self) -> py_ast.AST:
        """Build and return the real AST node for this object."""

    def as_python_source(self) -> str:
        """Return the Python source code for this AST node."""
        tree = self.as_ast()
        # Fill in location attributes before unparsing the tree.
        py_ast.fix_missing_locations(tree)
        source = py_ast.unparse(tree)
        return source
115
+
116
+
117
class CodeGenAstList(ABC):
    """
    Sibling base class to CodeGenAst for objects that produce a list of AST
    nodes rather than a single one. Every produced node must be a `stmt`.
    """

    # Declared here for subclasses to fill in; no value is assigned on the base.
    child_elements: ClassVar[list[str]]

    @abstractmethod
    def as_ast_list(self, allow_empty: bool = True) -> list[py_ast.stmt]:
        """Build and return the list of statement nodes for this object."""

    def as_python_source(self) -> str:
        """Return the Python source code for this AST list."""
        statements = self.as_ast_list()
        # Wrap the statements in a module node so the whole list can be unparsed.
        module_node = py_ast.Module(body=statements, type_ignores=[], **DEFAULT_AST_ARGS_MODULE)
        py_ast.fix_missing_locations(module_node)
        return py_ast.unparse(module_node)
133
+
134
+
135
# Union of the two code-generation base classes defined above.
CodeGenAstType = CodeGenAst | CodeGenAstList
136
+
137
+
138
class Scope:
    """
    Tracks the names used within one scope of the generated code.

    Name lookups fall back to `parent_scope`, so a name taken in an enclosing
    scope also counts as taken here. Assignment tracking is per-scope only.
    """

    def __init__(self, parent_scope: Scope | None = None):
        self.parent_scope = parent_scope
        # Names that are in use in this scope.
        self.names: set[str] = set()
        # Names held back for later use as function arguments only.
        self._function_arg_reserved_names: set[str] = set()
        # Names for which an assignment has been registered in this scope.
        self._assignments: set[str] = set()

    def is_name_in_use(self, name: str) -> bool:
        """Return True if `name` is in use in this scope or any ancestor scope."""
        if name in self.names:
            return True
        parent = self.parent_scope
        return parent is not None and parent.is_name_in_use(name)

    def is_name_reserved_function_arg(self, name: str) -> bool:
        """Return True if `name` is reserved as a function argument here or in an ancestor."""
        if name in self._function_arg_reserved_names:
            return True
        parent = self.parent_scope
        return parent is not None and parent.is_name_reserved_function_arg(name)

    def is_name_reserved(self, name: str) -> bool:
        """Return True if `name` is either in use or reserved as a function argument."""
        if self.is_name_in_use(name):
            return True
        return self.is_name_reserved_function_arg(name)

    def reserve_name(
        self,
        requested: str,
        function_arg: bool = False,
        is_builtin: bool = False,
    ):
        """
        Reserve a name as being in use in a scope and return the name actually used.

        Pass function_arg=True if this is a function argument. A function
        argument name that was previously reserved via
        `reserve_function_arg_name` is used exactly as requested; one that
        clashes with a name already in use raises AssertionError.

        In every other case the requested name is cleaned up and, if needed,
        given a numeric suffix (2, 3, ...) until it is free. Keywords count as
        taken unless `is_builtin` is true.
        """
        if function_arg:
            if self.is_name_reserved_function_arg(requested):
                assert not self.is_name_in_use(requested)
                self.names.add(requested)
                return requested
            if self.is_name_reserved(requested):
                raise AssertionError(f"Cannot use '{requested}' as argument name as it is already in use")

        base = cleanup_name(requested)

        def _blocked(candidate: str) -> bool:
            # Keywords ('class', 'def' etc.) count as used. Some builtins are
            # also keywords (e.g. 'None'), so the keyword check is skipped when
            # a builtin is being reserved.
            if keyword.iskeyword(candidate) and not is_builtin:
                return True
            # Parent scopes are consulted too, which avoids shadowing of
            # global names in a local scope.
            return self.is_name_reserved(candidate)

        candidate = base
        suffix = 2  # the instance without a suffix is regarded as 1
        while _blocked(candidate):
            candidate = base + str(suffix)
            suffix += 1

        self.names.add(candidate)
        return candidate

    def reserve_function_arg_name(self, name: str):
        """
        Reserve a name for *later* use as a function argument. The name is not
        treated as 'in use' in the current scope, but it will not be handed out
        for any purpose other than as a function argument.

        Raises AssertionError if the name is already in use or reserved.
        """
        # To keep things simple, and the generated code predictable, function
        # argument names are reserved in a separate set and the exact names
        # are insisted upon.
        if not self.is_name_reserved(name):
            self._function_arg_reserved_names.add(name)
            return
        raise AssertionError(f"Can't reserve '{name}' as function arg name as it is already reserved")

    def has_assignment(self, name: str) -> bool:
        """Return True if an assignment to `name` was registered in this scope."""
        return name in self._assignments

    def register_assignment(self, name: str) -> None:
        """Record that `name` has been assigned in this scope."""
        self._assignments.add(name)

    def create_name(self, name: str) -> Name:
        """Reserve `name` (possibly adjusted by `reserve_name`) and wrap it in a Name."""
        return Name(self.reserve_name(name), self)

    def name(self, name: str) -> Name:
        """Convenience utility returning a Name for `name` without reserving it."""
        return Name(name, self)
238
+
239
+
240
# Matches every character that is not an ASCII letter, digit or underscore.
_IDENTIFIER_SANITIZER_RE = re.compile("[^a-zA-Z0-9_]")
# Matches a leading ASCII letter or underscore.
_IDENTIFIER_START_RE = re.compile("^[a-zA-Z_]")


def cleanup_name(name: str) -> str:
    """
    Convert name to an allowable identifier.

    Characters other than ASCII letters, digits and underscores are dropped. If
    what remains does not begin with a letter or underscore (including the
    case where nothing remains), "n" is prepended.
    """
    # See the "Identifiers" section of the Python lexical analysis reference.
    stripped = _IDENTIFIER_SANITIZER_RE.sub("", name)
    if _IDENTIFIER_START_RE.match(stripped) is None:
        return "n" + stripped
    return stripped
253
+
254
+
255
class Statement(CodeGenAst):
    """Marker base class for code-generation nodes that represent statements."""

    pass
257
+
258
+
259
@runtime_checkable
class SupportsNameAssignment(Protocol):
    """Structural type for objects that can report whether they assign a given name."""

    def has_assignment_for_name(self, name: str) -> bool:
        """Return True if this object assigns to `name`."""
        ...
262
+
263
+
264
class _Annotation(Statement):
    """A bare type annotation without a value, e.g. ``x: int``."""

    child_elements = []

    def __init__(self, name: str, annotation: Expression):
        self.name = name
        self.annotation = annotation

    def as_ast(self):
        """Return an `AnnAssign` node with no value; raises AssertionError on a bad name."""
        if allowable_name(self.name):
            target = py_ast.Name(id=self.name, ctx=py_ast.Store(), **DEFAULT_AST_ARGS)
            annotation_ast = self.annotation.as_ast()
            # simple=1 marks a plain-name target (see the `ast.AnnAssign` docs).
            return py_ast.AnnAssign(
                target=target,
                annotation=annotation_ast,
                simple=1,
                value=None,
                **DEFAULT_AST_ARGS,
            )
        raise AssertionError(f"Expected {self.name} to be a valid Python identifier")

    def has_assignment_for_name(self, name: str) -> bool:
        """Return True if this annotation is for `name`."""
        return name == self.name
286
+
287
+
288
class _Assignment(Statement):
    """An assignment ``name = value``, or ``name: hint = value`` when a type hint is given."""

    child_elements = ["value"]

    def __init__(self, name: str, value: Expression, /, *, type_hint: Expression | None = None):
        self.name = name
        self.value = value
        self.type_hint = type_hint

    def as_ast(self):
        """Return an `Assign` or `AnnAssign` node; raises AssertionError on a bad name."""
        if not allowable_name(self.name):
            raise AssertionError(f"Expected {self.name} to be a valid Python identifier")

        target = py_ast.Name(id=self.name, ctx=py_ast.Store(), **DEFAULT_AST_ARGS)
        if self.type_hint is not None:
            annotation_ast = self.type_hint.as_ast()
            # simple=1 marks a plain-name target (see the `ast.AnnAssign` docs).
            return py_ast.AnnAssign(
                target=target,
                annotation=annotation_ast,
                simple=1,
                value=self.value.as_ast(),
                **DEFAULT_AST_ARGS,
            )
        return py_ast.Assign(
            targets=[target],
            value=self.value.as_ast(),
            **DEFAULT_AST_ARGS,
        )

    def has_assignment_for_name(self, name: str) -> bool:
        """Return True if this statement assigns to `name`."""
        return name == self.name
316
+
317
+
318
class Block(CodeGenAstList):
    """
    An ordered list of statements sharing a `Scope`.

    Statements should be added through the `create_*` helpers (or
    `add_statement`) rather than by manipulating `statements` directly.
    """

    child_elements = ["statements"]

    def __init__(self, scope: Scope, parent_block: Block | None = None):
        self.scope = scope
        # We allow `Expression` here for things like MethodCall which
        # are bare expressions that are still useful for side effects
        self.statements: list[Block | Statement | Expression] = []
        self.parent_block = parent_block

    def as_ast_list(self, allow_empty: bool = True) -> list[py_ast.stmt]:
        """
        Convert the contained statements to a flat list of `stmt` nodes.

        Nested blocks are flattened into the result. If nothing is produced and
        `allow_empty` is false, a single `pass` statement is returned instead.
        """
        retval: list[py_ast.stmt] = []
        for s in self.statements:
            if isinstance(s, CodeGenAstList):
                retval.extend(s.as_ast_list(allow_empty=True))
            elif isinstance(s, Statement):
                ast_obj = s.as_ast()
                assert isinstance(ast_obj, py_ast.stmt), (
                    f"Statement object returned {ast_obj!r} which is not an instance of py_ast.stmt"
                )
                retval.append(ast_obj)
            else:
                # Things like bare function/method calls need to be wrapped
                # in `Expr` to match the way Python parses.
                retval.append(py_ast.Expr(s.as_ast(), **DEFAULT_AST_ARGS))

        if len(retval) == 0 and not allow_empty:
            return [py_ast.Pass(**DEFAULT_AST_ARGS)]
        return retval

    def add_statement(self, statement: Statement | Block | Expression) -> None:
        """
        Append `statement` to this block.

        A `Block` without a parent is adopted by this block. A `Block` that
        already belongs to a different parent raises AssertionError.
        """
        self.statements.append(statement)
        if isinstance(statement, Block):
            if statement.parent_block is None:
                statement.parent_block = self
            else:
                if statement.parent_block != self:
                    raise AssertionError(
                        f"Block {statement} is already child of {statement.parent_block}, can't reassign to {self}"
                    )

    def create_import(self, module: str, as_: str | None = None) -> tuple[Import, Name]:
        """
        Add ``import module`` or ``import module as as_`` to the block.

        Returns the `Import` statement and the `Name` bound by the import.
        Raises AssertionError for a name that is not allowable or, for an
        alias, already in use.
        """
        return_name_object: Name
        if as_ is not None:
            # "import foo as bar" results in `bar` name being assigned.
            if not allowable_name(as_):
                raise AssertionError(f"{as_!r} is not an allowable 'as' name")
            if self.scope.is_name_in_use(as_):
                raise AssertionError(f"{as_!r} is already assigned in the scope")
            as_name_object = self.scope.create_name(as_)
            return_name_object = as_name_object
        else:
            as_name_object = None
            # "import foo" results in `foo` name being assigned
            # "import foo.bar" also results in `foo` being reserved.
            dotted_parts = module.split(".")
            for part in dotted_parts:
                if not allowable_name(part):
                    raise AssertionError(f"{module!r} not an allowable 'import' name")
            name_to_assign = dotted_parts[0]
            # We can't rename, so don't use `reserve_name` or `create_name`.
            # We also need to allow for multiple imports, like `import foo.bar` then `import foo.baz`

            if not self.scope.is_name_in_use(name_to_assign):
                self.scope.reserve_name(name_to_assign)
            return_name_object = self.scope.name(name_to_assign)

        import_statement = Import(module=module, as_=as_name_object)
        self.add_statement(import_statement)
        return import_statement, return_name_object

    def create_import_from(self, *, from_: str, import_: str, as_: str | None = None) -> tuple[ImportFrom, Name]:
        """
        Add ``from from_ import import_`` (optionally ``as as_``) to the block.

        Returns the `ImportFrom` statement and the `Name` bound by the import.
        Raises AssertionError for a name that is not allowable or already in use.
        """
        return_name_object: Name
        if as_ is not None:
            # "from foo import bar as baz" results in `baz` name being assigned.
            if not allowable_name(as_):
                raise AssertionError(f"{as_!r} is not an allowable 'as' name")
            if self.scope.is_name_in_use(as_):
                raise AssertionError(f"{as_!r} is already assigned in the scope")
            as_name_object = self.scope.create_name(as_)
            return_name_object = as_name_object
        else:
            as_name_object = None
            # Check the dotted bit.
            dotted_parts = from_.split(".")
            for part in dotted_parts:
                if not allowable_name(part):
                    raise AssertionError(f"{from_!r} not an allowable 'import' name")

            # Check the `import_` for clashes.
            name_to_assign = import_

            if self.scope.is_name_in_use(name_to_assign):
                raise AssertionError(f"{name_to_assign!r} is already assigned in the scope")
            return_name_object = self.scope.create_name(name_to_assign)

        import_statement = ImportFrom(from_module=from_, import_=import_, as_=as_name_object)
        self.add_statement(import_statement)
        return import_statement, return_name_object

    # Safe alternatives to Block.statements being manipulated directly:
    def create_assignment(
        self, name: str | Name, value: Expression, *, type_hint: Expression | None = None, allow_multiple: bool = False
    ):
        """
        Adds an assignment of the form:

           x = value

        The name must already be in use in the scope. A second assignment to
        the same name raises AssertionError unless `allow_multiple` is true.
        """
        if isinstance(name, Name):
            name = name.name
        if not self.scope.is_name_in_use(name):
            raise AssertionError(f"Cannot assign to unreserved name '{name}'")

        if self.scope.has_assignment(name):
            if not allow_multiple:
                raise AssertionError(f"Have already assigned to '{name}' in this scope")
        else:
            self.scope.register_assignment(name)

        self.add_statement(_Assignment(name, value, type_hint=type_hint))

    def create_annotation(self, name: str, annotation: Expression) -> Name:
        """
        Adds a bare type annotation of the form::

            x: int

        Reserves the name and adds the annotation statement to the block.
        """
        name_obj = self.scope.create_name(name)
        self.scope.register_assignment(name_obj.name)
        self.add_statement(_Annotation(name_obj.name, annotation))
        return name_obj

    def create_field(self, name: str, annotation: Expression, *, default: Expression | None = None) -> Name:
        """
        Create a typed field, typically used in dataclass bodies.

        If *default* is provided, creates an annotated assignment::

            x: int = 0

        Otherwise, creates a bare annotation::

            x: int
        """
        if default is not None:
            name_obj = self.scope.create_name(name)
            self.scope.register_assignment(name_obj.name)
            self.add_statement(_Assignment(name_obj.name, default, type_hint=annotation))
            return name_obj
        else:
            return self.create_annotation(name, annotation)

    def create_function(
        self,
        name: str,
        args: Sequence[str | FunctionArg],
        decorators: Sequence[Expression] | None = None,
        return_type: Expression | None = None,
    ) -> tuple[Function, Name]:
        """
        Reserve a name for a function, create the Function and add the function statement
        to the block.
        """
        name_obj = self.scope.create_name(name)
        func = Function(
            name_obj.name, args=args, parent_scope=self.scope, decorators=decorators, return_type=return_type
        )
        self.add_statement(func)
        return func, name_obj

    def create_class(
        self,
        name: str,
        bases: Sequence[Expression] | None = None,
        decorators: Sequence[Expression] | None = None,
    ) -> tuple[Class, Name]:
        """
        Reserve a name for a class, create the Class and add the class statement
        to the block.
        """
        name_obj = self.scope.create_name(name)
        cls = Class(name_obj.name, parent_scope=self.scope, bases=bases, decorators=decorators)
        self.add_statement(cls)
        return cls, name_obj

    def create_return(self, value: Expression) -> None:
        """Add a ``return value`` statement to the block."""
        self.add_statement(Return(value))

    def create_if(self) -> If:
        """
        Create an If statement, add it to this block, and return it.

        Usage::

            if_stmt = block.create_if()
            if_block = if_stmt.create_if_branch(condition)
            if_block.create_return(value)
        """
        if_statement = If(self.scope, parent_block=self)
        self.add_statement(if_statement)
        return if_statement

    def create_with(self, context_expr: Expression, target: Name | None = None) -> With:
        """
        Create a With statement, add it to this block, and return it

        Usage::

            with_stmt = block.create_with(expr, target_name)
            with_stmt.body.create_return(value)

        where ``target_name`` is a `Name` (or None for no ``as`` target).
        """
        with_statement = With(context_expr, target=target, parent_scope=self.scope, parent_block=self)
        self.add_statement(with_statement)
        return with_statement

    def has_assignment_for_name(self, name: str) -> bool:
        """
        Return True if a statement in this block, or in any ancestor block,
        reports an assignment to `name`.
        """
        for s in self.statements:
            if isinstance(s, SupportsNameAssignment) and s.has_assignment_for_name(name):
                return True
        if self.parent_block is not None:
            return self.parent_block.has_assignment_for_name(name)
        return False
546
+
547
+
548
class Module(Block, CodeGenAst):
    """A top-level block with its own root scope and optional leading file comments."""

    def __init__(self, reserve_builtins: bool = True):
        module_scope = Scope(parent_scope=None)
        # Optionally mark every builtin name as taken in the module scope.
        builtin_names = dir(builtins) if reserve_builtins else []
        for builtin_name in builtin_names:
            module_scope.reserve_name(builtin_name, is_builtin=True)
        Block.__init__(self, module_scope)
        self.file_comments: list[str] = []

    def as_ast(self) -> py_ast.Module:
        """Return a `Module` node containing this block's statements."""
        body = self.as_ast_list()
        return py_ast.Module(body=body, type_ignores=[], **DEFAULT_AST_ARGS_MODULE)

    def as_python_source(self) -> str:
        """Return the module source, preceded by one ``# `` line per file comment."""
        body_source = super().as_python_source()
        header_lines = [f"# {comment}\n" for comment in self.file_comments]
        return "".join(header_lines) + body_source

    def add_file_comment(self, comment: str) -> None:
        """Queue `comment` to be emitted at the top of the generated source."""
        self.file_comments.append(comment)
567
+
568
+
569
class ArgKind(enum.Enum):
    """The kind of a function argument."""

    # Arguments that can only be passed by position.
    POSITIONAL_ONLY = "positional_only"
    # Arguments that can be passed by position or by keyword.
    POSITIONAL_OR_KEYWORD = "positional_or_keyword"
    # Arguments that can only be passed by keyword.
    KEYWORD_ONLY = "keyword_only"
575
+
576
+
577
@dataclass(frozen=True)
class FunctionArg:
    """A function argument with a name, kind, and optional default value and annotation."""

    name: str
    kind: ArgKind = ArgKind.POSITIONAL_OR_KEYWORD
    default: Expression | None = None
    annotation: Expression | None = None

    @classmethod
    def _of_kind(
        cls, kind: ArgKind, name: str, default: Expression | None, annotation: Expression | None
    ) -> FunctionArg:
        """Shared constructor used by the kind-specific factories below."""
        return cls(name=name, kind=kind, default=default, annotation=annotation)

    @classmethod
    def positional(
        cls, name: str, *, default: Expression | None = None, annotation: Expression | None = None
    ) -> FunctionArg:
        """Create a positional-only argument."""
        return cls._of_kind(ArgKind.POSITIONAL_ONLY, name, default, annotation)

    @classmethod
    def keyword(
        cls, name: str, *, default: Expression | None = None, annotation: Expression | None = None
    ) -> FunctionArg:
        """Create a keyword-only argument."""
        return cls._of_kind(ArgKind.KEYWORD_ONLY, name, default, annotation)

    @classmethod
    def standard(
        cls, name: str, *, default: Expression | None = None, annotation: Expression | None = None
    ) -> FunctionArg:
        """Create a positional-or-keyword argument (the Python default)."""
        return cls._of_kind(ArgKind.POSITIONAL_OR_KEYWORD, name, default, annotation)
606
+
607
+
608
def _normalize_args(args: Sequence[str | FunctionArg]) -> list[FunctionArg]:
    """Return `args` as a list of FunctionArg, wrapping each plain string in a FunctionArg."""
    normalized: list[FunctionArg] = []
    for item in args:
        if isinstance(item, str):
            # A bare string becomes a FunctionArg with the default kind.
            item = FunctionArg(name=item)
        normalized.append(item)
    return normalized
611
+
612
+
613
def _validate_arg_order(args: list[FunctionArg]) -> None:
    """Check that `args` form a legal parameter list.

    Kinds must appear in the order positional-only, positional-or-keyword,
    keyword-only. Among the positional kinds, an argument without a default
    may not follow one with a default. Raises ValueError on a violation.
    """
    # First pass: the kind of each argument must never go "backwards".
    rank_of = {
        ArgKind.POSITIONAL_ONLY: 0,
        ArgKind.POSITIONAL_OR_KEYWORD: 1,
        ArgKind.KEYWORD_ONLY: 2,
    }
    last_rank = -1
    for arg in args:
        rank = rank_of[arg.kind]
        if rank < last_rank:
            raise ValueError(
                f"Argument '{arg.name}' of kind {arg.kind.value} "
                f"is out of order: positional-only args must come first, "
                f"then positional-or-keyword, then keyword-only"
            )
        last_rank = rank

    # Second pass: positional-only and positional-or-keyword arguments share
    # one defaults list, so a non-default may not follow a default anywhere
    # across those two groups.
    positional_kinds = (ArgKind.POSITIONAL_ONLY, ArgKind.POSITIONAL_OR_KEYWORD)
    default_seen = False
    for arg in args:
        if arg.kind not in positional_kinds:
            # keyword-only args can have defaults in any order (Python allows it)
            continue
        if arg.default is not None:
            default_seen = True
            continue
        if default_seen:
            raise ValueError(f"Non-default argument '{arg.name}' follows default argument in positional arguments")
647
+
648
+
649
class Function(Scope, Statement):
    """A function definition: a scope of its own plus a `body` block of statements."""

    child_elements = ["body"]

    def __init__(
        self,
        name: str,
        args: Sequence[str | FunctionArg] | None = None,
        parent_scope: Scope | None = None,
        decorators: Sequence[Expression] | None = None,
        return_type: Expression | None = None,
    ):
        super().__init__(parent_scope=parent_scope)
        self.body = Block(self)
        self.func_name = name
        self.decorators: list[Expression] = list(decorators) if decorators else []
        self.return_type: Expression | None = return_type
        self._args: list[FunctionArg] = []
        if args is not None:
            self.add_args(args)

    def add_args(self, args: Sequence[str | FunctionArg]) -> None:
        """
        Add arguments to the function, with the same validation as in __init__.

        The combined argument list is order-checked first. Each new name is
        then reserved in this scope, and a name that is already in use raises
        AssertionError.
        """
        new_args = _normalize_args(args)
        all_args = self._args + new_args
        _validate_arg_order(all_args)
        for new_arg in new_args:
            if self.is_name_in_use(new_arg.name):
                raise AssertionError(
                    f"Can't use '{new_arg.name}' as function argument name because it shadows other names"
                )
            self.reserve_name(new_arg.name, function_arg=True)
        self._args = all_args

    def as_ast(self) -> py_ast.stmt:
        """Return a `FunctionDef` node; raises AssertionError on an invalid function or argument name."""
        if not allowable_name(self.func_name):
            raise AssertionError(f"Expected '{self.func_name}' to be a valid Python identifier")
        for arg in self._args:
            if not allowable_name(arg.name):
                raise AssertionError(f"Expected '{arg.name}' to be a valid Python identifier")

        def _ast_args_of(kind: ArgKind) -> list[py_ast.arg]:
            # Convert every argument of the given kind, preserving order.
            converted = []
            for a in self._args:
                if a.kind == kind:
                    annotation_ast = None if a.annotation is None else a.annotation.as_ast()
                    converted.append(py_ast.arg(arg=a.name, annotation=annotation_ast, **DEFAULT_AST_ARGS))
            return converted

        posonlyargs = _ast_args_of(ArgKind.POSITIONAL_ONLY)
        regular_args = _ast_args_of(ArgKind.POSITIONAL_OR_KEYWORD)
        kwonlyargs = _ast_args_of(ArgKind.KEYWORD_ONLY)

        # defaults: right-aligned to posonlyargs + regular_args, so only the
        # positional arguments that actually have a default contribute.
        positional_kinds = (ArgKind.POSITIONAL_ONLY, ArgKind.POSITIONAL_OR_KEYWORD)
        defaults = [a.default.as_ast() for a in self._args if a.kind in positional_kinds and a.default is not None]

        # kw_defaults: one entry per kwonlyarg, None if no default
        kw_defaults: list[py_ast.expr | None] = []
        for a in self._args:
            if a.kind != ArgKind.KEYWORD_ONLY:
                continue
            kw_defaults.append(None if a.default is None else a.default.as_ast())

        arguments = py_ast.arguments(
            posonlyargs=posonlyargs,
            args=regular_args,
            vararg=None,
            kwonlyargs=kwonlyargs,
            kw_defaults=kw_defaults,
            kwarg=None,
            defaults=defaults,
            **DEFAULT_AST_ARGS_ARGUMENTS,
        )
        # An empty body is replaced by `pass` so the definition stays valid.
        body = self.body.as_ast_list(allow_empty=False)
        decorator_list = [d.as_ast() for d in self.decorators]
        returns = None if self.return_type is None else self.return_type.as_ast()
        return py_ast.FunctionDef(
            name=self.func_name,
            args=arguments,
            body=body,
            decorator_list=decorator_list,
            type_params=[],
            returns=returns,
            **DEFAULT_AST_ARGS,
        )

    def create_return(self, value: Expression):
        """Add a ``return value`` statement to the function body."""
        self.body.create_return(value)
728
+
729
+
730
class Class(Scope, Statement):
    """A class definition: a scope of its own plus a `body` block of statements."""

    child_elements = ["body"]

    def __init__(
        self,
        name: str,
        parent_scope: Scope | None = None,
        bases: Sequence[Expression] | None = None,
        decorators: Sequence[Expression] | None = None,
    ):
        super().__init__(parent_scope=parent_scope)
        self.body = Block(self)
        self.class_name = name
        self.bases: list[Expression] = list(bases) if bases else []
        self.decorators: list[Expression] = list(decorators) if decorators else []

    def as_ast(self) -> py_ast.stmt:
        """Return a `ClassDef` node; raises AssertionError on an invalid class name."""
        if allowable_name(self.class_name):
            base_nodes = [b.as_ast() for b in self.bases]
            # An empty body is replaced by `pass` so the definition stays valid.
            body_nodes = self.body.as_ast_list(allow_empty=False)
            decorator_nodes = [d.as_ast() for d in self.decorators]
            return py_ast.ClassDef(
                name=self.class_name,
                bases=base_nodes,
                keywords=[],
                body=body_nodes,
                decorator_list=decorator_nodes,
                type_params=[],
                **DEFAULT_AST_ARGS,
            )
        raise AssertionError(f"Expected '{self.class_name}' to be a valid Python identifier")
758
+
759
+
760
class Return(Statement):
    """A ``return <value>`` statement."""

    child_elements = ["value"]

    def __init__(self, value: Expression):
        self.value = value

    def as_ast(self):
        """Return a `Return` node wrapping the AST of `value`."""
        return py_ast.Return(self.value.as_ast(), **DEFAULT_AST_ARGS)

    def __repr__(self):
        # The closing parenthesis was previously missing from this string.
        return f"Return({self.value!r})"
771
+
772
+
773
class If(Statement):
    """
    A compound ``if``/``elif``/``else`` statement.

    Branches are added with `create_if_branch`; statements for the final
    ``else`` go into `else_block`.
    """

    child_elements = ["if_blocks", "conditions", "else_block"]

    def __init__(self, parent_scope: Scope, parent_block: Block | None = None):
        # We model a "compound if statement" as a list of if blocks
        # (if/elif/elif etc), each with their own condition, with a final else
        # block. Note this is quite different from Python's AST for the same
        # thing, so conversion to AST is more complex because of this.
        self.if_blocks: list[Block] = []
        self.conditions: list[Expression] = []
        self._parent_block = parent_block
        self.else_block = Block(parent_scope, parent_block=self._parent_block)
        self._parent_scope = parent_scope

    def create_if_branch(self, condition: Expression) -> Block:
        """
        Create new if branch with a condition and return its block.
        """
        new_if = Block(self._parent_scope, parent_block=self._parent_block)
        self.if_blocks.append(new_if)
        self.conditions.append(condition)
        return new_if

    def finalize(self) -> Block | Statement:
        """Return `else_block` when no branch was added, otherwise this statement."""
        if not self.if_blocks:
            # Unusual case of no conditions, only default case, but it
            # simplifies other code to be able to handle this uniformly. We can
            # replace this if statement with a single unconditional block.
            return self.else_block
        return self

    def as_ast(self) -> py_ast.If:
        """
        Return nested `If` nodes: each ``elif`` is an `If` stored in the
        previous node's ``orelse``, and the else block fills the last ``orelse``.

        Raises AssertionError if no branch has been added.
        """
        if len(self.if_blocks) == 0:
            raise AssertionError("Should have called `finalize` on If")
        if_ast = empty_If()
        current_if = if_ast
        previous_if = None
        for condition, if_block in zip(self.conditions, self.if_blocks):
            current_if.test = condition.as_ast()
            # A branch with no statements must still have a body; emit `pass`
            # for it, as the other compound statements in this module do.
            current_if.body = if_block.as_ast_list(allow_empty=False)
            if previous_if is not None:
                previous_if.orelse.append(current_if)

            previous_if = current_if
            current_if = empty_If()

        if self.else_block.statements:
            assert previous_if is not None
            previous_if.orelse = self.else_block.as_ast_list()

        return if_ast
824
+
825
+
826
class With(Statement):
    """A ``with <context_expr> [as <target>]:`` statement with a `body` block."""

    child_elements = ["context_expr", "body"]

    def __init__(
        self,
        context_expr: Expression,
        target: Name | None = None,
        *,
        parent_scope: Scope,
        parent_block: Block | None = None,
    ):
        self.context_expr = context_expr
        self.target = target
        self._parent_scope = parent_scope
        self._parent_block = parent_block
        # The body shares the parent scope rather than opening a new one.
        self.body = Block(parent_scope, parent_block=parent_block)

    def as_ast(self) -> py_ast.With:
        """Return a `With` node holding a single with-item."""
        if self.target is None:
            as_target = None
        else:
            as_target = py_ast.Name(id=self.target.name, ctx=py_ast.Store(), **DEFAULT_AST_ARGS)

        item = py_ast.withitem(
            context_expr=self.context_expr.as_ast(),
            optional_vars=as_target,
        )
        # An empty body is replaced by `pass` so the statement stays valid.
        body_nodes = self.body.as_ast_list(allow_empty=False)
        return py_ast.With(items=[item], body=body_nodes, **DEFAULT_AST_ARGS)
858
+
859
+
860
class Try(Statement):
    """A ``try``/``except``/``else`` statement with a single except handler."""

    child_elements = ["catch_exceptions", "try_block", "except_block", "else_block"]

    def __init__(self, catch_exceptions: Sequence[Expression], parent_scope: Scope):
        self.catch_exceptions = catch_exceptions
        self.try_block = Block(parent_scope)
        self.except_block = Block(parent_scope)
        self.else_block = Block(parent_scope)

    def as_ast(self) -> py_ast.Try:
        """Return a `Try` node with one handler and no ``finally`` part."""
        try_body = self.try_block.as_ast_list(allow_empty=False)

        # A single exception is used directly; otherwise the handler catches a
        # tuple of all of them.
        if len(self.catch_exceptions) == 1:
            caught = self.catch_exceptions[0].as_ast()
        else:
            caught = py_ast.Tuple(
                elts=[e.as_ast() for e in self.catch_exceptions],
                ctx=py_ast.Load(),
                **DEFAULT_AST_ARGS,
            )

        handler = py_ast.ExceptHandler(
            type=caught,
            name=None,
            body=self.except_block.as_ast_list(allow_empty=False),
            **DEFAULT_AST_ARGS,
        )
        else_body = self.else_block.as_ast_list(allow_empty=True)
        return py_ast.Try(
            body=try_body,
            handlers=[handler],
            orelse=else_body,
            finalbody=[],
            **DEFAULT_AST_ARGS,
        )

    def has_assignment_for_name(self, name: str) -> bool:
        """
        Return True only if `name` is assigned in the except block and also in
        either the try block or the else block.
        """
        assigned_on_success = self.try_block.has_assignment_for_name(
            name
        ) or self.else_block.has_assignment_for_name(name)
        if not assigned_on_success:
            return False
        if self.except_block.has_assignment_for_name(name):
            return True
        return False
899
+
900
+
901
class Import(Statement):
    """
    Simple import statements, supporting:

    - import foo
    - import foo as bar
    - import foo.bar
    - import foo.bar as baz

    Use via `Block.create_import`

    Multiple imports in one statement are deliberately not supported - they
    can be merged later by running a linter over the generated code, if
    desired.
    """

    def __init__(self, module: str, as_: Name | None) -> None:
        self.module = module
        self.as_ = as_

    def as_ast(self) -> py_ast.AST:
        """
        Render as a `py_ast.Import` node holding a single alias.
        """
        if self.as_ is not None:
            imported = py_ast.alias(name=self.module, asname=self.as_.name)
        else:
            # No alias needed:
            imported = py_ast.alias(name=self.module)
        return py_ast.Import(names=[imported], **DEFAULT_AST_ARGS)
925
+
926
+
927
class ImportFrom(Statement):
    """
    `from ... import ...` statement, supporting:

    - from foo import bar
    - from foo import bar as baz

    Use via `Block.create_import_from`

    Multiple imports in one statement are deliberately not supported - they
    can be merged later by running a linter over the generated code, if
    desired.
    """

    def __init__(self, from_module: str, import_: str, as_: Name | None) -> None:
        self.from_module = from_module
        self.import_ = import_
        self.as_ = as_

    def as_ast(self) -> py_ast.AST:
        """
        Render as an absolute (`level=0`) `py_ast.ImportFrom` node holding a
        single alias.
        """
        if self.as_ is not None:
            imported = py_ast.alias(name=self.import_, asname=self.as_.name)
        else:
            # No alias needed:
            imported = py_ast.alias(name=self.import_)
        return py_ast.ImportFrom(
            module=self.from_module,
            names=[imported],
            level=0,
            **DEFAULT_AST_ARGS,
        )
960
+
961
+
962
class Expression(CodeGenAst):
    """
    Base class for codegen nodes that render to a Python expression.

    In addition to the abstract `as_ast`, it provides small helpers that wrap
    `self` in another expression node, so that expressions can be built up by
    chaining method calls.
    """

    @abstractmethod
    def as_ast(self) -> py_ast.expr: ...

    # ---- Chaining helpers -------------------------------------------------

    def attr(self, attribute: str, /) -> Attr:
        """Build an attribute access on this expression."""
        return Attr(self, attribute)

    def call(
        self,
        args: Sequence[Expression] | None = None,
        kwargs: dict[str, Expression] | None = None,
    ) -> Call:
        """Build a call of this expression; missing args/kwargs mean none."""
        positional = args if args else []
        keyword = kwargs if kwargs else {}
        return Call(self, positional, keyword)

    def method_call(
        self,
        attribute: str,
        args: Sequence[Expression] | None = None,
        kwargs: dict[str, Expression] | None = None,
    ) -> Call:
        """Build a call of the attribute `attribute` of this expression."""
        bound_method = self.attr(attribute)
        return bound_method.call(args, kwargs)

    # ---- Arithmetic operators ---------------------------------------------

    def add(self, other: Expression, /) -> Add:
        """Build an `Add` node with `self` on the left and `other` on the right."""
        return Add(self, other)

    def sub(self, other: Expression, /) -> Sub:
        """Build a `Sub` node with `self` on the left and `other` on the right."""
        return Sub(self, other)

    def mul(self, other: Expression, /) -> Mul:
        """Build a `Mul` node with `self` on the left and `other` on the right."""
        return Mul(self, other)

    def div(self, other: Expression, /) -> Div:
        """Build a `Div` node with `self` on the left and `other` on the right."""
        return Div(self, other)

    def floordiv(self, other: Expression, /) -> FloorDiv:
        """Build a `FloorDiv` node with `self` on the left and `other` on the right."""
        return FloorDiv(self, other)

    def mod(self, other: Expression, /) -> Mod:
        """Build a `Mod` node with `self` on the left and `other` on the right."""
        return Mod(self, other)

    def pow(self, other: Expression, /) -> Pow:
        """Build a `Pow` node with `self` on the left and `other` on the right."""
        return Pow(self, other)

    def matmul(self, other: Expression, /) -> MatMul:
        """Build a `MatMul` node with `self` on the left and `other` on the right."""
        return MatMul(self, other)

    # ---- Comparison operators ---------------------------------------------

    def eq(self, other: Expression, /) -> Equals:
        """Build an `Equals` comparison of `self` against `other`."""
        return Equals(self, other)

    def ne(self, other: Expression, /) -> NotEquals:
        """Build a `NotEquals` comparison of `self` against `other`."""
        return NotEquals(self, other)

    def lt(self, other: Expression, /) -> Lt:
        """Build an `Lt` comparison of `self` against `other`."""
        return Lt(self, other)

    def gt(self, other: Expression, /) -> Gt:
        """Build a `Gt` comparison of `self` against `other`."""
        return Gt(self, other)

    def le(self, other: Expression, /) -> LtE:
        """Build an `LtE` comparison of `self` against `other`."""
        return LtE(self, other)

    def ge(self, other: Expression, /) -> GtE:
        """Build a `GtE` comparison of `self` against `other`."""
        return GtE(self, other)

    # ---- Boolean operators ------------------------------------------------

    def and_(self, other: Expression, /) -> And:
        """Build an `And` node combining `self` and `other`."""
        return And(self, other)

    def or_(self, other: Expression, /) -> Or:
        """Build an `Or` node combining `self` and `other`."""
        return Or(self, other)

    # ---- Membership operators ---------------------------------------------

    def in_(self, other: Expression, /) -> In:
        """Build an `In` test with `self` on the left and `other` on the right."""
        return In(self, other)

    def not_in(self, other: Expression, /) -> NotIn:
        """Build a `NotIn` test with `self` on the left and `other` on the right."""
        return NotIn(self, other)

    # ---- Unpacking --------------------------------------------------------

    def starred(self) -> Starred:
        """Wrap this expression in a `Starred` node."""
        return Starred(self)
1052
+
1053
+
1054
class String(Expression):
    """
    A string literal.

    Two `String` objects compare equal when their `string_value`s are equal.
    Because `__eq__` is defined here without `__hash__`, instances of this
    class are not hashable.
    """

    child_elements = []

    type = str

    def __init__(self, string_value: str):
        self.string_value = string_value

    def as_ast(self) -> py_ast.expr:
        literal = py_ast.Constant(
            self.string_value,
            kind=None,  # 3.8, indicates no prefix, needed only for tests
            **DEFAULT_AST_ARGS,
        )
        return literal

    def __repr__(self):
        return f"String({self.string_value!r})"

    def __eq__(self, other: object):
        if not isinstance(other, String):
            return False
        return other.string_value == self.string_value
1074
+
1075
+
1076
class Bool(Expression):
    """
    A boolean literal, rendered as a `py_ast.Constant` holding `value`.
    """

    child_elements = []

    type = bool

    def __init__(self, value: bool):
        self.value = value

    def as_ast(self) -> py_ast.expr:
        literal = py_ast.Constant(self.value, **DEFAULT_AST_ARGS)
        return literal

    def __repr__(self):
        return f"Bool({repr(self.value)})"
1089
+
1090
+
1091
class Bytes(Expression):
    """
    A bytes literal, rendered as a `py_ast.Constant` holding `value`.
    """

    child_elements = []

    type = bytes

    def __init__(self, value: bytes):
        self.value = value

    def as_ast(self) -> py_ast.expr:
        literal = py_ast.Constant(self.value, **DEFAULT_AST_ARGS)
        return literal

    def __repr__(self):
        return f"Bytes({repr(self.value)})"
1104
+
1105
+
1106
class Number(Expression):
    """
    An int or float literal, rendered as a `py_ast.Constant` holding `number`.
    """

    child_elements = []

    def __init__(self, number: int | float):
        self.number = number

    def as_ast(self) -> py_ast.expr:
        literal = py_ast.Constant(self.number, **DEFAULT_AST_ARGS)
        return literal

    def __repr__(self):
        return f"Number({self.number!r})"
1117
+
1118
+
1119
class List(Expression):
    """
    A list display (`[a, b, ...]`) built from the expressions in `items`.
    """

    child_elements = ["items"]

    def __init__(self, items: list[Expression]):
        self.items = items

    def as_ast(self) -> py_ast.expr:
        elements = [item.as_ast() for item in self.items]
        return py_ast.List(elts=elements, ctx=py_ast.Load(), **DEFAULT_AST_ARGS)
1127
+
1128
+
1129
class Tuple(Expression):
    """
    A tuple display (`(a, b, ...)`) built from the expressions in `items`.
    """

    child_elements = ["items"]

    def __init__(self, items: Sequence[Expression]):
        self.items = items

    def as_ast(self) -> py_ast.expr:
        elements = [item.as_ast() for item in self.items]
        return py_ast.Tuple(elts=elements, ctx=py_ast.Load(), **DEFAULT_AST_ARGS)
1137
+
1138
+
1139
class Set(Expression):
    """
    A set display (`{a, b, ...}`) built from the expressions in `items`.

    With no items, a call `set([])` is emitted instead of a display.
    """

    child_elements = ["items"]

    def __init__(self, items: Sequence[Expression]):
        self.items = items

    def as_ast(self) -> py_ast.expr:
        if len(self.items) > 0:
            elements = [item.as_ast() for item in self.items]
            return py_ast.Set(elts=elements, **DEFAULT_AST_ARGS)

        # {} is a dict literal in Python, so empty sets must use set([])
        set_name = py_ast.Name(id="set", ctx=py_ast.Load(), **DEFAULT_AST_ARGS)
        empty_list = py_ast.List(elts=[], ctx=py_ast.Load(), **DEFAULT_AST_ARGS)
        return py_ast.Call(
            func=set_name,
            args=[empty_list],
            keywords=[],
            **DEFAULT_AST_ARGS,
        )
1155
+
1156
+
1157
class Dict(Expression):
    """
    A dict display (`{k: v, ...}`) built from `pairs` of (key, value)
    expressions, in the order given.
    """

    child_elements = ["pairs"]

    def __init__(self, pairs: Sequence[tuple[Expression, Expression]]):
        self.pairs = pairs

    def as_ast(self) -> py_ast.expr:
        # All keys are rendered first, then all values.
        key_nodes = [key.as_ast() for key, _ in self.pairs]
        value_nodes = [value.as_ast() for _, value in self.pairs]
        return py_ast.Dict(keys=key_nodes, values=value_nodes, **DEFAULT_AST_ARGS)
1169
+
1170
+
1171
class StringJoinBase(Expression):
    """
    Common base for expressions that join several `parts` into one string.

    Subclasses supply `as_ast`; this class holds the parts and provides
    `build`, which simplifies the join where it can.
    """

    child_elements = ["parts"]

    type = str

    def __init__(self, parts: Sequence[Expression]):
        self.parts = parts

    def __repr__(self):
        rendered_parts = ", ".join(map(repr, self.parts))
        return f"{self.__class__.__name__}([{rendered_parts}])"

    @classmethod
    def build(cls: type[StringJoinBase], parts: Sequence[Expression]) -> StringJoinBase | Expression:
        """
        Build a string join operation, but return a simpler expression if possible.

        Neighbouring `String` parts are merged into one. If nothing is left,
        an empty `String` is returned; if a single part is left, that part is
        returned as is; otherwise an instance of `cls` is returned.
        """
        # Merge adjacent String objects.
        merged: list[Expression] = []
        for part in parts:
            previous = merged[-1] if merged else None
            if isinstance(previous, String) and isinstance(part, String):
                merged[-1] = String(previous.string_value + part.string_value)
            else:
                merged.append(part)

        # See if we can eliminate the StringJoin altogether
        remaining = len(merged)
        if remaining == 0:
            return String("")
        if remaining == 1:
            return merged[0]
        return cls(merged)
1202
+
1203
+
1204
class FStringJoin(StringJoinBase):
    """
    Joins the parts by emitting an f-string (`py_ast.JoinedStr`).
    """

    def as_ast(self) -> py_ast.expr:
        # `String` parts become the literal text of the f-string; every other
        # part becomes a `{...}` replacement field with no conversion and no
        # format spec.
        values: list[py_ast.expr] = [
            (
                part.as_ast()
                if isinstance(part, String)
                else py_ast.FormattedValue(
                    value=part.as_ast(),
                    conversion=-1,
                    format_spec=None,
                    **DEFAULT_AST_ARGS,
                )
            )
            for part in self.parts
        ]
        return py_ast.JoinedStr(values=values, **DEFAULT_AST_ARGS)
1221
+
1222
+
1223
class ConcatJoin(StringJoinBase):
    """
    Joins the parts by chaining them with the `+` operator.
    """

    def as_ast(self) -> py_ast.expr:
        """
        Render as a left-nested chain of `py_ast.BinOp` additions:
        `((a + b) + c) + ...`. A single part is rendered on its own, and
        no parts at all render as the empty string literal.
        """
        if len(self.parts) == 0:
            # Nothing to concatenate. `StringJoinBase.build` never creates
            # such an instance, but direct construction can; render the same
            # empty string that `build` would return rather than failing
            # with an IndexError below.
            return String("").as_ast()

        # Concatenate with +
        left = self.parts[0].as_ast()
        for part in self.parts[1:]:
            right = part.as_ast()
            left = py_ast.BinOp(
                left=left,
                op=py_ast.Add(**DEFAULT_AST_ARGS_ADD),
                right=right,
                **DEFAULT_AST_ARGS,
            )
        return left
1236
+
1237
+
1238
# For CPython, f-strings give a measurable improvement over concatenation,
# so make FStringJoin the default implementation behind the `StringJoin` name.

StringJoin = FStringJoin
1242
+
1243
+
1244
class Name(Expression):
    """
    A reference to a name that is already in use in `scope`.

    Construction fails with AssertionError if `scope.is_name_in_use(name)` is
    false. Equality is by exact class and `name`; since `__eq__` is defined
    here without `__hash__`, instances of this class are not hashable.
    """

    child_elements = []

    def __init__(self, name: str, scope: Scope):
        name_is_known = scope.is_name_in_use(name)
        if not name_is_known:
            raise AssertionError(f"Cannot refer to undefined name '{name}'")
        self.name = name

    def as_ast(self) -> py_ast.expr:
        """
        Render as a `py_ast.Name` in Load context.

        Raises AssertionError if `allowable_name` rejects the name.
        """
        if allowable_name(self.name, allow_builtin=True):
            return py_ast.Name(id=self.name, ctx=py_ast.Load(), **DEFAULT_AST_ARGS)
        raise AssertionError(f"Expected {self.name} to be a valid Python identifier")

    def __eq__(self, other: object):
        if type(other) is not type(self):
            return False
        return other.name == self.name

    def __repr__(self):
        return f"Name({self.name!r})"
1262
+
1263
+
1264
class Attr(Expression):
    """
    Attribute access on another expression: `<value>.<attribute>`.

    Construction fails with AssertionError if `allowable_name` rejects the
    attribute name.
    """

    child_elements = ["value"]

    def __init__(self, value: Expression, attribute: str) -> None:
        self.value = value
        attribute_ok = allowable_name(attribute, allow_builtin=True)
        if not attribute_ok:
            raise AssertionError(f"Expected {attribute} to be a valid Python identifier")
        self.attribute = attribute

    def as_ast(self) -> py_ast.expr:
        target = self.value.as_ast()
        return py_ast.Attribute(value=target, attr=self.attribute, **DEFAULT_AST_ARGS)
1275
+
1276
+
1277
class Starred(Expression):
    """
    A starred expression (`*value`), rendered as `py_ast.Starred` in Load
    context.
    """

    child_elements = ["value"]

    def __init__(self, value: Expression):
        self.value = value

    def as_ast(self) -> py_ast.expr:
        unpacked = self.value.as_ast()
        return py_ast.Starred(value=unpacked, ctx=py_ast.Load(), **DEFAULT_AST_ARGS)

    def __repr__(self):
        return f"Starred({repr(self.value)})"
1288
+
1289
+
1290
def function_call(
    function_name: str,
    args: Sequence[Expression],
    kwargs: dict[str, Expression],
    scope: Scope,
) -> Expression:
    """
    Build a call to the function named `function_name`, looked up in `scope`.

    Raises AssertionError if the name is not in use in `scope`, or if it is
    listed in `SENSITIVE_FUNCTIONS`.
    """
    name_is_known = scope.is_name_in_use(function_name)
    if not name_is_known:
        raise AssertionError(f"Cannot call unknown function '{function_name}'")

    is_sensitive = function_name in SENSITIVE_FUNCTIONS
    if is_sensitive:
        raise AssertionError(f"Disallowing call to '{function_name}'")

    callee = Name(name=function_name, scope=scope)
    return callee.call(args, kwargs)
1302
+
1303
+
1304
class Call(Expression):
    """
    A call expression: `value(*args, **kwargs)`.

    `args` is copied into a new list; `kwargs` is stored as passed.
    """

    child_elements = ["value", "args", "kwargs"]

    def __init__(
        self,
        value: Expression,
        args: Sequence[Expression],
        kwargs: dict[str, Expression],
    ):
        self.value = value
        self.args = list(args)
        self.kwargs = kwargs

    def as_ast(self) -> py_ast.expr:
        """
        Render as a `py_ast.Call` node.

        Raises AssertionError if `allowable_keyword_arg_name` rejects any
        keyword name.
        """
        for kwarg_name in self.kwargs:
            if not allowable_keyword_arg_name(kwarg_name):
                raise AssertionError(f"Expected {kwarg_name} to be a valid Fluent NamedArgument name")

        # True when at least one keyword name is not accepted by `allowable_name`
        needs_dict_unpacking = not all(allowable_name(kwarg_name) for kwarg_name in self.kwargs)
        sorted_pairs = sorted(self.kwargs.items()) if needs_dict_unpacking else None

        callee = self.value.as_ast()
        positional = [arg.as_ast() for arg in self.args]

        if sorted_pairs is not None:
            # This branch covers function arg names like 'foo-bar', which are
            # allowable in languages like Fluent, but not normally in Python. We work around
            # this using `my_function(**{'foo-bar': baz})` syntax.

            # (If we only wanted to exec the resulting AST, this branch is technically not
            # necessary, since it is the Python parser that disallows `foo-bar` as an identifier,
            # and we are by-passing that by creating AST directly. However, to produce something
            # that can be decompiled to valid Python, we solve the general case).
            key_nodes = [py_ast.Constant(key, kind=None, **DEFAULT_AST_ARGS) for key, _ in sorted_pairs]
            value_nodes = [value.as_ast() for _, value in sorted_pairs]
            keywords = [
                py_ast.keyword(
                    arg=None,
                    value=py_ast.Dict(keys=key_nodes, values=value_nodes, **DEFAULT_AST_ARGS),
                    **DEFAULT_AST_ARGS,
                )
            ]
        else:
            # Normal `my_function(foo=bar)` syntax
            keywords = [
                py_ast.keyword(arg=kwarg_name, value=kwarg_value.as_ast(), **DEFAULT_AST_ARGS)
                for kwarg_name, kwarg_value in self.kwargs.items()
            ]

        return py_ast.Call(func=callee, args=positional, keywords=keywords, **DEFAULT_AST_ARGS)

    def __repr__(self):
        return f"Call({repr(self.value)}, {self.args}, {self.kwargs})"
1365
+
1366
+
1367
def method_call(
    obj: Expression,
    method_name: str,
    args: Sequence[Expression],
    kwargs: dict[str, Expression],
):
    """
    Build a call of the method `method_name` on `obj`.

    Looks up the attribute via `obj.attr(method_name)` and then calls the
    result with the given `args` and `kwargs`.
    """
    bound_method = obj.attr(method_name)
    return bound_method.call(args=args, kwargs=kwargs)
1374
+
1375
+
1376
class DictLookup(Expression):
    """
    A subscript expression: `lookup_obj[lookup_arg]`, in Load context.
    """

    child_elements = ["lookup_obj", "lookup_arg"]

    def __init__(self, lookup_obj: Expression, lookup_arg: Expression):
        self.lookup_obj = lookup_obj
        self.lookup_arg = lookup_arg

    def as_ast(self) -> py_ast.expr:
        container = self.lookup_obj.as_ast()
        index = py_ast.subscript_slice_object(self.lookup_arg.as_ast())
        return py_ast.Subscript(
            value=container,
            slice=index,
            ctx=py_ast.Load(),
            **DEFAULT_AST_ARGS,
        )
1390
+
1391
+
1392
# Creating an instance is the same as calling the class by name, so this is
# simply another name for `function_call` (same arguments, same checks).
create_class_instance = function_call
1393
+
1394
+
1395
class NoneExpr(Expression):
    """
    The `None` literal, rendered as a `py_ast.Constant` holding `None`.
    """

    # Leaf node: declared explicitly, like the other literal expressions, so
    # that `rewriting_traverse` (which reads `child_elements` on every node)
    # finds nothing to descend into here.
    child_elements = []

    type = type(None)

    def as_ast(self) -> py_ast.expr:
        return py_ast.Constant(value=None, **DEFAULT_AST_ARGS)
1400
+
1401
+
1402
class BinaryOperator(Expression):
    """
    Base for expressions that combine a `left` and a `right` operand.
    """

    child_elements = ["left", "right"]

    def __init__(self, left: Expression, right: Expression):
        self.left, self.right = left, right
1408
+
1409
+
1410
class ArithOp(BinaryOperator, ABC):
    """Arithmetic binary operator (ast.BinOp)."""

    # The operator node class to instantiate; set by each concrete subclass.
    op: ClassVar[type[py_ast.operator]]

    def as_ast(self) -> py_ast.expr:
        lhs = self.left.as_ast()
        operator_node = self.op(**DEFAULT_AST_ARGS_ADD)
        rhs = self.right.as_ast()
        return py_ast.BinOp(left=lhs, op=operator_node, right=rhs, **DEFAULT_AST_ARGS)
1422
+
1423
+
1424
class Add(ArithOp):
    """Arithmetic operator rendered with `py_ast.Add` (`left + right`)."""

    op = py_ast.Add
1426
+
1427
+
1428
class Sub(ArithOp):
    """Arithmetic operator rendered with `py_ast.Sub` (`left - right`)."""

    op = py_ast.Sub
1430
+
1431
+
1432
class Mul(ArithOp):
    """Arithmetic operator rendered with `py_ast.Mult` (`left * right`)."""

    op = py_ast.Mult
1434
+
1435
+
1436
class Div(ArithOp):
    """Arithmetic operator rendered with `py_ast.Div` (`left / right`)."""

    op = py_ast.Div
1438
+
1439
+
1440
class FloorDiv(ArithOp):
    """Arithmetic operator rendered with `py_ast.FloorDiv` (`left // right`)."""

    op = py_ast.FloorDiv
1442
+
1443
+
1444
class Mod(ArithOp):
    """Arithmetic operator rendered with `py_ast.Mod` (`left % right`)."""

    op = py_ast.Mod
1446
+
1447
+
1448
class Pow(ArithOp):
    """Arithmetic operator rendered with `py_ast.Pow` (`left ** right`)."""

    op = py_ast.Pow
1450
+
1451
+
1452
class MatMul(ArithOp):
    """Arithmetic operator rendered with `py_ast.MatMult` (`left @ right`)."""

    op = py_ast.MatMult
1454
+
1455
+
1456
class CompareOp(BinaryOperator, ABC):
    """Comparison operator (ast.Compare)."""

    type = bool
    # The comparison node class to instantiate; set by each concrete subclass.
    op: ClassVar[type[py_ast.cmpop]]

    def as_ast(self) -> py_ast.expr:
        # A single comparison: one operator and one comparator.
        lhs = self.left.as_ast()
        rhs = self.right.as_ast()
        comparison = self.op()
        return py_ast.Compare(left=lhs, comparators=[rhs], ops=[comparison], **DEFAULT_AST_ARGS)
1469
+
1470
+
1471
class Equals(CompareOp):
    """Comparison rendered with `py_ast.Eq` (`left == right`)."""

    op = py_ast.Eq
1473
+
1474
+
1475
class NotEquals(CompareOp):
    """Comparison rendered with `py_ast.NotEq` (`left != right`)."""

    op = py_ast.NotEq
1477
+
1478
+
1479
class Lt(CompareOp):
    """Comparison rendered with `py_ast.Lt` (`left < right`)."""

    op = py_ast.Lt
1481
+
1482
+
1483
class Gt(CompareOp):
    """Comparison rendered with `py_ast.Gt` (`left > right`)."""

    op = py_ast.Gt
1485
+
1486
+
1487
class LtE(CompareOp):
    """Comparison rendered with `py_ast.LtE` (`left <= right`)."""

    op = py_ast.LtE
1489
+
1490
+
1491
class GtE(CompareOp):
    """Comparison rendered with `py_ast.GtE` (`left >= right`)."""

    op = py_ast.GtE
1493
+
1494
+
1495
class In(CompareOp):
    """Membership test rendered with `py_ast.In` (`left in right`)."""

    op = py_ast.In
1497
+
1498
+
1499
class NotIn(CompareOp):
    """Membership test rendered with `py_ast.NotIn` (`left not in right`)."""

    op = py_ast.NotIn
1501
+
1502
+
1503
class BoolOp(BinaryOperator, ABC):
    """Boolean operator over two operands (ast.BoolOp)."""

    type = bool
    # The boolean operator node class to instantiate; set by each concrete subclass.
    op: ClassVar[type[py_ast.boolop]]

    def as_ast(self) -> py_ast.expr:
        operator_node = self.op()
        operands = [self.left.as_ast(), self.right.as_ast()]
        return py_ast.BoolOp(op=operator_node, values=operands, **DEFAULT_AST_ARGS)
1513
+
1514
+
1515
class And(BoolOp):
    """Boolean operator rendered with `py_ast.And` (`left and right`)."""

    op = py_ast.And
1517
+
1518
+
1519
class Or(BoolOp):
    """Boolean operator rendered with `py_ast.Or` (`left or right`)."""

    op = py_ast.Or
1521
+
1522
+
1523
def traverse(ast_node: py_ast.AST, func: Callable[[py_ast.AST], None]):
    """
    Call 'func' once for every node yielded by `py_ast.walk(ast_node)`
    (where ast_node is an `ast.*` object), in the order they are yielded.
    """
    walker = py_ast.walk(ast_node)
    for visited in walker:
        func(visited)
1529
+
1530
+
1531
def simplify(codegen_ast: CodeGenAstType, simplifier: Callable[[CodeGenAstType, list[bool]], CodeGenAst]):
    """
    Repeatedly run `simplifier` over the whole tree until a pass records no
    change, then return `codegen_ast`.

    `simplifier` is called as `simplifier(node, changes)`. Another pass is run
    for as long as the `changes` list contains a true value after a pass.
    """
    # Seeded with True so that at least one pass is made.
    changes: list[bool] = [True]

    # `rewriting_traverse` expects a one-argument callable, so bind the shared
    # `changes` list here.
    def apply_simplifier(node: CodeGenAstType) -> CodeGenAstType:
        return simplifier(node, changes)

    while any(changes):
        changes.clear()
        rewriting_traverse(codegen_ast, apply_simplifier)
    return codegen_ast
1543
+
1544
+
1545
def rewriting_traverse(
    node: CodeGenAstType | Sequence[CodeGenAstType] | dict[str, CodeGenAstType],
    func: Callable[[CodeGenAstType], CodeGenAstType],
):
    """
    Apply 'func' to node and all sub CodeGenAst nodes

    For a codegen node, `func(node)` is called first; if it returns a
    different object, `node` is morphed into it in place (see `morph_into`),
    and then every attribute named in the (possibly new) `child_elements` is
    traversed. Lists and tuples are traversed element by element, and dicts
    value by value. Anything else is ignored.
    """
    if isinstance(node, (CodeGenAst, CodeGenAstList)):
        new_node = func(node)
        if new_node is not node:
            morph_into(node, new_node)
        # Read child_elements after morphing, so the replacement's children
        # are the ones visited.
        for k in node.child_elements:
            rewriting_traverse(getattr(node, k), func)
    elif isinstance(node, (list, tuple)):
        for i in node:
            rewriting_traverse(i, func)
    elif isinstance(node, dict):
        # Needed for dict-valued children such as `Call.kwargs`; without this
        # branch keyword-argument expressions would never be visited.
        for value in node.values():
            rewriting_traverse(value, func)
1561
+
1562
+
1563
def morph_into(item: object, new_item: object) -> None:
    """
    Make `item` take on the class and instance dictionary of `new_item`.
    """
    # This naughty little function lets `item` behave like `new_item` in
    # every way while keeping the identity of `item`, so that a tree of
    # objects does not have to be rebuilt with new objects. Note that the
    # instance dictionary is shared with `new_item`, not copied.
    replacement_class = new_item.__class__
    replacement_state = new_item.__dict__
    item.__class__ = replacement_class
    item.__dict__ = replacement_state
1569
+
1570
+
1571
def empty_If() -> py_ast.If:
    """
    Create an `If` ast node with no test and an empty `orelse` list.

    The `test` attribute is left as None here and must be set by the caller
    afterwards.
    """
    blank = py_ast.If(test=None, orelse=[], **DEFAULT_AST_ARGS)  # type: ignore[reportArgumentType]
    return blank
1577
+
1578
+
1579
# The plain Python values accepted by `auto`: the scalar types below, plus
# (recursively) lists, tuples, sets, frozensets and dicts built from them.
type PythonObj = (
    bool
    | str
    | bytes
    | int
    | float
    | None
    | list[PythonObj]
    | tuple[PythonObj, ...]
    | set[PythonObj]
    | frozenset[PythonObj]
    | dict[PythonObj, PythonObj]
)
1592
+
1593
+
1594
+ @overload
1595
+ def auto(value: bool) -> Bool: ... # type: ignore[overload-overlap] # bool before int/float is intentional
1596
+ @overload
1597
+ def auto(value: str) -> String: ...
1598
+ @overload
1599
+ def auto(value: bytes) -> Bytes: ...
1600
+ @overload
1601
+ def auto(value: int) -> Number: ...
1602
+ @overload
1603
+ def auto(value: float) -> Number: ...
1604
+ @overload
1605
+ def auto(value: None) -> NoneExpr: ...
1606
+ @overload
1607
+ def auto(value: list[PythonObj]) -> List: ...
1608
+ @overload
1609
+ def auto(value: tuple[PythonObj, ...]) -> Tuple: ...
1610
+ @overload
1611
+ def auto(value: set[PythonObj]) -> Set: ...
1612
+ @overload
1613
+ def auto(value: frozenset[PythonObj]) -> Set: ...
1614
+ @overload
1615
+ def auto(value: dict[PythonObj, PythonObj]) -> Dict: ...
1616
+
1617
+
1618
+ def auto(value: PythonObj) -> Expression:
1619
+ """
1620
+ Create a codegen Expression from a plain Python object.
1621
+
1622
+ Supports bool, str, bytes, int, float, None, and recursively
1623
+ list, tuple, set, frozenset, and dict.
1624
+ """
1625
+ if isinstance(value, bool):
1626
+ return Bool(value)
1627
+ elif isinstance(value, str):
1628
+ return String(value)
1629
+ elif isinstance(value, bytes):
1630
+ return Bytes(value)
1631
+ elif isinstance(value, (int, float)):
1632
+ return Number(value)
1633
+ elif value is None:
1634
+ return NoneExpr()
1635
+ elif isinstance(value, list):
1636
+ return List([auto(item) for item in value])
1637
+ elif isinstance(value, tuple):
1638
+ return Tuple([auto(item) for item in value])
1639
+ elif isinstance(value, (set, frozenset)):
1640
+ return Set([auto(item) for item in sorted(value, key=repr)])
1641
+ elif isinstance(value, dict): # type: ignore[reportUnnecessaryIsInstance]
1642
+ return Dict([(auto(k), auto(v)) for k, v in value.items()])
1643
+ assert_never(value)
1644
+
1645
+
1646
class constants:
    """
    Ready-made Expression objects for the `None`, `True` and `False` literals.
    """

    None_: NoneExpr = NoneExpr()
    True_: Bool = Bool(True)
    False_: Bool = Bool(False)