egglog 0.4.0__pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of egglog might be problematic. Click here for more details.

egglog/declarations.py ADDED
@@ -0,0 +1,934 @@
1
+ """
2
+ Data only descriptions of the components of an egraph and the expressions.
3
+ """
4
+
5
+ from __future__ import annotations
6
+
7
+ import itertools
8
+ from abc import ABC, abstractmethod
9
+ from collections import defaultdict
10
+ from dataclasses import dataclass, field
11
+ from typing import ClassVar, Iterable, Optional, Union
12
+
13
+ from typing_extensions import assert_never
14
+
15
+ from . import bindings
16
+
17
+ __all__ = [
18
+ "Declarations",
19
+ "ModuleDeclarations",
20
+ "JustTypeRef",
21
+ "ClassTypeVarRef",
22
+ "TypeRefWithVars",
23
+ "TypeOrVarRef",
24
+ "FunctionRef",
25
+ "MethodRef",
26
+ "ClassMethodRef",
27
+ "ClassVariableRef",
28
+ "FunctionCallableRef",
29
+ "CallableRef",
30
+ "ConstantRef",
31
+ "FunctionDecl",
32
+ "VarDecl",
33
+ "LitType",
34
+ "LitDecl",
35
+ "CallDecl",
36
+ "ExprDecl",
37
+ "TypedExprDecl",
38
+ "ClassDecl",
39
+ "Command",
40
+ "Action",
41
+ "ExprAction",
42
+ "Fact",
43
+ "Rewrite",
44
+ "BiRewrite",
45
+ "Eq",
46
+ "ExprFact",
47
+ "Rule",
48
+ "Let",
49
+ "Set",
50
+ "Delete",
51
+ "Union_",
52
+ "Panic",
53
+ "Action",
54
+ "Schedule",
55
+ "Sequence",
56
+ "Run",
57
+ ]
58
+ # Special methods which we might want to use as functions
59
+ # Mapping to the operator they represent for pretty printing them
60
+ # https://docs.python.org/3/reference/datamodel.html
61
+ BINARY_METHODS = {
62
+ "__lt__": "<",
63
+ "__le__": "<=",
64
+ "__eq__": "==",
65
+ "__ne__": "!=",
66
+ "__gt__": ">",
67
+ "__ge__": ">=",
68
+ # Numeric
69
+ "__add__": "+",
70
+ "__sub__": "-",
71
+ "__mul__": "*",
72
+ "__matmul__": "@",
73
+ "__truediv__": "/",
74
+ "__floordiv__": "//",
75
+ "__mod__": "%",
76
+ "__divmod__": "divmod",
77
+ "__pow__": "**",
78
+ "__lshift__": "<<",
79
+ "__rshift__": ">>",
80
+ "__and__": "&",
81
+ "__xor__": "^",
82
+ "__or__": "|",
83
+ }
84
+ UNARY_METHODS = {
85
+ "__pos__": "+",
86
+ "__neg__": "-",
87
+ "__invert__": "~",
88
+ }
89
+
90
+
91
@dataclass
class Declarations:
    """
    The declarations registered by a single module.

    Stores function, class and constant declarations keyed by their Python-side names. It also
    keeps bidirectional mappings between Python references (callables and types) and the egg
    names they are emitted under.
    """

    # Top-level function declarations, keyed by function name.
    _functions: dict[str, FunctionDecl] = field(default_factory=dict)
    # Class declarations, keyed by class name.
    _classes: dict[str, ClassDecl] = field(default_factory=dict)
    # Types of top-level constants, keyed by constant name.
    _constants: dict[str, JustTypeRef] = field(default_factory=dict)

    # Bidirectional mapping between egg function names and python callable references.
    # Note that there are possibly multiple callable references for a single egg function name, like `+`
    # for both int and rational classes.
    _egg_fn_to_callable_refs: defaultdict[str, set[CallableRef]] = field(default_factory=lambda: defaultdict(set))
    _callable_ref_to_egg_fn: dict[CallableRef, str] = field(default_factory=dict)

    # Bidirectional mapping between egg sort names and python type references.
    _egg_sort_to_type_ref: dict[str, JustTypeRef] = field(default_factory=dict)
    _type_ref_to_egg_sort: dict[JustTypeRef, str] = field(default_factory=dict)

    def set_function_decl(self, ref: FunctionCallableRef, decl: FunctionDecl) -> None:
        """
        Sets a function declaration for the given callable reference.

        :raises ValueError: if a declaration is already stored for ``ref``.
        :raises KeyError: if ``ref`` is a (class) method of a class not registered here.
        """
        if isinstance(ref, FunctionRef):
            if ref.name in self._functions:
                raise ValueError(f"Function {ref.name} already registered")
            self._functions[ref.name] = decl
        elif isinstance(ref, MethodRef):
            if ref.method_name in self._classes[ref.class_name].methods:
                raise ValueError(f"Method {ref.class_name}.{ref.method_name} already registered")
            self._classes[ref.class_name].methods[ref.method_name] = decl
        elif isinstance(ref, ClassMethodRef):
            if ref.method_name in self._classes[ref.class_name].class_methods:
                raise ValueError(f"Class method {ref.class_name}.{ref.method_name} already registered")
            self._classes[ref.class_name].class_methods[ref.method_name] = decl
        else:
            assert_never(ref)

    def set_constant_type(self, ref: ConstantCallableRef, tp: JustTypeRef) -> None:
        """
        Record the type of a constant or class variable.

        :raises ValueError: if a type is already stored for ``ref``.
        :raises KeyError: if ``ref`` is a class variable of a class not registered here.
        """
        if isinstance(ref, ConstantRef):
            if ref.name in self._constants:
                raise ValueError(f"Constant {ref.name} already registered")
            self._constants[ref.name] = tp
        elif isinstance(ref, ClassVariableRef):
            if ref.variable_name in self._classes[ref.class_name].class_variables:
                raise ValueError(f"Class variable {ref.class_name}.{ref.variable_name} already registered")
            self._classes[ref.class_name].class_variables[ref.variable_name] = tp
        else:
            assert_never(ref)

    def register_callable_ref(self, ref: CallableRef, egg_name: str) -> None:
        """
        Registers a callable reference with the given egg name, in both directions.

        This only records the name mapping; it does not check that a declaration exists for ``ref``.
        Several refs may share one egg name.

        :raises ValueError: if ``ref`` is already registered.
        """
        if ref in self._callable_ref_to_egg_fn:
            raise ValueError(f"Callable ref {ref} already registered")
        self._callable_ref_to_egg_fn[ref] = egg_name
        self._egg_fn_to_callable_refs[egg_name].add(ref)

    def get_function_decl(self, ref: FunctionCallableRef) -> FunctionDecl:
        """
        Return the stored declaration for ``ref``.

        :raises KeyError: if it is not stored here.
        """
        if isinstance(ref, FunctionRef):
            return self._functions[ref.name]
        elif isinstance(ref, MethodRef):
            return self._classes[ref.class_name].methods[ref.method_name]
        elif isinstance(ref, ClassMethodRef):
            return self._classes[ref.class_name].class_methods[ref.method_name]
        assert_never(ref)

    def get_constant_type(self, ref: ConstantCallableRef) -> JustTypeRef:
        """
        Return the stored type for a constant or class variable.

        :raises KeyError: if it is not stored here.
        """
        if isinstance(ref, ConstantRef):
            return self._constants[ref.name]
        elif isinstance(ref, ClassVariableRef):
            return self._classes[ref.class_name].class_variables[ref.variable_name]
        assert_never(ref)

    def get_callable_refs(self, egg_name: str) -> Iterable[CallableRef]:
        """
        Return the callable refs registered under ``egg_name``; empty if there are none.

        Uses ``.get`` instead of indexing: indexing the ``defaultdict`` would insert an empty set
        for every unknown name looked up. A tuple snapshot is returned so callers cannot mutate the
        internal set; it keeps the set's iteration order.
        """
        return tuple(self._egg_fn_to_callable_refs.get(egg_name, ()))

    def get_egg_fn(self, ref: CallableRef) -> str:
        """
        Return the egg function name ``ref`` was registered under.

        :raises KeyError: if ``ref`` is not registered here.
        """
        return self._callable_ref_to_egg_fn[ref]

    def get_egg_sort(self, ref: JustTypeRef) -> str:
        """
        Return the egg sort name registered for ``ref``.

        :raises KeyError: if ``ref`` is not registered here.
        """
        return self._type_ref_to_egg_sort[ref]
