egglog 7.1.0__cp310-none-win_amd64.whl → 7.2.0__cp310-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of egglog might be problematic. See the package's page on its registry for more details.

Binary file
egglog/declarations.py CHANGED
@@ -71,6 +71,8 @@ __all__ = [
71
71
  "CommandDecl",
72
72
  "SpecialFunctions",
73
73
  "FunctionSignature",
74
+ "DefaultRewriteDecl",
75
+ "InitRef",
74
76
  ]
75
77
 
76
78
 
@@ -80,7 +82,13 @@ class DelayedDeclerations:
80
82
 
81
83
  @property
82
84
  def __egg_decls__(self) -> Declarations:
83
- return self.__egg_decls_thunk__()
85
+ try:
86
+ return self.__egg_decls_thunk__()
87
+ # Catch attribute error, so that it isn't bubbled up as a missing attribute and fallbacks on `__getattr__`
88
+ # instead raise explicitly
89
+ except AttributeError as err:
90
+ msg = "Error resolving declerations"
91
+ raise RuntimeError(msg) from err
84
92
 
85
93
 
86
94
  @runtime_checkable
@@ -113,6 +121,12 @@ class Declarations:
113
121
  _classes: dict[str, ClassDecl] = field(default_factory=dict)
114
122
  _rulesets: dict[str, RulesetDecl | CombinedRulesetDecl] = field(default_factory=lambda: {"": RulesetDecl([])})
115
123
 
124
+ @property
125
+ def default_ruleset(self) -> RulesetDecl:
126
+ ruleset = self._rulesets[""]
127
+ assert isinstance(ruleset, RulesetDecl)
128
+ return ruleset
129
+
116
130
  @classmethod
117
131
  def create(cls, *others: DeclerationsLike) -> Declarations:
118
132
  others = upcast_declerations(others)
@@ -127,7 +141,7 @@ class Declarations:
127
141
 
128
142
  def copy(self) -> Declarations:
129
143
  new = Declarations()
130
- new |= self
144
+ self.update_other(new)
131
145
  return new
132
146
 
133
147
  def update(self, *others: DeclerationsLike) -> None:
@@ -154,9 +168,13 @@ class Declarations:
154
168
  other._functions |= self._functions
155
169
  other._classes |= self._classes
156
170
  other._constants |= self._constants
171
+ # Must combine rulesets bc the empty ruleset might be different, bc DefaultRewriteDecl
172
+ # is added to functions.
173
+ combined_default_rules: set[RewriteOrRuleDecl] = {*self.default_ruleset.rules, *other.default_ruleset.rules}
157
174
  other._rulesets |= self._rulesets
175
+ other._rulesets[""] = RulesetDecl(list(combined_default_rules))
158
176
 
159
- def get_callable_decl(self, ref: CallableRef) -> CallableDecl:
177
+ def get_callable_decl(self, ref: CallableRef) -> CallableDecl: # noqa: PLR0911
160
178
  match ref:
161
179
  case FunctionRef(name):
162
180
  return self._functions[name]
@@ -170,8 +188,29 @@ class Declarations:
170
188
  return self._classes[class_name].class_methods[name]
171
189
  case PropertyRef(class_name, property_name):
172
190
  return self._classes[class_name].properties[property_name]
191
+ case InitRef(class_name):
192
+ init_fn = self._classes[class_name].init
193
+ assert init_fn
194
+ return init_fn
173
195
  assert_never(ref)
174
196
 
197
+ def set_function_decl(
198
+ self, ref: FunctionRef | MethodRef | ClassMethodRef | PropertyRef | InitRef, decl: FunctionDecl
199
+ ) -> None:
200
+ match ref:
201
+ case FunctionRef(name):
202
+ self._functions[name] = decl
203
+ case MethodRef(class_name, method_name):
204
+ self._classes[class_name].methods[method_name] = decl
205
+ case ClassMethodRef(class_name, name):
206
+ self._classes[class_name].class_methods[name] = decl
207
+ case PropertyRef(class_name, property_name):
208
+ self._classes[class_name].properties[property_name] = decl
209
+ case InitRef(class_name):
210
+ self._classes[class_name].init = decl
211
+ case _:
212
+ assert_never(ref)
213
+
175
214
  def has_method(self, class_name: str, method_name: str) -> bool | None:
176
215
  """
177
216
  Returns whether the given class has the given method, or None if we cant find the class.
@@ -183,12 +222,20 @@ class Declarations:
183
222
  def get_class_decl(self, name: str) -> ClassDecl:
184
223
  return self._classes[name]
185
224
 
225
+ def get_paramaterized_class(self, name: str) -> TypeRefWithVars:
226
+ """
227
+ Returns a class reference with type parameters, if the class is paramaterized.
228
+ """
229
+ type_vars = self._classes[name].type_vars
230
+ return TypeRefWithVars(name, tuple(map(ClassTypeVarRef, type_vars)))
231
+
186
232
 
187
233
  @dataclass
188
234
  class ClassDecl:
189
235
  egg_name: str | None = None
190
236
  type_vars: tuple[str, ...] = ()
191
237
  builtin: bool = False
238
+ init: FunctionDecl | None = None
192
239
  class_methods: dict[str, FunctionDecl] = field(default_factory=dict)
193
240
  # These have to be seperate from class_methods so that printing them can be done easily
194
241
  class_variables: dict[str, ConstantDecl] = field(default_factory=dict)
@@ -293,6 +340,11 @@ class ClassMethodRef:
293
340
  method_name: str
294
341
 
295
342
 
343
+ @dataclass(frozen=True)
344
+ class InitRef:
345
+ class_name: str
346
+
347
+
296
348
  @dataclass(frozen=True)
297
349
  class ClassVariableRef:
298
350
  class_name: str
@@ -305,7 +357,9 @@ class PropertyRef:
305
357
  property_name: str
306
358
 
307
359
 
308
- CallableRef: TypeAlias = FunctionRef | ConstantRef | MethodRef | ClassMethodRef | ClassVariableRef | PropertyRef
360
+ CallableRef: TypeAlias = (
361
+ FunctionRef | ConstantRef | MethodRef | ClassMethodRef | InitRef | ClassVariableRef | PropertyRef
362
+ )
309
363
 
310
364
 
311
365
  ##
@@ -378,7 +432,6 @@ class FunctionSignature:
378
432
  @dataclass(frozen=True)
379
433
  class FunctionDecl:
380
434
  signature: FunctionSignature | SpecialFunctions = field(default_factory=FunctionSignature)
381
-
382
435
  # Egg params
383
436
  builtin: bool = False
384
437
  egg_name: str | None = None
@@ -458,7 +511,7 @@ class CallDecl:
458
511
  bound_tp_params: tuple[JustTypeRef, ...] | None = None
459
512
 
460
513
  def __post_init__(self) -> None:
461
- if self.bound_tp_params and not isinstance(self.callable, ClassMethodRef):
514
+ if self.bound_tp_params and not isinstance(self.callable, ClassMethodRef | InitRef):
462
515
  msg = "Cannot bind type parameters to a non-class method callable."
463
516
  raise ValueError(msg)
464
517
 
@@ -629,7 +682,13 @@ class RuleDecl:
629
682
  name: str | None
630
683
 
631
684
 
632
- RewriteOrRuleDecl: TypeAlias = RewriteDecl | BiRewriteDecl | RuleDecl
685
+ @dataclass(frozen=True)
686
+ class DefaultRewriteDecl:
687
+ ref: CallableRef
688
+ expr: ExprDecl
689
+
690
+
691
+ RewriteOrRuleDecl: TypeAlias = RewriteDecl | BiRewriteDecl | RuleDecl | DefaultRewriteDecl
633
692
 
634
693
 
635
694
  @dataclass(frozen=True)
egglog/egraph.py CHANGED
@@ -352,7 +352,7 @@ class _BaseModule:
352
352
  This is the same as defining a nullary function with a high cost.
353
353
  # TODO: Rename as declare to match eggglog?
354
354
  """
