egglog 12.0.0__cp313-cp313t-manylinux_2_17_ppc64.manylinux2014_ppc64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48) hide show
  1. egglog/__init__.py +13 -0
  2. egglog/bindings.cpython-313t-powerpc64-linux-gnu.so +0 -0
  3. egglog/bindings.pyi +887 -0
  4. egglog/builtins.py +1144 -0
  5. egglog/config.py +8 -0
  6. egglog/conversion.py +290 -0
  7. egglog/declarations.py +964 -0
  8. egglog/deconstruct.py +176 -0
  9. egglog/egraph.py +2247 -0
  10. egglog/egraph_state.py +978 -0
  11. egglog/examples/README.rst +5 -0
  12. egglog/examples/__init__.py +3 -0
  13. egglog/examples/bignum.py +32 -0
  14. egglog/examples/bool.py +38 -0
  15. egglog/examples/eqsat_basic.py +44 -0
  16. egglog/examples/fib.py +28 -0
  17. egglog/examples/higher_order_functions.py +42 -0
  18. egglog/examples/jointree.py +64 -0
  19. egglog/examples/lambda_.py +287 -0
  20. egglog/examples/matrix.py +175 -0
  21. egglog/examples/multiset.py +60 -0
  22. egglog/examples/ndarrays.py +144 -0
  23. egglog/examples/resolution.py +84 -0
  24. egglog/examples/schedule_demo.py +34 -0
  25. egglog/exp/MoA.ipynb +617 -0
  26. egglog/exp/__init__.py +3 -0
  27. egglog/exp/any_expr.py +947 -0
  28. egglog/exp/any_expr_example.ipynb +408 -0
  29. egglog/exp/array_api.py +2019 -0
  30. egglog/exp/array_api_jit.py +51 -0
  31. egglog/exp/array_api_loopnest.py +74 -0
  32. egglog/exp/array_api_numba.py +69 -0
  33. egglog/exp/array_api_program_gen.py +510 -0
  34. egglog/exp/program_gen.py +427 -0
  35. egglog/exp/siu_examples.py +32 -0
  36. egglog/ipython_magic.py +41 -0
  37. egglog/pretty.py +566 -0
  38. egglog/py.typed +0 -0
  39. egglog/runtime.py +888 -0
  40. egglog/thunk.py +97 -0
  41. egglog/type_constraint_solver.py +111 -0
  42. egglog/visualizer.css +1 -0
  43. egglog/visualizer.js +35798 -0
  44. egglog/visualizer_widget.py +39 -0
  45. egglog-12.0.0.dist-info/METADATA +93 -0
  46. egglog-12.0.0.dist-info/RECORD +48 -0
  47. egglog-12.0.0.dist-info/WHEEL +5 -0
  48. egglog-12.0.0.dist-info/licenses/LICENSE +21 -0