172
+
173
+
174
@dataclass
class ModuleDeclarations:
    """
    A set of working declarations for a module.

    All registrations are written into ``_decl``. Lookups search ``_decl`` first, then each entry
    of ``_included_decls`` in order, and return the first hit.
    """

    # The module's own declarations, which we can edit
    _decl: Declarations
    # A list of other declarations we can use, but not edit
    _included_decls: list[Declarations] = field(default_factory=list, repr=False)

    @property
    def all_decls(self) -> Iterable[Declarations]:
        """This module's own declarations, followed by the included ones."""
        return itertools.chain([self._decl], self._included_decls)

    def get_function_decl(self, ref: CallableRef) -> FunctionDecl:
        """
        Look up the function declaration for ``ref`` across all declarations.

        Constants and class variables have no stored ``FunctionDecl``. For them, a zero-argument
        declaration returning their type is built with ``to_constant_function_decl``.

        :raises KeyError: if ``ref`` is not found in any declarations.
        """
        if isinstance(ref, (ClassVariableRef, ConstantRef)):
            for decls in self.all_decls:
                try:
                    return decls.get_constant_type(ref).to_constant_function_decl()
                except KeyError:
                    pass
            raise KeyError(f"Constant {ref} not found")
        elif isinstance(ref, (FunctionRef, MethodRef, ClassMethodRef)):
            for decls in self.all_decls:
                try:
                    return decls.get_function_decl(ref)
                except KeyError:
                    pass
            raise KeyError(f"Function {ref} not found")
        else:
            assert_never(ref)

    def get_callable_refs(self, egg_name: str) -> Iterable[CallableRef]:
        """All callable refs registered under ``egg_name``, across every set of declarations."""
        return itertools.chain.from_iterable(decls.get_callable_refs(egg_name) for decls in self.all_decls)

    def get_egg_fn(self, ref: CallableRef) -> str:
        """
        Return the egg function name for ``ref`` from the first declarations that know it.

        :raises KeyError: if no declarations contain ``ref``.
        """
        for decls in self.all_decls:
            try:
                return decls.get_egg_fn(ref)
            except KeyError:
                pass
        raise KeyError(f"Callable ref {ref} not found")

    def get_egg_sort(self, ref: JustTypeRef) -> str:
        """
        Return the egg sort name for ``ref`` from the first declarations that know it.

        :raises KeyError: if no declarations contain ``ref``.
        """
        for decls in self.all_decls:
            try:
                return decls.get_egg_sort(ref)
            except KeyError:
                pass
        raise KeyError(f"Type {ref} not found")

    def get_class_decl(self, name: str) -> ClassDecl:
        """
        Return the class declaration named ``name`` from the first declarations that contain it.

        :raises KeyError: if no declarations contain the class.
        """
        for decls in self.all_decls:
            try:
                return decls._classes[name]
            except KeyError:
                pass
        raise KeyError(f"Class {name} not found")

    def get_registered_class_args(self, cls_name: str) -> tuple[JustTypeRef, ...]:
        """
        Given a class name, return the type args of the first registered sort for that class
        that has args, or ``()`` if there is none.
        """
        for decl in self.all_decls:
            for tp in decl._type_ref_to_egg_sort.keys():
                if tp.name == cls_name and tp.args:
                    return tp.args
        return ()

    def register_class(self, name: str, n_type_vars: int, egg_sort: Optional[str]) -> Iterable[bindings._Command]:
        """
        Register a new class in this module, plus the sort for its un-parameterized type.

        :param egg_sort: egg sort name to use; when ``None``, a name is generated.
        :returns: the commands needed to declare the sort.
        :raises ValueError: if the class is already registered in this module.
        """
        # Register class first
        if name in self._decl._classes:
            raise ValueError(f"Class {name} already registered")
        decl = ClassDecl(n_type_vars=n_type_vars)
        self._decl._classes[name] = decl
        _egg_sort, cmds = self.register_sort(JustTypeRef(name), egg_sort)
        return cmds

    def register_sort(
        self, ref: JustTypeRef, egg_name: Optional[str] = None
    ) -> tuple[str, Iterable[bindings._Command]]:
        """
        Register a sort with the given name. If no name is given, one is generated.

        If this is a type called with generic args, register the generic args as well.

        :returns: the egg sort name and the commands declaring it. If ``ref`` is already known
            (here or in an included module), returns its existing name and no commands.
        :raises ValueError: if ``egg_name`` is already used for another sort in this module.
        """
        # If the sort is already registered, do nothing
        try:
            egg_sort = self.get_egg_sort(ref)
        except KeyError:
            pass
        else:
            return (egg_sort, [])
        egg_name = egg_name or ref.generate_egg_name()
        if egg_name in self._decl._egg_sort_to_type_ref:
            raise ValueError(f"Sort {egg_name} is already registered.")
        self._decl._egg_sort_to_type_ref[egg_name] = ref
        self._decl._type_ref_to_egg_sort[ref] = egg_name
        # The sort itself is registered eagerly above. Any type-argument sorts are registered
        # by `ref.to_commands`, which runs only as the returned commands are iterated.
        return egg_name, ref.to_commands(self)

    def register_function_callable(
        self,
        ref: FunctionCallableRef,
        fn_decl: FunctionDecl,
        egg_name: Optional[str],
        cost: Optional[int],
        default: Optional[ExprDecl],
        merge: Optional[ExprDecl],
        merge_action: Iterable[Action],
    ) -> Iterable[bindings._Command]:
        """
        Register ``ref`` under ``egg_name`` (or a generated name), store ``fn_decl`` for it, and
        return the commands that declare the egg function.
        """
        egg_name = egg_name or ref.generate_egg_name()
        self._decl.register_callable_ref(ref, egg_name)
        self._decl.set_function_decl(ref, fn_decl)
        return fn_decl.to_commands(self, egg_name, cost, default, merge, merge_action)

    def register_constant_callable(
        self, ref: ConstantCallableRef, type_ref: JustTypeRef, egg_name: Optional[str]
    ) -> Iterable[bindings._Command]:
        """
        Register a constant or class variable of type ``type_ref`` and return its commands.

        The same egg name (``egg_name`` if given, otherwise a generated one) is used both for the
        ref -> egg-function mapping and for the declared function. Otherwise ``get_egg_fn`` would
        return a name that was never declared.
        """
        egg_function = egg_name or ref.generate_egg_name()
        self._decl.register_callable_ref(ref, egg_function)
        self._decl.set_constant_type(ref, type_ref)
        # Create a function declaration for a constant function. This is similar to how egglog compiles
        # the `declare` command.
        return type_ref.to_constant_function_decl().to_commands(self, egg_function)
