egglog 7.2.0__cp311-none-win_amd64.whl → 8.0.1__cp311-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of egglog might be problematic. More details are linked from the original page for this release.

egglog/egraph_state.py CHANGED
@@ -4,10 +4,10 @@ Implement conversion to/from egglog.
4
4
 
5
5
  from __future__ import annotations
6
6
 
7
+ import re
7
8
  from collections import defaultdict
8
9
  from dataclasses import dataclass, field
9
10
  from typing import TYPE_CHECKING, overload
10
- from weakref import WeakKeyDictionary
11
11
 
12
12
  from typing_extensions import assert_never
13
13
 
@@ -19,7 +19,7 @@ from .type_constraint_solver import TypeConstraintError, TypeConstraintSolver
19
19
  if TYPE_CHECKING:
20
20
  from collections.abc import Iterable
21
21
 
22
- __all__ = ["EGraphState", "GLOBAL_PY_OBJECT_SORT", "_rule_var_name"]
22
+ __all__ = ["EGraphState", "GLOBAL_PY_OBJECT_SORT"]
23
23
 
24
24
  # Create a global sort for python objects, so we can store them without an e-graph instance
25
25
  # Needed when serializing commands to egg commands when creating modules
@@ -52,7 +52,7 @@ class EGraphState:
52
52
  type_ref_to_egg_sort: dict[JustTypeRef, str] = field(default_factory=dict)
53
53
 
54
54
  # Cache of egg expressions for converting to egg
55
- expr_to_egg_cache: WeakKeyDictionary[ExprDecl, bindings._Expr] = field(default_factory=WeakKeyDictionary)
55
+ expr_to_egg_cache: dict[ExprDecl, bindings._Expr] = field(default_factory=dict)
56
56
 
57
57
  def copy(self) -> EGraphState:
