egglog 11.2.0__cp314-cp314-manylinux_2_17_x86_64.manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of egglog might be problematic. Click here for more details.

Files changed (46)
  1. egglog/__init__.py +13 -0
  2. egglog/bindings.cpython-314-x86_64-linux-gnu.so +0 -0
  3. egglog/bindings.pyi +734 -0
  4. egglog/builtins.py +1133 -0
  5. egglog/config.py +8 -0
  6. egglog/conversion.py +286 -0
  7. egglog/declarations.py +912 -0
  8. egglog/deconstruct.py +173 -0
  9. egglog/egraph.py +1875 -0
  10. egglog/egraph_state.py +680 -0
  11. egglog/examples/README.rst +5 -0
  12. egglog/examples/__init__.py +3 -0
  13. egglog/examples/bignum.py +32 -0
  14. egglog/examples/bool.py +38 -0
  15. egglog/examples/eqsat_basic.py +44 -0
  16. egglog/examples/fib.py +28 -0
  17. egglog/examples/higher_order_functions.py +42 -0
  18. egglog/examples/jointree.py +67 -0
  19. egglog/examples/lambda_.py +287 -0
  20. egglog/examples/matrix.py +175 -0
  21. egglog/examples/multiset.py +60 -0
  22. egglog/examples/ndarrays.py +144 -0
  23. egglog/examples/resolution.py +84 -0
  24. egglog/examples/schedule_demo.py +34 -0
  25. egglog/exp/__init__.py +3 -0
  26. egglog/exp/array_api.py +2019 -0
  27. egglog/exp/array_api_jit.py +51 -0
  28. egglog/exp/array_api_loopnest.py +74 -0
  29. egglog/exp/array_api_numba.py +69 -0
  30. egglog/exp/array_api_program_gen.py +510 -0
  31. egglog/exp/program_gen.py +425 -0
  32. egglog/exp/siu_examples.py +32 -0
  33. egglog/ipython_magic.py +41 -0
  34. egglog/pretty.py +509 -0
  35. egglog/py.typed +0 -0
  36. egglog/runtime.py +712 -0
  37. egglog/thunk.py +97 -0
  38. egglog/type_constraint_solver.py +113 -0
  39. egglog/version_compat.py +87 -0
  40. egglog/visualizer.css +1 -0
  41. egglog/visualizer.js +35777 -0
  42. egglog/visualizer_widget.py +39 -0
  43. egglog-11.2.0.dist-info/METADATA +74 -0
  44. egglog-11.2.0.dist-info/RECORD +46 -0
  45. egglog-11.2.0.dist-info/WHEEL +4 -0
  46. egglog-11.2.0.dist-info/licenses/LICENSE +21 -0
