egglog 9.0.0__pp311-pypy311_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of egglog might be problematic. Click here for more details.

Files changed (44)
  1. egglog/__init__.py +10 -0
  2. egglog/bindings.pyi +667 -0
  3. egglog/bindings.pypy311-pp73-aarch64-linux-gnu.so +0 -0
  4. egglog/builtins.py +1045 -0
  5. egglog/config.py +8 -0
  6. egglog/conversion.py +262 -0
  7. egglog/declarations.py +818 -0
  8. egglog/egraph.py +1909 -0
  9. egglog/egraph_state.py +634 -0
  10. egglog/examples/README.rst +5 -0
  11. egglog/examples/__init__.py +3 -0
  12. egglog/examples/bignum.py +31 -0
  13. egglog/examples/bool.py +38 -0
  14. egglog/examples/eqsat_basic.py +45 -0
  15. egglog/examples/fib.py +28 -0
  16. egglog/examples/higher_order_functions.py +45 -0
  17. egglog/examples/lambda_.py +288 -0
  18. egglog/examples/matrix.py +175 -0
  19. egglog/examples/multiset.py +61 -0
  20. egglog/examples/ndarrays.py +144 -0
  21. egglog/examples/resolution.py +84 -0
  22. egglog/examples/schedule_demo.py +34 -0
  23. egglog/exp/__init__.py +3 -0
  24. egglog/exp/array_api.py +1943 -0
  25. egglog/exp/array_api_jit.py +44 -0
  26. egglog/exp/array_api_loopnest.py +74 -0
  27. egglog/exp/array_api_numba.py +69 -0
  28. egglog/exp/array_api_program_gen.py +510 -0
  29. egglog/exp/program_gen.py +424 -0
  30. egglog/exp/siu_examples.py +32 -0
  31. egglog/functionalize.py +91 -0
  32. egglog/ipython_magic.py +41 -0
  33. egglog/pretty.py +510 -0
  34. egglog/py.typed +0 -0
  35. egglog/runtime.py +633 -0
  36. egglog/thunk.py +95 -0
  37. egglog/type_constraint_solver.py +113 -0
  38. egglog/visualizer.css +1 -0
  39. egglog/visualizer.js +35777 -0
  40. egglog/visualizer_widget.py +39 -0
  41. egglog-9.0.0.dist-info/METADATA +74 -0
  42. egglog-9.0.0.dist-info/RECORD +44 -0
  43. egglog-9.0.0.dist-info/WHEEL +4 -0
  44. egglog-9.0.0.dist-info/licenses/LICENSE +21 -0
egglog/egraph_state.py ADDED
@@ -0,0 +1,634 @@
1
+ """
2
+ Implement conversion to/from egglog.
3
+ """
4
+
5
+ from __future__ import annotations
6
+
7
+ import re
8
+ from collections import defaultdict
9
+ from dataclasses import dataclass, field
10
+ from typing import TYPE_CHECKING, Literal, overload
11
+
12
+ from typing_extensions import assert_never
13
+
14
+ from . import bindings
15
+ from .declarations import *
16
+ from .declarations import ConstructorDecl
17
+ from .pretty import *
18
+ from .type_constraint_solver import *
19
+
20
+ if TYPE_CHECKING:
21
+ from collections.abc import Iterable
22
+
23
__all__ = [
    "GLOBAL_PY_OBJECT_SORT",
    "EGraphState",
    "span",
]

# One process-wide sort for Python objects. It is not tied to any e-graph instance, so objects can be stored (and
# later loaded) without one, which is needed when serializing commands to egg commands while creating modules.
GLOBAL_PY_OBJECT_SORT: bindings.PyObjectSort = bindings.PyObjectSort()
28
+
29
+
30
def span(frame_index: int = 0) -> bindings.RustSpan:
    """
    Return the source span attached to generated egglog commands.

    Looking up the real file and line of the calling frame (via ``inspect.stack()``) is currently disabled because it
    was too expensive, so this always returns an empty placeholder span (file ``""``, offsets ``0`` and ``0``).

    :param frame_index: Which stack frame the span should describe, where 0 is the current frame this is called in
        and 1 is the parent. Currently ignored; kept so existing call sites keep working unchanged.
    """
    return bindings.RustSpan("", 0, 0)
