egglog 0.4.0__cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of egglog might be problematic. Click here for more details.

egglog/egraph.py ADDED
@@ -0,0 +1,1041 @@
1
+ from __future__ import annotations
2
+
3
+ from abc import ABC, abstractmethod
4
+ from copy import deepcopy
5
+ from dataclasses import InitVar, dataclass, field
6
+ from inspect import Parameter, currentframe, signature
7
+ from types import FunctionType
8
+ from typing import _GenericAlias # type: ignore[attr-defined]
9
+ from typing import (
10
+ TYPE_CHECKING,
11
+ Any,
12
+ Callable,
13
+ ClassVar,
14
+ Generic,
15
+ Iterable,
16
+ Literal,
17
+ NoReturn,
18
+ Optional,
19
+ TypeVar,
20
+ Union,
21
+ cast,
22
+ get_type_hints,
23
+ overload,
24
+ )
25
+
26
+ import graphviz
27
+ from egglog.declarations import Declarations
28
+ from typing_extensions import ParamSpec, get_args, get_origin
29
+
30
+ from . import bindings
31
+ from .declarations import *
32
+ from .monkeypatch import monkeypatch_forward_ref
33
+ from .runtime import *
34
+ from .runtime import _resolve_callable, class_to_ref
35
+
36
+ if TYPE_CHECKING:
37
+ from .builtins import String
38
+
39
# Patch ForwardRef evaluation before any annotations in this module are resolved.
monkeypatch_forward_ref()

# Public API of this module. `birewrite` and `Ruleset` are public (no leading underscore,
# documented, and `Ruleset` is returned by `ruleset()`), so they are exported as well.
__all__ = [
    "EGraph",
    "Module",
    "BUILTINS",
    "BaseExpr",
    "Unit",
    "rewrite",
    "birewrite",
    "eq",
    "panic",
    "let",
    "delete",
    "union",
    "set_",
    "rule",
    "var",
    "vars_",
    "Fact",
    "expr_parts",
    "Schedule",
    "run",
    "seq",
    "Ruleset",
]

T = TypeVar("T")
P = ParamSpec("P")
TYPE = TypeVar("TYPE", bound="type[BaseExpr]")
CALLABLE = TypeVar("CALLABLE", bound=Callable)
EXPR = TypeVar("EXPR", bound="BaseExpr")
E1 = TypeVar("E1", bound="BaseExpr")
E2 = TypeVar("E2", bound="BaseExpr")
E3 = TypeVar("E3", bound="BaseExpr")
E4 = TypeVar("E4", bound="BaseExpr")

# Attributes which are sometimes added to classes by the interpreter or the dataclass decorator, or by ipython.
# We ignore these when inspecting the class.
IGNORED_ATTRIBUTES = {
    "__module__",
    "__doc__",
    "__dict__",
    "__weakref__",
    "__orig_bases__",
    "__annotations__",
    "__hash__",
}