egglog/declarations.py ADDED
@@ -0,0 +1,912 @@
1
+ """
2
+ Data only descriptions of the components of an egraph and the expressions.
3
+
4
+ We separate it into two pieces, the references and the declarations, so that we can report mutually recursive types.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ from dataclasses import dataclass, field
10
+ from functools import cached_property
11
+ from typing import TYPE_CHECKING, ClassVar, Literal, Protocol, TypeAlias, TypeVar, Union, cast, runtime_checkable
12
+ from weakref import WeakValueDictionary
13
+
14
+ from typing_extensions import Self, assert_never
15
+
16
+ if TYPE_CHECKING:
17
+ from collections.abc import Callable, Iterable, Mapping
18
+
19
+
20
# Public API of this module. Each name must appear exactly once.
__all__ = [
    "ActionCommandDecl",
    "ActionDecl",
    "BiRewriteDecl",
    "CallDecl",
    "CallableDecl",
    "CallableRef",
    "ChangeDecl",
    "ClassDecl",
    "ClassMethodRef",
    "ClassTypeVarRef",
    "ClassVariableRef",
    "CombinedRulesetDecl",
    "CommandDecl",
    "ConstantDecl",
    "ConstantRef",
    "ConstructorDecl",
    "Declarations",
    "DeclerationsLike",
    "DefaultRewriteDecl",
    "DelayedDeclerations",
    "EqDecl",
    "ExprActionDecl",
    "ExprDecl",
    "ExprFactDecl",
    "FactDecl",
    "FunctionDecl",
    "FunctionRef",
    "FunctionSignature",
    "HasDeclerations",
    "InitRef",
    "JustTypeRef",
    "LetDecl",
    "LetRefDecl",
    "LitDecl",
    "LitType",
    "MethodRef",
    "PanicDecl",
    "PartialCallDecl",
    "PropertyRef",
    "PyObjectDecl",
    "RelationDecl",
    "RepeatDecl",
    "RewriteDecl",
    "RewriteOrRuleDecl",
    "RuleDecl",
    "RulesetDecl",
    "RunDecl",
    "SaturateDecl",
    "ScheduleDecl",
    "SequenceDecl",
    "SetCostDecl",
    "SetDecl",
    "SpecialFunctions",
    "TypeOrVarRef",
    "TypeRefWithVars",
    "TypeVarError",
    "TypedExprDecl",
    "UnboundVarDecl",
    "UnionDecl",
    "UnnamedFunctionRef",
    "collect_unbound_vars",
    "replace_typed_expr",
    "upcast_declerations",
]
86
+
87
+
88
+ @dataclass(match_args=False)
89
+ class DelayedDeclerations:
90
+ __egg_decls_thunk__: Callable[[], Declarations] = field(repr=False)
91
+
92
+ @property
93
+ def __egg_decls__(self) -> Declarations:
94
+ thunk = self.__egg_decls_thunk__
95
+ try:
96
+ return thunk()
97
+ # Catch attribute error, so that it isn't bubbled up as a missing attribute and fallbacks on `__getattr__`
98
+ # instead raise explicitly
99
+ except AttributeError as err:
100
+ msg = f"Cannot resolve declarations for {self}: {err}"
101
+ raise RuntimeError(msg) from err
102
+
103
+
104
+ @runtime_checkable
105
+ class HasDeclerations(Protocol):
106
+ @property
107
+ def __egg_decls__(self) -> Declarations: ...
108
+
109
+
110
+ DeclerationsLike: TypeAlias = Union[HasDeclerations, None, "Declarations"]
111
+
112
+
113
+ def upcast_declerations(declerations_like: Iterable[DeclerationsLike]) -> list[Declarations]:
114
+ d = []
115
+ for l in declerations_like:
116
+ if l is None:
117
+ continue
118
+ if isinstance(l, HasDeclerations):
119
+ d.append(l.__egg_decls__)
120
+ elif isinstance(l, Declarations):
121
+ d.append(l)
122
+ else:
123
+ assert_never(l)
124
+ return d
125
+
126
+
127
+ @dataclass
128
+ class Declarations:
129
+ _unnamed_functions: set[UnnamedFunctionRef] = field(default_factory=set)
130
+ _functions: dict[str, FunctionDecl | RelationDecl | ConstructorDecl] = field(default_factory=dict)
131
+ _constants: dict[str, ConstantDecl] = field(default_factory=dict)
132
+ _classes: dict[str, ClassDecl] = field(default_factory=dict)
133
+ _rulesets: dict[str, RulesetDecl | CombinedRulesetDecl] = field(default_factory=lambda: {"": RulesetDecl([])})
134
+
135
+ @property
136
+ def default_ruleset(self) -> RulesetDecl:
137
+ ruleset = self._rulesets[""]
138
+ assert isinstance(ruleset, RulesetDecl)
139
+ return ruleset
140
+
141
+ @classmethod
142
+ def create(cls, *others: DeclerationsLike) -> Declarations:
143
+ others = upcast_declerations(others)
144
+ if not others:
145
+ return Declarations()
146
+ first, *rest = others
147
+ if not rest:
148
+ return first
149
+ new = first.copy()
150
+ new.update(*rest)
151
+ return new
152
+
153
+ def copy(self) -> Declarations:
154
+ new = Declarations()
155
+ self.update_other(new)
156
+ return new
157
+
158
+ def update(self, *others: DeclerationsLike) -> None:
159
+ for other in others:
160
+ self |= other
161
+
162
+ def __or__(self, other: DeclerationsLike) -> Declarations:
163
+ result = self.copy()
164
+ result |= other
165
+ return result
166
+
167
+ def __ior__(self, other: DeclerationsLike) -> Self:
168
+ if other is None:
169
+ return self
170
+ if isinstance(other, HasDeclerations):
171
+ other = other.__egg_decls__
172
+ other.update_other(self)
173
+ return self
174
+
175
+ def update_other(self, other: Declarations) -> None:
176
+ """
177
+ Updates the other decl with these values in palce.
178
+ """
179
+ other._functions |= self._functions
180
+ other._classes |= self._classes
181
+ other._constants |= self._constants
182
+ # Must combine rulesets bc the empty ruleset might be different, bc DefaultRewriteDecl
183
+ # is added to functions.
184
+ combined_default_rules: set[RewriteOrRuleDecl] = {*self.default_ruleset.rules, *other.default_ruleset.rules}
185
+ other._rulesets |= self._rulesets
186
+ other._rulesets[""] = RulesetDecl(list(combined_default_rules))
187
+
188
+ def get_callable_decl(self, ref: CallableRef) -> CallableDecl: # noqa: PLR0911
189
+ match ref:
190
+ case FunctionRef(name):
191
+ return self._functions[name]
192
+ case ConstantRef(name):
193
+ return self._constants[name]
194
+ case MethodRef(class_name, method_name):
195
+ return self._classes[class_name].methods[method_name]
196
+ case ClassVariableRef(class_name, name):
197
+ return self._classes[class_name].class_variables[name]
198
+ case ClassMethodRef(class_name, name):
199
+ return self._classes[class_name].class_methods[name]
200
+ case PropertyRef(class_name, property_name):
201
+ return self._classes[class_name].properties[property_name]
202
+ case InitRef(class_name):
203
+ init_fn = self._classes[class_name].init
204
+ assert init_fn, f"Class {class_name} does not have an init function."
205
+ return init_fn
206
+ case UnnamedFunctionRef():
207
+ return ConstructorDecl(ref.signature)
208
+
209
+ assert_never(ref)
210
+
211
+ def set_function_decl(
212
+ self,
213
+ ref: FunctionRef | MethodRef | ClassMethodRef | PropertyRef | InitRef,
214
+ decl: FunctionDecl | ConstructorDecl,
215
+ ) -> None:
216
+ match ref:
217
+ case FunctionRef(name):
218
+ self._functions[name] = decl
219
+ case MethodRef(class_name, method_name):
220
+ self._classes[class_name].methods[method_name] = decl
221
+ case ClassMethodRef(class_name, name):
222
+ self._classes[class_name].class_methods[name] = decl
223
+ case PropertyRef(class_name, property_name):
224
+ self._classes[class_name].properties[property_name] = decl
225
+ case InitRef(class_name):
226
+ self._classes[class_name].init = decl
227
+ case _:
228
+ assert_never(ref)
229
+
230
+ def check_binary_method_with_types(self, method_name: str, self_type: JustTypeRef, other_type: JustTypeRef) -> bool:
231
+ """
232
+ Checks if the class has a binary method compatible with the given types.
233
+ """
234
+ vars: dict[ClassTypeVarRef, JustTypeRef] = {}
235
+ if callable_decl := self._classes[self_type.name].methods.get(method_name):
236
+ match callable_decl.signature:
237
+ case FunctionSignature((self_arg_type, other_arg_type)) if self_arg_type.matches_just(
238
+ vars, self_type
239
+ ) and other_arg_type.matches_just(vars, other_type):
240
+ return True
241
+ return False
242
+
243
+ def check_binary_method_with_self_type(self, method_name: str, self_type: JustTypeRef) -> JustTypeRef | None:
244
+ """
245
+ Checks if the class has a binary method with the given name and self type. Returns the other type if it exists.
246
+ """
247
+ vars: dict[ClassTypeVarRef, JustTypeRef] = {}
248
+ class_decl = self._classes.get(self_type.name)
249
+ if class_decl is None:
250
+ return None
251
+ if callable_decl := class_decl.methods.get(method_name):
252
+ match callable_decl.signature:
253
+ case FunctionSignature((self_arg_type, other_arg_type)) if self_arg_type.matches_just(vars, self_type):
254
+ return other_arg_type.to_just(vars)
255
+ return None
256
+
257
+ def check_binary_method_with_other_type(self, method_name: str, other_type: JustTypeRef) -> Iterable[JustTypeRef]:
258
+ """
259
+ Returns the types which are compatible with the given binary method name and other type.
260
+ """
261
+ for class_decl in self._classes.values():
262
+ vars: dict[ClassTypeVarRef, JustTypeRef] = {}
263
+ if callable_decl := class_decl.methods.get(method_name):
264
+ match callable_decl.signature:
265
+ case FunctionSignature((self_arg_type, other_arg_type)) if other_arg_type.matches_just(
266
+ vars, other_type
267
+ ):
268
+ yield self_arg_type.to_just(vars)
269
+
270
+ def get_class_decl(self, name: str) -> ClassDecl:
271
+ return self._classes[name]
272
+
273
+ def get_paramaterized_class(self, name: str) -> TypeRefWithVars:
274
+ """
275
+ Returns a class reference with type parameters, if the class is paramaterized.
276
+ """
277
+ type_vars = self._classes[name].type_vars
278
+ return TypeRefWithVars(name, type_vars)
279
+
280
+
281
@dataclass
class ClassDecl:
    """
    Declaration of a class and all of its members.

    Members are grouped by kind, and each group is a dict keyed by member name.
    """

    egg_name: str | None = None
    # Type parameters of a generic class; empty when the class is not generic.
    type_vars: tuple[ClassTypeVarRef, ...] = ()
    builtin: bool = False
    init: ConstructorDecl | FunctionDecl | None = None
    class_methods: dict[str, FunctionDecl | ConstructorDecl] = field(default_factory=dict)
    # Kept separate from `class_methods` so that printing them can be done easily.
    class_variables: dict[str, ConstantDecl] = field(default_factory=dict)
    methods: dict[str, FunctionDecl | ConstructorDecl] = field(default_factory=dict)
    properties: dict[str, FunctionDecl | ConstructorDecl] = field(default_factory=dict)
    preserved_methods: dict[str, Callable] = field(default_factory=dict)
    match_args: tuple[str, ...] = ()
294
+
295
+
296
@dataclass(frozen=True)
class RulesetDecl:
    """A collection of rules, stored as a (mutable) list."""

    rules: list[RewriteOrRuleDecl]

    def __hash__(self) -> int:
        # The dataclass-generated hash would fail on the `list` field, so hash an immutable
        # snapshot instead. Being hashable lets the pretty-printer remember which rulesets it has
        # already turned into strings.
        snapshot = tuple(self.rules)
        return hash((type(self), snapshot))
304
+
305
+
306
@dataclass(frozen=True)
class CombinedRulesetDecl:
    """A ruleset composed of other rulesets, referenced by name."""

    # Names of the combined rulesets (keys of `Declarations._rulesets`).
    rulesets: tuple[str, ...]
309
+
310
+
311
# There are two kinds of type references: one that may contain type variables recursively
# (`TypeRefWithVars`) and one that cannot (`JustTypeRef`). The variable-carrying kind is only used
# for classmethods and methods; the concrete kind is used for egg references and runtime values.
@dataclass(frozen=True)
class JustTypeRef:
    """A fully concrete type reference: a class name plus concrete type arguments."""

    name: str
    args: tuple[JustTypeRef, ...] = ()

    def to_var(self) -> TypeRefWithVars:
        """Lift this reference into a `TypeRefWithVars` (one that happens to contain no variables)."""
        lifted_args = tuple(arg.to_var() for arg in self.args)
        return TypeRefWithVars(self.name, lifted_args)

    def __str__(self) -> str:
        if not self.args:
            return self.name
        rendered = ", ".join(map(str, self.args))
        return f"{self.name}[{rendered}]"
326
+
327
+
328
##
# Type references with vars
##

# Mapping from resolved type-variable refs (name + module) back to the runtime TypeVar objects.
# When converting back, the very same TypeVar instance is handed out, because TypeVar equality is
# based on identity, not value.
_RESOLVED_TYPEVARS: dict[ClassTypeVarRef, TypeVar] = {}


class TypeVarError(RuntimeError):
    """Error when trying to resolve a type variable that doesn't exist."""


