egglog 7.0.0__cp310-none-win_amd64.whl → 7.1.0__cp310-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of egglog might be problematic; see the package's registry page for more details.

Binary file
egglog/bindings.pyi CHANGED
@@ -494,6 +494,12 @@ class Relation:
494
494
  class PrintOverallStatistics:
495
495
  def __init__(self) -> None: ...
496
496
 
497
+ @final
498
+ class UnstableCombinedRuleset:
499
+ name: str
500
+ rulesets: list[str]
501
+ def __init__(self, name: str, rulesets: list[str]) -> None: ...
502
+
497
503
  _Command: TypeAlias = (
498
504
  SetOption
499
505
  | Datatype
@@ -521,6 +527,7 @@ _Command: TypeAlias = (
521
527
  | CheckProof
522
528
  | Relation
523
529
  | PrintOverallStatistics
530
+ | UnstableCombinedRuleset
524
531
  )
525
532
 
526
533
  def termdag_term_to_expr(termdag: TermDag, term: _Term) -> _Expr: ...
egglog/builtins.py CHANGED
@@ -5,10 +5,14 @@ Builtin sorts and function to egg.
5
5
 
6
6
  from __future__ import annotations
7
7
 
8
- from typing import TYPE_CHECKING, Generic, Protocol, TypeAlias, TypeVar, Union
8
+ from functools import partial
9
+ from typing import TYPE_CHECKING, Generic, Protocol, TypeAlias, TypeVar, Union, overload
10
+
11
+ from typing_extensions import TypeVarTuple, Unpack
9
12
 
10
13
  from .conversion import converter
11
14
  from .egraph import Expr, Unit, function, method
15
+ from .runtime import RuntimeFunction
12
16
 
13
17
  if TYPE_CHECKING:
14
18
  from collections.abc import Callable
@@ -31,6 +35,7 @@ __all__ = [
31
35
  "py_eval",
32
36
  "py_exec",
33
37
  "py_eval_fn",
38
+ "UnstableFn",
34
39
  ]
35
40
 
36
41
 
@@ -461,3 +466,38 @@ def py_exec(code: StringLike, globals: object = PyObject.dict(), locals: object
461
466
  """
462
467
  Copies the locals, execs the Python code, and returns the locals with any updates.
463
468
  """
469
+
470
+
471
+ TS = TypeVarTuple("TS")
472
+
473
+ T1 = TypeVar("T1")
474
+ T2 = TypeVar("T2")
475
+ T3 = TypeVar("T3")
476
+
477
+
478
+ class UnstableFn(Expr, Generic[T, Unpack[TS]], builtin=True):
479
+ @overload
480
+ def __init__(self, f: Callable[[Unpack[TS]], T]) -> None: ...
481
+
482
+ @overload
483
+ def __init__(self, f: Callable[[T1, Unpack[TS]], T], _a: T1, /) -> None: ...
484
+
485
+ @overload
486
+ def __init__(self, f: Callable[[T1, T2, Unpack[TS]], T], _a: T1, _b: T2, /) -> None: ...
487
+
488
+ # Removing due to bug in MyPy
489
+ # https://github.com/python/mypy/issues/17212
490
+ # @overload
491
+ # def __init__(self, f: Callable[[T1, T2, T3, Unpack[TS]], T], _a: T1, _b: T2, _c: T3, /) -> None: ...
492
+
493
+ # etc, for partial application
494
+
495
+ @method(egg_fn="unstable-fn")
496
+ def __init__(self, f, *partial) -> None: ...
497
+
498
+ @method(egg_fn="unstable-app")
499
+ def __call__(self, *args: Unpack[TS]) -> T: ...
500
+
501
+
502
+ converter(RuntimeFunction, UnstableFn, UnstableFn)
503
+ converter(partial, UnstableFn, lambda p: UnstableFn(p.func, *p.args))
egglog/conversion.py CHANGED
@@ -1,7 +1,7 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  from dataclasses import dataclass
4
- from typing import TYPE_CHECKING, TypeVar, cast
4
+ from typing import TYPE_CHECKING, NewType, TypeVar, cast
5
5
 
6
6
  from .declarations import *
7
7
  from .pretty import *
@@ -16,7 +16,8 @@ if TYPE_CHECKING:
16
16
 
17
17
  __all__ = ["convert", "converter", "resolve_literal", "convert_to_same_type"]
18
18
  # Mapping from (source type, target type) to and function which takes in the runtimes values of the source and return the target
19
- CONVERSIONS: dict[tuple[type | JustTypeRef, JustTypeRef], tuple[int, Callable]] = {}
19
+ TypeName = NewType("TypeName", str)
20
+ CONVERSIONS: dict[tuple[type | TypeName, TypeName], tuple[int, Callable]] = {}
20
21
  # Global declerations to store all convertable types so we can query if they have certain methods or not
21
22
  # Defer it as a thunk so we can register conversions without triggering type signature loading
22
23
  CONVERSIONS_DECLS: Callable[[], Declarations] = Thunk.value(Declarations())
@@ -34,12 +35,12 @@ def converter(from_type: type[T], to_type: type[V], fn: Callable[[T], V], cost:
34
35
  Register a converter from some type to an egglog type.
35
36
  """
36
37
  to_type_name = process_tp(to_type)
37
- if not isinstance(to_type_name, JustTypeRef):
38
+ if not isinstance(to_type_name, str):
38
39
  raise TypeError(f"Expected return type to be a egglog type, got {to_type_name}")
39
40
  _register_converter(process_tp(from_type), to_type_name, fn, cost)
40
41
 
41
42
 
42
- def _register_converter(a: type | JustTypeRef, b: JustTypeRef, a_b: Callable, cost: int) -> None:
43
+ def _register_converter(a: type | TypeName, b: TypeName, a_b: Callable, cost: int) -> None:
43
44
  """
44
45
  Registers a converter from some type to an egglog type, if not already registered.
45
46
 
@@ -94,14 +95,17 @@ def convert_to_same_type(source: object, target: RuntimeExpr) -> RuntimeExpr:
94
95
  return resolve_literal(tp.to_var(), source)
95
96
 
96
97
 
97
- def process_tp(tp: type | RuntimeClass) -> JustTypeRef | type:
98
+ def process_tp(tp: type | RuntimeClass) -> TypeName | type:
98
99
  """
99
100
  Process a type before converting it, to add it to the global declerations and resolve to a ref.
100
101
  """
101
102
  global CONVERSIONS_DECLS
102
103
  if isinstance(tp, RuntimeClass):
103
104
  CONVERSIONS_DECLS = Thunk.fn(_combine_decls, CONVERSIONS_DECLS, tp)
104
- return tp.__egg_tp__.to_just()
105
+ egg_tp = tp.__egg_tp__
106
+ if egg_tp.args:
107
+ raise TypeError(f"Cannot register a converter for a generic type, got {tp}")
108
+ return TypeName(egg_tp.name)
105
109
  return tp
106
110
 
107
111
 
@@ -109,7 +113,7 @@ def _combine_decls(d: Callable[[], Declarations], x: HasDeclerations) -> Declara
109
113
  return Declarations.create(d(), x)
110
114
 
111
115
 
112
- def min_convertable_tp(a: object, b: object, name: str) -> JustTypeRef:
116
+ def min_convertable_tp(a: object, b: object, name: str) -> TypeName:
113
117
  """
114
118
  Returns the minimum convertable type between a and b, that has a method `name`, raising a ConvertError if no such type exists.
115
119
  """
@@ -117,14 +121,14 @@ def min_convertable_tp(a: object, b: object, name: str) -> JustTypeRef:
117
121
  a_tp = _get_tp(a)
118
122
  b_tp = _get_tp(b)
119
123
  a_converts_to = {
120
- to: c for ((from_, to), (c, _)) in CONVERSIONS.items() if from_ == a_tp and decls.has_method(to.name, name)
124
+ to: c for ((from_, to), (c, _)) in CONVERSIONS.items() if from_ == a_tp and decls.has_method(to, name)
121
125
  }
122
126
  b_converts_to = {
123
- to: c for ((from_, to), (c, _)) in CONVERSIONS.items() if from_ == b_tp and decls.has_method(to.name, name)
127
+ to: c for ((from_, to), (c, _)) in CONVERSIONS.items() if from_ == b_tp and decls.has_method(to, name)
124
128
  }
125
- if isinstance(a_tp, JustTypeRef):
129
+ if isinstance(a_tp, str):
126
130
  a_converts_to[a_tp] = 0
127
- if isinstance(b_tp, JustTypeRef):
131
+ if isinstance(b_tp, str):
128
132
  b_converts_to[b_tp] = 0
129
133
  common = set(a_converts_to) & set(b_converts_to)
130
134
  if not common:
@@ -143,28 +147,29 @@ def resolve_literal(tp: TypeOrVarRef, arg: object) -> RuntimeExpr:
143
147
  try:
144
148
  tp_just = tp.to_just()
145
149
  except NotImplementedError:
146
- # If this is a var, it has to be a runtime exprssions
150
+ # If this is a var, it has to be a runtime expession
147
151
  assert isinstance(arg, RuntimeExpr)
148
152
  return arg
149
- if arg_type == tp_just:
153
+ tp_name = TypeName(tp_just.name)
154
+ if arg_type == tp_name:
150
155
  # If the type is an egg type, it has to be a runtime expr
151
156
  assert isinstance(arg, RuntimeExpr)
152
157
  return arg
153
158
  # Try all parent types as well, if we are converting from a Python type
154
159
  for arg_type_instance in arg_type.__mro__ if isinstance(arg_type, type) else [arg_type]:
155
160
  try:
156
- fn = CONVERSIONS[(cast(JustTypeRef | type, arg_type_instance), tp_just)][1]
161
+ fn = CONVERSIONS[(cast(TypeName | type, arg_type_instance), tp_name)][1]
157
162
  except KeyError:
158
163
  continue
159
164
  break
160
165
  else:
161
- raise ConvertError(f"Cannot convert {arg_type} to {tp_just}")
166
+ raise ConvertError(f"Cannot convert {arg_type} to {tp_name}")
162
167
  return fn(arg)
163
168
 
164
169
 
165
- def _get_tp(x: object) -> JustTypeRef | type:
170
+ def _get_tp(x: object) -> TypeName | type:
166
171
  if isinstance(x, RuntimeExpr):
167
- return x.__egg_typed_expr__.tp
172
+ return TypeName(x.__egg_typed_expr__.tp.name)
168
173
  tp = type(x)
169
174
  # If this value has a custom metaclass, let's use that as our index instead of the type
170
175
  if type(tp) != type:
egglog/declarations.py CHANGED
@@ -39,6 +39,7 @@ __all__ = [
39
39
  "CallableDecl",
40
40
  "VarDecl",
41
41
  "PyObjectDecl",
42
+ "PartialCallDecl",
42
43
  "LitType",
43
44
  "LitDecl",
44
45
  "CallDecl",
@@ -46,6 +47,7 @@ __all__ = [
46
47
  "TypedExprDecl",
47
48
  "ClassDecl",
48
49
  "RulesetDecl",
50
+ "CombinedRulesetDecl",
49
51
  "SaturateDecl",
50
52
  "RepeatDecl",
51
53
  "SequenceDecl",
@@ -67,6 +69,8 @@ __all__ = [
67
69
  "RewriteOrRuleDecl",
68
70
  "ActionCommandDecl",
69
71
  "CommandDecl",
72
+ "SpecialFunctions",
73
+ "FunctionSignature",
70
74
  ]
71
75
 
72
76
 
@@ -88,9 +92,6 @@ class HasDeclerations(Protocol):
88
92
  DeclerationsLike: TypeAlias = Union[HasDeclerations, None, "Declarations"]
89
93
 
90
94
 
91
- # TODO: Make all ClassDecls take deferred type refs, which return new decls when resolving.
92
-
93
-
94
95
  def upcast_declerations(declerations_like: Iterable[DeclerationsLike]) -> list[Declarations]:
95
96
  d = []
96
97
  for l in declerations_like:
@@ -110,7 +111,7 @@ class Declarations:
110
111
  _functions: dict[str, FunctionDecl | RelationDecl] = field(default_factory=dict)
111
112
  _constants: dict[str, ConstantDecl] = field(default_factory=dict)
112
113
  _classes: dict[str, ClassDecl] = field(default_factory=dict)
113
- _rulesets: dict[str, RulesetDecl] = field(default_factory=lambda: {"": RulesetDecl([])})
114
+ _rulesets: dict[str, RulesetDecl | CombinedRulesetDecl] = field(default_factory=lambda: {"": RulesetDecl([])})
114
115
 
115
116
  @classmethod
116
117
  def create(cls, *others: DeclerationsLike) -> Declarations:
@@ -196,7 +197,7 @@ class ClassDecl:
196
197
  preserved_methods: dict[str, Callable] = field(default_factory=dict)
197
198
 
198
199
 
199
- @dataclass
200
+ @dataclass(frozen=True)
200
201
  class RulesetDecl:
201
202
  rules: list[RewriteOrRuleDecl]
202
203
 
@@ -206,6 +207,11 @@ class RulesetDecl:
206
207
  return hash((type(self), tuple(self.rules)))
207
208
 
208
209
 
210
+ @dataclass(frozen=True)
211
+ class CombinedRulesetDecl:
212
+ rulesets: tuple[str, ...]
213
+
214
+
209
215
  # Have two different types of type refs, one that can include vars recursively and one that cannot.
210
216
  # We only use the one with vars for classmethods and methods, and the other one for egg references as
211
217
  # well as runtime values.
@@ -316,10 +322,12 @@ class RelationDecl:
316
322
 
317
323
  def to_function_decl(self) -> FunctionDecl:
318
324
  return FunctionDecl(
319
- arg_types=tuple(a.to_var() for a in self.arg_types),
320
- arg_names=tuple(f"__{i}" for i in range(len(self.arg_types))),
321
- arg_defaults=self.arg_defaults,
322
- return_type=TypeRefWithVars("Unit"),
325
+ FunctionSignature(
326
+ arg_types=tuple(a.to_var() for a in self.arg_types),
327
+ arg_names=tuple(f"__{i}" for i in range(len(self.arg_types))),
328
+ arg_defaults=self.arg_defaults,
329
+ return_type=TypeRefWithVars("Unit"),
330
+ ),
323
331
  egg_name=self.egg_name,
324
332
  default=LitDecl(None),
325
333
  )
@@ -336,25 +344,41 @@ class ConstantDecl:
336
344
 
337
345
  def to_function_decl(self) -> FunctionDecl:
338
346
  return FunctionDecl(
339
- arg_types=(),
340
- arg_names=(),
341
- arg_defaults=(),
342
- return_type=self.type_ref.to_var(),
347
+ FunctionSignature(return_type=self.type_ref.to_var()),
343
348
  egg_name=self.egg_name,
344
349
  )
345
350
 
346
351
 
352
+ # special cases for partial function creation and application, which cannot use the normal python rules
353
+ SpecialFunctions: TypeAlias = Literal["fn-partial", "fn-app"]
354
+
355
+
347
356
  @dataclass(frozen=True)
348
- class FunctionDecl:
349
- # All args are delayed except for relations converted to function decls
350
- arg_types: tuple[TypeOrVarRef, ...]
351
- arg_names: tuple[str, ...]
357
+ class FunctionSignature:
358
+ arg_types: tuple[TypeOrVarRef, ...] = ()
359
+ arg_names: tuple[str, ...] = ()
352
360
  # List of defaults. None for any arg which doesn't have one.
353
- arg_defaults: tuple[ExprDecl | None, ...]
361
+ arg_defaults: tuple[ExprDecl | None, ...] = ()
354
362
  # If None, then the first arg is mutated and returned
355
- return_type: TypeOrVarRef | None
363
+ return_type: TypeOrVarRef | None = None
356
364
  var_arg_type: TypeOrVarRef | None = None
357
365
 
366
+ @property
367
+ def semantic_return_type(self) -> TypeOrVarRef:
368
+ """
369
+ The type that is returned by the function, which wil be in the first arg if it mutates it.
370
+ """
371
+ return self.return_type or self.arg_types[0]
372
+
373
+ @property
374
+ def mutates(self) -> bool:
375
+ return self.return_type is None
376
+
377
+
378
+ @dataclass(frozen=True)
379
+ class FunctionDecl:
380
+ signature: FunctionSignature | SpecialFunctions = field(default_factory=FunctionSignature)
381
+
358
382
  # Egg params
359
383
  builtin: bool = False
360
384
  egg_name: str | None = None
@@ -367,17 +391,6 @@ class FunctionDecl:
367
391
  def to_function_decl(self) -> FunctionDecl:
368
392
  return self
369
393
 
370
- @property
371
- def semantic_return_type(self) -> TypeOrVarRef:
372
- """
373
- The type that is returned by the function, which wil be in the first arg if it mutates it.
374
- """
375
- return self.return_type or self.arg_types[0]
376
-
377
- @property
378
- def mutates(self) -> bool:
379
- return self.return_type is None
380
-
381
394
 
382
395
  CallableDecl: TypeAlias = RelationDecl | ConstantDecl | FunctionDecl
383
396
 
@@ -463,7 +476,20 @@ class CallDecl:
463
476
  return hash(self) == hash(other)
464
477
 
465
478
 
466
- ExprDecl: TypeAlias = VarDecl | LitDecl | CallDecl | PyObjectDecl
479
+ @dataclass(frozen=True)
480
+ class PartialCallDecl:
481
+ """
482
+ A partially applied function aka a function sort.
483
+
484
+ Note it does not need to have any args, in which case it's just a function pointer.
485
+
486
+ Seperated from the call decl so it's clear it is translated to a `unstable-fn` call.
487
+ """
488
+
489
+ call: CallDecl
490
+
491
+
492
+ ExprDecl: TypeAlias = VarDecl | LitDecl | CallDecl | PyObjectDecl | PartialCallDecl
467
493
 
468
494
 
469
495
  @dataclass(frozen=True)
egglog/egraph.py CHANGED
@@ -75,6 +75,7 @@ __all__ = [
75
75
  "seq",
76
76
  "Command",
77
77
  "simplify",
78
+ "unstable_combine_rulesets",
78
79
  "check",
79
80
  "GraphvizKwargs",
80
81
  "Ruleset",
@@ -88,6 +89,7 @@ __all__ = [
88
89
  "Fact",
89
90
  "Action",
90
91
  "Command",
92
+ "check_eq",
91
93
  ]
92
94
 
93
95
  T = TypeVar("T")
@@ -145,6 +147,23 @@ def simplify(x: EXPR, schedule: Schedule | None = None) -> EXPR:
145
147
  return EGraph().extract(x)
146
148
 
147
149
 
150
+ def check_eq(x: EXPR, y: EXPR, schedule: Schedule | None = None) -> EGraph:
151
+ """
152
+ Verifies that two expressions are equal after running the schedule.
153
+ """
154
+ egraph = EGraph()
155
+ x_var = egraph.let("__check_eq_x", x)
156
+ y_var = egraph.let("__check_eq_y", y)
157
+ if schedule:
158
+ egraph.run(schedule)
159
+ fact = eq(x_var).to(y_var)
160
+ try:
161
+ egraph.check(fact)
162
+ except bindings.EggSmolError as err:
163
+ raise AssertionError(f"Failed {eq(x).to(y)}\n -> {ne(egraph.extract(x)).to(egraph.extract(y))})") from err
164
+ return egraph
165
+
166
+
148
167
  def check(x: FactLike, schedule: Schedule | None = None, *given: ActionLike) -> None:
149
168
  """
150
169
  Verifies that the fact is true given some assumptions and after running the schedule.
@@ -456,7 +475,7 @@ class _ExprMetaclass(type):
456
475
  return isinstance(instance, RuntimeExpr)
457
476
 
458
477
 
459
- def _generate_class_decls(
478
+ def _generate_class_decls( # noqa: C901
460
479
  namespace: dict[str, Any], frame: FrameType, builtin: bool, egg_sort: str | None, cls_name: str
461
480
  ) -> Declarations:
462
481
  """
@@ -518,6 +537,16 @@ def _generate_class_decls(
518
537
  locals = frame.f_locals
519
538
 
520
539
  def create_decl(fn: object, first: Literal["cls"] | TypeRefWithVars) -> FunctionDecl:
540
+ special_function_name: SpecialFunctions | None = (
541
+ "fn-partial" if egg_fn == "unstable-fn" else "fn-app" if egg_fn == "unstable-app" else None # noqa: B023
542
+ )
543
+ if special_function_name:
544
+ return FunctionDecl(
545
+ special_function_name,
546
+ builtin=True,
547
+ egg_name=egg_fn, # noqa: B023
548
+ )
549
+
521
550
  return _fn_decl(
522
551
  decls,
523
552
  egg_fn, # noqa: B023
@@ -649,6 +678,10 @@ def _fn_decl(
649
678
  raise NotImplementedError(f"Can only generate function decls for functions not {fn} {type(fn)}")
650
679
 
651
680
  hint_globals = fn.__globals__.copy()
681
+ # Copy Callable into global if not present bc sometimes it gets automatically removed by ruff to type only block
682
+ # https://docs.astral.sh/ruff/rules/typing-only-standard-library-import/
683
+ if "Callable" not in hint_globals:
684
+ hint_globals["Callable"] = Callable
652
685
 
653
686
  hints = get_type_hints(fn, hint_globals, hint_locals)
654
687
 
@@ -715,11 +748,13 @@ def _fn_decl(
715
748
  )
716
749
  decls.update(*merge_action)
717
750
  return FunctionDecl(
718
- return_type=None if mutates_first_arg else return_type,
719
- var_arg_type=var_arg_type,
720
- arg_types=arg_types,
721
- arg_names=tuple(t.name for t in params),
722
- arg_defaults=tuple(a.__egg_typed_expr__.expr if a is not None else None for a in arg_defaults),
751
+ FunctionSignature(
752
+ return_type=None if mutates_first_arg else return_type,
753
+ var_arg_type=var_arg_type,
754
+ arg_types=arg_types,
755
+ arg_names=tuple(t.name for t in params),
756
+ arg_defaults=tuple(a.__egg_typed_expr__.expr if a is not None else None for a in arg_defaults),
757
+ ),
723
758
  cost=cost,
724
759
  egg_name=egg_name,
725
760
  merge=merged.__egg_typed_expr__.expr if merged is not None else None,
@@ -933,13 +968,12 @@ class EGraph(_BaseModule):
933
968
  """
934
969
  Displays the e-graph in the notebook.
935
970
  """
936
- graphviz = self.graphviz(**kwargs)
937
971
  if IN_IPYTHON:
938
972
  from IPython.display import SVG, display
939
973
 
940
974
  display(SVG(self.graphviz_svg(**kwargs)))
941
975
  else:
942
- graphviz.render(view=True, format="svg", quiet=True)
976
+ self.graphviz(**kwargs).render(view=True, format="svg", quiet=True)
943
977
 
944
978
  def input(self, fn: Callable[..., String], path: str) -> None:
945
979
  """
@@ -1059,7 +1093,7 @@ class EGraph(_BaseModule):
1059
1093
  runtime_expr = to_runtime_expr(expr)
1060
1094
  self._add_decls(runtime_expr)
1061
1095
  typed_expr = runtime_expr.__egg_typed_expr__
1062
- extract_report = self._run_extract(typed_expr.expr, 0)
1096
+ extract_report = self._run_extract(typed_expr, 0)
1063
1097
 
1064
1098
  if not isinstance(extract_report, bindings.Best):
1065
1099
  msg = "No extract report saved"
@@ -1079,15 +1113,16 @@ class EGraph(_BaseModule):
1079
1113
  self._add_decls(runtime_expr)
1080
1114
  typed_expr = runtime_expr.__egg_typed_expr__
1081
1115
 
1082
- extract_report = self._run_extract(typed_expr.expr, n)
1116
+ extract_report = self._run_extract(typed_expr, n)
1083
1117
  if not isinstance(extract_report, bindings.Variants):
1084
1118
  msg = "Wrong extract report type"
1085
1119
  raise ValueError(msg) # noqa: TRY004
1086
1120
  new_exprs = self._state.exprs_from_egg(extract_report.termdag, extract_report.terms, typed_expr.tp)
1087
1121
  return [cast(EXPR, RuntimeExpr.__from_value__(self.__egg_decls__, expr)) for expr in new_exprs]
1088
1122
 
1089
- def _run_extract(self, expr: ExprDecl, n: int) -> bindings._ExtractReport:
1090
- expr = self._state.expr_to_egg(expr)
1123
+ def _run_extract(self, typed_expr: TypedExprDecl, n: int) -> bindings._ExtractReport:
1124
+ self._state.type_ref_to_egg(typed_expr.tp)
1125
+ expr = self._state.expr_to_egg(typed_expr.expr)
1091
1126
  self._egraph.run_program(bindings.ActionCommand(bindings.Extract(expr, bindings.Lit(bindings.Int(n)))))
1092
1127
  extract_report = self._egraph.extract_report()
1093
1128
  if not extract_report:
@@ -1276,8 +1311,10 @@ def ruleset(
1276
1311
  """
1277
1312
  Creates a ruleset with the following rules.
1278
1313
 
1279
- If no name is provided, one is generated based on the current module
1314
+ If no name is provided, try using the name of the funciton.
1280
1315
  """
1316
+ if isinstance(rule_or_generator, FunctionType):
1317
+ name = name or rule_or_generator.__name__
1281
1318
  r = Ruleset(name)
1282
1319
  if rule_or_generator is not None:
1283
1320
  r.register(rule_or_generator, *rules, _increase_frame=True)
@@ -1388,12 +1425,48 @@ class Ruleset(Schedule):
1388
1425
  def __repr__(self) -> str:
1389
1426
  return str(self)
1390
1427
 
1428
+ def __or__(self, other: Ruleset | UnstableCombinedRuleset) -> UnstableCombinedRuleset:
1429
+ return unstable_combine_rulesets(self, other)
1430
+
1391
1431
  # Create a unique name if we didn't pass one from the user
1392
1432
  @property
1393
1433
  def __egg_name__(self) -> str:
1394
1434
  return self.name or f"ruleset_{id(self)}"
1395
1435
 
1396
1436
 
1437
+ @dataclass
1438
+ class UnstableCombinedRuleset(Schedule):
1439
+ __egg_decls_thunk__: Callable[[], Declarations] = field(init=False)
1440
+ schedule: RunDecl = field(init=False)
1441
+ name: str | None
1442
+ rulesets: InitVar[list[Ruleset | UnstableCombinedRuleset]]
1443
+
1444
+ def __post_init__(self, rulesets: list[Ruleset | UnstableCombinedRuleset]) -> None:
1445
+ self.schedule = RunDecl(self.__egg_name__, ())
1446
+ self.__egg_decls_thunk__ = Thunk.fn(self._create_egg_decls, *rulesets)
1447
+
1448
+ @property
1449
+ def __egg_name__(self) -> str:
1450
+ return self.name or f"combined_ruleset_{id(self)}"
1451
+
1452
+ def _create_egg_decls(self, *rulesets: Ruleset | UnstableCombinedRuleset) -> Declarations:
1453
+ decls = Declarations.create(*rulesets)
1454
+ decls._rulesets[self.__egg_name__] = CombinedRulesetDecl(tuple(r.__egg_name__ for r in rulesets))
1455
+ return decls
1456
+
1457
+ def __or__(self, other: Ruleset | UnstableCombinedRuleset) -> UnstableCombinedRuleset:
1458
+ return unstable_combine_rulesets(self, other)
1459
+
1460
+
1461
+ def unstable_combine_rulesets(
1462
+ *rulesets: Ruleset | UnstableCombinedRuleset, name: str | None = None
1463
+ ) -> UnstableCombinedRuleset:
1464
+ """
1465
+ Combine multiple rulesets into a single ruleset.
1466
+ """
1467
+ return UnstableCombinedRuleset(name, list(rulesets))
1468
+
1469
+
1397
1470
  @dataclass
1398
1471
  class RewriteOrRule:
1399
1472
  __egg_decls__: Declarations
@@ -1556,9 +1629,9 @@ def var(name: str, bound: type[EXPR]) -> EXPR:
1556
1629
 
1557
1630
  def _var(name: str, bound: object) -> RuntimeExpr:
1558
1631
  """Create a new variable with the given name and type."""
1559
- if not isinstance(bound, RuntimeClass):
1560
- raise TypeError(f"Unexpected type {type(bound)}")
1561
- return RuntimeExpr.__from_value__(bound.__egg_decls__, TypedExprDecl(bound.__egg_tp__.to_just(), VarDecl(name)))
1632
+ decls = Declarations()
1633
+ type_ref = resolve_type_annotation(decls, bound)
1634
+ return RuntimeExpr.__from_value__(decls, TypedExprDecl(type_ref.to_just(), VarDecl(name)))
1562
1635
 
1563
1636
 
1564
1637
  def vars_(names: str, bound: type[EXPR]) -> Iterable[EXPR]:
@@ -1801,8 +1874,10 @@ def _rewrite_or_rule_generator(gen: RewriteOrRuleGenerator, frame: FrameType) ->
1801
1874
  """
1802
1875
  # Get the local scope from where the function is defined, so that we can get any type hints that are in the scope
1803
1876
  # but not in the globals
1804
-
1805
- hints = get_type_hints(gen, gen.__globals__, frame.f_locals)
1877
+ globals = gen.__globals__.copy()
1878
+ if "Callable" not in globals:
1879
+ globals["Callable"] = Callable
1880
+ hints = get_type_hints(gen, globals, frame.f_locals)
1806
1881
  args = [_var(p.name, hints[p.name]) for p in signature(gen).parameters.values()]
1807
1882
  return list(gen(*args)) # type: ignore[misc]
1808
1883