egglog 9.0.1__cp311-cp311-manylinux_2_17_ppc64.manylinux2014_ppc64.whl → 10.0.1__cp311-cp311-manylinux_2_17_ppc64.manylinux2014_ppc64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of egglog might be problematic. See the package's registry page for more details.

egglog/__init__.py CHANGED
@@ -3,6 +3,7 @@ Package for creating e-graphs in Python.
3
3
  """
4
4
 
5
5
  from . import config, ipython_magic # noqa: F401
6
+ from .bindings import EggSmolError # noqa: F401
6
7
  from .builtins import * # noqa: UP029
7
8
  from .conversion import ConvertError, convert, converter, get_type_args # noqa: F401
8
9
  from .egraph import *
egglog/builtins.py CHANGED
@@ -5,6 +5,7 @@ Builtin sorts and function to egg.
5
5
 
6
6
  from __future__ import annotations
7
7
 
8
+ from collections.abc import Callable
8
9
  from fractions import Fraction
9
10
  from functools import partial, reduce
10
11
  from types import FunctionType, MethodType
@@ -12,17 +13,15 @@ from typing import TYPE_CHECKING, Generic, Protocol, TypeAlias, TypeVar, Union,
12
13
 
13
14
  from typing_extensions import TypeVarTuple, Unpack
14
15
 
15
- from . import bindings
16
16
  from .conversion import convert, converter, get_type_args
17
17
  from .declarations import *
18
- from .egraph import BaseExpr, BuiltinExpr, EGraph, expr_fact, function, get_current_ruleset, method
19
- from .egraph_state import GLOBAL_PY_OBJECT_SORT
18
+ from .egraph import BaseExpr, BuiltinExpr, expr_fact, function, get_current_ruleset, method
20
19
  from .functionalize import functionalize
21
20
  from .runtime import RuntimeClass, RuntimeExpr, RuntimeFunction
22
21
  from .thunk import Thunk
23
22
 
24
23
  if TYPE_CHECKING:
25
- from collections.abc import Callable, Iterator
24
+ from collections.abc import Iterator
26
25
 
27
26
 
28
27
  __all__ = [
@@ -32,6 +31,7 @@ __all__ = [
32
31
  "BigRatLike",
33
32
  "Bool",
34
33
  "BoolLike",
34
+ "BuiltinEvalError",
35
35
  "Map",
36
36
  "MapLike",
37
37
  "MultiSet",
@@ -56,6 +56,17 @@ __all__ = [
56
56
  ]
57
57
 
58
58
 
59
+ class BuiltinEvalError(Exception):
60
+ """
61
+ Raised when an builtin cannot be evaluated into a Python primitive because it is complex.
62
+
63
+ Try extracting this expression first.
64
+ """
65
+
66
+ def __str__(self) -> str:
67
+ return f"Cannot evaluate builtin expression into a Python primitive. Try extracting this expression first: {super().__str__()}"
68
+
69
+
59
70
  class Unit(BuiltinExpr, egg_sort="Unit"):
60
71
  """
61
72
  The unit type. This is used to reprsent if a value exists in the e-graph or not.
@@ -72,8 +83,8 @@ class String(BuiltinExpr):
72
83
  @method(preserve=True)
73
84
  def eval(self) -> str:
74
85
  value = _extract_lit(self)
75
- assert isinstance(value, bindings.String)
76
- return value.value
86
+ assert isinstance(value, str)
87
+ return value
77
88
 
78
89
  def __init__(self, value: str) -> None: ...
79
90
 
@@ -97,8 +108,8 @@ class Bool(BuiltinExpr, egg_sort="bool"):
97
108
  @method(preserve=True)
98
109
  def eval(self) -> bool:
99
110
  value = _extract_lit(self)
100
- assert isinstance(value, bindings.Bool)
101
- return value.value
111
+ assert isinstance(value, bool)
112
+ return value
102
113
 
103
114
  @method(preserve=True)
104
115
  def __bool__(self) -> bool:
@@ -132,8 +143,8 @@ class i64(BuiltinExpr): # noqa: N801
132
143
  @method(preserve=True)
133
144
  def eval(self) -> int:
134
145
  value = _extract_lit(self)
135
- assert isinstance(value, bindings.Int)
136
- return value.value
146
+ assert isinstance(value, int)
147
+ return value
137
148
 
138
149
  @method(preserve=True)
139
150
  def __index__(self) -> int:
@@ -251,8 +262,8 @@ class f64(BuiltinExpr): # noqa: N801
251
262
  @method(preserve=True)
252
263
  def eval(self) -> float:
253
264
  value = _extract_lit(self)
254
- assert isinstance(value, bindings.Float)
255
- return value.value
265
+ assert isinstance(value, float)
266
+ return value
256
267
 
257
268
  @method(preserve=True)
258
269
  def __float__(self) -> float:
@@ -340,9 +351,12 @@ class Map(BuiltinExpr, Generic[T, V]):
340
351
  expr = cast("RuntimeExpr", self)
341
352
  d = {}
342
353
  while call.callable != ClassMethodRef("Map", "empty"):
343
- assert call.callable == MethodRef("Map", "insert")
354
+ msg = "Map can only be evaluated if it is empty or a series of inserts."
355
+ if call.callable != MethodRef("Map", "insert"):
356
+ raise BuiltinEvalError(msg)
344
357
  call_typed, k_typed, v_typed = call.args
345
- assert isinstance(call_typed.expr, CallDecl)
358
+ if not isinstance(call_typed.expr, CallDecl):
359
+ raise BuiltinEvalError(msg)
346
360
  k = cast("T", expr.__with_expr__(k_typed))
347
361
  v = cast("V", expr.__with_expr__(v_typed))
