egglog 7.0.0__cp311-none-win_amd64.whl → 7.2.0__cp311-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of egglog might be problematic. Click here for more details.
- egglog/bindings.cp311-win_amd64.pyd +0 -0
- egglog/bindings.pyi +7 -0
- egglog/builtins.py +41 -1
- egglog/conversion.py +22 -17
- egglog/declarations.py +122 -37
- egglog/egraph.py +219 -78
- egglog/egraph_state.py +124 -54
- egglog/examples/higher_order_functions.py +50 -0
- egglog/exp/array_api.py +12 -9
- egglog/pretty.py +71 -15
- egglog/runtime.py +118 -33
- egglog/thunk.py +17 -6
- egglog/type_constraint_solver.py +5 -4
- {egglog-7.0.0.dist-info → egglog-7.2.0.dist-info}/METADATA +10 -10
- {egglog-7.0.0.dist-info → egglog-7.2.0.dist-info}/RECORD +17 -16
- {egglog-7.0.0.dist-info → egglog-7.2.0.dist-info}/WHEEL +0 -0
- {egglog-7.0.0.dist-info → egglog-7.2.0.dist-info}/license_files/LICENSE +0 -0
egglog/egraph_state.py
CHANGED
|
@@ -19,7 +19,7 @@ from .type_constraint_solver import TypeConstraintError, TypeConstraintSolver
|
|
|
19
19
|
if TYPE_CHECKING:
|
|
20
20
|
from collections.abc import Iterable
|
|
21
21
|
|
|
22
|
-
__all__ = ["EGraphState", "GLOBAL_PY_OBJECT_SORT"]
|
|
22
|
+
__all__ = ["EGraphState", "GLOBAL_PY_OBJECT_SORT", "_rule_var_name"]
|
|
23
23
|
|
|
24
24
|
# Create a global sort for python objects, so we can store them without an e-graph instance
|
|
25
25
|
# Needed when serializing commands to egg commands when creating modules
|
|
@@ -87,17 +87,27 @@ class EGraphState:
|
|
|
87
87
|
"""
|
|
88
88
|
Registers a ruleset if it's not already registered.
|
|
89
89
|
"""
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
self.
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
90
|
+
match self.__egg_decls__._rulesets[name]:
|
|
91
|
+
case RulesetDecl(rules):
|
|
92
|
+
if name not in self.rulesets:
|
|
93
|
+
if name:
|
|
94
|
+
self.egraph.run_program(bindings.AddRuleset(name))
|
|
95
|
+
added_rules = self.rulesets[name] = set()
|
|
96
|
+
else:
|
|
97
|
+
added_rules = self.rulesets[name]
|
|
98
|
+
for rule in rules:
|
|
99
|
+
if rule in added_rules:
|
|
100
|
+
continue
|
|
101
|
+
cmd = self.command_to_egg(rule, name)
|
|
102
|
+
self.egraph.run_program(cmd)
|
|
103
|
+
added_rules.add(rule)
|
|
104
|
+
case CombinedRulesetDecl(rulesets):
|
|
105
|
+
if name in self.rulesets:
|
|
106
|
+
return
|
|
107
|
+
self.rulesets[name] = set()
|
|
108
|
+
for ruleset in rulesets:
|
|
109
|
+
self.ruleset_to_egg(ruleset)
|
|
110
|
+
self.egraph.run_program(bindings.UnstableCombinedRuleset(name, list(rulesets)))
|
|
101
111
|
|
|
102
112
|
def command_to_egg(self, cmd: CommandDecl, ruleset: str) -> bindings._Command:
|
|
103
113
|
match cmd:
|
|
@@ -106,8 +116,8 @@ class EGraphState:
|
|
|
106
116
|
case RewriteDecl(tp, lhs, rhs, conditions) | BiRewriteDecl(tp, lhs, rhs, conditions):
|
|
107
117
|
self.type_ref_to_egg(tp)
|
|
108
118
|
rewrite = bindings.Rewrite(
|
|
109
|
-
self.
|
|
110
|
-
self.
|
|
119
|
+
self._expr_to_egg(lhs),
|
|
120
|
+
self._expr_to_egg(rhs),
|
|
111
121
|
[self.fact_to_egg(c) for c in conditions],
|
|
112
122
|
)
|
|
113
123
|
return (
|
|
@@ -121,6 +131,16 @@ class EGraphState:
|
|
|
121
131
|
[self.fact_to_egg(f) for f in body],
|
|
122
132
|
)
|
|
123
133
|
return bindings.RuleCommand(name or "", ruleset, rule)
|
|
134
|
+
case DefaultRewriteDecl(ref, expr):
|
|
135
|
+
decl = self.__egg_decls__.get_callable_decl(ref).to_function_decl()
|
|
136
|
+
sig = decl.signature
|
|
137
|
+
assert isinstance(sig, FunctionSignature)
|
|
138
|
+
args = tuple(
|
|
139
|
+
TypedExprDecl(tp.to_just(), VarDecl(_rule_var_name(name)))
|
|
140
|
+
for name, tp in zip(sig.arg_names, sig.arg_types, strict=False)
|
|
141
|
+
)
|
|
142
|
+
rewrite_decl = RewriteDecl(sig.semantic_return_type.to_just(), CallDecl(ref, args), expr, (), False)
|
|
143
|
+
return self.command_to_egg(rewrite_decl, ruleset)
|
|
124
144
|
case _:
|
|
125
145
|
assert_never(cmd)
|
|
126
146
|
|
|
@@ -130,13 +150,13 @@ class EGraphState:
|
|
|
130
150
|
return bindings.Let(name, self.typed_expr_to_egg(typed_expr))
|
|
131
151
|
case SetDecl(tp, call, rhs):
|
|
132
152
|
self.type_ref_to_egg(tp)
|
|
133
|
-
call_ = self.
|
|
134
|
-
return bindings.Set(call_.name, call_.args, self.
|
|
153
|
+
call_ = self._expr_to_egg(call)
|
|
154
|
+
return bindings.Set(call_.name, call_.args, self._expr_to_egg(rhs))
|
|
135
155
|
case ExprActionDecl(typed_expr):
|
|
136
156
|
return bindings.Expr_(self.typed_expr_to_egg(typed_expr))
|
|
137
157
|
case ChangeDecl(tp, call, change):
|
|
138
158
|
self.type_ref_to_egg(tp)
|
|
139
|
-
call_ = self.
|
|
159
|
+
call_ = self._expr_to_egg(call)
|
|
140
160
|
egg_change: bindings._Change
|
|
141
161
|
match change:
|
|
142
162
|
case "delete":
|
|
@@ -148,7 +168,7 @@ class EGraphState:
|
|
|
148
168
|
return bindings.Change(egg_change, call_.name, call_.args)
|
|
149
169
|
case UnionDecl(tp, lhs, rhs):
|
|
150
170
|
self.type_ref_to_egg(tp)
|
|
151
|
-
return bindings.Union(self.
|
|
171
|
+
return bindings.Union(self._expr_to_egg(lhs), self._expr_to_egg(rhs))
|
|
152
172
|
case PanicDecl(name):
|
|
153
173
|
return bindings.Panic(name)
|
|
154
174
|
case _:
|
|
@@ -158,7 +178,7 @@ class EGraphState:
|
|
|
158
178
|
match fact:
|
|
159
179
|
case EqDecl(tp, exprs):
|
|
160
180
|
self.type_ref_to_egg(tp)
|
|
161
|
-
return bindings.Eq([self.
|
|
181
|
+
return bindings.Eq([self._expr_to_egg(e) for e in exprs])
|
|
162
182
|
case ExprFactDecl(typed_expr):
|
|
163
183
|
return bindings.Fact(self.typed_expr_to_egg(typed_expr))
|
|
164
184
|
case _:
|
|
@@ -184,14 +204,16 @@ class EGraphState:
|
|
|
184
204
|
)
|
|
185
205
|
case FunctionDecl():
|
|
186
206
|
if not decl.builtin:
|
|
207
|
+
signature = decl.signature
|
|
208
|
+
assert isinstance(signature, FunctionSignature), "Cannot turn special function to egg"
|
|
187
209
|
egg_fn_decl = bindings.FunctionDecl(
|
|
188
210
|
egg_name,
|
|
189
211
|
bindings.Schema(
|
|
190
|
-
[self.type_ref_to_egg(a.to_just()) for a in
|
|
191
|
-
self.type_ref_to_egg(
|
|
212
|
+
[self.type_ref_to_egg(a.to_just()) for a in signature.arg_types],
|
|
213
|
+
self.type_ref_to_egg(signature.semantic_return_type.to_just()),
|
|
192
214
|
),
|
|
193
|
-
self.
|
|
194
|
-
self.
|
|
215
|
+
self._expr_to_egg(decl.default) if decl.default else None,
|
|
216
|
+
self._expr_to_egg(decl.merge) if decl.merge else None,
|
|
195
217
|
[self.action_to_egg(a) for a in decl.on_merge],
|
|
196
218
|
decl.cost,
|
|
197
219
|
decl.unextractable,
|
|
@@ -212,25 +234,30 @@ class EGraphState:
|
|
|
212
234
|
decl = self.__egg_decls__._classes[ref.name]
|
|
213
235
|
self.type_ref_to_egg_sort[ref] = egg_name = decl.egg_name or _generate_type_egg_name(ref)
|
|
214
236
|
if not decl.builtin or ref.args:
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
(
|
|
220
|
-
self.type_ref_to_egg(
|
|
221
|
-
[bindings.Var(self.type_ref_to_egg(a)) for a in ref.args],
|
|
222
|
-
)
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
237
|
+
if ref.args:
|
|
238
|
+
if ref.name == "UnstableFn":
|
|
239
|
+
# UnstableFn is a special case, where the rest of args are collected into a call
|
|
240
|
+
type_args: list[bindings._Expr] = [
|
|
241
|
+
bindings.Call(
|
|
242
|
+
self.type_ref_to_egg(ref.args[1]),
|
|
243
|
+
[bindings.Var(self.type_ref_to_egg(a)) for a in ref.args[2:]],
|
|
244
|
+
),
|
|
245
|
+
bindings.Var(self.type_ref_to_egg(ref.args[0])),
|
|
246
|
+
]
|
|
247
|
+
else:
|
|
248
|
+
type_args = [bindings.Var(self.type_ref_to_egg(a)) for a in ref.args]
|
|
249
|
+
args = (self.type_ref_to_egg(JustTypeRef(ref.name)), type_args)
|
|
250
|
+
else:
|
|
251
|
+
args = None
|
|
252
|
+
self.egraph.run_program(bindings.Sort(egg_name, args))
|
|
228
253
|
# For builtin classes, let's also make sure we have the mapping of all egg fn names for class methods, because
|
|
229
254
|
# these can be created even without adding them to the e-graph, like `vec-empty` which can be extracted
|
|
230
255
|
# even if you never use that function.
|
|
231
256
|
if decl.builtin:
|
|
232
257
|
for method in decl.class_methods:
|
|
233
258
|
self.callable_ref_to_egg(ClassMethodRef(ref.name, method))
|
|
259
|
+
if decl.init:
|
|
260
|
+
self.callable_ref_to_egg(InitRef(ref.name))
|
|
234
261
|
|
|
235
262
|
return egg_name
|
|
236
263
|
|
|
@@ -247,15 +274,15 @@ class EGraphState:
|
|
|
247
274
|
|
|
248
275
|
def typed_expr_to_egg(self, typed_expr_decl: TypedExprDecl) -> bindings._Expr:
|
|
249
276
|
self.type_ref_to_egg(typed_expr_decl.tp)
|
|
250
|
-
return self.
|
|
277
|
+
return self._expr_to_egg(typed_expr_decl.expr)
|
|
251
278
|
|
|
252
279
|
@overload
|
|
253
|
-
def
|
|
280
|
+
def _expr_to_egg(self, expr_decl: CallDecl) -> bindings.Call: ...
|
|
254
281
|
|
|
255
282
|
@overload
|
|
256
|
-
def
|
|
283
|
+
def _expr_to_egg(self, expr_decl: ExprDecl) -> bindings._Expr: ...
|
|
257
284
|
|
|
258
|
-
def
|
|
285
|
+
def _expr_to_egg(self, expr_decl: ExprDecl) -> bindings._Expr:
|
|
259
286
|
"""
|
|
260
287
|
Convert an ExprDecl to an egg expression.
|
|
261
288
|
|
|
@@ -292,6 +319,9 @@ class EGraphState:
|
|
|
292
319
|
res = bindings.Call(egg_fn, egg_args)
|
|
293
320
|
case PyObjectDecl(value):
|
|
294
321
|
res = GLOBAL_PY_OBJECT_SORT.store(value)
|
|
322
|
+
case PartialCallDecl(call_decl):
|
|
323
|
+
egg_fn_call = self._expr_to_egg(call_decl)
|
|
324
|
+
res = bindings.Call("unstable-fn", [bindings.Lit(bindings.String(egg_fn_call.name)), *egg_fn_call.args])
|
|
295
325
|
case _:
|
|
296
326
|
assert_never(expr_decl.expr)
|
|
297
327
|
|
|
@@ -338,6 +368,8 @@ def _generate_callable_egg_name(ref: CallableRef) -> str:
|
|
|
338
368
|
| PropertyRef(cls_name, name)
|
|
339
369
|
):
|
|
340
370
|
return f"{cls_name}_{name}"
|
|
371
|
+
case InitRef(cls_name):
|
|
372
|
+
return f"{cls_name}___init__"
|
|
341
373
|
case _:
|
|
342
374
|
assert_never(ref)
|
|
343
375
|
|
|
@@ -371,26 +403,50 @@ class FromEggState:
|
|
|
371
403
|
if term.name == "py-object":
|
|
372
404
|
call = bindings.termdag_term_to_expr(self.termdag, term)
|
|
373
405
|
expr_decl = PyObjectDecl(self.state.egraph.eval_py_object(call))
|
|
406
|
+
if term.name == "unstable-fn":
|
|
407
|
+
# Get function name
|
|
408
|
+
fn_term, *arg_terms = term.args
|
|
409
|
+
fn_value = self.resolve_term(fn_term, JustTypeRef("String"))
|
|
410
|
+
assert isinstance(fn_value.expr, LitDecl)
|
|
411
|
+
fn_name = fn_value.expr.value
|
|
412
|
+
assert isinstance(fn_name, str)
|
|
413
|
+
|
|
414
|
+
# Resolve what types the partially applied args are
|
|
415
|
+
assert tp.name == "UnstableFn"
|
|
416
|
+
call_decl = self.from_call(tp.args[0], bindings.TermApp(fn_name, arg_terms))
|
|
417
|
+
expr_decl = PartialCallDecl(call_decl)
|
|
374
418
|
else:
|
|
375
419
|
expr_decl = self.from_call(tp, term)
|
|
376
420
|
else:
|
|
377
421
|
assert_never(term)
|
|
378
422
|
return TypedExprDecl(tp, expr_decl)
|
|
379
423
|
|
|
380
|
-
def from_call(
|
|
424
|
+
def from_call(
|
|
425
|
+
self,
|
|
426
|
+
tp: JustTypeRef,
|
|
427
|
+
term: bindings.TermApp, # additional_arg_tps: tuple[JustTypeRef, ...]
|
|
428
|
+
) -> CallDecl:
|
|
381
429
|
"""
|
|
382
430
|
Convert a call to a CallDecl.
|
|
383
431
|
|
|
384
432
|
There could be Python call refs which match the call, so we need to find the correct one.
|
|
433
|
+
|
|
434
|
+
The additional_arg_tps are known types for arguments that come after the term args, used to infer types
|
|
435
|
+
for partially applied functions, where we know the types of the later args, but not of the earlier ones that
|
|
436
|
+
we have values for.
|
|
385
437
|
"""
|
|
386
438
|
# Find the first callable ref that matches the call
|
|
387
439
|
for callable_ref in self.state.egg_fn_to_callable_refs[term.name]:
|
|
388
440
|
# If this is a classmethod, we might need the type params that were bound for this type
|
|
389
441
|
# This could be multiple types if the classmethod is ambiguous, like map create.
|
|
390
442
|
possible_types: Iterable[JustTypeRef | None]
|
|
391
|
-
|
|
392
|
-
|
|
393
|
-
|
|
443
|
+
signature = self.decls.get_callable_decl(callable_ref).to_function_decl().signature
|
|
444
|
+
assert isinstance(signature, FunctionSignature)
|
|
445
|
+
if isinstance(callable_ref, ClassMethodRef | InitRef | MethodRef):
|
|
446
|
+
# Need OR in case we have class method whose class was never added as a sort, which would happen
|
|
447
|
+
# if the class method didn't return that type and no other function did. In this case, we don't need
|
|
448
|
+
# to care about the type vars and we don't need to bind any possible type.
|
|
449
|
+
possible_types = self.state._get_possible_types(callable_ref.class_name) or [None]
|
|
394
450
|
cls_name = callable_ref.class_name
|
|
395
451
|
else:
|
|
396
452
|
possible_types = [None]
|
|
@@ -399,19 +455,33 @@ class FromEggState:
|
|
|
399
455
|
tcs = TypeConstraintSolver(self.decls)
|
|
400
456
|
if possible_type and possible_type.args:
|
|
401
457
|
tcs.bind_class(possible_type)
|
|
402
|
-
|
|
403
458
|
try:
|
|
404
459
|
arg_types, bound_tp_params = tcs.infer_arg_types(
|
|
405
|
-
|
|
460
|
+
signature.arg_types, signature.semantic_return_type, signature.var_arg_type, tp, cls_name
|
|
406
461
|
)
|
|
407
462
|
except TypeConstraintError:
|
|
408
463
|
continue
|
|
409
|
-
args
|
|
410
|
-
|
|
411
|
-
|
|
412
|
-
|
|
413
|
-
|
|
414
|
-
|
|
415
|
-
|
|
416
|
-
|
|
464
|
+
args = tuple(self.resolve_term(a, tp) for a, tp in zip(term.args, arg_types, strict=False))
|
|
465
|
+
|
|
466
|
+
return CallDecl(
|
|
467
|
+
callable_ref,
|
|
468
|
+
args,
|
|
469
|
+
# Don't include bound type params if this is just a method, we only needed them for type resolution
|
|
470
|
+
# but don't need to store them
|
|
471
|
+
bound_tp_params if isinstance(callable_ref, ClassMethodRef | InitRef) else None,
|
|
472
|
+
)
|
|
417
473
|
raise ValueError(f"Could not find callable ref for call {term}")
|
|
474
|
+
|
|
475
|
+
def resolve_term(self, term_id: int, tp: JustTypeRef) -> TypedExprDecl:
    """
    Return the typed expression for the term stored under ``term_id``.

    A result already present in ``self.cache`` is returned unchanged. Otherwise the
    node is looked up in ``self.termdag.nodes``, converted with ``self.from_expr``
    using ``tp`` as its type, stored in the cache, and returned.
    """
    # Fast path: this term id was converted earlier.
    if term_id in self.cache:
        return self.cache[term_id]

    # Slow path: convert the node and remember the result for later lookups.
    node = self.termdag.nodes[term_id]
    converted = self.from_expr(tp, node)
    self.cache[term_id] = converted
    return converted
|
|
481
|
+
|
|
482
|
+
|
|
483
|
+
def _rule_var_name(s: str) -> str:
|
|
484
|
+
"""
|
|
485
|
+
Create a hidden variable name, for rewrites, so that let bindings or function won't conflict with it
|
|
486
|
+
"""
|
|
487
|
+
return f"__var__{s}"
|
|
@@ -0,0 +1,50 @@
|
|
|
1
|
+
# mypy: disable-error-code="empty-body"
|
|
2
|
+
"""
|
|
3
|
+
Higher Order Functions
|
|
4
|
+
======================
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
from typing import TYPE_CHECKING
|
|
10
|
+
|
|
11
|
+
from egglog import *
|
|
12
|
+
|
|
13
|
+
if TYPE_CHECKING:
|
|
14
|
+
from collections.abc import Callable
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class Math(Expr):
|
|
18
|
+
def __init__(self, i: i64Like) -> None: ...
|
|
19
|
+
|
|
20
|
+
def __add__(self, other: Math) -> Math: ...
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class MathList(Expr):
|
|
24
|
+
def __init__(self) -> None: ...
|
|
25
|
+
|
|
26
|
+
def append(self, i: Math) -> MathList: ...
|
|
27
|
+
|
|
28
|
+
def map(self, f: Callable[[Math], Math]) -> MathList: ...
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
@ruleset
|
|
32
|
+
def math_ruleset(i: i64, j: i64, xs: MathList, x: Math, f: Callable[[Math], Math]): # noqa: ANN201
|
|
33
|
+
yield rewrite(Math(i) + Math(j)).to(Math(i + j))
|
|
34
|
+
yield rewrite(xs.append(x).map(f)).to(xs.map(f).append(f(x)))
|
|
35
|
+
yield rewrite(MathList().map(f)).to(MathList())
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
@function(ruleset=math_ruleset)
|
|
39
|
+
def increment_by_one(x: Math) -> Math:
|
|
40
|
+
return x + Math(1)
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
egraph = EGraph()
|
|
44
|
+
x = egraph.let("x", MathList().append(Math(1)).append(Math(2)))
|
|
45
|
+
y = egraph.let("y", x.map(increment_by_one))
|
|
46
|
+
egraph.run(math_ruleset.saturate())
|
|
47
|
+
|
|
48
|
+
egraph.check(eq(y).to(MathList().append(Math(2)).append(Math(3))))
|
|
49
|
+
|
|
50
|
+
egraph
|
egglog/exp/array_api.py
CHANGED
|
@@ -18,7 +18,7 @@ from egglog.runtime import RuntimeExpr
|
|
|
18
18
|
from .program_gen import *
|
|
19
19
|
|
|
20
20
|
if TYPE_CHECKING:
|
|
21
|
-
from collections.abc import Iterator
|
|
21
|
+
from collections.abc import Callable, Iterator
|
|
22
22
|
from types import ModuleType
|
|
23
23
|
|
|
24
24
|
# Pretend that exprs are numbers b/c sklearn does isinstance checks
|
|
@@ -257,7 +257,7 @@ class TupleInt(Expr):
|
|
|
257
257
|
|
|
258
258
|
def __getitem__(self, i: Int) -> Int: ...
|
|
259
259
|
|
|
260
|
-
def
|
|
260
|
+
def fold(self, init: Int, f: Callable[[Int, Int], Int]) -> Int: ...
|
|
261
261
|
|
|
262
262
|
|
|
263
263
|
converter(
|
|
@@ -272,7 +272,7 @@ converter(
|
|
|
272
272
|
|
|
273
273
|
|
|
274
274
|
@array_api_ruleset.register
|
|
275
|
-
def _tuple_int(ti: TupleInt, ti2: TupleInt, i: Int, i2: Int, k: i64):
|
|
275
|
+
def _tuple_int(ti: TupleInt, ti2: TupleInt, i: Int, i2: Int, k: i64, f: Callable[[Int, Int], Int]):
|
|
276
276
|
return [
|
|
277
277
|
rewrite(ti + TupleInt.EMPTY).to(ti),
|
|
278
278
|
rewrite(TupleInt(i).length()).to(Int(1)),
|
|
@@ -281,10 +281,10 @@ def _tuple_int(ti: TupleInt, ti2: TupleInt, i: Int, i2: Int, k: i64):
|
|
|
281
281
|
rewrite((TupleInt(i) + ti)[Int(0)]).to(i),
|
|
282
282
|
# Rule for indexing > 0
|
|
283
283
|
rule(eq(i).to((TupleInt(i2) + ti)[Int(k)]), k > 0).then(union(i).with_(ti[Int(k - 1)])),
|
|
284
|
-
#
|
|
285
|
-
rewrite(TupleInt(i)
|
|
286
|
-
rewrite(
|
|
287
|
-
rewrite(TupleInt.
|
|
284
|
+
# fold
|
|
285
|
+
rewrite(TupleInt.EMPTY.fold(i, f)).to(i),
|
|
286
|
+
rewrite(TupleInt(i2).fold(i, f)).to(f(i, i2)),
|
|
287
|
+
rewrite((TupleInt(i2) + ti).fold(i, f)).to(ti.fold(f(i, i2), f)),
|
|
288
288
|
]
|
|
289
289
|
|
|
290
290
|
|
|
@@ -882,7 +882,10 @@ converter(IntOrTuple, OptionalIntOrTuple, OptionalIntOrTuple.some)
|
|
|
882
882
|
|
|
883
883
|
@function
|
|
884
884
|
def asarray(
|
|
885
|
-
a: NDArray,
|
|
885
|
+
a: NDArray,
|
|
886
|
+
dtype: OptionalDType = OptionalDType.none,
|
|
887
|
+
copy: OptionalBool = OptionalBool.none,
|
|
888
|
+
device: OptionalDevice = OptionalDevice.none,
|
|
886
889
|
) -> NDArray: ...
|
|
887
890
|
|
|
888
891
|
|
|
@@ -1346,7 +1349,7 @@ def _unique(xs: TupleValue, a: NDArray, shape: TupleInt, copy: OptionalBool):
|
|
|
1346
1349
|
|
|
1347
1350
|
@array_api_ruleset.register
|
|
1348
1351
|
def _size(x: NDArray):
|
|
1349
|
-
yield rewrite(x.size).to(x.shape.
|
|
1352
|
+
yield rewrite(x.size).to(x.shape.fold(Int(1), Int.__mul__))
|
|
1350
1353
|
|
|
1351
1354
|
|
|
1352
1355
|
@overload
|
egglog/pretty.py
CHANGED
|
@@ -66,7 +66,7 @@ UNARY_METHODS = {
|
|
|
66
66
|
"__invert__": "~",
|
|
67
67
|
}
|
|
68
68
|
|
|
69
|
-
AllDecls: TypeAlias = RulesetDecl | CommandDecl | ActionDecl | FactDecl | ExprDecl | ScheduleDecl
|
|
69
|
+
AllDecls: TypeAlias = RulesetDecl | CombinedRulesetDecl | CommandDecl | ActionDecl | FactDecl | ExprDecl | ScheduleDecl
|
|
70
70
|
|
|
71
71
|
|
|
72
72
|
def pretty_decl(
|
|
@@ -106,7 +106,7 @@ def pretty_callable_ref(
|
|
|
106
106
|
"""
|
|
107
107
|
# Pass in three dummy args, which are the max used for any operation that
|
|
108
108
|
# is not a generic function call
|
|
109
|
-
args: list[ExprDecl] = [
|
|
109
|
+
args: list[ExprDecl] = [VarDecl(ARG_STR)] * 3
|
|
110
110
|
if first_arg:
|
|
111
111
|
args.insert(0, first_arg)
|
|
112
112
|
res = PrettyContext(decls, defaultdict(lambda: 0))._call_inner(
|
|
@@ -117,6 +117,10 @@ def pretty_callable_ref(
|
|
|
117
117
|
return res[0] if isinstance(res, tuple) else res
|
|
118
118
|
|
|
119
119
|
|
|
120
|
+
# TODO: Add a different pretty callable ref that doesn't fill in holes but instead returns the function
|
|
121
|
+
# so that things like Math.__add__ will be represented properly
|
|
122
|
+
|
|
123
|
+
|
|
120
124
|
@dataclass
|
|
121
125
|
class TraverseContext:
|
|
122
126
|
"""
|
|
@@ -162,6 +166,8 @@ class TraverseContext:
|
|
|
162
166
|
pass
|
|
163
167
|
case EqDecl(_, decls) | SequenceDecl(decls) | RulesetDecl(decls):
|
|
164
168
|
for de in decls:
|
|
169
|
+
if isinstance(de, DefaultRewriteDecl):
|
|
170
|
+
continue
|
|
165
171
|
self(de)
|
|
166
172
|
case CallDecl(_, exprs, _):
|
|
167
173
|
for e in exprs:
|
|
@@ -170,6 +176,12 @@ class TraverseContext:
|
|
|
170
176
|
if until:
|
|
171
177
|
for f in until:
|
|
172
178
|
self(f)
|
|
179
|
+
case PartialCallDecl(c):
|
|
180
|
+
self(c)
|
|
181
|
+
case CombinedRulesetDecl(_):
|
|
182
|
+
pass
|
|
183
|
+
case DefaultRewriteDecl():
|
|
184
|
+
pass
|
|
173
185
|
case _:
|
|
174
186
|
assert_never(decl)
|
|
175
187
|
|
|
@@ -231,6 +243,9 @@ class PrettyContext:
|
|
|
231
243
|
return name, name
|
|
232
244
|
case CallDecl(_, _, _):
|
|
233
245
|
return self._call(decl, parens)
|
|
246
|
+
case PartialCallDecl(CallDecl(ref, typed_args, _)):
|
|
247
|
+
arg_strs = (_pretty_callable(ref), *(self(a.expr, parens=False, unwrap_lit=True) for a in typed_args))
|
|
248
|
+
return f"UnstableFn({', '.join(arg_strs)})", "fn"
|
|
234
249
|
case PyObjectDecl(value):
|
|
235
250
|
return repr(value) if unwrap_lit else f"PyObject({value!r})", "PyObject"
|
|
236
251
|
case ActionCommandDecl(action):
|
|
@@ -265,8 +280,12 @@ class PrettyContext:
|
|
|
265
280
|
case RulesetDecl(rules):
|
|
266
281
|
if ruleset_name:
|
|
267
282
|
return f"ruleset(name={ruleset_name!r})", f"ruleset_{ruleset_name}"
|
|
268
|
-
args = ", ".join(
|
|
283
|
+
args = ", ".join(self(r) for r in rules if not isinstance(r, DefaultRewriteDecl))
|
|
269
284
|
return f"ruleset({args})", "ruleset"
|
|
285
|
+
case CombinedRulesetDecl(rulesets):
|
|
286
|
+
if ruleset_name:
|
|
287
|
+
rulesets = (*rulesets, f"name={ruleset_name!r})")
|
|
288
|
+
return f"unstable_combine_rulesets({', '.join(rulesets)})", "combined_ruleset"
|
|
270
289
|
case SaturateDecl(schedule):
|
|
271
290
|
return f"{self(schedule, parens=True)}.saturate()", "schedule"
|
|
272
291
|
case RepeatDecl(schedule, times):
|
|
@@ -283,6 +302,9 @@ class PrettyContext:
|
|
|
283
302
|
return ruleset_str, "schedule"
|
|
284
303
|
args = ", ".join(map(self, until))
|
|
285
304
|
return f"run({ruleset_str}, {args})", "schedule"
|
|
305
|
+
case DefaultRewriteDecl():
|
|
306
|
+
msg = "default rewrites should not be pretty printed"
|
|
307
|
+
raise TypeError(msg)
|
|
286
308
|
assert_never(decl)
|
|
287
309
|
|
|
288
310
|
def _call(
|
|
@@ -302,19 +324,28 @@ class PrettyContext:
|
|
|
302
324
|
l, r = self(args[0]), self(args[1])
|
|
303
325
|
return f"ne({l}).to({r})", "Unit"
|
|
304
326
|
function_decl = self.decls.get_callable_decl(ref).to_function_decl()
|
|
327
|
+
signature = function_decl.signature
|
|
328
|
+
|
|
305
329
|
# Determine how many of the last arguments are defaults, by iterating from the end and comparing the arg with the default
|
|
306
330
|
n_defaults = 0
|
|
307
|
-
|
|
308
|
-
|
|
309
|
-
|
|
310
|
-
|
|
311
|
-
|
|
312
|
-
|
|
331
|
+
# Don't try counting defaults for function application
|
|
332
|
+
if isinstance(signature, FunctionSignature):
|
|
333
|
+
for arg, default in zip(
|
|
334
|
+
reversed(args), reversed(signature.arg_defaults), strict=not signature.var_arg_type
|
|
335
|
+
):
|
|
336
|
+
if arg != default:
|
|
337
|
+
break
|
|
338
|
+
n_defaults += 1
|
|
313
339
|
if n_defaults:
|
|
314
340
|
args = args[:-n_defaults]
|
|
315
341
|
|
|
316
|
-
|
|
317
|
-
if
|
|
342
|
+
# If this is a function application, the type is the first type arg of the function object
|
|
343
|
+
if signature == "fn-app":
|
|
344
|
+
tp_name = decl.args[0].tp.args[0].name
|
|
345
|
+
else:
|
|
346
|
+
assert isinstance(signature, FunctionSignature)
|
|
347
|
+
tp_name = signature.semantic_return_type.name
|
|
348
|
+
if isinstance(signature, FunctionSignature) and signature.mutates:
|
|
318
349
|
first_arg = args[0]
|
|
319
350
|
expr_str = self(first_arg)
|
|
320
351
|
# copy an identifier expression iff it has multiple parents (b/c then we can't mutate it directly)
|
|
@@ -346,10 +377,8 @@ class PrettyContext:
|
|
|
346
377
|
case FunctionRef(name):
|
|
347
378
|
return name, args
|
|
348
379
|
case ClassMethodRef(class_name, method_name):
|
|
349
|
-
|
|
350
|
-
|
|
351
|
-
fn_str += f".{method_name}"
|
|
352
|
-
return fn_str, args
|
|
380
|
+
tp_ref = JustTypeRef(class_name, bound_tp_params or ())
|
|
381
|
+
return f"{tp_ref}.{method_name}", args
|
|
353
382
|
case MethodRef(_class_name, method_name):
|
|
354
383
|
slf, *args = args
|
|
355
384
|
slf = self(slf, parens=True)
|
|
@@ -376,6 +405,9 @@ class PrettyContext:
|
|
|
376
405
|
return f"{class_name}.{variable_name}"
|
|
377
406
|
case PropertyRef(_class_name, property_name):
|
|
378
407
|
return f"{self(args[0], parens=True)}.{property_name}"
|
|
408
|
+
case InitRef(class_name):
|
|
409
|
+
tp_ref = JustTypeRef(class_name, bound_tp_params or ())
|
|
410
|
+
return str(tp_ref), args
|
|
379
411
|
assert_never(ref)
|
|
380
412
|
|
|
381
413
|
def _generate_name(self, typ: str) -> str:
|
|
@@ -397,6 +429,30 @@ class PrettyContext:
|
|
|
397
429
|
return name
|
|
398
430
|
|
|
399
431
|
|
|
432
|
+
def _pretty_callable(ref: CallableRef) -> str:
|
|
433
|
+
"""
|
|
434
|
+
Returns a function call as a string.
|
|
435
|
+
"""
|
|
436
|
+
match ref:
|
|
437
|
+
case FunctionRef(name):
|
|
438
|
+
return name
|
|
439
|
+
case (
|
|
440
|
+
ClassMethodRef(class_name, method_name)
|
|
441
|
+
| MethodRef(class_name, method_name)
|
|
442
|
+
| PropertyRef(class_name, method_name)
|
|
443
|
+
):
|
|
444
|
+
return f"{class_name}.{method_name}"
|
|
445
|
+
case InitRef(class_name):
|
|
446
|
+
return class_name
|
|
447
|
+
case ConstantRef(_):
|
|
448
|
+
msg = "Constants should not be callable"
|
|
449
|
+
raise NotImplementedError(msg)
|
|
450
|
+
case ClassVariableRef(_, _):
|
|
451
|
+
msg = "Class variables should not be callable"
|
|
452
|
+
raise NotADirectoryError(msg)
|
|
453
|
+
assert_never(ref)
|
|
454
|
+
|
|
455
|
+
|
|
400
456
|
def _plot_line_length(expr: object): # pragma: no cover
|
|
401
457
|
"""
|
|
402
458
|
Plots the number of line lengths based on different max lengths
|