egglog 6.1.0__cp311-none-win_amd64.whl → 7.0.0__cp311-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of egglog might be problematic. Click here for more details.

egglog/declarations.py CHANGED
@@ -1,18 +1,17 @@
1
1
  """
2
2
  Data only descriptions of the components of an egraph and the expressions.
3
+
4
+ We separate it into two pieces, the references and the declarations, so that we can report mutually recursive types.
3
5
  """
4
6
 
5
7
  from __future__ import annotations
6
8
 
7
- from collections import defaultdict
8
9
  from dataclasses import dataclass, field
9
- from inspect import Parameter, Signature
10
- from typing import TYPE_CHECKING, Protocol, TypeAlias, Union, runtime_checkable
10
+ from functools import cached_property
11
+ from typing import TYPE_CHECKING, Literal, Protocol, TypeAlias, Union, runtime_checkable
11
12
 
12
13
  from typing_extensions import Self, assert_never
13
14
 
14
- from . import bindings
15
-
16
15
  if TYPE_CHECKING:
17
16
  from collections.abc import Callable, Iterable
18
17
 
@@ -20,84 +19,64 @@ if TYPE_CHECKING:
20
19
  __all__ = [
21
20
  "Declarations",
22
21
  "DeclerationsLike",
23
- "upcast_decleratioons",
22
+ "DelayedDeclerations",
23
+ "upcast_declerations",
24
+ "Declarations",
24
25
  "JustTypeRef",
25
26
  "ClassTypeVarRef",
26
27
  "TypeRefWithVars",
27
28
  "TypeOrVarRef",
28
- "FunctionRef",
29
29
  "MethodRef",
30
30
  "ClassMethodRef",
31
+ "FunctionRef",
32
+ "ConstantRef",
31
33
  "ClassVariableRef",
32
- "FunctionCallableRef",
33
34
  "PropertyRef",
34
35
  "CallableRef",
35
- "ConstantRef",
36
36
  "FunctionDecl",
37
+ "RelationDecl",
38
+ "ConstantDecl",
39
+ "CallableDecl",
37
40
  "VarDecl",
38
- "LitType",
39
41
  "PyObjectDecl",
42
+ "LitType",
40
43
  "LitDecl",
41
44
  "CallDecl",
42
45
  "ExprDecl",
43
46
  "TypedExprDecl",
44
47
  "ClassDecl",
45
- "PrettyContext",
46
- "GLOBAL_PY_OBJECT_SORT",
48
+ "RulesetDecl",
49
+ "SaturateDecl",
50
+ "RepeatDecl",
51
+ "SequenceDecl",
52
+ "RunDecl",
53
+ "ScheduleDecl",
54
+ "EqDecl",
55
+ "ExprFactDecl",
56
+ "FactDecl",
57
+ "LetDecl",
58
+ "SetDecl",
59
+ "ExprActionDecl",
60
+ "ChangeDecl",
61
+ "UnionDecl",
62
+ "PanicDecl",
63
+ "ActionDecl",
64
+ "RewriteDecl",
65
+ "BiRewriteDecl",
66
+ "RuleDecl",
67
+ "RewriteOrRuleDecl",
68
+ "ActionCommandDecl",
69
+ "CommandDecl",
47
70
  ]
48
71
 
49
- # Create a global sort for python objects, so we can store them without an e-graph instance
50
- # Needed when serializing commands to egg commands when creating modules
51
- GLOBAL_PY_OBJECT_SORT = bindings.PyObjectSort()
52
-
53
- # Special methods which we might want to use as functions
54
- # Mapping to the operator they represent for pretty printing them
55
- # https://docs.python.org/3/reference/datamodel.html
56
- BINARY_METHODS = {
57
- "__lt__": "<",
58
- "__le__": "<=",
59
- "__eq__": "==",
60
- "__ne__": "!=",
61
- "__gt__": ">",
62
- "__ge__": ">=",
63
- # Numeric
64
- "__add__": "+",
65
- "__sub__": "-",
66
- "__mul__": "*",
67
- "__matmul__": "@",
68
- "__truediv__": "/",
69
- "__floordiv__": "//",
70
- "__mod__": "%",
71
- # TODO: Support divmod, with tuple return value
72
- # "__divmod__": "divmod",
73
- # TODO: Three arg power
74
- "__pow__": "**",
75
- "__lshift__": "<<",
76
- "__rshift__": ">>",
77
- "__and__": "&",
78
- "__xor__": "^",
79
- "__or__": "|",
80
- }
81
- REFLECTED_BINARY_METHODS = {
82
- "__radd__": "__add__",
83
- "__rsub__": "__sub__",
84
- "__rmul__": "__mul__",
85
- "__rmatmul__": "__matmul__",
86
- "__rtruediv__": "__truediv__",
87
- "__rfloordiv__": "__floordiv__",
88
- "__rmod__": "__mod__",
89
- "__rpow__": "__pow__",
90
- "__rlshift__": "__lshift__",
91
- "__rrshift__": "__rshift__",
92
- "__rand__": "__and__",
93
- "__rxor__": "__xor__",
94
- "__ror__": "__or__",
95
- }
96
- UNARY_METHODS = {
97
- "__pos__": "+",
98
- "__neg__": "-",
99
- "__invert__": "~",
100
- }
72
+
73
+ @dataclass
74
+ class DelayedDeclerations:
75
+ __egg_decls_thunk__: Callable[[], Declarations]
76
+
77
+ @property
78
+ def __egg_decls__(self) -> Declarations:
79
+ return self.__egg_decls_thunk__()
101
80
 
102
81
 
103
82
  @runtime_checkable
@@ -109,7 +88,10 @@ class HasDeclerations(Protocol):
109
88
  DeclerationsLike: TypeAlias = Union[HasDeclerations, None, "Declarations"]
110
89
 
111
90
 
112
- def upcast_decleratioons(declerations_like: Iterable[DeclerationsLike]) -> list[Declarations]:
91
+ # TODO: Make all ClassDecls take deferred type refs, which return new decls when resolving.
92
+
93
+
94
+ def upcast_declerations(declerations_like: Iterable[DeclerationsLike]) -> list[Declarations]:
113
95
  d = []
114
96
  for l in declerations_like:
115
97
  if l is None:
@@ -125,30 +107,14 @@ def upcast_decleratioons(declerations_like: Iterable[DeclerationsLike]) -> list[
125
107
 
126
108
  @dataclass
127
109
  class Declarations:
128
- _functions: dict[str, FunctionDecl] = field(default_factory=dict)
110
+ _functions: dict[str, FunctionDecl | RelationDecl] = field(default_factory=dict)
111
+ _constants: dict[str, ConstantDecl] = field(default_factory=dict)
129
112
  _classes: dict[str, ClassDecl] = field(default_factory=dict)
130
- _constants: dict[str, JustTypeRef] = field(default_factory=dict)
131
-
132
- # Bidirectional mapping between egg function names and python callable references.
133
- # Note that there are possibly mutliple callable references for a single egg function name, like `+`
134
- # for both int and rational classes.
135
- _egg_fn_to_callable_refs: defaultdict[str, set[CallableRef]] = field(default_factory=lambda: defaultdict(set))
136
- _callable_ref_to_egg_fn: dict[CallableRef, str] = field(default_factory=dict)
137
-
138
- # Bidirectional mapping between egg sort names and python type references.
139
- _egg_sort_to_type_ref: dict[str, JustTypeRef] = field(default_factory=dict)
140
- _type_ref_to_egg_sort: dict[JustTypeRef, str] = field(default_factory=dict)
141
-
142
- # Mapping from egg name (of sort or function) to command to create it.
143
- _cmds: dict[str, bindings._Command] = field(default_factory=dict)
144
-
145
- def __post_init__(self) -> None:
146
- if "!=" not in self._egg_fn_to_callable_refs:
147
- self.register_callable_ref(FunctionRef("!="), "!=")
113
+ _rulesets: dict[str, RulesetDecl] = field(default_factory=lambda: {"": RulesetDecl([])})
148
114
 
149
115
  @classmethod
150
116
  def create(cls, *others: DeclerationsLike) -> Declarations:
151
- others = upcast_decleratioons(others)
117
+ others = upcast_declerations(others)
152
118
  if not others:
153
119
  return Declarations()
154
120
  first, *rest = others
@@ -159,25 +125,9 @@ class Declarations:
159
125
  return new
160
126
 
161
127
  def copy(self) -> Declarations:
162
- return Declarations(
163
- _functions=self._functions.copy(),
164
- _classes=self._classes.copy(),
165
- _constants=self._constants.copy(),
166
- _egg_fn_to_callable_refs=defaultdict(set, {k: v.copy() for k, v in self._egg_fn_to_callable_refs.items()}),
167
- _callable_ref_to_egg_fn=self._callable_ref_to_egg_fn.copy(),
168
- _egg_sort_to_type_ref=self._egg_sort_to_type_ref.copy(),
169
- _type_ref_to_egg_sort=self._type_ref_to_egg_sort.copy(),
170
- _cmds=self._cmds.copy(),
171
- )
172
-
173
- def __deepcopy__(self, memo: dict) -> Declarations:
174
- return self.copy()
175
-
176
- def add_cmd(self, name: str, cmd: bindings._Command) -> None:
177
- self._cmds[name] = cmd
178
-
179
- def list_cmds(self) -> list[bindings._Command]:
180
- return list(self._cmds.values())
128
+ new = Declarations()
129
+ new |= self
130
+ return new
181
131
 
182
132
  def update(self, *others: DeclerationsLike) -> None:
183
133
  for other in others:
@@ -200,82 +150,26 @@ class Declarations:
200
150
  """
201
151
  Updates the other decl with these values in place.
202
152
  """
203
- # If cmds are == skip unioning for time savings
204
- # if set(self._cmds) == set(other._cmds) and self.record_cmds and other.record_cmds:
205
- # return self
206
153
  other._functions |= self._functions
207
154
  other._classes |= self._classes
208
155
  other._constants |= self._constants
209
- other._egg_sort_to_type_ref |= self._egg_sort_to_type_ref
210
- other._type_ref_to_egg_sort |= self._type_ref_to_egg_sort
211
- other._cmds |= self._cmds
212
- other._callable_ref_to_egg_fn |= self._callable_ref_to_egg_fn
213
- for egg_fn, callable_refs in self._egg_fn_to_callable_refs.items():
214
- other._egg_fn_to_callable_refs[egg_fn] |= callable_refs
215
-
216
- def set_function_decl(self, ref: FunctionCallableRef, decl: FunctionDecl) -> None:
217
- """
218
- Sets a function declaration for the given callable reference.
219
- """
156
+ other._rulesets |= self._rulesets
157
+
158
+ def get_callable_decl(self, ref: CallableRef) -> CallableDecl:
220
159
  match ref:
221
160
  case FunctionRef(name):
222
- if name in self._functions:
223
- raise ValueError(f"Function {name} already registered")
224
- self._functions[name] = decl
161
+ return self._functions[name]
162
+ case ConstantRef(name):
163
+ return self._constants[name]
225
164
  case MethodRef(class_name, method_name):
226
- if method_name in self._classes[class_name].methods:
227
- raise ValueError(f"Method {class_name}.{method_name} already registered")
228
- self._classes[class_name].methods[method_name] = decl
229
- case ClassMethodRef(class_name, method_name):
230
- if method_name in self._classes[class_name].class_methods:
231
- raise ValueError(f"Class method {class_name}.{method_name} already registered")
232
- self._classes[class_name].class_methods[method_name] = decl
165
+ return self._classes[class_name].methods[method_name]
166
+ case ClassVariableRef(class_name, name):
167
+ return self._classes[class_name].class_variables[name]
168
+ case ClassMethodRef(class_name, name):
169
+ return self._classes[class_name].class_methods[name]
233
170
  case PropertyRef(class_name, property_name):
234
- if property_name in self._classes[class_name].properties:
235
- raise ValueError(f"Property {class_name}.{property_name} already registered")
236
- self._classes[class_name].properties[property_name] = decl
237
- case _:
238
- assert_never(ref)
239
-
240
- def set_constant_type(self, ref: ConstantCallableRef, tp: JustTypeRef) -> None:
241
- match ref:
242
- case ConstantRef(name):
243
- if name in self._constants:
244
- raise ValueError(f"Constant {name} already registered")
245
- self._constants[name] = tp
246
- case ClassVariableRef(class_name, variable_name):
247
- if variable_name in self._classes[class_name].class_variables:
248
- raise ValueError(f"Class variable {class_name}.{variable_name} already registered")
249
- self._classes[class_name].class_variables[variable_name] = tp
250
- case _:
251
- assert_never(ref)
252
-
253
- def register_callable_ref(self, ref: CallableRef, egg_name: str) -> None:
254
- """
255
- Registers a callable reference with the given egg name.
256
-
257
- The callable's function needs to be registered first.
258
- """
259
- if ref in self._callable_ref_to_egg_fn:
260
- raise ValueError(f"Callable ref {ref} already registered")
261
- self._callable_ref_to_egg_fn[ref] = egg_name
262
- self._egg_fn_to_callable_refs[egg_name].add(ref)
263
-
264
- def get_callable_refs(self, egg_name: str) -> Iterable[CallableRef]:
265
- return self._egg_fn_to_callable_refs[egg_name]
266
-
267
- def get_egg_fn(self, ref: CallableRef) -> str:
268
- return self._callable_ref_to_egg_fn[ref]
269
-
270
- def get_egg_sort(self, ref: JustTypeRef) -> str:
271
- return self._type_ref_to_egg_sort[ref]
272
-
273
- def op_mapping(self) -> dict[str, str]:
274
- """
275
- Create a mapping of egglog function name to Python function name, for use in the serialized format
276
- for better visualization.
277
- """
278
- return {k: str(next(iter(v))) for k, v in self._egg_fn_to_callable_refs.items() if len(v) == 1}
171
+ return self._classes[class_name].properties[property_name]
172
+ assert_never(ref)
279
173
 
280
174
  def has_method(self, class_name: str, method_name: str) -> bool | None:
281
175
  """
@@ -285,138 +179,31 @@ class Declarations:
285
179
  return method_name in self._classes[class_name].methods
286
180
  return None
287
181
 
288
- def get_function_decl(self, ref: CallableRef) -> FunctionDecl:
289
- match ref:
290
- case ConstantRef(name):
291
- return self._constants[name].to_constant_function_decl()
292
- case ClassVariableRef(class_name, variable_name):
293
- return self._classes[class_name].class_variables[variable_name].to_constant_function_decl()
294
- case FunctionRef(name):
295
- return self._functions[name]
296
- case MethodRef(class_name, method_name):
297
- return self._classes[class_name].methods[method_name]
298
- case ClassMethodRef(class_name, method_name):
299
- return self._classes[class_name].class_methods[method_name]
300
- case PropertyRef(class_name, property_name):
301
- return self._classes[class_name].properties[property_name]
302
- assert_never(ref)
303
-
304
182
  def get_class_decl(self, name: str) -> ClassDecl:
305
183
  return self._classes[name]
306
184
 
307
- def get_possible_types(self, cls_name: str) -> frozenset[JustTypeRef]:
308
- """
309
- Given a class name, returns all possible registered types that it can be.
310
- """
311
- return frozenset(tp for tp in self._type_ref_to_egg_sort if tp.name == cls_name)
312
185
 
313
- def register_class(self, name: str, type_vars: tuple[str, ...], builtin: bool, egg_sort: str | None) -> None:
314
- # Register class first
315
- if name in self._classes:
316
- raise ValueError(f"Class {name} already registered")
317
- decl = ClassDecl(type_vars=type_vars)
318
- self._classes[name] = decl
319
- self.register_sort(JustTypeRef(name), builtin, egg_sort)
186
+ @dataclass
187
+ class ClassDecl:
188
+ egg_name: str | None = None
189
+ type_vars: tuple[str, ...] = ()
190
+ builtin: bool = False
191
+ class_methods: dict[str, FunctionDecl] = field(default_factory=dict)
192
+ # These have to be separate from class_methods so that printing them can be done easily
193
+ class_variables: dict[str, ConstantDecl] = field(default_factory=dict)
194
+ methods: dict[str, FunctionDecl] = field(default_factory=dict)
195
+ properties: dict[str, FunctionDecl] = field(default_factory=dict)
196
+ preserved_methods: dict[str, Callable] = field(default_factory=dict)
320
197
 
321
- def register_sort(self, ref: JustTypeRef, builtin: bool, egg_name: str | None = None) -> str:
322
- """
323
- Register a sort with the given name. If no name is given, one is generated.
324
198
 
325
- If this is a type called with generic args, register the generic args as well.
326
- """
327
- # If the sort is already registered, do nothing
328
- try:
329
- egg_sort = self.get_egg_sort(ref)
330
- except KeyError:
331
- pass
332
- else:
333
- return egg_sort
334
- egg_name = egg_name or ref.generate_egg_name()
335
- if egg_name in self._egg_sort_to_type_ref:
336
- raise ValueError(f"Sort {egg_name} is already registered.")
337
- self._egg_sort_to_type_ref[egg_name] = ref
338
- self._type_ref_to_egg_sort[ref] = egg_name
339
- if not builtin:
340
- self.add_cmd(
341
- egg_name,
342
- bindings.Sort(
343
- egg_name,
344
- (
345
- self.get_egg_sort(JustTypeRef(ref.name)),
346
- [bindings.Var(self.register_sort(arg, False)) for arg in ref.args],
347
- )
348
- if ref.args
349
- else None,
350
- ),
351
- )
352
-
353
- return egg_name
354
-
355
- def register_function_callable(
356
- self,
357
- ref: FunctionCallableRef,
358
- fn_decl: FunctionDecl,
359
- egg_name: str | None,
360
- cost: int | None,
361
- default: ExprDecl | None,
362
- merge: ExprDecl | None,
363
- merge_action: list[bindings._Action],
364
- unextractable: bool,
365
- builtin: bool,
366
- is_relation: bool = False,
367
- ) -> None:
368
- """
369
- Registers a callable with the given egg name.
199
+ @dataclass
200
+ class RulesetDecl:
201
+ rules: list[RewriteOrRuleDecl]
370
202
 
371
- The callable's function needs to be registered first.
372
- """
373
- egg_name = egg_name or ref.generate_egg_name()
374
- self.register_callable_ref(ref, egg_name)
375
- self.set_function_decl(ref, fn_decl)
376
-
377
- # Skip generating the cmds if we don't want to record them, like for the builtins
378
- if builtin:
379
- return
380
-
381
- if fn_decl.var_arg_type is not None:
382
- msg = "egglog does not support variable arguments yet."
383
- raise NotImplementedError(msg)
384
- # Remove all vars from the type refs, raising an errory if we find one,
385
- # since we cannot create egg functions with vars
386
- arg_sorts = [self.register_sort(a.to_just(), False) for a in fn_decl.arg_types]
387
- cmd: bindings._Command
388
- if is_relation:
389
- assert not default
390
- assert not merge
391
- assert not merge_action
392
- assert not cost
393
- cmd = bindings.Relation(egg_name, arg_sorts)
394
- else:
395
- egg_fn_decl = bindings.FunctionDecl(
396
- egg_name,
397
- bindings.Schema(arg_sorts, self.register_sort(fn_decl.return_type.to_just(), False)),
398
- default.to_egg(self) if default else None,
399
- merge.to_egg(self) if merge else None,
400
- merge_action,
401
- cost,
402
- unextractable,
403
- )
404
- cmd = bindings.Function(egg_fn_decl)
405
- self.add_cmd(egg_name, cmd)
406
-
407
- def register_constant_callable(self, ref: ConstantCallableRef, type_ref: JustTypeRef, egg_name: str | None) -> None:
408
- egg_name = egg_name or ref.generate_egg_name()
409
- self.register_callable_ref(ref, egg_name)
410
- self.set_constant_type(ref, type_ref)
411
- egg_sort = self.register_sort(type_ref, False)
412
- # self.add_cmd(egg_name, bindings.Declare(egg_name, self.get_egg_sort(type_ref)))
413
- # Use function decleration instead of constant b/c constants cannot be extracted
414
- # https://github.com/egraphs-good/egglog/issues/334
415
- fn_decl = bindings.FunctionDecl(egg_name, bindings.Schema([], egg_sort))
416
- self.add_cmd(egg_name, bindings.Function(fn_decl))
417
-
418
- def register_preserved_method(self, class_: str, method: str, fn: Callable) -> None:
419
- self._classes[class_].preserved_methods[method] = fn
203
+ # Make hashable so when traversing for prettifying we can know which rulesets we have already
204
+ # made into strings
205
+ def __hash__(self) -> int:
206
+ return hash((type(self), tuple(self.rules)))
420
207
 
421
208
 
422
209
  # Have two different types of type refs, one that can include vars recursively and one that cannot.
@@ -427,38 +214,18 @@ class JustTypeRef:
427
214
  name: str
428
215
  args: tuple[JustTypeRef, ...] = ()
429
216
 
430
- def generate_egg_name(self) -> str:
431
- """
432
- Generates an egg sort name for this type reference by linearizing the type.
433
- """
434
- if not self.args:
435
- return self.name
436
- args = "_".join(a.generate_egg_name() for a in self.args)
437
- return f"{self.name}_{args}"
438
-
439
217
  def to_var(self) -> TypeRefWithVars:
440
218
  return TypeRefWithVars(self.name, tuple(a.to_var() for a in self.args))
441
219
 
442
- def pretty(self) -> str:
443
- if not self.args:
444
- return self.name
445
- args = ", ".join(a.pretty() for a in self.args)
446
- return f"{self.name}[{args}]"
220
+ def __str__(self) -> str:
221
+ if self.args:
222
+ return f"{self.name}[{', '.join(str(a) for a in self.args)}]"
223
+ return self.name
447
224
 
448
- def to_constant_function_decl(self) -> FunctionDecl:
449
- """
450
- Create a function declaration for a constant function.
451
225
 
452
- This is similar to how egglog compiles the `constant` command.
453
- """
454
- return FunctionDecl(
455
- arg_types=(),
456
- arg_names=(),
457
- arg_defaults=(),
458
- return_type=self.to_var(),
459
- mutates_first_arg=False,
460
- var_arg_type=None,
461
- )
226
+ ##
227
+ # Type references with vars
228
+ ##
462
229
 
463
230
 
464
231
  @dataclass(frozen=True)
@@ -473,7 +240,7 @@ class ClassTypeVarRef:
473
240
  msg = "egglog does not support generic classes yet."
474
241
  raise NotImplementedError(msg)
475
242
 
476
- def pretty(self) -> str:
243
+ def __str__(self) -> str:
477
244
  return self.name
478
245
 
479
246
 
@@ -485,30 +252,27 @@ class TypeRefWithVars:
485
252
  def to_just(self) -> JustTypeRef:
486
253
  return JustTypeRef(self.name, tuple(a.to_just() for a in self.args))
487
254
 
488
- def pretty(self) -> str:
489
- if not self.args:
490
- return self.name
491
- args = ", ".join(a.pretty() for a in self.args)
492
- return f"{self.name}[{args}]"
255
+ def __str__(self) -> str:
256
+ if self.args:
257
+ return f"{self.name}[{', '.join(str(a) for a in self.args)}]"
258
+ return self.name
493
259
 
494
260
 
495
261
  TypeOrVarRef: TypeAlias = ClassTypeVarRef | TypeRefWithVars
496
262
 
263
+ ##
264
+ # Callables References
265
+ ##
266
+
497
267
 
498
268
  @dataclass(frozen=True)
499
269
  class FunctionRef:
500
270
  name: str
501
271
 
502
- def generate_egg_name(self) -> str:
503
- return self.name
504
-
505
- def __str__(self) -> str:
506
- return self.name
507
272
 
508
-
509
- # Use this special character in place of the args, so that if the args are inlined
510
- # in the viz, they will replace it
511
- ARG = "·"
273
+ @dataclass(frozen=True)
274
+ class ConstantRef:
275
+ name: str
512
276
 
513
277
 
514
278
  @dataclass(frozen=True)
@@ -516,123 +280,115 @@ class MethodRef:
516
280
  class_name: str
517
281
  method_name: str
518
282
 
519
- def generate_egg_name(self) -> str:
520
- return f"{self.class_name}_{self.method_name}"
521
-
522
- def __str__(self) -> str: # noqa: PLR0911
523
- match self.method_name:
524
- case _ if self.method_name in UNARY_METHODS:
525
- return f"{UNARY_METHODS[self.method_name]}{ARG}"
526
- case _ if self.method_name in BINARY_METHODS:
527
- return f"({ARG} {BINARY_METHODS[self.method_name]} {ARG})"
528
- case "__getitem__":
529
- return f"{ARG}[{ARG}]"
530
- case "__call__":
531
- return f"{ARG}({ARG})"
532
- case "__delitem__":
533
- return f"del {ARG}[{ARG}]"
534
- case "__setitem__":
535
- return f"{ARG}[{ARG}] = {ARG}"
536
- return f"{ARG}.{self.method_name}"
537
-
538
283
 
539
284
  @dataclass(frozen=True)
540
285
  class ClassMethodRef:
541
286
  class_name: str
542
287
  method_name: str
543
288
 
544
- def generate_egg_name(self) -> str:
545
- return f"{self.class_name}_{self.method_name}"
546
-
547
- def __str__(self) -> str:
548
- if self.method_name == "__init__":
549
- return self.class_name
550
- return f"{self.class_name}.{self.method_name}"
551
-
552
289
 
553
290
  @dataclass(frozen=True)
554
- class ConstantRef:
555
- name: str
556
-
557
- def generate_egg_name(self) -> str:
558
- return self.name
559
-
560
- def __str__(self) -> str:
561
- return self.name
291
+ class ClassVariableRef:
292
+ class_name: str
293
+ var_name: str
562
294
 
563
295
 
564
296
  @dataclass(frozen=True)
565
- class ClassVariableRef:
297
+ class PropertyRef:
566
298
  class_name: str
567
- variable_name: str
299
+ property_name: str
568
300
 
569
- def generate_egg_name(self) -> str:
570
- return f"{self.class_name}_{self.variable_name}"
571
301
 
572
- def __str__(self) -> str:
573
- return f"{self.class_name}.{self.variable_name}"
302
+ CallableRef: TypeAlias = FunctionRef | ConstantRef | MethodRef | ClassMethodRef | ClassVariableRef | PropertyRef
303
+
304
+
305
+ ##
306
+ # Callables
307
+ ##
574
308
 
575
309
 
576
310
  @dataclass(frozen=True)
577
- class PropertyRef:
578
- class_name: str
579
- property_name: str
311
+ class RelationDecl:
312
+ arg_types: tuple[JustTypeRef, ...]
313
+ # List of defaults. None for any arg which doesn't have one.
314
+ arg_defaults: tuple[ExprDecl | None, ...]
315
+ egg_name: str | None
580
316
 
581
- def generate_egg_name(self) -> str:
582
- return f"{self.class_name}_{self.property_name}"
317
+ def to_function_decl(self) -> FunctionDecl:
318
+ return FunctionDecl(
319
+ arg_types=tuple(a.to_var() for a in self.arg_types),
320
+ arg_names=tuple(f"__{i}" for i in range(len(self.arg_types))),
321
+ arg_defaults=self.arg_defaults,
322
+ return_type=TypeRefWithVars("Unit"),
323
+ egg_name=self.egg_name,
324
+ default=LitDecl(None),
325
+ )
583
326
 
584
- def __str__(self) -> str:
585
- return f"{ARG}.{self.property_name}"
586
327
 
328
+ @dataclass(frozen=True)
329
+ class ConstantDecl:
330
+ """
331
+ Same as `(declare)` in egglog
332
+ """
587
333
 
588
- ConstantCallableRef: TypeAlias = ConstantRef | ClassVariableRef
589
- FunctionCallableRef: TypeAlias = FunctionRef | MethodRef | ClassMethodRef | PropertyRef
590
- CallableRef: TypeAlias = ConstantCallableRef | FunctionCallableRef
334
+ type_ref: JustTypeRef
335
+ egg_name: str | None = None
336
+
337
+ def to_function_decl(self) -> FunctionDecl:
338
+ return FunctionDecl(
339
+ arg_types=(),
340
+ arg_names=(),
341
+ arg_defaults=(),
342
+ return_type=self.type_ref.to_var(),
343
+ egg_name=self.egg_name,
344
+ )
591
345
 
592
346
 
593
347
  @dataclass(frozen=True)
594
348
  class FunctionDecl:
349
+ # All args are delayed except for relations converted to function decls
595
350
  arg_types: tuple[TypeOrVarRef, ...]
596
- # Is None for relation which doesn't have named args
597
- arg_names: tuple[str, ...] | None
351
+ arg_names: tuple[str, ...]
352
+ # List of defaults. None for any arg which doesn't have one.
598
353
  arg_defaults: tuple[ExprDecl | None, ...]
599
- return_type: TypeOrVarRef
600
- mutates_first_arg: bool
354
+ # If None, then the first arg is mutated and returned
355
+ return_type: TypeOrVarRef | None
601
356
  var_arg_type: TypeOrVarRef | None = None
602
357
 
603
- def __post_init__(self) -> None:
604
- # If we mutate the first arg, then the first arg should be the same type as the return
605
- if self.mutates_first_arg:
606
- assert self.arg_types[0] == self.return_type
607
-
608
- def to_signature(self, transform_default: Callable[[TypedExprDecl], object]) -> Signature:
609
- arg_names = self.arg_names or tuple(f"__{i}" for i in range(len(self.arg_types)))
610
- parameters = [
611
- Parameter(
612
- n,
613
- Parameter.POSITIONAL_OR_KEYWORD,
614
- default=transform_default(TypedExprDecl(t.to_just(), d)) if d else Parameter.empty,
615
- )
616
- for n, d, t in zip(arg_names, self.arg_defaults, self.arg_types, strict=True)
617
- ]
618
- if self.var_arg_type is not None:
619
- parameters.append(Parameter("__rest", Parameter.VAR_POSITIONAL))
620
- return Signature(parameters)
358
+ # Egg params
359
+ builtin: bool = False
360
+ egg_name: str | None = None
361
+ cost: int | None = None
362
+ default: ExprDecl | None = None
363
+ on_merge: tuple[ActionDecl, ...] = ()
364
+ merge: ExprDecl | None = None
365
+ unextractable: bool = False
621
366
 
367
+ def to_function_decl(self) -> FunctionDecl:
368
+ return self
622
369
 
623
- @dataclass(frozen=True)
624
- class VarDecl:
625
- name: str
370
+ @property
371
+ def semantic_return_type(self) -> TypeOrVarRef:
372
+ """
373
+ The type that is returned by the function, which will be in the first arg if it mutates it.
374
+ """
375
+ return self.return_type or self.arg_types[0]
626
376
 
627
- @classmethod
628
- def from_egg(cls, var: bindings.TermVar) -> ExprDecl:
629
- return cls(var.name)
377
+ @property
378
+ def mutates(self) -> bool:
379
+ return self.return_type is None
630
380
 
631
- def to_egg(self, _decls: Declarations) -> bindings.Var:
632
- return bindings.Var(self.name)
633
381
 
634
- def pretty(self, context: PrettyContext, **kwargs) -> str:
635
- return self.name
382
+ CallableDecl: TypeAlias = RelationDecl | ConstantDecl | FunctionDecl
383
+
384
+ ##
385
+ # Expressions
386
+ ##
387
+
388
+
389
+ @dataclass(frozen=True)
390
+ class VarDecl:
391
+ name: str
636
392
 
637
393
 
638
394
  @dataclass(frozen=True)
@@ -646,16 +402,14 @@ class PyObjectDecl:
646
402
  except TypeError:
647
403
  return id(self.value)
648
404
 
649
- @classmethod
650
- def from_egg(cls, egraph: bindings.EGraph, termdag: bindings.TermDag, term: bindings.TermApp) -> ExprDecl:
651
- call = bindings.termdag_term_to_expr(termdag, term)
652
- return cls(egraph.eval_py_object(call))
653
-
654
- def to_egg(self, _decls: Declarations) -> bindings._Expr:
655
- return GLOBAL_PY_OBJECT_SORT.store(self.value)
405
+ def __eq__(self, other: object) -> bool:
406
+ if not isinstance(other, PyObjectDecl):
407
+ return False
408
+ return self.parts == other.parts
656
409
 
657
- def pretty(self, context: PrettyContext, **kwargs) -> str:
658
- return repr(self.value)
410
+ @property
411
+ def parts(self) -> tuple[type, object]:
412
+ return (type(self.value), self.value)
659
413
 
660
414
 
661
415
  LitType: TypeAlias = int | str | float | bool | None
@@ -665,53 +419,30 @@ LitType: TypeAlias = int | str | float | bool | None
665
419
  class LitDecl:
666
420
  value: LitType
667
421
 
668
- @classmethod
669
- def from_egg(cls, lit: bindings.TermLit) -> ExprDecl:
670
- value = lit.value
671
- if isinstance(value, bindings.Unit):
672
- return cls(None)
673
- return cls(value.value)
674
-
675
- def to_egg(self, _decls: Declarations) -> bindings.Lit:
676
- if self.value is None:
677
- return bindings.Lit(bindings.Unit())
678
- if isinstance(self.value, bool):
679
- return bindings.Lit(bindings.Bool(self.value))
680
- if isinstance(self.value, int):
681
- return bindings.Lit(bindings.Int(self.value))
682
- if isinstance(self.value, float):
683
- return bindings.Lit(bindings.F64(self.value))
684
- if isinstance(self.value, str):
685
- return bindings.Lit(bindings.String(self.value))
686
- assert_never(self.value)
687
-
688
- def pretty(self, context: PrettyContext, unwrap_lit: bool = True, **kwargs) -> str:
422
+ def __hash__(self) -> int:
689
423
  """
690
- Returns a string representation of the literal.
691
-
692
- :param wrap_lit: If True, wraps the literal in a call to the literal constructor.
424
+ Include type in hash so that 1.0 != 1
693
425
  """
694
- if self.value is None:
695
- return "Unit()"
696
- if isinstance(self.value, bool):
697
- return f"Bool({self.value})" if not unwrap_lit else str(self.value)
698
- if isinstance(self.value, int):
699
- return f"i64({self.value})" if not unwrap_lit else str(self.value)
700
- if isinstance(self.value, float):
701
- return f"f64({self.value})" if not unwrap_lit else str(self.value)
702
- if isinstance(self.value, str):
703
- return f"String({self.value!r})" if not unwrap_lit else repr(self.value)
704
- assert_never(self.value)
426
+ return hash(self.parts)
427
+
428
+ def __eq__(self, other: object) -> bool:
429
+ if not isinstance(other, LitDecl):
430
+ return False
431
+ return self.parts == other.parts
432
+
433
+ @property
434
+ def parts(self) -> tuple[type, LitType]:
435
+ return (type(self.value), self.value)
705
436
 
706
437
 
707
438
  @dataclass(frozen=True)
708
439
  class CallDecl:
709
440
  callable: CallableRef
441
+ # TODO: Can I make these not typed expressions?
710
442
  args: tuple[TypedExprDecl, ...] = ()
711
443
  # type parameters that were bound to the callable, if it is a classmethod
712
444
  # Used for pretty printing classmethod calls with type parameters
713
445
  bound_tp_params: tuple[JustTypeRef, ...] | None = None
714
- _cached_hash: int | None = None
715
446
 
716
447
  def __post_init__(self) -> None:
717
448
  if self.bound_tp_params and not isinstance(self.callable, ClassMethodRef):
@@ -719,302 +450,165 @@ class CallDecl:
719
450
  raise ValueError(msg)
720
451
 
721
452
  def __hash__(self) -> int:
722
- # Modified hash which will cache result for performance
723
- if self._cached_hash is None:
724
- res = hash((self.callable, self.args, self.bound_tp_params))
725
- object.__setattr__(self, "_cached_hash", res)
726
- return res
727
453
  return self._cached_hash
728
454
 
455
+ @cached_property
456
+ def _cached_hash(self) -> int:
457
+ return hash((self.callable, self.args, self.bound_tp_params))
458
+
729
459
  def __eq__(self, other: object) -> bool:
730
460
  # Override eq to use cached hash for perf
731
461
  if not isinstance(other, CallDecl):
732
462
  return False
733
463
  return hash(self) == hash(other)
734
464
 
735
- @classmethod
736
- def from_egg(
737
- cls,
738
- egraph: bindings.EGraph,
739
- decls: Declarations,
740
- return_tp: JustTypeRef,
741
- termdag: bindings.TermDag,
742
- term: bindings.TermApp,
743
- cache: dict[int, TypedExprDecl],
744
- ) -> ExprDecl:
745
- """
746
- Convert an egg expression into a typed expression by using the declerations.
747
465
 
748
- Also pass in the desired type to do type checking top down. Needed to disambiguate calls like (map-create)
749
- during expression extraction, where we always know the types.
466
+ ExprDecl: TypeAlias = VarDecl | LitDecl | CallDecl | PyObjectDecl
467
+
468
+
469
+ @dataclass(frozen=True)
470
+ class TypedExprDecl:
471
+ tp: JustTypeRef
472
+ expr: ExprDecl
473
+
474
+ def descendants(self) -> list[TypedExprDecl]:
750
475
  """
751
- from .type_constraint_solver import TypeConstraintError, TypeConstraintSolver
752
-
753
- # Find the first callable ref that matches the call
754
- for callable_ref in decls.get_callable_refs(term.name):
755
- # If this is a classmethod, we might need the type params that were bound for this type
756
- # This could be multiple types if the classmethod is ambiguous, like map create.
757
- possible_types: Iterable[JustTypeRef | None]
758
- fn_decl = decls.get_function_decl(callable_ref)
759
- if isinstance(callable_ref, ClassMethodRef):
760
- possible_types = decls.get_possible_types(callable_ref.class_name)
761
- cls_name = callable_ref.class_name
762
- else:
763
- possible_types = [None]
764
- cls_name = None
765
- for possible_type in possible_types:
766
- tcs = TypeConstraintSolver(decls)
767
- if possible_type and possible_type.args:
768
- tcs.bind_class(possible_type)
769
-
770
- try:
771
- arg_types, bound_tp_params = tcs.infer_arg_types(
772
- fn_decl.arg_types, fn_decl.return_type, fn_decl.var_arg_type, return_tp, cls_name
773
- )
774
- except TypeConstraintError:
775
- continue
776
- args: list[TypedExprDecl] = []
777
- for a, tp in zip(term.args, arg_types, strict=False):
778
- if a in cache:
779
- res = cache[a]
780
- else:
781
- res = TypedExprDecl.from_egg(egraph, decls, tp, termdag, termdag.nodes[a], cache)
782
- cache[a] = res
783
- args.append(res)
784
- return cls(callable_ref, tuple(args), bound_tp_params)
785
- raise ValueError(f"Could not find callable ref for call {term}")
786
-
787
- def to_egg(self, decls: Declarations) -> bindings._Expr:
788
- """Convert a Call to an egg Call."""
789
- # This was removed when we replaced declerations constants with our b/c of unextractable constants
790
- # # If this is a constant, then emit it just as a var, not as a call
791
- # if isinstance(self.callable, ConstantRef | ClassVariableRef):
792
- # decls.get_egg_fn
793
- # return bindings.Var(egg_fn)
794
- if hasattr(self, "_cached_egg"):
795
- return self._cached_egg
796
- egg_fn = decls.get_egg_fn(self.callable)
797
- res = bindings.Call(egg_fn, [a.to_egg(decls) for a in self.args])
798
- object.__setattr__(self, "_cached_egg", res)
799
- return res
800
-
801
- def pretty(self, context: PrettyContext, parens: bool = True, **kwargs) -> str: # noqa: C901
476
+ Returns a list of all the descendants of this expression.
802
477
  """
803
- Pretty print the call.
478
+ l = [self]
479
+ if isinstance(self.expr, CallDecl):
480
+ for a in self.expr.args:
481
+ l.extend(a.descendants())
482
+ return l
804
483
 
805
- :param parens: If true, wrap the call in parens if it is a binary method call.
806
- """
807
- if self in context.names:
808
- return context.names[self]
809
- ref, args = self.callable, [a.expr for a in self.args]
810
- # Special case !=
811
- if ref == FunctionRef("!="):
812
- return f"ne({args[0].pretty(context, parens=False, unwrap_lit=False)}).to({args[1].pretty(context, parens=False, unwrap_lit=False)})"
813
- function_decl = context.decls.get_function_decl(ref)
814
- # Determine how many of the last arguments are defaults, by iterating from the end and comparing the arg with the default
815
- n_defaults = 0
816
- for arg, default in zip(
817
- reversed(args), reversed(function_decl.arg_defaults), strict=not function_decl.var_arg_type
818
- ):
819
- if arg != default:
820
- break
821
- n_defaults += 1
822
- if n_defaults:
823
- args = args[:-n_defaults]
824
- if function_decl.mutates_first_arg:
825
- first_arg = args[0]
826
- expr_str = first_arg.pretty(context, parens=False)
827
- # copy an identifer expression iff it has multiple parents (b/c then we can't mutate it directly)
828
- has_multiple_parents = context.parents[first_arg] > 1
829
- expr_name = context.name_expr(function_decl.arg_types[0], expr_str, copy_identifier=has_multiple_parents)
830
- # Set the first arg to be the name of the mutated arg and return the name
831
- args[0] = VarDecl(expr_name)
832
- else:
833
- expr_name = None
834
- match ref:
835
- case FunctionRef(name):
836
- expr = _pretty_call(context, name, args)
837
- case ClassMethodRef(class_name, method_name):
838
- tp_ref = JustTypeRef(class_name, self.bound_tp_params or ())
839
- fn_str = tp_ref.pretty() if method_name == "__init__" else f"{tp_ref.pretty()}.{method_name}"
840
- expr = _pretty_call(context, fn_str, args)
841
- case MethodRef(_class_name, method_name):
842
- slf, *args = args
843
- slf = slf.pretty(context, unwrap_lit=False)
844
- match method_name:
845
- case _ if method_name in UNARY_METHODS:
846
- expr = f"{UNARY_METHODS[method_name]}{slf}"
847
- case _ if method_name in BINARY_METHODS:
848
- assert len(args) == 1
849
- expr = f"{slf} {BINARY_METHODS[method_name]} {args[0].pretty(context)}"
850
- if parens:
851
- expr = f"({expr})"
852
- case "__getitem__":
853
- assert len(args) == 1
854
- expr = f"{slf}[{args[0].pretty(context, parens=False)}]"
855
- case "__call__":
856
- expr = _pretty_call(context, slf, args)
857
- case "__delitem__":
858
- assert len(args) == 1
859
- expr = f"del {slf}[{args[0].pretty(context, parens=False)}]"
860
- case "__setitem__":
861
- assert len(args) == 2
862
- expr = (
863
- f"{slf}[{args[0].pretty(context, parens=False)}] = {args[1].pretty(context, parens=False)}"
864
- )
865
- case _:
866
- expr = _pretty_call(context, f"{slf}.{method_name}", args)
867
- case ConstantRef(name):
868
- expr = name
869
- case ClassVariableRef(class_name, variable_name):
870
- expr = f"{class_name}.{variable_name}"
871
- case PropertyRef(_class_name, property_name):
872
- expr = f"{args[0].pretty(context)}.{property_name}"
873
- case _:
874
- assert_never(ref)
875
- # If we have a name, then we mutated
876
- if expr_name:
877
- context.statements.append(expr)
878
- context.names[self] = expr_name
879
- return expr_name
880
-
881
- # We use a heuristic to decide whether to name this sub-expression as a variable
882
- # The rough goal is to reduce the number of newlines, given our line length of ~180
883
- # We determine it's worth making a new line for this expression if the total characters
884
- # it would take up is > than some constant (~ line length).
885
- n_parents = context.parents[self]
886
- line_diff: int = len(expr) - LINE_DIFFERENCE
887
- if n_parents > 1 and n_parents * line_diff > MAX_LINE_LENGTH:
888
- expr_name = context.name_expr(function_decl.return_type, expr, copy_identifier=False)
889
- context.names[self] = expr_name
890
- return expr_name
891
- return expr
892
-
893
-
894
- MAX_LINE_LENGTH = 110
895
- LINE_DIFFERENCE = 10
896
-
897
-
898
- def _plot_line_length(expr: object):
899
- """
900
- Plots the number of line lengths based on different max lengths
901
- """
902
- global MAX_LINE_LENGTH, LINE_DIFFERENCE
903
- import altair as alt
904
- import pandas as pd
905
484
 
906
- sizes = []
907
- for line_length in range(40, 180, 10):
908
- MAX_LINE_LENGTH = line_length
909
- for diff in range(0, 40, 5):
910
- LINE_DIFFERENCE = diff
911
- new_l = len(str(expr).split())
912
- sizes.append((line_length, diff, new_l))
485
+ ##
486
+ # Schedules
487
+ ##
913
488
 
914
- df = pd.DataFrame(sizes, columns=["MAX_LINE_LENGTH", "LENGTH_DIFFERENCE", "n"]) # noqa: PD901
915
489
 
916
- return alt.Chart(df).mark_rect().encode(x="MAX_LINE_LENGTH:O", y="LENGTH_DIFFERENCE:O", color="n:Q")
490
+ @dataclass(frozen=True)
491
+ class SaturateDecl:
492
+ schedule: ScheduleDecl
917
493
 
918
494
 
919
- def _pretty_call(context: PrettyContext, fn: str, args: Iterable[ExprDecl]) -> str:
920
- return f"{fn}({', '.join(a.pretty(context, parens=False) for a in args)})"
495
+ @dataclass(frozen=True)
496
+ class RepeatDecl:
497
+ schedule: ScheduleDecl
498
+ times: int
921
499
 
922
500
 
923
- @dataclass
924
- class PrettyContext:
925
- decls: Declarations
926
- # List of statements of "context" setting variable for the expr
927
- statements: list[str] = field(default_factory=list)
928
-
929
- names: dict[ExprDecl, str] = field(default_factory=dict)
930
- parents: dict[ExprDecl, int] = field(default_factory=lambda: defaultdict(lambda: 0))
931
- _traversed_exprs: set[ExprDecl] = field(default_factory=set)
932
-
933
- # Mapping of type to the number of times we have generated a name for that type, used to generate unique names
934
- _gen_name_types: dict[str, int] = field(default_factory=lambda: defaultdict(lambda: 0))
935
-
936
- def generate_name(self, typ: str) -> str:
937
- self._gen_name_types[typ] += 1
938
- return f"_{typ}_{self._gen_name_types[typ]}"
939
-
940
- def name_expr(self, expr_type: TypeOrVarRef, expr_str: str, copy_identifier: bool) -> str:
941
- tp_name = expr_type.to_just().name
942
- # If the thing we are naming is already a variable, we don't need to name it
943
- if expr_str.isidentifier():
944
- if copy_identifier:
945
- name = self.generate_name(tp_name)
946
- self.statements.append(f"{name} = copy({expr_str})")
947
- else:
948
- name = expr_str
949
- else:
950
- name = self.generate_name(tp_name)
951
- self.statements.append(f"{name} = {expr_str}")
952
- return name
501
+ @dataclass(frozen=True)
502
+ class SequenceDecl:
503
+ schedules: tuple[ScheduleDecl, ...]
504
+
953
505
 
954
- def render(self, expr: str) -> str:
955
- return "\n".join([*self.statements, expr])
506
+ @dataclass(frozen=True)
507
+ class RunDecl:
508
+ ruleset: str
509
+ until: tuple[FactDecl, ...] | None
956
510
 
957
- def traverse_for_parents(self, expr: ExprDecl) -> None:
958
- if expr in self._traversed_exprs:
959
- return
960
- self._traversed_exprs.add(expr)
961
- if isinstance(expr, CallDecl):
962
- for arg in set(expr.args):
963
- self.parents[arg.expr] += 1
964
- self.traverse_for_parents(arg.expr)
965
511
 
512
+ ScheduleDecl: TypeAlias = SaturateDecl | RepeatDecl | SequenceDecl | RunDecl
966
513
 
967
- ExprDecl: TypeAlias = VarDecl | LitDecl | CallDecl | PyObjectDecl
514
+ ##
515
+ # Facts
516
+ ##
968
517
 
969
518
 
970
519
  @dataclass(frozen=True)
971
- class TypedExprDecl:
520
+ class EqDecl:
972
521
  tp: JustTypeRef
973
- expr: ExprDecl
522
+ exprs: tuple[ExprDecl, ...]
974
523
 
975
- @classmethod
976
- def from_egg(
977
- cls,
978
- egraph: bindings.EGraph,
979
- decls: Declarations,
980
- tp: JustTypeRef,
981
- termdag: bindings.TermDag,
982
- term: bindings._Term,
983
- cache: dict[int, TypedExprDecl],
984
- ) -> TypedExprDecl:
985
- expr_decl: ExprDecl
986
- if isinstance(term, bindings.TermVar):
987
- expr_decl = VarDecl.from_egg(term)
988
- elif isinstance(term, bindings.TermLit):
989
- expr_decl = LitDecl.from_egg(term)
990
- elif isinstance(term, bindings.TermApp):
991
- if term.name == "py-object":
992
- expr_decl = PyObjectDecl.from_egg(egraph, termdag, term)
993
- else:
994
- expr_decl = CallDecl.from_egg(egraph, decls, tp, termdag, term, cache)
995
- else:
996
- assert_never(term)
997
- return cls(tp, expr_decl)
998
524
 
999
- def to_egg(self, decls: Declarations) -> bindings._Expr:
1000
- return self.expr.to_egg(decls)
525
+ @dataclass(frozen=True)
526
+ class ExprFactDecl:
527
+ typed_expr: TypedExprDecl
1001
528
 
1002
- def descendants(self) -> list[TypedExprDecl]:
1003
- """
1004
- Returns a list of all the descendants of this expression.
1005
- """
1006
- l = [self]
1007
- if isinstance(self.expr, CallDecl):
1008
- for a in self.expr.args:
1009
- l.extend(a.descendants())
1010
- return l
1011
529
 
530
+ FactDecl: TypeAlias = EqDecl | ExprFactDecl
1012
531
 
1013
- @dataclass
1014
- class ClassDecl:
1015
- methods: dict[str, FunctionDecl] = field(default_factory=dict)
1016
- class_methods: dict[str, FunctionDecl] = field(default_factory=dict)
1017
- class_variables: dict[str, JustTypeRef] = field(default_factory=dict)
1018
- properties: dict[str, FunctionDecl] = field(default_factory=dict)
1019
- preserved_methods: dict[str, Callable] = field(default_factory=dict)
1020
- type_vars: tuple[str, ...] = field(default=())
532
+ ##
533
+ # Actions
534
+ ##
535
+
536
+
537
+ @dataclass(frozen=True)
538
+ class LetDecl:
539
+ name: str
540
+ typed_expr: TypedExprDecl
541
+
542
+
543
+ @dataclass(frozen=True)
544
+ class SetDecl:
545
+ tp: JustTypeRef
546
+ call: CallDecl
547
+ rhs: ExprDecl
548
+
549
+
550
+ @dataclass(frozen=True)
551
+ class ExprActionDecl:
552
+ typed_expr: TypedExprDecl
553
+
554
+
555
+ @dataclass(frozen=True)
556
+ class ChangeDecl:
557
+ tp: JustTypeRef
558
+ call: CallDecl
559
+ change: Literal["delete", "subsume"]
560
+
561
+
562
+ @dataclass(frozen=True)
563
+ class UnionDecl:
564
+ tp: JustTypeRef
565
+ lhs: ExprDecl
566
+ rhs: ExprDecl
567
+
568
+
569
+ @dataclass(frozen=True)
570
+ class PanicDecl:
571
+ msg: str
572
+
573
+
574
+ ActionDecl: TypeAlias = LetDecl | SetDecl | ExprActionDecl | ChangeDecl | UnionDecl | PanicDecl
575
+
576
+
577
+ ##
578
+ # Commands
579
+ ##
580
+
581
+
582
+ @dataclass(frozen=True)
583
+ class RewriteDecl:
584
+ tp: JustTypeRef
585
+ lhs: ExprDecl
586
+ rhs: ExprDecl
587
+ conditions: tuple[FactDecl, ...]
588
+ subsume: bool
589
+
590
+
591
+ @dataclass(frozen=True)
592
+ class BiRewriteDecl:
593
+ tp: JustTypeRef
594
+ lhs: ExprDecl
595
+ rhs: ExprDecl
596
+ conditions: tuple[FactDecl, ...]
597
+
598
+
599
+ @dataclass(frozen=True)
600
+ class RuleDecl:
601
+ head: tuple[ActionDecl, ...]
602
+ body: tuple[FactDecl, ...]
603
+ name: str | None
604
+
605
+
606
+ RewriteOrRuleDecl: TypeAlias = RewriteDecl | BiRewriteDecl | RuleDecl
607
+
608
+
609
+ @dataclass(frozen=True)
610
+ class ActionCommandDecl:
611
+ action: ActionDecl
612
+
613
+
614
+ CommandDecl: TypeAlias = RewriteOrRuleDecl | ActionCommandDecl