egglog-10.0.1-cp312-cp312-win_amd64.whl → egglog-11.0.0-cp312-cp312-win_amd64.whl

This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.

Potentially problematic release.


This version of egglog has been flagged as potentially problematic. (The "Click here" link to further details was not preserved in this copy.)

egglog/runtime.py CHANGED
@@ -11,30 +11,32 @@ so they are not mangled by Python and can be accessed by the user.
11
11
 
12
12
  from __future__ import annotations
13
13
 
14
+ import itertools
14
15
  import operator
16
+ import types
15
17
  from collections.abc import Callable
16
- from dataclasses import dataclass, replace
18
+ from dataclasses import InitVar, dataclass, replace
17
19
  from inspect import Parameter, Signature
18
20
  from itertools import zip_longest
19
- from typing import TYPE_CHECKING, TypeVar, Union, cast, get_args, get_origin
21
+ from typing import TYPE_CHECKING, Any, TypeVar, Union, cast, get_args, get_origin
20
22
 
21
23
  from .declarations import *
22
24
  from .pretty import *
23
25
  from .thunk import Thunk
24
26
  from .type_constraint_solver import *
27
+ from .version_compat import *
25
28
 
26
29
  if TYPE_CHECKING:
27
30
  from collections.abc import Iterable
28
31
 
29
- from .egraph import Fact
30
-
31
32
 