@dataclass(frozen=True)
class ClassTypeVarRef:
    """
    A class type variable represents one of the types of the class, if it is a generic class.
    """

    name: str
    module: str

    def to_just(self, vars: dict[ClassTypeVarRef, JustTypeRef] | None = None) -> JustTypeRef:
        """
        Resolve this variable to the concrete type bound to it in ``vars``.

        Raises:
            TypeVarError: if ``vars`` is ``None`` or has no binding for this variable.
        """
        if vars is None or self not in vars:
            raise TypeVarError(f"Cannot convert type variable {self} to concrete type without variable bindings")
        return vars[self]

    def __str__(self) -> str:
        # Prefer the original TypeVar's rendering. Refs that were never registered through
        # `from_type_var` fall back to their bare name. Previously this raised KeyError, which also
        # broke formatting the error message in `to_just` and masked the intended TypeVarError.
        typevar = _RESOLVED_TYPEVARS.get(self)
        if typevar is None:
            return self.name
        return str(typevar)

    @classmethod
    def from_type_var(cls, typevar: TypeVar) -> ClassTypeVarRef:
        """Create a ref for ``typevar`` and remember the TypeVar so `to_type_var` can return it."""
        res = cls(typevar.__name__, typevar.__module__)
        _RESOLVED_TYPEVARS[res] = typevar
        return res

    def to_type_var(self) -> TypeVar:
        """
        Return the TypeVar this ref was created from.

        Raises:
            KeyError: if the ref was not created via `from_type_var`.
        """
        return _RESOLVED_TYPEVARS[self]

    def matches_just(self, vars: dict[ClassTypeVarRef, JustTypeRef], other: JustTypeRef) -> bool:
        """
        Checks if this type variable matches the given JustTypeRef, including type variables.

        The first match binds the variable by mutating ``vars``. Later matches must agree with
        that binding.
        """
        if self in vars:
            return vars[self] == other
        vars[self] = other
        return True