# Declarations of the builtin module; set once when `_Builtins` is instantiated and then
# included in every subsequently created module.
_BUILTIN_DECLS: Declarations | None = None
88
+
89
+
90
@dataclass
class _BaseModule(ABC):
    """
    Base Module which provides methods to register sorts, expressions, actions etc.

    Inherited by:
    - EGraph: Holds a live EGraph instance
    - Builtins: Stores a list of the builtins which have already been pre-registered
    - Module: Stores a list of commands and additional declarations
    """

    # Any modules you want to depend on.
    # The shared list default is only ever read, never mutated, and dataclasses do not
    # allow a default_factory on an InitVar.
    deps: InitVar[list[Module]] = []
    # All dependencies flattened
    _flatted_deps: list[Module] = field(init=False, default_factory=list)
    _mod_decls: ModuleDeclarations = field(init=False)

    def __post_init__(self, modules: list[Module] = []) -> None:
        """
        Flatten the dependency modules and build this module's declarations.

        The builtin declarations (when already registered) come first. They are followed by the
        declarations of every transitive dependency, each included only once.
        """
        included_decls = [_BUILTIN_DECLS] if _BUILTIN_DECLS else []
        # Traverse all the included modules to flatten all their dependencies and add to the included declarations
        for mod in modules:
            for child_mod in [*mod._flatted_deps, mod]:
                if child_mod not in self._flatted_deps:
                    self._flatted_deps.append(child_mod)
                    included_decls.append(child_mod._mod_decls._decl)
        self._mod_decls = ModuleDeclarations(Declarations(), included_decls)

    @abstractmethod
    def _process_commands(self, cmds: Iterable[bindings._Command]) -> None:
        """
        Process the commands generated by this module.
        """
        raise NotImplementedError

    @overload
    def class_(self, *, egg_sort: str) -> Callable[[TYPE], TYPE]:
        ...

    @overload
    def class_(self, cls: TYPE, /) -> TYPE:
        ...

    def class_(self, *args, **kwargs) -> Any:
        """
        Registers a class.

        Can be used directly as `@mod.class_` or with options as `@mod.class_(egg_sort=...)`.
        The caller's frame locals/globals are captured so annotations referring to names
        defined in the caller's scope can be resolved.
        """
        frame = currentframe()
        assert frame
        prev_frame = frame.f_back
        assert prev_frame

        if kwargs:
            assert set(kwargs.keys()) == {"egg_sort"}
            return lambda cls: self._class(cls, prev_frame.f_locals, prev_frame.f_globals, kwargs["egg_sort"])
        assert len(args) == 1
        return self._class(args[0], prev_frame.f_locals, prev_frame.f_globals)

    def _class(
        self,
        cls: type[BaseExpr],
        hint_locals: dict[str, Any],
        hint_globals: dict[str, Any],
        egg_sort: Optional[str] = None,
    ) -> RuntimeClass:
        """
        Registers a class.

        The steps are:
        1. Register the sort.
        2. Register every `ClassVar` annotation as a class constant.
        3. Register every method (plain, or wrapped with `method(...)`) as a method, a
           classmethod, or the constructor.
        """
        cls_name = cls.__name__
        # Get all the methods from the class
        cls_dict: dict[str, Any] = {k: v for k, v in cls.__dict__.items() if k not in IGNORED_ATTRIBUTES}
        parameters: list[TypeVar] = cls_dict.pop("__parameters__", [])

        n_type_vars = len(parameters)
        self._process_commands(self._mod_decls.register_class(cls_name, n_type_vars, egg_sort))
        # The type ref of self is parameterized by the type vars
        slf_type_ref = TypeRefWithVars(cls_name, tuple(ClassTypeVarRef(i) for i in range(n_type_vars)))

        # First register any class vars as constants
        hint_globals = hint_globals.copy()
        hint_globals[cls_name] = cls
        for k, v in get_type_hints(cls, globalns=hint_globals, localns=hint_locals).items():
            # get_origin/get_args also handle plain, non-generic annotations such as `x: int`,
            # which have no `__origin__` attribute. Those now reach the NotImplementedError below
            # instead of failing with an AttributeError.
            if get_origin(v) is ClassVar:
                (inner_tp,) = get_args(v)
                self._register_constant(ClassVariableRef(cls_name, k), inner_tp, None, (cls, cls_name))
            else:
                raise NotImplementedError("The only supported annotations on class attributes are class vars")

        # Then register each of its methods
        for method_name, method in cls_dict.items():
            is_init = method_name == "__init__"
            # Don't register the init methods for literals, since those don't use the type checking mechanisms
            if is_init and cls_name in LIT_CLASS_NAMES:
                continue
            if isinstance(method, _WrappedMethod):
                fn = method.fn
                egg_fn = method.egg_fn
                cost = method.cost
                default = method.default
                merge = method.merge
                on_merge = method.on_merge
            else:
                fn = method
                egg_fn, cost, default, merge, on_merge = None, None, None, None, None
            if isinstance(fn, classmethod):
                fn = fn.__func__
                is_classmethod = True
            else:
                # We count __init__ as a classmethod since it is called on the class
                is_classmethod = is_init

            ref: ClassMethodRef | MethodRef = (
                ClassMethodRef(cls_name, method_name) if is_classmethod else MethodRef(cls_name, method_name)
            )
            self._register_function(
                ref,
                egg_fn,
                fn,
                hint_locals,
                default,
                cost,
                merge,
                on_merge,
                "cls" if is_classmethod and not is_init else slf_type_ref,
                parameters,
                is_init,
                # If this is an i64, use the runtime class for the alias so that i64Like is resolved properly
                # Otherwise, this might be a Map in which case pass in the original cls so that we
                # can do Map[T, V] on it, which is not allowed on the runtime class
                cls_type_and_name=(
                    RuntimeClass(self._mod_decls, cls_name) if cls_name in {"i64", "String"} else cls,
                    cls_name,
                ),
            )

        # Register != as a method so we can print it as a string
        self._mod_decls._decl.register_callable_ref(MethodRef(cls_name, "__ne__"), "!=")
        return RuntimeClass(self._mod_decls, cls_name)

    # We separate the function and method overloads to make it simpler to know if we are modifying a function or method,
    # So that we can add the functions eagerly to the registry and wait on the methods till we process the class.

    # We have to separate method/function overloads for those that use the T params and those that don't
    # Otherwise, if you say just pass in `cost` then the T param is inferred as `Nothing` and
    # It will break the typing.
    @overload
    def method(  # type: ignore
        self,
        *,
        egg_fn: Optional[str] = None,
        cost: Optional[int] = None,
        merge: Optional[Callable[[Any, Any], Any]] = None,
        on_merge: Optional[Callable[[Any, Any], Iterable[ActionLike]]] = None,
    ) -> Callable[[CALLABLE], CALLABLE]:
        ...

    @overload
    def method(
        self,
        *,
        egg_fn: Optional[str] = None,
        cost: Optional[int] = None,
        default: Optional[EXPR] = None,
        merge: Optional[Callable[[EXPR, EXPR], EXPR]] = None,
        on_merge: Optional[Callable[[EXPR, EXPR], Iterable[ActionLike]]] = None,
    ) -> Callable[[Callable[P, EXPR]], Callable[P, EXPR]]:
        ...

    def method(
        self,
        *,
        egg_fn: Optional[str] = None,
        cost: Optional[int] = None,
        default: Optional[EXPR] = None,
        merge: Optional[Callable[[EXPR, EXPR], EXPR]] = None,
        on_merge: Optional[Callable[[EXPR, EXPR], Iterable[ActionLike]]] = None,
    ) -> Callable[[Callable[P, EXPR]], Callable[P, EXPR]]:
        """
        Decorator which attaches options to a method.

        The method is only registered later, when its class is processed by `class_`.
        """
        return lambda fn: _WrappedMethod(egg_fn, cost, default, merge, on_merge, fn)

    @overload
    def function(self, fn: CALLABLE, /) -> CALLABLE:
        ...

    @overload
    def function(  # type: ignore
        self,
        *,
        egg_fn: Optional[str] = None,
        cost: Optional[int] = None,
        merge: Optional[Callable[[Any, Any], Any]] = None,
        on_merge: Optional[Callable[[Any, Any], Iterable[ActionLike]]] = None,
    ) -> Callable[[CALLABLE], CALLABLE]:
        ...

    @overload
    def function(
        self,
        *,
        egg_fn: Optional[str] = None,
        cost: Optional[int] = None,
        default: Optional[EXPR] = None,
        merge: Optional[Callable[[EXPR, EXPR], EXPR]] = None,
        on_merge: Optional[Callable[[EXPR, EXPR], Iterable[ActionLike]]] = None,
    ) -> Callable[[Callable[P, EXPR]], Callable[P, EXPR]]:
        ...

    def function(self, *args, **kwargs) -> Any:
        """
        Registers a function.
        """
        fn_locals = currentframe().f_back.f_locals  # type: ignore

        # If we have any positional args, then we are calling it directly on a function
        if args:
            assert len(args) == 1
            return self._function(args[0], fn_locals)
        # otherwise, we are passing some keyword args, so save those, and then return a partial
        return lambda fn: self._function(fn, fn_locals, **kwargs)

    def _function(
        self,
        fn: Callable[..., RuntimeExpr],
        hint_locals: dict[str, Any],
        egg_fn: Optional[str] = None,
        cost: Optional[int] = None,
        default: Optional[RuntimeExpr] = None,
        merge: Optional[Callable[[RuntimeExpr, RuntimeExpr], RuntimeExpr]] = None,
        on_merge: Optional[Callable[[RuntimeExpr, RuntimeExpr], Iterable[ActionLike]]] = None,
    ) -> RuntimeFunction:
        """
        Uncurried version of function decorator
        """
        name = fn.__name__
        # Save function declaration
        self._register_function(FunctionRef(name), egg_fn, fn, hint_locals, default, cost, merge, on_merge)
        # Return a runtime function which will act like the declaration
        return RuntimeFunction(self._mod_decls, name)

    def _register_function(
        self,
        ref: FunctionCallableRef,
        egg_name: Optional[str],
        fn: Any,
        # Pass in the locals, retrieved from the frame when wrapping,
        # so that we support classes and function defined inside of other functions (which won't show up in the globals)
        hint_locals: dict[str, Any],
        default: Optional[RuntimeExpr],
        cost: Optional[int],
        merge: Optional[Callable[[RuntimeExpr, RuntimeExpr], RuntimeExpr]],
        on_merge: Optional[Callable[[RuntimeExpr, RuntimeExpr], Iterable[ActionLike]]],
        # The first arg is either cls, for a classmethod, a self type, or none for a function
        first_arg: Literal["cls"] | TypeOrVarRef | None = None,
        cls_typevars: list[TypeVar] = [],
        is_init: bool = False,
        cls_type_and_name: Optional[tuple[type | RuntimeClass, str]] = None,
    ) -> None:
        """
        Build a FunctionDecl from a Python function's signature and type hints and register it.

        Raises NotImplementedError if `fn` is not a plain Python function. Raises ValueError
        for unsupported signatures: keyword-only args, anything after a var arg, or an
        annotated first arg on a method.
        """
        if not isinstance(fn, FunctionType):
            raise NotImplementedError(f"Can only generate function decls for functions not {fn} {type(fn)}")

        hint_globals = fn.__globals__.copy()

        if cls_type_and_name:
            hint_globals[cls_type_and_name[1]] = cls_type_and_name[0]
        hints = get_type_hints(fn, hint_globals, hint_locals)
        # If this is an init fn use the first arg as the return type
        if is_init:
            if not isinstance(first_arg, (ClassTypeVarRef, TypeRefWithVars)):
                raise ValueError("Init function must have a self type")
            return_type = first_arg
        else:
            return_type = self._resolve_type_annotation(hints["return"], cls_typevars, cls_type_and_name)

        params = list(signature(fn).parameters.values())
        # Remove first arg if this is a classmethod or a method, since it won't have an annotation
        if first_arg is not None:
            first, *params = params
            if first.annotation != Parameter.empty:
                raise ValueError(f"First arg of a method must not have an annotation, not {first.annotation}")

        # Check that all the params are positional or keyword, and that there is only one var arg at the end
        found_var_arg = False
        for param in params:
            if found_var_arg:
                raise ValueError("Can only have a single var arg at the end")
            kind = param.kind
            if kind == Parameter.VAR_POSITIONAL:
                found_var_arg = True
            elif kind != Parameter.POSITIONAL_OR_KEYWORD:
                raise ValueError(f"Can only register functions with positional or keyword args, not {param.kind}")

        if found_var_arg:
            # The loop above guarantees the var arg is the LAST parameter, so split it off the end.
            # Unpacking it from the front would mistake the first regular param for the var arg.
            *params, var_arg_param = params
            var_arg_type = self._resolve_type_annotation(hints[var_arg_param.name], cls_typevars, cls_type_and_name)
        else:
            var_arg_type = None
        arg_types = tuple(self._resolve_type_annotation(hints[t.name], cls_typevars, cls_type_and_name) for t in params)
        # If the first arg is a self, and this not an __init__ fn, add this as a typeref
        if isinstance(first_arg, (ClassTypeVarRef, TypeRefWithVars)) and not is_init:
            arg_types = (first_arg,) + arg_types

        default_decl = None if default is None else default.__egg_typed_expr__.expr
        # The merge function / on_merge actions are evaluated symbolically on the `old` and `new` vars
        merge_decl = (
            None
            if merge is None
            else merge(
                RuntimeExpr(self._mod_decls, TypedExprDecl(return_type.to_just(), VarDecl("old"))),
                RuntimeExpr(self._mod_decls, TypedExprDecl(return_type.to_just(), VarDecl("new"))),
            ).__egg_typed_expr__.expr
        )
        merge_action = (
            []
            if on_merge is None
            else _action_likes(
                on_merge(
                    RuntimeExpr(self._mod_decls, TypedExprDecl(return_type.to_just(), VarDecl("old"))),
                    RuntimeExpr(self._mod_decls, TypedExprDecl(return_type.to_just(), VarDecl("new"))),
                )
            )
        )
        fn_decl = FunctionDecl(return_type=return_type, var_arg_type=var_arg_type, arg_types=arg_types)
        self._process_commands(
            self._mod_decls.register_function_callable(
                ref, fn_decl, egg_name, cost, default_decl, merge_decl, merge_action
            )
        )

    def _resolve_type_annotation(
        self,
        tp: object,
        cls_typevars: list[TypeVar],
        cls_type_and_name: Optional[tuple[type | RuntimeClass, str]],
    ) -> TypeOrVarRef:
        """
        Convert a Python type annotation into a type reference.

        Raises TypeError for unions that are not `literal | Type` and for unrecognized annotations.
        """
        if isinstance(tp, TypeVar):
            return ClassTypeVarRef(cls_typevars.index(tp))
        # If there is a union, it should be of a literal and another type to allow type promotion
        if get_origin(tp) == Union:
            args = get_args(tp)
            if len(args) != 2:
                raise TypeError("Union types are only supported for type promotion")
            fst, snd = args
            if fst in {int, str, float}:
                return self._resolve_type_annotation(snd, cls_typevars, cls_type_and_name)
            if snd in {int, str, float}:
                return self._resolve_type_annotation(fst, cls_typevars, cls_type_and_name)
            raise TypeError("Union types are only supported for type promotion")

        # If this is the type for the class, use the class name
        if cls_type_and_name and tp == cls_type_and_name[0]:
            return TypeRefWithVars(cls_type_and_name[1])

        # If this is the class for this method and we have a parameterized class, recurse
        if (
            cls_type_and_name
            and isinstance(tp, _GenericAlias)
            and tp.__origin__ == cls_type_and_name[0]  # type: ignore
        ):
            return TypeRefWithVars(
                cls_type_and_name[1],
                tuple(
                    self._resolve_type_annotation(a, cls_typevars, cls_type_and_name)
                    for a in tp.__args__  # type: ignore
                ),
            )

        if isinstance(tp, (RuntimeClass, RuntimeParamaterizedClass)):
            return class_to_ref(tp).to_var()
        raise TypeError(f"Unexpected type annotation {tp}")

    def register(self, command_or_generator: CommandLike | CommandGenerator, *commands: CommandLike) -> None:
        """
        Registers any number of rewrites or rules.

        Accepts either the commands themselves, or a single function which generates them.
        """
        if isinstance(command_or_generator, FunctionType):
            assert not commands
            commands = tuple(_command_generator(command_or_generator))
        else:
            commands = (cast(CommandLike, command_or_generator), *commands)
        self._process_commands(_command_like(command)._to_egg_command(self._mod_decls) for command in commands)

    def ruleset(self, name: str) -> Ruleset:
        """
        Create a named ruleset and return a handle to it.
        """
        self._process_commands([bindings.AddRuleset(name)])
        return Ruleset(name)

    # Overload to support arities 0-4 until variadic generic support map, so we can map from type to value
    @overload
    def relation(
        self, name: str, tp1: type[E1], tp2: type[E2], tp3: type[E3], tp4: type[E4], /
    ) -> Callable[[E1, E2, E3, E4], Unit]:
        ...

    @overload
    def relation(self, name: str, tp1: type[E1], tp2: type[E2], tp3: type[E3], /) -> Callable[[E1, E2, E3], Unit]:
        ...

    @overload
    def relation(self, name: str, tp1: type[E1], tp2: type[E2], /) -> Callable[[E1, E2], Unit]:
        ...

    @overload
    def relation(self, name: str, tp1: type[T], /, *, egg_fn: Optional[str] = None) -> Callable[[T], Unit]:
        ...

    @overload
    def relation(self, name: str, /, *, egg_fn: Optional[str] = None) -> Callable[[], Unit]:
        ...

    def relation(self, name: str, /, *tps: type, egg_fn: Optional[str] = None) -> Callable[..., Unit]:
        """
        Defines a relation, which is the same as a function which returns unit.
        """
        arg_types = tuple(self._resolve_type_annotation(cast(object, tp), [], None) for tp in tps)
        fn_decl = FunctionDecl(arg_types, TypeRefWithVars("Unit"))
        commands = self._mod_decls.register_function_callable(
            FunctionRef(name), fn_decl, egg_fn, cost=None, default=None, merge=None, merge_action=[]
        )
        self._process_commands(commands)
        return cast(Callable[..., Unit], RuntimeFunction(self._mod_decls, name))

    def input(self, fn: Callable[..., String], path: str) -> None:
        """
        Loads a CSV file and sets it as *input, output of the function.
        """
        fn_name = self._mod_decls.get_egg_fn(_resolve_callable(fn))
        self._process_commands([bindings.Input(fn_name, path)])

    def constant(self, name: str, tp: type[EXPR], egg_name: Optional[str] = None) -> EXPR:
        """
        Defines a named constant of a certain type.

        This is the same as defining a nullary function with a high cost.
        """
        ref = ConstantRef(name)
        type_ref = self._register_constant(ref, tp, egg_name, None)
        return cast(EXPR, RuntimeExpr(self._mod_decls, TypedExprDecl(type_ref, CallDecl(ref))))

    def _register_constant(
        self,
        ref: ConstantRef | ClassVariableRef,
        tp: object,
        egg_name: Optional[str],
        cls_type_and_name: Optional[tuple[type | RuntimeClass, str]],
    ) -> JustTypeRef:
        """
        Register a constant, returning its typeref().
        """
        type_ref = self._resolve_type_annotation(tp, [], cls_type_and_name).to_just()
        self._process_commands(self._mod_decls.register_constant_callable(ref, type_ref, egg_name))
        return type_ref

    def define(self, name: str, expr: EXPR) -> EXPR:
        """
        Define a new expression in the egraph and return a reference to it.
        """
        # Don't support cost and maybe will be removed in favor of let
        # https://github.com/egraphs-good/egglog/issues/128#issuecomment-1523760578
        typed_expr = expr_parts(expr)
        self._process_commands([bindings.Define(name, typed_expr.to_egg(self._mod_decls), None)])
        return cast(EXPR, RuntimeExpr(self._mod_decls, TypedExprDecl(typed_expr.tp, VarDecl(name))))
