egglog 9.0.0__cp311-cp311-win_amd64.whl → 10.0.0__cp311-cp311-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of egglog might be problematic. See the registry's page for this release for more details (the original link was not preserved in this copy).

egglog/__init__.py CHANGED
@@ -3,6 +3,7 @@ Package for creating e-graphs in Python.
3
3
  """
4
4
 
5
5
  from . import config, ipython_magic # noqa: F401
6
+ from .bindings import EggSmolError # noqa: F401
6
7
  from .builtins import * # noqa: UP029
7
8
  from .conversion import ConvertError, convert, converter, get_type_args # noqa: F401
8
9
  from .egraph import *
Binary file
egglog/builtins.py CHANGED
@@ -12,11 +12,9 @@ from typing import TYPE_CHECKING, Generic, Protocol, TypeAlias, TypeVar, Union,
12
12
 
13
13
  from typing_extensions import TypeVarTuple, Unpack
14
14
 
15
- from . import bindings
16
15
  from .conversion import convert, converter, get_type_args
17
16
  from .declarations import *
18
- from .egraph import BaseExpr, BuiltinExpr, EGraph, expr_fact, function, get_current_ruleset, method
19
- from .egraph_state import GLOBAL_PY_OBJECT_SORT
17
+ from .egraph import BaseExpr, BuiltinExpr, expr_fact, function, get_current_ruleset, method
20
18
  from .functionalize import functionalize
21
19
  from .runtime import RuntimeClass, RuntimeExpr, RuntimeFunction
22
20
  from .thunk import Thunk
@@ -32,6 +30,7 @@ __all__ = [
32
30
  "BigRatLike",
33
31
  "Bool",
34
32
  "BoolLike",
33
+ "BuiltinEvalError",
35
34
  "Map",
36
35
  "MapLike",
37
36
  "MultiSet",
@@ -56,6 +55,17 @@ __all__ = [
56
55
  ]
57
56
 
58
57
 
58
+ class BuiltinEvalError(Exception):
59
+ """
60
+ Raised when an builtin cannot be evaluated into a Python primitive because it is complex.
61
+
62
+ Try extracting this expression first.
63
+ """
64
+
65
+ def __str__(self) -> str:
66
+ return f"Cannot evaluate builtin expression into a Python primitive. Try extracting this expression first: {super().__str__()}"
67
+
68
+
59
69
  class Unit(BuiltinExpr, egg_sort="Unit"):
60
70
  """
61
71
  The unit type. This is used to reprsent if a value exists in the e-graph or not.
@@ -72,8 +82,8 @@ class String(BuiltinExpr):
72
82
  @method(preserve=True)
73
83
  def eval(self) -> str:
74
84
  value = _extract_lit(self)
75
- assert isinstance(value, bindings.String)
76
- return value.value
85
+ assert isinstance(value, str)
86
+ return value
77
87
 
78
88
  def __init__(self, value: str) -> None: ...
79
89
 
@@ -97,8 +107,8 @@ class Bool(BuiltinExpr, egg_sort="bool"):
97
107
  @method(preserve=True)
98
108
  def eval(self) -> bool:
99
109
  value = _extract_lit(self)
100
- assert isinstance(value, bindings.Bool)
101
- return value.value
110
+ assert isinstance(value, bool)
111
+ return value
102
112
 
103
113
  @method(preserve=True)
104
114
  def __bool__(self) -> bool:
@@ -132,8 +142,8 @@ class i64(BuiltinExpr): # noqa: N801
132
142
  @method(preserve=True)
133
143
  def eval(self) -> int:
134
144
  value = _extract_lit(self)
135
- assert isinstance(value, bindings.Int)
136
- return value.value
145
+ assert isinstance(value, int)
146
+ return value
137
147
 
138
148
  @method(preserve=True)
139
149
  def __index__(self) -> int:
@@ -195,6 +205,9 @@ class i64(BuiltinExpr): # noqa: N801
195
205
 
196
206
  def __rrshift__(self, other: i64Like) -> i64: ...
197
207
 
208
+ @method(egg_fn="log2")
209
+ def log2(self) -> i64: ...
210
+
198
211
  @method(egg_fn="not-i64")
199
212
  def __invert__(self) -> i64: ...
200
213
 
@@ -248,8 +261,8 @@ class f64(BuiltinExpr): # noqa: N801
248
261
  @method(preserve=True)
249
262
  def eval(self) -> float:
250
263
  value = _extract_lit(self)
251
- assert isinstance(value, bindings.Float)
252
- return value.value
264
+ assert isinstance(value, float)
265
+ return value
253
266
 
254
267
  @method(preserve=True)
255
268
  def __float__(self) -> float:
@@ -337,9 +350,12 @@ class Map(BuiltinExpr, Generic[T, V]):
337
350
  expr = cast("RuntimeExpr", self)
338
351
  d = {}
339
352
  while call.callable != ClassMethodRef("Map", "empty"):
340
- assert call.callable == MethodRef("Map", "insert")
353
+ msg = "Map can only be evaluated if it is empty or a series of inserts."
354
+ if call.callable != MethodRef("Map", "insert"):
355
+ raise BuiltinEvalError(msg)
341
356
  call_typed, k_typed, v_typed = call.args
