sonolus.py 0.1.3__py3-none-any.whl → 0.1.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of sonolus.py might be problematic. Click here for more details.

Files changed (90) hide show
  1. sonolus/backend/blocks.py +756 -756
  2. sonolus/backend/excepthook.py +37 -37
  3. sonolus/backend/finalize.py +77 -69
  4. sonolus/backend/interpret.py +7 -7
  5. sonolus/backend/ir.py +29 -3
  6. sonolus/backend/mode.py +24 -24
  7. sonolus/backend/node.py +40 -40
  8. sonolus/backend/ops.py +197 -197
  9. sonolus/backend/optimize/__init__.py +0 -0
  10. sonolus/backend/optimize/allocate.py +126 -0
  11. sonolus/backend/optimize/constant_evaluation.py +374 -0
  12. sonolus/backend/optimize/copy_coalesce.py +85 -0
  13. sonolus/backend/optimize/dead_code.py +185 -0
  14. sonolus/backend/optimize/dominance.py +96 -0
  15. sonolus/backend/{flow.py → optimize/flow.py} +122 -92
  16. sonolus/backend/optimize/inlining.py +137 -0
  17. sonolus/backend/optimize/liveness.py +177 -0
  18. sonolus/backend/optimize/optimize.py +44 -0
  19. sonolus/backend/optimize/passes.py +52 -0
  20. sonolus/backend/optimize/simplify.py +191 -0
  21. sonolus/backend/optimize/ssa.py +200 -0
  22. sonolus/backend/place.py +17 -25
  23. sonolus/backend/utils.py +58 -48
  24. sonolus/backend/visitor.py +1151 -882
  25. sonolus/build/cli.py +7 -1
  26. sonolus/build/compile.py +88 -90
  27. sonolus/build/engine.py +10 -5
  28. sonolus/build/level.py +24 -23
  29. sonolus/build/node.py +43 -43
  30. sonolus/script/archetype.py +438 -139
  31. sonolus/script/array.py +27 -10
  32. sonolus/script/array_like.py +297 -0
  33. sonolus/script/bucket.py +253 -191
  34. sonolus/script/containers.py +257 -51
  35. sonolus/script/debug.py +26 -10
  36. sonolus/script/easing.py +365 -0
  37. sonolus/script/effect.py +191 -131
  38. sonolus/script/engine.py +71 -4
  39. sonolus/script/globals.py +303 -269
  40. sonolus/script/instruction.py +205 -151
  41. sonolus/script/internal/__init__.py +5 -5
  42. sonolus/script/internal/builtin_impls.py +255 -144
  43. sonolus/script/{callbacks.py → internal/callbacks.py} +127 -127
  44. sonolus/script/internal/constant.py +139 -0
  45. sonolus/script/internal/context.py +26 -9
  46. sonolus/script/internal/descriptor.py +17 -17
  47. sonolus/script/internal/dict_impl.py +65 -0
  48. sonolus/script/internal/generic.py +6 -9
  49. sonolus/script/internal/impl.py +38 -13
  50. sonolus/script/internal/introspection.py +17 -14
  51. sonolus/script/internal/math_impls.py +121 -0
  52. sonolus/script/internal/native.py +40 -38
  53. sonolus/script/internal/random.py +67 -0
  54. sonolus/script/internal/range.py +81 -0
  55. sonolus/script/internal/transient.py +51 -0
  56. sonolus/script/internal/tuple_impl.py +113 -0
  57. sonolus/script/internal/value.py +3 -3
  58. sonolus/script/interval.py +338 -112
  59. sonolus/script/iterator.py +167 -214
  60. sonolus/script/level.py +24 -0
  61. sonolus/script/num.py +80 -48
  62. sonolus/script/options.py +257 -191
  63. sonolus/script/particle.py +190 -157
  64. sonolus/script/pointer.py +30 -30
  65. sonolus/script/print.py +102 -81
  66. sonolus/script/project.py +8 -0
  67. sonolus/script/quad.py +263 -0
  68. sonolus/script/record.py +47 -16
  69. sonolus/script/runtime.py +52 -1
  70. sonolus/script/sprite.py +418 -333
  71. sonolus/script/text.py +409 -407
  72. sonolus/script/timing.py +114 -42
  73. sonolus/script/transform.py +332 -48
  74. sonolus/script/ui.py +216 -160
  75. sonolus/script/values.py +6 -13
  76. sonolus/script/vec.py +196 -78
  77. {sonolus_py-0.1.3.dist-info → sonolus_py-0.1.5.dist-info}/METADATA +1 -1
  78. sonolus_py-0.1.5.dist-info/RECORD +89 -0
  79. {sonolus_py-0.1.3.dist-info → sonolus_py-0.1.5.dist-info}/WHEEL +1 -1
  80. {sonolus_py-0.1.3.dist-info → sonolus_py-0.1.5.dist-info}/licenses/LICENSE +21 -21
  81. sonolus/backend/allocate.py +0 -51
  82. sonolus/backend/optimize.py +0 -9
  83. sonolus/backend/passes.py +0 -6
  84. sonolus/backend/simplify.py +0 -30
  85. sonolus/script/comptime.py +0 -160
  86. sonolus/script/graphics.py +0 -150
  87. sonolus/script/math.py +0 -92
  88. sonolus/script/range.py +0 -58
  89. sonolus_py-0.1.3.dist-info/RECORD +0 -75
  90. {sonolus_py-0.1.3.dist-info → sonolus_py-0.1.5.dist-info}/entry_points.txt +0 -0