348
362
  d[k] = v
@@ -404,7 +418,9 @@ class Set(BuiltinExpr, Generic[T]):
404
418
  @method(preserve=True)
405
419
  def eval(self) -> set[T]:
406
420
  call = _extract_call(self)
407
- assert call.callable == InitRef("Set")
421
+ if call.callable != InitRef("Set"):
422
+ msg = "Set can only be initialized with the Set constructor."
423
+ raise BuiltinEvalError(msg)
408
424
  return {cast("T", cast("RuntimeExpr", self).__with_expr__(x)) for x in call.args}
409
425
 
410
426
  @method(preserve=True)
@@ -466,7 +482,9 @@ class MultiSet(BuiltinExpr, Generic[T]):
466
482
  @method(preserve=True)
467
483
  def eval(self) -> list[T]:
468
484
  call = _extract_call(self)
469
- assert call.callable == InitRef("MultiSet")
485
+ if call.callable != InitRef("MultiSet"):
486
+ msg = "MultiSet can only be initialized with the MultiSet constructor."
487
+ raise BuiltinEvalError(msg)
470
488
  return [cast("T", cast("RuntimeExpr", self).__with_expr__(x)) for x in call.args]
471
489
 
472
490
  @method(preserve=True)
@@ -513,11 +531,15 @@ class Rational(BuiltinExpr):
513
531
  @method(preserve=True)
514
532
  def eval(self) -> Fraction:
515
533
  call = _extract_call(self)
516
- assert call.callable == InitRef("Rational")
534
+ if call.callable != InitRef("Rational"):
535
+ msg = "Rational can only be initialized with the Rational constructor."
536
+ raise BuiltinEvalError(msg)
517
537
 
518
538
  def _to_int(e: TypedExprDecl) -> int:
519
539
  expr = e.expr
520
- assert isinstance(expr, LitDecl)
540
+ if not isinstance(expr, LitDecl):
541
+ msg = "Rational can only be initialized with literals"
542
+ raise BuiltinEvalError(msg)
521
543
  assert isinstance(expr.value, int)
522
544
  return expr.value
523
545
 
@@ -596,9 +618,13 @@ class BigInt(BuiltinExpr):
596
618
  @method(preserve=True)
597
619
  def eval(self) -> int:
598
620
  call = _extract_call(self)
599
- assert call.callable == ClassMethodRef("BigInt", "from_string")
621
+ if call.callable != ClassMethodRef("BigInt", "from_string"):
622
+ msg = "BigInt can only be initialized with the BigInt constructor."
623
+ raise BuiltinEvalError(msg)
600
624
  (s,) = call.args
601
- assert isinstance(s.expr, LitDecl)
625
+ if not isinstance(s.expr, LitDecl):
626
+ msg = "BigInt can only be initialized with literals"
627
+ raise BuiltinEvalError(msg)
602
628
  assert isinstance(s.expr.value, str)
603
629
  return int(s.expr.value)
604
630
 
@@ -717,14 +743,19 @@ class BigRat(BuiltinExpr):
717
743
  @method(preserve=True)
718
744
  def eval(self) -> Fraction:
719
745
  call = _extract_call(self)
720
- assert call.callable == InitRef("BigRat")
746
+ if call.callable != InitRef("BigRat"):
747
+ msg = "BigRat can only be initialized with the BigRat constructor."
748
+ raise BuiltinEvalError(msg)
721
749
 
722
750
  def _to_fraction(e: TypedExprDecl) -> Fraction:
723
751
  expr = e.expr
724
- assert isinstance(expr, CallDecl)
725
- assert expr.callable == ClassMethodRef("BigInt", "from_string")
752
+ if not isinstance(expr, CallDecl) or expr.callable != ClassMethodRef("BigInt", "from_string"):
753
+ msg = "BigRat can only be initialized BigInt strings"
754
+ raise BuiltinEvalError(msg)
726
755
  (s,) = expr.args
727
- assert isinstance(s.expr, LitDecl)
756
+ if not isinstance(s.expr, LitDecl):
757
+ msg = "BigInt can only be initialized with literals"
758
+ raise BuiltinEvalError(msg)
728
759
  assert isinstance(s.expr.value, str)
729
760
  return Fraction(s.expr.value)
730
761
 
@@ -821,7 +852,10 @@ class Vec(BuiltinExpr, Generic[T]):
821
852
  call = _extract_call(self)
822
853
  if call.callable == ClassMethodRef("Vec", "empty"):
823
854
  return ()
824
- assert call.callable == InitRef("Vec")
855
+
856
+ if call.callable != InitRef("Vec"):
857
+ msg = "Vec can only be initialized with the Vec constructor."
858
+ raise BuiltinEvalError(msg)
825
859
  return tuple(cast("T", cast("RuntimeExpr", self).__with_expr__(x)) for x in call.args)
826
860
 
827
861
  @method(preserve=True)
@@ -889,10 +923,11 @@ VecLike: TypeAlias = Vec[T] | tuple[TO, ...] | list[TO]
889
923
  class PyObject(BuiltinExpr):
890
924
  @method(preserve=True)
891
925
  def eval(self) -> object:
892
- report = (EGraph.current or EGraph())._run_extract(cast("RuntimeExpr", self), 0)
893
- assert isinstance(report, bindings.Best)
894
- expr = report.termdag.term_to_expr(report.term, bindings.PanicSpan())
895
- return GLOBAL_PY_OBJECT_SORT.load(expr)
926
+ expr = cast("RuntimeExpr", self).__egg_typed_expr__.expr
927
+ if not isinstance(expr, PyObjectDecl):
928
+ msg = "PyObject can only be evaluated if it is a PyObject literal"
929
+ raise BuiltinEvalError(msg)
930
+ return expr.value
896
931
 
