egglog 6.0.1__cp312-none-win_amd64.whl → 7.0.0__cp312-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of egglog might be problematic; see the package's registry page for more details.

egglog/egraph.py CHANGED
@@ -3,12 +3,10 @@ from __future__ import annotations
3
3
  import inspect
4
4
  import pathlib
5
5
  import tempfile
6
- from abc import ABC, abstractmethod
6
+ from abc import abstractmethod
7
7
  from collections.abc import Callable, Iterable
8
8
  from contextvars import ContextVar, Token
9
- from copy import deepcopy
10
9
  from dataclasses import InitVar, dataclass, field
11
- from functools import cached_property
12
10
  from inspect import Parameter, currentframe, signature
13
11
  from types import FrameType, FunctionType
14
12
  from typing import (
@@ -18,6 +16,7 @@ from typing import (
18
16
  Generic,
19
17
  Literal,
20
18
  NoReturn,
19
+ TypeAlias,
21
20
  TypedDict,
22
21
  TypeVar,
23
22
  cast,
@@ -26,14 +25,16 @@ from typing import (
26
25
  )
27
26
 
28
27
  import graphviz
29
- from typing_extensions import ParamSpec, Self, Unpack, deprecated
30
-
31
- from egglog.declarations import REFLECTED_BINARY_METHODS, Declarations
28
+ from typing_extensions import ParamSpec, Self, Unpack, assert_never, deprecated
32
29
 
33
30
  from . import bindings
31
+ from .conversion import *
34
32
  from .declarations import *
33
+ from .egraph_state import *
35
34
  from .ipython_magic import IN_IPYTHON
35
+ from .pretty import pretty_decl
36
36
  from .runtime import *
37
+ from .thunk import *
37
38
 
38
39
  if TYPE_CHECKING:
39
40
  import ipywidgets
@@ -58,6 +59,7 @@ __all__ = [
58
59
  "let",
59
60
  "constant",
60
61
  "delete",
62
+ "subsume",
61
63
  "union",
62
64
  "set_",
63
65
  "rule",
@@ -65,6 +67,9 @@ __all__ = [
65
67
  "vars_",
66
68
  "Fact",
67
69
  "expr_parts",
70
+ "expr_action",
71
+ "expr_fact",
72
+ "action_command",
68
73
  "Schedule",
69
74
  "run",
70
75
  "seq",
@@ -79,11 +84,10 @@ __all__ = [
79
84
  "_NeBuilder",
80
85
  "_SetBuilder",
81
86
  "_UnionBuilder",
82
- "Rule",
83
- "Rewrite",
84
- "BiRewrite",
85
- "Union_",
87
+ "RewriteOrRule",
88
+ "Fact",
86
89
  "Action",
90
+ "Command",
87
91
  ]
88
92
 
89
93
  T = TypeVar("T")
@@ -112,7 +116,24 @@ IGNORED_ATTRIBUTES = {
112
116
  }
113
117
 
114
118
 
119
+ # special methods that return none and mutate self
115
120
  ALWAYS_MUTATES_SELF = {"__setitem__", "__delitem__"}
121
+ # special methods which must return real python values instead of lazy expressions
122
+ ALWAYS_PRESERVED = {
123
+ "__repr__",
124
+ "__str__",
125
+ "__bytes__",
126
+ "__format__",
127
+ "__hash__",
128
+ "__bool__",
129
+ "__len__",
130
+ "__length_hint__",
131
+ "__iter__",
132
+ "__reversed__",
133
+ "__contains__",
134
+ "__index__",
135
+ "__bufer__",
136
+ }
116
137
 
117
138
 
118
139
  def simplify(x: EXPR, schedule: Schedule | None = None) -> EXPR:
@@ -124,7 +145,7 @@ def simplify(x: EXPR, schedule: Schedule | None = None) -> EXPR:
124
145
  return EGraph().extract(x)
125
146
 
126
147
 
127
- def check(x: FactLike, schedule: Schedule | None = None, *given: Union_ | Expr | Set) -> None:
148
+ def check(x: FactLike, schedule: Schedule | None = None, *given: ActionLike) -> None:
128
149
  """
129
150
  Verifies that the fact is true given some assumptions and after running the schedule.
130
151
  """
@@ -136,9 +157,6 @@ def check(x: FactLike, schedule: Schedule | None = None, *given: Union_ | Expr |
136
157
  egraph.check(x)
137
158
 
138
159
 
139
- # def extract(res: )
140
-
141
-
142
160
  @dataclass
143
161
  class _BaseModule:
144
162
  """
@@ -181,15 +199,8 @@ class _BaseModule:
181
199
  Registers a class.
182
200
  """
183
201
  if kwargs:
184
- assert set(kwargs.keys()) == {"egg_sort"}
185
-
186
- def _inner(cls: object, egg_sort: str = kwargs["egg_sort"]):
187
- assert isinstance(cls, RuntimeClass)
188
- assert isinstance(cls.lazy_decls, _ClassDeclerationsConstructor)
189
- cls.lazy_decls.egg_sort = egg_sort
190
- return cls
191
-
192
- return _inner
202
+ msg = "Switch to subclassing from Expr and passing egg_sort as a keyword arg to the class constructor"
203
+ raise NotImplementedError(msg)
193
204
 
194
205
  assert len(args) == 1
195
206
  return args[0]
@@ -280,9 +291,9 @@ class _BaseModule:
280
291
  # If we have any positional args, then we are calling it directly on a function
281
292
  if args:
282
293
  assert len(args) == 1
283
- return _function(args[0], fn_locals, False)
294
+ return _FunctionConstructor(fn_locals)(args[0])
284
295
  # otherwise, we are passing some keyword args, so save those, and then return a partial
285
- return lambda fn: _function(fn, fn_locals, False, **kwargs)
296
+ return _FunctionConstructor(fn_locals, **kwargs)
286
297
 
287
298
  @deprecated("Use top level `ruleset` function instead")
288
299
  def ruleset(self, name: str) -> Ruleset:
@@ -324,17 +335,26 @@ class _BaseModule:
324
335
  """
325
336
  return constant(name, tp, egg_name)
326
337
 
327
- def register(self, /, command_or_generator: CommandLike | CommandGenerator, *command_likes: CommandLike) -> None:
338
+ def register(
339
+ self,
340
+ /,
341
+ command_or_generator: ActionLike | RewriteOrRule | RewriteOrRuleGenerator,
342
+ *command_likes: ActionLike | RewriteOrRule,
343
+ ) -> None:
328
344
  """
329
345
  Registers any number of rewrites or rules.
330
346
  """
331
347
  if isinstance(command_or_generator, FunctionType):
332
348
  assert not command_likes
333
- command_likes = tuple(_command_generator(command_or_generator))
349
+ current_frame = inspect.currentframe()
350
+ assert current_frame
351
+ original_frame = current_frame.f_back
352
+ assert original_frame
353
+ command_likes = tuple(_rewrite_or_rule_generator(command_or_generator, original_frame))
334
354
  else:
335
355
  command_likes = (cast(CommandLike, command_or_generator), *command_likes)
336
-
337
- self._register_commands(list(map(_command_like, command_likes)))
356
+ commands = [_command_like(c) for c in command_likes]
357
+ self._register_commands(commands)
338
358
 
339
359
  @abstractmethod
340
360
  def _register_commands(self, cmds: list[Command]) -> None:
@@ -417,136 +437,116 @@ class _ExprMetaclass(type):
417
437
  # If this is the Expr subclass, just return the class
418
438
  if not bases:
419
439
  return super().__new__(cls, name, bases, namespace)
440
+ # TODO: Raise error on subclassing or multiple inheritence
420
441
 
421
442
  frame = currentframe()
422
443
  assert frame
423
444
  prev_frame = frame.f_back
424
445
  assert prev_frame
425
- return _ClassDeclerationsConstructor(
426
- namespace=namespace,
427
- # Store frame so that we can get live access to updated locals/globals
428
- # Otherwise, f_locals returns a copy
429
- # https://peps.python.org/pep-0667/
430
- frame=prev_frame,
431
- builtin=builtin,
432
- egg_sort=egg_sort,
433
- cls_name=name,
434
- ).current_cls
446
+
447
+ # Store frame so that we can get live access to updated locals/globals
448
+ # Otherwise, f_locals returns a copy
449
+ # https://peps.python.org/pep-0667/
450
+ decls_thunk = Thunk.fn(
451
+ _generate_class_decls, namespace, prev_frame, builtin, egg_sort, name, fallback=Declarations
452
+ )
453
+ return RuntimeClass(decls_thunk, TypeRefWithVars(name))
435
454
 
436
455
  def __instancecheck__(cls, instance: object) -> bool:
437
456
  return isinstance(instance, RuntimeExpr)
438
457
 
439
458
 
440
- @dataclass
441
- class _ClassDeclerationsConstructor:
459
+ def _generate_class_decls(
460
+ namespace: dict[str, Any], frame: FrameType, builtin: bool, egg_sort: str | None, cls_name: str
461
+ ) -> Declarations:
442
462
  """
443
463
  Lazy constructor for class declerations to support classes with methods whose types are not yet defined.
444
464
  """
465
+ parameters: list[TypeVar] = (
466
+ # Get the generic params from the orig bases generic class
467
+ namespace["__orig_bases__"][1].__parameters__ if "__orig_bases__" in namespace else []
468
+ )
469
+ type_vars = tuple(p.__name__ for p in parameters)
470
+ del parameters
471
+ cls_decl = ClassDecl(egg_sort, type_vars, builtin)
472
+ decls = Declarations(_classes={cls_name: cls_decl})
473
+
474
+ ##
475
+ # Register class variables
476
+ ##
477
+ # Create a dummy type to pass to get_type_hints to resolve the annotations we have
478
+ _Dummytype = type("_DummyType", (), {"__annotations__": namespace.get("__annotations__", {})})
479
+ for k, v in get_type_hints(_Dummytype, globalns=frame.f_globals, localns=frame.f_locals).items():
480
+ if getattr(v, "__origin__", None) == ClassVar:
481
+ (inner_tp,) = v.__args__
482
+ type_ref = resolve_type_annotation(decls, inner_tp).to_just()
483
+ cls_decl.class_variables[k] = ConstantDecl(type_ref)
445
484
 
446
- namespace: dict[str, Any]
447
- frame: FrameType
448
- builtin: bool
449
- egg_sort: str | None
450
- cls_name: str
451
- current_cls: RuntimeClass = field(init=False)
485
+ else:
486
+ msg = f"On class {cls_name}, for attribute '{k}', expected a ClassVar, but got {v}"
487
+ raise NotImplementedError(msg)
452
488
 
453
- def __post_init__(self) -> None:
454
- self.current_cls = RuntimeClass(self, self.cls_name)
455
-
456
- def __call__(self, decls: Declarations) -> None: # noqa: PLR0912
457
- # Get all the methods from the class
458
- cls_dict: dict[str, Any] = {
459
- k: v for k, v in self.namespace.items() if k not in IGNORED_ATTRIBUTES or isinstance(v, _WrappedMethod)
460
- }
461
- parameters: list[TypeVar] = (
462
- # Get the generic params from the orig bases generic class
463
- self.namespace["__orig_bases__"][1].__parameters__ if "__orig_bases__" in self.namespace else []
464
- )
465
- type_vars = tuple(p.__name__ for p in parameters)
466
- del parameters
467
-
468
- decls.register_class(self.cls_name, type_vars, self.builtin, self.egg_sort)
469
- # The type ref of self is paramterized by the type vars
470
- slf_type_ref = TypeRefWithVars(self.cls_name, tuple(map(ClassTypeVarRef, type_vars)))
471
-
472
- # Create a dummy type to pass to get_type_hints to resolve the annotations we have
473
- class _Dummytype:
474
- pass
475
-
476
- _Dummytype.__annotations__ = self.namespace.get("__annotations__", {})
477
- # Make lazy update to locals, so we keep a live handle on them after class creation
478
- locals = self.frame.f_locals.copy()
479
- locals[self.cls_name] = self.current_cls
480
- for k, v in get_type_hints(_Dummytype, globalns=self.frame.f_globals, localns=locals).items():
481
- if getattr(v, "__origin__", None) == ClassVar:
482
- (inner_tp,) = v.__args__
483
- _register_constant(decls, ClassVariableRef(self.cls_name, k), inner_tp, None)
484
- else:
485
- msg = f"On class {self.cls_name}, for attribute '{k}', expected a ClassVar, but got {v}"
486
- raise NotImplementedError(msg)
487
-
488
- # Then register each of its methods
489
- for method_name, method in cls_dict.items():
490
- is_init = method_name == "__init__"
491
- # Don't register the init methods for literals, since those don't use the type checking mechanisms
492
- if is_init and self.cls_name in LIT_CLASS_NAMES:
493
- continue
494
- if isinstance(method, _WrappedMethod):
495
- fn = method.fn
496
- egg_fn = method.egg_fn
497
- cost = method.cost
498
- default = method.default
499
- merge = method.merge
500
- on_merge = method.on_merge
501
- mutates_first_arg = method.mutates_self
502
- unextractable = method.unextractable
503
- if method.preserve:
504
- decls.register_preserved_method(self.cls_name, method_name, fn)
505
- continue
506
- else:
507
- fn = method
489
+ ##
490
+ # Register methods, classmethods, preserved methods, and properties
491
+ ##
492
+
493
+ # The type ref of self is paramterized by the type vars
494
+ slf_type_ref = TypeRefWithVars(cls_name, tuple(map(ClassTypeVarRef, type_vars)))
495
+
496
+ # Get all the methods from the class
497
+ filtered_namespace: list[tuple[str, Any]] = [
498
+ (k, v) for k, v in namespace.items() if k not in IGNORED_ATTRIBUTES or isinstance(v, _WrappedMethod)
499
+ ]
500
+
501
+ # Then register each of its methods
502
+ for method_name, method in filtered_namespace:
503
+ is_init = method_name == "__init__"
504
+ # Don't register the init methods for literals, since those don't use the type checking mechanisms
505
+ if is_init and cls_name in LIT_CLASS_NAMES:
506
+ continue
507
+ match method:
508
+ case _WrappedMethod(egg_fn, cost, default, merge, on_merge, fn, preserve, mutates, unextractable):
509
+ pass
510
+ case _:
508
511
  egg_fn, cost, default, merge, on_merge = None, None, None, None, None
509
- unextractable = False
510
- mutates_first_arg = False
511
- if isinstance(fn, classmethod):
512
- fn = fn.__func__
513
- is_classmethod = True
514
- else:
515
- # We count __init__ as a classmethod since it is called on the class
516
- is_classmethod = is_init
517
-
518
- if isinstance(fn, property):
519
- fn = fn.fget
520
- is_property = True
521
- if is_classmethod:
522
- msg = "Can't have a classmethod property"
523
- raise NotImplementedError(msg)
524
- else:
525
- is_property = False
526
- ref: FunctionCallableRef = (
527
- ClassMethodRef(self.cls_name, method_name)
528
- if is_classmethod
529
- else PropertyRef(self.cls_name, method_name)
530
- if is_property
531
- else MethodRef(self.cls_name, method_name)
532
- )
533
- _register_function(
512
+ fn = method
513
+ unextractable, preserve = False, False
514
+ mutates = method_name in ALWAYS_MUTATES_SELF
515
+ if preserve:
516
+ cls_decl.preserved_methods[method_name] = fn
517
+ continue
518
+ locals = frame.f_locals
519
+
520
+ def create_decl(fn: object, first: Literal["cls"] | TypeRefWithVars) -> FunctionDecl:
521
+ return _fn_decl(
534
522
  decls,
535
- ref,
536
- egg_fn,
523
+ egg_fn, # noqa: B023
537
524
  fn,
538
- locals,
539
- default,
540
- cost,
541
- merge,
542
- on_merge,
543
- mutates_first_arg or method_name in ALWAYS_MUTATES_SELF,
544
- self.builtin,
545
- "cls" if is_classmethod and not is_init else slf_type_ref,
546
- is_init,
547
- unextractable=unextractable,
525
+ locals, # noqa: B023
526
+ default, # noqa: B023
527
+ cost, # noqa: B023
528
+ merge, # noqa: B023
529
+ on_merge, # noqa: B023
530
+ mutates, # noqa: B023
531
+ builtin,
532
+ first,
533
+ is_init, # noqa: B023
534
+ unextractable, # noqa: B023
548
535
  )
549
536
 
537
+ match fn:
538
+ case classmethod():
539
+ cls_decl.class_methods[method_name] = create_decl(fn.__func__, "cls")
540
+ case property():
541
+ cls_decl.properties[method_name] = create_decl(fn.fget, slf_type_ref)
542
+ case _:
543
+ if is_init:
544
+ cls_decl.class_methods[method_name] = create_decl(fn, slf_type_ref)
545
+ else:
546
+ cls_decl.methods[method_name] = create_decl(fn, slf_type_ref)
547
+
548
+ return decls
549
+
550
550
 
551
551
  @overload
552
552
  def function(fn: CALLABLE, /) -> CALLABLE: ...
@@ -589,48 +589,46 @@ def function(*args, **kwargs) -> Any:
589
589
  # If we have any positional args, then we are calling it directly on a function
590
590
  if args:
591
591
  assert len(args) == 1
592
- return _function(args[0], fn_locals, False)
592
+ return _FunctionConstructor(fn_locals)(args[0])
593
593
  # otherwise, we are passing some keyword args, so save those, and then return a partial
594
- return lambda fn: _function(fn, fn_locals, **kwargs)
594
+ return _FunctionConstructor(fn_locals, **kwargs)
595
595
 
596
596
 
597
- def _function(
598
- fn: Callable[..., RuntimeExpr],
599
- hint_locals: dict[str, Any],
600
- builtin: bool = False,
601
- mutates_first_arg: bool = False,
602
- egg_fn: str | None = None,
603
- cost: int | None = None,
604
- default: RuntimeExpr | None = None,
605
- merge: Callable[[RuntimeExpr, RuntimeExpr], RuntimeExpr] | None = None,
606
- on_merge: Callable[[RuntimeExpr, RuntimeExpr], Iterable[ActionLike]] | None = None,
607
- unextractable: bool = False,
608
- ) -> RuntimeFunction:
609
- """
610
- Uncurried version of function decorator
611
- """
612
- name = fn.__name__
613
- decls = Declarations()
614
- _register_function(
615
- decls,
616
- FunctionRef(name),
617
- egg_fn,
618
- fn,
619
- hint_locals,
620
- default,
621
- cost,
622
- merge,
623
- on_merge,
624
- mutates_first_arg,
625
- builtin,
626
- unextractable=unextractable,
627
- )
628
- return RuntimeFunction(decls, name)
597
+ @dataclass
598
+ class _FunctionConstructor:
599
+ hint_locals: dict[str, Any]
600
+ builtin: bool = False
601
+ mutates_first_arg: bool = False
602
+ egg_fn: str | None = None
603
+ cost: int | None = None
604
+ default: RuntimeExpr | None = None
605
+ merge: Callable[[RuntimeExpr, RuntimeExpr], RuntimeExpr] | None = None
606
+ on_merge: Callable[[RuntimeExpr, RuntimeExpr], Iterable[ActionLike]] | None = None
607
+ unextractable: bool = False
608
+
609
+ def __call__(self, fn: Callable[..., RuntimeExpr]) -> RuntimeFunction:
610
+ return RuntimeFunction(Thunk.fn(self.create_decls, fn), FunctionRef(fn.__name__))
611
+
612
+ def create_decls(self, fn: Callable[..., RuntimeExpr]) -> Declarations:
613
+ decls = Declarations()
614
+ decls._functions[fn.__name__] = _fn_decl(
615
+ decls,
616
+ self.egg_fn,
617
+ fn,
618
+ self.hint_locals,
619
+ self.default,
620
+ self.cost,
621
+ self.merge,
622
+ self.on_merge,
623
+ self.mutates_first_arg,
624
+ self.builtin,
625
+ unextractable=self.unextractable,
626
+ )
627
+ return decls
629
628
 
630
629
 
631
- def _register_function(
630
+ def _fn_decl(
632
631
  decls: Declarations,
633
- ref: FunctionCallableRef,
634
632
  egg_name: str | None,
635
633
  fn: object,
636
634
  # Pass in the locals, retrieved from the frame when wrapping,
@@ -646,7 +644,7 @@ def _register_function(
646
644
  first_arg: Literal["cls"] | TypeOrVarRef | None = None,
647
645
  is_init: bool = False,
648
646
  unextractable: bool = False,
649
- ) -> None:
647
+ ) -> FunctionDecl:
650
648
  if not isinstance(fn, FunctionType):
651
649
  raise NotImplementedError(f"Can only generate function decls for functions not {fn} {type(fn)}")
652
650
 
@@ -699,8 +697,8 @@ def _register_function(
699
697
  None
700
698
  if merge is None
701
699
  else merge(
702
- RuntimeExpr(decls, TypedExprDecl(return_type.to_just(), VarDecl("old"))),
703
- RuntimeExpr(decls, TypedExprDecl(return_type.to_just(), VarDecl("new"))),
700
+ RuntimeExpr.__from_value__(decls, TypedExprDecl(return_type.to_just(), VarDecl("old"))),
701
+ RuntimeExpr.__from_value__(decls, TypedExprDecl(return_type.to_just(), VarDecl("new"))),
704
702
  )
705
703
  )
706
704
  decls |= merged
@@ -710,30 +708,25 @@ def _register_function(
710
708
  if on_merge is None
711
709
  else _action_likes(
712
710
  on_merge(
713
- RuntimeExpr(decls, TypedExprDecl(return_type.to_just(), VarDecl("old"))),
714
- RuntimeExpr(decls, TypedExprDecl(return_type.to_just(), VarDecl("new"))),
711
+ RuntimeExpr.__from_value__(decls, TypedExprDecl(return_type.to_just(), VarDecl("old"))),
712
+ RuntimeExpr.__from_value__(decls, TypedExprDecl(return_type.to_just(), VarDecl("new"))),
715
713
  )
716
714
  )
717
715
  )
718
716
  decls.update(*merge_action)
719
- fn_decl = FunctionDecl(
720
- return_type=return_type,
717
+ return FunctionDecl(
718
+ return_type=None if mutates_first_arg else return_type,
721
719
  var_arg_type=var_arg_type,
722
720
  arg_types=arg_types,
723
721
  arg_names=tuple(t.name for t in params),
724
722
  arg_defaults=tuple(a.__egg_typed_expr__.expr if a is not None else None for a in arg_defaults),
725
- mutates_first_arg=mutates_first_arg,
726
- )
727
- decls.register_function_callable(
728
- ref,
729
- fn_decl,
730
- egg_name,
731
- cost,
732
- None if default is None else default.__egg_typed_expr__.expr,
733
- merged.__egg_typed_expr__.expr if merged is not None else None,
734
- [a._to_egg_action() for a in merge_action],
735
- unextractable,
736
- is_builtin,
723
+ cost=cost,
724
+ egg_name=egg_name,
725
+ merge=merged.__egg_typed_expr__.expr if merged is not None else None,
726
+ unextractable=unextractable,
727
+ builtin=is_builtin,
728
+ default=None if default is None else default.__egg_typed_expr__.expr,
729
+ on_merge=tuple(a.action for a in merge_action),
737
730
  )
738
731
 
739
732
 
@@ -764,49 +757,31 @@ def relation(name: str, /, *tps: type, egg_fn: str | None = None) -> Callable[..
764
757
  """
765
758
  Creates a function whose return type is `Unit` and has a default value.
766
759
  """
760
+ decls_thunk = Thunk.fn(_relation_decls, name, tps, egg_fn)
761
+ return cast(Callable[..., Unit], RuntimeFunction(decls_thunk, FunctionRef(name)))
762
+
763
+
764
+ def _relation_decls(name: str, tps: tuple[type, ...], egg_fn: str | None) -> Declarations:
767
765
  decls = Declarations()
768
766
  decls |= cast(RuntimeClass, Unit)
769
- arg_types = tuple(resolve_type_annotation(decls, tp) for tp in tps)
770
- fn_decl = FunctionDecl(arg_types, None, tuple(None for _ in tps), TypeRefWithVars("Unit"), mutates_first_arg=False)
771
- decls.register_function_callable(
772
- FunctionRef(name),
773
- fn_decl,
774
- egg_fn,
775
- cost=None,
776
- default=None,
777
- merge=None,
778
- merge_action=[],
779
- unextractable=False,
780
- builtin=False,
781
- is_relation=True,
782
- )
783
- return cast(Callable[..., Unit], RuntimeFunction(decls, name))
767
+ arg_types = tuple(resolve_type_annotation(decls, tp).to_just() for tp in tps)
768
+ decls._functions[name] = RelationDecl(arg_types, tuple(None for _ in tps), egg_fn)
769
+ return decls
784
770
 
785
771
 
786
772
  def constant(name: str, tp: type[EXPR], egg_name: str | None = None) -> EXPR:
787
773
  """
788
-
789
774
  A "constant" is implemented as the instantiation of a value that takes no args.
790
775
  This creates a function with `name` and return type `tp` and returns a value of it being called.
791
776
  """
792
- ref = ConstantRef(name)
793
- decls = Declarations()
794
- type_ref = _register_constant(decls, ref, tp, egg_name)
795
- return cast(EXPR, RuntimeExpr(decls, TypedExprDecl(type_ref, CallDecl(ref))))
777
+ return cast(EXPR, RuntimeExpr(Thunk.fn(_constant_thunk, name, tp, egg_name)))
796
778
 
797
779
 
798
- def _register_constant(
799
- decls: Declarations,
800
- ref: ConstantRef | ClassVariableRef,
801
- tp: object,
802
- egg_name: str | None,
803
- ) -> JustTypeRef:
804
- """
805
- Register a constant, returning its typeref().
806
- """
780
+ def _constant_thunk(name: str, tp: type, egg_name: str | None) -> tuple[Declarations, TypedExprDecl]:
781
+ decls = Declarations()
807
782
  type_ref = resolve_type_annotation(decls, tp).to_just()
808
- decls.register_constant_callable(ref, type_ref, egg_name)
809
- return type_ref
783
+ decls._constants[name] = ConstantDecl(type_ref, egg_name)
784
+ return decls, TypedExprDecl(type_ref, CallDecl(ConstantRef(name)))
810
785
 
811
786
 
812
787
  def _last_param_variable(params: list[Parameter]) -> bool:
@@ -858,29 +833,6 @@ class GraphvizKwargs(TypedDict, total=False):
858
833
  split_primitive_outputs: bool
859
834
 
860
835
 
861
- @dataclass
862
- class _EGraphState:
863
- """
864
- State of the EGraph declerations and rulesets, so when we pop/push the stack we know whats defined.
865
- """
866
-
867
- # The decleratons we have added. The _cmds represent all the symbols we have added
868
- decls: Declarations = field(default_factory=Declarations)
869
- # List of rulesets already added, so we don't re-add them if they are passed again
870
- added_rulesets: set[str] = field(default_factory=set)
871
-
872
- def add_decls(self, new_decls: Declarations) -> Iterable[bindings._Command]:
873
- new_cmds = [v for k, v in new_decls._cmds.items() if k not in self.decls._cmds]
874
- self.decls |= new_decls
875
- return new_cmds
876
-
877
- def add_rulesets(self, rulesets: Iterable[Ruleset]) -> Iterable[bindings._Command]:
878
- for ruleset in rulesets:
879
- if ruleset.egg_name not in self.added_rulesets:
880
- self.added_rulesets.add(ruleset.egg_name)
881
- yield from ruleset._cmds
882
-
883
-
884
836
  @dataclass
885
837
  class EGraph(_BaseModule):
886
838
  """
@@ -892,56 +844,34 @@ class EGraph(_BaseModule):
892
844
  seminaive: InitVar[bool] = True
893
845
  save_egglog_string: InitVar[bool] = False
894
846
 
895
- default_ruleset: Ruleset | None = None
896
- _egraph: bindings.EGraph = field(repr=False, init=False)
897
- _egglog_string: str | None = field(default=None, repr=False, init=False)
898
- _state: _EGraphState = field(default_factory=_EGraphState, repr=False)
847
+ _state: EGraphState = field(init=False)
899
848
  # For pushing/popping with egglog
900
- _state_stack: list[_EGraphState] = field(default_factory=list, repr=False)
849
+ _state_stack: list[EGraphState] = field(default_factory=list, repr=False)
901
850
  # For storing the global "current" egraph
902
851
  _token_stack: list[Token[EGraph]] = field(default_factory=list, repr=False)
903
852
 
904
853
  def __post_init__(self, modules: list[Module], seminaive: bool, save_egglog_string: bool) -> None:
905
- self._egraph = bindings.EGraph(GLOBAL_PY_OBJECT_SORT, seminaive=seminaive)
854
+ egraph = bindings.EGraph(GLOBAL_PY_OBJECT_SORT, seminaive=seminaive, record=save_egglog_string)
855
+ self._state = EGraphState(egraph)
906
856
  super().__post_init__(modules)
907
857
 
908
858
  for m in self._flatted_deps:
909
- self._add_decls(*m.cmds)
910
859
  self._register_commands(m.cmds)
911
- if save_egglog_string:
912
- self._egglog_string = ""
913
-
914
- def _register_commands(self, commands: list[Command]) -> None:
915
- for c in commands:
916
- if c.ruleset:
917
- self._add_schedule(c.ruleset)
918
-
919
- self._add_decls(*commands)
920
- self._process_commands(command._to_egg_command(self._default_ruleset_name) for command in commands)
921
-
922
- def _process_commands(self, commands: Iterable[bindings._Command]) -> None:
923
- commands = list(commands)
924
- self._egraph.run_program(*commands)
925
- if isinstance(self._egglog_string, str):
926
- self._egglog_string += "\n".join(str(c) for c in commands) + "\n"
927
860
 
928
861
  def _add_decls(self, *decls: DeclerationsLike) -> None:
929
- for d in upcast_decleratioons(decls):
930
- self._process_commands(self._state.add_decls(d))
931
-
932
- def _add_schedule(self, schedule: Schedule) -> None:
933
- self._add_decls(schedule)
934
- self._process_commands(self._state.add_rulesets(schedule._rulesets()))
862
+ for d in decls:
863
+ self._state.__egg_decls__ |= d
935
864
 
936
865
  @property
937
866
  def as_egglog_string(self) -> str:
938
867
  """
939
868
  Returns the egglog string for this module.
940
869
  """
941
- if self._egglog_string is None:
870
+ cmds = self._egraph.commands()
871
+ if cmds is None:
942
872
  msg = "Can't get egglog string unless EGraph created with save_egglog_string=True"
943
873
  raise ValueError(msg)
944
- return self._egglog_string
874
+ return cmds
945
875
 
946
876
  def _repr_mimebundle_(self, *args, **kwargs):
947
877
  """
@@ -954,7 +884,7 @@ class EGraph(_BaseModule):
954
884
  kwargs.setdefault("split_primitive_outputs", True)
955
885
  n_inline = kwargs.pop("n_inline_leaves", 0)
956
886
  serialized = self._egraph.serialize([], **kwargs) # type: ignore[misc]
957
- serialized.map_ops(self._state.decls.op_mapping())
887
+ serialized.map_ops(self._state.op_mapping())
958
888
  for _ in range(n_inline):
959
889
  serialized.inline_leaves()
960
890
  original = serialized.to_dot()
@@ -1016,17 +946,23 @@ class EGraph(_BaseModule):
1016
946
  Loads a CSV file and sets it as *input, output of the function.
1017
947
  """
1018
948
  ref, decls = resolve_callable(fn)
1019
- fn_name = decls.get_egg_fn(ref)
1020
- self._process_commands(decls.list_cmds())
1021
- self._process_commands([bindings.Input(fn_name, path)])
949
+ self._add_decls(decls)
950
+ fn_name = self._state.callable_ref_to_egg(ref)
951
+ self._egraph.run_program(bindings.Input(fn_name, path))
1022
952
 
1023
953
  def let(self, name: str, expr: EXPR) -> EXPR:
1024
954
  """
1025
955
  Define a new expression in the egraph and return a reference to it.
1026
956
  """
1027
- self._register_commands([let(name, expr)])
1028
- expr = to_runtime_expr(expr)
1029
- return cast(EXPR, RuntimeExpr(expr.__egg_decls__, TypedExprDecl(expr.__egg_typed_expr__.tp, VarDecl(name))))
957
+ action = let(name, expr)
958
+ self.register(action)
959
+ runtime_expr = to_runtime_expr(expr)
960
+ return cast(
961
+ EXPR,
962
+ RuntimeExpr.__from_value__(
963
+ self.__egg_decls__, TypedExprDecl(runtime_expr.__egg_typed_expr__.tp, VarDecl(name))
964
+ ),
965
+ )
1030
966
 
1031
967
  @overload
1032
968
  def simplify(self, expr: EXPR, limit: int, /, *until: Fact, ruleset: Ruleset | None = None) -> EXPR: ...
@@ -1041,28 +977,19 @@ class EGraph(_BaseModule):
1041
977
  Simplifies the given expression.
1042
978
  """
1043
979
  schedule = run(ruleset, *until) * limit_or_schedule if isinstance(limit_or_schedule, int) else limit_or_schedule
1044
- del limit_or_schedule
1045
- expr = to_runtime_expr(expr)
1046
- self._add_decls(expr)
1047
- self._add_schedule(schedule)
1048
-
1049
- # decls = Declarations.create(expr, schedule)
1050
- self._process_commands([bindings.Simplify(expr.__egg__, schedule._to_egg_schedule(self._default_ruleset_name))])
980
+ del limit_or_schedule, until, ruleset
981
+ runtime_expr = to_runtime_expr(expr)
982
+ self._add_decls(runtime_expr, schedule)
983
+ egg_schedule = self._state.schedule_to_egg(schedule.schedule)
984
+ typed_expr = runtime_expr.__egg_typed_expr__
985
+ egg_expr = self._state.expr_to_egg(typed_expr.expr)
986
+ self._egraph.run_program(bindings.Simplify(egg_expr, egg_schedule))
1051
987
  extract_report = self._egraph.extract_report()
1052
988
  if not isinstance(extract_report, bindings.Best):
1053
989
  msg = "No extract report saved"
1054
990
  raise ValueError(msg) # noqa: TRY004
1055
- new_typed_expr = TypedExprDecl.from_egg(
1056
- self._egraph, self._state.decls, expr.__egg_typed_expr__.tp, extract_report.termdag, extract_report.term, {}
1057
- )
1058
- return cast(EXPR, RuntimeExpr(self._state.decls, new_typed_expr))
1059
-
1060
- @property
1061
- def _default_ruleset_name(self) -> str:
1062
- if self.default_ruleset:
1063
- self._add_schedule(self.default_ruleset)
1064
- return self.default_ruleset.egg_name
1065
- return ""
991
+ (new_typed_expr,) = self._state.exprs_from_egg(extract_report.termdag, [extract_report.term], typed_expr.tp)
992
+ return cast(EXPR, RuntimeExpr.__from_value__(self.__egg_decls__, new_typed_expr))
1066
993
 
1067
994
  def include(self, path: str) -> None:
1068
995
  """
@@ -1092,8 +1019,9 @@ class EGraph(_BaseModule):
1092
1019
  return self._run_schedule(limit_or_schedule)
1093
1020
 
1094
1021
  def _run_schedule(self, schedule: Schedule) -> bindings.RunReport:
1095
- self._add_schedule(schedule)
1096
- self._process_commands([bindings.RunSchedule(schedule._to_egg_schedule(self._default_ruleset_name))])
1022
+ self._add_decls(schedule)
1023
+ egg_schedule = self._state.schedule_to_egg(schedule.schedule)
1024
+ self._egraph.run_program(bindings.RunSchedule(egg_schedule))
1097
1025
  run_report = self._egraph.run_report()
1098
1026
  if not run_report:
1099
1027
  msg = "No run report saved"
@@ -1104,18 +1032,18 @@ class EGraph(_BaseModule):
1104
1032
  """
1105
1033
  Check if a fact is true in the egraph.
1106
1034
  """
1107
- self._process_commands([self._facts_to_check(facts)])
1035
+ self._egraph.run_program(self._facts_to_check(facts))
1108
1036
 
1109
1037
  def check_fail(self, *facts: FactLike) -> None:
1110
1038
  """
1111
1039
  Checks that one of the facts is not true
1112
1040
  """
1113
- self._process_commands([bindings.Fail(self._facts_to_check(facts))])
1041
+ self._egraph.run_program(bindings.Fail(self._facts_to_check(facts)))
1114
1042
 
1115
- def _facts_to_check(self, facts: Iterable[FactLike]) -> bindings.Check:
1116
- facts = _fact_likes(facts)
1043
+ def _facts_to_check(self, fact_likes: Iterable[FactLike]) -> bindings.Check:
1044
+ facts = _fact_likes(fact_likes)
1117
1045
  self._add_decls(*facts)
1118
- egg_facts = [f._to_egg_fact() for f in _fact_likes(facts)]
1046
+ egg_facts = [self._state.fact_to_egg(f.fact) for f in _fact_likes(facts)]
1119
1047
  return bindings.Check(egg_facts)
1120
1048
 
1121
1049
  @overload
@@ -1128,16 +1056,17 @@ class EGraph(_BaseModule):
1128
1056
  """
1129
1057
  Extract the lowest cost expression from the egraph.
1130
1058
  """
1131
- assert isinstance(expr, RuntimeExpr)
1132
- self._add_decls(expr)
1133
- extract_report = self._run_extract(expr.__egg__, 0)
1059
+ runtime_expr = to_runtime_expr(expr)
1060
+ self._add_decls(runtime_expr)
1061
+ typed_expr = runtime_expr.__egg_typed_expr__
1062
+ extract_report = self._run_extract(typed_expr.expr, 0)
1063
+
1134
1064
  if not isinstance(extract_report, bindings.Best):
1135
1065
  msg = "No extract report saved"
1136
1066
  raise ValueError(msg) # noqa: TRY004
1137
- new_typed_expr = TypedExprDecl.from_egg(
1138
- self._egraph, self._state.decls, expr.__egg_typed_expr__.tp, extract_report.termdag, extract_report.term, {}
1139
- )
1140
- res = cast(EXPR, RuntimeExpr(self._state.decls, new_typed_expr))
1067
+ (new_typed_expr,) = self._state.exprs_from_egg(extract_report.termdag, [extract_report.term], typed_expr.tp)
1068
+
1069
+ res = cast(EXPR, RuntimeExpr.__from_value__(self.__egg_decls__, new_typed_expr))
1141
1070
  if include_cost:
1142
1071
  return res, extract_report.cost
1143
1072
  return res
@@ -1146,23 +1075,20 @@ class EGraph(_BaseModule):
1146
1075
  """
1147
1076
  Extract multiple expressions from the egraph.
1148
1077
  """
1149
- assert isinstance(expr, RuntimeExpr)
1150
- self._add_decls(expr)
1078
+ runtime_expr = to_runtime_expr(expr)
1079
+ self._add_decls(runtime_expr)
1080
+ typed_expr = runtime_expr.__egg_typed_expr__
1151
1081
 
1152
- extract_report = self._run_extract(expr.__egg__, n)
1082
+ extract_report = self._run_extract(typed_expr.expr, n)
1153
1083
  if not isinstance(extract_report, bindings.Variants):
1154
1084
  msg = "Wrong extract report type"
1155
1085
  raise ValueError(msg) # noqa: TRY004
1156
- new_exprs = [
1157
- TypedExprDecl.from_egg(
1158
- self._egraph, self._state.decls, expr.__egg_typed_expr__.tp, extract_report.termdag, term, {}
1159
- )
1160
- for term in extract_report.terms
1161
- ]
1162
- return [cast(EXPR, RuntimeExpr(self._state.decls, expr)) for expr in new_exprs]
1086
+ new_exprs = self._state.exprs_from_egg(extract_report.termdag, extract_report.terms, typed_expr.tp)
1087
+ return [cast(EXPR, RuntimeExpr.__from_value__(self.__egg_decls__, expr)) for expr in new_exprs]
1163
1088
 
1164
- def _run_extract(self, expr: bindings._Expr, n: int) -> bindings._ExtractReport:
1165
- self._process_commands([bindings.ActionCommand(bindings.Extract(expr, bindings.Lit(bindings.Int(n))))])
1089
+ def _run_extract(self, expr: ExprDecl, n: int) -> bindings._ExtractReport:
1090
+ expr = self._state.expr_to_egg(expr)
1091
+ self._egraph.run_program(bindings.ActionCommand(bindings.Extract(expr, bindings.Lit(bindings.Int(n)))))
1166
1092
  extract_report = self._egraph.extract_report()
1167
1093
  if not extract_report:
1168
1094
  msg = "No extract report saved"
@@ -1173,15 +1099,15 @@ class EGraph(_BaseModule):
1173
1099
  """
1174
1100
  Push the current state of the egraph, so that it can be popped later and reverted back.
1175
1101
  """
1176
- self._process_commands([bindings.Push(1)])
1102
+ self._egraph.run_program(bindings.Push(1))
1177
1103
  self._state_stack.append(self._state)
1178
- self._state = deepcopy(self._state)
1104
+ self._state = self._state.copy()
1179
1105
 
1180
1106
  def pop(self) -> None:
1181
1107
  """
1182
1108
  Pop the current state of the egraph, reverting back to the previous state.
1183
1109
  """
1184
- self._process_commands([bindings.Pop(1)])
1110
+ self._egraph.run_program(bindings.Pop(1))
1185
1111
  self._state = self._state_stack.pop()
1186
1112
 
1187
1113
  def __enter__(self) -> Self:
@@ -1217,9 +1143,10 @@ class EGraph(_BaseModule):
1217
1143
  """
1218
1144
  Evaluates the given expression (which must be a primitive type), returning the result.
1219
1145
  """
1220
- assert isinstance(expr, RuntimeExpr)
1221
- typed_expr = expr.__egg_typed_expr__
1222
- egg_expr = expr.__egg__
1146
+ runtime_expr = to_runtime_expr(expr)
1147
+ self._add_decls(runtime_expr)
1148
+ typed_expr = runtime_expr.__egg_typed_expr__
1149
+ egg_expr = self._state.expr_to_egg(typed_expr.expr)
1223
1150
  match typed_expr.tp:
1224
1151
  case JustTypeRef("i64"):
1225
1152
  return self._egraph.eval_i64(egg_expr)
@@ -1231,7 +1158,7 @@ class EGraph(_BaseModule):
1231
1158
  return self._egraph.eval_string(egg_expr)
1232
1159
  case JustTypeRef("PyObject"):
1233
1160
  return self._egraph.eval_py_object(egg_expr)
1234
- raise NotImplementedError(f"Eval not implemented for {typed_expr.tp.name}")
1161
+ raise TypeError(f"Eval not implemented for {typed_expr.tp}")
1235
1162
 
1236
1163
  def saturate(
1237
1164
  self, *, max: int = 1000, performance: bool = False, **kwargs: Unpack[GraphvizKwargs]
@@ -1270,6 +1197,32 @@ class EGraph(_BaseModule):
1270
1197
  """
1271
1198
  return CURRENT_EGRAPH.get()
1272
1199
 
1200
+ @property
1201
+ def _egraph(self) -> bindings.EGraph:
1202
+ return self._state.egraph
1203
+
1204
+ @property
1205
+ def __egg_decls__(self) -> Declarations:
1206
+ return self._state.__egg_decls__
1207
+
1208
+ def _register_commands(self, cmds: list[Command]) -> None:
1209
+ self._add_decls(*cmds)
1210
+ egg_cmds = list(map(self._command_to_egg, cmds))
1211
+ self._egraph.run_program(*egg_cmds)
1212
+
1213
+ def _command_to_egg(self, cmd: Command) -> bindings._Command:
1214
+ ruleset_name = ""
1215
+ cmd_decl: CommandDecl
1216
+ match cmd:
1217
+ case RewriteOrRule(_, cmd_decl, ruleset):
1218
+ if ruleset:
1219
+ ruleset_name = ruleset.__egg_name__
1220
+ case Action(_, action):
1221
+ cmd_decl = ActionCommandDecl(action)
1222
+ case _:
1223
+ assert_never(cmd)
1224
+ return self._state.command_to_egg(cmd_decl, ruleset_name)
1225
+
1273
1226
 
1274
1227
  CURRENT_EGRAPH = ContextVar[EGraph]("CURRENT_EGRAPH")
1275
1228
 
@@ -1316,61 +1269,53 @@ class Unit(Expr, egg_sort="Unit", builtin=True):
1316
1269
 
1317
1270
 
1318
1271
  def ruleset(
1319
- rule_or_generator: CommandLike | CommandGenerator | None = None, *rules: Rule | Rewrite, name: None | str = None
1272
+ rule_or_generator: RewriteOrRule | RewriteOrRuleGenerator | None = None,
1273
+ *rules: RewriteOrRule,
1274
+ name: None | str = None,
1320
1275
  ) -> Ruleset:
1321
1276
  """
1322
1277
  Creates a ruleset with the following rules.
1323
1278
 
1324
1279
  If no name is provided, one is generated based on the current module
1325
1280
  """
1326
- r = Ruleset(name=name)
1281
+ r = Ruleset(name)
1327
1282
  if rule_or_generator is not None:
1328
- r.register(rule_or_generator, *rules)
1283
+ r.register(rule_or_generator, *rules, _increase_frame=True)
1329
1284
  return r
1330
1285
 
1331
1286
 
1332
- class Schedule(ABC):
1287
+ @dataclass
1288
+ class Schedule(DelayedDeclerations):
1333
1289
  """
1334
1290
  A composition of some rulesets, either composing them sequentially, running them repeatedly, running them till saturation, or running until some facts are met
1335
1291
  """
1336
1292
 
1293
+ # Defer declerations so that we can have rule generators that used not yet defined yet
1294
+ schedule: ScheduleDecl
1295
+
1296
+ def __str__(self) -> str:
1297
+ return pretty_decl(self.__egg_decls__, self.schedule)
1298
+
1299
+ def __repr__(self) -> str:
1300
+ return str(self)
1301
+
1337
1302
  def __mul__(self, length: int) -> Schedule:
1338
1303
  """
1339
1304
  Repeat the schedule a number of times.
1340
1305
  """
1341
- return Repeat(length, self)
1306
+ return Schedule(self.__egg_decls_thunk__, RepeatDecl(self.schedule, length))
1342
1307
 
1343
1308
  def saturate(self) -> Schedule:
1344
1309
  """
1345
1310
  Run the schedule until the e-graph is saturated.
1346
1311
  """
1347
- return Saturate(self)
1312
+ return Schedule(self.__egg_decls_thunk__, SaturateDecl(self.schedule))
1348
1313
 
1349
1314
  def __add__(self, other: Schedule) -> Schedule:
1350
1315
  """
1351
1316
  Run two schedules in sequence.
1352
1317
  """
1353
- return Sequence((self, other))
1354
-
1355
- @abstractmethod
1356
- def __str__(self) -> str:
1357
- raise NotImplementedError
1358
-
1359
- @abstractmethod
1360
- def _to_egg_schedule(self, default_ruleset_name: str) -> bindings._Schedule:
1361
- raise NotImplementedError
1362
-
1363
- @abstractmethod
1364
- def _rulesets(self) -> Iterable[Ruleset]:
1365
- """
1366
- Mapping of all the rulesets used to commands.
1367
- """
1368
- raise NotImplementedError
1369
-
1370
- @property
1371
- @abstractmethod
1372
- def __egg_decls__(self) -> Declarations:
1373
- raise NotImplementedError
1318
+ return Schedule(Thunk.fn(Declarations.create, self, other), SequenceDecl((self.schedule, other.schedule)))
1374
1319
 
1375
1320
 
1376
1321
  @dataclass
@@ -1379,419 +1324,119 @@ class Ruleset(Schedule):
1379
1324
  A collection of rules, which can be run as a schedule.
1380
1325
  """
1381
1326
 
1327
+ __egg_decls_thunk__: Callable[[], Declarations] = field(init=False)
1328
+ schedule: RunDecl = field(init=False)
1382
1329
  name: str | None
1383
- rules: list[Rule | Rewrite] = field(default_factory=list)
1384
1330
 
1385
- def append(self, rule: Rule | Rewrite) -> None:
1331
+ # Current declerations we have accumulated
1332
+ _current_egg_decls: Declarations = field(default_factory=Declarations)
1333
+ # Current rulesets we have accumulated
1334
+ __egg_ruleset__: RulesetDecl = field(init=False)
1335
+ # Rule generator functions that have been deferred, to allow for late type binding
1336
+ deferred_rule_gens: list[Callable[[], Iterable[RewriteOrRule]]] = field(default_factory=list)
1337
+
1338
+ def __post_init__(self) -> None:
1339
+ self.schedule = RunDecl(self.__egg_name__, ())
1340
+ self.__egg_ruleset__ = self._current_egg_decls._rulesets[self.__egg_name__] = RulesetDecl([])
1341
+ self.__egg_decls_thunk__ = self._update_egg_decls
1342
+
1343
+ def _update_egg_decls(self) -> Declarations:
1344
+ """
1345
+ To return the egg decls, we go through our deferred rules and add any we haven't yet
1346
+ """
1347
+ while self.deferred_rule_gens:
1348
+ rules = self.deferred_rule_gens.pop()()
1349
+ self._current_egg_decls.update(*rules)
1350
+ self.__egg_ruleset__.rules.extend(r.decl for r in rules)
1351
+ return self._current_egg_decls
1352
+
1353
+ def append(self, rule: RewriteOrRule) -> None:
1386
1354
  """
1387
1355
  Register a rule with the ruleset.
1388
1356
  """
1389
- self.rules.append(rule)
1357
+ self._current_egg_decls |= rule
1358
+ self.__egg_ruleset__.rules.append(rule.decl)
1390
1359
 
1391
- def register(self, /, rule_or_generator: CommandLike | CommandGenerator, *rules: Rule | Rewrite) -> None:
1360
+ def register(
1361
+ self,
1362
+ /,
1363
+ rule_or_generator: RewriteOrRule | RewriteOrRuleGenerator,
1364
+ *rules: RewriteOrRule,
1365
+ _increase_frame: bool = False,
1366
+ ) -> None:
1392
1367
  """
1393
1368
  Register rewrites or rules, either as a function or as values.
1394
1369
  """
1395
- if isinstance(rule_or_generator, FunctionType):
1396
- assert not rules
1397
- rules = tuple(_command_generator(rule_or_generator))
1370
+ if isinstance(rule_or_generator, RewriteOrRule):
1371
+ self.append(rule_or_generator)
1372
+ for r in rules:
1373
+ self.append(r)
1398
1374
  else:
1399
- rules = (cast(Rule | Rewrite, rule_or_generator), *rules)
1400
- for r in rules:
1401
- self.append(r)
1402
-
1403
- @cached_property
1404
- def __egg_decls__(self) -> Declarations:
1405
- return Declarations.create(*self.rules)
1406
-
1407
- @property
1408
- def _cmds(self) -> list[bindings._Command]:
1409
- cmds = [r._to_egg_command(self.egg_name) for r in self.rules]
1410
- if self.egg_name:
1411
- cmds.insert(0, bindings.AddRuleset(self.egg_name))
1412
- return cmds
1375
+ assert not rules
1376
+ current_frame = inspect.currentframe()
1377
+ assert current_frame
1378
+ original_frame = current_frame.f_back
1379
+ assert original_frame
1380
+ if _increase_frame:
1381
+ original_frame = original_frame.f_back
1382
+ assert original_frame
1383
+ self.deferred_rule_gens.append(Thunk.fn(_rewrite_or_rule_generator, rule_or_generator, original_frame))
1413
1384
 
1414
1385
  def __str__(self) -> str:
1415
- return f"ruleset(name={self.egg_name!r})"
1386
+ return pretty_decl(self._current_egg_decls, self.__egg_ruleset__, ruleset_name=self.name)
1416
1387
 
1417
1388
  def __repr__(self) -> str:
1418
- if not self.rules:
1419
- return str(self)
1420
- rules = ", ".join(map(repr, self.rules))
1421
- return f"ruleset({rules}, name={self.egg_name!r})"
1422
-
1423
- def _to_egg_schedule(self, default_ruleset_name: str) -> bindings._Schedule:
1424
- return bindings.Run(self._to_egg_config())
1425
-
1426
- def _to_egg_config(self) -> bindings.RunConfig:
1427
- return bindings.RunConfig(self.egg_name, None)
1428
-
1429
- def _rulesets(self) -> Iterable[Ruleset]:
1430
- yield self
1389
+ return str(self)
1431
1390
 
1391
+ # Create a unique name if we didn't pass one from the user
1432
1392
  @property
1433
- def egg_name(self) -> str:
1434
- return self.name or f"_ruleset_{id(self)}"
1435
-
1436
-
1437
- class Command(ABC):
1438
- """
1439
- A command that can be executed in the egg interpreter.
1440
-
1441
- We only use this for commands which return no result and don't create new Python objects.
1442
-
1443
- Anything that can be passed to the `register` function in a Module is a Command.
1444
- """
1445
-
1446
- ruleset: Ruleset | None
1447
-
1448
- @property
1449
- @abstractmethod
1450
- def __egg_decls__(self) -> Declarations:
1451
- raise NotImplementedError
1452
-
1453
- @abstractmethod
1454
- def _to_egg_command(self, default_ruleset_name: str) -> bindings._Command:
1455
- raise NotImplementedError
1456
-
1457
- @abstractmethod
1458
- def __str__(self) -> str:
1459
- raise NotImplementedError
1393
+ def __egg_name__(self) -> str:
1394
+ return self.name or f"ruleset_{id(self)}"
1460
1395
 
1461
1396
 
1462
1397
  @dataclass
1463
- class Rewrite(Command):
1464
- ruleset: Ruleset | None
1465
- _lhs: RuntimeExpr
1466
- _rhs: RuntimeExpr
1467
- _conditions: tuple[Fact, ...]
1468
- _fn_name: ClassVar[str] = "rewrite"
1398
+ class RewriteOrRule:
1399
+ __egg_decls__: Declarations
1400
+ decl: RewriteOrRuleDecl
1401
+ ruleset: Ruleset | None = None
1469
1402
 
1470
1403
  def __str__(self) -> str:
1471
- args_str = ", ".join(map(str, [self._rhs, *self._conditions]))
1472
- return f"{self._fn_name}({self._lhs}).to({args_str})"
1473
-
1474
- def _to_egg_command(self, default_ruleset_name: str) -> bindings._Command:
1475
- return bindings.RewriteCommand(
1476
- self.ruleset.egg_name if self.ruleset else default_ruleset_name, self._to_egg_rewrite()
1477
- )
1478
-
1479
- def _to_egg_rewrite(self) -> bindings.Rewrite:
1480
- return bindings.Rewrite(
1481
- self._lhs.__egg_typed_expr__.expr.to_egg(self._lhs.__egg_decls__),
1482
- self._rhs.__egg_typed_expr__.expr.to_egg(self._rhs.__egg_decls__),
1483
- [c._to_egg_fact() for c in self._conditions],
1484
- )
1485
-
1486
- @cached_property
1487
- def __egg_decls__(self) -> Declarations:
1488
- return Declarations.create(self._lhs, self._rhs, *self._conditions)
1489
-
1490
- def with_ruleset(self, ruleset: Ruleset) -> Rewrite:
1491
- return Rewrite(ruleset, self._lhs, self._rhs, self._conditions)
1492
-
1404
+ return pretty_decl(self.__egg_decls__, self.decl)
1493
1405
 
1494
- @dataclass
1495
- class BiRewrite(Rewrite):
1496
- _fn_name: ClassVar[str] = "birewrite"
1497
-
1498
- def _to_egg_command(self, default_ruleset_name: str) -> bindings._Command:
1499
- return bindings.BiRewriteCommand(
1500
- self.ruleset.egg_name if self.ruleset else default_ruleset_name, self._to_egg_rewrite()
1501
- )
1406
+ def __repr__(self) -> str:
1407
+ return str(self)
1502
1408
 
1503
1409
 
1504
1410
  @dataclass
1505
- class Fact(ABC):
1411
+ class Fact:
1506
1412
  """
1507
1413
  A query on an EGraph, either by an expression or an equivalence between multiple expressions.
1508
1414
  """
1509
1415
 
1510
- @abstractmethod
1511
- def _to_egg_fact(self) -> bindings._Fact:
1512
- raise NotImplementedError
1513
-
1514
- @property
1515
- @abstractmethod
1516
- def __egg_decls__(self) -> Declarations:
1517
- raise NotImplementedError
1518
-
1519
-
1520
- @dataclass
1521
- class Eq(Fact):
1522
- _exprs: list[RuntimeExpr]
1523
-
1524
- def __str__(self) -> str:
1525
- first, *rest = self._exprs
1526
- args_str = ", ".join(map(str, rest))
1527
- return f"eq({first}).to({args_str})"
1528
-
1529
- def _to_egg_fact(self) -> bindings.Eq:
1530
- return bindings.Eq([e.__egg__ for e in self._exprs])
1531
-
1532
- @cached_property
1533
- def __egg_decls__(self) -> Declarations:
1534
- return Declarations.create(*self._exprs)
1535
-
1536
-
1537
- @dataclass
1538
- class ExprFact(Fact):
1539
- _expr: RuntimeExpr
1416
+ __egg_decls__: Declarations
1417
+ fact: FactDecl
1540
1418
 
1541
1419
  def __str__(self) -> str:
1542
- return str(self._expr)
1420
+ return pretty_decl(self.__egg_decls__, self.fact)
1543
1421
 
1544
- def _to_egg_fact(self) -> bindings.Fact:
1545
- return bindings.Fact(self._expr.__egg__)
1546
-
1547
- @cached_property
1548
- def __egg_decls__(self) -> Declarations:
1549
- return self._expr.__egg_decls__
1422
+ def __repr__(self) -> str:
1423
+ return str(self)
1550
1424
 
1551
1425
 
1552
1426
  @dataclass
1553
- class Rule(Command):
1554
- head: tuple[Action, ...]
1555
- body: tuple[Fact, ...]
1556
- name: str
1557
- ruleset: Ruleset | None
1558
-
1559
- def __str__(self) -> str:
1560
- head_str = ", ".join(map(str, self.head))
1561
- body_str = ", ".join(map(str, self.body))
1562
- return f"rule({body_str}).then({head_str})"
1563
-
1564
- def _to_egg_command(self, default_ruleset_name: str) -> bindings.RuleCommand:
1565
- return bindings.RuleCommand(
1566
- self.name,
1567
- self.ruleset.egg_name if self.ruleset else default_ruleset_name,
1568
- bindings.Rule(
1569
- [a._to_egg_action() for a in self.head],
1570
- [f._to_egg_fact() for f in self.body],
1571
- ),
1572
- )
1573
-
1574
- @cached_property
1575
- def __egg_decls__(self) -> Declarations:
1576
- return Declarations.create(*self.head, *self.body)
1577
-
1578
-
1579
- class Action(Command, ABC):
1427
+ class Action:
1580
1428
  """
1581
1429
  A change to an EGraph, either unioning multiple expressing, setting the value of a function call, deleting an expression, or panicking.
1582
1430
  """
1583
1431
 
1584
- @abstractmethod
1585
- def _to_egg_action(self) -> bindings._Action:
1586
- raise NotImplementedError
1587
-
1588
- def _to_egg_command(self, default_ruleset_name: str) -> bindings._Command:
1589
- return bindings.ActionCommand(self._to_egg_action())
1590
-
1591
- @property
1592
- def ruleset(self) -> None | Ruleset: # type: ignore[override]
1593
- return None
1594
-
1595
-
1596
- @dataclass
1597
- class Let(Action):
1598
- _name: str
1599
- _value: RuntimeExpr
1432
+ __egg_decls__: Declarations
1433
+ action: ActionDecl
1600
1434
 
1601
1435
  def __str__(self) -> str:
1602
- return f"let({self._name}, {self._value})"
1603
-
1604
- def _to_egg_action(self) -> bindings.Let:
1605
- return bindings.Let(self._name, self._value.__egg__)
1606
-
1607
- @property
1608
- def __egg_decls__(self) -> Declarations:
1609
- return self._value.__egg_decls__
1610
-
1611
-
1612
- @dataclass
1613
- class Set(Action):
1614
- """
1615
- Similar to union, except can be used on primitive expressions, whereas union can only be used on user defined expressions.
1616
- """
1617
-
1618
- _call: RuntimeExpr
1619
- _rhs: RuntimeExpr
1620
-
1621
- def __str__(self) -> str:
1622
- return f"set({self._call}).to({self._rhs})"
1623
-
1624
- def _to_egg_action(self) -> bindings.Set:
1625
- egg_call = self._call.__egg__
1626
- if not isinstance(egg_call, bindings.Call):
1627
- raise ValueError(f"Can only create a set with a call for the lhs, got {self._call}") # noqa: TRY004
1628
- return bindings.Set(
1629
- egg_call.name,
1630
- egg_call.args,
1631
- self._rhs.__egg__,
1632
- )
1633
-
1634
- @cached_property
1635
- def __egg_decls__(self) -> Declarations:
1636
- return Declarations.create(self._call, self._rhs)
1637
-
1638
-
1639
- @dataclass
1640
- class ExprAction(Action):
1641
- _expr: RuntimeExpr
1642
-
1643
- def __str__(self) -> str:
1644
- return str(self._expr)
1645
-
1646
- def _to_egg_action(self) -> bindings.Expr_:
1647
- return bindings.Expr_(self._expr.__egg__)
1648
-
1649
- @property
1650
- def __egg_decls__(self) -> Declarations:
1651
- return self._expr.__egg_decls__
1652
-
1653
-
1654
- @dataclass
1655
- class Delete(Action):
1656
- """
1657
- Remove a function call from an EGraph.
1658
- """
1659
-
1660
- _call: RuntimeExpr
1661
-
1662
- def __str__(self) -> str:
1663
- return f"delete({self._call})"
1664
-
1665
- def _to_egg_action(self) -> bindings.Delete:
1666
- egg_call = self._call.__egg__
1667
- if not isinstance(egg_call, bindings.Call):
1668
- raise ValueError(f"Can only create a call with a call for the lhs, got {self._call}") # noqa: TRY004
1669
- return bindings.Delete(egg_call.name, egg_call.args)
1670
-
1671
- @property
1672
- def __egg_decls__(self) -> Declarations:
1673
- return self._call.__egg_decls__
1674
-
1675
-
1676
- @dataclass
1677
- class Union_(Action): # noqa: N801
1678
- """
1679
- Merges two equivalence classes of two expressions.
1680
- """
1681
-
1682
- _lhs: RuntimeExpr
1683
- _rhs: RuntimeExpr
1684
-
1685
- def __str__(self) -> str:
1686
- return f"union({self._lhs}).with_({self._rhs})"
1687
-
1688
- def _to_egg_action(self) -> bindings.Union:
1689
- return bindings.Union(self._lhs.__egg__, self._rhs.__egg__)
1690
-
1691
- @cached_property
1692
- def __egg_decls__(self) -> Declarations:
1693
- return Declarations.create(self._lhs, self._rhs)
1694
-
1695
-
1696
- @dataclass
1697
- class Panic(Action):
1698
- message: str
1699
-
1700
- def __str__(self) -> str:
1701
- return f"panic({self.message})"
1702
-
1703
- def _to_egg_action(self) -> bindings.Panic:
1704
- return bindings.Panic(self.message)
1705
-
1706
- @cached_property
1707
- def __egg_decls__(self) -> Declarations:
1708
- return Declarations()
1709
-
1710
-
1711
- @dataclass
1712
- class Run(Schedule):
1713
- """Configuration of a run"""
1714
-
1715
- # None if using default ruleset
1716
- ruleset: Ruleset | None
1717
- until: tuple[Fact, ...]
1718
-
1719
- def __str__(self) -> str:
1720
- args_str = ", ".join(map(str, [self.ruleset, *self.until]))
1721
- return f"run({args_str})"
1722
-
1723
- def _to_egg_schedule(self, default_ruleset_name: str) -> bindings._Schedule:
1724
- return bindings.Run(self._to_egg_config(default_ruleset_name))
1725
-
1726
- def _to_egg_config(self, default_ruleset_name: str) -> bindings.RunConfig:
1727
- return bindings.RunConfig(
1728
- self.ruleset.egg_name if self.ruleset else default_ruleset_name,
1729
- [fact._to_egg_fact() for fact in self.until] if self.until else None,
1730
- )
1731
-
1732
- def _rulesets(self) -> Iterable[Ruleset]:
1733
- if self.ruleset:
1734
- yield self.ruleset
1735
-
1736
- @cached_property
1737
- def __egg_decls__(self) -> Declarations:
1738
- return Declarations.create(self.ruleset, *self.until)
1436
+ return pretty_decl(self.__egg_decls__, self.action)
1739
1437
 
1740
-
1741
- @dataclass
1742
- class Saturate(Schedule):
1743
- schedule: Schedule
1744
-
1745
- def __str__(self) -> str:
1746
- return f"{self.schedule}.saturate()"
1747
-
1748
- def _to_egg_schedule(self, default_ruleset_name: str) -> bindings._Schedule:
1749
- return bindings.Saturate(self.schedule._to_egg_schedule(default_ruleset_name))
1750
-
1751
- def _rulesets(self) -> Iterable[Ruleset]:
1752
- return self.schedule._rulesets()
1753
-
1754
- @property
1755
- def __egg_decls__(self) -> Declarations:
1756
- return self.schedule.__egg_decls__
1757
-
1758
-
1759
- @dataclass
1760
- class Repeat(Schedule):
1761
- length: int
1762
- schedule: Schedule
1763
-
1764
- def __str__(self) -> str:
1765
- return f"{self.schedule} * {self.length}"
1766
-
1767
- def _to_egg_schedule(self, default_ruleset_name: str) -> bindings._Schedule:
1768
- return bindings.Repeat(self.length, self.schedule._to_egg_schedule(default_ruleset_name))
1769
-
1770
- def _rulesets(self) -> Iterable[Ruleset]:
1771
- return self.schedule._rulesets()
1772
-
1773
- @property
1774
- def __egg_decls__(self) -> Declarations:
1775
- return self.schedule.__egg_decls__
1776
-
1777
-
1778
- @dataclass
1779
- class Sequence(Schedule):
1780
- schedules: tuple[Schedule, ...]
1781
-
1782
- def __str__(self) -> str:
1783
- return f"sequence({', '.join(map(str, self.schedules))})"
1784
-
1785
- def _to_egg_schedule(self, default_ruleset_name: str) -> bindings._Schedule:
1786
- return bindings.Sequence([schedule._to_egg_schedule(default_ruleset_name) for schedule in self.schedules])
1787
-
1788
- def _rulesets(self) -> Iterable[Ruleset]:
1789
- for s in self.schedules:
1790
- yield from s._rulesets()
1791
-
1792
- @cached_property
1793
- def __egg_decls__(self) -> Declarations:
1794
- return Declarations.create(*self.schedules)
1438
+ def __repr__(self) -> str:
1439
+ return str(self)
1795
1440
 
1796
1441
 
1797
1442
  # We use these builders so that when creating these structures we can type check
@@ -1800,16 +1445,16 @@ class Sequence(Schedule):
1800
1445
 
1801
1446
  @deprecated("Use <ruleset>.register(<rewrite>) instead of passing rulesets as arguments to rewrites.")
1802
1447
  @overload
1803
- def rewrite(lhs: EXPR, ruleset: Ruleset) -> _RewriteBuilder[EXPR]: ...
1448
+ def rewrite(lhs: EXPR, ruleset: Ruleset, *, subsume: bool = False) -> _RewriteBuilder[EXPR]: ...
1804
1449
 
1805
1450
 
1806
1451
  @overload
1807
- def rewrite(lhs: EXPR, ruleset: None = None) -> _RewriteBuilder[EXPR]: ...
1452
+ def rewrite(lhs: EXPR, ruleset: None = None, *, subsume: bool = False) -> _RewriteBuilder[EXPR]: ...
1808
1453
 
1809
1454
 
1810
- def rewrite(lhs: EXPR, ruleset: Ruleset | None = None) -> _RewriteBuilder[EXPR]:
1455
+ def rewrite(lhs: EXPR, ruleset: Ruleset | None = None, *, subsume: bool = False) -> _RewriteBuilder[EXPR]:
1811
1456
  """Rewrite the given expression to a new expression."""
1812
- return _RewriteBuilder(lhs, ruleset)
1457
+ return _RewriteBuilder(lhs, ruleset, subsume)
1813
1458
 
1814
1459
 
1815
1460
  @deprecated("Use <ruleset>.register(<birewrite>) instead of passing rulesets as arguments to birewrites.")
@@ -1838,25 +1483,41 @@ def ne(expr: EXPR) -> _NeBuilder[EXPR]:
1838
1483
 
1839
1484
  def panic(message: str) -> Action:
1840
1485
  """Raise an error with the given message."""
1841
- return Panic(message)
1486
+ return Action(Declarations(), PanicDecl(message))
1842
1487
 
1843
1488
 
1844
1489
  def let(name: str, expr: Expr) -> Action:
1845
1490
  """Create a let binding."""
1846
- return Let(name, to_runtime_expr(expr))
1491
+ runtime_expr = to_runtime_expr(expr)
1492
+ return Action(runtime_expr.__egg_decls__, LetDecl(name, runtime_expr.__egg_typed_expr__))
1847
1493
 
1848
1494
 
1849
1495
  def expr_action(expr: Expr) -> Action:
1850
- return ExprAction(to_runtime_expr(expr))
1496
+ runtime_expr = to_runtime_expr(expr)
1497
+ return Action(runtime_expr.__egg_decls__, ExprActionDecl(runtime_expr.__egg_typed_expr__))
1851
1498
 
1852
1499
 
1853
1500
  def delete(expr: Expr) -> Action:
1854
1501
  """Create a delete expression."""
1855
- return Delete(to_runtime_expr(expr))
1502
+ runtime_expr = to_runtime_expr(expr)
1503
+ typed_expr = runtime_expr.__egg_typed_expr__
1504
+ call_decl = typed_expr.expr
1505
+ assert isinstance(call_decl, CallDecl), "Can only delete calls, not literals or vars"
1506
+ return Action(runtime_expr.__egg_decls__, ChangeDecl(typed_expr.tp, call_decl, "delete"))
1507
+
1508
+
1509
+ def subsume(expr: Expr) -> Action:
1510
+ """Subsume an expression so it cannot be matched against or extracted"""
1511
+ runtime_expr = to_runtime_expr(expr)
1512
+ typed_expr = runtime_expr.__egg_typed_expr__
1513
+ call_decl = typed_expr.expr
1514
+ assert isinstance(call_decl, CallDecl), "Can only subsume calls, not literals or vars"
1515
+ return Action(runtime_expr.__egg_decls__, ChangeDecl(typed_expr.tp, call_decl, "subsume"))
1856
1516
 
1857
1517
 
1858
1518
  def expr_fact(expr: Expr) -> Fact:
1859
- return ExprFact(to_runtime_expr(expr))
1519
+ runtime_expr = to_runtime_expr(expr)
1520
+ return Fact(runtime_expr.__egg_decls__, ExprFactDecl(runtime_expr.__egg_typed_expr__))
1860
1521
 
1861
1522
 
1862
1523
  def union(lhs: EXPR) -> _UnionBuilder[EXPR]:
@@ -1883,6 +1544,11 @@ def rule(*facts: FactLike, ruleset: Ruleset | None = None, name: str | None = No
1883
1544
  return _RuleBuilder(facts=_fact_likes(facts), name=name, ruleset=ruleset)
1884
1545
 
1885
1546
 
1547
+ @deprecated("This function is now a no-op, you can remove it and use actions as commands")
1548
+ def action_command(action: Action) -> Action:
1549
+ return action
1550
+
1551
+
1886
1552
  def var(name: str, bound: type[EXPR]) -> EXPR:
1887
1553
  """Create a new variable with the given name and type."""
1888
1554
  return cast(EXPR, _var(name, bound))
@@ -1890,9 +1556,9 @@ def var(name: str, bound: type[EXPR]) -> EXPR:
1890
1556
 
1891
1557
  def _var(name: str, bound: object) -> RuntimeExpr:
1892
1558
  """Create a new variable with the given name and type."""
1893
- if not isinstance(bound, RuntimeClass | RuntimeParamaterizedClass):
1559
+ if not isinstance(bound, RuntimeClass):
1894
1560
  raise TypeError(f"Unexpected type {type(bound)}")
1895
- return RuntimeExpr(bound.__egg_decls__, TypedExprDecl(class_to_ref(bound), VarDecl(name)))
1561
+ return RuntimeExpr.__from_value__(bound.__egg_decls__, TypedExprDecl(bound.__egg_tp__.to_just(), VarDecl(name)))
1896
1562
 
1897
1563
 
1898
1564
  def vars_(names: str, bound: type[EXPR]) -> Iterable[EXPR]:
@@ -1905,16 +1571,29 @@ def vars_(names: str, bound: type[EXPR]) -> Iterable[EXPR]:
1905
1571
  class _RewriteBuilder(Generic[EXPR]):
1906
1572
  lhs: EXPR
1907
1573
  ruleset: Ruleset | None
1574
+ subsume: bool
1908
1575
 
1909
- def to(self, rhs: EXPR, *conditions: FactLike) -> Rewrite:
1576
+ def to(self, rhs: EXPR, *conditions: FactLike) -> RewriteOrRule:
1910
1577
  lhs = to_runtime_expr(self.lhs)
1911
- rule = Rewrite(self.ruleset, lhs, convert_to_same_type(rhs, lhs), _fact_likes(conditions))
1578
+ facts = _fact_likes(conditions)
1579
+ rhs = convert_to_same_type(rhs, lhs)
1580
+ rule = RewriteOrRule(
1581
+ Declarations.create(lhs, rhs, *facts, self.ruleset),
1582
+ RewriteDecl(
1583
+ lhs.__egg_typed_expr__.tp,
1584
+ lhs.__egg_typed_expr__.expr,
1585
+ rhs.__egg_typed_expr__.expr,
1586
+ tuple(f.fact for f in facts),
1587
+ self.subsume,
1588
+ ),
1589
+ )
1912
1590
  if self.ruleset:
1913
1591
  self.ruleset.append(rule)
1914
1592
  return rule
1915
1593
 
1916
1594
  def __str__(self) -> str:
1917
- return f"rewrite({self.lhs})"
1595
+ lhs = to_runtime_expr(self.lhs)
1596
+ return lhs.__egg_pretty__("rewrite")
1918
1597
 
1919
1598
 
1920
1599
  @dataclass
@@ -1922,15 +1601,26 @@ class _BirewriteBuilder(Generic[EXPR]):
1922
1601
  lhs: EXPR
1923
1602
  ruleset: Ruleset | None
1924
1603
 
1925
- def to(self, rhs: EXPR, *conditions: FactLike) -> Command:
1604
+ def to(self, rhs: EXPR, *conditions: FactLike) -> RewriteOrRule:
1926
1605
  lhs = to_runtime_expr(self.lhs)
1927
- rule = BiRewrite(self.ruleset, lhs, convert_to_same_type(rhs, lhs), _fact_likes(conditions))
1606
+ facts = _fact_likes(conditions)
1607
+ rhs = convert_to_same_type(rhs, lhs)
1608
+ rule = RewriteOrRule(
1609
+ Declarations.create(lhs, rhs, *facts, self.ruleset),
1610
+ BiRewriteDecl(
1611
+ lhs.__egg_typed_expr__.tp,
1612
+ lhs.__egg_typed_expr__.expr,
1613
+ rhs.__egg_typed_expr__.expr,
1614
+ tuple(f.fact for f in facts),
1615
+ ),
1616
+ )
1928
1617
  if self.ruleset:
1929
1618
  self.ruleset.append(rule)
1930
1619
  return rule
1931
1620
 
1932
1621
  def __str__(self) -> str:
1933
- return f"birewrite({self.lhs})"
1622
+ lhs = to_runtime_expr(self.lhs)
1623
+ return lhs.__egg_pretty__("birewrite")
1934
1624
 
1935
1625
 
1936
1626
  @dataclass
@@ -1939,52 +1629,84 @@ class _EqBuilder(Generic[EXPR]):
1939
1629
 
1940
1630
  def to(self, *exprs: EXPR) -> Fact:
1941
1631
  expr = to_runtime_expr(self.expr)
1942
- return Eq([expr] + [convert_to_same_type(e, expr) for e in exprs])
1632
+ args = [expr, *(convert_to_same_type(e, expr) for e in exprs)]
1633
+ return Fact(
1634
+ Declarations.create(*args),
1635
+ EqDecl(expr.__egg_typed_expr__.tp, tuple(a.__egg_typed_expr__.expr for a in args)),
1636
+ )
1637
+
1638
+ def __repr__(self) -> str:
1639
+ return str(self)
1943
1640
 
1944
1641
  def __str__(self) -> str:
1945
- return f"eq({self.expr})"
1642
+ expr = to_runtime_expr(self.expr)
1643
+ return expr.__egg_pretty__("eq")
1946
1644
 
1947
1645
 
1948
1646
  @dataclass
1949
1647
  class _NeBuilder(Generic[EXPR]):
1950
- expr: EXPR
1648
+ lhs: EXPR
1951
1649
 
1952
- def to(self, expr: EXPR) -> Unit:
1953
- assert isinstance(self.expr, RuntimeExpr)
1954
- args = (self.expr, convert_to_same_type(expr, self.expr))
1955
- decls = Declarations.create(*args)
1956
- res = RuntimeExpr(
1957
- decls,
1958
- TypedExprDecl(JustTypeRef("Unit"), CallDecl(FunctionRef("!="), tuple(a.__egg_typed_expr__ for a in args))),
1650
+ def to(self, rhs: EXPR) -> Unit:
1651
+ lhs = to_runtime_expr(self.lhs)
1652
+ rhs = convert_to_same_type(rhs, lhs)
1653
+ assert isinstance(Unit, RuntimeClass)
1654
+ res = RuntimeExpr.__from_value__(
1655
+ Declarations.create(Unit, lhs, rhs),
1656
+ TypedExprDecl(
1657
+ JustTypeRef("Unit"), CallDecl(FunctionRef("!="), (lhs.__egg_typed_expr__, rhs.__egg_typed_expr__))
1658
+ ),
1959
1659
  )
1960
1660
  return cast(Unit, res)
1961
1661
 
1662
+ def __repr__(self) -> str:
1663
+ return str(self)
1664
+
1962
1665
  def __str__(self) -> str:
1963
- return f"ne({self.expr})"
1666
+ expr = to_runtime_expr(self.lhs)
1667
+ return expr.__egg_pretty__("ne")
1964
1668
 
1965
1669
 
1966
1670
  @dataclass
1967
1671
  class _SetBuilder(Generic[EXPR]):
1968
- lhs: Expr
1672
+ lhs: EXPR
1969
1673
 
1970
- def to(self, rhs: EXPR) -> Set:
1674
+ def to(self, rhs: EXPR) -> Action:
1971
1675
  lhs = to_runtime_expr(self.lhs)
1972
- return Set(lhs, convert_to_same_type(rhs, lhs))
1676
+ rhs = convert_to_same_type(rhs, lhs)
1677
+ lhs_expr = lhs.__egg_typed_expr__.expr
1678
+ assert isinstance(lhs_expr, CallDecl), "Can only set function calls"
1679
+ return Action(
1680
+ Declarations.create(lhs, rhs),
1681
+ SetDecl(lhs.__egg_typed_expr__.tp, lhs_expr, rhs.__egg_typed_expr__.expr),
1682
+ )
1683
+
1684
+ def __repr__(self) -> str:
1685
+ return str(self)
1973
1686
 
1974
1687
  def __str__(self) -> str:
1975
- return f"set_({self.lhs})"
1688
+ lhs = to_runtime_expr(self.lhs)
1689
+ return lhs.__egg_pretty__("set_")
1976
1690
 
1977
1691
 
1978
1692
  @dataclass
1979
1693
  class _UnionBuilder(Generic[EXPR]):
1980
- lhs: Expr
1694
+ lhs: EXPR
1981
1695
 
1982
1696
  def with_(self, rhs: EXPR) -> Action:
1983
1697
  lhs = to_runtime_expr(self.lhs)
1984
- return Union_(lhs, convert_to_same_type(rhs, lhs))
1698
+ rhs = convert_to_same_type(rhs, lhs)
1699
+ return Action(
1700
+ Declarations.create(lhs, rhs),
1701
+ UnionDecl(lhs.__egg_typed_expr__.tp, lhs.__egg_typed_expr__.expr, rhs.__egg_typed_expr__.expr),
1702
+ )
1703
+
1704
+ def __repr__(self) -> str:
1705
+ return str(self)
1985
1706
 
1986
1707
  def __str__(self) -> str:
1987
- return f"union({self.lhs})"
1708
+ lhs = to_runtime_expr(self.lhs)
1709
+ return lhs.__egg_pretty__("union")
1988
1710
 
1989
1711
 
1990
1712
  @dataclass
@@ -1993,12 +1715,25 @@ class _RuleBuilder:
1993
1715
  name: str | None
1994
1716
  ruleset: Ruleset | None
1995
1717
 
1996
- def then(self, *actions: ActionLike) -> Rule:
1997
- rule = Rule(_action_likes(actions), self.facts, self.name or "", self.ruleset)
1718
+ def then(self, *actions: ActionLike) -> RewriteOrRule:
1719
+ actions = _action_likes(actions)
1720
+ rule = RewriteOrRule(
1721
+ Declarations.create(self.ruleset, *actions, *self.facts),
1722
+ RuleDecl(tuple(a.action for a in actions), tuple(f.fact for f in self.facts), self.name),
1723
+ )
1998
1724
  if self.ruleset:
1999
1725
  self.ruleset.append(rule)
2000
1726
  return rule
2001
1727
 
1728
+ def __str__(self) -> str:
1729
+ # TODO: Figure out how to stringify rulebuilder that preserves statements
1730
+ args = list(map(str, self.facts))
1731
+ if self.name is not None:
1732
+ args.append(f"name={self.name}")
1733
+ if ruleset is not None:
1734
+ args.append(f"ruleset={self.ruleset}")
1735
+ return f"rule({', '.join(args)})"
1736
+
2002
1737
 
2003
1738
  def expr_parts(expr: Expr) -> TypedExprDecl:
2004
1739
  """
@@ -2015,60 +1750,61 @@ def to_runtime_expr(expr: Expr) -> RuntimeExpr:
2015
1750
  return expr
2016
1751
 
2017
1752
 
2018
- def run(ruleset: Ruleset | None = None, *until: Fact) -> Run:
1753
+ def run(ruleset: Ruleset | None = None, *until: FactLike) -> Schedule:
2019
1754
  """
2020
1755
  Create a run configuration.
2021
1756
  """
2022
- return Run(ruleset, tuple(until))
1757
+ facts = _fact_likes(until)
1758
+ return Schedule(
1759
+ Thunk.fn(Declarations.create, ruleset, *facts),
1760
+ RunDecl(ruleset.__egg_name__ if ruleset else "", tuple(f.fact for f in facts) or None),
1761
+ )
2023
1762
 
2024
1763
 
2025
1764
  def seq(*schedules: Schedule) -> Schedule:
2026
1765
  """
2027
1766
  Run a sequence of schedules.
2028
1767
  """
2029
- return Sequence(tuple(schedules))
1768
+ return Schedule(Thunk.fn(Declarations.create, *schedules), SequenceDecl(tuple(s.schedule for s in schedules)))
2030
1769
 
2031
1770
 
2032
- CommandLike = Command | Expr
1771
+ ActionLike: TypeAlias = Action | Expr
2033
1772
 
2034
1773
 
2035
- def _command_like(command_like: CommandLike) -> Command:
2036
- if isinstance(command_like, Expr):
2037
- return expr_action(command_like)
2038
- return command_like
1774
+ def _action_likes(action_likes: Iterable[ActionLike]) -> tuple[Action, ...]:
1775
+ return tuple(map(_action_like, action_likes))
2039
1776
 
2040
1777
 
2041
- CommandGenerator = Callable[..., Iterable[Rule | Rewrite]]
1778
+ def _action_like(action_like: ActionLike) -> Action:
1779
+ if isinstance(action_like, Expr):
1780
+ return expr_action(action_like)
1781
+ return action_like
2042
1782
 
2043
1783
 
2044
- def _command_generator(gen: CommandGenerator) -> Iterable[Command]:
2045
- """
2046
- Calls the function with variables of the type and name of the arguments.
2047
- """
2048
- # Get the local scope from where the function is defined, so that we can get any type hints that are in the scope
2049
- # but not in the globals
2050
- current_frame = inspect.currentframe()
2051
- assert current_frame
2052
- register_frame = current_frame.f_back
2053
- assert register_frame
2054
- original_frame = register_frame.f_back
2055
- assert original_frame
2056
- hints = get_type_hints(gen, gen.__globals__, original_frame.f_locals)
2057
- args = (_var(p.name, hints[p.name]) for p in signature(gen).parameters.values())
2058
- return gen(*args)
1784
+ Command: TypeAlias = Action | RewriteOrRule
2059
1785
 
1786
+ CommandLike: TypeAlias = ActionLike | RewriteOrRule
2060
1787
 
2061
- ActionLike = Action | Expr
2062
1788
 
1789
+ def _command_like(command_like: CommandLike) -> Command:
1790
+ if isinstance(command_like, RewriteOrRule):
1791
+ return command_like
1792
+ return _action_like(command_like)
2063
1793
 
2064
- def _action_likes(action_likes: Iterable[ActionLike]) -> tuple[Action, ...]:
2065
- return tuple(map(_action_like, action_likes))
2066
1794
 
1795
+ RewriteOrRuleGenerator = Callable[..., Iterable[RewriteOrRule]]
2067
1796
 
2068
- def _action_like(action_like: ActionLike) -> Action:
2069
- if isinstance(action_like, Expr):
2070
- return expr_action(action_like)
2071
- return action_like
1797
+
1798
+ def _rewrite_or_rule_generator(gen: RewriteOrRuleGenerator, frame: FrameType) -> Iterable[RewriteOrRule]:
1799
+ """
1800
+ Returns a thunk which will call the function with variables of the type and name of the arguments.
1801
+ """
1802
+ # Get the local scope from where the function is defined, so that we can get any type hints that are in the scope
1803
+ # but not in the globals
1804
+
1805
+ hints = get_type_hints(gen, gen.__globals__, frame.f_locals)
1806
+ args = [_var(p.name, hints[p.name]) for p in signature(gen).parameters.values()]
1807
+ return list(gen(*args)) # type: ignore[misc]
2072
1808
 
2073
1809
 
2074
1810
  FactLike = Fact | Expr