@@ -1,882 +1,1151 @@
1
- # ruff: noqa: N802
2
- import ast
3
- import functools
4
- import inspect
5
- from collections.abc import Callable, Mapping
6
- from contextlib import contextmanager
7
- from types import FunctionType, MethodType
8
- from typing import Any, Never
9
-
10
- from sonolus.backend.excepthook import install_excepthook
11
- from sonolus.backend.utils import get_function, scan_writes
12
- from sonolus.script.debug import assert_true
13
- from sonolus.script.internal.builtin_impls import BUILTIN_IMPLS
14
- from sonolus.script.internal.context import Context, EmptyBinding, Scope, ValueBinding, ctx, set_ctx
15
- from sonolus.script.internal.descriptor import SonolusDescriptor
16
- from sonolus.script.internal.error import CompilationError
17
- from sonolus.script.internal.impl import try_validate_value, validate_value
18
- from sonolus.script.internal.value import Value
19
- from sonolus.script.iterator import SonolusIterator
20
- from sonolus.script.num import Num, is_num
21
-
22
- _compiler_internal_ = True
23
-
24
-
25
- def compile_and_call[**P, R](fn: Callable[P, R], /, *args: P.args, **kwargs: P.kwargs) -> R:
26
- if not ctx():
27
- return fn(*args, **kwargs)
28
- return validate_value(generate_fn_impl(fn)(*args, **kwargs))
29
-
30
-
31
- def generate_fn_impl(fn: Callable):
32
- install_excepthook()
33
- match fn:
34
- case Value() as value if value._is_py_():
35
- return generate_fn_impl(value._as_py_())
36
- case MethodType() as method:
37
- return functools.partial(generate_fn_impl(method.__func__), method.__self__)
38
- case FunctionType() as function:
39
- if getattr(function, "_meta_fn_", False):
40
- return function
41
- return functools.partial(eval_fn, function)
42
- case _:
43
- if callable(fn) and isinstance(fn, Value):
44
- return generate_fn_impl(fn.__call__)
45
- elif fn is type:
46
- return fn
47
- elif callable(fn):
48
- raise TypeError(f"Unsupported callable {fn!r}")
49
- else:
50
- raise TypeError(f"Not callable {fn!r}")
51
-
52
-
53
- def eval_fn(fn: Callable, /, *args, **kwargs):
54
- source_file, node = get_function(fn)
55
- bound_args = inspect.signature(fn).bind(*args, **kwargs)
56
- bound_args.apply_defaults()
57
- closurevars = inspect.getclosurevars(fn)
58
- global_vars = {**closurevars.nonlocals, **closurevars.globals, **closurevars.builtins}
59
- return Visitor(source_file, bound_args, global_vars).run(node)
60
-
61
-
62
- unary_ops = {
63
- ast.Invert: "__invert__",
64
- ast.UAdd: "__pos__",
65
- ast.USub: "__neg__",
66
- }
67
-
68
- bin_ops = {
69
- ast.Add: "__add__",
70
- ast.Sub: "__sub__",
71
- ast.Mult: "__mul__",
72
- ast.Div: "__truediv__",
73
- ast.FloorDiv: "__floordiv__",
74
- ast.Mod: "__mod__",
75
- ast.Pow: "__pow__",
76
- ast.LShift: "__lshift__",
77
- ast.RShift: "__rshift__",
78
- ast.BitOr: "__or__",
79
- ast.BitAnd: "__and__",
80
- ast.BitXor: "__xor__",
81
- ast.MatMult: "__matmul__",
82
- }
83
-
84
- rbin_ops = {
85
- ast.Add: "__radd__",
86
- ast.Sub: "__rsub__",
87
- ast.Mult: "__rmul__",
88
- ast.Div: "__rtruediv__",
89
- ast.FloorDiv: "__rfloordiv__",
90
- ast.Mod: "__rmod__",
91
- ast.Pow: "__rpow__",
92
- ast.LShift: "__rlshift__",
93
- ast.RShift: "__rrshift__",
94
- ast.BitOr: "__ror__",
95
- ast.BitAnd: "__rand__",
96
- ast.BitXor: "__rxor__",
97
- ast.MatMult: "__rmatmul__",
98
- }
99
-
100
- inplace_ops = {
101
- ast.Add: "__iadd__",
102
- ast.Sub: "__isub__",
103
- ast.Mult: "__imul__",
104
- ast.Div: "__itruediv__",
105
- ast.FloorDiv: "__ifloordiv__",
106
- ast.Mod: "__imod__",
107
- ast.Pow: "__ipow__",
108
- ast.LShift: "__ilshift__",
109
- ast.RShift: "__irshift__",
110
- ast.BitOr: "__ior__",
111
- ast.BitXor: "__ixor__",
112
- ast.BitAnd: "__iand__",
113
- ast.MatMult: "__imatmul__",
114
- }
115
-
116
- comp_ops = {
117
- ast.Eq: "__eq__",
118
- ast.NotEq: "__ne__",
119
- ast.Lt: "__lt__",
120
- ast.LtE: "__le__",
121
- ast.Gt: "__gt__",
122
- ast.GtE: "__ge__",
123
- }
124
-
125
- rcomp_ops = {
126
- ast.Eq: "__req__",
127
- ast.NotEq: "__rne__",
128
- ast.Lt: "__gt__",
129
- ast.LtE: "__ge__",
130
- ast.Gt: "__lt__",
131
- ast.GtE: "__le__",
132
- ast.In: "__contains__",
133
- ast.NotIn: "__contains__",
134
- }
135
-
136
-
137
- class Visitor(ast.NodeVisitor):
138
- source_file: str
139
- globals: dict[str, Any]
140
- bound_args: inspect.BoundArguments
141
- used_names: dict[str, int]
142
- return_ctxs: list[Context] # Contexts at return statements, which will branch to the exit
143
- loop_head_ctxs: list[Context] # Contexts at loop heads, from outer to inner
144
- break_ctxs: list[list[Context]] # Contexts at break statements, from outer to inner
145
-
146
- def __init__(self, source_file: str, bound_args: inspect.BoundArguments, global_vars: dict[str, Any]):
147
- self.source_file = source_file
148
- self.globals = {}
149
- for k, v in global_vars.items():
150
- # Unfortunately, inspect.closurevars also includes attributes
151
- if v is ctx:
152
- raise ValueError("Unexpected use of ctx in non-meta function")
153
- value = try_validate_value(BUILTIN_IMPLS.get(id(v), v))
154
- if value is not None:
155
- self.globals[k] = value
156
- self.bound_args = bound_args
157
- self.used_names = {}
158
- self.return_ctxs = []
159
- self.loop_head_ctxs = []
160
- self.break_ctxs = []
161
-
162
- def run(self, node):
163
- before_ctx = ctx()
164
- set_ctx(before_ctx.branch_with_scope(None, Scope()))
165
- for name, value in self.bound_args.arguments.items():
166
- ctx().scope.set_value(name, validate_value(value))
167
- match node:
168
- case ast.FunctionDef(body=body):
169
- ctx().scope.set_value("$return", validate_value(None))
170
- for stmt in body:
171
- self.visit(stmt)
172
- case _:
173
- raise NotImplementedError("Unsupported syntax")
174
- after_ctx = Context.meet([*self.return_ctxs, ctx()])
175
- result_binding = after_ctx.scope.get_binding("$return")
176
- if not isinstance(result_binding, ValueBinding):
177
- raise ValueError("Function has conflicting return values")
178
- set_ctx(after_ctx.branch_with_scope(None, before_ctx.scope.copy()))
179
- return result_binding.value
180
-
181
- def visit_FunctionDef(self, node):
182
- raise NotImplementedError("Nested functions are not supported")
183
-
184
- def visit_AsyncFunctionDef(self, node):
185
- raise NotImplementedError("Async functions are not supported")
186
-
187
- def visit_ClassDef(self, node):
188
- raise NotImplementedError("Classes within functions are not supported")
189
-
190
- def visit_Return(self, node):
191
- value = self.visit(node.value) if node.value else validate_value(None)
192
- ctx().scope.set_value("$return", value)
193
- self.return_ctxs.append(ctx())
194
- set_ctx(ctx().into_dead())
195
-
196
- def visit_Delete(self, node):
197
- raise NotImplementedError("Delete statements are not supported")
198
-
199
- def visit_Assign(self, node):
200
- value = self.visit(node.value)
201
- for target in node.targets:
202
- self.handle_assign(target, value)
203
-
204
- def visit_TypeAlias(self, node):
205
- raise NotImplementedError("Type aliases are not supported")
206
-
207
- def visit_AugAssign(self, node):
208
- lhs_value = self.visit(node.target)
209
- rhs_value = self.visit(node.value)
210
- inplace_fn_name = inplace_ops[type(node.op)]
211
- regular_fn_name = bin_ops[type(node.op)]
212
- right_fn_name = rbin_ops[type(node.op)]
213
- if hasattr(lhs_value, inplace_fn_name):
214
- result = self.handle_call(node, getattr(lhs_value, inplace_fn_name), rhs_value)
215
- if not self.is_not_implemented(result):
216
- if result is not lhs_value:
217
- raise ValueError("Inplace operation must return the same object")
218
- # Skip the actual assignment because the inplace operation already did the job, as an optimization
219
- # There could be side effects of assignment, but that's atypical
220
- return
221
- if hasattr(lhs_value, regular_fn_name):
222
- result = self.handle_call(node, getattr(lhs_value, regular_fn_name), rhs_value)
223
- if not self.is_not_implemented(result):
224
- self.handle_assign(node.target, result)
225
- return
226
- if hasattr(rhs_value, right_fn_name):
227
- result = self.handle_call(node, getattr(rhs_value, right_fn_name), lhs_value)
228
- if not self.is_not_implemented(result):
229
- self.handle_assign(node.target, result)
230
- return
231
- raise NotImplementedError("Unsupported augmented assignment")
232
-
233
- def visit_AnnAssign(self, node):
234
- value = self.visit(node.value)
235
- self.handle_assign(node.target, value)
236
-
237
- def visit_For(self, node):
238
- iterator = iter(self.visit(node.iter))
239
- if not isinstance(iterator, SonolusIterator):
240
- raise ValueError("Unsupported iterator")
241
- writes = scan_writes(node)
242
- header_ctx = ctx().prepare_loop_header(writes)
243
- self.loop_head_ctxs.append(header_ctx)
244
- self.break_ctxs.append([])
245
- set_ctx(header_ctx)
246
- has_next = self.ensure_boolean_num(self.handle_call(node, iterator.has_next))
247
- ctx().test = has_next.ir()
248
- body_ctx = ctx().branch(None)
249
- else_ctx = ctx().branch(0)
250
-
251
- set_ctx(body_ctx)
252
- self.handle_assign(node.target, self.handle_call(node, iterator.next))
253
- for stmt in node.body:
254
- self.visit(stmt)
255
- ctx().branch_to_loop_header(header_ctx)
256
-
257
- set_ctx(else_ctx)
258
- for stmt in node.orelse:
259
- self.visit(stmt)
260
- else_end_ctx = ctx()
261
-
262
- break_ctxs = self.break_ctxs.pop()
263
- after_ctx = Context.meet([else_end_ctx, *break_ctxs])
264
- set_ctx(after_ctx)
265
-
266
- def visit_While(self, node):
267
- writes = scan_writes(node)
268
- header_ctx = ctx().prepare_loop_header(writes)
269
- self.loop_head_ctxs.append(header_ctx)
270
- self.break_ctxs.append([])
271
- set_ctx(header_ctx)
272
- test = self.ensure_boolean_num(self.visit(node.test))
273
- ctx().test = test.ir()
274
- body_ctx = ctx().branch(None)
275
- else_ctx = ctx().branch(0)
276
-
277
- set_ctx(body_ctx)
278
- for stmt in node.body:
279
- self.visit(stmt)
280
- ctx().branch_to_loop_header(header_ctx)
281
-
282
- set_ctx(else_ctx)
283
- for stmt in node.orelse:
284
- self.visit(stmt)
285
- else_end_ctx = ctx()
286
-
287
- break_ctxs = self.break_ctxs.pop()
288
- after_ctx = Context.meet([else_end_ctx, *break_ctxs])
289
- set_ctx(after_ctx)
290
-
291
- def visit_If(self, node):
292
- test = self.ensure_boolean_num(self.visit(node.test))
293
-
294
- if test._is_py_():
295
- if test._as_py_():
296
- for stmt in node.body:
297
- self.visit(stmt)
298
- else:
299
- for stmt in node.orelse:
300
- self.visit(stmt)
301
- return
302
-
303
- ctx_init = ctx()
304
- ctx_init.test = test.ir()
305
- true_ctx = ctx_init.branch(None)
306
- false_ctx = ctx_init.branch(0)
307
-
308
- set_ctx(true_ctx)
309
- for stmt in node.body:
310
- self.visit(stmt)
311
- true_end_ctx = ctx()
312
-
313
- set_ctx(false_ctx)
314
- for stmt in node.orelse:
315
- self.visit(stmt)
316
- false_end_ctx = ctx()
317
-
318
- set_ctx(Context.meet([true_end_ctx, false_end_ctx]))
319
-
320
- def visit_With(self, node):
321
- raise NotImplementedError("With statements are not supported")
322
-
323
- def visit_AsyncWith(self, node):
324
- raise NotImplementedError("Async with statements are not supported")
325
-
326
- def visit_Match(self, node):
327
- subject = self.visit(node.subject)
328
- end_ctxs = []
329
- for case in node.cases:
330
- if not ctx().live:
331
- break
332
- true_ctx, false_ctx = self.handle_match_pattern(subject, case.pattern)
333
- if not true_ctx.live:
334
- set_ctx(false_ctx)
335
- continue
336
- set_ctx(true_ctx)
337
- guard = self.ensure_boolean_num(self.visit(case.guard)) if case.guard else validate_value(True)
338
- if guard._is_py_():
339
- if guard._as_py_():
340
- for stmt in case.body:
341
- self.visit(stmt)
342
- end_ctxs.append(ctx())
343
- else:
344
- end_ctxs.append(ctx())
345
- else:
346
- ctx().test = guard.ir()
347
- true_ctx = ctx().branch(None)
348
- false_ctx = ctx().branch(0)
349
- set_ctx(true_ctx)
350
- for stmt in case.body:
351
- self.visit(stmt)
352
- end_ctxs.append(ctx())
353
- set_ctx(false_ctx)
354
- if end_ctxs:
355
- set_ctx(Context.meet(end_ctxs))
356
-
357
- def handle_match_pattern(self, subject: Value, pattern: ast.pattern) -> tuple[Context, Context]:
358
- match pattern:
359
- case ast.MatchValue(value=value):
360
- value = self.visit(value)
361
- test = self.ensure_boolean_num(subject == value)
362
- ctx_init = ctx()
363
- ctx_init.test = test.ir()
364
- true_ctx = ctx_init.branch(None)
365
- false_ctx = ctx_init.branch(0)
366
- return true_ctx, false_ctx
367
- case ast.MatchSingleton(value=value):
368
- match value:
369
- case True:
370
- test = self.ensure_boolean_num(subject)
371
- case False:
372
- test = self.ensure_boolean_num(subject).not_()
373
- case None:
374
- test = Num._accept_(subject._is_py_() and subject._as_py_() is None)
375
- case _:
376
- raise NotImplementedError("Unsupported match singleton")
377
- ctx_init = ctx()
378
- ctx_init.test = test.ir()
379
- true_ctx = ctx_init.branch(None)
380
- false_ctx = ctx_init.branch(0)
381
- return true_ctx, false_ctx
382
- case ast.MatchSequence():
383
- raise NotImplementedError("Match sequences are not supported")
384
- case ast.MatchMapping():
385
- raise NotImplementedError("Match mappings are not supported")
386
- case ast.MatchClass(cls=cls, patterns=patterns, kwd_attrs=kwd_attrs, kwd_patterns=kwd_patterns):
387
- from sonolus.script.comptime import Comptime
388
- from sonolus.script.internal.generic import validate_type_spec
389
-
390
- cls = validate_type_spec(self.visit(cls))
391
- if not isinstance(cls, type):
392
- raise TypeError("Class is not a type")
393
- if issubclass(cls, Comptime):
394
- raise TypeError("Comptime is not supported in match patterns")
395
- if not isinstance(subject, cls):
396
- return ctx().into_dead(), ctx()
397
- if patterns:
398
- if not hasattr(cls, "__match_args__"):
399
- raise TypeError("Class does not support match patterns")
400
- if len(cls.__match_args__) < len(patterns):
401
- raise ValueError("Too many match patterns")
402
- # kwd_attrs can't be mixed with patterns on the syntax level,
403
- # so we can just set it like this since it's empty
404
- kwd_attrs = cls.__match_args__[: len(patterns)]
405
- kwd_patterns = patterns
406
- if kwd_attrs:
407
- true_ctx = ctx()
408
- false_ctxs = []
409
- for attr, subpattern in zip(kwd_attrs, kwd_patterns, strict=False):
410
- if not hasattr(subject, attr):
411
- raise AttributeError(f"Object has no attribute {attr}")
412
- value = self.handle_getattr(subpattern, subject, attr)
413
- true_ctx, false_ctx = self.handle_match_pattern(value, subpattern)
414
- false_ctxs.append(false_ctx)
415
- set_ctx(true_ctx)
416
- return true_ctx, Context.meet(false_ctxs)
417
- return ctx(), ctx().into_dead()
418
- case ast.MatchStar():
419
- raise NotImplementedError("Match stars are not supported")
420
- case ast.MatchAs(pattern=pattern, name=name):
421
- if pattern:
422
- true_ctx, false_ctx = self.handle_match_pattern(subject, pattern)
423
- if name:
424
- true_ctx.scope.set_value(name, subject)
425
- return true_ctx, false_ctx
426
- else:
427
- if name:
428
- ctx().scope.set_value(name, subject)
429
- return ctx(), ctx().into_dead()
430
- case ast.MatchOr():
431
- true_ctxs = []
432
- false_ctx = ctx()
433
- assert pattern.patterns
434
- for subpattern in pattern.patterns:
435
- true_ctx, false_ctx = self.handle_match_pattern(subject, subpattern)
436
- true_ctxs.append(true_ctx)
437
- set_ctx(false_ctx)
438
- return Context.meet(true_ctxs), false_ctx
439
-
440
- def visit_Raise(self, node):
441
- raise NotImplementedError("Raise statements are not supported")
442
-
443
- def visit_Try(self, node):
444
- raise NotImplementedError("Try statements are not supported")
445
-
446
- def visit_TryStar(self, node):
447
- raise NotImplementedError("Try* statements are not supported")
448
-
449
- def visit_Assert(self, node):
450
- self.handle_call(
451
- node, assert_true, self.visit(node.test), self.visit(node.msg) if node.msg else validate_value(None)
452
- )
453
-
454
- def visit_Import(self, node):
455
- raise NotImplementedError("Import statements are not supported")
456
-
457
- def visit_ImportFrom(self, node):
458
- raise NotImplementedError("Import statements are not supported")
459
-
460
- def visit_Global(self, node):
461
- raise NotImplementedError("Global statements are not supported")
462
-
463
- def visit_Nonlocal(self, node):
464
- raise NotImplementedError("Nonlocal statements are not supported")
465
-
466
- def visit_Expr(self, node):
467
- return self.visit(node.value)
468
-
469
- def visit_Pass(self, node):
470
- pass
471
-
472
- def visit_Break(self, node):
473
- self.break_ctxs[-1].append(ctx())
474
- set_ctx(ctx().into_dead())
475
-
476
- def visit_Continue(self, node):
477
- ctx().branch_to_loop_header(self.loop_head_ctxs[-1])
478
- set_ctx(ctx().into_dead())
479
-
480
- def visit_BoolOp(self, node) -> Value:
481
- match node.op:
482
- case ast.And():
483
- handler = self.handle_and
484
- case ast.Or():
485
- handler = self.handle_or
486
- case _:
487
- raise NotImplementedError(f"Unsupported bool operator {node.op}")
488
-
489
- if not node.values:
490
- raise ValueError("Bool operator requires at least one operand")
491
- if len(node.values) == 1:
492
- return self.visit(node.values[0])
493
- initial, *rest = node.values
494
- return handler(self.visit(initial), ast.copy_location(ast.BoolOp(op=node.op, values=rest), node))
495
-
496
- def visit_NamedExpr(self, node):
497
- value = self.visit(node.value)
498
- self.handle_assign(node.target, value)
499
- return value
500
-
501
- def visit_BinOp(self, node):
502
- lhs = self.visit(node.left)
503
- rhs = self.visit(node.right)
504
- op = bin_ops[type(node.op)]
505
- if hasattr(lhs, op):
506
- result = self.handle_call(node, getattr(lhs, op), rhs)
507
- if not self.is_not_implemented(result):
508
- return result
509
- if hasattr(rhs, rbin_ops[type(node.op)]):
510
- result = self.handle_call(node, getattr(rhs, rbin_ops[type(node.op)]), lhs)
511
- if not self.is_not_implemented(result):
512
- return result
513
- raise NotImplementedError(f"Unsupported operand types for binary operator {node.op}")
514
-
515
- def visit_UnaryOp(self, node):
516
- operand = self.visit(node.operand)
517
- if isinstance(node.op, ast.Not):
518
- return self.ensure_boolean_num(operand).not_()
519
- op = unary_ops[type(node.op)]
520
- if hasattr(operand, op):
521
- return self.handle_call(node, getattr(operand, op))
522
- raise NotImplementedError(f"Unsupported operand type for unary operator {node.op}")
523
-
524
- def visit_Lambda(self, node):
525
- raise NotImplementedError("Lambda functions are not supported")
526
-
527
- def visit_IfExp(self, node):
528
- test = self.ensure_boolean_num(self.visit(node.test))
529
-
530
- if test._is_py_():
531
- if test._as_py_():
532
- return self.visit(node.body)
533
- else:
534
- return self.visit(node.orelse)
535
-
536
- res_name = self.new_name("ifexp")
537
- ctx_init = ctx()
538
- ctx_init.test = test.ir()
539
-
540
- set_ctx(ctx_init.branch(None))
541
- true_value = self.visit(node.body)
542
- ctx().scope.set_value(res_name, true_value)
543
- ctx_true = ctx()
544
-
545
- set_ctx(ctx_init.branch(0))
546
- false_value = self.visit(node.orelse)
547
- ctx().scope.set_value(res_name, false_value)
548
- ctx_false = ctx()
549
-
550
- set_ctx(Context.meet([ctx_true, ctx_false]))
551
- return ctx().scope.get_value(res_name)
552
-
553
- def visit_Dict(self, node):
554
- return validate_value({self.visit(k): self.visit(v) for k, v in zip(node.keys, node.values, strict=True)})
555
-
556
- def visit_Set(self, node):
557
- raise NotImplementedError("Set literals are not supported")
558
-
559
- def visit_ListComp(self, node):
560
- raise NotImplementedError("List comprehensions are not supported")
561
-
562
- def visit_SetComp(self, node):
563
- raise NotImplementedError("Set comprehensions are not supported")
564
-
565
- def visit_DictComp(self, node):
566
- raise NotImplementedError("Dict comprehensions are not supported")
567
-
568
- def visit_GeneratorExp(self, node):
569
- raise NotImplementedError("Generator expressions are not supported")
570
-
571
- def visit_Await(self, node):
572
- raise NotImplementedError("Await expressions are not supported")
573
-
574
- def visit_Yield(self, node):
575
- raise NotImplementedError("Yield expressions are not supported")
576
-
577
- def visit_YieldFrom(self, node):
578
- raise NotImplementedError("Yield from expressions are not supported")
579
-
580
- def visit_Compare(self, node):
581
- result_name = self.new_name("compare")
582
- ctx().scope.set_value(result_name, Num._accept_(0))
583
- l_val = self.visit(node.left)
584
- false_ctxs = []
585
- for i, (op, rhs) in enumerate(zip(node.ops, node.comparators, strict=True)):
586
- r_val = self.visit(rhs)
587
- inverted = isinstance(op, ast.NotIn)
588
- result = None
589
- if isinstance(op, ast.Is | ast.IsNot):
590
- if not (r_val._is_py_() and r_val._as_py_() is None):
591
- raise TypeError("The right operand of 'is' must be None")
592
- if isinstance(op, ast.Is):
593
- result = Num._accept_(l_val._is_py_() and l_val._as_py_() is None)
594
- else:
595
- result = Num._accept_(not (l_val._is_py_() and l_val._as_py_() is None))
596
- elif type(op) in comp_ops and hasattr(l_val, comp_ops[type(op)]):
597
- result = self.handle_call(node, getattr(l_val, comp_ops[type(op)]), r_val)
598
- if (
599
- (result is None or self.is_not_implemented(result))
600
- and type(op) in rcomp_ops
601
- and hasattr(r_val, rcomp_ops[type(op)])
602
- ):
603
- result = self.handle_call(node, getattr(r_val, rcomp_ops[type(op)]), l_val)
604
- if result is None or self.is_not_implemented(result):
605
- raise NotImplementedError(f"Unsupported comparison operator {op}")
606
- result = self.ensure_boolean_num(result)
607
- if inverted:
608
- result = result.not_()
609
- curr_ctx = ctx()
610
- if i == len(node.ops) - 1:
611
- curr_ctx.scope.set_value(result_name, result)
612
- else:
613
- curr_ctx.test = result.ir()
614
- true_ctx = curr_ctx.branch(None)
615
- false_ctx = curr_ctx.branch(0)
616
- false_ctxs.append(false_ctx)
617
- set_ctx(true_ctx)
618
- l_val = r_val
619
- last_ctx = ctx() # This is the result of the last comparison returning true
620
- set_ctx(Context.meet([last_ctx, *false_ctxs]))
621
- return ctx().scope.get_value(result_name)
622
-
623
- def visit_Call(self, node):
624
- fn = self.visit(node.func)
625
- if fn is Num:
626
- raise ValueError("Calling int/bool/float is not supported")
627
- args = []
628
- kwargs = {}
629
- for arg in node.args:
630
- if isinstance(arg, ast.Starred):
631
- args.extend(self.handle_starred(self.visit(arg.value)))
632
- else:
633
- args.append(self.visit(arg))
634
- for keyword in node.keywords:
635
- if keyword.arg:
636
- kwargs[keyword.arg] = self.visit(keyword.value)
637
- else:
638
- value = self.visit(keyword.value)
639
- if value._is_py_() and isinstance(value._as_py_(), Mapping):
640
- kwargs.update(value._as_py_())
641
- else:
642
- raise ValueError("Starred keyword arguments (**kwargs) must be dictionaries")
643
- return self.handle_call(node, fn, *args, **kwargs)
644
-
645
- def visit_FormattedValue(self, node):
646
- raise NotImplementedError("F-strings are not supported")
647
-
648
- def visit_JoinedStr(self, node):
649
- raise NotImplementedError("F-strings are not supported")
650
-
651
- def visit_Constant(self, node):
652
- return validate_value(node.value)
653
-
654
- def visit_Attribute(self, node):
655
- return self.handle_getattr(node, self.visit(node.value), node.attr)
656
-
657
- def visit_Subscript(self, node):
658
- value = self.visit(node.value)
659
- slice_value = self.visit(node.slice)
660
- return self.handle_getitem(node, value, slice_value)
661
-
662
- def visit_Starred(self, node):
663
- raise NotImplementedError("Starred expressions are not supported")
664
-
665
- def visit_Name(self, node):
666
- if isinstance(ctx().scope.get_binding(node.id), EmptyBinding) and node.id in self.globals:
667
- # globals can have false positives due to limitations of inspect.closurevars
668
- # so we need to check that it's not defined as a local variable
669
- return self.globals[node.id]
670
- return ctx().scope.get_value(node.id)
671
-
672
- def visit_List(self, node):
673
- raise NotImplementedError("List literals are not supported")
674
-
675
- def visit_Tuple(self, node):
676
- values = []
677
- for elt in node.elts:
678
- if isinstance(elt, ast.Starred):
679
- values.extend(self.handle_starred(self.visit(elt.value)))
680
- else:
681
- values.append(self.visit(elt))
682
- return validate_value(tuple(values))
683
-
684
- def visit_Slice(self, node):
685
- raise NotImplementedError("Slices are not supported")
686
-
687
- def handle_assign(self, target: ast.stmt | ast.expr, value: Value):
688
- match target:
689
- case ast.Name(id=name):
690
- ctx().scope.set_value(name, value)
691
- case ast.Attribute(value=attr_value, attr=attr):
692
- attr_value = self.visit(attr_value)
693
- self.handle_setattr(target, attr_value, attr, value)
694
- case ast.Subscript(value=sub_value, slice=slice_expr):
695
- sub_value = self.visit(sub_value)
696
- slice_value = self.visit(slice_expr)
697
- self.handle_setitem(target, sub_value, slice_value, value)
698
- case ast.Tuple(elts=elts) | ast.List(elts=elts):
699
- values = self.handle_starred(value)
700
- if len(elts) != len(values):
701
- raise ValueError("Unpacking assignment requires the same number of elements")
702
- for elt, v in zip(elts, values, strict=False):
703
- self.handle_assign(elt, validate_value(v))
704
- case ast.Starred():
705
- raise NotImplementedError("Starred assignment is not supported")
706
- case _:
707
- raise NotImplementedError("Unsupported assignment target")
708
-
709
- def handle_and(self, l_val: Value, r_expr: ast.expr) -> Value:
710
- ctx_init = ctx()
711
- l_val = self.ensure_boolean_num(l_val)
712
- ctx_init.test = l_val.ir()
713
- res_name = self.new_name("and")
714
-
715
- set_ctx(ctx_init.branch(None))
716
- r_val = self.ensure_boolean_num(self.visit(r_expr))
717
- ctx().scope.set_value(res_name, r_val)
718
- ctx_true = ctx()
719
-
720
- set_ctx(ctx_init.branch(0))
721
- ctx().scope.set_value(res_name, Num._accept_(0))
722
- ctx_false = ctx()
723
-
724
- set_ctx(Context.meet([ctx_true, ctx_false]))
725
- if l_val._is_py_() and r_val._is_py_():
726
- return Num._accept_(l_val._as_py_() and r_val._as_py_())
727
- return ctx().scope.get_value(res_name)
728
-
729
- def handle_or(self, l_val: Value, r_expr: ast.expr) -> Value:
730
- ctx_init = ctx()
731
- l_val = self.ensure_boolean_num(l_val)
732
- ctx_init.test = l_val.ir()
733
- res_name = self.new_name("or")
734
-
735
- set_ctx(ctx_init.branch(None))
736
- ctx().scope.set_value(res_name, l_val)
737
- ctx_true = ctx()
738
-
739
- set_ctx(ctx_init.branch(0))
740
- r_val = self.ensure_boolean_num(self.visit(r_expr))
741
- ctx().scope.set_value(res_name, r_val)
742
- ctx_false = ctx()
743
-
744
- set_ctx(Context.meet([ctx_true, ctx_false]))
745
- if l_val._is_py_() and r_val._is_py_():
746
- return Num._accept_(l_val._as_py_() or r_val._as_py_())
747
- return ctx().scope.get_value(res_name)
748
-
749
- def generic_visit(self, node):
750
- if isinstance(node, ast.stmt | ast.expr):
751
- with self.reporting_errors_at_node(node):
752
- raise NotImplementedError(f"Unsupported syntax: {type(node).__name__}")
753
- raise NotImplementedError(f"Unsupported syntax: {type(node).__name__}")
754
-
755
def handle_getattr(self, node: ast.stmt | ast.expr, target: Value, key: str) -> Value:
    """Read attribute ``key`` from ``target``, reporting errors at ``node``.

    Properties are invoked through ``handle_call``; Sonolus descriptors,
    plain functions, class/static methods, missing class entries and
    non-descriptor class attributes are read with ``getattr``.  Any other
    descriptor raises ``TypeError``.
    """
    with self.reporting_errors_at_node(node):
        if target._is_py_():
            target = target._as_py_()
        # Look at the raw class entry so the descriptor itself is inspected.
        descriptor = type(target).__dict__.get(key)
        if isinstance(descriptor, property):
            return self.handle_call(node, descriptor.fget, target)
        plain_kinds = (SonolusDescriptor, FunctionType, classmethod, staticmethod)
        if descriptor is None or isinstance(descriptor, plain_kinds):
            return validate_value(getattr(target, key))
        if not hasattr(descriptor, "__get__"):
            return validate_value(getattr(target, key))
        raise TypeError(f"Unsupported field or descriptor {key}")
