egglog 7.2.0__cp312-none-win_amd64.whl → 8.0.0__cp312-none-win_amd64.whl

This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.

Potentially problematic release.


This version of egglog has been flagged as potentially problematic. Further details were available via a link on the original page, which was not preserved here.

egglog/egraph_state.py CHANGED
@@ -4,10 +4,10 @@ Implement conversion to/from egglog.
4
4
 
5
5
  from __future__ import annotations
6
6
 
7
+ import re
7
8
  from collections import defaultdict
8
9
  from dataclasses import dataclass, field
9
10
  from typing import TYPE_CHECKING, overload
10
- from weakref import WeakKeyDictionary
11
11
 
12
12
  from typing_extensions import assert_never
13
13
 
@@ -19,7 +19,7 @@ from .type_constraint_solver import TypeConstraintError, TypeConstraintSolver
19
19
  if TYPE_CHECKING:
20
20
  from collections.abc import Iterable
21
21
 
22
- __all__ = ["EGraphState", "GLOBAL_PY_OBJECT_SORT", "_rule_var_name"]
22
+ __all__ = ["EGraphState", "GLOBAL_PY_OBJECT_SORT"]
23
23
 
24
24
  # Create a global sort for python objects, so we can store them without an e-graph instance
25
25
  # Needed when serializing commands to egg commands when creating modules
@@ -52,7 +52,7 @@ class EGraphState:
52
52
  type_ref_to_egg_sort: dict[JustTypeRef, str] = field(default_factory=dict)
53
53
 
54
54
  # Cache of egg expressions for converting to egg
55
- expr_to_egg_cache: WeakKeyDictionary[ExprDecl, bindings._Expr] = field(default_factory=WeakKeyDictionary)
55
+ expr_to_egg_cache: dict[ExprDecl, bindings._Expr] = field(default_factory=dict)
56
56
 
57
57
  def copy(self) -> EGraphState:
58
58
  """
@@ -71,15 +71,15 @@ class EGraphState:
71
71
  def schedule_to_egg(self, schedule: ScheduleDecl) -> bindings._Schedule:
72
72
  match schedule:
73
73
  case SaturateDecl(schedule):
74
- return bindings.Saturate(self.schedule_to_egg(schedule))
74
+ return bindings.Saturate(bindings.DUMMY_SPAN, self.schedule_to_egg(schedule))
75
75
  case RepeatDecl(schedule, times):
76
- return bindings.Repeat(times, self.schedule_to_egg(schedule))
76
+ return bindings.Repeat(bindings.DUMMY_SPAN, times, self.schedule_to_egg(schedule))
77
77
  case SequenceDecl(schedules):
78
- return bindings.Sequence([self.schedule_to_egg(s) for s in schedules])
78
+ return bindings.Sequence(bindings.DUMMY_SPAN, [self.schedule_to_egg(s) for s in schedules])
79
79
  case RunDecl(ruleset_name, until):
80
80
  self.ruleset_to_egg(ruleset_name)
81
81
  config = bindings.RunConfig(ruleset_name, None if not until else list(map(self.fact_to_egg, until)))
82
- return bindings.Run(config)
82
+ return bindings.Run(bindings.DUMMY_SPAN, config)
83
83
  case _:
84
84
  assert_never(schedule)
85
85
 
@@ -116,6 +116,7 @@ class EGraphState:
116
116
  case RewriteDecl(tp, lhs, rhs, conditions) | BiRewriteDecl(tp, lhs, rhs, conditions):
117
117
  self.type_ref_to_egg(tp)
118
118
  rewrite = bindings.Rewrite(
119
+ bindings.DUMMY_SPAN,
119
120
  self._expr_to_egg(lhs),
120
121
  self._expr_to_egg(rhs),
121
122
  [self.fact_to_egg(c) for c in conditions],
@@ -127,19 +128,24 @@ class EGraphState:
127
128
  )
128
129
  case RuleDecl(head, body, name):
129
130
  rule = bindings.Rule(
131
+ bindings.DUMMY_SPAN,
130
132
  [self.action_to_egg(a) for a in head],
131
133
  [self.fact_to_egg(f) for f in body],
132
134
  )
133
135
  return bindings.RuleCommand(name or "", ruleset, rule)
136
+ # TODO: Replace with just constants value and looking at REF of function
134
137
  case DefaultRewriteDecl(ref, expr):
135
138
  decl = self.__egg_decls__.get_callable_decl(ref).to_function_decl()
136
139
  sig = decl.signature
137
140
  assert isinstance(sig, FunctionSignature)
138
- args = tuple(
139
- TypedExprDecl(tp.to_just(), VarDecl(_rule_var_name(name)))
141
+ # Replace args with rule_var_name mapping
142
+ arg_mapping = tuple(
143
+ TypedExprDecl(tp.to_just(), VarDecl(name, False))
140
144
  for name, tp in zip(sig.arg_names, sig.arg_types, strict=False)
141
145
  )
142
- rewrite_decl = RewriteDecl(sig.semantic_return_type.to_just(), CallDecl(ref, args), expr, (), False)
146
+ rewrite_decl = RewriteDecl(
147
+ sig.semantic_return_type.to_just(), CallDecl(ref, arg_mapping), expr, (), False
148
+ )
143
149
  return self.command_to_egg(rewrite_decl, ruleset)
144
150
  case _:
145
151
  assert_never(cmd)
@@ -147,13 +153,16 @@ class EGraphState:
147
153
  def action_to_egg(self, action: ActionDecl) -> bindings._Action:
148
154
  match action:
149
155
  case LetDecl(name, typed_expr):
150
- return bindings.Let(name, self.typed_expr_to_egg(typed_expr))
156
+ var_decl = VarDecl(name, True)
157
+ var_egg = self._expr_to_egg(var_decl)
158
+ self.expr_to_egg_cache[var_decl] = var_egg
159
+ return bindings.Let(bindings.DUMMY_SPAN, var_egg.name, self.typed_expr_to_egg(typed_expr))
151
160
  case SetDecl(tp, call, rhs):
152
161
  self.type_ref_to_egg(tp)
153
162
  call_ = self._expr_to_egg(call)
154
- return bindings.Set(call_.name, call_.args, self._expr_to_egg(rhs))
163
+ return bindings.Set(bindings.DUMMY_SPAN, call_.name, call_.args, self._expr_to_egg(rhs))
155
164
  case ExprActionDecl(typed_expr):
156
- return bindings.Expr_(self.typed_expr_to_egg(typed_expr))
165
+ return bindings.Expr_(bindings.DUMMY_SPAN, self.typed_expr_to_egg(typed_expr))
157
166
  case ChangeDecl(tp, call, change):
158
167
  self.type_ref_to_egg(tp)
159
168
  call_ = self._expr_to_egg(call)
@@ -165,12 +174,12 @@ class EGraphState:
165
174
  egg_change = bindings.Subsume()
166
175
  case _:
167
176
  assert_never(change)
168
- return bindings.Change(egg_change, call_.name, call_.args)
177
+ return bindings.Change(bindings.DUMMY_SPAN, egg_change, call_.name, call_.args)
169
178
  case UnionDecl(tp, lhs, rhs):
170
179
  self.type_ref_to_egg(tp)
171
- return bindings.Union(self._expr_to_egg(lhs), self._expr_to_egg(rhs))
180
+ return bindings.Union(bindings.DUMMY_SPAN, self._expr_to_egg(lhs), self._expr_to_egg(rhs))
172
181
  case PanicDecl(name):
173
- return bindings.Panic(name)
182
+ return bindings.Panic(bindings.DUMMY_SPAN, name)
174
183
  case _:
175
184
  assert_never(action)
176
185
 
@@ -178,9 +187,9 @@ class EGraphState:
178
187
  match fact:
179
188
  case EqDecl(tp, exprs):
180
189
  self.type_ref_to_egg(tp)
181
- return bindings.Eq([self._expr_to_egg(e) for e in exprs])
190
+ return bindings.Eq(bindings.DUMMY_SPAN, [self._expr_to_egg(e) for e in exprs])
182
191
  case ExprFactDecl(typed_expr):
183
- return bindings.Fact(self.typed_expr_to_egg(typed_expr))
192
+ return bindings.Fact(self.typed_expr_to_egg(typed_expr, False))
184
193
  case _:
185
194
  assert_never(fact)
186
195
 
@@ -191,7 +200,9 @@ class EGraphState:
191
200
  if ref in self.callable_ref_to_egg_fn:
192
201
  return self.callable_ref_to_egg_fn[ref]
193
202
  decl = self.__egg_decls__.get_callable_decl(ref)
194
- self.callable_ref_to_egg_fn[ref] = egg_name = decl.egg_name or _generate_callable_egg_name(ref)
203
+ self.callable_ref_to_egg_fn[ref] = egg_name = decl.egg_name or _sanitize_egg_ident(
204
+ self._generate_callable_egg_name(ref)
205
+ )
195
206
  self.egg_fn_to_callable_refs[egg_name].add(ref)
196
207
  match decl:
197
208
  case RelationDecl(arg_types, _, _):
@@ -239,13 +250,14 @@ class EGraphState:
239
250
  # UnstableFn is a special case, where the rest of args are collected into a call
240
251
  type_args: list[bindings._Expr] = [
241
252
  bindings.Call(
253
+ bindings.DUMMY_SPAN,
242
254
  self.type_ref_to_egg(ref.args[1]),
243
- [bindings.Var(self.type_ref_to_egg(a)) for a in ref.args[2:]],
255
+ [bindings.Var(bindings.DUMMY_SPAN, self.type_ref_to_egg(a)) for a in ref.args[2:]],
244
256
  ),
245
- bindings.Var(self.type_ref_to_egg(ref.args[0])),
257
+ bindings.Var(bindings.DUMMY_SPAN, self.type_ref_to_egg(ref.args[0])),
246
258
  ]
247
259
  else:
248
- type_args = [bindings.Var(self.type_ref_to_egg(a)) for a in ref.args]
260
+ type_args = [bindings.Var(bindings.DUMMY_SPAN, self.type_ref_to_egg(a)) for a in ref.args]
249
261
  args = (self.type_ref_to_egg(JustTypeRef(ref.name)), type_args)
250
262
  else:
251
263
  args = None
@@ -272,31 +284,60 @@ class EGraphState:
272
284
  if len(v) == 1
273
285
  }
274
286
 
275
- def typed_expr_to_egg(self, typed_expr_decl: TypedExprDecl) -> bindings._Expr:
287
+ def typed_expr_to_egg(self, typed_expr_decl: TypedExprDecl, transform_let: bool = True) -> bindings._Expr:
288
+ # transform all expressions with multiple parents into a let binding, so that less expressions
289
+ # are sent to egglog. Only for performance reasons.
290
+ if transform_let:
291
+ have_multiple_parents = _exprs_multiple_parents(typed_expr_decl)
292
+ for expr in reversed(have_multiple_parents):
293
+ self._transform_let(expr)
294
+
276
295
  self.type_ref_to_egg(typed_expr_decl.tp)
277
296
  return self._expr_to_egg(typed_expr_decl.expr)
278
297
 
298
+ def _transform_let(self, typed_expr: TypedExprDecl) -> None:
299
+ """
300
+ Rewrites this expression as a let binding if it's not already a let binding.
301
+ """
302
+ var_decl = VarDecl(f"__expr_{hash(typed_expr)}", True)
303
+ if var_decl in self.expr_to_egg_cache:
304
+ return
305
+ var_egg = self._expr_to_egg(var_decl)
306
+ cmd = bindings.ActionCommand(
307
+ bindings.Let(bindings.DUMMY_SPAN, var_egg.name, self.typed_expr_to_egg(typed_expr))
308
+ )
309
+ try:
310
+ self.egraph.run_program(cmd)
311
+ # errors when creating let bindings for things like `(vec-empty)`
312
+ except bindings.EggSmolError:
313
+ return
314
+ self.expr_to_egg_cache[typed_expr.expr] = var_egg
315
+ self.expr_to_egg_cache[var_decl] = var_egg
316
+
279
317
  @overload
280
318
  def _expr_to_egg(self, expr_decl: CallDecl) -> bindings.Call: ...
281
319
 
320
+ @overload
321
+ def _expr_to_egg(self, expr_decl: VarDecl) -> bindings.Var: ...
322
+
282
323
  @overload
283
324
  def _expr_to_egg(self, expr_decl: ExprDecl) -> bindings._Expr: ...
284
325
 
285
326
  def _expr_to_egg(self, expr_decl: ExprDecl) -> bindings._Expr:
286
327
  """
