egglog 9.0.0__cp311-cp311-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of egglog might be problematic. Click here for more details.

egglog/declarations.py ADDED
@@ -0,0 +1,818 @@
1
"""
Data only descriptions of the components of an egraph and the expressions.

We separate them into two pieces, the references and the declarations, so that mutually recursive
types can be expressed: references only name things, while declarations hold their contents.
"""
6
+
7
+ from __future__ import annotations
8
+
9
+ from dataclasses import dataclass, field
10
+ from functools import cached_property
11
+ from typing import TYPE_CHECKING, ClassVar, Literal, Protocol, TypeAlias, TypeVar, Union, cast, runtime_checkable
12
+ from weakref import WeakValueDictionary
13
+
14
+ from typing_extensions import Self, assert_never
15
+
16
+ if TYPE_CHECKING:
17
+ from collections.abc import Callable, Iterable, Mapping
18
+
19
+
20
# Public API of this module. Each name is listed exactly once.
__all__ = [
    "ActionCommandDecl",
    "ActionDecl",
    "BiRewriteDecl",
    "CallDecl",
    "CallableDecl",
    "CallableRef",
    "ChangeDecl",
    "ClassDecl",
    "ClassMethodRef",
    "ClassTypeVarRef",
    "ClassVariableRef",
    "CombinedRulesetDecl",
    "CommandDecl",
    "ConstantDecl",
    "ConstantRef",
    "ConstructorDecl",
    "Declarations",
    "DeclerationsLike",
    "DefaultRewriteDecl",
    "DelayedDeclerations",
    "EqDecl",
    "ExprActionDecl",
    "ExprDecl",
    "ExprFactDecl",
    "FactDecl",
    "FunctionDecl",
    "FunctionRef",
    "FunctionSignature",
    "HasDeclerations",
    "InitRef",
    "JustTypeRef",
    "LetDecl",
    "LitDecl",
    "LitType",
    "MethodRef",
    "PanicDecl",
    "PartialCallDecl",
    "PropertyRef",
    "PyObjectDecl",
    "RelationDecl",
    "RepeatDecl",
    "RewriteDecl",
    "RewriteOrRuleDecl",
    "RuleDecl",
    "RulesetDecl",
    "RunDecl",
    "SaturateDecl",
    "ScheduleDecl",
    "SequenceDecl",
    "SetDecl",
    "SpecialFunctions",
    "TypeOrVarRef",
    "TypeRefWithVars",
    "TypedExprDecl",
    "UnionDecl",
    "UnnamedFunctionRef",
    "VarDecl",
    "replace_typed_expr",
    "upcast_declerations",
]
82
+
83
+
84
@dataclass
class DelayedDeclerations:
    """
    Declarations that are produced lazily.

    Holds a zero-argument callable that builds the `Declarations` only when `__egg_decls__` is read.
    """

    # Callable producing the declarations; excluded from the repr.
    __egg_decls_thunk__: Callable[[], Declarations] = field(repr=False)

    @property
    def __egg_decls__(self) -> Declarations:
        """
        Evaluate the stored callable and return its declarations.

        Raises:
            RuntimeError: if the callable raises AttributeError (chained as the cause).
        """
        make_decls = self.__egg_decls_thunk__
        try:
            resolved = make_decls()
        except AttributeError as err:
            # An AttributeError escaping a property looks like a missing attribute and would fall
            # back to `__getattr__`; surface it as a RuntimeError instead.
            msg = f"Cannot resolve declerations for {self}"
            raise RuntimeError(msg) from err
        return resolved
