sutra-dev 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sutra_compiler/__init__.py +49 -0
- sutra_compiler/__main__.py +514 -0
- sutra_compiler/ast_nodes.py +553 -0
- sutra_compiler/codegen.py +1811 -0
- sutra_compiler/codegen_base.py +2436 -0
- sutra_compiler/codegen_pytorch.py +1472 -0
- sutra_compiler/diagnostics.py +145 -0
- sutra_compiler/inliner.py +581 -0
- sutra_compiler/lexer.py +821 -0
- sutra_compiler/parser.py +2112 -0
- sutra_compiler/review.py +322 -0
- sutra_compiler/simplify.py +1046 -0
- sutra_compiler/simplify_egglog.py +674 -0
- sutra_compiler/stdlib/axons.su +53 -0
- sutra_compiler/stdlib/embed.su +48 -0
- sutra_compiler/stdlib/javascript_object.su +18 -0
- sutra_compiler/stdlib/logic.su +202 -0
- sutra_compiler/stdlib/math.su +12 -0
- sutra_compiler/stdlib/memory.su +82 -0
- sutra_compiler/stdlib/numbers.su +99 -0
- sutra_compiler/stdlib/rotation.su +83 -0
- sutra_compiler/stdlib/similarity.su +97 -0
- sutra_compiler/stdlib/strings.su +56 -0
- sutra_compiler/stdlib/tensor.su +82 -0
- sutra_compiler/stdlib/vectors.su +119 -0
- sutra_compiler/stdlib_loader.py +219 -0
- sutra_compiler/sutradb_embedded.py +273 -0
- sutra_compiler/trace.py +135 -0
- sutra_compiler/validator.py +552 -0
- sutra_compiler/workspace.py +655 -0
- sutra_dev-0.2.0.dist-info/METADATA +80 -0
- sutra_dev-0.2.0.dist-info/RECORD +36 -0
- sutra_dev-0.2.0.dist-info/WHEEL +5 -0
- sutra_dev-0.2.0.dist-info/entry_points.txt +2 -0
- sutra_dev-0.2.0.dist-info/licenses/LICENSE +201 -0
- sutra_dev-0.2.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,581 @@
|
|
|
1
|
+
"""Stdlib function inliner — step 2 of the function-expansion pipeline.
|
|
2
|
+
|
|
3
|
+
Rewrites every `Call(Identifier(name), args)` in a module's AST, for
|
|
4
|
+
each name present in the stdlib symbol table whose function body is
|
|
5
|
+
a single `return <expr>;` statement, by substituting the arguments
|
|
6
|
+
into the parameter slots and replacing the call node with the
|
|
7
|
+
substituted expression.
|
|
8
|
+
|
|
9
|
+
### What this covers today
|
|
10
|
+
|
|
11
|
+
Functions with a single-return-expr body:
|
|
12
|
+
|
|
13
|
+
- logical_not(v) → `0 - v`
|
|
14
|
+
- logical_and(a, b) → `(a + b + a*b - a*a - b*b + a*a*b*b) * 0.5`
|
|
15
|
+
- logical_or(a, b) → `(a + b - a*b + a*a + b*b - a*a*b*b) * 0.5`
|
|
16
|
+
- neq(a, b) → `!(a == b)`
|
|
17
|
+
- lt(a, b) → `b > a`
|
|
18
|
+
- ge(a, b) → `a > b`
|
|
19
|
+
- le(a, b) → `a < b`
|
|
20
|
+
|
|
21
|
+
A user call like `logical_and(p, q)` becomes the polynomial form
|
|
22
|
+
inline; the compiler then sees `(p + q + p*q - p*p - q*q + p*p*q*q)
|
|
23
|
+
* 0.5` and can fold arithmetic, re-bundle, and emit tensor ops
|
|
24
|
+
directly — no call into the runtime's `_VSA.logical_and` method.
|
|
25
|
+
|
|
26
|
+
### What this doesn't cover (yet)
|
|
27
|
+
|
|
28
|
+
- Statement-bodied functions (today: `defuzzy` with its ten-iter
|
|
29
|
+
loop). Inlining a body that contains statements into an
|
|
30
|
+
expression position needs statement-level hoisting — a call
|
|
31
|
+
site has to become a preceding statement block plus a temp
|
|
32
|
+
variable in the expression slot. That's the next extension
|
|
33
|
+
(call it step 2.5) and the prerequisite for step 3's
|
|
34
|
+
loop-unroll propagation.
|
|
35
|
+
|
|
36
|
+
- Intrinsic calls. The stubs in stdlib (eq, gt, make_real,
|
|
37
|
+
complex_mul, bind, bundle, ...) aren't FunctionDecls today —
|
|
38
|
+
they live as commented pseudo-Sutra. When the `@intrinsic`
|
|
39
|
+
mechanism lands (step 5), the inliner will be extended to
|
|
40
|
+
resolve intrinsic calls too, but for now those calls are left
|
|
41
|
+
untouched and continue to hit the hardcoded runtime methods.
|
|
42
|
+
|
|
43
|
+
### Pipeline position
|
|
44
|
+
|
|
45
|
+
Runs before `simplify_module` so that the arithmetic constant
|
|
46
|
+
folding and zero-absorption rewrites in simplify can fold the
|
|
47
|
+
inlined polynomial bodies. Called from `translate_module` in
|
|
48
|
+
`codegen.py` and `codegen_pytorch.py` via `inline_stdlib_calls`.
|
|
49
|
+
"""
|
|
50
|
+
|
|
51
|
+
from __future__ import annotations
|
|
52
|
+
|
|
53
|
+
import copy
|
|
54
|
+
from typing import Dict, Optional
|
|
55
|
+
|
|
56
|
+
from . import ast_nodes as ast
|
|
57
|
+
from .stdlib_loader import load_stdlib
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
_STDLIB_CACHE: Optional[Dict[str, ast.FunctionDecl]] = None
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def _stdlib_table() -> Dict[str, ast.FunctionDecl]:
    """Return the stdlib symbol table, loading it on first use.

    The result of ``load_stdlib()`` is stored in the module-level
    ``_STDLIB_CACHE`` and handed back on every later call, so the stdlib
    is loaded at most once per process as long as the load yields a
    non-``None`` table.
    """
    global _STDLIB_CACHE
    cached = _STDLIB_CACHE
    if cached is not None:
        return cached
    cached = load_stdlib()
    _STDLIB_CACHE = cached
    return cached
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def inline_stdlib_calls(
    module: ast.Module,
    stdlib_table: Optional[Dict[str, ast.FunctionDecl]] = None,
) -> ast.Module:
    """Lower operators to stdlib calls, then inline those calls.

    Only stdlib functions whose body is a single ``return <expr>;`` are
    considered.  ``module`` is mutated in place and also returned.

    Args:
        module: The module AST to rewrite.
        stdlib_table: Optional name -> declaration mapping to use instead
            of the real stdlib (handy for tests with a synthetic stdlib).
            When omitted, the cached table from ``_stdlib_table()`` is used.

    Returns:
        The same ``module`` object that was passed in.
    """
    if stdlib_table is None:
        stdlib_table = _stdlib_table()

    # Keep only the functions the inliner is able to expand.
    candidates: Dict[str, ast.FunctionDecl] = {}
    for fn_name, fn_decl in stdlib_table.items():
        if _is_single_return_expr(fn_decl):
            candidates[fn_name] = fn_decl

    # Step 2.6: operators with an inlineable stdlib body become Call nodes
    # first (`a && b` -> logical_and(a, b), `!v` -> logical_not(v), ...),
    # so the inlining pass treats them exactly like direct user calls.
    _lower_operators_to_stdlib_calls(module, candidates)

    # Step 2: expand the stdlib calls themselves.
    for top_level_item in module.items:
        _rewrite_top_level(top_level_item, candidates)
    return module
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
def _is_single_return_expr(decl: ast.FunctionDecl) -> bool:
|
|
102
|
+
"""True iff the function body is exactly one `return <expr>;`
|
|
103
|
+
statement with a non-None value. These are the functions step 2
|
|
104
|
+
can inline today."""
|
|
105
|
+
stmts = decl.body.statements
|
|
106
|
+
if len(stmts) != 1:
|
|
107
|
+
return False
|
|
108
|
+
stmt = stmts[0]
|
|
109
|
+
return isinstance(stmt, ast.ReturnStmt) and stmt.value is not None
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
# ---------------------------------------------------------------------------
|
|
113
|
+
# Top-level and statement walk — finds expressions to rewrite
|
|
114
|
+
# ---------------------------------------------------------------------------
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
def _rewrite_top_level(item, table) -> None:
    """Inline stdlib calls inside one top-level module item.

    Function/method bodies, loop-function conditions and bodies, class
    members, top-level variable initializers and bare statements are
    walked; any other kind of item is left untouched.
    """
    if isinstance(item, (ast.FunctionDecl, ast.MethodDecl)):
        _rewrite_block(item.body, table)
        return

    if isinstance(item, ast.LoopFunctionDecl):
        # The condition is an expression too, so stdlib calls in it
        # (e.g. `<` -> `lt(a,b)` -> `b > a`) are inlined as well.
        item.condition = _rewrite_expr(item.condition, table)
        _rewrite_block(item.body, table)
        return

    if isinstance(item, ast.ClassDecl):
        for method in item.methods:
            _rewrite_block(method.body, table)
        for loop_fn in item.loop_functions:
            loop_fn.condition = _rewrite_expr(loop_fn.condition, table)
            _rewrite_block(loop_fn.body, table)
        return

    if isinstance(item, ast.VarDecl):
        if item.initializer is not None:
            item.initializer = _rewrite_expr(item.initializer, table)
        return

    if isinstance(item, ast.Stmt):
        _rewrite_stmt(item, table)
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
def _rewrite_block(block: ast.Block, table) -> None:
|
|
139
|
+
for stmt in block.statements:
|
|
140
|
+
_rewrite_stmt(stmt, table)
|
|
141
|
+
|
|
142
|
+
|
|
143
|
+
def _rewrite_stmt(stmt, table) -> None:
    """Inline stdlib calls in every expression reachable from ``stmt``.

    Each expression slot of the statement is reassigned with the result
    of ``_rewrite_expr``; nested blocks and statements are walked
    recursively.  Statement kinds not handled below (for instance
    ``BreakStmt`` / ``ContinueStmt``, which carry no expressions) are
    left as they are.
    """

    def rewrite(node):
        return _rewrite_expr(node, table)

    if isinstance(stmt, ast.VarDecl):
        if stmt.initializer is not None:
            stmt.initializer = rewrite(stmt.initializer)
        return

    if isinstance(stmt, ast.ReturnStmt):
        if stmt.value is not None:
            stmt.value = rewrite(stmt.value)
        return

    if isinstance(stmt, ast.ExprStmt):
        stmt.expr = rewrite(stmt.expr)
        return

    if isinstance(stmt, ast.Assignment):
        stmt.target = rewrite(stmt.target)
        stmt.value = rewrite(stmt.value)
        return

    if isinstance(stmt, ast.IfStmt):
        stmt.condition = rewrite(stmt.condition)
        _rewrite_block(stmt.then_branch, table)
        otherwise = stmt.else_branch
        if otherwise is None:
            return
        # An `else if` is stored as a nested IfStmt, a plain `else` as a Block.
        if isinstance(otherwise, ast.IfStmt):
            _rewrite_stmt(otherwise, table)
        else:
            _rewrite_block(otherwise, table)
        return

    if isinstance(stmt, ast.WhileStmt):
        stmt.condition = rewrite(stmt.condition)
        _rewrite_block(stmt.body, table)
        return

    if isinstance(stmt, ast.DoWhileStmt):
        _rewrite_block(stmt.body, table)
        stmt.condition = rewrite(stmt.condition)
        return

    if isinstance(stmt, ast.ForStmt):
        if stmt.init is not None:
            _rewrite_stmt(stmt.init, table)
        if stmt.condition is not None:
            stmt.condition = rewrite(stmt.condition)
        if stmt.step is not None:
            _rewrite_stmt(stmt.step, table)
        _rewrite_block(stmt.body, table)
        return

    if isinstance(stmt, ast.ForeachStmt):
        stmt.iterable = rewrite(stmt.iterable)
        _rewrite_block(stmt.body, table)
        return

    if isinstance(stmt, ast.LoopStmt):
        if stmt.count is not None:
            stmt.count = rewrite(stmt.count)
        if stmt.condition is not None:
            stmt.condition = rewrite(stmt.condition)
        _rewrite_block(stmt.body, table)
        return

    if isinstance(stmt, ast.PassStmt):
        # Pass values are expressions or ReplaceMarker leaves; only the
        # expressions are rewritten, and the list is updated in place.
        for position, passed in enumerate(stmt.values):
            if isinstance(passed, ast.ReplaceMarker):
                continue
            stmt.values[position] = rewrite(passed)
        return

    if isinstance(stmt, ast.LoopCallStmt):
        stmt.condition_arg = rewrite(stmt.condition_arg)
        return

    if isinstance(stmt, ast.TryStmt):
        _rewrite_block(stmt.try_block, table)
        for catch_clause in stmt.catches:
            _rewrite_block(catch_clause.body, table)
        return

    if isinstance(stmt, ast.Block):
        _rewrite_block(stmt, table)