376
+
377
+
378
+ @dataclass(frozen=True)
379
+ class TypeRefWithVars:
380
+ name: str
381
+ args: tuple[TypeOrVarRef, ...] = ()
382
+
383
+ def to_just(self, vars: dict[ClassTypeVarRef, JustTypeRef] | None = None) -> JustTypeRef:
384
+ return JustTypeRef(self.name, tuple(a.to_just(vars) for a in self.args))
385
+
386
+ def __str__(self) -> str:
387
+ if self.args:
388
+ return f"{self.name}[{', '.join(str(a) for a in self.args)}]"
389
+ return self.name
390
+
391
+ def matches_just(self, vars: dict[ClassTypeVarRef, JustTypeRef], other: JustTypeRef) -> bool:
392
+ """
393
+ Checks if this type reference matches the given JustTypeRef, including type variables.
394
+ """
395
+ return (
396
+ self.name == other.name
397
+ and len(self.args) == len(other.args)
398
+ and all(a.matches_just(vars, b) for a, b in zip(self.args, other.args, strict=True))
399
+ )
400
+
401
+
402
# A type as it appears in a signature: either a bare class type variable, or a type reference
# whose arguments may themselves contain type variables.
TypeOrVarRef: TypeAlias = ClassTypeVarRef | TypeRefWithVars
403
+
404
##
# Callables References
##