303
+
304
+
305
# Have two different types of type refs, one that can include vars recursively and one that cannot.
# We only use the one with vars for classmethods and methods, and the other one for egg references as
# well as runtime values.
@dataclass(frozen=True)
class JustTypeRef:
    """
    A type reference with no type variables: a name plus (possibly empty) concrete type arguments.
    """

    name: str
    args: tuple[JustTypeRef, ...] = ()

    def generate_egg_name(self) -> str:
        """
        Linearize this type into an egg sort name, e.g. ``Map[i64, Unit]``.
        """
        if self.args:
            inner = ", ".join([arg.generate_egg_name() for arg in self.args])
            return f"{self.name}[{inner}]"
        return self.name

    def to_commands(self, mod_decls: ModuleDeclarations) -> Iterable[bindings._Command]:
        """
        Yield the commands declaring this sort, preceded by commands for any argument sorts.

        This is a generator. Each argument's sort is registered on ``mod_decls`` as iteration
        reaches it.
        """
        sort_name = mod_decls.get_egg_sort(self)
        param_vars: list[bindings._Expr] = []
        for type_arg in self.args:
            arg_sort, arg_cmds = mod_decls.register_sort(type_arg)
            param_vars.append(bindings.Var(arg_sort))
            yield from arg_cmds
        presort_and_args = (self.name, param_vars) if param_vars else None
        yield bindings.Sort(sort_name, presort_and_args)

    def to_var(self) -> TypeRefWithVars:
        """Lift into the var-capable representation (which simply contains no vars)."""
        converted = tuple(arg.to_var() for arg in self.args)
        return TypeRefWithVars(self.name, converted)

    def pretty(self) -> str:
        """Render in Python subscript syntax, e.g. ``Map[i64, Unit]``."""
        if self.args:
            inner = ", ".join([arg.pretty() for arg in self.args])
            return f"{self.name}[{inner}]"
        return self.name

    def to_constant_function_decl(self) -> FunctionDecl:
        """
        Create a function declaration for a constant function. This is similar to how egglog compiles
        the `constant` command.
        """
        return FunctionDecl(arg_types=(), return_type=self.to_var(), var_arg_type=None)