769
-
770
def handle_setattr(self, node: ast.stmt | ast.expr, target: Value, key: str, value: Value):
    """Assign ``value`` to attribute ``key`` of ``target``, reporting errors at ``node``.

    Only properties with a setter and Sonolus descriptors are writable;
    anything else raises ``TypeError``.
    """
    with self.reporting_errors_at_node(node):
        if target._is_py_():
            target = target._as_py_()
        descriptor = getattr(type(target), key, None)
        if isinstance(descriptor, property):
            setter = descriptor.fset
            if setter is None:
                raise AttributeError(f"Cannot set attribute {key} because property has no setter")
            self.handle_call(node, setter, target, value)
        elif isinstance(descriptor, SonolusDescriptor):
            setattr(target, key, value)
        else:
            raise TypeError(f"Unsupported field or descriptor {key}")
784
-
785
- def handle_call[**P, R](
786
- self, node: ast.stmt | ast.expr, fn: Callable[P, R], /, *args: P.args, **kwargs: P.kwargs
787
- ) -> R:
788
- """Handles a call to the given callable."""
789
- if (
790
- isinstance(fn, Value)
791
- and fn._is_py_()
792
- and isinstance(fn._as_py_(), type)
793
- and issubclass(fn._as_py_(), Value)
794
- ):
795
- return validate_value(self.execute_at_node(node, fn._as_py_(), *args, **kwargs))
796
- else:
797
- return self.execute_at_node(node, lambda: validate_value(compile_and_call(fn, *args, **kwargs)))
798
-
799
def handle_getitem(self, node: ast.stmt | ast.expr, target: Value, key: Value) -> Value:
    """Evaluate ``target[key]``, reporting errors at ``node``.

    A constant target with a constant key is indexed directly in Python.
    Otherwise the target must be a ``Value`` exposing ``__getitem__``,
    which is invoked through ``handle_call``.
    """
    with self.reporting_errors_at_node(node):
        if target._is_py_():
            target = target._as_py_()
            if key._is_py_():
                return validate_value(target[key._as_py_()])
        # Shared fallback for both the unwrapped-constant and the non-constant target.
        if isinstance(target, Value) and hasattr(target, "__getitem__"):
            return self.handle_call(node, target.__getitem__, key)
        raise TypeError(f"Cannot get items on {type(target).__name__}")
