egglog 10.0.1__cp312-cp312-win_amd64.whl → 11.0.0__cp312-cp312-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of egglog might be problematic; see the release's page on the package registry for more details.

egglog/exp/array_api.py CHANGED
@@ -69,6 +69,7 @@ import numpy as np
69
69
 
70
70
  from egglog import *
71
71
  from egglog.runtime import RuntimeExpr
72
+ from egglog.version_compat import add_note
72
73
 
73
74
  from .program_gen import *
74
75
 
@@ -153,6 +154,18 @@ class Int(Expr, ruleset=array_api_ruleset):
153
154
  def __eq__(self, other: IntLike) -> Boolean: # type: ignore[override]
154
155
  ...
155
156
 
157
+ # add a hash so that this test can pass
158
+ # https://github.com/scikit-learn/scikit-learn/blob/6fd23fca53845b32b249f2b36051c081b65e2fab/sklearn/utils/validation.py#L486-L487
159
+ @method(preserve=True)
160
+ def __hash__(self) -> int:
161
+ egraph = _get_current_egraph()
162
+ egraph.register(self)
163
+ egraph.run(array_api_schedule)
164
+ simplified = egraph.extract(self)
165
+ return hash(cast("RuntimeExpr", simplified).__egg_typed_expr__)
166
+
167
+ def __round__(self, ndigits: OptionalIntLike = None) -> Int: ...
168
+
156
169
  # TODO: Fix this?
157
170
  # Make != always return a Bool, so that numpy.unique works on a tuple of ints
158
171
  # In _unique1d
@@ -279,6 +292,8 @@ def _int(i: i64, j: i64, r: Boolean, o: Int, b: Int):
279
292
  yield rewrite(Int.if_(TRUE, o, b), subsume=True).to(o)
280
293
  yield rewrite(Int.if_(FALSE, o, b), subsume=True).to(b)
281
294
 
295
+ yield rewrite(o.__round__(OptionalInt.none)).to(o)
296
+
282
297
  # Never cannot be equal to anything real
283
298
  yield rule(eq(Int.NEVER).to(Int(i))).then(panic("Int.NEVER cannot be equal to any real int"))
284
299
 
@@ -353,8 +368,14 @@ class Float(Expr, ruleset=array_api_ruleset):
353
368
  def __sub__(self, other: FloatLike) -> Float: ...
354
369
 
355
370
  def __pow__(self, other: FloatLike) -> Float: ...
371
+ def __round__(self, ndigits: OptionalIntLike = None) -> Float: ...
356
372
 
357
373
  def __eq__(self, other: FloatLike) -> Boolean: ... # type: ignore[override]
374
+ def __ne__(self, other: FloatLike) -> Boolean: ... # type: ignore[override]
375
+ def __lt__(self, other: FloatLike) -> Boolean: ...
376
+ def __le__(self, other: FloatLike) -> Boolean: ...
377
+ def __gt__(self, other: FloatLike) -> Boolean: ...
378
+ def __ge__(self, other: FloatLike) -> Boolean: ...
358
379
 
359
380
 
360
381
  converter(float, Float, lambda x: Float(x))
@@ -365,9 +386,10 @@ FloatLike: TypeAlias = Float | float | IntLike
365
386
 
366
387
 
367
388
  @array_api_ruleset.register