98
+
99
+
100
+ @runtime_checkable
101
+ class HasDeclerations(Protocol):
102
+ @property
103
+ def __egg_decls__(self) -> Declarations: ...
104
+
105
+
106
+ DeclerationsLike: TypeAlias = Union[HasDeclerations, None, "Declarations"]
107
+
108
+
109
def upcast_declerations(declerations_like: Iterable[DeclerationsLike]) -> list[Declarations]:
    """
    Normalize declarations-like values into a list of `Declarations`.

    `None` entries are dropped. Objects implementing `HasDeclerations` are replaced by their
    `__egg_decls__`, and `Declarations` instances are kept as they are. Input order is preserved.
    """

    def _resolve(candidate: HasDeclerations | Declarations) -> Declarations:
        # HasDeclerations is checked first, exactly as callers have always relied on.
        if isinstance(candidate, HasDeclerations):
            return candidate.__egg_decls__
        if isinstance(candidate, Declarations):
            return candidate
        assert_never(candidate)

    return [_resolve(candidate) for candidate in declerations_like if candidate is not None]
121
+
122
+
123
+ @dataclass
124
+ class Declarations:
125
+ _unnamed_functions: set[UnnamedFunctionRef] = field(default_factory=set)
126
+ _functions: dict[str, FunctionDecl | RelationDecl | ConstructorDecl] = field(default_factory=dict)
127
+ _constants: dict[str, ConstantDecl] = field(default_factory=dict)
128
+ _classes: dict[str, ClassDecl] = field(default_factory=dict)
129
+ _rulesets: dict[str, RulesetDecl | CombinedRulesetDecl] = field(default_factory=lambda: {"": RulesetDecl([])})
130
+
131
+ @property
132
+ def default_ruleset(self) -> RulesetDecl:
133
+ ruleset = self._rulesets[""]
134
+ assert isinstance(ruleset, RulesetDecl)
135
+ return ruleset
136
+
137
+ @classmethod
138
+ def create(cls, *others: DeclerationsLike) -> Declarations:
139
+ others = upcast_declerations(others)
140
+ if not others:
141
+ return Declarations()
142
+ first, *rest = others
143
+ if not rest:
144
+ return first
145
+ new = first.copy()
146
+ new.update(*rest)
147
+ return new
148
+
149
+ def copy(self) -> Declarations:
150
+ new = Declarations()
151
+ self.update_other(new)
152
+ return new
153
+
154
+ def update(self, *others: DeclerationsLike) -> None:
155
+ for other in others:
156
+ self |= other
157
+
158
+ def __or__(self, other: DeclerationsLike) -> Declarations:
159
+ result = self.copy()
160
+ result |= other
161
+ return result
162
+
163
+ def __ior__(self, other: DeclerationsLike) -> Self:
164
+ if other is None:
165
+ return self
166
+ if isinstance(other, HasDeclerations):
167
+ other = other.__egg_decls__
168
+ other.update_other(self)
169
+ return self
170
+
171
+ def update_other(self, other: Declarations) -> None:
172
+ """
173
+ Updates the other decl with these values in palce.
174
+ """
175
+ other._functions |= self._functions
176
+ other._classes |= self._classes
177
+ other._constants |= self._constants
178
+ # Must combine rulesets bc the empty ruleset might be different, bc DefaultRewriteDecl
179
+ # is added to functions.
180
+ combined_default_rules: set[RewriteOrRuleDecl] = {*self.default_ruleset.rules, *other.default_ruleset.rules}
181
+ other._rulesets |= self._rulesets
182
+ other._rulesets[""] = RulesetDecl(list(combined_default_rules))
183
+
184
+ def get_callable_decl(self, ref: CallableRef) -> CallableDecl: # noqa: PLR0911
185
+ match ref:
186
+ case FunctionRef(name):
187
+ return self._functions[name]
188
+ case ConstantRef(name):
189
+ return self._constants[name]
190
+ case MethodRef(class_name, method_name):
191
+ return self._classes[class_name].methods[method_name]
192
+ case ClassVariableRef(class_name, name):
193
+ return self._classes[class_name].class_variables[name]
194
+ case ClassMethodRef(class_name, name):
195
+ return self._classes[class_name].class_methods[name]
196
+ case PropertyRef(class_name, property_name):
197
+ return self._classes[class_name].properties[property_name]
198
+ case InitRef(class_name):
199
+ init_fn = self._classes[class_name].init
200
+ assert init_fn, f"Class {class_name} does not have an init function."
201
+ return init_fn
202
+ case UnnamedFunctionRef():
203
+ return ConstructorDecl(ref.signature)
204
+
205
+ assert_never(ref)
206
+
207
+ def set_function_decl(
208
+ self,
209
+ ref: FunctionRef | MethodRef | ClassMethodRef | PropertyRef | InitRef,
210
+ decl: FunctionDecl | ConstructorDecl,
211
+ ) -> None:
212
+ match ref:
213
+ case FunctionRef(name):
214
+ self._functions[name] = decl
215
+ case MethodRef(class_name, method_name):
216
+ self._classes[class_name].methods[method_name] = decl
217
+ case ClassMethodRef(class_name, name):
218
+ self._classes[class_name].class_methods[name] = decl
219
+ case PropertyRef(class_name, property_name):
220
+ self._classes[class_name].properties[property_name] = decl
221
+ case InitRef(class_name):
222
+ self._classes[class_name].init = decl
223
+ case _:
224
+ assert_never(ref)
225
+
226
+ def has_method(self, class_name: str, method_name: str) -> bool | None:
227
+ """
228
+ Returns whether the given class has the given method, or None if we cant find the class.
229
+ """
230
+ if class_name in self._classes:
231
+ return method_name in self._classes[class_name].methods
232
+ return None
233
+
234
+ def get_class_decl(self, name: str) -> ClassDecl:
235
+ return self._classes[name]
236
+
237
+ def get_paramaterized_class(self, name: str) -> TypeRefWithVars:
238
+ """
239
+ Returns a class reference with type parameters, if the class is paramaterized.
240
+ """
241
+ type_vars = self._classes[name].type_vars
242
+ return TypeRefWithVars(name, type_vars)
243
+
244
+
245
@dataclass
class ClassDecl:
    """
    Everything declared on a single class.

    Mutable: `Declarations.set_function_decl` fills in methods, class methods, properties and the
    init after the class is registered.
    """

    # Name of the class on the egglog side, if one was given (None otherwise).
    egg_name: str | None = None
    # Type variables of the class; used as the type arguments of its parameterized reference.
    type_vars: tuple[ClassTypeVarRef, ...] = ()
    builtin: bool = False
    # Declaration resolved for an `InitRef` to this class; None when the class has no init.
    init: ConstructorDecl | FunctionDecl | None = None
    class_methods: dict[str, FunctionDecl | ConstructorDecl] = field(default_factory=dict)
    # Kept apart from class_methods so that printing them can be done easily.
    class_variables: dict[str, ConstantDecl] = field(default_factory=dict)
    methods: dict[str, FunctionDecl | ConstructorDecl] = field(default_factory=dict)
    properties: dict[str, FunctionDecl | ConstructorDecl] = field(default_factory=dict)
    # Plain Python callables stored by name (presumably kept as-is rather than translated — confirm).
    preserved_methods: dict[str, Callable] = field(default_factory=dict)