547
+
548
+
549
@dataclass
class _Builtins(_BaseModule):
    """
    Module holding the builtin declarations; it may only be instantiated once.
    """

    def __post_init__(self, modules: list[Module] = []) -> None:
        """
        Publish this module's declarations as the global builtins so later modules include them.

        Raises RuntimeError when the builtins were already initialized.
        """
        global _BUILTIN_DECLS
        # Builtins never depend on other modules.
        assert not modules
        super().__post_init__(modules)
        if _BUILTIN_DECLS is None:
            _BUILTIN_DECLS = self._mod_decls._decl
            return
        raise RuntimeError("Builtins already initialized")

    def _process_commands(self, cmds: Iterable[bindings._Command]) -> None:
        """
        Drop the commands: the builtins are already known to the backend, so there is nothing to run.
        """
567
+
568
+
569
@dataclass
class Module(_BaseModule):
    """
    A reusable collection of declarations whose commands are buffered instead of executed.

    The buffered commands are replayed by any EGraph that depends on this module.
    """

    _cmds: list[bindings._Command] = field(default_factory=list, repr=False)

    def _process_commands(self, cmds: Iterable[bindings._Command]) -> None:
        # Buffer every command, preserving order.
        buffered = self._cmds
        for command in cmds:
            buffered.append(command)