58
58
  """
@@ -71,15 +71,15 @@ class EGraphState:
71
71
  def schedule_to_egg(self, schedule: ScheduleDecl) -> bindings._Schedule:
72
72
  match schedule:
73
73
  case SaturateDecl(schedule):
74
- return bindings.Saturate(self.schedule_to_egg(schedule))
74
+ return bindings.Saturate(bindings.DUMMY_SPAN, self.schedule_to_egg(schedule))
75
75
  case RepeatDecl(schedule, times):
76
- return bindings.Repeat(times, self.schedule_to_egg(schedule))
76
+ return bindings.Repeat(bindings.DUMMY_SPAN, times, self.schedule_to_egg(schedule))
77
77
  case SequenceDecl(schedules):
78
- return bindings.Sequence([self.schedule_to_egg(s) for s in schedules])
78
+ return bindings.Sequence(bindings.DUMMY_SPAN, [self.schedule_to_egg(s) for s in schedules])
79
79
  case RunDecl(ruleset_name, until):
80
80
  self.ruleset_to_egg(ruleset_name)
81
81
  config = bindings.RunConfig(ruleset_name, None if not until else list(map(self.fact_to_egg, until)))
82
- return bindings.Run(config)
82
+ return bindings.Run(bindings.DUMMY_SPAN, config)
83
83
  case _:
84
84
  assert_never(schedule)
85
85
 
@@ -116,6 +116,7 @@ class EGraphState:
116
116
  case RewriteDecl(tp, lhs, rhs, conditions) | BiRewriteDecl(tp, lhs, rhs, conditions):
117
117
  self.type_ref_to_egg(tp)
118
118
  rewrite = bindings.Rewrite(
119
+ bindings.DUMMY_SPAN,
119
120
  self._expr_to_egg(lhs),
120
121
  self._expr_to_egg(rhs),
121
122
  [self.fact_to_egg(c) for c in conditions],
@@ -127,19 +128,24 @@ class EGraphState:
127
128
  )
128
129
  case RuleDecl(head, body, name):
129
130
  rule = bindings.Rule(
131
+ bindings.DUMMY_SPAN,
130
132
  [self.action_to_egg(a) for a in head],
131
133
  [self.fact_to_egg(f) for f in body],
132
134
  )
133
135
  return bindings.RuleCommand(name or "", ruleset, rule)
136
+ # TODO: Replace with just constants value and looking at REF of function
134
137
  case DefaultRewriteDecl(ref, expr):
135
138
  decl = self.__egg_decls__.get_callable_decl(ref).to_function_decl()
136
139
  sig = decl.signature
137
140
  assert isinstance(sig, FunctionSignature)
138
- args = tuple(
139
- TypedExprDecl(tp.to_just(), VarDecl(_rule_var_name(name)))
141
+ # Replace args with rule_var_name mapping
142
+ arg_mapping = tuple(
143
+ TypedExprDecl(tp.to_just(), VarDecl(name, False))
140
144
  for name, tp in zip(sig.arg_names, sig.arg_types, strict=False)
141
145
  )
142
- rewrite_decl = RewriteDecl(sig.semantic_return_type.to_just(), CallDecl(ref, args), expr, (), False)
146
+ rewrite_decl = RewriteDecl(
147
+ sig.semantic_return_type.to_just(), CallDecl(ref, arg_mapping), expr, (), False
148
+ )
143
149
  return self.command_to_egg(rewrite_decl, ruleset)
144
150
  case _:
145
151
  assert_never(cmd)
@@ -147,13 +153,16 @@ class EGraphState:
147
153
  def action_to_egg(self, action: ActionDecl) -> bindings._Action:
148
154
  match action:
149
155
  case LetDecl(name, typed_expr):
150
- return bindings.Let(name, self.typed_expr_to_egg(typed_expr))
156
+ var_decl = VarDecl(name, True)
157
+ var_egg = self._expr_to_egg(var_decl)
158
+ self.expr_to_egg_cache[var_decl] = var_egg
159
+ return bindings.Let(bindings.DUMMY_SPAN, var_egg.name, self.typed_expr_to_egg(typed_expr))
151
160
  case SetDecl(tp, call, rhs):
152
161
  self.type_ref_to_egg(tp)
153
162
  call_ = self._expr_to_egg(call)
154
- return bindings.Set(call_.name, call_.args, self._expr_to_egg(rhs))
163
+ return bindings.Set(bindings.DUMMY_SPAN, call_.name, call_.args, self._expr_to_egg(rhs))
155
164
  case ExprActionDecl(typed_expr):
156
- return bindings.Expr_(self.typed_expr_to_egg(typed_expr))
165
+ return bindings.Expr_(bindings.DUMMY_SPAN, self.typed_expr_to_egg(typed_expr))
157
166
  case ChangeDecl(tp, call, change):
158
167
  self.type_ref_to_egg(tp)
159
168
  call_ = self._expr_to_egg(call)
@@ -165,12 +174,12 @@ class EGraphState:
165
174
  egg_change = bindings.Subsume()
166
175
  case _:
167
176
  assert_never(change)
168
- return bindings.Change(egg_change, call_.name, call_.args)
177
+ return bindings.Change(bindings.DUMMY_SPAN, egg_change, call_.name, call_.args)
169
178
  case UnionDecl(tp, lhs, rhs):
170
179
  self.type_ref_to_egg(tp)
171
- return bindings.Union(self._expr_to_egg(lhs), self._expr_to_egg(rhs))
180
+ return bindings.Union(bindings.DUMMY_SPAN, self._expr_to_egg(lhs), self._expr_to_egg(rhs))
172
181
  case PanicDecl(name):
173
- return bindings.Panic(name)
182
+ return bindings.Panic(bindings.DUMMY_SPAN, name)
174
183
  case _:
175
184
  assert_never(action)
176
185
 
@@ -178,9 +187,9 @@ class EGraphState:
178
187
  match fact:
179
188
  case EqDecl(tp, exprs):
180
189
  self.type_ref_to_egg(tp)
181
- return bindings.Eq([self._expr_to_egg(e) for e in exprs])
190
+ return bindings.Eq(bindings.DUMMY_SPAN, [self._expr_to_egg(e) for e in exprs])
182
191
  case ExprFactDecl(typed_expr):
183
- return bindings.Fact(self.typed_expr_to_egg(typed_expr))
192
+ return bindings.Fact(self.typed_expr_to_egg(typed_expr, False))
184
193
  case _:
185
194
  assert_never(fact)
186
195
 
@@ -191,22 +200,31 @@ class EGraphState:
191
200
  if ref in self.callable_ref_to_egg_fn:
192
201
  return self.callable_ref_to_egg_fn[ref]
193
202
  decl = self.__egg_decls__.get_callable_decl(ref)
194
- self.callable_ref_to_egg_fn[ref] = egg_name = decl.egg_name or _generate_callable_egg_name(ref)
203
+ self.callable_ref_to_egg_fn[ref] = egg_name = decl.egg_name or _sanitize_egg_ident(
204
+ self._generate_callable_egg_name(ref)
205
+ )
195
206
  self.egg_fn_to_callable_refs[egg_name].add(ref)
196
207
  match decl:
197
208
  case RelationDecl(arg_types, _, _):
198
- self.egraph.run_program(bindings.Relation(egg_name, [self.type_ref_to_egg(a) for a in arg_types]))
209
+ self.egraph.run_program(
210
+ bindings.Relation(bindings.DUMMY_SPAN, egg_name, [self.type_ref_to_egg(a) for a in arg_types])
211
+ )
199
212
  case ConstantDecl(tp, _):
200
213
  # Use function decleration instead of constant b/c constants cannot be extracted
201
214
  # https://github.com/egraphs-good/egglog/issues/334
202
215
  self.egraph.run_program(
203
- bindings.Function(bindings.FunctionDecl(egg_name, bindings.Schema([], self.type_ref_to_egg(tp))))
216
+ bindings.Function(
217
+ bindings.FunctionDecl(
218
+ bindings.DUMMY_SPAN, egg_name, bindings.Schema([], self.type_ref_to_egg(tp))
219
+ )
220
+ )
204
221
  )
205
222
  case FunctionDecl():
206
223
  if not decl.builtin:
207
224
  signature = decl.signature
208
225
  assert isinstance(signature, FunctionSignature), "Cannot turn special function to egg"
209
226
  egg_fn_decl = bindings.FunctionDecl(
227
+ bindings.DUMMY_SPAN,
210
228
  egg_name,
211
229
  bindings.Schema(
212
230
  [self.type_ref_to_egg(a.to_just()) for a in signature.arg_types],
@@ -239,17 +257,18 @@ class EGraphState:
239
257
  # UnstableFn is a special case, where the rest of args are collected into a call
240
258
  type_args: list[bindings._Expr] = [
241
259
  bindings.Call(
260
+ bindings.DUMMY_SPAN,
242
261
  self.type_ref_to_egg(ref.args[1]),
243
- [bindings.Var(self.type_ref_to_egg(a)) for a in ref.args[2:]],
262
+ [bindings.Var(bindings.DUMMY_SPAN, self.type_ref_to_egg(a)) for a in ref.args[2:]],
244
263
  ),
245
- bindings.Var(self.type_ref_to_egg(ref.args[0])),
264
+ bindings.Var(bindings.DUMMY_SPAN, self.type_ref_to_egg(ref.args[0])),
246
265
  ]
247
266
  else:
248
- type_args = [bindings.Var(self.type_ref_to_egg(a)) for a in ref.args]
267
+ type_args = [bindings.Var(bindings.DUMMY_SPAN, self.type_ref_to_egg(a)) for a in ref.args]
249
268
  args = (self.type_ref_to_egg(JustTypeRef(ref.name)), type_args)
250
269
  else:
251
270
  args = None
252
- self.egraph.run_program(bindings.Sort(egg_name, args))
271
+ self.egraph.run_program(bindings.Sort(bindings.DUMMY_SPAN, egg_name, args))
253
272
  # For builtin classes, let's also make sure we have the mapping of all egg fn names for class methods, because
254
273
  # these can be created even without adding them to the e-graph, like `vec-empty` which can be extracted
255
274
  # even if you never use that function.
@@ -272,31 +291,60 @@ class EGraphState:
272
291
  if len(v) == 1
273
292
  }
274
293
 
275
- def typed_expr_to_egg(self, typed_expr_decl: TypedExprDecl) -> bindings._Expr:
294
+ def typed_expr_to_egg(self, typed_expr_decl: TypedExprDecl, transform_let: bool = True) -> bindings._Expr:
295
+ # transform all expressions with multiple parents into a let binding, so that less expressions
296
+ # are sent to egglog. Only for performance reasons.
297
+ if transform_let:
298
+ have_multiple_parents = _exprs_multiple_parents(typed_expr_decl)
299
+ for expr in reversed(have_multiple_parents):
300
+ self._transform_let(expr)
301
+
276
302
  self.type_ref_to_egg(typed_expr_decl.tp)
277
303
  return self._expr_to_egg(typed_expr_decl.expr)
278
304
 
305
+ def _transform_let(self, typed_expr: TypedExprDecl) -> None:
306
+ """
307
+ Rewrites this expression as a let binding if it's not already a let binding.
308
+ """
309
+ var_decl = VarDecl(f"__expr_{hash(typed_expr)}", True)
310
+ if var_decl in self.expr_to_egg_cache:
311
+ return
312
+ var_egg = self._expr_to_egg(var_decl)
313
+ cmd = bindings.ActionCommand(
314
+ bindings.Let(bindings.DUMMY_SPAN, var_egg.name, self.typed_expr_to_egg(typed_expr))
315
+ )
316
+ try:
317
+ self.egraph.run_program(cmd)
318
+ # errors when creating let bindings for things like `(vec-empty)`
319
+ except bindings.EggSmolError:
320
+ return
321
+ self.expr_to_egg_cache[typed_expr.expr] = var_egg
322
+ self.expr_to_egg_cache[var_decl] = var_egg
323
+
279
324
  @overload
280
325
  def _expr_to_egg(self, expr_decl: CallDecl) -> bindings.Call: ...
281
326
 
327
+ @overload
328
+ def _expr_to_egg(self, expr_decl: VarDecl) -> bindings.Var: ...
329
+
282
330
  @overload
283
331
  def _expr_to_egg(self, expr_decl: ExprDecl) -> bindings._Expr: ...
284
332
 
285
- def _expr_to_egg(self, expr_decl: ExprDecl) -> bindings._Expr:
333
+ def _expr_to_egg(self, expr_decl: ExprDecl) -> bindings._Expr: # noqa: PLR0912,C901
286
334
  """