257
+
258
+
259
@dataclass(frozen=True)
class RulesetDecl:
    """A ruleset's contents: its rewrites and rules, in order."""

    rules: list[RewriteOrRuleDecl]

    # The list field rules out the generated dataclass hash, so hash a tuple snapshot instead. This
    # lets rulesets be tracked in sets/dicts when pretty-printing, to skip already-rendered ones.
    def __hash__(self) -> int:
        rules_snapshot = tuple(self.rules)
        return hash((type(self), rules_snapshot))


@dataclass(frozen=True)
class CombinedRulesetDecl:
    """A ruleset made up of other rulesets, referenced by their names."""

    rulesets: tuple[str, ...]
272
+
273
+
274
# There are two kinds of type references: `TypeRefWithVars`, whose arguments may recursively
# contain type variables, and `JustTypeRef`, which cannot. The variable-carrying kind is only used
# for classmethods and methods; this one is used for egg references as well as runtime values.
@dataclass(frozen=True)
class JustTypeRef:
    """A concrete type reference: a type name plus fully resolved type arguments."""

    name: str
    args: tuple[JustTypeRef, ...] = ()

    def to_var(self) -> TypeRefWithVars:
        """Return the equivalent `TypeRefWithVars` (one that happens to contain no variables)."""
        converted = tuple(arg.to_var() for arg in self.args)
        return TypeRefWithVars(self.name, converted)

    def __str__(self) -> str:
        """Render as ``Name`` or ``Name[Arg1, Arg2]``."""
        if not self.args:
            return self.name
        rendered_args = ", ".join(map(str, self.args))
        return f"{self.name}[{rendered_args}]"