575
+
576
+
577
@dataclass
class EGraph(_BaseModule):
    """
    Represents an EGraph instance at runtime
    """

    _egraph: bindings.EGraph = field(repr=False, default_factory=bindings.EGraph)
    # The current declarations which have been pushed to the stack
    _decl_stack: list[Declarations] = field(default_factory=list, repr=False)

    def __post_init__(self, modules: list[Module] = []) -> None:
        """
        Build the declarations, then replay the buffered commands of every flattened dependency.
        """
        super().__post_init__(modules)
        for m in self._flatted_deps:
            self._process_commands(m._cmds)

    def _process_commands(self, commands: Iterable[bindings._Command]) -> None:
        """
        Run the commands immediately on the underlying egraph.
        """
        self._egraph.run_program(*commands)

    def _repr_mimebundle_(self, *args, **kwargs):
        """
        Returns the graphviz representation of the e-graph.
        """

        return self.graphviz._repr_mimebundle_(*args, **kwargs)

    @property
    def graphviz(self) -> graphviz.Source:
        """
        The e-graph rendered as a graphviz source object.
        """
        return graphviz.Source(self._egraph.to_graphviz_string())

    def _repr_html_(self) -> str:
        """
        Add a _repr_html_ to be an SVG to work with sphinx gallery
        ala https://github.com/xflr6/graphviz/pull/121
        until this PR is merged and released
        https://github.com/sphinx-gallery/sphinx-gallery/pull/1138
        """
        return self.graphviz.pipe(format="svg").decode()

    def display(self):
        """
        Displays the e-graph in the notebook.
        """
        from IPython.display import display

        display(self)

    def simplify(self, expr: EXPR, limit: int, *until: Fact, ruleset: Optional[Ruleset] = None) -> EXPR:
        """
        Simplifies the given expression.

        Raises ValueError if the backend did not store an extract report.
        """
        typed_expr = expr_parts(expr)
        egg_expr = typed_expr.to_egg(self._mod_decls)
        self._process_commands(
            [bindings.Simplify(egg_expr, Run(limit, _ruleset_name(ruleset), until)._to_egg_config(self._mod_decls))]
        )
        extract_report = self._egraph.extract_report()
        if not extract_report:
            raise ValueError("No extract report saved")
        new_typed_expr = TypedExprDecl.from_egg(self._mod_decls, extract_report.expr)
        return cast(EXPR, RuntimeExpr(self._mod_decls, new_typed_expr))

    def include(self, path: str) -> None:
        """
        Include a file of rules.
        """
        raise NotImplementedError(
            "Not implemented yet, because we don't have a way of registering the types with Python"
        )

    def output(self) -> None:
        raise NotImplementedError("Not imeplemented yet, because there are no examples in the egglog repo")

    @overload
    def run(self, limit: int, /, *until: Fact, ruleset: Optional[Ruleset] = None) -> bindings.RunReport:
        ...

    @overload
    def run(self, schedule: Schedule, /) -> bindings.RunReport:
        ...

    def run(
        self, limit_or_schedule: int | Schedule, /, *until: Fact, ruleset: Optional[Ruleset] = None
    ) -> bindings.RunReport:
        """
        Run the egraph until the given limit or until the given facts are true.
        """
        if isinstance(limit_or_schedule, int):
            # `run` here is the module-level schedule builder, not this method
            limit_or_schedule = run(ruleset, limit_or_schedule, *until)
        return self._run_schedule(limit_or_schedule)

    def _run_schedule(self, schedule: Schedule) -> bindings.RunReport:
        """
        Run a schedule and return its report; raises ValueError if no report was stored.
        """
        self._process_commands([bindings.RunScheduleCommand(schedule._to_egg_schedule(self._mod_decls))])
        run_report = self._egraph.run_report()
        if not run_report:
            raise ValueError("No run report saved")
        return run_report

    def check(self, *facts: FactLike) -> None:
        """
        Check if a fact is true in the egraph.
        """
        self._process_commands([self._facts_to_check(facts)])

    def check_fail(self, *facts: FactLike) -> None:
        """
        Checks that one of the facts is not true
        """
        self._process_commands([bindings.Fail(self._facts_to_check(facts))])

    def _facts_to_check(self, facts: Iterable[FactLike]) -> bindings.Check:
        """
        Build a single Check command covering all the given facts.
        """
        egg_facts = [f._to_egg_fact(self._mod_decls) for f in _fact_likes(facts)]
        return bindings.Check(egg_facts)

    def extract(self, expr: EXPR) -> EXPR:
        """
        Extract the lowest cost expression from the egraph.

        Raises RuntimeError if the extracted expression's type differs from the input's.
        """
        typed_expr = expr_parts(expr)
        egg_expr = typed_expr.to_egg(self._mod_decls)
        extract_report = self._run_extract(egg_expr, 0)
        new_typed_expr = TypedExprDecl.from_egg(self._mod_decls, extract_report.expr)
        if new_typed_expr.tp != typed_expr.tp:
            raise RuntimeError(f"Type mismatch: {new_typed_expr.tp} != {typed_expr.tp}")
        return cast(EXPR, RuntimeExpr(self._mod_decls, new_typed_expr))

    def extract_multiple(self, expr: EXPR, n: int) -> list[EXPR]:
        """
        Extract multiple expressions from the egraph.
        """
        typed_expr = expr_parts(expr)
        egg_expr = typed_expr.to_egg(self._mod_decls)
        extract_report = self._run_extract(egg_expr, n)
        new_exprs = [TypedExprDecl.from_egg(self._mod_decls, variant) for variant in extract_report.variants]
        return [cast(EXPR, RuntimeExpr(self._mod_decls, new_expr)) for new_expr in new_exprs]

    def _run_extract(self, expr: bindings._Expr, n: int) -> bindings.ExtractReport:
        """
        Run an Extract command and return its report; raises ValueError if none was stored.
        """
        self._process_commands([bindings.Extract(n, expr)])
        extract_report = self._egraph.extract_report()
        if not extract_report:
            raise ValueError("No extract report saved")
        return extract_report

    def push(self) -> None:
        """
        Push the current state of the egraph, so that it can be popped later and reverted back.

        The current declarations are saved on the stack and replaced by a deep copy. Anything
        registered after the push therefore only modifies the copy and is discarded by `pop`.
        """
        self._process_commands([bindings.Push(1)])
        self._decl_stack.append(self._mod_decls._decl)
        # Work on a copy so the declarations saved on the stack stay untouched until `pop` restores them.
        self._mod_decls._decl = deepcopy(self._mod_decls._decl)

    def pop(self) -> None:
        """
        Pop the current state of the egraph, reverting back to the previous state.
        """
        self._process_commands([bindings.Pop(1)])
        self._mod_decls._decl = self._decl_stack.pop()

    def __enter__(self) -> EGraph:
        """
        Copy the egraph state, so that it can be reverted back to the original state at the end.

        Returns the egraph itself, so it can be bound with `with egraph as eg:`.
        """
        self.push()
        return self

    def __exit__(self, exc_type, exc, exc_tb):
        """
        Revert to the state saved by `__enter__`; exceptions are not suppressed.
        """
        self.pop()
