egglog-6.1.0-cp312-none-win_amd64.whl → egglog-7.1.0-cp312-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of egglog might be problematic; further details were available via a link on the original page, which is not included in this copy.

egglog/egraph.py CHANGED
@@ -3,12 +3,10 @@ from __future__ import annotations
3
3
  import inspect
4
4
  import pathlib
5
5
  import tempfile
6
- from abc import ABC, abstractmethod
6
+ from abc import abstractmethod
7
7
  from collections.abc import Callable, Iterable
8
8
  from contextvars import ContextVar, Token
9
- from copy import deepcopy
10
9
  from dataclasses import InitVar, dataclass, field
11
- from functools import cached_property
12
10
  from inspect import Parameter, currentframe, signature
13
11
  from types import FrameType, FunctionType
14
12
  from typing import (
@@ -18,6 +16,7 @@ from typing import (
18
16
  Generic,
19
17
  Literal,
20
18
  NoReturn,
19
+ TypeAlias,
21
20
  TypedDict,
22
21
  TypeVar,
23
22
  cast,
@@ -26,14 +25,16 @@ from typing import (
26
25
  )
27
26
 
28
27
  import graphviz
29
- from typing_extensions import ParamSpec, Self, Unpack, deprecated
30
-
31
- from egglog.declarations import REFLECTED_BINARY_METHODS, Declarations
28
+ from typing_extensions import ParamSpec, Self, Unpack, assert_never, deprecated
32
29
 
33
30
  from . import bindings
31
+ from .conversion import *
34
32
  from .declarations import *
33
+ from .egraph_state import *
35
34
  from .ipython_magic import IN_IPYTHON
35
+ from .pretty import pretty_decl
36
36
  from .runtime import *
37
+ from .thunk import *
37
38
 
38
39
  if TYPE_CHECKING:
39
40
  import ipywidgets
@@ -58,6 +59,7 @@ __all__ = [
58
59
  "let",
59
60
  "constant",
60
61
  "delete",
62
+ "subsume",
61
63
  "union",
62
64
  "set_",
63
65
  "rule",
@@ -65,11 +67,15 @@ __all__ = [
65
67
  "vars_",
66
68
  "Fact",
67
69
  "expr_parts",
70
+ "expr_action",
71
+ "expr_fact",
72
+ "action_command",
68
73
  "Schedule",
69
74
  "run",
70
75
  "seq",
71
76
  "Command",
72
77
  "simplify",
78
+ "unstable_combine_rulesets",
73
79
  "check",
74
80
  "GraphvizKwargs",
75
81
  "Ruleset",
@@ -79,11 +85,11 @@ __all__ = [
79
85
  "_NeBuilder",
80
86
  "_SetBuilder",
81
87
  "_UnionBuilder",
82
- "Rule",
83
- "Rewrite",
84
- "BiRewrite",
85
- "Union_",
88
+ "RewriteOrRule",
89
+ "Fact",
86
90
  "Action",
91
+ "Command",
92
+ "check_eq",
87
93
  ]
88
94
 
89
95
  T = TypeVar("T")
@@ -112,7 +118,24 @@ IGNORED_ATTRIBUTES = {
112
118
  }
113
119
 
114
120
 
121
+ # special methods that return none and mutate self
115
122
  ALWAYS_MUTATES_SELF = {"__setitem__", "__delitem__"}
123
+ # special methods which must return real python values instead of lazy expressions
124
+ ALWAYS_PRESERVED = {
125
+ "__repr__",
126
+ "__str__",
127
+ "__bytes__",
128
+ "__format__",
129
+ "__hash__",
130
+ "__bool__",
131
+ "__len__",
132
+ "__length_hint__",
133
+ "__iter__",
134
+ "__reversed__",
135
+ "__contains__",
136
+ "__index__",
137
+ "__bufer__",
138
+ }
116
139
 
117
140
 
118
141
  def simplify(x: EXPR, schedule: Schedule | None = None) -> EXPR:
@@ -124,7 +147,24 @@ def simplify(x: EXPR, schedule: Schedule | None = None) -> EXPR:
124
147
  return EGraph().extract(x)
125
148
 
126
149
 
127
- def check(x: FactLike, schedule: Schedule | None = None, *given: Union_ | Expr | Set) -> None:
150
+ def check_eq(x: EXPR, y: EXPR, schedule: Schedule | None = None) -> EGraph:
151
+ """
152
+ Verifies that two expressions are equal after running the schedule.
153
+ """
154
+ egraph = EGraph()
155
+ x_var = egraph.let("__check_eq_x", x)
156
+ y_var = egraph.let("__check_eq_y", y)
157
+ if schedule:
158
+ egraph.run(schedule)
159
+ fact = eq(x_var).to(y_var)
160
+ try:
161
+ egraph.check(fact)
162
+ except bindings.EggSmolError as err:
163
+ raise AssertionError(f"Failed {eq(x).to(y)}\n -> {ne(egraph.extract(x)).to(egraph.extract(y))})") from err
164
+ return egraph
165
+
166
+
167
+ def check(x: FactLike, schedule: Schedule | None = None, *given: ActionLike) -> None:
128
168
  """
129
169
  Verifies that the fact is true given some assumptions and after running the schedule.
130
170
  """
@@ -136,9 +176,6 @@ def check(x: FactLike, schedule: Schedule | None = None, *given: Union_ | Expr |
136
176
  egraph.check(x)
137
177
 
138
178
 
139
- # def extract(res: )
140
-
141
-
142
179
  @dataclass
143
180
  class _BaseModule:
144
181
  """
@@ -181,15 +218,8 @@ class _BaseModule:
181
218
  Registers a class.
182
219
  """
183
220
  if kwargs:
184
- assert set(kwargs.keys()) == {"egg_sort"}
185
-
186
- def _inner(cls: object, egg_sort: str = kwargs["egg_sort"]):
187
- assert isinstance(cls, RuntimeClass)
188
- assert isinstance(cls.lazy_decls, _ClassDeclerationsConstructor)
189
- cls.lazy_decls.egg_sort = egg_sort
190
- return cls
191
-
192
- return _inner
221
+ msg = "Switch to subclassing from Expr and passing egg_sort as a keyword arg to the class constructor"
222
+ raise NotImplementedError(msg)
193
223
 
194
224
  assert len(args) == 1
195
225
  return args[0]
@@ -280,9 +310,9 @@ class _BaseModule:
280
310
  # If we have any positional args, then we are calling it directly on a function
281
311
  if args:
282
312
  assert len(args) == 1
283
- return _function(args[0], fn_locals, False)
313
+ return _FunctionConstructor(fn_locals)(args[0])
284
314
  # otherwise, we are passing some keyword args, so save those, and then return a partial
285
- return lambda fn: _function(fn, fn_locals, False, **kwargs)
315
+ return _FunctionConstructor(fn_locals, **kwargs)
286
316
 
287
317
  @deprecated("Use top level `ruleset` function instead")
288
318
  def ruleset(self, name: str) -> Ruleset:
@@ -324,17 +354,26 @@ class _BaseModule:
324
354
  """
325
355
  return constant(name, tp, egg_name)
326
356
 
327
- def register(self, /, command_or_generator: CommandLike | CommandGenerator, *command_likes: CommandLike) -> None:
357
+ def register(
358
+ self,
359
+ /,
360
+ command_or_generator: ActionLike | RewriteOrRule | RewriteOrRuleGenerator,
361
+ *command_likes: ActionLike | RewriteOrRule,
362
+ ) -> None:
328
363
  """
329
364
  Registers any number of rewrites or rules.
330
365
  """
331
366
  if isinstance(command_or_generator, FunctionType):
332
367
  assert not command_likes
333
- command_likes = tuple(_command_generator(command_or_generator))
368
+ current_frame = inspect.currentframe()
369
+ assert current_frame
370
+ original_frame = current_frame.f_back
371
+ assert original_frame
372
+ command_likes = tuple(_rewrite_or_rule_generator(command_or_generator, original_frame))
334
373
  else:
335
374
  command_likes = (cast(CommandLike, command_or_generator), *command_likes)
336
-
337
- self._register_commands(list(map(_command_like, command_likes)))
375
+ commands = [_command_like(c) for c in command_likes]
376
+ self._register_commands(commands)
338
377
 
339
378
  @abstractmethod
340
379
  def _register_commands(self, cmds: list[Command]) -> None:
@@ -417,136 +456,126 @@ class _ExprMetaclass(type):
417
456
  # If this is the Expr subclass, just return the class
418
457
  if not bases:
419
458
  return super().__new__(cls, name, bases, namespace)
459
+ # TODO: Raise error on subclassing or multiple inheritence
420
460
 
421
461
  frame = currentframe()
422
462
  assert frame
423
463
  prev_frame = frame.f_back
424
464
  assert prev_frame
425
- return _ClassDeclerationsConstructor(
426
- namespace=namespace,
427
- # Store frame so that we can get live access to updated locals/globals
428
- # Otherwise, f_locals returns a copy
429
- # https://peps.python.org/pep-0667/
430
- frame=prev_frame,
431
- builtin=builtin,
432
- egg_sort=egg_sort,
433
- cls_name=name,
434
- ).current_cls
465
+
466
+ # Store frame so that we can get live access to updated locals/globals
467
+ # Otherwise, f_locals returns a copy
468
+ # https://peps.python.org/pep-0667/
469
+ decls_thunk = Thunk.fn(
470
+ _generate_class_decls, namespace, prev_frame, builtin, egg_sort, name, fallback=Declarations
471
+ )
472
+ return RuntimeClass(decls_thunk, TypeRefWithVars(name))
435
473
 
436
474
  def __instancecheck__(cls, instance: object) -> bool:
437
475
  return isinstance(instance, RuntimeExpr)
438
476
 
439
477
 
440
- @dataclass
441
- class _ClassDeclerationsConstructor:
478
+ def _generate_class_decls( # noqa: C901
479
+ namespace: dict[str, Any], frame: FrameType, builtin: bool, egg_sort: str | None, cls_name: str
480
+ ) -> Declarations:
442
481
  """
443
482
  Lazy constructor for class declerations to support classes with methods whose types are not yet defined.
444
483
  """
484
+ parameters: list[TypeVar] = (
485
+ # Get the generic params from the orig bases generic class
486
+ namespace["__orig_bases__"][1].__parameters__ if "__orig_bases__" in namespace else []
487
+ )
488
+ type_vars = tuple(p.__name__ for p in parameters)
489
+ del parameters
490
+ cls_decl = ClassDecl(egg_sort, type_vars, builtin)
491
+ decls = Declarations(_classes={cls_name: cls_decl})
492
+
493
+ ##
494
+ # Register class variables
495
+ ##
496
+ # Create a dummy type to pass to get_type_hints to resolve the annotations we have
497
+ _Dummytype = type("_DummyType", (), {"__annotations__": namespace.get("__annotations__", {})})
498
+ for k, v in get_type_hints(_Dummytype, globalns=frame.f_globals, localns=frame.f_locals).items():
499
+ if getattr(v, "__origin__", None) == ClassVar:
500
+ (inner_tp,) = v.__args__
501
+ type_ref = resolve_type_annotation(decls, inner_tp).to_just()
502
+ cls_decl.class_variables[k] = ConstantDecl(type_ref)
445
503
 
446
- namespace: dict[str, Any]
447
- frame: FrameType
448
- builtin: bool
449
- egg_sort: str | None
450
- cls_name: str
451
- current_cls: RuntimeClass = field(init=False)
504
+ else:
505
+ msg = f"On class {cls_name}, for attribute '{k}', expected a ClassVar, but got {v}"
506
+ raise NotImplementedError(msg)
452
507
 
453
- def __post_init__(self) -> None:
454
- self.current_cls = RuntimeClass(self, self.cls_name)
455
-
456
- def __call__(self, decls: Declarations) -> None: # noqa: PLR0912
457
- # Get all the methods from the class
458
- cls_dict: dict[str, Any] = {
459
- k: v for k, v in self.namespace.items() if k not in IGNORED_ATTRIBUTES or isinstance(v, _WrappedMethod)
460
- }
461
- parameters: list[TypeVar] = (
462
- # Get the generic params from the orig bases generic class
463
- self.namespace["__orig_bases__"][1].__parameters__ if "__orig_bases__" in self.namespace else []
464
- )
465
- type_vars = tuple(p.__name__ for p in parameters)
466
- del parameters
467
-
468
- decls.register_class(self.cls_name, type_vars, self.builtin, self.egg_sort)
469
- # The type ref of self is paramterized by the type vars
470
- slf_type_ref = TypeRefWithVars(self.cls_name, tuple(map(ClassTypeVarRef, type_vars)))
471
-
472
- # Create a dummy type to pass to get_type_hints to resolve the annotations we have
473
- class _Dummytype:
474
- pass
475
-
476
- _Dummytype.__annotations__ = self.namespace.get("__annotations__", {})
477
- # Make lazy update to locals, so we keep a live handle on them after class creation
478
- locals = self.frame.f_locals.copy()
479
- locals[self.cls_name] = self.current_cls
480
- for k, v in get_type_hints(_Dummytype, globalns=self.frame.f_globals, localns=locals).items():
481
- if getattr(v, "__origin__", None) == ClassVar:
482
- (inner_tp,) = v.__args__
483
- _register_constant(decls, ClassVariableRef(self.cls_name, k), inner_tp, None)
484
- else:
485
- msg = f"On class {self.cls_name}, for attribute '{k}', expected a ClassVar, but got {v}"
486
- raise NotImplementedError(msg)
487
-
488
- # Then register each of its methods
489
- for method_name, method in cls_dict.items():
490
- is_init = method_name == "__init__"
491
- # Don't register the init methods for literals, since those don't use the type checking mechanisms
492
- if is_init and self.cls_name in LIT_CLASS_NAMES:
493
- continue
494
- if isinstance(method, _WrappedMethod):
495
- fn = method.fn
496
- egg_fn = method.egg_fn
497
- cost = method.cost
498
- default = method.default
499
- merge = method.merge
500
- on_merge = method.on_merge
501
- mutates_first_arg = method.mutates_self
502
- unextractable = method.unextractable
503
- if method.preserve:
504
- decls.register_preserved_method(self.cls_name, method_name, fn)
505
- continue
506
- else:
507
- fn = method
508
+ ##
509
+ # Register methods, classmethods, preserved methods, and properties
510
+ ##
511
+
512
+ # The type ref of self is paramterized by the type vars
513
+ slf_type_ref = TypeRefWithVars(cls_name, tuple(map(ClassTypeVarRef, type_vars)))
514
+
515
+ # Get all the methods from the class
516
+ filtered_namespace: list[tuple[str, Any]] = [
517
+ (k, v) for k, v in namespace.items() if k not in IGNORED_ATTRIBUTES or isinstance(v, _WrappedMethod)
518
+ ]
519
+
520
+ # Then register each of its methods
521
+ for method_name, method in filtered_namespace:
522
+ is_init = method_name == "__init__"
523
+ # Don't register the init methods for literals, since those don't use the type checking mechanisms
524
+ if is_init and cls_name in LIT_CLASS_NAMES:
525
+ continue
526
+ match method:
527
+ case _WrappedMethod(egg_fn, cost, default, merge, on_merge, fn, preserve, mutates, unextractable):
528
+ pass
529
+ case _:
508
530
  egg_fn, cost, default, merge, on_merge = None, None, None, None, None
509
- unextractable = False
510
- mutates_first_arg = False
511
- if isinstance(fn, classmethod):
512
- fn = fn.__func__
513
- is_classmethod = True
514
- else:
515
- # We count __init__ as a classmethod since it is called on the class
516
- is_classmethod = is_init
517
-
518
- if isinstance(fn, property):
519
- fn = fn.fget
520
- is_property = True
521
- if is_classmethod:
522
- msg = "Can't have a classmethod property"
523
- raise NotImplementedError(msg)
524
- else:
525
- is_property = False
526
- ref: FunctionCallableRef = (
527
- ClassMethodRef(self.cls_name, method_name)
528
- if is_classmethod
529
- else PropertyRef(self.cls_name, method_name)
530
- if is_property
531
- else MethodRef(self.cls_name, method_name)
531
+ fn = method
532
+ unextractable, preserve = False, False
533
+ mutates = method_name in ALWAYS_MUTATES_SELF
534
+ if preserve:
535
+ cls_decl.preserved_methods[method_name] = fn
536
+ continue
537
+ locals = frame.f_locals
538
+
539
+ def create_decl(fn: object, first: Literal["cls"] | TypeRefWithVars) -> FunctionDecl:
540
+ special_function_name: SpecialFunctions | None = (
541
+ "fn-partial" if egg_fn == "unstable-fn" else "fn-app" if egg_fn == "unstable-app" else None # noqa: B023
532
542
  )
533
- _register_function(
543
+ if special_function_name:
544
+ return FunctionDecl(
545
+ special_function_name,
546
+ builtin=True,
547
+ egg_name=egg_fn, # noqa: B023
548
+ )
549
+
550
+ return _fn_decl(
534
551
  decls,
535
- ref,
536
- egg_fn,
552
+ egg_fn, # noqa: B023
537
553
  fn,
538
- locals,
539
- default,
540
- cost,
541
- merge,
542
- on_merge,
543
- mutates_first_arg or method_name in ALWAYS_MUTATES_SELF,
544
- self.builtin,
545
- "cls" if is_classmethod and not is_init else slf_type_ref,
546
- is_init,
547
- unextractable=unextractable,
554
+ locals, # noqa: B023
555
+ default, # noqa: B023
556
+ cost, # noqa: B023
557
+ merge, # noqa: B023
558
+ on_merge, # noqa: B023
559
+ mutates, # noqa: B023
560
+ builtin,
561
+ first,
562
+ is_init, # noqa: B023
563
+ unextractable, # noqa: B023
548
564
  )
549
565
 
566
+ match fn:
567
+ case classmethod():
568
+ cls_decl.class_methods[method_name] = create_decl(fn.__func__, "cls")
569
+ case property():
570
+ cls_decl.properties[method_name] = create_decl(fn.fget, slf_type_ref)
571
+ case _:
572
+ if is_init:
573
+ cls_decl.class_methods[method_name] = create_decl(fn, slf_type_ref)
574
+ else:
575
+ cls_decl.methods[method_name] = create_decl(fn, slf_type_ref)
576
+
577
+ return decls
578
+
550
579
 
551
580
  @overload
552
581
  def function(fn: CALLABLE, /) -> CALLABLE: ...
@@ -589,48 +618,46 @@ def function(*args, **kwargs) -> Any:
589
618
  # If we have any positional args, then we are calling it directly on a function
590
619
  if args:
591
620
  assert len(args) == 1
592
- return _function(args[0], fn_locals, False)
621
+ return _FunctionConstructor(fn_locals)(args[0])
593
622
  # otherwise, we are passing some keyword args, so save those, and then return a partial
594
- return lambda fn: _function(fn, fn_locals, **kwargs)
623
+ return _FunctionConstructor(fn_locals, **kwargs)
595
624
 
596
625
 
597
- def _function(
598
- fn: Callable[..., RuntimeExpr],
599
- hint_locals: dict[str, Any],
600
- builtin: bool = False,
601
- mutates_first_arg: bool = False,
602
- egg_fn: str | None = None,
603
- cost: int | None = None,
604
- default: RuntimeExpr | None = None,
605
- merge: Callable[[RuntimeExpr, RuntimeExpr], RuntimeExpr] | None = None,
606
- on_merge: Callable[[RuntimeExpr, RuntimeExpr], Iterable[ActionLike]] | None = None,
607
- unextractable: bool = False,
608
- ) -> RuntimeFunction:
609
- """
610
- Uncurried version of function decorator
611
- """
612
- name = fn.__name__
613
- decls = Declarations()
614
- _register_function(
615
- decls,
616
- FunctionRef(name),
617
- egg_fn,
618
- fn,
619
- hint_locals,
620
- default,
621
- cost,
622
- merge,
623
- on_merge,
624
- mutates_first_arg,
625
- builtin,
626
- unextractable=unextractable,
627
- )
628
- return RuntimeFunction(decls, name)
626
+ @dataclass
627
+ class _FunctionConstructor:
628
+ hint_locals: dict[str, Any]
629
+ builtin: bool = False
630
+ mutates_first_arg: bool = False
631
+ egg_fn: str | None = None
632
+ cost: int | None = None
633
+ default: RuntimeExpr | None = None
634
+ merge: Callable[[RuntimeExpr, RuntimeExpr], RuntimeExpr] | None = None
635
+ on_merge: Callable[[RuntimeExpr, RuntimeExpr], Iterable[ActionLike]] | None = None
636
+ unextractable: bool = False
637
+
638
+ def __call__(self, fn: Callable[..., RuntimeExpr]) -> RuntimeFunction:
639
+ return RuntimeFunction(Thunk.fn(self.create_decls, fn), FunctionRef(fn.__name__))
640
+
641
+ def create_decls(self, fn: Callable[..., RuntimeExpr]) -> Declarations:
642
+ decls = Declarations()
643
+ decls._functions[fn.__name__] = _fn_decl(
644
+ decls,
645
+ self.egg_fn,
646
+ fn,
647
+ self.hint_locals,
648
+ self.default,
649
+ self.cost,
650
+ self.merge,
651
+ self.on_merge,
652
+ self.mutates_first_arg,
653
+ self.builtin,
654
+ unextractable=self.unextractable,
655
+ )
656
+ return decls
629
657
 
630
658
 
631
- def _register_function(
659
+ def _fn_decl(
632
660
  decls: Declarations,
633
- ref: FunctionCallableRef,
634
661
  egg_name: str | None,
635
662
  fn: object,
636
663
  # Pass in the locals, retrieved from the frame when wrapping,
@@ -646,11 +673,15 @@ def _register_function(
646
673
  first_arg: Literal["cls"] | TypeOrVarRef | None = None,
647
674
  is_init: bool = False,
648
675
  unextractable: bool = False,
649
- ) -> None:
676
+ ) -> FunctionDecl:
650
677
  if not isinstance(fn, FunctionType):
651
678
  raise NotImplementedError(f"Can only generate function decls for functions not {fn} {type(fn)}")
652
679
 
653
680
  hint_globals = fn.__globals__.copy()
681
+ # Copy Callable into global if not present bc sometimes it gets automatically removed by ruff to type only block
682
+ # https://docs.astral.sh/ruff/rules/typing-only-standard-library-import/
683
+ if "Callable" not in hint_globals:
684
+ hint_globals["Callable"] = Callable
654
685
 
655
686
  hints = get_type_hints(fn, hint_globals, hint_locals)
656
687
 
@@ -699,8 +730,8 @@ def _register_function(
699
730
  None
700
731
  if merge is None
701
732
  else merge(
702
- RuntimeExpr(decls, TypedExprDecl(return_type.to_just(), VarDecl("old"))),
703
- RuntimeExpr(decls, TypedExprDecl(return_type.to_just(), VarDecl("new"))),
733
+ RuntimeExpr.__from_value__(decls, TypedExprDecl(return_type.to_just(), VarDecl("old"))),
734
+ RuntimeExpr.__from_value__(decls, TypedExprDecl(return_type.to_just(), VarDecl("new"))),
704
735
  )
705
736
  )
706
737
  decls |= merged
@@ -710,30 +741,27 @@ def _register_function(
710
741
  if on_merge is None
711
742
  else _action_likes(
712
743
  on_merge(
713
- RuntimeExpr(decls, TypedExprDecl(return_type.to_just(), VarDecl("old"))),
714
- RuntimeExpr(decls, TypedExprDecl(return_type.to_just(), VarDecl("new"))),
744
+ RuntimeExpr.__from_value__(decls, TypedExprDecl(return_type.to_just(), VarDecl("old"))),
745
+ RuntimeExpr.__from_value__(decls, TypedExprDecl(return_type.to_just(), VarDecl("new"))),
715
746
  )
716
747
  )
717
748
  )
718
749
  decls.update(*merge_action)
719
- fn_decl = FunctionDecl(
720
- return_type=return_type,
721
- var_arg_type=var_arg_type,
722
- arg_types=arg_types,
723
- arg_names=tuple(t.name for t in params),
724
- arg_defaults=tuple(a.__egg_typed_expr__.expr if a is not None else None for a in arg_defaults),
725
- mutates_first_arg=mutates_first_arg,
726
- )
727
- decls.register_function_callable(
728
- ref,
729
- fn_decl,
730
- egg_name,
731
- cost,
732
- None if default is None else default.__egg_typed_expr__.expr,
733
- merged.__egg_typed_expr__.expr if merged is not None else None,
734
- [a._to_egg_action() for a in merge_action],
735
- unextractable,
736
- is_builtin,
750
+ return FunctionDecl(
751
+ FunctionSignature(
752
+ return_type=None if mutates_first_arg else return_type,
753
+ var_arg_type=var_arg_type,
754
+ arg_types=arg_types,
755
+ arg_names=tuple(t.name for t in params),
756
+ arg_defaults=tuple(a.__egg_typed_expr__.expr if a is not None else None for a in arg_defaults),
757
+ ),
758
+ cost=cost,
759
+ egg_name=egg_name,
760
+ merge=merged.__egg_typed_expr__.expr if merged is not None else None,
761
+ unextractable=unextractable,
762
+ builtin=is_builtin,
763
+ default=None if default is None else default.__egg_typed_expr__.expr,
764
+ on_merge=tuple(a.action for a in merge_action),
737
765
  )
738
766
 
739
767
 
@@ -764,49 +792,31 @@ def relation(name: str, /, *tps: type, egg_fn: str | None = None) -> Callable[..
764
792
  """
765
793
  Creates a function whose return type is `Unit` and has a default value.
766
794
  """
795
+ decls_thunk = Thunk.fn(_relation_decls, name, tps, egg_fn)
796
+ return cast(Callable[..., Unit], RuntimeFunction(decls_thunk, FunctionRef(name)))
797
+
798
+
799
+ def _relation_decls(name: str, tps: tuple[type, ...], egg_fn: str | None) -> Declarations:
767
800
  decls = Declarations()
768
801
  decls |= cast(RuntimeClass, Unit)
769
- arg_types = tuple(resolve_type_annotation(decls, tp) for tp in tps)
770
- fn_decl = FunctionDecl(arg_types, None, tuple(None for _ in tps), TypeRefWithVars("Unit"), mutates_first_arg=False)
771
- decls.register_function_callable(
772
- FunctionRef(name),
773
- fn_decl,
774
- egg_fn,
775
- cost=None,
776
- default=None,
777
- merge=None,
778
- merge_action=[],
779
- unextractable=False,
780
- builtin=False,
781
- is_relation=True,
782
- )
783
- return cast(Callable[..., Unit], RuntimeFunction(decls, name))
802
+ arg_types = tuple(resolve_type_annotation(decls, tp).to_just() for tp in tps)
803
+ decls._functions[name] = RelationDecl(arg_types, tuple(None for _ in tps), egg_fn)
804
+ return decls
784
805
 
785
806
 
786
807
  def constant(name: str, tp: type[EXPR], egg_name: str | None = None) -> EXPR:
787
808
  """
788
-
789
809
  A "constant" is implemented as the instantiation of a value that takes no args.
790
810
  This creates a function with `name` and return type `tp` and returns a value of it being called.
791
811
  """
792
- ref = ConstantRef(name)
793
- decls = Declarations()
794
- type_ref = _register_constant(decls, ref, tp, egg_name)
795
- return cast(EXPR, RuntimeExpr(decls, TypedExprDecl(type_ref, CallDecl(ref))))
812
+ return cast(EXPR, RuntimeExpr(Thunk.fn(_constant_thunk, name, tp, egg_name)))
796
813
 
797
814
 
798
- def _register_constant(
799
- decls: Declarations,
800
- ref: ConstantRef | ClassVariableRef,
801
- tp: object,
802
- egg_name: str | None,
803
- ) -> JustTypeRef:
804
- """
805
- Register a constant, returning its typeref().
806
- """
815
+ def _constant_thunk(name: str, tp: type, egg_name: str | None) -> tuple[Declarations, TypedExprDecl]:
816
+ decls = Declarations()
807
817
  type_ref = resolve_type_annotation(decls, tp).to_just()
808
- decls.register_constant_callable(ref, type_ref, egg_name)
809
- return type_ref
818
+ decls._constants[name] = ConstantDecl(type_ref, egg_name)
819
+ return decls, TypedExprDecl(type_ref, CallDecl(ConstantRef(name)))
810
820
 
811
821
 
812
822
  def _last_param_variable(params: list[Parameter]) -> bool:
@@ -858,29 +868,6 @@ class GraphvizKwargs(TypedDict, total=False):
858
868
  split_primitive_outputs: bool
859
869
 
860
870
 
861
- @dataclass
862
- class _EGraphState:
863
- """
864
- State of the EGraph declerations and rulesets, so when we pop/push the stack we know whats defined.
865
- """
866
-
867
- # The decleratons we have added. The _cmds represent all the symbols we have added
868
- decls: Declarations = field(default_factory=Declarations)
869
- # List of rulesets already added, so we don't re-add them if they are passed again
870
- added_rulesets: set[str] = field(default_factory=set)
871
-
872
- def add_decls(self, new_decls: Declarations) -> Iterable[bindings._Command]:
873
- new_cmds = [v for k, v in new_decls._cmds.items() if k not in self.decls._cmds]
874
- self.decls |= new_decls
875
- return new_cmds
876
-
877
- def add_rulesets(self, rulesets: Iterable[Ruleset]) -> Iterable[bindings._Command]:
878
- for ruleset in rulesets:
879
- if ruleset.egg_name not in self.added_rulesets:
880
- self.added_rulesets.add(ruleset.egg_name)
881
- yield from ruleset._cmds
882
-
883
-
884
871
  @dataclass
885
872
  class EGraph(_BaseModule):
886
873
  """
@@ -892,56 +879,34 @@ class EGraph(_BaseModule):
892
879
  seminaive: InitVar[bool] = True
893
880
  save_egglog_string: InitVar[bool] = False
894
881
 
895
- default_ruleset: Ruleset | None = None
896
- _egraph: bindings.EGraph = field(repr=False, init=False)
897
- _egglog_string: str | None = field(default=None, repr=False, init=False)
898
- _state: _EGraphState = field(default_factory=_EGraphState, repr=False)
882
+ _state: EGraphState = field(init=False)
899
883
  # For pushing/popping with egglog
900
- _state_stack: list[_EGraphState] = field(default_factory=list, repr=False)
884
+ _state_stack: list[EGraphState] = field(default_factory=list, repr=False)
901
885
  # For storing the global "current" egraph
902
886
  _token_stack: list[Token[EGraph]] = field(default_factory=list, repr=False)
903
887
 
904
888
  def __post_init__(self, modules: list[Module], seminaive: bool, save_egglog_string: bool) -> None:
905
- self._egraph = bindings.EGraph(GLOBAL_PY_OBJECT_SORT, seminaive=seminaive)
889
+ egraph = bindings.EGraph(GLOBAL_PY_OBJECT_SORT, seminaive=seminaive, record=save_egglog_string)
890
+ self._state = EGraphState(egraph)
906
891
  super().__post_init__(modules)
907
892
 
908
893
  for m in self._flatted_deps:
909
- self._add_decls(*m.cmds)
910
894
  self._register_commands(m.cmds)
911
- if save_egglog_string:
912
- self._egglog_string = ""
913
-
914
- def _register_commands(self, commands: list[Command]) -> None:
915
- for c in commands:
916
- if c.ruleset:
917
- self._add_schedule(c.ruleset)
918
-
919
- self._add_decls(*commands)
920
- self._process_commands(command._to_egg_command(self._default_ruleset_name) for command in commands)
921
-
922
- def _process_commands(self, commands: Iterable[bindings._Command]) -> None:
923
- commands = list(commands)
924
- self._egraph.run_program(*commands)
925
- if isinstance(self._egglog_string, str):
926
- self._egglog_string += "\n".join(str(c) for c in commands) + "\n"
927
895
 
928
896
  def _add_decls(self, *decls: DeclerationsLike) -> None:
929
- for d in upcast_decleratioons(decls):
930
- self._process_commands(self._state.add_decls(d))
931
-
932
- def _add_schedule(self, schedule: Schedule) -> None:
933
- self._add_decls(schedule)
934
- self._process_commands(self._state.add_rulesets(schedule._rulesets()))
897
+ for d in decls:
898
+ self._state.__egg_decls__ |= d
935
899
 
936
900
  @property
937
901
  def as_egglog_string(self) -> str:
938
902
  """
939
903
  Returns the egglog string for this module.
940
904
  """
941
- if self._egglog_string is None:
905
+ cmds = self._egraph.commands()
906
+ if cmds is None:
942
907
  msg = "Can't get egglog string unless EGraph created with save_egglog_string=True"
943
908
  raise ValueError(msg)
944
- return self._egglog_string
909
+ return cmds
945
910
 
946
911
  def _repr_mimebundle_(self, *args, **kwargs):
947
912
  """
@@ -954,7 +919,7 @@ class EGraph(_BaseModule):
954
919
  kwargs.setdefault("split_primitive_outputs", True)
955
920
  n_inline = kwargs.pop("n_inline_leaves", 0)
956
921
  serialized = self._egraph.serialize([], **kwargs) # type: ignore[misc]
957
- serialized.map_ops(self._state.decls.op_mapping())
922
+ serialized.map_ops(self._state.op_mapping())
958
923
  for _ in range(n_inline):
959
924
  serialized.inline_leaves()
960
925
  original = serialized.to_dot()
@@ -1003,30 +968,35 @@ class EGraph(_BaseModule):
1003
968
  """
1004
969
  Displays the e-graph in the notebook.
1005
970
  """
1006
- graphviz = self.graphviz(**kwargs)
1007
971
  if IN_IPYTHON:
1008
972
  from IPython.display import SVG, display
1009
973
 
1010
974
  display(SVG(self.graphviz_svg(**kwargs)))
1011
975
  else:
1012
- graphviz.render(view=True, format="svg", quiet=True)
976
+ self.graphviz(**kwargs).render(view=True, format="svg", quiet=True)
1013
977
 
1014
978
  def input(self, fn: Callable[..., String], path: str) -> None:
1015
979
  """
1016
980
  Loads a CSV file and sets it as *input, output of the function.
1017
981
  """
1018
982
  ref, decls = resolve_callable(fn)
1019
- fn_name = decls.get_egg_fn(ref)
1020
- self._process_commands(decls.list_cmds())
1021
- self._process_commands([bindings.Input(fn_name, path)])
983
+ self._add_decls(decls)
984
+ fn_name = self._state.callable_ref_to_egg(ref)
985
+ self._egraph.run_program(bindings.Input(fn_name, path))
1022
986
 
1023
987
  def let(self, name: str, expr: EXPR) -> EXPR:
1024
988
  """
1025
989
  Define a new expression in the egraph and return a reference to it.
1026
990
  """
1027
- self._register_commands([let(name, expr)])
1028
- expr = to_runtime_expr(expr)
1029
- return cast(EXPR, RuntimeExpr(expr.__egg_decls__, TypedExprDecl(expr.__egg_typed_expr__.tp, VarDecl(name))))
991
+ action = let(name, expr)
992
+ self.register(action)
993
+ runtime_expr = to_runtime_expr(expr)
994
+ return cast(
995
+ EXPR,
996
+ RuntimeExpr.__from_value__(
997
+ self.__egg_decls__, TypedExprDecl(runtime_expr.__egg_typed_expr__.tp, VarDecl(name))
998
+ ),
999
+ )
1030
1000
 
1031
1001
  @overload
1032
1002
  def simplify(self, expr: EXPR, limit: int, /, *until: Fact, ruleset: Ruleset | None = None) -> EXPR: ...
@@ -1041,28 +1011,19 @@ class EGraph(_BaseModule):
1041
1011
  Simplifies the given expression.
1042
1012
  """
1043
1013
  schedule = run(ruleset, *until) * limit_or_schedule if isinstance(limit_or_schedule, int) else limit_or_schedule
1044
- del limit_or_schedule
1045
- expr = to_runtime_expr(expr)
1046
- self._add_decls(expr)
1047
- self._add_schedule(schedule)
1048
-
1049
- # decls = Declarations.create(expr, schedule)
1050
- self._process_commands([bindings.Simplify(expr.__egg__, schedule._to_egg_schedule(self._default_ruleset_name))])
1014
+ del limit_or_schedule, until, ruleset
1015
+ runtime_expr = to_runtime_expr(expr)
1016
+ self._add_decls(runtime_expr, schedule)
1017
+ egg_schedule = self._state.schedule_to_egg(schedule.schedule)
1018
+ typed_expr = runtime_expr.__egg_typed_expr__
1019
+ egg_expr = self._state.expr_to_egg(typed_expr.expr)
1020
+ self._egraph.run_program(bindings.Simplify(egg_expr, egg_schedule))
1051
1021
  extract_report = self._egraph.extract_report()
1052
1022
  if not isinstance(extract_report, bindings.Best):
1053
1023
  msg = "No extract report saved"
1054
1024
  raise ValueError(msg) # noqa: TRY004
1055
- new_typed_expr = TypedExprDecl.from_egg(
1056
- self._egraph, self._state.decls, expr.__egg_typed_expr__.tp, extract_report.termdag, extract_report.term, {}
1057
- )
1058
- return cast(EXPR, RuntimeExpr(self._state.decls, new_typed_expr))
1059
-
1060
- @property
1061
- def _default_ruleset_name(self) -> str:
1062
- if self.default_ruleset:
1063
- self._add_schedule(self.default_ruleset)
1064
- return self.default_ruleset.egg_name
1065
- return ""
1025
+ (new_typed_expr,) = self._state.exprs_from_egg(extract_report.termdag, [extract_report.term], typed_expr.tp)
1026
+ return cast(EXPR, RuntimeExpr.__from_value__(self.__egg_decls__, new_typed_expr))
1066
1027
 
1067
1028
  def include(self, path: str) -> None:
1068
1029
  """
@@ -1092,8 +1053,9 @@ class EGraph(_BaseModule):
1092
1053
  return self._run_schedule(limit_or_schedule)
1093
1054
 
1094
1055
  def _run_schedule(self, schedule: Schedule) -> bindings.RunReport:
1095
- self._add_schedule(schedule)
1096
- self._process_commands([bindings.RunSchedule(schedule._to_egg_schedule(self._default_ruleset_name))])
1056
+ self._add_decls(schedule)
1057
+ egg_schedule = self._state.schedule_to_egg(schedule.schedule)
1058
+ self._egraph.run_program(bindings.RunSchedule(egg_schedule))
1097
1059
  run_report = self._egraph.run_report()
1098
1060
  if not run_report:
1099
1061
  msg = "No run report saved"
@@ -1104,18 +1066,18 @@ class EGraph(_BaseModule):
1104
1066
  """
1105
1067
  Check if a fact is true in the egraph.
1106
1068
  """
1107
- self._process_commands([self._facts_to_check(facts)])
1069
+ self._egraph.run_program(self._facts_to_check(facts))
1108
1070
 
1109
1071
  def check_fail(self, *facts: FactLike) -> None:
1110
1072
  """
1111
1073
  Checks that one of the facts is not true
1112
1074
  """
1113
- self._process_commands([bindings.Fail(self._facts_to_check(facts))])
1075
+ self._egraph.run_program(bindings.Fail(self._facts_to_check(facts)))
1114
1076
 
1115
- def _facts_to_check(self, facts: Iterable[FactLike]) -> bindings.Check:
1116
- facts = _fact_likes(facts)
1077
+ def _facts_to_check(self, fact_likes: Iterable[FactLike]) -> bindings.Check:
1078
+ facts = _fact_likes(fact_likes)
1117
1079
  self._add_decls(*facts)
1118
- egg_facts = [f._to_egg_fact() for f in _fact_likes(facts)]
1080
+ egg_facts = [self._state.fact_to_egg(f.fact) for f in _fact_likes(facts)]
1119
1081
  return bindings.Check(egg_facts)
1120
1082
 
1121
1083
  @overload
@@ -1128,16 +1090,17 @@ class EGraph(_BaseModule):
1128
1090
  """
1129
1091
  Extract the lowest cost expression from the egraph.
1130
1092
  """
1131
- assert isinstance(expr, RuntimeExpr)
1132
- self._add_decls(expr)
1133
- extract_report = self._run_extract(expr.__egg__, 0)
1093
+ runtime_expr = to_runtime_expr(expr)
1094
+ self._add_decls(runtime_expr)
1095
+ typed_expr = runtime_expr.__egg_typed_expr__
1096
+ extract_report = self._run_extract(typed_expr, 0)
1097
+
1134
1098
  if not isinstance(extract_report, bindings.Best):
1135
1099
  msg = "No extract report saved"
1136
1100
  raise ValueError(msg) # noqa: TRY004
1137
- new_typed_expr = TypedExprDecl.from_egg(
1138
- self._egraph, self._state.decls, expr.__egg_typed_expr__.tp, extract_report.termdag, extract_report.term, {}
1139
- )
1140
- res = cast(EXPR, RuntimeExpr(self._state.decls, new_typed_expr))
1101
+ (new_typed_expr,) = self._state.exprs_from_egg(extract_report.termdag, [extract_report.term], typed_expr.tp)
1102
+
1103
+ res = cast(EXPR, RuntimeExpr.__from_value__(self.__egg_decls__, new_typed_expr))
1141
1104
  if include_cost:
1142
1105
  return res, extract_report.cost
1143
1106
  return res
@@ -1146,23 +1109,21 @@ class EGraph(_BaseModule):
1146
1109
  """
1147
1110
  Extract multiple expressions from the egraph.
1148
1111
  """
1149
- assert isinstance(expr, RuntimeExpr)
1150
- self._add_decls(expr)
1112
+ runtime_expr = to_runtime_expr(expr)
1113
+ self._add_decls(runtime_expr)
1114
+ typed_expr = runtime_expr.__egg_typed_expr__
1151
1115
 
1152
- extract_report = self._run_extract(expr.__egg__, n)
1116
+ extract_report = self._run_extract(typed_expr, n)
1153
1117
  if not isinstance(extract_report, bindings.Variants):
1154
1118
  msg = "Wrong extract report type"
1155
1119
  raise ValueError(msg) # noqa: TRY004
1156
- new_exprs = [
1157
- TypedExprDecl.from_egg(
1158
- self._egraph, self._state.decls, expr.__egg_typed_expr__.tp, extract_report.termdag, term, {}
1159
- )
1160
- for term in extract_report.terms
1161
- ]
1162
- return [cast(EXPR, RuntimeExpr(self._state.decls, expr)) for expr in new_exprs]
1120
+ new_exprs = self._state.exprs_from_egg(extract_report.termdag, extract_report.terms, typed_expr.tp)
1121
+ return [cast(EXPR, RuntimeExpr.__from_value__(self.__egg_decls__, expr)) for expr in new_exprs]
1163
1122
 
1164
- def _run_extract(self, expr: bindings._Expr, n: int) -> bindings._ExtractReport:
1165
- self._process_commands([bindings.ActionCommand(bindings.Extract(expr, bindings.Lit(bindings.Int(n))))])
1123
+ def _run_extract(self, typed_expr: TypedExprDecl, n: int) -> bindings._ExtractReport:
1124
+ self._state.type_ref_to_egg(typed_expr.tp)
1125
+ expr = self._state.expr_to_egg(typed_expr.expr)
1126
+ self._egraph.run_program(bindings.ActionCommand(bindings.Extract(expr, bindings.Lit(bindings.Int(n)))))
1166
1127
  extract_report = self._egraph.extract_report()
1167
1128
  if not extract_report:
1168
1129
  msg = "No extract report saved"
@@ -1173,15 +1134,15 @@ class EGraph(_BaseModule):
1173
1134
  """
1174
1135
  Push the current state of the egraph, so that it can be popped later and reverted back.
1175
1136
  """
1176
- self._process_commands([bindings.Push(1)])
1137
+ self._egraph.run_program(bindings.Push(1))
1177
1138
  self._state_stack.append(self._state)
1178
- self._state = deepcopy(self._state)
1139
+ self._state = self._state.copy()
1179
1140
 
1180
1141
  def pop(self) -> None:
1181
1142
  """
1182
1143
  Pop the current state of the egraph, reverting back to the previous state.
1183
1144
  """
1184
- self._process_commands([bindings.Pop(1)])
1145
+ self._egraph.run_program(bindings.Pop(1))
1185
1146
  self._state = self._state_stack.pop()
1186
1147
 
1187
1148
  def __enter__(self) -> Self:
@@ -1217,9 +1178,10 @@ class EGraph(_BaseModule):
1217
1178
  """
1218
1179
  Evaluates the given expression (which must be a primitive type), returning the result.
1219
1180
  """
1220
- assert isinstance(expr, RuntimeExpr)
1221
- typed_expr = expr.__egg_typed_expr__
1222
- egg_expr = expr.__egg__
1181
+ runtime_expr = to_runtime_expr(expr)
1182
+ self._add_decls(runtime_expr)
1183
+ typed_expr = runtime_expr.__egg_typed_expr__
1184
+ egg_expr = self._state.expr_to_egg(typed_expr.expr)
1223
1185
  match typed_expr.tp:
1224
1186
  case JustTypeRef("i64"):
1225
1187
  return self._egraph.eval_i64(egg_expr)
@@ -1231,7 +1193,7 @@ class EGraph(_BaseModule):
1231
1193
  return self._egraph.eval_string(egg_expr)
1232
1194
  case JustTypeRef("PyObject"):
1233
1195
  return self._egraph.eval_py_object(egg_expr)
1234
- raise NotImplementedError(f"Eval not implemented for {typed_expr.tp.name}")
1196
+ raise TypeError(f"Eval not implemented for {typed_expr.tp}")
1235
1197
 
1236
1198
  def saturate(
1237
1199
  self, *, max: int = 1000, performance: bool = False, **kwargs: Unpack[GraphvizKwargs]
@@ -1270,6 +1232,32 @@ class EGraph(_BaseModule):
1270
1232
  """
1271
1233
  return CURRENT_EGRAPH.get()
1272
1234
 
1235
+ @property
1236
+ def _egraph(self) -> bindings.EGraph:
1237
+ return self._state.egraph
1238
+
1239
+ @property
1240
+ def __egg_decls__(self) -> Declarations:
1241
+ return self._state.__egg_decls__
1242
+
1243
+ def _register_commands(self, cmds: list[Command]) -> None:
1244
+ self._add_decls(*cmds)
1245
+ egg_cmds = list(map(self._command_to_egg, cmds))
1246
+ self._egraph.run_program(*egg_cmds)
1247
+
1248
+ def _command_to_egg(self, cmd: Command) -> bindings._Command:
1249
+ ruleset_name = ""
1250
+ cmd_decl: CommandDecl
1251
+ match cmd:
1252
+ case RewriteOrRule(_, cmd_decl, ruleset):
1253
+ if ruleset:
1254
+ ruleset_name = ruleset.__egg_name__
1255
+ case Action(_, action):
1256
+ cmd_decl = ActionCommandDecl(action)
1257
+ case _:
1258
+ assert_never(cmd)
1259
+ return self._state.command_to_egg(cmd_decl, ruleset_name)
1260
+
1273
1261
 
1274
1262
  CURRENT_EGRAPH = ContextVar[EGraph]("CURRENT_EGRAPH")
1275
1263
 
@@ -1316,61 +1304,55 @@ class Unit(Expr, egg_sort="Unit", builtin=True):
1316
1304
 
1317
1305
 
1318
1306
  def ruleset(
1319
- rule_or_generator: CommandLike | CommandGenerator | None = None, *rules: Rule | Rewrite, name: None | str = None
1307
+ rule_or_generator: RewriteOrRule | RewriteOrRuleGenerator | None = None,
1308
+ *rules: RewriteOrRule,
1309
+ name: None | str = None,
1320
1310
  ) -> Ruleset:
1321
1311
  """
1322
1312
  Creates a ruleset with the following rules.
1323
1313
 
1324
- If no name is provided, one is generated based on the current module
1314
+ If no name is provided, try using the name of the funciton.
1325
1315
  """
1326
- r = Ruleset(name=name)
1316
+ if isinstance(rule_or_generator, FunctionType):
1317
+ name = name or rule_or_generator.__name__
1318
+ r = Ruleset(name)
1327
1319
  if rule_or_generator is not None:
1328
- r.register(rule_or_generator, *rules)
1320
+ r.register(rule_or_generator, *rules, _increase_frame=True)
1329
1321
  return r
1330
1322
 
1331
1323
 
1332
- class Schedule(ABC):
1324
+ @dataclass
1325
+ class Schedule(DelayedDeclerations):
1333
1326
  """
1334
1327
  A composition of some rulesets, either composing them sequentially, running them repeatedly, running them till saturation, or running until some facts are met
1335
1328
  """
1336
1329
 
1330
+ # Defer declerations so that we can have rule generators that used not yet defined yet
1331
+ schedule: ScheduleDecl
1332
+
1333
+ def __str__(self) -> str:
1334
+ return pretty_decl(self.__egg_decls__, self.schedule)
1335
+
1336
+ def __repr__(self) -> str:
1337
+ return str(self)
1338
+
1337
1339
  def __mul__(self, length: int) -> Schedule:
1338
1340
  """
1339
1341
  Repeat the schedule a number of times.
1340
1342
  """
1341
- return Repeat(length, self)
1343
+ return Schedule(self.__egg_decls_thunk__, RepeatDecl(self.schedule, length))
1342
1344
 
1343
1345
  def saturate(self) -> Schedule:
1344
1346
  """
1345
1347
  Run the schedule until the e-graph is saturated.
1346
1348
  """
1347
- return Saturate(self)
1349
+ return Schedule(self.__egg_decls_thunk__, SaturateDecl(self.schedule))
1348
1350
 
1349
1351
  def __add__(self, other: Schedule) -> Schedule:
1350
1352
  """
1351
1353
  Run two schedules in sequence.
1352
1354
  """
1353
- return Sequence((self, other))
1354
-
1355
- @abstractmethod
1356
- def __str__(self) -> str:
1357
- raise NotImplementedError
1358
-
1359
- @abstractmethod
1360
- def _to_egg_schedule(self, default_ruleset_name: str) -> bindings._Schedule:
1361
- raise NotImplementedError
1362
-
1363
- @abstractmethod
1364
- def _rulesets(self) -> Iterable[Ruleset]:
1365
- """
1366
- Mapping of all the rulesets used to commands.
1367
- """
1368
- raise NotImplementedError
1369
-
1370
- @property
1371
- @abstractmethod
1372
- def __egg_decls__(self) -> Declarations:
1373
- raise NotImplementedError
1355
+ return Schedule(Thunk.fn(Declarations.create, self, other), SequenceDecl((self.schedule, other.schedule)))
1374
1356
 
1375
1357
 
1376
1358
  @dataclass
@@ -1379,422 +1361,155 @@ class Ruleset(Schedule):
1379
1361
  A collection of rules, which can be run as a schedule.
1380
1362
  """
1381
1363
 
1364
+ __egg_decls_thunk__: Callable[[], Declarations] = field(init=False)
1365
+ schedule: RunDecl = field(init=False)
1382
1366
  name: str | None
1383
- rules: list[Rule | Rewrite] = field(default_factory=list)
1384
1367
 
1385
- def append(self, rule: Rule | Rewrite) -> None:
1368
+ # Current declerations we have accumulated
1369
+ _current_egg_decls: Declarations = field(default_factory=Declarations)
1370
+ # Current rulesets we have accumulated
1371
+ __egg_ruleset__: RulesetDecl = field(init=False)
1372
+ # Rule generator functions that have been deferred, to allow for late type binding
1373
+ deferred_rule_gens: list[Callable[[], Iterable[RewriteOrRule]]] = field(default_factory=list)
1374
+
1375
+ def __post_init__(self) -> None:
1376
+ self.schedule = RunDecl(self.__egg_name__, ())
1377
+ self.__egg_ruleset__ = self._current_egg_decls._rulesets[self.__egg_name__] = RulesetDecl([])
1378
+ self.__egg_decls_thunk__ = self._update_egg_decls
1379
+
1380
+ def _update_egg_decls(self) -> Declarations:
1381
+ """
1382
+ To return the egg decls, we go through our deferred rules and add any we haven't yet
1383
+ """
1384
+ while self.deferred_rule_gens:
1385
+ rules = self.deferred_rule_gens.pop()()
1386
+ self._current_egg_decls.update(*rules)
1387
+ self.__egg_ruleset__.rules.extend(r.decl for r in rules)
1388
+ return self._current_egg_decls
1389
+
1390
+ def append(self, rule: RewriteOrRule) -> None:
1386
1391
  """
1387
1392
  Register a rule with the ruleset.
1388
1393
  """
1389
- self.rules.append(rule)
1394
+ self._current_egg_decls |= rule
1395
+ self.__egg_ruleset__.rules.append(rule.decl)
1390
1396
 
1391
- def register(self, /, rule_or_generator: CommandLike | CommandGenerator, *rules: Rule | Rewrite) -> None:
1397
+ def register(
1398
+ self,
1399
+ /,
1400
+ rule_or_generator: RewriteOrRule | RewriteOrRuleGenerator,
1401
+ *rules: RewriteOrRule,
1402
+ _increase_frame: bool = False,
1403
+ ) -> None:
1392
1404
  """
1393
1405
  Register rewrites or rules, either as a function or as values.
1394
1406
  """
1395
- if isinstance(rule_or_generator, FunctionType):
1396
- assert not rules
1397
- rules = tuple(_command_generator(rule_or_generator))
1407
+ if isinstance(rule_or_generator, RewriteOrRule):
1408
+ self.append(rule_or_generator)
1409
+ for r in rules:
1410
+ self.append(r)
1398
1411
  else:
1399
- rules = (cast(Rule | Rewrite, rule_or_generator), *rules)
1400
- for r in rules:
1401
- self.append(r)
1402
-
1403
- @cached_property
1404
- def __egg_decls__(self) -> Declarations:
1405
- return Declarations.create(*self.rules)
1406
-
1407
- @property
1408
- def _cmds(self) -> list[bindings._Command]:
1409
- cmds = [r._to_egg_command(self.egg_name) for r in self.rules]
1410
- if self.egg_name:
1411
- cmds.insert(0, bindings.AddRuleset(self.egg_name))
1412
- return cmds
1412
+ assert not rules
1413
+ current_frame = inspect.currentframe()
1414
+ assert current_frame
1415
+ original_frame = current_frame.f_back
1416
+ assert original_frame
1417
+ if _increase_frame:
1418
+ original_frame = original_frame.f_back
1419
+ assert original_frame
1420
+ self.deferred_rule_gens.append(Thunk.fn(_rewrite_or_rule_generator, rule_or_generator, original_frame))
1413
1421
 
1414
1422
  def __str__(self) -> str:
1415
- return f"ruleset(name={self.egg_name!r})"
1423
+ return pretty_decl(self._current_egg_decls, self.__egg_ruleset__, ruleset_name=self.name)
1416
1424
 
1417
1425
  def __repr__(self) -> str:
1418
- if not self.rules:
1419
- return str(self)
1420
- rules = ", ".join(map(repr, self.rules))
1421
- return f"ruleset({rules}, name={self.egg_name!r})"
1422
-
1423
- def _to_egg_schedule(self, default_ruleset_name: str) -> bindings._Schedule:
1424
- return bindings.Run(self._to_egg_config())
1425
-
1426
- def _to_egg_config(self) -> bindings.RunConfig:
1427
- return bindings.RunConfig(self.egg_name, None)
1426
+ return str(self)
1428
1427
 
1429
- def _rulesets(self) -> Iterable[Ruleset]:
1430
- yield self
1428
+ def __or__(self, other: Ruleset | UnstableCombinedRuleset) -> UnstableCombinedRuleset:
1429
+ return unstable_combine_rulesets(self, other)
1431
1430
 
1431
+ # Create a unique name if we didn't pass one from the user
1432
1432
  @property
1433
- def egg_name(self) -> str:
1434
- return self.name or f"_ruleset_{id(self)}"
1435
-
1436
-
1437
- class Command(ABC):
1438
- """
1439
- A command that can be executed in the egg interpreter.
1440
-
1441
- We only use this for commands which return no result and don't create new Python objects.
1442
-
1443
- Anything that can be passed to the `register` function in a Module is a Command.
1444
- """
1445
-
1446
- ruleset: Ruleset | None
1447
-
1448
- @property
1449
- @abstractmethod
1450
- def __egg_decls__(self) -> Declarations:
1451
- raise NotImplementedError
1452
-
1453
- @abstractmethod
1454
- def _to_egg_command(self, default_ruleset_name: str) -> bindings._Command:
1455
- raise NotImplementedError
1456
-
1457
- @abstractmethod
1458
- def __str__(self) -> str:
1459
- raise NotImplementedError
1460
-
1461
-
1462
- @dataclass
1463
- class Rewrite(Command):
1464
- ruleset: Ruleset | None
1465
- _lhs: RuntimeExpr
1466
- _rhs: RuntimeExpr
1467
- _conditions: tuple[Fact, ...]
1468
- _subsume: bool
1469
- _fn_name: ClassVar[str] = "rewrite"
1470
-
1471
- def __str__(self) -> str:
1472
- args_str = ", ".join(map(str, [self._rhs, *self._conditions]))
1473
- return f"{self._fn_name}({self._lhs}).to({args_str})"
1474
-
1475
- def _to_egg_command(self, default_ruleset_name: str) -> bindings._Command:
1476
- return bindings.RewriteCommand(
1477
- self.ruleset.egg_name if self.ruleset else default_ruleset_name, self._to_egg_rewrite(), self._subsume
1478
- )
1479
-
1480
- def _to_egg_rewrite(self) -> bindings.Rewrite:
1481
- return bindings.Rewrite(
1482
- self._lhs.__egg_typed_expr__.expr.to_egg(self._lhs.__egg_decls__),
1483
- self._rhs.__egg_typed_expr__.expr.to_egg(self._rhs.__egg_decls__),
1484
- [c._to_egg_fact() for c in self._conditions],
1485
- )
1486
-
1487
- @cached_property
1488
- def __egg_decls__(self) -> Declarations:
1489
- return Declarations.create(self._lhs, self._rhs, *self._conditions)
1490
-
1491
- def with_ruleset(self, ruleset: Ruleset) -> Rewrite:
1492
- return Rewrite(ruleset, self._lhs, self._rhs, self._conditions, self._subsume)
1433
+ def __egg_name__(self) -> str:
1434
+ return self.name or f"ruleset_{id(self)}"
1493
1435
 
1494
1436
 
1495
1437
  @dataclass
1496
- class BiRewrite(Rewrite):
1497
- _fn_name: ClassVar[str] = "birewrite"
1498
-
1499
- def _to_egg_command(self, default_ruleset_name: str) -> bindings._Command:
1500
- return bindings.BiRewriteCommand(
1501
- self.ruleset.egg_name if self.ruleset else default_ruleset_name, self._to_egg_rewrite()
1502
- )
1503
-
1504
-
1505
- @dataclass
1506
- class Fact(ABC):
1507
- """
1508
- A query on an EGraph, either by an expression or an equivalence between multiple expressions.
1509
- """
1510
-
1511
- @abstractmethod
1512
- def _to_egg_fact(self) -> bindings._Fact:
1513
- raise NotImplementedError
1514
-
1515
- @property
1516
- @abstractmethod
1517
- def __egg_decls__(self) -> Declarations:
1518
- raise NotImplementedError
1519
-
1520
-
1521
- @dataclass
1522
- class Eq(Fact):
1523
- _exprs: list[RuntimeExpr]
1524
-
1525
- def __str__(self) -> str:
1526
- first, *rest = self._exprs
1527
- args_str = ", ".join(map(str, rest))
1528
- return f"eq({first}).to({args_str})"
1529
-
1530
- def _to_egg_fact(self) -> bindings.Eq:
1531
- return bindings.Eq([e.__egg__ for e in self._exprs])
1532
-
1533
- @cached_property
1534
- def __egg_decls__(self) -> Declarations:
1535
- return Declarations.create(*self._exprs)
1536
-
1537
-
1538
- @dataclass
1539
- class ExprFact(Fact):
1540
- _expr: RuntimeExpr
1541
-
1542
- def __str__(self) -> str:
1543
- return str(self._expr)
1544
-
1545
- def _to_egg_fact(self) -> bindings.Fact:
1546
- return bindings.Fact(self._expr.__egg__)
1547
-
1548
- @cached_property
1549
- def __egg_decls__(self) -> Declarations:
1550
- return self._expr.__egg_decls__
1551
-
1552
-
1553
- @dataclass
1554
- class Rule(Command):
1555
- head: tuple[Action, ...]
1556
- body: tuple[Fact, ...]
1557
- name: str
1558
- ruleset: Ruleset | None
1559
-
1560
- def __str__(self) -> str:
1561
- head_str = ", ".join(map(str, self.head))
1562
- body_str = ", ".join(map(str, self.body))
1563
- return f"rule({body_str}).then({head_str})"
1564
-
1565
- def _to_egg_command(self, default_ruleset_name: str) -> bindings.RuleCommand:
1566
- return bindings.RuleCommand(
1567
- self.name,
1568
- self.ruleset.egg_name if self.ruleset else default_ruleset_name,
1569
- bindings.Rule(
1570
- [a._to_egg_action() for a in self.head],
1571
- [f._to_egg_fact() for f in self.body],
1572
- ),
1573
- )
1574
-
1575
- @cached_property
1576
- def __egg_decls__(self) -> Declarations:
1577
- return Declarations.create(*self.head, *self.body)
1578
-
1579
-
1580
- class Action(Command, ABC):
1581
- """
1582
- A change to an EGraph, either unioning multiple expressing, setting the value of a function call, deleting an expression, or panicking.
1583
- """
1584
-
1585
- @abstractmethod
1586
- def _to_egg_action(self) -> bindings._Action:
1587
- raise NotImplementedError
1438
+ class UnstableCombinedRuleset(Schedule):
1439
+ __egg_decls_thunk__: Callable[[], Declarations] = field(init=False)
1440
+ schedule: RunDecl = field(init=False)
1441
+ name: str | None
1442
+ rulesets: InitVar[list[Ruleset | UnstableCombinedRuleset]]
1588
1443
 
1589
- def _to_egg_command(self, default_ruleset_name: str) -> bindings._Command:
1590
- return bindings.ActionCommand(self._to_egg_action())
1444
+ def __post_init__(self, rulesets: list[Ruleset | UnstableCombinedRuleset]) -> None:
1445
+ self.schedule = RunDecl(self.__egg_name__, ())
1446
+ self.__egg_decls_thunk__ = Thunk.fn(self._create_egg_decls, *rulesets)
1591
1447
 
1592
1448
  @property
1593
- def ruleset(self) -> None | Ruleset: # type: ignore[override]
1594
- return None
1595
-
1596
-
1597
- @dataclass
1598
- class Let(Action):
1599
- _name: str
1600
- _value: RuntimeExpr
1601
-
1602
- def __str__(self) -> str:
1603
- return f"let({self._name}, {self._value})"
1449
+ def __egg_name__(self) -> str:
1450
+ return self.name or f"combined_ruleset_{id(self)}"
1604
1451
 
1605
- def _to_egg_action(self) -> bindings.Let:
1606
- return bindings.Let(self._name, self._value.__egg__)
1452
+ def _create_egg_decls(self, *rulesets: Ruleset | UnstableCombinedRuleset) -> Declarations:
1453
+ decls = Declarations.create(*rulesets)
1454
+ decls._rulesets[self.__egg_name__] = CombinedRulesetDecl(tuple(r.__egg_name__ for r in rulesets))
1455
+ return decls
1607
1456
 
1608
- @property
1609
- def __egg_decls__(self) -> Declarations:
1610
- return self._value.__egg_decls__
1457
+ def __or__(self, other: Ruleset | UnstableCombinedRuleset) -> UnstableCombinedRuleset:
1458
+ return unstable_combine_rulesets(self, other)
1611
1459
 
1612
1460
 
1613
- @dataclass
1614
- class Set(Action):
1461
+ def unstable_combine_rulesets(
1462
+ *rulesets: Ruleset | UnstableCombinedRuleset, name: str | None = None
1463
+ ) -> UnstableCombinedRuleset:
1615
1464
  """
1616
- Similar to union, except can be used on primitive expressions, whereas union can only be used on user defined expressions.
1465
+ Combine multiple rulesets into a single ruleset.
1617
1466
  """
1618
-
1619
- _call: RuntimeExpr
1620
- _rhs: RuntimeExpr
1621
-
1622
- def __str__(self) -> str:
1623
- return f"set({self._call}).to({self._rhs})"
1624
-
1625
- def _to_egg_action(self) -> bindings.Set:
1626
- egg_call = self._call.__egg__
1627
- if not isinstance(egg_call, bindings.Call):
1628
- raise ValueError(f"Can only create a set with a call for the lhs, got {self._call}") # noqa: TRY004
1629
- return bindings.Set(
1630
- egg_call.name,
1631
- egg_call.args,
1632
- self._rhs.__egg__,
1633
- )
1634
-
1635
- @cached_property
1636
- def __egg_decls__(self) -> Declarations:
1637
- return Declarations.create(self._call, self._rhs)
1467
+ return UnstableCombinedRuleset(name, list(rulesets))
1638
1468
 
1639
1469
 
1640
1470
  @dataclass
1641
- class ExprAction(Action):
1642
- _expr: RuntimeExpr
1471
+ class RewriteOrRule:
1472
+ __egg_decls__: Declarations
1473
+ decl: RewriteOrRuleDecl
1474
+ ruleset: Ruleset | None = None
1643
1475
 
1644
1476
  def __str__(self) -> str:
1645
- return str(self._expr)
1646
-
1647
- def _to_egg_action(self) -> bindings.Expr_:
1648
- return bindings.Expr_(self._expr.__egg__)
1477
+ return pretty_decl(self.__egg_decls__, self.decl)
1649
1478
 
1650
- @property
1651
- def __egg_decls__(self) -> Declarations:
1652
- return self._expr.__egg_decls__
1479
+ def __repr__(self) -> str:
1480
+ return str(self)
1653
1481
 
1654
1482
 
1655
1483
  @dataclass
1656
- class Change(Action):
1484
+ class Fact:
1657
1485
  """
1658
- Change a function call in an EGraph.
1486
+ A query on an EGraph, either by an expression or an equivalence between multiple expressions.
1659
1487
  """
1660
1488
 
1661
- change: Literal["delete", "subsume"]
1662
- _call: RuntimeExpr
1489
+ __egg_decls__: Declarations
1490
+ fact: FactDecl
1663
1491
 
1664
1492
  def __str__(self) -> str:
1665
- return f"{self.change}({self._call})"
1666
-
1667
- def _to_egg_action(self) -> bindings.Change:
1668
- egg_call = self._call.__egg__
1669
- if not isinstance(egg_call, bindings.Call):
1670
- raise ValueError(f"Can only create a call with a call for the lhs, got {self._call}") # noqa: TRY004
1671
- change: bindings._Change = bindings.Delete() if self.change == "delete" else bindings.Subsume()
1672
- return bindings.Change(change, egg_call.name, egg_call.args)
1493
+ return pretty_decl(self.__egg_decls__, self.fact)
1673
1494
 
1674
- @property
1675
- def __egg_decls__(self) -> Declarations:
1676
- return self._call.__egg_decls__
1495
+ def __repr__(self) -> str:
1496
+ return str(self)
1677
1497
 
1678
1498
 
1679
1499
  @dataclass
1680
- class Union_(Action): # noqa: N801
1500
+ class Action:
1681
1501
  """
1682
- Merges two equivalence classes of two expressions.
1502
+ A change to an EGraph, either unioning multiple expressing, setting the value of a function call, deleting an expression, or panicking.
1683
1503
  """
1684
1504
 
1685
- _lhs: RuntimeExpr
1686
- _rhs: RuntimeExpr
1687
-
1688
- def __str__(self) -> str:
1689
- return f"union({self._lhs}).with_({self._rhs})"
1690
-
1691
- def _to_egg_action(self) -> bindings.Union:
1692
- return bindings.Union(self._lhs.__egg__, self._rhs.__egg__)
1693
-
1694
- @cached_property
1695
- def __egg_decls__(self) -> Declarations:
1696
- return Declarations.create(self._lhs, self._rhs)
1697
-
1698
-
1699
- @dataclass
1700
- class Panic(Action):
1701
- message: str
1702
-
1703
- def __str__(self) -> str:
1704
- return f"panic({self.message})"
1705
-
1706
- def _to_egg_action(self) -> bindings.Panic:
1707
- return bindings.Panic(self.message)
1708
-
1709
- @cached_property
1710
- def __egg_decls__(self) -> Declarations:
1711
- return Declarations()
1712
-
1713
-
1714
- @dataclass
1715
- class Run(Schedule):
1716
- """Configuration of a run"""
1717
-
1718
- # None if using default ruleset
1719
- ruleset: Ruleset | None
1720
- until: tuple[Fact, ...]
1505
+ __egg_decls__: Declarations
1506
+ action: ActionDecl
1721
1507
 
1722
1508
  def __str__(self) -> str:
1723
- args_str = ", ".join(map(str, [self.ruleset, *self.until]))
1724
- return f"run({args_str})"
1725
-
1726
- def _to_egg_schedule(self, default_ruleset_name: str) -> bindings._Schedule:
1727
- return bindings.Run(self._to_egg_config(default_ruleset_name))
1509
+ return pretty_decl(self.__egg_decls__, self.action)
1728
1510
 
1729
- def _to_egg_config(self, default_ruleset_name: str) -> bindings.RunConfig:
1730
- return bindings.RunConfig(
1731
- self.ruleset.egg_name if self.ruleset else default_ruleset_name,
1732
- [fact._to_egg_fact() for fact in self.until] if self.until else None,
1733
- )
1734
-
1735
- def _rulesets(self) -> Iterable[Ruleset]:
1736
- if self.ruleset:
1737
- yield self.ruleset
1738
-
1739
- @cached_property
1740
- def __egg_decls__(self) -> Declarations:
1741
- return Declarations.create(self.ruleset, *self.until)
1742
-
1743
-
1744
- @dataclass
1745
- class Saturate(Schedule):
1746
- schedule: Schedule
1747
-
1748
- def __str__(self) -> str:
1749
- return f"{self.schedule}.saturate()"
1750
-
1751
- def _to_egg_schedule(self, default_ruleset_name: str) -> bindings._Schedule:
1752
- return bindings.Saturate(self.schedule._to_egg_schedule(default_ruleset_name))
1753
-
1754
- def _rulesets(self) -> Iterable[Ruleset]:
1755
- return self.schedule._rulesets()
1756
-
1757
- @property
1758
- def __egg_decls__(self) -> Declarations:
1759
- return self.schedule.__egg_decls__
1760
-
1761
-
1762
- @dataclass
1763
- class Repeat(Schedule):
1764
- length: int
1765
- schedule: Schedule
1766
-
1767
- def __str__(self) -> str:
1768
- return f"{self.schedule} * {self.length}"
1769
-
1770
- def _to_egg_schedule(self, default_ruleset_name: str) -> bindings._Schedule:
1771
- return bindings.Repeat(self.length, self.schedule._to_egg_schedule(default_ruleset_name))
1772
-
1773
- def _rulesets(self) -> Iterable[Ruleset]:
1774
- return self.schedule._rulesets()
1775
-
1776
- @property
1777
- def __egg_decls__(self) -> Declarations:
1778
- return self.schedule.__egg_decls__
1779
-
1780
-
1781
- @dataclass
1782
- class Sequence(Schedule):
1783
- schedules: tuple[Schedule, ...]
1784
-
1785
- def __str__(self) -> str:
1786
- return f"sequence({', '.join(map(str, self.schedules))})"
1787
-
1788
- def _to_egg_schedule(self, default_ruleset_name: str) -> bindings._Schedule:
1789
- return bindings.Sequence([schedule._to_egg_schedule(default_ruleset_name) for schedule in self.schedules])
1790
-
1791
- def _rulesets(self) -> Iterable[Ruleset]:
1792
- for s in self.schedules:
1793
- yield from s._rulesets()
1794
-
1795
- @cached_property
1796
- def __egg_decls__(self) -> Declarations:
1797
- return Declarations.create(*self.schedules)
1511
+ def __repr__(self) -> str:
1512
+ return str(self)
1798
1513
 
1799
1514
 
1800
1515
  # We use these builders so that when creating these structures we can type check
@@ -1841,30 +1556,41 @@ def ne(expr: EXPR) -> _NeBuilder[EXPR]:
1841
1556
 
1842
1557
  def panic(message: str) -> Action:
1843
1558
  """Raise an error with the given message."""
1844
- return Panic(message)
1559
+ return Action(Declarations(), PanicDecl(message))
1845
1560
 
1846
1561
 
1847
1562
  def let(name: str, expr: Expr) -> Action:
1848
1563
  """Create a let binding."""
1849
- return Let(name, to_runtime_expr(expr))
1564
+ runtime_expr = to_runtime_expr(expr)
1565
+ return Action(runtime_expr.__egg_decls__, LetDecl(name, runtime_expr.__egg_typed_expr__))
1850
1566
 
1851
1567
 
1852
1568
  def expr_action(expr: Expr) -> Action:
1853
- return ExprAction(to_runtime_expr(expr))
1569
+ runtime_expr = to_runtime_expr(expr)
1570
+ return Action(runtime_expr.__egg_decls__, ExprActionDecl(runtime_expr.__egg_typed_expr__))
1854
1571
 
1855
1572
 
1856
1573
  def delete(expr: Expr) -> Action:
1857
1574
  """Create a delete expression."""
1858
- return Change("delete", to_runtime_expr(expr))
1575
+ runtime_expr = to_runtime_expr(expr)
1576
+ typed_expr = runtime_expr.__egg_typed_expr__
1577
+ call_decl = typed_expr.expr
1578
+ assert isinstance(call_decl, CallDecl), "Can only delete calls, not literals or vars"
1579
+ return Action(runtime_expr.__egg_decls__, ChangeDecl(typed_expr.tp, call_decl, "delete"))
1859
1580
 
1860
1581
 
1861
1582
  def subsume(expr: Expr) -> Action:
1862
- """Subsume an expression."""
1863
- return Change("subsume", to_runtime_expr(expr))
1583
+ """Subsume an expression so it cannot be matched against or extracted"""
1584
+ runtime_expr = to_runtime_expr(expr)
1585
+ typed_expr = runtime_expr.__egg_typed_expr__
1586
+ call_decl = typed_expr.expr
1587
+ assert isinstance(call_decl, CallDecl), "Can only subsume calls, not literals or vars"
1588
+ return Action(runtime_expr.__egg_decls__, ChangeDecl(typed_expr.tp, call_decl, "subsume"))
1864
1589
 
1865
1590
 
1866
1591
  def expr_fact(expr: Expr) -> Fact:
1867
- return ExprFact(to_runtime_expr(expr))
1592
+ runtime_expr = to_runtime_expr(expr)
1593
+ return Fact(runtime_expr.__egg_decls__, ExprFactDecl(runtime_expr.__egg_typed_expr__))
1868
1594
 
1869
1595
 
1870
1596
  def union(lhs: EXPR) -> _UnionBuilder[EXPR]:
@@ -1891,6 +1617,11 @@ def rule(*facts: FactLike, ruleset: Ruleset | None = None, name: str | None = No
1891
1617
  return _RuleBuilder(facts=_fact_likes(facts), name=name, ruleset=ruleset)
1892
1618
 
1893
1619
 
@deprecated("This function is now a no-op, you can remove it and use actions as commands")
def action_command(action: Action) -> Action:
    """
    Return ``action`` unchanged.

    Retained only so existing callers keep working. Actions can be used as commands
    directly.
    """
    result = action
    return result
def var(name: str, bound: type[EXPR]) -> EXPR:
    """
    Create a new variable with the given name and type.

    This is the typed public wrapper around ``_var``. The ``cast`` only tells type
    checkers that the runtime expression stands in for a value of type ``bound``.
    """
    runtime_var = _var(name, bound)
    return cast(EXPR, runtime_var)
@@ -1898,9 +1629,9 @@ def var(name: str, bound: type[EXPR]) -> EXPR:
1898
1629
 
def _var(name: str, bound: object) -> RuntimeExpr:
    """
    Create a new variable with the given name and type.

    ``bound`` is resolved with ``resolve_type_annotation`` against a fresh
    ``Declarations``. That same ``Declarations`` object is attached to the returned
    variable expression.
    """
    decls = Declarations()
    resolved = resolve_type_annotation(decls, bound)
    typed = TypedExprDecl(resolved.to_just(), VarDecl(name))
    return RuntimeExpr.__from_value__(decls, typed)
1906
1637
  def vars_(names: str, bound: type[EXPR]) -> Iterable[EXPR]:
@@ -1915,15 +1646,27 @@ class _RewriteBuilder(Generic[EXPR]):
1915
1646
  ruleset: Ruleset | None
1916
1647
  subsume: bool
1917
1648
 
def to(self, rhs: EXPR, *conditions: FactLike) -> RewriteOrRule:
    """
    Finish the rewrite ``lhs -> rhs``, guarded by the optional ``conditions``.

    ``rhs`` is converted to the same type as ``lhs``. If the builder was given a ruleset,
    the new rewrite is appended to it before being returned.
    """
    lhs = to_runtime_expr(self.lhs)
    facts = _fact_likes(conditions)
    rhs_expr = convert_to_same_type(rhs, lhs)
    decls = Declarations.create(lhs, rhs_expr, *facts, self.ruleset)
    lhs_typed = lhs.__egg_typed_expr__
    rewrite = RewriteDecl(
        lhs_typed.tp,
        lhs_typed.expr,
        rhs_expr.__egg_typed_expr__.expr,
        tuple(fact.fact for fact in facts),
        self.subsume,
    )
    rule = RewriteOrRule(decls, rewrite)
    if self.ruleset:
        self.ruleset.append(rule)
    return rule
def __str__(self) -> str:
    """Pretty-print this pending builder using the ``rewrite`` wrapper name."""
    return to_runtime_expr(self.lhs).__egg_pretty__("rewrite")
1929
1672
  @dataclass
@@ -1931,15 +1674,26 @@ class _BirewriteBuilder(Generic[EXPR]):
1931
1674
  lhs: EXPR
1932
1675
  ruleset: Ruleset | None
1933
1676
 
def to(self, rhs: EXPR, *conditions: FactLike) -> RewriteOrRule:
    """
    Finish the bidirectional rewrite between ``lhs`` and ``rhs``, guarded by ``conditions``.

    ``rhs`` is converted to the same type as ``lhs``. If the builder was given a ruleset,
    the new rewrite is appended to it before being returned.
    """
    lhs = to_runtime_expr(self.lhs)
    facts = _fact_likes(conditions)
    rhs_expr = convert_to_same_type(rhs, lhs)
    decls = Declarations.create(lhs, rhs_expr, *facts, self.ruleset)
    lhs_typed = lhs.__egg_typed_expr__
    birewrite = BiRewriteDecl(
        lhs_typed.tp,
        lhs_typed.expr,
        rhs_expr.__egg_typed_expr__.expr,
        tuple(fact.fact for fact in facts),
    )
    rule = RewriteOrRule(decls, birewrite)
    if self.ruleset:
        self.ruleset.append(rule)
    return rule
def __str__(self) -> str:
    """Pretty-print this pending builder using the ``birewrite`` wrapper name."""
    return to_runtime_expr(self.lhs).__egg_pretty__("birewrite")
1945
1699
  @dataclass
@@ -1948,52 +1702,84 @@ class _EqBuilder(Generic[EXPR]):
1948
1702
 
def to(self, *exprs: EXPR) -> Fact:
    """
    Build an equality fact between the builder's expression and every item in ``exprs``.

    Each item is converted to the same type as the builder's expression. The resulting
    ``EqDecl`` lists the builder's expression first, followed by the items in order.
    """
    first = to_runtime_expr(self.expr)
    args = [first]
    args.extend(convert_to_same_type(other, first) for other in exprs)
    decls = Declarations.create(*args)
    eq = EqDecl(first.__egg_typed_expr__.tp, tuple(arg.__egg_typed_expr__.expr for arg in args))
    return Fact(decls, eq)
def __repr__(self) -> str:
    """Use the pretty ``__str__`` form as the repr, in place of a dataclass-generated one."""
    text = str(self)
    return text
def __str__(self) -> str:
    """Pretty-print this pending builder using the ``eq`` wrapper name."""
    return to_runtime_expr(self.expr).__egg_pretty__("eq")
@dataclass
class _NeBuilder(Generic[EXPR]):
    """
    Pending inequality builder that holds the left-hand side.

    ``to(rhs)`` produces a ``Unit``-typed call to the ``!=`` function on both sides.
    """

    lhs: EXPR

    def to(self, rhs: EXPR) -> Unit:
        """Compare ``lhs`` with ``rhs`` (converted to ``lhs``'s type) using ``!=``."""
        left = to_runtime_expr(self.lhs)
        right = convert_to_same_type(rhs, left)
        # Asserts ``Unit`` is a ``RuntimeClass`` (this also narrows it for type checkers)
        # before its declarations are merged in.
        assert isinstance(Unit, RuntimeClass)
        decls = Declarations.create(Unit, left, right)
        call = CallDecl(FunctionRef("!="), (left.__egg_typed_expr__, right.__egg_typed_expr__))
        res = RuntimeExpr.__from_value__(decls, TypedExprDecl(JustTypeRef("Unit"), call))
        return cast(Unit, res)

    def __repr__(self) -> str:
        """Use the pretty ``__str__`` form as the repr."""
        return str(self)

    def __str__(self) -> str:
        """Pretty-print this pending builder using the ``ne`` wrapper name."""
        return to_runtime_expr(self.lhs).__egg_pretty__("ne")
@dataclass
class _SetBuilder(Generic[EXPR]):
    """
    Pending ``set_`` builder that holds the function call being assigned.

    ``to(rhs)`` produces an action whose ``SetDecl`` targets ``lhs`` with the value ``rhs``.
    """

    lhs: EXPR

    def to(self, rhs: EXPR) -> Action:
        """
        Assign ``rhs`` (converted to ``lhs``'s type) to the call ``lhs``.

        ``lhs`` must be a function call. This is enforced with an ``assert``, so the check
        is skipped under ``-O``.
        """
        target = to_runtime_expr(self.lhs)
        value = convert_to_same_type(rhs, target)
        call = target.__egg_typed_expr__.expr
        assert isinstance(call, CallDecl), "Can only set function calls"
        decls = Declarations.create(target, value)
        return Action(decls, SetDecl(target.__egg_typed_expr__.tp, call, value.__egg_typed_expr__.expr))

    def __repr__(self) -> str:
        """Use the pretty ``__str__`` form as the repr."""
        return str(self)

    def __str__(self) -> str:
        """Pretty-print this pending builder using the ``set_`` wrapper name."""
        return to_runtime_expr(self.lhs).__egg_pretty__("set_")
@dataclass
class _UnionBuilder(Generic[EXPR]):
    """
    Pending ``union`` builder that holds the left-hand expression.

    ``with_(rhs)`` produces an action whose ``UnionDecl`` joins ``lhs`` and ``rhs``.
    """

    lhs: EXPR

    def with_(self, rhs: EXPR) -> Action:
        """Union ``lhs`` with ``rhs``, converting ``rhs`` to ``lhs``'s type first."""
        left = to_runtime_expr(self.lhs)
        right = convert_to_same_type(rhs, left)
        decls = Declarations.create(left, right)
        left_typed = left.__egg_typed_expr__
        return Action(decls, UnionDecl(left_typed.tp, left_typed.expr, right.__egg_typed_expr__.expr))

    def __repr__(self) -> str:
        """Use the pretty ``__str__`` form as the repr."""
        return str(self)

    def __str__(self) -> str:
        """Pretty-print this pending builder using the ``union`` wrapper name."""
        return to_runtime_expr(self.lhs).__egg_pretty__("union")
1999
1785
  @dataclass
@@ -2002,12 +1788,25 @@ class _RuleBuilder:
2002
1788
  name: str | None
2003
1789
  ruleset: Ruleset | None
2004
1790
 
def then(self, *actions: ActionLike) -> RewriteOrRule:
    """
    Complete the rule with ``actions``.

    Bare expressions among ``actions`` are normalized with ``_action_likes``. If the
    builder was given a ruleset, the finished rule is appended to it before being returned.
    """
    normalized = _action_likes(actions)
    decls = Declarations.create(self.ruleset, *normalized, *self.facts)
    rule_decl = RuleDecl(
        tuple(action.action for action in normalized),
        tuple(fact.fact for fact in self.facts),
        self.name,
    )
    rule = RewriteOrRule(decls, rule_decl)
    if self.ruleset:
        self.ruleset.append(rule)
    return rule
def __str__(self) -> str:
    """
    Render the pending rule as ``rule(<facts>, name=..., ruleset=...)``.

    The ``then(...)`` actions are not stored on the builder, so only the facts and the
    optional ``name`` / ``ruleset`` keywords appear. Each keyword is shown only when the
    builder's corresponding field is not ``None``.
    """
    # TODO: Figure out how to stringify rulebuilder that preserves statements
    args = [str(fact) for fact in self.facts]
    if self.name is not None:
        args.append(f"name={self.name}")
    # Test this builder's own field: a bare ``ruleset`` would resolve to a name outside the
    # instance and would not reflect whether this rule was given a ruleset.
    if self.ruleset is not None:
        args.append(f"ruleset={self.ruleset}")
    return f"rule({', '.join(args)})"
2012
1811
  def expr_parts(expr: Expr) -> TypedExprDecl:
2013
1812
  """
@@ -2024,60 +1823,63 @@ def to_runtime_expr(expr: Expr) -> RuntimeExpr:
2024
1823
  return expr
2025
1824
 
2026
1825
 
def run(ruleset: Ruleset | None = None, *until: FactLike) -> Schedule:
    """
    Create a run configuration.

    A missing (falsy) ``ruleset`` is encoded as the empty ruleset name. Any ``until``
    facts are attached to the ``RunDecl``; when there are none, that field is ``None``.
    Declarations are gathered through ``Thunk.fn`` (presumably deferred until needed).
    """
    facts = _fact_likes(until)
    decls_thunk = Thunk.fn(Declarations.create, ruleset, *facts)
    ruleset_name = ruleset.__egg_name__ if ruleset else ""
    until_facts = tuple(fact.fact for fact in facts) or None
    return Schedule(decls_thunk, RunDecl(ruleset_name, until_facts))
def seq(*schedules: Schedule) -> Schedule:
    """
    Run a sequence of schedules.

    Each input's ``schedule`` declaration is placed, in order, inside a ``SequenceDecl``.
    """
    decls_thunk = Thunk.fn(Declarations.create, *schedules)
    steps = tuple(schedule.schedule for schedule in schedules)
    return Schedule(decls_thunk, SequenceDecl(steps))
# Values accepted wherever an action is expected: an ``Action`` or a bare expression,
# which ``_action_like`` wraps with ``expr_action``.
ActionLike: TypeAlias = Action | Expr
def _action_likes(action_likes: Iterable[ActionLike]) -> tuple[Action, ...]:
    """Normalize every item of ``action_likes`` with ``_action_like`` and return them as a tuple."""
    return tuple(_action_like(item) for item in action_likes)
def _action_like(action_like: ActionLike) -> Action:
    """Return ``action_like`` as an ``Action``, wrapping a bare expression with ``expr_action``."""
    return expr_action(action_like) if isinstance(action_like, Expr) else action_like
# Normalized command values: a plain ``Action`` or a ``RewriteOrRule``.
Command: TypeAlias = Action | RewriteOrRule

# Inputs accepted where a command is expected; ``ActionLike`` also admits bare expressions.
CommandLike: TypeAlias = ActionLike | RewriteOrRule
def _command_like(command_like: CommandLike) -> Command:
    """
    Normalize ``command_like`` into a ``Command``.

    A ``RewriteOrRule`` is returned unchanged. Anything else is handled by ``_action_like``.
    """
    if not isinstance(command_like, RewriteOrRule):
        return _action_like(command_like)
    return command_like
# Signature of a user function that takes typed variables and yields rewrites/rules;
# such functions are consumed by ``_rewrite_or_rule_generator``.
RewriteOrRuleGenerator = Callable[..., Iterable[RewriteOrRule]]
def _rewrite_or_rule_generator(gen: RewriteOrRuleGenerator, frame: FrameType) -> Iterable[RewriteOrRule]:
    """
    Call ``gen`` with one fresh variable per parameter and return what it yields, as a list.

    Each parameter ``p`` of ``gen`` is bound to ``_var(p.name, <annotation of p>)``.
    Annotations are resolved against a copy of ``gen``'s module globals, with ``frame``'s
    locals as the local namespace. This lets types defined in the caller's local scope,
    which are not in the module globals, be found too. ``gen`` is called and consumed
    eagerly; no thunk is involved.

    Raises:
        KeyError: if a parameter of ``gen`` has no type annotation.
    """
    # Work on a copy so the injected name below never leaks into ``gen``'s module namespace.
    globalns = dict(gen.__globals__)
    # Make ``Callable`` resolvable in string annotations even if ``gen``'s module lacks it.
    globalns.setdefault("Callable", Callable)
    hints = get_type_hints(gen, globalns, frame.f_locals)
    args = [_var(param.name, hints[param.name]) for param in signature(gen).parameters.values()]
    return list(gen(*args))  # type: ignore[misc]
# Values accepted wherever a fact is expected: a ``Fact`` or a bare expression.
FactLike = Fact | Expr