289
+
290
+
291
##
# Type references with vars
##

# Maps each resolved type-variable reference (name + module) back to the runtime `TypeVar` it came
# from. TypeVars compare by identity rather than value, so the very same instance must be handed
# back when a reference is converted into a `TypeVar` again.
_RESOLVED_TYPEVARS: dict[ClassTypeVarRef, TypeVar] = {}


@dataclass(frozen=True)
class ClassTypeVarRef:
    """
    A class type variable represents one of the types of the class, if it is a generic class.
    """

    name: str
    module: str

    def to_just(self) -> JustTypeRef:
        """Always raises: type variables cannot be turned into concrete types yet."""
        msg = "egglog does not support generic classes yet."
        raise NotImplementedError(msg)

    def __str__(self) -> str:
        """Render as ``module.name``."""
        return ".".join((self.module, self.name))

    @classmethod
    def from_type_var(cls, typevar: TypeVar) -> ClassTypeVarRef:
        """Build a reference to ``typevar`` and remember the instance for `to_type_var`."""
        ref = cls(name=typevar.__name__, module=typevar.__module__)
        _RESOLVED_TYPEVARS[ref] = typevar
        return ref

    def to_type_var(self) -> TypeVar:
        """Return the `TypeVar` previously registered via `from_type_var` (KeyError if none)."""
        return _RESOLVED_TYPEVARS[self]
325
+
326
+
327
@dataclass(frozen=True)
class TypeRefWithVars:
    """A type reference whose arguments may be concrete types or class type variables."""

    name: str
    args: tuple[TypeOrVarRef, ...] = ()

    def to_just(self) -> JustTypeRef:
        """Convert to a `JustTypeRef`, converting each argument in turn."""
        concrete = tuple(arg.to_just() for arg in self.args)
        return JustTypeRef(self.name, concrete)

    def __str__(self) -> str:
        """Render as ``Name`` or ``Name[Arg1, Arg2]``."""
        if not self.args:
            return self.name
        rendered_args = ", ".join(map(str, self.args))
        return f"{self.name}[{rendered_args}]"
339
+
340
+
341
# A type reference that is either a class type variable or a (possibly parameterized) type.
TypeOrVarRef: TypeAlias = ClassTypeVarRef | TypeRefWithVars
342
+
343
##
# Callables References
##


@dataclass(frozen=True)
class UnnamedFunctionRef:
    """
    A reference to a function that doesn't have a name, but does have a body.
    """

    # The argument variables (each a typed `VarDecl`) and the typed body expression.
    args: tuple[TypedExprDecl, ...]
    res: TypedExprDecl

    @property
    def signature(self) -> FunctionSignature:
        """
        Derive the signature from the argument variables and the body's type.

        Every argument must be a `VarDecl` (asserted); none of them has a default.
        """
        types: list[TypeOrVarRef] = []
        names: list[str] = []
        for arg in self.args:
            types.append(arg.tp.to_var())
            variable = arg.expr
            assert isinstance(variable, VarDecl)
            names.append(variable.name)
        return FunctionSignature(
            arg_types=tuple(types),
            arg_names=tuple(names),
            arg_defaults=tuple(None for _ in self.args),
            return_type=self.res.tp.to_var(),
        )

    @property
    def egg_name(self) -> None | str:
        """Unnamed functions never have an egglog name."""
        return None
