egglog 7.2.0__cp312-none-win_amd64.whl → 8.0.1__cp312-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of egglog might be problematic. Click here for more details.

egglog/pretty.py CHANGED
@@ -16,6 +16,7 @@ from .declarations import *
16
16
  if TYPE_CHECKING:
17
17
  from collections.abc import Mapping
18
18
 
19
+
19
20
  __all__ = [
20
21
  "pretty_decl",
21
22
  "pretty_callable_ref",
@@ -77,9 +78,9 @@ def pretty_decl(
77
78
 
78
79
  This will use re-format the result and put the expression on the last line, preceeded by the statements.
79
80
  """
80
- traverse = TraverseContext()
81
+ traverse = TraverseContext(decls)
81
82
  traverse(decl, toplevel=True)
82
- pretty = traverse.pretty(decls)
83
+ pretty = traverse.pretty()
83
84
  expr = pretty(decl, ruleset_name=ruleset_name)
84
85
  if wrapping_fn:
85
86
  expr = f"{wrapping_fn}({expr})"
@@ -106,15 +107,20 @@ def pretty_callable_ref(
106
107
  """
107
108
  # Pass in three dummy args, which are the max used for any operation that
108
109
  # is not a generic function call
109
- args: list[ExprDecl] = [VarDecl(ARG_STR)] * 3
110
+ args: list[ExprDecl] = [VarDecl(ARG_STR, False)] * 3
110
111
  if first_arg:
111
112
  args.insert(0, first_arg)
112
- res = PrettyContext(decls, defaultdict(lambda: 0))._call_inner(
113
- ref, args, bound_tp_params=bound_tp_params, parens=False
114
- )
113
+ context = PrettyContext(decls, defaultdict(lambda: 0))
114
+ res = context._call_inner(ref, args, bound_tp_params=bound_tp_params, parens=False)
115
115
  # Either returns a function or a function with args. If args are provided, they would just be called,
116
116
  # on the function, so return them, because they are dummies
117
- return res[0] if isinstance(res, tuple) else res
117
+ if isinstance(res, tuple):
118
+ name = res[0]
119
+ # if this is an unnamed function, return it but don't partially apply any args
120
+ if isinstance(name, UnnamedFunctionRef):
121
+ return context._pretty_function_body(name, [])
122
+ return name
123
+ return res
118
124
 
119
125
 
120
126
  # TODO: Add a different pretty callable ref that doesnt fill in wholes but instead returns the function
@@ -128,18 +134,20 @@ class TraverseContext:
128
134
  expression has.
129
135
  """
130
136
 
137
+ decls: Declarations
138
+
131
139
  # All expressions we have seen (incremented the parent counts of all children)
132
140
  _seen: set[AllDecls] = field(default_factory=set)
133
141
  # The number of parents for each expressions
134
142
  parents: Counter[AllDecls] = field(default_factory=Counter)
135
143
 
136
- def pretty(self, decls: Declarations) -> PrettyContext:
144
+ def pretty(self) -> PrettyContext:
137
145
  """
138
146
  Create a pretty context from the state of this traverse context.
139
147
  """
140
- return PrettyContext(decls, self.parents)
148
+ return PrettyContext(self.decls, self.parents)
141
149
 
142
- def __call__(self, decl: AllDecls, toplevel: bool = False) -> None: # noqa: C901
150
+ def __call__(self, decl: AllDecls, toplevel: bool = False) -> None: # noqa: C901, PLR0912
143
151
  if not toplevel:
144
152
  self.parents[decl] += 1
145
153
  if decl in self._seen:
@@ -169,9 +177,13 @@ class TraverseContext:
169
177
  if isinstance(de, DefaultRewriteDecl):
170
178
  continue
171
179
  self(de)
172
- case CallDecl(_, exprs, _):
173
- for e in exprs:
174
- self(e.expr)
180
+ case CallDecl(ref, exprs, _):
181
+ match ref:
182
+ case FunctionRef(UnnamedFunctionRef(_, res)):
183
+ self(res.expr)
184
+ case _:
185
+ for e in exprs:
186
+ self(e.expr)
175
187
  case RunDecl(_, until):
176
188
  if until:
177
189
  for f in until:
@@ -224,7 +236,7 @@ class PrettyContext:
224
236
  return expr_name
225
237
  return expr
226
238
 
227
- def uncached(self, decl: AllDecls, *, unwrap_lit: bool, parens: bool, ruleset_name: str | None) -> tuple[str, str]: # noqa: PLR0911
239
+ def uncached(self, decl: AllDecls, *, unwrap_lit: bool, parens: bool, ruleset_name: str | None) -> tuple[str, str]: # noqa: C901, PLR0911, PLR0912
228
240
  match decl:
229
241
  case LitDecl(value):
230
242
  match value:
@@ -244,8 +256,7 @@ class PrettyContext:
244
256
  case CallDecl(_, _, _):
245
257
  return self._call(decl, parens)
246
258
  case PartialCallDecl(CallDecl(ref, typed_args, _)):
247
- arg_strs = (_pretty_callable(ref), *(self(a.expr, parens=False, unwrap_lit=True) for a in typed_args))
248
- return f"UnstableFn({', '.join(arg_strs)})", "fn"
259
+ return self._pretty_partial(ref, [a.expr for a in typed_args]), "fn"
249
260
  case PyObjectDecl(value):
250
261
  return repr(value) if unwrap_lit else f"PyObject({value!r})", "PyObject"
251
262
  case ActionCommandDecl(action):
@@ -352,12 +363,16 @@ class PrettyContext:
352
363
  has_multiple_parents = self.parents[first_arg] > 1
353
364
  self.names[decl] = expr_name = self._name_expr(tp_name, expr_str, copy_identifier=has_multiple_parents)
354
365
  # Set the first arg to be the name of the mutated arg and return the name
355
- args[0] = VarDecl(expr_name)
366
+ args[0] = VarDecl(expr_name, True)
356
367
  else:
357
368
  expr_name = None
358
369
  res = self._call_inner(ref, args, decl.bound_tp_params, parens)
359
370
  expr = (
360
- f"{res[0]}({', '.join(self(a, parens=False, unwrap_lit=True) for a in res[1])})"
371
+ (
372
+ f"{name}({', '.join(self(a, parens=False, unwrap_lit=True) for a in res[1])})"
373
+ if isinstance((name := res[0]), str)
374
+ else ((called := self._pretty_function_body(name, res[1])) if not parens else f"({called})")
375
+ )
361
376
  if isinstance(res, tuple)
362
377
  else res
363
378
  )
@@ -367,9 +382,13 @@ class PrettyContext:
367
382
  return expr_name, tp_name
368
383
  return expr, tp_name
369
384
 
370
- def _call_inner( # noqa: PLR0911
371
- self, ref: CallableRef, args: list[ExprDecl], bound_tp_params: tuple[JustTypeRef, ...] | None, parens: bool
372
- ) -> tuple[str, list[ExprDecl]] | str:
385
+ def _call_inner( # noqa: C901, PLR0911, PLR0912
386
+ self,
387
+ ref: CallableRef,
388
+ args: list[ExprDecl],
389
+ bound_tp_params: tuple[JustTypeRef, ...] | None,
390
+ parens: bool,
391
+ ) -> tuple[str | UnnamedFunctionRef, list[ExprDecl]] | str:
373
392
  """
374
393
  Pretty print the call, returning either the full function call or a tuple of the function and the args.
375
394
  """
@@ -408,6 +427,8 @@ class PrettyContext:
408
427
  case InitRef(class_name):
409
428
  tp_ref = JustTypeRef(class_name, bound_tp_params or ())
410
429
  return str(tp_ref), args
430
+ case UnnamedFunctionRef():
431
+ return ref, args
411
432
  assert_never(ref)
412
433
 
413
434
  def _generate_name(self, typ: str) -> str:
@@ -428,29 +449,52 @@ class PrettyContext:
428
449
  self.statements.append(f"{name} = {expr_str}")
429
450
  return name
430
451
 
452
+ def _pretty_partial(self, ref: CallableRef, args: list[ExprDecl]) -> str:
453
+ """
454
+ Returns a partial function call as a string.
455
+ """
456
+ match ref:
457
+ case FunctionRef(name):
458
+ fn = name
459
+ case UnnamedFunctionRef():
460
+ return self._pretty_function_body(ref, args)
461
+ case (
462
+ ClassMethodRef(class_name, method_name)
463
+ | MethodRef(class_name, method_name)
464
+ | PropertyRef(class_name, method_name)
465
+ ):
466
+ fn = f"{class_name}.{method_name}"
467
+ case InitRef(class_name):
468
+ fn = class_name
469
+ case ConstantRef(_):
470
+ msg = "Constants should not be callable"
471
+ raise NotImplementedError(msg)
472
+ case ClassVariableRef(_, _):
473
+ msg = "Class variables should not be callable"
474
+ raise NotADirectoryError(msg)
475
+ case _:
476
+ assert_never(ref)
477
+ if not args:
478
+ return fn
479
+ arg_strs = (
480
+ fn,
481
+ *(self(a, parens=False, unwrap_lit=True) for a in args),
482
+ )
483
+ return f"partial({', '.join(arg_strs)})"
431
484
 
432
- def _pretty_callable(ref: CallableRef) -> str:
433
- """
434
- Returns a function call as a string.
435
- """
436
- match ref:
437
- case FunctionRef(name):
438
- return name
439
- case (
440
- ClassMethodRef(class_name, method_name)
441
- | MethodRef(class_name, method_name)
442
- | PropertyRef(class_name, method_name)
443
- ):
444
- return f"{class_name}.{method_name}"
445
- case InitRef(class_name):
446
- return class_name
447
- case ConstantRef(_):
448
- msg = "Constants should not be callable"
449
- raise NotImplementedError(msg)
450
- case ClassVariableRef(_, _):
451
- msg = "Class variables should not be callable"
452
- raise NotADirectoryError(msg)
453
- assert_never(ref)
485
+ def _pretty_function_body(self, fn: UnnamedFunctionRef, args: list[ExprDecl]) -> str:
486
+ """
487
+ Pretty print the body of a function, partially applying some arguments.
488
+ """
489
+ var_args = fn.args
490
+ replacements = {var_arg: TypedExprDecl(var_arg.tp, arg) for var_arg, arg in zip(var_args, args, strict=False)}
491
+ var_args = var_args[len(args) :]
492
+ res = replace_typed_expr(fn.res, replacements)
493
+ arg_names = fn.args[len(args) :]
494
+ prefix = "lambda"
495
+ if arg_names:
496
+ prefix += f" {', '.join(self(a.expr) for a in arg_names)}"
497
+ return f"{prefix}: {self(res.expr)}"
454
498
 
455
499
 
456
500
  def _plot_line_length(expr: object): # pragma: no cover
@@ -469,6 +513,6 @@ def _plot_line_length(expr: object): # pragma: no cover
469
513
  new_l = len(str(expr).split())
470
514
  sizes.append((line_length, diff, new_l))
471
515
 
472
- df = pd.DataFrame(sizes, columns=["MAX_LINE_LENGTH", "LENGTH_DIFFERENCE", "n"]) # noqa: PD901
516
+ df = pd.DataFrame(sizes, columns=["MAX_LINE_LENGTH", "LENGTH_DIFFERENCE", "n"])
473
517
 
474
518
  return alt.Chart(df).mark_rect().encode(x="MAX_LINE_LENGTH:O", y="LENGTH_DIFFERENCE:O", color="n:Q")
egglog/runtime.py CHANGED
@@ -81,7 +81,7 @@ def resolve_type_annotation(decls: Declarations, tp: object) -> TypeOrVarRef:
81
81
  return resolve_type_annotation(decls, first)
82
82
 
83
83
  # If the type is `object` then this is assumed to be a PyObjectLike, i.e. converted into a PyObject
84
- if tp == object:
84
+ if tp is object:
85
85
  assert _PY_OBJECT_CLASS
86
86
  return resolve_type_annotation(decls, _PY_OBJECT_CLASS)
87
87
  # If the type is a `Callable` then convert it into a UnstableFn
@@ -127,8 +127,8 @@ class RuntimeClass(DelayedDeclerations):
127
127
  # If this is a literal type, initializing it with a literal should return a literal
128
128
  if (name := self.__egg_tp__.name) == "PyObject":
129
129
  assert len(args) == 1
130
- return RuntimeExpr.__from_value__(
131
- self.__egg_decls__, TypedExprDecl(self.__egg_tp__.to_just(), PyObjectDecl(args[0]))
130
+ return RuntimeExpr(
131
+ self.__egg_decls_thunk__, Thunk.value(TypedExprDecl(self.__egg_tp__.to_just(), PyObjectDecl(args[0])))
132
132
  )
133
133
  if name == "UnstableFn":
134
134
  assert not kwargs
@@ -150,20 +150,20 @@ class RuntimeClass(DelayedDeclerations):
150
150
  value = PartialCallDecl(replace(call, args=call.args[:n_args]))
151
151
  remaining_arg_types = [a.tp for a in call.args[n_args:]]
152
152
  type_ref = JustTypeRef("UnstableFn", (return_tp, *remaining_arg_types))
153
- return RuntimeExpr.__from_value__(Declarations.create(self, res), TypedExprDecl(type_ref, value))
153
+ return RuntimeExpr.__from_values__(Declarations.create(self, res), TypedExprDecl(type_ref, value))
154
154
 
155
155
  if name in UNARY_LIT_CLASS_NAMES:
156
156
  assert len(args) == 1
157
157
  assert isinstance(args[0], int | float | str | bool)
158
- return RuntimeExpr.__from_value__(
159
- self.__egg_decls__, TypedExprDecl(self.__egg_tp__.to_just(), LitDecl(args[0]))
158
+ return RuntimeExpr(
159
+ self.__egg_decls_thunk__, Thunk.value(TypedExprDecl(self.__egg_tp__.to_just(), LitDecl(args[0])))
160
160
  )
161
161
  if name == UNIT_CLASS_NAME:
162
162
  assert len(args) == 0
163
- return RuntimeExpr.__from_value__(
164
- self.__egg_decls__, TypedExprDecl(self.__egg_tp__.to_just(), LitDecl(None))
163
+ return RuntimeExpr(
164
+ self.__egg_decls_thunk__, Thunk.value(TypedExprDecl(self.__egg_tp__.to_just(), LitDecl(None)))
165
165
  )
166
- fn = RuntimeFunction(Thunk.value(self.__egg_decls__), InitRef(name), self.__egg_tp__.to_just())
166
+ fn = RuntimeFunction(self.__egg_decls_thunk__, Thunk.value(InitRef(name)), self.__egg_tp__.to_just())
167
167
  return fn(*args, **kwargs) # type: ignore[arg-type]
168
168
 
169
169
  def __dir__(self) -> list[str]:
@@ -208,19 +208,21 @@ class RuntimeClass(DelayedDeclerations):
208
208
  # if this is a class variable, return an expr for it, otherwise, assume it's a method
209
209
  if name in cls_decl.class_variables:
210
210
  return_tp = cls_decl.class_variables[name]
211
- return RuntimeExpr.__from_value__(
212
- self.__egg_decls__,
213
- TypedExprDecl(return_tp.type_ref, CallDecl(ClassVariableRef(self.__egg_tp__.name, name))),
211
+ return RuntimeExpr(
212
+ self.__egg_decls_thunk__,
213
+ Thunk.value(TypedExprDecl(return_tp.type_ref, CallDecl(ClassVariableRef(self.__egg_tp__.name, name)))),
214
214
  )
215
215
  if name in cls_decl.class_methods:
216
216
  return RuntimeFunction(
217
- Thunk.value(self.__egg_decls__), ClassMethodRef(self.__egg_tp__.name, name), self.__egg_tp__.to_just()
217
+ self.__egg_decls_thunk__,
218
+ Thunk.value(ClassMethodRef(self.__egg_tp__.name, name)),
219
+ self.__egg_tp__.to_just(),
218
220
  )
219
221
  # allow referencing properties and methods as class variables as well
220
222
  if name in cls_decl.properties:
221
- return RuntimeFunction(Thunk.value(self.__egg_decls__), PropertyRef(self.__egg_tp__.name, name))
223
+ return RuntimeFunction(self.__egg_decls_thunk__, Thunk.value(PropertyRef(self.__egg_tp__.name, name)))
222
224
  if name in cls_decl.methods:
223
- return RuntimeFunction(Thunk.value(self.__egg_decls__), MethodRef(self.__egg_tp__.name, name))
225
+ return RuntimeFunction(self.__egg_decls_thunk__, Thunk.value(MethodRef(self.__egg_tp__.name, name)))
224
226
 
225
227
  msg = f"Class {self.__egg_tp__.name} has no method {name}"
226
228
  if name == "__ne__":
@@ -241,16 +243,23 @@ class RuntimeClass(DelayedDeclerations):
241
243
 
242
244
  @dataclass
243
245
  class RuntimeFunction(DelayedDeclerations):
244
- __egg_ref__: CallableRef
246
+ __egg_ref_thunk__: Callable[[], CallableRef]
245
247
  # bound methods need to store RuntimeExpr not just TypedExprDecl, so they can mutate the expr if required on self
246
248
  __egg_bound__: JustTypeRef | RuntimeExpr | None = None
247
249
 
250
+ @property
251
+ def __egg_ref__(self) -> CallableRef:
252
+ return self.__egg_ref_thunk__()
253
+
248
254
  def __call__(self, *args: object, _egg_partial_function: bool = False, **kwargs: object) -> RuntimeExpr | None:
249
255
  from .conversion import resolve_literal
250
256
 
251
257
  if isinstance(self.__egg_bound__, RuntimeExpr):
252
258
  args = (self.__egg_bound__, *args)
253
- signature = self.__egg_decls__.get_callable_decl(self.__egg_ref__).to_function_decl().signature
259
+ try:
260
+ signature = self.__egg_decls__.get_callable_decl(self.__egg_ref__).to_function_decl().signature
261
+ except Exception as e:
262
+ raise TypeError(f"Failed to find callable {self}") from e
254
263
  decls = self.__egg_decls__.copy()
255
264
  # Special case function application bc we dont support variadic generics yet generally
256
265
  if signature == "fn-app":
@@ -285,7 +294,7 @@ class RuntimeFunction(DelayedDeclerations):
285
294
  args = bound.args
286
295
 
287
296
  upcasted_args = [
288
- resolve_literal(cast(TypeOrVarRef, tp), arg)
297
+ resolve_literal(cast(TypeOrVarRef, tp), arg, Thunk.value(decls))
289
298
  for arg, tp in zip_longest(args, signature.arg_types, fillvalue=signature.var_arg_type)
290
299
  ]
291
300
  decls.update(*upcasted_args)
@@ -322,9 +331,10 @@ class RuntimeFunction(DelayedDeclerations):
322
331
  # If there is not return type, we are mutating the first arg
323
332
  if not signature.return_type:
324
333
  first_arg = upcasted_args[0]
325
- first_arg.__egg_thunk__ = Thunk.value((decls, typed_expr_decl))
334
+ first_arg.__egg_decls_thunk__ = Thunk.value(decls)
335
+ first_arg.__egg_typed_expr_thunk__ = Thunk.value(typed_expr_decl)
326
336
  return None
327
- return RuntimeExpr.__from_value__(decls, typed_expr_decl)
337
+ return RuntimeExpr.__from_values__(decls, typed_expr_decl)
328
338
 
329
339
  def __str__(self) -> str:
330
340
  first_arg, bound_tp_params = None, None
@@ -349,7 +359,9 @@ def to_py_signature(sig: FunctionSignature, decls: Declarations, optional_args:
349
359
  Parameter(
350
360
  n,
351
361
  Parameter.POSITIONAL_OR_KEYWORD,
352
- default=RuntimeExpr.__from_value__(decls, TypedExprDecl(t.to_just(), d if d is not None else VarDecl(n)))
362
+ default=RuntimeExpr.__from_values__(
363
+ decls, TypedExprDecl(t.to_just(), d if d is not None else VarDecl(n, True))
364
+ )
353
365
  if d is not None or optional_args
354
366
  else Parameter.empty,
355
367
  )
@@ -387,21 +399,16 @@ PARTIAL_METHODS = {
387
399
 
388
400
 
389
401
  @dataclass
390
- class RuntimeExpr:
391
- # Defer needing decls/expr so we can make constants that don't resolve their class types
392
- __egg_thunk__: Callable[[], tuple[Declarations, TypedExprDecl]]
402
+ class RuntimeExpr(DelayedDeclerations):
403
+ __egg_typed_expr_thunk__: Callable[[], TypedExprDecl]
393
404
 
394
405
  @classmethod
395
- def __from_value__(cls, d: Declarations, e: TypedExprDecl) -> RuntimeExpr:
396
- return cls(Thunk.value((d, e)))
397
-
398
- @property
399
- def __egg_decls__(self) -> Declarations:
400
- return self.__egg_thunk__()[0]
406
+ def __from_values__(cls, d: Declarations, e: TypedExprDecl) -> RuntimeExpr:
407
+ return cls(Thunk.value(d), Thunk.value(e))
401
408
 
402
409
  @property
403
410
  def __egg_typed_expr__(self) -> TypedExprDecl:
404
- return self.__egg_thunk__()[1]
411
+ return self.__egg_typed_expr_thunk__()
405
412
 
406
413
  def __getattr__(self, name: str) -> RuntimeFunction | RuntimeExpr | Callable | None:
407
414
  cls_name = self.__egg_class_name__
@@ -411,9 +418,9 @@ class RuntimeExpr:
411
418
  return preserved_methods[name].__get__(self)
412
419
 
413
420
  if name in class_decl.methods:
414
- return RuntimeFunction(Thunk.value(self.__egg_decls__), MethodRef(cls_name, name), self)
421
+ return RuntimeFunction(self.__egg_decls_thunk__, Thunk.value(MethodRef(cls_name, name)), self)
415
422
  if name in class_decl.properties:
416
- return RuntimeFunction(Thunk.value(self.__egg_decls__), PropertyRef(cls_name, name), self)()
423
+ return RuntimeFunction(self.__egg_decls_thunk__, Thunk.value(PropertyRef(cls_name, name)), self)()
417
424
  raise AttributeError(f"{cls_name} has no method {name}") from None
418
425
 
419
426
  def __repr__(self) -> str:
@@ -456,10 +463,11 @@ class RuntimeExpr:
456
463
  # otherwise copy will try to call `__getstate__` before object is initialized with properties which will cause inifinite recursion
457
464
 
458
465
  def __getstate__(self) -> tuple[Declarations, TypedExprDecl]:
459
- return self.__egg_thunk__()
466
+ return self.__egg_decls__, self.__egg_typed_expr__
460
467
 
461
468
  def __setstate__(self, d: tuple[Declarations, TypedExprDecl]) -> None:
462
- self.__egg_thunk__ = Thunk.value(d)
469
+ self.__egg_decls_thunk__ = Thunk.value(d[0])
470
+ self.__egg_typed_expr_thunk__ = Thunk.value(d[1])
463
471
 
464
472
  def __hash__(self) -> int:
465
473
  return hash(self.__egg_typed_expr__)
@@ -497,7 +505,7 @@ for name in list(BINARY_METHODS) + list(UNARY_METHODS) + ["__getitem__", "__call
497
505
  # we use the standard process
498
506
  pass
499
507
  if __name in class_decl.methods:
500
- fn = RuntimeFunction(Thunk.value(self.__egg_decls__), MethodRef(class_name, __name), self)
508
+ fn = RuntimeFunction(self.__egg_decls_thunk__, Thunk.value(MethodRef(class_name, __name)), self)
501
509
  return fn(*args, **kwargs) # type: ignore[arg-type]
502
510
  if __name in PARTIAL_METHODS:
503
511
  return NotImplemented
@@ -523,7 +531,7 @@ def call_method_min_conversion(slf: object, other: object, name: str) -> Runtime
523
531
  min_tp = min_convertable_tp(slf, other, name)
524
532
  slf = resolve_literal(TypeRefWithVars(min_tp), slf)
525
533
  other = resolve_literal(TypeRefWithVars(min_tp), other)
526
- method = RuntimeFunction(Thunk.value(slf.__egg_decls__), MethodRef(slf.__egg_class_name__, name), slf)
534
+ method = RuntimeFunction(slf.__egg_decls_thunk__, Thunk.value(MethodRef(slf.__egg_class_name__, name)), slf)
527
535
  return method(other)
528
536
 
529
537
 
@@ -545,7 +553,12 @@ def resolve_callable(callable: object) -> tuple[CallableRef, Declarations]:
545
553
  """
546
554
  match callable:
547
555
  case RuntimeFunction(decls, ref, _):
548
- return ref, decls()
556
+ return ref(), decls()
549
557
  case RuntimeClass(thunk, tp):
550
- return ClassMethodRef(tp.name, "__init__"), thunk()
551
- raise NotImplementedError(f"Cannot turn {callable} into a callable ref")
558
+ return InitRef(tp.name), thunk()
559
+ case RuntimeExpr(decl_thunk, expr_thunk):
560
+ if not isinstance((expr := expr_thunk().expr), CallDecl) or not isinstance(expr.callable, ConstantRef):
561
+ raise NotImplementedError(f"Can only turn constants into callable refs, not {expr}")
562
+ return expr.callable, decl_thunk()
563
+ case _:
564
+ raise NotImplementedError(f"Cannot turn {callable} of type {type(callable)} into a callable ref")
egglog/thunk.py CHANGED
@@ -9,10 +9,27 @@ if TYPE_CHECKING:
9
9
  from collections.abc import Callable
10
10
 
11
11
 
12
- __all__ = ["Thunk"]
12
+ __all__ = ["Thunk", "split_thunk"]
13
13
 
14
14
  T = TypeVar("T")
15
15
  TS = TypeVarTuple("TS")
16
+ V = TypeVar("V")
17
+
18
+
19
+ def split_thunk(fn: Callable[[], tuple[T, V]]) -> tuple[Callable[[], T], Callable[[], V]]:
20
+ s = _Split(fn)
21
+ return s.left, s.right
22
+
23
+
24
+ @dataclass
25
+ class _Split(Generic[T, V]):
26
+ fn: Callable[[], tuple[T, V]]
27
+
28
+ def left(self) -> T:
29
+ return self.fn()[0]
30
+
31
+ def right(self) -> V:
32
+ return self.fn()[1]
16
33
 
17
34
 
18
35
  @dataclass
@@ -21,18 +38,16 @@ class Thunk(Generic[T, Unpack[TS]]):
21
38
  Cached delayed function call.
22
39
  """
23
40
 
24
- state: Resolved[T] | Unresolved[T, Unpack[TS]] | Resolving[T] | Error
41
+ state: Resolved[T] | Unresolved[T, Unpack[TS]] | Resolving | Error
25
42
 
26
43
  @classmethod
27
- def fn(
28
- cls, fn: Callable[[Unpack[TS]], T], *args: Unpack[TS], fallback: Callable[[], T] | None = None
29
- ) -> Thunk[T, Unpack[TS]]:
44
+ def fn(cls, fn: Callable[[Unpack[TS]], T], *args: Unpack[TS]) -> Thunk[T, Unpack[TS]]:
30
45
  """
31
46
  Create a thunk based on some functions and some partial args.
32
47
 
33
- If the function is called while it is being resolved recursively, will instead return the fallback, if provided.
48
+ If the function is called while it is being resolved recursively it will raise an exception.
34
49
  """
35
- return cls(Unresolved(fn, args, fallback))
50
+ return cls(Unresolved(fn, args))
36
51
 
37
52
  @classmethod
38
53
  def value(cls, value: T) -> Thunk[T]:
@@ -42,21 +57,19 @@ class Thunk(Generic[T, Unpack[TS]]):
42
57
  match self.state:
43
58
  case Resolved(value):
44
59
  return value
45
- case Unresolved(fn, args, fallback):
46
- self.state = Resolving(fallback)
60
+ case Unresolved(fn, args):
61
+ self.state = Resolving()
47
62
  try:
48
63
  res = fn(*args)
49
64
  except Exception as e:
50
65
  self.state = Error(e)
51
- raise
66
+ raise e from None
52
67
  else:
53
68
  self.state = Resolved(res)
54
69
  return res
55
- case Resolving(fallback):
56
- if fallback is None:
57
- msg = "Recursively resolving thunk without fallback"
58
- raise ValueError(msg)
59
- return fallback()
70
+ case Resolving():
71
+ msg = "Recursively resolving thunk"
72
+ raise ValueError(msg)
60
73
  case Error(e):
61
74
  raise e
62
75
 
@@ -70,12 +83,11 @@ class Resolved(Generic[T]):
70
83
  class Unresolved(Generic[T, Unpack[TS]]):
71
84
  fn: Callable[[Unpack[TS]], T]
72
85
  args: tuple[Unpack[TS]]
73
- fallback: Callable[[], T] | None
74
86
 
75
87
 
76
88
  @dataclass
77
- class Resolving(Generic[T]):
78
- fallback: Callable[[], T] | None
89
+ class Resolving:
90
+ pass
79
91
 
80
92
 
81
93
  @dataclass