@dataclass(frozen=True)
class UnnamedFunctionRef:
    """
    A reference to a function that doesn't have a name, but does have a body.
    """

    # Parameters: each one is an `UnboundVarDecl` expression paired with its type.
    args: tuple[TypedExprDecl, ...]
    # The typed body expression.
    res: TypedExprDecl

    @property
    def signature(self) -> FunctionSignature:
        """
        Build the function's signature from its parameters and body.

        Each parameter must be an `UnboundVarDecl` (asserted); its name becomes the argument name.
        No argument has a default.
        """
        names: list[str] = []
        for typed_arg in self.args:
            assert isinstance(typed_arg.expr, UnboundVarDecl)
            names.append(typed_arg.expr.name)
        return FunctionSignature(
            arg_types=tuple(typed_arg.tp.to_var() for typed_arg in self.args),
            arg_names=tuple(names),
            arg_defaults=tuple(None for _ in self.args),
            return_type=self.res.tp.to_var(),
        )

    @property
    def egg_name(self) -> None | str:
        """Unnamed functions never have an egg name."""
        return None
437
+
438
+
439
@dataclass(frozen=True)
class FunctionRef:
    """Reference to a top-level function, relation or constructor, looked up by name in `Declarations._functions`."""

    name: str


@dataclass(frozen=True)
class ConstantRef:
    """Reference to a constant, looked up by name in `Declarations._constants`."""

    name: str


@dataclass(frozen=True)
class MethodRef:
    """Reference to the method ``method_name`` of class ``class_name``."""

    class_name: str
    method_name: str


@dataclass(frozen=True)
class ClassMethodRef:
    """Reference to the class method ``method_name`` of class ``class_name``."""

    class_name: str
    method_name: str


@dataclass(frozen=True)
class InitRef:
    """Reference to the ``init`` declaration of class ``class_name``."""

    class_name: str


@dataclass(frozen=True)
class ClassVariableRef:
    """Reference to the class variable ``var_name`` of class ``class_name``."""

    class_name: str
    var_name: str


@dataclass(frozen=True)
class PropertyRef:
    """Reference to the property ``property_name`` of class ``class_name``."""

    class_name: str
    property_name: str
476
+
477
+
478
# Every kind of reference that `Declarations.get_callable_decl` can resolve to a declaration.
CallableRef: TypeAlias = (
    FunctionRef
    | ConstantRef
    | MethodRef
    | ClassMethodRef
    | InitRef
    | ClassVariableRef
    | PropertyRef
    | UnnamedFunctionRef
)
488
+
489
+
490
##
# Callables
##


@dataclass(frozen=True)
class RelationDecl:
    """A relation: a callable over ``arg_types`` whose return type is ``Unit``."""

    arg_types: tuple[JustTypeRef, ...]
    # List of defaults. None for any arg which doesn't have one.
    arg_defaults: tuple[ExprDecl | None, ...]
    egg_name: str | None

    @property
    def signature(self) -> FunctionSignature:
        """Signature with positional argument names ``__0``, ``__1``, ... and a ``Unit`` return type."""
        var_types = tuple(arg_type.to_var() for arg_type in self.arg_types)
        positional_names = tuple(map("__{}".format, range(len(self.arg_types))))
        return FunctionSignature(
            arg_types=var_types,
            arg_names=positional_names,
            arg_defaults=self.arg_defaults,
            return_type=TypeRefWithVars("Unit"),
        )
510
+
511
+
512
@dataclass(frozen=True)
class ConstantDecl:
    """
    A constant of type ``type_ref``; same as `(declare)` in egglog.
    """

    type_ref: JustTypeRef
    egg_name: str | None = None

    @property
    def signature(self) -> FunctionSignature:
        """A signature with no arguments whose return type is ``type_ref``."""
        result_type = self.type_ref.to_var()
        return FunctionSignature(return_type=result_type)
524
+
525
+
526
# special cases for partial function creation and application, which cannot use the normal python rules.
# A `FunctionDecl.signature` may be one of these literals instead of a `FunctionSignature`.
SpecialFunctions: TypeAlias = Literal["fn-partial", "fn-app"]
528
+
529
+
530
@dataclass(frozen=True)
class FunctionSignature:
    """Argument names, types and defaults, plus the return type, of a callable."""

    arg_types: tuple[TypeOrVarRef, ...] = ()
    arg_names: tuple[str, ...] = ()
    # List of defaults. None for any arg which doesn't have one.
    arg_defaults: tuple[ExprDecl | None, ...] = ()
    # If None, then the first arg is mutated and returned
    return_type: TypeOrVarRef | None = None
    # NOTE(review): presumably the type of extra variadic arguments, if any — confirm.
    var_arg_type: TypeOrVarRef | None = None
    # Whether to reverse args when emitting to egglog
    reverse_args: bool = False

    @property
    def semantic_return_type(self) -> TypeOrVarRef:
        """
        The type returned by the function. For a mutating function this is the type of the first arg.
        """
        if self.return_type:
            return self.return_type
        return self.arg_types[0]

    @property
    def mutates(self) -> bool:
        """True when there is no return type, i.e. the first argument is mutated and returned."""
        has_no_return_type = self.return_type is None
        return has_no_return_type
552
+
553
+
554
@dataclass(frozen=True)
class FunctionDecl:
    """
    Declaration of a function.

    ``signature`` is either a regular `FunctionSignature` or one of the `SpecialFunctions` literals.
    """

    signature: FunctionSignature | SpecialFunctions = field(default_factory=FunctionSignature)
    # NOTE(review): presumably marks functions provided by egglog itself rather than user-defined — confirm.
    builtin: bool = False
    # Optional explicit egglog name; None when not given.
    egg_name: str | None = None
    # Optional merge expression. NOTE(review): presumably used when two values for the same inputs meet — confirm.
    merge: ExprDecl | None = None


