egglog 9.0.0__cp313-cp313t-manylinux_2_17_ppc64.manylinux2014_ppc64.whl → 10.0.0__cp313-cp313t-manylinux_2_17_ppc64.manylinux2014_ppc64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of egglog might be problematic. Click here for more details.
- egglog/__init__.py +1 -0
- egglog/bindings.cpython-313t-powerpc64-linux-gnu.so +0 -0
- egglog/builtins.py +75 -37
- egglog/conversion.py +6 -0
- egglog/egraph.py +7 -60
- egglog/examples/bignum.py +3 -2
- egglog/examples/eqsat_basic.py +7 -8
- egglog/examples/multiset.py +20 -21
- egglog/exp/array_api.py +47 -7
- egglog/exp/array_api_jit.py +7 -6
- {egglog-9.0.0.dist-info → egglog-10.0.0.dist-info}/METADATA +1 -1
- {egglog-9.0.0.dist-info → egglog-10.0.0.dist-info}/RECORD +14 -14
- {egglog-9.0.0.dist-info → egglog-10.0.0.dist-info}/WHEEL +0 -0
- {egglog-9.0.0.dist-info → egglog-10.0.0.dist-info}/licenses/LICENSE +0 -0
egglog/__init__.py
CHANGED
|
@@ -3,6 +3,7 @@ Package for creating e-graphs in Python.
|
|
|
3
3
|
"""
|
|
4
4
|
|
|
5
5
|
from . import config, ipython_magic # noqa: F401
|
|
6
|
+
from .bindings import EggSmolError # noqa: F401
|
|
6
7
|
from .builtins import * # noqa: UP029
|
|
7
8
|
from .conversion import ConvertError, convert, converter, get_type_args # noqa: F401
|
|
8
9
|
from .egraph import *
|
|
Binary file
|
egglog/builtins.py
CHANGED
|
@@ -12,11 +12,9 @@ from typing import TYPE_CHECKING, Generic, Protocol, TypeAlias, TypeVar, Union,
|
|
|
12
12
|
|
|
13
13
|
from typing_extensions import TypeVarTuple, Unpack
|
|
14
14
|
|
|
15
|
-
from . import bindings
|
|
16
15
|
from .conversion import convert, converter, get_type_args
|
|
17
16
|
from .declarations import *
|
|
18
|
-
from .egraph import BaseExpr, BuiltinExpr,
|
|
19
|
-
from .egraph_state import GLOBAL_PY_OBJECT_SORT
|
|
17
|
+
from .egraph import BaseExpr, BuiltinExpr, expr_fact, function, get_current_ruleset, method
|
|
20
18
|
from .functionalize import functionalize
|
|
21
19
|
from .runtime import RuntimeClass, RuntimeExpr, RuntimeFunction
|
|
22
20
|
from .thunk import Thunk
|
|
@@ -32,6 +30,7 @@ __all__ = [
|
|
|
32
30
|
"BigRatLike",
|
|
33
31
|
"Bool",
|
|
34
32
|
"BoolLike",
|
|
33
|
+
"BuiltinEvalError",
|
|
35
34
|
"Map",
|
|
36
35
|
"MapLike",
|
|
37
36
|
"MultiSet",
|
|
@@ -56,6 +55,17 @@ __all__ = [
|
|
|
56
55
|
]
|
|
57
56
|
|
|
58
57
|
|
|
58
|
+
class BuiltinEvalError(Exception):
|
|
59
|
+
"""
|
|
60
|
+
Raised when an builtin cannot be evaluated into a Python primitive because it is complex.
|
|
61
|
+
|
|
62
|
+
Try extracting this expression first.
|
|
63
|
+
"""
|
|
64
|
+
|
|
65
|
+
def __str__(self) -> str:
|
|
66
|
+
return f"Cannot evaluate builtin expression into a Python primitive. Try extracting this expression first: {super().__str__()}"
|
|
67
|
+
|
|
68
|
+
|
|
59
69
|
class Unit(BuiltinExpr, egg_sort="Unit"):
|
|
60
70
|
"""
|
|
61
71
|
The unit type. This is used to reprsent if a value exists in the e-graph or not.
|
|
@@ -72,8 +82,8 @@ class String(BuiltinExpr):
|
|
|
72
82
|
@method(preserve=True)
|
|
73
83
|
def eval(self) -> str:
|
|
74
84
|
value = _extract_lit(self)
|
|
75
|
-
assert isinstance(value,
|
|
76
|
-
return value
|
|
85
|
+
assert isinstance(value, str)
|
|
86
|
+
return value
|
|
77
87
|
|
|
78
88
|
def __init__(self, value: str) -> None: ...
|
|
79
89
|
|
|
@@ -97,8 +107,8 @@ class Bool(BuiltinExpr, egg_sort="bool"):
|
|
|
97
107
|
@method(preserve=True)
|
|
98
108
|
def eval(self) -> bool:
|
|
99
109
|
value = _extract_lit(self)
|
|
100
|
-
assert isinstance(value,
|
|
101
|
-
return value
|
|
110
|
+
assert isinstance(value, bool)
|
|
111
|
+
return value
|
|
102
112
|
|
|
103
113
|
@method(preserve=True)
|
|
104
114
|
def __bool__(self) -> bool:
|
|
@@ -132,8 +142,8 @@ class i64(BuiltinExpr): # noqa: N801
|
|
|
132
142
|
@method(preserve=True)
|
|
133
143
|
def eval(self) -> int:
|
|
134
144
|
value = _extract_lit(self)
|
|
135
|
-
assert isinstance(value,
|
|
136
|
-
return value
|
|
145
|
+
assert isinstance(value, int)
|
|
146
|
+
return value
|
|
137
147
|
|
|
138
148
|
@method(preserve=True)
|
|
139
149
|
def __index__(self) -> int:
|
|
@@ -195,6 +205,9 @@ class i64(BuiltinExpr): # noqa: N801
|
|
|
195
205
|
|
|
196
206
|
def __rrshift__(self, other: i64Like) -> i64: ...
|
|
197
207
|
|
|
208
|
+
@method(egg_fn="log2")
|
|
209
|
+
def log2(self) -> i64: ...
|
|
210
|
+
|
|
198
211
|
@method(egg_fn="not-i64")
|
|
199
212
|
def __invert__(self) -> i64: ...
|
|
200
213
|
|
|
@@ -248,8 +261,8 @@ class f64(BuiltinExpr): # noqa: N801
|
|
|
248
261
|
@method(preserve=True)
|
|
249
262
|
def eval(self) -> float:
|
|
250
263
|
value = _extract_lit(self)
|
|
251
|
-
assert isinstance(value,
|
|
252
|
-
return value
|
|
264
|
+
assert isinstance(value, float)
|
|
265
|
+
return value
|
|
253
266
|
|
|
254
267
|
@method(preserve=True)
|
|
255
268
|
def __float__(self) -> float:
|
|
@@ -337,9 +350,12 @@ class Map(BuiltinExpr, Generic[T, V]):
|
|
|
337
350
|
expr = cast("RuntimeExpr", self)
|
|
338
351
|
d = {}
|
|
339
352
|
while call.callable != ClassMethodRef("Map", "empty"):
|
|
340
|
-
|
|
353
|
+
msg = "Map can only be evaluated if it is empty or a series of inserts."
|
|
354
|
+
if call.callable != MethodRef("Map", "insert"):
|
|
355
|
+
raise BuiltinEvalError(msg)
|
|
341
356
|
call_typed, k_typed, v_typed = call.args
|
|
342
|
-
|
|
357
|
+
if not isinstance(call_typed.expr, CallDecl):
|
|
358
|
+
raise BuiltinEvalError(msg)
|
|
343
359
|
k = cast("T", expr.__with_expr__(k_typed))
|
|
344
360
|
v = cast("V", expr.__with_expr__(v_typed))
|
|
345
361
|
d[k] = v
|
|
@@ -401,7 +417,9 @@ class Set(BuiltinExpr, Generic[T]):
|
|
|
401
417
|
@method(preserve=True)
|
|
402
418
|
def eval(self) -> set[T]:
|
|
403
419
|
call = _extract_call(self)
|
|
404
|
-
|
|
420
|
+
if call.callable != InitRef("Set"):
|
|
421
|
+
msg = "Set can only be initialized with the Set constructor."
|
|
422
|
+
raise BuiltinEvalError(msg)
|
|
405
423
|
return {cast("T", cast("RuntimeExpr", self).__with_expr__(x)) for x in call.args}
|
|
406
424
|
|
|
407
425
|
@method(preserve=True)
|
|
@@ -463,7 +481,9 @@ class MultiSet(BuiltinExpr, Generic[T]):
|
|
|
463
481
|
@method(preserve=True)
|
|
464
482
|
def eval(self) -> list[T]:
|
|
465
483
|
call = _extract_call(self)
|
|
466
|
-
|
|
484
|
+
if call.callable != InitRef("MultiSet"):
|
|
485
|
+
msg = "MultiSet can only be initialized with the MultiSet constructor."
|
|
486
|
+
raise BuiltinEvalError(msg)
|
|
467
487
|
return [cast("T", cast("RuntimeExpr", self).__with_expr__(x)) for x in call.args]
|
|
468
488
|
|
|
469
489
|
@method(preserve=True)
|
|
@@ -510,11 +530,15 @@ class Rational(BuiltinExpr):
|
|
|
510
530
|
@method(preserve=True)
|
|
511
531
|
def eval(self) -> Fraction:
|
|
512
532
|
call = _extract_call(self)
|
|
513
|
-
|
|
533
|
+
if call.callable != InitRef("Rational"):
|
|
534
|
+
msg = "Rational can only be initialized with the Rational constructor."
|
|
535
|
+
raise BuiltinEvalError(msg)
|
|
514
536
|
|
|
515
537
|
def _to_int(e: TypedExprDecl) -> int:
|
|
516
538
|
expr = e.expr
|
|
517
|
-
|
|
539
|
+
if not isinstance(expr, LitDecl):
|
|
540
|
+
msg = "Rational can only be initialized with literals"
|
|
541
|
+
raise BuiltinEvalError(msg)
|
|
518
542
|
assert isinstance(expr.value, int)
|
|
519
543
|
return expr.value
|
|
520
544
|
|
|
@@ -593,9 +617,13 @@ class BigInt(BuiltinExpr):
|
|
|
593
617
|
@method(preserve=True)
|
|
594
618
|
def eval(self) -> int:
|
|
595
619
|
call = _extract_call(self)
|
|
596
|
-
|
|
620
|
+
if call.callable != ClassMethodRef("BigInt", "from_string"):
|
|
621
|
+
msg = "BigInt can only be initialized with the BigInt constructor."
|
|
622
|
+
raise BuiltinEvalError(msg)
|
|
597
623
|
(s,) = call.args
|
|
598
|
-
|
|
624
|
+
if not isinstance(s.expr, LitDecl):
|
|
625
|
+
msg = "BigInt can only be initialized with literals"
|
|
626
|
+
raise BuiltinEvalError(msg)
|
|
599
627
|
assert isinstance(s.expr.value, str)
|
|
600
628
|
return int(s.expr.value)
|
|
601
629
|
|
|
@@ -714,14 +742,19 @@ class BigRat(BuiltinExpr):
|
|
|
714
742
|
@method(preserve=True)
|
|
715
743
|
def eval(self) -> Fraction:
|
|
716
744
|
call = _extract_call(self)
|
|
717
|
-
|
|
745
|
+
if call.callable != InitRef("BigRat"):
|
|
746
|
+
msg = "BigRat can only be initialized with the BigRat constructor."
|
|
747
|
+
raise BuiltinEvalError(msg)
|
|
718
748
|
|
|
719
749
|
def _to_fraction(e: TypedExprDecl) -> Fraction:
|
|
720
750
|
expr = e.expr
|
|
721
|
-
|
|
722
|
-
|
|
751
|
+
if not isinstance(expr, CallDecl) or expr.callable != ClassMethodRef("BigInt", "from_string"):
|
|
752
|
+
msg = "BigRat can only be initialized BigInt strings"
|
|
753
|
+
raise BuiltinEvalError(msg)
|
|
723
754
|
(s,) = expr.args
|
|
724
|
-
|
|
755
|
+
if not isinstance(s.expr, LitDecl):
|
|
756
|
+
msg = "BigInt can only be initialized with literals"
|
|
757
|
+
raise BuiltinEvalError(msg)
|
|
725
758
|
assert isinstance(s.expr.value, str)
|
|
726
759
|
return Fraction(s.expr.value)
|
|
727
760
|
|
|
@@ -818,7 +851,10 @@ class Vec(BuiltinExpr, Generic[T]):
|
|
|
818
851
|
call = _extract_call(self)
|
|
819
852
|
if call.callable == ClassMethodRef("Vec", "empty"):
|
|
820
853
|
return ()
|
|
821
|
-
|
|
854
|
+
|
|
855
|
+
if call.callable != InitRef("Vec"):
|
|
856
|
+
msg = "Vec can only be initialized with the Vec constructor."
|
|
857
|
+
raise BuiltinEvalError(msg)
|
|
822
858
|
return tuple(cast("T", cast("RuntimeExpr", self).__with_expr__(x)) for x in call.args)
|
|
823
859
|
|
|
824
860
|
@method(preserve=True)
|
|
@@ -886,10 +922,11 @@ VecLike: TypeAlias = Vec[T] | tuple[TO, ...] | list[TO]
|
|
|
886
922
|
class PyObject(BuiltinExpr):
|
|
887
923
|
@method(preserve=True)
|
|
888
924
|
def eval(self) -> object:
|
|
889
|
-
|
|
890
|
-
|
|
891
|
-
|
|
892
|
-
|
|
925
|
+
expr = cast("RuntimeExpr", self).__egg_typed_expr__.expr
|
|
926
|
+
if not isinstance(expr, PyObjectDecl):
|
|
927
|
+
msg = "PyObject can only be evaluated if it is a PyObject literal"
|
|
928
|
+
raise BuiltinEvalError(msg)
|
|
929
|
+
return expr.value
|
|
893
930
|
|
|
894
931
|
def __init__(self, value: object) -> None: ...
|
|
895
932
|
|
|
@@ -1024,22 +1061,23 @@ def value_to_annotation(a: object) -> type | None:
|
|
|
1024
1061
|
converter(FunctionType, UnstableFn, _convert_function)
|
|
1025
1062
|
|
|
1026
1063
|
|
|
1027
|
-
def _extract_lit(e: BaseExpr) ->
|
|
1064
|
+
def _extract_lit(e: BaseExpr) -> LitType:
|
|
1028
1065
|
"""
|
|
1029
1066
|
Special case extracting literals to make this faster by using termdag directly.
|
|
1030
1067
|
"""
|
|
1031
|
-
|
|
1032
|
-
|
|
1033
|
-
|
|
1034
|
-
|
|
1035
|
-
return
|
|
1068
|
+
expr = cast("RuntimeExpr", e).__egg_typed_expr__.expr
|
|
1069
|
+
if not isinstance(expr, LitDecl):
|
|
1070
|
+
msg = "Expected a literal"
|
|
1071
|
+
raise BuiltinEvalError(msg)
|
|
1072
|
+
return expr.value
|
|
1036
1073
|
|
|
1037
1074
|
|
|
1038
1075
|
def _extract_call(e: BaseExpr) -> CallDecl:
|
|
1039
1076
|
"""
|
|
1040
1077
|
Extracts the call form of an expression
|
|
1041
1078
|
"""
|
|
1042
|
-
|
|
1043
|
-
|
|
1044
|
-
|
|
1079
|
+
expr = cast("RuntimeExpr", e).__egg_typed_expr__.expr
|
|
1080
|
+
if not isinstance(expr, CallDecl):
|
|
1081
|
+
msg = "Expected a call expression"
|
|
1082
|
+
raise BuiltinEvalError(msg)
|
|
1045
1083
|
return expr
|
egglog/conversion.py
CHANGED
|
@@ -149,6 +149,12 @@ def min_convertable_tp(a: object, b: object, name: str) -> JustTypeRef:
|
|
|
149
149
|
decls = _retrieve_conversion_decls()
|
|
150
150
|
a_tp = _get_tp(a)
|
|
151
151
|
b_tp = _get_tp(b)
|
|
152
|
+
# Make sure at least one of the types has the method, to avoid issue with == upcasting improperly
|
|
153
|
+
if not (
|
|
154
|
+
(isinstance(a_tp, JustTypeRef) and decls.has_method(a_tp.name, name))
|
|
155
|
+
or (isinstance(b_tp, JustTypeRef) and decls.has_method(b_tp.name, name))
|
|
156
|
+
):
|
|
157
|
+
raise ConvertError(f"Neither {a_tp} nor {b_tp} has method {name}")
|
|
152
158
|
a_converts_to = {
|
|
153
159
|
to: c for ((from_, to), (c, _)) in CONVERSIONS.items() if from_ == a_tp and decls.has_method(to.name, name)
|
|
154
160
|
}
|
egglog/egraph.py
CHANGED
|
@@ -4,7 +4,7 @@ import contextlib
|
|
|
4
4
|
import inspect
|
|
5
5
|
import pathlib
|
|
6
6
|
import tempfile
|
|
7
|
-
from collections.abc import Callable, Generator, Iterable
|
|
7
|
+
from collections.abc import Callable, Generator, Iterable
|
|
8
8
|
from contextvars import ContextVar
|
|
9
9
|
from dataclasses import InitVar, dataclass, field
|
|
10
10
|
from functools import partial
|
|
@@ -17,7 +17,6 @@ from typing import (
|
|
|
17
17
|
Generic,
|
|
18
18
|
Literal,
|
|
19
19
|
Never,
|
|
20
|
-
Protocol,
|
|
21
20
|
TypeAlias,
|
|
22
21
|
TypedDict,
|
|
23
22
|
TypeVar,
|
|
@@ -85,7 +84,6 @@ __all__ = [
|
|
|
85
84
|
"set_",
|
|
86
85
|
"simplify",
|
|
87
86
|
"subsume",
|
|
88
|
-
"try_evaling",
|
|
89
87
|
"union",
|
|
90
88
|
"unstable_combine_rulesets",
|
|
91
89
|
"var",
|
|
@@ -847,11 +845,10 @@ class EGraph:
|
|
|
847
845
|
Can run actions, check facts, run schedules, or extract minimal cost expressions.
|
|
848
846
|
"""
|
|
849
847
|
|
|
850
|
-
current: ClassVar[EGraph | None] = None
|
|
851
848
|
seminaive: InitVar[bool] = True
|
|
852
849
|
save_egglog_string: InitVar[bool] = False
|
|
853
850
|
|
|
854
|
-
_state: EGraphState = field(init=False)
|
|
851
|
+
_state: EGraphState = field(init=False, repr=False)
|
|
855
852
|
# For pushing/popping with egglog
|
|
856
853
|
_state_stack: list[EGraphState] = field(default_factory=list, repr=False)
|
|
857
854
|
# For storing the global "current" egraph
|
|
@@ -1200,16 +1197,6 @@ class EGraph:
|
|
|
1200
1197
|
if visualize:
|
|
1201
1198
|
VisualizerWidget(egraphs=egraphs).display_or_open()
|
|
1202
1199
|
|
|
1203
|
-
@contextlib.contextmanager
|
|
1204
|
-
def set_current(self) -> Iterator[None]:
|
|
1205
|
-
"""
|
|
1206
|
-
Context manager that will set the current egraph. It will be set back after.
|
|
1207
|
-
"""
|
|
1208
|
-
prev_current = EGraph.current
|
|
1209
|
-
EGraph.current = self
|
|
1210
|
-
yield
|
|
1211
|
-
EGraph.current = prev_current
|
|
1212
|
-
|
|
1213
1200
|
@property
|
|
1214
1201
|
def _egraph(self) -> bindings.EGraph:
|
|
1215
1202
|
return self._state.egraph
|
|
@@ -1303,8 +1290,6 @@ class Schedule(DelayedDeclerations):
|
|
|
1303
1290
|
A composition of some rulesets, either composing them sequentially, running them repeatedly, running them till saturation, or running until some facts are met
|
|
1304
1291
|
"""
|
|
1305
1292
|
|
|
1306
|
-
current: ClassVar[Schedule | None] = None
|
|
1307
|
-
|
|
1308
1293
|
# Defer declerations so that we can have rule generators that used not yet defined yet
|
|
1309
1294
|
schedule: ScheduleDecl
|
|
1310
1295
|
|
|
@@ -1332,16 +1317,6 @@ class Schedule(DelayedDeclerations):
|
|
|
1332
1317
|
"""
|
|
1333
1318
|
return Schedule(Thunk.fn(Declarations.create, self, other), SequenceDecl((self.schedule, other.schedule)))
|
|
1334
1319
|
|
|
1335
|
-
@contextlib.contextmanager
|
|
1336
|
-
def set_current(self) -> Iterator[None]:
|
|
1337
|
-
"""
|
|
1338
|
-
Context manager that will set the current schedule. It will be set back after
|
|
1339
|
-
"""
|
|
1340
|
-
prev_current = Schedule.current
|
|
1341
|
-
Schedule.current = self
|
|
1342
|
-
yield
|
|
1343
|
-
Schedule.current = prev_current
|
|
1344
|
-
|
|
1345
1320
|
|
|
1346
1321
|
@dataclass
|
|
1347
1322
|
class Ruleset(Schedule):
|
|
@@ -1488,9 +1463,12 @@ class Fact:
|
|
|
1488
1463
|
|
|
1489
1464
|
def __bool__(self) -> bool:
|
|
1490
1465
|
"""
|
|
1491
|
-
Returns True if the two sides of an equality are equal
|
|
1466
|
+
Returns True if the two sides of an equality are structurally equal.
|
|
1492
1467
|
"""
|
|
1493
|
-
|
|
1468
|
+
if not isinstance(self.fact, EqDecl):
|
|
1469
|
+
msg = "Can only check equality facts"
|
|
1470
|
+
raise TypeError(msg)
|
|
1471
|
+
return self.fact.left == self.fact.right
|
|
1494
1472
|
|
|
1495
1473
|
|
|
1496
1474
|
@dataclass
|
|
@@ -1876,34 +1854,3 @@ def set_current_ruleset(r: Ruleset | None) -> Generator[None, None, None]:
|
|
|
1876
1854
|
yield
|
|
1877
1855
|
finally:
|
|
1878
1856
|
_CURRENT_RULESET.reset(token)
|
|
1879
|
-
|
|
1880
|
-
|
|
1881
|
-
T_co = TypeVar("T_co", covariant=True)
|
|
1882
|
-
|
|
1883
|
-
|
|
1884
|
-
class _EvalsTo(Protocol, Generic[T_co]):
|
|
1885
|
-
def eval(self) -> T_co: ...
|
|
1886
|
-
|
|
1887
|
-
|
|
1888
|
-
def try_evaling(schedule: Schedule, expr: Expr, prim_expr: _EvalsTo[T]) -> T:
|
|
1889
|
-
"""
|
|
1890
|
-
Try evaling the expression that will result in a primitive expression being fill.
|
|
1891
|
-
if it fails, display the egraph and raise an error.
|
|
1892
|
-
"""
|
|
1893
|
-
egraph = EGraph.current or EGraph()
|
|
1894
|
-
with egraph.set_current():
|
|
1895
|
-
try:
|
|
1896
|
-
return prim_expr.eval()
|
|
1897
|
-
except BaseException: # noqa: S110
|
|
1898
|
-
pass
|
|
1899
|
-
# If this primitive doesn't exist in the egraph, we need to try to create it by
|
|
1900
|
-
# registering the expression and running the schedule
|
|
1901
|
-
egraph.register(expr)
|
|
1902
|
-
egraph.run(Schedule.current or schedule)
|
|
1903
|
-
try:
|
|
1904
|
-
with egraph.set_current():
|
|
1905
|
-
return prim_expr.eval()
|
|
1906
|
-
except BaseException as e:
|
|
1907
|
-
# egraph.display(n_inline_leaves=1, split_primitive_outputs=True)
|
|
1908
|
-
e.add_note(f"Cannot evaluate {egraph.extract(expr)}")
|
|
1909
|
-
raise
|
egglog/examples/bignum.py
CHANGED
|
@@ -12,14 +12,15 @@ x = BigInt(-1234)
|
|
|
12
12
|
y = BigInt.from_string("2")
|
|
13
13
|
z = BigRat(x, y)
|
|
14
14
|
|
|
15
|
-
|
|
15
|
+
egraph = EGraph()
|
|
16
|
+
|
|
17
|
+
assert egraph.extract(z.numer.to_string()).eval() == "-617"
|
|
16
18
|
|
|
17
19
|
|
|
18
20
|
@function
|
|
19
21
|
def bignums(x: BigInt, y: BigInt) -> BigRat: ...
|
|
20
22
|
|
|
21
23
|
|
|
22
|
-
egraph = EGraph()
|
|
23
24
|
egraph.register(set_(bignums(x, y)).to(z))
|
|
24
25
|
|
|
25
26
|
c = var("c", BigRat)
|
egglog/examples/eqsat_basic.py
CHANGED
|
@@ -28,18 +28,17 @@ expr2 = Num(6) + Num(2) * Num.var("x")
|
|
|
28
28
|
a, b, c = vars_("a b c", Num)
|
|
29
29
|
i, j = vars_("i j", i64)
|
|
30
30
|
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
|
|
31
|
+
egraph = EGraph()
|
|
32
|
+
egraph.register(expr1, expr2)
|
|
33
|
+
|
|
34
|
+
egraph.run(
|
|
35
35
|
ruleset(
|
|
36
36
|
rewrite(a + b).to(b + a),
|
|
37
37
|
rewrite(a * (b + c)).to((a * b) + (a * c)),
|
|
38
38
|
rewrite(Num(i) + Num(j)).to(Num(i + j)),
|
|
39
39
|
rewrite(Num(i) * Num(j)).to(Num(i * j)),
|
|
40
40
|
)
|
|
41
|
-
* 10
|
|
42
|
-
# On these two initial expressions
|
|
43
|
-
expr1,
|
|
44
|
-
expr2,
|
|
41
|
+
* 10
|
|
45
42
|
)
|
|
43
|
+
|
|
44
|
+
egraph.check(expr1 == expr2)
|
egglog/examples/multiset.py
CHANGED
|
@@ -29,33 +29,32 @@ egraph = EGraph()
|
|
|
29
29
|
xs = MultiSet(Math(1), Math(2), Math(3))
|
|
30
30
|
egraph.register(xs)
|
|
31
31
|
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
assert xs != MultiSet(Math(1), Math(1), Math(2), Math(3))
|
|
32
|
+
egraph.check(xs == MultiSet(Math(1), Math(3), Math(2)))
|
|
33
|
+
egraph.check_fail(xs == MultiSet(Math(1), Math(1), Math(2), Math(3)))
|
|
35
34
|
|
|
36
|
-
|
|
35
|
+
assert Counter(egraph.extract(xs).eval()) == Counter({Math(1): 1, Math(2): 1, Math(3): 1})
|
|
37
36
|
|
|
38
|
-
inserted = MultiSet(Math(1), Math(2), Math(3), Math(4))
|
|
39
|
-
egraph.register(inserted)
|
|
40
|
-
assert xs.insert(Math(4)) == inserted
|
|
41
37
|
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
38
|
+
inserted = MultiSet(Math(1), Math(2), Math(3), Math(4))
|
|
39
|
+
egraph.register(inserted)
|
|
40
|
+
egraph.check(xs.insert(Math(4)) == inserted)
|
|
41
|
+
egraph.check(xs.contains(Math(1)))
|
|
42
|
+
egraph.check(xs.not_contains(Math(4)))
|
|
43
|
+
assert Math(1) in xs
|
|
44
|
+
assert Math(4) not in xs
|
|
46
45
|
|
|
47
|
-
|
|
46
|
+
egraph.check(xs.remove(Math(1)) == MultiSet(Math(2), Math(3)))
|
|
48
47
|
|
|
49
|
-
|
|
50
|
-
|
|
48
|
+
assert egraph.extract(xs.length()).eval() == 3
|
|
49
|
+
assert len(xs) == 3
|
|
51
50
|
|
|
52
|
-
|
|
51
|
+
egraph.check(MultiSet(Math(1), Math(1)).length() == i64(2))
|
|
53
52
|
|
|
54
|
-
|
|
53
|
+
egraph.check(MultiSet(Math(1)).pick() == Math(1))
|
|
55
54
|
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
55
|
+
mapped = xs.map(square)
|
|
56
|
+
egraph.register(mapped)
|
|
57
|
+
egraph.run(math_ruleset)
|
|
58
|
+
egraph.check(mapped == MultiSet(Math(1), Math(4), Math(9)))
|
|
60
59
|
|
|
61
|
-
|
|
60
|
+
egraph.check(xs + xs == MultiSet(Math(1), Math(2), Math(3), Math(1), Math(2), Math(3)))
|
egglog/exp/array_api.py
CHANGED
|
@@ -54,6 +54,7 @@ A: So that when converting from functional to lists we can use the same index fu
|
|
|
54
54
|
|
|
55
55
|
from __future__ import annotations
|
|
56
56
|
|
|
57
|
+
import contextlib
|
|
57
58
|
import itertools
|
|
58
59
|
import math
|
|
59
60
|
import numbers
|
|
@@ -61,7 +62,7 @@ import os
|
|
|
61
62
|
import sys
|
|
62
63
|
from copy import copy
|
|
63
64
|
from types import EllipsisType
|
|
64
|
-
from typing import TYPE_CHECKING, ClassVar, TypeAlias, cast
|
|
65
|
+
from typing import TYPE_CHECKING, Any, ClassVar, TypeAlias, cast
|
|
65
66
|
|
|
66
67
|
import numpy as np
|
|
67
68
|
|
|
@@ -93,7 +94,7 @@ class Boolean(Expr, ruleset=array_api_ruleset):
|
|
|
93
94
|
|
|
94
95
|
@method(preserve=True)
|
|
95
96
|
def eval(self) -> bool:
|
|
96
|
-
return try_evaling(array_api_schedule, self, self.to_bool)
|
|
97
|
+
return try_evaling(_get_current_egraph(), array_api_schedule, self, self.to_bool)
|
|
97
98
|
|
|
98
99
|
@property
|
|
99
100
|
def to_bool(self) -> Bool: ...
|
|
@@ -218,7 +219,7 @@ class Int(Expr, ruleset=array_api_ruleset):
|
|
|
218
219
|
|
|
219
220
|
@method(preserve=True)
|
|
220
221
|
def eval(self) -> int:
|
|
221
|
-
return try_evaling(array_api_schedule, self, self.to_i64)
|
|
222
|
+
return try_evaling(_get_current_egraph(), array_api_schedule, self, self.to_i64)
|
|
222
223
|
|
|
223
224
|
@method(preserve=True)
|
|
224
225
|
def __index__(self) -> int:
|
|
@@ -330,7 +331,7 @@ class Float(Expr, ruleset=array_api_ruleset):
|
|
|
330
331
|
|
|
331
332
|
@method(preserve=True)
|
|
332
333
|
def eval(self) -> float:
|
|
333
|
-
return try_evaling(array_api_schedule, self, self.to_f64)
|
|
334
|
+
return try_evaling(_get_current_egraph(), array_api_schedule, self, self.to_f64)
|
|
334
335
|
|
|
335
336
|
def abs(self) -> Float: ...
|
|
336
337
|
|
|
@@ -438,7 +439,7 @@ class TupleInt(Expr, ruleset=array_api_ruleset):
|
|
|
438
439
|
|
|
439
440
|
@method(preserve=True)
|
|
440
441
|
def eval(self) -> tuple[Int, ...]:
|
|
441
|
-
return try_evaling(array_api_schedule, self, self.to_vec)
|
|
442
|
+
return try_evaling(_get_current_egraph(), array_api_schedule, self, self.to_vec)
|
|
442
443
|
|
|
443
444
|
def foldl(self, f: Callable[[Int, Int], Int], init: Int) -> Int: ...
|
|
444
445
|
def foldl_boolean(self, f: Callable[[Boolean, Int], Boolean], init: Boolean) -> Boolean: ...
|
|
@@ -586,7 +587,7 @@ class TupleTupleInt(Expr, ruleset=array_api_ruleset):
|
|
|
586
587
|
|
|
587
588
|
@method(preserve=True)
|
|
588
589
|
def eval(self) -> tuple[TupleInt, ...]:
|
|
589
|
-
return try_evaling(array_api_schedule, self, self.to_vec)
|
|
590
|
+
return try_evaling(_get_current_egraph(), array_api_schedule, self, self.to_vec)
|
|
590
591
|
|
|
591
592
|
def drop(self, n: Int) -> TupleTupleInt:
|
|
592
593
|
return TupleTupleInt(self.length() - n, lambda i: self[i + n])
|
|
@@ -1282,7 +1283,7 @@ class TupleNDArray(Expr, ruleset=array_api_ruleset):
|
|
|
1282
1283
|
|
|
1283
1284
|
@method(preserve=True)
|
|
1284
1285
|
def eval(self) -> tuple[NDArray, ...]:
|
|
1285
|
-
return try_evaling(array_api_schedule, self, self.to_vec)
|
|
1286
|
+
return try_evaling(_get_current_egraph(), array_api_schedule, self, self.to_vec)
|
|
1286
1287
|
|
|
1287
1288
|
|
|
1288
1289
|
converter(Vec[NDArray], TupleNDArray, lambda x: TupleNDArray.from_vec(x))
|
|
@@ -1941,3 +1942,42 @@ def array_api_vec_to_cons_ruleset(
|
|
|
1941
1942
|
|
|
1942
1943
|
array_api_combined_ruleset = array_api_ruleset | array_api_vec_to_cons_ruleset
|
|
1943
1944
|
array_api_schedule = array_api_combined_ruleset.saturate()
|
|
1945
|
+
|
|
1946
|
+
_CURRENT_EGRAPH: None | EGraph = None
|
|
1947
|
+
|
|
1948
|
+
|
|
1949
|
+
@contextlib.contextmanager
|
|
1950
|
+
def set_array_api_egraph(egraph: EGraph) -> Iterator[None]:
|
|
1951
|
+
"""
|
|
1952
|
+
Context manager that will set the current egraph. It will be set back after.
|
|
1953
|
+
"""
|
|
1954
|
+
global _CURRENT_EGRAPH
|
|
1955
|
+
assert _CURRENT_EGRAPH is None
|
|
1956
|
+
_CURRENT_EGRAPH = egraph
|
|
1957
|
+
yield
|
|
1958
|
+
_CURRENT_EGRAPH = None
|
|
1959
|
+
|
|
1960
|
+
|
|
1961
|
+
def _get_current_egraph() -> EGraph:
|
|
1962
|
+
return _CURRENT_EGRAPH or EGraph()
|
|
1963
|
+
|
|
1964
|
+
|
|
1965
|
+
def try_evaling(egraph: EGraph, schedule: Schedule, expr: Expr, prim_expr: BuiltinExpr) -> Any:
|
|
1966
|
+
"""
|
|
1967
|
+
Try evaling the expression that will result in a primitive expression being fill.
|
|
1968
|
+
if it fails, display the egraph and raise an error.
|
|
1969
|
+
"""
|
|
1970
|
+
try:
|
|
1971
|
+
extracted = egraph.extract(prim_expr)
|
|
1972
|
+
except EggSmolError:
|
|
1973
|
+
# If this primitive doesn't exist in the egraph, we need to try to create it by
|
|
1974
|
+
# registering the expression and running the schedule
|
|
1975
|
+
egraph.register(expr)
|
|
1976
|
+
egraph.run(schedule)
|
|
1977
|
+
try:
|
|
1978
|
+
extracted = egraph.extract(prim_expr)
|
|
1979
|
+
except BaseException as e:
|
|
1980
|
+
# egraph.display(n_inline_leaves=1, split_primitive_outputs=True)
|
|
1981
|
+
e.add_note(f"Cannot evaluate {egraph.extract(expr)}")
|
|
1982
|
+
raise
|
|
1983
|
+
return extracted.eval() # type: ignore[attr-defined]
|
egglog/exp/array_api_jit.py
CHANGED
|
@@ -4,8 +4,8 @@ from typing import TypeVar, cast
|
|
|
4
4
|
|
|
5
5
|
import numpy as np
|
|
6
6
|
|
|
7
|
-
from egglog import EGraph
|
|
8
|
-
from egglog.exp.array_api import NDArray
|
|
7
|
+
from egglog import EGraph
|
|
8
|
+
from egglog.exp.array_api import NDArray, set_array_api_egraph, try_evaling
|
|
9
9
|
from egglog.exp.array_api_numba import array_api_numba_schedule
|
|
10
10
|
from egglog.exp.array_api_program_gen import EvalProgram, array_api_program_gen_schedule, ndarray_function_two_program
|
|
11
11
|
|
|
@@ -20,8 +20,7 @@ def jit(fn: X) -> X:
|
|
|
20
20
|
"""
|
|
21
21
|
egraph, res, res_optimized, program = function_to_program(fn, save_egglog_string=False)
|
|
22
22
|
fn_program = EvalProgram(program, {"np": np})
|
|
23
|
-
|
|
24
|
-
fn = cast("X", try_evaling(array_api_program_gen_schedule, fn_program, fn_program.as_py_object))
|
|
23
|
+
fn = cast("X", try_evaling(egraph, array_api_program_gen_schedule, fn_program, fn_program.as_py_object))
|
|
25
24
|
fn.initial_expr = res # type: ignore[attr-defined]
|
|
26
25
|
fn.expr = res_optimized # type: ignore[attr-defined]
|
|
27
26
|
return fn
|
|
@@ -32,9 +31,11 @@ def function_to_program(fn: Callable, save_egglog_string: bool) -> tuple[EGraph,
|
|
|
32
31
|
arg1, arg2 = sig.parameters.keys()
|
|
33
32
|
egraph = EGraph(save_egglog_string=save_egglog_string)
|
|
34
33
|
with egraph:
|
|
35
|
-
with egraph
|
|
34
|
+
with set_array_api_egraph(egraph):
|
|
36
35
|
res = fn(NDArray.var(arg1), NDArray.var(arg2))
|
|
37
|
-
|
|
36
|
+
egraph.register(res)
|
|
37
|
+
egraph.run(array_api_numba_schedule)
|
|
38
|
+
res_optimized = egraph.extract(res)
|
|
38
39
|
|
|
39
40
|
return (
|
|
40
41
|
egraph,
|
|
@@ -1,16 +1,16 @@
|
|
|
1
|
-
egglog-
|
|
2
|
-
egglog-
|
|
3
|
-
egglog-
|
|
1
|
+
egglog-10.0.0.dist-info/METADATA,sha256=doz_j7GUVTFvkbavd_wTPB4etBIjYF4IpJxs7e4hlVA,4009
|
|
2
|
+
egglog-10.0.0.dist-info/WHEEL,sha256=nkVomvFuhEEAN02m8Bbi0AHf2v0ovNpK1mLZPbatVh0,128
|
|
3
|
+
egglog-10.0.0.dist-info/licenses/LICENSE,sha256=w7VlVv5O_FPZRo8Z-4Zb_q7D5ac3YDs8JUkMZ4Gq9CE,1070
|
|
4
4
|
egglog/visualizer_widget.py,sha256=LtVfzOtv2WeKtNuILQQ_9SOHWvRr8YdBYQDKQSgry_s,1319
|
|
5
5
|
egglog/egraph_state.py,sha256=KxP4aXfh21vVqx1HbQMtb2PRxJg4Gzp1Po6Jp5Zppk0,28086
|
|
6
|
-
egglog/builtins.py,sha256=
|
|
6
|
+
egglog/builtins.py,sha256=6kVNJHRsJJDy_4_MWKtMztbpfWKyv15NnHb3m-9aJlY,29609
|
|
7
7
|
egglog/exp/array_api_numba.py,sha256=X3H1TnCjPL92uVm6OvcWMJ11IeorAE58zWiOX6huPv4,2696
|
|
8
8
|
egglog/exp/program_gen.py,sha256=9q8-ihkXLgeRU_BrINF3t3gThdS74KCzTo4UiRIR6Dk,13027
|
|
9
9
|
egglog/exp/array_api_program_gen.py,sha256=0FpWZZJ3VqIod9vIYfQYUIiyswfj4h4K_4s6GZb8SJY,21742
|
|
10
10
|
egglog/exp/siu_examples.py,sha256=yZ-sgH2Y12iTdwBUumP7D2OtCGL83M6pPW7PMobVFXc,719
|
|
11
11
|
egglog/exp/__init__.py,sha256=nPtzrH1bz1LVZhZCuS0S9Qild8m5gEikjOVqWAFIa88,49
|
|
12
|
-
egglog/exp/array_api.py,sha256=
|
|
13
|
-
egglog/exp/array_api_jit.py,sha256=
|
|
12
|
+
egglog/exp/array_api.py,sha256=jcVAwJN-b_n7bDnDIz2cadQ67Okk5CDGPOiAgpTJgfc,63678
|
|
13
|
+
egglog/exp/array_api_jit.py,sha256=90RmyivRoCKzVtiYWnTBkK2q2FDkD2p1iq7vzbh68b4,1546
|
|
14
14
|
egglog/exp/array_api_loopnest.py,sha256=-kbyorlGxvlaNsLx1nmLfEZHQM7VMEBwSKtV0l-bs0g,2444
|
|
15
15
|
egglog/type_constraint_solver.py,sha256=_y52JJoxWJDzokvS07-QHJ4sQpvWE86_Dk43RAUqhAk,4516
|
|
16
16
|
egglog/runtime.py,sha256=vyzN52njMucBaf8sRl4A5FVBrqDMfNaV8PlG8lzNxhc,25741
|
|
@@ -18,13 +18,13 @@ egglog/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
|
18
18
|
egglog/bindings.pyi,sha256=b0osaWlb6Yqh7Cw7I_DnRjU-6NTeXq1XOolvOvYMMcI,13830
|
|
19
19
|
egglog/config.py,sha256=yM3FIcVCKnhWZmHD0pxkzx7ah7t5PxZx3WUqKtA9tjU,168
|
|
20
20
|
egglog/declarations.py,sha256=UzpOjoZz21T9QfodmS3U00vx5VcTcryXmAw11CALix8,22193
|
|
21
|
-
egglog/__init__.py,sha256=
|
|
21
|
+
egglog/__init__.py,sha256=5n2nas3VbxxgsvYeT8CutVNHbC2jggXQtDQ3RzX3vCM,315
|
|
22
22
|
egglog/thunk.py,sha256=BwdqvNhY8g3KdSE5DtT1yIoxCF3mNGcqoODy8nXm9z8,2168
|
|
23
23
|
egglog/pretty.py,sha256=5hcc5-tAGnhcpz_aClJ1PgE3StOvFqQRYXYFDWXbNP0,20476
|
|
24
|
-
egglog/conversion.py,sha256=
|
|
24
|
+
egglog/conversion.py,sha256=lE9a3myOWyxlYbCVqOMKnqXw2_Ul24DJuJlVZ50HCA0,9989
|
|
25
25
|
egglog/ipython_magic.py,sha256=2hs3g2cSiyDmbCvE2t1OINmu17Bb8MWV--2DpEWwO7I,1189
|
|
26
26
|
egglog/functionalize.py,sha256=VMqiWcDbn1pYrY3FMxT143Bk90dMaxo2VapUL60-fvE,3847
|
|
27
|
-
egglog/egraph.py,sha256=
|
|
27
|
+
egglog/egraph.py,sha256=efeQ6KLPy4FPtEB9bLxyf4FI0RqFYzsYOBZP-aP-3YQ,62681
|
|
28
28
|
egglog/visualizer.css,sha256=eL0POoThQRc0P4OYnDT-d808ln9O5Qy6DizH9Z5LgWc,259398
|
|
29
29
|
egglog/visualizer.js,sha256=2qZZ-9W_INJx4gZMYjnVXl27IjT_JNuQyEeI2dbjWoU,3753315
|
|
30
30
|
egglog/examples/schedule_demo.py,sha256=JbXdPII7_adxtgyKVAiqCyV2sj88VZ-DhomYrdn8vuc,618
|
|
@@ -32,13 +32,13 @@ egglog/examples/README.rst,sha256=ztTvpofR0eotSqGoCy_C1fPLDPCncjvcqDanXtLHNNU,23
|
|
|
32
32
|
egglog/examples/fib.py,sha256=BOHxKWA7jGx4FURBmfmuZKfLo6xq9-uXAwAXjYid7LU,492
|
|
33
33
|
egglog/examples/resolution.py,sha256=BJd5JClA3DBVGfiVRa-H0gbbFvIqeP3uYbhCXHblSQc,2119
|
|
34
34
|
egglog/examples/lambda_.py,sha256=LK1ZSCITpvxpRMaqISZkRAKqNAU7QbaWq9oGe7ku2Ns,8208
|
|
35
|
-
egglog/examples/multiset.py,sha256=
|
|
36
|
-
egglog/examples/bignum.py,sha256=
|
|
35
|
+
egglog/examples/multiset.py,sha256=fdEPvNFkHe_XmzBv90JCP8SCxoKgpg_CIDsR8HWcbnY,1447
|
|
36
|
+
egglog/examples/bignum.py,sha256=LziQqBdOaYNhO3VQS6ZokHDO7_QSsOMV4Whyrj6-6nA,536
|
|
37
37
|
egglog/examples/matrix.py,sha256=7_mPcMcgE-t_GJDyf76-nv3xhPIeN2mvFkc_p_Gnr8g,4961
|
|
38
38
|
egglog/examples/__init__.py,sha256=wm9evUbMPfbtylXIjbDdRTAVMLH4OjT4Z77PCBFyaPU,31
|
|
39
39
|
egglog/examples/higher_order_functions.py,sha256=hWs9sCM5mERuuYscw9Qp-W3HEHi0EMtAs57gV2Qv69k,1119
|
|
40
40
|
egglog/examples/bool.py,sha256=e0z2YoYJsLlhpSADZK1yRYHzilyxSZWGiYAaM0DQ_Gw,695
|
|
41
41
|
egglog/examples/ndarrays.py,sha256=mfr410eletH8gfdg-P8L90vlF6TUifvYV_-ryOwvZZE,4042
|
|
42
|
-
egglog/examples/eqsat_basic.py,sha256=
|
|
43
|
-
egglog/bindings.cpython-313t-powerpc64-linux-gnu.so,sha256=
|
|
44
|
-
egglog-
|
|
42
|
+
egglog/examples/eqsat_basic.py,sha256=2xtM81gG9Br72mr58N-2BUeksR7C_UXnZJ4MvzSPplc,869
|
|
43
|
+
egglog/bindings.cpython-313t-powerpc64-linux-gnu.so,sha256=w6c8a6OcN0Xb9Kf_7NLeuWb3vvj5bFXJyYe3n9RK8Zc,6991072
|
|
44
|
+
egglog-10.0.0.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|