349
+
350
+
351
@dataclass(frozen=True)
class ClassTypeVarRef:
    """
    Stands for the type variable at position ``index`` of a generic class.
    """

    index: int

    def to_just(self) -> JustTypeRef:
        """Always raises: a class type variable cannot be turned into a concrete type."""
        message = "egglog does not support generic classes yet."
        raise NotImplementedError(message)
362
+
363
+
364
@dataclass(frozen=True)
class TypeRefWithVars:
    """
    A type reference whose arguments may contain class type variables.
    """

    name: str
    args: tuple[TypeOrVarRef, ...] = ()

    def to_just(self) -> JustTypeRef:
        """
        Convert to a `JustTypeRef` by converting every argument recursively.

        Propagates whatever an argument's ``to_just`` raises. For a `ClassTypeVarRef`, that is
        ``NotImplementedError``.
        """
        concrete_args = tuple([arg.to_just() for arg in self.args])
        return JustTypeRef(self.name, concrete_args)


# A type reference that is either a (possibly parameterized) type or a class type variable.
TypeOrVarRef = Union[ClassTypeVarRef, TypeRefWithVars]
374
+
375
+
376
@dataclass(frozen=True)
class FunctionRef:
    """Reference to a top-level function by its Python name."""

    name: str

    def generate_egg_name(self) -> str:
        """The default egg name is the Python name unchanged."""
        default_name = self.name
        return default_name
382
+
383
+
384
@dataclass(frozen=True)
class MethodRef:
    """Reference to an instance method ``method_name`` of class ``class_name``."""

    class_name: str
    method_name: str

    def generate_egg_name(self) -> str:
        """Default egg name: ``<class>.<method>``."""
        owner, member = self.class_name, self.method_name
        return f"{owner}.{member}"
391
+
392
+
393
@dataclass(frozen=True)
class ClassMethodRef:
    """Reference to a class method ``method_name`` of class ``class_name``."""

    class_name: str
    method_name: str

    def to_egg(self, decls: Declarations) -> str:
        """Look up the egg function name this class method was registered under."""
        egg_fn = decls.get_egg_fn(self)
        return egg_fn

    def generate_egg_name(self) -> str:
        """Default egg name: ``<class>.<method>``."""
        owner, member = self.class_name, self.method_name
        return f"{owner}.{member}"
403
+
404
+
405
@dataclass(frozen=True)
class ConstantRef:
    """Reference to a top-level constant by its Python name."""

    name: str

    def generate_egg_name(self) -> str:
        """The default egg name is the Python name unchanged."""
        default_name = self.name
        return default_name
411
+
412
+
413
@dataclass(frozen=True)
class ClassVariableRef:
    """Reference to class variable ``variable_name`` of class ``class_name``."""

    class_name: str
    variable_name: str

    def generate_egg_name(self) -> str:
        """Default egg name: ``<class>.<variable>``."""
        owner, member = self.class_name, self.variable_name
        return f"{owner}.{member}"


# Refs whose value is a constant: declared as zero-argument functions.
ConstantCallableRef = Union[ConstantRef, ClassVariableRef]
# Refs that have a stored `FunctionDecl`.
FunctionCallableRef = Union[FunctionRef, MethodRef, ClassMethodRef]
# Anything that can appear as the callee of a `CallDecl`.
CallableRef = Union[ConstantCallableRef, FunctionCallableRef]
425
+
426
+
427
@dataclass(frozen=True)
class FunctionDecl:
    """
    A function signature: positional argument types, a return type, and an optional variadic
    argument type.
    """

    # TODO: Add arg name to arg so can call with keyword arg
    arg_types: tuple[TypeOrVarRef, ...]
    return_type: TypeOrVarRef
    var_arg_type: Optional[TypeOrVarRef] = None

    def to_commands(
        self,
        mod_decls: ModuleDeclarations,
        egg_name: str,
        cost: Optional[int] = None,
        default: Optional[ExprDecl] = None,
        merge: Optional[ExprDecl] = None,
        merge_action: Iterable[Action] = (),
    ) -> Iterable[bindings._Command]:
        """
        Yield the egg commands that declare this function under ``egg_name``.

        This is a generator, so nothing runs (not even the variadic check) until it is iterated.
        First, sorts for each argument type and then the return type are registered on ``mod_decls``,
        yielding any commands they need. Finally, one ``bindings.Function`` is yielded.

        :raises NotImplementedError: if ``var_arg_type`` is set.
        """
        if self.var_arg_type is not None:
            raise NotImplementedError("egglog does not support variable arguments yet.")
        arg_sorts: list[str] = []
        for arg_type in self.arg_types:
            # Strip vars from the type ref: `to_just` raises on a class type var, since egg
            # functions cannot be declared over type variables.
            sort_name, sort_cmds = mod_decls.register_sort(arg_type.to_just())
            yield from sort_cmds
            arg_sorts.append(sort_name)
        return_sort, return_cmds = mod_decls.register_sort(self.return_type.to_just())
        yield from return_cmds

        schema = bindings.Schema(arg_sorts, return_sort)
        egg_default = default.to_egg(mod_decls) if default else None
        egg_merge = merge.to_egg(mod_decls) if merge else None
        egg_merge_actions = [action._to_egg_action(mod_decls) for action in merge_action]
        yield bindings.Function(
            bindings.FunctionDecl(egg_name, schema, egg_default, egg_merge, egg_merge_actions, cost)
        )
464
+
465
+
466
@dataclass(frozen=True)
class VarDecl:
    """A variable expression, referenced by name."""

    name: str

    @classmethod
    def from_egg(cls, var: bindings.Var) -> TypedExprDecl:
        """Always raises: the type of an egg variable cannot be recovered here."""
        message = "Cannot turn var into egg type because typing unknown."
        raise NotImplementedError(message)

    def to_egg(self, _decls: ModuleDeclarations) -> bindings.Var:
        """Build the egg variable with the same name."""
        egg_var = bindings.Var(self.name)
        return egg_var

    def pretty(self, **kwargs) -> str:
        """Variables print as their bare name; formatting options are ignored."""
        rendered = self.name
        return rendered