342
- assert isinstance(call_typed.expr, CallDecl)
357
+ if not isinstance(call_typed.expr, CallDecl):
358
+ raise BuiltinEvalError(msg)
343
359
  k = cast("T", expr.__with_expr__(k_typed))
344
360
  v = cast("V", expr.__with_expr__(v_typed))
345
361
  d[k] = v
@@ -401,7 +417,9 @@ class Set(BuiltinExpr, Generic[T]):
401
417
  @method(preserve=True)
402
418
  def eval(self) -> set[T]:
403
419
  call = _extract_call(self)
404
- assert call.callable == InitRef("Set")
420
+ if call.callable != InitRef("Set"):
421
+ msg = "Set can only be initialized with the Set constructor."
422
+ raise BuiltinEvalError(msg)
405
423
  return {cast("T", cast("RuntimeExpr", self).__with_expr__(x)) for x in call.args}
406
424
 
407
425
  @method(preserve=True)
@@ -463,7 +481,9 @@ class MultiSet(BuiltinExpr, Generic[T]):
463
481
  @method(preserve=True)
464
482
  def eval(self) -> list[T]:
465
483
  call = _extract_call(self)
466
- assert call.callable == InitRef("MultiSet")
484
+ if call.callable != InitRef("MultiSet"):
485
+ msg = "MultiSet can only be initialized with the MultiSet constructor."
486
+ raise BuiltinEvalError(msg)
467
487
  return [cast("T", cast("RuntimeExpr", self).__with_expr__(x)) for x in call.args]
468
488
 
469
489
  @method(preserve=True)
@@ -510,11 +530,15 @@ class Rational(BuiltinExpr):
510
530
  @method(preserve=True)
511
531
  def eval(self) -> Fraction:
512
532
  call = _extract_call(self)
513
- assert call.callable == InitRef("Rational")
533
+ if call.callable != InitRef("Rational"):
534
+ msg = "Rational can only be initialized with the Rational constructor."
535
+ raise BuiltinEvalError(msg)
514
536
 
515
537
  def _to_int(e: TypedExprDecl) -> int:
516
538
  expr = e.expr
517
- assert isinstance(expr, LitDecl)
539
+ if not isinstance(expr, LitDecl):
540
+ msg = "Rational can only be initialized with literals"
541
+ raise BuiltinEvalError(msg)
518
542
  assert isinstance(expr.value, int)
519
543
  return expr.value
520
544
 
@@ -593,9 +617,13 @@ class BigInt(BuiltinExpr):
593
617
  @method(preserve=True)
594
618
  def eval(self) -> int:
595
619
  call = _extract_call(self)
596
- assert call.callable == ClassMethodRef("BigInt", "from_string")
620
+ if call.callable != ClassMethodRef("BigInt", "from_string"):
621
+ msg = "BigInt can only be initialized with the BigInt constructor."
622
+ raise BuiltinEvalError(msg)
597
623
  (s,) = call.args
598
- assert isinstance(s.expr, LitDecl)
624
+ if not isinstance(s.expr, LitDecl):
625
+ msg = "BigInt can only be initialized with literals"
626
+ raise BuiltinEvalError(msg)
599
627
  assert isinstance(s.expr.value, str)
600
628
  return int(s.expr.value)
601
629
 
@@ -714,14 +742,19 @@ class BigRat(BuiltinExpr):
714
742
  @method(preserve=True)
715
743
  def eval(self) -> Fraction:
716
744
  call = _extract_call(self)
717
- assert call.callable == InitRef("BigRat")
745
+ if call.callable != InitRef("BigRat"):
746
+ msg = "BigRat can only be initialized with the BigRat constructor."
747
+ raise BuiltinEvalError(msg)
718
748
 
719
749
  def _to_fraction(e: TypedExprDecl) -> Fraction:
720
750
  expr = e.expr
721
- assert isinstance(expr, CallDecl)
722
- assert expr.callable == ClassMethodRef("BigInt", "from_string")
751
+ if not isinstance(expr, CallDecl) or expr.callable != ClassMethodRef("BigInt", "from_string"):
752
+ msg = "BigRat can only be initialized BigInt strings"
753
+ raise BuiltinEvalError(msg)
723
754
  (s,) = expr.args
724
- assert isinstance(s.expr, LitDecl)
755
+ if not isinstance(s.expr, LitDecl):
756
+ msg = "BigInt can only be initialized with literals"
757
+ raise BuiltinEvalError(msg)
725
758
  assert isinstance(s.expr.value, str)
726
759
  return Fraction(s.expr.value)
727
760
 
@@ -818,7 +851,10 @@ class Vec(BuiltinExpr, Generic[T]):
818
851
  call = _extract_call(self)
819
852
  if call.callable == ClassMethodRef("Vec", "empty"):
820
853
  return ()
