egglog-7.2.0-cp311-none-win_amd64.whl → egglog-8.0.0-cp311-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of egglog has been flagged as potentially problematic; see the registry's advisory page for this release for details.
- egglog/bindings.cp311-win_amd64.pyd +0 -0
- egglog/bindings.pyi +63 -23
- egglog/builtins.py +49 -6
- egglog/conversion.py +31 -8
- egglog/declarations.py +83 -4
- egglog/egraph.py +241 -173
- egglog/egraph_state.py +137 -61
- egglog/examples/higher_order_functions.py +3 -8
- egglog/exp/array_api.py +274 -92
- egglog/exp/array_api_jit.py +1 -4
- egglog/exp/array_api_loopnest.py +145 -0
- egglog/exp/array_api_numba.py +1 -1
- egglog/exp/array_api_program_gen.py +51 -12
- egglog/functionalize.py +91 -0
- egglog/pretty.py +84 -40
- egglog/runtime.py +52 -39
- egglog/thunk.py +30 -18
- egglog/visualizer.css +1 -0
- egglog/visualizer.js +35753 -0
- egglog/visualizer_widget.py +39 -0
- {egglog-7.2.0.dist-info → egglog-8.0.0.dist-info}/METADATA +33 -32
- egglog-8.0.0.dist-info/RECORD +42 -0
- {egglog-7.2.0.dist-info → egglog-8.0.0.dist-info}/WHEEL +1 -1
- egglog/graphviz_widget.py +0 -34
- egglog/widget.css +0 -6
- egglog/widget.js +0 -50
- egglog-7.2.0.dist-info/RECORD +0 -40
- {egglog-7.2.0.dist-info/license_files → egglog-8.0.0.dist-info/licenses}/LICENSE +0 -0
egglog/runtime.py
CHANGED
|
@@ -127,8 +127,8 @@ class RuntimeClass(DelayedDeclerations):
|
|
|
127
127
|
# If this is a literal type, initializing it with a literal should return a literal
|
|
128
128
|
if (name := self.__egg_tp__.name) == "PyObject":
|
|
129
129
|
assert len(args) == 1
|
|
130
|
-
return RuntimeExpr
|
|
131
|
-
self.
|
|
130
|
+
return RuntimeExpr(
|
|
131
|
+
self.__egg_decls_thunk__, Thunk.value(TypedExprDecl(self.__egg_tp__.to_just(), PyObjectDecl(args[0])))
|
|
132
132
|
)
|
|
133
133
|
if name == "UnstableFn":
|
|
134
134
|
assert not kwargs
|
|
@@ -150,20 +150,20 @@ class RuntimeClass(DelayedDeclerations):
|
|
|
150
150
|
value = PartialCallDecl(replace(call, args=call.args[:n_args]))
|
|
151
151
|
remaining_arg_types = [a.tp for a in call.args[n_args:]]
|
|
152
152
|
type_ref = JustTypeRef("UnstableFn", (return_tp, *remaining_arg_types))
|
|
153
|
-
return RuntimeExpr.
|
|
153
|
+
return RuntimeExpr.__from_values__(Declarations.create(self, res), TypedExprDecl(type_ref, value))
|
|
154
154
|
|
|
155
155
|
if name in UNARY_LIT_CLASS_NAMES:
|
|
156
156
|
assert len(args) == 1
|
|
157
157
|
assert isinstance(args[0], int | float | str | bool)
|
|
158
|
-
return RuntimeExpr
|
|
159
|
-
self.
|
|
158
|
+
return RuntimeExpr(
|
|
159
|
+
self.__egg_decls_thunk__, Thunk.value(TypedExprDecl(self.__egg_tp__.to_just(), LitDecl(args[0])))
|
|
160
160
|
)
|
|
161
161
|
if name == UNIT_CLASS_NAME:
|
|
162
162
|
assert len(args) == 0
|
|
163
|
-
return RuntimeExpr
|
|
164
|
-
self.
|
|
163
|
+
return RuntimeExpr(
|
|
164
|
+
self.__egg_decls_thunk__, Thunk.value(TypedExprDecl(self.__egg_tp__.to_just(), LitDecl(None)))
|
|
165
165
|
)
|
|
166
|
-
fn = RuntimeFunction(Thunk.value(
|
|
166
|
+
fn = RuntimeFunction(self.__egg_decls_thunk__, Thunk.value(InitRef(name)), self.__egg_tp__.to_just())
|
|
167
167
|
return fn(*args, **kwargs) # type: ignore[arg-type]
|
|
168
168
|
|
|
169
169
|
def __dir__(self) -> list[str]:
|
|
@@ -208,19 +208,21 @@ class RuntimeClass(DelayedDeclerations):
|
|
|
208
208
|
# if this is a class variable, return an expr for it, otherwise, assume it's a method
|
|
209
209
|
if name in cls_decl.class_variables:
|
|
210
210
|
return_tp = cls_decl.class_variables[name]
|
|
211
|
-
return RuntimeExpr
|
|
212
|
-
self.
|
|
213
|
-
TypedExprDecl(return_tp.type_ref, CallDecl(ClassVariableRef(self.__egg_tp__.name, name))),
|
|
211
|
+
return RuntimeExpr(
|
|
212
|
+
self.__egg_decls_thunk__,
|
|
213
|
+
Thunk.value(TypedExprDecl(return_tp.type_ref, CallDecl(ClassVariableRef(self.__egg_tp__.name, name)))),
|
|
214
214
|
)
|
|
215
215
|
if name in cls_decl.class_methods:
|
|
216
216
|
return RuntimeFunction(
|
|
217
|
-
|
|
217
|
+
self.__egg_decls_thunk__,
|
|
218
|
+
Thunk.value(ClassMethodRef(self.__egg_tp__.name, name)),
|
|
219
|
+
self.__egg_tp__.to_just(),
|
|
218
220
|
)
|
|
219
221
|
# allow referencing properties and methods as class variables as well
|
|
220
222
|
if name in cls_decl.properties:
|
|
221
|
-
return RuntimeFunction(Thunk.value(
|
|
223
|
+
return RuntimeFunction(self.__egg_decls_thunk__, Thunk.value(PropertyRef(self.__egg_tp__.name, name)))
|
|
222
224
|
if name in cls_decl.methods:
|
|
223
|
-
return RuntimeFunction(Thunk.value(
|
|
225
|
+
return RuntimeFunction(self.__egg_decls_thunk__, Thunk.value(MethodRef(self.__egg_tp__.name, name)))
|
|
224
226
|
|
|
225
227
|
msg = f"Class {self.__egg_tp__.name} has no method {name}"
|
|
226
228
|
if name == "__ne__":
|
|
@@ -241,16 +243,23 @@ class RuntimeClass(DelayedDeclerations):
|
|
|
241
243
|
|
|
242
244
|
@dataclass
|
|
243
245
|
class RuntimeFunction(DelayedDeclerations):
|
|
244
|
-
|
|
246
|
+
__egg_ref_thunk__: Callable[[], CallableRef]
|
|
245
247
|
# bound methods need to store RuntimeExpr not just TypedExprDecl, so they can mutate the expr if required on self
|
|
246
248
|
__egg_bound__: JustTypeRef | RuntimeExpr | None = None
|
|
247
249
|
|
|
250
|
+
@property
|
|
251
|
+
def __egg_ref__(self) -> CallableRef:
|
|
252
|
+
return self.__egg_ref_thunk__()
|
|
253
|
+
|
|
248
254
|
def __call__(self, *args: object, _egg_partial_function: bool = False, **kwargs: object) -> RuntimeExpr | None:
|
|
249
255
|
from .conversion import resolve_literal
|
|
250
256
|
|
|
251
257
|
if isinstance(self.__egg_bound__, RuntimeExpr):
|
|
252
258
|
args = (self.__egg_bound__, *args)
|
|
253
|
-
|
|
259
|
+
try:
|
|
260
|
+
signature = self.__egg_decls__.get_callable_decl(self.__egg_ref__).to_function_decl().signature
|
|
261
|
+
except Exception as e:
|
|
262
|
+
raise TypeError(f"Failed to find callable {self}") from e
|
|
254
263
|
decls = self.__egg_decls__.copy()
|
|
255
264
|
# Special case function application bc we dont support variadic generics yet generally
|
|
256
265
|
if signature == "fn-app":
|
|
@@ -285,7 +294,7 @@ class RuntimeFunction(DelayedDeclerations):
|
|
|
285
294
|
args = bound.args
|
|
286
295
|
|
|
287
296
|
upcasted_args = [
|
|
288
|
-
resolve_literal(cast(TypeOrVarRef, tp), arg)
|
|
297
|
+
resolve_literal(cast(TypeOrVarRef, tp), arg, Thunk.value(decls))
|
|
289
298
|
for arg, tp in zip_longest(args, signature.arg_types, fillvalue=signature.var_arg_type)
|
|
290
299
|
]
|
|
291
300
|
decls.update(*upcasted_args)
|
|
@@ -322,9 +331,10 @@ class RuntimeFunction(DelayedDeclerations):
|
|
|
322
331
|
# If there is not return type, we are mutating the first arg
|
|
323
332
|
if not signature.return_type:
|
|
324
333
|
first_arg = upcasted_args[0]
|
|
325
|
-
first_arg.
|
|
334
|
+
first_arg.__egg_decls_thunk__ = Thunk.value(decls)
|
|
335
|
+
first_arg.__egg_typed_expr_thunk__ = Thunk.value(typed_expr_decl)
|
|
326
336
|
return None
|
|
327
|
-
return RuntimeExpr.
|
|
337
|
+
return RuntimeExpr.__from_values__(decls, typed_expr_decl)
|
|
328
338
|
|
|
329
339
|
def __str__(self) -> str:
|
|
330
340
|
first_arg, bound_tp_params = None, None
|
|
@@ -349,7 +359,9 @@ def to_py_signature(sig: FunctionSignature, decls: Declarations, optional_args:
|
|
|
349
359
|
Parameter(
|
|
350
360
|
n,
|
|
351
361
|
Parameter.POSITIONAL_OR_KEYWORD,
|
|
352
|
-
default=RuntimeExpr.
|
|
362
|
+
default=RuntimeExpr.__from_values__(
|
|
363
|
+
decls, TypedExprDecl(t.to_just(), d if d is not None else VarDecl(n, True))
|
|
364
|
+
)
|
|
353
365
|
if d is not None or optional_args
|
|
354
366
|
else Parameter.empty,
|
|
355
367
|
)
|
|
@@ -387,21 +399,16 @@ PARTIAL_METHODS = {
|
|
|
387
399
|
|
|
388
400
|
|
|
389
401
|
@dataclass
|
|
390
|
-
class RuntimeExpr:
|
|
391
|
-
|
|
392
|
-
__egg_thunk__: Callable[[], tuple[Declarations, TypedExprDecl]]
|
|
402
|
+
class RuntimeExpr(DelayedDeclerations):
|
|
403
|
+
__egg_typed_expr_thunk__: Callable[[], TypedExprDecl]
|
|
393
404
|
|
|
394
405
|
@classmethod
|
|
395
|
-
def
|
|
396
|
-
return cls(Thunk.value(
|
|
397
|
-
|
|
398
|
-
@property
|
|
399
|
-
def __egg_decls__(self) -> Declarations:
|
|
400
|
-
return self.__egg_thunk__()[0]
|
|
406
|
+
def __from_values__(cls, d: Declarations, e: TypedExprDecl) -> RuntimeExpr:
|
|
407
|
+
return cls(Thunk.value(d), Thunk.value(e))
|
|
401
408
|
|
|
402
409
|
@property
|
|
403
410
|
def __egg_typed_expr__(self) -> TypedExprDecl:
|
|
404
|
-
return self.
|
|
411
|
+
return self.__egg_typed_expr_thunk__()
|
|
405
412
|
|
|
406
413
|
def __getattr__(self, name: str) -> RuntimeFunction | RuntimeExpr | Callable | None:
|
|
407
414
|
cls_name = self.__egg_class_name__
|
|
@@ -411,9 +418,9 @@ class RuntimeExpr:
|
|
|
411
418
|
return preserved_methods[name].__get__(self)
|
|
412
419
|
|
|
413
420
|
if name in class_decl.methods:
|
|
414
|
-
return RuntimeFunction(Thunk.value(
|
|
421
|
+
return RuntimeFunction(self.__egg_decls_thunk__, Thunk.value(MethodRef(cls_name, name)), self)
|
|
415
422
|
if name in class_decl.properties:
|
|
416
|
-
return RuntimeFunction(Thunk.value(
|
|
423
|
+
return RuntimeFunction(self.__egg_decls_thunk__, Thunk.value(PropertyRef(cls_name, name)), self)()
|
|
417
424
|
raise AttributeError(f"{cls_name} has no method {name}") from None
|
|
418
425
|
|
|
419
426
|
def __repr__(self) -> str:
|
|
@@ -456,10 +463,11 @@ class RuntimeExpr:
|
|
|
456
463
|
# otherwise copy will try to call `__getstate__` before object is initialized with properties which will cause inifinite recursion
|
|
457
464
|
|
|
458
465
|
def __getstate__(self) -> tuple[Declarations, TypedExprDecl]:
|
|
459
|
-
return self.
|
|
466
|
+
return self.__egg_decls__, self.__egg_typed_expr__
|
|
460
467
|
|
|
461
468
|
def __setstate__(self, d: tuple[Declarations, TypedExprDecl]) -> None:
|
|
462
|
-
self.
|
|
469
|
+
self.__egg_decls_thunk__ = Thunk.value(d[0])
|
|
470
|
+
self.__egg_typed_expr_thunk__ = Thunk.value(d[1])
|
|
463
471
|
|
|
464
472
|
def __hash__(self) -> int:
|
|
465
473
|
return hash(self.__egg_typed_expr__)
|
|
@@ -497,7 +505,7 @@ for name in list(BINARY_METHODS) + list(UNARY_METHODS) + ["__getitem__", "__call
|
|
|
497
505
|
# we use the standard process
|
|
498
506
|
pass
|
|
499
507
|
if __name in class_decl.methods:
|
|
500
|
-
fn = RuntimeFunction(Thunk.value(
|
|
508
|
+
fn = RuntimeFunction(self.__egg_decls_thunk__, Thunk.value(MethodRef(class_name, __name)), self)
|
|
501
509
|
return fn(*args, **kwargs) # type: ignore[arg-type]
|
|
502
510
|
if __name in PARTIAL_METHODS:
|
|
503
511
|
return NotImplemented
|
|
@@ -523,7 +531,7 @@ def call_method_min_conversion(slf: object, other: object, name: str) -> Runtime
|
|
|
523
531
|
min_tp = min_convertable_tp(slf, other, name)
|
|
524
532
|
slf = resolve_literal(TypeRefWithVars(min_tp), slf)
|
|
525
533
|
other = resolve_literal(TypeRefWithVars(min_tp), other)
|
|
526
|
-
method = RuntimeFunction(Thunk.value(
|
|
534
|
+
method = RuntimeFunction(slf.__egg_decls_thunk__, Thunk.value(MethodRef(slf.__egg_class_name__, name)), slf)
|
|
527
535
|
return method(other)
|
|
528
536
|
|
|
529
537
|
|
|
@@ -545,7 +553,12 @@ def resolve_callable(callable: object) -> tuple[CallableRef, Declarations]:
|
|
|
545
553
|
"""
|
|
546
554
|
match callable:
|
|
547
555
|
case RuntimeFunction(decls, ref, _):
|
|
548
|
-
return ref, decls()
|
|
556
|
+
return ref(), decls()
|
|
549
557
|
case RuntimeClass(thunk, tp):
|
|
550
|
-
return
|
|
551
|
-
|
|
558
|
+
return InitRef(tp.name), thunk()
|
|
559
|
+
case RuntimeExpr(decl_thunk, expr_thunk):
|
|
560
|
+
if not isinstance((expr := expr_thunk().expr), CallDecl) or not isinstance(expr.callable, ConstantRef):
|
|
561
|
+
raise NotImplementedError(f"Can only turn constants into callable refs, not {expr}")
|
|
562
|
+
return expr.callable, decl_thunk()
|
|
563
|
+
case _:
|
|
564
|
+
raise NotImplementedError(f"Cannot turn {callable} of type {type(callable)} into a callable ref")
|
egglog/thunk.py
CHANGED
|
@@ -9,10 +9,27 @@ if TYPE_CHECKING:
|
|
|
9
9
|
from collections.abc import Callable
|
|
10
10
|
|
|
11
11
|
|
|
12
|
-
__all__ = ["Thunk"]
|
|
12
|
+
__all__ = ["Thunk", "split_thunk"]
|
|
13
13
|
|
|
14
14
|
T = TypeVar("T")
|
|
15
15
|
TS = TypeVarTuple("TS")
|
|
16
|
+
V = TypeVar("V")
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def split_thunk(fn: Callable[[], tuple[T, V]]) -> tuple[Callable[[], T], Callable[[], V]]:
|
|
20
|
+
s = _Split(fn)
|
|
21
|
+
return s.left, s.right
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
@dataclass
|
|
25
|
+
class _Split(Generic[T, V]):
|
|
26
|
+
fn: Callable[[], tuple[T, V]]
|
|
27
|
+
|
|
28
|
+
def left(self) -> T:
|
|
29
|
+
return self.fn()[0]
|
|
30
|
+
|
|
31
|
+
def right(self) -> V:
|
|
32
|
+
return self.fn()[1]
|
|
16
33
|
|
|
17
34
|
|
|
18
35
|
@dataclass
|
|
@@ -21,18 +38,16 @@ class Thunk(Generic[T, Unpack[TS]]):
|
|
|
21
38
|
Cached delayed function call.
|
|
22
39
|
"""
|
|
23
40
|
|
|
24
|
-
state: Resolved[T] | Unresolved[T, Unpack[TS]] | Resolving
|
|
41
|
+
state: Resolved[T] | Unresolved[T, Unpack[TS]] | Resolving | Error
|
|
25
42
|
|
|
26
43
|
@classmethod
|
|
27
|
-
def fn(
|
|
28
|
-
cls, fn: Callable[[Unpack[TS]], T], *args: Unpack[TS], fallback: Callable[[], T] | None = None
|
|
29
|
-
) -> Thunk[T, Unpack[TS]]:
|
|
44
|
+
def fn(cls, fn: Callable[[Unpack[TS]], T], *args: Unpack[TS]) -> Thunk[T, Unpack[TS]]:
|
|
30
45
|
"""
|
|
31
46
|
Create a thunk based on some functions and some partial args.
|
|
32
47
|
|
|
33
|
-
If the function is called while it is being resolved recursively
|
|
48
|
+
If the function is called while it is being resolved recursively it will raise an exception.
|
|
34
49
|
"""
|
|
35
|
-
return cls(Unresolved(fn, args
|
|
50
|
+
return cls(Unresolved(fn, args))
|
|
36
51
|
|
|
37
52
|
@classmethod
|
|
38
53
|
def value(cls, value: T) -> Thunk[T]:
|
|
@@ -42,21 +57,19 @@ class Thunk(Generic[T, Unpack[TS]]):
|
|
|
42
57
|
match self.state:
|
|
43
58
|
case Resolved(value):
|
|
44
59
|
return value
|
|
45
|
-
case Unresolved(fn, args
|
|
46
|
-
self.state = Resolving(
|
|
60
|
+
case Unresolved(fn, args):
|
|
61
|
+
self.state = Resolving()
|
|
47
62
|
try:
|
|
48
63
|
res = fn(*args)
|
|
49
64
|
except Exception as e:
|
|
50
65
|
self.state = Error(e)
|
|
51
|
-
raise
|
|
66
|
+
raise e from None
|
|
52
67
|
else:
|
|
53
68
|
self.state = Resolved(res)
|
|
54
69
|
return res
|
|
55
|
-
case Resolving(
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
raise ValueError(msg)
|
|
59
|
-
return fallback()
|
|
70
|
+
case Resolving():
|
|
71
|
+
msg = "Recursively resolving thunk"
|
|
72
|
+
raise ValueError(msg)
|
|
60
73
|
case Error(e):
|
|
61
74
|
raise e
|
|
62
75
|
|
|
@@ -70,12 +83,11 @@ class Resolved(Generic[T]):
|
|
|
70
83
|
class Unresolved(Generic[T, Unpack[TS]]):
|
|
71
84
|
fn: Callable[[Unpack[TS]], T]
|
|
72
85
|
args: tuple[Unpack[TS]]
|
|
73
|
-
fallback: Callable[[], T] | None
|
|
74
86
|
|
|
75
87
|
|
|
76
88
|
@dataclass
|
|
77
|
-
class Resolving
|
|
78
|
-
|
|
89
|
+
class Resolving:
|
|
90
|
+
pass
|
|
79
91
|
|
|
80
92
|
|
|
81
93
|
@dataclass
|