egglog 9.0.1__cp311-cp311-win_amd64.whl → 10.0.0__cp311-cp311-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of egglog might be problematic. Click here for more details.

egglog/__init__.py CHANGED
@@ -3,6 +3,7 @@ Package for creating e-graphs in Python.
3
3
  """
4
4
 
5
5
  from . import config, ipython_magic # noqa: F401
6
+ from .bindings import EggSmolError # noqa: F401
6
7
  from .builtins import * # noqa: UP029
7
8
  from .conversion import ConvertError, convert, converter, get_type_args # noqa: F401
8
9
  from .egraph import *
Binary file
egglog/builtins.py CHANGED
@@ -12,11 +12,9 @@ from typing import TYPE_CHECKING, Generic, Protocol, TypeAlias, TypeVar, Union,
12
12
 
13
13
  from typing_extensions import TypeVarTuple, Unpack
14
14
 
15
- from . import bindings
16
15
  from .conversion import convert, converter, get_type_args
17
16
  from .declarations import *
18
- from .egraph import BaseExpr, BuiltinExpr, EGraph, expr_fact, function, get_current_ruleset, method
19
- from .egraph_state import GLOBAL_PY_OBJECT_SORT
17
+ from .egraph import BaseExpr, BuiltinExpr, expr_fact, function, get_current_ruleset, method
20
18
  from .functionalize import functionalize
21
19
  from .runtime import RuntimeClass, RuntimeExpr, RuntimeFunction
22
20
  from .thunk import Thunk
@@ -32,6 +30,7 @@ __all__ = [
32
30
  "BigRatLike",
33
31
  "Bool",
34
32
  "BoolLike",
33
+ "BuiltinEvalError",
35
34
  "Map",
36
35
  "MapLike",
37
36
  "MultiSet",
@@ -56,6 +55,17 @@ __all__ = [
56
55
  ]
57
56
 
58
57
 
58
+ class BuiltinEvalError(Exception):
59
+ """
60
+ Raised when an builtin cannot be evaluated into a Python primitive because it is complex.
61
+
62
+ Try extracting this expression first.
63
+ """
64
+
65
+ def __str__(self) -> str:
66
+ return f"Cannot evaluate builtin expression into a Python primitive. Try extracting this expression first: {super().__str__()}"
67
+
68
+
59
69
  class Unit(BuiltinExpr, egg_sort="Unit"):
60
70
  """
61
71
  The unit type. This is used to reprsent if a value exists in the e-graph or not.
@@ -72,8 +82,8 @@ class String(BuiltinExpr):
72
82
  @method(preserve=True)
73
83
  def eval(self) -> str:
74
84
  value = _extract_lit(self)
75
- assert isinstance(value, bindings.String)
76
- return value.value
85
+ assert isinstance(value, str)
86
+ return value
77
87
 
78
88
  def __init__(self, value: str) -> None: ...
79
89
 
@@ -97,8 +107,8 @@ class Bool(BuiltinExpr, egg_sort="bool"):
97
107
  @method(preserve=True)
98
108
  def eval(self) -> bool:
99
109
  value = _extract_lit(self)
100
- assert isinstance(value, bindings.Bool)
101
- return value.value
110
+ assert isinstance(value, bool)
111
+ return value
102
112
 
103
113
  @method(preserve=True)
104
114
  def __bool__(self) -> bool:
@@ -132,8 +142,8 @@ class i64(BuiltinExpr): # noqa: N801
132
142
  @method(preserve=True)
133
143
  def eval(self) -> int:
134
144
  value = _extract_lit(self)
135
- assert isinstance(value, bindings.Int)
136
- return value.value
145
+ assert isinstance(value, int)
146
+ return value
137
147
 
138
148
  @method(preserve=True)
139
149
  def __index__(self) -> int:
@@ -251,8 +261,8 @@ class f64(BuiltinExpr): # noqa: N801
251
261
  @method(preserve=True)
252
262
  def eval(self) -> float:
253
263
  value = _extract_lit(self)
254
- assert isinstance(value, bindings.Float)
255
- return value.value
264
+ assert isinstance(value, float)
265
+ return value
256
266
 
257
267
  @method(preserve=True)
258
268
  def __float__(self) -> float:
@@ -340,9 +350,12 @@ class Map(BuiltinExpr, Generic[T, V]):
340
350
  expr = cast("RuntimeExpr", self)
341
351
  d = {}
342
352
  while call.callable != ClassMethodRef("Map", "empty"):
343
- assert call.callable == MethodRef("Map", "insert")
353
+ msg = "Map can only be evaluated if it is empty or a series of inserts."
354
+ if call.callable != MethodRef("Map", "insert"):
355
+ raise BuiltinEvalError(msg)
344
356
  call_typed, k_typed, v_typed = call.args
345
- assert isinstance(call_typed.expr, CallDecl)
357
+ if not isinstance(call_typed.expr, CallDecl):
358
+ raise BuiltinEvalError(msg)
346
359
  k = cast("T", expr.__with_expr__(k_typed))