742
+
743
+
744
@dataclass(frozen=True)
class _WrappedMethod(Generic[P, EXPR]):
    """
    Stores a method together with the options passed to `method(...)`.

    The class decorator unwraps it later and registers `fn` with these options. Field order
    matters, because instances are built positionally as
    (egg_fn, cost, default, merge, on_merge, fn).
    """

    egg_fn: Optional[str]
    cost: Optional[int]
    default: Optional[EXPR]
    merge: Optional[Callable[[EXPR, EXPR], EXPR]]
    on_merge: Optional[Callable[[EXPR, EXPR], Iterable[ActionLike]]]
    fn: Callable[P, EXPR]

    def __call__(self, *args: P.args, **kwargs: P.kwargs) -> EXPR:
        # Reaching this means the owning class was never processed by `class_`.
        message = "We should never call a wrapped method. Did you forget to wrap the class?"
        raise NotImplementedError(message)
759
+
760
+
761
class _BaseExprMetaclass(type):
    """
    Metaclass for BaseExpr. It makes `isinstance(x, SomeExprClass)` succeed for runtime
    expression objects, which are not real instances of the user-declared classes.
    """

    def __instancecheck__(cls, instance: object) -> bool:
        is_runtime_expr = isinstance(instance, RuntimeExpr)
        return is_runtime_expr