32
33
  __all__ = [
33
34
  "LIT_CLASS_NAMES",
34
- "REFLECTED_BINARY_METHODS",
35
+ "NUMERIC_BINARY_METHODS",
35
36
  "RuntimeClass",
36
37
  "RuntimeExpr",
37
38
  "RuntimeFunction",
39
+ "define_expr_method",
38
40
  "resolve_callable",
39
41
  "resolve_type_annotation",
40
42
  "resolve_type_annotation_mutate",
@@ -45,24 +47,34 @@ UNIT_CLASS_NAME = "Unit"
45
47
  UNARY_LIT_CLASS_NAMES = {"i64", "f64", "Bool", "String"}
46
48
  LIT_CLASS_NAMES = UNARY_LIT_CLASS_NAMES | {UNIT_CLASS_NAME, "PyObject"}
47
49
 
48
- REFLECTED_BINARY_METHODS = {
49
- "__radd__": "__add__",
50
- "__rsub__": "__sub__",
51
- "__rmul__": "__mul__",
52
- "__rmatmul__": "__matmul__",
53
- "__rtruediv__": "__truediv__",
54
- "__rfloordiv__": "__floordiv__",
55
- "__rmod__": "__mod__",
56
- "__rpow__": "__pow__",
57
- "__rlshift__": "__lshift__",
58
- "__rrshift__": "__rshift__",
59
- "__rand__": "__and__",
60
- "__rxor__": "__xor__",
61
- "__ror__": "__or__",
50
+ # All methods which should return NotImplemented if they fail to resolve and are reflected as well
51
+ # From https://docs.python.org/3/reference/datamodel.html#emulating-numeric-types
52
+
53
+ NUMERIC_BINARY_METHODS = {
54
+ "__add__",
55
+ "__sub__",
56
+ "__mul__",
57
+ "__matmul__",
58
+ "__truediv__",
59
+ "__floordiv__",
60
+ "__mod__",
61
+ "__divmod__",
62
+ "__pow__",
63
+ "__lshift__",
64
+ "__rshift__",
65
+ "__and__",
66
+ "__xor__",
67
+ "__or__",
68
+ "__lt__",
69
+ "__le__",
70
+ "__gt__",
71
+ "__ge__",
62
72
  }
63
73
 
64
- # Methods that need to return real Python values not expressions
65
- PRESERVED_METHODS = [
74
+
75
+ # Methods that need to be defined on the runtime type that holds `Expr` objects, so that they can be used as methods.
76
+
77
+ TYPE_DEFINED_METHODS = {
66
78
  "__bool__",
67
79
  "__len__",
68
80
  "__complex__",
@@ -70,9 +82,15 @@ PRESERVED_METHODS = [
70
82
  "__float__",
71
83
  "__iter__",
72
84
  "__index__",
73
- "__float__",
74
- "__int__",
75
- ]
85
+ "__call__",
86
+ "__getitem__",
87
+ "__setitem__",
88
+ "__delitem__",
89
+ "__pos__",
90
+ "__neg__",
91
+ "__invert__",
92
+ "__round__",
93
+ }
76
94
 
77
95
  # Set this globally so we can get access to PyObject when we have a type annotation of just object.
78
96
  # This is the only time a type annotation doesn't need to include the egglog type b/c object is top so that would be redundant statically.
@@ -134,11 +152,48 @@ def inverse_resolve_type_annotation(decls_thunk: Callable[[], Declarations], tp:
134
152
  ##
135
153
 
136
154
 
137
- @dataclass
138
- class RuntimeClass(DelayedDeclerations):
155
+ class BaseClassFactoryMeta(type):
156
+ """
157
+ Base metaclass for all runtime classes created by ClassFactory
158
+ """
159
+
160
+ def __instancecheck__(cls, instance: object) -> bool:
161
+ assert isinstance(cls, RuntimeClass)
162
+ return isinstance(instance, RuntimeExpr) and cls.__egg_tp__.name == instance.__egg_typed_expr__.tp.name
163
+
164
+
165
+ class ClassFactory(type):
166
+ """
167
+ A metaclass for types which should create `type` objects when instantiated.
168
+
169
+ That's so that they work with `isinstance` and can be placed in `match ClassName()`.
170
+ """
171
+
172
+ def __call__(cls, *args, **kwargs) -> type:
173
+ # If we have params, don't inherit from `type` because we don't need to match against this and also
174
+ # this won't work with `Union[X]` because it won't look at `__parameters__` for instances of `type`.
175
+ if kwargs.pop("_egg_has_params", False):
176
+ return super().__call__(*args, **kwargs)
177
+ namespace: dict[str, Any] = {}
178
+ for m in reversed(cls.__mro__):
179
+ namespace.update(m.__dict__)
180
+ init = namespace.pop("__init__")
181
+ meta = types.new_class("type(RuntimeClass)", (BaseClassFactoryMeta,), {}, lambda ns: ns.update(**namespace))
182
+ tp = types.new_class("RuntimeClass", (), {"metaclass": meta})
183
+ init(tp, *args, **kwargs)
184
+ return tp
185
+
186
+ def __instancecheck__(cls, instance: object) -> bool:
187
+ return isinstance(instance, BaseClassFactoryMeta)
188
+
189
+
190
+ @dataclass(match_args=False)
191
+ class RuntimeClass(DelayedDeclerations, metaclass=ClassFactory):
139
192
  __egg_tp__: TypeRefWithVars
193
+ # True if we want `__parameters__` to be recognized by `Union`, which means we can't inherit from `type` directly.
194
+ _egg_has_params: InitVar[bool] = False
140
195
 
141
- def __post_init__(self) -> None:
196
+ def __post_init__(self, _egg_has_params: bool) -> None:
142
197
  global _PY_OBJECT_CLASS, _UNSTABLE_FN_CLASS
143
198
  if (name := self.__egg_tp__.name) == "PyObject":
144
199
  _PY_OBJECT_CLASS = self
@@ -228,7 +283,7 @@ class RuntimeClass(DelayedDeclerations):
228
283
  else:
229
284
  final_args = new_args
230
285
  tp = TypeRefWithVars(self.__egg_tp__.name, final_args)
231
- return RuntimeClass(Thunk.fn(Declarations.create, self, *decls_like), tp)
286
+ return RuntimeClass(Thunk.fn(Declarations.create, self, *decls_like), tp, _egg_has_params=True)
232
287
 
233
288
  def __getattr__(self, name: str) -> RuntimeFunction | RuntimeExpr | Callable:
234
289
  if name == "__origin__" and self.__egg_tp__.args:
@@ -249,8 +304,7 @@ class RuntimeClass(DelayedDeclerations):
249
304
  try:
250
305
  cls_decl = self.__egg_decls__._classes[self.__egg_tp__.name]
251
306
  except Exception as e:
252
- e.add_note(f"Error processing class {self.__egg_tp__.name}")
253
- raise
307
+ raise add_note(f"Error processing class {self.__egg_tp__.name}", e) from None
254
308
 
255
309
  preserved_methods = cls_decl.preserved_methods
256
310
  if name in preserved_methods:
@@ -281,10 +335,21 @@ class RuntimeClass(DelayedDeclerations):
281
335
  def __str__(self) -> str:
282
336
  return str(self.__egg_tp__)
283
337
 
338
+ def __repr__(self) -> str:
339
+ return str(self)
340
+
284
341
  # Make hashable so can go in Union
285
342
  def __hash__(self) -> int:
286
343
  return hash((id(self.__egg_decls_thunk__), self.__egg_tp__))
287
344
 
345
+ def __eq__(self, other: object) -> bool:
346
+ """
347
+ Support equality for runtime comparison of egglog classes.
348
+ """
349
+ if not isinstance(other, RuntimeClass):
350
+ return NotImplemented
351
+ return self.__egg_tp__ == other.__egg_tp__
352
+
288
353
  # Support unioning like types
289
354
  def __or__(self, value: type) -> object:
290
355
  return Union[self, value] # noqa: UP007
@@ -296,6 +361,10 @@ class RuntimeClass(DelayedDeclerations):
296
361
  """
297
362
  return tuple(inverse_resolve_type_annotation(self.__egg_decls_thunk__, tp) for tp in self.__egg_tp__.args)
298
363
 
364
+ @property
365
+ def __match_args__(self) -> tuple[str, ...]:
366
+ return self.__egg_decls__._classes[self.__egg_tp__.name].match_args
367
+
299
368
 
300
369
  @dataclass
301
370
  class RuntimeFunction(DelayedDeclerations):
@@ -303,20 +372,30 @@ class RuntimeFunction(DelayedDeclerations):
303
372
  # bound methods need to store RuntimeExpr not just TypedExprDecl, so they can mutate the expr if required on self
304
373
  __egg_bound__: JustTypeRef | RuntimeExpr | None = None
305
374
 
375
+ def __eq__(self, other: object) -> bool:
376
+ """
377
+ Support equality for runtime comparison of egglog functions.
378
+ """
379
+ if not isinstance(other, RuntimeFunction):
380
+ return NotImplemented
381
+ return self.__egg_ref__ == other.__egg_ref__ and bool(self.__egg_bound__ == other.__egg_bound__)
382
+
383
+ def __hash__(self) -> int:
384
+ return hash((self.__egg_ref__, self.__egg_bound__))
385
+
306
386
  @property
307
387
  def __egg_ref__(self) -> CallableRef:
308
388
  return self.__egg_ref_thunk__()
309
389
 
310
390
  def __call__(self, *args: object, _egg_partial_function: bool = False, **kwargs: object) -> RuntimeExpr | None:
311
- from .conversion import resolve_literal
391
+ from .conversion import resolve_literal # noqa: PLC0415
312
392
 
313
393
  if isinstance(self.__egg_bound__, RuntimeExpr):
314
394
  args = (self.__egg_bound__, *args)
315
395
  try:
316
396
  signature = self.__egg_decls__.get_callable_decl(self.__egg_ref__).signature
317
397
  except Exception as e:
318
- e.add_note(f"Failed to find callable {self}")
319
- raise
398
+ raise add_note(f"Failed to find callable {self}", e) # noqa: B904
320
399
  decls = self.__egg_decls__.copy()
321
400
  # Special case function application bc we dont support variadic generics yet generally
322
401
  if signature == "fn-app":
@@ -344,7 +423,7 @@ class RuntimeFunction(DelayedDeclerations):
344
423
  try:
345
424
  bound = py_signature.bind(*args, **kwargs)
346
425
  except TypeError as err:
347
- raise TypeError(f"Failed to call {self} with args {args} and kwargs {kwargs}") from err
426
+ raise TypeError(f"Failed to bind arguments for {self} with args {args} and kwargs {kwargs}: {err}") from err
348
427
  del kwargs
349
428
  bound.apply_defaults()
350
429
  assert not bound.kwargs
@@ -413,9 +492,7 @@ def to_py_signature(sig: FunctionSignature, decls: Declarations, optional_args:
413
492
  Parameter(
414
493
  n,
415
494
  Parameter.POSITIONAL_OR_KEYWORD,
416
- default=RuntimeExpr.__from_values__(
417
- decls, TypedExprDecl(t.to_just(), d if d is not None else VarDecl(n, True))
418
- )
495
+ default=RuntimeExpr.__from_values__(decls, TypedExprDecl(t.to_just(), d or LetRefDecl(n)))
419
496
  if d is not None or optional_args
420
497
  else Parameter.empty,
421
498
  )
@@ -426,32 +503,6 @@ def to_py_signature(sig: FunctionSignature, decls: Declarations, optional_args:
426
503
  return Signature(parameters)
427
504
 
428
505
 
429
- # All methods which should return NotImplemented if they fail to resolve
430
- # From https://docs.python.org/3/reference/datamodel.html
431
- PARTIAL_METHODS = {
432
- "__lt__",
433
- "__le__",
434
- "__eq__",
435
- "__ne__",
436
- "__gt__",
437
- "__ge__",
438
- "__add__",
439
- "__sub__",
440
- "__mul__",
441
- "__matmul__",
442
- "__truediv__",
443
- "__floordiv__",
444
- "__mod__",
445
- "__divmod__",
446
- "__pow__",
447
- "__lshift__",
448
- "__rshift__",
449
- "__and__",
450
- "__xor__",
451
- "__or__",
452
- }
453
-
454
-
455
506
  @dataclass
456
507
  class RuntimeExpr(DelayedDeclerations):
457
508
  __egg_typed_expr_thunk__: Callable[[], TypedExprDecl]
@@ -468,17 +519,14 @@ class RuntimeExpr(DelayedDeclerations):
468
519
  return self.__egg_typed_expr_thunk__()
469
520
 
470
521
  def __getattr__(self, name: str) -> RuntimeFunction | RuntimeExpr | Callable | None:
471
- cls_name = self.__egg_class_name__
472
- class_decl = self.__egg_class_decl__
473
-
474
- if name in (preserved_methods := class_decl.preserved_methods):
475
- return preserved_methods[name].__get__(self)
476
-
477
- if name in class_decl.methods:
478
- return RuntimeFunction(self.__egg_decls_thunk__, Thunk.value(MethodRef(cls_name, name)), self)
479
- if name in class_decl.properties:
480
- return RuntimeFunction(self.__egg_decls_thunk__, Thunk.value(PropertyRef(cls_name, name)), self)()
481
- raise AttributeError(f"{cls_name} has no method {name}") from None
522
+ if (method := _get_expr_method(self, name)) is not None:
523
+ return method
524
+ if name in self.__egg_class_decl__.properties:
525
+ fn = RuntimeFunction(
526
+ self.__egg_decls_thunk__, Thunk.value(PropertyRef(self.__egg_class_name__, name)), self
527
+ )
528
+ return fn()
529
+ raise AttributeError(f"{self.__egg_class_name__} has no method {name}") from None
482
530
 
483
531
  def __repr__(self) -> str:
484
532
  """
@@ -493,7 +541,7 @@ class RuntimeExpr(DelayedDeclerations):
493
541
  return pretty_decl(self.__egg_decls__, self.__egg_typed_expr__.expr, wrapping_fn=wrapping_fn)
494
542
 
495
543
  def _ipython_display_(self) -> None:
496
- from IPython.display import Code, display
544
+ from IPython.display import Code, display # noqa: PLC0415
497
545
 
498
546
  display(Code(str(self), language="python"))
499
547
 
@@ -509,13 +557,6 @@ class RuntimeExpr(DelayedDeclerations):
509
557
  def __egg_class_decl__(self) -> ClassDecl:
510
558
  return self.__egg_decls__.get_class_decl(self.__egg_class_name__)
511
559
 
512
- # These both will be overriden below in the special methods section, but add these here for type hinting purposes
513
- def __eq__(self, other: object) -> Fact: # type: ignore[override, empty-body]
514
- ...
515
-
516
- def __ne__(self, other: object) -> RuntimeExpr: # type: ignore[override, empty-body]
517
- ...
518
-
519
560
  # Implement these so that copy() works on this object
520
561
  # otherwise copy will try to call `__getstate__` before object is initialized with properties which will cause inifinite recursion
521
562
 
@@ -527,91 +568,102 @@ class RuntimeExpr(DelayedDeclerations):
527
568
  self.__egg_typed_expr_thunk__ = Thunk.value(d[1])
528
569
 
529
570
  def __hash__(self) -> int:
571
+ if (method := _get_expr_method(self, "__hash__")) is not None:
572
+ return cast("int", cast("Any", method()))
530
573
  return hash(self.__egg_typed_expr__)
531
574
 
575
+ # Implement this directly to special case behavior where it transforms to an egraph equality, if it is not a
576
+ # preserved method or defined on the class
577
+ def __eq__(self, other: object) -> object: # type: ignore[override]
578
+ if (method := _get_expr_method(self, "__eq__")) is not None:
579
+ return method(other)
532
580
 
533
- # Define each of the special methods, since we have already declared them for pretty printing
534
- for name in list(BINARY_METHODS) + list(UNARY_METHODS) + ["__getitem__", "__call__", "__setitem__", "__delitem__"]:
581
+ # TODO: Check if two objects can be upcasted to be the same. If not, then return NotImplemented so other
582
+ # expr gets a chance to resolve __eq__ which could be a preserved method.
583
+ from .egraph import BaseExpr, eq # noqa: PLC0415
535
584
 
536
- def _special_method(
537
- self: RuntimeExpr,
538
- *args: object,
539
- __name: str = name,
540
- **kwargs: object,
541
- ) -> RuntimeExpr | Fact | None:
542
- from .conversion import ConvertError
585
+ return eq(cast("BaseExpr", self)).to(cast("BaseExpr", other))
543
586
 
544
- class_name = self.__egg_class_name__
545
- class_decl = self.__egg_class_decl__
546
- # First, try to resolve as preserved method
547
- try:
548
- method = class_decl.preserved_methods[__name]
549
- except KeyError:
550
- pass
551
- else:
552
- return method(self, *args, **kwargs)
553
- # If this is a "partial" method meaning that it can return NotImplemented,
554
- # we want to find the "best" superparent (lowest cost) of the arg types to call with it, instead of just
555
- # using the arg type of the self arg.
556
- # This is neccesary so if we add like an int to a ndarray, it will upcast the int to an ndarray, instead of vice versa.
557
- if __name in PARTIAL_METHODS:
558
- try:
559
- return call_method_min_conversion(self, args[0], __name)
560
- except ConvertError:
561
- # Defer raising not imeplemented in case the dunder method is not symmetrical, then
562
- # we use the standard process
563
- pass
564
- if __name in class_decl.methods:
565
- fn = RuntimeFunction(self.__egg_decls_thunk__, Thunk.value(MethodRef(class_name, __name)), self)
566
- return fn(*args, **kwargs) # type: ignore[arg-type]
567
- # Handle == and != fallbacks to eq and ne helpers if the methods aren't defined on the class explicitly.
568
- if __name == "__eq__":
569
- from .egraph import BaseExpr, eq
570
-
571
- return eq(cast("BaseExpr", self)).to(cast("BaseExpr", args[0]))
572
- if __name == "__ne__":
573
- from .egraph import BaseExpr, ne
574
-
575
- return cast("RuntimeExpr", ne(cast("BaseExpr", self)).to(cast("BaseExpr", args[0])))
576
-
577
- if __name in PARTIAL_METHODS:
578
- return NotImplemented
579
- raise TypeError(f"{class_name!r} object does not support {__name}")
587
+ def __ne__(self, other: object) -> object: # type: ignore[override]
588
+ if (method := _get_expr_method(self, "__ne__")) is not None:
589
+ return method(other)
590
+
591
+ from .egraph import BaseExpr, ne # noqa: PLC0415
580
592
 
581
- setattr(RuntimeExpr, name, _special_method)
593
+ return ne(cast("BaseExpr", self)).to(cast("BaseExpr", other))
582
594
 
583
- # For each of the reflected binary methods, translate to the corresponding non-reflected method
584
- for reflected, non_reflected in REFLECTED_BINARY_METHODS.items():
595
+ def __call__(
596
+ self, *args: object, **kwargs: object
597
+ ) -> object: # define it here only for type checking, it will be overriden below
598
+ ...
585
599
 
586
- def _reflected_method(self: RuntimeExpr, other: object, __non_reflected: str = non_reflected) -> RuntimeExpr | None:
587
- # All binary methods are also "partial" meaning we should try to upcast first.
588
- return call_method_min_conversion(other, self, __non_reflected)
589
600
 
590
- setattr(RuntimeExpr, reflected, _reflected_method)
601
+ def _get_expr_method(expr: RuntimeExpr, name: str) -> RuntimeFunction | RuntimeExpr | Callable | None:
602
+ if name in (preserved_methods := expr.__egg_class_decl__.preserved_methods):
603
+ return preserved_methods[name].__get__(expr)
591
604
 
605
+ if name in expr.__egg_class_decl__.methods:
606
+ return RuntimeFunction(expr.__egg_decls_thunk__, Thunk.value(MethodRef(expr.__egg_class_name__, name)), expr)
607
+ return None
592
608
 
593
- def call_method_min_conversion(slf: object, other: object, name: str) -> RuntimeExpr | None:
594
- from .conversion import min_convertable_tp, resolve_literal
595
609
 
596
- # find a minimum type that both can be converted to
597
- # This is so so that calls like `-0.1 * Int("x")` work by upcasting both to floats.
598
- min_tp = min_convertable_tp(slf, other, name).to_var()
599
- slf = resolve_literal(min_tp, slf)
600
- other = resolve_literal(min_tp, other)
601
- method = RuntimeFunction(slf.__egg_decls_thunk__, Thunk.value(MethodRef(slf.__egg_class_name__, name)), slf)
602
- return method(other)
610
+ def define_expr_method(name: str) -> None:
611
+ """
612
+ Given the name of a method, explicitly defines it on the runtime type that holds `Expr` objects as a method.
603
613
 
614
+ Call this if you need a method to be defined on the type itself where overrindg with `__getattr__` does not suffice,
615
+ like for NumPy's `__array_ufunc__`.
616
+ """
604
617
 
605
- for name in PRESERVED_METHODS:
618
+ def _defined_method(self: RuntimeExpr, *args, __name: str = name, **kwargs):
619
+ fn = _get_expr_method(self, __name)
620
+ if fn is None:
621
+ raise TypeError(f"{self.__egg_class_name__} expression has no method {__name}")
622
+ return fn(*args, **kwargs)
606
623
 
607
- def _preserved_method(self: RuntimeExpr, __name: str = name):
608
- try:
609
- method = self.__egg_decls__.get_class_decl(self.__egg_typed_expr__.tp.name).preserved_methods[__name]
610
- except KeyError as e:
611
- raise TypeError(f"{self.__egg_typed_expr__.tp.name} has no method {__name}") from e
612
- return method(self)
624
+ setattr(RuntimeExpr, name, _defined_method)
625
+
626
+
627
+ for name in TYPE_DEFINED_METHODS:
628
+ define_expr_method(name)
629
+
630
+
631
+ for name, r_method in itertools.product(NUMERIC_BINARY_METHODS, (False, True)):
632
+
633
+ def _numeric_binary_method(self: object, other: object, name: str = name, r_method: bool = r_method) -> object:
634
+ """
635
+ Implements numeric binary operations.
636
+
637
+ Tries to find the minimum cost conversion of either the LHS or the RHS, by finding all methods with either
638
+ the LHS or the RHS as exactly the right type and then upcasting the other to that type.
639
+ """
640
+ # 1. switch if reversed method
641
+ if r_method:
642
+ self, other = other, self
643
+ # If the types don't exactly match to start, then we need to try converting one of them, by finding the cheapest conversion
644
+ if not (
645
+ isinstance(self, RuntimeExpr)
646
+ and isinstance(other, RuntimeExpr)
647
+ and (
648
+ self.__egg_decls__.check_binary_method_with_types(
649
+ name, self.__egg_typed_expr__.tp, other.__egg_typed_expr__.tp
650
+ )
651
+ )
652
+ ):
653
+ from .conversion import min_binary_conversion, resolve_type # noqa: PLC0415
654
+
655
+ best_method = min_binary_conversion(name, resolve_type(self), resolve_type(other))
656
+
657
+ if not best_method:
658
+ raise RuntimeError(f"Cannot resolve {name} for {self} and {other}, no conversion found")
659
+ self, other = best_method[0](self), best_method[1](other)
660
+
661
+ method_ref = MethodRef(self.__egg_class_name__, name)
662
+ fn = RuntimeFunction(Thunk.value(self.__egg_decls__), Thunk.value(method_ref), self)
663
+ return fn(other)
613
664
 
614
- setattr(RuntimeExpr, name, _preserved_method)
665
+ method_name = f"__r{name[2:]}" if r_method else name
666
+ setattr(RuntimeExpr, method_name, _numeric_binary_method)
615
667
 
616
668
 
617
669
  def resolve_callable(callable: object) -> tuple[CallableRef, Declarations]:
egglog/thunk.py CHANGED
@@ -41,13 +41,13 @@ class Thunk(Generic[T, Unpack[TS]]):
41
41
  state: Resolved[T] | Unresolved[T, Unpack[TS]] | Resolving | Error
42
42
 
43
43
  @classmethod
44
- def fn(cls, fn: Callable[[Unpack[TS]], T], *args: Unpack[TS]) -> Thunk[T, Unpack[TS]]:
44
+ def fn(cls, fn: Callable[[Unpack[TS]], T], *args: Unpack[TS], context: str | None = None) -> Thunk[T, Unpack[TS]]:
45
45
  """
46
46
  Create a thunk based on some functions and some partial args.
47
47
 
48
48
  If the function is called while it is being resolved recursively it will raise an exception.
49
49
  """
50
- return cls(Unresolved(fn, args))
50
+ return cls(Unresolved(fn, args, context))
51
51
 
52
52
  @classmethod
53
53
  def value(cls, value: T) -> Thunk[T]:
@@ -57,12 +57,12 @@ class Thunk(Generic[T, Unpack[TS]]):
57
57
  match self.state:
58
58
  case Resolved(value):
59
59
  return value
60
- case Unresolved(fn, args):
60
+ case Unresolved(fn, args, context):
61
61
  self.state = Resolving()
62
62
  try:
63
63
  res = fn(*args)
64
64
  except Exception as e:
65
- self.state = Error(e)
65
+ self.state = Error(e, context)
66
66
  raise e from None
67
67
  else:
68
68
  self.state = Resolved(res)
@@ -83,6 +83,7 @@ class Resolved(Generic[T]):
83
83
  class Unresolved(Generic[T, Unpack[TS]]):
84
84
  fn: Callable[[Unpack[TS]], T]
85
85
  args: tuple[Unpack[TS]]
86
+ context: str | None
86
87
 
87
88
 
88
89
  @dataclass
@@ -93,3 +94,4 @@ class Resolving:
93
94
  @dataclass
94
95
  class Error:
95
96
  e: Exception
97
+ context: str | None
egglog/type_constraint_solver.py CHANGED
@@ -107,7 +107,7 @@ class TypeConstraintSolver:
107
107
  try:
108
108
  return self._cls_typevar_index_to_type[cls_name][tp]
109
109
  except KeyError as e:
110
- raise TypeConstraintError(f"Not enough bound typevars for {tp} in class {cls_name}") from e
110
+ raise TypeConstraintError(f"Not enough bound typevars for {tp!r} in class {cls_name}") from e
111
111
  case TypeRefWithVars(name, args):
112
112
  return JustTypeRef(name, tuple(self.substitute_typevars(arg, cls_name) for arg in args))
113
113
  assert_never(tp)
egglog/version_compat.py ADDED
@@ -0,0 +1,87 @@
1
+ import collections
2
+ import sys
3
+ import types
4
+ import typing
5
+
6
+ BEFORE_3_11 = sys.version_info < (3, 11)
7
+
8
+ __all__ = ["add_note"]
9
+
10
+
11
+ def add_note(message: str, exc: BaseException) -> BaseException:
12
+ """
13
+ Backwards compatible add_note for Python <= 3.10
14
+ """
15
+ if BEFORE_3_11:
16
+ return exc
17
+ exc.add_note(message)
18
+ return exc
19
+
20
+
21
+ # For Python version 3.10 need to monkeypatch this function so that RuntimeClass type parameters
22
+ # will be collected as typevars
23
+ if BEFORE_3_11:
24
+
25
+ @typing.no_type_check
26
+ def _collect_type_vars_monkeypatch(types_, typevar_types=None):
27
+ """
28
+ Collect all type variable contained
29
+ in types in order of first appearance (lexicographic order). For example::
30
+
31
+ _collect_type_vars((T, List[S, T])) == (T, S)
32
+ """
33
+ from .runtime import RuntimeClass # noqa: PLC0415
34
+
35
+ if typevar_types is None:
36
+ typevar_types = typing.TypeVar
37
+ tvars = []
38
+ for t in types_:
39
+ if isinstance(t, typevar_types) and t not in tvars:
40
+ tvars.append(t)
41
+ # **MONKEYPATCH CHANGE HERE TO ADD RuntimeClass**
42
+ if isinstance(t, (typing._GenericAlias, typing.GenericAlias, types.UnionType, RuntimeClass)): # type: ignore[name-defined]
43
+ tvars.extend([t for t in t.__parameters__ if t not in tvars])
44
+ return tuple(tvars)
45
+
46
+ typing._collect_type_vars = _collect_type_vars_monkeypatch # type: ignore[attr-defined]
47
+
48
+ @typing.no_type_check
49
+ @typing._tp_cache
50
+ def __getitem__monkeypatch(self, params): # noqa: C901, PLR0912
51
+ from .runtime import RuntimeClass # noqa: PLC0415
52
+
53
+ if self.__origin__ in (typing.Generic, typing.Protocol):
54
+ # Can't subscript Generic[...] or Protocol[...].
55
+ raise TypeError(f"Cannot subscript already-subscripted {self}")
56
+ if not isinstance(params, tuple):
57
+ params = (params,)
58
+ params = tuple(typing._type_convert(p) for p in params)
59
+ if self._paramspec_tvars and any(isinstance(t, typing.ParamSpec) for t in self.__parameters__):
60
+ params = typing._prepare_paramspec_params(self, params)
61
+ else:
62
+ typing._check_generic(self, params, len(self.__parameters__))
63
+
64
+ subst = dict(zip(self.__parameters__, params, strict=False))
65
+ new_args = []
66
+ for arg in self.__args__:
67
+ if isinstance(arg, self._typevar_types):
68
+ if isinstance(arg, typing.ParamSpec):
69
+ arg = subst[arg] # noqa: PLW2901
70
+ if not typing._is_param_expr(arg):
71
+ raise TypeError(f"Expected a list of types, an ellipsis, ParamSpec, or Concatenate. Got {arg}")
72
+ else:
73
+ arg = subst[arg] # noqa: PLW2901
74
+ # **MONKEYPATCH CHANGE HERE TO ADD RuntimeClass**
75
+ elif isinstance(arg, (typing._GenericAlias, typing.GenericAlias, types.UnionType, RuntimeClass)):
76
+ subparams = arg.__parameters__
77
+ if subparams:
78
+ subargs = tuple(subst[x] for x in subparams)
79
+ arg = arg[subargs] # noqa: PLW2901
80
+ # Required to flatten out the args for CallableGenericAlias
81
+ if self.__origin__ == collections.abc.Callable and isinstance(arg, tuple):
82
+ new_args.extend(arg)
83
+ else:
84
+ new_args.append(arg)
85
+ return self.copy_with(tuple(new_args))
86
+
87
+ typing._GenericAlias.__getitem__ = __getitem__monkeypatch # type: ignore[attr-defined]
egglog-11.0.0.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: egglog
3
- Version: 10.0.1
3
+ Version: 11.0.0
4
4
  Classifier: Environment :: MacOS X
5
5
  Classifier: Environment :: Win32 (MS Windows)
6
6
  Classifier: Intended Audience :: Developers