egglog 7.0.0__cp312-none-win_amd64.whl → 7.2.0__cp312-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of egglog has been flagged as potentially problematic; see the package's page on the registry for details.

egglog/egraph.py CHANGED
@@ -75,6 +75,7 @@ __all__ = [
75
75
  "seq",
76
76
  "Command",
77
77
  "simplify",
78
+ "unstable_combine_rulesets",
78
79
  "check",
79
80
  "GraphvizKwargs",
80
81
  "Ruleset",
@@ -88,6 +89,7 @@ __all__ = [
88
89
  "Fact",
89
90
  "Action",
90
91
  "Command",
92
+ "check_eq",
91
93
  ]
92
94
 
93
95
  T = TypeVar("T")
@@ -145,6 +147,23 @@ def simplify(x: EXPR, schedule: Schedule | None = None) -> EXPR:
145
147
  return EGraph().extract(x)
146
148
 
147
149
 
150
+ def check_eq(x: EXPR, y: EXPR, schedule: Schedule | None = None) -> EGraph:
151
+ """
152
+ Verifies that two expressions are equal after running the schedule.
153
+ """
154
+ egraph = EGraph()
155
+ x_var = egraph.let("__check_eq_x", x)
156
+ y_var = egraph.let("__check_eq_y", y)
157
+ if schedule:
158
+ egraph.run(schedule)
159
+ fact = eq(x_var).to(y_var)
160
+ try:
161
+ egraph.check(fact)
162
+ except bindings.EggSmolError as err:
163
+ raise AssertionError(f"Failed {eq(x).to(y)}\n -> {ne(egraph.extract(x)).to(egraph.extract(y))})") from err
164
+ return egraph
165
+
166
+
148
167
  def check(x: FactLike, schedule: Schedule | None = None, *given: ActionLike) -> None:
149
168
  """
150
169
  Verifies that the fact is true given some assumptions and after running the schedule.
@@ -333,7 +352,7 @@ class _BaseModule:
333
352
  This is the same as defining a nullary function with a high cost.
334
353
  # TODO: Rename as declare to match eggglog?
335
354
  """
336
- return constant(name, tp, egg_name)
355
+ return constant(name, tp, egg_name=egg_name)
337
356
 
338
357
  def register(
339
358
  self,
@@ -433,6 +452,7 @@ class _ExprMetaclass(type):
433
452
  namespace: dict[str, Any],
434
453
  egg_sort: str | None = None,
435
454
  builtin: bool = False,
455
+ ruleset: Ruleset | None = None,
436
456
  ) -> RuntimeClass | type:
437
457
  # If this is the Expr subclass, just return the class
438
458
  if not bases:
@@ -448,7 +468,14 @@ class _ExprMetaclass(type):
448
468
  # Otherwise, f_locals returns a copy
449
469
  # https://peps.python.org/pep-0667/
450
470
  decls_thunk = Thunk.fn(
451
- _generate_class_decls, namespace, prev_frame, builtin, egg_sort, name, fallback=Declarations
471
+ _generate_class_decls,
472
+ namespace,
473
+ prev_frame,
474
+ builtin,
475
+ egg_sort,
476
+ name,
477
+ ruleset,
478
+ fallback=Declarations,
452
479
  )
453
480
  return RuntimeClass(decls_thunk, TypeRefWithVars(name))
454
481
 
@@ -457,7 +484,12 @@ class _ExprMetaclass(type):
457
484
 
458
485
 
459
486
  def _generate_class_decls(
460
- namespace: dict[str, Any], frame: FrameType, builtin: bool, egg_sort: str | None, cls_name: str
487
+ namespace: dict[str, Any],
488
+ frame: FrameType,
489
+ builtin: bool,
490
+ egg_sort: str | None,
491
+ cls_name: str,
492
+ ruleset: Ruleset | None,
461
493
  ) -> Declarations:
462
494
  """
463
495
  Lazy constructor for class declerations to support classes with methods whose types are not yet defined.
@@ -479,9 +511,9 @@ def _generate_class_decls(
479
511
  for k, v in get_type_hints(_Dummytype, globalns=frame.f_globals, localns=frame.f_locals).items():
480
512
  if getattr(v, "__origin__", None) == ClassVar:
481
513
  (inner_tp,) = v.__args__
482
- type_ref = resolve_type_annotation(decls, inner_tp).to_just()
483
- cls_decl.class_variables[k] = ConstantDecl(type_ref)
484
-
514
+ type_ref = resolve_type_annotation(decls, inner_tp)
515
+ cls_decl.class_variables[k] = ConstantDecl(type_ref.to_just())
516
+ _add_default_rewrite(decls, ClassVariableRef(cls_name, k), type_ref, namespace.pop(k, None), ruleset)
485
517
  else:
486
518
  msg = f"On class {cls_name}, for attribute '{k}', expected a ClassVar, but got {v}"
487
519
  raise NotImplementedError(msg)
@@ -491,13 +523,15 @@ def _generate_class_decls(
491
523
  ##
492
524
 
493
525
  # The type ref of self is paramterized by the type vars
494
- slf_type_ref = TypeRefWithVars(cls_name, tuple(map(ClassTypeVarRef, type_vars)))
526
+ TypeRefWithVars(cls_name, tuple(map(ClassTypeVarRef, type_vars)))
495
527
 
496
528
  # Get all the methods from the class
497
529
  filtered_namespace: list[tuple[str, Any]] = [
498
530
  (k, v) for k, v in namespace.items() if k not in IGNORED_ATTRIBUTES or isinstance(v, _WrappedMethod)
499
531
  ]
500
532
 
533
+ # all methods we should try adding default functions for
534
+ default_function_refs: dict[ClassMethodRef | MethodRef | PropertyRef, Callable] = {}
501
535
  # Then register each of its methods
502
536
  for method_name, method in filtered_namespace:
503
537
  is_init = method_name == "__init__"
@@ -517,33 +551,26 @@ def _generate_class_decls(
517
551
  continue
518
552
  locals = frame.f_locals
519
553
 
520
- def create_decl(fn: object, first: Literal["cls"] | TypeRefWithVars) -> FunctionDecl:
521
- return _fn_decl(
522
- decls,
523
- egg_fn, # noqa: B023
524
- fn,
525
- locals, # noqa: B023
526
- default, # noqa: B023
527
- cost, # noqa: B023
528
- merge, # noqa: B023
529
- on_merge, # noqa: B023
530
- mutates, # noqa: B023
531
- builtin,
532
- first,
533
- is_init, # noqa: B023
534
- unextractable, # noqa: B023
535
- )
536
-
554
+ ref: ClassMethodRef | MethodRef | PropertyRef | InitRef
537
555
  match fn:
538
556
  case classmethod():
539
- cls_decl.class_methods[method_name] = create_decl(fn.__func__, "cls")
557
+ ref = ClassMethodRef(cls_name, method_name)
558
+ fn = fn.__func__
540
559
  case property():
541
- cls_decl.properties[method_name] = create_decl(fn.fget, slf_type_ref)
560
+ ref = PropertyRef(cls_name, method_name)
561
+ fn = fn.fget
542
562
  case _:
543
- if is_init:
544
- cls_decl.class_methods[method_name] = create_decl(fn, slf_type_ref)
545
- else:
546
- cls_decl.methods[method_name] = create_decl(fn, slf_type_ref)
563
+ ref = InitRef(cls_name) if is_init else MethodRef(cls_name, method_name)
564
+
565
+ _fn_decl(decls, egg_fn, ref, fn, locals, default, cost, merge, on_merge, mutates, builtin, unextractable)
566
+
567
+ if not builtin and not isinstance(ref, InitRef) and not mutates:
568
+ default_function_refs[ref] = fn
569
+
570
+ # Add all rewrite methods at the end so that all methods are registered first and can be accessed
571
+ # in the bodies
572
+ for ref, fn in default_function_refs.items():
573
+ _add_default_rewrite_function(decls, ref, fn, ruleset)
547
574
 
548
575
  return decls
549
576
 
@@ -562,6 +589,7 @@ def function(
562
589
  mutates_first_arg: bool = False,
563
590
  unextractable: bool = False,
564
591
  builtin: bool = False,
592
+ ruleset: Ruleset | None = None,
565
593
  ) -> Callable[[CALLABLE], CALLABLE]: ...
566
594
 
567
595
 
@@ -575,6 +603,7 @@ def function(
575
603
  on_merge: Callable[[EXPR, EXPR], Iterable[ActionLike]] | None = None,
576
604
  mutates_first_arg: bool = False,
577
605
  unextractable: bool = False,
606
+ ruleset: Ruleset | None = None,
578
607
  ) -> Callable[[Callable[P, EXPR]], Callable[P, EXPR]]: ...
579
608
 
580
609
 
@@ -605,15 +634,17 @@ class _FunctionConstructor:
605
634
  merge: Callable[[RuntimeExpr, RuntimeExpr], RuntimeExpr] | None = None
606
635
  on_merge: Callable[[RuntimeExpr, RuntimeExpr], Iterable[ActionLike]] | None = None
607
636
  unextractable: bool = False
637
+ ruleset: Ruleset | None = None
608
638
 
609
639
  def __call__(self, fn: Callable[..., RuntimeExpr]) -> RuntimeFunction:
610
640
  return RuntimeFunction(Thunk.fn(self.create_decls, fn), FunctionRef(fn.__name__))
611
641
 
612
642
  def create_decls(self, fn: Callable[..., RuntimeExpr]) -> Declarations:
613
643
  decls = Declarations()
614
- decls._functions[fn.__name__] = _fn_decl(
644
+ _fn_decl(
615
645
  decls,
616
646
  self.egg_fn,
647
+ (ref := FunctionRef(fn.__name__)),
617
648
  fn,
618
649
  self.hint_locals,
619
650
  self.default,
@@ -624,12 +655,14 @@ class _FunctionConstructor:
624
655
  self.builtin,
625
656
  unextractable=self.unextractable,
626
657
  )
658
+ _add_default_rewrite_function(decls, ref, fn, self.ruleset)
627
659
  return decls
628
660
 
629
661
 
630
662
  def _fn_decl(
631
663
  decls: Declarations,
632
664
  egg_name: str | None,
665
+ ref: FunctionRef | MethodRef | PropertyRef | ClassMethodRef | InitRef,
633
666
  fn: object,
634
667
  # Pass in the locals, retrieved from the frame when wrapping,
635
668
  # so that we support classes and function defined inside of other functions (which won't show up in the globals)
@@ -640,22 +673,35 @@ def _fn_decl(
640
673
  on_merge: Callable[[RuntimeExpr, RuntimeExpr], Iterable[ActionLike]] | None,
641
674
  mutates_first_arg: bool,
642
675
  is_builtin: bool,
643
- # The first arg is either cls, for a classmethod, a self type, or none for a function
644
- first_arg: Literal["cls"] | TypeOrVarRef | None = None,
645
- is_init: bool = False,
646
676
  unextractable: bool = False,
647
- ) -> FunctionDecl:
677
+ ) -> None:
678
+ """
679
+ Sets the function decl for the function object.
680
+ """
648
681
  if not isinstance(fn, FunctionType):
649
682
  raise NotImplementedError(f"Can only generate function decls for functions not {fn} {type(fn)}")
650
683
 
684
+ # partial function creation and calling are handled with a special case in the type checker, so don't
685
+ # use the normal logic
686
+ special_function_name: SpecialFunctions | None = (
687
+ "fn-partial" if egg_name == "unstable-fn" else "fn-app" if egg_name == "unstable-app" else None
688
+ )
689
+ if special_function_name:
690
+ decls.set_function_decl(ref, FunctionDecl(special_function_name, builtin=True, egg_name=egg_name))
691
+ return
692
+
651
693
  hint_globals = fn.__globals__.copy()
694
+ # Copy Callable into global if not present bc sometimes it gets automatically removed by ruff to type only block
695
+ # https://docs.astral.sh/ruff/rules/typing-only-standard-library-import/
696
+ if "Callable" not in hint_globals:
697
+ hint_globals["Callable"] = Callable
652
698
 
653
699
  hints = get_type_hints(fn, hint_globals, hint_locals)
654
700
 
655
701
  params = list(signature(fn).parameters.values())
656
702
 
657
- # If this is an init function, or a classmethod, remove the first arg name
658
- if is_init or first_arg == "cls":
703
+ # If this is an init function, or a classmethod, the first arg is not used
704
+ if isinstance(ref, ClassMethodRef | InitRef):
659
705
  params = params[1:]
660
706
 
661
707
  if _last_param_variable(params):
@@ -665,9 +711,8 @@ def _fn_decl(
665
711
  else:
666
712
  var_arg_type = None
667
713
  arg_types = tuple(
668
- first_arg
669
- # If the first arg is a self, and this not an __init__ fn, add this as a typeref
670
- if i == 0 and isinstance(first_arg, ClassTypeVarRef | TypeRefWithVars) and not is_init
714
+ decls.get_paramaterized_class(ref.class_name)
715
+ if i == 0 and isinstance(ref, MethodRef | PropertyRef)
671
716
  else resolve_type_annotation(decls, hints[t.name])
672
717
  for i, t in enumerate(params)
673
718
  )
@@ -680,17 +725,15 @@ def _fn_decl(
680
725
 
681
726
  decls.update(*arg_defaults)
682
727
 
683
- # If this is an init fn use the first arg as the return type
684
- if is_init:
685
- assert not mutates_first_arg
686
- if not isinstance(first_arg, ClassTypeVarRef | TypeRefWithVars):
687
- msg = "Init function must have a self type"
688
- raise ValueError(msg)
689
- return_type = first_arg
690
- elif mutates_first_arg:
691
- return_type = arg_types[0]
692
- else:
693
- return_type = resolve_type_annotation(decls, hints["return"])
728
+ return_type = (
729
+ decls.get_paramaterized_class(ref.class_name)
730
+ if isinstance(ref, InitRef)
731
+ else arg_types[0]
732
+ if mutates_first_arg
733
+ else resolve_type_annotation(decls, hints["return"])
734
+ )
735
+
736
+ arg_names = tuple(t.name for t in params)
694
737
 
695
738
  decls |= default
696
739
  merged = (
@@ -714,12 +757,14 @@ def _fn_decl(
714
757
  )
715
758
  )
716
759
  decls.update(*merge_action)
717
- return FunctionDecl(
718
- return_type=None if mutates_first_arg else return_type,
719
- var_arg_type=var_arg_type,
720
- arg_types=arg_types,
721
- arg_names=tuple(t.name for t in params),
722
- arg_defaults=tuple(a.__egg_typed_expr__.expr if a is not None else None for a in arg_defaults),
760
+ decl = FunctionDecl(
761
+ FunctionSignature(
762
+ return_type=None if mutates_first_arg else return_type,
763
+ var_arg_type=var_arg_type,
764
+ arg_types=arg_types,
765
+ arg_names=arg_names,
766
+ arg_defaults=tuple(a.__egg_typed_expr__.expr if a is not None else None for a in arg_defaults),
767
+ ),
723
768
  cost=cost,
724
769
  egg_name=egg_name,
725
770
  merge=merged.__egg_typed_expr__.expr if merged is not None else None,
@@ -728,6 +773,7 @@ def _fn_decl(
728
773
  default=None if default is None else default.__egg_typed_expr__.expr,
729
774
  on_merge=tuple(a.action for a in merge_action),
730
775
  )
776
+ decls.set_function_decl(ref, decl)
731
777
 
732
778
 
733
779
  # Overload to support aritys 0-4 until variadic generic support map, so we can map from type to value
@@ -769,19 +815,73 @@ def _relation_decls(name: str, tps: tuple[type, ...], egg_fn: str | None) -> Dec
769
815
  return decls
770
816
 
771
817
 
772
- def constant(name: str, tp: type[EXPR], egg_name: str | None = None) -> EXPR:
818
+ def constant(
819
+ name: str,
820
+ tp: type[EXPR],
821
+ default_replacement: EXPR | None = None,
822
+ /,
823
+ *,
824
+ egg_name: str | None = None,
825
+ ruleset: Ruleset | None = None,
826
+ ) -> EXPR:
773
827
  """
774
828
  A "constant" is implemented as the instantiation of a value that takes no args.
775
829
  This creates a function with `name` and return type `tp` and returns a value of it being called.
776
830
  """
777
- return cast(EXPR, RuntimeExpr(Thunk.fn(_constant_thunk, name, tp, egg_name)))
831
+ return cast(EXPR, RuntimeExpr(Thunk.fn(_constant_thunk, name, tp, egg_name, default_replacement, ruleset)))
778
832
 
779
833
 
780
- def _constant_thunk(name: str, tp: type, egg_name: str | None) -> tuple[Declarations, TypedExprDecl]:
834
+ def _constant_thunk(
835
+ name: str, tp: type, egg_name: str | None, default_replacement: object, ruleset: Ruleset | None
836
+ ) -> tuple[Declarations, TypedExprDecl]:
781
837
  decls = Declarations()
782
- type_ref = resolve_type_annotation(decls, tp).to_just()
783
- decls._constants[name] = ConstantDecl(type_ref, egg_name)
784
- return decls, TypedExprDecl(type_ref, CallDecl(ConstantRef(name)))
838
+ type_ref = resolve_type_annotation(decls, tp)
839
+ callable_ref = ConstantRef(name)
840
+ decls._constants[name] = ConstantDecl(type_ref.to_just(), egg_name)
841
+ _add_default_rewrite(decls, callable_ref, type_ref, default_replacement, ruleset)
842
+ return decls, TypedExprDecl(type_ref.to_just(), CallDecl(callable_ref))
843
+
844
+
845
+ def _add_default_rewrite_function(decls: Declarations, ref: CallableRef, fn: Callable, ruleset: Ruleset | None) -> None:
846
+ """
847
+ Adds a default rewrite for a function, by calling the functions with vars and adding it if it is not None.
848
+ """
849
+ callable_decl = decls.get_callable_decl(ref)
850
+ assert isinstance(callable_decl, FunctionDecl)
851
+ signature = callable_decl.signature
852
+ assert isinstance(signature, FunctionSignature)
853
+
854
+ var_args: list[object] = [
855
+ RuntimeExpr.__from_value__(decls, TypedExprDecl(tp.to_just(), VarDecl(_rule_var_name(name))))
856
+ for name, tp in zip(signature.arg_names, signature.arg_types, strict=False)
857
+ ]
858
+ # If this is a classmethod, add the class as the first arg
859
+ if isinstance(ref, ClassMethodRef):
860
+ tp = decls.get_paramaterized_class(ref.class_name)
861
+ var_args.insert(0, RuntimeClass(Thunk.value(decls), tp))
862
+ _add_default_rewrite(decls, ref, signature.semantic_return_type, fn(*var_args), ruleset)
863
+
864
+
865
+ def _add_default_rewrite(
866
+ decls: Declarations, ref: CallableRef, type_ref: TypeOrVarRef, default_rewrite: object, ruleset: Ruleset | None
867
+ ) -> None:
868
+ """
869
+ Adds a default rewrite for the callable, if the default rewrite is not None
870
+
871
+ Will add it to the ruleset if it is passed in, or add it to the default ruleset on the passed in decls if not.
872
+ """
873
+ if default_rewrite is None:
874
+ return
875
+ resolved_value = resolve_literal(type_ref, default_rewrite)
876
+ rewrite_decl = DefaultRewriteDecl(ref, resolved_value.__egg_typed_expr__.expr)
877
+ if ruleset:
878
+ ruleset_decls = ruleset._current_egg_decls
879
+ ruleset_decl = ruleset.__egg_ruleset__
880
+ else:
881
+ ruleset_decls = decls
882
+ ruleset_decl = decls.default_ruleset
883
+ ruleset_decl.rules.append(rewrite_decl)
884
+ ruleset_decls |= resolved_value
785
885
 
786
886
 
787
887
  def _last_param_variable(params: list[Parameter]) -> bool:
@@ -933,13 +1033,12 @@ class EGraph(_BaseModule):
933
1033
  """
934
1034
  Displays the e-graph in the notebook.
935
1035
  """
936
- graphviz = self.graphviz(**kwargs)
937
1036
  if IN_IPYTHON:
938
1037
  from IPython.display import SVG, display
939
1038
 
940
1039
  display(SVG(self.graphviz_svg(**kwargs)))
941
1040
  else:
942
- graphviz.render(view=True, format="svg", quiet=True)
1041
+ self.graphviz(**kwargs).render(view=True, format="svg", quiet=True)
943
1042
 
944
1043
  def input(self, fn: Callable[..., String], path: str) -> None:
945
1044
  """
@@ -957,6 +1056,7 @@ class EGraph(_BaseModule):
957
1056
  action = let(name, expr)
958
1057
  self.register(action)
959
1058
  runtime_expr = to_runtime_expr(expr)
1059
+ self._add_decls(runtime_expr)
960
1060
  return cast(
961
1061
  EXPR,
962
1062
  RuntimeExpr.__from_value__(
@@ -982,7 +1082,8 @@ class EGraph(_BaseModule):
982
1082
  self._add_decls(runtime_expr, schedule)
983
1083
  egg_schedule = self._state.schedule_to_egg(schedule.schedule)
984
1084
  typed_expr = runtime_expr.__egg_typed_expr__
985
- egg_expr = self._state.expr_to_egg(typed_expr.expr)
1085
+ # Must also register type
1086
+ egg_expr = self._state.typed_expr_to_egg(typed_expr)
986
1087
  self._egraph.run_program(bindings.Simplify(egg_expr, egg_schedule))
987
1088
  extract_report = self._egraph.extract_report()
988
1089
  if not isinstance(extract_report, bindings.Best):
@@ -1059,7 +1160,7 @@ class EGraph(_BaseModule):
1059
1160
  runtime_expr = to_runtime_expr(expr)
1060
1161
  self._add_decls(runtime_expr)
1061
1162
  typed_expr = runtime_expr.__egg_typed_expr__
1062
- extract_report = self._run_extract(typed_expr.expr, 0)
1163
+ extract_report = self._run_extract(typed_expr, 0)
1063
1164
 
1064
1165
  if not isinstance(extract_report, bindings.Best):
1065
1166
  msg = "No extract report saved"
@@ -1079,15 +1180,15 @@ class EGraph(_BaseModule):
1079
1180
  self._add_decls(runtime_expr)
1080
1181
  typed_expr = runtime_expr.__egg_typed_expr__
1081
1182
 
1082
- extract_report = self._run_extract(typed_expr.expr, n)
1183
+ extract_report = self._run_extract(typed_expr, n)
1083
1184
  if not isinstance(extract_report, bindings.Variants):
1084
1185
  msg = "Wrong extract report type"
1085
1186
  raise ValueError(msg) # noqa: TRY004
1086
1187
  new_exprs = self._state.exprs_from_egg(extract_report.termdag, extract_report.terms, typed_expr.tp)
1087
1188
  return [cast(EXPR, RuntimeExpr.__from_value__(self.__egg_decls__, expr)) for expr in new_exprs]
1088
1189
 
1089
- def _run_extract(self, expr: ExprDecl, n: int) -> bindings._ExtractReport:
1090
- expr = self._state.expr_to_egg(expr)
1190
+ def _run_extract(self, typed_expr: TypedExprDecl, n: int) -> bindings._ExtractReport:
1191
+ expr = self._state.typed_expr_to_egg(typed_expr)
1091
1192
  self._egraph.run_program(bindings.ActionCommand(bindings.Extract(expr, bindings.Lit(bindings.Int(n)))))
1092
1193
  extract_report = self._egraph.extract_report()
1093
1194
  if not extract_report:
@@ -1146,7 +1247,7 @@ class EGraph(_BaseModule):
1146
1247
  runtime_expr = to_runtime_expr(expr)
1147
1248
  self._add_decls(runtime_expr)
1148
1249
  typed_expr = runtime_expr.__egg_typed_expr__
1149
- egg_expr = self._state.expr_to_egg(typed_expr.expr)
1250
+ egg_expr = self._state.typed_expr_to_egg(typed_expr)
1150
1251
  match typed_expr.tp:
1151
1252
  case JustTypeRef("i64"):
1152
1253
  return self._egraph.eval_i64(egg_expr)
@@ -1276,8 +1377,10 @@ def ruleset(
1276
1377
  """
1277
1378
  Creates a ruleset with the following rules.
1278
1379
 
1279
- If no name is provided, one is generated based on the current module
1380
+ If no name is provided, try using the name of the funciton.
1280
1381
  """
1382
+ if isinstance(rule_or_generator, FunctionType):
1383
+ name = name or rule_or_generator.__name__
1281
1384
  r = Ruleset(name)
1282
1385
  if rule_or_generator is not None:
1283
1386
  r.register(rule_or_generator, *rules, _increase_frame=True)
@@ -1388,12 +1491,48 @@ class Ruleset(Schedule):
1388
1491
  def __repr__(self) -> str:
1389
1492
  return str(self)
1390
1493
 
1494
+ def __or__(self, other: Ruleset | UnstableCombinedRuleset) -> UnstableCombinedRuleset:
1495
+ return unstable_combine_rulesets(self, other)
1496
+
1391
1497
  # Create a unique name if we didn't pass one from the user
1392
1498
  @property
1393
1499
  def __egg_name__(self) -> str:
1394
1500
  return self.name or f"ruleset_{id(self)}"
1395
1501
 
1396
1502
 
1503
+ @dataclass
1504
+ class UnstableCombinedRuleset(Schedule):
1505
+ __egg_decls_thunk__: Callable[[], Declarations] = field(init=False)
1506
+ schedule: RunDecl = field(init=False)
1507
+ name: str | None
1508
+ rulesets: InitVar[list[Ruleset | UnstableCombinedRuleset]]
1509
+
1510
+ def __post_init__(self, rulesets: list[Ruleset | UnstableCombinedRuleset]) -> None:
1511
+ self.schedule = RunDecl(self.__egg_name__, ())
1512
+ self.__egg_decls_thunk__ = Thunk.fn(self._create_egg_decls, *rulesets)
1513
+
1514
+ @property
1515
+ def __egg_name__(self) -> str:
1516
+ return self.name or f"combined_ruleset_{id(self)}"
1517
+
1518
+ def _create_egg_decls(self, *rulesets: Ruleset | UnstableCombinedRuleset) -> Declarations:
1519
+ decls = Declarations.create(*rulesets)
1520
+ decls._rulesets[self.__egg_name__] = CombinedRulesetDecl(tuple(r.__egg_name__ for r in rulesets))
1521
+ return decls
1522
+
1523
+ def __or__(self, other: Ruleset | UnstableCombinedRuleset) -> UnstableCombinedRuleset:
1524
+ return unstable_combine_rulesets(self, other)
1525
+
1526
+
1527
+ def unstable_combine_rulesets(
1528
+ *rulesets: Ruleset | UnstableCombinedRuleset, name: str | None = None
1529
+ ) -> UnstableCombinedRuleset:
1530
+ """
1531
+ Combine multiple rulesets into a single ruleset.
1532
+ """
1533
+ return UnstableCombinedRuleset(name, list(rulesets))
1534
+
1535
+
1397
1536
  @dataclass
1398
1537
  class RewriteOrRule:
1399
1538
  __egg_decls__: Declarations
@@ -1556,9 +1695,9 @@ def var(name: str, bound: type[EXPR]) -> EXPR:
1556
1695
 
1557
1696
  def _var(name: str, bound: object) -> RuntimeExpr:
1558
1697
  """Create a new variable with the given name and type."""
1559
- if not isinstance(bound, RuntimeClass):
1560
- raise TypeError(f"Unexpected type {type(bound)}")
1561
- return RuntimeExpr.__from_value__(bound.__egg_decls__, TypedExprDecl(bound.__egg_tp__.to_just(), VarDecl(name)))
1698
+ decls = Declarations()
1699
+ type_ref = resolve_type_annotation(decls, bound)
1700
+ return RuntimeExpr.__from_value__(decls, TypedExprDecl(type_ref.to_just(), VarDecl(name)))
1562
1701
 
1563
1702
 
1564
1703
  def vars_(names: str, bound: type[EXPR]) -> Iterable[EXPR]:
@@ -1801,9 +1940,11 @@ def _rewrite_or_rule_generator(gen: RewriteOrRuleGenerator, frame: FrameType) ->
1801
1940
  """
1802
1941
  # Get the local scope from where the function is defined, so that we can get any type hints that are in the scope
1803
1942
  # but not in the globals
1804
-
1805
- hints = get_type_hints(gen, gen.__globals__, frame.f_locals)
1806
- args = [_var(p.name, hints[p.name]) for p in signature(gen).parameters.values()]
1943
+ globals = gen.__globals__.copy()
1944
+ if "Callable" not in globals:
1945
+ globals["Callable"] = Callable
1946
+ hints = get_type_hints(gen, globals, frame.f_locals)
1947
+ args = [_var(_rule_var_name(p.name), hints[p.name]) for p in signature(gen).parameters.values()]
1807
1948
  return list(gen(*args)) # type: ignore[misc]
1808
1949
 
1809
1950