42
+
43
+
44
+ @dataclass
45
+ class EGraphState:
46
+ """
47
+ State of the EGraph declerations and rulesets, so when we pop/push the stack we know whats defined.
48
+
49
+ Used for converting to/from egg and for pretty printing.
50
+ """
51
+
52
+ egraph: bindings.EGraph
53
+ # The decleratons we have added.
54
+ __egg_decls__: Declarations = field(default_factory=Declarations)
55
+ # Mapping of added rulesets to the added rules
56
+ rulesets: dict[str, set[RewriteOrRuleDecl]] = field(default_factory=dict)
57
+
58
+ # Bidirectional mapping between egg function names and python callable references.
59
+ # Note that there are possibly mutliple callable references for a single egg function name, like `+`
60
+ # for both int and rational classes.
61
+ egg_fn_to_callable_refs: dict[str, set[CallableRef]] = field(
62
+ default_factory=lambda: defaultdict(set, {"!=": {FunctionRef("!=")}})
63
+ )
64
+ callable_ref_to_egg_fn: dict[CallableRef, tuple[str, bool]] = field(
65
+ default_factory=lambda: {FunctionRef("!="): ("!=", False)}
66
+ )
67
+
68
+ # Bidirectional mapping between egg sort names and python type references.
69
+ type_ref_to_egg_sort: dict[JustTypeRef, str] = field(default_factory=dict)
70
+
71
+ # Cache of egg expressions for converting to egg
72
+ expr_to_egg_cache: dict[ExprDecl, bindings._Expr] = field(default_factory=dict)
73
+
74
+ def copy(self) -> EGraphState:
75
+ """
76
+ Returns a copy of the state. Th egraph reference is kept the same. Used for pushing/popping.
77
+ """
78
+ return EGraphState(
79
+ egraph=self.egraph,
80
+ __egg_decls__=self.__egg_decls__.copy(),
81
+ rulesets={k: v.copy() for k, v in self.rulesets.items()},
82
+ egg_fn_to_callable_refs=defaultdict(set, {k: v.copy() for k, v in self.egg_fn_to_callable_refs.items()}),
83
+ callable_ref_to_egg_fn=self.callable_ref_to_egg_fn.copy(),
84
+ type_ref_to_egg_sort=self.type_ref_to_egg_sort.copy(),
85
+ expr_to_egg_cache=self.expr_to_egg_cache.copy(),
86
+ )
87
+
88
+ def schedule_to_egg(self, schedule: ScheduleDecl) -> bindings._Schedule:
89
+ match schedule:
90
+ case SaturateDecl(schedule):
91
+ return bindings.Saturate(span(), self.schedule_to_egg(schedule))
92
+ case RepeatDecl(schedule, times):
93
+ return bindings.Repeat(span(), times, self.schedule_to_egg(schedule))
94
+ case SequenceDecl(schedules):
95
+ return bindings.Sequence(span(), [self.schedule_to_egg(s) for s in schedules])
96
+ case RunDecl(ruleset_name, until):
97
+ self.ruleset_to_egg(ruleset_name)
98
+ config = bindings.RunConfig(ruleset_name, None if not until else list(map(self.fact_to_egg, until)))
99
+ return bindings.Run(span(), config)
100
+ case _:
101
+ assert_never(schedule)
102
+
103
+ def ruleset_to_egg(self, name: str) -> None:
104
+ """
105
+ Registers a ruleset if it's not already registered.
106
+ """
107
+ match self.__egg_decls__._rulesets[name]:
108
+ case RulesetDecl(rules):
109
+ if name not in self.rulesets:
110
+ if name:
111
+ self.egraph.run_program(bindings.AddRuleset(name))
112
+ added_rules = self.rulesets[name] = set()
113
+ else:
114
+ added_rules = self.rulesets[name]
115
+ for rule in rules:
116
+ if rule in added_rules:
117
+ continue
118
+ cmd = self.command_to_egg(rule, name)
119
+ if cmd is not None:
120
+ self.egraph.run_program(cmd)
121
+ added_rules.add(rule)
122
+ case CombinedRulesetDecl(rulesets):
123
+ if name in self.rulesets:
124
+ return
125
+ self.rulesets[name] = set()
126
+ for ruleset in rulesets:
127
+ self.ruleset_to_egg(ruleset)
128
+ self.egraph.run_program(bindings.UnstableCombinedRuleset(name, list(rulesets)))
129
+
130
+ def command_to_egg(self, cmd: CommandDecl, ruleset: str) -> bindings._Command | None:
131
+ match cmd:
132
+ case ActionCommandDecl(action):
133
+ action_egg = self.action_to_egg(action, expr_to_let=True)
134
+ if not action_egg:
135
+ return None
136
+ return bindings.ActionCommand(action_egg)
137
+ case RewriteDecl(tp, lhs, rhs, conditions) | BiRewriteDecl(tp, lhs, rhs, conditions):
138
+ self.type_ref_to_egg(tp)
139
+ rewrite = bindings.Rewrite(
140
+ span(),
141
+ self._expr_to_egg(lhs),
142
+ self._expr_to_egg(rhs),
143
+ [self.fact_to_egg(c) for c in conditions],
144
+ )
145
+ return (
146
+ bindings.RewriteCommand(ruleset, rewrite, cmd.subsume)
147
+ if isinstance(cmd, RewriteDecl)
148
+ else bindings.BiRewriteCommand(ruleset, rewrite)
149
+ )
150
+ case RuleDecl(head, body, name):
151
+ rule = bindings.Rule(
152
+ span(),
153
+ [self.action_to_egg(a) for a in head],
154
+ [self.fact_to_egg(f) for f in body],
155
+ )
156
+ return bindings.RuleCommand(name or "", ruleset, rule)
157
+ # TODO: Replace with just constants value and looking at REF of function
158
+ case DefaultRewriteDecl(ref, expr, subsume):
159
+ sig = self.__egg_decls__.get_callable_decl(ref).signature
160
+ assert isinstance(sig, FunctionSignature)
161
+ # Replace args with rule_var_name mapping
162
+ arg_mapping = tuple(
163
+ TypedExprDecl(tp.to_just(), VarDecl(name, False))
164
+ for name, tp in zip(sig.arg_names, sig.arg_types, strict=False)
165
+ )
166
+ rewrite_decl = RewriteDecl(
167
+ sig.semantic_return_type.to_just(), CallDecl(ref, arg_mapping), expr, (), subsume
168
+ )
169
+ return self.command_to_egg(rewrite_decl, ruleset)
170
+ case _:
171
+ assert_never(cmd)
172
+
173
+ @overload
174
+ def action_to_egg(self, action: ActionDecl) -> bindings._Action: ...
175
+
176
+ @overload
177
+ def action_to_egg(self, action: ActionDecl, expr_to_let: Literal[True] = ...) -> bindings._Action | None: ...
178
+
179
+ def action_to_egg(self, action: ActionDecl, expr_to_let: bool = False) -> bindings._Action | None: # noqa: C901, PLR0911, PLR0912
180
+ match action:
181
+ case LetDecl(name, typed_expr):
182
+ var_decl = VarDecl(name, True)
183
+ var_egg = self._expr_to_egg(var_decl)
184
+ self.expr_to_egg_cache[var_decl] = var_egg
185
+ return bindings.Let(span(), var_egg.name, self.typed_expr_to_egg(typed_expr))
186
+ case SetDecl(tp, call, rhs):
187
+ self.type_ref_to_egg(tp)
188
+ call_ = self._expr_to_egg(call)
189
+ return bindings.Set(span(), call_.name, call_.args, self._expr_to_egg(rhs))
190
+ case ExprActionDecl(typed_expr):
191
+ if expr_to_let:
192
+ maybe_typed_expr = self._transform_let(typed_expr)
193
+ if maybe_typed_expr:
194
+ typed_expr = maybe_typed_expr
195
+ else:
196
+ return None
197
+ return bindings.Expr_(span(), self.typed_expr_to_egg(typed_expr))
198
+ case ChangeDecl(tp, call, change):
199
+ self.type_ref_to_egg(tp)
200
+ call_ = self._expr_to_egg(call)
201
+ egg_change: bindings._Change
202
+ match change:
203
+ case "delete":
204
+ egg_change = bindings.Delete()
205
+ case "subsume":
206
+ egg_change = bindings.Subsume()
207
+ case _:
208
+ assert_never(change)
209
+ return bindings.Change(span(), egg_change, call_.name, call_.args)
210
+ case UnionDecl(tp, lhs, rhs):
211
+ self.type_ref_to_egg(tp)
212
+ return bindings.Union(span(), self._expr_to_egg(lhs), self._expr_to_egg(rhs))
213
+ case PanicDecl(name):
214
+ return bindings.Panic(span(), name)
215
+ case _:
216
+ assert_never(action)
217
+
218
+ def fact_to_egg(self, fact: FactDecl) -> bindings._Fact:
219
+ match fact:
220
+ case EqDecl(tp, left, right):
221
+ self.type_ref_to_egg(tp)
222
+ return bindings.Eq(span(), self._expr_to_egg(left), self._expr_to_egg(right))
223
+ case ExprFactDecl(typed_expr):
224
+ return bindings.Fact(self.typed_expr_to_egg(typed_expr, False))
225
+ case _:
226
+ assert_never(fact)
227
+
228
+ def callable_ref_to_egg(self, ref: CallableRef) -> tuple[str, bool]:
229
+ """
230
+ Returns the egg function name for a callable reference, registering it if it is not already registered.
231
+
232
+ Also returns whether the args should be reversed
233
+ """
234
+ if ref in self.callable_ref_to_egg_fn:
235
+ return self.callable_ref_to_egg_fn[ref]
236
+ decl = self.__egg_decls__.get_callable_decl(ref)
237
+ egg_name = decl.egg_name or _sanitize_egg_ident(self._generate_callable_egg_name(ref))
238
+ self.egg_fn_to_callable_refs[egg_name].add(ref)
239
+ reverse_args = False
240
+ match decl:
241
+ case RelationDecl(arg_types, _, _):
242
+ self.egraph.run_program(
243
+ bindings.Relation(span(), egg_name, [self.type_ref_to_egg(a) for a in arg_types])
244
+ )
245
+ case ConstantDecl(tp, _):
246
+ # Use constructor decleration instead of constant b/c constants cannot be extracted
247
+ # https://github.com/egraphs-good/egglog/issues/334
248
+ self.egraph.run_program(
249
+ bindings.Constructor(span(), egg_name, bindings.Schema([], self.type_ref_to_egg(tp)), None, False)
250
+ )
251
+ case FunctionDecl(signature, builtin, _, merge):
252
+ if isinstance(signature, FunctionSignature):
253
+ reverse_args = signature.reverse_args
254
+ if not builtin:
255
+ assert isinstance(signature, FunctionSignature), "Cannot turn special function to egg"
256
+ # Compile functions that return unit to relations, because these show up in methods where you
257
+ # cant use the relation helper
258
+ schema = self._signature_to_egg_schema(signature)
259
+ if signature.return_type == TypeRefWithVars("Unit"):
260
+ if merge:
261
+ msg = "Cannot specify a merge function for a function that returns unit"
262
+ raise ValueError(msg)
263
+ self.egraph.run_program(bindings.Relation(span(), egg_name, schema.input))
264
+ else:
265
+ self.egraph.run_program(
266
+ bindings.Function(
267
+ span(),
268
+ egg_name,
269
+ self._signature_to_egg_schema(signature),
270
+ self._expr_to_egg(merge) if merge else None,
271
+ )
272
+ )
273
+ case ConstructorDecl(signature, _, cost, unextractable):
274
+ self.egraph.run_program(
275
+ bindings.Constructor(
276
+ span(),
277
+ egg_name,
278
+ self._signature_to_egg_schema(signature),
279
+ cost,
280
+ unextractable,
281
+ )
282
+ )
283
+
284
+ case _:
285
+ assert_never(decl)
286
+ self.callable_ref_to_egg_fn[ref] = egg_name, reverse_args
287
+ return egg_name, reverse_args
288
+
289
+ def _signature_to_egg_schema(self, signature: FunctionSignature) -> bindings.Schema:
290
+ return bindings.Schema(
291
+ [self.type_ref_to_egg(a.to_just()) for a in signature.arg_types],
292
+ self.type_ref_to_egg(signature.semantic_return_type.to_just()),
293
+ )
294
+
295
+ def type_ref_to_egg(self, ref: JustTypeRef) -> str: # noqa: C901, PLR0912
296
+ """
297
+ Returns the egg sort name for a type reference, registering it if it is not already registered.
298
+ """
299
+ try:
300
+ return self.type_ref_to_egg_sort[ref]
301
+ except KeyError:
302
+ pass
303
+ decl = self.__egg_decls__._classes[ref.name]
304
+ self.type_ref_to_egg_sort[ref] = egg_name = decl.egg_name or _generate_type_egg_name(ref)
305
+ if not decl.builtin or ref.args:
306
+ if ref.args:
307
+ if ref.name == "UnstableFn":
308
+ # UnstableFn is a special case, where the rest of args are collected into a call
309
+ type_args: list[bindings._Expr] = [
310
+ bindings.Call(
311
+ span(),
312
+ self.type_ref_to_egg(ref.args[1]),
313
+ [bindings.Var(span(), self.type_ref_to_egg(a)) for a in ref.args[2:]],
314
+ ),
315
+ bindings.Var(span(), self.type_ref_to_egg(ref.args[0])),
316
+ ]
317
+ else:
318
+ # If any of methods have another type ref in them process all those first with substituted vars
319
+ # so that things like multiset - mapp will be added. Function type must be added first.
320
+ # Find all args of all methods and find any with type args themselves that are not this type and add them
321
+ tcs = TypeConstraintSolver(self.__egg_decls__)
322
+ tcs.bind_class(ref)
323
+ for method in decl.methods.values():
324
+ if not isinstance((signature := method.signature), FunctionSignature):
325
+ continue
326
+ for arg_tp in signature.arg_types:
327
+ if isinstance(arg_tp, TypeRefWithVars) and arg_tp.args and arg_tp.name != ref.name:
328
+ self.type_ref_to_egg(tcs.substitute_typevars(arg_tp, ref.name))
329
+
330
+ type_args = [bindings.Var(span(), self.type_ref_to_egg(a)) for a in ref.args]
331
+ args = (self.type_ref_to_egg(JustTypeRef(ref.name)), type_args)
332
+ else:
333
+ args = None
334
+ self.egraph.run_program(bindings.Sort(span(), egg_name, args))
335
+ # For builtin classes, let's also make sure we have the mapping of all egg fn names for class methods, because
336
+ # these can be created even without adding them to the e-graph, like `vec-empty` which can be extracted
337
+ # even if you never use that function.
338
+ if decl.builtin:
339
+ for method_name in decl.class_methods:
340
+ self.callable_ref_to_egg(ClassMethodRef(ref.name, method_name))
341
+ if decl.init:
342
+ self.callable_ref_to_egg(InitRef(ref.name))
343
+
344
+ return egg_name
345
+
346
+ def op_mapping(self) -> dict[str, str]:
347
+ """
348
+ Create a mapping of egglog function name to Python function name, for use in the serialized format
349
+ for better visualization.
350
+ """
351
+ return {
352
+ k: pretty_callable_ref(self.__egg_decls__, next(iter(v)))
353
+ for k, v in self.egg_fn_to_callable_refs.items()
354
+ if len(v) == 1
355
+ }
356
+
357
+ def typed_expr_to_egg(self, typed_expr_decl: TypedExprDecl, transform_let: bool = True) -> bindings._Expr:
358
+ # transform all expressions with multiple parents into a let binding, so that less expressions
359
+ # are sent to egglog. Only for performance reasons.
360
+ if transform_let:
361
+ have_multiple_parents = _exprs_multiple_parents(typed_expr_decl)
362
+ for expr in reversed(have_multiple_parents):
363
+ self._transform_let(expr)
364
+
365
+ self.type_ref_to_egg(typed_expr_decl.tp)
366
+ return self._expr_to_egg(typed_expr_decl.expr)
367
+
368
+ def _transform_let(self, typed_expr: TypedExprDecl) -> TypedExprDecl | None:
369
+ """
370
+ Rewrites this expression as a let binding if it's not already a let binding.
371
+ """
372
+ var_decl = VarDecl(f"__expr_{hash(typed_expr)}", True)
373
+ if var_decl in self.expr_to_egg_cache:
374
+ return None
375
+ var_egg = self._expr_to_egg(var_decl)
376
+ cmd = bindings.ActionCommand(bindings.Let(span(), var_egg.name, self.typed_expr_to_egg(typed_expr)))
377
+ try:
378
+ self.egraph.run_program(cmd)
379
+ # errors when creating let bindings for things like `(vec-empty)`
380
+ except bindings.EggSmolError:
381
+ return typed_expr
382
+ self.expr_to_egg_cache[typed_expr.expr] = var_egg
383
+ self.expr_to_egg_cache[var_decl] = var_egg
384
+ return None
385
+
386
+ @overload
387
+ def _expr_to_egg(self, expr_decl: CallDecl) -> bindings.Call: ...
388
+
389
+ @overload
390
+ def _expr_to_egg(self, expr_decl: VarDecl) -> bindings.Var: ...
391
+
392
+ @overload
393
+ def _expr_to_egg(self, expr_decl: ExprDecl) -> bindings._Expr: ...
394
+
395
+ def _expr_to_egg(self, expr_decl: ExprDecl) -> bindings._Expr: # noqa: PLR0912,C901
396
+ """
397
+ Convert an ExprDecl to an egg expression.
398
+ """
399
+ try:
400
+ return self.expr_to_egg_cache[expr_decl]
401
+ except KeyError:
402
+ pass
403
+ res: bindings._Expr
404
+ match expr_decl:
405
+ case VarDecl(name, is_let):
406
+ # prefix let bindings with % to avoid name conflicts with rewrites
407
+ if is_let:
408
+ name = f"%{name}"
409
+ res = bindings.Var(span(), name)
410
+ case LitDecl(value):
411
+ l: bindings._Literal
412
+ match value:
413
+ case None:
414
+ l = bindings.Unit()
415
+ case bool(i):
416
+ l = bindings.Bool(i)
417
+ case int(i):
418
+ l = bindings.Int(i)
419
+ case float(f):
420
+ l = bindings.Float(f)
421
+ case str(s):
422
+ l = bindings.String(s)
423
+ case _:
424
+ assert_never(value)
425
+ res = bindings.Lit(span(), l)
426
+ case CallDecl(ref, args, _):
427
+ egg_fn, reverse_args = self.callable_ref_to_egg(ref)
428
+ egg_args = [self.typed_expr_to_egg(a, False) for a in args]
429
+ if reverse_args:
430
+ egg_args.reverse()
431
+ res = bindings.Call(span(), egg_fn, egg_args)
432
+ case PyObjectDecl(value):
433
+ res = GLOBAL_PY_OBJECT_SORT.store(value)
434
+ case PartialCallDecl(call_decl):
435
+ egg_fn_call = self._expr_to_egg(call_decl)
436
+ res = bindings.Call(
437
+ span(),
438
+ "unstable-fn",
439
+ [bindings.Lit(span(), bindings.String(egg_fn_call.name)), *egg_fn_call.args],
440
+ )
441
+ case _:
442
+ assert_never(expr_decl.expr)
443
+ self.expr_to_egg_cache[expr_decl] = res
444
+ return res
445
+
446
+ def exprs_from_egg(
447
+ self, termdag: bindings.TermDag, terms: list[bindings._Term], tp: JustTypeRef
448
+ ) -> Iterable[TypedExprDecl]:
449
+ """
450
+ Create a function that can convert from an egg term to a typed expr.
451
+ """
452
+ state = FromEggState(self, termdag)
453
+ return [state.from_expr(tp, term) for term in terms]
454
+
455
+ def _get_possible_types(self, cls_name: str) -> frozenset[JustTypeRef]:
456
+ """
457
+ Given a class name, returns all possible registered types that it can be.
458
+ """
459
+ return frozenset(tp for tp in self.type_ref_to_egg_sort if tp.name == cls_name)
460
+
461
+ def _generate_callable_egg_name(self, ref: CallableRef) -> str:
462
+ """
463
+ Generates a valid egg function name for a callable reference.
464
+ """
465
+ match ref:
466
+ case FunctionRef(name):
467
+ return name
468
+
469
+ case ConstantRef(name):
470
+ return name
471
+ case (
472
+ MethodRef(cls_name, name)
473
+ | ClassMethodRef(cls_name, name)
474
+ | ClassVariableRef(cls_name, name)
475
+ | PropertyRef(cls_name, name)
476
+ ):
477
+ return f"{cls_name}.{name}"
478
+ case InitRef(cls_name):
479
+ return f"{cls_name}.__init__"
480
+ case UnnamedFunctionRef(args, val):
481
+ parts = [str(self._expr_to_egg(a.expr)) + "-" + str(self.type_ref_to_egg(a.tp)) for a in args] + [
482
+ str(self.typed_expr_to_egg(val, False))
483
+ ]
484
+ return "_".join(parts)
485
+ case _:
486
+ assert_never(ref)
487
+
488
+
489
+ # https://chatgpt.com/share/9ab899b4-4e17-4426-a3f2-79d67a5ec456
490
+ _EGGLOG_INVALID_IDENT = re.compile(r"[^\w\-+*/?!=<>&|^/%]")
491
+
492
+
493
+ def _sanitize_egg_ident(input_string: str) -> str:
494
+ """
495
+ Replaces all invalid characters in an egg identifier with an underscore.
496
+ """
497
+ return _EGGLOG_INVALID_IDENT.sub("_", input_string)
498
+
499
+
500
def _exprs_multiple_parents(typed_expr: TypedExprDecl) -> list[TypedExprDecl]:
    """
    Returns all expressions that have multiple parents (a list but semantically just an ordered set).

    Walks the expression graph from `typed_expr` with a worklist. Each distinct expression is expanded once, and an
    expression is recorded whenever it is popped again after it has already been expanded.

    NOTE(review): `to_traverse` is a set, so a child added while an equal entry is still pending collapses into
    that entry. For example, an argument repeated within a single call (both args of `f(a, a)`) is visited only once
    and is not reported. An expression popped three or more times can also appear more than once in the result. In
    this file the result only feeds `typed_expr_to_egg`'s let-binding optimization (performance only), so this is
    presumably tolerated; confirm before relying on it elsewhere.
    """
    to_traverse = {typed_expr}
    traversed = set[TypedExprDecl]()
    traversed_twice = list[TypedExprDecl]()
    while to_traverse:
        typed_expr = to_traverse.pop()
        # Already expanded, so it was reached again through another parent
        if typed_expr in traversed:
            traversed_twice.append(typed_expr)
            continue
        traversed.add(typed_expr)
        expr = typed_expr.expr
        # Only the arguments of calls and partial calls are traversed
        if isinstance(expr, CallDecl):
            to_traverse.update(expr.args)
        elif isinstance(expr, PartialCallDecl):
            to_traverse.update(expr.call.args)
    return traversed_twice
