egglog 10.0.1__cp313-cp313-win_amd64.whl → 11.0.0__cp313-cp313-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of egglog might be problematic. Click here for more details.
- egglog/__init__.py +3 -1
- egglog/bindings.cp313-win_amd64.pyd +0 -0
- egglog/bindings.pyi +26 -34
- egglog/builtins.py +231 -182
- egglog/conversion.py +61 -43
- egglog/declarations.py +104 -18
- egglog/deconstruct.py +173 -0
- egglog/egraph.py +86 -144
- egglog/egraph_state.py +17 -14
- egglog/examples/bignum.py +1 -1
- egglog/examples/multiset.py +2 -2
- egglog/exp/array_api.py +46 -12
- egglog/exp/array_api_jit.py +11 -5
- egglog/exp/array_api_program_gen.py +1 -1
- egglog/exp/program_gen.py +4 -3
- egglog/pretty.py +11 -25
- egglog/runtime.py +203 -151
- egglog/thunk.py +6 -4
- egglog/type_constraint_solver.py +1 -1
- egglog/version_compat.py +87 -0
- {egglog-10.0.1.dist-info → egglog-11.0.0.dist-info}/METADATA +1 -1
- egglog-11.0.0.dist-info/RECORD +45 -0
- {egglog-10.0.1.dist-info → egglog-11.0.0.dist-info}/WHEEL +1 -1
- egglog/functionalize.py +0 -91
- egglog-10.0.1.dist-info/RECORD +0 -44
- {egglog-10.0.1.dist-info → egglog-11.0.0.dist-info}/licenses/LICENSE +0 -0
egglog/exp/array_api.py
CHANGED
|
@@ -69,6 +69,7 @@ import numpy as np
|
|
|
69
69
|
|
|
70
70
|
from egglog import *
|
|
71
71
|
from egglog.runtime import RuntimeExpr
|
|
72
|
+
from egglog.version_compat import add_note
|
|
72
73
|
|
|
73
74
|
from .program_gen import *
|
|
74
75
|
|
|
@@ -153,6 +154,18 @@ class Int(Expr, ruleset=array_api_ruleset):
|
|
|
153
154
|
def __eq__(self, other: IntLike) -> Boolean: # type: ignore[override]
|
|
154
155
|
...
|
|
155
156
|
|
|
157
|
+
# add a hash so that this test can pass
|
|
158
|
+
# https://github.com/scikit-learn/scikit-learn/blob/6fd23fca53845b32b249f2b36051c081b65e2fab/sklearn/utils/validation.py#L486-L487
|
|
159
|
+
@method(preserve=True)
|
|
160
|
+
def __hash__(self) -> int:
|
|
161
|
+
egraph = _get_current_egraph()
|
|
162
|
+
egraph.register(self)
|
|
163
|
+
egraph.run(array_api_schedule)
|
|
164
|
+
simplified = egraph.extract(self)
|
|
165
|
+
return hash(cast("RuntimeExpr", simplified).__egg_typed_expr__)
|
|
166
|
+
|
|
167
|
+
def __round__(self, ndigits: OptionalIntLike = None) -> Int: ...
|
|
168
|
+
|
|
156
169
|
# TODO: Fix this?
|
|
157
170
|
# Make != always return a Bool, so that numpy.unique works on a tuple of ints
|
|
158
171
|
# In _unique1d
|
|
@@ -279,6 +292,8 @@ def _int(i: i64, j: i64, r: Boolean, o: Int, b: Int):
|
|
|
279
292
|
yield rewrite(Int.if_(TRUE, o, b), subsume=True).to(o)
|
|
280
293
|
yield rewrite(Int.if_(FALSE, o, b), subsume=True).to(b)
|
|
281
294
|
|
|
295
|
+
yield rewrite(o.__round__(OptionalInt.none)).to(o)
|
|
296
|
+
|
|
282
297
|
# Never cannot be equal to anything real
|
|
283
298
|
yield rule(eq(Int.NEVER).to(Int(i))).then(panic("Int.NEVER cannot be equal to any real int"))
|
|
284
299
|
|
|
@@ -353,8 +368,14 @@ class Float(Expr, ruleset=array_api_ruleset):
|
|
|
353
368
|
def __sub__(self, other: FloatLike) -> Float: ...
|
|
354
369
|
|
|
355
370
|
def __pow__(self, other: FloatLike) -> Float: ...
|
|
371
|
+
def __round__(self, ndigits: OptionalIntLike = None) -> Float: ...
|
|
356
372
|
|
|
357
373
|
def __eq__(self, other: FloatLike) -> Boolean: ... # type: ignore[override]
|
|
374
|
+
def __ne__(self, other: FloatLike) -> Boolean: ... # type: ignore[override]
|
|
375
|
+
def __lt__(self, other: FloatLike) -> Boolean: ...
|
|
376
|
+
def __le__(self, other: FloatLike) -> Boolean: ...
|
|
377
|
+
def __gt__(self, other: FloatLike) -> Boolean: ...
|
|
378
|
+
def __ge__(self, other: FloatLike) -> Boolean: ...
|
|
358
379
|
|
|
359
380
|
|
|
360
381
|
converter(float, Float, lambda x: Float(x))
|
|
@@ -365,9 +386,10 @@ FloatLike: TypeAlias = Float | float | IntLike
|
|
|
365
386
|
|
|
366
387
|
|
|
367
388
|
@array_api_ruleset.register
|
|
368
|
-
def _float(fl: Float, f: f64, f2: f64, i: i64, r: BigRat, r1: BigRat):
|
|
389
|
+
def _float(fl: Float, f: f64, f2: f64, i: i64, r: BigRat, r1: BigRat, i_: Int):
|
|
369
390
|
return [
|
|
370
391
|
rule(eq(fl).to(Float(f))).then(set_(fl.to_f64).to(f)),
|
|
392
|
+
rewrite(Float.from_int(Int(i))).to(Float(f64.from_i64(i))),
|
|
371
393
|
rewrite(Float(f).abs()).to(Float(f), f >= 0.0),
|
|
372
394
|
rewrite(Float(f).abs()).to(Float(-f), f < 0.0),
|
|
373
395
|
# Convert from float to rationl, if its a whole number i.e. can be converted to int
|
|
@@ -382,11 +404,22 @@ def _float(fl: Float, f: f64, f2: f64, i: i64, r: BigRat, r1: BigRat):
|
|
|
382
404
|
rewrite(Float.rational(r) - Float.rational(r1)).to(Float.rational(r - r1)),
|
|
383
405
|
rewrite(Float.rational(r) * Float.rational(r1)).to(Float.rational(r * r1)),
|
|
384
406
|
rewrite(Float(f) ** Float(f2)).to(Float(f**f2)),
|
|
385
|
-
#
|
|
407
|
+
# comparisons
|
|
386
408
|
rewrite(Float(f) == Float(f)).to(TRUE),
|
|
387
409
|
rewrite(Float(f) == Float(f2)).to(FALSE, ne(f).to(f2)),
|
|
410
|
+
rewrite(Float(f) != Float(f2)).to(TRUE, f != f2),
|
|
411
|
+
rewrite(Float(f) != Float(f)).to(FALSE),
|
|
412
|
+
rewrite(Float(f) >= Float(f2)).to(TRUE, f >= f2),
|
|
413
|
+
rewrite(Float(f) >= Float(f2)).to(FALSE, f < f2),
|
|
414
|
+
rewrite(Float(f) <= Float(f2)).to(TRUE, f <= f2),
|
|
415
|
+
rewrite(Float(f) <= Float(f2)).to(FALSE, f > f2),
|
|
416
|
+
rewrite(Float(f) > Float(f2)).to(TRUE, f > f2),
|
|
417
|
+
rewrite(Float(f) > Float(f2)).to(FALSE, f <= f2),
|
|
418
|
+
rewrite(Float(f) < Float(f2)).to(TRUE, f < f2),
|
|
388
419
|
rewrite(Float.rational(r) == Float.rational(r)).to(TRUE),
|
|
389
420
|
rewrite(Float.rational(r) == Float.rational(r1)).to(FALSE, ne(r).to(r1)),
|
|
421
|
+
# round
|
|
422
|
+
rewrite(Float.rational(r).__round__()).to(Float.rational(r.round())),
|
|
390
423
|
]
|
|
391
424
|
|
|
392
425
|
|
|
@@ -670,6 +703,8 @@ class OptionalInt(Expr, ruleset=array_api_ruleset):
|
|
|
670
703
|
def some(cls, value: Int) -> OptionalInt: ...
|
|
671
704
|
|
|
672
705
|
|
|
706
|
+
OptionalIntLike: TypeAlias = OptionalInt | IntLike | None
|
|
707
|
+
|
|
673
708
|
converter(type(None), OptionalInt, lambda _: OptionalInt.none)
|
|
674
709
|
converter(Int, OptionalInt, OptionalInt.some)
|
|
675
710
|
|
|
@@ -1198,13 +1233,13 @@ class NDArray(Expr, ruleset=array_api_ruleset):
|
|
|
1198
1233
|
|
|
1199
1234
|
NDArrayLike: TypeAlias = NDArray | ValueLike | TupleValueLike
|
|
1200
1235
|
|
|
1201
|
-
converter(NDArray, IndexKey, IndexKey.ndarray)
|
|
1202
|
-
converter(Value, NDArray, NDArray.scalar)
|
|
1236
|
+
converter(NDArray, IndexKey, lambda v: IndexKey.ndarray(v))
|
|
1237
|
+
converter(Value, NDArray, lambda v: NDArray.scalar(v))
|
|
1203
1238
|
# Need this if we want to use ints in slices of arrays coming from 1d arrays, but make it more expensive
|
|
1204
1239
|
# to prefer upcasting in the other direction when we can, which is safer at runtime
|
|
1205
1240
|
converter(NDArray, Value, lambda n: n.to_value(), 100)
|
|
1206
|
-
converter(TupleValue, NDArray, NDArray.vector)
|
|
1207
|
-
converter(TupleInt, TupleValue, TupleValue.from_tuple_int)
|
|
1241
|
+
converter(TupleValue, NDArray, lambda v: NDArray.vector(v))
|
|
1242
|
+
converter(TupleInt, TupleValue, lambda v: TupleValue.from_tuple_int(v))
|
|
1208
1243
|
|
|
1209
1244
|
|
|
1210
1245
|
@array_api_ruleset.register
|
|
@@ -1383,8 +1418,8 @@ class IntOrTuple(Expr, ruleset=array_api_ruleset):
|
|
|
1383
1418
|
def tuple(cls, value: TupleIntLike) -> IntOrTuple: ...
|
|
1384
1419
|
|
|
1385
1420
|
|
|
1386
|
-
converter(Int, IntOrTuple, IntOrTuple.int)
|
|
1387
|
-
converter(TupleInt, IntOrTuple, IntOrTuple.tuple)
|
|
1421
|
+
converter(Int, IntOrTuple, lambda v: IntOrTuple.int(v))
|
|
1422
|
+
converter(TupleInt, IntOrTuple, lambda v: IntOrTuple.tuple(v))
|
|
1388
1423
|
|
|
1389
1424
|
|
|
1390
1425
|
class OptionalIntOrTuple(Expr, ruleset=array_api_ruleset):
|
|
@@ -1395,7 +1430,7 @@ class OptionalIntOrTuple(Expr, ruleset=array_api_ruleset):
|
|
|
1395
1430
|
|
|
1396
1431
|
|
|
1397
1432
|
converter(type(None), OptionalIntOrTuple, lambda _: OptionalIntOrTuple.none)
|
|
1398
|
-
converter(IntOrTuple, OptionalIntOrTuple, OptionalIntOrTuple.some)
|
|
1433
|
+
converter(IntOrTuple, OptionalIntOrTuple, lambda v: OptionalIntOrTuple.some(v))
|
|
1399
1434
|
|
|
1400
1435
|
|
|
1401
1436
|
@function
|
|
@@ -1980,6 +2015,5 @@ def try_evaling(egraph: EGraph, schedule: Schedule, expr: Expr, prim_expr: Built
|
|
|
1980
2015
|
extracted = egraph.extract(prim_expr)
|
|
1981
2016
|
except BaseException as e:
|
|
1982
2017
|
# egraph.display(n_inline_leaves=1, split_primitive_outputs=True)
|
|
1983
|
-
|
|
1984
|
-
|
|
1985
|
-
return extracted.eval() # type: ignore[attr-defined]
|
|
2018
|
+
raise add_note(f"Cannot evaluate {egraph.extract(expr)}", e) # noqa: B904
|
|
2019
|
+
return extracted.value # type: ignore[attr-defined]
|
egglog/exp/array_api_jit.py
CHANGED
|
@@ -14,16 +14,22 @@ from .program_gen import Program
|
|
|
14
14
|
X = TypeVar("X", bound=Callable)
|
|
15
15
|
|
|
16
16
|
|
|
17
|
-
def jit(
|
|
17
|
+
def jit(
|
|
18
|
+
fn: X,
|
|
19
|
+
*,
|
|
20
|
+
handle_expr: Callable[[NDArray], None] | None = None,
|
|
21
|
+
handle_optimized_expr: Callable[[NDArray], None] | None = None,
|
|
22
|
+
) -> X:
|
|
18
23
|
"""
|
|
19
24
|
Jit compiles a function
|
|
20
25
|
"""
|
|
21
26
|
egraph, res, res_optimized, program = function_to_program(fn, save_egglog_string=False)
|
|
27
|
+
if handle_expr:
|
|
28
|
+
handle_expr(res)
|
|
29
|
+
if handle_optimized_expr:
|
|
30
|
+
handle_optimized_expr(res_optimized)
|
|
22
31
|
fn_program = EvalProgram(program, {"np": np})
|
|
23
|
-
|
|
24
|
-
fn.initial_expr = res # type: ignore[attr-defined]
|
|
25
|
-
fn.expr = res_optimized # type: ignore[attr-defined]
|
|
26
|
-
return fn
|
|
32
|
+
return cast("X", try_evaling(egraph, array_api_program_gen_schedule, fn_program, fn_program.as_py_object))
|
|
27
33
|
|
|
28
34
|
|
|
29
35
|
def function_to_program(fn: Callable, save_egglog_string: bool) -> tuple[EGraph, NDArray, NDArray, Program]:
|
|
@@ -505,6 +505,6 @@ def _ndarray_program(
|
|
|
505
505
|
yield rewrite(ndarray_program(abs(x))).to((Program("np.abs(") + ndarray_program(x) + ")").assign())
|
|
506
506
|
|
|
507
507
|
# asarray
|
|
508
|
-
yield rewrite(ndarray_program(asarray(x, odtype))).to(
|
|
508
|
+
yield rewrite(ndarray_program(asarray(x, odtype, OptionalBool.none, optional_device_))).to(
|
|
509
509
|
Program("np.asarray(") + ndarray_program(x) + ", " + optional_dtype_program(odtype) + ")"
|
|
510
510
|
)
|
egglog/exp/program_gen.py
CHANGED
|
@@ -5,12 +5,10 @@ Builds up imperative string expressions from a functional expression.
|
|
|
5
5
|
|
|
6
6
|
from __future__ import annotations
|
|
7
7
|
|
|
8
|
-
from typing import TypeAlias
|
|
8
|
+
from typing import TypeAlias
|
|
9
9
|
|
|
10
10
|
from egglog import *
|
|
11
11
|
|
|
12
|
-
ProgramLike: TypeAlias = Union["Program", StringLike]
|
|
13
|
-
|
|
14
12
|
|
|
15
13
|
class Program(Expr):
|
|
16
14
|
"""
|
|
@@ -97,6 +95,9 @@ class Program(Expr):
|
|
|
97
95
|
"""
|
|
98
96
|
|
|
99
97
|
|
|
98
|
+
ProgramLike: TypeAlias = Program | StringLike
|
|
99
|
+
|
|
100
|
+
|
|
100
101
|
converter(String, Program, Program)
|
|
101
102
|
|
|
102
103
|
|
egglog/pretty.py
CHANGED
|
@@ -107,7 +107,7 @@ def pretty_callable_ref(
|
|
|
107
107
|
"""
|
|
108
108
|
# Pass in three dummy args, which are the max used for any operation that
|
|
109
109
|
# is not a generic function call
|
|
110
|
-
args: list[ExprDecl] = [
|
|
110
|
+
args: list[ExprDecl] = [UnboundVarDecl(ARG_STR)] * 3
|
|
111
111
|
if first_arg:
|
|
112
112
|
args.insert(0, first_arg)
|
|
113
113
|
context = PrettyContext(decls, defaultdict(lambda: 0))
|
|
@@ -166,7 +166,7 @@ class TraverseContext:
|
|
|
166
166
|
self(d.expr)
|
|
167
167
|
case ChangeDecl(_, d, _) | SaturateDecl(d) | RepeatDecl(d, _) | ActionCommandDecl(d):
|
|
168
168
|
self(d)
|
|
169
|
-
case PanicDecl(_) |
|
|
169
|
+
case PanicDecl(_) | UnboundVarDecl(_) | LetRefDecl(_) | LitDecl(_) | PyObjectDecl(_):
|
|
170
170
|
pass
|
|
171
171
|
case SequenceDecl(decls) | RulesetDecl(decls):
|
|
172
172
|
for de in decls:
|
|
@@ -233,6 +233,10 @@ class PrettyContext:
|
|
|
233
233
|
return expr
|
|
234
234
|
|
|
235
235
|
def uncached(self, decl: AllDecls, *, unwrap_lit: bool, parens: bool, ruleset_name: str | None) -> tuple[str, str]: # noqa: C901, PLR0911, PLR0912
|
|
236
|
+
"""
|
|
237
|
+
Returns a tuple of a string value of the decleration and the "type" to use when create a memoized cached version
|
|
238
|
+
for de-duplication.
|
|
239
|
+
"""
|
|
236
240
|
match decl:
|
|
237
241
|
case LitDecl(value):
|
|
238
242
|
match value:
|
|
@@ -247,7 +251,7 @@ class PrettyContext:
|
|
|
247
251
|
case str(s):
|
|
248
252
|
return repr(s) if unwrap_lit else f"String({s!r})", "String"
|
|
249
253
|
assert_never(value)
|
|
250
|
-
case
|
|
254
|
+
case UnboundVarDecl(name) | LetRefDecl(name):
|
|
251
255
|
return name, name
|
|
252
256
|
case CallDecl(_, _, _):
|
|
253
257
|
return self._call(decl, parens)
|
|
@@ -357,7 +361,7 @@ class PrettyContext:
|
|
|
357
361
|
has_multiple_parents = self.parents[first_arg] > 1
|
|
358
362
|
self.names[decl] = expr_name = self._name_expr(tp_name, expr_str, copy_identifier=has_multiple_parents)
|
|
359
363
|
# Set the first arg to be the name of the mutated arg and return the name
|
|
360
|
-
args[0] =
|
|
364
|
+
args[0] = LetRefDecl(expr_name)
|
|
361
365
|
else:
|
|
362
366
|
expr_name = None
|
|
363
367
|
res = self._call_inner(ref, args, decl.bound_tp_params, parens)
|
|
@@ -390,6 +394,7 @@ class PrettyContext:
|
|
|
390
394
|
return f"{tp_ref}.{method_name}", args
|
|
391
395
|
case MethodRef(_class_name, method_name):
|
|
392
396
|
slf, *args = args
|
|
397
|
+
non_str_slf = slf
|
|
393
398
|
slf = self(slf, parens=True)
|
|
394
399
|
match method_name:
|
|
395
400
|
case _ if method_name in UNARY_METHODS:
|
|
@@ -406,6 +411,8 @@ class PrettyContext:
|
|
|
406
411
|
return f"del {slf}[{self(args[0], unwrap_lit=True)}]"
|
|
407
412
|
case "__setitem__":
|
|
408
413
|
return f"{slf}[{self(args[0], unwrap_lit=True)}] = {self(args[1], unwrap_lit=True)}"
|
|
414
|
+
case "__round__":
|
|
415
|
+
return "round", [non_str_slf, *args]
|
|
409
416
|
case _:
|
|
410
417
|
return f"{slf}.{method_name}", args
|
|
411
418
|
case ConstantRef(name):
|
|
@@ -487,24 +494,3 @@ class PrettyContext:
|
|
|
487
494
|
if arg_names:
|
|
488
495
|
prefix += f" {', '.join(self(a.expr) for a in arg_names)}"
|
|
489
496
|
return f"{prefix}: {self(res.expr)}"
|
|
490
|
-
|
|
491
|
-
|
|
492
|
-
def _plot_line_length(expr: object): # pragma: no cover
|
|
493
|
-
"""
|
|
494
|
-
Plots the number of line lengths based on different max lengths
|
|
495
|
-
"""
|
|
496
|
-
global MAX_LINE_LENGTH, LINE_DIFFERENCE
|
|
497
|
-
import altair as alt
|
|
498
|
-
import pandas as pd
|
|
499
|
-
|
|
500
|
-
sizes = []
|
|
501
|
-
for line_length in range(40, 180, 10):
|
|
502
|
-
MAX_LINE_LENGTH = line_length
|
|
503
|
-
for diff in range(0, 40, 5):
|
|
504
|
-
LINE_DIFFERENCE = diff
|
|
505
|
-
new_l = len(str(expr).split())
|
|
506
|
-
sizes.append((line_length, diff, new_l))
|
|
507
|
-
|
|
508
|
-
df = pd.DataFrame(sizes, columns=["MAX_LINE_LENGTH", "LENGTH_DIFFERENCE", "n"])
|
|
509
|
-
|
|
510
|
-
return alt.Chart(df).mark_rect().encode(x="MAX_LINE_LENGTH:O", y="LENGTH_DIFFERENCE:O", color="n:Q")
|