egglog 7.1.0__cp311-none-win_amd64.whl → 8.0.0__cp311-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of egglog might be problematic. Click here for more details.

egglog/egraph_state.py CHANGED
@@ -4,10 +4,10 @@ Implement conversion to/from egglog.
4
4
 
5
5
  from __future__ import annotations
6
6
 
7
+ import re
7
8
  from collections import defaultdict
8
9
  from dataclasses import dataclass, field
9
10
  from typing import TYPE_CHECKING, overload
10
- from weakref import WeakKeyDictionary
11
11
 
12
12
  from typing_extensions import assert_never
13
13
 
@@ -52,7 +52,7 @@ class EGraphState:
52
52
  type_ref_to_egg_sort: dict[JustTypeRef, str] = field(default_factory=dict)
53
53
 
54
54
  # Cache of egg expressions for converting to egg
55
- expr_to_egg_cache: WeakKeyDictionary[ExprDecl, bindings._Expr] = field(default_factory=WeakKeyDictionary)
55
+ expr_to_egg_cache: dict[ExprDecl, bindings._Expr] = field(default_factory=dict)
56
56
 
57
57
  def copy(self) -> EGraphState:
58
58
  """
@@ -71,15 +71,15 @@ class EGraphState:
71
71
  def schedule_to_egg(self, schedule: ScheduleDecl) -> bindings._Schedule:
72
72
  match schedule:
73
73
  case SaturateDecl(schedule):
74
- return bindings.Saturate(self.schedule_to_egg(schedule))
74
+ return bindings.Saturate(bindings.DUMMY_SPAN, self.schedule_to_egg(schedule))
75
75
  case RepeatDecl(schedule, times):
76
- return bindings.Repeat(times, self.schedule_to_egg(schedule))
76
+ return bindings.Repeat(bindings.DUMMY_SPAN, times, self.schedule_to_egg(schedule))
77
77
  case SequenceDecl(schedules):
78
- return bindings.Sequence([self.schedule_to_egg(s) for s in schedules])
78
+ return bindings.Sequence(bindings.DUMMY_SPAN, [self.schedule_to_egg(s) for s in schedules])
79
79
  case RunDecl(ruleset_name, until):
80
80
  self.ruleset_to_egg(ruleset_name)
81
81
  config = bindings.RunConfig(ruleset_name, None if not until else list(map(self.fact_to_egg, until)))
82
- return bindings.Run(config)
82
+ return bindings.Run(bindings.DUMMY_SPAN, config)
83
83
  case _:
84
84
  assert_never(schedule)
85
85
 
@@ -98,7 +98,8 @@ class EGraphState:
98
98
  for rule in rules:
99
99
  if rule in added_rules:
100
100
  continue
101
- self.egraph.run_program(self.command_to_egg(rule, name))
101
+ cmd = self.command_to_egg(rule, name)
102
+ self.egraph.run_program(cmd)
102
103
  added_rules.add(rule)
103
104
  case CombinedRulesetDecl(rulesets):
104
105
  if name in self.rulesets:
@@ -115,8 +116,9 @@ class EGraphState:
115
116
  case RewriteDecl(tp, lhs, rhs, conditions) | BiRewriteDecl(tp, lhs, rhs, conditions):
116
117
  self.type_ref_to_egg(tp)
117
118
  rewrite = bindings.Rewrite(
118
- self.expr_to_egg(lhs),
119
- self.expr_to_egg(rhs),
119
+ bindings.DUMMY_SPAN,
120
+ self._expr_to_egg(lhs),
121
+ self._expr_to_egg(rhs),
120
122
  [self.fact_to_egg(c) for c in conditions],
121
123
  )
122
124
  return (
@@ -126,26 +128,44 @@ class EGraphState:
126
128
  )
127
129
  case RuleDecl(head, body, name):
128
130
  rule = bindings.Rule(
131
+ bindings.DUMMY_SPAN,
129
132
  [self.action_to_egg(a) for a in head],
130
133
  [self.fact_to_egg(f) for f in body],
131
134
  )
132
135
  return bindings.RuleCommand(name or "", ruleset, rule)
136
+ # TODO: Replace with just constants value and looking at REF of function
137
+ case DefaultRewriteDecl(ref, expr):
138
+ decl = self.__egg_decls__.get_callable_decl(ref).to_function_decl()
139
+ sig = decl.signature
140
+ assert isinstance(sig, FunctionSignature)
141
+ # Replace args with rule_var_name mapping
142
+ arg_mapping = tuple(
143
+ TypedExprDecl(tp.to_just(), VarDecl(name, False))
144
+ for name, tp in zip(sig.arg_names, sig.arg_types, strict=False)
145
+ )
146
+ rewrite_decl = RewriteDecl(
147
+ sig.semantic_return_type.to_just(), CallDecl(ref, arg_mapping), expr, (), False
148
+ )
149
+ return self.command_to_egg(rewrite_decl, ruleset)
133
150
  case _:
134
151
  assert_never(cmd)
135
152
 
136
153
  def action_to_egg(self, action: ActionDecl) -> bindings._Action:
137
154
  match action:
138
155
  case LetDecl(name, typed_expr):
139
- return bindings.Let(name, self.typed_expr_to_egg(typed_expr))
156
+ var_decl = VarDecl(name, True)
157
+ var_egg = self._expr_to_egg(var_decl)
158
+ self.expr_to_egg_cache[var_decl] = var_egg
159
+ return bindings.Let(bindings.DUMMY_SPAN, var_egg.name, self.typed_expr_to_egg(typed_expr))
140
160
  case SetDecl(tp, call, rhs):
141
161
  self.type_ref_to_egg(tp)
142
- call_ = self.expr_to_egg(call)
143
- return bindings.Set(call_.name, call_.args, self.expr_to_egg(rhs))
162
+ call_ = self._expr_to_egg(call)
163
+ return bindings.Set(bindings.DUMMY_SPAN, call_.name, call_.args, self._expr_to_egg(rhs))
144
164
  case ExprActionDecl(typed_expr):
145
- return bindings.Expr_(self.typed_expr_to_egg(typed_expr))
165
+ return bindings.Expr_(bindings.DUMMY_SPAN, self.typed_expr_to_egg(typed_expr))
146
166
  case ChangeDecl(tp, call, change):
147
167
  self.type_ref_to_egg(tp)
148
- call_ = self.expr_to_egg(call)
168
+ call_ = self._expr_to_egg(call)
149
169
  egg_change: bindings._Change
150
170
  match change:
151
171
  case "delete":
@@ -154,12 +174,12 @@ class EGraphState:
154
174
  egg_change = bindings.Subsume()
155
175
  case _:
156
176
  assert_never(change)
157
- return bindings.Change(egg_change, call_.name, call_.args)
177
+ return bindings.Change(bindings.DUMMY_SPAN, egg_change, call_.name, call_.args)
158
178
  case UnionDecl(tp, lhs, rhs):
159
179
  self.type_ref_to_egg(tp)
160
- return bindings.Union(self.expr_to_egg(lhs), self.expr_to_egg(rhs))
180
+ return bindings.Union(bindings.DUMMY_SPAN, self._expr_to_egg(lhs), self._expr_to_egg(rhs))
161
181
  case PanicDecl(name):
162
- return bindings.Panic(name)
182
+ return bindings.Panic(bindings.DUMMY_SPAN, name)
163
183
  case _:
164
184
  assert_never(action)
165
185
 
@@ -167,9 +187,9 @@ class EGraphState:
167
187
  match fact:
168
188
  case EqDecl(tp, exprs):
169
189
  self.type_ref_to_egg(tp)
170
- return bindings.Eq([self.expr_to_egg(e) for e in exprs])
190
+ return bindings.Eq(bindings.DUMMY_SPAN, [self._expr_to_egg(e) for e in exprs])
171
191
  case ExprFactDecl(typed_expr):
172
- return bindings.Fact(self.typed_expr_to_egg(typed_expr))
192
+ return bindings.Fact(self.typed_expr_to_egg(typed_expr, False))
173
193
  case _:
174
194
  assert_never(fact)
175
195
 
@@ -180,7 +200,9 @@ class EGraphState:
180
200
  if ref in self.callable_ref_to_egg_fn:
181
201
  return self.callable_ref_to_egg_fn[ref]
182
202
  decl = self.__egg_decls__.get_callable_decl(ref)
183
- self.callable_ref_to_egg_fn[ref] = egg_name = decl.egg_name or _generate_callable_egg_name(ref)
203
+ self.callable_ref_to_egg_fn[ref] = egg_name = decl.egg_name or _sanitize_egg_ident(
204
+ self._generate_callable_egg_name(ref)
205
+ )
184
206
  self.egg_fn_to_callable_refs[egg_name].add(ref)
185
207
  match decl:
186
208
  case RelationDecl(arg_types, _, _):
@@ -201,8 +223,8 @@ class EGraphState:
201
223
  [self.type_ref_to_egg(a.to_just()) for a in signature.arg_types],
202
224
  self.type_ref_to_egg(signature.semantic_return_type.to_just()),
203
225
  ),
204
- self.expr_to_egg(decl.default) if decl.default else None,
205
- self.expr_to_egg(decl.merge) if decl.merge else None,
226
+ self._expr_to_egg(decl.default) if decl.default else None,
227
+ self._expr_to_egg(decl.merge) if decl.merge else None,
206
228
  [self.action_to_egg(a) for a in decl.on_merge],
207
229
  decl.cost,
208
230
  decl.unextractable,
@@ -228,13 +250,14 @@ class EGraphState:
228
250
  # UnstableFn is a special case, where the rest of args are collected into a call
229
251
  type_args: list[bindings._Expr] = [
230
252
  bindings.Call(
253
+ bindings.DUMMY_SPAN,
231
254
  self.type_ref_to_egg(ref.args[1]),
232
- [bindings.Var(self.type_ref_to_egg(a)) for a in ref.args[2:]],
255
+ [bindings.Var(bindings.DUMMY_SPAN, self.type_ref_to_egg(a)) for a in ref.args[2:]],
233
256
  ),
234
- bindings.Var(self.type_ref_to_egg(ref.args[0])),
257
+ bindings.Var(bindings.DUMMY_SPAN, self.type_ref_to_egg(ref.args[0])),
235
258
  ]
236
259
  else:
237
- type_args = [bindings.Var(self.type_ref_to_egg(a)) for a in ref.args]
260
+ type_args = [bindings.Var(bindings.DUMMY_SPAN, self.type_ref_to_egg(a)) for a in ref.args]
238
261
  args = (self.type_ref_to_egg(JustTypeRef(ref.name)), type_args)
239
262
  else:
240
263
  args = None
@@ -245,6 +268,8 @@ class EGraphState:
245
268
  if decl.builtin:
246
269
  for method in decl.class_methods:
247
270
  self.callable_ref_to_egg(ClassMethodRef(ref.name, method))
271
+ if decl.init:
272
+ self.callable_ref_to_egg(InitRef(ref.name))
248
273
 
249
274
  return egg_name
250
275
 
@@ -259,31 +284,60 @@ class EGraphState:
259
284
  if len(v) == 1
260
285
  }
261
286
 
262
- def typed_expr_to_egg(self, typed_expr_decl: TypedExprDecl) -> bindings._Expr:
287
+ def typed_expr_to_egg(self, typed_expr_decl: TypedExprDecl, transform_let: bool = True) -> bindings._Expr:
288
+ # transform all expressions with multiple parents into a let binding, so that fewer expressions
289
+ # are sent to egglog. Only for performance reasons.
290
+ if transform_let:
291
+ have_multiple_parents = _exprs_multiple_parents(typed_expr_decl)
292
+ for expr in reversed(have_multiple_parents):
293
+ self._transform_let(expr)
294
+
263
295
  self.type_ref_to_egg(typed_expr_decl.tp)
264
- return self.expr_to_egg(typed_expr_decl.expr)
296
+ return self._expr_to_egg(typed_expr_decl.expr)
297
+
298
+ def _transform_let(self, typed_expr: TypedExprDecl) -> None:
299
+ """
300
+ Rewrites this expression as a let binding if it's not already a let binding.
301
+ """
302
+ var_decl = VarDecl(f"__expr_{hash(typed_expr)}", True)
303
+ if var_decl in self.expr_to_egg_cache:
304
+ return
305
+ var_egg = self._expr_to_egg(var_decl)
306
+ cmd = bindings.ActionCommand(
307
+ bindings.Let(bindings.DUMMY_SPAN, var_egg.name, self.typed_expr_to_egg(typed_expr))
308
+ )
309
+ try:
310
+ self.egraph.run_program(cmd)
311
+ # errors when creating let bindings for things like `(vec-empty)`
312
+ except bindings.EggSmolError:
313
+ return
314
+ self.expr_to_egg_cache[typed_expr.expr] = var_egg
315
+ self.expr_to_egg_cache[var_decl] = var_egg
265
316
 
266
317
  @overload
267
- def expr_to_egg(self, expr_decl: CallDecl) -> bindings.Call: ...
318
+ def _expr_to_egg(self, expr_decl: CallDecl) -> bindings.Call: ...
268
319
 
269
320
  @overload
270
- def expr_to_egg(self, expr_decl: ExprDecl) -> bindings._Expr: ...
321
+ def _expr_to_egg(self, expr_decl: VarDecl) -> bindings.Var: ...
271
322
 
272
- def expr_to_egg(self, expr_decl: ExprDecl) -> bindings._Expr:
323
+ @overload
324
+ def _expr_to_egg(self, expr_decl: ExprDecl) -> bindings._Expr: ...
325
+
326
+ def _expr_to_egg(self, expr_decl: ExprDecl) -> bindings._Expr:
273
327
  """