287
335
  Convert an ExprDecl to an egg expression.
288
-
289
- Cached using weakrefs to avoid memory leaks.
290
336
  """
291
337
  try:
292
338
  return self.expr_to_egg_cache[expr_decl]
293
339
  except KeyError:
294
340
  pass
295
-
296
341
  res: bindings._Expr
297
342
  match expr_decl:
298
- case VarDecl(name):
299
- res = bindings.Var(name)
343
+ case VarDecl(name, is_let):
344
+ # prefix let bindings with % to avoid name conflicts with rewrites
345
+ if is_let:
346
+ name = f"%{name}"
347
+ res = bindings.Var(bindings.DUMMY_SPAN, name)
300
348
  case LitDecl(value):
301
349
  l: bindings._Literal
302
350
  match value:
@@ -312,19 +360,22 @@ class EGraphState:
312
360
  l = bindings.String(s)
313
361
  case _:
314
362
  assert_never(value)
315
- res = bindings.Lit(l)
363
+ res = bindings.Lit(bindings.DUMMY_SPAN, l)
316
364
  case CallDecl(ref, args, _):
317
365
  egg_fn = self.callable_ref_to_egg(ref)
318
- egg_args = [self.typed_expr_to_egg(a) for a in args]
319
- res = bindings.Call(egg_fn, egg_args)
366
+ egg_args = [self.typed_expr_to_egg(a, False) for a in args]
367
+ res = bindings.Call(bindings.DUMMY_SPAN, egg_fn, egg_args)
320
368
  case PyObjectDecl(value):
321
369
  res = GLOBAL_PY_OBJECT_SORT.store(value)
322
370
  case PartialCallDecl(call_decl):
323
371
  egg_fn_call = self._expr_to_egg(call_decl)
324
- res = bindings.Call("unstable-fn", [bindings.Lit(bindings.String(egg_fn_call.name)), *egg_fn_call.args])
372
+ res = bindings.Call(
373
+ bindings.DUMMY_SPAN,
374
+ "unstable-fn",
375
+ [bindings.Lit(bindings.DUMMY_SPAN, bindings.String(egg_fn_call.name)), *egg_fn_call.args],
376
+ )
325
377
  case _:
326
378
  assert_never(expr_decl.expr)
327
-
328
379
  self.expr_to_egg_cache[expr_decl] = res
329
380
  return res
330
381
 
@@ -343,6 +394,65 @@ class EGraphState:
343
394
  """