376
+
377
+
378
+ @dataclass(frozen=True)
379
+ class FunctionRef:
380
+ name: str
381
+
382
+
383
+ @dataclass(frozen=True)
384
+ class ConstantRef:
385
+ name: str
386
+
387
+
388
+ @dataclass(frozen=True)
389
+ class MethodRef:
390
+ class_name: str
391
+ method_name: str
392
+
393
+
394
+ @dataclass(frozen=True)
395
+ class ClassMethodRef:
396
+ class_name: str
397
+ method_name: str
398
+
399
+
400
+ @dataclass(frozen=True)
401
+ class InitRef:
402
+ class_name: str
403
+
404
+
405
+ @dataclass(frozen=True)
406
+ class ClassVariableRef:
407
+ class_name: str
408
+ var_name: str
409
+
410
+
411
+ @dataclass(frozen=True)
412
+ class PropertyRef:
413
+ class_name: str
414
+ property_name: str
415
+
416
+
417
# Every kind of reference that can name something callable.
CallableRef: TypeAlias = (
    FunctionRef
    | ConstantRef
    | MethodRef
    | ClassMethodRef
    | InitRef
    | ClassVariableRef
    | PropertyRef
    | UnnamedFunctionRef
)
427
+
428
+
429
##
# Callables
##


@dataclass(frozen=True)
class RelationDecl:
    """Declaration of a relation: typed arguments and no meaningful result."""

    arg_types: tuple[JustTypeRef, ...]
    # One default per argument; None for any argument which doesn't have one.
    arg_defaults: tuple[ExprDecl | None, ...]
    egg_name: str | None

    @property
    def signature(self) -> FunctionSignature:
        """Describe the relation as a function with arguments ``__0``, ``__1``, … returning ``Unit``."""
        arity = len(self.arg_types)
        return FunctionSignature(
            arg_types=tuple(tp.to_var() for tp in self.arg_types),
            arg_names=tuple(f"__{index}" for index in range(arity)),
            arg_defaults=self.arg_defaults,
            return_type=TypeRefWithVars("Unit"),
        )
449
+
450
+
451
@dataclass(frozen=True)
class ConstantDecl:
    """
    Same as `(declare)` in egglog
    """

    type_ref: JustTypeRef
    egg_name: str | None = None

    @property
    def signature(self) -> FunctionSignature:
        """A signature with no arguments whose return type is the constant's type."""
        as_var = self.type_ref.to_var()
        return FunctionSignature(return_type=as_var)
463
+
464
+
465
+ # special cases for partial function creation and application, which cannot use the normal python rules
466
+ SpecialFunctions: TypeAlias = Literal["fn-partial", "fn-app"]
467
+
468
+
469
+ @dataclass(frozen=True)
470
+ class FunctionSignature:
471
+ arg_types: tuple[TypeOrVarRef, ...] = ()
472
+ arg_names: tuple[str, ...] = ()
473
+ # List of defaults. None for any arg which doesn't have one.
474
+ arg_defaults: tuple[ExprDecl | None, ...] = ()
475
+ # If None, then the first arg is mutated and returned
476
+ return_type: TypeOrVarRef | None = None
477
+ var_arg_type: TypeOrVarRef | None = None
478
+ # Whether to reverse args when emitting to egglog
479
+ reverse_args: bool = False
480
+
481
+ @property
482
+ def semantic_return_type(self) -> TypeOrVarRef:
483
+ """
484
+ The type that is returned by the function, which wil be in the first arg if it mutates it.
485
+ """
486
+ return self.return_type or self.arg_types[0]
487
+
488
+ @property
489
+ def mutates(self) -> bool:
490
+ return self.return_type is None
491
+
492
+
493
@dataclass(frozen=True)
class FunctionDecl:
    """Declaration of a function: either a full signature or one of the special functions."""

    signature: FunctionSignature | SpecialFunctions = field(default_factory=FunctionSignature)
    builtin: bool = False
    egg_name: str | None = None
    # Optional merge expression (presumably egglog's `:merge`) — confirm against the emitter.
    merge: ExprDecl | None = None


@dataclass(frozen=True)
class ConstructorDecl:
    """Declaration of a constructor, with optional cost and extraction settings."""

    signature: FunctionSignature = field(default_factory=FunctionSignature)
    egg_name: str | None = None
    cost: int | None = None
    unextractable: bool = False