519
+
520
+
521
+ def _generate_type_egg_name(ref: JustTypeRef) -> str:
522
+ """
523
+ Generates an egg sort name for this type reference by linearizing the type.
524
+ """
525
+ name = ref.name
526
+ if not ref.args:
527
+ return name
528
+ return f"{name}_{'_'.join(map(_generate_type_egg_name, ref.args))}"
529
+
530
+
531
@dataclass
class FromEggState:
    """
    State used while converting egg terms from a `TermDag` back into typed expressions.

    `EGraphState.exprs_from_egg` creates one instance per batch of terms, so sub-terms shared between them are
    converted only once.
    """

    # The e-graph state whose declarations and name mappings resolve egg names back to Python references
    state: EGraphState
    # The term DAG that term IDs passed to `resolve_term` refer into
    termdag: bindings.TermDag
    # Cache of termdag ID to TypedExprDecl
    cache: dict[int, TypedExprDecl] = field(default_factory=dict)

    @property
    def decls(self) -> Declarations:
        """The declarations of the underlying e-graph state."""
        return self.state.__egg_decls__

    def from_expr(self, tp: JustTypeRef, term: bindings._Term) -> TypedExprDecl:
        """
        Convert an egg term, expected to have type `tp`, into a typed expression.

        Handles variables, literals, stored Python objects (`py-object`), partially applied functions
        (`unstable-fn`), and ordinary calls.
        """
        expr_decl: ExprDecl
        if isinstance(term, bindings.TermVar):
            expr_decl = VarDecl(term.name, True)
        elif isinstance(term, bindings.TermLit):
            value = term.value
            expr_decl = LitDecl(None if isinstance(value, bindings.Unit) else value.value)
        elif isinstance(term, bindings.TermApp):
            if term.name == "py-object":
                call = self.termdag.term_to_expr(term, span())
                expr_decl = PyObjectDecl(GLOBAL_PY_OBJECT_SORT.load(call))
            elif term.name == "unstable-fn":
                # The first argument is a string literal naming the function; the rest are the partially applied args
                fn_term, *arg_terms = term.args
                fn_value = self.resolve_term(fn_term, JustTypeRef("String"))
                assert isinstance(fn_value.expr, LitDecl)
                fn_name = fn_value.expr.value
                assert isinstance(fn_name, str)

                # Resolve what types the partially applied args are, using the function type's first type argument
                # as the expected type of the call
                assert tp.name == "UnstableFn"
                call_decl = self.from_call(tp.args[0], bindings.TermApp(fn_name, arg_terms))
                expr_decl = PartialCallDecl(call_decl)
            else:
                expr_decl = self.from_call(tp, term)
        else:
            assert_never(term)
        return TypedExprDecl(tp, expr_decl)

    def from_call(
        self,
        tp: JustTypeRef,
        term: bindings.TermApp,
    ) -> CallDecl:
        """
        Convert an egg call term, expected to return `tp`, into a `CallDecl`.

        Several Python callable refs can share one egg function name (e.g. `+` on different classes), so each
        registered candidate is tried in turn and the first whose signature type-checks against `tp` is used.
        The term's arguments are then converted with the inferred argument types.

        :raises ValueError: if no registered callable ref for the term's function name matches.
        """
        # Use `.get` rather than indexing: the mapping is normally a defaultdict, and indexing an unknown name would
        # insert an empty entry into the shared e-graph state as a side effect of this read.
        for callable_ref in self.state.egg_fn_to_callable_refs.get(term.name, set()):
            # If this is a classmethod, we might need the type params that were bound for this type
            # This could be multiple types if the classmethod is ambiguous, like map create.
            possible_types: Iterable[JustTypeRef | None]
            signature = self.decls.get_callable_decl(callable_ref).signature
            assert isinstance(signature, FunctionSignature)
            if isinstance(callable_ref, ClassMethodRef | InitRef | MethodRef):
                # Need OR in case we have a class method whose class was never added as a sort, which would happen
                # if the class method didn't return that type and no other function did. In this case, we don't need
                # to care about the type vars and we don't need to bind any possible type.
                possible_types = self.state._get_possible_types(callable_ref.class_name) or [None]
                cls_name = callable_ref.class_name
            else:
                possible_types = [None]
                cls_name = None
            for possible_type in possible_types:
                tcs = TypeConstraintSolver(self.decls)
                if possible_type and possible_type.args:
                    tcs.bind_class(possible_type)
                try:
                    arg_types, bound_tp_params = tcs.infer_arg_types(
                        signature.arg_types, signature.semantic_return_type, signature.var_arg_type, tp, cls_name
                    )
                except TypeConstraintError:
                    continue
                args = tuple(
                    self.resolve_term(arg_term, arg_tp)
                    for arg_term, arg_tp in zip(term.args, arg_types, strict=False)
                )

                return CallDecl(
                    callable_ref,
                    args,
                    # Don't include bound type params if this is just a method, we only needed them for type resolution
                    # but don't need to store them
                    bound_tp_params if isinstance(callable_ref, ClassMethodRef | InitRef) else None,
                )
        raise ValueError(f"Could not find callable ref for call {term}")

    def resolve_term(self, term_id: int, tp: JustTypeRef) -> TypedExprDecl:
        """
        Convert the term with ID `term_id` in the term DAG, memoizing the result by ID.

        The cache is keyed only by term ID: a term that was already converted is returned as-is, even if `tp`
        differs from the type used the first time.
        """
        if (cached := self.cache.get(term_id)) is not None:
            return cached
        res = self.cache[term_id] = self.from_expr(tp, self.termdag.get(term_id))
        return res