347
360
  v = cast("V", expr.__with_expr__(v_typed))
348
361
  d[k] = v
@@ -404,7 +417,9 @@ class Set(BuiltinExpr, Generic[T]):
404
417
  @method(preserve=True)
405
418
  def eval(self) -> set[T]:
406
419
  call = _extract_call(self)
407
- assert call.callable == InitRef("Set")
420
+ if call.callable != InitRef("Set"):
421
+ msg = "Set can only be initialized with the Set constructor."
422
+ raise BuiltinEvalError(msg)
408
423
  return {cast("T", cast("RuntimeExpr", self).__with_expr__(x)) for x in call.args}
409
424
 
410
425
  @method(preserve=True)
@@ -466,7 +481,9 @@ class MultiSet(BuiltinExpr, Generic[T]):
466
481
  @method(preserve=True)
467
482
  def eval(self) -> list[T]:
468
483
  call = _extract_call(self)
469
- assert call.callable == InitRef("MultiSet")
484
+ if call.callable != InitRef("MultiSet"):
485
+ msg = "MultiSet can only be initialized with the MultiSet constructor."
486
+ raise BuiltinEvalError(msg)
470
487
  return [cast("T", cast("RuntimeExpr", self).__with_expr__(x)) for x in call.args]
471
488
 
472
489
  @method(preserve=True)
@@ -513,11 +530,15 @@ class Rational(BuiltinExpr):
513
530
  @method(preserve=True)
514
531
  def eval(self) -> Fraction:
515
532
  call = _extract_call(self)
516
- assert call.callable == InitRef("Rational")
533
+ if call.callable != InitRef("Rational"):
534
+ msg = "Rational can only be initialized with the Rational constructor."
535
+ raise BuiltinEvalError(msg)
517
536
 
518
537
  def _to_int(e: TypedExprDecl) -> int:
519
538
  expr = e.expr
520
- assert isinstance(expr, LitDecl)
539
+ if not isinstance(expr, LitDecl):
540
+ msg = "Rational can only be initialized with literals"
541
+ raise BuiltinEvalError(msg)
521
542
  assert isinstance(expr.value, int)
522
543
  return expr.value
523
544
 
@@ -596,9 +617,13 @@ class BigInt(BuiltinExpr):
596
617
  @method(preserve=True)
597
618
  def eval(self) -> int:
598
619
  call = _extract_call(self)
599
- assert call.callable == ClassMethodRef("BigInt", "from_string")
620
+ if call.callable != ClassMethodRef("BigInt", "from_string"):
621
+ msg = "BigInt can only be initialized with the BigInt constructor."
622
+ raise BuiltinEvalError(msg)
600
623
  (s,) = call.args
601
- assert isinstance(s.expr, LitDecl)
624
+ if not isinstance(s.expr, LitDecl):
625
+ msg = "BigInt can only be initialized with literals"
626
+ raise BuiltinEvalError(msg)
602
627
  assert isinstance(s.expr.value, str)
603
628
  return int(s.expr.value)
604
629
 
@@ -717,14 +742,19 @@ class BigRat(BuiltinExpr):
717
742
  @method(preserve=True)
718
743
  def eval(self) -> Fraction:
719
744
  call = _extract_call(self)
720
- assert call.callable == InitRef("BigRat")
745
+ if call.callable != InitRef("BigRat"):
746
+ msg = "BigRat can only be initialized with the BigRat constructor."
747
+ raise BuiltinEvalError(msg)
721
748
 
722
749
  def _to_fraction(e: TypedExprDecl) -> Fraction:
723
750
  expr = e.expr
724
- assert isinstance(expr, CallDecl)
725
- assert expr.callable == ClassMethodRef("BigInt", "from_string")
751
+ if not isinstance(expr, CallDecl) or expr.callable != ClassMethodRef("BigInt", "from_string"):
752
+ msg = "BigRat can only be initialized BigInt strings"
753
+ raise BuiltinEvalError(msg)
726
754
  (s,) = expr.args
727
- assert isinstance(s.expr, LitDecl)
755
+ if not isinstance(s.expr, LitDecl):
756
+ msg = "BigInt can only be initialized with literals"
757
+ raise BuiltinEvalError(msg)
728
758
  assert isinstance(s.expr.value, str)
729
759
  return Fraction(s.expr.value)
730
760
 
@@ -821,7 +851,10 @@ class Vec(BuiltinExpr, Generic[T]):
821
851
  call = _extract_call(self)
822
852
  if call.callable == ClassMethodRef("Vec", "empty"):
823
853
  return ()
824
- assert call.callable == InitRef("Vec")
854
+
855
+ if call.callable != InitRef("Vec"):
856
+ msg = "Vec can only be initialized with the Vec constructor."
857
+ raise BuiltinEvalError(msg)
825
858
  return tuple(cast("T", cast("RuntimeExpr", self).__with_expr__(x)) for x in call.args)
826
859
 
827
860
  @method(preserve=True)
@@ -889,10 +922,11 @@ VecLike: TypeAlias = Vec[T] | tuple[TO, ...] | list[TO]
889
922
  class PyObject(BuiltinExpr):