344
395
  return frozenset(tp for tp in self.type_ref_to_egg_sort if tp.name == cls_name)
345
396
 
397
+ def _generate_callable_egg_name(self, ref: CallableRef) -> str:
398
+ """
399
+ Generates a valid egg function name for a callable reference.
400
+ """
401
+ match ref:
402
+ case FunctionRef(name):
403
+ return name
404
+
405
+ case ConstantRef(name):
406
+ return name
407
+ case (
408
+ MethodRef(cls_name, name)
409
+ | ClassMethodRef(cls_name, name)
410
+ | ClassVariableRef(cls_name, name)
411
+ | PropertyRef(cls_name, name)
412
+ ):
413
+ return f"{cls_name}.{name}"
414
+ case InitRef(cls_name):
415
+ return f"{cls_name}.__init__"
416
+ case UnnamedFunctionRef(args, val):
417
+ parts = [str(self._expr_to_egg(a.expr)) + "-" + str(self.type_ref_to_egg(a.tp)) for a in args] + [
418
+ str(self.typed_expr_to_egg(val, False))
419
+ ]
420
+ return "_".join(parts)
421
+ case _:
422
+ assert_never(ref)
423
+
424
+
425
+ # https://chatgpt.com/share/9ab899b4-4e17-4426-a3f2-79d67a5ec456
426
+ _EGGLOG_INVALID_IDENT = re.compile(r"[^\w\-+*/?!=<>&|^/%]")
427
+
428
+
429
+ def _sanitize_egg_ident(input_string: str) -> str:
430
+ """
431
+ Replaces all invalid characters in an egg identifier with an underscore.
432
+ """
433
+ return _EGGLOG_INVALID_IDENT.sub("_", input_string)
434
+
435
+
436
+ def _exprs_multiple_parents(typed_expr: TypedExprDecl) -> list[TypedExprDecl]:
437
+ """
438
+ Returns all expressions that have multiple parents (a list but semantically just an ordered set).
439
+ """
440
+ to_traverse = {typed_expr}
441
+ traversed = set[TypedExprDecl]()
442
+ traversed_twice = list[TypedExprDecl]()
443
+ while to_traverse:
444
+ typed_expr = to_traverse.pop()
445
+ if typed_expr in traversed:
446
+ traversed_twice.append(typed_expr)
447
+ continue
448
+ traversed.add(typed_expr)
449
+ expr = typed_expr.expr
450
+ if isinstance(expr, CallDecl):
451
+ to_traverse.update(expr.args)
452
+ elif isinstance(expr, PartialCallDecl):
453
+ to_traverse.update(expr.call.args)
454
+ return traversed_twice
455
+
346
456
 