|
|
201
|
+
|
|
202
|
+
|
|
203
|
+
# ---------------------------------------------------------------------------
|
|
204
|
+
# Expression rewrite — post-order, inlines Call nodes at the bottom
|
|
205
|
+
# ---------------------------------------------------------------------------
|
|
206
|
+
|
|
207
|
+
|
|
208
|
+
def _rewrite_expr(expr, table):
|
|
209
|
+
"""Post-order: rewrite children first, then consider the node
|
|
210
|
+
itself for inlining. Returns the (possibly replaced) expression."""
|
|
211
|
+
if expr is None:
|
|
212
|
+
return None
|
|
213
|
+
|
|
214
|
+
if isinstance(expr, ast.Call):
|
|
215
|
+
expr.callee = _rewrite_expr(expr.callee, table)
|
|
216
|
+
expr.args = [_rewrite_expr(a, table) for a in expr.args]
|
|
217
|
+
if (isinstance(expr.callee, ast.Identifier)
|
|
218
|
+
and expr.callee.name in table):
|
|
219
|
+
return _do_inline(expr, table[expr.callee.name], table)
|
|
220
|
+
return expr
|
|
221
|
+
|
|
222
|
+
if isinstance(expr, ast.BinaryOp):
|
|
223
|
+
expr.left = _rewrite_expr(expr.left, table)
|
|
224
|
+
expr.right = _rewrite_expr(expr.right, table)
|
|
225
|
+
return expr
|
|
226
|
+
if isinstance(expr, ast.UnaryOp):
|
|
227
|
+
expr.operand = _rewrite_expr(expr.operand, table)
|
|
228
|
+
return expr
|
|
229
|
+
if isinstance(expr, ast.PostfixOp):
|
|
230
|
+
expr.operand = _rewrite_expr(expr.operand, table)
|
|
231
|
+
return expr
|
|
232
|
+
if isinstance(expr, ast.Parenthesized):
|
|
233
|
+
expr.inner = _rewrite_expr(expr.inner, table)
|
|
234
|
+
return expr
|
|
235
|
+
if isinstance(expr, ast.Subscript):
|
|
236
|
+
expr.target = _rewrite_expr(expr.target, table)
|
|
237
|
+
expr.index = _rewrite_expr(expr.index, table)
|
|
238
|
+
return expr
|
|
239
|
+
if isinstance(expr, ast.MemberAccess):
|
|
240
|
+
expr.obj = _rewrite_expr(expr.obj, table)
|
|
241
|
+
return expr
|
|
242
|
+
if isinstance(expr, ast.Assignment):
|
|
243
|
+
expr.target = _rewrite_expr(expr.target, table)
|
|
244
|
+
expr.value = _rewrite_expr(expr.value, table)
|
|
245
|
+
return expr
|
|
246
|
+
if isinstance(expr, ast.CastExpr):
|
|
247
|
+
expr.expr = _rewrite_expr(expr.expr, table)
|
|
248
|
+
return expr
|
|
249
|
+
if isinstance(expr, ast.UnsafeCastExpr):
|
|
250
|
+
expr.expr = _rewrite_expr(expr.expr, table)
|
|
251
|
+
return expr
|
|
252
|
+
if isinstance(expr, ast.UnsafeOverrideExpr):
|
|
253
|
+
expr.expr = _rewrite_expr(expr.expr, table)
|
|
254
|
+
return expr
|
|
255
|
+
if isinstance(expr, ast.DefuzzyExpr):
|
|
256
|
+
expr.expr = _rewrite_expr(expr.expr, table)
|
|
257
|
+
return expr
|
|
258
|
+
if isinstance(expr, ast.EmbedExpr):
|
|
259
|
+
expr.expr = _rewrite_expr(expr.expr, table)
|
|
260
|
+
return expr
|
|
261
|
+
if isinstance(expr, ast.ArrayLiteral):
|
|
262
|
+
expr.elements = [_rewrite_expr(e, table) for e in expr.elements]
|
|
263
|
+
return expr
|
|
264
|
+
if isinstance(expr, ast.MapLiteral):
|
|
265
|
+
expr.keys = [_rewrite_expr(k, table) for k in expr.keys]
|
|
266
|
+
expr.values = [_rewrite_expr(v, table) for v in expr.values]
|
|
267
|
+
return expr
|
|
268
|
+
if isinstance(expr, ast.InterpolatedString):
|
|
269
|
+
expr.parts = [
|
|
270
|
+
part if isinstance(part, str) else _rewrite_expr(part, table)
|
|
271
|
+
for part in expr.parts
|
|
272
|
+
]
|
|
273
|
+
return expr
|
|
274
|
+
|
|
275
|
+
# Leaves: Identifier, IntLiteral, FloatLiteral, StringLiteral,
|
|
276
|
+
# CharLiteral, BoolLiteral, UnknownLiteral, ComplexLiteral,
|
|
277
|
+
# ImaginaryLiteral, TypeRef, etc. Return unchanged.
|
|
278
|
+
return expr
|
|
279
|
+
|
|
280
|
+
|
|
281
|
+
# ---------------------------------------------------------------------------
|
|
282
|
+
# The actual inline: param-arg substitution into a cloned body
|
|
283
|
+
# ---------------------------------------------------------------------------
|
|
284
|
+
|
|
285
|
+
|
|
286
|
+
def _do_inline(call: ast.Call, decl: ast.FunctionDecl, table=None):
|
|
287
|
+
"""Return the inlined expression for `call`, or the original call
|
|
288
|
+
unchanged if arities disagree (let the validator flag it).
|
|
289
|
+
|
|
290
|
+
After substituting params into the body we re-run the rewriter on
|
|
291
|
+
the result, so inlined bodies that themselves contain stdlib
|
|
292
|
+
calls get fully expanded in one pass. Today's stdlib has no
|
|
293
|
+
recursion, so this terminates trivially; a future `@intrinsic`
|
|
294
|
+
form that's self-referential would need a depth guard."""
|
|
295
|
+
if len(call.args) != len(decl.params):
|
|
296
|
+
return call
|
|
297
|
+
|
|
298
|
+
# Deep-copy the return expression so we don't alias the stdlib
|
|
299
|
+
# AST across call sites.
|
|
300
|
+
return_stmt: ast.ReturnStmt = decl.body.statements[0]
|
|
301
|
+
body_expr = copy.deepcopy(return_stmt.value)
|
|
302
|
+
|
|
303
|
+
subst = {
|
|
304
|
+
param.name: arg
|
|
305
|
+
for param, arg in zip(decl.params, call.args)
|
|
306
|
+
}
|
|
307
|
+
substituted = _substitute_params(body_expr, subst)
|
|
308
|
+
# Recurse: the substituted body may contain operators or stdlib
|
|
309
|
+
# calls that still need lowering/inlining. e.g. neq's body is
|
|
310
|
+
# `!(a == b)` — the `!` is a UnaryOp that needs operator-lowering
|
|
311
|
+
# into a logical_not Call, which itself then inlines. A single
|
|
312
|
+
# pre-order pass over user code wouldn't see this `!` because it
|
|
313
|
+
# only appeared after inlining.
|
|
314
|
+
if table is not None:
|
|
315
|
+
substituted = _lower_ops_expr(substituted, table)
|
|
316
|
+
return _rewrite_expr(substituted, table)
|
|
317
|
+
return substituted
|
|
318
|
+
|
|
319
|
+
|
|
320
|
+
def _substitute_params(expr, subst: Dict[str, object]):
|
|
321
|
+
"""Replace `Identifier(name)` with `copy.deepcopy(subst[name])`
|
|
322
|
+
wherever `name` is a parameter. Each occurrence gets its own copy
|
|
323
|
+
so downstream mutation doesn't create AST aliasing across the
|
|
324
|
+
different use-sites."""
|
|
325
|
+
if expr is None:
|
|
326
|
+
return None
|
|
327
|
+
|
|
328
|
+
if isinstance(expr, ast.Identifier):
|
|
329
|
+
if expr.name in subst:
|
|
330
|
+
return copy.deepcopy(subst[expr.name])
|
|
331
|
+
return expr
|
|
332
|
+
|
|
333
|
+
if isinstance(expr, ast.BinaryOp):
|
|
334
|
+
expr.left = _substitute_params(expr.left, subst)
|
|
335
|
+
expr.right = _substitute_params(expr.right, subst)
|
|
336
|
+
return expr
|
|
337
|
+
if isinstance(expr, ast.UnaryOp):
|
|
338
|
+
expr.operand = _substitute_params(expr.operand, subst)
|
|
339
|
+
return expr
|
|
340
|
+
if isinstance(expr, ast.PostfixOp):
|
|
341
|
+
expr.operand = _substitute_params(expr.operand, subst)
|
|
342
|
+
return expr
|
|
343
|
+
if isinstance(expr, ast.Parenthesized):
|
|
344
|
+
expr.inner = _substitute_params(expr.inner, subst)
|
|
345
|
+
return expr
|
|
346
|
+
if isinstance(expr, ast.Call):
|
|
347
|
+
expr.callee = _substitute_params(expr.callee, subst)
|
|
348
|
+
expr.args = [_substitute_params(a, subst) for a in expr.args]
|
|
349
|
+
return expr
|
|
350
|
+
if isinstance(expr, ast.Subscript):
|
|
351
|
+
expr.target = _substitute_params(expr.target, subst)
|
|
352
|
+
expr.index = _substitute_params(expr.index, subst)
|
|
353
|
+
return expr
|
|
354
|
+
if isinstance(expr, ast.MemberAccess):
|
|
355
|
+
expr.obj = _substitute_params(expr.obj, subst)
|
|
356
|
+
return expr
|
|
357
|
+
if isinstance(expr, ast.CastExpr):
|
|
358
|
+
expr.expr = _substitute_params(expr.expr, subst)
|
|
359
|
+
return expr
|
|
360
|
+
if isinstance(expr, ast.UnsafeCastExpr):
|
|
361
|
+
expr.expr = _substitute_params(expr.expr, subst)
|
|
362
|
+
return expr
|
|
363
|
+
if isinstance(expr, ast.UnsafeOverrideExpr):
|
|
364
|
+
expr.expr = _substitute_params(expr.expr, subst)
|
|
365
|
+
return expr
|
|
366
|
+
if isinstance(expr, ast.DefuzzyExpr):
|
|
367
|
+
expr.expr = _substitute_params(expr.expr, subst)
|
|
368
|
+
return expr
|
|
369
|
+
if isinstance(expr, ast.EmbedExpr):
|
|
370
|
+
expr.expr = _substitute_params(expr.expr, subst)
|
|
371
|
+
return expr
|
|
372
|
+
if isinstance(expr, ast.ArrayLiteral):
|
|
373
|
+
expr.elements = [_substitute_params(e, subst) for e in expr.elements]
|
|
374
|
+
return expr
|
|
375
|
+
if isinstance(expr, ast.MapLiteral):
|
|
376
|
+
expr.keys = [_substitute_params(k, subst) for k in expr.keys]
|
|
377
|
+
expr.values = [_substitute_params(v, subst) for v in expr.values]
|
|
378
|
+
return expr
|
|
379
|
+
|
|
380
|
+
# Literals and everything else: no substitution.
|
|
381
|
+
return expr
|
|
382
|
+
|
|
383
|
+
|
|
384
|
+
# ---------------------------------------------------------------------------
|
|
385
|
+
# Step 2.6 — operator lowering to stdlib calls
|
|
386
|
+
# ---------------------------------------------------------------------------
|
|
387
|
+
#
|
|
388
|
+
# Rewrite `!v`, `a && b`, `a || b`, `nand`/`xor`/`xnor`, `a != b`, `a < b`, `a <= b`, `a >= b`
|
|
389
|
+
# into explicit Call nodes targeting their stdlib counterparts. After this
|
|
390
|
+
# runs, the inliner pass below sees a uniform Call shape whether the user
|
|
391
|
+
# wrote `logical_and(a, b)` or `a && b`. The operators that don't have a
|
|
392
|
+
# stdlib body today (`==`, `>`) are left alone — they continue to compile
|
|
393
|
+
# through the hardcoded runtime methods until their stdlib forms land
|
|
394
|
+
# (blocked on eq/gt intrinsics).
|
|
395
|
+
#
|
|
396
|
+
# Operators that aren't part of this set — `+`, `-`, `*`, `/`, etc. — are
|
|
397
|
+
# not candidates for stdlib lowering. They're not "logic ops with a
|
|
398
|
+
# Sutra-source definition"; they're primitive tensor arithmetic the
|
|
399
|
+
# codegen knows how to emit directly.
|
|
400
|
+
|
|
401
|
+
_BINARY_OP_TO_STDLIB = {
|
|
402
|
+
"&&": "logical_and",
|
|
403
|
+
"||": "logical_or",
|
|
404
|
+
"nand": "logical_nand",
|
|
405
|
+
"xor": "logical_xor",
|
|
406
|
+
"xnor": "logical_xnor",
|
|
407
|
+
"!=": "neq",
|
|
408
|
+
"<": "lt",
|
|
409
|
+
"<=": "le",
|
|
410
|
+
">=": "ge",
|
|
411
|
+
}
|
|
412
|
+
_UNARY_OP_TO_STDLIB = {
|
|
413
|
+
"!": "logical_not",
|
|
414
|
+
}
|
|
415
|
+
|
|
416
|
+
|
|
417
|
+
def _lower_operators_to_stdlib_calls(
|
|
418
|
+
module: ast.Module, inlineable: Dict[str, ast.FunctionDecl]
|
|
419
|
+
) -> None:
|
|
420
|
+
"""Walk the module and replace each operator node whose stdlib
|
|
421
|
+
counterpart is present and inlineable with a Call to it. Filters
|
|
422
|
+
by the inlineable set so operators without a stdlib body (or
|
|
423
|
+
stdlib body that can't yet be inlined) stay as operators."""
|
|
424
|
+
for item in module.items:
|
|
425
|
+
_lower_ops_top_level(item, inlineable)
|
|
426
|
+
|
|
427
|
+
|
|
428
|
+
def _lower_ops_top_level(item, inlineable) -> None:
    """Lower operators to stdlib calls inside one top-level module item.

    Function/method bodies, loop-function conditions and bodies, class
    members, top-level variable initializers and bare statements are
    walked; any other kind of item is left untouched.
    """
    if isinstance(item, (ast.FunctionDecl, ast.MethodDecl)):
        _lower_ops_block(item.body, inlineable)
        return

    if isinstance(item, ast.LoopFunctionDecl):
        item.condition = _lower_ops_expr(item.condition, inlineable)
        _lower_ops_block(item.body, inlineable)
        return

    if isinstance(item, ast.ClassDecl):
        for method in item.methods:
            _lower_ops_block(method.body, inlineable)
        for loop_fn in item.loop_functions:
            loop_fn.condition = _lower_ops_expr(loop_fn.condition, inlineable)
            _lower_ops_block(loop_fn.body, inlineable)
        return

    if isinstance(item, ast.VarDecl):
        if item.initializer is not None:
            item.initializer = _lower_ops_expr(item.initializer, inlineable)
        return

    if isinstance(item, ast.Stmt):
        _lower_ops_stmt(item, inlineable)