@dataclass(frozen=True)
class ConstructorDecl:
    """Declaration of a constructor. `Declarations.get_callable_decl` also synthesizes one for `UnnamedFunctionRef`s."""

    signature: FunctionSignature = field(default_factory=FunctionSignature)
    # Optional explicit egglog name; None when not given.
    egg_name: str | None = None
    # Optional cost; None when not specified.
    cost: int | None = None
    unextractable: bool = False


# Every kind of declaration a `CallableRef` can resolve to.
CallableDecl: TypeAlias = RelationDecl | ConstantDecl | FunctionDecl | ConstructorDecl
571
+
572
##
# Expressions
##


@dataclass(frozen=True)
class UnboundVarDecl:
    """A variable expression identified by ``name``. Parameters of an `UnnamedFunctionRef` are expressed with these."""

    name: str
    # NOTE(review): presumably an alternate name used when emitting to egglog — confirm.
    egg_name: str | None = None


@dataclass(frozen=True)
class LetRefDecl:
    """A reference, by name, to a let-bound value (NOTE(review): presumably one bound by `LetDecl`)."""

    name: str
586
+
587
+
588
@dataclass(frozen=True)
class PyObjectDecl:
    """
    An expression wrapping an arbitrary Python object.

    Hashable values are hashed and compared by ``(type(value), value)``. Unhashable values use
    identity for both hashing and equality, so objects that compare equal always hash equal.
    """

    value: object

    def __hash__(self) -> int:
        """Tries using the hash of the value, if unhashable use the ID."""
        try:
            return hash(self.parts)
        except TypeError:
            return id(self.value)

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, PyObjectDecl):
            return False
        if self.value is other.value:
            return True
        if not (self._is_hashable() and other._is_hashable()):
            # `__hash__` falls back to `id()` for unhashable values, so comparing them by value
            # would give equal objects different hashes. It would also call `==` on objects whose
            # `==` is not a plain bool (e.g. NumPy arrays), which can raise.
            return False
        return self.parts == other.parts

    def _is_hashable(self) -> bool:
        """Whether ``(type(value), value)`` can be hashed."""
        try:
            hash(self.parts)
        except TypeError:
            return False
        return True

    @property
    def parts(self) -> tuple[type, object]:
        return (type(self.value), self.value)
607
+
608
+
609
+ LitType: TypeAlias = int | str | float | bool | None
610
+
611
+
612
+ @dataclass(frozen=True)
613
+ class LitDecl:
614
+ value: LitType
615
+
616
+ def __hash__(self) -> int:
617
+ """
618
+ Include type in has so that 1.0 != 1
619
+ """
620
+ return hash(self.parts)
621
+
622
+ def __eq__(self, other: object) -> bool:
623
+ if not isinstance(other, LitDecl):
624
+ return False
625
+ return self.parts == other.parts
626
+
627
+ @property
628
+ def parts(self) -> tuple[type, LitType]:
629
+ return (type(self.value), self.value)
630
+
631
+
632
@dataclass(frozen=True)
class CallDecl:
    """
    A call of a `CallableRef` on typed argument expressions.

    Instances are pooled. `__new__` returns an already-existing live instance when one was created
    with the same ``(callable, args, bound_tp_params)``, which is why `__eq__`/`__ne__` can be
    plain identity checks.
    """

    callable: CallableRef
    # TODO: Can I make these not typed expressions?
    args: tuple[TypedExprDecl, ...] = ()
    # type parameters that were bound to the callable, if it is a classmethod
    # Used for pretty printing classmethod calls with type parameters
    bound_tp_params: tuple[JustTypeRef, ...] | None = None

    # pool objects for faster __eq__
    # Values are held weakly, so an entry disappears once nothing else references that CallDecl.
    _args_to_value: ClassVar[WeakValueDictionary[tuple[object, ...], CallDecl]] = WeakValueDictionary({})

    def __new__(cls, *args: object, **kwargs: object) -> Self:
        """
        Pool CallDecls so that they can be compared by identity more quickly.

        Necessary because we search for common parents when serializing CallDecl trees to egglog to
        only serialize each sub-tree once.

        Because the returned object is an instance of ``cls``, Python still runs the dataclass
        ``__init__`` (and ``__post_init__``) on it afterwards, even for a pooled instance. That
        re-assigns field values equal to the ones it already has.
        """
        # normalize the args/kwargs to a tuple so that they can be compared
        # (mirrors the dataclass field order and defaults: callable, args=(), bound_tp_params=None)
        callable = args[0] if args else kwargs["callable"]
        args_ = args[1] if len(args) > 1 else kwargs.get("args", ())
        bound_tp_params = args[2] if len(args) > 2 else kwargs.get("bound_tp_params")

        # Used as a dict key, so ``args`` is expected to be a hashable tuple (as annotated).
        normalized_args = (callable, args_, bound_tp_params)
        try:
            return cast("Self", cls._args_to_value[normalized_args])
        except KeyError:
            res = super().__new__(cls)
            cls._args_to_value[normalized_args] = res
            return res

    def __post_init__(self) -> None:
        """
        Raises:
            ValueError: if non-empty ``bound_tp_params`` are given for a callable that is not a
                `ClassMethodRef` or `InitRef`.
        """
        if self.bound_tp_params and not isinstance(self.callable, ClassMethodRef | InitRef):
            msg = "Cannot bind type parameters to a non-class method callable."
            raise ValueError(msg)

    def __hash__(self) -> int:
        # Hashing walks the whole argument tree, so the result is computed once and cached.
        return self._cached_hash

    @cached_property
    def _cached_hash(self) -> int:
        # `cached_property` writes to the instance `__dict__` directly, so it works despite `frozen=True`.
        return hash((self.callable, self.args, self.bound_tp_params))

    def __eq__(self, other: object) -> bool:
        # Identity is sufficient because equal CallDecls are pooled into the same object.
        return self is other

    def __ne__(self, other: object) -> bool:
        return self is not other