347
457
  def _generate_type_egg_name(ref: JustTypeRef) -> str:
348
458
  """
@@ -354,26 +464,6 @@ def _generate_type_egg_name(ref: JustTypeRef) -> str:
354
464
  return f"{name}_{'_'.join(map(_generate_type_egg_name, ref.args))}"
355
465
 
356
466
 
357
- def _generate_callable_egg_name(ref: CallableRef) -> str:
358
- """
359
- Generates a valid egg function name for a callable reference.
360
- """
361
- match ref:
362
- case FunctionRef(name) | ConstantRef(name):
363
- return name
364
- case (
365
- MethodRef(cls_name, name)
366
- | ClassMethodRef(cls_name, name)
367
- | ClassVariableRef(cls_name, name)
368
- | PropertyRef(cls_name, name)
369
- ):
370
- return f"{cls_name}_{name}"
371
- case InitRef(cls_name):
372
- return f"{cls_name}___init__"
373
- case _:
374
- assert_never(ref)
375
-
376
-
377
467
  @dataclass
378
468
  class FromEggState:
379
469
  """
@@ -395,7 +485,7 @@ class FromEggState:
395
485
  """
396
486
  expr_decl: ExprDecl
397
487
  if isinstance(term, bindings.TermVar):
398
- expr_decl = VarDecl(term.name)
488
+ expr_decl = VarDecl(term.name, True)
399
489
  elif isinstance(term, bindings.TermLit):
400
490
  value = term.value
401
491
  expr_decl = LitDecl(None if isinstance(value, bindings.Unit) else value.value)
@@ -403,7 +493,7 @@ class FromEggState:
403
493
  if term.name == "py-object":
404
494
  call = bindings.termdag_term_to_expr(self.termdag, term)
405
495
  expr_decl = PyObjectDecl(self.state.egraph.eval_py_object(call))
406
- if term.name == "unstable-fn":
496
+ elif term.name == "unstable-fn":
407
497
  # Get function name
408
498
  fn_term, *arg_terms = term.args
409
499
  fn_value = self.resolve_term(fn_term, JustTypeRef("String"))
@@ -478,10 +568,3 @@ class FromEggState:
478
568
  except KeyError:
479
569
  res = self.cache[term_id] = self.from_expr(tp, self.termdag.nodes[term_id])
480
570
  return res
481
-
482
-
483
- def _rule_var_name(s: str) -> str:
484
- """
485
- Create a hidden variable name, for rewrites, so that let bindings or function won't conflict with it
486
- """
487
- return f"__var__{s}"
@@ -16,35 +16,30 @@ if TYPE_CHECKING:
16
16
 
17
17
  class Math(Expr):
18
18
  def __init__(self, i: i64Like) -> None: ...
19
-
20
19
  def __add__(self, other: Math) -> Math: ...
21
20
 
22
21
 
23
22
  class MathList(Expr):
24
23
  def __init__(self) -> None: ...
25
-
26
24
  def append(self, i: Math) -> MathList: ...
27
-
28
25
  def map(self, f: Callable[[Math], Math]) -> MathList: ...
29
26
 
30
27
 
31
28
  @ruleset
32
- def math_ruleset(i: i64, j: i64, xs: MathList, x: Math, f: Callable[[Math], Math]): # noqa: ANN201
29
+ def math_ruleset(i: i64, j: i64, xs: MathList, x: Math, f: Callable[[Math], Math]):
33
30
  yield rewrite(Math(i) + Math(j)).to(Math(i + j))
34
31
  yield rewrite(xs.append(x).map(f)).to(xs.map(f).append(f(x)))
35
32
  yield rewrite(MathList().map(f)).to(MathList())
36
33
 
37
34
 
38
35
  @function(ruleset=math_ruleset)
39
- def increment_by_one(x: Math) -> Math:
40
- return x + Math(1)
36
+ def incr_list(xs: MathList) -> MathList:
37
+ return xs.map(lambda x: x + Math(1))
41
38
 
42
39
 
43
40
  egraph = EGraph()
44
- x = egraph.let("x", MathList().append(Math(1)).append(Math(2)))
45
- y = egraph.let("y", x.map(increment_by_one))
41
+ y = egraph.let("y", incr_list(MathList().append(Math(1)).append(Math(2))))
46
42
  egraph.run(math_ruleset.saturate())
47
-
48
43
  egraph.check(eq(y).to(MathList().append(Math(2)).append(Math(3))))
49
44
 
50
45
  egraph