287
328
  Convert an ExprDecl to an egg expression.
288
-
289
- Cached using weakrefs to avoid memory leaks.
290
329
  """
291
330
  try:
292
331
  return self.expr_to_egg_cache[expr_decl]
293
332
  except KeyError:
294
333
  pass
295
-
296
334
  res: bindings._Expr
297
335
  match expr_decl:
298
- case VarDecl(name):
299
- res = bindings.Var(name)
336
+ case VarDecl(name, is_let):
337
+ # prefix let bindings with % to avoid name conflicts with rewrites
338
+ if is_let:
339
+ name = f"%{name}"
340
+ res = bindings.Var(bindings.DUMMY_SPAN, name)
300
341
  case LitDecl(value):
301
342
  l: bindings._Literal
302
343
  match value:
@@ -312,19 +353,22 @@ class EGraphState:
312
353
  l = bindings.String(s)
313
354
  case _:
314
355
  assert_never(value)
315
- res = bindings.Lit(l)
356
+ res = bindings.Lit(bindings.DUMMY_SPAN, l)
316
357
  case CallDecl(ref, args, _):
317
358
  egg_fn = self.callable_ref_to_egg(ref)
318
- egg_args = [self.typed_expr_to_egg(a) for a in args]
319
- res = bindings.Call(egg_fn, egg_args)
359
+ egg_args = [self.typed_expr_to_egg(a, False) for a in args]
360
+ res = bindings.Call(bindings.DUMMY_SPAN, egg_fn, egg_args)
320
361
  case PyObjectDecl(value):
321
362
  res = GLOBAL_PY_OBJECT_SORT.store(value)
322
363
  case PartialCallDecl(call_decl):
323
364
  egg_fn_call = self._expr_to_egg(call_decl)
324
- res = bindings.Call("unstable-fn", [bindings.Lit(bindings.String(egg_fn_call.name)), *egg_fn_call.args])
365
+ res = bindings.Call(
366
+ bindings.DUMMY_SPAN,
367
+ "unstable-fn",
368
+ [bindings.Lit(bindings.DUMMY_SPAN, bindings.String(egg_fn_call.name)), *egg_fn_call.args],
369
+ )
325
370
  case _:
326
371
  assert_never(expr_decl.expr)
327
-
328
372
  self.expr_to_egg_cache[expr_decl] = res
329
373
  return res
330
374
 
@@ -343,6 +387,65 @@ class EGraphState:
343
387
  """