egglog/declarations.py ADDED
@@ -0,0 +1,964 @@
1
+ """
2
+ Data only descriptions of the components of an egraph and the expressions.
3
+
4
+ We seperate it it into two pieces, the references the declerations, so that we can report mutually recursive types.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ from dataclasses import dataclass, field
10
+ from functools import cached_property
11
+ from typing import (
12
+ TYPE_CHECKING,
13
+ ClassVar,
14
+ Literal,
15
+ Protocol,
16
+ Self,
17
+ TypeAlias,
18
+ TypeVar,
19
+ Union,
20
+ assert_never,
21
+ cast,
22
+ runtime_checkable,
23
+ )
24
+ from uuid import UUID
25
+ from weakref import WeakValueDictionary
26
+
27
+ from .bindings import Value
28
+
29
+ if TYPE_CHECKING:
30
+ from collections.abc import Callable, Iterable, Mapping
31
+
32
+
33
# Public API of this module. Each name appears exactly once ("Declarations" was previously listed twice).
__all__ = [
    "ActionCommandDecl",
    "ActionDecl",
    "BackOffDecl",
    "BiRewriteDecl",
    "CallDecl",
    "CallableDecl",
    "CallableRef",
    "ChangeDecl",
    "ClassDecl",
    "ClassMethodRef",
    "ClassTypeVarRef",
    "ClassVariableRef",
    "CombinedRulesetDecl",
    "CommandDecl",
    "ConstantDecl",
    "ConstantRef",
    "ConstructorDecl",
    "Declarations",
    "DeclerationsLike",
    "DefaultRewriteDecl",
    "DelayedDeclerations",
    "EqDecl",
    "ExprActionDecl",
    "ExprDecl",
    "ExprFactDecl",
    "FactDecl",
    "FunctionDecl",
    "FunctionRef",
    "FunctionSignature",
    "GetCostDecl",
    "HasDeclerations",
    "Ident",
    "InitRef",
    "JustTypeRef",
    "LetDecl",
    "LetRefDecl",
    "LetSchedulerDecl",
    "LitDecl",
    "LitType",
    "MethodRef",
    "PanicDecl",
    "PartialCallDecl",
    "PropertyRef",
    "PyObjectDecl",
    "RelationDecl",
    "RepeatDecl",
    "RewriteDecl",
    "RewriteOrRuleDecl",
    "RuleDecl",
    "RulesetDecl",
    "RunDecl",
    "SaturateDecl",
    "ScheduleDecl",
    "SequenceDecl",
    "SetCostDecl",
    "SetDecl",
    "SpecialFunctions",
    "TypeOrVarRef",
    "TypeRefWithVars",
    "TypeVarError",
    "TypedExprDecl",
    "UnboundVarDecl",
    "UnionDecl",
    "UnnamedFunctionRef",
    "ValueDecl",
    "collect_unbound_vars",
    "replace_typed_expr",
    "upcast_declerations",
]
104
+
105
+
106
+ @dataclass(match_args=False)
107
+ class DelayedDeclerations:
108
+ __egg_decls_thunk__: Callable[[], Declarations] = field(repr=False)
109
+
110
+ @property
111
+ def __egg_decls__(self) -> Declarations:
112
+ thunk = self.__egg_decls_thunk__
113
+ try:
114
+ return thunk()
115
+ # Catch attribute error, so that it isn't bubbled up as a missing attribute and fallbacks on `__getattr__`
116
+ # instead raise explicitly
117
+ except AttributeError as err:
118
+ msg = f"Cannot resolve declarations for {self}: {err}"
119
+ raise RuntimeError(msg) from err
120
+
121
+
122
@runtime_checkable
class HasDeclerations(Protocol):
    """
    Structural type for any object that exposes its declarations through an ``__egg_decls__`` attribute.

    Being ``runtime_checkable``, ``isinstance(obj, HasDeclerations)`` only checks that the attribute exists.
    """

    @property
    def __egg_decls__(self) -> Declarations:
        """The declarations this object carries."""
126
+
127
+
128
# Anything `upcast_declerations` accepts: an object exposing `__egg_decls__`, a `Declarations` instance itself,
# or `None` (which is skipped). Spelled with `Union` because a runtime `|` cannot combine a class with the
# string forward reference "Declarations".
DeclerationsLike: TypeAlias = Union[HasDeclerations, None, "Declarations"]
129
+
130
+
131
+ def upcast_declerations(declerations_like: Iterable[DeclerationsLike]) -> list[Declarations]:
132
+ d = []
133
+ for l in declerations_like:
134
+ if l is None:
135
+ continue
136
+ if isinstance(l, HasDeclerations):
137
+ d.append(l.__egg_decls__)
138
+ elif isinstance(l, Declarations):
139
+ d.append(l)
140
+ else:
141
+ assert_never(l)
142
+ return d
143
+
144
+
145
@dataclass(frozen=True)
class Ident:
    """
    A name, optionally qualified by the module it was defined in.

    Renders as ``module.name``, or just ``name`` when ``module`` is ``None`` or empty.
    """

    name: str
    module: str | None = None

    def __str__(self) -> str:
        # An empty module string is treated the same as having no module.
        return f"{self.module}.{self.name}" if self.module else self.name

    @classmethod
    def builtin(cls, name: str) -> Ident:
        """Build an identifier for ``name`` in the ``egglog.builtins`` module."""
        return cls(name=name, module="egglog.builtins")
158
+
159
+
160
# Key under which every `Declarations` stores its default (unnamed) ruleset in `_rulesets`.
default_ruleset_identifier = Ident("")
161
+
162
+
163
+ @dataclass
164
+ class Declarations:
165
+ _unnamed_functions: set[UnnamedFunctionRef] = field(default_factory=set)
166
+ _functions: dict[Ident, FunctionDecl | RelationDecl | ConstructorDecl] = field(default_factory=dict)
167
+ _constants: dict[Ident, ConstantDecl] = field(default_factory=dict)
168
+ _classes: dict[Ident, ClassDecl] = field(default_factory=dict)
169
+ _rulesets: dict[Ident, RulesetDecl | CombinedRulesetDecl] = field(
170
+ default_factory=lambda: {default_ruleset_identifier: RulesetDecl([])}
171
+ )
172
+
173
+ @property
174
+ def default_ruleset(self) -> RulesetDecl:
175
+ ruleset = self._rulesets[default_ruleset_identifier]
176
+ assert isinstance(ruleset, RulesetDecl)
177
+ return ruleset
178
+
179
+ @classmethod
180
+ def create(cls, *others: DeclerationsLike) -> Declarations:
181
+ others = upcast_declerations(others)
182
+ if not others:
183
+ return Declarations()
184
+ first, *rest = others
185
+ if not rest:
186
+ return first
187
+ new = first.copy()
188
+ new.update(*rest)
189
+ return new
190
+
191
+ def copy(self) -> Declarations:
192
+ new = Declarations()
193
+ self.update_other(new)
194
+ return new
195
+
196
+ def update(self, *others: DeclerationsLike) -> None:
197
+ for other in others:
198
+ self |= other
199
+
200
+ def __or__(self, other: DeclerationsLike) -> Declarations:
201
+ result = self.copy()
202
+ result |= other
203
+ return result
204
+
205
+ def __ior__(self, other: DeclerationsLike) -> Self:
206
+ if other is None:
207
+ return self
208
+ if isinstance(other, HasDeclerations):
209
+ other = other.__egg_decls__
210
+ other.update_other(self)
211
+ return self
212
+
213
+ def update_other(self, other: Declarations) -> None:
214
+ """
215
+ Updates the other decl with these values in palce.
216
+ """
217
+ other._functions |= self._functions
218
+ other._classes |= self._classes
219
+ other._constants |= self._constants
220
+ # Must combine rulesets bc the empty ruleset might be different, bc DefaultRewriteDecl
221
+ # is added to functions.
222
+ combined_default_rules: set[RewriteOrRuleDecl] = {*self.default_ruleset.rules, *other.default_ruleset.rules}
223
+ other._rulesets |= self._rulesets
224
+ other._rulesets[default_ruleset_identifier] = RulesetDecl(list(combined_default_rules))
225
+
226
+ def get_callable_decl(self, ref: CallableRef) -> CallableDecl: # noqa: PLR0911
227
+ match ref:
228
+ case FunctionRef(name):
229
+ return self._functions[name]
230
+ case ConstantRef(name):
231
+ return self._constants[name]
232
+ case MethodRef(class_name, method_name):
233
+ return self._classes[class_name].methods[method_name]
234
+ case ClassVariableRef(class_name, name):
235
+ return self._classes[class_name].class_variables[name]
236
+ case ClassMethodRef(class_name, name):
237
+ return self._classes[class_name].class_methods[name]
238
+ case PropertyRef(class_name, property_name):
239
+ return self._classes[class_name].properties[property_name]
240
+ case InitRef(class_name):
241
+ init_fn = self._classes[class_name].init
242
+ assert init_fn, f"Class {class_name} does not have an init function."
243
+ return init_fn
244
+ case UnnamedFunctionRef():
245
+ return ConstructorDecl(ref.signature)
246
+
247
+ assert_never(ref)
248
+
249
+ def set_function_decl(
250
+ self,
251
+ ref: FunctionRef | MethodRef | ClassMethodRef | PropertyRef | InitRef,
252
+ decl: FunctionDecl | ConstructorDecl,
253
+ ) -> None:
254
+ match ref:
255
+ case FunctionRef(name):
256
+ self._functions[name] = decl
257
+ case MethodRef(class_name, method_name):
258
+ self._classes[class_name].methods[method_name] = decl
259
+ case ClassMethodRef(class_name, name):
260
+ self._classes[class_name].class_methods[name] = decl
261
+ case PropertyRef(class_name, property_name):
262
+ self._classes[class_name].properties[property_name] = decl
263
+ case InitRef(class_name):
264
+ self._classes[class_name].init = decl
265
+ case _:
266
+ assert_never(ref)
267
+
268
+ def check_binary_method_with_types(self, method_name: str, self_type: JustTypeRef, other_type: JustTypeRef) -> bool:
269
+ """
270
+ Checks if the class has a binary method compatible with the given types.
271
+ """
272
+ vars: dict[ClassTypeVarRef, JustTypeRef] = {}
273
+ if callable_decl := self._classes[self_type.ident].methods.get(method_name):
274
+ match callable_decl.signature:
275
+ case FunctionSignature((self_arg_type, other_arg_type)) if self_arg_type.matches_just(
276
+ vars, self_type
277
+ ) and other_arg_type.matches_just(vars, other_type):
278
+ return True
279
+ return False
280
+
281
+ def check_binary_method_with_self_type(self, method_name: str, self_type: JustTypeRef) -> JustTypeRef | None:
282
+ """
283
+ Checks if the class has a binary method with the given name and self type. Returns the other type if it exists.
284
+ """
285
+ vars: dict[ClassTypeVarRef, JustTypeRef] = {}
286
+ class_decl = self._classes.get(self_type.ident)
287
+ if class_decl is None:
288
+ return None
289
+ if callable_decl := class_decl.methods.get(method_name):
290
+ match callable_decl.signature:
291
+ case FunctionSignature((self_arg_type, other_arg_type)) if self_arg_type.matches_just(vars, self_type):
292
+ return other_arg_type.to_just(vars)
293
+ return None
294
+
295
+ def check_binary_method_with_other_type(self, method_name: str, other_type: JustTypeRef) -> Iterable[JustTypeRef]:
296
+ """
297
+ Returns the types which are compatible with the given binary method name and other type.
298
+ """
299
+ for class_decl in self._classes.values():
300
+ vars: dict[ClassTypeVarRef, JustTypeRef] = {}
301
+ if callable_decl := class_decl.methods.get(method_name):
302
+ match callable_decl.signature:
303
+ case FunctionSignature((self_arg_type, other_arg_type)) if other_arg_type.matches_just(
304
+ vars, other_type
305
+ ):
306
+ yield self_arg_type.to_just(vars)
307
+
308
+ def get_class_decl(self, ident: Ident) -> ClassDecl:
309
+ return self._classes[ident]
310
+
311
+ def get_paramaterized_class(self, ident: Ident) -> TypeRefWithVars:
312
+ """
313
+ Returns a class reference with type parameters, if the class is paramaterized.
314
+ """
315
+ type_vars = self._classes[ident].type_vars
316
+ return TypeRefWithVars(ident, type_vars)
317
+
318
+
319
@dataclass
class ClassDecl:
    """
    Declaration of a (possibly generic) class and the callables attached to it.

    The member dictionaries are keyed by the member name carried in the matching reference:
    ``MethodRef.method_name``, ``ClassMethodRef.method_name``, ``PropertyRef.property_name`` and
    ``ClassVariableRef.var_name``. An ``InitRef`` resolves to ``init``.
    """

    # NOTE(review): presumably the sort name used when emitting to egglog, with a default derived elsewhere when None — verify.
    egg_name: str | None = None
    # Generic type parameters, in order; `Declarations.get_paramaterized_class` uses them as the class's type args.
    type_vars: tuple[ClassTypeVarRef, ...] = ()
    # NOTE(review): presumably marks classes provided by egglog builtins rather than user code — verify.
    builtin: bool = False
    init: ConstructorDecl | FunctionDecl | None = None
    class_methods: dict[str, FunctionDecl | ConstructorDecl] = field(default_factory=dict)
    # These have to be separate from class_methods so that printing them can be done easily
    class_variables: dict[str, ConstantDecl] = field(default_factory=dict)
    methods: dict[str, FunctionDecl | ConstructorDecl] = field(default_factory=dict)
    properties: dict[str, FunctionDecl | ConstructorDecl] = field(default_factory=dict)
    # Plain Python callables stored by name; presumably kept as-is rather than translated to egglog — verify.
    preserved_methods: dict[str, Callable] = field(default_factory=dict)
    # NOTE(review): presumably the attribute names used for positional `match` patterns — verify.
    match_args: tuple[str, ...] = field(default=())
    doc: str | None = field(default=None)
333
+
334
+
335
@dataclass(frozen=True)
class RulesetDecl:
    """
    An ordered list of rewrites and rules.

    Frozen but holding a mutable list, so ``__hash__`` is written by hand. It hashes a tuple snapshot of the rules
    together with the class. This keeps rulesets usable as dict keys / set members, e.g. so that pretty-printing
    can remember which rulesets it has already turned into strings.
    """

    rules: list[RewriteOrRuleDecl]

    def __hash__(self) -> int:
        frozen_rules = tuple(self.rules)
        return hash((type(self), frozen_rules))
343
+
344
+
345
@dataclass(frozen=True)
class CombinedRulesetDecl:
    """
    A ruleset defined as the combination of other rulesets, referenced by their identifiers.

    Stored in ``Declarations._rulesets`` alongside plain ``RulesetDecl`` values.
    """

    rulesets: tuple[Ident, ...]
348
+
349
+
350
# Have two different types of type refs, one that can include vars recursively and one that cannot.
# We only use the one with vars for classmethods and methods, and the other one for egg references as
# well as runtime values.
@dataclass(frozen=True)
class JustTypeRef:
    """
    A fully concrete type: a class identifier plus concrete type arguments, with no type variables.
    """

    ident: Ident
    args: tuple[JustTypeRef, ...] = ()

    def to_var(self) -> TypeRefWithVars:
        """Convert to the equivalent ``TypeRefWithVars`` (one that simply contains no variables)."""
        lifted_args = tuple(arg.to_var() for arg in self.args)
        return TypeRefWithVars(self.ident, lifted_args)

    def __str__(self) -> str:
        # Only the bare name is shown (not the module), with arguments in square brackets when present.
        name = self.ident.name
        if not self.args:
            return str(name)
        joined = ", ".join(map(str, self.args))
        return f"{name}[{joined}]"
365
+
366
+
367
##
# Type references with vars
##

# Mapping from resolved type variables (each identified by the name and module of its TypeVar) to the runtime
# `TypeVar` objects they were created from. When spitting them back out again we can then use the same instance,
# since `TypeVar` equality is based on identity, not value.
# Written by `ClassTypeVarRef.from_type_var` and read by `ClassTypeVarRef.to_type_var`; entries are never removed.
_RESOLVED_TYPEVARS: dict[ClassTypeVarRef, TypeVar] = {}
375
+
376
+
377
class TypeVarError(RuntimeError):
    """
    Raised when a class type variable has to be turned into a concrete type but no binding for it is available.
    """
379
+
380
+
381
@dataclass(frozen=True)
class ClassTypeVarRef:
    """
    A class type variable represents one of the types of the class, if it is a generic class.

    It is identified by the name and module of the Python ``TypeVar`` it stands for. :meth:`from_type_var`
    remembers the original ``TypeVar`` object in ``_RESOLVED_TYPEVARS``.
    """

    ident: Ident

    def to_just(self, vars: dict[ClassTypeVarRef, JustTypeRef] | None = None) -> JustTypeRef:
        """
        Return the concrete type bound to this variable in ``vars``.

        :raises TypeVarError: if ``vars`` is ``None`` or has no binding for this variable.
        """
        if vars is None or self not in vars:
            msg = f"Cannot convert type variable {self} to concrete type without variable bindings"
            raise TypeVarError(msg)
        return vars[self]

    def __str__(self) -> str:
        # Use the registered TypeVar's own string form when available. Otherwise fall back to the ``~name`` form
        # an invariant TypeVar prints as. This covers refs not created through `from_type_var` in this process,
        # so formatting never raises (e.g. while building the `to_just` error message above).
        type_var = _RESOLVED_TYPEVARS.get(self)
        if type_var is None:
            return f"~{self.ident.name}"
        return str(type_var)

    @classmethod
    def from_type_var(cls, typevar: TypeVar) -> ClassTypeVarRef:
        """Create a ref for ``typevar`` and remember the ``TypeVar`` object so `to_type_var` can return it."""
        res = cls(Ident(typevar.__name__, typevar.__module__))
        _RESOLVED_TYPEVARS[res] = typevar
        return res

    def to_type_var(self) -> TypeVar:
        """
        Return the ``TypeVar`` registered for this ref by :meth:`from_type_var`.

        :raises KeyError: if no ``TypeVar`` has been registered for this ref.
        """
        return _RESOLVED_TYPEVARS[self]

    def matches_just(self, vars: dict[ClassTypeVarRef, JustTypeRef], other: JustTypeRef) -> bool:
        """
        Checks if this type variable matches the given JustTypeRef, including type variables.

        The first time the variable is seen it is bound to ``other`` (mutating ``vars``). Later occurrences
        must equal that binding.
        """
        if self in vars:
            return vars[self] == other
        vars[self] = other
        return True
414
+
415
+
416
+ @dataclass(frozen=True)
417
+ class TypeRefWithVars:
418
+ ident: Ident
419
+ args: tuple[TypeOrVarRef, ...] = ()
420
+
421
+ def to_just(self, vars: dict[ClassTypeVarRef, JustTypeRef] | None = None) -> JustTypeRef:
422
+ return JustTypeRef(self.ident, tuple(a.to_just(vars) for a in self.args))
423
+
424
+ def __str__(self) -> str:
425
+ if self.args:
426
+ return f"{self.ident.name}[{', '.join(str(a) for a in self.args)}]"
427
+ return str(self.ident.name)
428
+
429
+ def matches_just(self, vars: dict[ClassTypeVarRef, JustTypeRef], other: JustTypeRef) -> bool:
430
+ """
431
+ Checks if this type reference matches the given JustTypeRef, including type variables.
432
+ """
433
+ return (
434
+ self.ident == other.ident
435
+ and len(self.args) == len(other.args)
436
+ and all(a.matches_just(vars, b) for a, b in zip(self.args, other.args, strict=True))
437
+ )
438
+
439
+
440
# Either a bare class type variable, or a class reference whose arguments may themselves contain type variables.
TypeOrVarRef: TypeAlias = ClassTypeVarRef | TypeRefWithVars
441
+
442
+ ##
443
+ # Callables References
444
+ ##
445
+
446
+
447
@dataclass(frozen=True)
class UnnamedFunctionRef:
    """
    A reference to a function that doesn't have a name, but does have a body.
    """

    # The parameters: each is a typed expression whose `expr` must be an `UnboundVarDecl` naming the parameter.
    args: tuple[TypedExprDecl, ...]
    # The body expression, typed with the function's return type.
    res: TypedExprDecl

    @property
    def signature(self) -> FunctionSignature:
        """
        Derive a signature from the parameters and body: no defaults, return type taken from ``res``.

        :raises AssertionError: if a parameter's expression is not an ``UnboundVarDecl``.
        """
        param_types: list[TypeOrVarRef] = []
        param_names: list[str] = []
        for param in self.args:
            param_types.append(param.tp.to_var())
            var = param.expr
            assert isinstance(var, UnboundVarDecl)
            param_names.append(var.name)
        n_params = len(self.args)
        return FunctionSignature(
            arg_types=tuple(param_types),
            arg_names=tuple(param_names),
            arg_defaults=(None,) * n_params,
            return_type=self.res.tp.to_var(),
        )

    @property
    def egg_name(self) -> None | str:
        """Unnamed functions never have an egglog name."""
        return None
475
+
476
+
477
@dataclass(frozen=True)
class FunctionRef:
    """Reference to a top-level function, relation or constructor, keyed by ``ident`` in ``Declarations._functions``."""

    ident: Ident


@dataclass(frozen=True)
class ConstantRef:
    """Reference to a top-level constant, keyed by ``ident`` in ``Declarations._constants``."""

    ident: Ident


@dataclass(frozen=True)
class MethodRef:
    """Reference to the instance method ``method_name`` of the class ``ident``."""

    ident: Ident
    method_name: str


@dataclass(frozen=True)
class ClassMethodRef:
    """Reference to the class method ``method_name`` of the class ``ident``."""

    ident: Ident
    method_name: str


@dataclass(frozen=True)
class InitRef:
    """Reference to the ``init`` declaration of the class ``ident``."""

    ident: Ident


@dataclass(frozen=True)
class ClassVariableRef:
    """Reference to the class variable ``var_name`` of the class ``ident``."""

    ident: Ident
    var_name: str


@dataclass(frozen=True)
class PropertyRef:
    """Reference to the property ``property_name`` of the class ``ident``."""

    ident: Ident
    property_name: str


# Every kind of reference that `Declarations.get_callable_decl` resolves to a callable declaration.
CallableRef: TypeAlias = (
    FunctionRef
    | ConstantRef
    | MethodRef
    | ClassMethodRef
    | InitRef
    | ClassVariableRef
    | PropertyRef
    | UnnamedFunctionRef
)
526
+
527
+
528
+ ##
529
+ # Callables
530
+ ##
531
+
532
+
533
@dataclass(frozen=True)
class RelationDecl:
    """
    Declaration of a relation: a callable over ``arg_types`` whose signature returns the builtin ``Unit`` type.
    """

    arg_types: tuple[JustTypeRef, ...]
    # List of defaults. None for any arg which doesn't have one.
    arg_defaults: tuple[ExprDecl | None, ...]
    egg_name: str | None

    @property
    def signature(self) -> FunctionSignature:
        """View this relation as a signature with positional argument names ``__0``, ``__1``, ..."""
        var_types = tuple(tp.to_var() for tp in self.arg_types)
        names = tuple(f"__{idx}" for idx, _ in enumerate(self.arg_types))
        unit = TypeRefWithVars(Ident.builtin("Unit"))
        return FunctionSignature(
            arg_types=var_types,
            arg_names=names,
            arg_defaults=self.arg_defaults,
            return_type=unit,
        )
548
+
549
+
550
@dataclass(frozen=True)
class ConstantDecl:
    """
    Same as `(declare)` in egglog: a constant of type ``type_ref``.
    """

    type_ref: JustTypeRef
    egg_name: str | None = None

    @property
    def signature(self) -> FunctionSignature:
        """A signature with no arguments whose return type is ``type_ref``."""
        result_type = self.type_ref.to_var()
        return FunctionSignature(return_type=result_type)
562
+
563
+
564
# special cases for partial function creation and application, which cannot use the normal python rules
# (`FunctionDecl.signature` may be one of these literals instead of a full `FunctionSignature`).
SpecialFunctions: TypeAlias = Literal["fn-partial", "fn-app"]
566
+
567
+
568
@dataclass(frozen=True)
class FunctionSignature:
    """
    Signature of a callable, with argument and return types that may contain class type variables.
    """

    arg_types: tuple[TypeOrVarRef, ...] = ()
    arg_names: tuple[str, ...] = ()
    # List of defaults. None for any arg which doesn't have one.
    arg_defaults: tuple[ExprDecl | None, ...] = ()
    # If None, then the first arg is mutated and returned
    return_type: TypeOrVarRef | None = None
    var_arg_type: TypeOrVarRef | None = None
    # Whether to reverse args when emitting to egglog
    reverse_args: bool = False

    @property
    def mutates(self) -> bool:
        """True when there is no explicit return type, i.e. the first argument is mutated and returned."""
        return self.return_type is None

    @property
    def semantic_return_type(self) -> TypeOrVarRef:
        """
        The type a call evaluates to. This is ``return_type`` when set, otherwise the type of the (mutated) first argument.
        """
        explicit = self.return_type
        return explicit if explicit else self.arg_types[0]
590
+
591
+
592
@dataclass(frozen=True)
class FunctionDecl:
    """
    Declaration of a function.

    ``signature`` is either a full ``FunctionSignature`` or one of the ``SpecialFunctions`` literals.
    """

    signature: FunctionSignature | SpecialFunctions = field(default_factory=FunctionSignature)
    # NOTE(review): presumably marks functions provided by egglog builtins rather than user code — verify.
    builtin: bool = False
    egg_name: str | None = None
    # NOTE(review): presumably the merge expression used when a value is set more than once — verify.
    merge: ExprDecl | None = None
    doc: str | None = None


@dataclass(frozen=True)
class ConstructorDecl:
    """
    Declaration of a constructor.

    ``Declarations.get_callable_decl`` also builds one on the fly for an ``UnnamedFunctionRef``.
    """

    signature: FunctionSignature = field(default_factory=FunctionSignature)
    egg_name: str | None = None
    # NOTE(review): presumably the extraction cost, with None meaning egglog's default — verify.
    cost: int | None = None
    # NOTE(review): presumably excludes this constructor from extraction — verify.
    unextractable: bool = False
    doc: str | None = None


# Every declaration kind that `Declarations.get_callable_decl` can return.
CallableDecl: TypeAlias = RelationDecl | ConstantDecl | FunctionDecl | ConstructorDecl
611
+
612
+ ##
613
+ # Expressions
614
+ ##
615
+
616
+
617
@dataclass(frozen=True)
class UnboundVarDecl:
    """
    A named variable expression that is not bound to a value (e.g. a parameter of an ``UnnamedFunctionRef``).
    """

    name: str
    # NOTE(review): presumably an override for the name emitted to egglog — verify.
    egg_name: str | None = None


@dataclass(frozen=True)
class LetRefDecl:
    """A reference, by name, to a let-bound value (presumably one introduced by a ``LetDecl`` — verify)."""

    name: str


@dataclass(frozen=True)
class PyObjectDecl:
    """
    A Python object embedded in an expression, stored in pickled form.

    NOTE(review): unpickling can execute arbitrary code; these bytes must only ever come from trusted sources.
    """

    pickled: bytes


# Python values that can appear directly as a `LitDecl` literal.
LitType: TypeAlias = int | str | float | bool | None
634
+
635
+
636
@dataclass(frozen=True)
class LitDecl:
    """
    A literal value.

    Equality and hashing include the Python type of the value. As a result ``LitDecl(1) != LitDecl(1.0)`` and
    ``LitDecl(True) != LitDecl(1)``, even though the raw values compare equal.
    """

    value: LitType

    @property
    def parts(self) -> tuple[type, LitType]:
        """The ``(type, value)`` pair that equality and hashing are based on."""
        literal = self.value
        return type(literal), literal

    def __eq__(self, other: object) -> bool:
        return isinstance(other, LitDecl) and self.parts == other.parts

    def __hash__(self) -> int:
        """Hash the ``(type, value)`` pair, consistent with ``__eq__``."""
        return hash(self.parts)
654
+
655
+
656
+ @dataclass(frozen=True)
657
+ class CallDecl:
658
+ callable: CallableRef
659
+ # TODO: Can I make these not typed expressions?
660
+ args: tuple[TypedExprDecl, ...] = ()
661
+ # type parameters that were bound to the callable, if it is a classmethod
662
+ # Used for pretty printing classmethod calls with type parameters
663
+ bound_tp_params: tuple[JustTypeRef, ...] = ()
664
+
665
+ # pool objects for faster __eq__
666
+ _args_to_value: ClassVar[WeakValueDictionary[tuple[object, ...], CallDecl]] = WeakValueDictionary({})
667
+
668
+ def __new__(cls, *args: object, **kwargs: object) -> Self:
669
+ """
670
+ Pool CallDecls so that they can be compared by identity more quickly.
671
+
672
+ Neccessary bc we search for common parents when serializing CallDecl trees to egglog to
673
+ only serialize each sub-tree once.
674
+ """
675
+ # normalize the args/kwargs to a tuple so that they can be compared
676
+ callable = args[0] if args else kwargs["callable"]
677
+ args_ = args[1] if len(args) > 1 else kwargs.get("args", ())
678
+ bound_tp_params = args[2] if len(args) > 2 else kwargs.get("bound_tp_params", ())
679
+
680
+ normalized_args = (callable, args_, bound_tp_params)
681
+ try:
682
+ return cast("Self", cls._args_to_value[normalized_args])
683
+ except KeyError:
684
+ res = super().__new__(cls)
685
+ cls._args_to_value[normalized_args] = res
686
+ return res
687
+
688
+ def __post_init__(self) -> None:
689
+ if self.bound_tp_params and not isinstance(self.callable, ClassMethodRef | InitRef):
690
+ msg = "Cannot bind type parameters to a non-class method callable."
691
+ raise ValueError(msg)
692
+
693
+ def __hash__(self) -> int:
694
+ return self._cached_hash
695
+
696
+ @cached_property
697
+ def _cached_hash(self) -> int:
698
+ return hash((self.callable, self.args, self.bound_tp_params))
699
+
700
+ def __eq__(self, other: object) -> bool:
701
+ return self is other
702
+
703
+ def __ne__(self, other: object) -> bool:
704
+ return self is not other
705
+
706
+
707
@dataclass(frozen=True)
class PartialCallDecl:
    """
    A partially applied function, aka a function sort.

    Note it does not need to have any args, in which case it's just a function pointer.

    Separated from the call decl so it's clear it is translated to an `unstable-fn` call.
    """

    call: CallDecl
718
+
719
+
720
@dataclass(frozen=True)
class GetCostDecl:
    """An expression asking for the cost of calling ``callable`` with ``args`` (semantics defined by the backend — verify)."""

    callable: CallableRef
    args: tuple[TypedExprDecl, ...]


@dataclass(frozen=True)
class ValueDecl:
    """An expression wrapping a ``Value`` object from the compiled ``bindings`` module."""

    value: Value


# Every kind of expression node that can appear inside a `TypedExprDecl`.
ExprDecl: TypeAlias = (
    UnboundVarDecl | LetRefDecl | LitDecl | CallDecl | PyObjectDecl | PartialCallDecl | ValueDecl | GetCostDecl
)
734
+
735
+
736
@dataclass(frozen=True)
class TypedExprDecl:
    """An expression paired with its concrete type."""

    tp: JustTypeRef
    expr: ExprDecl

    def descendants(self) -> list[TypedExprDecl]:
        """
        Return this expression followed by all of its descendants, in depth-first pre-order.

        Only the arguments of a direct ``CallDecl`` are descended into. A sub-expression reachable along several
        paths appears once per path.
        """
        expr = self.expr
        if not isinstance(expr, CallDecl):
            return [self]
        return [self, *(node for arg in expr.args for node in arg.descendants())]
750
+
751
+
752
def replace_typed_expr(typed_expr: TypedExprDecl, replacements: Mapping[TypedExprDecl, TypedExprDecl]) -> TypedExprDecl:
    """
    Replace all the typed expressions in the given typed expression with the replacements.

    A node found in ``replacements`` is swapped for its replacement, which is not traversed further. Calls and
    partial calls are rebuilt with rewritten arguments. Any other node is returned unchanged. Results are memoized
    per node, so shared sub-expressions are processed only once.
    """
    memo: dict[TypedExprDecl, TypedExprDecl] = {}

    def rebuild_children(node: TypedExprDecl) -> TypedExprDecl:
        # Rebuild a (partial) call around rewritten arguments; leave every other node as-is.
        expr = node.expr
        is_partial = isinstance(expr, PartialCallDecl)
        call = expr.call if is_partial else expr
        if not isinstance(call, CallDecl):
            return node
        new_call = CallDecl(call.callable, tuple(rewrite(arg) for arg in call.args), call.bound_tp_params)
        return TypedExprDecl(node.tp, PartialCallDecl(new_call) if is_partial else new_call)

    def rewrite(node: TypedExprDecl) -> TypedExprDecl:
        if node in memo:
            return memo[node]
        if node in replacements:
            result = replacements[node]
        else:
            result = rebuild_children(node)
        memo[node] = result
        return result

    return rewrite(typed_expr)
781
+
782
+
783
def collect_unbound_vars(typed_expr: TypedExprDecl) -> set[TypedExprDecl]:
    """
    Returns the set of all unbound vars.

    Walks through the arguments of calls and partial calls, visiting each distinct node once. It collects every
    typed expression whose ``expr`` is an ``UnboundVarDecl``.
    """
    visited: set[TypedExprDecl] = set()
    found: set[TypedExprDecl] = set()

    def walk(node: TypedExprDecl) -> None:
        if node in visited:
            return
        visited.add(node)
        expr = node.expr
        if isinstance(expr, PartialCallDecl) and isinstance(expr.call, CallDecl):
            children = expr.call.args
        elif isinstance(expr, CallDecl):
            children = expr.args
        else:
            if isinstance(expr, UnboundVarDecl):
                found.add(node)
            return
        for child in children:
            walk(child)

    walk(typed_expr)
    return found
803
+
804
+
805
+ ##
806
+ # Schedules
807
+ ##
808
+
809
+
810
@dataclass(frozen=True)
class SaturateDecl:
    """Schedule node wrapping ``schedule`` in a saturate combinator."""

    schedule: ScheduleDecl


@dataclass(frozen=True)
class RepeatDecl:
    """Schedule node repeating ``schedule`` ``times`` times (presumably at most — verify against egglog)."""

    schedule: ScheduleDecl
    times: int


@dataclass(frozen=True)
class SequenceDecl:
    """Schedule node running ``schedules`` in order."""

    schedules: tuple[ScheduleDecl, ...]


@dataclass(frozen=True)
class RunDecl:
    """
    Schedule node running the ruleset identified by ``ruleset`` (a key into ``Declarations._rulesets``).

    ``until`` optionally holds facts (presumably a stopping condition — verify). ``scheduler`` optionally selects
    a back-off scheduler.
    """

    ruleset: Ident
    until: tuple[FactDecl, ...] | None
    scheduler: BackOffDecl | None = None


@dataclass(frozen=True)
class LetSchedulerDecl:
    """Schedule node that introduces ``scheduler`` for use inside ``inner``."""

    scheduler: BackOffDecl
    inner: ScheduleDecl


# Every kind of schedule node.
ScheduleDecl: TypeAlias = SaturateDecl | RepeatDecl | SequenceDecl | RunDecl | LetSchedulerDecl


@dataclass(frozen=True)
class BackOffDecl:
    """
    Parameters of a back-off scheduler.

    ``id`` keeps otherwise-identical schedulers distinct, since equality of this frozen dataclass compares fields.
    ``None`` limits presumably mean egglog's defaults — verify.
    """

    id: UUID
    match_limit: int | None
    ban_length: int | None
847
+
848
+
849
+ ##
850
+ # Facts
851
+ ##
852
+
853
+
854
@dataclass(frozen=True)
class EqDecl:
    """Fact asserting that ``left`` and ``right``, both of type ``tp``, are equal."""

    tp: JustTypeRef
    left: ExprDecl
    right: ExprDecl


@dataclass(frozen=True)
class ExprFactDecl:
    """Fact consisting of a single typed expression."""

    typed_expr: TypedExprDecl


# Every kind of fact (used in rule bodies, rewrite conditions and `RunDecl.until`).
FactDecl: TypeAlias = EqDecl | ExprFactDecl
867
+
868
+ ##
869
+ # Actions
870
+ ##
871
+
872
+
873
@dataclass(frozen=True)
class LetDecl:
    """Action binding ``name`` to the value of ``typed_expr``."""

    name: str
    typed_expr: TypedExprDecl


@dataclass(frozen=True)
class SetDecl:
    """Action setting the output of ``call`` (of type ``tp``) to ``rhs``."""

    tp: JustTypeRef
    call: CallDecl
    rhs: ExprDecl


@dataclass(frozen=True)
class ExprActionDecl:
    """Action consisting of a single typed expression."""

    typed_expr: TypedExprDecl


@dataclass(frozen=True)
class ChangeDecl:
    """Action applying ``change`` (``"delete"`` or ``"subsume"``) to ``call``."""

    tp: JustTypeRef
    call: CallDecl
    change: Literal["delete", "subsume"]


@dataclass(frozen=True)
class UnionDecl:
    """Action unioning ``lhs`` and ``rhs``, both of type ``tp``."""

    tp: JustTypeRef
    lhs: ExprDecl
    rhs: ExprDecl


@dataclass(frozen=True)
class PanicDecl:
    """Action that panics with the message ``msg``."""

    msg: str


@dataclass(frozen=True)
class SetCostDecl:
    """Action setting the cost of the call ``expr`` to ``cost`` (semantics defined by the backend — verify)."""

    tp: JustTypeRef
    expr: CallDecl
    cost: ExprDecl


# Every kind of action (used in rule heads and `ActionCommandDecl`).
ActionDecl: TypeAlias = LetDecl | SetDecl | ExprActionDecl | ChangeDecl | UnionDecl | PanicDecl | SetCostDecl
918
+
919
+
920
+ ##
921
+ # Commands
922
+ ##
923
+
924
+
925
@dataclass(frozen=True)
class RewriteDecl:
    """One-directional rewrite ``lhs`` => ``rhs`` of type ``tp``, guarded by ``conditions``."""

    tp: JustTypeRef
    lhs: ExprDecl
    rhs: ExprDecl
    conditions: tuple[FactDecl, ...]
    # NOTE(review): presumably subsumes the matched lhs after rewriting — verify against egglog.
    subsume: bool


@dataclass(frozen=True)
class BiRewriteDecl:
    """Bidirectional rewrite between ``lhs`` and ``rhs`` of type ``tp``, guarded by ``conditions``."""

    tp: JustTypeRef
    lhs: ExprDecl
    rhs: ExprDecl
    conditions: tuple[FactDecl, ...]


@dataclass(frozen=True)
class RuleDecl:
    """Rule with an optional ``name``: when the ``body`` facts match, run the ``head`` actions."""

    head: tuple[ActionDecl, ...]
    body: tuple[FactDecl, ...]
    name: str | None


@dataclass(frozen=True)
class DefaultRewriteDecl:
    """
    Rewrite of the callable ``ref`` to ``expr``. Per the note in ``Declarations.update_other``, these are added
    to the default ruleset when functions are declared.
    """

    ref: CallableRef
    expr: ExprDecl
    subsume: bool


# Every kind of entry a `RulesetDecl` can hold.
RewriteOrRuleDecl: TypeAlias = RewriteDecl | BiRewriteDecl | RuleDecl | DefaultRewriteDecl


@dataclass(frozen=True)
class ActionCommandDecl:
    """Top-level command that runs a single ``action``."""

    action: ActionDecl


# Every kind of top-level command.
CommandDecl: TypeAlias = RewriteOrRuleDecl | ActionCommandDecl