egglog 7.2.0__cp310-none-win_amd64.whl → 8.0.0__cp310-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of egglog might be problematic; see the registry's advisory page for this release for more details.

egglog/runtime.py CHANGED
@@ -127,8 +127,8 @@ class RuntimeClass(DelayedDeclerations):
127
127
  # If this is a literal type, initializing it with a literal should return a literal
128
128
  if (name := self.__egg_tp__.name) == "PyObject":
129
129
  assert len(args) == 1
130
- return RuntimeExpr.__from_value__(
131
- self.__egg_decls__, TypedExprDecl(self.__egg_tp__.to_just(), PyObjectDecl(args[0]))
130
+ return RuntimeExpr(
131
+ self.__egg_decls_thunk__, Thunk.value(TypedExprDecl(self.__egg_tp__.to_just(), PyObjectDecl(args[0])))
132
132
  )
133
133
  if name == "UnstableFn":
134
134
  assert not kwargs
@@ -150,20 +150,20 @@ class RuntimeClass(DelayedDeclerations):
150
150
  value = PartialCallDecl(replace(call, args=call.args[:n_args]))
151
151
  remaining_arg_types = [a.tp for a in call.args[n_args:]]
152
152
  type_ref = JustTypeRef("UnstableFn", (return_tp, *remaining_arg_types))
153
- return RuntimeExpr.__from_value__(Declarations.create(self, res), TypedExprDecl(type_ref, value))
153
+ return RuntimeExpr.__from_values__(Declarations.create(self, res), TypedExprDecl(type_ref, value))
154
154
 
155
155
  if name in UNARY_LIT_CLASS_NAMES:
156
156
  assert len(args) == 1
157
157
  assert isinstance(args[0], int | float | str | bool)
158
- return RuntimeExpr.__from_value__(
159
- self.__egg_decls__, TypedExprDecl(self.__egg_tp__.to_just(), LitDecl(args[0]))
158
+ return RuntimeExpr(
159
+ self.__egg_decls_thunk__, Thunk.value(TypedExprDecl(self.__egg_tp__.to_just(), LitDecl(args[0])))
160
160
  )
161
161
  if name == UNIT_CLASS_NAME:
162
162
  assert len(args) == 0
163
- return RuntimeExpr.__from_value__(
164
- self.__egg_decls__, TypedExprDecl(self.__egg_tp__.to_just(), LitDecl(None))
163
+ return RuntimeExpr(
164
+ self.__egg_decls_thunk__, Thunk.value(TypedExprDecl(self.__egg_tp__.to_just(), LitDecl(None)))
165
165
  )
166
- fn = RuntimeFunction(Thunk.value(self.__egg_decls__), InitRef(name), self.__egg_tp__.to_just())
166
+ fn = RuntimeFunction(self.__egg_decls_thunk__, Thunk.value(InitRef(name)), self.__egg_tp__.to_just())
167
167
  return fn(*args, **kwargs) # type: ignore[arg-type]
168
168
 
169
169
  def __dir__(self) -> list[str]:
@@ -208,19 +208,21 @@ class RuntimeClass(DelayedDeclerations):
208
208
  # if this is a class variable, return an expr for it, otherwise, assume it's a method
209
209
  if name in cls_decl.class_variables:
210
210
  return_tp = cls_decl.class_variables[name]
211
- return RuntimeExpr.__from_value__(
212
- self.__egg_decls__,
213
- TypedExprDecl(return_tp.type_ref, CallDecl(ClassVariableRef(self.__egg_tp__.name, name))),
211
+ return RuntimeExpr(
212
+ self.__egg_decls_thunk__,
213
+ Thunk.value(TypedExprDecl(return_tp.type_ref, CallDecl(ClassVariableRef(self.__egg_tp__.name, name)))),
214
214
  )
215
215
  if name in cls_decl.class_methods:
216
216
  return RuntimeFunction(
217
- Thunk.value(self.__egg_decls__), ClassMethodRef(self.__egg_tp__.name, name), self.__egg_tp__.to_just()
217
+ self.__egg_decls_thunk__,
218
+ Thunk.value(ClassMethodRef(self.__egg_tp__.name, name)),
219
+ self.__egg_tp__.to_just(),
218
220
  )
219
221
  # allow referencing properties and methods as class variables as well
220
222
  if name in cls_decl.properties:
221
- return RuntimeFunction(Thunk.value(self.__egg_decls__), PropertyRef(self.__egg_tp__.name, name))
223
+ return RuntimeFunction(self.__egg_decls_thunk__, Thunk.value(PropertyRef(self.__egg_tp__.name, name)))
222
224
  if name in cls_decl.methods:
223
- return RuntimeFunction(Thunk.value(self.__egg_decls__), MethodRef(self.__egg_tp__.name, name))
225
+ return RuntimeFunction(self.__egg_decls_thunk__, Thunk.value(MethodRef(self.__egg_tp__.name, name)))
224
226
 
225
227
  msg = f"Class {self.__egg_tp__.name} has no method {name}"
226
228
  if name == "__ne__":
@@ -241,16 +243,23 @@ class RuntimeClass(DelayedDeclerations):
241
243
 
242
244
  @dataclass
243
245
  class RuntimeFunction(DelayedDeclerations):
244
- __egg_ref__: CallableRef
246
+ __egg_ref_thunk__: Callable[[], CallableRef]
245
247
  # bound methods need to store RuntimeExpr not just TypedExprDecl, so they can mutate the expr if required on self
246
248
  __egg_bound__: JustTypeRef | RuntimeExpr | None = None
247
249
 
250
+ @property
251
+ def __egg_ref__(self) -> CallableRef:
252
+ return self.__egg_ref_thunk__()
253
+
248
254
  def __call__(self, *args: object, _egg_partial_function: bool = False, **kwargs: object) -> RuntimeExpr | None:
249
255
  from .conversion import resolve_literal
250
256
 
251
257
  if isinstance(self.__egg_bound__, RuntimeExpr):
252
258
  args = (self.__egg_bound__, *args)
253
- signature = self.__egg_decls__.get_callable_decl(self.__egg_ref__).to_function_decl().signature
259
+ try:
260
+ signature = self.__egg_decls__.get_callable_decl(self.__egg_ref__).to_function_decl().signature
261
+ except Exception as e:
262
+ raise TypeError(f"Failed to find callable {self}") from e
254
263
  decls = self.__egg_decls__.copy()
255
264
  # Special case function application bc we dont support variadic generics yet generally
256
265
  if signature == "fn-app":
@@ -285,7 +294,7 @@ class RuntimeFunction(DelayedDeclerations):
285
294
  args = bound.args
286
295
 
287
296
  upcasted_args = [
288
- resolve_literal(cast(TypeOrVarRef, tp), arg)
297
+ resolve_literal(cast(TypeOrVarRef, tp), arg, Thunk.value(decls))
289
298
  for arg, tp in zip_longest(args, signature.arg_types, fillvalue=signature.var_arg_type)
290
299
  ]
291
300
  decls.update(*upcasted_args)
@@ -322,9 +331,10 @@ class RuntimeFunction(DelayedDeclerations):
322
331
  # If there is not return type, we are mutating the first arg
323
332
  if not signature.return_type:
324
333
  first_arg = upcasted_args[0]
325
- first_arg.__egg_thunk__ = Thunk.value((decls, typed_expr_decl))
334
+ first_arg.__egg_decls_thunk__ = Thunk.value(decls)
335
+ first_arg.__egg_typed_expr_thunk__ = Thunk.value(typed_expr_decl)
326
336
  return None
327
- return RuntimeExpr.__from_value__(decls, typed_expr_decl)
337
+ return RuntimeExpr.__from_values__(decls, typed_expr_decl)
328
338
 
329
339
  def __str__(self) -> str:
330
340
  first_arg, bound_tp_params = None, None
@@ -349,7 +359,9 @@ def to_py_signature(sig: FunctionSignature, decls: Declarations, optional_args:
349
359
  Parameter(
350
360
  n,
351
361
  Parameter.POSITIONAL_OR_KEYWORD,
352
- default=RuntimeExpr.__from_value__(decls, TypedExprDecl(t.to_just(), d if d is not None else VarDecl(n)))
362
+ default=RuntimeExpr.__from_values__(
363
+ decls, TypedExprDecl(t.to_just(), d if d is not None else VarDecl(n, True))
364
+ )
353
365
  if d is not None or optional_args
354
366
  else Parameter.empty,
355
367
  )
@@ -387,21 +399,16 @@ PARTIAL_METHODS = {
387
399
 
388
400
 
389
401
  @dataclass
390
- class RuntimeExpr:
391
- # Defer needing decls/expr so we can make constants that don't resolve their class types
392
- __egg_thunk__: Callable[[], tuple[Declarations, TypedExprDecl]]
402
+ class RuntimeExpr(DelayedDeclerations):
403
+ __egg_typed_expr_thunk__: Callable[[], TypedExprDecl]
393
404
 
394
405
  @classmethod
395
- def __from_value__(cls, d: Declarations, e: TypedExprDecl) -> RuntimeExpr:
396
- return cls(Thunk.value((d, e)))
397
-
398
- @property
399
- def __egg_decls__(self) -> Declarations:
400
- return self.__egg_thunk__()[0]
406
+ def __from_values__(cls, d: Declarations, e: TypedExprDecl) -> RuntimeExpr:
407
+ return cls(Thunk.value(d), Thunk.value(e))
401
408
 
402
409
  @property
403
410
  def __egg_typed_expr__(self) -> TypedExprDecl:
404
- return self.__egg_thunk__()[1]
411
+ return self.__egg_typed_expr_thunk__()
405
412
 
406
413
  def __getattr__(self, name: str) -> RuntimeFunction | RuntimeExpr | Callable | None:
407
414
  cls_name = self.__egg_class_name__
@@ -411,9 +418,9 @@ class RuntimeExpr:
411
418
  return preserved_methods[name].__get__(self)
412
419
 
413
420
  if name in class_decl.methods:
414
- return RuntimeFunction(Thunk.value(self.__egg_decls__), MethodRef(cls_name, name), self)
421
+ return RuntimeFunction(self.__egg_decls_thunk__, Thunk.value(MethodRef(cls_name, name)), self)
415
422
  if name in class_decl.properties:
416
- return RuntimeFunction(Thunk.value(self.__egg_decls__), PropertyRef(cls_name, name), self)()
423
+ return RuntimeFunction(self.__egg_decls_thunk__, Thunk.value(PropertyRef(cls_name, name)), self)()
417
424
  raise AttributeError(f"{cls_name} has no method {name}") from None
418
425
 
419
426
  def __repr__(self) -> str:
@@ -456,10 +463,11 @@ class RuntimeExpr:
456
463
  # otherwise copy will try to call `__getstate__` before object is initialized with properties which will cause inifinite recursion
457
464
 
458
465
  def __getstate__(self) -> tuple[Declarations, TypedExprDecl]:
459
- return self.__egg_thunk__()
466
+ return self.__egg_decls__, self.__egg_typed_expr__
460
467
 
461
468
  def __setstate__(self, d: tuple[Declarations, TypedExprDecl]) -> None:
462
- self.__egg_thunk__ = Thunk.value(d)
469
+ self.__egg_decls_thunk__ = Thunk.value(d[0])
470
+ self.__egg_typed_expr_thunk__ = Thunk.value(d[1])
463
471
 
464
472
  def __hash__(self) -> int:
465
473
  return hash(self.__egg_typed_expr__)
@@ -497,7 +505,7 @@ for name in list(BINARY_METHODS) + list(UNARY_METHODS) + ["__getitem__", "__call
497
505
  # we use the standard process
498
506
  pass
499
507
  if __name in class_decl.methods:
500
- fn = RuntimeFunction(Thunk.value(self.__egg_decls__), MethodRef(class_name, __name), self)
508
+ fn = RuntimeFunction(self.__egg_decls_thunk__, Thunk.value(MethodRef(class_name, __name)), self)
501
509
  return fn(*args, **kwargs) # type: ignore[arg-type]
502
510
  if __name in PARTIAL_METHODS:
503
511
  return NotImplemented
@@ -523,7 +531,7 @@ def call_method_min_conversion(slf: object, other: object, name: str) -> Runtime
523
531
  min_tp = min_convertable_tp(slf, other, name)
524
532
  slf = resolve_literal(TypeRefWithVars(min_tp), slf)
525
533
  other = resolve_literal(TypeRefWithVars(min_tp), other)
526
- method = RuntimeFunction(Thunk.value(slf.__egg_decls__), MethodRef(slf.__egg_class_name__, name), slf)
534
+ method = RuntimeFunction(slf.__egg_decls_thunk__, Thunk.value(MethodRef(slf.__egg_class_name__, name)), slf)
527
535
  return method(other)
528
536
 
529
537
 
@@ -545,7 +553,12 @@ def resolve_callable(callable: object) -> tuple[CallableRef, Declarations]:
545
553
  """
546
554
  match callable:
547
555
  case RuntimeFunction(decls, ref, _):
548
- return ref, decls()
556
+ return ref(), decls()
549
557
  case RuntimeClass(thunk, tp):
550
- return ClassMethodRef(tp.name, "__init__"), thunk()
551
- raise NotImplementedError(f"Cannot turn {callable} into a callable ref")
558
+ return InitRef(tp.name), thunk()
559
+ case RuntimeExpr(decl_thunk, expr_thunk):
560
+ if not isinstance((expr := expr_thunk().expr), CallDecl) or not isinstance(expr.callable, ConstantRef):
561
+ raise NotImplementedError(f"Can only turn constants into callable refs, not {expr}")
562
+ return expr.callable, decl_thunk()
563
+ case _:
564
+ raise NotImplementedError(f"Cannot turn {callable} of type {type(callable)} into a callable ref")
egglog/thunk.py CHANGED
@@ -9,10 +9,27 @@ if TYPE_CHECKING:
9
9
  from collections.abc import Callable
10
10
 
11
11
 
12
- __all__ = ["Thunk"]
12
+ __all__ = ["Thunk", "split_thunk"]
13
13
 
14
14
  T = TypeVar("T")
15
15
  TS = TypeVarTuple("TS")
16
+ V = TypeVar("V")
17
+
18
+
19
+ def split_thunk(fn: Callable[[], tuple[T, V]]) -> tuple[Callable[[], T], Callable[[], V]]:
20
+ s = _Split(fn)
21
+ return s.left, s.right
22
+
23
+
24
+ @dataclass
25
+ class _Split(Generic[T, V]):
26
+ fn: Callable[[], tuple[T, V]]
27
+
28
+ def left(self) -> T:
29
+ return self.fn()[0]
30
+
31
+ def right(self) -> V:
32
+ return self.fn()[1]
16
33
 
17
34
 
18
35
  @dataclass
@@ -21,18 +38,16 @@ class Thunk(Generic[T, Unpack[TS]]):
21
38
  Cached delayed function call.
22
39
  """
23
40
 
24
- state: Resolved[T] | Unresolved[T, Unpack[TS]] | Resolving[T] | Error
41
+ state: Resolved[T] | Unresolved[T, Unpack[TS]] | Resolving | Error
25
42
 
26
43
  @classmethod
27
- def fn(
28
- cls, fn: Callable[[Unpack[TS]], T], *args: Unpack[TS], fallback: Callable[[], T] | None = None
29
- ) -> Thunk[T, Unpack[TS]]:
44
+ def fn(cls, fn: Callable[[Unpack[TS]], T], *args: Unpack[TS]) -> Thunk[T, Unpack[TS]]:
30
45
  """
31
46
  Create a thunk based on some functions and some partial args.
32
47
 
33
- If the function is called while it is being resolved recursively, will instead return the fallback, if provided.
48
+ If the function is called while it is being resolved recursively it will raise an exception.
34
49
  """
35
- return cls(Unresolved(fn, args, fallback))
50
+ return cls(Unresolved(fn, args))
36
51
 
37
52
  @classmethod
38
53
  def value(cls, value: T) -> Thunk[T]:
@@ -42,21 +57,19 @@ class Thunk(Generic[T, Unpack[TS]]):
42
57
  match self.state:
43
58
  case Resolved(value):
44
59
  return value
45
- case Unresolved(fn, args, fallback):
46
- self.state = Resolving(fallback)
60
+ case Unresolved(fn, args):
61
+ self.state = Resolving()
47
62
  try:
48
63
  res = fn(*args)
49
64
  except Exception as e:
50
65
  self.state = Error(e)
51
- raise
66
+ raise e from None
52
67
  else:
53
68
  self.state = Resolved(res)
54
69
  return res
55
- case Resolving(fallback):
56
- if fallback is None:
57
- msg = "Recursively resolving thunk without fallback"
58
- raise ValueError(msg)
59
- return fallback()
70
+ case Resolving():
71
+ msg = "Recursively resolving thunk"
72
+ raise ValueError(msg)
60
73
  case Error(e):
61
74
  raise e
62
75
 
@@ -70,12 +83,11 @@ class Resolved(Generic[T]):
70
83
  class Unresolved(Generic[T, Unpack[TS]]):
71
84
  fn: Callable[[Unpack[TS]], T]
72
85
  args: tuple[Unpack[TS]]
73
- fallback: Callable[[], T] | None
74
86
 
75
87
 
76
88
  @dataclass
77
- class Resolving(Generic[T]):
78
- fallback: Callable[[], T] | None
89
+ class Resolving:
90
+ pass
79
91
 
80
92
 
81
93
  @dataclass