344
388
  return frozenset(tp for tp in self.type_ref_to_egg_sort if tp.name == cls_name)
345
389
 
390
+ def _generate_callable_egg_name(self, ref: CallableRef) -> str:
391
+ """
392
+ Generates a valid egg function name for a callable reference.
393
+ """
394
+ match ref:
395
+ case FunctionRef(name):
396
+ return name
397
+
398
+ case ConstantRef(name):
399
+ return name
400
+ case (
401
+ MethodRef(cls_name, name)
402
+ | ClassMethodRef(cls_name, name)
403
+ | ClassVariableRef(cls_name, name)
404
+ | PropertyRef(cls_name, name)
405
+ ):
406
+ return f"{cls_name}.{name}"
407
+ case InitRef(cls_name):
408
+ return f"{cls_name}.__init__"
409
+ case UnnamedFunctionRef(args, val):
410
+ parts = [str(self._expr_to_egg(a.expr)) + "-" + str(self.type_ref_to_egg(a.tp)) for a in args] + [
411
+ str(self.typed_expr_to_egg(val, False))
412
+ ]
413
+ return "_".join(parts)
414
+ case _:
415
+ assert_never(ref)
416
+
417
+
418
+ # https://chatgpt.com/share/9ab899b4-4e17-4426-a3f2-79d67a5ec456
419
+ _EGGLOG_INVALID_IDENT = re.compile(r"[^\w\-+*/?!=<>&|^/%]")
420
+
421
+
422
+ def _sanitize_egg_ident(input_string: str) -> str:
423
+ """
424
+ Replaces all invalid characters in an egg identifier with an underscore.
425
+ """
426
+ return _EGGLOG_INVALID_IDENT.sub("_", input_string)
427
+
428
+
429
+ def _exprs_multiple_parents(typed_expr: TypedExprDecl) -> list[TypedExprDecl]:
430
+ """
431
+ Returns all expressions that have multiple parents (a list but semantically just an ordered set).
432
+ """
433
+ to_traverse = {typed_expr}
434
+ traversed = set[TypedExprDecl]()
435
+ traversed_twice = list[TypedExprDecl]()
436
+ while to_traverse:
437
+ typed_expr = to_traverse.pop()
438
+ if typed_expr in traversed:
439
+ traversed_twice.append(typed_expr)
440
+ continue
441
+ traversed.add(typed_expr)
442
+ expr = typed_expr.expr
443
+ if isinstance(expr, CallDecl):
444
+ to_traverse.update(expr.args)
445
+ elif isinstance(expr, PartialCallDecl):
446
+ to_traverse.update(expr.call.args)
447
+ return traversed_twice
448
+
346
449
 
