egglog 6.1.0__cp311-none-win_amd64.whl → 7.1.0__cp311-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of egglog might be problematic. Click here for more details.

egglog/declarations.py CHANGED
@@ -1,18 +1,17 @@
1
1
  """
2
2
  Data only descriptions of the components of an egraph and the expressions.
3
+
4
+ We separate it into two pieces, the references and the declarations, so that we can report mutually recursive types.
3
5
  """
4
6
 
5
7
  from __future__ import annotations
6
8
 
7
- from collections import defaultdict
8
9
  from dataclasses import dataclass, field
9
- from inspect import Parameter, Signature
10
- from typing import TYPE_CHECKING, Protocol, TypeAlias, Union, runtime_checkable
10
+ from functools import cached_property
11
+ from typing import TYPE_CHECKING, Literal, Protocol, TypeAlias, Union, runtime_checkable
11
12
 
12
13
  from typing_extensions import Self, assert_never
13
14
 
14
- from . import bindings
15
-
16
15
  if TYPE_CHECKING:
17
16
  from collections.abc import Callable, Iterable
18
17
 
@@ -20,84 +19,68 @@ if TYPE_CHECKING:
20
19
  __all__ = [
21
20
  "Declarations",
22
21
  "DeclerationsLike",
23
- "upcast_decleratioons",
22
+ "DelayedDeclerations",
23
+ "upcast_declerations",
24
+ "Declarations",
24
25
  "JustTypeRef",
25
26
  "ClassTypeVarRef",
26
27
  "TypeRefWithVars",
27
28
  "TypeOrVarRef",
28
- "FunctionRef",
29
29
  "MethodRef",
30
30
  "ClassMethodRef",
31
+ "FunctionRef",
32
+ "ConstantRef",
31
33
  "ClassVariableRef",
32
- "FunctionCallableRef",
33
34
  "PropertyRef",
34
35
  "CallableRef",
35
- "ConstantRef",
36
36
  "FunctionDecl",
37
+ "RelationDecl",
38
+ "ConstantDecl",
39
+ "CallableDecl",
37
40
  "VarDecl",
38
- "LitType",
39
41
  "PyObjectDecl",
42
+ "PartialCallDecl",
43
+ "LitType",
40
44
  "LitDecl",
41
45
  "CallDecl",
42
46
  "ExprDecl",
43
47
  "TypedExprDecl",
44
48
  "ClassDecl",
45
- "PrettyContext",
46
- "GLOBAL_PY_OBJECT_SORT",
49
+ "RulesetDecl",
50
+ "CombinedRulesetDecl",
51
+ "SaturateDecl",
52
+ "RepeatDecl",
53
+ "SequenceDecl",
54
+ "RunDecl",
55
+ "ScheduleDecl",
56
+ "EqDecl",
57
+ "ExprFactDecl",
58
+ "FactDecl",
59
+ "LetDecl",
60
+ "SetDecl",
61
+ "ExprActionDecl",
62
+ "ChangeDecl",
63
+ "UnionDecl",
64
+ "PanicDecl",
65
+ "ActionDecl",
66
+ "RewriteDecl",
67
+ "BiRewriteDecl",
68
+ "RuleDecl",
69
+ "RewriteOrRuleDecl",
70
+ "ActionCommandDecl",
71
+ "CommandDecl",
72
+ "SpecialFunctions",
73
+ "FunctionSignature",
47
74
  ]
48
75
 
49
- # Create a global sort for python objects, so we can store them without an e-graph instance
50
- # Needed when serializing commands to egg commands when creating modules
51
- GLOBAL_PY_OBJECT_SORT = bindings.PyObjectSort()
52
-
53
- # Special methods which we might want to use as functions
54
- # Mapping to the operator they represent for pretty printing them
55
- # https://docs.python.org/3/reference/datamodel.html
56
- BINARY_METHODS = {
57
- "__lt__": "<",
58
- "__le__": "<=",
59
- "__eq__": "==",
60
- "__ne__": "!=",
61
- "__gt__": ">",
62
- "__ge__": ">=",
63
- # Numeric
64
- "__add__": "+",
65
- "__sub__": "-",
66
- "__mul__": "*",
67
- "__matmul__": "@",
68
- "__truediv__": "/",
69
- "__floordiv__": "//",
70
- "__mod__": "%",
71
- # TODO: Support divmod, with tuple return value
72
- # "__divmod__": "divmod",
73
- # TODO: Three arg power
74
- "__pow__": "**",
75
- "__lshift__": "<<",
76
- "__rshift__": ">>",
77
- "__and__": "&",
78
- "__xor__": "^",
79
- "__or__": "|",
80
- }
81
- REFLECTED_BINARY_METHODS = {
82
- "__radd__": "__add__",
83
- "__rsub__": "__sub__",
84
- "__rmul__": "__mul__",
85
- "__rmatmul__": "__matmul__",
86
- "__rtruediv__": "__truediv__",
87
- "__rfloordiv__": "__floordiv__",
88
- "__rmod__": "__mod__",
89
- "__rpow__": "__pow__",
90
- "__rlshift__": "__lshift__",
91
- "__rrshift__": "__rshift__",
92
- "__rand__": "__and__",
93
- "__rxor__": "__xor__",
94
- "__ror__": "__or__",
95
- }
96
- UNARY_METHODS = {
97
- "__pos__": "+",
98
- "__neg__": "-",
99
- "__invert__": "~",
100
- }
76
+
77
+ @dataclass
78
+ class DelayedDeclerations:
79
+ __egg_decls_thunk__: Callable[[], Declarations]
80
+
81
+ @property
82
+ def __egg_decls__(self) -> Declarations:
83
+ return self.__egg_decls_thunk__()
101
84
 
102
85
 
103
86
  @runtime_checkable
@@ -109,7 +92,7 @@ class HasDeclerations(Protocol):
109
92
  DeclerationsLike: TypeAlias = Union[HasDeclerations, None, "Declarations"]
110
93
 
111
94
 
112
- def upcast_decleratioons(declerations_like: Iterable[DeclerationsLike]) -> list[Declarations]:
95
+ def upcast_declerations(declerations_like: Iterable[DeclerationsLike]) -> list[Declarations]:
113
96
  d = []
114
97
  for l in declerations_like:
115
98
  if l is None:
@@ -125,30 +108,14 @@ def upcast_decleratioons(declerations_like: Iterable[DeclerationsLike]) -> list[
125
108
 
126
109
  @dataclass
127
110
  class Declarations:
128
- _functions: dict[str, FunctionDecl] = field(default_factory=dict)
111
+ _functions: dict[str, FunctionDecl | RelationDecl] = field(default_factory=dict)
112
+ _constants: dict[str, ConstantDecl] = field(default_factory=dict)
129
113
  _classes: dict[str, ClassDecl] = field(default_factory=dict)
130
- _constants: dict[str, JustTypeRef] = field(default_factory=dict)
131
-
132
- # Bidirectional mapping between egg function names and python callable references.
133
- # Note that there are possibly mutliple callable references for a single egg function name, like `+`
134
- # for both int and rational classes.
135
- _egg_fn_to_callable_refs: defaultdict[str, set[CallableRef]] = field(default_factory=lambda: defaultdict(set))
136
- _callable_ref_to_egg_fn: dict[CallableRef, str] = field(default_factory=dict)
137
-
138
- # Bidirectional mapping between egg sort names and python type references.
139
- _egg_sort_to_type_ref: dict[str, JustTypeRef] = field(default_factory=dict)
140
- _type_ref_to_egg_sort: dict[JustTypeRef, str] = field(default_factory=dict)
141
-
142
- # Mapping from egg name (of sort or function) to command to create it.
143
- _cmds: dict[str, bindings._Command] = field(default_factory=dict)
144
-
145
- def __post_init__(self) -> None:
146
- if "!=" not in self._egg_fn_to_callable_refs:
147
- self.register_callable_ref(FunctionRef("!="), "!=")
114
+ _rulesets: dict[str, RulesetDecl | CombinedRulesetDecl] = field(default_factory=lambda: {"": RulesetDecl([])})
148
115
 
149
116
  @classmethod
150
117
  def create(cls, *others: DeclerationsLike) -> Declarations:
151
- others = upcast_decleratioons(others)
118
+ others = upcast_declerations(others)
152
119
  if not others:
153
120
  return Declarations()
154
121
  first, *rest = others
@@ -159,25 +126,9 @@ class Declarations:
159
126
  return new
160
127
 
161
128
  def copy(self) -> Declarations:
162
- return Declarations(
163
- _functions=self._functions.copy(),
164
- _classes=self._classes.copy(),
165
- _constants=self._constants.copy(),
166
- _egg_fn_to_callable_refs=defaultdict(set, {k: v.copy() for k, v in self._egg_fn_to_callable_refs.items()}),
167
- _callable_ref_to_egg_fn=self._callable_ref_to_egg_fn.copy(),
168
- _egg_sort_to_type_ref=self._egg_sort_to_type_ref.copy(),
169
- _type_ref_to_egg_sort=self._type_ref_to_egg_sort.copy(),
170
- _cmds=self._cmds.copy(),
171
- )
172
-
173
- def __deepcopy__(self, memo: dict) -> Declarations:
174
- return self.copy()
175
-
176
- def add_cmd(self, name: str, cmd: bindings._Command) -> None:
177
- self._cmds[name] = cmd
178
-
179
- def list_cmds(self) -> list[bindings._Command]:
180
- return list(self._cmds.values())
129
+ new = Declarations()
130
+ new |= self
131
+ return new
181
132
 
182
133
  def update(self, *others: DeclerationsLike) -> None:
183
134
  for other in others:
@@ -200,82 +151,26 @@ class Declarations:
200
151
  """
201
152
  Updates the other decl with these values in palce.
202
153
  """
203
- # If cmds are == skip unioning for time savings
204
- # if set(self._cmds) == set(other._cmds) and self.record_cmds and other.record_cmds:
205
- # return self
206
154
  other._functions |= self._functions
207
155
  other._classes |= self._classes
208
156
  other._constants |= self._constants
209
- other._egg_sort_to_type_ref |= self._egg_sort_to_type_ref
210
- other._type_ref_to_egg_sort |= self._type_ref_to_egg_sort
211
- other._cmds |= self._cmds
212
- other._callable_ref_to_egg_fn |= self._callable_ref_to_egg_fn
213
- for egg_fn, callable_refs in self._egg_fn_to_callable_refs.items():
214
- other._egg_fn_to_callable_refs[egg_fn] |= callable_refs
215
-
216
- def set_function_decl(self, ref: FunctionCallableRef, decl: FunctionDecl) -> None:
217
- """
218
- Sets a function declaration for the given callable reference.
219
- """
157
+ other._rulesets |= self._rulesets
158
+
159
+ def get_callable_decl(self, ref: CallableRef) -> CallableDecl:
220
160
  match ref:
221
161
  case FunctionRef(name):
222
- if name in self._functions:
223
- raise ValueError(f"Function {name} already registered")
224
- self._functions[name] = decl
162
+ return self._functions[name]
163
+ case ConstantRef(name):
164
+ return self._constants[name]
225
165
  case MethodRef(class_name, method_name):
226
- if method_name in self._classes[class_name].methods:
227
- raise ValueError(f"Method {class_name}.{method_name} already registered")
228
- self._classes[class_name].methods[method_name] = decl
229
- case ClassMethodRef(class_name, method_name):
230
- if method_name in self._classes[class_name].class_methods:
231
- raise ValueError(f"Class method {class_name}.{method_name} already registered")
232
- self._classes[class_name].class_methods[method_name] = decl
166
+ return self._classes[class_name].methods[method_name]
167
+ case ClassVariableRef(class_name, name):
168
+ return self._classes[class_name].class_variables[name]
169
+ case ClassMethodRef(class_name, name):
170
+ return self._classes[class_name].class_methods[name]
233
171
  case PropertyRef(class_name, property_name):
234
- if property_name in self._classes[class_name].properties:
235
- raise ValueError(f"Property {class_name}.{property_name} already registered")
236
- self._classes[class_name].properties[property_name] = decl
237
- case _:
238
- assert_never(ref)
239
-
240
- def set_constant_type(self, ref: ConstantCallableRef, tp: JustTypeRef) -> None:
241
- match ref:
242
- case ConstantRef(name):
243
- if name in self._constants:
244
- raise ValueError(f"Constant {name} already registered")
245
- self._constants[name] = tp
246
- case ClassVariableRef(class_name, variable_name):
247
- if variable_name in self._classes[class_name].class_variables:
248
- raise ValueError(f"Class variable {class_name}.{variable_name} already registered")
249
- self._classes[class_name].class_variables[variable_name] = tp
250
- case _:
251
- assert_never(ref)
252
-
253
- def register_callable_ref(self, ref: CallableRef, egg_name: str) -> None:
254
- """
255
- Registers a callable reference with the given egg name.
256
-
257
- The callable's function needs to be registered first.
258
- """
259
- if ref in self._callable_ref_to_egg_fn:
260
- raise ValueError(f"Callable ref {ref} already registered")
261
- self._callable_ref_to_egg_fn[ref] = egg_name
262
- self._egg_fn_to_callable_refs[egg_name].add(ref)
263
-
264
- def get_callable_refs(self, egg_name: str) -> Iterable[CallableRef]:
265
- return self._egg_fn_to_callable_refs[egg_name]
266
-
267
- def get_egg_fn(self, ref: CallableRef) -> str:
268
- return self._callable_ref_to_egg_fn[ref]
269
-
270
- def get_egg_sort(self, ref: JustTypeRef) -> str:
271
- return self._type_ref_to_egg_sort[ref]
272
-
273
- def op_mapping(self) -> dict[str, str]:
274
- """
275
- Create a mapping of egglog function name to Python function name, for use in the serialized format
276
- for better visualization.
277
- """
278
- return {k: str(next(iter(v))) for k, v in self._egg_fn_to_callable_refs.items() if len(v) == 1}
172
+ return self._classes[class_name].properties[property_name]
173
+ assert_never(ref)
279
174
 
280
175
  def has_method(self, class_name: str, method_name: str) -> bool | None:
281
176
  """
@@ -285,138 +180,36 @@ class Declarations:
285
180
  return method_name in self._classes[class_name].methods
286
181
  return None
287
182
 
288
- def get_function_decl(self, ref: CallableRef) -> FunctionDecl:
289
- match ref:
290
- case ConstantRef(name):
291
- return self._constants[name].to_constant_function_decl()
292
- case ClassVariableRef(class_name, variable_name):
293
- return self._classes[class_name].class_variables[variable_name].to_constant_function_decl()
294
- case FunctionRef(name):
295
- return self._functions[name]
296
- case MethodRef(class_name, method_name):
297
- return self._classes[class_name].methods[method_name]
298
- case ClassMethodRef(class_name, method_name):
299
- return self._classes[class_name].class_methods[method_name]
300
- case PropertyRef(class_name, property_name):
301
- return self._classes[class_name].properties[property_name]
302
- assert_never(ref)
303
-
304
183
  def get_class_decl(self, name: str) -> ClassDecl:
305
184
  return self._classes[name]
306
185
 
307
- def get_possible_types(self, cls_name: str) -> frozenset[JustTypeRef]:
308
- """
309
- Given a class name, returns all possible registered types that it can be.
310
- """
311
- return frozenset(tp for tp in self._type_ref_to_egg_sort if tp.name == cls_name)
312
186
 
313
- def register_class(self, name: str, type_vars: tuple[str, ...], builtin: bool, egg_sort: str | None) -> None:
314
- # Register class first
315
- if name in self._classes:
316
- raise ValueError(f"Class {name} already registered")
317
- decl = ClassDecl(type_vars=type_vars)
318
- self._classes[name] = decl
319
- self.register_sort(JustTypeRef(name), builtin, egg_sort)
187
+ @dataclass
188
+ class ClassDecl:
189
+ egg_name: str | None = None
190
+ type_vars: tuple[str, ...] = ()
191
+ builtin: bool = False
192
+ class_methods: dict[str, FunctionDecl] = field(default_factory=dict)
193
+ # These have to be separate from class_methods so that printing them can be done easily
194
+ class_variables: dict[str, ConstantDecl] = field(default_factory=dict)
195
+ methods: dict[str, FunctionDecl] = field(default_factory=dict)
196
+ properties: dict[str, FunctionDecl] = field(default_factory=dict)
197
+ preserved_methods: dict[str, Callable] = field(default_factory=dict)
320
198
 
321
- def register_sort(self, ref: JustTypeRef, builtin: bool, egg_name: str | None = None) -> str:
322
- """
323
- Register a sort with the given name. If no name is given, one is generated.
324
199
 
325
- If this is a type called with generic args, register the generic args as well.
326
- """
327
- # If the sort is already registered, do nothing
328
- try:
329
- egg_sort = self.get_egg_sort(ref)
330
- except KeyError:
331
- pass
332
- else:
333
- return egg_sort
334
- egg_name = egg_name or ref.generate_egg_name()
335
- if egg_name in self._egg_sort_to_type_ref:
336
- raise ValueError(f"Sort {egg_name} is already registered.")
337
- self._egg_sort_to_type_ref[egg_name] = ref
338
- self._type_ref_to_egg_sort[ref] = egg_name
339
- if not builtin:
340
- self.add_cmd(
341
- egg_name,
342
- bindings.Sort(
343
- egg_name,
344
- (
345
- self.get_egg_sort(JustTypeRef(ref.name)),
346
- [bindings.Var(self.register_sort(arg, False)) for arg in ref.args],
347
- )
348
- if ref.args
349
- else None,
350
- ),
351
- )
352
-
353
- return egg_name
354
-
355
- def register_function_callable(
356
- self,
357
- ref: FunctionCallableRef,
358
- fn_decl: FunctionDecl,
359
- egg_name: str | None,
360
- cost: int | None,
361
- default: ExprDecl | None,
362
- merge: ExprDecl | None,
363
- merge_action: list[bindings._Action],
364
- unextractable: bool,
365
- builtin: bool,
366
- is_relation: bool = False,
367
- ) -> None:
368
- """
369
- Registers a callable with the given egg name.
200
+ @dataclass(frozen=True)
201
+ class RulesetDecl:
202
+ rules: list[RewriteOrRuleDecl]
370
203
 
371
- The callable's function needs to be registered first.
372
- """
373
- egg_name = egg_name or ref.generate_egg_name()
374
- self.register_callable_ref(ref, egg_name)
375
- self.set_function_decl(ref, fn_decl)
376
-
377
- # Skip generating the cmds if we don't want to record them, like for the builtins
378
- if builtin:
379
- return
380
-
381
- if fn_decl.var_arg_type is not None:
382
- msg = "egglog does not support variable arguments yet."
383
- raise NotImplementedError(msg)
384
- # Remove all vars from the type refs, raising an errory if we find one,
385
- # since we cannot create egg functions with vars
386
- arg_sorts = [self.register_sort(a.to_just(), False) for a in fn_decl.arg_types]
387
- cmd: bindings._Command
388
- if is_relation:
389
- assert not default
390
- assert not merge
391
- assert not merge_action
392
- assert not cost
393
- cmd = bindings.Relation(egg_name, arg_sorts)
394
- else:
395
- egg_fn_decl = bindings.FunctionDecl(
396
- egg_name,
397
- bindings.Schema(arg_sorts, self.register_sort(fn_decl.return_type.to_just(), False)),
398
- default.to_egg(self) if default else None,
399
- merge.to_egg(self) if merge else None,
400
- merge_action,
401
- cost,
402
- unextractable,
403
- )
404
- cmd = bindings.Function(egg_fn_decl)
405
- self.add_cmd(egg_name, cmd)
406
-
407
- def register_constant_callable(self, ref: ConstantCallableRef, type_ref: JustTypeRef, egg_name: str | None) -> None:
408
- egg_name = egg_name or ref.generate_egg_name()
409
- self.register_callable_ref(ref, egg_name)
410
- self.set_constant_type(ref, type_ref)
411
- egg_sort = self.register_sort(type_ref, False)
412
- # self.add_cmd(egg_name, bindings.Declare(egg_name, self.get_egg_sort(type_ref)))
413
- # Use function decleration instead of constant b/c constants cannot be extracted
414
- # https://github.com/egraphs-good/egglog/issues/334
415
- fn_decl = bindings.FunctionDecl(egg_name, bindings.Schema([], egg_sort))
416
- self.add_cmd(egg_name, bindings.Function(fn_decl))
417
-
418
- def register_preserved_method(self, class_: str, method: str, fn: Callable) -> None:
419
- self._classes[class_].preserved_methods[method] = fn
204
+ # Make hashable so when traversing for prettifying we can know which rulesets we have already
205
+ # made into strings
206
+ def __hash__(self) -> int:
207
+ return hash((type(self), tuple(self.rules)))
208
+
209
+
210
+ @dataclass(frozen=True)
211
+ class CombinedRulesetDecl:
212
+ rulesets: tuple[str, ...]
420
213
 
421
214
 
422
215
  # Have two different types of type refs, one that can include vars recursively and one that cannot.
@@ -427,38 +220,18 @@ class JustTypeRef:
427
220
  name: str
428
221
  args: tuple[JustTypeRef, ...] = ()
429
222
 
430
- def generate_egg_name(self) -> str:
431
- """
432
- Generates an egg sort name for this type reference by linearizing the type.
433
- """
434
- if not self.args:
435
- return self.name
436
- args = "_".join(a.generate_egg_name() for a in self.args)
437
- return f"{self.name}_{args}"
438
-
439
223
  def to_var(self) -> TypeRefWithVars:
440
224
  return TypeRefWithVars(self.name, tuple(a.to_var() for a in self.args))
441
225
 
442
- def pretty(self) -> str:
443
- if not self.args:
444
- return self.name
445
- args = ", ".join(a.pretty() for a in self.args)
446
- return f"{self.name}[{args}]"
226
+ def __str__(self) -> str:
227
+ if self.args:
228
+ return f"{self.name}[{', '.join(str(a) for a in self.args)}]"
229
+ return self.name
447
230
 
448
- def to_constant_function_decl(self) -> FunctionDecl:
449
- """
450
- Create a function declaration for a constant function.
451
231
 
452
- This is similar to how egglog compiles the `constant` command.
453
- """
454
- return FunctionDecl(
455
- arg_types=(),
456
- arg_names=(),
457
- arg_defaults=(),
458
- return_type=self.to_var(),
459
- mutates_first_arg=False,
460
- var_arg_type=None,
461
- )
232
+ ##
233
+ # Type references with vars
234
+ ##
462
235
 
463
236
 
464
237
  @dataclass(frozen=True)
@@ -473,7 +246,7 @@ class ClassTypeVarRef:
473
246
  msg = "egglog does not support generic classes yet."
474
247
  raise NotImplementedError(msg)
475
248
 
476
- def pretty(self) -> str:
249
+ def __str__(self) -> str:
477
250
  return self.name
478
251
 
479
252
 
@@ -485,30 +258,27 @@ class TypeRefWithVars:
485
258
  def to_just(self) -> JustTypeRef:
486
259
  return JustTypeRef(self.name, tuple(a.to_just() for a in self.args))
487
260
 
488
- def pretty(self) -> str:
489
- if not self.args:
490
- return self.name
491
- args = ", ".join(a.pretty() for a in self.args)
492
- return f"{self.name}[{args}]"
261
+ def __str__(self) -> str:
262
+ if self.args:
263
+ return f"{self.name}[{', '.join(str(a) for a in self.args)}]"
264
+ return self.name
493
265
 
494
266
 
495
267
  TypeOrVarRef: TypeAlias = ClassTypeVarRef | TypeRefWithVars
496
268
 
269
+ ##
270
+ # Callables References
271
+ ##
272
+
497
273
 
498
274
  @dataclass(frozen=True)
499
275
  class FunctionRef:
500
276
  name: str
501
277
 
502
- def generate_egg_name(self) -> str:
503
- return self.name
504
-
505
- def __str__(self) -> str:
506
- return self.name
507
-
508
278
 
509
- # Use this special character in place of the args, so that if the args are inlined
510
- # in the viz, they will replace it
511
- ARG = "·"
279
+ @dataclass(frozen=True)
280
+ class ConstantRef:
281
+ name: str
512
282
 
513
283
 
514
284
  @dataclass(frozen=True)
@@ -516,123 +286,122 @@ class MethodRef:
516
286
  class_name: str
517
287
  method_name: str
518
288
 
519
- def generate_egg_name(self) -> str:
520
- return f"{self.class_name}_{self.method_name}"
521
-
522
- def __str__(self) -> str: # noqa: PLR0911
523
- match self.method_name:
524
- case _ if self.method_name in UNARY_METHODS:
525
- return f"{UNARY_METHODS[self.method_name]}{ARG}"
526
- case _ if self.method_name in BINARY_METHODS:
527
- return f"({ARG} {BINARY_METHODS[self.method_name]} {ARG})"
528
- case "__getitem__":
529
- return f"{ARG}[{ARG}]"
530
- case "__call__":
531
- return f"{ARG}({ARG})"
532
- case "__delitem__":
533
- return f"del {ARG}[{ARG}]"
534
- case "__setitem__":
535
- return f"{ARG}[{ARG}] = {ARG}"
536
- return f"{ARG}.{self.method_name}"
537
-
538
289
 
539
290
  @dataclass(frozen=True)
540
291
  class ClassMethodRef:
541
292
  class_name: str
542
293
  method_name: str
543
294
 
544
- def generate_egg_name(self) -> str:
545
- return f"{self.class_name}_{self.method_name}"
546
295
 
547
- def __str__(self) -> str:
548
- if self.method_name == "__init__":
549
- return self.class_name
550
- return f"{self.class_name}.{self.method_name}"
296
+ @dataclass(frozen=True)
297
+ class ClassVariableRef:
298
+ class_name: str
299
+ var_name: str
551
300
 
552
301
 
553
302
  @dataclass(frozen=True)
554
- class ConstantRef:
555
- name: str
303
+ class PropertyRef:
304
+ class_name: str
305
+ property_name: str
556
306
 
557
- def generate_egg_name(self) -> str:
558
- return self.name
559
307
 
560
- def __str__(self) -> str:
561
- return self.name
308
+ CallableRef: TypeAlias = FunctionRef | ConstantRef | MethodRef | ClassMethodRef | ClassVariableRef | PropertyRef
562
309
 
563
310
 
564
- @dataclass(frozen=True)
565
- class ClassVariableRef:
566
- class_name: str
567
- variable_name: str
311
+ ##
312
+ # Callables
313
+ ##
568
314
 
569
- def generate_egg_name(self) -> str:
570
- return f"{self.class_name}_{self.variable_name}"
571
315
 
572
- def __str__(self) -> str:
573
- return f"{self.class_name}.{self.variable_name}"
316
+ @dataclass(frozen=True)
317
+ class RelationDecl:
318
+ arg_types: tuple[JustTypeRef, ...]
319
+ # List of defaults. None for any arg which doesn't have one.
320
+ arg_defaults: tuple[ExprDecl | None, ...]
321
+ egg_name: str | None
322
+
323
+ def to_function_decl(self) -> FunctionDecl:
324
+ return FunctionDecl(
325
+ FunctionSignature(
326
+ arg_types=tuple(a.to_var() for a in self.arg_types),
327
+ arg_names=tuple(f"__{i}" for i in range(len(self.arg_types))),
328
+ arg_defaults=self.arg_defaults,
329
+ return_type=TypeRefWithVars("Unit"),
330
+ ),
331
+ egg_name=self.egg_name,
332
+ default=LitDecl(None),
333
+ )
574
334
 
575
335
 
576
336
  @dataclass(frozen=True)
577
- class PropertyRef:
578
- class_name: str
579
- property_name: str
337
+ class ConstantDecl:
338
+ """
339
+ Same as `(declare)` in egglog
340
+ """
580
341
 
581
- def generate_egg_name(self) -> str:
582
- return f"{self.class_name}_{self.property_name}"
342
+ type_ref: JustTypeRef
343
+ egg_name: str | None = None
583
344
 
584
- def __str__(self) -> str:
585
- return f"{ARG}.{self.property_name}"
345
+ def to_function_decl(self) -> FunctionDecl:
346
+ return FunctionDecl(
347
+ FunctionSignature(return_type=self.type_ref.to_var()),
348
+ egg_name=self.egg_name,
349
+ )
586
350
 
587
351
 
588
- ConstantCallableRef: TypeAlias = ConstantRef | ClassVariableRef
589
- FunctionCallableRef: TypeAlias = FunctionRef | MethodRef | ClassMethodRef | PropertyRef
590
- CallableRef: TypeAlias = ConstantCallableRef | FunctionCallableRef
352
+ # special cases for partial function creation and application, which cannot use the normal python rules
353
+ SpecialFunctions: TypeAlias = Literal["fn-partial", "fn-app"]
591
354
 
592
355
 
593
356
  @dataclass(frozen=True)
594
- class FunctionDecl:
595
- arg_types: tuple[TypeOrVarRef, ...]
596
- # Is None for relation which doesn't have named args
597
- arg_names: tuple[str, ...] | None
598
- arg_defaults: tuple[ExprDecl | None, ...]
599
- return_type: TypeOrVarRef
600
- mutates_first_arg: bool
357
+ class FunctionSignature:
358
+ arg_types: tuple[TypeOrVarRef, ...] = ()
359
+ arg_names: tuple[str, ...] = ()
360
+ # List of defaults. None for any arg which doesn't have one.
361
+ arg_defaults: tuple[ExprDecl | None, ...] = ()
362
+ # If None, then the first arg is mutated and returned
363
+ return_type: TypeOrVarRef | None = None
601
364
  var_arg_type: TypeOrVarRef | None = None
602
365
 
603
- def __post_init__(self) -> None:
604
- # If we mutate the first arg, then the first arg should be the same type as the return
605
- if self.mutates_first_arg:
606
- assert self.arg_types[0] == self.return_type
607
-
608
- def to_signature(self, transform_default: Callable[[TypedExprDecl], object]) -> Signature:
609
- arg_names = self.arg_names or tuple(f"__{i}" for i in range(len(self.arg_types)))
610
- parameters = [
611
- Parameter(
612
- n,
613
- Parameter.POSITIONAL_OR_KEYWORD,
614
- default=transform_default(TypedExprDecl(t.to_just(), d)) if d else Parameter.empty,
615
- )
616
- for n, d, t in zip(arg_names, self.arg_defaults, self.arg_types, strict=True)
617
- ]
618
- if self.var_arg_type is not None:
619
- parameters.append(Parameter("__rest", Parameter.VAR_POSITIONAL))
620
- return Signature(parameters)
366
+ @property
367
+ def semantic_return_type(self) -> TypeOrVarRef:
368
+ """
369
+ The type that is returned by the function, which will be in the first arg if it mutates it.
370
+ """
371
+ return self.return_type or self.arg_types[0]
372
+
373
+ @property
374
+ def mutates(self) -> bool:
375
+ return self.return_type is None
621
376
 
622
377
 
623
378
  @dataclass(frozen=True)
624
- class VarDecl:
625
- name: str
379
+ class FunctionDecl:
380
+ signature: FunctionSignature | SpecialFunctions = field(default_factory=FunctionSignature)
381
+
382
+ # Egg params
383
+ builtin: bool = False
384
+ egg_name: str | None = None
385
+ cost: int | None = None
386
+ default: ExprDecl | None = None
387
+ on_merge: tuple[ActionDecl, ...] = ()
388
+ merge: ExprDecl | None = None
389
+ unextractable: bool = False
390
+
391
+ def to_function_decl(self) -> FunctionDecl:
392
+ return self
626
393
 
627
- @classmethod
628
- def from_egg(cls, var: bindings.TermVar) -> ExprDecl:
629
- return cls(var.name)
630
394
 
631
- def to_egg(self, _decls: Declarations) -> bindings.Var:
632
- return bindings.Var(self.name)
395
+ CallableDecl: TypeAlias = RelationDecl | ConstantDecl | FunctionDecl
633
396
 
634
- def pretty(self, context: PrettyContext, **kwargs) -> str:
635
- return self.name
397
+ ##
398
+ # Expressions
399
+ ##
400
+
401
+
402
+ @dataclass(frozen=True)
403
+ class VarDecl:
404
+ name: str
636
405
 
637
406
 
638
407
  @dataclass(frozen=True)
@@ -646,16 +415,14 @@ class PyObjectDecl:
646
415
  except TypeError:
647
416
  return id(self.value)
648
417
 
649
- @classmethod
650
- def from_egg(cls, egraph: bindings.EGraph, termdag: bindings.TermDag, term: bindings.TermApp) -> ExprDecl:
651
- call = bindings.termdag_term_to_expr(termdag, term)
652
- return cls(egraph.eval_py_object(call))
653
-
654
- def to_egg(self, _decls: Declarations) -> bindings._Expr:
655
- return GLOBAL_PY_OBJECT_SORT.store(self.value)
418
+ def __eq__(self, other: object) -> bool:
419
+ if not isinstance(other, PyObjectDecl):
420
+ return False
421
+ return self.parts == other.parts
656
422
 
657
- def pretty(self, context: PrettyContext, **kwargs) -> str:
658
- return repr(self.value)
423
+ @property
424
+ def parts(self) -> tuple[type, object]:
425
+ return (type(self.value), self.value)
659
426
 
660
427
 
661
428
  LitType: TypeAlias = int | str | float | bool | None
@@ -665,53 +432,30 @@ LitType: TypeAlias = int | str | float | bool | None
665
432
  class LitDecl:
666
433
  value: LitType
667
434
 
668
- @classmethod
669
- def from_egg(cls, lit: bindings.TermLit) -> ExprDecl:
670
- value = lit.value
671
- if isinstance(value, bindings.Unit):
672
- return cls(None)
673
- return cls(value.value)
674
-
675
- def to_egg(self, _decls: Declarations) -> bindings.Lit:
676
- if self.value is None:
677
- return bindings.Lit(bindings.Unit())
678
- if isinstance(self.value, bool):
679
- return bindings.Lit(bindings.Bool(self.value))
680
- if isinstance(self.value, int):
681
- return bindings.Lit(bindings.Int(self.value))
682
- if isinstance(self.value, float):
683
- return bindings.Lit(bindings.F64(self.value))
684
- if isinstance(self.value, str):
685
- return bindings.Lit(bindings.String(self.value))
686
- assert_never(self.value)
687
-
688
- def pretty(self, context: PrettyContext, unwrap_lit: bool = True, **kwargs) -> str:
435
+ def __hash__(self) -> int:
689
436
  """
690
- Returns a string representation of the literal.
691
-
692
- :param wrap_lit: If True, wraps the literal in a call to the literal constructor.
437
+ Include type in the hash so that 1.0 != 1
693
438
  """
694
- if self.value is None:
695
- return "Unit()"
696
- if isinstance(self.value, bool):
697
- return f"Bool({self.value})" if not unwrap_lit else str(self.value)
698
- if isinstance(self.value, int):
699
- return f"i64({self.value})" if not unwrap_lit else str(self.value)
700
- if isinstance(self.value, float):
701
- return f"f64({self.value})" if not unwrap_lit else str(self.value)
702
- if isinstance(self.value, str):
703
- return f"String({self.value!r})" if not unwrap_lit else repr(self.value)
704
- assert_never(self.value)
439
+ return hash(self.parts)
440
+
441
+ def __eq__(self, other: object) -> bool:
442
+ if not isinstance(other, LitDecl):
443
+ return False
444
+ return self.parts == other.parts
445
+
446
+ @property
447
+ def parts(self) -> tuple[type, LitType]:
448
+ return (type(self.value), self.value)
705
449
 
706
450
 
707
451
  @dataclass(frozen=True)
708
452
  class CallDecl:
709
453
  callable: CallableRef
454
+ # TODO: Can I make these not typed expressions?
710
455
  args: tuple[TypedExprDecl, ...] = ()
711
456
  # type parameters that were bound to the callable, if it is a classmethod
712
457
  # Used for pretty printing classmethod calls with type parameters
713
458
  bound_tp_params: tuple[JustTypeRef, ...] | None = None
714
- _cached_hash: int | None = None
715
459
 
716
460
  def __post_init__(self) -> None:
717
461
  if self.bound_tp_params and not isinstance(self.callable, ClassMethodRef):
@@ -719,252 +463,33 @@ class CallDecl:
719
463
  raise ValueError(msg)
720
464
 
721
465
  def __hash__(self) -> int:
722
- # Modified hash which will cache result for performance
723
- if self._cached_hash is None:
724
- res = hash((self.callable, self.args, self.bound_tp_params))
725
- object.__setattr__(self, "_cached_hash", res)
726
- return res
727
466
  return self._cached_hash
728
467
 
468
+ @cached_property
469
+ def _cached_hash(self) -> int:
470
+ return hash((self.callable, self.args, self.bound_tp_params))
471
+
729
472
  def __eq__(self, other: object) -> bool:
730
473
  # Override eq to use cached hash for perf
731
474
  if not isinstance(other, CallDecl):
732
475
  return False
733
476
  return hash(self) == hash(other)
734
477
 
735
- @classmethod
736
- def from_egg(
737
- cls,
738
- egraph: bindings.EGraph,
739
- decls: Declarations,
740
- return_tp: JustTypeRef,
741
- termdag: bindings.TermDag,
742
- term: bindings.TermApp,
743
- cache: dict[int, TypedExprDecl],
744
- ) -> ExprDecl:
745
- """
746
- Convert an egg expression into a typed expression by using the declerations.
747
478
 
748
- Also pass in the desired type to do type checking top down. Needed to disambiguate calls like (map-create)
749
- during expression extraction, where we always know the types.
750
- """
751
- from .type_constraint_solver import TypeConstraintError, TypeConstraintSolver
752
-
753
- # Find the first callable ref that matches the call
754
- for callable_ref in decls.get_callable_refs(term.name):
755
- # If this is a classmethod, we might need the type params that were bound for this type
756
- # This could be multiple types if the classmethod is ambiguous, like map create.
757
- possible_types: Iterable[JustTypeRef | None]
758
- fn_decl = decls.get_function_decl(callable_ref)
759
- if isinstance(callable_ref, ClassMethodRef):
760
- possible_types = decls.get_possible_types(callable_ref.class_name)
761
- cls_name = callable_ref.class_name
762
- else:
763
- possible_types = [None]
764
- cls_name = None
765
- for possible_type in possible_types:
766
- tcs = TypeConstraintSolver(decls)
767
- if possible_type and possible_type.args:
768
- tcs.bind_class(possible_type)
769
-
770
- try:
771
- arg_types, bound_tp_params = tcs.infer_arg_types(
772
- fn_decl.arg_types, fn_decl.return_type, fn_decl.var_arg_type, return_tp, cls_name
773
- )
774
- except TypeConstraintError:
775
- continue
776
- args: list[TypedExprDecl] = []
777
- for a, tp in zip(term.args, arg_types, strict=False):
778
- if a in cache:
779
- res = cache[a]
780
- else:
781
- res = TypedExprDecl.from_egg(egraph, decls, tp, termdag, termdag.nodes[a], cache)
782
- cache[a] = res
783
- args.append(res)
784
- return cls(callable_ref, tuple(args), bound_tp_params)
785
- raise ValueError(f"Could not find callable ref for call {term}")
786
-
787
- def to_egg(self, decls: Declarations) -> bindings._Expr:
788
- """Convert a Call to an egg Call."""
789
- # This was removed when we replaced declerations constants with our b/c of unextractable constants
790
- # # If this is a constant, then emit it just as a var, not as a call
791
- # if isinstance(self.callable, ConstantRef | ClassVariableRef):
792
- # decls.get_egg_fn
793
- # return bindings.Var(egg_fn)
794
- if hasattr(self, "_cached_egg"):
795
- return self._cached_egg
796
- egg_fn = decls.get_egg_fn(self.callable)
797
- res = bindings.Call(egg_fn, [a.to_egg(decls) for a in self.args])
798
- object.__setattr__(self, "_cached_egg", res)
799
- return res
800
-
801
- def pretty(self, context: PrettyContext, parens: bool = True, **kwargs) -> str: # noqa: C901
802
- """
803
- Pretty print the call.
804
-
805
- :param parens: If true, wrap the call in parens if it is a binary method call.
806
- """
807
- if self in context.names:
808
- return context.names[self]
809
- ref, args = self.callable, [a.expr for a in self.args]
810
- # Special case !=
811
- if ref == FunctionRef("!="):
812
- return f"ne({args[0].pretty(context, parens=False, unwrap_lit=False)}).to({args[1].pretty(context, parens=False, unwrap_lit=False)})"
813
- function_decl = context.decls.get_function_decl(ref)
814
- # Determine how many of the last arguments are defaults, by iterating from the end and comparing the arg with the default
815
- n_defaults = 0
816
- for arg, default in zip(
817
- reversed(args), reversed(function_decl.arg_defaults), strict=not function_decl.var_arg_type
818
- ):
819
- if arg != default:
820
- break
821
- n_defaults += 1
822
- if n_defaults:
823
- args = args[:-n_defaults]
824
- if function_decl.mutates_first_arg:
825
- first_arg = args[0]
826
- expr_str = first_arg.pretty(context, parens=False)
827
- # copy an identifer expression iff it has multiple parents (b/c then we can't mutate it directly)
828
- has_multiple_parents = context.parents[first_arg] > 1
829
- expr_name = context.name_expr(function_decl.arg_types[0], expr_str, copy_identifier=has_multiple_parents)
830
- # Set the first arg to be the name of the mutated arg and return the name
831
- args[0] = VarDecl(expr_name)
832
- else:
833
- expr_name = None
834
- match ref:
835
- case FunctionRef(name):
836
- expr = _pretty_call(context, name, args)
837
- case ClassMethodRef(class_name, method_name):
838
- tp_ref = JustTypeRef(class_name, self.bound_tp_params or ())
839
- fn_str = tp_ref.pretty() if method_name == "__init__" else f"{tp_ref.pretty()}.{method_name}"
840
- expr = _pretty_call(context, fn_str, args)
841
- case MethodRef(_class_name, method_name):
842
- slf, *args = args
843
- slf = slf.pretty(context, unwrap_lit=False)
844
- match method_name:
845
- case _ if method_name in UNARY_METHODS:
846
- expr = f"{UNARY_METHODS[method_name]}{slf}"
847
- case _ if method_name in BINARY_METHODS:
848
- assert len(args) == 1
849
- expr = f"{slf} {BINARY_METHODS[method_name]} {args[0].pretty(context)}"
850
- if parens:
851
- expr = f"({expr})"
852
- case "__getitem__":
853
- assert len(args) == 1
854
- expr = f"{slf}[{args[0].pretty(context, parens=False)}]"
855
- case "__call__":
856
- expr = _pretty_call(context, slf, args)
857
- case "__delitem__":
858
- assert len(args) == 1
859
- expr = f"del {slf}[{args[0].pretty(context, parens=False)}]"
860
- case "__setitem__":
861
- assert len(args) == 2
862
- expr = (
863
- f"{slf}[{args[0].pretty(context, parens=False)}] = {args[1].pretty(context, parens=False)}"
864
- )
865
- case _:
866
- expr = _pretty_call(context, f"{slf}.{method_name}", args)
867
- case ConstantRef(name):
868
- expr = name
869
- case ClassVariableRef(class_name, variable_name):
870
- expr = f"{class_name}.{variable_name}"
871
- case PropertyRef(_class_name, property_name):
872
- expr = f"{args[0].pretty(context)}.{property_name}"
873
- case _:
874
- assert_never(ref)
875
- # If we have a name, then we mutated
876
- if expr_name:
877
- context.statements.append(expr)
878
- context.names[self] = expr_name
879
- return expr_name
880
-
881
- # We use a heuristic to decide whether to name this sub-expression as a variable
882
- # The rough goal is to reduce the number of newlines, given our line length of ~180
883
- # We determine it's worth making a new line for this expression if the total characters
884
- # it would take up is > than some constant (~ line length).
885
- n_parents = context.parents[self]
886
- line_diff: int = len(expr) - LINE_DIFFERENCE
887
- if n_parents > 1 and n_parents * line_diff > MAX_LINE_LENGTH:
888
- expr_name = context.name_expr(function_decl.return_type, expr, copy_identifier=False)
889
- context.names[self] = expr_name
890
- return expr_name
891
- return expr
892
-
893
-
894
- MAX_LINE_LENGTH = 110
895
- LINE_DIFFERENCE = 10
896
-
897
-
898
- def _plot_line_length(expr: object):
899
- """
900
- Plots the number of line lengths based on different max lengths
479
+ @dataclass(frozen=True)
480
+ class PartialCallDecl:
901
481
  """
902
- global MAX_LINE_LENGTH, LINE_DIFFERENCE
903
- import altair as alt
904
- import pandas as pd
482
+ A partially applied function aka a function sort.
905
483
 
906
- sizes = []
907
- for line_length in range(40, 180, 10):
908
- MAX_LINE_LENGTH = line_length
909
- for diff in range(0, 40, 5):
910
- LINE_DIFFERENCE = diff
911
- new_l = len(str(expr).split())
912
- sizes.append((line_length, diff, new_l))
913
-
914
- df = pd.DataFrame(sizes, columns=["MAX_LINE_LENGTH", "LENGTH_DIFFERENCE", "n"]) # noqa: PD901
915
-
916
- return alt.Chart(df).mark_rect().encode(x="MAX_LINE_LENGTH:O", y="LENGTH_DIFFERENCE:O", color="n:Q")
917
-
918
-
919
- def _pretty_call(context: PrettyContext, fn: str, args: Iterable[ExprDecl]) -> str:
920
- return f"{fn}({', '.join(a.pretty(context, parens=False) for a in args)})"
921
-
922
-
923
- @dataclass
924
- class PrettyContext:
925
- decls: Declarations
926
- # List of statements of "context" setting variable for the expr
927
- statements: list[str] = field(default_factory=list)
928
-
929
- names: dict[ExprDecl, str] = field(default_factory=dict)
930
- parents: dict[ExprDecl, int] = field(default_factory=lambda: defaultdict(lambda: 0))
931
- _traversed_exprs: set[ExprDecl] = field(default_factory=set)
932
-
933
- # Mapping of type to the number of times we have generated a name for that type, used to generate unique names
934
- _gen_name_types: dict[str, int] = field(default_factory=lambda: defaultdict(lambda: 0))
935
-
936
- def generate_name(self, typ: str) -> str:
937
- self._gen_name_types[typ] += 1
938
- return f"_{typ}_{self._gen_name_types[typ]}"
939
-
940
- def name_expr(self, expr_type: TypeOrVarRef, expr_str: str, copy_identifier: bool) -> str:
941
- tp_name = expr_type.to_just().name
942
- # If the thing we are naming is already a variable, we don't need to name it
943
- if expr_str.isidentifier():
944
- if copy_identifier:
945
- name = self.generate_name(tp_name)
946
- self.statements.append(f"{name} = copy({expr_str})")
947
- else:
948
- name = expr_str
949
- else:
950
- name = self.generate_name(tp_name)
951
- self.statements.append(f"{name} = {expr_str}")
952
- return name
484
+ Note it does not need to have any args, in which case it's just a function pointer.
953
485
 
954
- def render(self, expr: str) -> str:
955
- return "\n".join([*self.statements, expr])
486
+ Seperated from the call decl so it's clear it is translated to a `unstable-fn` call.
487
+ """
956
488
 
957
- def traverse_for_parents(self, expr: ExprDecl) -> None:
958
- if expr in self._traversed_exprs:
959
- return
960
- self._traversed_exprs.add(expr)
961
- if isinstance(expr, CallDecl):
962
- for arg in set(expr.args):
963
- self.parents[arg.expr] += 1
964
- self.traverse_for_parents(arg.expr)
489
+ call: CallDecl
965
490
 
966
491
 
967
- ExprDecl: TypeAlias = VarDecl | LitDecl | CallDecl | PyObjectDecl
492
+ ExprDecl: TypeAlias = VarDecl | LitDecl | CallDecl | PyObjectDecl | PartialCallDecl
968
493
 
969
494
 
970
495
  @dataclass(frozen=True)
@@ -972,33 +497,6 @@ class TypedExprDecl:
972
497
  tp: JustTypeRef
973
498
  expr: ExprDecl
974
499
 
975
- @classmethod
976
- def from_egg(
977
- cls,
978
- egraph: bindings.EGraph,
979
- decls: Declarations,
980
- tp: JustTypeRef,
981
- termdag: bindings.TermDag,
982
- term: bindings._Term,
983
- cache: dict[int, TypedExprDecl],
984
- ) -> TypedExprDecl:
985
- expr_decl: ExprDecl
986
- if isinstance(term, bindings.TermVar):
987
- expr_decl = VarDecl.from_egg(term)
988
- elif isinstance(term, bindings.TermLit):
989
- expr_decl = LitDecl.from_egg(term)
990
- elif isinstance(term, bindings.TermApp):
991
- if term.name == "py-object":
992
- expr_decl = PyObjectDecl.from_egg(egraph, termdag, term)
993
- else:
994
- expr_decl = CallDecl.from_egg(egraph, decls, tp, termdag, term, cache)
995
- else:
996
- assert_never(term)
997
- return cls(tp, expr_decl)
998
-
999
- def to_egg(self, decls: Declarations) -> bindings._Expr:
1000
- return self.expr.to_egg(decls)
1001
-
1002
500
  def descendants(self) -> list[TypedExprDecl]:
1003
501
  """
1004
502
  Returns a list of all the descendants of this expression.
@@ -1010,11 +508,133 @@ class TypedExprDecl:
1010
508
  return l
1011
509
 
1012
510
 
1013
- @dataclass
1014
- class ClassDecl:
1015
- methods: dict[str, FunctionDecl] = field(default_factory=dict)
1016
- class_methods: dict[str, FunctionDecl] = field(default_factory=dict)
1017
- class_variables: dict[str, JustTypeRef] = field(default_factory=dict)
1018
- properties: dict[str, FunctionDecl] = field(default_factory=dict)
1019
- preserved_methods: dict[str, Callable] = field(default_factory=dict)
1020
- type_vars: tuple[str, ...] = field(default=())
511
+ ##
512
+ # Schedules
513
+ ##
514
+
515
+
516
+ @dataclass(frozen=True)
517
+ class SaturateDecl:
518
+ schedule: ScheduleDecl
519
+
520
+
521
+ @dataclass(frozen=True)
522
+ class RepeatDecl:
523
+ schedule: ScheduleDecl
524
+ times: int
525
+
526
+
527
+ @dataclass(frozen=True)
528
+ class SequenceDecl:
529
+ schedules: tuple[ScheduleDecl, ...]
530
+
531
+
532
+ @dataclass(frozen=True)
533
+ class RunDecl:
534
+ ruleset: str
535
+ until: tuple[FactDecl, ...] | None
536
+
537
+
538
+ ScheduleDecl: TypeAlias = SaturateDecl | RepeatDecl | SequenceDecl | RunDecl
539
+
540
+ ##
541
+ # Facts
542
+ ##
543
+
544
+
545
+ @dataclass(frozen=True)
546
+ class EqDecl:
547
+ tp: JustTypeRef
548
+ exprs: tuple[ExprDecl, ...]
549
+
550
+
551
+ @dataclass(frozen=True)
552
+ class ExprFactDecl:
553
+ typed_expr: TypedExprDecl
554
+
555
+
556
+ FactDecl: TypeAlias = EqDecl | ExprFactDecl
557
+
558
+ ##
559
+ # Actions
560
+ ##
561
+
562
+
563
+ @dataclass(frozen=True)
564
+ class LetDecl:
565
+ name: str
566
+ typed_expr: TypedExprDecl
567
+
568
+
569
+ @dataclass(frozen=True)
570
+ class SetDecl:
571
+ tp: JustTypeRef
572
+ call: CallDecl
573
+ rhs: ExprDecl
574
+
575
+
576
+ @dataclass(frozen=True)
577
+ class ExprActionDecl:
578
+ typed_expr: TypedExprDecl
579
+
580
+
581
+ @dataclass(frozen=True)
582
+ class ChangeDecl:
583
+ tp: JustTypeRef
584
+ call: CallDecl
585
+ change: Literal["delete", "subsume"]
586
+
587
+
588
+ @dataclass(frozen=True)
589
+ class UnionDecl:
590
+ tp: JustTypeRef
591
+ lhs: ExprDecl
592
+ rhs: ExprDecl
593
+
594
+
595
+ @dataclass(frozen=True)
596
+ class PanicDecl:
597
+ msg: str
598
+
599
+
600
+ ActionDecl: TypeAlias = LetDecl | SetDecl | ExprActionDecl | ChangeDecl | UnionDecl | PanicDecl
601
+
602
+
603
+ ##
604
+ # Commands
605
+ ##
606
+
607
+
608
+ @dataclass(frozen=True)
609
+ class RewriteDecl:
610
+ tp: JustTypeRef
611
+ lhs: ExprDecl
612
+ rhs: ExprDecl
613
+ conditions: tuple[FactDecl, ...]
614
+ subsume: bool
615
+
616
+
617
+ @dataclass(frozen=True)
618
+ class BiRewriteDecl:
619
+ tp: JustTypeRef
620
+ lhs: ExprDecl
621
+ rhs: ExprDecl
622
+ conditions: tuple[FactDecl, ...]
623
+
624
+
625
+ @dataclass(frozen=True)
626
+ class RuleDecl:
627
+ head: tuple[ActionDecl, ...]
628
+ body: tuple[FactDecl, ...]
629
+ name: str | None
630
+
631
+
632
+ RewriteOrRuleDecl: TypeAlias = RewriteDecl | BiRewriteDecl | RuleDecl
633
+
634
+
635
+ @dataclass(frozen=True)
636
+ class ActionCommandDecl:
637
+ action: ActionDecl
638
+
639
+
640
+ CommandDecl: TypeAlias = RewriteOrRuleDecl | ActionCommandDecl