890
923
  @method(preserve=True)
891
924
  def eval(self) -> object:
892
- report = (EGraph.current or EGraph())._run_extract(cast("RuntimeExpr", self), 0)
893
- assert isinstance(report, bindings.Best)
894
- expr = report.termdag.term_to_expr(report.term, bindings.PanicSpan())
895
- return GLOBAL_PY_OBJECT_SORT.load(expr)
925
+ expr = cast("RuntimeExpr", self).__egg_typed_expr__.expr
926
+ if not isinstance(expr, PyObjectDecl):
927
+ msg = "PyObject can only be evaluated if it is a PyObject literal"
928
+ raise BuiltinEvalError(msg)
929
+ return expr.value
896
930
 
897
931
  def __init__(self, value: object) -> None: ...
898
932
 
@@ -1027,22 +1061,23 @@ def value_to_annotation(a: object) -> type | None:
1027
1061
  converter(FunctionType, UnstableFn, _convert_function)
1028
1062
 
1029
1063
 
1030
- def _extract_lit(e: BaseExpr) -> bindings._Literal:
1064
+ def _extract_lit(e: BaseExpr) -> LitType:
1031
1065
  """
1032
1066
  Special case extracting literals to make this faster by using termdag directly.
1033
1067
  """
1034
- report = (EGraph.current or EGraph())._run_extract(cast("RuntimeExpr", e), 0)
1035
- assert isinstance(report, bindings.Best)
1036
- term = report.term
1037
- assert isinstance(term, bindings.TermLit)
1038
- return term.value
1068
+ expr = cast("RuntimeExpr", e).__egg_typed_expr__.expr
1069
+ if not isinstance(expr, LitDecl):
1070
+ msg = "Expected a literal"
1071
+ raise BuiltinEvalError(msg)
1072
+ return expr.value
1039
1073
 
1040
1074
 
1041
1075
  def _extract_call(e: BaseExpr) -> CallDecl:
1042
1076
  """
1043
1077
  Extracts the call form of an expression
1044
1078
  """
1045
- extracted = cast("RuntimeExpr", (EGraph.current or EGraph()).extract(e))
1046
- expr = extracted.__egg_typed_expr__.expr
1047
- assert isinstance(expr, CallDecl)
1079
+ expr = cast("RuntimeExpr", e).__egg_typed_expr__.expr
1080
+ if not isinstance(expr, CallDecl):
1081
+ msg = "Expected a call expression"
1082
+ raise BuiltinEvalError(msg)
1048
1083
  return expr
egglog/conversion.py CHANGED
@@ -149,6 +149,12 @@ def min_convertable_tp(a: object, b: object, name: str) -> JustTypeRef:
149
149
  decls = _retrieve_conversion_decls()
150
150
  a_tp = _get_tp(a)
151
151
  b_tp = _get_tp(b)
152
+ # Make sure at least one of the types has the method, to avoid issue with == upcasting improperly
153
+ if not (
154
+ (isinstance(a_tp, JustTypeRef) and decls.has_method(a_tp.name, name))
155
+ or (isinstance(b_tp, JustTypeRef) and decls.has_method(b_tp.name, name))
156
+ ):
157
+ raise ConvertError(f"Neither {a_tp} nor {b_tp} has method {name}")
152
158
  a_converts_to = {
153
159
  to: c for ((from_, to), (c, _)) in CONVERSIONS.items() if from_ == a_tp and decls.has_method(to.name, name)
154
160
  }
egglog/egraph.py CHANGED
@@ -4,7 +4,7 @@ import contextlib
4
4
  import inspect
5
5
  import pathlib
6
6
  import tempfile
7
- from collections.abc import Callable, Generator, Iterable, Iterator
7
+ from collections.abc import Callable, Generator, Iterable
8
8
  from contextvars import ContextVar
9
9
  from dataclasses import InitVar, dataclass, field
10
10
  from functools import partial