479
+
480
+
481
# Python values a literal can hold; `None` stands for the unit value.
LitType = Union[int, str, float, None]


@dataclass(frozen=True)
class LitDecl:
    """A literal value: ``i64``, ``f64``, ``String`` or ``Unit``."""

    value: LitType

    @classmethod
    def from_egg(cls, lit: bindings.Lit) -> TypedExprDecl:
        """Convert an egg literal into a literal declaration tagged with its type."""
        payload = lit.value
        if isinstance(payload, bindings.Int):
            return TypedExprDecl(JustTypeRef("i64"), cls(payload.value))
        if isinstance(payload, bindings.String):
            return TypedExprDecl(JustTypeRef("String"), cls(payload.value))
        if isinstance(payload, bindings.F64):
            return TypedExprDecl(JustTypeRef("f64"), cls(payload.value))
        if isinstance(payload, bindings.Unit):
            return TypedExprDecl(JustTypeRef("Unit"), cls(None))
        assert_never(payload)

    def to_egg(self, _decls: ModuleDeclarations) -> bindings.Lit:
        """Convert to an egg literal; ``None`` becomes the unit literal."""
        value = self.value
        if value is None:
            return bindings.Lit(bindings.Unit())
        if isinstance(value, int):
            payload = bindings.Int(value)
        elif isinstance(value, float):
            payload = bindings.F64(value)
        elif isinstance(value, str):
            payload = bindings.String(value)
        else:
            assert_never(value)
        return bindings.Lit(payload)

    def pretty(self, wrap_lit=True, **kwargs) -> str:
        """
        Returns a string representation of the literal.

        :param wrap_lit: If True, wraps the literal in a call to the literal constructor
            (``i64(...)``, ``f64(...)``, ``String(...)``). Unit always prints as ``Unit()``.
        """
        value = self.value
        if value is None:
            return "Unit()"
        if isinstance(value, int):
            constructor, text = "i64", str(value)
        elif isinstance(value, float):
            constructor, text = "f64", str(value)
        elif isinstance(value, str):
            constructor, text = "String", repr(value)
        else:
            assert_never(value)
        return f"{constructor}({text})" if wrap_lit else text
526
+
527
+
528
@dataclass(frozen=True)
class CallDecl:
    """
    A call of a function, method, class method, constant, or class variable.
    """

    # The thing being called.
    callable: CallableRef
    # Typed arguments. For a `MethodRef`, the first one is the receiver (`self`).
    args: tuple[TypedExprDecl, ...] = ()
    # type parameters that were bound to the callable, if it is a classmethod
    # Used for pretty printing classmethod calls with type parameters
    bound_tp_params: Optional[tuple[JustTypeRef, ...]] = None

    def __post_init__(self):
        # Only class methods may carry bound type parameters (e.g. `Map[i64, Unit]()`).
        if not self.bound_tp_params:
            return
        if not isinstance(self.callable, ClassMethodRef):
            raise ValueError("Cannot bind type parameters to a non-class method callable.")

    @classmethod
    def from_egg(cls, mod_decls: ModuleDeclarations, call: bindings.Call) -> TypedExprDecl:
        """
        Convert an egg call into a typed call declaration.

        Arguments are converted first. The result type is then inferred for the first callable
        ref that ``mod_decls`` yields for ``call.name``; the loop returns on its first iteration.

        NOTE(review): when several refs share one egg name, a failed type inference is not retried
        with the next candidate — confirm this is intended.

        :raises ValueError: if no callable ref is registered under ``call.name``.
        """
        from .type_constraint_solver import TypeConstraintSolver

        typed_args = tuple(TypedExprDecl.from_egg(mod_decls, egg_arg) for egg_arg in call.args)
        arg_types = tuple(typed.tp for typed in typed_args)

        for callable_ref in mod_decls.get_callable_refs(call.name):
            # Class methods of generic classes need the type args their sort was registered with.
            # egglog currently allows only one instantiation of a generic sort per program, so
            # the first registered instantiation is the one used.
            if isinstance(callable_ref, ClassMethodRef):
                registered_args = mod_decls.get_registered_class_args(callable_ref.class_name)
                solver = TypeConstraintSolver.from_type_parameters(registered_args)
            else:
                solver = TypeConstraintSolver()
            fn_decl = mod_decls.get_function_decl(callable_ref)
            return_tp = solver.infer_return_type(
                fn_decl.arg_types, fn_decl.return_type, fn_decl.var_arg_type, arg_types
            )
            return TypedExprDecl(return_tp, cls(callable_ref, typed_args))
        raise ValueError(f"Could not find callable ref for call {call}")

    def to_egg(self, mod_decls: ModuleDeclarations) -> bindings.Call:
        """Convert a Call to an egg Call."""
        egg_fn = mod_decls.get_egg_fn(self.callable)
        egg_args = [typed.to_egg(mod_decls) for typed in self.args]
        return bindings.Call(egg_fn, egg_args)

    def pretty(self, parens=True, **kwargs) -> str:
        """
        Pretty print the call as Python source.

        Dunder methods listed in ``UNARY_METHODS``/``BINARY_METHODS`` print as operators.
        ``__getitem__`` and ``__call__`` print as indexing and calling. A class method named
        ``__init__`` prints as a constructor call.

        :param parens: If true, wrap the call in parens if it is a binary method call.
        """
        ref = self.callable
        arg_exprs = [typed.expr for typed in self.args]
        if isinstance(ref, FunctionRef):
            callee = ref.name
        elif isinstance(ref, ClassMethodRef):
            owner = JustTypeRef(ref.class_name, self.bound_tp_params or ()).pretty()
            callee = owner if ref.method_name == "__init__" else f"{owner}.{ref.method_name}"
        elif isinstance(ref, MethodRef):
            method = ref.method_name
            receiver, *arg_exprs = arg_exprs
            if method in UNARY_METHODS:
                return f"{UNARY_METHODS[method]}{receiver.pretty()}"
            if method in BINARY_METHODS:
                assert len(arg_exprs) == 1
                rendered = f"{receiver.pretty()} {BINARY_METHODS[method]} {arg_exprs[0].pretty(wrap_lit=False)}"
                return f"({rendered})" if parens else rendered
            if method == "__getitem__":
                assert len(arg_exprs) == 1
                return f"{receiver.pretty()}[{arg_exprs[0].pretty(wrap_lit=False)}]"
            if method == "__call__":
                joined = ", ".join(expr.pretty(wrap_lit=False) for expr in arg_exprs)
                return f"{receiver.pretty()}({joined})"
            callee = f"{receiver.pretty()}.{method}"
        elif isinstance(ref, ConstantRef):
            return ref.name
        elif isinstance(ref, ClassVariableRef):
            return f"{ref.class_name}.{ref.variable_name}"
        else:
            assert_never(ref)
        joined_args = ", ".join(expr.pretty(wrap_lit=False) for expr in arg_exprs)
        return f"{callee}({joined_args})"