@@ -0,0 +1,5 @@
1
+ Examples Gallery
2
+ ================
3
+
4
+ This is a gallery of examples, most of which were translated from the original
5
+ `egglog rust examples <https://github.com/egraphs-good/egglog/tree/08a6e8fecdb77e6ba72a1b1d9ff4aff33229912c/tests>`_.
@@ -0,0 +1,3 @@
1
+ """
2
+ Examples using egglog.
3
+ """
@@ -0,0 +1,31 @@
1
# mypy: disable-error-code="empty-body"
"""
BigNum/BigRat Example
=====================
"""

from __future__ import annotations

from egglog import *

# Build arbitrary-precision values: an integer from a literal, an integer parsed from a string, and their ratio.
x = BigInt(-1234)
y = BigInt.from_string("2")
z = BigRat(x, y)

# The numerator of -1234/2 is expected to be -617.
assert z.numer.to_string() == "-617"


# An uninterpreted function from two big integers to a big rational (body intentionally empty).
@function
def bignums(x: BigInt, y: BigInt) -> BigRat: ...


egraph = EGraph()
# Set the value of bignums(x, y) to z in the e-graph.
egraph.register(set_(bignums(x, y)).to(z))

c = var("c", BigRat)
a, b = vars_("a b", BigInt)
# Query the e-graph: some bignums(a, b) == c should exist whose numerator equals a >> 1 and whose denominator
# equals b >> 1 (for the values above: -1234 >> 1 == -617 and 2 >> 1 == 1).
egraph.check(
    bignums(a, b) == c,
    c.numer == a >> 1,
    c.denom == b >> 1,
)