355
- return constant(name, tp, egg_name)
355
+ return constant(name, tp, egg_name=egg_name)
356
356
 
357
357
  def register(
358
358
  self,
@@ -452,6 +452,7 @@ class _ExprMetaclass(type):
452
452
  namespace: dict[str, Any],
453
453
  egg_sort: str | None = None,
454
454
  builtin: bool = False,
455
+ ruleset: Ruleset | None = None,
455
456
  ) -> RuntimeClass | type:
456
457
  # If this is the Expr subclass, just return the class
457
458
  if not bases:
@@ -467,7 +468,14 @@ class _ExprMetaclass(type):
467
468
  # Otherwise, f_locals returns a copy
468
469
  # https://peps.python.org/pep-0667/
469
470
  decls_thunk = Thunk.fn(
470
- _generate_class_decls, namespace, prev_frame, builtin, egg_sort, name, fallback=Declarations
471
+ _generate_class_decls,
472
+ namespace,
473
+ prev_frame,
474
+ builtin,
475
+ egg_sort,
476
+ name,
477
+ ruleset,
478
+ fallback=Declarations,
471
479
  )
472
480
  return RuntimeClass(decls_thunk, TypeRefWithVars(name))
473
481
 
@@ -475,8 +483,13 @@ class _ExprMetaclass(type):
475
483
  return isinstance(instance, RuntimeExpr)
476
484
 
477
485
 