604
+
605
+
606
def test_expr_pretty():
    """
    Sanity checks for the `VarDecl`, `LitDecl`, and `CallDecl` pretty printers.

    The expected strings match the printers as written:
    - string literals are rendered with ``repr`` (single quotes);
    - the unit literal prints as ``Unit()``;
    - binary operators are parenthesized unless ``parens=False`` is passed.
    """
    assert VarDecl("x").pretty() == "x"
    assert LitDecl(42).pretty() == "i64(42)"
    assert LitDecl("foo").pretty() == "String('foo')"
    assert LitDecl(None).pretty() == "Unit()"

    def v(x: str) -> TypedExprDecl:
        return TypedExprDecl(JustTypeRef(""), VarDecl(x))

    assert CallDecl(FunctionRef("foo"), (v("x"),)).pretty() == "foo(x)"
    assert CallDecl(FunctionRef("foo"), (v("x"), v("y"), v("z"))).pretty() == "foo(x, y, z)"
    assert CallDecl(MethodRef("foo", "__add__"), (v("x"), v("y"))).pretty() == "(x + y)"
    assert CallDecl(MethodRef("foo", "__add__"), (v("x"), v("y"))).pretty(parens=False) == "x + y"
    assert CallDecl(MethodRef("foo", "__getitem__"), (v("x"), v("y"))).pretty() == "x[y]"
    assert CallDecl(ClassMethodRef("foo", "__init__"), (v("x"), v("y"))).pretty() == "foo(x, y)"
    assert CallDecl(ClassMethodRef("foo", "bar"), (v("x"), v("y"))).pretty() == "foo.bar(x, y)"
    assert CallDecl(MethodRef("foo", "__call__"), (v("x"), v("y"))).pretty() == "x(y)"
    assert (
        CallDecl(
            ClassMethodRef("Map", "__init__"),
            (),
            (JustTypeRef("i64"), JustTypeRef("Unit")),
        ).pretty()
        == "Map[i64, Unit]()"
    )
630
+
631
+
632
# Any untyped expression.
ExprDecl = Union[VarDecl, LitDecl, CallDecl]


@dataclass(frozen=True)
class TypedExprDecl:
    """An expression together with its type."""

    tp: JustTypeRef
    expr: ExprDecl

    @classmethod
    def from_egg(cls, mod_decls: ModuleDeclarations, expr: bindings._Expr) -> TypedExprDecl:
        """Dispatch on the egg expression's kind to the matching declaration's ``from_egg``."""
        if isinstance(expr, bindings.Var):
            typed = VarDecl.from_egg(expr)
        elif isinstance(expr, bindings.Lit):
            typed = LitDecl.from_egg(expr)
        elif isinstance(expr, bindings.Call):
            typed = CallDecl.from_egg(mod_decls, expr)
        else:
            assert_never(expr)
        return typed

    def to_egg(self, decls: ModuleDeclarations) -> bindings._Expr:
        """Convert the wrapped expression; the type is not needed on the egg side."""
        inner = self.expr
        return inner.to_egg(decls)
652
+
653
+
654
@dataclass
class ClassDecl:
    """
    Everything declared on a single class.
    """

    # Instance method declarations, keyed by method name.
    methods: dict[str, FunctionDecl] = field(default_factory=dict)
    # Class method declarations, keyed by method name.
    class_methods: dict[str, FunctionDecl] = field(default_factory=dict)
    # Class variable types, keyed by variable name.
    class_variables: dict[str, JustTypeRef] = field(default_factory=dict)
    # Number of type variables the class was registered with.
    n_type_vars: int = 0
660
+
661
+
662
class Command(ABC):
    """
    A command that can be executed in the egg interpreter.

    We only use this for commands which return no result and don't create new Python objects.

    Anything that can be passed to the `register` function in a Module is a Command.

    Subclasses provide a Python-style ``__str__`` and ``_to_egg_command`` for the egg form.
    """

    @abstractmethod
    def __str__(self) -> str:
        raise NotImplementedError

    @abstractmethod
    def _to_egg_command(self, mod_decls: ModuleDeclarations) -> bindings._Command:
        raise NotImplementedError