897
932
  def __init__(self, value: object) -> None: ...
898
933
 
@@ -1027,22 +1062,23 @@ def value_to_annotation(a: object) -> type | None:
1027
1062
  converter(FunctionType, UnstableFn, _convert_function)
1028
1063
 
1029
1064
 
1030
- def _extract_lit(e: BaseExpr) -> bindings._Literal:
1065
+ def _extract_lit(e: BaseExpr) -> LitType:
1031
1066
  """
1032
1067
  Special case extracting literals to make this faster by using termdag directly.
1033
1068
  """
1034
- report = (EGraph.current or EGraph())._run_extract(cast("RuntimeExpr", e), 0)
1035
- assert isinstance(report, bindings.Best)
1036
- term = report.term
1037
- assert isinstance(term, bindings.TermLit)
1038
- return term.value
1069
+ expr = cast("RuntimeExpr", e).__egg_typed_expr__.expr
1070
+ if not isinstance(expr, LitDecl):
1071
+ msg = "Expected a literal"
1072
+ raise BuiltinEvalError(msg)
1073
+ return expr.value
1039
1074
 
1040
1075
 
1041
1076
  def _extract_call(e: BaseExpr) -> CallDecl:
1042
1077
  """
1043
1078
  Extracts the call form of an expression
1044
1079
  """
1045
- extracted = cast("RuntimeExpr", (EGraph.current or EGraph()).extract(e))
1046
- expr = extracted.__egg_typed_expr__.expr
1047
- assert isinstance(expr, CallDecl)
1080
+ expr = cast("RuntimeExpr", e).__egg_typed_expr__.expr
1081
+ if not isinstance(expr, CallDecl):
1082
+ msg = "Expected a call expression"
1083
+ raise BuiltinEvalError(msg)
1048
1084
  return expr
egglog/conversion.py CHANGED
@@ -149,6 +149,12 @@ def min_convertable_tp(a: object, b: object, name: str) -> JustTypeRef:
149
149
  decls = _retrieve_conversion_decls()
150
150
  a_tp = _get_tp(a)
151
151
  b_tp = _get_tp(b)
152
+ # Make sure at least one of the types has the method, to avoid issue with == upcasting improperly
153
+ if not (
154
+ (isinstance(a_tp, JustTypeRef) and decls.has_method(a_tp.name, name))
155
+ or (isinstance(b_tp, JustTypeRef) and decls.has_method(b_tp.name, name))
156
+ ):
157
+ raise ConvertError(f"Neither {a_tp} nor {b_tp} has method {name}")
152
158
  a_converts_to = {
153
159
  to: c for ((from_, to), (c, _)) in CONVERSIONS.items() if from_ == a_tp and decls.has_method(to.name, name)
154
160
  }
egglog/egraph.py CHANGED
@@ -4,7 +4,7 @@ import contextlib
4
4
  import inspect
5
5
  import pathlib
6
6
  import tempfile
7
- from collections.abc import Callable, Generator, Iterable, Iterator
7
+ from collections.abc import Callable, Generator, Iterable
8
8
  from contextvars import ContextVar
9
9
  from dataclasses import InitVar, dataclass, field
10
10
  from functools import partial