821
- assert call.callable == InitRef("Vec")
854
+
855
+ if call.callable != InitRef("Vec"):
856
+ msg = "Vec can only be initialized with the Vec constructor."
857
+ raise BuiltinEvalError(msg)
822
858
  return tuple(cast("T", cast("RuntimeExpr", self).__with_expr__(x)) for x in call.args)
823
859
 
824
860
  @method(preserve=True)
@@ -886,10 +922,11 @@ VecLike: TypeAlias = Vec[T] | tuple[TO, ...] | list[TO]
886
922
  class PyObject(BuiltinExpr):
887
923
  @method(preserve=True)
888
924
  def eval(self) -> object:
889
- report = (EGraph.current or EGraph())._run_extract(cast("RuntimeExpr", self), 0)
890
- assert isinstance(report, bindings.Best)
891
- expr = report.termdag.term_to_expr(report.term, bindings.PanicSpan())
892
- return GLOBAL_PY_OBJECT_SORT.load(expr)
925
+ expr = cast("RuntimeExpr", self).__egg_typed_expr__.expr
926
+ if not isinstance(expr, PyObjectDecl):
927
+ msg = "PyObject can only be evaluated if it is a PyObject literal"
928
+ raise BuiltinEvalError(msg)
929
+ return expr.value
893
930
 
894
931
  def __init__(self, value: object) -> None: ...
895
932
 
@@ -1024,22 +1061,23 @@ def value_to_annotation(a: object) -> type | None:
1024
1061
  converter(FunctionType, UnstableFn, _convert_function)
1025
1062
 
1026
1063
 
1027
- def _extract_lit(e: BaseExpr) -> bindings._Literal:
1064
+ def _extract_lit(e: BaseExpr) -> LitType:
1028
1065
  """
1029
1066
  Special case extracting literals to make this faster by using termdag directly.
1030
1067
  """
1031
- report = (EGraph.current or EGraph())._run_extract(cast("RuntimeExpr", e), 0)
1032
- assert isinstance(report, bindings.Best)
1033
- term = report.term
1034
- assert isinstance(term, bindings.TermLit)
1035
- return term.value
1068
+ expr = cast("RuntimeExpr", e).__egg_typed_expr__.expr
1069
+ if not isinstance(expr, LitDecl):
1070
+ msg = "Expected a literal"
1071
+ raise BuiltinEvalError(msg)
1072
+ return expr.value
1036
1073
 
1037
1074
 
1038
1075
  def _extract_call(e: BaseExpr) -> CallDecl:
1039
1076
  """
1040
1077
  Extracts the call form of an expression
1041
1078
  """
1042
- extracted = cast("RuntimeExpr", (EGraph.current or EGraph()).extract(e))
1043
- expr = extracted.__egg_typed_expr__.expr
1044
- assert isinstance(expr, CallDecl)
1079
+ expr = cast("RuntimeExpr", e).__egg_typed_expr__.expr
1080
+ if not isinstance(expr, CallDecl):
1081
+ msg = "Expected a call expression"
1082
+ raise BuiltinEvalError(msg)
1045
1083
  return expr
egglog/conversion.py CHANGED
@@ -149,6 +149,12 @@ def min_convertable_tp(a: object, b: object, name: str) -> JustTypeRef:
149
149
  decls = _retrieve_conversion_decls()
150
150
  a_tp = _get_tp(a)
151
151
  b_tp = _get_tp(b)
152
+ # Make sure at least one of the types has the method, to avoid issue with == upcasting improperly
153
+ if not (
154
+ (isinstance(a_tp, JustTypeRef) and decls.has_method(a_tp.name, name))
155
+ or (isinstance(b_tp, JustTypeRef) and decls.has_method(b_tp.name, name))
156
+ ):
157
+ raise ConvertError(f"Neither {a_tp} nor {b_tp} has method {name}")
152
158
  a_converts_to = {
153
159
  to: c for ((from_, to), (c, _)) in CONVERSIONS.items() if from_ == a_tp and decls.has_method(to.name, name)
154
160
  }
egglog/egraph.py CHANGED
@@ -4,7 +4,7 @@ import contextlib
4
4
  import inspect
5
5
  import pathlib
6
6
  import tempfile
7
- from collections.abc import Callable, Generator, Iterable, Iterator
7
+ from collections.abc import Callable, Generator, Iterable
8
8
  from contextvars import ContextVar
9
9
  from dataclasses import InitVar, dataclass, field
10
10
  from functools import partial