681
+
682
+
683
@dataclass(frozen=True)
class PartialCallDecl:
    """
    A partially applied function aka a function sort.

    Note it does not need to have any args, in which case it's just a function pointer.

    Separated from the call decl so it's clear it is translated to a `unstable-fn` call.
    """

    # The underlying call; its ``args`` are the arguments applied so far.
    call: CallDecl


# Every kind of expression node.
ExprDecl: TypeAlias = UnboundVarDecl | LetRefDecl | LitDecl | CallDecl | PyObjectDecl | PartialCallDecl
697
+
698
+
699
@dataclass(frozen=True)
class TypedExprDecl:
    """An expression paired with its concrete type."""

    tp: JustTypeRef
    expr: ExprDecl

    def descendants(self) -> list[TypedExprDecl]:
        """
        Returns a list of all the descendants of this expression.

        The list starts with this expression itself, then walks the arguments of `CallDecl` nodes
        depth-first, left to right. A shared sub-expression appears once per occurrence.
        """
        # NOTE(review): arguments inside a `PartialCallDecl` are not visited here, unlike in
        # `replace_typed_expr` / `collect_unbound_vars` — confirm this is intended.
        ordered: list[TypedExprDecl] = []
        pending: list[TypedExprDecl] = [self]
        while pending:
            node = pending.pop()
            ordered.append(node)
            if isinstance(node.expr, CallDecl):
                # Push in reverse so the leftmost argument is visited next (pre-order).
                pending.extend(reversed(node.expr.args))
        return ordered
713
+
714
+
715
def replace_typed_expr(typed_expr: TypedExprDecl, replacements: Mapping[TypedExprDecl, TypedExprDecl]) -> TypedExprDecl:
    """
    Replace all the typed expressions in the given typed expression with the replacements.

    A sub-expression found in ``replacements`` is swapped out wholesale, and its children are not
    visited. Otherwise the arguments of `CallDecl` and `PartialCallDecl` nodes are rewritten
    recursively, and any other expression is returned unchanged. Results are memoized per
    sub-expression, so shared sub-trees are processed only once.
    """
    memo: dict[TypedExprDecl, TypedExprDecl] = {}

    def rebuild(node: TypedExprDecl) -> TypedExprDecl:
        # Rewrite the arguments of a (possibly partial) call; leave other expressions alone.
        expr = node.expr
        if isinstance(expr, CallDecl):
            call, is_partial = expr, False
        elif isinstance(expr, PartialCallDecl) and isinstance(expr.call, CallDecl):
            call, is_partial = expr.call, True
        else:
            return node
        rewritten_args = tuple(rewrite(arg) for arg in call.args)
        new_call = CallDecl(call.callable, rewritten_args, call.bound_tp_params)
        return TypedExprDecl(node.tp, PartialCallDecl(new_call) if is_partial else new_call)

    def rewrite(node: TypedExprDecl) -> TypedExprDecl:
        try:
            return memo[node]
        except KeyError:
            pass
        result = replacements[node] if node in replacements else rebuild(node)
        memo[node] = result
        return result

    return rewrite(typed_expr)
744
+
745
+
746
def collect_unbound_vars(typed_expr: TypedExprDecl) -> set[TypedExprDecl]:
    """
    Returns the set of all unbound vars.

    Walks ``typed_expr``, descending into the arguments of `CallDecl` and `PartialCallDecl` nodes,
    and returns every typed sub-expression whose expression is an `UnboundVarDecl`. Each distinct
    sub-expression is visited at most once.
    """
    found: set[TypedExprDecl] = set()
    visited: set[TypedExprDecl] = set()
    worklist = [typed_expr]
    while worklist:
        current = worklist.pop()
        if current in visited:
            continue
        visited.add(current)
        expr = current.expr
        if isinstance(expr, PartialCallDecl) and isinstance(expr.call, CallDecl):
            # A partial call contributes the arguments of its wrapped call.
            expr = expr.call
        if isinstance(expr, CallDecl):
            worklist.extend(expr.args)
        elif isinstance(expr, UnboundVarDecl):
            found.add(current)
    return found