812
-
813
def handle_setitem(self, node: ast.stmt | ast.expr, target: Value, key: Value, value: Value):
    """Evaluate ``target[key] = value``, reporting errors at ``node``.

    A constant target with a constant key is assigned directly in Python.
    Otherwise the target must be a ``Value`` exposing ``__setitem__``,
    which is invoked through ``handle_call``.

    Raises:
        TypeError: if the target does not support item assignment.
    """
    with self.reporting_errors_at_node(node):
        if target._is_py_():
            target = target._as_py_()
            if key._is_py_():
                target[key._as_py_()] = value._as_py_()
                # The store is complete; without this return the code fell through
                # to the checks below and raised TypeError after a successful store.
                return None
        if isinstance(target, Value) and hasattr(target, "__setitem__"):
            return self.handle_call(node, target.__setitem__, key, value)
        raise TypeError(f"Cannot set items on {type(target).__name__}")
826
-
827
def handle_starred(self, value: Value) -> tuple[Value, ...]:
    """Expand a starred expression; only constant tuples are supported."""
    if not value._is_py_() or not isinstance(value._as_py_(), tuple):
        raise ValueError("Unsupported starred expression")
    return value._as_py_()
831
-
832
def is_not_implemented(self, value):
    """Tell whether ``value`` is the constant ``NotImplemented`` once validated."""
    checked = validate_value(value)
    return checked._is_py_() and checked._as_py_() is NotImplemented
835
-
836
def ensure_boolean_num(self, value) -> Num:
    """Return ``value`` unchanged if it is a Num, otherwise raise ``TypeError``."""
    # Only the type is checked for now; custom __bool__ implementations could be supported later.
    if is_num(value):
        return value
    raise TypeError(f"Invalid type where a bool (Num) was expected: {type(value).__name__}")
841
-
842
def raise_exception_at_node(self, node: ast.stmt | ast.expr, cause: Exception) -> Never:
    """Throws a compilation error at the given node.

    The error carries ``str(cause)`` as its message and is chained to ``cause``.
    """

    def thrower() -> Never:
        message = str(cause)
        raise CompilationError(message) from cause

    # Run the raiser through execute_at_node so the traceback points at `node`.
    self.execute_at_node(node, thrower)
849
-
850
- def execute_at_node[**P, R](
851
- self, node: ast.stmt | ast.expr, fn: Callable[P, R], /, *args: P.args, **kwargs: P.kwargs
852
- ) -> R:
853
- """Executes the given function at the given node for a better traceback."""
854
- expr = ast.Expression(
855
- body=ast.Call(
856
- func=ast.Name(id="fn", ctx=ast.Load()),
857
- args=[ast.Starred(value=ast.Name(id="args", ctx=ast.Load()), ctx=ast.Load())],
858
- keywords=[ast.keyword(value=ast.Name(id="kwargs", ctx=ast.Load()), arg=None)],
859
- lineno=node.lineno,
860
- col_offset=node.col_offset,
861
- end_lineno=node.end_lineno,
862
- end_col_offset=node.end_col_offset,
863
- ),
864
- )
865
- expr = ast.fix_missing_locations(expr)
866
- return eval(
867
- compile(expr, filename=self.source_file, mode="eval"),
868
- {"fn": fn, "args": args, "kwargs": kwargs, "_filter_traceback_": True},
869
- )
870
-
871
- @contextmanager
872
- def reporting_errors_at_node(self, node: ast.stmt | ast.expr):
873
- try:
874
- yield
875
- except CompilationError as e:
876
- raise e from None
877
- except Exception as e:
878
- self.raise_exception_at_node(node, e)
879
-
880
- def new_name(self, name: str):
881
- self.used_names[name] = self.used_names.get(name, 0) + 1
882
- return f"${name}_{self.used_names[name]}"
1
+ # ruff: noqa: N802
2
+ import ast
3
+ import builtins
4
+ import functools
5
+ import inspect
6
+ from collections.abc import Callable, Sequence
7
+ from types import FunctionType, MethodType, MethodWrapperType
8
+ from typing import Any, Never, Self
9
+
10
+ from sonolus.backend.excepthook import install_excepthook
11
+ from sonolus.backend.utils import get_function, scan_writes
12
+ from sonolus.script.debug import assert_true
13
+ from sonolus.script.internal.builtin_impls import BUILTIN_IMPLS, _bool, _float, _int, _len
14
+ from sonolus.script.internal.constant import ConstantValue
15
+ from sonolus.script.internal.context import Context, EmptyBinding, Scope, ValueBinding, ctx, set_ctx
16
+ from sonolus.script.internal.descriptor import SonolusDescriptor
17
+ from sonolus.script.internal.error import CompilationError
18
+ from sonolus.script.internal.impl import validate_value
19
+ from sonolus.script.internal.tuple_impl import TupleImpl
20
+ from sonolus.script.internal.value import Value
21
+ from sonolus.script.iterator import SonolusIterator
22
+ from sonolus.script.num import Num, _is_num
23
+
24
+ _compiler_internal_ = True
25
+
26
+
27
+ def compile_and_call[**P, R](fn: Callable[P, R], /, *args: P.args, **kwargs: P.kwargs) -> R:
28
+ if not ctx():
29
+ return fn(*args, **kwargs)
30
+ return validate_value(generate_fn_impl(fn)(*args, **kwargs))
31
+
32
+
33
def generate_fn_impl(fn: Callable):
    """Return a callable that evaluates ``fn`` under the compiler.

    Constant wrappers are unwrapped, bound methods are rebound onto the
    implementation of their function, functions flagged ``_meta_fn_`` are
    returned unchanged, other functions are routed through ``eval_fn``, and
    callable ``Value`` instances are dispatched via ``__call__``.

    Raises:
        TypeError: for unsupported callables and non-callable objects.
    """
    install_excepthook()
    if isinstance(fn, ConstantValue) and fn._is_py_():
        return generate_fn_impl(fn._as_py_())
    if isinstance(fn, MethodType):
        return functools.partial(generate_fn_impl(fn.__func__), fn.__self__)
    if isinstance(fn, FunctionType):
        if getattr(fn, "_meta_fn_", False):
            return fn
        return functools.partial(eval_fn, fn)
    if not callable(fn):
        raise TypeError(f"'{type(fn).__name__}' object is not callable")
    if isinstance(fn, Value):
        return generate_fn_impl(fn.__call__)
    raise TypeError(f"Unsupported callable {fn!r}")
51
+
52
+
53
def eval_fn(fn: Callable, /, *args, **kwargs):
    """Evaluate ``fn`` by running a ``Visitor`` over its AST.

    Arguments are bound against the signature of ``fn`` with defaults
    applied.  Name lookup uses builtins, overlaid by the function's globals,
    overlaid by its closure variables.
    """
    source_file, node = get_function(fn)
    bound_args = inspect.signature(fn).bind(*args, **kwargs)
    bound_args.apply_defaults()
    # Later updates win, so closure variables shadow globals, which shadow builtins.
    namespace = dict(builtins.__dict__)
    namespace.update(fn.__globals__)
    namespace.update(inspect.getclosurevars(fn).nonlocals)
    visitor = Visitor(source_file, bound_args, namespace)
    return visitor.run(node)
63
+
64
+
65
# Unary operator node type -> dunder method looked up on the operand.
unary_ops = {
    ast.Invert: "__invert__",
    ast.UAdd: "__pos__",
    ast.USub: "__neg__",
}

# Binary operator node type -> dunder method looked up on the left operand.
bin_ops = {
    ast.Add: "__add__",
    ast.Sub: "__sub__",
    ast.Mult: "__mul__",
    ast.Div: "__truediv__",
    ast.FloorDiv: "__floordiv__",
    ast.Mod: "__mod__",
    ast.Pow: "__pow__",
    ast.LShift: "__lshift__",
    ast.RShift: "__rshift__",
    ast.BitOr: "__or__",
    ast.BitAnd: "__and__",
    ast.BitXor: "__xor__",
    ast.MatMult: "__matmul__",
}

# Binary operator node type -> reflected dunder method looked up on the right operand.
rbin_ops = {
    ast.Add: "__radd__",
    ast.Sub: "__rsub__",
    ast.Mult: "__rmul__",
    ast.Div: "__rtruediv__",
    ast.FloorDiv: "__rfloordiv__",
    ast.Mod: "__rmod__",
    ast.Pow: "__rpow__",
    ast.LShift: "__rlshift__",
    ast.RShift: "__rrshift__",
    ast.BitOr: "__ror__",
    ast.BitAnd: "__rand__",
    ast.BitXor: "__rxor__",
    ast.MatMult: "__rmatmul__",
}

# Binary operator node type -> in-place dunder method used for augmented assignment.
inplace_ops = {
    ast.Add: "__iadd__",
    ast.Sub: "__isub__",
    ast.Mult: "__imul__",
    ast.Div: "__itruediv__",
    ast.FloorDiv: "__ifloordiv__",
    ast.Mod: "__imod__",
    ast.Pow: "__ipow__",
    ast.LShift: "__ilshift__",
    ast.RShift: "__irshift__",
    ast.BitOr: "__ior__",
    ast.BitXor: "__ixor__",
    ast.BitAnd: "__iand__",
    ast.MatMult: "__imatmul__",
}

# Comparison node type -> dunder method looked up on the left operand.
comp_ops = {
    ast.Eq: "__eq__",
    ast.NotEq: "__ne__",
    ast.Lt: "__lt__",
    ast.LtE: "__le__",
    ast.Gt: "__gt__",
    ast.GtE: "__ge__",
}