678
+
679
+
680
@dataclass(frozen=True)
class Rewrite(Command):
    """
    A rewrite from ``_lhs`` to ``_rhs`` in ``_ruleset``, guarded by ``_conditions``.
    """

    _ruleset: str
    _lhs: ExprDecl
    _rhs: ExprDecl
    _conditions: tuple[Fact, ...]
    # Builder name used when printing; subclasses override it.
    _fn_name: ClassVar[str] = "rewrite"

    def __str__(self) -> str:
        to_parts = [self._rhs.pretty()] + [str(condition) for condition in self._conditions]
        return f"{self._fn_name}({self._lhs.pretty()}).to({', '.join(to_parts)})"

    def _to_egg_command(self, mod_decls: ModuleDeclarations) -> bindings._Command:
        egg_rewrite = self._to_egg_rewrite(mod_decls)
        return bindings.RewriteCommand(self._ruleset, egg_rewrite)

    def _to_egg_rewrite(self, mod_decls: ModuleDeclarations) -> bindings.Rewrite:
        """Build the egg rewrite shared by the one-way and bidirectional commands."""
        egg_lhs = self._lhs.to_egg(mod_decls)
        egg_rhs = self._rhs.to_egg(mod_decls)
        egg_conditions = [condition._to_egg_fact(mod_decls) for condition in self._conditions]
        return bindings.Rewrite(egg_lhs, egg_rhs, egg_conditions)
701
+
702
+
703
@dataclass(frozen=True)
class BiRewrite(Rewrite):
    """A rewrite that is emitted as a bidirectional egg rewrite command."""

    _fn_name: ClassVar[str] = "birewrite"

    def _to_egg_command(self, mod_decls: ModuleDeclarations) -> bindings._Command:
        egg_rewrite = self._to_egg_rewrite(mod_decls)
        return bindings.BiRewriteCommand(self._ruleset, egg_rewrite)
709
+
710
+
711
class Fact(ABC):
    """
    An e-graph fact, either an equality or a unit expression.

    Subclasses convert themselves with ``_to_egg_fact``.
    """

    @abstractmethod
    def _to_egg_fact(self, mod_decls: ModuleDeclarations) -> bindings._Fact:
        raise NotImplementedError
719
+
720
+
721
@dataclass(frozen=True)
class Eq(Fact):
    """A fact asserting that all of ``_exprs`` are equal."""

    _exprs: tuple[ExprDecl, ...]

    def __str__(self) -> str:
        head, *others = [expr.pretty() for expr in self._exprs]
        return f"eq({head}).to({', '.join(others)})"

    def _to_egg_fact(self, mod_decls: ModuleDeclarations) -> bindings.Eq:
        egg_exprs = [expr.to_egg(mod_decls) for expr in self._exprs]
        return bindings.Eq(egg_exprs)
732
+
733
+
734
@dataclass(frozen=True)
class ExprFact(Fact):
    """A fact consisting of a single expression."""

    _expr: ExprDecl

    def __str__(self) -> str:
        rendered = self._expr.pretty()
        return rendered

    def _to_egg_fact(self, mod_decls: ModuleDeclarations) -> bindings.Fact:
        egg_expr = self._expr.to_egg(mod_decls)
        return bindings.Fact(egg_expr)
743
+
744
+
745
@dataclass(frozen=True)
class Rule(Command):
    """
    A rule with a ``body`` of facts and a ``head`` of actions, emitted as a
    ``bindings.RuleCommand`` named ``name`` in ``ruleset``.
    """

    # Actions, emitted as the egg rule's head.
    head: tuple[Action, ...]
    # Facts, emitted as the egg rule's body.
    body: tuple[Fact, ...]
    name: str
    ruleset: str

    def __str__(self) -> str:
        # Print in the order of the `rule(*facts).then(*actions)` builder: the facts (body) go
        # inside `rule(...)` and the actions (head) go inside `.then(...)`.
        body_str = ", ".join(map(str, self.body))
        head_str = ", ".join(map(str, self.head))
        return f"rule({body_str}).then({head_str})"

    def _to_egg_command(self, mod_decls: ModuleDeclarations) -> bindings.RuleCommand:
        return bindings.RuleCommand(
            self.name,
            self.ruleset,
            bindings.Rule(
                [a._to_egg_action(mod_decls) for a in self.head],
                [f._to_egg_fact(mod_decls) for f in self.body],
            ),
        )
766
+
767
+
768
class Action(Command, ABC):
    """
    An action. It can also be registered directly as a command, in which case it is wrapped in
    an egg ``ActionCommand``.
    """

    @abstractmethod
    def _to_egg_action(self, mod_decls: ModuleDeclarations) -> bindings._Action:
        raise NotImplementedError

    def _to_egg_command(self, mod_decls: ModuleDeclarations) -> bindings._Command:
        egg_action = self._to_egg_action(mod_decls)
        return bindings.ActionCommand(egg_action)
775
+
776
+
777
@dataclass(frozen=True)
class Let(Action):
    """Binds the name ``_name`` to the value of ``_value``."""

    _name: str
    _value: ExprDecl

    def __str__(self) -> str:
        value_str = self._value.pretty()
        return f"let({self._name}, {value_str})"

    def _to_egg_action(self, mod_decls: ModuleDeclarations) -> bindings.Let:
        egg_value = self._value.to_egg(mod_decls)
        return bindings.Let(self._name, egg_value)
787
+
788
+
789
@dataclass(frozen=True)
class Set(Action):
    """Sets the value of the function call ``_call`` to ``_rhs``."""

    _call: CallDecl
    _rhs: ExprDecl

    def __str__(self) -> str:
        call_str, rhs_str = self._call.pretty(), self._rhs.pretty()
        return f"set({call_str}).to({rhs_str})"

    def _to_egg_action(self, mod_decls: ModuleDeclarations) -> bindings.Set:
        call = self._call
        egg_fn = mod_decls.get_egg_fn(call.callable)
        egg_args = [arg.to_egg(mod_decls) for arg in call.args]
        egg_rhs = self._rhs.to_egg(mod_decls)
        return bindings.Set(egg_fn, egg_args, egg_rhs)