769
+
770
+
771
class BaseExpr(metaclass=_BaseExprMetaclass):
    """
    Root class for all expression types; it gives every expression a typed `!=` operator.
    """

    def __ne__(self: EXPR, other_expr: EXPR) -> Unit:  # type: ignore[override, empty-body]
        """
        Compare whether two expressions are not equal.

        :param self: The expression to compare.
        :param other_expr: The other expression to compare to, which must be of the same type.
        :meta public:
        """

    def __eq__(self, other: NoReturn) -> NoReturn:  # type: ignore[override, empty-body]
        """
        Equality is not supported yet. This stub exists only so that type checkers such as
        MyPy report an error when `==` is used on expressions.
        """
792
+
793
+
794
# The single builtins module; creating it publishes its declarations for every later module.
BUILTINS: _Builtins = _Builtins()
795
+
796
+
797
@BUILTINS.class_(egg_sort="Unit")
class Unit(BaseExpr):
    """
    The unit type. It is also used to represent whether a value exists, i.e. whether it is resolved.
    """

    def __init__(self) -> None:
        """Construct the unit value (registered as the class constructor; the body is never run)."""
805
+
806
+
807
@dataclass(frozen=True)
class Ruleset:
    """Immutable, hashable handle that refers to a ruleset by its name."""

    name: str
810
+
811
+
812
def _ruleset_name(ruleset: Optional[Ruleset]) -> str:
    """Name of the given ruleset, or the empty string (the default ruleset) when none is given."""
    if not ruleset:
        return ""
    return ruleset.name
814
+
815
+
816
+ # We use these builders so that when creating these structures we can type check
817
+ # if the arguments are the same type of expression
818
+
819
+
820
def rewrite(lhs: EXPR, ruleset: Optional[Ruleset] = None) -> _RewriteBuilder[EXPR]:
    """Rewrite the given expression to a new expression; finish with `.to(rhs, *conditions)`."""
    builder = _RewriteBuilder(lhs=lhs, ruleset=ruleset)
    return builder