@@ -17,7 +17,6 @@ from typing import (
17
17
  Generic,
18
18
  Literal,
19
19
  Never,
20
- Protocol,
21
20
  TypeAlias,
22
21
  TypedDict,
23
22
  TypeVar,
@@ -85,7 +84,6 @@ __all__ = [
85
84
  "set_",
86
85
  "simplify",
87
86
  "subsume",
88
- "try_evaling",
89
87
  "union",
90
88
  "unstable_combine_rulesets",
91
89
  "var",
@@ -847,11 +845,10 @@ class EGraph:
847
845
  Can run actions, check facts, run schedules, or extract minimal cost expressions.
848
846
  """
849
847
 
850
- current: ClassVar[EGraph | None] = None
851
848
  seminaive: InitVar[bool] = True
852
849
  save_egglog_string: InitVar[bool] = False
853
850
 
854
- _state: EGraphState = field(init=False)
851
+ _state: EGraphState = field(init=False, repr=False)
855
852
  # For pushing/popping with egglog
856
853
  _state_stack: list[EGraphState] = field(default_factory=list, repr=False)
857
854
  # For storing the global "current" egraph
@@ -1200,16 +1197,6 @@ class EGraph:
1200
1197
  if visualize:
1201
1198
  VisualizerWidget(egraphs=egraphs).display_or_open()
1202
1199
 
1203
- @contextlib.contextmanager
1204
- def set_current(self) -> Iterator[None]:
1205
- """
1206
- Context manager that will set the current egraph. It will be set back after.
1207
- """
1208
- prev_current = EGraph.current
1209
- EGraph.current = self
1210
- yield
1211
- EGraph.current = prev_current
1212
-
1213
1200
  @property
1214
1201
  def _egraph(self) -> bindings.EGraph:
1215
1202
  return self._state.egraph
@@ -1303,8 +1290,6 @@ class Schedule(DelayedDeclerations):
1303
1290
  A composition of some rulesets, either composing them sequentially, running them repeatedly, running them till saturation, or running until some facts are met
1304
1291
  """
1305
1292
 
1306
- current: ClassVar[Schedule | None] = None
1307
-
1308
1293
  # Defer declerations so that we can have rule generators that used not yet defined yet
1309
1294
  schedule: ScheduleDecl
1310
1295
 
@@ -1332,16 +1317,6 @@ class Schedule(DelayedDeclerations):
1332
1317
  """
1333
1318
  return Schedule(Thunk.fn(Declarations.create, self, other), SequenceDecl((self.schedule, other.schedule)))
1334
1319
 
1335
- @contextlib.contextmanager
1336
- def set_current(self) -> Iterator[None]:
1337
- """
1338
- Context manager that will set the current schedule. It will be set back after
1339
- """
1340
- prev_current = Schedule.current
1341
- Schedule.current = self
1342
- yield
1343
- Schedule.current = prev_current
1344
-
1345
1320
 
1346
1321
  @dataclass
1347
1322
  class Ruleset(Schedule):
@@ -1488,9 +1463,12 @@ class Fact:
1488
1463
 
1489
1464
  def __bool__(self) -> bool:
1490
1465
  """
1491
- Returns True if the two sides of an equality are equal in the egraph or the expression is in the egraph.
1466
+ Returns True if the two sides of an equality are structurally equal.
1492
1467
  """
1493
- return (EGraph.current or EGraph()).check_bool(self)
1468
+ if not isinstance(self.fact, EqDecl):
1469
+ msg = "Can only check equality facts"
1470
+ raise TypeError(msg)
1471
+ return self.fact.left == self.fact.right
1494
1472
 
1495
1473
 
1496
1474
  @dataclass
@@ -1876,34 +1854,3 @@ def set_current_ruleset(r: Ruleset | None) -> Generator[None, None, None]:
1876
1854
  yield
1877
1855
  finally:
1878
1856
  _CURRENT_RULESET.reset(token)
1879
-
1880
-
1881
- T_co = TypeVar("T_co", covariant=True)
1882
-
1883
-
1884
- class _EvalsTo(Protocol, Generic[T_co]):
1885
- def eval(self) -> T_co: ...
1886
-
1887
-
1888
- def try_evaling(schedule: Schedule, expr: Expr, prim_expr: _EvalsTo[T]) -> T:
1889
- """
1890
- Try evaling the expression that will result in a primitive expression being fill.
1891
- if it fails, display the egraph and raise an error.
1892
- """
1893
- egraph = EGraph.current or EGraph()
1894
- with egraph.set_current():
1895
- try:
1896
- return prim_expr.eval()
1897
- except BaseException: # noqa: S110
1898
- pass
1899
- # If this primitive doesn't exist in the egraph, we need to try to create it by
1900
- # registering the expression and running the schedule
1901
- egraph.register(expr)
1902
- egraph.run(Schedule.current or schedule)
1903
- try:
1904
- with egraph.set_current():
1905
- return prim_expr.eval()
1906
- except BaseException as e:
1907
- # egraph.display(n_inline_leaves=1, split_primitive_outputs=True)
1908
- e.add_note(f"Cannot evaluate {egraph.extract(expr)}")
1909
- raise
egglog/examples/bignum.py CHANGED
@@ -12,14 +12,15 @@ x = BigInt(-1234)
12
12
  y = BigInt.from_string("2")
13
13
  z = BigRat(x, y)
14
14
 
15
- assert z.numer.to_string() == "-617"
15
+ egraph = EGraph()
16
+
17
+ assert egraph.extract(z.numer.to_string()).eval() == "-617"
16
18
 
17
19
 
18
20
  @function
19
21
  def bignums(x: BigInt, y: BigInt) -> BigRat: ...
20
22
 
21
23
 
22
- egraph = EGraph()
23
24
  egraph.register(set_(bignums(x, y)).to(z))
24
25
 
25
26
  c = var("c", BigRat)
@@ -28,18 +28,17 @@ expr2 = Num(6) + Num(2) * Num.var("x")
28
28
  a, b, c = vars_("a b c", Num)
29
29
  i, j = vars_("i j", i64)
30
30
 
31
- check(
32
- # Check that these expressions are equal
33
- eq(expr1).to(expr2),
34
- # After running these rules, up to ten times
31
+ egraph = EGraph()
32
+ egraph.register(expr1, expr2)
33
+
34
+ egraph.run(
35
35
  ruleset(
36
36
  rewrite(a + b).to(b + a),
37
37
  rewrite(a * (b + c)).to((a * b) + (a * c)),
38
38
  rewrite(Num(i) + Num(j)).to(Num(i + j)),
39
39
  rewrite(Num(i) * Num(j)).to(Num(i * j)),
40
40
  )
41
- * 10,
42
- # On these two initial expressions
43
- expr1,
44
- expr2,
41
+ * 10
45
42
  )
43
+
44
+ egraph.check(expr1 == expr2)
@@ -29,33 +29,32 @@ egraph = EGraph()
29
29
  xs = MultiSet(Math(1), Math(2), Math(3))
30
30
  egraph.register(xs)
31
31
 
32
- with egraph.set_current():
33
- assert xs == MultiSet(Math(1), Math(3), Math(2))
34
- assert xs != MultiSet(Math(1), Math(1), Math(2), Math(3))
32
+ egraph.check(xs == MultiSet(Math(1), Math(3), Math(2)))
33
+ egraph.check_fail(xs == MultiSet(Math(1), Math(1), Math(2), Math(3)))
35
34
 
36
- assert Counter(xs) == Counter({Math(1): 1, Math(2): 1, Math(3): 1})
35
+ assert Counter(egraph.extract(xs).eval()) == Counter({Math(1): 1, Math(2): 1, Math(3): 1})
37
36
 
38
- inserted = MultiSet(Math(1), Math(2), Math(3), Math(4))
39
- egraph.register(inserted)
40
- assert xs.insert(Math(4)) == inserted
41
37
 
42
- assert xs.contains(Math(1))
43
- assert xs.not_contains(Math(4))
44
- assert Math(1) in xs
45
- assert Math(4) not in xs
38
+ inserted = MultiSet(Math(1), Math(2), Math(3), Math(4))
39
+ egraph.register(inserted)
40
+ egraph.check(xs.insert(Math(4)) == inserted)
41
+ egraph.check(xs.contains(Math(1)))
42
+ egraph.check(xs.not_contains(Math(4)))
43
+ assert Math(1) in xs
44
+ assert Math(4) not in xs
46
45
 
47
- assert xs.remove(Math(1)) == MultiSet(Math(2), Math(3))
46
+ egraph.check(xs.remove(Math(1)) == MultiSet(Math(2), Math(3)))
48
47
 
49
- assert xs.length() == i64(3)
50
- assert len(xs) == 3
48
+ assert egraph.extract(xs.length()).eval() == 3
49
+ assert len(xs) == 3
51
50
 
52
- assert MultiSet(Math(1), Math(1)).length() == i64(2)
51
+ egraph.check(MultiSet(Math(1), Math(1)).length() == i64(2))
53
52
 
54
- assert MultiSet(Math(1)).pick() == Math(1)
53
+ egraph.check(MultiSet(Math(1)).pick() == Math(1))
55
54
 
56
- mapped = xs.map(square)
57
- egraph.register(mapped)
58
- egraph.run(math_ruleset)
59
- assert mapped == MultiSet(Math(1), Math(4), Math(9))
55
+ mapped = xs.map(square)
56
+ egraph.register(mapped)
57
+ egraph.run(math_ruleset)
58
+ egraph.check(mapped == MultiSet(Math(1), Math(4), Math(9)))
60
59
 
61
- assert xs + xs == MultiSet(Math(1), Math(2), Math(3), Math(1), Math(2), Math(3))
60
+ egraph.check(xs + xs == MultiSet(Math(1), Math(2), Math(3), Math(1), Math(2), Math(3)))
egglog/exp/array_api.py CHANGED
@@ -54,6 +54,7 @@ A: So that when converting from functional to lists we can use the same index fu
54
54
 
55
55
  from __future__ import annotations
56
56
 
57
+ import contextlib
57
58
  import itertools
58
59
  import math
59
60
  import numbers
@@ -61,7 +62,7 @@ import os
61
62
  import sys
62
63
  from copy import copy
63
64
  from types import EllipsisType
64
- from typing import TYPE_CHECKING, ClassVar, TypeAlias, cast
65
+ from typing import TYPE_CHECKING, Any, ClassVar, TypeAlias, cast
65
66
 
66
67
  import numpy as np
67
68
 
@@ -93,7 +94,7 @@ class Boolean(Expr, ruleset=array_api_ruleset):
93
94
 
94
95
  @method(preserve=True)
95
96
  def eval(self) -> bool:
96
- return try_evaling(array_api_schedule, self, self.to_bool)
97
+ return try_evaling(_get_current_egraph(), array_api_schedule, self, self.to_bool)
97
98
 
98
99
  @property
99
100
  def to_bool(self) -> Bool: ...
@@ -218,7 +219,7 @@ class Int(Expr, ruleset=array_api_ruleset):
218
219
 
219
220
  @method(preserve=True)
220
221
  def eval(self) -> int:
221
- return try_evaling(array_api_schedule, self, self.to_i64)
222
+ return try_evaling(_get_current_egraph(), array_api_schedule, self, self.to_i64)
222
223
 
223
224
  @method(preserve=True)
224
225
  def __index__(self) -> int:
@@ -330,7 +331,7 @@ class Float(Expr, ruleset=array_api_ruleset):
330
331
 
331
332
  @method(preserve=True)
332
333
  def eval(self) -> float:
333
- return try_evaling(array_api_schedule, self, self.to_f64)
334
+ return try_evaling(_get_current_egraph(), array_api_schedule, self, self.to_f64)
334
335
 
335
336
  def abs(self) -> Float: ...
336
337
 
@@ -438,7 +439,7 @@ class TupleInt(Expr, ruleset=array_api_ruleset):
438
439
 
439
440
  @method(preserve=True)
440
441
  def eval(self) -> tuple[Int, ...]:
441
- return try_evaling(array_api_schedule, self, self.to_vec)
442
+ return try_evaling(_get_current_egraph(), array_api_schedule, self, self.to_vec)
442
443
 
443
444
  def foldl(self, f: Callable[[Int, Int], Int], init: Int) -> Int: ...
444
445
  def foldl_boolean(self, f: Callable[[Boolean, Int], Boolean], init: Boolean) -> Boolean: ...
@@ -586,7 +587,7 @@ class TupleTupleInt(Expr, ruleset=array_api_ruleset):
586
587
 
587
588
  @method(preserve=True)
588
589
  def eval(self) -> tuple[TupleInt, ...]:
589
- return try_evaling(array_api_schedule, self, self.to_vec)
590
+ return try_evaling(_get_current_egraph(), array_api_schedule, self, self.to_vec)
590
591
 
591
592
  def drop(self, n: Int) -> TupleTupleInt:
592
593
  return TupleTupleInt(self.length() - n, lambda i: self[i + n])
@@ -1282,7 +1283,7 @@ class TupleNDArray(Expr, ruleset=array_api_ruleset):
1282
1283
 
1283
1284
  @method(preserve=True)
1284
1285
  def eval(self) -> tuple[NDArray, ...]:
1285
- return try_evaling(array_api_schedule, self, self.to_vec)
1286
+ return try_evaling(_get_current_egraph(), array_api_schedule, self, self.to_vec)
1286
1287
 
1287
1288
 
1288
1289
  converter(Vec[NDArray], TupleNDArray, lambda x: TupleNDArray.from_vec(x))
@@ -1941,3 +1942,42 @@ def array_api_vec_to_cons_ruleset(
1941
1942
 
1942
1943
  array_api_combined_ruleset = array_api_ruleset | array_api_vec_to_cons_ruleset
1943
1944
  array_api_schedule = array_api_combined_ruleset.saturate()
1945
+
1946
+ _CURRENT_EGRAPH: None | EGraph = None
1947
+
1948
+
1949
+ @contextlib.contextmanager
1950
+ def set_array_api_egraph(egraph: EGraph) -> Iterator[None]:
1951
+ """
1952
+ Context manager that will set the current egraph. It will be set back after.
1953
+ """
1954
+ global _CURRENT_EGRAPH
1955
+ assert _CURRENT_EGRAPH is None
1956
+ _CURRENT_EGRAPH = egraph
1957
+ yield
1958
+ _CURRENT_EGRAPH = None
1959
+
1960
+
1961
+ def _get_current_egraph() -> EGraph:
1962
+ return _CURRENT_EGRAPH or EGraph()
1963
+
1964
+
1965
+ def try_evaling(egraph: EGraph, schedule: Schedule, expr: Expr, prim_expr: BuiltinExpr) -> Any:
1966
+ """
1967
+ Try evaling the expression that will result in a primitive expression being fill.
1968
+ if it fails, display the egraph and raise an error.
1969
+ """
1970
+ try:
1971
+ extracted = egraph.extract(prim_expr)
1972
+ except EggSmolError:
1973
+ # If this primitive doesn't exist in the egraph, we need to try to create it by
1974
+ # registering the expression and running the schedule
1975
+ egraph.register(expr)
1976
+ egraph.run(schedule)
1977
+ try:
1978
+ extracted = egraph.extract(prim_expr)
1979
+ except BaseException as e:
1980
+ # egraph.display(n_inline_leaves=1, split_primitive_outputs=True)
1981
+ e.add_note(f"Cannot evaluate {egraph.extract(expr)}")
1982
+ raise
1983
+ return extracted.eval() # type: ignore[attr-defined]
@@ -4,8 +4,8 @@ from typing import TypeVar, cast
4
4
 
5
5
  import numpy as np
6
6
 
7
- from egglog import EGraph, try_evaling
8
- from egglog.exp.array_api import NDArray
7
+ from egglog import EGraph
8
+ from egglog.exp.array_api import NDArray, set_array_api_egraph, try_evaling
9
9
  from egglog.exp.array_api_numba import array_api_numba_schedule
10
10
  from egglog.exp.array_api_program_gen import EvalProgram, array_api_program_gen_schedule, ndarray_function_two_program
11
11
 
@@ -20,8 +20,7 @@ def jit(fn: X) -> X:
20
20
  """
21
21
  egraph, res, res_optimized, program = function_to_program(fn, save_egglog_string=False)
22
22
  fn_program = EvalProgram(program, {"np": np})
23
- with egraph.set_current():
24
- fn = cast("X", try_evaling(array_api_program_gen_schedule, fn_program, fn_program.as_py_object))
23
+ fn = cast("X", try_evaling(egraph, array_api_program_gen_schedule, fn_program, fn_program.as_py_object))
25
24
  fn.initial_expr = res # type: ignore[attr-defined]
26
25
  fn.expr = res_optimized # type: ignore[attr-defined]
27
26
  return fn
@@ -32,9 +31,11 @@ def function_to_program(fn: Callable, save_egglog_string: bool) -> tuple[EGraph,
32
31
  arg1, arg2 = sig.parameters.keys()
33
32
  egraph = EGraph(save_egglog_string=save_egglog_string)
34
33
  with egraph:
35
- with egraph.set_current():
34
+ with set_array_api_egraph(egraph):
36
35
  res = fn(NDArray.var(arg1), NDArray.var(arg2))
37
- res_optimized = egraph.simplify(res, array_api_numba_schedule)
36
+ egraph.register(res)
37
+ egraph.run(array_api_numba_schedule)
38
+ res_optimized = egraph.extract(res)
38
39
 
39
40
  return (
40
41
  egraph,
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: egglog
3
- Version: 9.0.1
3
+ Version: 10.0.0
4
4
  Classifier: Environment :: MacOS X
5
5
  Classifier: Environment :: Win32 (MS Windows)
6
6
  Classifier: Intended Audience :: Developers
@@ -1,28 +1,28 @@
1
- egglog-9.0.1.dist-info/METADATA,sha256=c75HEwCDBdeEj8KEDeHe4FBweGGWS-mHSsYyOVRTDH4,4018
2
- egglog-9.0.1.dist-info/WHEEL,sha256=tAGdc4C2KTz7B2CZ8Jf3DcKSAviAbCg44UH9ma2gYww,96
3
- egglog-9.0.1.dist-info/licenses/LICENSE,sha256=TfaboMVZ81Q6OUaKjU7z6uVjSlcGKclLYcOpgDbm9_s,1091
1
+ egglog-10.0.0.dist-info/METADATA,sha256=aWuZK6v1QbkwYpgfn9_9UnMiwCvPBONgvULMJCkMyFw,4019
2
+ egglog-10.0.0.dist-info/WHEEL,sha256=tAGdc4C2KTz7B2CZ8Jf3DcKSAviAbCg44UH9ma2gYww,96
3
+ egglog-10.0.0.dist-info/licenses/LICENSE,sha256=TfaboMVZ81Q6OUaKjU7z6uVjSlcGKclLYcOpgDbm9_s,1091
4
4
  egglog/bindings.pyi,sha256=fnRiGyHuM-m0fDe91kFryUC09U87f5XXPiY6Lno8bvk,14497
5
- egglog/builtins.py,sha256=_hgkpMUF-lCNasV7wG9r-7mSd-w0eNfyjYnwCKiWEP8,29143
5
+ egglog/builtins.py,sha256=ZLXp5txd7CD_Wlb137DFdy_qGbvmvzbfb54tc_KbRfY,30692
6
6
  egglog/config.py,sha256=mALVaxh7zmGrbuyzaVKVmYKcu1lF703QsKJw8AF7gSM,176
7
- egglog/conversion.py,sha256=qhH-3hs9TYnQWooaJOiV_fnol19_Zzxxf6XZ7LHudM4,9895
7
+ egglog/conversion.py,sha256=o6iE32taKz-PBX8aDo5i4jGELdEai-wrzMEZUC1lqsU,10257
8
8
  egglog/declarations.py,sha256=R4U8JhM6ChOMmPEgG6tL2sJ8wJGsIEEVVN5lVY-6k2w,23011
9
- egglog/egraph.py,sha256=t7Z4yaAeI1d-gQgsC3xDIE6U4OQQghN7AH3yqlqJ2gI,66259
9
+ egglog/egraph.py,sha256=ssy2UHt8weA8GXzomqTWMKB389HzMo62k826gas5Fng,64537
10
10
  egglog/egraph_state.py,sha256=wkuHoXmEzcUFspvNl8ZG7_xS1J7XZTK3RPnlZEuvk9k,28720
11
- egglog/examples/bignum.py,sha256=pnXy8J8dQs-QjuLH4icahJFz2TyqRCL0gpBn9YPQg2M,543
11
+ egglog/examples/bignum.py,sha256=r4UeVwpKK2iSZiT5MK4qqTWhWu76CAtSJ5YGDboO9tk,568
12
12
  egglog/examples/bool.py,sha256=pWZTjfXR1cFy3KcihLBU5AF5rn83ImORlhUUJ1YiAXc,733
13
- egglog/examples/eqsat_basic.py,sha256=ORXFYYEDsEZK2IPhHtoFsd-LdjMiQi1nn7kix4Nam0s,1011
13
+ egglog/examples/eqsat_basic.py,sha256=ivYHFD52sjYyl5rWQ60ScxmvcbABYUBlPobDMu6HvN4,913
14
14
  egglog/examples/fib.py,sha256=wAn-PjazxgHDkXAU4o2xTk_GtM_iGL0biV66vWM1st4,520
15
15
  egglog/examples/higher_order_functions.py,sha256=5P-_10y8eTh-z3bze5GVzKHyFTeefTM774haF8KYGig,1164
16
16
  egglog/examples/lambda_.py,sha256=hQBOaSw_yorNcbkQVu2EhgSc0IZNWIny7asaOlcUk9s,8496
17
17
  egglog/examples/matrix.py,sha256=_zmjgfFr2O_LjTcsTD-45_38Y_M1sP3AV39K6oFxAdw,5136
18
- egglog/examples/multiset.py,sha256=4qVjs74SL_0T4MYOyr-3TomhMXuke2DVrQFk-CQJT04,1499
18
+ egglog/examples/multiset.py,sha256=kSUFiN4qs8WXvexPCecqk3o5DzdywXacDmIiC5IDKCM,1507
19
19
  egglog/examples/ndarrays.py,sha256=T-wwef-n-3LDSjaO35zA8AZH5DXFFqq0XBSCQKEXV6E,4186
20
20
  egglog/examples/README.rst,sha256=QrbfmivODBvUvmY3-dHarcbC6bEvwoqAfTDhiI-aJxU,237
21
21
  egglog/examples/resolution.py,sha256=sKkbRI_v9XkQM0DriacKLINqKKDqYGFhvMCAS9tZbTA,2203
22
22
  egglog/examples/schedule_demo.py,sha256=iJtIbcLaZ7zK8UalY0z7KAKMqYjQx0MKTsNF24lKtik,652
23
23
  egglog/examples/__init__.py,sha256=KuhaJFOyz_rpUvEqZubsgLnv6rhQNE_AVFXA6bUnpdY,34
24
- egglog/exp/array_api.py,sha256=LDrNkrMznd937_It1sKn9H_gCnJMELC0jsGUjcjjA3Y,64166
25
- egglog/exp/array_api_jit.py,sha256=9eDVR0hCLeM3RLkIOrfitjX1bCvwoWZG1LF5nmRLWhc,1540
24
+ egglog/exp/array_api.py,sha256=Y-mG8tC8ihFkeIbNDsDyZl7-7zP6YFxZGS4rfsWOG-8,65661
25
+ egglog/exp/array_api_jit.py,sha256=c0lN-411M9LmEz8uM0ELXTtfUPMJDZFmy81-x3ck6PY,1591
26
26
  egglog/exp/array_api_loopnest.py,sha256=1KZ8kWtv00TVCbzqzWG92vz8Y1pBlPOLRIrJfxyDs9s,2518
27
27
  egglog/exp/array_api_numba.py,sha256=g0Yb7auB2MtJHgSzn7iVnsxn6ffL4FHggKvbEuuDwZs,2765
28
28
  egglog/exp/array_api_program_gen.py,sha256=qXaaW_Koz4P-KmQcyRHY1NDLCJL3O7NnE6iq2NX4yZU,22252
@@ -39,6 +39,6 @@ egglog/type_constraint_solver.py,sha256=kSJf8S1CxkicpPy_B2cNP9XiTf-d4TQNd7_IwfG-
39
39
  egglog/visualizer.css,sha256=DDmxZfLmFygwVkQwel0gGUiM5BZqgunawWeSfaRYBr8,259399
40
40
  egglog/visualizer.js,sha256=kDkYzE5cLLhmpNaWqSfQWzD0SlSsNe9KzT_GpNqhWgo,3789092
41
41
  egglog/visualizer_widget.py,sha256=TGkf61P_f4EvCFcyqwQC_8XYUkujFhcQnEG9vHWLPC4,1358
42
- egglog/__init__.py,sha256=YBsuQdA1JYMtdz79jVNGY1DXDPaWChjiba0Bu3RYfr0,276
43
- egglog/bindings.cp311-win_amd64.pyd,sha256=HGN2sZzhRRBh6b_yyDUaD1giNNn5F42QekvOd0kdpUA,4500480
44
- egglog-9.0.1.dist-info/RECORD,,
42
+ egglog/__init__.py,sha256=TZ3IMVKoz5Qr97VFe-YT7oyT2NACikTis21PpVlSCmg,326
43
+ egglog/bindings.cp311-win_amd64.pyd,sha256=lydYMT5DJB3jAsg_IpweX1BgjWbVm5Z7cPQDaakldrY,4500992
44
+ egglog-10.0.0.dist-info/RECORD,,