egglog-7.1.0-cp312-none-win_amd64.whl → egglog-8.0.0-cp312-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of egglog might be problematic; see the package's registry page for more details.

egglog/runtime.py CHANGED
@@ -127,8 +127,8 @@ class RuntimeClass(DelayedDeclerations):
127
127
  # If this is a literal type, initializing it with a literal should return a literal
128
128
  if (name := self.__egg_tp__.name) == "PyObject":
129
129
  assert len(args) == 1
130
- return RuntimeExpr.__from_value__(
131
- self.__egg_decls__, TypedExprDecl(self.__egg_tp__.to_just(), PyObjectDecl(args[0]))
130
+ return RuntimeExpr(
131
+ self.__egg_decls_thunk__, Thunk.value(TypedExprDecl(self.__egg_tp__.to_just(), PyObjectDecl(args[0])))
132
132
  )
133
133
  if name == "UnstableFn":
134
134
  assert not kwargs
@@ -150,23 +150,21 @@ class RuntimeClass(DelayedDeclerations):
150
150
  value = PartialCallDecl(replace(call, args=call.args[:n_args]))
151
151
  remaining_arg_types = [a.tp for a in call.args[n_args:]]
152
152
  type_ref = JustTypeRef("UnstableFn", (return_tp, *remaining_arg_types))
153
- return RuntimeExpr.__from_value__(Declarations.create(self, res), TypedExprDecl(type_ref, value))
153
+ return RuntimeExpr.__from_values__(Declarations.create(self, res), TypedExprDecl(type_ref, value))
154
154
 
155
155
  if name in UNARY_LIT_CLASS_NAMES:
156
156
  assert len(args) == 1
157
157
  assert isinstance(args[0], int | float | str | bool)
158
- return RuntimeExpr.__from_value__(
159
- self.__egg_decls__, TypedExprDecl(self.__egg_tp__.to_just(), LitDecl(args[0]))
158
+ return RuntimeExpr(
159
+ self.__egg_decls_thunk__, Thunk.value(TypedExprDecl(self.__egg_tp__.to_just(), LitDecl(args[0])))
160
160
  )
161
161
  if name == UNIT_CLASS_NAME:
162
162
  assert len(args) == 0
163
- return RuntimeExpr.__from_value__(
164
- self.__egg_decls__, TypedExprDecl(self.__egg_tp__.to_just(), LitDecl(None))
163
+ return RuntimeExpr(
164
+ self.__egg_decls_thunk__, Thunk.value(TypedExprDecl(self.__egg_tp__.to_just(), LitDecl(None)))
165
165
  )
166
-
167
- return RuntimeFunction(
168
- Thunk.value(self.__egg_decls__), ClassMethodRef(name, "__init__"), self.__egg_tp__.to_just()
169
- )(*args, **kwargs) # type: ignore[arg-type]
166
+ fn = RuntimeFunction(self.__egg_decls_thunk__, Thunk.value(InitRef(name)), self.__egg_tp__.to_just())
167
+ return fn(*args, **kwargs) # type: ignore[arg-type]
170
168
 
171
169
  def __dir__(self) -> list[str]:
172
170
  cls_decl = self.__egg_decls__.get_class_decl(self.__egg_tp__.name)
@@ -210,19 +208,21 @@ class RuntimeClass(DelayedDeclerations):
210
208
  # if this is a class variable, return an expr for it, otherwise, assume it's a method
211
209
  if name in cls_decl.class_variables:
212
210
  return_tp = cls_decl.class_variables[name]
213
- return RuntimeExpr.__from_value__(
214
- self.__egg_decls__,
215
- TypedExprDecl(return_tp.type_ref, CallDecl(ClassVariableRef(self.__egg_tp__.name, name))),
211
+ return RuntimeExpr(
212
+ self.__egg_decls_thunk__,
213
+ Thunk.value(TypedExprDecl(return_tp.type_ref, CallDecl(ClassVariableRef(self.__egg_tp__.name, name)))),
216
214
  )
217
215
  if name in cls_decl.class_methods:
218
216
  return RuntimeFunction(
219
- Thunk.value(self.__egg_decls__), ClassMethodRef(self.__egg_tp__.name, name), self.__egg_tp__.to_just()
217
+ self.__egg_decls_thunk__,
218
+ Thunk.value(ClassMethodRef(self.__egg_tp__.name, name)),
219
+ self.__egg_tp__.to_just(),
220
220
  )
221
221
  # allow referencing properties and methods as class variables as well
222
222
  if name in cls_decl.properties:
223
- return RuntimeFunction(Thunk.value(self.__egg_decls__), PropertyRef(self.__egg_tp__.name, name))
223
+ return RuntimeFunction(self.__egg_decls_thunk__, Thunk.value(PropertyRef(self.__egg_tp__.name, name)))
224
224
  if name in cls_decl.methods:
225
- return RuntimeFunction(Thunk.value(self.__egg_decls__), MethodRef(self.__egg_tp__.name, name))
225
+ return RuntimeFunction(self.__egg_decls_thunk__, Thunk.value(MethodRef(self.__egg_tp__.name, name)))
226
226
 
227
227
  msg = f"Class {self.__egg_tp__.name} has no method {name}"
228
228
  if name == "__ne__":
@@ -243,16 +243,23 @@ class RuntimeClass(DelayedDeclerations):
243
243
 
244
244
  @dataclass
245
245
  class RuntimeFunction(DelayedDeclerations):
246
- __egg_ref__: CallableRef
246
+ __egg_ref_thunk__: Callable[[], CallableRef]
247
247
  # bound methods need to store RuntimeExpr not just TypedExprDecl, so they can mutate the expr if required on self
248
248
  __egg_bound__: JustTypeRef | RuntimeExpr | None = None
249
249
 
250
+ @property
251
+ def __egg_ref__(self) -> CallableRef:
252
+ return self.__egg_ref_thunk__()
253
+
250
254
  def __call__(self, *args: object, _egg_partial_function: bool = False, **kwargs: object) -> RuntimeExpr | None:
251
255
  from .conversion import resolve_literal
252
256
 
253
257
  if isinstance(self.__egg_bound__, RuntimeExpr):
254
258
  args = (self.__egg_bound__, *args)
255
- signature = self.__egg_decls__.get_callable_decl(self.__egg_ref__).to_function_decl().signature
259
+ try:
260
+ signature = self.__egg_decls__.get_callable_decl(self.__egg_ref__).to_function_decl().signature
261
+ except Exception as e:
262
+ raise TypeError(f"Failed to find callable {self}") from e
256
263
  decls = self.__egg_decls__.copy()
257
264
  # Special case function application bc we dont support variadic generics yet generally
258
265
  if signature == "fn-app":
@@ -277,14 +284,17 @@ class RuntimeFunction(DelayedDeclerations):
277
284
 
278
285
  # Turn all keyword args into positional args
279
286
  py_signature = to_py_signature(signature, self.__egg_decls__, _egg_partial_function)
280
- bound = py_signature.bind(*args, **kwargs)
287
+ try:
288
+ bound = py_signature.bind(*args, **kwargs)
289
+ except TypeError as err:
290
+ raise TypeError(f"Failed to call {self} with args {args} and kwargs {kwargs}") from err
281
291
  del kwargs
282
292
  bound.apply_defaults()
283
293
  assert not bound.kwargs
284
294
  args = bound.args
285
295
 
286
296
  upcasted_args = [
287
- resolve_literal(cast(TypeOrVarRef, tp), arg)
297
+ resolve_literal(cast(TypeOrVarRef, tp), arg, Thunk.value(decls))
288
298
  for arg, tp in zip_longest(args, signature.arg_types, fillvalue=signature.var_arg_type)
289
299
  ]
290
300
  decls.update(*upcasted_args)
@@ -310,7 +320,9 @@ class RuntimeFunction(DelayedDeclerations):
310
320
  return_tp = tcs.infer_return_type(
311
321
  signature.arg_types, signature.semantic_return_type, signature.var_arg_type, arg_types, cls_name
312
322
  )
313
- bound_params = cast(JustTypeRef, bound_tp).args if isinstance(self.__egg_ref__, ClassMethodRef) else None
323
+ bound_params = (
324
+ cast(JustTypeRef, bound_tp).args if isinstance(self.__egg_ref__, ClassMethodRef | InitRef) else None
325
+ )
314
326
  # If we were using unstable-app to call a funciton, add that function back as the first arg.
315
327
  if function_value:
316
328
  arg_exprs = (function_value, *arg_exprs)
@@ -319,9 +331,10 @@ class RuntimeFunction(DelayedDeclerations):
319
331
  # If there is not return type, we are mutating the first arg
320
332
  if not signature.return_type:
321
333
  first_arg = upcasted_args[0]
322
- first_arg.__egg_thunk__ = Thunk.value((decls, typed_expr_decl))
334
+ first_arg.__egg_decls_thunk__ = Thunk.value(decls)
335
+ first_arg.__egg_typed_expr_thunk__ = Thunk.value(typed_expr_decl)
323
336
  return None
324
- return RuntimeExpr.__from_value__(decls, typed_expr_decl)
337
+ return RuntimeExpr.__from_values__(decls, typed_expr_decl)
325
338
 
326
339
  def __str__(self) -> str:
327
340
  first_arg, bound_tp_params = None, None
@@ -346,7 +359,9 @@ def to_py_signature(sig: FunctionSignature, decls: Declarations, optional_args:
346
359
  Parameter(
347
360
  n,
348
361
  Parameter.POSITIONAL_OR_KEYWORD,
349
- default=RuntimeExpr.__from_value__(decls, TypedExprDecl(t.to_just(), d if d is not None else VarDecl(n)))
362
+ default=RuntimeExpr.__from_values__(
363
+ decls, TypedExprDecl(t.to_just(), d if d is not None else VarDecl(n, True))
364
+ )
350
365
  if d is not None or optional_args
351
366
  else Parameter.empty,
352
367
  )
@@ -384,21 +399,16 @@ PARTIAL_METHODS = {
384
399
 
385
400
 
386
401
  @dataclass
387
- class RuntimeExpr:
388
- # Defer needing decls/expr so we can make constants that don't resolve their class types
389
- __egg_thunk__: Callable[[], tuple[Declarations, TypedExprDecl]]
402
+ class RuntimeExpr(DelayedDeclerations):
403
+ __egg_typed_expr_thunk__: Callable[[], TypedExprDecl]
390
404
 
391
405
  @classmethod
392
- def __from_value__(cls, d: Declarations, e: TypedExprDecl) -> RuntimeExpr:
393
- return cls(Thunk.value((d, e)))
394
-
395
- @property
396
- def __egg_decls__(self) -> Declarations:
397
- return self.__egg_thunk__()[0]
406
+ def __from_values__(cls, d: Declarations, e: TypedExprDecl) -> RuntimeExpr:
407
+ return cls(Thunk.value(d), Thunk.value(e))
398
408
 
399
409
  @property
400
410
  def __egg_typed_expr__(self) -> TypedExprDecl:
401
- return self.__egg_thunk__()[1]
411
+ return self.__egg_typed_expr_thunk__()
402
412
 
403
413
  def __getattr__(self, name: str) -> RuntimeFunction | RuntimeExpr | Callable | None:
404
414
  cls_name = self.__egg_class_name__
@@ -408,9 +418,9 @@ class RuntimeExpr:
408
418
  return preserved_methods[name].__get__(self)
409
419
 
410
420
  if name in class_decl.methods:
411
- return RuntimeFunction(Thunk.value(self.__egg_decls__), MethodRef(cls_name, name), self)
421
+ return RuntimeFunction(self.__egg_decls_thunk__, Thunk.value(MethodRef(cls_name, name)), self)
412
422
  if name in class_decl.properties:
413
- return RuntimeFunction(Thunk.value(self.__egg_decls__), PropertyRef(cls_name, name), self)()
423
+ return RuntimeFunction(self.__egg_decls_thunk__, Thunk.value(PropertyRef(cls_name, name)), self)()
414
424
  raise AttributeError(f"{cls_name} has no method {name}") from None
415
425
 
416
426
  def __repr__(self) -> str:
@@ -453,10 +463,11 @@ class RuntimeExpr:
453
463
  # otherwise copy will try to call `__getstate__` before object is initialized with properties which will cause inifinite recursion
454
464
 
455
465
  def __getstate__(self) -> tuple[Declarations, TypedExprDecl]:
456
- return self.__egg_thunk__()
466
+ return self.__egg_decls__, self.__egg_typed_expr__
457
467
 
458
468
  def __setstate__(self, d: tuple[Declarations, TypedExprDecl]) -> None:
459
- self.__egg_thunk__ = Thunk.value(d)
469
+ self.__egg_decls_thunk__ = Thunk.value(d[0])
470
+ self.__egg_typed_expr_thunk__ = Thunk.value(d[1])
460
471
 
461
472
  def __hash__(self) -> int:
462
473
  return hash(self.__egg_typed_expr__)
@@ -494,7 +505,7 @@ for name in list(BINARY_METHODS) + list(UNARY_METHODS) + ["__getitem__", "__call
494
505
  # we use the standard process
495
506
  pass
496
507
  if __name in class_decl.methods:
497
- fn = RuntimeFunction(Thunk.value(self.__egg_decls__), MethodRef(class_name, __name), self)
508
+ fn = RuntimeFunction(self.__egg_decls_thunk__, Thunk.value(MethodRef(class_name, __name)), self)
498
509
  return fn(*args, **kwargs) # type: ignore[arg-type]
499
510
  if __name in PARTIAL_METHODS:
500
511
  return NotImplemented
@@ -520,7 +531,7 @@ def call_method_min_conversion(slf: object, other: object, name: str) -> Runtime
520
531
  min_tp = min_convertable_tp(slf, other, name)
521
532
  slf = resolve_literal(TypeRefWithVars(min_tp), slf)
522
533
  other = resolve_literal(TypeRefWithVars(min_tp), other)
523
- method = RuntimeFunction(Thunk.value(slf.__egg_decls__), MethodRef(slf.__egg_class_name__, name), slf)
534
+ method = RuntimeFunction(slf.__egg_decls_thunk__, Thunk.value(MethodRef(slf.__egg_class_name__, name)), slf)
524
535
  return method(other)
525
536
 
526
537
 
@@ -542,7 +553,12 @@ def resolve_callable(callable: object) -> tuple[CallableRef, Declarations]:
542
553
  """
543
554
  match callable:
544
555
  case RuntimeFunction(decls, ref, _):
545
- return ref, decls()
556
+ return ref(), decls()
546
557
  case RuntimeClass(thunk, tp):
547
- return ClassMethodRef(tp.name, "__init__"), thunk()
548
- raise NotImplementedError(f"Cannot turn {callable} into a callable ref")
558
+ return InitRef(tp.name), thunk()
559
+ case RuntimeExpr(decl_thunk, expr_thunk):
560
+ if not isinstance((expr := expr_thunk().expr), CallDecl) or not isinstance(expr.callable, ConstantRef):
561
+ raise NotImplementedError(f"Can only turn constants into callable refs, not {expr}")
562
+ return expr.callable, decl_thunk()
563
+ case _:
564
+ raise NotImplementedError(f"Cannot turn {callable} of type {type(callable)} into a callable ref")
egglog/thunk.py CHANGED
@@ -9,10 +9,27 @@ if TYPE_CHECKING:
9
9
  from collections.abc import Callable
10
10
 
11
11
 
12
- __all__ = ["Thunk"]
12
+ __all__ = ["Thunk", "split_thunk"]
13
13
 
14
14
  T = TypeVar("T")
15
15
  TS = TypeVarTuple("TS")
16
+ V = TypeVar("V")
17
+
18
+
19
+ def split_thunk(fn: Callable[[], tuple[T, V]]) -> tuple[Callable[[], T], Callable[[], V]]:
20
+ s = _Split(fn)
21
+ return s.left, s.right
22
+
23
+
24
+ @dataclass
25
+ class _Split(Generic[T, V]):
26
+ fn: Callable[[], tuple[T, V]]
27
+
28
+ def left(self) -> T:
29
+ return self.fn()[0]
30
+
31
+ def right(self) -> V:
32
+ return self.fn()[1]
16
33
 
17
34
 
18
35
  @dataclass
@@ -21,18 +38,16 @@ class Thunk(Generic[T, Unpack[TS]]):
21
38
  Cached delayed function call.
22
39
  """
23
40
 
24
- state: Resolved[T] | Unresolved[T, Unpack[TS]] | Resolving[T]
41
+ state: Resolved[T] | Unresolved[T, Unpack[TS]] | Resolving | Error
25
42
 
26
43
  @classmethod
27
- def fn(
28
- cls, fn: Callable[[Unpack[TS]], T], *args: Unpack[TS], fallback: Callable[[], T] | None = None
29
- ) -> Thunk[T, Unpack[TS]]:
44
+ def fn(cls, fn: Callable[[Unpack[TS]], T], *args: Unpack[TS]) -> Thunk[T, Unpack[TS]]:
30
45
  """
31
46
  Create a thunk based on some functions and some partial args.
32
47
 
33
- If the function is called while it is being resolved recursively, will instead return the fallback, if provided.
48
+ If the function is called while it is being resolved recursively it will raise an exception.
34
49
  """
35
- return cls(Unresolved(fn, args, fallback))
50
+ return cls(Unresolved(fn, args))
36
51
 
37
52
  @classmethod
38
53
  def value(cls, value: T) -> Thunk[T]:
@@ -42,16 +57,21 @@ class Thunk(Generic[T, Unpack[TS]]):
42
57
  match self.state:
43
58
  case Resolved(value):
44
59
  return value
45
- case Unresolved(fn, args, fallback):
46
- self.state = Resolving(fallback)
47
- res = fn(*args)
48
- self.state = Resolved(res)
49
- return res
50
- case Resolving(fallback):
51
- if fallback is None:
52
- msg = "Recursively resolving thunk without fallback"
53
- raise ValueError(msg)
54
- return fallback()
60
+ case Unresolved(fn, args):
61
+ self.state = Resolving()
62
+ try:
63
+ res = fn(*args)
64
+ except Exception as e:
65
+ self.state = Error(e)
66
+ raise e from None
67
+ else:
68
+ self.state = Resolved(res)
69
+ return res
70
+ case Resolving():
71
+ msg = "Recursively resolving thunk"
72
+ raise ValueError(msg)
73
+ case Error(e):
74
+ raise e
55
75
 
56
76
 
57
77
  @dataclass
@@ -63,9 +83,13 @@ class Resolved(Generic[T]):
63
83
  class Unresolved(Generic[T, Unpack[TS]]):
64
84
  fn: Callable[[Unpack[TS]], T]
65
85
  args: tuple[Unpack[TS]]
66
- fallback: Callable[[], T] | None
67
86
 
68
87
 
69
88
  @dataclass
70
- class Resolving(Generic[T]):
71
- fallback: Callable[[], T] | None
89
+ class Resolving:
90
+ pass
91
+
92
+
93
+ @dataclass
94
+ class Error:
95
+ e: Exception
@@ -79,9 +79,10 @@ class TypeConstraintSolver:
79
79
  Also returns the bound type params if the class name is passed in.
80
80
  """
81
81
  self._infer_typevars(fn_return, return_, cls_name)
82
- arg_types = (
83
- self._subtitute_typevars(a, cls_name) for a in chain(fn_args, repeat(fn_var_args) if fn_var_args else [])
84
- )
82
+ arg_types: Iterable[JustTypeRef] = [self._subtitute_typevars(a, cls_name) for a in fn_args]
83
+ if fn_var_args:
84
+ # Need to be generator so it can be infinite for variable args
85
+ arg_types = chain(arg_types, repeat(self._subtitute_typevars(fn_var_args, cls_name)))
85
86
  bound_typevars = (
86
87
  tuple(
87
88
  v
@@ -132,8 +133,8 @@ class TypeConstraintSolver:
132
133
  def _subtitute_typevars(self, tp: TypeOrVarRef, cls_name: str | None) -> JustTypeRef:
133
134
  match tp:
134
135
  case ClassTypeVarRef(name):
136
+ assert cls_name is not None
135
137
  try:
136
- assert cls_name is not None
137
138
  return self._cls_typevar_index_to_type[cls_name][name]
138
139
  except KeyError as e:
139
140
  raise TypeConstraintError(f"Not enough bound typevars for {tp}") from e