# Comparison node type -> dunder method looked up on the right operand
# (ordering comparisons are mirrored; membership tests map to __contains__).
rcomp_ops = {
    ast.Eq: "__eq__",
    ast.NotEq: "__ne__",
    ast.Lt: "__gt__",
    ast.LtE: "__ge__",
    ast.Gt: "__lt__",
    ast.GtE: "__le__",
    ast.In: "__contains__",  # Only supported on the right side
    ast.NotIn: "__contains__",
}
138
+
139
# Operator node type -> source-level spelling, used when formatting error messages.
# Every operator that appears in the dunder tables needs an entry here; a missing
# one (previously ``ast.MatMult``) turns the intended TypeError for an unsupported
# operand into a KeyError raised while building the message.
op_to_symbol = {
    ast.Add: "+",
    ast.Sub: "-",
    ast.Mult: "*",
    ast.Div: "/",
    ast.FloorDiv: "//",
    ast.Mod: "%",
    ast.Pow: "**",
    ast.MatMult: "@",
    ast.Eq: "==",
    ast.NotEq: "!=",
    ast.Lt: "<",
    ast.LtE: "<=",
    ast.Gt: ">",
    ast.GtE: ">=",
    ast.And: "and",
    ast.Or: "or",
    ast.BitAnd: "&",
    ast.BitOr: "|",
    ast.BitXor: "^",
    ast.LShift: "<<",
    ast.RShift: ">>",
    ast.USub: "-",
    ast.UAdd: "+",
    ast.Invert: "~",
    ast.Not: "not",
    ast.In: "in",
    ast.NotIn: "not in",
    ast.Is: "is",
    ast.IsNot: "is not",
}
167
+
168
+
169
+ class Visitor(ast.NodeVisitor):
170
+ source_file: str
171
+ globals: dict[str, Any]
172
+ bound_args: inspect.BoundArguments
173
+ used_names: dict[str, int]
174
+ return_ctxs: list[Context] # Contexts at return statements, which will branch to the exit
175
+ loop_head_ctxs: list[Context] # Contexts at loop heads, from outer to inner
176
+ break_ctxs: list[list[Context]] # Contexts at break statements, from outer to inner
177
+ active_ctx: Context | None # The active context for use in nested functions=
178
+ parent: Self | None # The parent visitor for use in nested functions
179
+
180
+ def __init__(
181
+ self,
182
+ source_file: str,
183
+ bound_args: inspect.BoundArguments,
184
+ global_vars: dict[str, Any],
185
+ parent: Self | None = None,
186
+ ):
187
+ self.source_file = source_file
188
+ self.globals = global_vars
189
+ self.bound_args = bound_args
190
+ self.used_names = {}
191
+ self.return_ctxs = []
192
+ self.loop_head_ctxs = []
193
+ self.break_ctxs = []
194
+ self.active_ctx = None
195
+ self.parent = parent
196
+
197
def run(self, node):
    """Evaluate a function or lambda node and return its result value.

    The body is visited in a fresh ``Scope`` seeded with the bound
    arguments.  The result is carried in the ``$return`` scope entry; once
    every return context is merged, the caller's scope is restored.

    Raises:
        NotImplementedError: if ``node`` is neither a FunctionDef nor a Lambda.
        ValueError: if the merged ``$return`` binding is not a ``ValueBinding``.
    """
    before_ctx = ctx()
    # Enter a new scope for the function body.
    set_ctx(before_ctx.branch_with_scope(None, Scope()))
    for name, value in self.bound_args.arguments.items():
        ctx().scope.set_value(name, validate_value(value))
    match node:
        case ast.FunctionDef(body=body):
            # Default result when the body falls off the end without returning.
            ctx().scope.set_value("$return", validate_value(None))
            for stmt in body:
                self.visit(stmt)
        case ast.Lambda(body=body):
            result = self.visit(body)
            ctx().scope.set_value("$return", result)
        case _:
            raise NotImplementedError("Unsupported syntax")
    # Join every explicit return with the fall-through end of the body.
    after_ctx = Context.meet([*self.return_ctxs, ctx()])
    self.active_ctx = after_ctx
    result_binding = after_ctx.scope.get_binding("$return")
    if not isinstance(result_binding, ValueBinding):
        raise ValueError("Function has conflicting return values")
    # Continue in a copy of the caller's scope.
    set_ctx(after_ctx.branch_with_scope(None, before_ctx.scope.copy()))
    return result_binding.value
219
+
220
def visit(self, node):
    """Visit a node.

    Dispatches to ``visit_<NodeClassName>`` (or ``generic_visit`` when no
    such method exists) with errors reported at ``node``.
    """
    # We want this here so this is filtered out of tracebacks
    handler_name = f"visit_{node.__class__.__name__}"
    handler = getattr(self, handler_name, self.generic_visit)
    with self.reporting_errors_at_node(node):
        return handler(node)
227
+
228
def visit_FunctionDef(self, node):
    """Define a nested function and bind it to its name in the current scope.

    The function is a ``_meta_fn_`` wrapper that runs a child ``Visitor``
    over ``node`` when called; decorators are applied innermost-first.
    """
    signature = self.arguments_to_signature(node.args)

    def fn(*args, **kwargs):
        bound = signature.bind(*args, **kwargs)
        bound.apply_defaults()
        child = Visitor(self.source_file, bound, self.globals, self)
        return child.run(node)

    fn._meta_fn_ = True
    fn.__name__ = fn.__qualname__ = node.name

    # The decorator nearest the def is applied first.
    for decorator in reversed(node.decorator_list):
        decorator_value = self.visit(decorator)
        fn = self.handle_call(decorator, decorator_value, fn)

    ctx().scope.set_value(node.name, validate_value(fn))
250
+
251
def visit_AsyncFunctionDef(self, node):
    """Reject ``async def``; it is not supported."""
    message = "Async functions are not supported"
    raise NotImplementedError(message)
253
+
254
def visit_ClassDef(self, node):
    """Reject class definitions inside functions; they are not supported."""
    message = "Classes within functions are not supported"
    raise NotImplementedError(message)
256
+
257
def visit_Return(self, node):
    """Record a ``return``: store the value in ``$return`` and end the current path."""
    if node.value:
        returned = self.visit(node.value)
    else:
        returned = validate_value(None)
    ctx().scope.set_value("$return", returned)
    # Remember this exit so it can be merged later, then mark what follows as dead.
    self.return_ctxs.append(ctx())
    set_ctx(ctx().into_dead())
262
+
263
def visit_Delete(self, node):
    """Handle ``del``; only subscript targets (``del a[i]``) are supported."""
    for target in node.targets:
        if isinstance(target, ast.Name):
            raise NotImplementedError("Deleting variables is not supported")
        if isinstance(target, ast.Subscript):
            container = self.visit(target.value)
            index = self.visit(target.slice)
            self.handle_delitem(target, container, index)
            continue
        if isinstance(target, ast.Attribute):
            raise NotImplementedError("Deleting attributes is not supported")
        raise NotImplementedError("Unsupported delete target")
274
+
275
def visit_Assign(self, node):
    """Evaluate the right-hand side once and assign it to every target in order."""
    assigned = self.visit(node.value)
    for target in node.targets:
        self.handle_assign(target, assigned)
279
+
280
def visit_TypeAlias(self, node):
    """Reject ``type X = ...`` statements; they are not supported."""
    message = "Type aliases are not supported"
    raise NotImplementedError(message)
282
+
283
def visit_AugAssign(self, node):
    """Handle ``target <op>= value``.

    Tries the in-place dunder on the target, then the regular binary dunder
    on the target, then the reflected dunder on the value (only when the
    operand types differ).  The first attempt that does not yield
    ``NotImplemented`` is assigned back to the target.

    Raises:
        ValueError: if the in-place method returns a different object.
        TypeError: if no attempt produces a result.
    """
    current = self.visit(node.target)
    operand = self.visit(node.value)
    op_type = type(node.op)
    inplace_name = inplace_ops[op_type]
    forward_name = bin_ops[op_type]
    reflected_name = rbin_ops[op_type]

    if hasattr(current, inplace_name):
        outcome = self.handle_call(node, getattr(current, inplace_name), operand)
        if not self.is_not_implemented(outcome):
            if outcome is not current:
                raise ValueError("Inplace operation must return the same object")
            self.handle_assign(node.target, outcome)
            return

    if hasattr(current, forward_name):
        outcome = self.handle_call(node, getattr(current, forward_name), operand)
        if not self.is_not_implemented(outcome):
            self.handle_assign(node.target, outcome)
            return

    if hasattr(operand, reflected_name) and type(current) is not type(operand):
        outcome = self.handle_call(node, getattr(operand, reflected_name), current)
        if not self.is_not_implemented(outcome):
            self.handle_assign(node.target, outcome)
            return

    raise TypeError(
        f"unsupported operand type(s) for {op_to_symbol[type(node.op)]}=: "
        f"'{type(current).__name__}' and '{type(operand).__name__}'"
    )
310
+
311
def visit_AnnAssign(self, node):
    """Handle an annotated assignment (``target: annotation = value``).

    The annotation is not evaluated.  A bare annotation without a value
    (``x: int``) assigns nothing, so it is skipped; previously ``None`` was
    passed to ``visit`` in that case and compilation crashed.
    """
    if node.value is None:
        return
    value = self.visit(node.value)
    self.handle_assign(node.target, value)
314
+
315
def visit_For(self, node):
    """Handle a ``for`` loop.

    Iterating a ``TupleImpl`` unrolls the body once per element.  Any other
    iterable must produce a ``SonolusIterator`` from ``__iter__``; the loop
    is then built from its ``has_next`` / ``next`` methods.

    Raises:
        ValueError: if ``__iter__`` does not return a ``SonolusIterator``.
    """
    from sonolus.script.internal.tuple_impl import TupleImpl

    iterable = self.visit(node.iter)
    if isinstance(iterable, TupleImpl):
        # Unroll the loop
        # NOTE(review): this path does not visit node.orelse and pushes no
        # break/continue bookkeeping - confirm whether that is intended.
        for value in iterable.value:
            set_ctx(ctx().branch(None))
            self.handle_assign(node.target, validate_value(value))
            for stmt in node.body:
                self.visit(stmt)
        return
    iterator = self.handle_call(node, iterable.__iter__)
    if not isinstance(iterator, SonolusIterator):
        raise ValueError("Unsupported iterator")
    writes = scan_writes(node)
    header_ctx = ctx().prepare_loop_header(writes)
    self.loop_head_ctxs.append(header_ctx)
    self.break_ctxs.append([])
    set_ctx(header_ctx)
    has_next = self.ensure_boolean_num(self.handle_call(node, iterator.has_next))
    if has_next._is_py_() and not has_next._as_py_():
        # The loop will never run, continue after evaluating the condition.
        # Pop this loop's entries first: leaving them on the stacks would make a
        # later break/continue in an enclosing loop target this dead loop.
        self.loop_head_ctxs.pop()
        self.break_ctxs.pop()
        for stmt in node.orelse:
            self.visit(stmt)
        return
    ctx().test = has_next.ir()
    body_ctx = ctx().branch(None)
    else_ctx = ctx().branch(0)

    set_ctx(body_ctx)
    self.handle_assign(node.target, self.handle_call(node, iterator.next))
    for stmt in node.body:
        self.visit(stmt)
    ctx().branch_to_loop_header(header_ctx)

    set_ctx(else_ctx)
    for stmt in node.orelse:
        self.visit(stmt)
    else_end_ctx = ctx()

    self.loop_head_ctxs.pop()
    break_ctxs = self.break_ctxs.pop()
    # Execution resumes where the else clause ended or where any break left the loop.
    after_ctx = Context.meet([else_end_ctx, *break_ctxs])
    set_ctx(after_ctx)
360
+
361
def visit_While(self, node):
    """Handle a ``while`` loop.

    A loop header is prepared from the names written in the loop, the test
    is evaluated at the header, the body branches back to the header, and
    the ``else`` clause runs on the edge where the test is 0.
    """
    writes = scan_writes(node)
    header_ctx = ctx().prepare_loop_header(writes)
    self.loop_head_ctxs.append(header_ctx)
    self.break_ctxs.append([])
    set_ctx(header_ctx)
    test = self.ensure_boolean_num(self.visit(node.test))
    if test._is_py_() and not test._as_py_():
        # The loop will never run, continue after evaluating the condition.
        # Pop this loop's entries first: leaving them on the stacks would make a
        # later break/continue in an enclosing loop target this dead loop.
        self.loop_head_ctxs.pop()
        self.break_ctxs.pop()
        for stmt in node.orelse:
            self.visit(stmt)
        return
    ctx().test = test.ir()
    body_ctx = ctx().branch(None)
    else_ctx = ctx().branch(0)

    set_ctx(body_ctx)
    for stmt in node.body:
        self.visit(stmt)
    ctx().branch_to_loop_header(header_ctx)

    set_ctx(else_ctx)
    for stmt in node.orelse:
        self.visit(stmt)
    else_end_ctx = ctx()

    self.loop_head_ctxs.pop()
    break_ctxs = self.break_ctxs.pop()
    # Execution resumes where the else clause ended or where any break left the loop.
    after_ctx = Context.meet([else_end_ctx, *break_ctxs])
    set_ctx(after_ctx)
391
+
392
def visit_If(self, node):
    """Handle an ``if`` statement.

    A constant test selects a single branch to visit.  Otherwise both
    branches are visited on separate edges and their end contexts merged.
    """
    condition = self.ensure_boolean_num(self.visit(node.test))

    if condition._is_py_():
        chosen = node.body if condition._as_py_() else node.orelse
        for stmt in chosen:
            self.visit(stmt)
        return

    entry_ctx = ctx()
    entry_ctx.test = condition.ir()
    then_ctx = entry_ctx.branch(None)
    otherwise_ctx = entry_ctx.branch(0)

    branch_ends = []
    for start_ctx, statements in ((then_ctx, node.body), (otherwise_ctx, node.orelse)):
        set_ctx(start_ctx)
        for stmt in statements:
            self.visit(stmt)
        branch_ends.append(ctx())

    set_ctx(Context.meet(branch_ends))
420
+
421
def visit_With(self, node):
    """Reject ``with`` statements; they are not supported."""
    message = "With statements are not supported"
    raise NotImplementedError(message)
423
+
424
def visit_AsyncWith(self, node):
    """Reject ``async with`` statements; they are not supported."""
    message = "Async with statements are not supported"
    raise NotImplementedError(message)
