egglog 9.0.1__cp312-cp312-win_amd64.whl → 10.0.0__cp312-cp312-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of egglog might be problematic. Click here for more details.
- egglog/__init__.py +1 -0
- egglog/bindings.cp312-win_amd64.pyd +0 -0
- egglog/builtins.py +72 -37
- egglog/conversion.py +6 -0
- egglog/egraph.py +7 -60
- egglog/examples/bignum.py +3 -2
- egglog/examples/eqsat_basic.py +7 -8
- egglog/examples/multiset.py +20 -21
- egglog/exp/array_api.py +47 -7
- egglog/exp/array_api_jit.py +7 -6
- {egglog-9.0.1.dist-info → egglog-10.0.0.dist-info}/METADATA +1 -1
- {egglog-9.0.1.dist-info → egglog-10.0.0.dist-info}/RECORD +14 -14
- {egglog-9.0.1.dist-info → egglog-10.0.0.dist-info}/WHEEL +0 -0
- {egglog-9.0.1.dist-info → egglog-10.0.0.dist-info}/licenses/LICENSE +0 -0
egglog/__init__.py
CHANGED
|
@@ -3,6 +3,7 @@ Package for creating e-graphs in Python.
|
|
|
3
3
|
"""
|
|
4
4
|
|
|
5
5
|
from . import config, ipython_magic # noqa: F401
|
|
6
|
+
from .bindings import EggSmolError # noqa: F401
|
|
6
7
|
from .builtins import * # noqa: UP029
|
|
7
8
|
from .conversion import ConvertError, convert, converter, get_type_args # noqa: F401
|
|
8
9
|
from .egraph import *
|
|
Binary file
|
egglog/builtins.py
CHANGED
|
@@ -12,11 +12,9 @@ from typing import TYPE_CHECKING, Generic, Protocol, TypeAlias, TypeVar, Union,
|
|
|
12
12
|
|
|
13
13
|
from typing_extensions import TypeVarTuple, Unpack
|
|
14
14
|
|
|
15
|
-
from . import bindings
|
|
16
15
|
from .conversion import convert, converter, get_type_args
|
|
17
16
|
from .declarations import *
|
|
18
|
-
from .egraph import BaseExpr, BuiltinExpr,
|
|
19
|
-
from .egraph_state import GLOBAL_PY_OBJECT_SORT
|
|
17
|
+
from .egraph import BaseExpr, BuiltinExpr, expr_fact, function, get_current_ruleset, method
|
|
20
18
|
from .functionalize import functionalize
|
|
21
19
|
from .runtime import RuntimeClass, RuntimeExpr, RuntimeFunction
|
|
22
20
|
from .thunk import Thunk
|
|
@@ -32,6 +30,7 @@ __all__ = [
|
|
|
32
30
|
"BigRatLike",
|
|
33
31
|
"Bool",
|
|
34
32
|
"BoolLike",
|
|
33
|
+
"BuiltinEvalError",
|
|
35
34
|
"Map",
|
|
36
35
|
"MapLike",
|
|
37
36
|
"MultiSet",
|
|
@@ -56,6 +55,17 @@ __all__ = [
|
|
|
56
55
|
]
|
|
57
56
|
|
|
58
57
|
|
|
58
|
+
class BuiltinEvalError(Exception):
|
|
59
|
+
"""
|
|
60
|
+
Raised when an builtin cannot be evaluated into a Python primitive because it is complex.
|
|
61
|
+
|
|
62
|
+
Try extracting this expression first.
|
|
63
|
+
"""
|
|
64
|
+
|
|
65
|
+
def __str__(self) -> str:
|
|
66
|
+
return f"Cannot evaluate builtin expression into a Python primitive. Try extracting this expression first: {super().__str__()}"
|
|
67
|
+
|
|
68
|
+
|
|
59
69
|
class Unit(BuiltinExpr, egg_sort="Unit"):
|
|
60
70
|
"""
|
|
61
71
|
The unit type. This is used to reprsent if a value exists in the e-graph or not.
|
|
@@ -72,8 +82,8 @@ class String(BuiltinExpr):
|
|
|
72
82
|
@method(preserve=True)
|
|
73
83
|
def eval(self) -> str:
|
|
74
84
|
value = _extract_lit(self)
|
|
75
|
-
assert isinstance(value,
|
|
76
|
-
return value
|
|
85
|
+
assert isinstance(value, str)
|
|
86
|
+
return value
|
|
77
87
|
|
|
78
88
|
def __init__(self, value: str) -> None: ...
|
|
79
89
|
|
|
@@ -97,8 +107,8 @@ class Bool(BuiltinExpr, egg_sort="bool"):
|
|
|
97
107
|
@method(preserve=True)
|
|
98
108
|
def eval(self) -> bool:
|
|
99
109
|
value = _extract_lit(self)
|
|
100
|
-
assert isinstance(value,
|
|
101
|
-
return value
|
|
110
|
+
assert isinstance(value, bool)
|
|
111
|
+
return value
|
|
102
112
|
|
|
103
113
|
@method(preserve=True)
|
|
104
114
|
def __bool__(self) -> bool:
|
|
@@ -132,8 +142,8 @@ class i64(BuiltinExpr): # noqa: N801
|
|
|
132
142
|
@method(preserve=True)
|
|
133
143
|
def eval(self) -> int:
|
|
134
144
|
value = _extract_lit(self)
|
|
135
|
-
assert isinstance(value,
|
|
136
|
-
return value
|
|
145
|
+
assert isinstance(value, int)
|
|
146
|
+
return value
|
|
137
147
|
|
|
138
148
|
@method(preserve=True)
|
|
139
149
|
def __index__(self) -> int:
|
|
@@ -251,8 +261,8 @@ class f64(BuiltinExpr): # noqa: N801
|
|
|
251
261
|
@method(preserve=True)
|
|
252
262
|
def eval(self) -> float:
|
|
253
263
|
value = _extract_lit(self)
|
|
254
|
-
assert isinstance(value,
|
|
255
|
-
return value
|
|
264
|
+
assert isinstance(value, float)
|
|
265
|
+
return value
|
|
256
266
|
|
|
257
267
|
@method(preserve=True)
|
|
258
268
|
def __float__(self) -> float:
|
|
@@ -340,9 +350,12 @@ class Map(BuiltinExpr, Generic[T, V]):
|
|
|
340
350
|
expr = cast("RuntimeExpr", self)
|
|
341
351
|
d = {}
|
|
342
352
|
while call.callable != ClassMethodRef("Map", "empty"):
|
|
343
|
-
|
|
353
|
+
msg = "Map can only be evaluated if it is empty or a series of inserts."
|
|
354
|
+
if call.callable != MethodRef("Map", "insert"):
|
|
355
|
+
raise BuiltinEvalError(msg)
|
|
344
356
|
call_typed, k_typed, v_typed = call.args
|
|
345
|
-
|
|
357
|
+
if not isinstance(call_typed.expr, CallDecl):
|
|
358
|
+
raise BuiltinEvalError(msg)
|
|
346
359
|
k = cast("T", expr.__with_expr__(k_typed))
|
|
347
360
|
v = cast("V", expr.__with_expr__(v_typed))
|
|
348
361
|
d[k] = v
|
|
@@ -404,7 +417,9 @@ class Set(BuiltinExpr, Generic[T]):
|
|
|
404
417
|
@method(preserve=True)
|
|
405
418
|
def eval(self) -> set[T]:
|
|
406
419
|
call = _extract_call(self)
|
|
407
|
-
|
|
420
|
+
if call.callable != InitRef("Set"):
|
|
421
|
+
msg = "Set can only be initialized with the Set constructor."
|
|
422
|
+
raise BuiltinEvalError(msg)
|
|
408
423
|
return {cast("T", cast("RuntimeExpr", self).__with_expr__(x)) for x in call.args}
|
|
409
424
|
|
|
410
425
|
@method(preserve=True)
|
|
@@ -466,7 +481,9 @@ class MultiSet(BuiltinExpr, Generic[T]):
|
|
|
466
481
|
@method(preserve=True)
|
|
467
482
|
def eval(self) -> list[T]:
|
|
468
483
|
call = _extract_call(self)
|
|
469
|
-
|
|
484
|
+
if call.callable != InitRef("MultiSet"):
|
|
485
|
+
msg = "MultiSet can only be initialized with the MultiSet constructor."
|
|
486
|
+
raise BuiltinEvalError(msg)
|
|
470
487
|
return [cast("T", cast("RuntimeExpr", self).__with_expr__(x)) for x in call.args]
|
|
471
488
|
|
|
472
489
|
@method(preserve=True)
|
|
@@ -513,11 +530,15 @@ class Rational(BuiltinExpr):
|
|
|
513
530
|
@method(preserve=True)
|
|
514
531
|
def eval(self) -> Fraction:
|
|
515
532
|
call = _extract_call(self)
|
|
516
|
-
|
|
533
|
+
if call.callable != InitRef("Rational"):
|
|
534
|
+
msg = "Rational can only be initialized with the Rational constructor."
|
|
535
|
+
raise BuiltinEvalError(msg)
|
|
517
536
|
|
|
518
537
|
def _to_int(e: TypedExprDecl) -> int:
|
|
519
538
|
expr = e.expr
|
|
520
|
-
|
|
539
|
+
if not isinstance(expr, LitDecl):
|
|
540
|
+
msg = "Rational can only be initialized with literals"
|
|
541
|
+
raise BuiltinEvalError(msg)
|
|
521
542
|
assert isinstance(expr.value, int)
|
|
522
543
|
return expr.value
|
|
523
544
|
|
|
@@ -596,9 +617,13 @@ class BigInt(BuiltinExpr):
|
|
|
596
617
|
@method(preserve=True)
|
|
597
618
|
def eval(self) -> int:
|
|
598
619
|
call = _extract_call(self)
|
|
599
|
-
|
|
620
|
+
if call.callable != ClassMethodRef("BigInt", "from_string"):
|
|
621
|
+
msg = "BigInt can only be initialized with the BigInt constructor."
|
|
622
|
+
raise BuiltinEvalError(msg)
|
|
600
623
|
(s,) = call.args
|
|
601
|
-
|
|
624
|
+
if not isinstance(s.expr, LitDecl):
|
|
625
|
+
msg = "BigInt can only be initialized with literals"
|
|
626
|
+
raise BuiltinEvalError(msg)
|
|
602
627
|
assert isinstance(s.expr.value, str)
|
|
603
628
|
return int(s.expr.value)
|
|
604
629
|
|
|
@@ -717,14 +742,19 @@ class BigRat(BuiltinExpr):
|
|
|
717
742
|
@method(preserve=True)
|
|
718
743
|
def eval(self) -> Fraction:
|
|
719
744
|
call = _extract_call(self)
|
|
720
|
-
|
|
745
|
+
if call.callable != InitRef("BigRat"):
|
|
746
|
+
msg = "BigRat can only be initialized with the BigRat constructor."
|
|
747
|
+
raise BuiltinEvalError(msg)
|
|
721
748
|
|
|
722
749
|
def _to_fraction(e: TypedExprDecl) -> Fraction:
|
|
723
750
|
expr = e.expr
|
|
724
|
-
|
|
725
|
-
|
|
751
|
+
if not isinstance(expr, CallDecl) or expr.callable != ClassMethodRef("BigInt", "from_string"):
|
|
752
|
+
msg = "BigRat can only be initialized BigInt strings"
|
|
753
|
+
raise BuiltinEvalError(msg)
|
|
726
754
|
(s,) = expr.args
|
|
727
|
-
|
|
755
|
+
if not isinstance(s.expr, LitDecl):
|
|
756
|
+
msg = "BigInt can only be initialized with literals"
|
|
757
|
+
raise BuiltinEvalError(msg)
|
|
728
758
|
assert isinstance(s.expr.value, str)
|
|
729
759
|
return Fraction(s.expr.value)
|
|
730
760
|
|
|
@@ -821,7 +851,10 @@ class Vec(BuiltinExpr, Generic[T]):
|
|
|
821
851
|
call = _extract_call(self)
|
|
822
852
|
if call.callable == ClassMethodRef("Vec", "empty"):
|
|
823
853
|
return ()
|
|
824
|
-
|
|
854
|
+
|
|
855
|
+
if call.callable != InitRef("Vec"):
|
|
856
|
+
msg = "Vec can only be initialized with the Vec constructor."
|
|
857
|
+
raise BuiltinEvalError(msg)
|
|
825
858
|
return tuple(cast("T", cast("RuntimeExpr", self).__with_expr__(x)) for x in call.args)
|
|
826
859
|
|
|
827
860
|
@method(preserve=True)
|
|
@@ -889,10 +922,11 @@ VecLike: TypeAlias = Vec[T] | tuple[TO, ...] | list[TO]
|
|
|
889
922
|
class PyObject(BuiltinExpr):
|
|
890
923
|
@method(preserve=True)
|
|
891
924
|
def eval(self) -> object:
|
|
892
|
-
|
|
893
|
-
|
|
894
|
-
|
|
895
|
-
|
|
925
|
+
expr = cast("RuntimeExpr", self).__egg_typed_expr__.expr
|
|
926
|
+
if not isinstance(expr, PyObjectDecl):
|
|
927
|
+
msg = "PyObject can only be evaluated if it is a PyObject literal"
|
|
928
|
+
raise BuiltinEvalError(msg)
|
|
929
|
+
return expr.value
|
|
896
930
|
|
|
897
931
|
def __init__(self, value: object) -> None: ...
|
|
898
932
|
|
|
@@ -1027,22 +1061,23 @@ def value_to_annotation(a: object) -> type | None:
|
|
|
1027
1061
|
converter(FunctionType, UnstableFn, _convert_function)
|
|
1028
1062
|
|
|
1029
1063
|
|
|
1030
|
-
def _extract_lit(e: BaseExpr) ->
|
|
1064
|
+
def _extract_lit(e: BaseExpr) -> LitType:
|
|
1031
1065
|
"""
|
|
1032
1066
|
Special case extracting literals to make this faster by using termdag directly.
|
|
1033
1067
|
"""
|
|
1034
|
-
|
|
1035
|
-
|
|
1036
|
-
|
|
1037
|
-
|
|
1038
|
-
return
|
|
1068
|
+
expr = cast("RuntimeExpr", e).__egg_typed_expr__.expr
|
|
1069
|
+
if not isinstance(expr, LitDecl):
|
|
1070
|
+
msg = "Expected a literal"
|
|
1071
|
+
raise BuiltinEvalError(msg)
|
|
1072
|
+
return expr.value
|
|
1039
1073
|
|
|
1040
1074
|
|
|
1041
1075
|
def _extract_call(e: BaseExpr) -> CallDecl:
|
|
1042
1076
|
"""
|
|
1043
1077
|
Extracts the call form of an expression
|
|
1044
1078
|
"""
|
|
1045
|
-
|
|
1046
|
-
|
|
1047
|
-
|
|
1079
|
+
expr = cast("RuntimeExpr", e).__egg_typed_expr__.expr
|
|
1080
|
+
if not isinstance(expr, CallDecl):
|
|
1081
|
+
msg = "Expected a call expression"
|
|
1082
|
+
raise BuiltinEvalError(msg)
|
|
1048
1083
|
return expr
|
egglog/conversion.py
CHANGED
|
@@ -149,6 +149,12 @@ def min_convertable_tp(a: object, b: object, name: str) -> JustTypeRef:
|
|
|
149
149
|
decls = _retrieve_conversion_decls()
|
|
150
150
|
a_tp = _get_tp(a)
|
|
151
151
|
b_tp = _get_tp(b)
|
|
152
|
+
# Make sure at least one of the types has the method, to avoid issue with == upcasting improperly
|
|
153
|
+
if not (
|
|
154
|
+
(isinstance(a_tp, JustTypeRef) and decls.has_method(a_tp.name, name))
|
|
155
|
+
or (isinstance(b_tp, JustTypeRef) and decls.has_method(b_tp.name, name))
|
|
156
|
+
):
|
|
157
|
+
raise ConvertError(f"Neither {a_tp} nor {b_tp} has method {name}")
|
|
152
158
|
a_converts_to = {
|
|
153
159
|
to: c for ((from_, to), (c, _)) in CONVERSIONS.items() if from_ == a_tp and decls.has_method(to.name, name)
|
|
154
160
|
}
|
egglog/egraph.py
CHANGED
|
@@ -4,7 +4,7 @@ import contextlib
|
|
|
4
4
|
import inspect
|
|
5
5
|
import pathlib
|
|
6
6
|
import tempfile
|
|
7
|
-
from collections.abc import Callable, Generator, Iterable
|
|
7
|
+
from collections.abc import Callable, Generator, Iterable
|
|
8
8
|
from contextvars import ContextVar
|
|
9
9
|
from dataclasses import InitVar, dataclass, field
|
|
10
10
|
from functools import partial
|
|
@@ -17,7 +17,6 @@ from typing import (
|
|
|
17
17
|
Generic,
|
|
18
18
|
Literal,
|
|
19
19
|
Never,
|
|
20
|
-
Protocol,
|
|
21
20
|
TypeAlias,
|
|
22
21
|
TypedDict,
|
|
23
22
|
TypeVar,
|
|
@@ -85,7 +84,6 @@ __all__ = [
|
|
|
85
84
|
"set_",
|
|
86
85
|
"simplify",
|
|
87
86
|
"subsume",
|
|
88
|
-
"try_evaling",
|
|
89
87
|
"union",
|
|
90
88
|
"unstable_combine_rulesets",
|
|
91
89
|
"var",
|
|
@@ -847,11 +845,10 @@ class EGraph:
|
|
|
847
845
|
Can run actions, check facts, run schedules, or extract minimal cost expressions.
|
|
848
846
|
"""
|
|
849
847
|
|
|
850
|
-
current: ClassVar[EGraph | None] = None
|
|
851
848
|
seminaive: InitVar[bool] = True
|
|
852
849
|
save_egglog_string: InitVar[bool] = False
|
|
853
850
|
|
|
854
|
-
_state: EGraphState = field(init=False)
|
|
851
|
+
_state: EGraphState = field(init=False, repr=False)
|
|
855
852
|
# For pushing/popping with egglog
|
|
856
853
|
_state_stack: list[EGraphState] = field(default_factory=list, repr=False)
|
|
857
854
|
# For storing the global "current" egraph
|
|
@@ -1200,16 +1197,6 @@ class EGraph:
|
|
|
1200
1197
|
if visualize:
|
|
1201
1198
|
VisualizerWidget(egraphs=egraphs).display_or_open()
|
|
1202
1199
|
|
|
1203
|
-
@contextlib.contextmanager
|
|
1204
|
-
def set_current(self) -> Iterator[None]:
|
|
1205
|
-
"""
|
|
1206
|
-
Context manager that will set the current egraph. It will be set back after.
|
|
1207
|
-
"""
|
|
1208
|
-
prev_current = EGraph.current
|
|
1209
|
-
EGraph.current = self
|
|
1210
|
-
yield
|
|
1211
|
-
EGraph.current = prev_current
|
|
1212
|
-
|
|
1213
1200
|
@property
|
|
1214
1201
|
def _egraph(self) -> bindings.EGraph:
|
|
1215
1202
|
return self._state.egraph
|
|
@@ -1303,8 +1290,6 @@ class Schedule(DelayedDeclerations):
|
|
|
1303
1290
|
A composition of some rulesets, either composing them sequentially, running them repeatedly, running them till saturation, or running until some facts are met
|
|
1304
1291
|
"""
|
|
1305
1292
|
|
|
1306
|
-
current: ClassVar[Schedule | None] = None
|
|
1307
|
-
|
|
1308
1293
|
# Defer declerations so that we can have rule generators that used not yet defined yet
|
|
1309
1294
|
schedule: ScheduleDecl
|
|
1310
1295
|
|
|
@@ -1332,16 +1317,6 @@ class Schedule(DelayedDeclerations):
|
|
|
1332
1317
|
"""
|
|
1333
1318
|
return Schedule(Thunk.fn(Declarations.create, self, other), SequenceDecl((self.schedule, other.schedule)))
|
|
1334
1319
|
|
|
1335
|
-
@contextlib.contextmanager
|
|
1336
|
-
def set_current(self) -> Iterator[None]:
|
|
1337
|
-
"""
|
|
1338
|
-
Context manager that will set the current schedule. It will be set back after
|
|
1339
|
-
"""
|
|
1340
|
-
prev_current = Schedule.current
|
|
1341
|
-
Schedule.current = self
|
|
1342
|
-
yield
|
|
1343
|
-
Schedule.current = prev_current
|
|
1344
|
-
|
|
1345
1320
|
|
|
1346
1321
|
@dataclass
|
|
1347
1322
|
class Ruleset(Schedule):
|
|
@@ -1488,9 +1463,12 @@ class Fact:
|
|
|
1488
1463
|
|
|
1489
1464
|
def __bool__(self) -> bool:
|
|
1490
1465
|
"""
|
|
1491
|
-
Returns True if the two sides of an equality are equal
|
|
1466
|
+
Returns True if the two sides of an equality are structurally equal.
|
|
1492
1467
|
"""
|
|
1493
|
-
|
|
1468
|
+
if not isinstance(self.fact, EqDecl):
|
|
1469
|
+
msg = "Can only check equality facts"
|
|
1470
|
+
raise TypeError(msg)
|
|
1471
|
+
return self.fact.left == self.fact.right
|
|
1494
1472
|
|
|
1495
1473
|
|
|
1496
1474
|
@dataclass
|
|
@@ -1876,34 +1854,3 @@ def set_current_ruleset(r: Ruleset | None) -> Generator[None, None, None]:
|
|
|
1876
1854
|
yield
|
|
1877
1855
|
finally:
|
|
1878
1856
|
_CURRENT_RULESET.reset(token)
|
|
1879
|
-
|
|
1880
|
-
|
|
1881
|
-
T_co = TypeVar("T_co", covariant=True)
|
|
1882
|
-
|
|
1883
|
-
|
|
1884
|
-
class _EvalsTo(Protocol, Generic[T_co]):
|
|
1885
|
-
def eval(self) -> T_co: ...
|
|
1886
|
-
|
|
1887
|
-
|
|
1888
|
-
def try_evaling(schedule: Schedule, expr: Expr, prim_expr: _EvalsTo[T]) -> T:
|
|
1889
|
-
"""
|
|
1890
|
-
Try evaling the expression that will result in a primitive expression being fill.
|
|
1891
|
-
if it fails, display the egraph and raise an error.
|
|
1892
|
-
"""
|
|
1893
|
-
egraph = EGraph.current or EGraph()
|
|
1894
|
-
with egraph.set_current():
|
|
1895
|
-
try:
|
|
1896
|
-
return prim_expr.eval()
|
|
1897
|
-
except BaseException: # noqa: S110
|
|
1898
|
-
pass
|
|
1899
|
-
# If this primitive doesn't exist in the egraph, we need to try to create it by
|
|
1900
|
-
# registering the expression and running the schedule
|
|
1901
|
-
egraph.register(expr)
|
|
1902
|
-
egraph.run(Schedule.current or schedule)
|
|
1903
|
-
try:
|
|
1904
|
-
with egraph.set_current():
|
|
1905
|
-
return prim_expr.eval()
|
|
1906
|
-
except BaseException as e:
|
|
1907
|
-
# egraph.display(n_inline_leaves=1, split_primitive_outputs=True)
|
|
1908
|
-
e.add_note(f"Cannot evaluate {egraph.extract(expr)}")
|
|
1909
|
-
raise
|
egglog/examples/bignum.py
CHANGED
|
@@ -12,14 +12,15 @@ x = BigInt(-1234)
|
|
|
12
12
|
y = BigInt.from_string("2")
|
|
13
13
|
z = BigRat(x, y)
|
|
14
14
|
|
|
15
|
-
|
|
15
|
+
egraph = EGraph()
|
|
16
|
+
|
|
17
|
+
assert egraph.extract(z.numer.to_string()).eval() == "-617"
|
|
16
18
|
|
|
17
19
|
|
|
18
20
|
@function
|
|
19
21
|
def bignums(x: BigInt, y: BigInt) -> BigRat: ...
|
|
20
22
|
|
|
21
23
|
|
|
22
|
-
egraph = EGraph()
|
|
23
24
|
egraph.register(set_(bignums(x, y)).to(z))
|
|
24
25
|
|
|
25
26
|
c = var("c", BigRat)
|
egglog/examples/eqsat_basic.py
CHANGED
|
@@ -28,18 +28,17 @@ expr2 = Num(6) + Num(2) * Num.var("x")
|
|
|
28
28
|
a, b, c = vars_("a b c", Num)
|
|
29
29
|
i, j = vars_("i j", i64)
|
|
30
30
|
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
31
|
+
egraph = EGraph()
|
|
32
|
+
egraph.register(expr1, expr2)
|
|
33
|
+
|
|
34
|
+
egraph.run(
|
|
35
35
|
ruleset(
|
|
36
36
|
rewrite(a + b).to(b + a),
|
|
37
37
|
rewrite(a * (b + c)).to((a * b) + (a * c)),
|
|
38
38
|
rewrite(Num(i) + Num(j)).to(Num(i + j)),
|
|
39
39
|
rewrite(Num(i) * Num(j)).to(Num(i * j)),
|
|
40
40
|
)
|
|
41
|
-
* 10
|
|
42
|
-
# On these two initial expressions
|
|
43
|
-
expr1,
|
|
44
|
-
expr2,
|
|
41
|
+
* 10
|
|
45
42
|
)
|
|
43
|
+
|
|
44
|
+
egraph.check(expr1 == expr2)
|
egglog/examples/multiset.py
CHANGED
|
@@ -29,33 +29,32 @@ egraph = EGraph()
|
|
|
29
29
|
xs = MultiSet(Math(1), Math(2), Math(3))
|
|
30
30
|
egraph.register(xs)
|
|
31
31
|
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
assert xs != MultiSet(Math(1), Math(1), Math(2), Math(3))
|
|
32
|
+
egraph.check(xs == MultiSet(Math(1), Math(3), Math(2)))
|
|
33
|
+
egraph.check_fail(xs == MultiSet(Math(1), Math(1), Math(2), Math(3)))
|
|
35
34
|
|
|
36
|
-
|
|
35
|
+
assert Counter(egraph.extract(xs).eval()) == Counter({Math(1): 1, Math(2): 1, Math(3): 1})
|
|
37
36
|
|
|
38
|
-
inserted = MultiSet(Math(1), Math(2), Math(3), Math(4))
|
|
39
|
-
egraph.register(inserted)
|
|
40
|
-
assert xs.insert(Math(4)) == inserted
|
|
41
37
|
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
38
|
+
inserted = MultiSet(Math(1), Math(2), Math(3), Math(4))
|
|
39
|
+
egraph.register(inserted)
|
|
40
|
+
egraph.check(xs.insert(Math(4)) == inserted)
|
|
41
|
+
egraph.check(xs.contains(Math(1)))
|
|
42
|
+
egraph.check(xs.not_contains(Math(4)))
|
|
43
|
+
assert Math(1) in xs
|
|
44
|
+
assert Math(4) not in xs
|
|
46
45
|
|
|
47
|
-
|
|
46
|
+
egraph.check(xs.remove(Math(1)) == MultiSet(Math(2), Math(3)))
|
|
48
47
|
|
|
49
|
-
|
|
50
|
-
|
|
48
|
+
assert egraph.extract(xs.length()).eval() == 3
|
|
49
|
+
assert len(xs) == 3
|
|
51
50
|
|
|
52
|
-
|
|
51
|
+
egraph.check(MultiSet(Math(1), Math(1)).length() == i64(2))
|
|
53
52
|
|
|
54
|
-
|
|
53
|
+
egraph.check(MultiSet(Math(1)).pick() == Math(1))
|
|
55
54
|
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
55
|
+
mapped = xs.map(square)
|
|
56
|
+
egraph.register(mapped)
|
|
57
|
+
egraph.run(math_ruleset)
|
|
58
|
+
egraph.check(mapped == MultiSet(Math(1), Math(4), Math(9)))
|
|
60
59
|
|
|
61
|
-
|
|
60
|
+
egraph.check(xs + xs == MultiSet(Math(1), Math(2), Math(3), Math(1), Math(2), Math(3)))
|
egglog/exp/array_api.py
CHANGED
|
@@ -54,6 +54,7 @@ A: So that when converting from functional to lists we can use the same index fu
|
|
|
54
54
|
|
|
55
55
|
from __future__ import annotations
|
|
56
56
|
|
|
57
|
+
import contextlib
|
|
57
58
|
import itertools
|
|
58
59
|
import math
|
|
59
60
|
import numbers
|
|
@@ -61,7 +62,7 @@ import os
|
|
|
61
62
|
import sys
|
|
62
63
|
from copy import copy
|
|
63
64
|
from types import EllipsisType
|
|
64
|
-
from typing import TYPE_CHECKING, ClassVar, TypeAlias, cast
|
|
65
|
+
from typing import TYPE_CHECKING, Any, ClassVar, TypeAlias, cast
|
|
65
66
|
|
|
66
67
|
import numpy as np
|
|
67
68
|
|
|
@@ -93,7 +94,7 @@ class Boolean(Expr, ruleset=array_api_ruleset):
|
|
|
93
94
|
|
|
94
95
|
@method(preserve=True)
|
|
95
96
|
def eval(self) -> bool:
|
|
96
|
-
return try_evaling(array_api_schedule, self, self.to_bool)
|
|
97
|
+
return try_evaling(_get_current_egraph(), array_api_schedule, self, self.to_bool)
|
|
97
98
|
|
|
98
99
|
@property
|
|
99
100
|
def to_bool(self) -> Bool: ...
|
|
@@ -218,7 +219,7 @@ class Int(Expr, ruleset=array_api_ruleset):
|
|
|
218
219
|
|
|
219
220
|
@method(preserve=True)
|
|
220
221
|
def eval(self) -> int:
|
|
221
|
-
return try_evaling(array_api_schedule, self, self.to_i64)
|
|
222
|
+
return try_evaling(_get_current_egraph(), array_api_schedule, self, self.to_i64)
|
|
222
223
|
|
|
223
224
|
@method(preserve=True)
|
|
224
225
|
def __index__(self) -> int:
|
|
@@ -330,7 +331,7 @@ class Float(Expr, ruleset=array_api_ruleset):
|
|
|
330
331
|
|
|
331
332
|
@method(preserve=True)
|
|
332
333
|
def eval(self) -> float:
|
|
333
|
-
return try_evaling(array_api_schedule, self, self.to_f64)
|
|
334
|
+
return try_evaling(_get_current_egraph(), array_api_schedule, self, self.to_f64)
|
|
334
335
|
|
|
335
336
|
def abs(self) -> Float: ...
|
|
336
337
|
|
|
@@ -438,7 +439,7 @@ class TupleInt(Expr, ruleset=array_api_ruleset):
|
|
|
438
439
|
|
|
439
440
|
@method(preserve=True)
|
|
440
441
|
def eval(self) -> tuple[Int, ...]:
|
|
441
|
-
return try_evaling(array_api_schedule, self, self.to_vec)
|
|
442
|
+
return try_evaling(_get_current_egraph(), array_api_schedule, self, self.to_vec)
|
|
442
443
|
|
|
443
444
|
def foldl(self, f: Callable[[Int, Int], Int], init: Int) -> Int: ...
|
|
444
445
|
def foldl_boolean(self, f: Callable[[Boolean, Int], Boolean], init: Boolean) -> Boolean: ...
|
|
@@ -586,7 +587,7 @@ class TupleTupleInt(Expr, ruleset=array_api_ruleset):
|
|
|
586
587
|
|
|
587
588
|
@method(preserve=True)
|
|
588
589
|
def eval(self) -> tuple[TupleInt, ...]:
|
|
589
|
-
return try_evaling(array_api_schedule, self, self.to_vec)
|
|
590
|
+
return try_evaling(_get_current_egraph(), array_api_schedule, self, self.to_vec)
|
|
590
591
|
|
|
591
592
|
def drop(self, n: Int) -> TupleTupleInt:
|
|
592
593
|
return TupleTupleInt(self.length() - n, lambda i: self[i + n])
|
|
@@ -1282,7 +1283,7 @@ class TupleNDArray(Expr, ruleset=array_api_ruleset):
|
|
|
1282
1283
|
|
|
1283
1284
|
@method(preserve=True)
|
|
1284
1285
|
def eval(self) -> tuple[NDArray, ...]:
|
|
1285
|
-
return try_evaling(array_api_schedule, self, self.to_vec)
|
|
1286
|
+
return try_evaling(_get_current_egraph(), array_api_schedule, self, self.to_vec)
|
|
1286
1287
|
|
|
1287
1288
|
|
|
1288
1289
|
converter(Vec[NDArray], TupleNDArray, lambda x: TupleNDArray.from_vec(x))
|
|
@@ -1941,3 +1942,42 @@ def array_api_vec_to_cons_ruleset(
|
|
|
1941
1942
|
|
|
1942
1943
|
array_api_combined_ruleset = array_api_ruleset | array_api_vec_to_cons_ruleset
|
|
1943
1944
|
array_api_schedule = array_api_combined_ruleset.saturate()
|
|
1945
|
+
|
|
1946
|
+
_CURRENT_EGRAPH: None | EGraph = None
|
|
1947
|
+
|
|
1948
|
+
|
|
1949
|
+
@contextlib.contextmanager
|
|
1950
|
+
def set_array_api_egraph(egraph: EGraph) -> Iterator[None]:
|
|
1951
|
+
"""
|
|
1952
|
+
Context manager that will set the current egraph. It will be set back after.
|
|
1953
|
+
"""
|
|
1954
|
+
global _CURRENT_EGRAPH
|
|
1955
|
+
assert _CURRENT_EGRAPH is None
|
|
1956
|
+
_CURRENT_EGRAPH = egraph
|
|
1957
|
+
yield
|
|
1958
|
+
_CURRENT_EGRAPH = None
|
|
1959
|
+
|
|
1960
|
+
|
|
1961
|
+
def _get_current_egraph() -> EGraph:
|
|
1962
|
+
return _CURRENT_EGRAPH or EGraph()
|
|
1963
|
+
|
|
1964
|
+
|
|
1965
|
+
def try_evaling(egraph: EGraph, schedule: Schedule, expr: Expr, prim_expr: BuiltinExpr) -> Any:
|
|
1966
|
+
"""
|
|
1967
|
+
Try evaling the expression that will result in a primitive expression being fill.
|
|
1968
|
+
if it fails, display the egraph and raise an error.
|
|
1969
|
+
"""
|
|
1970
|
+
try:
|
|
1971
|
+
extracted = egraph.extract(prim_expr)
|
|
1972
|
+
except EggSmolError:
|
|
1973
|
+
# If this primitive doesn't exist in the egraph, we need to try to create it by
|
|
1974
|
+
# registering the expression and running the schedule
|
|
1975
|
+
egraph.register(expr)
|
|
1976
|
+
egraph.run(schedule)
|
|
1977
|
+
try:
|
|
1978
|
+
extracted = egraph.extract(prim_expr)
|
|
1979
|
+
except BaseException as e:
|
|
1980
|
+
# egraph.display(n_inline_leaves=1, split_primitive_outputs=True)
|
|
1981
|
+
e.add_note(f"Cannot evaluate {egraph.extract(expr)}")
|
|
1982
|
+
raise
|
|
1983
|
+
return extracted.eval() # type: ignore[attr-defined]
|
egglog/exp/array_api_jit.py
CHANGED
|
@@ -4,8 +4,8 @@ from typing import TypeVar, cast
|
|
|
4
4
|
|
|
5
5
|
import numpy as np
|
|
6
6
|
|
|
7
|
-
from egglog import EGraph
|
|
8
|
-
from egglog.exp.array_api import NDArray
|
|
7
|
+
from egglog import EGraph
|
|
8
|
+
from egglog.exp.array_api import NDArray, set_array_api_egraph, try_evaling
|
|
9
9
|
from egglog.exp.array_api_numba import array_api_numba_schedule
|
|
10
10
|
from egglog.exp.array_api_program_gen import EvalProgram, array_api_program_gen_schedule, ndarray_function_two_program
|
|
11
11
|
|
|
@@ -20,8 +20,7 @@ def jit(fn: X) -> X:
|
|
|
20
20
|
"""
|
|
21
21
|
egraph, res, res_optimized, program = function_to_program(fn, save_egglog_string=False)
|
|
22
22
|
fn_program = EvalProgram(program, {"np": np})
|
|
23
|
-
|
|
24
|
-
fn = cast("X", try_evaling(array_api_program_gen_schedule, fn_program, fn_program.as_py_object))
|
|
23
|
+
fn = cast("X", try_evaling(egraph, array_api_program_gen_schedule, fn_program, fn_program.as_py_object))
|
|
25
24
|
fn.initial_expr = res # type: ignore[attr-defined]
|
|
26
25
|
fn.expr = res_optimized # type: ignore[attr-defined]
|
|
27
26
|
return fn
|
|
@@ -32,9 +31,11 @@ def function_to_program(fn: Callable, save_egglog_string: bool) -> tuple[EGraph,
|
|
|
32
31
|
arg1, arg2 = sig.parameters.keys()
|
|
33
32
|
egraph = EGraph(save_egglog_string=save_egglog_string)
|
|
34
33
|
with egraph:
|
|
35
|
-
with egraph
|
|
34
|
+
with set_array_api_egraph(egraph):
|
|
36
35
|
res = fn(NDArray.var(arg1), NDArray.var(arg2))
|
|
37
|
-
|
|
36
|
+
egraph.register(res)
|
|
37
|
+
egraph.run(array_api_numba_schedule)
|
|
38
|
+
res_optimized = egraph.extract(res)
|
|
38
39
|
|
|
39
40
|
return (
|
|
40
41
|
egraph,
|
|
@@ -1,28 +1,28 @@
|
|
|
1
|
-
egglog-
|
|
2
|
-
egglog-
|
|
3
|
-
egglog-
|
|
1
|
+
egglog-10.0.0.dist-info/METADATA,sha256=aWuZK6v1QbkwYpgfn9_9UnMiwCvPBONgvULMJCkMyFw,4019
|
|
2
|
+
egglog-10.0.0.dist-info/WHEEL,sha256=jABKVkLC9kJr8mi_er5jOqpiQUjARSLXDUIIxDqsS50,96
|
|
3
|
+
egglog-10.0.0.dist-info/licenses/LICENSE,sha256=TfaboMVZ81Q6OUaKjU7z6uVjSlcGKclLYcOpgDbm9_s,1091
|
|
4
4
|
egglog/bindings.pyi,sha256=fnRiGyHuM-m0fDe91kFryUC09U87f5XXPiY6Lno8bvk,14497
|
|
5
|
-
egglog/builtins.py,sha256=
|
|
5
|
+
egglog/builtins.py,sha256=ZLXp5txd7CD_Wlb137DFdy_qGbvmvzbfb54tc_KbRfY,30692
|
|
6
6
|
egglog/config.py,sha256=mALVaxh7zmGrbuyzaVKVmYKcu1lF703QsKJw8AF7gSM,176
|
|
7
|
-
egglog/conversion.py,sha256=
|
|
7
|
+
egglog/conversion.py,sha256=o6iE32taKz-PBX8aDo5i4jGELdEai-wrzMEZUC1lqsU,10257
|
|
8
8
|
egglog/declarations.py,sha256=R4U8JhM6ChOMmPEgG6tL2sJ8wJGsIEEVVN5lVY-6k2w,23011
|
|
9
|
-
egglog/egraph.py,sha256=
|
|
9
|
+
egglog/egraph.py,sha256=ssy2UHt8weA8GXzomqTWMKB389HzMo62k826gas5Fng,64537
|
|
10
10
|
egglog/egraph_state.py,sha256=wkuHoXmEzcUFspvNl8ZG7_xS1J7XZTK3RPnlZEuvk9k,28720
|
|
11
|
-
egglog/examples/bignum.py,sha256=
|
|
11
|
+
egglog/examples/bignum.py,sha256=r4UeVwpKK2iSZiT5MK4qqTWhWu76CAtSJ5YGDboO9tk,568
|
|
12
12
|
egglog/examples/bool.py,sha256=pWZTjfXR1cFy3KcihLBU5AF5rn83ImORlhUUJ1YiAXc,733
|
|
13
|
-
egglog/examples/eqsat_basic.py,sha256=
|
|
13
|
+
egglog/examples/eqsat_basic.py,sha256=ivYHFD52sjYyl5rWQ60ScxmvcbABYUBlPobDMu6HvN4,913
|
|
14
14
|
egglog/examples/fib.py,sha256=wAn-PjazxgHDkXAU4o2xTk_GtM_iGL0biV66vWM1st4,520
|
|
15
15
|
egglog/examples/higher_order_functions.py,sha256=5P-_10y8eTh-z3bze5GVzKHyFTeefTM774haF8KYGig,1164
|
|
16
16
|
egglog/examples/lambda_.py,sha256=hQBOaSw_yorNcbkQVu2EhgSc0IZNWIny7asaOlcUk9s,8496
|
|
17
17
|
egglog/examples/matrix.py,sha256=_zmjgfFr2O_LjTcsTD-45_38Y_M1sP3AV39K6oFxAdw,5136
|
|
18
|
-
egglog/examples/multiset.py,sha256=
|
|
18
|
+
egglog/examples/multiset.py,sha256=kSUFiN4qs8WXvexPCecqk3o5DzdywXacDmIiC5IDKCM,1507
|
|
19
19
|
egglog/examples/ndarrays.py,sha256=T-wwef-n-3LDSjaO35zA8AZH5DXFFqq0XBSCQKEXV6E,4186
|
|
20
20
|
egglog/examples/README.rst,sha256=QrbfmivODBvUvmY3-dHarcbC6bEvwoqAfTDhiI-aJxU,237
|
|
21
21
|
egglog/examples/resolution.py,sha256=sKkbRI_v9XkQM0DriacKLINqKKDqYGFhvMCAS9tZbTA,2203
|
|
22
22
|
egglog/examples/schedule_demo.py,sha256=iJtIbcLaZ7zK8UalY0z7KAKMqYjQx0MKTsNF24lKtik,652
|
|
23
23
|
egglog/examples/__init__.py,sha256=KuhaJFOyz_rpUvEqZubsgLnv6rhQNE_AVFXA6bUnpdY,34
|
|
24
|
-
egglog/exp/array_api.py,sha256=
|
|
25
|
-
egglog/exp/array_api_jit.py,sha256=
|
|
24
|
+
egglog/exp/array_api.py,sha256=Y-mG8tC8ihFkeIbNDsDyZl7-7zP6YFxZGS4rfsWOG-8,65661
|
|
25
|
+
egglog/exp/array_api_jit.py,sha256=c0lN-411M9LmEz8uM0ELXTtfUPMJDZFmy81-x3ck6PY,1591
|
|
26
26
|
egglog/exp/array_api_loopnest.py,sha256=1KZ8kWtv00TVCbzqzWG92vz8Y1pBlPOLRIrJfxyDs9s,2518
|
|
27
27
|
egglog/exp/array_api_numba.py,sha256=g0Yb7auB2MtJHgSzn7iVnsxn6ffL4FHggKvbEuuDwZs,2765
|
|
28
28
|
egglog/exp/array_api_program_gen.py,sha256=qXaaW_Koz4P-KmQcyRHY1NDLCJL3O7NnE6iq2NX4yZU,22252
|
|
@@ -39,6 +39,6 @@ egglog/type_constraint_solver.py,sha256=kSJf8S1CxkicpPy_B2cNP9XiTf-d4TQNd7_IwfG-
|
|
|
39
39
|
egglog/visualizer.css,sha256=DDmxZfLmFygwVkQwel0gGUiM5BZqgunawWeSfaRYBr8,259399
|
|
40
40
|
egglog/visualizer.js,sha256=kDkYzE5cLLhmpNaWqSfQWzD0SlSsNe9KzT_GpNqhWgo,3789092
|
|
41
41
|
egglog/visualizer_widget.py,sha256=TGkf61P_f4EvCFcyqwQC_8XYUkujFhcQnEG9vHWLPC4,1358
|
|
42
|
-
egglog/__init__.py,sha256=
|
|
43
|
-
egglog/bindings.cp312-win_amd64.pyd,sha256=
|
|
44
|
-
egglog-
|
|
42
|
+
egglog/__init__.py,sha256=TZ3IMVKoz5Qr97VFe-YT7oyT2NACikTis21PpVlSCmg,326
|
|
43
|
+
egglog/bindings.cp312-win_amd64.pyd,sha256=dvGDJB1fnJuq2jbfACQTUqn6dEidzGrcJ36li09GvLk,4540416
|
|
44
|
+
egglog-10.0.0.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|