478
- def _generate_class_decls( # noqa: C901
479
- namespace: dict[str, Any], frame: FrameType, builtin: bool, egg_sort: str | None, cls_name: str
486
+ def _generate_class_decls(
487
+ namespace: dict[str, Any],
488
+ frame: FrameType,
489
+ builtin: bool,
490
+ egg_sort: str | None,
491
+ cls_name: str,
492
+ ruleset: Ruleset | None,
480
493
  ) -> Declarations:
481
494
  """
482
495
  Lazy constructor for class declerations to support classes with methods whose types are not yet defined.
@@ -498,9 +511,9 @@ def _generate_class_decls( # noqa: C901
498
511
  for k, v in get_type_hints(_Dummytype, globalns=frame.f_globals, localns=frame.f_locals).items():
499
512
  if getattr(v, "__origin__", None) == ClassVar:
500
513
  (inner_tp,) = v.__args__
501
- type_ref = resolve_type_annotation(decls, inner_tp).to_just()
502
- cls_decl.class_variables[k] = ConstantDecl(type_ref)
503
-
514
+ type_ref = resolve_type_annotation(decls, inner_tp)
515
+ cls_decl.class_variables[k] = ConstantDecl(type_ref.to_just())
516
+ _add_default_rewrite(decls, ClassVariableRef(cls_name, k), type_ref, namespace.pop(k, None), ruleset)
504
517
  else:
505
518
  msg = f"On class {cls_name}, for attribute '{k}', expected a ClassVar, but got {v}"
506
519
  raise NotImplementedError(msg)
@@ -510,13 +523,15 @@ def _generate_class_decls( # noqa: C901
510
523
  ##
511
524
 
512
525
  # The type ref of self is paramterized by the type vars
513
- slf_type_ref = TypeRefWithVars(cls_name, tuple(map(ClassTypeVarRef, type_vars)))
526
+ TypeRefWithVars(cls_name, tuple(map(ClassTypeVarRef, type_vars)))
514
527
 
515
528
  # Get all the methods from the class
516
529
  filtered_namespace: list[tuple[str, Any]] = [
517
530
  (k, v) for k, v in namespace.items() if k not in IGNORED_ATTRIBUTES or isinstance(v, _WrappedMethod)
518
531
  ]
519
532
 
533
+ # all methods we should try adding default functions for
534
+ default_function_refs: dict[ClassMethodRef | MethodRef | PropertyRef, Callable] = {}
520
535
  # Then register each of its methods
521
536
  for method_name, method in filtered_namespace:
522
537
  is_init = method_name == "__init__"
@@ -536,43 +551,26 @@ def _generate_class_decls( # noqa: C901
536
551
  continue
537
552
  locals = frame.f_locals
538
553
 
539
- def create_decl(fn: object, first: Literal["cls"] | TypeRefWithVars) -> FunctionDecl:
540
- special_function_name: SpecialFunctions | None = (
541
- "fn-partial" if egg_fn == "unstable-fn" else "fn-app" if egg_fn == "unstable-app" else None # noqa: B023
542
- )
543
- if special_function_name:
544
- return FunctionDecl(
545
- special_function_name,
546
- builtin=True,
547
- egg_name=egg_fn, # noqa: B023
548
- )
549
-
550
- return _fn_decl(
551
- decls,
552
- egg_fn, # noqa: B023
553
- fn,
554
- locals, # noqa: B023
555
- default, # noqa: B023
556
- cost, # noqa: B023
557
- merge, # noqa: B023
558
- on_merge, # noqa: B023
559
- mutates, # noqa: B023
560
- builtin,
561
- first,
562
- is_init, # noqa: B023
563
- unextractable, # noqa: B023
564
- )
565
-
554
+ ref: ClassMethodRef | MethodRef | PropertyRef | InitRef
566
555
  match fn:
567
556
  case classmethod():
568
- cls_decl.class_methods[method_name] = create_decl(fn.__func__, "cls")
557
+ ref = ClassMethodRef(cls_name, method_name)
558
+ fn = fn.__func__
569
559
  case property():
570
- cls_decl.properties[method_name] = create_decl(fn.fget, slf_type_ref)
560
+ ref = PropertyRef(cls_name, method_name)
561
+ fn = fn.fget
571
562
  case _:
572
- if is_init:
573
- cls_decl.class_methods[method_name] = create_decl(fn, slf_type_ref)
574
- else:
575
- cls_decl.methods[method_name] = create_decl(fn, slf_type_ref)
563
+ ref = InitRef(cls_name) if is_init else MethodRef(cls_name, method_name)
564
+
565
+ _fn_decl(decls, egg_fn, ref, fn, locals, default, cost, merge, on_merge, mutates, builtin, unextractable)
566
+
567
+ if not builtin and not isinstance(ref, InitRef) and not mutates:
568
+ default_function_refs[ref] = fn
569
+
570
+ # Add all rewrite methods at the end so that all methods are registered first and can be accessed
571
+ # in the bodies
572
+ for ref, fn in default_function_refs.items():
573
+ _add_default_rewrite_function(decls, ref, fn, ruleset)
576
574
 
577
575
  return decls
578
576
 
@@ -591,6 +589,7 @@ def function(
591
589
  mutates_first_arg: bool = False,
592
590
  unextractable: bool = False,
593
591
  builtin: bool = False,
592
+ ruleset: Ruleset | None = None,
594
593
  ) -> Callable[[CALLABLE], CALLABLE]: ...
595
594
 
596
595
 
@@ -604,6 +603,7 @@ def function(
604
603
  on_merge: Callable[[EXPR, EXPR], Iterable[ActionLike]] | None = None,
605
604
  mutates_first_arg: bool = False,
606
605
  unextractable: bool = False,
606
+ ruleset: Ruleset | None = None,
607
607
  ) -> Callable[[Callable[P, EXPR]], Callable[P, EXPR]]: ...
608
608
 
609
609
 
@@ -634,15 +634,17 @@ class _FunctionConstructor:
634
634
  merge: Callable[[RuntimeExpr, RuntimeExpr], RuntimeExpr] | None = None
635
635
  on_merge: Callable[[RuntimeExpr, RuntimeExpr], Iterable[ActionLike]] | None = None
636
636
  unextractable: bool = False
637
+ ruleset: Ruleset | None = None
637
638
 
638
639
  def __call__(self, fn: Callable[..., RuntimeExpr]) -> RuntimeFunction:
639
640
  return RuntimeFunction(Thunk.fn(self.create_decls, fn), FunctionRef(fn.__name__))
640
641
 
641
642
  def create_decls(self, fn: Callable[..., RuntimeExpr]) -> Declarations:
642
643
  decls = Declarations()
643
- decls._functions[fn.__name__] = _fn_decl(
644
+ _fn_decl(
644
645
  decls,
645
646
  self.egg_fn,
647
+ (ref := FunctionRef(fn.__name__)),
646
648
  fn,
647
649
  self.hint_locals,
648
650
  self.default,
@@ -653,12 +655,14 @@ class _FunctionConstructor:
653
655
  self.builtin,
654
656
  unextractable=self.unextractable,
655
657
  )
658
+ _add_default_rewrite_function(decls, ref, fn, self.ruleset)
656
659
  return decls
657
660
 
658
661
 
659
662
  def _fn_decl(
660
663
  decls: Declarations,
661
664
  egg_name: str | None,
665
+ ref: FunctionRef | MethodRef | PropertyRef | ClassMethodRef | InitRef,
662
666
  fn: object,
663
667
  # Pass in the locals, retrieved from the frame when wrapping,
664
668
  # so that we support classes and function defined inside of other functions (which won't show up in the globals)
@@ -669,14 +673,23 @@ def _fn_decl(
669
673
  on_merge: Callable[[RuntimeExpr, RuntimeExpr], Iterable[ActionLike]] | None,
670
674
  mutates_first_arg: bool,
671
675
  is_builtin: bool,
672
- # The first arg is either cls, for a classmethod, a self type, or none for a function
673
- first_arg: Literal["cls"] | TypeOrVarRef | None = None,
674
- is_init: bool = False,
675
676
  unextractable: bool = False,
676
- ) -> FunctionDecl:
677
+ ) -> None:
678
+ """
679
+ Sets the function decl for the function object.
680
+ """
677
681
  if not isinstance(fn, FunctionType):
678
682
  raise NotImplementedError(f"Can only generate function decls for functions not {fn} {type(fn)}")
679
683
 
684
+ # partial function creation and calling are handled with a special case in the type checker, so don't
685
+ # use the normal logic
686
+ special_function_name: SpecialFunctions | None = (
687
+ "fn-partial" if egg_name == "unstable-fn" else "fn-app" if egg_name == "unstable-app" else None
688
+ )
689
+ if special_function_name:
690
+ decls.set_function_decl(ref, FunctionDecl(special_function_name, builtin=True, egg_name=egg_name))
691
+ return
692
+
680
693
  hint_globals = fn.__globals__.copy()
681
694
  # Copy Callable into global if not present bc sometimes it gets automatically removed by ruff to type only block
682
695
  # https://docs.astral.sh/ruff/rules/typing-only-standard-library-import/
@@ -687,8 +700,8 @@ def _fn_decl(
687
700
 
688
701
  params = list(signature(fn).parameters.values())
689
702
 
690
- # If this is an init function, or a classmethod, remove the first arg name
691
- if is_init or first_arg == "cls":
703
+ # If this is an init function, or a classmethod, the first arg is not used
704
+ if isinstance(ref, ClassMethodRef | InitRef):
692
705
  params = params[1:]
693
706
 
694
707
  if _last_param_variable(params):
@@ -698,9 +711,8 @@ def _fn_decl(
698
711
  else:
699
712
  var_arg_type = None
700
713
  arg_types = tuple(
701
- first_arg
702
- # If the first arg is a self, and this not an __init__ fn, add this as a typeref
703
- if i == 0 and isinstance(first_arg, ClassTypeVarRef | TypeRefWithVars) and not is_init
714
+ decls.get_paramaterized_class(ref.class_name)
715
+ if i == 0 and isinstance(ref, MethodRef | PropertyRef)
704
716
  else resolve_type_annotation(decls, hints[t.name])
705
717
  for i, t in enumerate(params)
706
718
  )
@@ -713,17 +725,15 @@ def _fn_decl(
713
725
 
714
726
  decls.update(*arg_defaults)
715
727
 
716
- # If this is an init fn use the first arg as the return type
717
- if is_init:
718
- assert not mutates_first_arg
719
- if not isinstance(first_arg, ClassTypeVarRef | TypeRefWithVars):
720
- msg = "Init function must have a self type"
721
- raise ValueError(msg)
722
- return_type = first_arg
723
- elif mutates_first_arg:
724
- return_type = arg_types[0]
725
- else:
726
- return_type = resolve_type_annotation(decls, hints["return"])
728
+ return_type = (
729
+ decls.get_paramaterized_class(ref.class_name)
730
+ if isinstance(ref, InitRef)
731
+ else arg_types[0]
732
+ if mutates_first_arg
733
+ else resolve_type_annotation(decls, hints["return"])
734
+ )
735
+
736
+ arg_names = tuple(t.name for t in params)
727
737
 
728
738
  decls |= default
729
739
  merged = (
@@ -747,12 +757,12 @@ def _fn_decl(
747
757
  )
748
758
  )
749
759
  decls.update(*merge_action)
750
- return FunctionDecl(
760
+ decl = FunctionDecl(
751
761
  FunctionSignature(
752
762
  return_type=None if mutates_first_arg else return_type,
753
763
  var_arg_type=var_arg_type,
754
764
  arg_types=arg_types,
755
- arg_names=tuple(t.name for t in params),
765
+ arg_names=arg_names,
756
766
  arg_defaults=tuple(a.__egg_typed_expr__.expr if a is not None else None for a in arg_defaults),
757
767
  ),
758
768
  cost=cost,
@@ -763,6 +773,7 @@ def _fn_decl(
763
773
  default=None if default is None else default.__egg_typed_expr__.expr,
764
774
  on_merge=tuple(a.action for a in merge_action),
765
775
  )
776
+ decls.set_function_decl(ref, decl)
766
777
 
767
778
 
768
779
  # Overload to support aritys 0-4 until variadic generic support map, so we can map from type to value
@@ -804,19 +815,73 @@ def _relation_decls(name: str, tps: tuple[type, ...], egg_fn: str | None) -> Dec
804
815
  return decls
805
816
 
806
817
 
807
- def constant(name: str, tp: type[EXPR], egg_name: str | None = None) -> EXPR:
818
+ def constant(
819
+ name: str,
820
+ tp: type[EXPR],
821
+ default_replacement: EXPR | None = None,
822
+ /,
823
+ *,
824
+ egg_name: str | None = None,
825
+ ruleset: Ruleset | None = None,
826
+ ) -> EXPR:
808
827
  """
809
828
  A "constant" is implemented as the instantiation of a value that takes no args.
810
829
  This creates a function with `name` and return type `tp` and returns a value of it being called.
811
830
  """
812
- return cast(EXPR, RuntimeExpr(Thunk.fn(_constant_thunk, name, tp, egg_name)))
831
+ return cast(EXPR, RuntimeExpr(Thunk.fn(_constant_thunk, name, tp, egg_name, default_replacement, ruleset)))
813
832
 
814
833
 
815
- def _constant_thunk(name: str, tp: type, egg_name: str | None) -> tuple[Declarations, TypedExprDecl]:
834
+ def _constant_thunk(
835
+ name: str, tp: type, egg_name: str | None, default_replacement: object, ruleset: Ruleset | None
836
+ ) -> tuple[Declarations, TypedExprDecl]:
816
837
  decls = Declarations()
817
- type_ref = resolve_type_annotation(decls, tp).to_just()
818
- decls._constants[name] = ConstantDecl(type_ref, egg_name)
819
- return decls, TypedExprDecl(type_ref, CallDecl(ConstantRef(name)))
838
+ type_ref = resolve_type_annotation(decls, tp)
839
+ callable_ref = ConstantRef(name)
840
+ decls._constants[name] = ConstantDecl(type_ref.to_just(), egg_name)
841
+ _add_default_rewrite(decls, callable_ref, type_ref, default_replacement, ruleset)
842
+ return decls, TypedExprDecl(type_ref.to_just(), CallDecl(callable_ref))
843
+
844
+
845
+ def _add_default_rewrite_function(decls: Declarations, ref: CallableRef, fn: Callable, ruleset: Ruleset | None) -> None:
846
+ """
847
+ Adds a default rewrite for a function, by calling the functions with vars and adding it if it is not None.
848
+ """
849
+ callable_decl = decls.get_callable_decl(ref)
850
+ assert isinstance(callable_decl, FunctionDecl)
851
+ signature = callable_decl.signature
852
+ assert isinstance(signature, FunctionSignature)
853
+
854
+ var_args: list[object] = [
855
+ RuntimeExpr.__from_value__(decls, TypedExprDecl(tp.to_just(), VarDecl(_rule_var_name(name))))
856
+ for name, tp in zip(signature.arg_names, signature.arg_types, strict=False)
857
+ ]
858
+ # If this is a classmethod, add the class as the first arg
859
+ if isinstance(ref, ClassMethodRef):
860
+ tp = decls.get_paramaterized_class(ref.class_name)
861
+ var_args.insert(0, RuntimeClass(Thunk.value(decls), tp))
862
+ _add_default_rewrite(decls, ref, signature.semantic_return_type, fn(*var_args), ruleset)
863
+
864
+
865
+ def _add_default_rewrite(
866
+ decls: Declarations, ref: CallableRef, type_ref: TypeOrVarRef, default_rewrite: object, ruleset: Ruleset | None
867
+ ) -> None:
868
+ """
869
+ Adds a default rewrite for the callable, if the default rewrite is not None
870
+
871
+ Will add it to the ruleset if it is passed in, or add it to the default ruleset on the passed in decls if not.
872
+ """
873
+ if default_rewrite is None:
874
+ return
875
+ resolved_value = resolve_literal(type_ref, default_rewrite)
876
+ rewrite_decl = DefaultRewriteDecl(ref, resolved_value.__egg_typed_expr__.expr)
877
+ if ruleset:
878
+ ruleset_decls = ruleset._current_egg_decls
879
+ ruleset_decl = ruleset.__egg_ruleset__
880
+ else:
881
+ ruleset_decls = decls
882
+ ruleset_decl = decls.default_ruleset
883
+ ruleset_decl.rules.append(rewrite_decl)
884
+ ruleset_decls |= resolved_value
820
885
 
821
886
 
822
887
  def _last_param_variable(params: list[Parameter]) -> bool:
@@ -991,6 +1056,7 @@ class EGraph(_BaseModule):
991
1056
  action = let(name, expr)
992
1057
  self.register(action)
993
1058
  runtime_expr = to_runtime_expr(expr)
1059
+ self._add_decls(runtime_expr)
994
1060
  return cast(
995
1061
  EXPR,
996
1062
  RuntimeExpr.__from_value__(
@@ -1016,7 +1082,8 @@ class EGraph(_BaseModule):
1016
1082
  self._add_decls(runtime_expr, schedule)
1017
1083
  egg_schedule = self._state.schedule_to_egg(schedule.schedule)
1018
1084
  typed_expr = runtime_expr.__egg_typed_expr__
1019
- egg_expr = self._state.expr_to_egg(typed_expr.expr)
1085
+ # Must also register type
1086
+ egg_expr = self._state.typed_expr_to_egg(typed_expr)
1020
1087
  self._egraph.run_program(bindings.Simplify(egg_expr, egg_schedule))
1021
1088
  extract_report = self._egraph.extract_report()
1022
1089
  if not isinstance(extract_report, bindings.Best):
@@ -1121,8 +1188,7 @@ class EGraph(_BaseModule):
1121
1188
  return [cast(EXPR, RuntimeExpr.__from_value__(self.__egg_decls__, expr)) for expr in new_exprs]
1122
1189
 
1123
1190
  def _run_extract(self, typed_expr: TypedExprDecl, n: int) -> bindings._ExtractReport:
1124
- self._state.type_ref_to_egg(typed_expr.tp)
1125
- expr = self._state.expr_to_egg(typed_expr.expr)
1191
+ expr = self._state.typed_expr_to_egg(typed_expr)
1126
1192
  self._egraph.run_program(bindings.ActionCommand(bindings.Extract(expr, bindings.Lit(bindings.Int(n)))))
1127
1193
  extract_report = self._egraph.extract_report()
1128
1194
  if not extract_report:
@@ -1181,7 +1247,7 @@ class EGraph(_BaseModule):
1181
1247
  runtime_expr = to_runtime_expr(expr)
1182
1248
  self._add_decls(runtime_expr)
1183
1249
  typed_expr = runtime_expr.__egg_typed_expr__
1184
- egg_expr = self._state.expr_to_egg(typed_expr.expr)
1250
+ egg_expr = self._state.typed_expr_to_egg(typed_expr)
1185
1251
  match typed_expr.tp:
1186
1252
  case JustTypeRef("i64"):
1187
1253
  return self._egraph.eval_i64(egg_expr)
@@ -1878,7 +1944,7 @@ def _rewrite_or_rule_generator(gen: RewriteOrRuleGenerator, frame: FrameType) ->
1878
1944
  if "Callable" not in globals:
1879
1945
  globals["Callable"] = Callable
1880
1946
  hints = get_type_hints(gen, globals, frame.f_locals)
1881
- args = [_var(p.name, hints[p.name]) for p in signature(gen).parameters.values()]
1947
+ args = [_var(_rule_var_name(p.name), hints[p.name]) for p in signature(gen).parameters.values()]
1882
1948
  return list(gen(*args)) # type: ignore[misc]
1883
1949
 
1884
1950
 
egglog/egraph_state.py CHANGED
@@ -19,7 +19,7 @@ from .type_constraint_solver import TypeConstraintError, TypeConstraintSolver
19
19
  if TYPE_CHECKING:
20
20
  from collections.abc import Iterable
21
21
 
22
- __all__ = ["EGraphState", "GLOBAL_PY_OBJECT_SORT"]
22
+ __all__ = ["EGraphState", "GLOBAL_PY_OBJECT_SORT", "_rule_var_name"]
23
23
 
24
24
  # Create a global sort for python objects, so we can store them without an e-graph instance
25
25
  # Needed when serializing commands to egg commands when creating modules
@@ -98,7 +98,8 @@ class EGraphState:
98
98
  for rule in rules:
99
99
  if rule in added_rules:
100
100
  continue
101
- self.egraph.run_program(self.command_to_egg(rule, name))
101
+ cmd = self.command_to_egg(rule, name)
102
+ self.egraph.run_program(cmd)
102
103
  added_rules.add(rule)
103
104
  case CombinedRulesetDecl(rulesets):
104
105
  if name in self.rulesets:
@@ -115,8 +116,8 @@ class EGraphState:
115
116
  case RewriteDecl(tp, lhs, rhs, conditions) | BiRewriteDecl(tp, lhs, rhs, conditions):
116
117
  self.type_ref_to_egg(tp)
117
118
  rewrite = bindings.Rewrite(
118
- self.expr_to_egg(lhs),
119
- self.expr_to_egg(rhs),
119
+ self._expr_to_egg(lhs),
120
+ self._expr_to_egg(rhs),
120
121
  [self.fact_to_egg(c) for c in conditions],
121
122
  )
122
123
  return (
@@ -130,6 +131,16 @@ class EGraphState:
130
131
  [self.fact_to_egg(f) for f in body],
131
132
  )
132
133
  return bindings.RuleCommand(name or "", ruleset, rule)
134
+ case DefaultRewriteDecl(ref, expr):
135
+ decl = self.__egg_decls__.get_callable_decl(ref).to_function_decl()
136
+ sig = decl.signature
137
+ assert isinstance(sig, FunctionSignature)
138
+ args = tuple(
139
+ TypedExprDecl(tp.to_just(), VarDecl(_rule_var_name(name)))
140
+ for name, tp in zip(sig.arg_names, sig.arg_types, strict=False)
141
+ )
142
+ rewrite_decl = RewriteDecl(sig.semantic_return_type.to_just(), CallDecl(ref, args), expr, (), False)
143
+ return self.command_to_egg(rewrite_decl, ruleset)
133
144
  case _:
134
145
  assert_never(cmd)
135
146
 
@@ -139,13 +150,13 @@ class EGraphState:
139
150
  return bindings.Let(name, self.typed_expr_to_egg(typed_expr))
140
151
  case SetDecl(tp, call, rhs):
141
152
  self.type_ref_to_egg(tp)
142
- call_ = self.expr_to_egg(call)
143
- return bindings.Set(call_.name, call_.args, self.expr_to_egg(rhs))
153
+ call_ = self._expr_to_egg(call)
154
+ return bindings.Set(call_.name, call_.args, self._expr_to_egg(rhs))
144
155
  case ExprActionDecl(typed_expr):
145
156
  return bindings.Expr_(self.typed_expr_to_egg(typed_expr))
146
157
  case ChangeDecl(tp, call, change):
147
158
  self.type_ref_to_egg(tp)
148
- call_ = self.expr_to_egg(call)
159
+ call_ = self._expr_to_egg(call)
149
160
  egg_change: bindings._Change
150
161
  match change:
151
162
  case "delete":
@@ -157,7 +168,7 @@ class EGraphState:
157
168
  return bindings.Change(egg_change, call_.name, call_.args)
158
169
  case UnionDecl(tp, lhs, rhs):
159
170
  self.type_ref_to_egg(tp)
160
- return bindings.Union(self.expr_to_egg(lhs), self.expr_to_egg(rhs))
171
+ return bindings.Union(self._expr_to_egg(lhs), self._expr_to_egg(rhs))
161
172
  case PanicDecl(name):
162
173
  return bindings.Panic(name)
163
174
  case _:
@@ -167,7 +178,7 @@ class EGraphState:
167
178
  match fact:
168
179
  case EqDecl(tp, exprs):
169
180
  self.type_ref_to_egg(tp)
170
- return bindings.Eq([self.expr_to_egg(e) for e in exprs])
181
+ return bindings.Eq([self._expr_to_egg(e) for e in exprs])
171
182
  case ExprFactDecl(typed_expr):
172
183
  return bindings.Fact(self.typed_expr_to_egg(typed_expr))
173
184
  case _:
@@ -201,8 +212,8 @@ class EGraphState:
201
212
  [self.type_ref_to_egg(a.to_just()) for a in signature.arg_types],
202
213
  self.type_ref_to_egg(signature.semantic_return_type.to_just()),
203
214
  ),
204
- self.expr_to_egg(decl.default) if decl.default else None,
205
- self.expr_to_egg(decl.merge) if decl.merge else None,
215
+ self._expr_to_egg(decl.default) if decl.default else None,
216
+ self._expr_to_egg(decl.merge) if decl.merge else None,
206
217
  [self.action_to_egg(a) for a in decl.on_merge],
207
218
  decl.cost,
208
219
  decl.unextractable,
@@ -245,6 +256,8 @@ class EGraphState:
245
256
  if decl.builtin:
246
257
  for method in decl.class_methods:
247
258
  self.callable_ref_to_egg(ClassMethodRef(ref.name, method))
259
+ if decl.init:
260
+ self.callable_ref_to_egg(InitRef(ref.name))
248
261
 
249
262
  return egg_name
250
263
 
@@ -261,15 +274,15 @@ class EGraphState:
261
274
 
262
275
  def typed_expr_to_egg(self, typed_expr_decl: TypedExprDecl) -> bindings._Expr:
263
276
  self.type_ref_to_egg(typed_expr_decl.tp)
264
- return self.expr_to_egg(typed_expr_decl.expr)
277
+ return self._expr_to_egg(typed_expr_decl.expr)
265
278
 
266
279
  @overload
267
- def expr_to_egg(self, expr_decl: CallDecl) -> bindings.Call: ...
280
+ def _expr_to_egg(self, expr_decl: CallDecl) -> bindings.Call: ...
268
281
 
269
282
  @overload
270
- def expr_to_egg(self, expr_decl: ExprDecl) -> bindings._Expr: ...
283
+ def _expr_to_egg(self, expr_decl: ExprDecl) -> bindings._Expr: ...
271
284
 
272
- def expr_to_egg(self, expr_decl: ExprDecl) -> bindings._Expr:
285
+ def _expr_to_egg(self, expr_decl: ExprDecl) -> bindings._Expr:
273
286
  """
274
287
  Convert an ExprDecl to an egg expression.
275
288
 
@@ -307,7 +320,7 @@ class EGraphState:
307
320
  case PyObjectDecl(value):
308
321
  res = GLOBAL_PY_OBJECT_SORT.store(value)
309
322
  case PartialCallDecl(call_decl):
310
- egg_fn_call = self.expr_to_egg(call_decl)
323
+ egg_fn_call = self._expr_to_egg(call_decl)
311
324
  res = bindings.Call("unstable-fn", [bindings.Lit(bindings.String(egg_fn_call.name)), *egg_fn_call.args])
312
325
  case _:
313
326
  assert_never(expr_decl.expr)
@@ -355,6 +368,8 @@ def _generate_callable_egg_name(ref: CallableRef) -> str:
355
368
  | PropertyRef(cls_name, name)
356
369
  ):