# Every declaration that a `CallableRef` can resolve to.
CallableDecl: TypeAlias = RelationDecl | ConstantDecl | FunctionDecl | ConstructorDecl
510
+
511
##
# Expressions
##


@dataclass(frozen=True)
class VarDecl:
    """A variable expression."""

    name: str
    # Separates let-bound vars from vars created in rules, so they get a prefix in egglog and
    # cannot shadow each other.
    is_let: bool
521
+
522
+
523
@dataclass(frozen=True)
class PyObjectDecl:
    """
    An expression wrapping an arbitrary Python object.

    Hashable values compare and hash by ``(type, value)``. Unhashable values hash by ``id`` and
    therefore also compare by identity, so equal objects always have equal hashes.
    """

    value: object

    def __hash__(self) -> int:
        """Tries using the hash of the value, if unhashable use the ID."""
        try:
            return hash(self.parts)
        except TypeError:
            return id(self.value)

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, PyObjectDecl):
            return False
        if self.value is other.value:
            return True
        # Unhashable values are hashed by identity, so comparing them by value would let two
        # "equal" objects have different hashes (and value comparison of such objects need not
        # even produce a bool). Treat distinct unhashable objects as unequal.
        if not (self._is_hashable() and other._is_hashable()):
            return False
        return self.parts == other.parts

    def _is_hashable(self) -> bool:
        """Whether ``parts`` can be hashed, i.e. whether `__hash__` uses the value rather than the id."""
        try:
            hash(self.parts)
        except TypeError:
            return False
        return True

    @property
    def parts(self) -> tuple[type, object]:
        """The ``(type, value)`` pair used for value-based hashing and comparison."""
        return (type(self.value), self.value)
542
+
543
+
544
+ LitType: TypeAlias = int | str | float | bool | None
545
+
546
+
547
+ @dataclass(frozen=True)
548
+ class LitDecl:
549
+ value: LitType
550
+
551
+ def __hash__(self) -> int:
552
+ """
553
+ Include type in has so that 1.0 != 1
554
+ """
555
+ return hash(self.parts)
556
+
557
+ def __eq__(self, other: object) -> bool:
558
+ if not isinstance(other, LitDecl):
559
+ return False
560
+ return self.parts == other.parts
561
+
562
+ @property
563
+ def parts(self) -> tuple[type, LitType]:
564
+ return (type(self.value), self.value)
565
+
566
+
567
@dataclass(frozen=True)
class CallDecl:
    """
    A call of a callable on typed argument expressions.

    Instances are pooled: building a `CallDecl` with the same ``(callable, args, bound_tp_params)``
    as a still-alive instance returns that instance, which is why equality is plain identity.
    """

    callable: CallableRef
    # TODO: Can I make these not typed expressions?
    args: tuple[TypedExprDecl, ...] = ()
    # type parameters that were bound to the callable, if it is a classmethod
    # Used for pretty printing classmethod calls with type parameters
    bound_tp_params: tuple[JustTypeRef, ...] | None = None

    # pool objects for faster __eq__; weak values so unused calls can still be garbage collected
    _args_to_value: ClassVar[WeakValueDictionary[tuple[object, ...], CallDecl]] = WeakValueDictionary({})

    def __new__(cls, *args: object, **kwargs: object) -> Self:
        """
        Pool CallDecls so that they can be compared by identity more quickly.

        Necessary bc we search for common parents when serializing CallDecl trees to egglog to
        only serialize each sub-tree once.
        """
        # normalize the args/kwargs to a tuple so that they can be compared
        callable = args[0] if args else kwargs["callable"]
        args_ = args[1] if len(args) > 1 else kwargs.get("args", ())
        bound_tp_params = args[2] if len(args) > 2 else kwargs.get("bound_tp_params")

        normalized_args = (callable, args_, bound_tp_params)
        try:
            return cast("Self", cls._args_to_value[normalized_args])
        except KeyError:
            res = super().__new__(cls)
            cls._args_to_value[normalized_args] = res
            return res

    def __reduce__(self) -> tuple[type[CallDecl], tuple[object, ...]]:
        """
        Rebuild through the constructor when copying or pickling.

        The default reduction calls ``cls.__new__(cls)`` with no arguments, which fails in
        `__new__` above (``callable`` is required) and would in any case bypass the pool,
        producing an instance never equal to anything. Going through the constructor returns the
        pooled instance instead.
        """
        return (type(self), (self.callable, self.args, self.bound_tp_params))

    def __post_init__(self) -> None:
        if self.bound_tp_params and not isinstance(self.callable, ClassMethodRef | InitRef):
            msg = "Cannot bind type parameters to a non-class method callable."
            raise ValueError(msg)

    def __hash__(self) -> int:
        return self._cached_hash

    @cached_property
    def _cached_hash(self) -> int:
        # Computed once per (pooled) instance; the fields never change after construction.
        return hash((self.callable, self.args, self.bound_tp_params))

    def __eq__(self, other: object) -> bool:
        # Pooling guarantees equal calls are the same object.
        return self is other

    def __ne__(self, other: object) -> bool:
        return self is not other