274
328
  Convert an ExprDecl to an egg expression.
275
-
276
- Cached using weakrefs to avoid memory leaks.
277
329
  """
278
330
  try:
279
331
  return self.expr_to_egg_cache[expr_decl]
280
332
  except KeyError:
281
333
  pass
282
-
283
334
  res: bindings._Expr
284
335
  match expr_decl:
285
- case VarDecl(name):
286
- res = bindings.Var(name)
336
+ case VarDecl(name, is_let):
337
+ # prefix let bindings with % to avoid name conflicts with rewrites
338
+ if is_let:
339
+ name = f"%{name}"
340
+ res = bindings.Var(bindings.DUMMY_SPAN, name)
287
341
  case LitDecl(value):
288
342
  l: bindings._Literal
289
343
  match value:
@@ -299,19 +353,22 @@ class EGraphState:
299
353
  l = bindings.String(s)
300
354
  case _:
301
355
  assert_never(value)
302
- res = bindings.Lit(l)
356
+ res = bindings.Lit(bindings.DUMMY_SPAN, l)
303
357
  case CallDecl(ref, args, _):
304
358
  egg_fn = self.callable_ref_to_egg(ref)
305
- egg_args = [self.typed_expr_to_egg(a) for a in args]
306
- res = bindings.Call(egg_fn, egg_args)
359
+ egg_args = [self.typed_expr_to_egg(a, False) for a in args]
360
+ res = bindings.Call(bindings.DUMMY_SPAN, egg_fn, egg_args)
307
361
  case PyObjectDecl(value):
308
362
  res = GLOBAL_PY_OBJECT_SORT.store(value)
309
363
  case PartialCallDecl(call_decl):
310
- egg_fn_call = self.expr_to_egg(call_decl)
311
- res = bindings.Call("unstable-fn", [bindings.Lit(bindings.String(egg_fn_call.name)), *egg_fn_call.args])
364
+ egg_fn_call = self._expr_to_egg(call_decl)
365
+ res = bindings.Call(
366
+ bindings.DUMMY_SPAN,
367
+ "unstable-fn",
368
+ [bindings.Lit(bindings.DUMMY_SPAN, bindings.String(egg_fn_call.name)), *egg_fn_call.args],
369
+ )
312
370
  case _:
313
371
  assert_never(expr_decl.expr)
314
-
315
372
  self.expr_to_egg_cache[expr_decl] = res
316
373
  return res
317
374
 
@@ -330,6 +387,65 @@ class EGraphState:
330
387
  """