357
370
  return f"{cls_name}_{name}"
371
+ case InitRef(cls_name):
372
+ return f"{cls_name}___init__"
358
373
  case _:
359
374
  assert_never(ref)
360
375
 
@@ -427,8 +442,11 @@ class FromEggState:
427
442
  possible_types: Iterable[JustTypeRef | None]
428
443
  signature = self.decls.get_callable_decl(callable_ref).to_function_decl().signature
429
444
  assert isinstance(signature, FunctionSignature)
430
- if isinstance(callable_ref, ClassMethodRef):
431
- possible_types = self.state._get_possible_types(callable_ref.class_name)
445
+ if isinstance(callable_ref, ClassMethodRef | InitRef | MethodRef):
446
+ # Need OR in case we have class method whose class whas never added as a sort, which would happen
447
+ # if the class method didn't return that type and no other function did. In this case, we don't need
448
+ # to care about the type vars and we we don't need to bind any possible type.
449
+ possible_types = self.state._get_possible_types(callable_ref.class_name) or [None]
432
450
  cls_name = callable_ref.class_name
433
451
  else:
434
452
  possible_types = [None]
@@ -437,7 +455,6 @@ class FromEggState:
437
455
  tcs = TypeConstraintSolver(self.decls)
438
456
  if possible_type and possible_type.args:
439
457
  tcs.bind_class(possible_type)
440
-
441
458
  try:
442
459
  arg_types, bound_tp_params = tcs.infer_arg_types(
443
460
  signature.arg_types, signature.semantic_return_type, signature.var_arg_type, tp, cls_name
@@ -445,7 +462,14 @@ class FromEggState:
445
462
  except TypeConstraintError:
446
463
  continue
447
464
  args = tuple(self.resolve_term(a, tp) for a, tp in zip(term.args, arg_types, strict=False))
448
- return CallDecl(callable_ref, args, bound_tp_params)
465
+
466
+ return CallDecl(
467
+ callable_ref,
468
+ args,
469
+ # Don't include bound type params if this is just a method, we only needed them for type resolution
470
+ # but dont need to store them
471
+ bound_tp_params if isinstance(callable_ref, ClassMethodRef | InitRef) else None,
472
+ )
449
473
  raise ValueError(f"Could not find callable ref for call {term}")
450
474
 
451
475
  def resolve_term(self, term_id: int, tp: JustTypeRef) -> TypedExprDecl:
@@ -454,3 +478,10 @@ class FromEggState:
454
478
  except KeyError:
455
479
  res = self.cache[term_id] = self.from_expr(tp, self.termdag.nodes[term_id])
456
480
  return res
481
+
482
+
483
+ def _rule_var_name(s: str) -> str:
484
+ """
485
+ Create a hidden variable name, for rewrites, so that let bindings or function won't conflict with it
486
+ """
487
+ return f"__var__{s}"
@@ -0,0 +1,50 @@
1
+ # mypy: disable-error-code="empty-body"
2
+ """
3
+ Higher Order Functions
4
+ ======================
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ from typing import TYPE_CHECKING
10
+
11
+ from egglog import *
12
+
13
+ if TYPE_CHECKING:
14
+ from collections.abc import Callable
15
+
16
+
17
+ class Math(Expr):
18
+ def __init__(self, i: i64Like) -> None: ...
19
+
20
+ def __add__(self, other: Math) -> Math: ...
21
+
22
+
23
+ class MathList(Expr):
24
+ def __init__(self) -> None: ...
25
+
26
+ def append(self, i: Math) -> MathList: ...
27
+
28
+ def map(self, f: Callable[[Math], Math]) -> MathList: ...
29
+
30
+
31
+ @ruleset
32
+ def math_ruleset(i: i64, j: i64, xs: MathList, x: Math, f: Callable[[Math], Math]): # noqa: ANN201
33
+ yield rewrite(Math(i) + Math(j)).to(Math(i + j))
34
+ yield rewrite(xs.append(x).map(f)).to(xs.map(f).append(f(x)))
35
+ yield rewrite(MathList().map(f)).to(MathList())
36
+
37
+
38
+ @function(ruleset=math_ruleset)
39
+ def increment_by_one(x: Math) -> Math:
40
+ return x + Math(1)
41
+
42
+
43
+ egraph = EGraph()
44
+ x = egraph.let("x", MathList().append(Math(1)).append(Math(2)))
45
+ y = egraph.let("y", x.map(increment_by_one))
46
+ egraph.run(math_ruleset.saturate())
47
+
48
+ egraph.check(eq(y).to(MathList().append(Math(2)).append(Math(3))))
49
+
50
+ egraph
egglog/exp/array_api.py CHANGED
@@ -882,7 +882,10 @@ converter(IntOrTuple, OptionalIntOrTuple, OptionalIntOrTuple.some)
882
882
 
883
883
  @function
884
884
  def asarray(
885
- a: NDArray, dtype: OptionalDType = OptionalDType.none, copy: OptionalBool = OptionalBool.none
885
+ a: NDArray,
886
+ dtype: OptionalDType = OptionalDType.none,
887
+ copy: OptionalBool = OptionalBool.none,
888
+ device: OptionalDevice = OptionalDevice.none,
886
889
  ) -> NDArray: ...
887
890
 
888
891
 
egglog/pretty.py CHANGED
@@ -166,6 +166,8 @@ class TraverseContext:
166
166
  pass
167
167
  case EqDecl(_, decls) | SequenceDecl(decls) | RulesetDecl(decls):
168
168
  for de in decls:
169
+ if isinstance(de, DefaultRewriteDecl):
170
+ continue
169
171
  self(de)
170
172
  case CallDecl(_, exprs, _):
171
173
  for e in exprs:
@@ -178,6 +180,8 @@ class TraverseContext:
178
180
  self(c)
179
181
  case CombinedRulesetDecl(_):
180
182
  pass
183
+ case DefaultRewriteDecl():
184
+ pass
181
185
  case _:
182
186
  assert_never(decl)
183
187
 
@@ -276,7 +280,7 @@ class PrettyContext:
276
280
  case RulesetDecl(rules):
277
281
  if ruleset_name:
278
282
  return f"ruleset(name={ruleset_name!r})", f"ruleset_{ruleset_name}"
279
- args = ", ".join(map(self, rules))
283
+ args = ", ".join(self(r) for r in rules if not isinstance(r, DefaultRewriteDecl))
280
284
  return f"ruleset({args})", "ruleset"
281
285
  case CombinedRulesetDecl(rulesets):
282
286
  if ruleset_name:
@@ -298,6 +302,9 @@ class PrettyContext:
298
302
  return ruleset_str, "schedule"
299
303
  args = ", ".join(map(self, until))
300
304
  return f"run({ruleset_str}, {args})", "schedule"
305
+ case DefaultRewriteDecl():
306
+ msg = "default rewrites should not be pretty printed"
307
+ raise TypeError(msg)
301
308
  assert_never(decl)
302
309
 
303
310
  def _call(
@@ -370,10 +377,8 @@ class PrettyContext:
370
377
  case FunctionRef(name):
371
378
  return name, args
372
379
  case ClassMethodRef(class_name, method_name):
373
- fn_str = str(JustTypeRef(class_name, bound_tp_params or ()))
374
- if method_name != "__init__":
375
- fn_str += f".{method_name}"
376
- return fn_str, args
380
+ tp_ref = JustTypeRef(class_name, bound_tp_params or ())
381
+ return f"{tp_ref}.{method_name}", args
377
382
  case MethodRef(_class_name, method_name):
378
383
  slf, *args = args
379
384
  slf = self(slf, parens=True)
@@ -400,6 +405,9 @@ class PrettyContext:
400
405
  return f"{class_name}.{variable_name}"
401
406
  case PropertyRef(_class_name, property_name):
402
407
  return f"{self(args[0], parens=True)}.{property_name}"
408
+ case InitRef(class_name):
409
+ tp_ref = JustTypeRef(class_name, bound_tp_params or ())
410
+ return str(tp_ref), args
403
411
  assert_never(ref)
404
412
 
405
413
  def _generate_name(self, typ: str) -> str:
@@ -434,6 +442,8 @@ def _pretty_callable(ref: CallableRef) -> str:
434
442
  | PropertyRef(class_name, method_name)
435
443
  ):
436
444
  return f"{class_name}.{method_name}"
445
+ case InitRef(class_name):
446
+ return class_name
437
447
  case ConstantRef(_):
438
448
  msg = "Constants should not be callable"
439
449
  raise NotImplementedError(msg)
egglog/runtime.py CHANGED
@@ -163,10 +163,8 @@ class RuntimeClass(DelayedDeclerations):
163
163
  return RuntimeExpr.__from_value__(
164
164
  self.__egg_decls__, TypedExprDecl(self.__egg_tp__.to_just(), LitDecl(None))
165
165
  )
166
-
167
- return RuntimeFunction(
168
- Thunk.value(self.__egg_decls__), ClassMethodRef(name, "__init__"), self.__egg_tp__.to_just()
169
- )(*args, **kwargs) # type: ignore[arg-type]
166
+ fn = RuntimeFunction(Thunk.value(self.__egg_decls__), InitRef(name), self.__egg_tp__.to_just())
167
+ return fn(*args, **kwargs) # type: ignore[arg-type]
170
168
 
171
169
  def __dir__(self) -> list[str]:
172
170
  cls_decl = self.__egg_decls__.get_class_decl(self.__egg_tp__.name)
@@ -277,7 +275,10 @@ class RuntimeFunction(DelayedDeclerations):
277
275
 
278
276
  # Turn all keyword args into positional args
279
277
  py_signature = to_py_signature(signature, self.__egg_decls__, _egg_partial_function)
280
- bound = py_signature.bind(*args, **kwargs)
278
+ try:
279
+ bound = py_signature.bind(*args, **kwargs)
280
+ except TypeError as err:
281
+ raise TypeError(f"Failed to call {self} with args {args} and kwargs {kwargs}") from err
281
282
  del kwargs
282
283
  bound.apply_defaults()
283
284
  assert not bound.kwargs
@@ -310,7 +311,9 @@ class RuntimeFunction(DelayedDeclerations):
310
311
  return_tp = tcs.infer_return_type(
311
312
  signature.arg_types, signature.semantic_return_type, signature.var_arg_type, arg_types, cls_name
312
313
  )
313
- bound_params = cast(JustTypeRef, bound_tp).args if isinstance(self.__egg_ref__, ClassMethodRef) else None
314
+ bound_params = (
315
+ cast(JustTypeRef, bound_tp).args if isinstance(self.__egg_ref__, ClassMethodRef | InitRef) else None
316
+ )
314
317
  # If we were using unstable-app to call a funciton, add that function back as the first arg.
315
318
  if function_value:
316
319
  arg_exprs = (function_value, *arg_exprs)
egglog/thunk.py CHANGED
@@ -21,7 +21,7 @@ class Thunk(Generic[T, Unpack[TS]]):
21
21
  Cached delayed function call.
22
22
  """
23
23
 
24
- state: Resolved[T] | Unresolved[T, Unpack[TS]] | Resolving[T]
24
+ state: Resolved[T] | Unresolved[T, Unpack[TS]] | Resolving[T] | Error
25
25
 
26
26
  @classmethod
27
27
  def fn(
@@ -44,14 +44,21 @@ class Thunk(Generic[T, Unpack[TS]]):
44
44
  return value
45
45
  case Unresolved(fn, args, fallback):
46
46
  self.state = Resolving(fallback)
47
- res = fn(*args)
48
- self.state = Resolved(res)
49
- return res
47
+ try:
48
+ res = fn(*args)
49
+ except Exception as e:
50
+ self.state = Error(e)
51
+ raise
52
+ else:
53
+ self.state = Resolved(res)
54
+ return res
50
55
  case Resolving(fallback):
51
56
  if fallback is None:
52
57
  msg = "Recursively resolving thunk without fallback"
53
58
  raise ValueError(msg)
54
59
  return fallback()
60
+ case Error(e):
61
+ raise e
55
62
 
56
63
 
57
64
  @dataclass
@@ -69,3 +76,8 @@ class Unresolved(Generic[T, Unpack[TS]]):
69
76
  @dataclass
70
77
  class Resolving(Generic[T]):
71
78
  fallback: Callable[[], T] | None
79
+
80
+
81
+ @dataclass
82
+ class Error:
83
+ e: Exception
@@ -79,9 +79,10 @@ class TypeConstraintSolver:
79
79
  Also returns the bound type params if the class name is passed in.
80
80
  """
81
81
  self._infer_typevars(fn_return, return_, cls_name)
82
- arg_types = (
83
- self._subtitute_typevars(a, cls_name) for a in chain(fn_args, repeat(fn_var_args) if fn_var_args else [])
84
- )
82
+ arg_types: Iterable[JustTypeRef] = [self._subtitute_typevars(a, cls_name) for a in fn_args]
83
+ if fn_var_args:
84
+ # Need to be generator so it can be infinite for variable args
85
+ arg_types = chain(arg_types, repeat(self._subtitute_typevars(fn_var_args, cls_name)))
85
86
  bound_typevars = (
86
87
  tuple(
87
88
  v
@@ -132,8 +133,8 @@ class TypeConstraintSolver:
132
133
  def _subtitute_typevars(self, tp: TypeOrVarRef, cls_name: str | None) -> JustTypeRef:
133
134
  match tp:
134
135
  case ClassTypeVarRef(name):
136
+ assert cls_name is not None
135
137
  try:
136
- assert cls_name is not None
137
138
  return self._cls_typevar_index_to_type[cls_name][name]
138
139
  except KeyError as e:
139
140
  raise TypeConstraintError(f"Not enough bound typevars for {tp}") from e
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: egglog
3
- Version: 7.1.0
3
+ Version: 7.2.0
4
4
  Classifier: Environment :: MacOS X
5
5
  Classifier: Environment :: Win32 (MS Windows)
6
6
  Classifier: Intended Audience :: Developers
@@ -19,22 +19,6 @@ Classifier: Typing :: Typed
19
19
  Requires-Dist: typing-extensions
20
20
  Requires-Dist: black
21
21
  Requires-Dist: graphviz
22
- Requires-Dist: scikit-learn; extra == 'array'
23
- Requires-Dist: array_api_compat; extra == 'array'
24
- Requires-Dist: numba==0.59.1; extra == 'array'
25
- Requires-Dist: llvmlite==0.42.0; extra == 'array'
26
- Requires-Dist: pre-commit; extra == 'dev'
27
- Requires-Dist: ruff; extra == 'dev'
28
- Requires-Dist: mypy; extra == 'dev'
29
- Requires-Dist: anywidget[dev]; extra == 'dev'
30
- Requires-Dist: egglog[docs,test]; extra == 'dev'
31
- Requires-Dist: pytest; extra == 'test'
32
- Requires-Dist: mypy; extra == 'test'
33
- Requires-Dist: syrupy; extra == 'test'
34
- Requires-Dist: egglog[array]; extra == 'test'
35
- Requires-Dist: pytest-codspeed; extra == 'test'
36
- Requires-Dist: pytest-benchmark; extra == 'test'
37
- Requires-Dist: pytest-xdist; extra == 'test'
38
22
  Requires-Dist: pydata-sphinx-theme; extra == 'docs'
39
23
  Requires-Dist: myst-nb; extra == 'docs'
40
24
  Requires-Dist: sphinx-autodoc-typehints; extra == 'docs'
@@ -47,10 +31,26 @@ Requires-Dist: egglog[array]; extra == 'docs'
47
31
  Requires-Dist: line-profiler; extra == 'docs'
48
32
  Requires-Dist: sphinxcontrib-mermaid; extra == 'docs'
49
33
  Requires-Dist: ablog; extra == 'docs'
34
+ Requires-Dist: pytest; extra == 'test'
35
+ Requires-Dist: mypy; extra == 'test'
36
+ Requires-Dist: syrupy; extra == 'test'
37
+ Requires-Dist: egglog[array]; extra == 'test'
38
+ Requires-Dist: pytest-codspeed; extra == 'test'
39
+ Requires-Dist: pytest-benchmark; extra == 'test'
40
+ Requires-Dist: pytest-xdist; extra == 'test'
41
+ Requires-Dist: scikit-learn; extra == 'array'
42
+ Requires-Dist: array_api_compat; extra == 'array'
43
+ Requires-Dist: numba==0.59.1; extra == 'array'
44
+ Requires-Dist: llvmlite==0.42.0; extra == 'array'
45
+ Requires-Dist: pre-commit; extra == 'dev'
46
+ Requires-Dist: ruff; extra == 'dev'
47
+ Requires-Dist: mypy; extra == 'dev'
48
+ Requires-Dist: anywidget[dev]; extra == 'dev'
49
+ Requires-Dist: egglog[docs,test]; extra == 'dev'
50
+ Provides-Extra: docs
51
+ Provides-Extra: test
50
52
  Provides-Extra: array
51
53
  Provides-Extra: dev
52
- Provides-Extra: test
53
- Provides-Extra: docs
54
54
  License-File: LICENSE
55
55
  Summary: e-graphs in Python built around the the egglog rust library
56
56
  License: MIT
@@ -1,16 +1,17 @@
1
- egglog-7.1.0.dist-info/METADATA,sha256=t8YGDTCeCIpcJMDrAiicXrJAARElS-vM1R6-GY4wNzE,3829
2
- egglog-7.1.0.dist-info/WHEEL,sha256=keLBtIOE7ZgLyd8Cijw37iqv72a2jhfxMdDu8Unr7zo,96
3
- egglog-7.1.0.dist-info/license_files/LICENSE,sha256=TfaboMVZ81Q6OUaKjU7z6uVjSlcGKclLYcOpgDbm9_s,1091
1
+ egglog-7.2.0.dist-info/METADATA,sha256=O3Uq5oIIHWZ-bG2XVILykl-0dYyuNuoR9Dxj38v6bXs,3829
2
+ egglog-7.2.0.dist-info/WHEEL,sha256=keLBtIOE7ZgLyd8Cijw37iqv72a2jhfxMdDu8Unr7zo,96
3
+ egglog-7.2.0.dist-info/license_files/LICENSE,sha256=TfaboMVZ81Q6OUaKjU7z6uVjSlcGKclLYcOpgDbm9_s,1091
4
4
  egglog/bindings.pyi,sha256=iFdtYHqPjuVt46qh1IE81cOI-lbGgPISKQB-3ERNzhI,11770
5
5
  egglog/builtins.py,sha256=p5oZLDleCgQC8G2Zp2EvmqsfVnpgmARqUHqosKSSKnQ,13120
6
6
  egglog/config.py,sha256=mALVaxh7zmGrbuyzaVKVmYKcu1lF703QsKJw8AF7gSM,176
7
7
  egglog/conversion.py,sha256=4JhocGd1_nwmFMVNCzDDwj6aWfhBFTX2hm_7Xc_DiUM,6328
8
- egglog/declarations.py,sha256=LFClid1ojrQFF7ezJSanwZA89j01fjJ-NkvH0VUH0Wk,16199
9
- egglog/egraph.py,sha256=WRu_omk-5i53wokpkSTw9wA2EqW_WI2-ZX8340_w3q8,66198
10
- egglog/egraph_state.py,sha256=ic6MEL41-DjozBCe-GSKqHM2vkabuC865odeg_zT964,20074
8
+ egglog/declarations.py,sha256=WpZ1yyw4vlWm3Xi_otPP2IU51IqNcDmJA5y20RcBFBQ,18583
9
+ egglog/egraph.py,sha256=Q3ZJhtzzWcsukuarT0vT48-DkfHe17adLLUHn-pILPE,68657
10
+ egglog/egraph_state.py,sha256=kXCNt4zenG1meRv5T6B8xI5oH_NreA06JMhmqv_F-og,21802
11
11
  egglog/examples/bool.py,sha256=pWZTjfXR1cFy3KcihLBU5AF5rn83ImORlhUUJ1YiAXc,733
12
12
  egglog/examples/eqsat_basic.py,sha256=ORXFYYEDsEZK2IPhHtoFsd-LdjMiQi1nn7kix4Nam0s,1011
13
13
  egglog/examples/fib.py,sha256=wAn-PjazxgHDkXAU4o2xTk_GtM_iGL0biV66vWM1st4,520
14
+ egglog/examples/higher_order_functions.py,sha256=CBZqnqVdYGgztHw5QWtcnUhzs1nnBcw4O-gjSzAxjRc,1203
14
15
  egglog/examples/lambda_.py,sha256=hQBOaSw_yorNcbkQVu2EhgSc0IZNWIny7asaOlcUk9s,8496
15
16
  egglog/examples/matrix.py,sha256=_zmjgfFr2O_LjTcsTD-45_38Y_M1sP3AV39K6oFxAdw,5136
16
17
  egglog/examples/ndarrays.py,sha256=T-wwef-n-3LDSjaO35zA8AZH5DXFFqq0XBSCQKEXV6E,4186
@@ -18,7 +19,7 @@ egglog/examples/README.rst,sha256=QrbfmivODBvUvmY3-dHarcbC6bEvwoqAfTDhiI-aJxU,23
18
19
  egglog/examples/resolution.py,sha256=sKkbRI_v9XkQM0DriacKLINqKKDqYGFhvMCAS9tZbTA,2203
19
20
  egglog/examples/schedule_demo.py,sha256=iJtIbcLaZ7zK8UalY0z7KAKMqYjQx0MKTsNF24lKtik,652
20
21
  egglog/examples/__init__.py,sha256=KuhaJFOyz_rpUvEqZubsgLnv6rhQNE_AVFXA6bUnpdY,34
21
- egglog/exp/array_api.py,sha256=diQNK-Uxxt6Qr58zJuQDM_BF6eKgwa-D5eYB5VJoT70,41412
22
+ egglog/exp/array_api.py,sha256=b7MoUuH2sIW3KnUX7fGPaflwlCCskgrt1GsidVT-KG4,41474
22
23
  egglog/exp/array_api_jit.py,sha256=HIZzd0G17u-u_F4vfRdhoYvRo-ETx5HFO3RBcOfLcco,1287
23
24
  egglog/exp/array_api_numba.py,sha256=OjONBLdFRukx4vKiNK_rBxjMzsxbWpMEdD7JcGHJjmY,2924
24
25
  egglog/exp/array_api_program_gen.py,sha256=crgUYXXNhQdfTq31FSIpWLIfzNsgQD8ngg3OosCtIgg,19680
@@ -27,13 +28,13 @@ egglog/exp/siu_examples.py,sha256=KZmpkSCgbL4uqHhx2Jh9Adz1_cDbkIlyXPa1Kh_0O_8,78
27
28
  egglog/exp/__init__.py,sha256=G9zeKUcPBgIhgUg1meC86OfZVFETYIugyHWseTcCe_0,52
28
29
  egglog/graphviz_widget.py,sha256=YtI7LCFWHihDQ1qLvdj2SVYEcuZLSooFUYheunOTxdc,1339
29
30
  egglog/ipython_magic.py,sha256=VA19xAb6Sz7-IxlJBbnZW_gVFDqaYNnvdMB9QitndjE,1254
30
- egglog/pretty.py,sha256=4YzItMaXSQw2xzId9pnjUikOojwE6G7Hg0_M4sOGJ84,19071
31
+ egglog/pretty.py,sha256=s5XBIB7kMrxV25aO-F4NW5Jli2DOoc72wwswFYxJPjM,19561
31
32
  egglog/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
32
- egglog/runtime.py,sha256=P5wKThFglf2l5AynxgVAryyNjvLtFMbou7JCSSsORno,23013
33
- egglog/thunk.py,sha256=jYMb4owraT-3zGQEYkqMK76hk39Qdpc6bfEa7CVkjcg,1894
34
- egglog/type_constraint_solver.py,sha256=0c8oy-sLCVN0XndfAyRxO1AvIgfQuszw0NaRJ9ILysE,5692
33
+ egglog/runtime.py,sha256=s32xdoD_mIdDRi_PJQ5E_kSNyhUkT310osH38kZfCSk,23174
34
+ egglog/thunk.py,sha256=bm1zZUzwey0DijtMZkSulyfiWjCAg_1UmXaSG8qvvjs,2170
35
+ egglog/type_constraint_solver.py,sha256=MW22gwryNOqsYaB__dCCjVvdnB1Km5fovoW2IgUXaVo,5836
35
36
  egglog/widget.css,sha256=WJS2M1wQdujhSTCakMa_ZXuoTPre1Uy1lPcvBE1LZQU,102
36
37
  egglog/widget.js,sha256=UNOER3sYZ-bS7Qhw9S6qtpR81FuHa5DzXQaWWtQq334,2021
37
38
  egglog/__init__.py,sha256=iUrVe5fb0XFyMCS3CwTjwhKEtU8KpIkdTpJpnUhm8o0,247
38
- egglog/bindings.cp310-win_amd64.pyd,sha256=pjYXlaSXtC50TWNdnGwLjKbntZ0WqfM5AR9hqehz2xc,4591104
39
- egglog-7.1.0.dist-info/RECORD,,
39
+ egglog/bindings.cp310-win_amd64.pyd,sha256=NO6b-M4HPsi0uWxf6FKU7kBdS8WAsLbrJXeclR6XjuU,4591616
40
+ egglog-7.2.0.dist-info/RECORD,,
File without changes