368
- def _float(fl: Float, f: f64, f2: f64, i: i64, r: BigRat, r1: BigRat):
389
+ def _float(fl: Float, f: f64, f2: f64, i: i64, r: BigRat, r1: BigRat, i_: Int):
369
390
  return [
370
391
  rule(eq(fl).to(Float(f))).then(set_(fl.to_f64).to(f)),
392
+ rewrite(Float.from_int(Int(i))).to(Float(f64.from_i64(i))),
371
393
  rewrite(Float(f).abs()).to(Float(f), f >= 0.0),
372
394
  rewrite(Float(f).abs()).to(Float(-f), f < 0.0),
373
395
  # Convert from float to rationl, if its a whole number i.e. can be converted to int
@@ -382,11 +404,22 @@ def _float(fl: Float, f: f64, f2: f64, i: i64, r: BigRat, r1: BigRat):
382
404
  rewrite(Float.rational(r) - Float.rational(r1)).to(Float.rational(r - r1)),
383
405
  rewrite(Float.rational(r) * Float.rational(r1)).to(Float.rational(r * r1)),
384
406
  rewrite(Float(f) ** Float(f2)).to(Float(f**f2)),
385
- # ==
407
+ # comparisons
386
408
  rewrite(Float(f) == Float(f)).to(TRUE),
387
409
  rewrite(Float(f) == Float(f2)).to(FALSE, ne(f).to(f2)),
410
+ rewrite(Float(f) != Float(f2)).to(TRUE, f != f2),
411
+ rewrite(Float(f) != Float(f)).to(FALSE),
412
+ rewrite(Float(f) >= Float(f2)).to(TRUE, f >= f2),
413
+ rewrite(Float(f) >= Float(f2)).to(FALSE, f < f2),
414
+ rewrite(Float(f) <= Float(f2)).to(TRUE, f <= f2),
415
+ rewrite(Float(f) <= Float(f2)).to(FALSE, f > f2),
416
+ rewrite(Float(f) > Float(f2)).to(TRUE, f > f2),
417
+ rewrite(Float(f) > Float(f2)).to(FALSE, f <= f2),
418
+ rewrite(Float(f) < Float(f2)).to(TRUE, f < f2),
388
419
  rewrite(Float.rational(r) == Float.rational(r)).to(TRUE),
389
420
  rewrite(Float.rational(r) == Float.rational(r1)).to(FALSE, ne(r).to(r1)),
421
+ # round
422
+ rewrite(Float.rational(r).__round__()).to(Float.rational(r.round())),
390
423
  ]
391
424
 
392
425
 
@@ -670,6 +703,8 @@ class OptionalInt(Expr, ruleset=array_api_ruleset):
670
703
  def some(cls, value: Int) -> OptionalInt: ...
671
704
 
672
705
 
706
+ OptionalIntLike: TypeAlias = OptionalInt | IntLike | None
707
+
673
708
  converter(type(None), OptionalInt, lambda _: OptionalInt.none)
674
709
  converter(Int, OptionalInt, OptionalInt.some)
675
710
 
@@ -1198,13 +1233,13 @@ class NDArray(Expr, ruleset=array_api_ruleset):
1198
1233
 
1199
1234
  NDArrayLike: TypeAlias = NDArray | ValueLike | TupleValueLike
1200
1235
 
1201
- converter(NDArray, IndexKey, IndexKey.ndarray)
1202
- converter(Value, NDArray, NDArray.scalar)
1236
+ converter(NDArray, IndexKey, lambda v: IndexKey.ndarray(v))
1237
+ converter(Value, NDArray, lambda v: NDArray.scalar(v))
1203
1238
  # Need this if we want to use ints in slices of arrays coming from 1d arrays, but make it more expensive
1204
1239
  # to prefer upcasting in the other direction when we can, which is safer at runtime
1205
1240
  converter(NDArray, Value, lambda n: n.to_value(), 100)
1206
- converter(TupleValue, NDArray, NDArray.vector)
1207
- converter(TupleInt, TupleValue, TupleValue.from_tuple_int)
1241
+ converter(TupleValue, NDArray, lambda v: NDArray.vector(v))
1242
+ converter(TupleInt, TupleValue, lambda v: TupleValue.from_tuple_int(v))
1208
1243
 
1209
1244
 
1210
1245
  @array_api_ruleset.register
@@ -1383,8 +1418,8 @@ class IntOrTuple(Expr, ruleset=array_api_ruleset):
1383
1418
  def tuple(cls, value: TupleIntLike) -> IntOrTuple: ...
1384
1419
 
1385
1420
 
1386
- converter(Int, IntOrTuple, IntOrTuple.int)
1387
- converter(TupleInt, IntOrTuple, IntOrTuple.tuple)
1421
+ converter(Int, IntOrTuple, lambda v: IntOrTuple.int(v))
1422
+ converter(TupleInt, IntOrTuple, lambda v: IntOrTuple.tuple(v))
1388
1423
 
1389
1424
 
1390
1425
  class OptionalIntOrTuple(Expr, ruleset=array_api_ruleset):
@@ -1395,7 +1430,7 @@ class OptionalIntOrTuple(Expr, ruleset=array_api_ruleset):
1395
1430
 
1396
1431
 
1397
1432
  converter(type(None), OptionalIntOrTuple, lambda _: OptionalIntOrTuple.none)