426
+
427
def visit_Match(self, node):
    """Handle a ``match`` statement.

    Cases are tried in order.  Each pattern yields a matched and an
    unmatched context; the body runs on the matched path (subject to the
    guard) and the next case continues from the unmatched path.  The ends
    of all bodies and the final unmatched path are merged.
    """
    subject = self.visit(node.subject)
    end_ctxs = []
    for case in node.cases:
        if not ctx().live:
            # Nothing can reach the remaining cases.
            break
        true_ctx, false_ctx = self.handle_match_pattern(subject, case.pattern)
        if not true_ctx.live:
            # The pattern can never match; move straight on to the next case.
            set_ctx(false_ctx)
            continue
        set_ctx(true_ctx)
        guard = self.ensure_boolean_num(self.visit(case.guard)) if case.guard else validate_value(True)
        if guard._is_py_():
            if guard._as_py_():
                for stmt in case.body:
                    self.visit(stmt)
                end_ctxs.append(ctx())
            else:
                # Merge failing before the guard and failing now at the guard (which we know is guaranteed to fail)
                false_ctx = Context.meet([ctx(), false_ctx])
        else:
            # Guard value is not a constant: branch on it.
            ctx().test = guard.ir()
            guard_true_ctx = ctx().branch(None)
            guard_false_ctx = ctx().branch(0)
            set_ctx(guard_true_ctx)
            for stmt in case.body:
                self.visit(stmt)
            end_ctxs.append(ctx())
            false_ctx = Context.meet([false_ctx, guard_false_ctx])
        set_ctx(false_ctx)
    # The path on which no case matched also flows to the end of the statement.
    end_ctxs.append(ctx())
    if end_ctxs:
        set_ctx(Context.meet(end_ctxs))
460
+
461
+ def handle_match_pattern(self, subject: Value, pattern: ast.pattern) -> tuple[Context, Context]:
462
+ from sonolus.script.internal.generic import validate_type_spec
463
+ from sonolus.script.internal.tuple_impl import TupleImpl
464
+
465
+ if not ctx().live:
466
+ return ctx().into_dead(), ctx()
467
+
468
+ match pattern:
469
+ case ast.MatchValue(value=value):
470
+ value = self.visit(value)
471
+ test = self.ensure_boolean_num(validate_value(subject == value))
472
+ if test._is_py_():
473
+ if test._as_py_():
474
+ return ctx(), ctx().into_dead()
475
+ else:
476
+ return ctx().into_dead(), ctx()
477
+ ctx_init = ctx()
478
+ ctx_init.test = test.ir()
479
+ true_ctx = ctx_init.branch(None)
480
+ false_ctx = ctx_init.branch(0)
481
+ return true_ctx, false_ctx
482
+ case ast.MatchSingleton(value=value):
483
+ match value:
484
+ case True:
485
+ raise NotImplementedError("Matching against True is not supported, use 1 instead")
486
+ case False:
487
+ raise NotImplementedError("Matching against False is not supported, use 0 instead")
488
+ case None:
489
+ test = validate_value(subject._is_py_() and subject._as_py_() is None)
490
+ case _:
491
+ raise NotImplementedError("Unsupported match singleton")
492
+ ctx_init = ctx()
493
+ ctx_init.test = test.ir()
494
+ true_ctx = ctx_init.branch(None)
495
+ false_ctx = ctx_init.branch(0)
496
+ return true_ctx, false_ctx
497
+ case ast.MatchSequence(patterns=patterns):
498
+ target_len = len(patterns)
499
+ if not (isinstance(subject, Sequence | TupleImpl)):
500
+ return ctx().into_dead(), ctx()
501
+ length_test = self.ensure_boolean_num(validate_value(_len(subject) == target_len))
502
+ ctx_init = ctx()
503
+ if not length_test._is_py_():
504
+ ctx_init.test = length_test.ir()
505
+ true_ctx = ctx_init.branch(None)
506
+ false_ctxs = [ctx_init.branch(0)]
507
+ elif length_test._as_py_():
508
+ true_ctx = ctx_init
509
+ false_ctxs = []
510
+ else:
511
+ return ctx().into_dead(), ctx()
512
+ set_ctx(true_ctx)
513
+ for i, subpattern in enumerate(patterns):
514
+ if not ctx().live:
515
+ break
516
+ value = self.handle_getitem(subpattern, subject, validate_value(i))
517
+ true_ctx, false_ctx = self.handle_match_pattern(value, subpattern)
518
+ false_ctxs.append(false_ctx)
519
+ set_ctx(true_ctx)
520
+ return true_ctx, Context.meet(false_ctxs)
521
+ case ast.MatchMapping():
522
+ raise NotImplementedError("Match mappings are not supported")
523
+ case ast.MatchClass(cls=cls, patterns=patterns, kwd_attrs=kwd_attrs, kwd_patterns=kwd_patterns):
524
+ cls = self.visit(cls)
525
+ if cls._is_py_() and cls._as_py_() in {_int, _float, _bool}:
526
+ raise TypeError("Instance check against int, float, or bool is not supported, use Num instead")
527
+ cls = validate_type_spec(cls)
528
+ if not isinstance(cls, type):
529
+ raise TypeError("Class is not a type")
530
+ if not isinstance(subject, cls):
531
+ return ctx().into_dead(), ctx()
532
+ if patterns:
533
+ if not hasattr(cls, "__match_args__"):
534
+ raise TypeError("Class does not support match patterns")
535
+ if len(cls.__match_args__) < len(patterns):
536
+ raise ValueError("Too many match patterns")
537
+ # kwd_attrs can't be mixed with patterns on the syntax level,
538
+ # so we can just set it like this since it's empty
539
+ kwd_attrs = cls.__match_args__[: len(patterns)]
540
+ kwd_patterns = patterns
541
+ if kwd_attrs:
542
+ true_ctx = ctx()
543
+ false_ctxs = []
544
+ for attr, subpattern in zip(kwd_attrs, kwd_patterns, strict=False):
545
+ if not hasattr(subject, attr):
546
+ raise AttributeError(f"Object has no attribute {attr}")
547
+ value = self.handle_getattr(subpattern, subject, attr)
548
+ true_ctx, false_ctx = self.handle_match_pattern(value, subpattern)
549
+ false_ctxs.append(false_ctx)
550
+ set_ctx(true_ctx)
551
+ return true_ctx, Context.meet(false_ctxs)
552
+ return ctx(), ctx().into_dead()
553
+ case ast.MatchStar():
554
+ raise NotImplementedError("Match stars are not supported")
555
+ case ast.MatchAs(pattern=pattern, name=name):
556
+ if pattern:
557
+ true_ctx, false_ctx = self.handle_match_pattern(subject, pattern)
558
+ if name:
559
+ true_ctx.scope.set_value(name, subject)
560
+ return true_ctx, false_ctx
561
+ else:
562
+ if name:
563
+ ctx().scope.set_value(name, subject)
564
+ return ctx(), ctx().into_dead()
565
+ case ast.MatchOr():
566
+ true_ctxs = []
567
+ assert pattern.patterns
568
+ for subpattern in pattern.patterns:
569
+ if not ctx().live:
570
+ break
571
+ true_ctx, false_ctx = self.handle_match_pattern(subject, subpattern)
572
+ true_ctxs.append(true_ctx)
573
+ set_ctx(false_ctx)
574
+ return Context.meet(true_ctxs), ctx()
575
+
576
def visit_Raise(self, node):
    """Reject ``raise`` statements; they are not supported."""
    message = "Raise statements are not supported"
    raise NotImplementedError(message)
578
+
579
def visit_Try(self, node):
    """Reject ``try`` statements; they are not supported."""
    message = "Try statements are not supported"
    raise NotImplementedError(message)
581
+
582
def visit_TryStar(self, node):
    """Reject ``try`` / ``except*`` statements; they are not supported."""
    message = "Try* statements are not supported"
    raise NotImplementedError(message)
584
+
585
def visit_Assert(self, node):
    """Lower ``assert test, msg`` to a call of ``assert_true(test, msg)``.

    A missing message is passed as a validated ``None``.
    """
    condition = self.visit(node.test)
    message = self.visit(node.msg) if node.msg else validate_value(None)
    self.handle_call(node, assert_true, condition, message)
589
+
590
def visit_Import(self, node):
    """Reject ``import`` statements; they are not supported."""
    message = "Import statements are not supported"
    raise NotImplementedError(message)
592
+
593
def visit_ImportFrom(self, node):
    """Reject ``from ... import ...`` statements; they are not supported."""
    message = "Import statements are not supported"
    raise NotImplementedError(message)
595
+
596
def visit_Global(self, node):
    """Reject ``global`` statements; they are not supported."""
    message = "Global statements are not supported"
    raise NotImplementedError(message)
598
+
599
def visit_Nonlocal(self, node):
    """Reject ``nonlocal`` statements; they are not supported."""
    message = "Nonlocal statements are not supported"
    raise NotImplementedError(message)
601
+
602
def visit_Expr(self, node):
    """Evaluate an expression statement and return the value of its expression."""
    expression = node.value
    return self.visit(expression)
604
+
605
def visit_Pass(self, node):
    """Handle ``pass``: there is nothing to do."""
    return None
607
+
608
def visit_Break(self, node):
    """Handle ``break``: record the current context for the innermost loop and end this path."""
    innermost_breaks = self.break_ctxs[-1]
    innermost_breaks.append(ctx())
    set_ctx(ctx().into_dead())
611
+
612
def visit_Continue(self, node):
    """Handle ``continue``: branch to the innermost loop header and end this path."""
    current = ctx()
    current.branch_to_loop_header(self.loop_head_ctxs[-1])
    set_ctx(ctx().into_dead())
615
+
616
def visit_BoolOp(self, node) -> Value:
    """Handle chained ``and`` / ``or`` expressions.

    The first operand is evaluated here; the remaining operands are wrapped
    in a new ``BoolOp`` node and handed, unevaluated, to ``handle_and`` or
    ``handle_or``.
    """
    if isinstance(node.op, ast.And):
        handler = self.handle_and
    elif isinstance(node.op, ast.Or):
        handler = self.handle_or
    else:
        raise NotImplementedError(f"Unsupported bool operator {op_to_symbol[type(node.op)]}")

    operands = node.values
    if not operands:
        raise ValueError("Bool operator requires at least one operand")
    if len(operands) == 1:
        return self.visit(operands[0])
    first, *remaining = operands
    first_value = self.visit(first)
    # The tail keeps the position of the original node.
    tail = ast.copy_location(ast.BoolOp(op=node.op, values=remaining), node)
    return handler(first_value, tail)
631
+
632
def visit_NamedExpr(self, node):
    """Handle ``target := value``: assign the value to the target and return it."""
    walrus_value = self.visit(node.value)
    self.handle_assign(node.target, walrus_value)
    return walrus_value
636
+
637
def visit_BinOp(self, node):
    """Handle a binary operator expression.

    Two constant types are combined directly in Python.  Otherwise the
    forward dunder on the left operand is tried, then the reflected dunder
    on the right operand (only when the operand types differ).

    Raises:
        TypeError: if neither attempt produces a result.
    """
    left = self.visit(node.left)
    right = self.visit(node.right)
    op_type = type(node.op)
    forward_name = bin_ops[op_type]

    if left._is_py_() and right._is_py_():
        left_py, right_py = left._as_py_(), right._as_py_()
        if isinstance(left_py, type) and isinstance(right_py, type):
            return validate_value(getattr(left_py, forward_name)(right_py))

    if hasattr(left, forward_name):
        outcome = self.handle_call(node, getattr(left, forward_name), right)
        if not self.is_not_implemented(outcome):
            return outcome

    reflected_name = rbin_ops[op_type]
    if hasattr(right, reflected_name) and type(left) is not type(right):
        outcome = self.handle_call(node, getattr(right, reflected_name), left)
        if not self.is_not_implemented(outcome):
            return outcome

    raise TypeError(
        f"unsupported operand type(s) for {op_to_symbol[type(node.op)]}: "
        f"'{type(left).__name__}' and '{type(right).__name__}'"
    )
658
+
659
def visit_UnaryOp(self, node):
    """Handle a unary operator expression.

    ``not`` requires a Num operand and uses its ``not_()``; the other
    operators call the matching dunder on the operand.

    Raises:
        TypeError: if the operand lacks the dunder for the operator.
    """
    operand = self.visit(node.operand)
    if isinstance(node.op, ast.Not):
        return self.ensure_boolean_num(operand).not_()
    dunder = unary_ops[type(node.op)]
    if not hasattr(operand, dunder):
        raise TypeError(f"bad operand type for unary {op_to_symbol[type(node.op)]}: '{type(operand).__name__}'")
    return self.handle_call(node, getattr(operand, dunder))
667
+
668
def visit_Lambda(self, node):
    """Build a ``_meta_fn_`` wrapper for a lambda and return it as a validated value.

    Calling the wrapper runs a child ``Visitor`` over ``node``.
    """
    signature = self.arguments_to_signature(node.args)

    def fn(*args, **kwargs):
        bound = signature.bind(*args, **kwargs)
        bound.apply_defaults()
        child = Visitor(self.source_file, bound, self.globals, self)
        return child.run(node)

    fn._meta_fn_ = True
    fn.__name__ = "<lambda>"

    return validate_value(fn)