|
|
445
|
+
|
|
446
|
+
|
|
447
|
+
def _lower_ops_block(block: ast.Block, inlineable) -> None:
|
|
448
|
+
for stmt in block.statements:
|
|
449
|
+
_lower_ops_stmt(stmt, inlineable)
|
|
450
|
+
|
|
451
|
+
|
|
452
|
+
def _lower_ops_stmt(stmt, inlineable) -> None:
    """Lower operators in every expression reachable from ``stmt``.

    Each expression slot of the statement is reassigned with the result
    of ``_lower_ops_expr``; nested blocks and statements are walked
    recursively.  Statement kinds not handled below are left as they are.
    """

    def lower(node):
        return _lower_ops_expr(node, inlineable)

    if isinstance(stmt, ast.VarDecl):
        if stmt.initializer is not None:
            stmt.initializer = lower(stmt.initializer)
        return

    if isinstance(stmt, ast.ReturnStmt):
        if stmt.value is not None:
            stmt.value = lower(stmt.value)
        return

    if isinstance(stmt, ast.ExprStmt):
        stmt.expr = lower(stmt.expr)
        return

    if isinstance(stmt, ast.Assignment):
        stmt.target = lower(stmt.target)
        stmt.value = lower(stmt.value)
        return

    if isinstance(stmt, ast.IfStmt):
        stmt.condition = lower(stmt.condition)
        _lower_ops_block(stmt.then_branch, inlineable)
        otherwise = stmt.else_branch
        if otherwise is None:
            return
        # An `else if` is stored as a nested IfStmt, a plain `else` as a Block.
        if isinstance(otherwise, ast.IfStmt):
            _lower_ops_stmt(otherwise, inlineable)
        else:
            _lower_ops_block(otherwise, inlineable)
        return

    if isinstance(stmt, ast.WhileStmt):
        stmt.condition = lower(stmt.condition)
        _lower_ops_block(stmt.body, inlineable)
        return

    if isinstance(stmt, ast.DoWhileStmt):
        _lower_ops_block(stmt.body, inlineable)
        stmt.condition = lower(stmt.condition)
        return

    if isinstance(stmt, ast.ForStmt):
        if stmt.init is not None:
            _lower_ops_stmt(stmt.init, inlineable)
        if stmt.condition is not None:
            stmt.condition = lower(stmt.condition)
        if stmt.step is not None:
            _lower_ops_stmt(stmt.step, inlineable)
        _lower_ops_block(stmt.body, inlineable)
        return

    if isinstance(stmt, ast.ForeachStmt):
        stmt.iterable = lower(stmt.iterable)
        _lower_ops_block(stmt.body, inlineable)
        return

    if isinstance(stmt, ast.LoopStmt):
        if stmt.count is not None:
            stmt.count = lower(stmt.count)
        if stmt.condition is not None:
            stmt.condition = lower(stmt.condition)
        _lower_ops_block(stmt.body, inlineable)
        return

    if isinstance(stmt, ast.PassStmt):
        # ReplaceMarker entries are leaves; only real expressions are
        # lowered, and the list is updated in place.
        for position, passed in enumerate(stmt.values):
            if isinstance(passed, ast.ReplaceMarker):
                continue
            stmt.values[position] = lower(passed)
        return

    if isinstance(stmt, ast.LoopCallStmt):
        stmt.condition_arg = lower(stmt.condition_arg)
        return

    if isinstance(stmt, ast.TryStmt):
        _lower_ops_block(stmt.try_block, inlineable)
        for catch_clause in stmt.catches:
            _lower_ops_block(catch_clause.body, inlineable)
        return

    if isinstance(stmt, ast.Block):
        _lower_ops_block(stmt, inlineable)