1398
- converter(IntOrTuple, OptionalIntOrTuple, OptionalIntOrTuple.some)
1433
+ converter(IntOrTuple, OptionalIntOrTuple, lambda v: OptionalIntOrTuple.some(v))
1399
1434
 
1400
1435
 
1401
1436
  @function
@@ -1980,6 +2015,5 @@ def try_evaling(egraph: EGraph, schedule: Schedule, expr: Expr, prim_expr: Built
1980
2015
  extracted = egraph.extract(prim_expr)
1981
2016
  except BaseException as e:
1982
2017
  # egraph.display(n_inline_leaves=1, split_primitive_outputs=True)
1983
- e.add_note(f"Cannot evaluate {egraph.extract(expr)}")
1984
- raise
1985
- return extracted.eval() # type: ignore[attr-defined]
2018
+ raise add_note(f"Cannot evaluate {egraph.extract(expr)}", e) # noqa: B904
2019
+ return extracted.value # type: ignore[attr-defined]
@@ -14,16 +14,22 @@ from .program_gen import Program
14
14
  X = TypeVar("X", bound=Callable)
15
15
 
16
16
 
17
- def jit(fn: X) -> X:
17
+ def jit(
18
+ fn: X,
19
+ *,
20
+ handle_expr: Callable[[NDArray], None] | None = None,
21
+ handle_optimized_expr: Callable[[NDArray], None] | None = None,
22
+ ) -> X:
18
23
  """
19
24
  Jit compiles a function
20
25
  """
21
26
  egraph, res, res_optimized, program = function_to_program(fn, save_egglog_string=False)
27
+ if handle_expr:
28
+ handle_expr(res)
29
+ if handle_optimized_expr:
30
+ handle_optimized_expr(res_optimized)
22
31
  fn_program = EvalProgram(program, {"np": np})
23
- fn = cast("X", try_evaling(egraph, array_api_program_gen_schedule, fn_program, fn_program.as_py_object))
24
- fn.initial_expr = res # type: ignore[attr-defined]
25
- fn.expr = res_optimized # type: ignore[attr-defined]
26
- return fn
32
+ return cast("X", try_evaling(egraph, array_api_program_gen_schedule, fn_program, fn_program.as_py_object))
27
33
 
28
34
 
29
35
  def function_to_program(fn: Callable, save_egglog_string: bool) -> tuple[EGraph, NDArray, NDArray, Program]:
@@ -505,6 +505,6 @@ def _ndarray_program(
505
505
  yield rewrite(ndarray_program(abs(x))).to((Program("np.abs(") + ndarray_program(x) + ")").assign())
506
506
 
507
507
  # asarray
508
- yield rewrite(ndarray_program(asarray(x, odtype))).to(
508
+ yield rewrite(ndarray_program(asarray(x, odtype, OptionalBool.none, optional_device_))).to(
509
509
  Program("np.asarray(") + ndarray_program(x) + ", " + optional_dtype_program(odtype) + ")"
510
510
  )
egglog/exp/program_gen.py CHANGED
@@ -5,12 +5,10 @@ Builds up imperative string expressions from a functional expression.
5
5
 
6
6
  from __future__ import annotations
7
7
 
8
- from typing import TypeAlias, Union
8
+ from typing import TypeAlias
9
9
 
10
10
  from egglog import *
11
11
 
12
- ProgramLike: TypeAlias = Union["Program", StringLike]
13
-
14
12
 
15
13
  class Program(Expr):
16
14
  """
@@ -97,6 +95,9 @@ class Program(Expr):
97
95
  """
98
96
 
99
97
 
98
+ ProgramLike: TypeAlias = Program | StringLike
99
+
100
+
100
101
  converter(String, Program, Program)
101
102
 
102
103
 
egglog/pretty.py CHANGED
@@ -107,7 +107,7 @@ def pretty_callable_ref(
107
107
  """
108
108
  # Pass in three dummy args, which are the max used for any operation that
109
109
  # is not a generic function call
110
- args: list[ExprDecl] = [VarDecl(ARG_STR, False)] * 3
110
+ args: list[ExprDecl] = [UnboundVarDecl(ARG_STR)] * 3
111
111
  if first_arg:
112
112
  args.insert(0, first_arg)
113
113
  context = PrettyContext(decls, defaultdict(lambda: 0))
@@ -166,7 +166,7 @@ class TraverseContext:
166
166
  self(d.expr)
167
167
  case ChangeDecl(_, d, _) | SaturateDecl(d) | RepeatDecl(d, _) | ActionCommandDecl(d):
168
168
  self(d)
169
- case PanicDecl(_) | VarDecl(_) | LitDecl(_) | PyObjectDecl(_):
169
+ case PanicDecl(_) | UnboundVarDecl(_) | LetRefDecl(_) | LitDecl(_) | PyObjectDecl(_):
170
170
  pass
171
171
  case SequenceDecl(decls) | RulesetDecl(decls):
172
172
  for de in decls:
@@ -233,6 +233,10 @@ class PrettyContext:
233
233
  return expr
234
234
 
235
235
  def uncached(self, decl: AllDecls, *, unwrap_lit: bool, parens: bool, ruleset_name: str | None) -> tuple[str, str]: # noqa: C901, PLR0911, PLR0912
236
+ """
237
+ Returns a tuple of a string value of the decleration and the "type" to use when create a memoized cached version
238
+ for de-duplication.
239
+ """
236
240
  match decl:
237
241
  case LitDecl(value):
238
242
  match value:
@@ -247,7 +251,7 @@ class PrettyContext:
247
251
  case str(s):
248
252
  return repr(s) if unwrap_lit else f"String({s!r})", "String"
249
253
  assert_never(value)
250
- case VarDecl(name):
254
+ case UnboundVarDecl(name) | LetRefDecl(name):
251
255
  return name, name
252
256
  case CallDecl(_, _, _):
253
257
  return self._call(decl, parens)
@@ -357,7 +361,7 @@ class PrettyContext:
357
361
  has_multiple_parents = self.parents[first_arg] > 1
358
362
  self.names[decl] = expr_name = self._name_expr(tp_name, expr_str, copy_identifier=has_multiple_parents)
359
363
  # Set the first arg to be the name of the mutated arg and return the name
360
- args[0] = VarDecl(expr_name, True)
364
+ args[0] = LetRefDecl(expr_name)
361
365
  else:
362
366
  expr_name = None
363
367
  res = self._call_inner(ref, args, decl.bound_tp_params, parens)
@@ -390,6 +394,7 @@ class PrettyContext:
390
394
  return f"{tp_ref}.{method_name}", args
391
395
  case MethodRef(_class_name, method_name):
392
396
  slf, *args = args
397
+ non_str_slf = slf
393
398
  slf = self(slf, parens=True)
394
399
  match method_name:
395
400
  case _ if method_name in UNARY_METHODS:
@@ -406,6 +411,8 @@ class PrettyContext:
406
411
  return f"del {slf}[{self(args[0], unwrap_lit=True)}]"
407
412
  case "__setitem__":
408
413
  return f"{slf}[{self(args[0], unwrap_lit=True)}] = {self(args[1], unwrap_lit=True)}"
414
+ case "__round__":
415
+ return "round", [non_str_slf, *args]
409
416
  case _:
410
417
  return f"{slf}.{method_name}", args
411
418
  case ConstantRef(name):
@@ -487,24 +494,3 @@ class PrettyContext:
487
494
  if arg_names:
488
495
  prefix += f" {', '.join(self(a.expr) for a in arg_names)}"
489
496
  return f"{prefix}: {self(res.expr)}"
490
-
491
-
492
- def _plot_line_length(expr: object): # pragma: no cover
493
- """
494
- Plots the number of line lengths based on different max lengths
495
- """
496
- global MAX_LINE_LENGTH, LINE_DIFFERENCE
497
- import altair as alt
498
- import pandas as pd
499
-
500
- sizes = []
501
- for line_length in range(40, 180, 10):
502
- MAX_LINE_LENGTH = line_length
503
- for diff in range(0, 40, 5):
504
- LINE_DIFFERENCE = diff
505
- new_l = len(str(expr).split())
506
- sizes.append((line_length, diff, new_l))
507
-
508
- df = pd.DataFrame(sizes, columns=["MAX_LINE_LENGTH", "LENGTH_DIFFERENCE", "n"])
509
-
510
- return alt.Chart(df).mark_rect().encode(x="MAX_LINE_LENGTH:O", y="LENGTH_DIFFERENCE:O", color="n:Q")