616
+
617
+
618
@dataclass(frozen=True)
class PartialCallDecl:
    """
    A partially applied function, aka a function sort.

    It may have no args at all, in which case it is just a function pointer.

    Kept separate from `CallDecl` so it's clear it is translated to an `unstable-fn` call.
    """

    call: CallDecl
629
+
630
+
631
# Every kind of expression node.
ExprDecl: TypeAlias = VarDecl | LitDecl | CallDecl | PyObjectDecl | PartialCallDecl


@dataclass(frozen=True)
class TypedExprDecl:
    """An expression together with its concrete type."""

    tp: JustTypeRef
    expr: ExprDecl

    def descendants(self) -> list[TypedExprDecl]:
        """
        Return this expression followed by all of its descendants, in depth-first pre-order.

        Only the arguments of `CallDecl` nodes are descended into. A sub-expression reachable
        along several paths appears once per path. Uses an explicit stack rather than recursion,
        so very deep expressions do not hit Python's recursion limit.
        """
        result: list[TypedExprDecl] = []
        stack: list[TypedExprDecl] = [self]
        while stack:
            node = stack.pop()
            result.append(node)
            if isinstance(node.expr, CallDecl):
                # Push children reversed so the leftmost argument is visited first, giving the
                # same order as a left-to-right recursive traversal.
                stack.extend(reversed(node.expr.args))
        return result
648
+
649
+
650
def replace_typed_expr(typed_expr: TypedExprDecl, replacements: Mapping[TypedExprDecl, TypedExprDecl]) -> TypedExprDecl:
    """
    Replace all the typed expressions in the given typed expression with the replacements.

    A node found in ``replacements`` is swapped out wholesale, without looking inside the
    replacement. The arguments of `CallDecl` and `PartialCallDecl` nodes are rewritten
    recursively, and any other expression is returned unchanged. Results are memoized per input
    node, so shared sub-expressions are only processed once.
    """
    memo: dict[TypedExprDecl, TypedExprDecl] = {}

    def _rebuild(node: TypedExprDecl) -> TypedExprDecl:
        # Rewrite the arguments of call-like nodes; everything else is a leaf.
        expr = node.expr
        if isinstance(expr, CallDecl):
            call, is_partial = expr, False
        elif isinstance(expr, PartialCallDecl) and isinstance(expr.call, CallDecl):
            call, is_partial = expr.call, True
        else:
            return node
        rewritten_args = tuple(_rewrite(arg) for arg in call.args)
        new_call = CallDecl(call.callable, rewritten_args, call.bound_tp_params)
        return TypedExprDecl(node.tp, PartialCallDecl(new_call) if is_partial else new_call)

    def _rewrite(node: TypedExprDecl) -> TypedExprDecl:
        if node not in memo:
            memo[node] = replacements[node] if node in replacements else _rebuild(node)
        return memo[node]

    return _rewrite(typed_expr)