685
+
686
def visit_IfExp(self, node):
    """Handle ``body if test else orelse``.

    A constant test evaluates only the selected operand.  Otherwise each
    operand is evaluated on its own edge and stored in a temporary scope
    name, which is read back after the edges are merged.
    """
    condition = self.ensure_boolean_num(self.visit(node.test))

    if condition._is_py_():
        chosen = node.body if condition._as_py_() else node.orelse
        return self.visit(chosen)

    slot = self.new_name("ifexp")
    entry_ctx = ctx()
    entry_ctx.test = condition.ir()

    # Edge for a true test.
    set_ctx(entry_ctx.branch(None))
    ctx().scope.set_value(slot, self.visit(node.body))
    then_end = ctx()

    # Edge for a false test.
    set_ctx(entry_ctx.branch(0))
    ctx().scope.set_value(slot, self.visit(node.orelse))
    otherwise_end = ctx()

    set_ctx(Context.meet([then_end, otherwise_end]))
    return ctx().scope.get_value(slot)
711
+
712
def visit_Dict(self, node):
    """Evaluate a dict literal into a validated value.

    Raises:
        NotImplementedError: if the literal uses ``**`` unpacking.  The AST
            stores such an entry with a ``None`` key, which previously was
            passed to ``visit`` and crashed with an unrelated error.
    """
    if any(key is None for key in node.keys):
        raise NotImplementedError("Dict unpacking in dict literals is not supported")
    return validate_value({self.visit(k): self.visit(v) for k, v in zip(node.keys, node.values, strict=True)})
714
+
715
def visit_Set(self, node):
    """Reject set literals; they are not supported."""
    message = "Set literals are not supported"
    raise NotImplementedError(message)
717
+
718
def visit_ListComp(self, node):
    """Reject list comprehensions; they are not supported."""
    message = "List comprehensions are not supported"
    raise NotImplementedError(message)
720
+
721
def visit_SetComp(self, node):
    """Reject set comprehensions; they are not supported."""
    message = "Set comprehensions are not supported"
    raise NotImplementedError(message)
723
+
724
def visit_DictComp(self, node):
    """Reject dict comprehensions; they are not supported."""
    message = "Dict comprehensions are not supported"
    raise NotImplementedError(message)
726
+
727
def visit_GeneratorExp(self, node):
    """Reject generator expressions; they are not supported."""
    message = "Generator expressions are not supported"
    raise NotImplementedError(message)
729
+
730
def visit_Await(self, node):
    """Reject ``await`` expressions; they are not supported."""
    message = "Await expressions are not supported"
    raise NotImplementedError(message)
732
+
733
def visit_Yield(self, node):
    """Reject ``yield`` expressions; they are not supported."""
    message = "Yield expressions are not supported"
    raise NotImplementedError(message)
735
+
736
def visit_YieldFrom(self, node):
    """Reject ``yield from`` expressions; they are not supported."""
    message = "Yield from expressions are not supported"
    raise NotImplementedError(message)
738
+
739
+ def _has_real_method(self, obj: Value, method_name: str) -> bool:
740
+ return hasattr(obj, method_name) and not isinstance(getattr(obj, method_name), MethodWrapperType)
741
+
742
def visit_Compare(self, node):
    """Handle a (possibly chained) comparison such as ``a < b <= c``.

    Each link is evaluated left to right.  A link that is false
    short-circuits the chain; the overall result is carried in a temporary
    scope name initialised to 0 and overwritten by the final link.

    Raises:
        TypeError: if ``is`` / ``is not`` has a right operand other than
            None, or if no method supports the comparison.
    """
    result_name = self.new_name("compare")
    # Default result for every path that leaves the chain early.
    ctx().scope.set_value(result_name, Num._accept_(0))
    l_val = self.visit(node.left)
    false_ctxs = []
    for i, (op, rhs) in enumerate(zip(node.ops, node.comparators, strict=True)):
        r_val = self.visit(rhs)
        # `not in` is evaluated as `in` and negated afterwards.
        inverted = isinstance(op, ast.NotIn)
        result = None
        if isinstance(op, ast.Is | ast.IsNot):
            if not (r_val._is_py_() and r_val._as_py_() is None):
                raise TypeError("The right operand of 'is' must be None")
            if isinstance(op, ast.Is):
                result = Num._accept_(l_val._is_py_() and l_val._as_py_() is None)
            else:
                result = Num._accept_(not (l_val._is_py_() and l_val._as_py_() is None))
        elif type(op) in comp_ops and self._has_real_method(l_val, comp_ops[type(op)]):
            result = self.handle_call(node, getattr(l_val, comp_ops[type(op)]), r_val)
        # Fall back to the mirrored method on the right operand.
        if (
            (result is None or self.is_not_implemented(result))
            and type(op) in rcomp_ops
            and self._has_real_method(r_val, rcomp_ops[type(op)])
        ):
            result = self.handle_call(node, getattr(r_val, rcomp_ops[type(op)]), l_val)
        if result is None or self.is_not_implemented(result):
            # Last resort for equality tests: compare by identity.
            if type(op) is ast.Eq:
                result = Num._accept_(l_val is r_val)
            elif type(op) is ast.NotEq:
                result = Num._accept_(l_val is not r_val)
            else:
                raise TypeError(
                    f"'{op_to_symbol[type(op)]}' not supported between instances of '{type(l_val).__name__}' and "
                    f"'{type(r_val).__name__}'"
                )
        result = self.ensure_boolean_num(result)
        if inverted:
            result = result.not_()
        curr_ctx = ctx()
        if i == len(node.ops) - 1:
            # Final link: its value is the value of the whole chain.
            curr_ctx.scope.set_value(result_name, result)
        elif result._is_py_():
            if result._as_py_():
                l_val = r_val
            else:
                # Constant false: the chain ends here with the default result.
                false_ctxs.append(curr_ctx)
                set_ctx(curr_ctx.into_dead())
                break
        else:
            # Non-constant link: continue the chain only on the true edge.
            curr_ctx.test = result.ir()
            true_ctx = curr_ctx.branch(None)
            false_ctx = curr_ctx.branch(0)
            false_ctxs.append(false_ctx)
            set_ctx(true_ctx)
            l_val = r_val
    last_ctx = ctx()  # This is the result of the last comparison returning true
    set_ctx(Context.meet([last_ctx, *false_ctxs]))
    return ctx().scope.get_value(result_name)
799
+
800
def visit_Call(self, node):
    """Compile a call expression, expanding ``*args`` and ``**kwargs`` arguments.

    Starred positional arguments are expanded via ``handle_starred``; a
    ``**`` argument must evaluate to a ``DictImpl`` whose keys are all strings.
    """
    from sonolus.script.internal.dict_impl import DictImpl

    fn = self.visit(node.func)

    positional = []
    for arg_node in node.args:
        if isinstance(arg_node, ast.Starred):
            positional.extend(self.handle_starred(self.visit(arg_node.value)))
        else:
            positional.append(self.visit(arg_node))

    named = {}
    for kw in node.keywords:
        kw_value = self.visit(kw.value)
        if kw.arg:
            named[kw.arg] = kw_value
            continue
        # No name on the keyword node means this is a ``**mapping`` argument.
        if not isinstance(kw_value, DictImpl):
            raise ValueError("Starred keyword arguments (**kwargs) must be dictionaries")
        if any(not isinstance(k, str) for k in kw_value.value):
            raise ValueError("Keyword arguments must be strings")
        named.update(kw_value.value)

    return self.handle_call(node, fn, *positional, **named)
823
+
824
def visit_FormattedValue(self, node):
    """Reject f-string replacement fields, which this compiler does not support."""
    message = "F-strings are not supported"
    raise NotImplementedError(message)
826
+
827
def visit_JoinedStr(self, node):
    """Reject f-strings, which this compiler does not support."""
    message = "F-strings are not supported"
    raise NotImplementedError(message)
829
+
830
def visit_Constant(self, node):
    """Compile a literal constant by passing its Python value to ``validate_value``."""
    literal = node.value
    return validate_value(literal)
832
+
833
def visit_Attribute(self, node):
    """Compile ``<value>.<attr>`` by visiting the owner and delegating to ``handle_getattr``."""
    owner = self.visit(node.value)
    return self.handle_getattr(node, owner, node.attr)
835
+
836
def visit_Subscript(self, node):
    """Compile ``<value>[<slice>]``; the container is visited before the index."""
    # Call arguments are evaluated left to right, so the visit order is preserved.
    return self.handle_getitem(node, self.visit(node.value), self.visit(node.slice))
840
+
841
def visit_Starred(self, node):
    """Reject a starred expression that is visited directly."""
    message = "Starred expressions are not supported"
    raise NotImplementedError(message)
843
+
844
def visit_Name(self, node):
    """Resolve a name, searching this visitor's scope, then its parents, then globals.

    Raises:
        ValueError: if the name resolves to the global ``ctx`` function.
        NameError: if the name is bound nowhere.
    """
    self.active_ctx = ctx()

    # Walk outwards through the chain of enclosing visitors.
    visitor = self
    while visitor:
        if not isinstance(visitor.active_ctx.scope.get_binding(node.id), EmptyBinding):
            return visitor.active_ctx.scope.get_value(node.id)
        visitor = visitor.parent

    if node.id not in self.globals:
        raise NameError(f"Name {node.id} is not defined")

    global_value = self.globals[node.id]
    if global_value is ctx:
        raise ValueError("Unexpected use of ctx in non meta-function")
    # Substitute the registered implementation for the global, if there is one.
    return validate_value(BUILTIN_IMPLS.get(id(global_value), global_value))
857
+
858
def visit_List(self, node):
    """Reject list displays, which this compiler does not support."""
    message = "List literals are not supported"
    raise NotImplementedError(message)
860
+
861
def visit_Tuple(self, node):
    """Compile a tuple display, expanding any starred elements in place."""
    items = []
    for element in node.elts:
        if not isinstance(element, ast.Starred):
            items.append(self.visit(element))
            continue
        # ``*expr`` contributes every element of the unpacked value.
        items.extend(self.handle_starred(self.visit(element.value)))
    return validate_value(tuple(items))
869
+
870
def visit_Slice(self, node):
    """Reject slice expressions, which this compiler does not support."""
    message = "Slices are not supported"
    raise NotImplementedError(message)
872
+
873
def handle_assign(self, target: ast.stmt | ast.expr, value: Value):
    """Assign ``value`` to an assignment target node.

    Supports plain names, attributes, subscripts and (recursively) tuple/list
    unpacking targets. Starred targets and anything else are rejected.
    """
    if isinstance(target, ast.Name):
        ctx().scope.set_value(target.id, value)
    elif isinstance(target, ast.Attribute):
        owner = self.visit(target.value)
        self.handle_setattr(target, owner, target.attr, value)
    elif isinstance(target, ast.Subscript):
        container = self.visit(target.value)
        index = self.visit(target.slice)
        self.handle_setitem(target, container, index, value)
    elif isinstance(target, ast.Tuple | ast.List):
        items = self.handle_starred(value)
        if len(target.elts) != len(items):
            raise ValueError("Unpacking assignment requires the same number of elements")
        for sub_target, item in zip(target.elts, items, strict=False):
            self.handle_assign(sub_target, validate_value(item))
    elif isinstance(target, ast.Starred):
        raise NotImplementedError("Starred assignment is not supported")
    else:
        raise NotImplementedError("Unsupported assignment target")
894
+
895
def handle_and(self, l_val: Value, r_expr: ast.expr) -> Value:
    """Compile ``l_val and <r_expr>`` with short-circuit evaluation.

    Args:
        l_val: The already-evaluated left operand; must be a boolean ``Num``.
        r_expr: The unevaluated right operand. It is only visited on the path
            where the left operand is truthy.

    Returns:
        The right operand if the left one is a truthy compile-time constant,
        the left operand if it is a falsy compile-time constant, and otherwise
        the value of a synthetic scope variable merged from both branches.

    Raises:
        TypeError: via ``ensure_boolean_num`` if either operand is not a Num.
    """
    ctx_init = ctx()
    l_val = self.ensure_boolean_num(l_val)

    if l_val._is_py_():
        if l_val._as_py_():
            # The rhs is definitely evaluated, so we can return it directly
            return self.ensure_boolean_num(self.visit(r_expr))
        return l_val

    # From here on l_val is a runtime value, so no constant folding is possible.
    ctx_init.test = l_val.ir()
    res_name = self.new_name("and")

    # Truthy path: the result is whatever the right operand evaluates to.
    set_ctx(ctx_init.branch(None))
    r_val = self.ensure_boolean_num(self.visit(r_expr))
    ctx().scope.set_value(res_name, r_val)
    ctx_true = ctx()

    # Falsy path: the right operand is skipped and the result is 0.
    set_ctx(ctx_init.branch(0))
    ctx().scope.set_value(res_name, Num._accept_(0))
    ctx_false = ctx()

    set_ctx(Context.meet([ctx_true, ctx_false]))
    return ctx().scope.get_value(res_name)