331
388
  return frozenset(tp for tp in self.type_ref_to_egg_sort if tp.name == cls_name)
332
389
 
390
+ def _generate_callable_egg_name(self, ref: CallableRef) -> str:
391
+ """
392
+ Generates a valid egg function name for a callable reference.
393
+ """
394
+ match ref:
395
+ case FunctionRef(name):
396
+ return name
397
+
398
+ case ConstantRef(name):
399
+ return name
400
+ case (
401
+ MethodRef(cls_name, name)
402
+ | ClassMethodRef(cls_name, name)
403
+ | ClassVariableRef(cls_name, name)
404
+ | PropertyRef(cls_name, name)
405
+ ):
406
+ return f"{cls_name}.{name}"
407
+ case InitRef(cls_name):
408
+ return f"{cls_name}.__init__"
409
+ case UnnamedFunctionRef(args, val):
410
+ parts = [str(self._expr_to_egg(a.expr)) + "-" + str(self.type_ref_to_egg(a.tp)) for a in args] + [
411
+ str(self.typed_expr_to_egg(val, False))
412
+ ]
413
+ return "_".join(parts)
414
+ case _:
415
+ assert_never(ref)
416
+
417
+
418
+ # https://chatgpt.com/share/9ab899b4-4e17-4426-a3f2-79d67a5ec456
419
+ _EGGLOG_INVALID_IDENT = re.compile(r"[^\w\-+*/?!=<>&|^/%]")
420
+
421
+
422
+ def _sanitize_egg_ident(input_string: str) -> str:
423
+ """
424
+ Replaces all invalid characters in an egg identifier with an underscore.
425
+ """
426
+ return _EGGLOG_INVALID_IDENT.sub("_", input_string)
427
+
428
+
429
+ def _exprs_multiple_parents(typed_expr: TypedExprDecl) -> list[TypedExprDecl]:
430
+ """
431
+ Returns all expressions that have multiple parents (a list but semantically just an ordered set).
432
+ """
433
+ to_traverse = {typed_expr}
434
+ traversed = set[TypedExprDecl]()
435
+ traversed_twice = list[TypedExprDecl]()
436
+ while to_traverse:
437
+ typed_expr = to_traverse.pop()
438
+ if typed_expr in traversed:
439
+ traversed_twice.append(typed_expr)
440
+ continue
441
+ traversed.add(typed_expr)
442
+ expr = typed_expr.expr
443
+ if isinstance(expr, CallDecl):
444
+ to_traverse.update(expr.args)
445
+ elif isinstance(expr, PartialCallDecl):
446
+ to_traverse.update(expr.call.args)
447
+ return traversed_twice
448
+
333
449
 