823
+
824
+
825
def birewrite(lhs: EXPR, ruleset: Optional[Ruleset] = None) -> _BirewriteBuilder[EXPR]:
    """Rewrite the given expression to a new expression and vice versa; finish with `.to(...)`."""
    builder = _BirewriteBuilder(lhs, ruleset)
    return builder
828
+
829
+
830
def eq(expr: EXPR) -> _EqBuilder[EXPR]:
    """Start an equality check on the given expression; it is completed by the returned builder."""
    builder = _EqBuilder(expr)
    return builder
833
+
834
+
835
def panic(message: str) -> Action:
    """Build an action that raises an error carrying `message`."""
    action = Panic(message)
    return action
838
+
839
+
840
def let(name: str, expr: BaseExpr) -> Action:
    """Build an action binding `name` to the given expression."""
    bound_expr = expr_parts(expr).expr
    return Let(name, bound_expr)
843
+
844
+
845
def expr_action(expr: BaseExpr) -> Action:
    """Wrap a bare expression as an action."""
    return ExprAction(expr_parts(expr).expr)
848
+
849
+
850
def delete(expr: BaseExpr) -> Action:
    """
    Build an action that deletes the given function call.

    Raises ValueError if the expression is not a call.
    """
    decl = expr_parts(expr).expr
    if isinstance(decl, CallDecl):
        return Delete(decl)
    raise ValueError(f"Can only delete calls not {decl}")
856
+
857
+
858
def expr_fact(expr: BaseExpr) -> Fact:
    """Wrap a bare expression as a fact."""
    inner = expr_parts(expr).expr
    return ExprFact(inner)
860
+
861
+
862
def union(lhs: EXPR) -> _UnionBuilder[EXPR]:
    """
    Begin a union with ``lhs``.

    Finish it with ``.with_(rhs)`` on the returned builder.
    """
    builder = _UnionBuilder(lhs)
    return builder
865
+
866
+
867
def set_(lhs: EXPR) -> _SetBuilder[EXPR]:
    """
    Begin setting the value of the call ``lhs``.

    Finish it with ``.to(rhs)`` on the returned builder.
    """
    builder = _SetBuilder(lhs)
    return builder
870
+
871
+
872
def rule(*facts: FactLike, ruleset: Optional[Ruleset] = None, name: Optional[str] = None) -> _RuleBuilder:
    """
    Begin a rule whose query is ``facts``.

    Each fact-like argument is converted with ``_fact_likes``. Finish the rule with
    ``.then(*actions)`` on the returned builder.
    """
    converted = _fact_likes(facts)
    return _RuleBuilder(facts=converted, name=name, ruleset=ruleset)
875
+
876
+
877
def var(name: str, bound: type[EXPR]) -> EXPR:
    """
    Create a variable called ``name`` whose type is ``bound``.

    The runtime object is cast to ``EXPR`` so the type checker sees the requested type.
    """
    runtime_expr = _var(name, bound)
    return cast(EXPR, runtime_expr)
880
+
881
+
882
def _var(name: str, bound: Any) -> RuntimeExpr:
    """
    Untyped implementation behind ``var``.

    Raises:
        TypeError: if ``bound`` is neither a ``RuntimeClass`` nor a ``RuntimeParamaterizedClass``.
    """
    accepted = (RuntimeClass, RuntimeParamaterizedClass)
    if isinstance(bound, accepted):
        decls = bound.__egg_decls__
        typed = TypedExprDecl(class_to_ref(bound), VarDecl(name))
        return RuntimeExpr(decls, typed)
    raise TypeError(f"Unexpected type {type(bound)}")
887
+
888
+
889
def vars_(names: str, bound: type[EXPR]) -> Iterable[EXPR]:
    """
    Create one variable of type ``bound`` for each whitespace-separated name in ``names``.

    ``names`` is split with ``str.split()`` (no separator argument). Repeated, leading or
    trailing whitespace, including tabs and newlines, therefore never produces a variable
    with an empty name. Variables are yielded lazily, in order.
    """
    for name in names.split():
        yield var(name, bound)
893
+
894
+
895
@dataclass
class _RewriteBuilder(Generic[EXPR]):
    """Intermediate object returned by ``rewrite``; call ``.to`` to produce the command."""

    lhs: EXPR
    ruleset: Optional[Ruleset]

    def to(self, rhs: EXPR, *conditions: FactLike) -> Command:
        """Return a ``Rewrite`` command from ``lhs`` to ``rhs``, guarded by ``conditions``."""
        ruleset_name = _ruleset_name(self.ruleset)
        lhs_decl = expr_parts(self.lhs).expr
        rhs_decl = expr_parts(rhs).expr
        condition_facts = _fact_likes(conditions)
        return Rewrite(ruleset_name, lhs_decl, rhs_decl, condition_facts)

    def __str__(self) -> str:
        return f"rewrite({self.lhs})"
910
+
911
+
912
@dataclass
class _BirewriteBuilder(Generic[EXPR]):
    """Intermediate object returned by ``birewrite``; call ``.to`` to produce the command."""

    lhs: EXPR
    ruleset: Optional[Ruleset]

    def to(self, rhs: EXPR, *conditions: FactLike) -> Command:
        """Return a ``BiRewrite`` command between ``lhs`` and ``rhs``, guarded by ``conditions``."""
        ruleset_name = _ruleset_name(self.ruleset)
        lhs_decl = expr_parts(self.lhs).expr
        rhs_decl = expr_parts(rhs).expr
        condition_facts = _fact_likes(conditions)
        return BiRewrite(ruleset_name, lhs_decl, rhs_decl, condition_facts)

    def __str__(self) -> str:
        return f"birewrite({self.lhs})"