922
+
923
def handle_or(self, l_val: Value, r_expr: ast.expr) -> Value:
    """Compile ``l_val or <r_expr>`` with short-circuit evaluation.

    Args:
        l_val: The already-evaluated left operand; must be a boolean ``Num``.
        r_expr: The unevaluated right operand. It is only visited on the path
            where the left operand is falsy.

    Returns:
        The left operand if it is a truthy compile-time constant, the right
        operand if the left one is a falsy compile-time constant, and otherwise
        the value of a synthetic scope variable merged from both branches.

    Raises:
        TypeError: via ``ensure_boolean_num`` if either operand is not a Num.
    """
    ctx_init = ctx()
    l_val = self.ensure_boolean_num(l_val)

    if l_val._is_py_():
        if l_val._as_py_():
            return l_val
        # The rhs is definitely evaluated, so we can return it directly
        return self.ensure_boolean_num(self.visit(r_expr))

    # From here on l_val is a runtime value, so no constant folding is possible.
    ctx_init.test = l_val.ir()
    res_name = self.new_name("or")

    # Truthy path: the right operand is skipped and the result is the left operand.
    set_ctx(ctx_init.branch(None))
    ctx().scope.set_value(res_name, l_val)
    ctx_true = ctx()

    # Falsy path: the result is whatever the right operand evaluates to.
    set_ctx(ctx_init.branch(0))
    r_val = self.ensure_boolean_num(self.visit(r_expr))
    ctx().scope.set_value(res_name, r_val)
    ctx_false = ctx()

    set_ctx(Context.meet([ctx_true, ctx_false]))
    return ctx().scope.get_value(res_name)
950
+
951
+ def generic_visit(self, node):
952
+ if isinstance(node, ast.stmt | ast.expr):
953
+ with self.reporting_errors_at_node(node):
954
+ raise NotImplementedError(f"Unsupported syntax: {type(node).__name__}")
955
+ raise NotImplementedError(f"Unsupported syntax: {type(node).__name__}")
956
+
957
def handle_getattr(self, node: ast.stmt | ast.expr, target: Value, key: str) -> Value:
    """Read attribute ``key`` from ``target``, reporting errors at ``node``.

    The attribute is looked up in the ``__dict__`` of ``type(target)`` only.
    Properties have their getter invoked through ``handle_call``; supported
    descriptors, missing entries and non-descriptors use a plain ``getattr``.

    Raises:
        TypeError: if the class attribute is an unsupported kind of descriptor.
    """
    with self.reporting_errors_at_node(node):
        if isinstance(target, ConstantValue):
            # Unwrap so we can access fields
            target = target._as_py_()
        descriptor = type(target).__dict__.get(key)

        if isinstance(descriptor, property):
            return self.handle_call(node, descriptor.fget, target)

        directly_readable = descriptor is None or isinstance(
            descriptor, (SonolusDescriptor, FunctionType, classmethod, staticmethod)
        )
        if directly_readable or not hasattr(descriptor, "__get__"):
            return validate_value(getattr(target, key))

        raise TypeError(f"Unsupported field or descriptor {key}")
972
+
973
def handle_setattr(self, node: ast.stmt | ast.expr, target: Value, key: str, value: Value):
    """Assign ``value`` to attribute ``key`` of ``target``, reporting errors at ``node``.

    Only properties with a setter and ``SonolusDescriptor`` attributes may be
    assigned to.

    Raises:
        AttributeError: if the attribute is a property without a setter.
        TypeError: if the attribute is neither a property nor a SonolusDescriptor.
    """
    with self.reporting_errors_at_node(node):
        if target._is_py_():
            target = target._as_py_()
        descriptor = getattr(type(target), key, None)

        if isinstance(descriptor, property):
            setter = descriptor.fset
            if setter is None:
                raise AttributeError(f"Cannot set attribute {key} because property has no setter")
            self.handle_call(node, setter, target, value)
        elif isinstance(descriptor, SonolusDescriptor):
            setattr(target, key, value)
        else:
            raise TypeError(f"Unsupported field or descriptor {key}")
987
+
988
+ def handle_call[**P, R](
989
+ self, node: ast.stmt | ast.expr, fn: Callable[P, R], /, *args: P.args, **kwargs: P.kwargs
990
+ ) -> R:
991
+ """Handles a call to the given callable."""
992
+ self.active_ctx = ctx()
993
+ if (
994
+ isinstance(fn, Value)
995
+ and fn._is_py_()
996
+ and isinstance(fn._as_py_(), type)
997
+ and issubclass(fn._as_py_(), Value)
998
+ ):
999
+ return validate_value(self.execute_at_node(node, fn._as_py_(), *args, **kwargs))
1000
+ else:
1001
+ return self.execute_at_node(node, lambda: validate_value(compile_and_call(fn, *args, **kwargs)))
1002
+
1003
def handle_getitem(self, node: ast.stmt | ast.expr, target: Value, key: Value) -> Value:
    """Compile ``target[key]``, reporting errors at ``node``.

    Subscripting a constant Python type requires a compile-time constant key;
    any other target must be a ``Value`` that defines ``__getitem__``.

    Raises:
        ValueError: if a type is subscripted with a non-constant key.
        TypeError: if the target does not support item access.
    """
    with self.reporting_errors_at_node(node):
        if target._is_py_() and isinstance(target._as_py_(), type):
            if not key._is_py_():
                raise ValueError("Type parameters must be compile-time constants")
            return validate_value(target._as_py_()[key._as_py_()])
        if not (isinstance(target, Value) and hasattr(target, "__getitem__")):
            raise TypeError(f"Cannot get items on {type(target).__name__}")
        return self.handle_call(node, target.__getitem__, key)
1013
+
1014
def handle_setitem(self, node: ast.stmt | ast.expr, target: Value, key: Value, value: Value):
    """Compile ``target[key] = value``, reporting errors at ``node``.

    Raises:
        TypeError: if ``target`` is not a Value that defines ``__setitem__``.
    """
    with self.reporting_errors_at_node(node):
        supports_assignment = isinstance(target, Value) and hasattr(target, "__setitem__")
        if not supports_assignment:
            raise TypeError(f"Cannot set items on {type(target).__name__}")
        return self.handle_call(node, target.__setitem__, key, value)
1019
+
1020
def handle_delitem(self, node: ast.stmt | ast.expr, target: Value, key: Value):
    """Compile ``del target[key]``, reporting errors at ``node``.

    Raises:
        TypeError: if ``target`` is not a Value that defines ``__delitem__``.
    """
    with self.reporting_errors_at_node(node):
        supports_deletion = isinstance(target, Value) and hasattr(target, "__delitem__")
        if not supports_deletion:
            raise TypeError(f"Cannot delete items on {type(target).__name__}")
        return self.handle_call(node, target.__delitem__, key)
1025
+
1026
def handle_starred(self, value: Value) -> tuple[Value, ...]:
    """Return the elements of an unpacked value; only ``TupleImpl`` is supported.

    Raises:
        ValueError: if ``value`` is not a ``TupleImpl``.
    """
    if not isinstance(value, TupleImpl):
        raise ValueError("Unsupported starred expression")
    return value.value
1030
+
1031
def is_not_implemented(self, value):
    """Return whether ``value`` is the compile-time constant ``NotImplemented``."""
    validated = validate_value(value)
    if not validated._is_py_():
        return False
    return validated._as_py_() is NotImplemented
1034
+
1035
def ensure_boolean_num(self, value) -> Num:
    """Return ``value`` unchanged if it is a Num, otherwise raise ``TypeError``."""
    # This just checks the type for now, although we could support custom __bool__ implementations in the future
    if _is_num(value):
        return value
    raise TypeError(f"Invalid type where a bool (Num) was expected: {type(value).__name__}")
1040
+
1041
+ def arguments_to_signature(self, arguments: ast.arguments) -> inspect.Signature:
1042
+ parameters: list[inspect.Parameter] = []
1043
+ pos_only_count = len(arguments.posonlyargs)
1044
+ for i, arg in enumerate(arguments.posonlyargs):
1045
+ default_idx = i - pos_only_count + len(arguments.defaults)
1046
+ default = self.visit(arguments.defaults[default_idx]) if default_idx >= 0 else None
1047
+ param = inspect.Parameter(
1048
+ name=arg.arg,
1049
+ kind=inspect.Parameter.POSITIONAL_ONLY,
1050
+ default=default if default_idx >= 0 else inspect.Parameter.empty,
1051
+ annotation=inspect.Parameter.empty,
1052
+ )
1053
+ parameters.append(param)
1054
+
1055
+ pos_kw_count = len(arguments.args)
1056
+ for i, arg in enumerate(arguments.args):
1057
+ default_idx = i - pos_kw_count + len(arguments.defaults)
1058
+ default = self.visit(arguments.defaults[default_idx]) if default_idx >= 0 else None
1059
+ param = inspect.Parameter(
1060
+ name=arg.arg,
1061
+ kind=inspect.Parameter.POSITIONAL_OR_KEYWORD,
1062
+ default=default if default_idx >= 0 else inspect.Parameter.empty,
1063
+ annotation=inspect.Parameter.empty,
1064
+ )
1065
+ parameters.append(param)
1066
+
1067
+ if arguments.vararg:
1068
+ param = inspect.Parameter(
1069
+ name=arguments.vararg.arg,
1070
+ kind=inspect.Parameter.VAR_POSITIONAL,
1071
+ default=inspect.Parameter.empty,
1072
+ annotation=inspect.Parameter.empty,
1073
+ )
1074
+ parameters.append(param)
1075
+
1076
+ for i, arg in enumerate(arguments.kwonlyargs):
1077
+ default = self.visit(arguments.kw_defaults[i]) if arguments.kw_defaults[i] is not None else None
1078
+ param = inspect.Parameter(
1079
+ name=arg.arg,
1080
+ kind=inspect.Parameter.KEYWORD_ONLY,
1081
+ default=default if default is not None else inspect.Parameter.empty,
1082
+ annotation=inspect.Parameter.empty,
1083
+ )
1084
+ parameters.append(param)
1085
+
1086
+ if arguments.kwarg:
1087
+ param = inspect.Parameter(
1088
+ name=arguments.kwarg.arg,
1089
+ kind=inspect.Parameter.VAR_KEYWORD,
1090
+ default=inspect.Parameter.empty,
1091
+ annotation=inspect.Parameter.empty,
1092
+ )
1093
+ parameters.append(param)
1094
+
1095
+ return inspect.Signature(parameters)
1096
+
1097
def raise_exception_at_node(self, node: ast.stmt | ast.expr, cause: Exception) -> Never:
    """Throws a compilation error at the given node.

    The raised ``CompilationError`` carries ``str(cause)`` as its message and
    is chained to ``cause``.
    """

    def raise_compilation_error() -> Never:
        raise CompilationError(str(cause)) from cause

    # Running the raiser through execute_at_node ties the error to ``node``.
    self.execute_at_node(node, raise_compilation_error)
1104
+
1105
+ def execute_at_node[**P, R](
1106
+ self, node: ast.stmt | ast.expr, fn: Callable[P, R], /, *args: P.args, **kwargs: P.kwargs
1107
+ ) -> R:
1108
+ """Executes the given function at the given node for a better traceback."""
1109
+ expr = ast.Expression(
1110
+ body=ast.Call(
1111
+ func=ast.Name(id="fn", ctx=ast.Load()),
1112
+ args=[ast.Starred(value=ast.Name(id="args", ctx=ast.Load()), ctx=ast.Load())],
1113
+ keywords=[ast.keyword(value=ast.Name(id="kwargs", ctx=ast.Load()), arg=None)],
1114
+ lineno=node.lineno,
1115
+ col_offset=node.col_offset,
1116
+ end_lineno=node.end_lineno,
1117
+ end_col_offset=node.end_col_offset,
1118
+ ),
1119
+ )
1120
+ expr = ast.fix_missing_locations(expr)
1121
+ return eval(
1122
+ compile(expr, filename=self.source_file, mode="eval"),
1123
+ {"fn": fn, "args": args, "kwargs": kwargs, "_filter_traceback_": True},
1124
+ )
1125
+
1126
def reporting_errors_at_node(self, node: ast.stmt | ast.expr):
    """Return a ``ReportingErrorsAtNode`` context manager bound to this visitor and ``node``."""
    reporter = ReportingErrorsAtNode(self, node)
    return reporter
1128
+
1129
def new_name(self, name: str):
    """Return a fresh ``$<name>_<n>`` identifier, bumping the counter kept per base name."""
    count = self.used_names.get(name, 0) + 1
    self.used_names[name] = count
    return f"${name}_{count}"
1132
+
1133
+
1134
+ # Not using @contextmanager so it doesn't end up in tracebacks
1135
+ class ReportingErrorsAtNode:
1136
+ def __init__(self, compiler, node: ast.stmt | ast.expr):
1137
+ self.compiler = compiler
1138
+ self.node = node
1139
+
1140
+ def __enter__(self):
1141
+ return self
1142
+
1143
+ def __exit__(self, exc_type, exc_value, traceback):
1144
+ if exc_type is None:
1145
+ return
1146
+
1147
+ if issubclass(exc_type, CompilationError):
1148
+ raise exc_value from exc_value.__cause__
1149
+
1150
+ if exc_value is not None:
1151
+ self.compiler.raise_exception_at_node(self.node, exc_value)