egglog 7.1.0__cp312-none-win_amd64.whl → 8.0.0__cp312-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of egglog might be problematic. Click here for more details.
- egglog/bindings.cp312-win_amd64.pyd +0 -0
- egglog/bindings.pyi +63 -23
- egglog/builtins.py +49 -6
- egglog/conversion.py +31 -8
- egglog/declarations.py +146 -8
- egglog/egraph.py +337 -203
- egglog/egraph_state.py +171 -64
- egglog/examples/higher_order_functions.py +45 -0
- egglog/exp/array_api.py +278 -93
- egglog/exp/array_api_jit.py +1 -4
- egglog/exp/array_api_loopnest.py +145 -0
- egglog/exp/array_api_numba.py +1 -1
- egglog/exp/array_api_program_gen.py +51 -12
- egglog/functionalize.py +91 -0
- egglog/pretty.py +97 -43
- egglog/runtime.py +60 -44
- egglog/thunk.py +44 -20
- egglog/type_constraint_solver.py +5 -4
- egglog/visualizer.css +1 -0
- egglog/visualizer.js +35753 -0
- egglog/visualizer_widget.py +39 -0
- {egglog-7.1.0.dist-info → egglog-8.0.0.dist-info}/METADATA +31 -30
- egglog-8.0.0.dist-info/RECORD +42 -0
- {egglog-7.1.0.dist-info → egglog-8.0.0.dist-info}/WHEEL +1 -1
- egglog/graphviz_widget.py +0 -34
- egglog/widget.css +0 -6
- egglog/widget.js +0 -50
- egglog-7.1.0.dist-info/RECORD +0 -39
- {egglog-7.1.0.dist-info/license_files → egglog-8.0.0.dist-info/licenses}/LICENSE +0 -0
egglog/runtime.py
CHANGED
|
@@ -127,8 +127,8 @@ class RuntimeClass(DelayedDeclerations):
|
|
|
127
127
|
# If this is a literal type, initializing it with a literal should return a literal
|
|
128
128
|
if (name := self.__egg_tp__.name) == "PyObject":
|
|
129
129
|
assert len(args) == 1
|
|
130
|
-
return RuntimeExpr
|
|
131
|
-
self.
|
|
130
|
+
return RuntimeExpr(
|
|
131
|
+
self.__egg_decls_thunk__, Thunk.value(TypedExprDecl(self.__egg_tp__.to_just(), PyObjectDecl(args[0])))
|
|
132
132
|
)
|
|
133
133
|
if name == "UnstableFn":
|
|
134
134
|
assert not kwargs
|
|
@@ -150,23 +150,21 @@ class RuntimeClass(DelayedDeclerations):
|
|
|
150
150
|
value = PartialCallDecl(replace(call, args=call.args[:n_args]))
|
|
151
151
|
remaining_arg_types = [a.tp for a in call.args[n_args:]]
|
|
152
152
|
type_ref = JustTypeRef("UnstableFn", (return_tp, *remaining_arg_types))
|
|
153
|
-
return RuntimeExpr.
|
|
153
|
+
return RuntimeExpr.__from_values__(Declarations.create(self, res), TypedExprDecl(type_ref, value))
|
|
154
154
|
|
|
155
155
|
if name in UNARY_LIT_CLASS_NAMES:
|
|
156
156
|
assert len(args) == 1
|
|
157
157
|
assert isinstance(args[0], int | float | str | bool)
|
|
158
|
-
return RuntimeExpr
|
|
159
|
-
self.
|
|
158
|
+
return RuntimeExpr(
|
|
159
|
+
self.__egg_decls_thunk__, Thunk.value(TypedExprDecl(self.__egg_tp__.to_just(), LitDecl(args[0])))
|
|
160
160
|
)
|
|
161
161
|
if name == UNIT_CLASS_NAME:
|
|
162
162
|
assert len(args) == 0
|
|
163
|
-
return RuntimeExpr
|
|
164
|
-
self.
|
|
163
|
+
return RuntimeExpr(
|
|
164
|
+
self.__egg_decls_thunk__, Thunk.value(TypedExprDecl(self.__egg_tp__.to_just(), LitDecl(None)))
|
|
165
165
|
)
|
|
166
|
-
|
|
167
|
-
return
|
|
168
|
-
Thunk.value(self.__egg_decls__), ClassMethodRef(name, "__init__"), self.__egg_tp__.to_just()
|
|
169
|
-
)(*args, **kwargs) # type: ignore[arg-type]
|
|
166
|
+
fn = RuntimeFunction(self.__egg_decls_thunk__, Thunk.value(InitRef(name)), self.__egg_tp__.to_just())
|
|
167
|
+
return fn(*args, **kwargs) # type: ignore[arg-type]
|
|
170
168
|
|
|
171
169
|
def __dir__(self) -> list[str]:
|
|
172
170
|
cls_decl = self.__egg_decls__.get_class_decl(self.__egg_tp__.name)
|
|
@@ -210,19 +208,21 @@ class RuntimeClass(DelayedDeclerations):
|
|
|
210
208
|
# if this is a class variable, return an expr for it, otherwise, assume it's a method
|
|
211
209
|
if name in cls_decl.class_variables:
|
|
212
210
|
return_tp = cls_decl.class_variables[name]
|
|
213
|
-
return RuntimeExpr
|
|
214
|
-
self.
|
|
215
|
-
TypedExprDecl(return_tp.type_ref, CallDecl(ClassVariableRef(self.__egg_tp__.name, name))),
|
|
211
|
+
return RuntimeExpr(
|
|
212
|
+
self.__egg_decls_thunk__,
|
|
213
|
+
Thunk.value(TypedExprDecl(return_tp.type_ref, CallDecl(ClassVariableRef(self.__egg_tp__.name, name)))),
|
|
216
214
|
)
|
|
217
215
|
if name in cls_decl.class_methods:
|
|
218
216
|
return RuntimeFunction(
|
|
219
|
-
|
|
217
|
+
self.__egg_decls_thunk__,
|
|
218
|
+
Thunk.value(ClassMethodRef(self.__egg_tp__.name, name)),
|
|
219
|
+
self.__egg_tp__.to_just(),
|
|
220
220
|
)
|
|
221
221
|
# allow referencing properties and methods as class variables as well
|
|
222
222
|
if name in cls_decl.properties:
|
|
223
|
-
return RuntimeFunction(Thunk.value(
|
|
223
|
+
return RuntimeFunction(self.__egg_decls_thunk__, Thunk.value(PropertyRef(self.__egg_tp__.name, name)))
|
|
224
224
|
if name in cls_decl.methods:
|
|
225
|
-
return RuntimeFunction(Thunk.value(
|
|
225
|
+
return RuntimeFunction(self.__egg_decls_thunk__, Thunk.value(MethodRef(self.__egg_tp__.name, name)))
|
|
226
226
|
|
|
227
227
|
msg = f"Class {self.__egg_tp__.name} has no method {name}"
|
|
228
228
|
if name == "__ne__":
|
|
@@ -243,16 +243,23 @@ class RuntimeClass(DelayedDeclerations):
|
|
|
243
243
|
|
|
244
244
|
@dataclass
|
|
245
245
|
class RuntimeFunction(DelayedDeclerations):
|
|
246
|
-
|
|
246
|
+
__egg_ref_thunk__: Callable[[], CallableRef]
|
|
247
247
|
# bound methods need to store RuntimeExpr not just TypedExprDecl, so they can mutate the expr if required on self
|
|
248
248
|
__egg_bound__: JustTypeRef | RuntimeExpr | None = None
|
|
249
249
|
|
|
250
|
+
@property
|
|
251
|
+
def __egg_ref__(self) -> CallableRef:
|
|
252
|
+
return self.__egg_ref_thunk__()
|
|
253
|
+
|
|
250
254
|
def __call__(self, *args: object, _egg_partial_function: bool = False, **kwargs: object) -> RuntimeExpr | None:
|
|
251
255
|
from .conversion import resolve_literal
|
|
252
256
|
|
|
253
257
|
if isinstance(self.__egg_bound__, RuntimeExpr):
|
|
254
258
|
args = (self.__egg_bound__, *args)
|
|
255
|
-
|
|
259
|
+
try:
|
|
260
|
+
signature = self.__egg_decls__.get_callable_decl(self.__egg_ref__).to_function_decl().signature
|
|
261
|
+
except Exception as e:
|
|
262
|
+
raise TypeError(f"Failed to find callable {self}") from e
|
|
256
263
|
decls = self.__egg_decls__.copy()
|
|
257
264
|
# Special case function application bc we dont support variadic generics yet generally
|
|
258
265
|
if signature == "fn-app":
|
|
@@ -277,14 +284,17 @@ class RuntimeFunction(DelayedDeclerations):
|
|
|
277
284
|
|
|
278
285
|
# Turn all keyword args into positional args
|
|
279
286
|
py_signature = to_py_signature(signature, self.__egg_decls__, _egg_partial_function)
|
|
280
|
-
|
|
287
|
+
try:
|
|
288
|
+
bound = py_signature.bind(*args, **kwargs)
|
|
289
|
+
except TypeError as err:
|
|
290
|
+
raise TypeError(f"Failed to call {self} with args {args} and kwargs {kwargs}") from err
|
|
281
291
|
del kwargs
|
|
282
292
|
bound.apply_defaults()
|
|
283
293
|
assert not bound.kwargs
|
|
284
294
|
args = bound.args
|
|
285
295
|
|
|
286
296
|
upcasted_args = [
|
|
287
|
-
resolve_literal(cast(TypeOrVarRef, tp), arg)
|
|
297
|
+
resolve_literal(cast(TypeOrVarRef, tp), arg, Thunk.value(decls))
|
|
288
298
|
for arg, tp in zip_longest(args, signature.arg_types, fillvalue=signature.var_arg_type)
|
|
289
299
|
]
|
|
290
300
|
decls.update(*upcasted_args)
|
|
@@ -310,7 +320,9 @@ class RuntimeFunction(DelayedDeclerations):
|
|
|
310
320
|
return_tp = tcs.infer_return_type(
|
|
311
321
|
signature.arg_types, signature.semantic_return_type, signature.var_arg_type, arg_types, cls_name
|
|
312
322
|
)
|
|
313
|
-
bound_params =
|
|
323
|
+
bound_params = (
|
|
324
|
+
cast(JustTypeRef, bound_tp).args if isinstance(self.__egg_ref__, ClassMethodRef | InitRef) else None
|
|
325
|
+
)
|
|
314
326
|
# If we were using unstable-app to call a funciton, add that function back as the first arg.
|
|
315
327
|
if function_value:
|
|
316
328
|
arg_exprs = (function_value, *arg_exprs)
|
|
@@ -319,9 +331,10 @@ class RuntimeFunction(DelayedDeclerations):
|
|
|
319
331
|
# If there is not return type, we are mutating the first arg
|
|
320
332
|
if not signature.return_type:
|
|
321
333
|
first_arg = upcasted_args[0]
|
|
322
|
-
first_arg.
|
|
334
|
+
first_arg.__egg_decls_thunk__ = Thunk.value(decls)
|
|
335
|
+
first_arg.__egg_typed_expr_thunk__ = Thunk.value(typed_expr_decl)
|
|
323
336
|
return None
|
|
324
|
-
return RuntimeExpr.
|
|
337
|
+
return RuntimeExpr.__from_values__(decls, typed_expr_decl)
|
|
325
338
|
|
|
326
339
|
def __str__(self) -> str:
|
|
327
340
|
first_arg, bound_tp_params = None, None
|
|
@@ -346,7 +359,9 @@ def to_py_signature(sig: FunctionSignature, decls: Declarations, optional_args:
|
|
|
346
359
|
Parameter(
|
|
347
360
|
n,
|
|
348
361
|
Parameter.POSITIONAL_OR_KEYWORD,
|
|
349
|
-
default=RuntimeExpr.
|
|
362
|
+
default=RuntimeExpr.__from_values__(
|
|
363
|
+
decls, TypedExprDecl(t.to_just(), d if d is not None else VarDecl(n, True))
|
|
364
|
+
)
|
|
350
365
|
if d is not None or optional_args
|
|
351
366
|
else Parameter.empty,
|
|
352
367
|
)
|
|
@@ -384,21 +399,16 @@ PARTIAL_METHODS = {
|
|
|
384
399
|
|
|
385
400
|
|
|
386
401
|
@dataclass
|
|
387
|
-
class RuntimeExpr:
|
|
388
|
-
|
|
389
|
-
__egg_thunk__: Callable[[], tuple[Declarations, TypedExprDecl]]
|
|
402
|
+
class RuntimeExpr(DelayedDeclerations):
|
|
403
|
+
__egg_typed_expr_thunk__: Callable[[], TypedExprDecl]
|
|
390
404
|
|
|
391
405
|
@classmethod
|
|
392
|
-
def
|
|
393
|
-
return cls(Thunk.value(
|
|
394
|
-
|
|
395
|
-
@property
|
|
396
|
-
def __egg_decls__(self) -> Declarations:
|
|
397
|
-
return self.__egg_thunk__()[0]
|
|
406
|
+
def __from_values__(cls, d: Declarations, e: TypedExprDecl) -> RuntimeExpr:
|
|
407
|
+
return cls(Thunk.value(d), Thunk.value(e))
|
|
398
408
|
|
|
399
409
|
@property
|
|
400
410
|
def __egg_typed_expr__(self) -> TypedExprDecl:
|
|
401
|
-
return self.
|
|
411
|
+
return self.__egg_typed_expr_thunk__()
|
|
402
412
|
|
|
403
413
|
def __getattr__(self, name: str) -> RuntimeFunction | RuntimeExpr | Callable | None:
|
|
404
414
|
cls_name = self.__egg_class_name__
|
|
@@ -408,9 +418,9 @@ class RuntimeExpr:
|
|
|
408
418
|
return preserved_methods[name].__get__(self)
|
|
409
419
|
|
|
410
420
|
if name in class_decl.methods:
|
|
411
|
-
return RuntimeFunction(Thunk.value(
|
|
421
|
+
return RuntimeFunction(self.__egg_decls_thunk__, Thunk.value(MethodRef(cls_name, name)), self)
|
|
412
422
|
if name in class_decl.properties:
|
|
413
|
-
return RuntimeFunction(Thunk.value(
|
|
423
|
+
return RuntimeFunction(self.__egg_decls_thunk__, Thunk.value(PropertyRef(cls_name, name)), self)()
|
|
414
424
|
raise AttributeError(f"{cls_name} has no method {name}") from None
|
|
415
425
|
|
|
416
426
|
def __repr__(self) -> str:
|
|
@@ -453,10 +463,11 @@ class RuntimeExpr:
|
|
|
453
463
|
# otherwise copy will try to call `__getstate__` before object is initialized with properties which will cause inifinite recursion
|
|
454
464
|
|
|
455
465
|
def __getstate__(self) -> tuple[Declarations, TypedExprDecl]:
|
|
456
|
-
return self.
|
|
466
|
+
return self.__egg_decls__, self.__egg_typed_expr__
|
|
457
467
|
|
|
458
468
|
def __setstate__(self, d: tuple[Declarations, TypedExprDecl]) -> None:
|
|
459
|
-
self.
|
|
469
|
+
self.__egg_decls_thunk__ = Thunk.value(d[0])
|
|
470
|
+
self.__egg_typed_expr_thunk__ = Thunk.value(d[1])
|
|
460
471
|
|
|
461
472
|
def __hash__(self) -> int:
|
|
462
473
|
return hash(self.__egg_typed_expr__)
|
|
@@ -494,7 +505,7 @@ for name in list(BINARY_METHODS) + list(UNARY_METHODS) + ["__getitem__", "__call
|
|
|
494
505
|
# we use the standard process
|
|
495
506
|
pass
|
|
496
507
|
if __name in class_decl.methods:
|
|
497
|
-
fn = RuntimeFunction(Thunk.value(
|
|
508
|
+
fn = RuntimeFunction(self.__egg_decls_thunk__, Thunk.value(MethodRef(class_name, __name)), self)
|
|
498
509
|
return fn(*args, **kwargs) # type: ignore[arg-type]
|
|
499
510
|
if __name in PARTIAL_METHODS:
|
|
500
511
|
return NotImplemented
|
|
@@ -520,7 +531,7 @@ def call_method_min_conversion(slf: object, other: object, name: str) -> Runtime
|
|
|
520
531
|
min_tp = min_convertable_tp(slf, other, name)
|
|
521
532
|
slf = resolve_literal(TypeRefWithVars(min_tp), slf)
|
|
522
533
|
other = resolve_literal(TypeRefWithVars(min_tp), other)
|
|
523
|
-
method = RuntimeFunction(Thunk.value(
|
|
534
|
+
method = RuntimeFunction(slf.__egg_decls_thunk__, Thunk.value(MethodRef(slf.__egg_class_name__, name)), slf)
|
|
524
535
|
return method(other)
|
|
525
536
|
|
|
526
537
|
|
|
@@ -542,7 +553,12 @@ def resolve_callable(callable: object) -> tuple[CallableRef, Declarations]:
|
|
|
542
553
|
"""
|
|
543
554
|
match callable:
|
|
544
555
|
case RuntimeFunction(decls, ref, _):
|
|
545
|
-
return ref, decls()
|
|
556
|
+
return ref(), decls()
|
|
546
557
|
case RuntimeClass(thunk, tp):
|
|
547
|
-
return
|
|
548
|
-
|
|
558
|
+
return InitRef(tp.name), thunk()
|
|
559
|
+
case RuntimeExpr(decl_thunk, expr_thunk):
|
|
560
|
+
if not isinstance((expr := expr_thunk().expr), CallDecl) or not isinstance(expr.callable, ConstantRef):
|
|
561
|
+
raise NotImplementedError(f"Can only turn constants into callable refs, not {expr}")
|
|
562
|
+
return expr.callable, decl_thunk()
|
|
563
|
+
case _:
|
|
564
|
+
raise NotImplementedError(f"Cannot turn {callable} of type {type(callable)} into a callable ref")
|
egglog/thunk.py
CHANGED
|
@@ -9,10 +9,27 @@ if TYPE_CHECKING:
|
|
|
9
9
|
from collections.abc import Callable
|
|
10
10
|
|
|
11
11
|
|
|
12
|
-
__all__ = ["Thunk"]
|
|
12
|
+
__all__ = ["Thunk", "split_thunk"]
|
|
13
13
|
|
|
14
14
|
T = TypeVar("T")
|
|
15
15
|
TS = TypeVarTuple("TS")
|
|
16
|
+
V = TypeVar("V")
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def split_thunk(fn: Callable[[], tuple[T, V]]) -> tuple[Callable[[], T], Callable[[], V]]:
|
|
20
|
+
s = _Split(fn)
|
|
21
|
+
return s.left, s.right
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
@dataclass
|
|
25
|
+
class _Split(Generic[T, V]):
|
|
26
|
+
fn: Callable[[], tuple[T, V]]
|
|
27
|
+
|
|
28
|
+
def left(self) -> T:
|
|
29
|
+
return self.fn()[0]
|
|
30
|
+
|
|
31
|
+
def right(self) -> V:
|
|
32
|
+
return self.fn()[1]
|
|
16
33
|
|
|
17
34
|
|
|
18
35
|
@dataclass
|
|
@@ -21,18 +38,16 @@ class Thunk(Generic[T, Unpack[TS]]):
|
|
|
21
38
|
Cached delayed function call.
|
|
22
39
|
"""
|
|
23
40
|
|
|
24
|
-
state: Resolved[T] | Unresolved[T, Unpack[TS]] | Resolving
|
|
41
|
+
state: Resolved[T] | Unresolved[T, Unpack[TS]] | Resolving | Error
|
|
25
42
|
|
|
26
43
|
@classmethod
|
|
27
|
-
def fn(
|
|
28
|
-
cls, fn: Callable[[Unpack[TS]], T], *args: Unpack[TS], fallback: Callable[[], T] | None = None
|
|
29
|
-
) -> Thunk[T, Unpack[TS]]:
|
|
44
|
+
def fn(cls, fn: Callable[[Unpack[TS]], T], *args: Unpack[TS]) -> Thunk[T, Unpack[TS]]:
|
|
30
45
|
"""
|
|
31
46
|
Create a thunk based on some functions and some partial args.
|
|
32
47
|
|
|
33
|
-
If the function is called while it is being resolved recursively
|
|
48
|
+
If the function is called while it is being resolved recursively it will raise an exception.
|
|
34
49
|
"""
|
|
35
|
-
return cls(Unresolved(fn, args
|
|
50
|
+
return cls(Unresolved(fn, args))
|
|
36
51
|
|
|
37
52
|
@classmethod
|
|
38
53
|
def value(cls, value: T) -> Thunk[T]:
|
|
@@ -42,16 +57,21 @@ class Thunk(Generic[T, Unpack[TS]]):
|
|
|
42
57
|
match self.state:
|
|
43
58
|
case Resolved(value):
|
|
44
59
|
return value
|
|
45
|
-
case Unresolved(fn, args
|
|
46
|
-
self.state = Resolving(
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
60
|
+
case Unresolved(fn, args):
|
|
61
|
+
self.state = Resolving()
|
|
62
|
+
try:
|
|
63
|
+
res = fn(*args)
|
|
64
|
+
except Exception as e:
|
|
65
|
+
self.state = Error(e)
|
|
66
|
+
raise e from None
|
|
67
|
+
else:
|
|
68
|
+
self.state = Resolved(res)
|
|
69
|
+
return res
|
|
70
|
+
case Resolving():
|
|
71
|
+
msg = "Recursively resolving thunk"
|
|
72
|
+
raise ValueError(msg)
|
|
73
|
+
case Error(e):
|
|
74
|
+
raise e
|
|
55
75
|
|
|
56
76
|
|
|
57
77
|
@dataclass
|
|
@@ -63,9 +83,13 @@ class Resolved(Generic[T]):
|
|
|
63
83
|
class Unresolved(Generic[T, Unpack[TS]]):
|
|
64
84
|
fn: Callable[[Unpack[TS]], T]
|
|
65
85
|
args: tuple[Unpack[TS]]
|
|
66
|
-
fallback: Callable[[], T] | None
|
|
67
86
|
|
|
68
87
|
|
|
69
88
|
@dataclass
|
|
70
|
-
class Resolving
|
|
71
|
-
|
|
89
|
+
class Resolving:
|
|
90
|
+
pass
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
@dataclass
|
|
94
|
+
class Error:
|
|
95
|
+
e: Exception
|
egglog/type_constraint_solver.py
CHANGED
|
@@ -79,9 +79,10 @@ class TypeConstraintSolver:
|
|
|
79
79
|
Also returns the bound type params if the class name is passed in.
|
|
80
80
|
"""
|
|
81
81
|
self._infer_typevars(fn_return, return_, cls_name)
|
|
82
|
-
arg_types = (
|
|
83
|
-
|
|
84
|
-
|
|
82
|
+
arg_types: Iterable[JustTypeRef] = [self._subtitute_typevars(a, cls_name) for a in fn_args]
|
|
83
|
+
if fn_var_args:
|
|
84
|
+
# Need to be generator so it can be infinite for variable args
|
|
85
|
+
arg_types = chain(arg_types, repeat(self._subtitute_typevars(fn_var_args, cls_name)))
|
|
85
86
|
bound_typevars = (
|
|
86
87
|
tuple(
|
|
87
88
|
v
|
|
@@ -132,8 +133,8 @@ class TypeConstraintSolver:
|
|
|
132
133
|
def _subtitute_typevars(self, tp: TypeOrVarRef, cls_name: str | None) -> JustTypeRef:
|
|
133
134
|
match tp:
|
|
134
135
|
case ClassTypeVarRef(name):
|
|
136
|
+
assert cls_name is not None
|
|
135
137
|
try:
|
|
136
|
-
assert cls_name is not None
|
|
137
138
|
return self._cls_typevar_index_to_type[cls_name][name]
|
|
138
139
|
except KeyError as e:
|
|
139
140
|
raise TypeConstraintError(f"Not enough bound typevars for {tp}") from e
|