803
+
804
+
805
@dataclass(frozen=True)
class ExprAction(Action):
    """An action consisting of evaluating a single expression."""

    _expr: ExprDecl

    def __str__(self) -> str:
        rendered = self._expr.pretty()
        return rendered

    def _to_egg_action(self, mod_decls: ModuleDeclarations) -> bindings.Expr_:
        egg_expr = self._expr.to_egg(mod_decls)
        return bindings.Expr_(egg_expr)
814
+
815
+
816
@dataclass(frozen=True)
class Delete(Action):
    """Deletes the entry for the function call ``_call``."""

    _call: CallDecl

    def __str__(self) -> str:
        call_str = self._call.pretty()
        return f"delete({call_str})"

    def _to_egg_action(self, mod_decls: ModuleDeclarations) -> bindings.Delete:
        call = self._call
        egg_fn = mod_decls.get_egg_fn(call.callable)
        egg_args = [arg.to_egg(mod_decls) for arg in call.args]
        return bindings.Delete(egg_fn, egg_args)
827
+
828
+
829
@dataclass(frozen=True)
class Union_(Action):
    """Unions ``_lhs`` with ``_rhs``."""

    _lhs: ExprDecl
    _rhs: ExprDecl

    def __str__(self) -> str:
        lhs_str, rhs_str = self._lhs.pretty(), self._rhs.pretty()
        return f"union({lhs_str}).with_({rhs_str})"

    def _to_egg_action(self, mod_decls: ModuleDeclarations) -> bindings.Union:
        egg_lhs = self._lhs.to_egg(mod_decls)
        egg_rhs = self._rhs.to_egg(mod_decls)
        return bindings.Union(egg_lhs, egg_rhs)
839
+
840
+
841
@dataclass(frozen=True)
class Panic(Action):
    """An action emitted as ``bindings.Panic`` carrying ``message``."""

    message: str

    def __str__(self) -> str:
        # `message` is free text, so print it as a quoted string literal (as `LitDecl.pretty`
        # does for strings) so the output reads as a valid `panic("...")` call.
        return f"panic({self.message!r})"

    def _to_egg_action(self, _decls: ModuleDeclarations) -> bindings.Panic:
        return bindings.Panic(self.message)
850
+
851
+
852
+ # def action_decl_to_egg(decls: Declarations, action: ActionDecl) -> bindings._Action:
853
+ # if isinstance(action, (CallDecl, LitDecl, VarDecl)):
854
+ # return bindings.Expr_(action.to_egg(decls))
855
+ # return action.to_egg(decls)
856
+
857
+
858
class Schedule(ABC):
    """
    Base class for schedules. Subclasses provide a Python-style ``__str__`` and an egg
    conversion; ``*`` and ``saturate()`` wrap a schedule in `Repeat` / `Saturate`.
    """

    @abstractmethod
    def __str__(self) -> str:
        raise NotImplementedError

    @abstractmethod
    def _to_egg_schedule(self, mod_decls: ModuleDeclarations) -> bindings._Schedule:
        raise NotImplementedError

    def __mul__(self, length: int) -> Schedule:
        """
        Repeat the schedule a number of times.
        """
        repeated = Repeat(length, self)
        return repeated

    def saturate(self) -> Schedule:
        """
        Run the schedule until the e-graph is saturated.
        """
        saturated = Saturate(self)
        return saturated
878
+
879
+
880
@dataclass
class Run(Schedule):
    """Configuration of a run"""

    # Maximum number of iterations, passed through to `bindings.RunConfig`.
    limit: int
    ruleset: str
    # Facts passed as the run's `until` condition; omitted (None) when empty.
    until: tuple[Fact, ...]

    def __str__(self) -> str:
        rendered = [str(self.ruleset), str(self.limit)] + [str(fact) for fact in self.until]
        return "run(" + ", ".join(rendered) + ")"

    def _to_egg_schedule(self, mod_decls: ModuleDeclarations) -> bindings._Schedule:
        config = self._to_egg_config(mod_decls)
        return bindings.Run(config)

    def _to_egg_config(self, mod_decls: ModuleDeclarations) -> bindings.RunConfig:
        until_facts = None
        if self.until:
            until_facts = [fact._to_egg_fact(mod_decls) for fact in self.until]
        return bindings.RunConfig(self.ruleset, self.limit, until_facts)
901
+
902
+
903
@dataclass
class Saturate(Schedule):
    """Runs the wrapped ``schedule`` until the e-graph is saturated."""

    schedule: Schedule

    def __str__(self) -> str:
        inner = str(self.schedule)
        return inner + ".saturate()"

    def _to_egg_schedule(self, mod_decls: ModuleDeclarations) -> bindings._Schedule:
        inner = self.schedule._to_egg_schedule(mod_decls)
        return bindings.Saturate(inner)
912
+
913
+
914
@dataclass
class Repeat(Schedule):
    """Runs the wrapped ``schedule`` ``length`` times."""

    length: int
    schedule: Schedule

    def __str__(self) -> str:
        inner = str(self.schedule)
        return f"{inner} * {self.length}"

    def _to_egg_schedule(self, mod_decls: ModuleDeclarations) -> bindings._Schedule:
        inner = self.schedule._to_egg_schedule(mod_decls)
        return bindings.Repeat(self.length, inner)
924
+
925
+
926
@dataclass
class Sequence(Schedule):
    """Runs each of ``schedules`` in order."""

    schedules: tuple[Schedule, ...]

    def __str__(self) -> str:
        inner = ", ".join(str(schedule) for schedule in self.schedules)
        return f"sequence({inner})"

    def _to_egg_schedule(self, mod_decls: ModuleDeclarations) -> bindings._Schedule:
        egg_schedules = [schedule._to_egg_schedule(mod_decls) for schedule in self.schedules]
        return bindings.Sequence(egg_schedules)