679
+
680
+
681
+ ##
682
+ # Schedules
683
+ ##
684
+
685
+
686
+ @dataclass(frozen=True)
687
+ class SaturateDecl:
688
+ schedule: ScheduleDecl
689
+
690
+
691
+ @dataclass(frozen=True)
692
+ class RepeatDecl:
693
+ schedule: ScheduleDecl
694
+ times: int
695
+
696
+
697
+ @dataclass(frozen=True)
698
+ class SequenceDecl:
699
+ schedules: tuple[ScheduleDecl, ...]
700
+
701
+
702
+ @dataclass(frozen=True)
703
+ class RunDecl:
704
+ ruleset: str
705
+ until: tuple[FactDecl, ...] | None
706
+
707
+
708
+ ScheduleDecl: TypeAlias = SaturateDecl | RepeatDecl | SequenceDecl | RunDecl
709
+
710
+ ##
711
+ # Facts
712
+ ##
713
+
714
+
715
+ @dataclass(frozen=True)
716
+ class EqDecl:
717
+ tp: JustTypeRef
718
+ left: ExprDecl
719
+ right: ExprDecl
720
+
721
+
722
+ @dataclass(frozen=True)
723
+ class ExprFactDecl:
724
+ typed_expr: TypedExprDecl
725
+
726
+
727
+ FactDecl: TypeAlias = EqDecl | ExprFactDecl
728
+
729
+ ##
730
+ # Actions
731
+ ##
732
+
733
+
734
+ @dataclass(frozen=True)
735
+ class LetDecl:
736
+ name: str
737
+ typed_expr: TypedExprDecl
738
+
739
+
740
+ @dataclass(frozen=True)
741
+ class SetDecl:
742
+ tp: JustTypeRef
743
+ call: CallDecl
744
+ rhs: ExprDecl
745
+
746
+
747
+ @dataclass(frozen=True)
748
+ class ExprActionDecl:
749
+ typed_expr: TypedExprDecl
750
+
751
+
752
+ @dataclass(frozen=True)
753
+ class ChangeDecl:
754
+ tp: JustTypeRef
755
+ call: CallDecl
756
+ change: Literal["delete", "subsume"]
757
+
758
+
759
+ @dataclass(frozen=True)
760
+ class UnionDecl:
761
+ tp: JustTypeRef
762
+ lhs: ExprDecl
763
+ rhs: ExprDecl
764
+
765
+
766
+ @dataclass(frozen=True)
767
+ class PanicDecl:
768
+ msg: str
769
+
770
+
771
+ ActionDecl: TypeAlias = LetDecl | SetDecl | ExprActionDecl | ChangeDecl | UnionDecl | PanicDecl
772
+
773
+
774
+ ##
775
+ # Commands
776
+ ##
777
+
778
+
779
+ @dataclass(frozen=True)
780
+ class RewriteDecl:
781
+ tp: JustTypeRef
782
+ lhs: ExprDecl
783
+ rhs: ExprDecl
784
+ conditions: tuple[FactDecl, ...]
785
+ subsume: bool
786
+
787
+
788
+ @dataclass(frozen=True)
789
+ class BiRewriteDecl:
790
+ tp: JustTypeRef
791
+ lhs: ExprDecl
792
+ rhs: ExprDecl
793
+ conditions: tuple[FactDecl, ...]
794
+
795
+
796
+ @dataclass(frozen=True)
797
+ class RuleDecl:
798
+ head: tuple[ActionDecl, ...]
799
+ body: tuple[FactDecl, ...]
800
+ name: str | None
801
+
802
+
803
+ @dataclass(frozen=True)
804
+ class DefaultRewriteDecl:
805
+ ref: CallableRef
806
+ expr: ExprDecl
807
+ subsume: bool
808
+
809
+
810
+ RewriteOrRuleDecl: TypeAlias = RewriteDecl | BiRewriteDecl | RuleDecl | DefaultRewriteDecl
811
+
812
+
813
+ @dataclass(frozen=True)
814
+ class ActionCommandDecl:
815
+ action: ActionDecl
816
+
817
+
818
+ CommandDecl: TypeAlias = RewriteOrRuleDecl | ActionCommandDecl