@@ -17,7 +17,6 @@ from typing import (
17
17
  Generic,
18
18
  Literal,
19
19
  Never,
20
- Protocol,
21
20
  TypeAlias,
22
21
  TypedDict,
23
22
  TypeVar,
@@ -85,7 +84,6 @@ __all__ = [
85
84
  "set_",
86
85
  "simplify",
87
86
  "subsume",
88
- "try_evaling",
89
87
  "union",
90
88
  "unstable_combine_rulesets",
91
89
  "var",
@@ -847,11 +845,10 @@ class EGraph:
847
845
  Can run actions, check facts, run schedules, or extract minimal cost expressions.
848
846
  """
849
847
 
850
- current: ClassVar[EGraph | None] = None
851
848
  seminaive: InitVar[bool] = True
852
849
  save_egglog_string: InitVar[bool] = False
853
850
 
854
- _state: EGraphState = field(init=False)
851
+ _state: EGraphState = field(init=False, repr=False)
855
852
  # For pushing/popping with egglog
856
853
  _state_stack: list[EGraphState] = field(default_factory=list, repr=False)
857
854
  # For storing the global "current" egraph
@@ -1200,16 +1197,6 @@ class EGraph:
1200
1197
  if visualize:
1201
1198
  VisualizerWidget(egraphs=egraphs).display_or_open()
1202
1199
 
1203
- @contextlib.contextmanager
1204
- def set_current(self) -> Iterator[None]:
1205
- """
1206
- Context manager that will set the current egraph. It will be set back after.
1207
- """
1208
- prev_current = EGraph.current
1209
- EGraph.current = self
1210
- yield
1211
- EGraph.current = prev_current
1212
-
1213
1200
  @property
1214
1201
  def _egraph(self) -> bindings.EGraph:
1215
1202
  return self._state.egraph
@@ -1303,8 +1290,6 @@ class Schedule(DelayedDeclerations):
1303
1290
  A composition of some rulesets, either composing them sequentially, running them repeatedly, running them till saturation, or running until some facts are met
1304
1291
  """
1305
1292
 
1306
- current: ClassVar[Schedule | None] = None
1307
-
1308
1293
  # Defer declerations so that we can have rule generators that used not yet defined yet
1309
1294
  schedule: ScheduleDecl
1310
1295
 
@@ -1332,16 +1317,6 @@ class Schedule(DelayedDeclerations):
1332
1317
  """
1333
1318
  return Schedule(Thunk.fn(Declarations.create, self, other), SequenceDecl((self.schedule, other.schedule)))
1334
1319
 
1335
- @contextlib.contextmanager
1336
- def set_current(self) -> Iterator[None]:
1337
- """
1338
- Context manager that will set the current schedule. It will be set back after
1339
- """
1340
- prev_current = Schedule.current
1341
- Schedule.current = self
1342
- yield
1343
- Schedule.current = prev_current
1344
-
1345
1320
 
1346
1321
  @dataclass
1347
1322
  class Ruleset(Schedule):
@@ -1488,9 +1463,12 @@ class Fact:
1488
1463
 
1489
1464
  def __bool__(self) -> bool:
1490
1465
  """
1491
- Returns True if the two sides of an equality are equal in the egraph or the expression is in the egraph.
1466
+ Returns True if the two sides of an equality are structurally equal.
1492
1467
  """
1493
- return (EGraph.current or EGraph()).check_bool(self)
1468
+ if not isinstance(self.fact, EqDecl):
1469
+ msg = "Can only check equality facts"
1470
+ raise TypeError(msg)
1471
+ return self.fact.left == self.fact.right
1494
1472
 
1495
1473
 
1496
1474
  @dataclass
@@ -1876,34 +1854,3 @@ def set_current_ruleset(r: Ruleset | None) -> Generator[None, None, None]:
1876
1854
  yield
1877
1855
  finally:
1878
1856
  _CURRENT_RULESET.reset(token)
1879
-
1880
-
1881
- T_co = TypeVar("T_co", covariant=True)
1882
-
1883
-
1884
- class _EvalsTo(Protocol, Generic[T_co]):
1885
- def eval(self) -> T_co: ...
1886
-
1887
-
1888
- def try_evaling(schedule: Schedule, expr: Expr, prim_expr: _EvalsTo[T]) -> T:
1889
- """
1890
- Try evaling the expression that will result in a primitive expression being fill.
1891
- if it fails, display the egraph and raise an error.
1892
- """
1893
- egraph = EGraph.current or EGraph()
1894
- with egraph.set_current():
1895
- try:
1896
- return prim_expr.eval()
1897
- except BaseException: # noqa: S110
1898
- pass
1899
- # If this primitive doesn't exist in the egraph, we need to try to create it by
1900
- # registering the expression and running the schedule
1901
- egraph.register(expr)
1902
- egraph.run(Schedule.current or schedule)
1903
- try:
1904
- with egraph.set_current():
1905
- return prim_expr.eval()
1906
- except BaseException as e:
1907
- # egraph.display(n_inline_leaves=1, split_primitive_outputs=True)
1908
- e.add_note(f"Cannot evaluate {egraph.extract(expr)}")
1909
- raise
egglog/examples/bignum.py CHANGED
@@ -12,14 +12,15 @@ x = BigInt(-1234)
12
12
  y = BigInt.from_string("2")
13
13
  z = BigRat(x, y)
14
14
 
15
- assert z.numer.to_string() == "-617"
15
+ egraph = EGraph()
16
+
17
+ assert egraph.extract(z.numer.to_string()).eval() == "-617"
16
18
 
17
19
 
18
20
  @function
19
21
  def bignums(x: BigInt, y: BigInt) -> BigRat: ...
20
22
 
21
23
 
22
- egraph = EGraph()
23
24
  egraph.register(set_(bignums(x, y)).to(z))
24
25
 
25
26
  c = var("c", BigRat)
@@ -28,18 +28,17 @@ expr2 = Num(6) + Num(2) * Num.var("x")
28
28
  a, b, c = vars_("a b c", Num)
29
29
  i, j = vars_("i j", i64)
30
30
 
31
- check(
32
- # Check that these expressions are equal
33
- eq(expr1).to(expr2),
34
- # After running these rules, up to ten times
31
+ egraph = EGraph()
32
+ egraph.register(expr1, expr2)
33
+
34
+ egraph.run(
35
35
  ruleset(
36
36
  rewrite(a + b).to(b + a),
37
37
  rewrite(a * (b + c)).to((a * b) + (a * c)),
38
38
  rewrite(Num(i) + Num(j)).to(Num(i + j)),
39
39
  rewrite(Num(i) * Num(j)).to(Num(i * j)),
40
40
  )
41
- * 10,
42
- # On these two initial expressions
43
- expr1,
44
- expr2,
41
+ * 10
45
42
  )
43
+
44
+ egraph.check(expr1 == expr2)
@@ -29,33 +29,32 @@ egraph = EGraph()
29
29
  xs = MultiSet(Math(1), Math(2), Math(3))
30
30
  egraph.register(xs)
31
31
 
32
- with egraph.set_current():
33
- assert xs == MultiSet(Math(1), Math(3), Math(2))
34
- assert xs != MultiSet(Math(1), Math(1), Math(2), Math(3))
32
+ egraph.check(xs == MultiSet(Math(1), Math(3), Math(2)))
33
+ egraph.check_fail(xs == MultiSet(Math(1), Math(1), Math(2), Math(3)))
35
34
 
36
- assert Counter(xs) == Counter({Math(1): 1, Math(2): 1, Math(3): 1})
35
+ assert Counter(egraph.extract(xs).eval()) == Counter({Math(1): 1, Math(2): 1, Math(3): 1})
37
36
 
38
- inserted = MultiSet(Math(1), Math(2), Math(3), Math(4))
39
- egraph.register(inserted)
40
- assert xs.insert(Math(4)) == inserted
41
37
 
42
- assert xs.contains(Math(1))
43
- assert xs.not_contains(Math(4))
44
- assert Math(1) in xs
45
- assert Math(4) not in xs
38
+ inserted = MultiSet(Math(1), Math(2), Math(3), Math(4))
39
+ egraph.register(inserted)
40
+ egraph.check(xs.insert(Math(4)) == inserted)
41
+ egraph.check(xs.contains(Math(1)))
42
+ egraph.check(xs.not_contains(Math(4)))
43
+ assert Math(1) in xs
44
+ assert Math(4) not in xs
46
45
 
47
- assert xs.remove(Math(1)) == MultiSet(Math(2), Math(3))
46
+ egraph.check(xs.remove(Math(1)) == MultiSet(Math(2), Math(3)))
48
47
 
49
- assert xs.length() == i64(3)
50
- assert len(xs) == 3
48
+ assert egraph.extract(xs.length()).eval() == 3
49
+ assert len(xs) == 3
51
50
 
52
- assert MultiSet(Math(1), Math(1)).length() == i64(2)
51
+ egraph.check(MultiSet(Math(1), Math(1)).length() == i64(2))
53
52
 
54
- assert MultiSet(Math(1)).pick() == Math(1)
53
+ egraph.check(MultiSet(Math(1)).pick() == Math(1))
55
54
 
56
- mapped = xs.map(square)
57
- egraph.register(mapped)
58
- egraph.run(math_ruleset)
59
- assert mapped == MultiSet(Math(1), Math(4), Math(9))
55
+ mapped = xs.map(square)
56
+ egraph.register(mapped)
57
+ egraph.run(math_ruleset)
58
+ egraph.check(mapped == MultiSet(Math(1), Math(4), Math(9)))
60
59
 
61
- assert xs + xs == MultiSet(Math(1), Math(2), Math(3), Math(1), Math(2), Math(3))
60
+ egraph.check(xs + xs == MultiSet(Math(1), Math(2), Math(3), Math(1), Math(2), Math(3)))
egglog/exp/array_api.py CHANGED
@@ -54,6 +54,7 @@ A: So that when converting from functional to lists we can use the same index fu
54
54
 
55
55
  from __future__ import annotations
56
56
 
57
+ import contextlib
57
58
  import itertools
58
59
  import math
59
60
  import numbers
@@ -61,7 +62,7 @@ import os
61
62
  import sys
62
63
  from copy import copy
63
64
  from types import EllipsisType
64
- from typing import TYPE_CHECKING, ClassVar, TypeAlias, cast
65
+ from typing import TYPE_CHECKING, Any, ClassVar, TypeAlias, cast
65
66
 
66
67
  import numpy as np
67
68
 
@@ -93,7 +94,7 @@ class Boolean(Expr, ruleset=array_api_ruleset):
93
94
 
94
95
  @method(preserve=True)
95
96
  def eval(self) -> bool:
96
- return try_evaling(array_api_schedule, self, self.to_bool)
97
+ return try_evaling(_get_current_egraph(), array_api_schedule, self, self.to_bool)
97
98
 
98
99
  @property
99
100
  def to_bool(self) -> Bool: ...
@@ -218,7 +219,7 @@ class Int(Expr, ruleset=array_api_ruleset):
218
219
 
219
220
  @method(preserve=True)
220
221
  def eval(self) -> int:
221
- return try_evaling(array_api_schedule, self, self.to_i64)
222
+ return try_evaling(_get_current_egraph(), array_api_schedule, self, self.to_i64)
222
223
 
223
224
  @method(preserve=True)
224
225
  def __index__(self) -> int:
@@ -330,7 +331,7 @@ class Float(Expr, ruleset=array_api_ruleset):
330
331
 
331
332
  @method(preserve=True)
332
333
  def eval(self) -> float:
333
- return try_evaling(array_api_schedule, self, self.to_f64)
334
+ return try_evaling(_get_current_egraph(), array_api_schedule, self, self.to_f64)
334
335
 
335
336
  def abs(self) -> Float: ...
336
337
 
@@ -438,7 +439,7 @@ class TupleInt(Expr, ruleset=array_api_ruleset):
438
439
 
439
440
  @method(preserve=True)
440
441
  def eval(self) -> tuple[Int, ...]:
441
- return try_evaling(array_api_schedule, self, self.to_vec)
442
+ return try_evaling(_get_current_egraph(), array_api_schedule, self, self.to_vec)
442
443
 
443
444
  def foldl(self, f: Callable[[Int, Int], Int], init: Int) -> Int: ...
444
445
  def foldl_boolean(self, f: Callable[[Boolean, Int], Boolean], init: Boolean) -> Boolean: ...
@@ -586,7 +587,7 @@ class TupleTupleInt(Expr, ruleset=array_api_ruleset):
586
587
 
587
588
  @method(preserve=True)
588
589
  def eval(self) -> tuple[TupleInt, ...]:
589
- return try_evaling(array_api_schedule, self, self.to_vec)
590
+ return try_evaling(_get_current_egraph(), array_api_schedule, self, self.to_vec)
590
591
 
591
592
  def drop(self, n: Int) -> TupleTupleInt:
592
593
  return TupleTupleInt(self.length() - n, lambda i: self[i + n])
@@ -1282,7 +1283,7 @@ class TupleNDArray(Expr, ruleset=array_api_ruleset):
1282
1283
 
1283
1284
  @method(preserve=True)
1284
1285
  def eval(self) -> tuple[NDArray, ...]:
1285
- return try_evaling(array_api_schedule, self, self.to_vec)
1286
+ return try_evaling(_get_current_egraph(), array_api_schedule, self, self.to_vec)
1286
1287
 
1287
1288
 
1288
1289
  converter(Vec[NDArray], TupleNDArray, lambda x: TupleNDArray.from_vec(x))
@@ -1941,3 +1942,42 @@ def array_api_vec_to_cons_ruleset(
1941
1942
 
1942
1943
  array_api_combined_ruleset = array_api_ruleset | array_api_vec_to_cons_ruleset
1943
1944
  array_api_schedule = array_api_combined_ruleset.saturate()
1945
+
1946
+ _CURRENT_EGRAPH: None | EGraph = None
1947
+
1948
+
1949
+ @contextlib.contextmanager
1950
+ def set_array_api_egraph(egraph: EGraph) -> Iterator[None]:
1951
+ """
1952
+ Context manager that will set the current egraph. It will be set back after.
1953
+ """
1954
+ global _CURRENT_EGRAPH
1955
+ assert _CURRENT_EGRAPH is None
1956
+ _CURRENT_EGRAPH = egraph
1957
+ yield
1958
+ _CURRENT_EGRAPH = None
1959
+
1960
+
1961
+ def _get_current_egraph() -> EGraph:
1962
+ return _CURRENT_EGRAPH or EGraph()
1963
+
1964
+
1965
+ def try_evaling(egraph: EGraph, schedule: Schedule, expr: Expr, prim_expr: BuiltinExpr) -> Any:
1966
+ """
1967
+ Try evaling the expression that will result in a primitive expression being fill.
1968
+ if it fails, display the egraph and raise an error.
1969
+ """
1970
+ try:
1971
+ extracted = egraph.extract(prim_expr)
1972
+ except EggSmolError:
1973
+ # If this primitive doesn't exist in the egraph, we need to try to create it by
1974
+ # registering the expression and running the schedule
1975
+ egraph.register(expr)
1976
+ egraph.run(schedule)
1977
+ try:
1978
+ extracted = egraph.extract(prim_expr)
1979
+ except BaseException as e:
1980
+ # egraph.display(n_inline_leaves=1, split_primitive_outputs=True)
1981
+ e.add_note(f"Cannot evaluate {egraph.extract(expr)}")
1982
+ raise
1983
+ return extracted.eval() # type: ignore[attr-defined]
@@ -4,8 +4,8 @@ from typing import TypeVar, cast
4
4
 
5
5
  import numpy as np
6
6
 
7
- from egglog import EGraph, try_evaling
8
- from egglog.exp.array_api import NDArray
7
+ from egglog import EGraph
8
+ from egglog.exp.array_api import NDArray, set_array_api_egraph, try_evaling
9
9
  from egglog.exp.array_api_numba import array_api_numba_schedule
10
10
  from egglog.exp.array_api_program_gen import EvalProgram, array_api_program_gen_schedule, ndarray_function_two_program
11
11
 
@@ -20,8 +20,7 @@ def jit(fn: X) -> X:
20
20
  """
21
21
  egraph, res, res_optimized, program = function_to_program(fn, save_egglog_string=False)
22
22
  fn_program = EvalProgram(program, {"np": np})
23
- with egraph.set_current():
24
- fn = cast("X", try_evaling(array_api_program_gen_schedule, fn_program, fn_program.as_py_object))
23
+ fn = cast("X", try_evaling(egraph, array_api_program_gen_schedule, fn_program, fn_program.as_py_object))
25
24
  fn.initial_expr = res # type: ignore[attr-defined]
26
25
  fn.expr = res_optimized # type: ignore[attr-defined]
27
26
  return fn
@@ -32,9 +31,11 @@ def function_to_program(fn: Callable, save_egglog_string: bool) -> tuple[EGraph,
32
31
  arg1, arg2 = sig.parameters.keys()
33
32
  egraph = EGraph(save_egglog_string=save_egglog_string)
34
33
  with egraph:
35
- with egraph.set_current():
34
+ with set_array_api_egraph(egraph):
36
35
  res = fn(NDArray.var(arg1), NDArray.var(arg2))
37
- res_optimized = egraph.simplify(res, array_api_numba_schedule)
36
+ egraph.register(res)
37
+ egraph.run(array_api_numba_schedule)
38
+ res_optimized = egraph.extract(res)
38
39
 
39
40
  return (
40
41
  egraph,
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: egglog
3
- Version: 9.0.0
3
+ Version: 10.0.0
4
4
  Classifier: Environment :: MacOS X
5
5
  Classifier: Environment :: Win32 (MS Windows)
6
6
  Classifier: Intended Audience :: Developers
@@ -1,28 +1,28 @@
1
- egglog-9.0.0.dist-info/METADATA,sha256=IvZ414Uz6q7vVueU9cJ2CdukU-Yq9CdoZ2IbJAikiVE,4018
2
- egglog-9.0.0.dist-info/WHEEL,sha256=tAGdc4C2KTz7B2CZ8Jf3DcKSAviAbCg44UH9ma2gYww,96
3
- egglog-9.0.0.dist-info/licenses/LICENSE,sha256=TfaboMVZ81Q6OUaKjU7z6uVjSlcGKclLYcOpgDbm9_s,1091
1
+ egglog-10.0.0.dist-info/METADATA,sha256=aWuZK6v1QbkwYpgfn9_9UnMiwCvPBONgvULMJCkMyFw,4019
2
+ egglog-10.0.0.dist-info/WHEEL,sha256=tAGdc4C2KTz7B2CZ8Jf3DcKSAviAbCg44UH9ma2gYww,96
3
+ egglog-10.0.0.dist-info/licenses/LICENSE,sha256=TfaboMVZ81Q6OUaKjU7z6uVjSlcGKclLYcOpgDbm9_s,1091
4
4
  egglog/bindings.pyi,sha256=fnRiGyHuM-m0fDe91kFryUC09U87f5XXPiY6Lno8bvk,14497
5
- egglog/builtins.py,sha256=GaryZMi5SlyhZ2LS3ggV7mCdrmmJ3EU0GrvvIcMTsEI,29081
5
+ egglog/builtins.py,sha256=ZLXp5txd7CD_Wlb137DFdy_qGbvmvzbfb54tc_KbRfY,30692
6
6
  egglog/config.py,sha256=mALVaxh7zmGrbuyzaVKVmYKcu1lF703QsKJw8AF7gSM,176
7
- egglog/conversion.py,sha256=qhH-3hs9TYnQWooaJOiV_fnol19_Zzxxf6XZ7LHudM4,9895
7
+ egglog/conversion.py,sha256=o6iE32taKz-PBX8aDo5i4jGELdEai-wrzMEZUC1lqsU,10257
8
8
  egglog/declarations.py,sha256=R4U8JhM6ChOMmPEgG6tL2sJ8wJGsIEEVVN5lVY-6k2w,23011
9
- egglog/egraph.py,sha256=t7Z4yaAeI1d-gQgsC3xDIE6U4OQQghN7AH3yqlqJ2gI,66259
9
+ egglog/egraph.py,sha256=ssy2UHt8weA8GXzomqTWMKB389HzMo62k826gas5Fng,64537
10
10
  egglog/egraph_state.py,sha256=wkuHoXmEzcUFspvNl8ZG7_xS1J7XZTK3RPnlZEuvk9k,28720
11
- egglog/examples/bignum.py,sha256=pnXy8J8dQs-QjuLH4icahJFz2TyqRCL0gpBn9YPQg2M,543
11
+ egglog/examples/bignum.py,sha256=r4UeVwpKK2iSZiT5MK4qqTWhWu76CAtSJ5YGDboO9tk,568
12
12
  egglog/examples/bool.py,sha256=pWZTjfXR1cFy3KcihLBU5AF5rn83ImORlhUUJ1YiAXc,733
13
- egglog/examples/eqsat_basic.py,sha256=ORXFYYEDsEZK2IPhHtoFsd-LdjMiQi1nn7kix4Nam0s,1011
13
+ egglog/examples/eqsat_basic.py,sha256=ivYHFD52sjYyl5rWQ60ScxmvcbABYUBlPobDMu6HvN4,913
14
14
  egglog/examples/fib.py,sha256=wAn-PjazxgHDkXAU4o2xTk_GtM_iGL0biV66vWM1st4,520
15
15
  egglog/examples/higher_order_functions.py,sha256=5P-_10y8eTh-z3bze5GVzKHyFTeefTM774haF8KYGig,1164
16
16
  egglog/examples/lambda_.py,sha256=hQBOaSw_yorNcbkQVu2EhgSc0IZNWIny7asaOlcUk9s,8496
17
17
  egglog/examples/matrix.py,sha256=_zmjgfFr2O_LjTcsTD-45_38Y_M1sP3AV39K6oFxAdw,5136
18
- egglog/examples/multiset.py,sha256=4qVjs74SL_0T4MYOyr-3TomhMXuke2DVrQFk-CQJT04,1499
18
+ egglog/examples/multiset.py,sha256=kSUFiN4qs8WXvexPCecqk3o5DzdywXacDmIiC5IDKCM,1507
19
19
  egglog/examples/ndarrays.py,sha256=T-wwef-n-3LDSjaO35zA8AZH5DXFFqq0XBSCQKEXV6E,4186
20
20
  egglog/examples/README.rst,sha256=QrbfmivODBvUvmY3-dHarcbC6bEvwoqAfTDhiI-aJxU,237
21
21
  egglog/examples/resolution.py,sha256=sKkbRI_v9XkQM0DriacKLINqKKDqYGFhvMCAS9tZbTA,2203
22
22
  egglog/examples/schedule_demo.py,sha256=iJtIbcLaZ7zK8UalY0z7KAKMqYjQx0MKTsNF24lKtik,652
23
23
  egglog/examples/__init__.py,sha256=KuhaJFOyz_rpUvEqZubsgLnv6rhQNE_AVFXA6bUnpdY,34
24
- egglog/exp/array_api.py,sha256=LDrNkrMznd937_It1sKn9H_gCnJMELC0jsGUjcjjA3Y,64166
25
- egglog/exp/array_api_jit.py,sha256=9eDVR0hCLeM3RLkIOrfitjX1bCvwoWZG1LF5nmRLWhc,1540
24
+ egglog/exp/array_api.py,sha256=Y-mG8tC8ihFkeIbNDsDyZl7-7zP6YFxZGS4rfsWOG-8,65661
25
+ egglog/exp/array_api_jit.py,sha256=c0lN-411M9LmEz8uM0ELXTtfUPMJDZFmy81-x3ck6PY,1591
26
26
  egglog/exp/array_api_loopnest.py,sha256=1KZ8kWtv00TVCbzqzWG92vz8Y1pBlPOLRIrJfxyDs9s,2518
27
27
  egglog/exp/array_api_numba.py,sha256=g0Yb7auB2MtJHgSzn7iVnsxn6ffL4FHggKvbEuuDwZs,2765
28
28
  egglog/exp/array_api_program_gen.py,sha256=qXaaW_Koz4P-KmQcyRHY1NDLCJL3O7NnE6iq2NX4yZU,22252
@@ -39,6 +39,6 @@ egglog/type_constraint_solver.py,sha256=kSJf8S1CxkicpPy_B2cNP9XiTf-d4TQNd7_IwfG-
39
39
  egglog/visualizer.css,sha256=DDmxZfLmFygwVkQwel0gGUiM5BZqgunawWeSfaRYBr8,259399
40
40
  egglog/visualizer.js,sha256=kDkYzE5cLLhmpNaWqSfQWzD0SlSsNe9KzT_GpNqhWgo,3789092
41
41
  egglog/visualizer_widget.py,sha256=TGkf61P_f4EvCFcyqwQC_8XYUkujFhcQnEG9vHWLPC4,1358
42
- egglog/__init__.py,sha256=YBsuQdA1JYMtdz79jVNGY1DXDPaWChjiba0Bu3RYfr0,276
43
- egglog/bindings.cp311-win_amd64.pyd,sha256=BYXenTFpW_C0rpAnUl30v5LQVQtW3b2Ulw1Y5LFoV8U,4500480
44
- egglog-9.0.0.dist-info/RECORD,,
42
+ egglog/__init__.py,sha256=TZ3IMVKoz5Qr97VFe-YT7oyT2NACikTis21PpVlSCmg,326
43
+ egglog/bindings.cp311-win_amd64.pyd,sha256=lydYMT5DJB3jAsg_IpweX1BgjWbVm5Z7cPQDaakldrY,4500992
44
+ egglog-10.0.0.dist-info/RECORD,,