347
450
  def _generate_type_egg_name(ref: JustTypeRef) -> str:
348
451
  """
@@ -354,26 +457,6 @@ def _generate_type_egg_name(ref: JustTypeRef) -> str:
354
457
  return f"{name}_{'_'.join(map(_generate_type_egg_name, ref.args))}"
355
458
 
356
459
 
357
- def _generate_callable_egg_name(ref: CallableRef) -> str:
358
- """
359
- Generates a valid egg function name for a callable reference.
360
- """
361
- match ref:
362
- case FunctionRef(name) | ConstantRef(name):
363
- return name
364
- case (
365
- MethodRef(cls_name, name)
366
- | ClassMethodRef(cls_name, name)
367
- | ClassVariableRef(cls_name, name)
368
- | PropertyRef(cls_name, name)
369
- ):
370
- return f"{cls_name}_{name}"
371
- case InitRef(cls_name):
372
- return f"{cls_name}___init__"
373
- case _:
374
- assert_never(ref)
375
-
376
-
377
460
  @dataclass
378
461
  class FromEggState:
379
462
  """
@@ -395,7 +478,7 @@ class FromEggState:
395
478
  """
396
479
  expr_decl: ExprDecl
397
480
  if isinstance(term, bindings.TermVar):
398
- expr_decl = VarDecl(term.name)
481
+ expr_decl = VarDecl(term.name, True)
399
482
  elif isinstance(term, bindings.TermLit):
400
483
  value = term.value
401
484
  expr_decl = LitDecl(None if isinstance(value, bindings.Unit) else value.value)
@@ -478,10 +561,3 @@ class FromEggState:
478
561
  except KeyError:
479
562
  res = self.cache[term_id] = self.from_expr(tp, self.termdag.nodes[term_id])
480
563
  return res
481
-
482
-
483
- def _rule_var_name(s: str) -> str:
484
- """
485
- Create a hidden variable name, for rewrites, so that let bindings or function won't conflict with it
486
- """
487
- return f"__var__{s}"
@@ -16,15 +16,12 @@ if TYPE_CHECKING:
16
16
 
17
17
  class Math(Expr):
18
18
  def __init__(self, i: i64Like) -> None: ...
19
-
20
19
  def __add__(self, other: Math) -> Math: ...
21
20
 
22
21
 
23
22
  class MathList(Expr):
24
23
  def __init__(self) -> None: ...
25
-
26
24
  def append(self, i: Math) -> MathList: ...
27
-
28
25
  def map(self, f: Callable[[Math], Math]) -> MathList: ...
29
26
 
30
27
 
@@ -36,15 +33,13 @@ def math_ruleset(i: i64, j: i64, xs: MathList, x: Math, f: Callable[[Math], Math
36
33
 
37
34
 
38
35
  @function(ruleset=math_ruleset)
39
- def increment_by_one(x: Math) -> Math:
40
- return x + Math(1)
36
+ def incr_list(xs: MathList) -> MathList:
37
+ return xs.map(lambda x: x + Math(1))
41
38
 
42
39
 
43
40
  egraph = EGraph()
44
- x = egraph.let("x", MathList().append(Math(1)).append(Math(2)))
45
- y = egraph.let("y", x.map(increment_by_one))
41
+ y = egraph.let("y", incr_list(MathList().append(Math(1)).append(Math(2))))
46
42
  egraph.run(math_ruleset.saturate())
47
-
48
43
  egraph.check(eq(y).to(MathList().append(Math(2)).append(Math(3))))
49
44
 
50
45
  egraph