@@ -17,7 +17,6 @@ from typing import (
17
17
  Generic,
18
18
  Literal,
19
19
  Never,
20
- Protocol,
21
20
  TypeAlias,
22
21
  TypedDict,
23
22
  TypeVar,
@@ -85,7 +84,6 @@ __all__ = [
85
84
  "set_",
86
85
  "simplify",
87
86
  "subsume",
88
- "try_evaling",
89
87
  "union",
90
88
  "unstable_combine_rulesets",
91
89
  "var",
@@ -847,11 +845,10 @@ class EGraph:
847
845
  Can run actions, check facts, run schedules, or extract minimal cost expressions.
848
846
  """
849
847
 
850
- current: ClassVar[EGraph | None] = None
851
848
  seminaive: InitVar[bool] = True
852
849
  save_egglog_string: InitVar[bool] = False
853
850
 
854
- _state: EGraphState = field(init=False)
851
+ _state: EGraphState = field(init=False, repr=False)
855
852
  # For pushing/popping with egglog
856
853
  _state_stack: list[EGraphState] = field(default_factory=list, repr=False)
857
854
  # For storing the global "current" egraph
@@ -1200,16 +1197,6 @@ class EGraph:
1200
1197
  if visualize:
1201
1198
  VisualizerWidget(egraphs=egraphs).display_or_open()
1202
1199
 
1203
- @contextlib.contextmanager
1204
- def set_current(self) -> Iterator[None]:
1205
- """
1206
- Context manager that will set the current egraph. It will be set back after.
1207
- """
1208
- prev_current = EGraph.current
1209
- EGraph.current = self
1210
- yield
1211
- EGraph.current = prev_current
1212
-
1213
1200
  @property
1214
1201
  def _egraph(self) -> bindings.EGraph:
1215
1202
  return self._state.egraph
@@ -1303,8 +1290,6 @@ class Schedule(DelayedDeclerations):
1303
1290
  A composition of some rulesets, either composing them sequentially, running them repeatedly, running them till saturation, or running until some facts are met
1304
1291
  """
1305
1292
 
1306
- current: ClassVar[Schedule | None] = None
1307
-
1308
1293
  # Defer declerations so that we can have rule generators that used not yet defined yet
1309
1294
  schedule: ScheduleDecl
1310
1295
 
@@ -1332,16 +1317,6 @@ class Schedule(DelayedDeclerations):
1332
1317
  """
1333
1318
  return Schedule(Thunk.fn(Declarations.create, self, other), SequenceDecl((self.schedule, other.schedule)))
1334
1319
 
1335
- @contextlib.contextmanager
1336
- def set_current(self) -> Iterator[None]:
1337
- """
1338
- Context manager that will set the current schedule. It will be set back after
1339
- """
1340
- prev_current = Schedule.current
1341
- Schedule.current = self
1342
- yield
1343
- Schedule.current = prev_current
1344
-
1345
1320
 
1346
1321
  @dataclass
1347
1322
  class Ruleset(Schedule):
@@ -1488,9 +1463,12 @@ class Fact:
1488
1463
 
1489
1464
  def __bool__(self) -> bool:
1490
1465
  """
1491
- Returns True if the two sides of an equality are equal in the egraph or the expression is in the egraph.
1466
+ Returns True if the two sides of an equality are structurally equal.
1492
1467
  """
1493
- return (EGraph.current or EGraph()).check_bool(self)
1468
+ if not isinstance(self.fact, EqDecl):
1469
+ msg = "Can only check equality facts"
1470
+ raise TypeError(msg)
1471
+ return self.fact.left == self.fact.right
1494
1472
 
1495
1473
 
1496
1474
  @dataclass
@@ -1839,12 +1817,13 @@ def _rewrite_or_rule_generator(gen: RewriteOrRuleGenerator, frame: FrameType) ->
1839
1817
  """
1840
1818
  Returns a thunk which will call the function with variables of the type and name of the arguments.
1841
1819
  """
1842
- # Get the local scope from where the function is defined, so that we can get any type hints that are in the scope
1843
- # but not in the globals
1844
- globals = gen.__globals__.copy()
1845
- if "Callable" not in globals:
1846
- globals["Callable"] = Callable
1847
- hints = get_type_hints(gen, globals, frame.f_locals)
1820
+ # Need to manually pass in the frame locals from the generator, because otherwise classes defined within function
1821
+ # will not be available in the annotations
1822
+ # combine locals and globals so that they are the same dict. Otherwise get_type_hints will go through the wrong
1823
+ # path and give an error for the test
1824
+ # python/tests/test_no_import_star.py::test_no_import_star_rulesset
1825
+ combined = {**gen.__globals__, **frame.f_locals}
1826
+ hints = get_type_hints(gen, combined, combined)
1848
1827
  args = [_var(p.name, hints[p.name]) for p in signature(gen).parameters.values()]
1849
1828
  return list(gen(*args)) # type: ignore[misc]
1850
1829
 
@@ -1876,34 +1855,3 @@ def set_current_ruleset(r: Ruleset | None) -> Generator[None, None, None]:
1876
1855
  yield
1877
1856
  finally:
1878
1857
  _CURRENT_RULESET.reset(token)
1879
-
1880
-
1881
- T_co = TypeVar("T_co", covariant=True)
1882
-
1883
-
1884
- class _EvalsTo(Protocol, Generic[T_co]):
1885
- def eval(self) -> T_co: ...
1886
-
1887
-
1888
- def try_evaling(schedule: Schedule, expr: Expr, prim_expr: _EvalsTo[T]) -> T:
1889
- """
1890
- Try evaling the expression that will result in a primitive expression being fill.
1891
- if it fails, display the egraph and raise an error.
1892
- """
1893
- egraph = EGraph.current or EGraph()
1894
- with egraph.set_current():
1895
- try:
1896
- return prim_expr.eval()
1897
- except BaseException: # noqa: S110
1898
- pass
1899
- # If this primitive doesn't exist in the egraph, we need to try to create it by
1900
- # registering the expression and running the schedule
1901
- egraph.register(expr)
1902
- egraph.run(Schedule.current or schedule)
1903
- try:
1904
- with egraph.set_current():
1905
- return prim_expr.eval()
1906
- except BaseException as e:
1907
- # egraph.display(n_inline_leaves=1, split_primitive_outputs=True)
1908
- e.add_note(f"Cannot evaluate {egraph.extract(expr)}")
1909
- raise
egglog/examples/bignum.py CHANGED
@@ -12,14 +12,15 @@ x = BigInt(-1234)
12
12
  y = BigInt.from_string("2")
13
13
  z = BigRat(x, y)
14
14
 
15
- assert z.numer.to_string() == "-617"
15
+ egraph = EGraph()
16
+
17
+ assert egraph.extract(z.numer.to_string()).eval() == "-617"
16
18
 
17
19
 
18
20
  @function
19
21
  def bignums(x: BigInt, y: BigInt) -> BigRat: ...
20
22
 
21
23
 
22
- egraph = EGraph()
23
24
  egraph.register(set_(bignums(x, y)).to(z))
24
25
 
25
26
  c = var("c", BigRat)
@@ -28,18 +28,17 @@ expr2 = Num(6) + Num(2) * Num.var("x")
28
28
  a, b, c = vars_("a b c", Num)
29
29
  i, j = vars_("i j", i64)
30
30
 
31
- check(
32
- # Check that these expressions are equal
33
- eq(expr1).to(expr2),
34
- # After running these rules, up to ten times
31
+ egraph = EGraph()
32
+ egraph.register(expr1, expr2)
33
+
34
+ egraph.run(
35
35
  ruleset(
36
36
  rewrite(a + b).to(b + a),
37
37
  rewrite(a * (b + c)).to((a * b) + (a * c)),
38
38
  rewrite(Num(i) + Num(j)).to(Num(i + j)),
39
39
  rewrite(Num(i) * Num(j)).to(Num(i * j)),
40
40
  )
41
- * 10,
42
- # On these two initial expressions
43
- expr1,
44
- expr2,
41
+ * 10
45
42
  )
43
+
44
+ egraph.check(expr1 == expr2)
@@ -6,13 +6,10 @@ Higher Order Functions
6
6
 
7
7
  from __future__ import annotations
8
8
 
9
- from typing import TYPE_CHECKING
9
+ from collections.abc import Callable
10
10
 
11
11
  from egglog import *
12
12
 
13
- if TYPE_CHECKING:
14
- from collections.abc import Callable
15
-
16
13
 
17
14
  class Math(Expr):
18
15
  def __init__(self, i: i64Like) -> None: ...
@@ -7,12 +7,11 @@ Lambda Calculus
7
7
 
8
8
  from __future__ import annotations
9
9
 
10
- from typing import TYPE_CHECKING, ClassVar
10
+ from collections.abc import Callable
11
+ from typing import ClassVar
11
12
 
12
13
  from egglog import *
13
-
14
- if TYPE_CHECKING:
15
- from collections.abc import Callable
14
+ from egglog import Expr
16
15
 
17
16
 
18
17
  class Val(Expr):
@@ -29,33 +29,32 @@ egraph = EGraph()
29
29
  xs = MultiSet(Math(1), Math(2), Math(3))
30
30
  egraph.register(xs)
31
31
 
32
- with egraph.set_current():
33
- assert xs == MultiSet(Math(1), Math(3), Math(2))
34
- assert xs != MultiSet(Math(1), Math(1), Math(2), Math(3))
32
+ egraph.check(xs == MultiSet(Math(1), Math(3), Math(2)))
33
+ egraph.check_fail(xs == MultiSet(Math(1), Math(1), Math(2), Math(3)))
35
34
 
36
- assert Counter(xs) == Counter({Math(1): 1, Math(2): 1, Math(3): 1})
35
+ assert Counter(egraph.extract(xs).eval()) == Counter({Math(1): 1, Math(2): 1, Math(3): 1})
37
36
 
38
- inserted = MultiSet(Math(1), Math(2), Math(3), Math(4))
39
- egraph.register(inserted)
40
- assert xs.insert(Math(4)) == inserted
41
37
 
42
- assert xs.contains(Math(1))
43
- assert xs.not_contains(Math(4))
44
- assert Math(1) in xs
45
- assert Math(4) not in xs
38
+ inserted = MultiSet(Math(1), Math(2), Math(3), Math(4))
39
+ egraph.register(inserted)
40
+ egraph.check(xs.insert(Math(4)) == inserted)
41
+ egraph.check(xs.contains(Math(1)))
42
+ egraph.check(xs.not_contains(Math(4)))
43
+ assert Math(1) in xs
44
+ assert Math(4) not in xs
46
45
 
47
- assert xs.remove(Math(1)) == MultiSet(Math(2), Math(3))
46
+ egraph.check(xs.remove(Math(1)) == MultiSet(Math(2), Math(3)))
48
47
 
49
- assert xs.length() == i64(3)
50
- assert len(xs) == 3
48
+ assert egraph.extract(xs.length()).eval() == 3
49
+ assert len(xs) == 3
51
50
 
52
- assert MultiSet(Math(1), Math(1)).length() == i64(2)
51
+ egraph.check(MultiSet(Math(1), Math(1)).length() == i64(2))
53
52
 
54
- assert MultiSet(Math(1)).pick() == Math(1)
53
+ egraph.check(MultiSet(Math(1)).pick() == Math(1))
55
54
 
56
- mapped = xs.map(square)
57
- egraph.register(mapped)
58
- egraph.run(math_ruleset)
59
- assert mapped == MultiSet(Math(1), Math(4), Math(9))
55
+ mapped = xs.map(square)
56
+ egraph.register(mapped)
57
+ egraph.run(math_ruleset)
58
+ egraph.check(mapped == MultiSet(Math(1), Math(4), Math(9)))
60
59
 
61
- assert xs + xs == MultiSet(Math(1), Math(2), Math(3), Math(1), Math(2), Math(3))
60
+ egraph.check(xs + xs == MultiSet(Math(1), Math(2), Math(3), Math(1), Math(2), Math(3)))
egglog/exp/array_api.py CHANGED
@@ -54,14 +54,16 @@ A: So that when converting from functional to lists we can use the same index fu
54
54
 
55
55
  from __future__ import annotations
56
56
 
57
+ import contextlib
57
58
  import itertools
58
59
  import math
59
60
  import numbers
60
61
  import os
61
62
  import sys
63
+ from collections.abc import Callable
62
64
  from copy import copy
63
65
  from types import EllipsisType
64
- from typing import TYPE_CHECKING, ClassVar, TypeAlias, cast
66
+ from typing import TYPE_CHECKING, Any, ClassVar, TypeAlias, cast
65
67
 
66
68
  import numpy as np
67
69
 
@@ -71,9 +73,10 @@ from egglog.runtime import RuntimeExpr
71
73
  from .program_gen import *
72
74
 
73
75
  if TYPE_CHECKING:
74
- from collections.abc import Callable, Iterator
76
+ from collections.abc import Iterator
75
77
  from types import ModuleType
76
78
 
79
+
77
80
  # Pretend that exprs are numbers b/c sklearn does isinstance checks
78
81
  numbers.Integral.register(RuntimeExpr)
79
82
 
@@ -93,7 +96,7 @@ class Boolean(Expr, ruleset=array_api_ruleset):
93
96
 
94
97
  @method(preserve=True)
95
98
  def eval(self) -> bool:
96
- return try_evaling(array_api_schedule, self, self.to_bool)
99
+ return try_evaling(_get_current_egraph(), array_api_schedule, self, self.to_bool)
97
100
 
98
101
  @property
99
102
  def to_bool(self) -> Bool: ...
@@ -218,7 +221,7 @@ class Int(Expr, ruleset=array_api_ruleset):
218
221
 
219
222
  @method(preserve=True)
220
223
  def eval(self) -> int:
221
- return try_evaling(array_api_schedule, self, self.to_i64)
224
+ return try_evaling(_get_current_egraph(), array_api_schedule, self, self.to_i64)
222
225
 
223
226
  @method(preserve=True)
224
227
  def __index__(self) -> int:
@@ -330,7 +333,7 @@ class Float(Expr, ruleset=array_api_ruleset):
330
333
 
331
334
  @method(preserve=True)
332
335
  def eval(self) -> float:
333
- return try_evaling(array_api_schedule, self, self.to_f64)
336
+ return try_evaling(_get_current_egraph(), array_api_schedule, self, self.to_f64)
334
337
 
335
338
  def abs(self) -> Float: ...
336
339
 
@@ -438,7 +441,7 @@ class TupleInt(Expr, ruleset=array_api_ruleset):
438
441
 
439
442
  @method(preserve=True)
440
443
  def eval(self) -> tuple[Int, ...]:
441
- return try_evaling(array_api_schedule, self, self.to_vec)
444
+ return try_evaling(_get_current_egraph(), array_api_schedule, self, self.to_vec)
442
445
 
443
446
  def foldl(self, f: Callable[[Int, Int], Int], init: Int) -> Int: ...
444
447
  def foldl_boolean(self, f: Callable[[Boolean, Int], Boolean], init: Boolean) -> Boolean: ...
@@ -586,7 +589,7 @@ class TupleTupleInt(Expr, ruleset=array_api_ruleset):
586
589
 
587
590
  @method(preserve=True)
588
591
  def eval(self) -> tuple[TupleInt, ...]:
589
- return try_evaling(array_api_schedule, self, self.to_vec)
592
+ return try_evaling(_get_current_egraph(), array_api_schedule, self, self.to_vec)
590
593
 
591
594
  def drop(self, n: Int) -> TupleTupleInt:
592
595
  return TupleTupleInt(self.length() - n, lambda i: self[i + n])
@@ -1282,7 +1285,7 @@ class TupleNDArray(Expr, ruleset=array_api_ruleset):
1282
1285
 
1283
1286
  @method(preserve=True)
1284
1287
  def eval(self) -> tuple[NDArray, ...]:
1285
- return try_evaling(array_api_schedule, self, self.to_vec)
1288
+ return try_evaling(_get_current_egraph(), array_api_schedule, self, self.to_vec)
1286
1289
 
1287
1290
 
1288
1291
  converter(Vec[NDArray], TupleNDArray, lambda x: TupleNDArray.from_vec(x))
@@ -1941,3 +1944,42 @@ def array_api_vec_to_cons_ruleset(
1941
1944
 
1942
1945
  array_api_combined_ruleset = array_api_ruleset | array_api_vec_to_cons_ruleset
1943
1946
  array_api_schedule = array_api_combined_ruleset.saturate()
1947
+
1948
+ _CURRENT_EGRAPH: None | EGraph = None
1949
+
1950
+
1951
+ @contextlib.contextmanager
1952
+ def set_array_api_egraph(egraph: EGraph) -> Iterator[None]:
1953
+ """
1954
+ Context manager that will set the current egraph. It will be set back after.
1955
+ """
1956
+ global _CURRENT_EGRAPH
1957
+ assert _CURRENT_EGRAPH is None
1958
+ _CURRENT_EGRAPH = egraph
1959
+ yield
1960
+ _CURRENT_EGRAPH = None
1961
+
1962
+
1963
+ def _get_current_egraph() -> EGraph:
1964
+ return _CURRENT_EGRAPH or EGraph()
1965
+
1966
+
1967
+ def try_evaling(egraph: EGraph, schedule: Schedule, expr: Expr, prim_expr: BuiltinExpr) -> Any:
1968
+ """
1969
+ Try evaling the expression that will result in a primitive expression being fill.
1970
+ if it fails, display the egraph and raise an error.
1971
+ """
1972
+ try:
1973
+ extracted = egraph.extract(prim_expr)
1974
+ except EggSmolError:
1975
+ # If this primitive doesn't exist in the egraph, we need to try to create it by
1976
+ # registering the expression and running the schedule
1977
+ egraph.register(expr)
1978
+ egraph.run(schedule)
1979
+ try:
1980
+ extracted = egraph.extract(prim_expr)
1981
+ except BaseException as e:
1982
+ # egraph.display(n_inline_leaves=1, split_primitive_outputs=True)
1983
+ e.add_note(f"Cannot evaluate {egraph.extract(expr)}")
1984
+ raise
1985
+ return extracted.eval() # type: ignore[attr-defined]
@@ -4,8 +4,8 @@ from typing import TypeVar, cast
4
4
 
5
5
  import numpy as np
6
6
 
7
- from egglog import EGraph, try_evaling
8
- from egglog.exp.array_api import NDArray
7
+ from egglog import EGraph
8
+ from egglog.exp.array_api import NDArray, set_array_api_egraph, try_evaling
9
9
  from egglog.exp.array_api_numba import array_api_numba_schedule
10
10
  from egglog.exp.array_api_program_gen import EvalProgram, array_api_program_gen_schedule, ndarray_function_two_program
11
11
 
@@ -20,8 +20,7 @@ def jit(fn: X) -> X:
20
20
  """
21
21
  egraph, res, res_optimized, program = function_to_program(fn, save_egglog_string=False)
22
22
  fn_program = EvalProgram(program, {"np": np})
23
- with egraph.set_current():
24
- fn = cast("X", try_evaling(array_api_program_gen_schedule, fn_program, fn_program.as_py_object))
23
+ fn = cast("X", try_evaling(egraph, array_api_program_gen_schedule, fn_program, fn_program.as_py_object))
25
24
  fn.initial_expr = res # type: ignore[attr-defined]
26
25
  fn.expr = res_optimized # type: ignore[attr-defined]
27
26
  return fn
@@ -32,9 +31,11 @@ def function_to_program(fn: Callable, save_egglog_string: bool) -> tuple[EGraph,
32
31
  arg1, arg2 = sig.parameters.keys()
33
32
  egraph = EGraph(save_egglog_string=save_egglog_string)
34
33
  with egraph:
35
- with egraph.set_current():
34
+ with set_array_api_egraph(egraph):
36
35
  res = fn(NDArray.var(arg1), NDArray.var(arg2))
37
- res_optimized = egraph.simplify(res, array_api_numba_schedule)
36
+ egraph.register(res)
37
+ egraph.run(array_api_numba_schedule)
38
+ res_optimized = egraph.extract(res)
38
39
 
39
40
  return (
40
41
  egraph,
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: egglog
3
- Version: 9.0.1
3
+ Version: 10.0.1
4
4
  Classifier: Environment :: MacOS X
5
5
  Classifier: Environment :: Win32 (MS Windows)
6
6
  Classifier: Intended Audience :: Developers
@@ -1,16 +1,16 @@
1
- egglog-9.0.1.dist-info/METADATA,sha256=avDfv1AMXRxm7eQcAhZYWbEm0UVuFB68FYUcT_jzUVQ,4008
2
- egglog-9.0.1.dist-info/WHEEL,sha256=TXoyImGYix7Rgz9vFJf5yHjHurpJZift2bEePFE703I,127
3
- egglog-9.0.1.dist-info/licenses/LICENSE,sha256=w7VlVv5O_FPZRo8Z-4Zb_q7D5ac3YDs8JUkMZ4Gq9CE,1070
1
+ egglog-10.0.1.dist-info/METADATA,sha256=BKd5QxorELhavgKQf4E6fDVQ8FpCCifHy4v_BxFiXT0,4009
2
+ egglog-10.0.1.dist-info/WHEEL,sha256=TXoyImGYix7Rgz9vFJf5yHjHurpJZift2bEePFE703I,127
3
+ egglog-10.0.1.dist-info/licenses/LICENSE,sha256=w7VlVv5O_FPZRo8Z-4Zb_q7D5ac3YDs8JUkMZ4Gq9CE,1070
4
4
  egglog/visualizer_widget.py,sha256=LtVfzOtv2WeKtNuILQQ_9SOHWvRr8YdBYQDKQSgry_s,1319
5
5
  egglog/egraph_state.py,sha256=KxP4aXfh21vVqx1HbQMtb2PRxJg4Gzp1Po6Jp5Zppk0,28086
6
- egglog/builtins.py,sha256=yMvPxcYKjUxIFtDzKUr1XUwZSN5pLlcR66K7BCG5eS4,28095
6
+ egglog/builtins.py,sha256=7cZk-MxUvOA3OT9RSX7WkYBoj1whAfp3tOB3a82ZTH0,29636
7
7
  egglog/exp/array_api_numba.py,sha256=X3H1TnCjPL92uVm6OvcWMJ11IeorAE58zWiOX6huPv4,2696
8
8
  egglog/exp/program_gen.py,sha256=9q8-ihkXLgeRU_BrINF3t3gThdS74KCzTo4UiRIR6Dk,13027
9
9
  egglog/exp/array_api_program_gen.py,sha256=0FpWZZJ3VqIod9vIYfQYUIiyswfj4h4K_4s6GZb8SJY,21742
10
10
  egglog/exp/siu_examples.py,sha256=yZ-sgH2Y12iTdwBUumP7D2OtCGL83M6pPW7PMobVFXc,719
11
11
  egglog/exp/__init__.py,sha256=nPtzrH1bz1LVZhZCuS0S9Qild8m5gEikjOVqWAFIa88,49
12
- egglog/exp/array_api.py,sha256=V-xiDokdp2QOOxduXwDO39BazxjG9mFMeD5tuIMhy-s,62223
13
- egglog/exp/array_api_jit.py,sha256=RzlMyEQ8pAKrQSjJ93jAHK-i3pFu6OXgs6ljDpPHqtQ,1496
12
+ egglog/exp/array_api.py,sha256=jKQoDg3bYg0TGGEJLwXF3okWDkgi4JA4ayXzwf-9XB4,63706
13
+ egglog/exp/array_api_jit.py,sha256=90RmyivRoCKzVtiYWnTBkK2q2FDkD2p1iq7vzbh68b4,1546
14
14
  egglog/exp/array_api_loopnest.py,sha256=-kbyorlGxvlaNsLx1nmLfEZHQM7VMEBwSKtV0l-bs0g,2444
15
15
  egglog/type_constraint_solver.py,sha256=_y52JJoxWJDzokvS07-QHJ4sQpvWE86_Dk43RAUqhAk,4516
16
16
  egglog/runtime.py,sha256=vyzN52njMucBaf8sRl4A5FVBrqDMfNaV8PlG8lzNxhc,25741
@@ -18,27 +18,27 @@ egglog/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
18
18
  egglog/bindings.pyi,sha256=b0osaWlb6Yqh7Cw7I_DnRjU-6NTeXq1XOolvOvYMMcI,13830
19
19
  egglog/config.py,sha256=yM3FIcVCKnhWZmHD0pxkzx7ah7t5PxZx3WUqKtA9tjU,168
20
20
  egglog/declarations.py,sha256=UzpOjoZz21T9QfodmS3U00vx5VcTcryXmAw11CALix8,22193
21
- egglog/__init__.py,sha256=xVZPz0Pplz0OoXcrSPhNY7fHfja7Sao03DYbIKp8kcU,266
21
+ egglog/__init__.py,sha256=5n2nas3VbxxgsvYeT8CutVNHbC2jggXQtDQ3RzX3vCM,315
22
22
  egglog/thunk.py,sha256=BwdqvNhY8g3KdSE5DtT1yIoxCF3mNGcqoODy8nXm9z8,2168
23
23
  egglog/pretty.py,sha256=5hcc5-tAGnhcpz_aClJ1PgE3StOvFqQRYXYFDWXbNP0,20476
24
- egglog/conversion.py,sha256=FnZFEZdZHkyrmg8aJ0_7SA-my5O0b2hjtjGGBrFR00U,9633
24
+ egglog/conversion.py,sha256=lE9a3myOWyxlYbCVqOMKnqXw2_Ul24DJuJlVZ50HCA0,9989
25
25
  egglog/ipython_magic.py,sha256=2hs3g2cSiyDmbCvE2t1OINmu17Bb8MWV--2DpEWwO7I,1189
26
26
  egglog/functionalize.py,sha256=VMqiWcDbn1pYrY3FMxT143Bk90dMaxo2VapUL60-fvE,3847
27
- egglog/egraph.py,sha256=SJI2I7njTUdyg9BU0-SdwkHTvQVfck0mkCjA5zCvJ1c,64350
27
+ egglog/egraph.py,sha256=Y2vAfJpH3ss51zc9krkfAVzkcttKtSwrnSpmrBupvXs,62867
28
28
  egglog/visualizer.css,sha256=eL0POoThQRc0P4OYnDT-d808ln9O5Qy6DizH9Z5LgWc,259398
29
29
  egglog/visualizer.js,sha256=2qZZ-9W_INJx4gZMYjnVXl27IjT_JNuQyEeI2dbjWoU,3753315
30
30
  egglog/examples/schedule_demo.py,sha256=JbXdPII7_adxtgyKVAiqCyV2sj88VZ-DhomYrdn8vuc,618
31
31
  egglog/examples/README.rst,sha256=ztTvpofR0eotSqGoCy_C1fPLDPCncjvcqDanXtLHNNU,232
32
32
  egglog/examples/fib.py,sha256=BOHxKWA7jGx4FURBmfmuZKfLo6xq9-uXAwAXjYid7LU,492
33
33
  egglog/examples/resolution.py,sha256=BJd5JClA3DBVGfiVRa-H0gbbFvIqeP3uYbhCXHblSQc,2119
34
- egglog/examples/lambda_.py,sha256=LK1ZSCITpvxpRMaqISZkRAKqNAU7QbaWq9oGe7ku2Ns,8208
35
- egglog/examples/multiset.py,sha256=bI7AmNlU3_D1GAhti2ojei5WQcbA3nnMun7y33hX0bE,1438
36
- egglog/examples/bignum.py,sha256=Rs_kttvEu_OJW26CRR0AY4Hup_FsvLpg6Klkt-5QfFU,512
34
+ egglog/examples/lambda_.py,sha256=iQvwaXVhp2VNOMS7j1WwceZaiq3dqqilwUkMcW5GFBE,8194
35
+ egglog/examples/multiset.py,sha256=fdEPvNFkHe_XmzBv90JCP8SCxoKgpg_CIDsR8HWcbnY,1447
36
+ egglog/examples/bignum.py,sha256=LziQqBdOaYNhO3VQS6ZokHDO7_QSsOMV4Whyrj6-6nA,536
37
37
  egglog/examples/matrix.py,sha256=7_mPcMcgE-t_GJDyf76-nv3xhPIeN2mvFkc_p_Gnr8g,4961
38
38
  egglog/examples/__init__.py,sha256=wm9evUbMPfbtylXIjbDdRTAVMLH4OjT4Z77PCBFyaPU,31
39
- egglog/examples/higher_order_functions.py,sha256=hWs9sCM5mERuuYscw9Qp-W3HEHi0EMtAs57gV2Qv69k,1119
39
+ egglog/examples/higher_order_functions.py,sha256=DNLIQfPJCX_DOLbHNiaYsfvcFIYCYOsRUqp99r9bpc8,1063
40
40
  egglog/examples/bool.py,sha256=e0z2YoYJsLlhpSADZK1yRYHzilyxSZWGiYAaM0DQ_Gw,695
41
41
  egglog/examples/ndarrays.py,sha256=mfr410eletH8gfdg-P8L90vlF6TUifvYV_-ryOwvZZE,4042
42
- egglog/examples/eqsat_basic.py,sha256=YKA_l6kKNfPk1geaf-qDohYHpvIji_HmABGr3Ez_N8k,966
43
- egglog/bindings.cpython-311-powerpc64-linux-gnu.so,sha256=GTNefRU7n9yaMMUINZsiHHm6WIiHNXX0Oagu_aDPhqw,6993280
44
- egglog-9.0.1.dist-info/RECORD,,
42
+ egglog/examples/eqsat_basic.py,sha256=2xtM81gG9Br72mr58N-2BUeksR7C_UXnZJ4MvzSPplc,869
43
+ egglog/bindings.cpython-311-powerpc64-linux-gnu.so,sha256=El3R-hSWj7dW_ACcndW3WdEwQ6PeTC_uHNxnf9MVVTA,6992704
44
+ egglog-10.0.1.dist-info/RECORD,,