927
+
928
+
929
@dataclass
class _EqBuilder(Generic[EXPR]):
    """Intermediate object returned by ``eq``; call ``.to`` to produce the fact."""

    expr: EXPR

    def to(self, *exprs: EXPR) -> Fact:
        """Return an ``Eq`` fact over ``expr`` followed by every expression in ``exprs``."""
        operands = (self.expr, *exprs)
        decls = tuple(expr_parts(operand).expr for operand in operands)
        return Eq(decls)

    def __str__(self) -> str:
        return f"eq({self.expr})"
938
+
939
+
940
@dataclass
class _SetBuilder(Generic[EXPR]):
    """Intermediate object returned by ``set_``; call ``.to`` to produce the action."""

    lhs: BaseExpr

    def to(self, rhs: EXPR) -> Action:
        """
        Return a ``Set`` action assigning ``rhs`` to the call ``lhs``.

        Raises:
            ValueError: if ``lhs`` is not backed by a function call (``CallDecl``).
        """
        target = expr_parts(self.lhs).expr
        if isinstance(target, CallDecl):
            return Set(target, expr_parts(rhs).expr)
        raise ValueError(f"Can only create a call with a call for the lhs, got {target}")

    def __str__(self) -> str:
        return f"set_({self.lhs})"
952
+
953
+
954
@dataclass
class _UnionBuilder(Generic[EXPR]):
    """Intermediate object returned by ``union``; call ``.with_`` to produce the action."""

    lhs: BaseExpr

    def with_(self, rhs: EXPR) -> Action:
        """Return a ``Union_`` action joining ``lhs`` and ``rhs``."""
        lhs_decl = expr_parts(self.lhs).expr
        rhs_decl = expr_parts(rhs).expr
        return Union_(lhs_decl, rhs_decl)

    def __str__(self) -> str:
        return f"union({self.lhs})"
963
+
964
+
965
@dataclass
class _RuleBuilder:
    """Intermediate object returned by ``rule``; call ``.then`` to produce the command."""

    facts: tuple[Fact, ...]
    name: Optional[str]
    ruleset: Optional[Ruleset]

    def then(self, *actions: ActionLike) -> Command:
        """Return a ``Rule`` command that runs ``actions`` when ``facts`` match."""
        body = _action_likes(actions)
        rule_name = self.name or ""
        return Rule(body, self.facts, rule_name, _ruleset_name(self.ruleset))
973
+
974
+
975
def expr_parts(expr: BaseExpr) -> TypedExprDecl:
    """
    Return the underlying type and declaration of the expression.

    Useful for testing structural equality or debugging.

    Raises:
        TypeError: if ``expr`` is not a ``RuntimeExpr``.
    """
    # An explicit check rather than ``assert``: asserts are stripped under
    # ``python -O``, which would turn a bad argument into an obscure AttributeError.
    if not isinstance(expr, RuntimeExpr):
        raise TypeError(f"Expected a RuntimeExpr, got {type(expr)}")
    return expr.__egg_typed_expr__
981
+
982
+
983
def run(ruleset: Optional[Ruleset] = None, limit: int = 1, *until: Fact) -> Run:
    """
    Create a run configuration.

    Builds a ``Run`` from ``limit``, the ruleset's name (empty when ``ruleset`` is None)
    and the ``until`` facts. Because ``until`` is variadic and follows two defaulted
    parameters, ``ruleset`` and ``limit`` must be passed positionally whenever any
    ``until`` facts are supplied.
    """
    name = _ruleset_name(ruleset)
    return Run(limit, name, tuple(until))
988
+
989
+
990
def seq(*schedules: Schedule) -> Schedule:
    """
    Run a sequence of schedules.

    The schedules are wrapped, in the order given, in a ``Sequence``.
    """
    ordered = tuple(schedules)
    return Sequence(ordered)
995
+
996
+
997
# Anything accepted where a command is expected: a ``Command``, or a bare expression.
CommandLike = Union[Command, BaseExpr]


def _command_like(command_like: CommandLike) -> Command:
    """Normalize a command-like value: bare expressions are wrapped with ``expr_action``."""
    if not isinstance(command_like, BaseExpr):
        return command_like
    return expr_action(command_like)
1004
+
1005
+
1006
# A callable that takes any arguments and yields commands.
CommandGenerator = Callable[..., Iterable[Command]]


def _command_generator(gen: CommandGenerator) -> Iterable[Command]:
    """
    Call ``gen`` with one variable per parameter.

    Each variable is named after the parameter and typed by the parameter's type hint.
    A parameter without a type hint raises ``KeyError``.
    """
    hints = get_type_hints(gen)
    params = signature(gen).parameters
    variables = [_var(param_name, hints[param_name]) for param_name in params]
    return gen(*variables)
1016
+
1017
+
1018
# Anything accepted where an action is expected: an ``Action``, or a bare expression.
ActionLike = Union[Action, BaseExpr]


def _action_likes(action_likes: Iterable[ActionLike]) -> tuple[Action, ...]:
    """Normalize each action-like value with ``_action_like`` and return them as a tuple."""
    return tuple(_action_like(item) for item in action_likes)


def _action_like(action_like: ActionLike) -> Action:
    """Normalize one action-like value: bare expressions are wrapped with ``expr_action``."""
    return expr_action(action_like) if isinstance(action_like, BaseExpr) else action_like
1029
+
1030
+
1031
# Anything accepted where a fact is expected: a ``Fact``, or a ``Unit`` expression.
FactLike = Union[Fact, Unit]


def _fact_likes(fact_likes: Iterable[FactLike]) -> tuple[Fact, ...]:
    """Normalize each fact-like value with ``_fact_like`` and return them as a tuple."""
    return tuple(_fact_like(item) for item in fact_likes)


def _fact_like(fact_like: FactLike) -> Fact:
    """Normalize one fact-like value: bare expressions are wrapped with ``expr_fact``."""
    return expr_fact(fact_like) if isinstance(fact_like, BaseExpr) else fact_like