334
450
  def _generate_type_egg_name(ref: JustTypeRef) -> str:
335
451
  """
@@ -341,24 +457,6 @@ def _generate_type_egg_name(ref: JustTypeRef) -> str:
341
457
  return f"{name}_{'_'.join(map(_generate_type_egg_name, ref.args))}"
342
458
 
343
459
 
344
- def _generate_callable_egg_name(ref: CallableRef) -> str:
345
- """
346
- Generates a valid egg function name for a callable reference.
347
- """
348
- match ref:
349
- case FunctionRef(name) | ConstantRef(name):
350
- return name
351
- case (
352
- MethodRef(cls_name, name)
353
- | ClassMethodRef(cls_name, name)
354
- | ClassVariableRef(cls_name, name)
355
- | PropertyRef(cls_name, name)
356
- ):
357
- return f"{cls_name}_{name}"
358
- case _:
359
- assert_never(ref)
360
-
361
-
362
460
  @dataclass
363
461
  class FromEggState:
364
462
  """
@@ -380,7 +478,7 @@ class FromEggState:
380
478
  """
381
479
  expr_decl: ExprDecl
382
480
  if isinstance(term, bindings.TermVar):
383
- expr_decl = VarDecl(term.name)
481
+ expr_decl = VarDecl(term.name, True)
384
482
  elif isinstance(term, bindings.TermLit):
385
483
  value = term.value
386
484
  expr_decl = LitDecl(None if isinstance(value, bindings.Unit) else value.value)
@@ -427,8 +525,11 @@ class FromEggState:
427
525
  possible_types: Iterable[JustTypeRef | None]
428
526
  signature = self.decls.get_callable_decl(callable_ref).to_function_decl().signature
429
527
  assert isinstance(signature, FunctionSignature)
430
- if isinstance(callable_ref, ClassMethodRef):
431
- possible_types = self.state._get_possible_types(callable_ref.class_name)
528
+ if isinstance(callable_ref, ClassMethodRef | InitRef | MethodRef):
529
+ # Need OR in case we have a class method whose class was never added as a sort, which would happen
530
+ # if the class method didn't return that type and no other function did. In this case, we don't need
531
+ # to care about the type vars and we don't need to bind any possible type.
532
+ possible_types = self.state._get_possible_types(callable_ref.class_name) or [None]
432
533
  cls_name = callable_ref.class_name
433
534
  else:
434
535
  possible_types = [None]
@@ -437,7 +538,6 @@ class FromEggState:
437
538
  tcs = TypeConstraintSolver(self.decls)
438
539
  if possible_type and possible_type.args:
439
540
  tcs.bind_class(possible_type)
440
-
441
541
  try:
442
542
  arg_types, bound_tp_params = tcs.infer_arg_types(
443
543
  signature.arg_types, signature.semantic_return_type, signature.var_arg_type, tp, cls_name
@@ -445,7 +545,14 @@ class FromEggState:
445
545
  except TypeConstraintError:
446
546
  continue
447
547
  args = tuple(self.resolve_term(a, tp) for a, tp in zip(term.args, arg_types, strict=False))
448
- return CallDecl(callable_ref, args, bound_tp_params)
548
+
549
+ return CallDecl(
550
+ callable_ref,
551
+ args,
552
+ # Don't include bound type params if this is just a method, we only needed them for type resolution
553
+ # but don't need to store them
554
+ bound_tp_params if isinstance(callable_ref, ClassMethodRef | InitRef) else None,
555
+ )
449
556
  raise ValueError(f"Could not find callable ref for call {term}")
450
557
 
451
558
  def resolve_term(self, term_id: int, tp: JustTypeRef) -> TypedExprDecl:
@@ -0,0 +1,45 @@
1
+ # mypy: disable-error-code="empty-body"
2
+ """
3
+ Higher Order Functions
4
+ ======================
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ from typing import TYPE_CHECKING
10
+
11
+ from egglog import *
12
+
13
+ if TYPE_CHECKING:
14
+ from collections.abc import Callable
15
+
16
+
17
+ class Math(Expr):
18
+ def __init__(self, i: i64Like) -> None: ...
19
+ def __add__(self, other: Math) -> Math: ...
20
+
21
+
22
+ class MathList(Expr):
23
+ def __init__(self) -> None: ...
24
+ def append(self, i: Math) -> MathList: ...
25
+ def map(self, f: Callable[[Math], Math]) -> MathList: ...
26
+
27
+
28
+ @ruleset
29
+ def math_ruleset(i: i64, j: i64, xs: MathList, x: Math, f: Callable[[Math], Math]): # noqa: ANN201
30
+ yield rewrite(Math(i) + Math(j)).to(Math(i + j))
31
+ yield rewrite(xs.append(x).map(f)).to(xs.map(f).append(f(x)))
32
+ yield rewrite(MathList().map(f)).to(MathList())
33
+
34
+
35
+ @function(ruleset=math_ruleset)
36
+ def incr_list(xs: MathList) -> MathList:
37
+ return xs.map(lambda x: x + Math(1))
38
+
39
+
40
+ egraph = EGraph()
41
+ y = egraph.let("y", incr_list(MathList().append(Math(1)).append(Math(2))))
42
+ egraph.run(math_ruleset.saturate())
43
+ egraph.check(eq(y).to(MathList().append(Math(2)).append(Math(3))))
44
+
45
+ egraph