766
+
767
+
768
##
# Schedules
##


@dataclass(frozen=True)
class SaturateDecl:
    """Schedule wrapping ``schedule`` (NOTE(review): presumably runs it until saturation — confirm)."""

    schedule: ScheduleDecl


@dataclass(frozen=True)
class RepeatDecl:
    """Schedule wrapping ``schedule`` with a repetition count ``times``."""

    schedule: ScheduleDecl
    times: int


@dataclass(frozen=True)
class SequenceDecl:
    """Schedule composed of a tuple of sub-schedules."""

    schedules: tuple[ScheduleDecl, ...]


@dataclass(frozen=True)
class RunDecl:
    """Schedule step that runs the ruleset named ``ruleset``."""

    ruleset: str
    # Optional facts. NOTE(review): presumably a stopping condition for the run — confirm.
    until: tuple[FactDecl, ...] | None


# Every kind of schedule node.
ScheduleDecl: TypeAlias = SaturateDecl | RepeatDecl | SequenceDecl | RunDecl
796
+
797
##
# Facts
##


@dataclass(frozen=True)
class EqDecl:
    """Fact relating two expressions ``left`` and ``right`` of type ``tp`` (presumably asserting they are equal)."""

    tp: JustTypeRef
    left: ExprDecl
    right: ExprDecl


@dataclass(frozen=True)
class ExprFactDecl:
    """Fact consisting of a single typed expression."""

    typed_expr: TypedExprDecl


# Every kind of fact; used in rule bodies, rewrite conditions and `RunDecl.until`.
FactDecl: TypeAlias = EqDecl | ExprFactDecl
815
+
816
##
# Actions
##


@dataclass(frozen=True)
class LetDecl:
    """Action binding ``name`` to ``typed_expr``."""

    name: str
    typed_expr: TypedExprDecl


@dataclass(frozen=True)
class SetDecl:
    """Action setting the value of ``call`` (of type ``tp``) to ``rhs`` (NOTE(review): presumably egglog's `set`)."""

    tp: JustTypeRef
    call: CallDecl
    rhs: ExprDecl


@dataclass(frozen=True)
class ExprActionDecl:
    """Action consisting of a bare typed expression."""

    typed_expr: TypedExprDecl


@dataclass(frozen=True)
class ChangeDecl:
    """Action applying ``change`` (``"delete"`` or ``"subsume"``) to ``call`` of type ``tp``."""

    tp: JustTypeRef
    call: CallDecl
    change: Literal["delete", "subsume"]


@dataclass(frozen=True)
class UnionDecl:
    """Action relating ``lhs`` and ``rhs`` of type ``tp`` (NOTE(review): presumably egglog's `union`)."""

    tp: JustTypeRef
    lhs: ExprDecl
    rhs: ExprDecl


@dataclass(frozen=True)
class PanicDecl:
    """Action carrying the panic message ``msg``."""

    msg: str


@dataclass(frozen=True)
class SetCostDecl:
    """Action associating a ``cost`` expression with the call ``expr`` of type ``tp``."""

    tp: JustTypeRef
    expr: CallDecl
    cost: ExprDecl


# Every kind of action; used in rule heads (`RuleDecl.head`) and `ActionCommandDecl`.
ActionDecl: TypeAlias = LetDecl | SetDecl | ExprActionDecl | ChangeDecl | UnionDecl | PanicDecl | SetCostDecl
866
+
867
+
868
##
# Commands
##


@dataclass(frozen=True)
class RewriteDecl:
    """Rewrite from ``lhs`` to ``rhs`` (both of type ``tp``) under ``conditions``, with a ``subsume`` flag."""

    tp: JustTypeRef
    lhs: ExprDecl
    rhs: ExprDecl
    conditions: tuple[FactDecl, ...]
    subsume: bool


@dataclass(frozen=True)
class BiRewriteDecl:
    """Bidirectional rewrite between ``lhs`` and ``rhs`` (both of type ``tp``) under ``conditions``."""

    tp: JustTypeRef
    lhs: ExprDecl
    rhs: ExprDecl
    conditions: tuple[FactDecl, ...]


@dataclass(frozen=True)
class RuleDecl:
    """Rule with facts in ``body`` and actions in ``head``, optionally named."""

    head: tuple[ActionDecl, ...]
    body: tuple[FactDecl, ...]
    name: str | None


@dataclass(frozen=True)
class DefaultRewriteDecl:
    """Default rewrite for the callable ``ref`` to ``expr``, with a ``subsume`` flag; stored in the default ruleset."""

    ref: CallableRef
    expr: ExprDecl
    subsume: bool


# Anything that can be stored in a `RulesetDecl`.
RewriteOrRuleDecl: TypeAlias = RewriteDecl | BiRewriteDecl | RuleDecl | DefaultRewriteDecl


@dataclass(frozen=True)
class ActionCommandDecl:
    """Command that performs a single action."""

    action: ActionDecl


# Every kind of command.
CommandDecl: TypeAlias = RewriteOrRuleDecl | ActionCommandDecl