egglog 7.2.0__cp311-none-win_amd64.whl → 8.0.1__cp311-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of egglog might be problematic. Click here for more details.
- egglog/bindings.cp311-win_amd64.pyd +0 -0
- egglog/bindings.pyi +107 -53
- egglog/builtins.py +49 -6
- egglog/conversion.py +32 -9
- egglog/declarations.py +82 -4
- egglog/egraph.py +260 -179
- egglog/egraph_state.py +149 -66
- egglog/examples/higher_order_functions.py +4 -9
- egglog/exp/array_api.py +278 -93
- egglog/exp/array_api_jit.py +4 -8
- egglog/exp/array_api_loopnest.py +149 -0
- egglog/exp/array_api_numba.py +1 -1
- egglog/exp/array_api_program_gen.py +62 -25
- egglog/exp/program_gen.py +23 -17
- egglog/functionalize.py +91 -0
- egglog/ipython_magic.py +1 -1
- egglog/pretty.py +88 -44
- egglog/runtime.py +53 -40
- egglog/thunk.py +30 -18
- egglog/visualizer.css +1 -0
- egglog/visualizer.js +35774 -0
- egglog/visualizer_widget.py +39 -0
- {egglog-7.2.0.dist-info → egglog-8.0.1.dist-info}/METADATA +33 -32
- egglog-8.0.1.dist-info/RECORD +42 -0
- {egglog-7.2.0.dist-info → egglog-8.0.1.dist-info}/WHEEL +1 -1
- egglog/graphviz_widget.py +0 -34
- egglog/widget.css +0 -6
- egglog/widget.js +0 -50
- egglog-7.2.0.dist-info/RECORD +0 -40
- {egglog-7.2.0.dist-info/license_files → egglog-8.0.1.dist-info/licenses}/LICENSE +0 -0
egglog/pretty.py
CHANGED
|
@@ -16,6 +16,7 @@ from .declarations import *
|
|
|
16
16
|
if TYPE_CHECKING:
|
|
17
17
|
from collections.abc import Mapping
|
|
18
18
|
|
|
19
|
+
|
|
19
20
|
__all__ = [
|
|
20
21
|
"pretty_decl",
|
|
21
22
|
"pretty_callable_ref",
|
|
@@ -77,9 +78,9 @@ def pretty_decl(
|
|
|
77
78
|
|
|
78
79
|
This will use re-format the result and put the expression on the last line, preceeded by the statements.
|
|
79
80
|
"""
|
|
80
|
-
traverse = TraverseContext()
|
|
81
|
+
traverse = TraverseContext(decls)
|
|
81
82
|
traverse(decl, toplevel=True)
|
|
82
|
-
pretty = traverse.pretty(
|
|
83
|
+
pretty = traverse.pretty()
|
|
83
84
|
expr = pretty(decl, ruleset_name=ruleset_name)
|
|
84
85
|
if wrapping_fn:
|
|
85
86
|
expr = f"{wrapping_fn}({expr})"
|
|
@@ -106,15 +107,20 @@ def pretty_callable_ref(
|
|
|
106
107
|
"""
|
|
107
108
|
# Pass in three dummy args, which are the max used for any operation that
|
|
108
109
|
# is not a generic function call
|
|
109
|
-
args: list[ExprDecl] = [VarDecl(ARG_STR)] * 3
|
|
110
|
+
args: list[ExprDecl] = [VarDecl(ARG_STR, False)] * 3
|
|
110
111
|
if first_arg:
|
|
111
112
|
args.insert(0, first_arg)
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
)
|
|
113
|
+
context = PrettyContext(decls, defaultdict(lambda: 0))
|
|
114
|
+
res = context._call_inner(ref, args, bound_tp_params=bound_tp_params, parens=False)
|
|
115
115
|
# Either returns a function or a function with args. If args are provided, they would just be called,
|
|
116
116
|
# on the function, so return them, because they are dummies
|
|
117
|
-
|
|
117
|
+
if isinstance(res, tuple):
|
|
118
|
+
name = res[0]
|
|
119
|
+
# if this is an unnamed function, return it but don't partially apply any args
|
|
120
|
+
if isinstance(name, UnnamedFunctionRef):
|
|
121
|
+
return context._pretty_function_body(name, [])
|
|
122
|
+
return name
|
|
123
|
+
return res
|
|
118
124
|
|
|
119
125
|
|
|
120
126
|
# TODO: Add a different pretty callable ref that doesnt fill in wholes but instead returns the function
|
|
@@ -128,18 +134,20 @@ class TraverseContext:
|
|
|
128
134
|
expression has.
|
|
129
135
|
"""
|
|
130
136
|
|
|
137
|
+
decls: Declarations
|
|
138
|
+
|
|
131
139
|
# All expressions we have seen (incremented the parent counts of all children)
|
|
132
140
|
_seen: set[AllDecls] = field(default_factory=set)
|
|
133
141
|
# The number of parents for each expressions
|
|
134
142
|
parents: Counter[AllDecls] = field(default_factory=Counter)
|
|
135
143
|
|
|
136
|
-
def pretty(self
|
|
144
|
+
def pretty(self) -> PrettyContext:
|
|
137
145
|
"""
|
|
138
146
|
Create a pretty context from the state of this traverse context.
|
|
139
147
|
"""
|
|
140
|
-
return PrettyContext(decls, self.parents)
|
|
148
|
+
return PrettyContext(self.decls, self.parents)
|
|
141
149
|
|
|
142
|
-
def __call__(self, decl: AllDecls, toplevel: bool = False) -> None: # noqa: C901
|
|
150
|
+
def __call__(self, decl: AllDecls, toplevel: bool = False) -> None: # noqa: C901, PLR0912
|
|
143
151
|
if not toplevel:
|
|
144
152
|
self.parents[decl] += 1
|
|
145
153
|
if decl in self._seen:
|
|
@@ -169,9 +177,13 @@ class TraverseContext:
|
|
|
169
177
|
if isinstance(de, DefaultRewriteDecl):
|
|
170
178
|
continue
|
|
171
179
|
self(de)
|
|
172
|
-
case CallDecl(
|
|
173
|
-
|
|
174
|
-
|
|
180
|
+
case CallDecl(ref, exprs, _):
|
|
181
|
+
match ref:
|
|
182
|
+
case FunctionRef(UnnamedFunctionRef(_, res)):
|
|
183
|
+
self(res.expr)
|
|
184
|
+
case _:
|
|
185
|
+
for e in exprs:
|
|
186
|
+
self(e.expr)
|
|
175
187
|
case RunDecl(_, until):
|
|
176
188
|
if until:
|
|
177
189
|
for f in until:
|
|
@@ -224,7 +236,7 @@ class PrettyContext:
|
|
|
224
236
|
return expr_name
|
|
225
237
|
return expr
|
|
226
238
|
|
|
227
|
-
def uncached(self, decl: AllDecls, *, unwrap_lit: bool, parens: bool, ruleset_name: str | None) -> tuple[str, str]: # noqa: PLR0911
|
|
239
|
+
def uncached(self, decl: AllDecls, *, unwrap_lit: bool, parens: bool, ruleset_name: str | None) -> tuple[str, str]: # noqa: C901, PLR0911, PLR0912
|
|
228
240
|
match decl:
|
|
229
241
|
case LitDecl(value):
|
|
230
242
|
match value:
|
|
@@ -244,8 +256,7 @@ class PrettyContext:
|
|
|
244
256
|
case CallDecl(_, _, _):
|
|
245
257
|
return self._call(decl, parens)
|
|
246
258
|
case PartialCallDecl(CallDecl(ref, typed_args, _)):
|
|
247
|
-
|
|
248
|
-
return f"UnstableFn({', '.join(arg_strs)})", "fn"
|
|
259
|
+
return self._pretty_partial(ref, [a.expr for a in typed_args]), "fn"
|
|
249
260
|
case PyObjectDecl(value):
|
|
250
261
|
return repr(value) if unwrap_lit else f"PyObject({value!r})", "PyObject"
|
|
251
262
|
case ActionCommandDecl(action):
|
|
@@ -352,12 +363,16 @@ class PrettyContext:
|
|
|
352
363
|
has_multiple_parents = self.parents[first_arg] > 1
|
|
353
364
|
self.names[decl] = expr_name = self._name_expr(tp_name, expr_str, copy_identifier=has_multiple_parents)
|
|
354
365
|
# Set the first arg to be the name of the mutated arg and return the name
|
|
355
|
-
args[0] = VarDecl(expr_name)
|
|
366
|
+
args[0] = VarDecl(expr_name, True)
|
|
356
367
|
else:
|
|
357
368
|
expr_name = None
|
|
358
369
|
res = self._call_inner(ref, args, decl.bound_tp_params, parens)
|
|
359
370
|
expr = (
|
|
360
|
-
|
|
371
|
+
(
|
|
372
|
+
f"{name}({', '.join(self(a, parens=False, unwrap_lit=True) for a in res[1])})"
|
|
373
|
+
if isinstance((name := res[0]), str)
|
|
374
|
+
else ((called := self._pretty_function_body(name, res[1])) if not parens else f"({called})")
|
|
375
|
+
)
|
|
361
376
|
if isinstance(res, tuple)
|
|
362
377
|
else res
|
|
363
378
|
)
|
|
@@ -367,9 +382,13 @@ class PrettyContext:
|
|
|
367
382
|
return expr_name, tp_name
|
|
368
383
|
return expr, tp_name
|
|
369
384
|
|
|
370
|
-
def _call_inner( # noqa: PLR0911
|
|
371
|
-
self,
|
|
372
|
-
|
|
385
|
+
def _call_inner( # noqa: C901, PLR0911, PLR0912
|
|
386
|
+
self,
|
|
387
|
+
ref: CallableRef,
|
|
388
|
+
args: list[ExprDecl],
|
|
389
|
+
bound_tp_params: tuple[JustTypeRef, ...] | None,
|
|
390
|
+
parens: bool,
|
|
391
|
+
) -> tuple[str | UnnamedFunctionRef, list[ExprDecl]] | str:
|
|
373
392
|
"""
|
|
374
393
|
Pretty print the call, returning either the full function call or a tuple of the function and the args.
|
|
375
394
|
"""
|
|
@@ -408,6 +427,8 @@ class PrettyContext:
|
|
|
408
427
|
case InitRef(class_name):
|
|
409
428
|
tp_ref = JustTypeRef(class_name, bound_tp_params or ())
|
|
410
429
|
return str(tp_ref), args
|
|
430
|
+
case UnnamedFunctionRef():
|
|
431
|
+
return ref, args
|
|
411
432
|
assert_never(ref)
|
|
412
433
|
|
|
413
434
|
def _generate_name(self, typ: str) -> str:
|
|
@@ -428,29 +449,52 @@ class PrettyContext:
|
|
|
428
449
|
self.statements.append(f"{name} = {expr_str}")
|
|
429
450
|
return name
|
|
430
451
|
|
|
452
|
+
def _pretty_partial(self, ref: CallableRef, args: list[ExprDecl]) -> str:
|
|
453
|
+
"""
|
|
454
|
+
Returns a partial function call as a string.
|
|
455
|
+
"""
|
|
456
|
+
match ref:
|
|
457
|
+
case FunctionRef(name):
|
|
458
|
+
fn = name
|
|
459
|
+
case UnnamedFunctionRef():
|
|
460
|
+
return self._pretty_function_body(ref, args)
|
|
461
|
+
case (
|
|
462
|
+
ClassMethodRef(class_name, method_name)
|
|
463
|
+
| MethodRef(class_name, method_name)
|
|
464
|
+
| PropertyRef(class_name, method_name)
|
|
465
|
+
):
|
|
466
|
+
fn = f"{class_name}.{method_name}"
|
|
467
|
+
case InitRef(class_name):
|
|
468
|
+
fn = class_name
|
|
469
|
+
case ConstantRef(_):
|
|
470
|
+
msg = "Constants should not be callable"
|
|
471
|
+
raise NotImplementedError(msg)
|
|
472
|
+
case ClassVariableRef(_, _):
|
|
473
|
+
msg = "Class variables should not be callable"
|
|
474
|
+
raise NotADirectoryError(msg)
|
|
475
|
+
case _:
|
|
476
|
+
assert_never(ref)
|
|
477
|
+
if not args:
|
|
478
|
+
return fn
|
|
479
|
+
arg_strs = (
|
|
480
|
+
fn,
|
|
481
|
+
*(self(a, parens=False, unwrap_lit=True) for a in args),
|
|
482
|
+
)
|
|
483
|
+
return f"partial({', '.join(arg_strs)})"
|
|
431
484
|
|
|
432
|
-
def
|
|
433
|
-
|
|
434
|
-
|
|
435
|
-
|
|
436
|
-
|
|
437
|
-
|
|
438
|
-
|
|
439
|
-
|
|
440
|
-
|
|
441
|
-
|
|
442
|
-
|
|
443
|
-
|
|
444
|
-
|
|
445
|
-
case InitRef(class_name):
|
|
446
|
-
return class_name
|
|
447
|
-
case ConstantRef(_):
|
|
448
|
-
msg = "Constants should not be callable"
|
|
449
|
-
raise NotImplementedError(msg)
|
|
450
|
-
case ClassVariableRef(_, _):
|
|
451
|
-
msg = "Class variables should not be callable"
|
|
452
|
-
raise NotADirectoryError(msg)
|
|
453
|
-
assert_never(ref)
|
|
485
|
+
def _pretty_function_body(self, fn: UnnamedFunctionRef, args: list[ExprDecl]) -> str:
|
|
486
|
+
"""
|
|
487
|
+
Pretty print the body of a function, partially applying some arguments.
|
|
488
|
+
"""
|
|
489
|
+
var_args = fn.args
|
|
490
|
+
replacements = {var_arg: TypedExprDecl(var_arg.tp, arg) for var_arg, arg in zip(var_args, args, strict=False)}
|
|
491
|
+
var_args = var_args[len(args) :]
|
|
492
|
+
res = replace_typed_expr(fn.res, replacements)
|
|
493
|
+
arg_names = fn.args[len(args) :]
|
|
494
|
+
prefix = "lambda"
|
|
495
|
+
if arg_names:
|
|
496
|
+
prefix += f" {', '.join(self(a.expr) for a in arg_names)}"
|
|
497
|
+
return f"{prefix}: {self(res.expr)}"
|
|
454
498
|
|
|
455
499
|
|
|
456
500
|
def _plot_line_length(expr: object): # pragma: no cover
|
|
@@ -469,6 +513,6 @@ def _plot_line_length(expr: object): # pragma: no cover
|
|
|
469
513
|
new_l = len(str(expr).split())
|
|
470
514
|
sizes.append((line_length, diff, new_l))
|
|
471
515
|
|
|
472
|
-
df = pd.DataFrame(sizes, columns=["MAX_LINE_LENGTH", "LENGTH_DIFFERENCE", "n"])
|
|
516
|
+
df = pd.DataFrame(sizes, columns=["MAX_LINE_LENGTH", "LENGTH_DIFFERENCE", "n"])
|
|
473
517
|
|
|
474
518
|
return alt.Chart(df).mark_rect().encode(x="MAX_LINE_LENGTH:O", y="LENGTH_DIFFERENCE:O", color="n:Q")
|
egglog/runtime.py
CHANGED
|
@@ -81,7 +81,7 @@ def resolve_type_annotation(decls: Declarations, tp: object) -> TypeOrVarRef:
|
|
|
81
81
|
return resolve_type_annotation(decls, first)
|
|
82
82
|
|
|
83
83
|
# If the type is `object` then this is assumed to be a PyObjectLike, i.e. converted into a PyObject
|
|
84
|
-
if tp
|
|
84
|
+
if tp is object:
|
|
85
85
|
assert _PY_OBJECT_CLASS
|
|
86
86
|
return resolve_type_annotation(decls, _PY_OBJECT_CLASS)
|
|
87
87
|
# If the type is a `Callable` then convert it into a UnstableFn
|
|
@@ -127,8 +127,8 @@ class RuntimeClass(DelayedDeclerations):
|
|
|
127
127
|
# If this is a literal type, initializing it with a literal should return a literal
|
|
128
128
|
if (name := self.__egg_tp__.name) == "PyObject":
|
|
129
129
|
assert len(args) == 1
|
|
130
|
-
return RuntimeExpr
|
|
131
|
-
self.
|
|
130
|
+
return RuntimeExpr(
|
|
131
|
+
self.__egg_decls_thunk__, Thunk.value(TypedExprDecl(self.__egg_tp__.to_just(), PyObjectDecl(args[0])))
|
|
132
132
|
)
|
|
133
133
|
if name == "UnstableFn":
|
|
134
134
|
assert not kwargs
|
|
@@ -150,20 +150,20 @@ class RuntimeClass(DelayedDeclerations):
|
|
|
150
150
|
value = PartialCallDecl(replace(call, args=call.args[:n_args]))
|
|
151
151
|
remaining_arg_types = [a.tp for a in call.args[n_args:]]
|
|
152
152
|
type_ref = JustTypeRef("UnstableFn", (return_tp, *remaining_arg_types))
|
|
153
|
-
return RuntimeExpr.
|
|
153
|
+
return RuntimeExpr.__from_values__(Declarations.create(self, res), TypedExprDecl(type_ref, value))
|
|
154
154
|
|
|
155
155
|
if name in UNARY_LIT_CLASS_NAMES:
|
|
156
156
|
assert len(args) == 1
|
|
157
157
|
assert isinstance(args[0], int | float | str | bool)
|
|
158
|
-
return RuntimeExpr
|
|
159
|
-
self.
|
|
158
|
+
return RuntimeExpr(
|
|
159
|
+
self.__egg_decls_thunk__, Thunk.value(TypedExprDecl(self.__egg_tp__.to_just(), LitDecl(args[0])))
|
|
160
160
|
)
|
|
161
161
|
if name == UNIT_CLASS_NAME:
|
|
162
162
|
assert len(args) == 0
|
|
163
|
-
return RuntimeExpr
|
|
164
|
-
self.
|
|
163
|
+
return RuntimeExpr(
|
|
164
|
+
self.__egg_decls_thunk__, Thunk.value(TypedExprDecl(self.__egg_tp__.to_just(), LitDecl(None)))
|
|
165
165
|
)
|
|
166
|
-
fn = RuntimeFunction(Thunk.value(
|
|
166
|
+
fn = RuntimeFunction(self.__egg_decls_thunk__, Thunk.value(InitRef(name)), self.__egg_tp__.to_just())
|
|
167
167
|
return fn(*args, **kwargs) # type: ignore[arg-type]
|
|
168
168
|
|
|
169
169
|
def __dir__(self) -> list[str]:
|
|
@@ -208,19 +208,21 @@ class RuntimeClass(DelayedDeclerations):
|
|
|
208
208
|
# if this is a class variable, return an expr for it, otherwise, assume it's a method
|
|
209
209
|
if name in cls_decl.class_variables:
|
|
210
210
|
return_tp = cls_decl.class_variables[name]
|
|
211
|
-
return RuntimeExpr
|
|
212
|
-
self.
|
|
213
|
-
TypedExprDecl(return_tp.type_ref, CallDecl(ClassVariableRef(self.__egg_tp__.name, name))),
|
|
211
|
+
return RuntimeExpr(
|
|
212
|
+
self.__egg_decls_thunk__,
|
|
213
|
+
Thunk.value(TypedExprDecl(return_tp.type_ref, CallDecl(ClassVariableRef(self.__egg_tp__.name, name)))),
|
|
214
214
|
)
|
|
215
215
|
if name in cls_decl.class_methods:
|
|
216
216
|
return RuntimeFunction(
|
|
217
|
-
|
|
217
|
+
self.__egg_decls_thunk__,
|
|
218
|
+
Thunk.value(ClassMethodRef(self.__egg_tp__.name, name)),
|
|
219
|
+
self.__egg_tp__.to_just(),
|
|
218
220
|
)
|
|
219
221
|
# allow referencing properties and methods as class variables as well
|
|
220
222
|
if name in cls_decl.properties:
|
|
221
|
-
return RuntimeFunction(Thunk.value(
|
|
223
|
+
return RuntimeFunction(self.__egg_decls_thunk__, Thunk.value(PropertyRef(self.__egg_tp__.name, name)))
|
|
222
224
|
if name in cls_decl.methods:
|
|
223
|
-
return RuntimeFunction(Thunk.value(
|
|
225
|
+
return RuntimeFunction(self.__egg_decls_thunk__, Thunk.value(MethodRef(self.__egg_tp__.name, name)))
|
|
224
226
|
|
|
225
227
|
msg = f"Class {self.__egg_tp__.name} has no method {name}"
|
|
226
228
|
if name == "__ne__":
|
|
@@ -241,16 +243,23 @@ class RuntimeClass(DelayedDeclerations):
|
|
|
241
243
|
|
|
242
244
|
@dataclass
|
|
243
245
|
class RuntimeFunction(DelayedDeclerations):
|
|
244
|
-
|
|
246
|
+
__egg_ref_thunk__: Callable[[], CallableRef]
|
|
245
247
|
# bound methods need to store RuntimeExpr not just TypedExprDecl, so they can mutate the expr if required on self
|
|
246
248
|
__egg_bound__: JustTypeRef | RuntimeExpr | None = None
|
|
247
249
|
|
|
250
|
+
@property
|
|
251
|
+
def __egg_ref__(self) -> CallableRef:
|
|
252
|
+
return self.__egg_ref_thunk__()
|
|
253
|
+
|
|
248
254
|
def __call__(self, *args: object, _egg_partial_function: bool = False, **kwargs: object) -> RuntimeExpr | None:
|
|
249
255
|
from .conversion import resolve_literal
|
|
250
256
|
|
|
251
257
|
if isinstance(self.__egg_bound__, RuntimeExpr):
|
|
252
258
|
args = (self.__egg_bound__, *args)
|
|
253
|
-
|
|
259
|
+
try:
|
|
260
|
+
signature = self.__egg_decls__.get_callable_decl(self.__egg_ref__).to_function_decl().signature
|
|
261
|
+
except Exception as e:
|
|
262
|
+
raise TypeError(f"Failed to find callable {self}") from e
|
|
254
263
|
decls = self.__egg_decls__.copy()
|
|
255
264
|
# Special case function application bc we dont support variadic generics yet generally
|
|
256
265
|
if signature == "fn-app":
|
|
@@ -285,7 +294,7 @@ class RuntimeFunction(DelayedDeclerations):
|
|
|
285
294
|
args = bound.args
|
|
286
295
|
|
|
287
296
|
upcasted_args = [
|
|
288
|
-
resolve_literal(cast(TypeOrVarRef, tp), arg)
|
|
297
|
+
resolve_literal(cast(TypeOrVarRef, tp), arg, Thunk.value(decls))
|
|
289
298
|
for arg, tp in zip_longest(args, signature.arg_types, fillvalue=signature.var_arg_type)
|
|
290
299
|
]
|
|
291
300
|
decls.update(*upcasted_args)
|
|
@@ -322,9 +331,10 @@ class RuntimeFunction(DelayedDeclerations):
|
|
|
322
331
|
# If there is not return type, we are mutating the first arg
|
|
323
332
|
if not signature.return_type:
|
|
324
333
|
first_arg = upcasted_args[0]
|
|
325
|
-
first_arg.
|
|
334
|
+
first_arg.__egg_decls_thunk__ = Thunk.value(decls)
|
|
335
|
+
first_arg.__egg_typed_expr_thunk__ = Thunk.value(typed_expr_decl)
|
|
326
336
|
return None
|
|
327
|
-
return RuntimeExpr.
|
|
337
|
+
return RuntimeExpr.__from_values__(decls, typed_expr_decl)
|
|
328
338
|
|
|
329
339
|
def __str__(self) -> str:
|
|
330
340
|
first_arg, bound_tp_params = None, None
|
|
@@ -349,7 +359,9 @@ def to_py_signature(sig: FunctionSignature, decls: Declarations, optional_args:
|
|
|
349
359
|
Parameter(
|
|
350
360
|
n,
|
|
351
361
|
Parameter.POSITIONAL_OR_KEYWORD,
|
|
352
|
-
default=RuntimeExpr.
|
|
362
|
+
default=RuntimeExpr.__from_values__(
|
|
363
|
+
decls, TypedExprDecl(t.to_just(), d if d is not None else VarDecl(n, True))
|
|
364
|
+
)
|
|
353
365
|
if d is not None or optional_args
|
|
354
366
|
else Parameter.empty,
|
|
355
367
|
)
|
|
@@ -387,21 +399,16 @@ PARTIAL_METHODS = {
|
|
|
387
399
|
|
|
388
400
|
|
|
389
401
|
@dataclass
|
|
390
|
-
class RuntimeExpr:
|
|
391
|
-
|
|
392
|
-
__egg_thunk__: Callable[[], tuple[Declarations, TypedExprDecl]]
|
|
402
|
+
class RuntimeExpr(DelayedDeclerations):
|
|
403
|
+
__egg_typed_expr_thunk__: Callable[[], TypedExprDecl]
|
|
393
404
|
|
|
394
405
|
@classmethod
|
|
395
|
-
def
|
|
396
|
-
return cls(Thunk.value(
|
|
397
|
-
|
|
398
|
-
@property
|
|
399
|
-
def __egg_decls__(self) -> Declarations:
|
|
400
|
-
return self.__egg_thunk__()[0]
|
|
406
|
+
def __from_values__(cls, d: Declarations, e: TypedExprDecl) -> RuntimeExpr:
|
|
407
|
+
return cls(Thunk.value(d), Thunk.value(e))
|
|
401
408
|
|
|
402
409
|
@property
|
|
403
410
|
def __egg_typed_expr__(self) -> TypedExprDecl:
|
|
404
|
-
return self.
|
|
411
|
+
return self.__egg_typed_expr_thunk__()
|
|
405
412
|
|
|
406
413
|
def __getattr__(self, name: str) -> RuntimeFunction | RuntimeExpr | Callable | None:
|
|
407
414
|
cls_name = self.__egg_class_name__
|
|
@@ -411,9 +418,9 @@ class RuntimeExpr:
|
|
|
411
418
|
return preserved_methods[name].__get__(self)
|
|
412
419
|
|
|
413
420
|
if name in class_decl.methods:
|
|
414
|
-
return RuntimeFunction(Thunk.value(
|
|
421
|
+
return RuntimeFunction(self.__egg_decls_thunk__, Thunk.value(MethodRef(cls_name, name)), self)
|
|
415
422
|
if name in class_decl.properties:
|
|
416
|
-
return RuntimeFunction(Thunk.value(
|
|
423
|
+
return RuntimeFunction(self.__egg_decls_thunk__, Thunk.value(PropertyRef(cls_name, name)), self)()
|
|
417
424
|
raise AttributeError(f"{cls_name} has no method {name}") from None
|
|
418
425
|
|
|
419
426
|
def __repr__(self) -> str:
|
|
@@ -456,10 +463,11 @@ class RuntimeExpr:
|
|
|
456
463
|
# otherwise copy will try to call `__getstate__` before object is initialized with properties which will cause inifinite recursion
|
|
457
464
|
|
|
458
465
|
def __getstate__(self) -> tuple[Declarations, TypedExprDecl]:
|
|
459
|
-
return self.
|
|
466
|
+
return self.__egg_decls__, self.__egg_typed_expr__
|
|
460
467
|
|
|
461
468
|
def __setstate__(self, d: tuple[Declarations, TypedExprDecl]) -> None:
|
|
462
|
-
self.
|
|
469
|
+
self.__egg_decls_thunk__ = Thunk.value(d[0])
|
|
470
|
+
self.__egg_typed_expr_thunk__ = Thunk.value(d[1])
|
|
463
471
|
|
|
464
472
|
def __hash__(self) -> int:
|
|
465
473
|
return hash(self.__egg_typed_expr__)
|
|
@@ -497,7 +505,7 @@ for name in list(BINARY_METHODS) + list(UNARY_METHODS) + ["__getitem__", "__call
|
|
|
497
505
|
# we use the standard process
|
|
498
506
|
pass
|
|
499
507
|
if __name in class_decl.methods:
|
|
500
|
-
fn = RuntimeFunction(Thunk.value(
|
|
508
|
+
fn = RuntimeFunction(self.__egg_decls_thunk__, Thunk.value(MethodRef(class_name, __name)), self)
|
|
501
509
|
return fn(*args, **kwargs) # type: ignore[arg-type]
|
|
502
510
|
if __name in PARTIAL_METHODS:
|
|
503
511
|
return NotImplemented
|
|
@@ -523,7 +531,7 @@ def call_method_min_conversion(slf: object, other: object, name: str) -> Runtime
|
|
|
523
531
|
min_tp = min_convertable_tp(slf, other, name)
|
|
524
532
|
slf = resolve_literal(TypeRefWithVars(min_tp), slf)
|
|
525
533
|
other = resolve_literal(TypeRefWithVars(min_tp), other)
|
|
526
|
-
method = RuntimeFunction(Thunk.value(
|
|
534
|
+
method = RuntimeFunction(slf.__egg_decls_thunk__, Thunk.value(MethodRef(slf.__egg_class_name__, name)), slf)
|
|
527
535
|
return method(other)
|
|
528
536
|
|
|
529
537
|
|
|
@@ -545,7 +553,12 @@ def resolve_callable(callable: object) -> tuple[CallableRef, Declarations]:
|
|
|
545
553
|
"""
|
|
546
554
|
match callable:
|
|
547
555
|
case RuntimeFunction(decls, ref, _):
|
|
548
|
-
return ref, decls()
|
|
556
|
+
return ref(), decls()
|
|
549
557
|
case RuntimeClass(thunk, tp):
|
|
550
|
-
return
|
|
551
|
-
|
|
558
|
+
return InitRef(tp.name), thunk()
|
|
559
|
+
case RuntimeExpr(decl_thunk, expr_thunk):
|
|
560
|
+
if not isinstance((expr := expr_thunk().expr), CallDecl) or not isinstance(expr.callable, ConstantRef):
|
|
561
|
+
raise NotImplementedError(f"Can only turn constants into callable refs, not {expr}")
|
|
562
|
+
return expr.callable, decl_thunk()
|
|
563
|
+
case _:
|
|
564
|
+
raise NotImplementedError(f"Cannot turn {callable} of type {type(callable)} into a callable ref")
|
egglog/thunk.py
CHANGED
|
@@ -9,10 +9,27 @@ if TYPE_CHECKING:
|
|
|
9
9
|
from collections.abc import Callable
|
|
10
10
|
|
|
11
11
|
|
|
12
|
-
__all__ = ["Thunk"]
|
|
12
|
+
__all__ = ["Thunk", "split_thunk"]
|
|
13
13
|
|
|
14
14
|
T = TypeVar("T")
|
|
15
15
|
TS = TypeVarTuple("TS")
|
|
16
|
+
V = TypeVar("V")
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def split_thunk(fn: Callable[[], tuple[T, V]]) -> tuple[Callable[[], T], Callable[[], V]]:
|
|
20
|
+
s = _Split(fn)
|
|
21
|
+
return s.left, s.right
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
@dataclass
|
|
25
|
+
class _Split(Generic[T, V]):
|
|
26
|
+
fn: Callable[[], tuple[T, V]]
|
|
27
|
+
|
|
28
|
+
def left(self) -> T:
|
|
29
|
+
return self.fn()[0]
|
|
30
|
+
|
|
31
|
+
def right(self) -> V:
|
|
32
|
+
return self.fn()[1]
|
|
16
33
|
|
|
17
34
|
|
|
18
35
|
@dataclass
|
|
@@ -21,18 +38,16 @@ class Thunk(Generic[T, Unpack[TS]]):
|
|
|
21
38
|
Cached delayed function call.
|
|
22
39
|
"""
|
|
23
40
|
|
|
24
|
-
state: Resolved[T] | Unresolved[T, Unpack[TS]] | Resolving
|
|
41
|
+
state: Resolved[T] | Unresolved[T, Unpack[TS]] | Resolving | Error
|
|
25
42
|
|
|
26
43
|
@classmethod
|
|
27
|
-
def fn(
|
|
28
|
-
cls, fn: Callable[[Unpack[TS]], T], *args: Unpack[TS], fallback: Callable[[], T] | None = None
|
|
29
|
-
) -> Thunk[T, Unpack[TS]]:
|
|
44
|
+
def fn(cls, fn: Callable[[Unpack[TS]], T], *args: Unpack[TS]) -> Thunk[T, Unpack[TS]]:
|
|
30
45
|
"""
|
|
31
46
|
Create a thunk based on some functions and some partial args.
|
|
32
47
|
|
|
33
|
-
If the function is called while it is being resolved recursively
|
|
48
|
+
If the function is called while it is being resolved recursively it will raise an exception.
|
|
34
49
|
"""
|
|
35
|
-
return cls(Unresolved(fn, args
|
|
50
|
+
return cls(Unresolved(fn, args))
|
|
36
51
|
|
|
37
52
|
@classmethod
|
|
38
53
|
def value(cls, value: T) -> Thunk[T]:
|
|
@@ -42,21 +57,19 @@ class Thunk(Generic[T, Unpack[TS]]):
|
|
|
42
57
|
match self.state:
|
|
43
58
|
case Resolved(value):
|
|
44
59
|
return value
|
|
45
|
-
case Unresolved(fn, args
|
|
46
|
-
self.state = Resolving(
|
|
60
|
+
case Unresolved(fn, args):
|
|
61
|
+
self.state = Resolving()
|
|
47
62
|
try:
|
|
48
63
|
res = fn(*args)
|
|
49
64
|
except Exception as e:
|
|
50
65
|
self.state = Error(e)
|
|
51
|
-
raise
|
|
66
|
+
raise e from None
|
|
52
67
|
else:
|
|
53
68
|
self.state = Resolved(res)
|
|
54
69
|
return res
|
|
55
|
-
case Resolving(
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
raise ValueError(msg)
|
|
59
|
-
return fallback()
|
|
70
|
+
case Resolving():
|
|
71
|
+
msg = "Recursively resolving thunk"
|
|
72
|
+
raise ValueError(msg)
|
|
60
73
|
case Error(e):
|
|
61
74
|
raise e
|
|
62
75
|
|
|
@@ -70,12 +83,11 @@ class Resolved(Generic[T]):
|
|
|
70
83
|
class Unresolved(Generic[T, Unpack[TS]]):
|
|
71
84
|
fn: Callable[[Unpack[TS]], T]
|
|
72
85
|
args: tuple[Unpack[TS]]
|
|
73
|
-
fallback: Callable[[], T] | None
|
|
74
86
|
|
|
75
87
|
|
|
76
88
|
@dataclass
|
|
77
|
-
class Resolving
|
|
78
|
-
|
|
89
|
+
class Resolving:
|
|
90
|
+
pass
|
|
79
91
|
|
|
80
92
|
|
|
81
93
|
@dataclass
|