|
|
507
|
+
|
|
508
|
+
|
|
509
|
+
def _lower_ops_expr(expr, inlineable):
|
|
510
|
+
if expr is None:
|
|
511
|
+
return None
|
|
512
|
+
|
|
513
|
+
# Recurse into children first (post-order).
|
|
514
|
+
if isinstance(expr, ast.BinaryOp):
|
|
515
|
+
expr.left = _lower_ops_expr(expr.left, inlineable)
|
|
516
|
+
expr.right = _lower_ops_expr(expr.right, inlineable)
|
|
517
|
+
stdlib_name = _BINARY_OP_TO_STDLIB.get(expr.op)
|
|
518
|
+
if stdlib_name is not None and stdlib_name in inlineable:
|
|
519
|
+
return ast.Call(
|
|
520
|
+
callee=ast.Identifier(name=stdlib_name, span=expr.span),
|
|
521
|
+
type_args=[],
|
|
522
|
+
args=[expr.left, expr.right],
|
|
523
|
+
span=expr.span,
|
|
524
|
+
)
|
|
525
|
+
return expr
|
|
526
|
+
|
|
527
|
+
if isinstance(expr, ast.UnaryOp):
|
|
528
|
+
expr.operand = _lower_ops_expr(expr.operand, inlineable)
|
|
529
|
+
stdlib_name = _UNARY_OP_TO_STDLIB.get(expr.op)
|
|
530
|
+
if stdlib_name is not None and stdlib_name in inlineable:
|
|
531
|
+
return ast.Call(
|
|
532
|
+
callee=ast.Identifier(name=stdlib_name, span=expr.span),
|
|
533
|
+
type_args=[],
|
|
534
|
+
args=[expr.operand],
|
|
535
|
+
span=expr.span,
|
|
536
|
+
)
|
|
537
|
+
return expr
|
|
538
|
+
|
|
539
|
+
# Structural recursion for every other expression shape.
|
|
540
|
+
if isinstance(expr, ast.PostfixOp):
|
|
541
|
+
expr.operand = _lower_ops_expr(expr.operand, inlineable)
|
|
542
|
+
return expr
|
|
543
|
+
if isinstance(expr, ast.Call):
|
|
544
|
+
expr.callee = _lower_ops_expr(expr.callee, inlineable)
|
|
545
|
+
expr.args = [_lower_ops_expr(a, inlineable) for a in expr.args]
|
|
546
|
+
return expr
|
|
547
|
+
if isinstance(expr, ast.Parenthesized):
|
|
548
|
+
expr.inner = _lower_ops_expr(expr.inner, inlineable)
|
|
549
|
+
return expr
|
|
550
|
+
if isinstance(expr, ast.Subscript):
|
|
551
|
+
expr.target = _lower_ops_expr(expr.target, inlineable)
|
|
552
|
+
expr.index = _lower_ops_expr(expr.index, inlineable)
|
|
553
|
+
return expr
|
|
554
|
+
if isinstance(expr, ast.MemberAccess):
|
|
555
|
+
expr.obj = _lower_ops_expr(expr.obj, inlineable)
|
|
556
|
+
return expr
|
|
557
|
+
if isinstance(expr, ast.Assignment):
|
|
558
|
+
expr.target = _lower_ops_expr(expr.target, inlineable)
|
|
559
|
+
expr.value = _lower_ops_expr(expr.value, inlineable)
|
|
560
|
+
return expr
|
|
561
|
+
if isinstance(expr, (ast.CastExpr, ast.UnsafeCastExpr,
|
|
562
|
+
ast.UnsafeOverrideExpr, ast.DefuzzyExpr,
|
|
563
|
+
ast.EmbedExpr)):
|
|
564
|
+
expr.expr = _lower_ops_expr(expr.expr, inlineable)
|
|
565
|
+
return expr
|
|
566
|
+
if isinstance(expr, ast.ArrayLiteral):
|
|
567
|
+
expr.elements = [_lower_ops_expr(e, inlineable) for e in expr.elements]
|
|
568
|
+
return expr
|
|
569
|
+
if isinstance(expr, ast.MapLiteral):
|
|
570
|
+
expr.keys = [_lower_ops_expr(k, inlineable) for k in expr.keys]
|
|
571
|
+
expr.values = [_lower_ops_expr(v, inlineable) for v in expr.values]
|
|
572
|
+
return expr
|
|
573
|
+
if isinstance(expr, ast.InterpolatedString):
|
|
574
|
+
expr.parts = [
|
|
575
|
+
part if isinstance(part, str) else _lower_ops_expr(part, inlineable)
|
|
576
|
+
for part in expr.parts
|
|
577
|
+
]
|
|
578
|
+
return expr
|
|
579
|
+
|
|
580
|
+
# Leaves (Identifier, literals, TypeRef) — no-op.
|
|
581
|
+
return expr
|