egglog 6.1.0__cp310-none-win_amd64.whl → 7.0.0__cp310-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of egglog might be problematic. Click here for more details.

egglog/egraph.py CHANGED
@@ -3,12 +3,10 @@ from __future__ import annotations
3
3
  import inspect
4
4
  import pathlib
5
5
  import tempfile
6
- from abc import ABC, abstractmethod
6
+ from abc import abstractmethod
7
7
  from collections.abc import Callable, Iterable
8
8
  from contextvars import ContextVar, Token
9
- from copy import deepcopy
10
9
  from dataclasses import InitVar, dataclass, field
11
- from functools import cached_property
12
10
  from inspect import Parameter, currentframe, signature
13
11
  from types import FrameType, FunctionType
14
12
  from typing import (
@@ -18,6 +16,7 @@ from typing import (
18
16
  Generic,
19
17
  Literal,
20
18
  NoReturn,
19
+ TypeAlias,
21
20
  TypedDict,
22
21
  TypeVar,
23
22
  cast,
@@ -26,14 +25,16 @@ from typing import (
26
25
  )
27
26
 
28
27
  import graphviz
29
- from typing_extensions import ParamSpec, Self, Unpack, deprecated
30
-
31
- from egglog.declarations import REFLECTED_BINARY_METHODS, Declarations
28
+ from typing_extensions import ParamSpec, Self, Unpack, assert_never, deprecated
32
29
 
33
30
  from . import bindings
31
+ from .conversion import *
34
32
  from .declarations import *
33
+ from .egraph_state import *
35
34
  from .ipython_magic import IN_IPYTHON
35
+ from .pretty import pretty_decl
36
36
  from .runtime import *
37
+ from .thunk import *
37
38
 
38
39
  if TYPE_CHECKING:
39
40
  import ipywidgets
@@ -58,6 +59,7 @@ __all__ = [
58
59
  "let",
59
60
  "constant",
60
61
  "delete",
62
+ "subsume",
61
63
  "union",
62
64
  "set_",
63
65
  "rule",
@@ -65,6 +67,9 @@ __all__ = [
65
67
  "vars_",
66
68
  "Fact",
67
69
  "expr_parts",
70
+ "expr_action",
71
+ "expr_fact",
72
+ "action_command",
68
73
  "Schedule",
69
74
  "run",
70
75
  "seq",
@@ -79,11 +84,10 @@ __all__ = [
79
84
  "_NeBuilder",
80
85
  "_SetBuilder",
81
86
  "_UnionBuilder",
82
- "Rule",
83
- "Rewrite",
84
- "BiRewrite",
85
- "Union_",
87
+ "RewriteOrRule",
88
+ "Fact",
86
89
  "Action",
90
+ "Command",
87
91
  ]
88
92
 
89
93
  T = TypeVar("T")
@@ -112,7 +116,24 @@ IGNORED_ATTRIBUTES = {
112
116
  }
113
117
 
114
118
 
119
+ # special methods that return none and mutate self
115
120
  ALWAYS_MUTATES_SELF = {"__setitem__", "__delitem__"}
121
+ # special methods which must return real python values instead of lazy expressions
122
+ ALWAYS_PRESERVED = {
123
+ "__repr__",
124
+ "__str__",
125
+ "__bytes__",
126
+ "__format__",
127
+ "__hash__",
128
+ "__bool__",
129
+ "__len__",
130
+ "__length_hint__",
131
+ "__iter__",
132
+ "__reversed__",
133
+ "__contains__",
134
+ "__index__",
135
+ "__bufer__",
136
+ }
116
137
 
117
138
 
118
139
  def simplify(x: EXPR, schedule: Schedule | None = None) -> EXPR:
@@ -124,7 +145,7 @@ def simplify(x: EXPR, schedule: Schedule | None = None) -> EXPR:
124
145
  return EGraph().extract(x)
125
146
 
126
147
 
127
- def check(x: FactLike, schedule: Schedule | None = None, *given: Union_ | Expr | Set) -> None:
148
+ def check(x: FactLike, schedule: Schedule | None = None, *given: ActionLike) -> None:
128
149
  """
129
150
  Verifies that the fact is true given some assumptions and after running the schedule.
130
151
  """
@@ -136,9 +157,6 @@ def check(x: FactLike, schedule: Schedule | None = None, *given: Union_ | Expr |
136
157
  egraph.check(x)
137
158
 
138
159
 
139
- # def extract(res: )
140
-
141
-
142
160
  @dataclass
143
161
  class _BaseModule:
144
162
  """
@@ -181,15 +199,8 @@ class _BaseModule:
181
199
  Registers a class.
182
200
  """
183
201
  if kwargs:
184
- assert set(kwargs.keys()) == {"egg_sort"}
185
-
186
- def _inner(cls: object, egg_sort: str = kwargs["egg_sort"]):
187
- assert isinstance(cls, RuntimeClass)
188
- assert isinstance(cls.lazy_decls, _ClassDeclerationsConstructor)
189
- cls.lazy_decls.egg_sort = egg_sort
190
- return cls
191
-
192
- return _inner
202
+ msg = "Switch to subclassing from Expr and passing egg_sort as a keyword arg to the class constructor"
203
+ raise NotImplementedError(msg)
193
204
 
194
205
  assert len(args) == 1
195
206
  return args[0]
@@ -280,9 +291,9 @@ class _BaseModule:
280
291
  # If we have any positional args, then we are calling it directly on a function
281
292
  if args:
282
293
  assert len(args) == 1
283
- return _function(args[0], fn_locals, False)
294
+ return _FunctionConstructor(fn_locals)(args[0])
284
295
  # otherwise, we are passing some keyword args, so save those, and then return a partial
285
- return lambda fn: _function(fn, fn_locals, False, **kwargs)
296
+ return _FunctionConstructor(fn_locals, **kwargs)
286
297
 
287
298
  @deprecated("Use top level `ruleset` function instead")
288
299
  def ruleset(self, name: str) -> Ruleset:
@@ -324,17 +335,26 @@ class _BaseModule:
324
335
  """
325
336
  return constant(name, tp, egg_name)
326
337
 
327
- def register(self, /, command_or_generator: CommandLike | CommandGenerator, *command_likes: CommandLike) -> None:
338
+ def register(
339
+ self,
340
+ /,
341
+ command_or_generator: ActionLike | RewriteOrRule | RewriteOrRuleGenerator,
342
+ *command_likes: ActionLike | RewriteOrRule,
343
+ ) -> None:
328
344
  """
329
345
  Registers any number of rewrites or rules.
330
346
  """
331
347
  if isinstance(command_or_generator, FunctionType):
332
348
  assert not command_likes
333
- command_likes = tuple(_command_generator(command_or_generator))
349
+ current_frame = inspect.currentframe()
350
+ assert current_frame
351
+ original_frame = current_frame.f_back
352
+ assert original_frame
353
+ command_likes = tuple(_rewrite_or_rule_generator(command_or_generator, original_frame))
334
354
  else:
335
355
  command_likes = (cast(CommandLike, command_or_generator), *command_likes)
336
-
337
- self._register_commands(list(map(_command_like, command_likes)))
356
+ commands = [_command_like(c) for c in command_likes]
357
+ self._register_commands(commands)
338
358
 
339
359
  @abstractmethod
340
360
  def _register_commands(self, cmds: list[Command]) -> None:
@@ -417,136 +437,116 @@ class _ExprMetaclass(type):
417
437
  # If this is the Expr subclass, just return the class
418
438
  if not bases:
419
439
  return super().__new__(cls, name, bases, namespace)
440
+ # TODO: Raise error on subclassing or multiple inheritence
420
441
 
421
442
  frame = currentframe()
422
443
  assert frame
423
444
  prev_frame = frame.f_back
424
445
  assert prev_frame
425
- return _ClassDeclerationsConstructor(
426
- namespace=namespace,
427
- # Store frame so that we can get live access to updated locals/globals
428
- # Otherwise, f_locals returns a copy
429
- # https://peps.python.org/pep-0667/
430
- frame=prev_frame,
431
- builtin=builtin,
432
- egg_sort=egg_sort,
433
- cls_name=name,
434
- ).current_cls
446
+
447
+ # Store frame so that we can get live access to updated locals/globals
448
+ # Otherwise, f_locals returns a copy
449
+ # https://peps.python.org/pep-0667/
450
+ decls_thunk = Thunk.fn(
451
+ _generate_class_decls, namespace, prev_frame, builtin, egg_sort, name, fallback=Declarations
452
+ )
453
+ return RuntimeClass(decls_thunk, TypeRefWithVars(name))
435
454
 
436
455
  def __instancecheck__(cls, instance: object) -> bool:
437
456
  return isinstance(instance, RuntimeExpr)
438
457
 
439
458
 
440
- @dataclass
441
- class _ClassDeclerationsConstructor:
459
+ def _generate_class_decls(
460
+ namespace: dict[str, Any], frame: FrameType, builtin: bool, egg_sort: str | None, cls_name: str
461
+ ) -> Declarations:
442
462
  """
443
463
  Lazy constructor for class declerations to support classes with methods whose types are not yet defined.
444
464
  """
465
+ parameters: list[TypeVar] = (
466
+ # Get the generic params from the orig bases generic class
467
+ namespace["__orig_bases__"][1].__parameters__ if "__orig_bases__" in namespace else []
468
+ )
469
+ type_vars = tuple(p.__name__ for p in parameters)
470
+ del parameters
471
+ cls_decl = ClassDecl(egg_sort, type_vars, builtin)
472
+ decls = Declarations(_classes={cls_name: cls_decl})
473
+
474
+ ##
475
+ # Register class variables
476
+ ##
477
+ # Create a dummy type to pass to get_type_hints to resolve the annotations we have
478
+ _Dummytype = type("_DummyType", (), {"__annotations__": namespace.get("__annotations__", {})})
479
+ for k, v in get_type_hints(_Dummytype, globalns=frame.f_globals, localns=frame.f_locals).items():
480
+ if getattr(v, "__origin__", None) == ClassVar:
481
+ (inner_tp,) = v.__args__
482
+ type_ref = resolve_type_annotation(decls, inner_tp).to_just()
483
+ cls_decl.class_variables[k] = ConstantDecl(type_ref)
445
484
 
446
- namespace: dict[str, Any]
447
- frame: FrameType
448
- builtin: bool
449
- egg_sort: str | None
450
- cls_name: str
451
- current_cls: RuntimeClass = field(init=False)
485
+ else:
486
+ msg = f"On class {cls_name}, for attribute '{k}', expected a ClassVar, but got {v}"
487
+ raise NotImplementedError(msg)
452
488
 
453
- def __post_init__(self) -> None:
454
- self.current_cls = RuntimeClass(self, self.cls_name)
455
-
456
- def __call__(self, decls: Declarations) -> None: # noqa: PLR0912
457
- # Get all the methods from the class
458
- cls_dict: dict[str, Any] = {
459
- k: v for k, v in self.namespace.items() if k not in IGNORED_ATTRIBUTES or isinstance(v, _WrappedMethod)
460
- }
461
- parameters: list[TypeVar] = (
462
- # Get the generic params from the orig bases generic class
463
- self.namespace["__orig_bases__"][1].__parameters__ if "__orig_bases__" in self.namespace else []
464
- )
465
- type_vars = tuple(p.__name__ for p in parameters)
466
- del parameters
467
-
468
- decls.register_class(self.cls_name, type_vars, self.builtin, self.egg_sort)
469
- # The type ref of self is paramterized by the type vars
470
- slf_type_ref = TypeRefWithVars(self.cls_name, tuple(map(ClassTypeVarRef, type_vars)))
471
-
472
- # Create a dummy type to pass to get_type_hints to resolve the annotations we have
473
- class _Dummytype:
474
- pass
475
-
476
- _Dummytype.__annotations__ = self.namespace.get("__annotations__", {})
477
- # Make lazy update to locals, so we keep a live handle on them after class creation
478
- locals = self.frame.f_locals.copy()
479
- locals[self.cls_name] = self.current_cls
480
- for k, v in get_type_hints(_Dummytype, globalns=self.frame.f_globals, localns=locals).items():
481
- if getattr(v, "__origin__", None) == ClassVar:
482
- (inner_tp,) = v.__args__
483
- _register_constant(decls, ClassVariableRef(self.cls_name, k), inner_tp, None)
484
- else:
485
- msg = f"On class {self.cls_name}, for attribute '{k}', expected a ClassVar, but got {v}"
486
- raise NotImplementedError(msg)
487
-
488
- # Then register each of its methods
489
- for method_name, method in cls_dict.items():
490
- is_init = method_name == "__init__"
491
- # Don't register the init methods for literals, since those don't use the type checking mechanisms
492
- if is_init and self.cls_name in LIT_CLASS_NAMES:
493
- continue
494
- if isinstance(method, _WrappedMethod):
495
- fn = method.fn
496
- egg_fn = method.egg_fn
497
- cost = method.cost
498
- default = method.default
499
- merge = method.merge
500
- on_merge = method.on_merge
501
- mutates_first_arg = method.mutates_self
502
- unextractable = method.unextractable
503
- if method.preserve:
504
- decls.register_preserved_method(self.cls_name, method_name, fn)
505
- continue
506
- else:
507
- fn = method
489
+ ##
490
+ # Register methods, classmethods, preserved methods, and properties
491
+ ##
492
+
493
+ # The type ref of self is paramterized by the type vars
494
+ slf_type_ref = TypeRefWithVars(cls_name, tuple(map(ClassTypeVarRef, type_vars)))
495
+
496
+ # Get all the methods from the class
497
+ filtered_namespace: list[tuple[str, Any]] = [
498
+ (k, v) for k, v in namespace.items() if k not in IGNORED_ATTRIBUTES or isinstance(v, _WrappedMethod)
499
+ ]
500
+
501
+ # Then register each of its methods
502
+ for method_name, method in filtered_namespace:
503
+ is_init = method_name == "__init__"
504
+ # Don't register the init methods for literals, since those don't use the type checking mechanisms
505
+ if is_init and cls_name in LIT_CLASS_NAMES:
506
+ continue
507
+ match method:
508
+ case _WrappedMethod(egg_fn, cost, default, merge, on_merge, fn, preserve, mutates, unextractable):
509
+ pass
510
+ case _:
508
511
  egg_fn, cost, default, merge, on_merge = None, None, None, None, None
509
- unextractable = False
510
- mutates_first_arg = False
511
- if isinstance(fn, classmethod):
512
- fn = fn.__func__
513
- is_classmethod = True
514
- else:
515
- # We count __init__ as a classmethod since it is called on the class
516
- is_classmethod = is_init
517
-
518
- if isinstance(fn, property):
519
- fn = fn.fget
520
- is_property = True
521
- if is_classmethod:
522
- msg = "Can't have a classmethod property"
523
- raise NotImplementedError(msg)
524
- else:
525
- is_property = False
526
- ref: FunctionCallableRef = (
527
- ClassMethodRef(self.cls_name, method_name)
528
- if is_classmethod
529
- else PropertyRef(self.cls_name, method_name)
530
- if is_property
531
- else MethodRef(self.cls_name, method_name)
532
- )
533
- _register_function(
512
+ fn = method
513
+ unextractable, preserve = False, False
514
+ mutates = method_name in ALWAYS_MUTATES_SELF
515
+ if preserve:
516
+ cls_decl.preserved_methods[method_name] = fn
517
+ continue
518
+ locals = frame.f_locals
519
+
520
+ def create_decl(fn: object, first: Literal["cls"] | TypeRefWithVars) -> FunctionDecl:
521
+ return _fn_decl(
534
522
  decls,
535
- ref,
536
- egg_fn,
523
+ egg_fn, # noqa: B023
537
524
  fn,
538
- locals,
539
- default,
540
- cost,
541
- merge,
542
- on_merge,
543
- mutates_first_arg or method_name in ALWAYS_MUTATES_SELF,
544
- self.builtin,
545
- "cls" if is_classmethod and not is_init else slf_type_ref,
546
- is_init,
547
- unextractable=unextractable,
525
+ locals, # noqa: B023
526
+ default, # noqa: B023
527
+ cost, # noqa: B023
528
+ merge, # noqa: B023
529
+ on_merge, # noqa: B023
530
+ mutates, # noqa: B023
531
+ builtin,
532
+ first,
533
+ is_init, # noqa: B023
534
+ unextractable, # noqa: B023
548
535
  )
549
536
 
537
+ match fn:
538
+ case classmethod():
539
+ cls_decl.class_methods[method_name] = create_decl(fn.__func__, "cls")
540
+ case property():
541
+ cls_decl.properties[method_name] = create_decl(fn.fget, slf_type_ref)
542
+ case _:
543
+ if is_init:
544
+ cls_decl.class_methods[method_name] = create_decl(fn, slf_type_ref)
545
+ else:
546
+ cls_decl.methods[method_name] = create_decl(fn, slf_type_ref)
547
+
548
+ return decls
549
+
550
550
 
551
551
  @overload
552
552
  def function(fn: CALLABLE, /) -> CALLABLE: ...
@@ -589,48 +589,46 @@ def function(*args, **kwargs) -> Any:
589
589
  # If we have any positional args, then we are calling it directly on a function
590
590
  if args:
591
591
  assert len(args) == 1
592
- return _function(args[0], fn_locals, False)
592
+ return _FunctionConstructor(fn_locals)(args[0])
593
593
  # otherwise, we are passing some keyword args, so save those, and then return a partial
594
- return lambda fn: _function(fn, fn_locals, **kwargs)
594
+ return _FunctionConstructor(fn_locals, **kwargs)
595
595
 
596
596
 
597
- def _function(
598
- fn: Callable[..., RuntimeExpr],
599
- hint_locals: dict[str, Any],
600
- builtin: bool = False,
601
- mutates_first_arg: bool = False,
602
- egg_fn: str | None = None,
603
- cost: int | None = None,
604
- default: RuntimeExpr | None = None,
605
- merge: Callable[[RuntimeExpr, RuntimeExpr], RuntimeExpr] | None = None,
606
- on_merge: Callable[[RuntimeExpr, RuntimeExpr], Iterable[ActionLike]] | None = None,
607
- unextractable: bool = False,
608
- ) -> RuntimeFunction:
609
- """
610
- Uncurried version of function decorator
611
- """
612
- name = fn.__name__
613
- decls = Declarations()
614
- _register_function(
615
- decls,
616
- FunctionRef(name),
617
- egg_fn,
618
- fn,
619
- hint_locals,
620
- default,
621
- cost,
622
- merge,
623
- on_merge,
624
- mutates_first_arg,
625
- builtin,
626
- unextractable=unextractable,
627
- )
628
- return RuntimeFunction(decls, name)
597
+ @dataclass
598
+ class _FunctionConstructor:
599
+ hint_locals: dict[str, Any]
600
+ builtin: bool = False
601
+ mutates_first_arg: bool = False
602
+ egg_fn: str | None = None
603
+ cost: int | None = None
604
+ default: RuntimeExpr | None = None
605
+ merge: Callable[[RuntimeExpr, RuntimeExpr], RuntimeExpr] | None = None
606
+ on_merge: Callable[[RuntimeExpr, RuntimeExpr], Iterable[ActionLike]] | None = None
607
+ unextractable: bool = False
608
+
609
+ def __call__(self, fn: Callable[..., RuntimeExpr]) -> RuntimeFunction:
610
+ return RuntimeFunction(Thunk.fn(self.create_decls, fn), FunctionRef(fn.__name__))
611
+
612
+ def create_decls(self, fn: Callable[..., RuntimeExpr]) -> Declarations:
613
+ decls = Declarations()
614
+ decls._functions[fn.__name__] = _fn_decl(
615
+ decls,
616
+ self.egg_fn,
617
+ fn,
618
+ self.hint_locals,
619
+ self.default,
620
+ self.cost,
621
+ self.merge,
622
+ self.on_merge,
623
+ self.mutates_first_arg,
624
+ self.builtin,
625
+ unextractable=self.unextractable,
626
+ )
627
+ return decls
629
628
 
630
629
 
631
- def _register_function(
630
+ def _fn_decl(
632
631
  decls: Declarations,
633
- ref: FunctionCallableRef,
634
632
  egg_name: str | None,
635
633
  fn: object,
636
634
  # Pass in the locals, retrieved from the frame when wrapping,
@@ -646,7 +644,7 @@ def _register_function(
646
644
  first_arg: Literal["cls"] | TypeOrVarRef | None = None,
647
645
  is_init: bool = False,
648
646
  unextractable: bool = False,
649
- ) -> None:
647
+ ) -> FunctionDecl:
650
648
  if not isinstance(fn, FunctionType):
651
649
  raise NotImplementedError(f"Can only generate function decls for functions not {fn} {type(fn)}")
652
650
 
@@ -699,8 +697,8 @@ def _register_function(
699
697
  None
700
698
  if merge is None
701
699
  else merge(
702
- RuntimeExpr(decls, TypedExprDecl(return_type.to_just(), VarDecl("old"))),
703
- RuntimeExpr(decls, TypedExprDecl(return_type.to_just(), VarDecl("new"))),
700
+ RuntimeExpr.__from_value__(decls, TypedExprDecl(return_type.to_just(), VarDecl("old"))),
701
+ RuntimeExpr.__from_value__(decls, TypedExprDecl(return_type.to_just(), VarDecl("new"))),
704
702
  )
705
703
  )
706
704
  decls |= merged
@@ -710,30 +708,25 @@ def _register_function(
710
708
  if on_merge is None
711
709
  else _action_likes(
712
710
  on_merge(
713
- RuntimeExpr(decls, TypedExprDecl(return_type.to_just(), VarDecl("old"))),
714
- RuntimeExpr(decls, TypedExprDecl(return_type.to_just(), VarDecl("new"))),
711
+ RuntimeExpr.__from_value__(decls, TypedExprDecl(return_type.to_just(), VarDecl("old"))),
712
+ RuntimeExpr.__from_value__(decls, TypedExprDecl(return_type.to_just(), VarDecl("new"))),
715
713
  )
716
714
  )
717
715
  )
718
716
  decls.update(*merge_action)
719
- fn_decl = FunctionDecl(
720
- return_type=return_type,
717
+ return FunctionDecl(
718
+ return_type=None if mutates_first_arg else return_type,
721
719
  var_arg_type=var_arg_type,
722
720
  arg_types=arg_types,
723
721
  arg_names=tuple(t.name for t in params),
724
722
  arg_defaults=tuple(a.__egg_typed_expr__.expr if a is not None else None for a in arg_defaults),
725
- mutates_first_arg=mutates_first_arg,
726
- )
727
- decls.register_function_callable(
728
- ref,
729
- fn_decl,
730
- egg_name,
731
- cost,
732
- None if default is None else default.__egg_typed_expr__.expr,
733
- merged.__egg_typed_expr__.expr if merged is not None else None,
734
- [a._to_egg_action() for a in merge_action],
735
- unextractable,
736
- is_builtin,
723
+ cost=cost,
724
+ egg_name=egg_name,
725
+ merge=merged.__egg_typed_expr__.expr if merged is not None else None,
726
+ unextractable=unextractable,
727
+ builtin=is_builtin,
728
+ default=None if default is None else default.__egg_typed_expr__.expr,
729
+ on_merge=tuple(a.action for a in merge_action),
737
730
  )
738
731
 
739
732
 
@@ -764,49 +757,31 @@ def relation(name: str, /, *tps: type, egg_fn: str | None = None) -> Callable[..
764
757
  """
765
758
  Creates a function whose return type is `Unit` and has a default value.
766
759
  """
760
+ decls_thunk = Thunk.fn(_relation_decls, name, tps, egg_fn)
761
+ return cast(Callable[..., Unit], RuntimeFunction(decls_thunk, FunctionRef(name)))
762
+
763
+
764
+ def _relation_decls(name: str, tps: tuple[type, ...], egg_fn: str | None) -> Declarations:
767
765
  decls = Declarations()
768
766
  decls |= cast(RuntimeClass, Unit)
769
- arg_types = tuple(resolve_type_annotation(decls, tp) for tp in tps)
770
- fn_decl = FunctionDecl(arg_types, None, tuple(None for _ in tps), TypeRefWithVars("Unit"), mutates_first_arg=False)
771
- decls.register_function_callable(
772
- FunctionRef(name),
773
- fn_decl,
774
- egg_fn,
775
- cost=None,
776
- default=None,
777
- merge=None,
778
- merge_action=[],
779
- unextractable=False,
780
- builtin=False,
781
- is_relation=True,
782
- )
783
- return cast(Callable[..., Unit], RuntimeFunction(decls, name))
767
+ arg_types = tuple(resolve_type_annotation(decls, tp).to_just() for tp in tps)
768
+ decls._functions[name] = RelationDecl(arg_types, tuple(None for _ in tps), egg_fn)
769
+ return decls
784
770
 
785
771
 
786
772
  def constant(name: str, tp: type[EXPR], egg_name: str | None = None) -> EXPR:
787
773
  """
788
-
789
774
  A "constant" is implemented as the instantiation of a value that takes no args.
790
775
  This creates a function with `name` and return type `tp` and returns a value of it being called.
791
776
  """
792
- ref = ConstantRef(name)
793
- decls = Declarations()
794
- type_ref = _register_constant(decls, ref, tp, egg_name)
795
- return cast(EXPR, RuntimeExpr(decls, TypedExprDecl(type_ref, CallDecl(ref))))
777
+ return cast(EXPR, RuntimeExpr(Thunk.fn(_constant_thunk, name, tp, egg_name)))
796
778
 
797
779
 
798
- def _register_constant(
799
- decls: Declarations,
800
- ref: ConstantRef | ClassVariableRef,
801
- tp: object,
802
- egg_name: str | None,
803
- ) -> JustTypeRef:
804
- """
805
- Register a constant, returning its typeref().
806
- """
780
+ def _constant_thunk(name: str, tp: type, egg_name: str | None) -> tuple[Declarations, TypedExprDecl]:
781
+ decls = Declarations()
807
782
  type_ref = resolve_type_annotation(decls, tp).to_just()
808
- decls.register_constant_callable(ref, type_ref, egg_name)
809
- return type_ref
783
+ decls._constants[name] = ConstantDecl(type_ref, egg_name)
784
+ return decls, TypedExprDecl(type_ref, CallDecl(ConstantRef(name)))
810
785
 
811
786
 
812
787
  def _last_param_variable(params: list[Parameter]) -> bool:
@@ -858,29 +833,6 @@ class GraphvizKwargs(TypedDict, total=False):
858
833
  split_primitive_outputs: bool
859
834
 
860
835
 
861
- @dataclass
862
- class _EGraphState:
863
- """
864
- State of the EGraph declerations and rulesets, so when we pop/push the stack we know whats defined.
865
- """
866
-
867
- # The decleratons we have added. The _cmds represent all the symbols we have added
868
- decls: Declarations = field(default_factory=Declarations)
869
- # List of rulesets already added, so we don't re-add them if they are passed again
870
- added_rulesets: set[str] = field(default_factory=set)
871
-
872
- def add_decls(self, new_decls: Declarations) -> Iterable[bindings._Command]:
873
- new_cmds = [v for k, v in new_decls._cmds.items() if k not in self.decls._cmds]
874
- self.decls |= new_decls
875
- return new_cmds
876
-
877
- def add_rulesets(self, rulesets: Iterable[Ruleset]) -> Iterable[bindings._Command]:
878
- for ruleset in rulesets:
879
- if ruleset.egg_name not in self.added_rulesets:
880
- self.added_rulesets.add(ruleset.egg_name)
881
- yield from ruleset._cmds
882
-
883
-
884
836
  @dataclass
885
837
  class EGraph(_BaseModule):
886
838
  """
@@ -892,56 +844,34 @@ class EGraph(_BaseModule):
892
844
  seminaive: InitVar[bool] = True
893
845
  save_egglog_string: InitVar[bool] = False
894
846
 
895
- default_ruleset: Ruleset | None = None
896
- _egraph: bindings.EGraph = field(repr=False, init=False)
897
- _egglog_string: str | None = field(default=None, repr=False, init=False)
898
- _state: _EGraphState = field(default_factory=_EGraphState, repr=False)
847
+ _state: EGraphState = field(init=False)
899
848
  # For pushing/popping with egglog
900
- _state_stack: list[_EGraphState] = field(default_factory=list, repr=False)
849
+ _state_stack: list[EGraphState] = field(default_factory=list, repr=False)
901
850
  # For storing the global "current" egraph
902
851
  _token_stack: list[Token[EGraph]] = field(default_factory=list, repr=False)
903
852
 
904
853
  def __post_init__(self, modules: list[Module], seminaive: bool, save_egglog_string: bool) -> None:
905
- self._egraph = bindings.EGraph(GLOBAL_PY_OBJECT_SORT, seminaive=seminaive)
854
+ egraph = bindings.EGraph(GLOBAL_PY_OBJECT_SORT, seminaive=seminaive, record=save_egglog_string)
855
+ self._state = EGraphState(egraph)
906
856
  super().__post_init__(modules)
907
857
 
908
858
  for m in self._flatted_deps:
909
- self._add_decls(*m.cmds)
910
859
  self._register_commands(m.cmds)
911
- if save_egglog_string:
912
- self._egglog_string = ""
913
-
914
- def _register_commands(self, commands: list[Command]) -> None:
915
- for c in commands:
916
- if c.ruleset:
917
- self._add_schedule(c.ruleset)
918
-
919
- self._add_decls(*commands)
920
- self._process_commands(command._to_egg_command(self._default_ruleset_name) for command in commands)
921
-
922
- def _process_commands(self, commands: Iterable[bindings._Command]) -> None:
923
- commands = list(commands)
924
- self._egraph.run_program(*commands)
925
- if isinstance(self._egglog_string, str):
926
- self._egglog_string += "\n".join(str(c) for c in commands) + "\n"
927
860
 
928
861
  def _add_decls(self, *decls: DeclerationsLike) -> None:
929
- for d in upcast_decleratioons(decls):
930
- self._process_commands(self._state.add_decls(d))
931
-
932
- def _add_schedule(self, schedule: Schedule) -> None:
933
- self._add_decls(schedule)
934
- self._process_commands(self._state.add_rulesets(schedule._rulesets()))
862
+ for d in decls:
863
+ self._state.__egg_decls__ |= d
935
864
 
936
865
  @property
937
866
  def as_egglog_string(self) -> str:
938
867
  """
939
868
  Returns the egglog string for this module.
940
869
  """
941
- if self._egglog_string is None:
870
+ cmds = self._egraph.commands()
871
+ if cmds is None:
942
872
  msg = "Can't get egglog string unless EGraph created with save_egglog_string=True"
943
873
  raise ValueError(msg)
944
- return self._egglog_string
874
+ return cmds
945
875
 
946
876
  def _repr_mimebundle_(self, *args, **kwargs):
947
877
  """
@@ -954,7 +884,7 @@ class EGraph(_BaseModule):
954
884
  kwargs.setdefault("split_primitive_outputs", True)
955
885
  n_inline = kwargs.pop("n_inline_leaves", 0)
956
886
  serialized = self._egraph.serialize([], **kwargs) # type: ignore[misc]
957
- serialized.map_ops(self._state.decls.op_mapping())
887
+ serialized.map_ops(self._state.op_mapping())
958
888
  for _ in range(n_inline):
959
889
  serialized.inline_leaves()
960
890
  original = serialized.to_dot()
@@ -1016,17 +946,23 @@ class EGraph(_BaseModule):
1016
946
  Loads a CSV file and sets it as *input, output of the function.
1017
947
  """
1018
948
  ref, decls = resolve_callable(fn)
1019
- fn_name = decls.get_egg_fn(ref)
1020
- self._process_commands(decls.list_cmds())
1021
- self._process_commands([bindings.Input(fn_name, path)])
949
+ self._add_decls(decls)
950
+ fn_name = self._state.callable_ref_to_egg(ref)
951
+ self._egraph.run_program(bindings.Input(fn_name, path))
1022
952
 
1023
953
  def let(self, name: str, expr: EXPR) -> EXPR:
1024
954
  """
1025
955
  Define a new expression in the egraph and return a reference to it.
1026
956
  """
1027
- self._register_commands([let(name, expr)])
1028
- expr = to_runtime_expr(expr)
1029
- return cast(EXPR, RuntimeExpr(expr.__egg_decls__, TypedExprDecl(expr.__egg_typed_expr__.tp, VarDecl(name))))
957
+ action = let(name, expr)
958
+ self.register(action)
959
+ runtime_expr = to_runtime_expr(expr)
960
+ return cast(
961
+ EXPR,
962
+ RuntimeExpr.__from_value__(
963
+ self.__egg_decls__, TypedExprDecl(runtime_expr.__egg_typed_expr__.tp, VarDecl(name))
964
+ ),
965
+ )
1030
966
 
1031
967
  @overload
1032
968
  def simplify(self, expr: EXPR, limit: int, /, *until: Fact, ruleset: Ruleset | None = None) -> EXPR: ...
@@ -1041,28 +977,19 @@ class EGraph(_BaseModule):
1041
977
  Simplifies the given expression.
1042
978
  """
1043
979
  schedule = run(ruleset, *until) * limit_or_schedule if isinstance(limit_or_schedule, int) else limit_or_schedule
1044
- del limit_or_schedule
1045
- expr = to_runtime_expr(expr)
1046
- self._add_decls(expr)
1047
- self._add_schedule(schedule)
1048
-
1049
- # decls = Declarations.create(expr, schedule)
1050
- self._process_commands([bindings.Simplify(expr.__egg__, schedule._to_egg_schedule(self._default_ruleset_name))])
980
+ del limit_or_schedule, until, ruleset
981
+ runtime_expr = to_runtime_expr(expr)
982
+ self._add_decls(runtime_expr, schedule)
983
+ egg_schedule = self._state.schedule_to_egg(schedule.schedule)
984
+ typed_expr = runtime_expr.__egg_typed_expr__
985
+ egg_expr = self._state.expr_to_egg(typed_expr.expr)
986
+ self._egraph.run_program(bindings.Simplify(egg_expr, egg_schedule))
1051
987
  extract_report = self._egraph.extract_report()
1052
988
  if not isinstance(extract_report, bindings.Best):
1053
989
  msg = "No extract report saved"
1054
990
  raise ValueError(msg) # noqa: TRY004
1055
- new_typed_expr = TypedExprDecl.from_egg(
1056
- self._egraph, self._state.decls, expr.__egg_typed_expr__.tp, extract_report.termdag, extract_report.term, {}
1057
- )
1058
- return cast(EXPR, RuntimeExpr(self._state.decls, new_typed_expr))
1059
-
1060
- @property
1061
- def _default_ruleset_name(self) -> str:
1062
- if self.default_ruleset:
1063
- self._add_schedule(self.default_ruleset)
1064
- return self.default_ruleset.egg_name
1065
- return ""
991
+ (new_typed_expr,) = self._state.exprs_from_egg(extract_report.termdag, [extract_report.term], typed_expr.tp)
992
+ return cast(EXPR, RuntimeExpr.__from_value__(self.__egg_decls__, new_typed_expr))
1066
993
 
1067
994
  def include(self, path: str) -> None:
1068
995
  """
@@ -1092,8 +1019,9 @@ class EGraph(_BaseModule):
1092
1019
  return self._run_schedule(limit_or_schedule)
1093
1020
 
1094
1021
  def _run_schedule(self, schedule: Schedule) -> bindings.RunReport:
1095
- self._add_schedule(schedule)
1096
- self._process_commands([bindings.RunSchedule(schedule._to_egg_schedule(self._default_ruleset_name))])
1022
+ self._add_decls(schedule)
1023
+ egg_schedule = self._state.schedule_to_egg(schedule.schedule)
1024
+ self._egraph.run_program(bindings.RunSchedule(egg_schedule))
1097
1025
  run_report = self._egraph.run_report()
1098
1026
  if not run_report:
1099
1027
  msg = "No run report saved"
@@ -1104,18 +1032,18 @@ class EGraph(_BaseModule):
1104
1032
  """
1105
1033
  Check if a fact is true in the egraph.
1106
1034
  """
1107
- self._process_commands([self._facts_to_check(facts)])
1035
+ self._egraph.run_program(self._facts_to_check(facts))
1108
1036
 
1109
1037
  def check_fail(self, *facts: FactLike) -> None:
1110
1038
  """
1111
1039
  Checks that one of the facts is not true
1112
1040
  """
1113
- self._process_commands([bindings.Fail(self._facts_to_check(facts))])
1041
+ self._egraph.run_program(bindings.Fail(self._facts_to_check(facts)))
1114
1042
 
1115
- def _facts_to_check(self, facts: Iterable[FactLike]) -> bindings.Check:
1116
- facts = _fact_likes(facts)
1043
+ def _facts_to_check(self, fact_likes: Iterable[FactLike]) -> bindings.Check:
1044
+ facts = _fact_likes(fact_likes)
1117
1045
  self._add_decls(*facts)
1118
- egg_facts = [f._to_egg_fact() for f in _fact_likes(facts)]
1046
+ egg_facts = [self._state.fact_to_egg(f.fact) for f in _fact_likes(facts)]
1119
1047
  return bindings.Check(egg_facts)
1120
1048
 
1121
1049
  @overload
@@ -1128,16 +1056,17 @@ class EGraph(_BaseModule):
1128
1056
  """
1129
1057
  Extract the lowest cost expression from the egraph.
1130
1058
  """
1131
- assert isinstance(expr, RuntimeExpr)
1132
- self._add_decls(expr)
1133
- extract_report = self._run_extract(expr.__egg__, 0)
1059
+ runtime_expr = to_runtime_expr(expr)
1060
+ self._add_decls(runtime_expr)
1061
+ typed_expr = runtime_expr.__egg_typed_expr__
1062
+ extract_report = self._run_extract(typed_expr.expr, 0)
1063
+
1134
1064
  if not isinstance(extract_report, bindings.Best):
1135
1065
  msg = "No extract report saved"
1136
1066
  raise ValueError(msg) # noqa: TRY004
1137
- new_typed_expr = TypedExprDecl.from_egg(
1138
- self._egraph, self._state.decls, expr.__egg_typed_expr__.tp, extract_report.termdag, extract_report.term, {}
1139
- )
1140
- res = cast(EXPR, RuntimeExpr(self._state.decls, new_typed_expr))
1067
+ (new_typed_expr,) = self._state.exprs_from_egg(extract_report.termdag, [extract_report.term], typed_expr.tp)
1068
+
1069
+ res = cast(EXPR, RuntimeExpr.__from_value__(self.__egg_decls__, new_typed_expr))
1141
1070
  if include_cost:
1142
1071
  return res, extract_report.cost
1143
1072
  return res
@@ -1146,23 +1075,20 @@ class EGraph(_BaseModule):
1146
1075
  """
1147
1076
  Extract multiple expressions from the egraph.
1148
1077
  """
1149
- assert isinstance(expr, RuntimeExpr)
1150
- self._add_decls(expr)
1078
+ runtime_expr = to_runtime_expr(expr)
1079
+ self._add_decls(runtime_expr)
1080
+ typed_expr = runtime_expr.__egg_typed_expr__
1151
1081
 
1152
- extract_report = self._run_extract(expr.__egg__, n)
1082
+ extract_report = self._run_extract(typed_expr.expr, n)
1153
1083
  if not isinstance(extract_report, bindings.Variants):
1154
1084
  msg = "Wrong extract report type"
1155
1085
  raise ValueError(msg) # noqa: TRY004
1156
- new_exprs = [
1157
- TypedExprDecl.from_egg(
1158
- self._egraph, self._state.decls, expr.__egg_typed_expr__.tp, extract_report.termdag, term, {}
1159
- )
1160
- for term in extract_report.terms
1161
- ]
1162
- return [cast(EXPR, RuntimeExpr(self._state.decls, expr)) for expr in new_exprs]
1086
+ new_exprs = self._state.exprs_from_egg(extract_report.termdag, extract_report.terms, typed_expr.tp)
1087
+ return [cast(EXPR, RuntimeExpr.__from_value__(self.__egg_decls__, expr)) for expr in new_exprs]
1163
1088
 
1164
- def _run_extract(self, expr: bindings._Expr, n: int) -> bindings._ExtractReport:
1165
- self._process_commands([bindings.ActionCommand(bindings.Extract(expr, bindings.Lit(bindings.Int(n))))])
1089
+ def _run_extract(self, expr: ExprDecl, n: int) -> bindings._ExtractReport:
1090
+ expr = self._state.expr_to_egg(expr)
1091
+ self._egraph.run_program(bindings.ActionCommand(bindings.Extract(expr, bindings.Lit(bindings.Int(n)))))
1166
1092
  extract_report = self._egraph.extract_report()
1167
1093
  if not extract_report:
1168
1094
  msg = "No extract report saved"
@@ -1173,15 +1099,15 @@ class EGraph(_BaseModule):
1173
1099
  """
1174
1100
  Push the current state of the egraph, so that it can be popped later and reverted back.
1175
1101
  """
1176
- self._process_commands([bindings.Push(1)])
1102
+ self._egraph.run_program(bindings.Push(1))
1177
1103
  self._state_stack.append(self._state)
1178
- self._state = deepcopy(self._state)
1104
+ self._state = self._state.copy()
1179
1105
 
1180
1106
  def pop(self) -> None:
1181
1107
  """
1182
1108
  Pop the current state of the egraph, reverting back to the previous state.
1183
1109
  """
1184
- self._process_commands([bindings.Pop(1)])
1110
+ self._egraph.run_program(bindings.Pop(1))
1185
1111
  self._state = self._state_stack.pop()
1186
1112
 
1187
1113
  def __enter__(self) -> Self:
@@ -1217,9 +1143,10 @@ class EGraph(_BaseModule):
1217
1143
  """
1218
1144
  Evaluates the given expression (which must be a primitive type), returning the result.
1219
1145
  """
1220
- assert isinstance(expr, RuntimeExpr)
1221
- typed_expr = expr.__egg_typed_expr__
1222
- egg_expr = expr.__egg__
1146
+ runtime_expr = to_runtime_expr(expr)
1147
+ self._add_decls(runtime_expr)
1148
+ typed_expr = runtime_expr.__egg_typed_expr__
1149
+ egg_expr = self._state.expr_to_egg(typed_expr.expr)
1223
1150
  match typed_expr.tp:
1224
1151
  case JustTypeRef("i64"):
1225
1152
  return self._egraph.eval_i64(egg_expr)
@@ -1231,7 +1158,7 @@ class EGraph(_BaseModule):
1231
1158
  return self._egraph.eval_string(egg_expr)
1232
1159
  case JustTypeRef("PyObject"):
1233
1160
  return self._egraph.eval_py_object(egg_expr)
1234
- raise NotImplementedError(f"Eval not implemented for {typed_expr.tp.name}")
1161
+ raise TypeError(f"Eval not implemented for {typed_expr.tp}")
1235
1162
 
1236
1163
  def saturate(
1237
1164
  self, *, max: int = 1000, performance: bool = False, **kwargs: Unpack[GraphvizKwargs]
@@ -1270,6 +1197,32 @@ class EGraph(_BaseModule):
1270
1197
  """
1271
1198
  return CURRENT_EGRAPH.get()
1272
1199
 
1200
+ @property
1201
+ def _egraph(self) -> bindings.EGraph:
1202
+ return self._state.egraph
1203
+
1204
+ @property
1205
+ def __egg_decls__(self) -> Declarations:
1206
+ return self._state.__egg_decls__
1207
+
1208
+ def _register_commands(self, cmds: list[Command]) -> None:
1209
+ self._add_decls(*cmds)
1210
+ egg_cmds = list(map(self._command_to_egg, cmds))
1211
+ self._egraph.run_program(*egg_cmds)
1212
+
1213
+ def _command_to_egg(self, cmd: Command) -> bindings._Command:
1214
+ ruleset_name = ""
1215
+ cmd_decl: CommandDecl
1216
+ match cmd:
1217
+ case RewriteOrRule(_, cmd_decl, ruleset):
1218
+ if ruleset:
1219
+ ruleset_name = ruleset.__egg_name__
1220
+ case Action(_, action):
1221
+ cmd_decl = ActionCommandDecl(action)
1222
+ case _:
1223
+ assert_never(cmd)
1224
+ return self._state.command_to_egg(cmd_decl, ruleset_name)
1225
+
1273
1226
 
1274
1227
  CURRENT_EGRAPH = ContextVar[EGraph]("CURRENT_EGRAPH")
1275
1228
 
@@ -1316,61 +1269,53 @@ class Unit(Expr, egg_sort="Unit", builtin=True):
1316
1269
 
1317
1270
 
1318
1271
  def ruleset(
1319
- rule_or_generator: CommandLike | CommandGenerator | None = None, *rules: Rule | Rewrite, name: None | str = None
1272
+ rule_or_generator: RewriteOrRule | RewriteOrRuleGenerator | None = None,
1273
+ *rules: RewriteOrRule,
1274
+ name: None | str = None,
1320
1275
  ) -> Ruleset:
1321
1276
  """
1322
1277
  Creates a ruleset with the following rules.
1323
1278
 
1324
1279
  If no name is provided, one is generated based on the current module
1325
1280
  """
1326
- r = Ruleset(name=name)
1281
+ r = Ruleset(name)
1327
1282
  if rule_or_generator is not None:
1328
- r.register(rule_or_generator, *rules)
1283
+ r.register(rule_or_generator, *rules, _increase_frame=True)
1329
1284
  return r
1330
1285
 
1331
1286
 
1332
- class Schedule(ABC):
1287
+ @dataclass
1288
+ class Schedule(DelayedDeclerations):
1333
1289
  """
1334
1290
  A composition of some rulesets, either composing them sequentially, running them repeatedly, running them till saturation, or running until some facts are met
1335
1291
  """
1336
1292
 
1293
+ # Defer declerations so that we can have rule generators that used not yet defined yet
1294
+ schedule: ScheduleDecl
1295
+
1296
+ def __str__(self) -> str:
1297
+ return pretty_decl(self.__egg_decls__, self.schedule)
1298
+
1299
+ def __repr__(self) -> str:
1300
+ return str(self)
1301
+
1337
1302
  def __mul__(self, length: int) -> Schedule:
1338
1303
  """
1339
1304
  Repeat the schedule a number of times.
1340
1305
  """
1341
- return Repeat(length, self)
1306
+ return Schedule(self.__egg_decls_thunk__, RepeatDecl(self.schedule, length))
1342
1307
 
1343
1308
  def saturate(self) -> Schedule:
1344
1309
  """
1345
1310
  Run the schedule until the e-graph is saturated.
1346
1311
  """
1347
- return Saturate(self)
1312
+ return Schedule(self.__egg_decls_thunk__, SaturateDecl(self.schedule))
1348
1313
 
1349
1314
  def __add__(self, other: Schedule) -> Schedule:
1350
1315
  """
1351
1316
  Run two schedules in sequence.
1352
1317
  """
1353
- return Sequence((self, other))
1354
-
1355
- @abstractmethod
1356
- def __str__(self) -> str:
1357
- raise NotImplementedError
1358
-
1359
- @abstractmethod
1360
- def _to_egg_schedule(self, default_ruleset_name: str) -> bindings._Schedule:
1361
- raise NotImplementedError
1362
-
1363
- @abstractmethod
1364
- def _rulesets(self) -> Iterable[Ruleset]:
1365
- """
1366
- Mapping of all the rulesets used to commands.
1367
- """
1368
- raise NotImplementedError
1369
-
1370
- @property
1371
- @abstractmethod
1372
- def __egg_decls__(self) -> Declarations:
1373
- raise NotImplementedError
1318
+ return Schedule(Thunk.fn(Declarations.create, self, other), SequenceDecl((self.schedule, other.schedule)))
1374
1319
 
1375
1320
 
1376
1321
  @dataclass
@@ -1379,422 +1324,119 @@ class Ruleset(Schedule):
1379
1324
  A collection of rules, which can be run as a schedule.
1380
1325
  """
1381
1326
 
1327
+ __egg_decls_thunk__: Callable[[], Declarations] = field(init=False)
1328
+ schedule: RunDecl = field(init=False)
1382
1329
  name: str | None
1383
- rules: list[Rule | Rewrite] = field(default_factory=list)
1384
1330
 
1385
- def append(self, rule: Rule | Rewrite) -> None:
1331
+ # Current declerations we have accumulated
1332
+ _current_egg_decls: Declarations = field(default_factory=Declarations)
1333
+ # Current rulesets we have accumulated
1334
+ __egg_ruleset__: RulesetDecl = field(init=False)
1335
+ # Rule generator functions that have been deferred, to allow for late type binding
1336
+ deferred_rule_gens: list[Callable[[], Iterable[RewriteOrRule]]] = field(default_factory=list)
1337
+
1338
+ def __post_init__(self) -> None:
1339
+ self.schedule = RunDecl(self.__egg_name__, ())
1340
+ self.__egg_ruleset__ = self._current_egg_decls._rulesets[self.__egg_name__] = RulesetDecl([])
1341
+ self.__egg_decls_thunk__ = self._update_egg_decls
1342
+
1343
+ def _update_egg_decls(self) -> Declarations:
1344
+ """
1345
+ To return the egg decls, we go through our deferred rules and add any we haven't yet
1346
+ """
1347
+ while self.deferred_rule_gens:
1348
+ rules = self.deferred_rule_gens.pop()()
1349
+ self._current_egg_decls.update(*rules)
1350
+ self.__egg_ruleset__.rules.extend(r.decl for r in rules)
1351
+ return self._current_egg_decls
1352
+
1353
+ def append(self, rule: RewriteOrRule) -> None:
1386
1354
  """
1387
1355
  Register a rule with the ruleset.
1388
1356
  """
1389
- self.rules.append(rule)
1357
+ self._current_egg_decls |= rule
1358
+ self.__egg_ruleset__.rules.append(rule.decl)
1390
1359
 
1391
- def register(self, /, rule_or_generator: CommandLike | CommandGenerator, *rules: Rule | Rewrite) -> None:
1360
+ def register(
1361
+ self,
1362
+ /,
1363
+ rule_or_generator: RewriteOrRule | RewriteOrRuleGenerator,
1364
+ *rules: RewriteOrRule,
1365
+ _increase_frame: bool = False,
1366
+ ) -> None:
1392
1367
  """
1393
1368
  Register rewrites or rules, either as a function or as values.
1394
1369
  """
1395
- if isinstance(rule_or_generator, FunctionType):
1396
- assert not rules
1397
- rules = tuple(_command_generator(rule_or_generator))
1370
+ if isinstance(rule_or_generator, RewriteOrRule):
1371
+ self.append(rule_or_generator)
1372
+ for r in rules:
1373
+ self.append(r)
1398
1374
  else:
1399
- rules = (cast(Rule | Rewrite, rule_or_generator), *rules)
1400
- for r in rules:
1401
- self.append(r)
1402
-
1403
- @cached_property
1404
- def __egg_decls__(self) -> Declarations:
1405
- return Declarations.create(*self.rules)
1406
-
1407
- @property
1408
- def _cmds(self) -> list[bindings._Command]:
1409
- cmds = [r._to_egg_command(self.egg_name) for r in self.rules]
1410
- if self.egg_name:
1411
- cmds.insert(0, bindings.AddRuleset(self.egg_name))
1412
- return cmds
1375
+ assert not rules
1376
+ current_frame = inspect.currentframe()
1377
+ assert current_frame
1378
+ original_frame = current_frame.f_back
1379
+ assert original_frame
1380
+ if _increase_frame:
1381
+ original_frame = original_frame.f_back
1382
+ assert original_frame
1383
+ self.deferred_rule_gens.append(Thunk.fn(_rewrite_or_rule_generator, rule_or_generator, original_frame))
1413
1384
 
1414
1385
  def __str__(self) -> str:
1415
- return f"ruleset(name={self.egg_name!r})"
1386
+ return pretty_decl(self._current_egg_decls, self.__egg_ruleset__, ruleset_name=self.name)
1416
1387
 
1417
1388
  def __repr__(self) -> str:
1418
- if not self.rules:
1419
- return str(self)
1420
- rules = ", ".join(map(repr, self.rules))
1421
- return f"ruleset({rules}, name={self.egg_name!r})"
1422
-
1423
- def _to_egg_schedule(self, default_ruleset_name: str) -> bindings._Schedule:
1424
- return bindings.Run(self._to_egg_config())
1425
-
1426
- def _to_egg_config(self) -> bindings.RunConfig:
1427
- return bindings.RunConfig(self.egg_name, None)
1428
-
1429
- def _rulesets(self) -> Iterable[Ruleset]:
1430
- yield self
1431
-
1432
- @property
1433
- def egg_name(self) -> str:
1434
- return self.name or f"_ruleset_{id(self)}"
1435
-
1436
-
1437
- class Command(ABC):
1438
- """
1439
- A command that can be executed in the egg interpreter.
1440
-
1441
- We only use this for commands which return no result and don't create new Python objects.
1442
-
1443
- Anything that can be passed to the `register` function in a Module is a Command.
1444
- """
1445
-
1446
- ruleset: Ruleset | None
1389
+ return str(self)
1447
1390
 
1391
+ # Create a unique name if we didn't pass one from the user
1448
1392
  @property
1449
- @abstractmethod
1450
- def __egg_decls__(self) -> Declarations:
1451
- raise NotImplementedError
1452
-
1453
- @abstractmethod
1454
- def _to_egg_command(self, default_ruleset_name: str) -> bindings._Command:
1455
- raise NotImplementedError
1456
-
1457
- @abstractmethod
1458
- def __str__(self) -> str:
1459
- raise NotImplementedError
1393
+ def __egg_name__(self) -> str:
1394
+ return self.name or f"ruleset_{id(self)}"
1460
1395
 
1461
1396
 
1462
1397
  @dataclass
1463
- class Rewrite(Command):
1464
- ruleset: Ruleset | None
1465
- _lhs: RuntimeExpr
1466
- _rhs: RuntimeExpr
1467
- _conditions: tuple[Fact, ...]
1468
- _subsume: bool
1469
- _fn_name: ClassVar[str] = "rewrite"
1398
+ class RewriteOrRule:
1399
+ __egg_decls__: Declarations
1400
+ decl: RewriteOrRuleDecl
1401
+ ruleset: Ruleset | None = None
1470
1402
 
1471
1403
  def __str__(self) -> str:
1472
- args_str = ", ".join(map(str, [self._rhs, *self._conditions]))
1473
- return f"{self._fn_name}({self._lhs}).to({args_str})"
1474
-
1475
- def _to_egg_command(self, default_ruleset_name: str) -> bindings._Command:
1476
- return bindings.RewriteCommand(
1477
- self.ruleset.egg_name if self.ruleset else default_ruleset_name, self._to_egg_rewrite(), self._subsume
1478
- )
1479
-
1480
- def _to_egg_rewrite(self) -> bindings.Rewrite:
1481
- return bindings.Rewrite(
1482
- self._lhs.__egg_typed_expr__.expr.to_egg(self._lhs.__egg_decls__),
1483
- self._rhs.__egg_typed_expr__.expr.to_egg(self._rhs.__egg_decls__),
1484
- [c._to_egg_fact() for c in self._conditions],
1485
- )
1486
-
1487
- @cached_property
1488
- def __egg_decls__(self) -> Declarations:
1489
- return Declarations.create(self._lhs, self._rhs, *self._conditions)
1490
-
1491
- def with_ruleset(self, ruleset: Ruleset) -> Rewrite:
1492
- return Rewrite(ruleset, self._lhs, self._rhs, self._conditions, self._subsume)
1404
+ return pretty_decl(self.__egg_decls__, self.decl)
1493
1405
 
1494
-
1495
- @dataclass
1496
- class BiRewrite(Rewrite):
1497
- _fn_name: ClassVar[str] = "birewrite"
1498
-
1499
- def _to_egg_command(self, default_ruleset_name: str) -> bindings._Command:
1500
- return bindings.BiRewriteCommand(
1501
- self.ruleset.egg_name if self.ruleset else default_ruleset_name, self._to_egg_rewrite()
1502
- )
1406
+ def __repr__(self) -> str:
1407
+ return str(self)
1503
1408
 
1504
1409
 
1505
1410
  @dataclass
1506
- class Fact(ABC):
1411
+ class Fact:
1507
1412
  """
1508
1413
  A query on an EGraph, either by an expression or an equivalence between multiple expressions.
1509
1414
  """
1510
1415
 
1511
- @abstractmethod
1512
- def _to_egg_fact(self) -> bindings._Fact:
1513
- raise NotImplementedError
1514
-
1515
- @property
1516
- @abstractmethod
1517
- def __egg_decls__(self) -> Declarations:
1518
- raise NotImplementedError
1519
-
1520
-
1521
- @dataclass
1522
- class Eq(Fact):
1523
- _exprs: list[RuntimeExpr]
1416
+ __egg_decls__: Declarations
1417
+ fact: FactDecl
1524
1418
 
1525
1419
  def __str__(self) -> str:
1526
- first, *rest = self._exprs
1527
- args_str = ", ".join(map(str, rest))
1528
- return f"eq({first}).to({args_str})"
1529
-
1530
- def _to_egg_fact(self) -> bindings.Eq:
1531
- return bindings.Eq([e.__egg__ for e in self._exprs])
1420
+ return pretty_decl(self.__egg_decls__, self.fact)
1532
1421
 
1533
- @cached_property
1534
- def __egg_decls__(self) -> Declarations:
1535
- return Declarations.create(*self._exprs)
1536
-
1537
-
1538
- @dataclass
1539
- class ExprFact(Fact):
1540
- _expr: RuntimeExpr
1541
-
1542
- def __str__(self) -> str:
1543
- return str(self._expr)
1544
-
1545
- def _to_egg_fact(self) -> bindings.Fact:
1546
- return bindings.Fact(self._expr.__egg__)
1547
-
1548
- @cached_property
1549
- def __egg_decls__(self) -> Declarations:
1550
- return self._expr.__egg_decls__
1422
+ def __repr__(self) -> str:
1423
+ return str(self)
1551
1424
 
1552
1425
 
1553
1426
  @dataclass
1554
- class Rule(Command):
1555
- head: tuple[Action, ...]
1556
- body: tuple[Fact, ...]
1557
- name: str
1558
- ruleset: Ruleset | None
1559
-
1560
- def __str__(self) -> str:
1561
- head_str = ", ".join(map(str, self.head))
1562
- body_str = ", ".join(map(str, self.body))
1563
- return f"rule({body_str}).then({head_str})"
1564
-
1565
- def _to_egg_command(self, default_ruleset_name: str) -> bindings.RuleCommand:
1566
- return bindings.RuleCommand(
1567
- self.name,
1568
- self.ruleset.egg_name if self.ruleset else default_ruleset_name,
1569
- bindings.Rule(
1570
- [a._to_egg_action() for a in self.head],
1571
- [f._to_egg_fact() for f in self.body],
1572
- ),
1573
- )
1574
-
1575
- @cached_property
1576
- def __egg_decls__(self) -> Declarations:
1577
- return Declarations.create(*self.head, *self.body)
1578
-
1579
-
1580
- class Action(Command, ABC):
1427
+ class Action:
1581
1428
  """
1582
1429
  A change to an EGraph, either unioning multiple expressing, setting the value of a function call, deleting an expression, or panicking.
1583
1430
  """
1584
1431
 
1585
- @abstractmethod
1586
- def _to_egg_action(self) -> bindings._Action:
1587
- raise NotImplementedError
1588
-
1589
- def _to_egg_command(self, default_ruleset_name: str) -> bindings._Command:
1590
- return bindings.ActionCommand(self._to_egg_action())
1591
-
1592
- @property
1593
- def ruleset(self) -> None | Ruleset: # type: ignore[override]
1594
- return None
1595
-
1596
-
1597
- @dataclass
1598
- class Let(Action):
1599
- _name: str
1600
- _value: RuntimeExpr
1601
-
1602
- def __str__(self) -> str:
1603
- return f"let({self._name}, {self._value})"
1604
-
1605
- def _to_egg_action(self) -> bindings.Let:
1606
- return bindings.Let(self._name, self._value.__egg__)
1607
-
1608
- @property
1609
- def __egg_decls__(self) -> Declarations:
1610
- return self._value.__egg_decls__
1611
-
1612
-
1613
- @dataclass
1614
- class Set(Action):
1615
- """
1616
- Similar to union, except can be used on primitive expressions, whereas union can only be used on user defined expressions.
1617
- """
1618
-
1619
- _call: RuntimeExpr
1620
- _rhs: RuntimeExpr
1621
-
1622
- def __str__(self) -> str:
1623
- return f"set({self._call}).to({self._rhs})"
1624
-
1625
- def _to_egg_action(self) -> bindings.Set:
1626
- egg_call = self._call.__egg__
1627
- if not isinstance(egg_call, bindings.Call):
1628
- raise ValueError(f"Can only create a set with a call for the lhs, got {self._call}") # noqa: TRY004
1629
- return bindings.Set(
1630
- egg_call.name,
1631
- egg_call.args,
1632
- self._rhs.__egg__,
1633
- )
1634
-
1635
- @cached_property
1636
- def __egg_decls__(self) -> Declarations:
1637
- return Declarations.create(self._call, self._rhs)
1638
-
1639
-
1640
- @dataclass
1641
- class ExprAction(Action):
1642
- _expr: RuntimeExpr
1643
-
1644
- def __str__(self) -> str:
1645
- return str(self._expr)
1646
-
1647
- def _to_egg_action(self) -> bindings.Expr_:
1648
- return bindings.Expr_(self._expr.__egg__)
1649
-
1650
- @property
1651
- def __egg_decls__(self) -> Declarations:
1652
- return self._expr.__egg_decls__
1653
-
1654
-
1655
- @dataclass
1656
- class Change(Action):
1657
- """
1658
- Change a function call in an EGraph.
1659
- """
1660
-
1661
- change: Literal["delete", "subsume"]
1662
- _call: RuntimeExpr
1663
-
1664
- def __str__(self) -> str:
1665
- return f"{self.change}({self._call})"
1666
-
1667
- def _to_egg_action(self) -> bindings.Change:
1668
- egg_call = self._call.__egg__
1669
- if not isinstance(egg_call, bindings.Call):
1670
- raise ValueError(f"Can only create a call with a call for the lhs, got {self._call}") # noqa: TRY004
1671
- change: bindings._Change = bindings.Delete() if self.change == "delete" else bindings.Subsume()
1672
- return bindings.Change(change, egg_call.name, egg_call.args)
1673
-
1674
- @property
1675
- def __egg_decls__(self) -> Declarations:
1676
- return self._call.__egg_decls__
1677
-
1678
-
1679
- @dataclass
1680
- class Union_(Action): # noqa: N801
1681
- """
1682
- Merges two equivalence classes of two expressions.
1683
- """
1684
-
1685
- _lhs: RuntimeExpr
1686
- _rhs: RuntimeExpr
1687
-
1688
- def __str__(self) -> str:
1689
- return f"union({self._lhs}).with_({self._rhs})"
1690
-
1691
- def _to_egg_action(self) -> bindings.Union:
1692
- return bindings.Union(self._lhs.__egg__, self._rhs.__egg__)
1693
-
1694
- @cached_property
1695
- def __egg_decls__(self) -> Declarations:
1696
- return Declarations.create(self._lhs, self._rhs)
1697
-
1698
-
1699
- @dataclass
1700
- class Panic(Action):
1701
- message: str
1702
-
1703
- def __str__(self) -> str:
1704
- return f"panic({self.message})"
1705
-
1706
- def _to_egg_action(self) -> bindings.Panic:
1707
- return bindings.Panic(self.message)
1708
-
1709
- @cached_property
1710
- def __egg_decls__(self) -> Declarations:
1711
- return Declarations()
1712
-
1713
-
1714
- @dataclass
1715
- class Run(Schedule):
1716
- """Configuration of a run"""
1717
-
1718
- # None if using default ruleset
1719
- ruleset: Ruleset | None
1720
- until: tuple[Fact, ...]
1721
-
1722
- def __str__(self) -> str:
1723
- args_str = ", ".join(map(str, [self.ruleset, *self.until]))
1724
- return f"run({args_str})"
1725
-
1726
- def _to_egg_schedule(self, default_ruleset_name: str) -> bindings._Schedule:
1727
- return bindings.Run(self._to_egg_config(default_ruleset_name))
1728
-
1729
- def _to_egg_config(self, default_ruleset_name: str) -> bindings.RunConfig:
1730
- return bindings.RunConfig(
1731
- self.ruleset.egg_name if self.ruleset else default_ruleset_name,
1732
- [fact._to_egg_fact() for fact in self.until] if self.until else None,
1733
- )
1734
-
1735
- def _rulesets(self) -> Iterable[Ruleset]:
1736
- if self.ruleset:
1737
- yield self.ruleset
1738
-
1739
- @cached_property
1740
- def __egg_decls__(self) -> Declarations:
1741
- return Declarations.create(self.ruleset, *self.until)
1742
-
1743
-
1744
- @dataclass
1745
- class Saturate(Schedule):
1746
- schedule: Schedule
1747
-
1748
- def __str__(self) -> str:
1749
- return f"{self.schedule}.saturate()"
1750
-
1751
- def _to_egg_schedule(self, default_ruleset_name: str) -> bindings._Schedule:
1752
- return bindings.Saturate(self.schedule._to_egg_schedule(default_ruleset_name))
1753
-
1754
- def _rulesets(self) -> Iterable[Ruleset]:
1755
- return self.schedule._rulesets()
1756
-
1757
- @property
1758
- def __egg_decls__(self) -> Declarations:
1759
- return self.schedule.__egg_decls__
1760
-
1761
-
1762
- @dataclass
1763
- class Repeat(Schedule):
1764
- length: int
1765
- schedule: Schedule
1766
-
1767
- def __str__(self) -> str:
1768
- return f"{self.schedule} * {self.length}"
1769
-
1770
- def _to_egg_schedule(self, default_ruleset_name: str) -> bindings._Schedule:
1771
- return bindings.Repeat(self.length, self.schedule._to_egg_schedule(default_ruleset_name))
1772
-
1773
- def _rulesets(self) -> Iterable[Ruleset]:
1774
- return self.schedule._rulesets()
1775
-
1776
- @property
1777
- def __egg_decls__(self) -> Declarations:
1778
- return self.schedule.__egg_decls__
1779
-
1780
-
1781
- @dataclass
1782
- class Sequence(Schedule):
1783
- schedules: tuple[Schedule, ...]
1432
+ __egg_decls__: Declarations
1433
+ action: ActionDecl
1784
1434
 
1785
1435
  def __str__(self) -> str:
1786
- return f"sequence({', '.join(map(str, self.schedules))})"
1787
-
1788
- def _to_egg_schedule(self, default_ruleset_name: str) -> bindings._Schedule:
1789
- return bindings.Sequence([schedule._to_egg_schedule(default_ruleset_name) for schedule in self.schedules])
1436
+ return pretty_decl(self.__egg_decls__, self.action)
1790
1437
 
1791
- def _rulesets(self) -> Iterable[Ruleset]:
1792
- for s in self.schedules:
1793
- yield from s._rulesets()
1794
-
1795
- @cached_property
1796
- def __egg_decls__(self) -> Declarations:
1797
- return Declarations.create(*self.schedules)
1438
+ def __repr__(self) -> str:
1439
+ return str(self)
1798
1440
 
1799
1441
 
1800
1442
  # We use these builders so that when creating these structures we can type check
@@ -1841,30 +1483,41 @@ def ne(expr: EXPR) -> _NeBuilder[EXPR]:
1841
1483
 
1842
1484
  def panic(message: str) -> Action:
1843
1485
  """Raise an error with the given message."""
1844
- return Panic(message)
1486
+ return Action(Declarations(), PanicDecl(message))
1845
1487
 
1846
1488
 
1847
1489
  def let(name: str, expr: Expr) -> Action:
1848
1490
  """Create a let binding."""
1849
- return Let(name, to_runtime_expr(expr))
1491
+ runtime_expr = to_runtime_expr(expr)
1492
+ return Action(runtime_expr.__egg_decls__, LetDecl(name, runtime_expr.__egg_typed_expr__))
1850
1493
 
1851
1494
 
1852
1495
  def expr_action(expr: Expr) -> Action:
1853
- return ExprAction(to_runtime_expr(expr))
1496
+ runtime_expr = to_runtime_expr(expr)
1497
+ return Action(runtime_expr.__egg_decls__, ExprActionDecl(runtime_expr.__egg_typed_expr__))
1854
1498
 
1855
1499
 
1856
1500
  def delete(expr: Expr) -> Action:
1857
1501
  """Create a delete expression."""
1858
- return Change("delete", to_runtime_expr(expr))
1502
+ runtime_expr = to_runtime_expr(expr)
1503
+ typed_expr = runtime_expr.__egg_typed_expr__
1504
+ call_decl = typed_expr.expr
1505
+ assert isinstance(call_decl, CallDecl), "Can only delete calls, not literals or vars"
1506
+ return Action(runtime_expr.__egg_decls__, ChangeDecl(typed_expr.tp, call_decl, "delete"))
1859
1507
 
1860
1508
 
1861
1509
  def subsume(expr: Expr) -> Action:
1862
- """Subsume an expression."""
1863
- return Change("subsume", to_runtime_expr(expr))
1510
+ """Subsume an expression so it cannot be matched against or extracted"""
1511
+ runtime_expr = to_runtime_expr(expr)
1512
+ typed_expr = runtime_expr.__egg_typed_expr__
1513
+ call_decl = typed_expr.expr
1514
+ assert isinstance(call_decl, CallDecl), "Can only subsume calls, not literals or vars"
1515
+ return Action(runtime_expr.__egg_decls__, ChangeDecl(typed_expr.tp, call_decl, "subsume"))
1864
1516
 
1865
1517
 
1866
1518
  def expr_fact(expr: Expr) -> Fact:
1867
- return ExprFact(to_runtime_expr(expr))
1519
+ runtime_expr = to_runtime_expr(expr)
1520
+ return Fact(runtime_expr.__egg_decls__, ExprFactDecl(runtime_expr.__egg_typed_expr__))
1868
1521
 
1869
1522
 
1870
1523
  def union(lhs: EXPR) -> _UnionBuilder[EXPR]:
@@ -1891,6 +1544,11 @@ def rule(*facts: FactLike, ruleset: Ruleset | None = None, name: str | None = No
1891
1544
  return _RuleBuilder(facts=_fact_likes(facts), name=name, ruleset=ruleset)
1892
1545
 
1893
1546
 
1547
+ @deprecated("This function is now a no-op, you can remove it and use actions as commands")
1548
+ def action_command(action: Action) -> Action:
1549
+ return action
1550
+
1551
+
1894
1552
  def var(name: str, bound: type[EXPR]) -> EXPR:
1895
1553
  """Create a new variable with the given name and type."""
1896
1554
  return cast(EXPR, _var(name, bound))
@@ -1898,9 +1556,9 @@ def var(name: str, bound: type[EXPR]) -> EXPR:
1898
1556
 
1899
1557
  def _var(name: str, bound: object) -> RuntimeExpr:
1900
1558
  """Create a new variable with the given name and type."""
1901
- if not isinstance(bound, RuntimeClass | RuntimeParamaterizedClass):
1559
+ if not isinstance(bound, RuntimeClass):
1902
1560
  raise TypeError(f"Unexpected type {type(bound)}")
1903
- return RuntimeExpr(bound.__egg_decls__, TypedExprDecl(class_to_ref(bound), VarDecl(name)))
1561
+ return RuntimeExpr.__from_value__(bound.__egg_decls__, TypedExprDecl(bound.__egg_tp__.to_just(), VarDecl(name)))
1904
1562
 
1905
1563
 
1906
1564
  def vars_(names: str, bound: type[EXPR]) -> Iterable[EXPR]:
@@ -1915,15 +1573,27 @@ class _RewriteBuilder(Generic[EXPR]):
1915
1573
  ruleset: Ruleset | None
1916
1574
  subsume: bool
1917
1575
 
1918
- def to(self, rhs: EXPR, *conditions: FactLike) -> Rewrite:
1576
+ def to(self, rhs: EXPR, *conditions: FactLike) -> RewriteOrRule:
1919
1577
  lhs = to_runtime_expr(self.lhs)
1920
- rule = Rewrite(self.ruleset, lhs, convert_to_same_type(rhs, lhs), _fact_likes(conditions), self.subsume)
1578
+ facts = _fact_likes(conditions)
1579
+ rhs = convert_to_same_type(rhs, lhs)
1580
+ rule = RewriteOrRule(
1581
+ Declarations.create(lhs, rhs, *facts, self.ruleset),
1582
+ RewriteDecl(
1583
+ lhs.__egg_typed_expr__.tp,
1584
+ lhs.__egg_typed_expr__.expr,
1585
+ rhs.__egg_typed_expr__.expr,
1586
+ tuple(f.fact for f in facts),
1587
+ self.subsume,
1588
+ ),
1589
+ )
1921
1590
  if self.ruleset:
1922
1591
  self.ruleset.append(rule)
1923
1592
  return rule
1924
1593
 
1925
1594
  def __str__(self) -> str:
1926
- return f"rewrite({self.lhs})"
1595
+ lhs = to_runtime_expr(self.lhs)
1596
+ return lhs.__egg_pretty__("rewrite")
1927
1597
 
1928
1598
 
1929
1599
  @dataclass
@@ -1931,15 +1601,26 @@ class _BirewriteBuilder(Generic[EXPR]):
1931
1601
  lhs: EXPR
1932
1602
  ruleset: Ruleset | None
1933
1603
 
1934
- def to(self, rhs: EXPR, *conditions: FactLike) -> Command:
1604
+ def to(self, rhs: EXPR, *conditions: FactLike) -> RewriteOrRule:
1935
1605
  lhs = to_runtime_expr(self.lhs)
1936
- rule = BiRewrite(self.ruleset, lhs, convert_to_same_type(rhs, lhs), _fact_likes(conditions), False)
1606
+ facts = _fact_likes(conditions)
1607
+ rhs = convert_to_same_type(rhs, lhs)
1608
+ rule = RewriteOrRule(
1609
+ Declarations.create(lhs, rhs, *facts, self.ruleset),
1610
+ BiRewriteDecl(
1611
+ lhs.__egg_typed_expr__.tp,
1612
+ lhs.__egg_typed_expr__.expr,
1613
+ rhs.__egg_typed_expr__.expr,
1614
+ tuple(f.fact for f in facts),
1615
+ ),
1616
+ )
1937
1617
  if self.ruleset:
1938
1618
  self.ruleset.append(rule)
1939
1619
  return rule
1940
1620
 
1941
1621
  def __str__(self) -> str:
1942
- return f"birewrite({self.lhs})"
1622
+ lhs = to_runtime_expr(self.lhs)
1623
+ return lhs.__egg_pretty__("birewrite")
1943
1624
 
1944
1625
 
@dataclass
class _EqBuilder(Generic[EXPR]):
    """Pending equality fact about ``expr``; :meth:`to` supplies the expressions it must equal."""

    expr: EXPR

    def to(self, *exprs: EXPR) -> Fact:
        """Build an equality fact over ``expr`` and ``exprs``, each converted to ``expr``'s type."""
        first = to_runtime_expr(self.expr)
        operands = [first]
        operands.extend(convert_to_same_type(other, first) for other in exprs)
        decls = Declarations.create(*operands)
        eq_type = first.__egg_typed_expr__.tp
        return Fact(decls, EqDecl(eq_type, tuple(op.__egg_typed_expr__.expr for op in operands)))

    def __repr__(self) -> str:
        """Show the same text as ``str``."""
        return str(self)

    def __str__(self) -> str:
        """Render via ``expr``'s pretty-printer, labelled ``eq``."""
        return to_runtime_expr(self.expr).__egg_pretty__("eq")
1955
1644
 
1956
1645
 
@dataclass
class _NeBuilder(Generic[EXPR]):
    """Pending inequality; :meth:`to` builds a ``Unit``-typed ``!=`` call between ``lhs`` and ``rhs``."""

    lhs: EXPR

    def to(self, rhs: EXPR) -> Unit:
        """Return the ``!=`` call of ``lhs`` and ``rhs`` (converted to ``lhs``'s type) as a ``Unit`` expression."""
        left = to_runtime_expr(self.lhs)
        right = convert_to_same_type(rhs, left)
        # Unit is passed as a value to Declarations.create below, so it must be a RuntimeClass here.
        assert isinstance(Unit, RuntimeClass)
        decls = Declarations.create(Unit, left, right)
        ne_call = CallDecl(FunctionRef("!="), (left.__egg_typed_expr__, right.__egg_typed_expr__))
        typed_call = TypedExprDecl(JustTypeRef("Unit"), ne_call)
        return cast(Unit, RuntimeExpr.__from_value__(decls, typed_call))

    def __repr__(self) -> str:
        """Show the same text as ``str``."""
        return str(self)

    def __str__(self) -> str:
        """Render via ``lhs``'s pretty-printer, labelled ``ne``."""
        return to_runtime_expr(self.lhs).__egg_pretty__("ne")
1973
1668
 
1974
1669
 
@dataclass
class _SetBuilder(Generic[EXPR]):
    """Pending ``set_`` action; :meth:`to` assigns a value to the function call held in ``lhs``."""

    lhs: EXPR

    def to(self, rhs: EXPR) -> Action:
        """
        Build the set action for ``lhs`` with ``rhs`` converted to ``lhs``'s type.

        Raises ``AssertionError`` (when assertions are enabled) if ``lhs`` is not a function call.
        """
        left = to_runtime_expr(self.lhs)
        right = convert_to_same_type(rhs, left)
        call_decl = left.__egg_typed_expr__.expr
        assert isinstance(call_decl, CallDecl), "Can only set function calls"
        decls = Declarations.create(left, right)
        set_decl = SetDecl(left.__egg_typed_expr__.tp, call_decl, right.__egg_typed_expr__.expr)
        return Action(decls, set_decl)

    def __repr__(self) -> str:
        """Show the same text as ``str``."""
        return str(self)

    def __str__(self) -> str:
        """Render via ``lhs``'s pretty-printer, labelled ``set_``."""
        return to_runtime_expr(self.lhs).__egg_pretty__("set_")
1985
1690
 
1986
1691
 
@dataclass
class _UnionBuilder(Generic[EXPR]):
    """Pending union action; :meth:`with_` supplies the expression to union with ``lhs``."""

    lhs: EXPR

    def with_(self, rhs: EXPR) -> Action:
        """Build the union action of ``lhs`` and ``rhs`` (converted to ``lhs``'s type)."""
        left = to_runtime_expr(self.lhs)
        right = convert_to_same_type(rhs, left)
        decls = Declarations.create(left, right)
        left_typed = left.__egg_typed_expr__
        union_decl = UnionDecl(left_typed.tp, left_typed.expr, right.__egg_typed_expr__.expr)
        return Action(decls, union_decl)

    def __repr__(self) -> str:
        """Show the same text as ``str``."""
        return str(self)

    def __str__(self) -> str:
        """Render via ``lhs``'s pretty-printer, labelled ``union``."""
        return to_runtime_expr(self.lhs).__egg_pretty__("union")
1997
1710
 
1998
1711
 
1999
1712
  @dataclass
@@ -2002,12 +1715,25 @@ class _RuleBuilder:
2002
1715
  name: str | None
2003
1716
  ruleset: Ruleset | None
2004
1717
 
2005
- def then(self, *actions: ActionLike) -> Rule:
2006
- rule = Rule(_action_likes(actions), self.facts, self.name or "", self.ruleset)
1718
+ def then(self, *actions: ActionLike) -> RewriteOrRule:
1719
+ actions = _action_likes(actions)
1720
+ rule = RewriteOrRule(
1721
+ Declarations.create(self.ruleset, *actions, *self.facts),
1722
+ RuleDecl(tuple(a.action for a in actions), tuple(f.fact for f in self.facts), self.name),
1723
+ )
2007
1724
  if self.ruleset:
2008
1725
  self.ruleset.append(rule)
2009
1726
  return rule
2010
1727
 
1728
+ def __str__(self) -> str:
1729
+ # TODO: Figure out how to stringify rulebuilder that preserves statements
1730
+ args = list(map(str, self.facts))
1731
+ if self.name is not None:
1732
+ args.append(f"name={self.name}")
1733
+ if ruleset is not None:
1734
+ args.append(f"ruleset={self.ruleset}")
1735
+ return f"rule({', '.join(args)})"
1736
+
2011
1737
 
2012
1738
  def expr_parts(expr: Expr) -> TypedExprDecl:
2013
1739
  """
@@ -2024,60 +1750,61 @@ def to_runtime_expr(expr: Expr) -> RuntimeExpr:
2024
1750
  return expr
2025
1751
 
2026
1752
 
def run(ruleset: Ruleset | None = None, *until: FactLike) -> Schedule:
    """
    Create a run configuration.

    When ``ruleset`` is falsy the run refers to the ruleset name ``""``. The ``until`` facts are
    passed through to the run declaration; if none are given it receives ``None``.
    """
    until_facts = _fact_likes(until)
    decls = Thunk.fn(Declarations.create, ruleset, *until_facts)
    ruleset_name = ruleset.__egg_name__ if ruleset else ""
    stop_facts = tuple(uf.fact for uf in until_facts) or None
    return Schedule(decls, RunDecl(ruleset_name, stop_facts))
2032
1762
 
2033
1763
 
def seq(*schedules: Schedule) -> Schedule:
    """
    Run a sequence of schedules.

    Returns one schedule whose declaration is a ``SequenceDecl`` of the given schedules, in order.
    """
    decls = Thunk.fn(Declarations.create, *schedules)
    steps = tuple(sched.schedule for sched in schedules)
    return Schedule(decls, SequenceDecl(steps))
2039
1769
 
2040
1770
 
# Anything accepted where an action is expected: an Action, or a bare expression.
ActionLike: TypeAlias = Action | Expr


def _action_likes(action_likes: Iterable[ActionLike]) -> tuple[Action, ...]:
    """Normalize each item of ``action_likes`` with ``_action_like`` and return them as a tuple."""
    return tuple(_action_like(item) for item in action_likes)


def _action_like(action_like: ActionLike) -> Action:
    """Wrap a bare expression with ``expr_action``; anything else is returned unchanged."""
    return expr_action(action_like) if isinstance(action_like, Expr) else action_like
2051
1782
 
2052
1783
 
# A command is an action or a rewrite/rule; the "-Like" alias also admits bare expressions via ActionLike.
Command: TypeAlias = Action | RewriteOrRule

CommandLike: TypeAlias = ActionLike | RewriteOrRule


def _command_like(command_like: CommandLike) -> Command:
    """Return rewrites/rules unchanged and normalize everything else with ``_action_like``."""
    if not isinstance(command_like, RewriteOrRule):
        return _action_like(command_like)
    return command_like
2072
1793
 
RewriteOrRuleGenerator = Callable[..., Iterable[RewriteOrRule]]


def _rewrite_or_rule_generator(gen: RewriteOrRuleGenerator, frame: FrameType) -> Iterable[RewriteOrRule]:
    """
    Call ``gen`` with one variable per parameter and return everything it yields, as a list.

    Each variable is created with ``_var`` from the parameter's name and its annotated type.
    Annotations are resolved against ``gen``'s globals plus ``frame``'s locals. Presumably ``frame``
    is the scope where ``gen`` was defined, so that locally defined types can be found.

    ``gen`` is called eagerly; no thunk is returned.

    Raises ``KeyError`` if a parameter of ``gen`` has no type annotation.
    """
    # Get the local scope from where the function is defined, so that we can get any type hints that are in the scope
    # but not in the globals
    hints = get_type_hints(gen, gen.__globals__, frame.f_locals)
    args = [_var(p.name, hints[p.name]) for p in signature(gen).parameters.values()]
    return list(gen(*args))  # type: ignore[misc]
2081
1808
 
2082
1809
 
# Anything accepted wherever a fact is expected (normalized by _fact_likes).
FactLike: TypeAlias = Fact | Expr