egglog 10.0.2__cp312-cp312-win_amd64.whl → 11.1.0__cp312-cp312-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of egglog might be problematic. (The original page linked to further details at this point; the link was lost when the page was extracted.)

@@ -505,6 +505,6 @@ def _ndarray_program(
505
505
  yield rewrite(ndarray_program(abs(x))).to((Program("np.abs(") + ndarray_program(x) + ")").assign())
506
506
 
507
507
  # asarray
508
- yield rewrite(ndarray_program(asarray(x, odtype))).to(
508
+ yield rewrite(ndarray_program(asarray(x, odtype, OptionalBool.none, optional_device_))).to(
509
509
  Program("np.asarray(") + ndarray_program(x) + ", " + optional_dtype_program(odtype) + ")"
510
510
  )
egglog/exp/program_gen.py CHANGED
@@ -79,7 +79,7 @@ class Program(Expr):
79
79
  Triggers compilation of the program.
80
80
  """
81
81
 
82
- @method(merge=lambda old, _new: old) # type: ignore[misc]
82
+ @method(merge=lambda old, _new: old) # type: ignore[prop-decorator]
83
83
  @property
84
84
  def parent(self) -> Program:
85
85
  """
@@ -108,7 +108,7 @@ class EvalProgram(Expr):
108
108
  """
109
109
 
110
110
  # Only allow it to be set once, b/c hash of functions not stable
111
- @method(merge=lambda old, _new: old) # type: ignore[misc]
111
+ @method(merge=lambda old, _new: old) # type: ignore[prop-decorator]
112
112
  @property
113
113
  def as_py_object(self) -> PyObject:
114
114
  """
egglog/pretty.py CHANGED
@@ -107,7 +107,7 @@ def pretty_callable_ref(
107
107
  """
108
108
  # Pass in three dummy args, which are the max used for any operation that
109
109
  # is not a generic function call
110
- args: list[ExprDecl] = [VarDecl(ARG_STR, False)] * 3
110
+ args: list[ExprDecl] = [UnboundVarDecl(ARG_STR)] * 3
111
111
  if first_arg:
112
112
  args.insert(0, first_arg)
113
113
  context = PrettyContext(decls, defaultdict(lambda: 0))
@@ -166,7 +166,7 @@ class TraverseContext:
166
166
  self(d.expr)
167
167
  case ChangeDecl(_, d, _) | SaturateDecl(d) | RepeatDecl(d, _) | ActionCommandDecl(d):
168
168
  self(d)
169
- case PanicDecl(_) | VarDecl(_) | LitDecl(_) | PyObjectDecl(_):
169
+ case PanicDecl(_) | UnboundVarDecl(_) | LetRefDecl(_) | LitDecl(_) | PyObjectDecl(_):
170
170
  pass
171
171
  case SequenceDecl(decls) | RulesetDecl(decls):
172
172
  for de in decls:
@@ -233,6 +233,10 @@ class PrettyContext:
233
233
  return expr
234
234
 
235
235
  def uncached(self, decl: AllDecls, *, unwrap_lit: bool, parens: bool, ruleset_name: str | None) -> tuple[str, str]: # noqa: C901, PLR0911, PLR0912
236
+ """
237
+ Returns a tuple of a string value of the decleration and the "type" to use when create a memoized cached version
238
+ for de-duplication.
239
+ """
236
240
  match decl:
237
241
  case LitDecl(value):
238
242
  match value:
@@ -247,7 +251,7 @@ class PrettyContext:
247
251
  case str(s):
248
252
  return repr(s) if unwrap_lit else f"String({s!r})", "String"
249
253
  assert_never(value)
250
- case VarDecl(name):
254
+ case UnboundVarDecl(name) | LetRefDecl(name):
251
255
  return name, name
252
256
  case CallDecl(_, _, _):
253
257
  return self._call(decl, parens)
@@ -357,7 +361,7 @@ class PrettyContext:
357
361
  has_multiple_parents = self.parents[first_arg] > 1
358
362
  self.names[decl] = expr_name = self._name_expr(tp_name, expr_str, copy_identifier=has_multiple_parents)
359
363
  # Set the first arg to be the name of the mutated arg and return the name
360
- args[0] = VarDecl(expr_name, True)
364
+ args[0] = LetRefDecl(expr_name)
361
365
  else:
362
366
  expr_name = None
363
367
  res = self._call_inner(ref, args, decl.bound_tp_params, parens)
@@ -390,6 +394,7 @@ class PrettyContext:
390
394
  return f"{tp_ref}.{method_name}", args
391
395
  case MethodRef(_class_name, method_name):
392
396
  slf, *args = args
397
+ non_str_slf = slf
393
398
  slf = self(slf, parens=True)
394
399
  match method_name:
395
400
  case _ if method_name in UNARY_METHODS:
@@ -406,6 +411,8 @@ class PrettyContext:
406
411
  return f"del {slf}[{self(args[0], unwrap_lit=True)}]"
407
412
  case "__setitem__":
408
413
  return f"{slf}[{self(args[0], unwrap_lit=True)}] = {self(args[1], unwrap_lit=True)}"
414
+ case "__round__":
415
+ return "round", [non_str_slf, *args]
409
416
  case _:
410
417
  return f"{slf}.{method_name}", args
411
418
  case ConstantRef(name):
@@ -487,24 +494,3 @@ class PrettyContext:
487
494
  if arg_names:
488
495
  prefix += f" {', '.join(self(a.expr) for a in arg_names)}"
489
496
  return f"{prefix}: {self(res.expr)}"
490
-
491
-
492
- def _plot_line_length(expr: object): # pragma: no cover
493
- """
494
- Plots the number of line lengths based on different max lengths
495
- """
496
- global MAX_LINE_LENGTH, LINE_DIFFERENCE
497
- import altair as alt
498
- import pandas as pd
499
-
500
- sizes = []
501
- for line_length in range(40, 180, 10):
502
- MAX_LINE_LENGTH = line_length
503
- for diff in range(0, 40, 5):
504
- LINE_DIFFERENCE = diff
505
- new_l = len(str(expr).split())
506
- sizes.append((line_length, diff, new_l))
507
-
508
- df = pd.DataFrame(sizes, columns=["MAX_LINE_LENGTH", "LENGTH_DIFFERENCE", "n"])
509
-
510
- return alt.Chart(df).mark_rect().encode(x="MAX_LINE_LENGTH:O", y="LENGTH_DIFFERENCE:O", color="n:Q")
egglog/runtime.py CHANGED
@@ -11,12 +11,14 @@ so they are not mangled by Python and can be accessed by the user.
11
11
 
12
12
  from __future__ import annotations
13
13
 
14
+ import itertools
14
15
  import operator
16
+ import types
15
17
  from collections.abc import Callable
16
- from dataclasses import dataclass, replace
18
+ from dataclasses import InitVar, dataclass, replace
17
19
  from inspect import Parameter, Signature
18
20
  from itertools import zip_longest
19
- from typing import TYPE_CHECKING, TypeVar, Union, cast, get_args, get_origin
21
+ from typing import TYPE_CHECKING, Any, TypeVar, Union, cast, get_args, get_origin
20
22
 
21
23
  from .declarations import *
22
24
  from .pretty import *
@@ -27,15 +29,14 @@ from .version_compat import *
27
29
  if TYPE_CHECKING:
28
30
  from collections.abc import Iterable
29
31
 
30
- from .egraph import Fact
31
-
32
32
 
33
33
  __all__ = [
34
34
  "LIT_CLASS_NAMES",
35
- "REFLECTED_BINARY_METHODS",
35
+ "NUMERIC_BINARY_METHODS",
36
36
  "RuntimeClass",
37
37
  "RuntimeExpr",
38
38
  "RuntimeFunction",
39
+ "define_expr_method",
39
40
  "resolve_callable",
40
41
  "resolve_type_annotation",
41
42
  "resolve_type_annotation_mutate",
@@ -46,24 +47,34 @@ UNIT_CLASS_NAME = "Unit"
46
47
  UNARY_LIT_CLASS_NAMES = {"i64", "f64", "Bool", "String"}
47
48
  LIT_CLASS_NAMES = UNARY_LIT_CLASS_NAMES | {UNIT_CLASS_NAME, "PyObject"}
48
49
 
49
- REFLECTED_BINARY_METHODS = {
50
- "__radd__": "__add__",
51
- "__rsub__": "__sub__",
52
- "__rmul__": "__mul__",
53
- "__rmatmul__": "__matmul__",
54
- "__rtruediv__": "__truediv__",
55
- "__rfloordiv__": "__floordiv__",
56
- "__rmod__": "__mod__",
57
- "__rpow__": "__pow__",
58
- "__rlshift__": "__lshift__",
59
- "__rrshift__": "__rshift__",
60
- "__rand__": "__and__",
61
- "__rxor__": "__xor__",
62
- "__ror__": "__or__",
50
+ # All methods which should return NotImplemented if they fail to resolve and are reflected as well
51
+ # From https://docs.python.org/3/reference/datamodel.html#emulating-numeric-types
52
+
53
+ NUMERIC_BINARY_METHODS = {
54
+ "__add__",
55
+ "__sub__",
56
+ "__mul__",
57
+ "__matmul__",
58
+ "__truediv__",
59
+ "__floordiv__",
60
+ "__mod__",
61
+ "__divmod__",
62
+ "__pow__",
63
+ "__lshift__",
64
+ "__rshift__",
65
+ "__and__",
66
+ "__xor__",
67
+ "__or__",
68
+ "__lt__",
69
+ "__le__",
70
+ "__gt__",
71
+ "__ge__",
63
72
  }
64
73
 
65
- # Methods that need to return real Python values not expressions
66
- PRESERVED_METHODS = [
74
+
75
+ # Methods that need to be defined on the runtime type that holds `Expr` objects, so that they can be used as methods.
76
+
77
+ TYPE_DEFINED_METHODS = {
67
78
  "__bool__",
68
79
  "__len__",
69
80
  "__complex__",
@@ -71,9 +82,15 @@ PRESERVED_METHODS = [
71
82
  "__float__",
72
83
  "__iter__",
73
84
  "__index__",
74
- "__float__",
75
- "__int__",
76
- ]
85
+ "__call__",
86
+ "__getitem__",
87
+ "__setitem__",
88
+ "__delitem__",
89
+ "__pos__",
90
+ "__neg__",
91
+ "__invert__",
92
+ "__round__",
93
+ }
77
94
 
78
95
  # Set this globally so we can get access to PyObject when we have a type annotation of just object.
79
96
  # This is the only time a type annotation doesn't need to include the egglog type b/c object is top so that would be redundant statically.
@@ -135,11 +152,48 @@ def inverse_resolve_type_annotation(decls_thunk: Callable[[], Declarations], tp:
135
152
  ##
136
153
 
137
154
 
138
- @dataclass
139
- class RuntimeClass(DelayedDeclerations):
155
+ class BaseClassFactoryMeta(type):
156
+ """
157
+ Base metaclass for all runtime classes created by ClassFactory
158
+ """
159
+
160
+ def __instancecheck__(cls, instance: object) -> bool:
161
+ assert isinstance(cls, RuntimeClass)
162
+ return isinstance(instance, RuntimeExpr) and cls.__egg_tp__.name == instance.__egg_typed_expr__.tp.name
163
+
164
+
165
+ class ClassFactory(type):
166
+ """
167
+ A metaclass for types which should create `type` objects when instantiated.
168
+
169
+ That's so that they work with `isinstance` and can be placed in `match ClassName()`.
170
+ """
171
+
172
+ def __call__(cls, *args, **kwargs) -> type:
173
+ # If we have params, don't inherit from `type` because we don't need to match against this and also
174
+ # this won't work with `Union[X]` because it won't look at `__parameters__` for instances of `type`.
175
+ if kwargs.pop("_egg_has_params", False):
176
+ return super().__call__(*args, **kwargs)
177
+ namespace: dict[str, Any] = {}
178
+ for m in reversed(cls.__mro__):
179
+ namespace.update(m.__dict__)
180
+ init = namespace.pop("__init__")
181
+ meta = types.new_class("type(RuntimeClass)", (BaseClassFactoryMeta,), {}, lambda ns: ns.update(**namespace))
182
+ tp = types.new_class("RuntimeClass", (), {"metaclass": meta})
183
+ init(tp, *args, **kwargs)
184
+ return tp
185
+
186
+ def __instancecheck__(cls, instance: object) -> bool:
187
+ return isinstance(instance, BaseClassFactoryMeta)
188
+
189
+
190
+ @dataclass(match_args=False)
191
+ class RuntimeClass(DelayedDeclerations, metaclass=ClassFactory):
140
192
  __egg_tp__: TypeRefWithVars
193
+ # True if we want `__parameters__` to be recognized by `Union`, which means we can't inherit from `type` directly.
194
+ _egg_has_params: InitVar[bool] = False
141
195
 
142
- def __post_init__(self) -> None:
196
+ def __post_init__(self, _egg_has_params: bool) -> None:
143
197
  global _PY_OBJECT_CLASS, _UNSTABLE_FN_CLASS
144
198
  if (name := self.__egg_tp__.name) == "PyObject":
145
199
  _PY_OBJECT_CLASS = self
@@ -229,7 +283,7 @@ class RuntimeClass(DelayedDeclerations):
229
283
  else:
230
284
  final_args = new_args
231
285
  tp = TypeRefWithVars(self.__egg_tp__.name, final_args)
232
- return RuntimeClass(Thunk.fn(Declarations.create, self, *decls_like), tp)
286
+ return RuntimeClass(Thunk.fn(Declarations.create, self, *decls_like), tp, _egg_has_params=True)
233
287
 
234
288
  def __getattr__(self, name: str) -> RuntimeFunction | RuntimeExpr | Callable:
235
289
  if name == "__origin__" and self.__egg_tp__.args:
@@ -288,6 +342,14 @@ class RuntimeClass(DelayedDeclerations):
288
342
  def __hash__(self) -> int:
289
343
  return hash((id(self.__egg_decls_thunk__), self.__egg_tp__))
290
344
 
345
+ def __eq__(self, other: object) -> bool:
346
+ """
347
+ Support equality for runtime comparison of egglog classes.
348
+ """
349
+ if not isinstance(other, RuntimeClass):
350
+ return NotImplemented
351
+ return self.__egg_tp__ == other.__egg_tp__
352
+
291
353
  # Support unioning like types
292
354
  def __or__(self, value: type) -> object:
293
355
  return Union[self, value] # noqa: UP007
@@ -299,6 +361,10 @@ class RuntimeClass(DelayedDeclerations):
299
361
  """
300
362
  return tuple(inverse_resolve_type_annotation(self.__egg_decls_thunk__, tp) for tp in self.__egg_tp__.args)
301
363
 
364
+ @property
365
+ def __match_args__(self) -> tuple[str, ...]:
366
+ return self.__egg_decls__._classes[self.__egg_tp__.name].match_args
367
+
302
368
 
303
369
  @dataclass
304
370
  class RuntimeFunction(DelayedDeclerations):
@@ -306,12 +372,23 @@ class RuntimeFunction(DelayedDeclerations):
306
372
  # bound methods need to store RuntimeExpr not just TypedExprDecl, so they can mutate the expr if required on self
307
373
  __egg_bound__: JustTypeRef | RuntimeExpr | None = None
308
374
 
375
+ def __eq__(self, other: object) -> bool:
376
+ """
377
+ Support equality for runtime comparison of egglog functions.
378
+ """
379
+ if not isinstance(other, RuntimeFunction):
380
+ return NotImplemented
381
+ return self.__egg_ref__ == other.__egg_ref__ and bool(self.__egg_bound__ == other.__egg_bound__)
382
+
383
+ def __hash__(self) -> int:
384
+ return hash((self.__egg_ref__, self.__egg_bound__))
385
+
309
386
  @property
310
387
  def __egg_ref__(self) -> CallableRef:
311
388
  return self.__egg_ref_thunk__()
312
389
 
313
390
  def __call__(self, *args: object, _egg_partial_function: bool = False, **kwargs: object) -> RuntimeExpr | None:
314
- from .conversion import resolve_literal
391
+ from .conversion import resolve_literal # noqa: PLC0415
315
392
 
316
393
  if isinstance(self.__egg_bound__, RuntimeExpr):
317
394
  args = (self.__egg_bound__, *args)
@@ -346,7 +423,7 @@ class RuntimeFunction(DelayedDeclerations):
346
423
  try:
347
424
  bound = py_signature.bind(*args, **kwargs)
348
425
  except TypeError as err:
349
- raise TypeError(f"Failed to call {self} with args {args} and kwargs {kwargs}") from err
426
+ raise TypeError(f"Failed to bind arguments for {self} with args {args} and kwargs {kwargs}: {err}") from err
350
427
  del kwargs
351
428
  bound.apply_defaults()
352
429
  assert not bound.kwargs
@@ -415,9 +492,7 @@ def to_py_signature(sig: FunctionSignature, decls: Declarations, optional_args:
415
492
  Parameter(
416
493
  n,
417
494
  Parameter.POSITIONAL_OR_KEYWORD,
418
- default=RuntimeExpr.__from_values__(
419
- decls, TypedExprDecl(t.to_just(), d if d is not None else VarDecl(n, True))
420
- )
495
+ default=RuntimeExpr.__from_values__(decls, TypedExprDecl(t.to_just(), d or LetRefDecl(n)))
421
496
  if d is not None or optional_args
422
497
  else Parameter.empty,
423
498
  )
@@ -428,32 +503,6 @@ def to_py_signature(sig: FunctionSignature, decls: Declarations, optional_args:
428
503
  return Signature(parameters)
429
504
 
430
505
 
431
- # All methods which should return NotImplemented if they fail to resolve
432
- # From https://docs.python.org/3/reference/datamodel.html
433
- PARTIAL_METHODS = {
434
- "__lt__",
435
- "__le__",
436
- "__eq__",
437
- "__ne__",
438
- "__gt__",
439
- "__ge__",
440
- "__add__",
441
- "__sub__",
442
- "__mul__",
443
- "__matmul__",
444
- "__truediv__",
445
- "__floordiv__",
446
- "__mod__",
447
- "__divmod__",
448
- "__pow__",
449
- "__lshift__",
450
- "__rshift__",
451
- "__and__",
452
- "__xor__",
453
- "__or__",
454
- }
455
-
456
-
457
506
  @dataclass
458
507
  class RuntimeExpr(DelayedDeclerations):
459
508
  __egg_typed_expr_thunk__: Callable[[], TypedExprDecl]
@@ -470,17 +519,14 @@ class RuntimeExpr(DelayedDeclerations):
470
519
  return self.__egg_typed_expr_thunk__()
471
520
 
472
521
  def __getattr__(self, name: str) -> RuntimeFunction | RuntimeExpr | Callable | None:
473
- cls_name = self.__egg_class_name__
474
- class_decl = self.__egg_class_decl__
475
-
476
- if name in (preserved_methods := class_decl.preserved_methods):
477
- return preserved_methods[name].__get__(self)
478
-
479
- if name in class_decl.methods:
480
- return RuntimeFunction(self.__egg_decls_thunk__, Thunk.value(MethodRef(cls_name, name)), self)
481
- if name in class_decl.properties:
482
- return RuntimeFunction(self.__egg_decls_thunk__, Thunk.value(PropertyRef(cls_name, name)), self)()
483
- raise AttributeError(f"{cls_name} has no method {name}") from None
522
+ if (method := _get_expr_method(self, name)) is not None:
523
+ return method
524
+ if name in self.__egg_class_decl__.properties:
525
+ fn = RuntimeFunction(
526
+ self.__egg_decls_thunk__, Thunk.value(PropertyRef(self.__egg_class_name__, name)), self
527
+ )
528
+ return fn()
529
+ raise AttributeError(f"{self.__egg_class_name__} has no method {name}") from None
484
530
 
485
531
  def __repr__(self) -> str:
486
532
  """
@@ -495,7 +541,7 @@ class RuntimeExpr(DelayedDeclerations):
495
541
  return pretty_decl(self.__egg_decls__, self.__egg_typed_expr__.expr, wrapping_fn=wrapping_fn)
496
542
 
497
543
  def _ipython_display_(self) -> None:
498
- from IPython.display import Code, display
544
+ from IPython.display import Code, display # noqa: PLC0415
499
545
 
500
546
  display(Code(str(self), language="python"))
501
547
 
@@ -511,13 +557,6 @@ class RuntimeExpr(DelayedDeclerations):
511
557
  def __egg_class_decl__(self) -> ClassDecl:
512
558
  return self.__egg_decls__.get_class_decl(self.__egg_class_name__)
513
559
 
514
- # These both will be overriden below in the special methods section, but add these here for type hinting purposes
515
- def __eq__(self, other: object) -> Fact: # type: ignore[override, empty-body]
516
- ...
517
-
518
- def __ne__(self, other: object) -> RuntimeExpr: # type: ignore[override, empty-body]
519
- ...
520
-
521
560
  # Implement these so that copy() works on this object
522
561
  # otherwise copy will try to call `__getstate__` before object is initialized with properties which will cause inifinite recursion
523
562
 
@@ -529,91 +568,102 @@ class RuntimeExpr(DelayedDeclerations):
529
568
  self.__egg_typed_expr_thunk__ = Thunk.value(d[1])
530
569
 
531
570
  def __hash__(self) -> int:
571
+ if (method := _get_expr_method(self, "__hash__")) is not None:
572
+ return cast("int", cast("Any", method()))
532
573
  return hash(self.__egg_typed_expr__)
533
574
 
575
+ # Implement this directly to special case behavior where it transforms to an egraph equality, if it is not a
576
+ # preserved method or defined on the class
577
+ def __eq__(self, other: object) -> object: # type: ignore[override]
578
+ if (method := _get_expr_method(self, "__eq__")) is not None:
579
+ return method(other)
534
580
 
535
- # Define each of the special methods, since we have already declared them for pretty printing
536
- for name in list(BINARY_METHODS) + list(UNARY_METHODS) + ["__getitem__", "__call__", "__setitem__", "__delitem__"]:
581
+ # TODO: Check if two objects can be upcasted to be the same. If not, then return NotImplemented so other
582
+ # expr gets a chance to resolve __eq__ which could be a preserved method.
583
+ from .egraph import BaseExpr, eq # noqa: PLC0415
537
584
 
538
- def _special_method(
539
- self: RuntimeExpr,
540
- *args: object,
541
- __name: str = name,
542
- **kwargs: object,
543
- ) -> RuntimeExpr | Fact | None:
544
- from .conversion import ConvertError
585
+ return eq(cast("BaseExpr", self)).to(cast("BaseExpr", other))
545
586
 
546
- class_name = self.__egg_class_name__
547
- class_decl = self.__egg_class_decl__
548
- # First, try to resolve as preserved method
549
- try:
550
- method = class_decl.preserved_methods[__name]
551
- except KeyError:
552
- pass
553
- else:
554
- return method(self, *args, **kwargs)
555
- # If this is a "partial" method meaning that it can return NotImplemented,
556
- # we want to find the "best" superparent (lowest cost) of the arg types to call with it, instead of just
557
- # using the arg type of the self arg.
558
- # This is neccesary so if we add like an int to a ndarray, it will upcast the int to an ndarray, instead of vice versa.
559
- if __name in PARTIAL_METHODS:
560
- try:
561
- return call_method_min_conversion(self, args[0], __name)
562
- except ConvertError:
563
- # Defer raising not imeplemented in case the dunder method is not symmetrical, then
564
- # we use the standard process
565
- pass
566
- if __name in class_decl.methods:
567
- fn = RuntimeFunction(self.__egg_decls_thunk__, Thunk.value(MethodRef(class_name, __name)), self)
568
- return fn(*args, **kwargs) # type: ignore[arg-type]
569
- # Handle == and != fallbacks to eq and ne helpers if the methods aren't defined on the class explicitly.
570
- if __name == "__eq__":
571
- from .egraph import BaseExpr, eq
572
-
573
- return eq(cast("BaseExpr", self)).to(cast("BaseExpr", args[0]))
574
- if __name == "__ne__":
575
- from .egraph import BaseExpr, ne
576
-
577
- return cast("RuntimeExpr", ne(cast("BaseExpr", self)).to(cast("BaseExpr", args[0])))
578
-
579
- if __name in PARTIAL_METHODS:
580
- return NotImplemented
581
- raise TypeError(f"{class_name!r} object does not support {__name}")
587
+ def __ne__(self, other: object) -> object: # type: ignore[override]
588
+ if (method := _get_expr_method(self, "__ne__")) is not None:
589
+ return method(other)
582
590
 
583
- setattr(RuntimeExpr, name, _special_method)
591
+ from .egraph import BaseExpr, ne # noqa: PLC0415
584
592
 
585
- # For each of the reflected binary methods, translate to the corresponding non-reflected method
586
- for reflected, non_reflected in REFLECTED_BINARY_METHODS.items():
593
+ return ne(cast("BaseExpr", self)).to(cast("BaseExpr", other))
587
594
 
588
- def _reflected_method(self: RuntimeExpr, other: object, __non_reflected: str = non_reflected) -> RuntimeExpr | None:
589
- # All binary methods are also "partial" meaning we should try to upcast first.
590
- return call_method_min_conversion(other, self, __non_reflected)
595
+ def __call__(
596
+ self, *args: object, **kwargs: object
597
+ ) -> object: # define it here only for type checking, it will be overriden below
598
+ ...
599
+
600
+
601
+ def _get_expr_method(expr: RuntimeExpr, name: str) -> RuntimeFunction | RuntimeExpr | Callable | None:
602
+ if name in (preserved_methods := expr.__egg_class_decl__.preserved_methods):
603
+ return preserved_methods[name].__get__(expr)
591
604
 
592
- setattr(RuntimeExpr, reflected, _reflected_method)
605
+ if name in expr.__egg_class_decl__.methods:
606
+ return RuntimeFunction(expr.__egg_decls_thunk__, Thunk.value(MethodRef(expr.__egg_class_name__, name)), expr)
607
+ return None
593
608
 
594
609
 
595
- def call_method_min_conversion(slf: object, other: object, name: str) -> RuntimeExpr | None:
596
- from .conversion import min_convertable_tp, resolve_literal
610
+ def define_expr_method(name: str) -> None:
611
+ """
612
+ Given the name of a method, explicitly defines it on the runtime type that holds `Expr` objects as a method.
597
613
 
598
- # find a minimum type that both can be converted to
599
- # This is so so that calls like `-0.1 * Int("x")` work by upcasting both to floats.
600
- min_tp = min_convertable_tp(slf, other, name).to_var()
601
- slf = resolve_literal(min_tp, slf)
602
- other = resolve_literal(min_tp, other)
603
- method = RuntimeFunction(slf.__egg_decls_thunk__, Thunk.value(MethodRef(slf.__egg_class_name__, name)), slf)
604
- return method(other)
614
+ Call this if you need a method to be defined on the type itself where overrindg with `__getattr__` does not suffice,
615
+ like for NumPy's `__array_ufunc__`.
616
+ """
605
617
 
618
+ def _defined_method(self: RuntimeExpr, *args, __name: str = name, **kwargs):
619
+ fn = _get_expr_method(self, __name)
620
+ if fn is None:
621
+ raise TypeError(f"{self.__egg_class_name__} expression has no method {__name}")
622
+ return fn(*args, **kwargs)
606
623
 
607
- for name in PRESERVED_METHODS:
624
+ setattr(RuntimeExpr, name, _defined_method)
608
625
 
609
- def _preserved_method(self: RuntimeExpr, __name: str = name):
610
- try:
611
- method = self.__egg_decls__.get_class_decl(self.__egg_typed_expr__.tp.name).preserved_methods[__name]
612
- except KeyError as e:
613
- raise TypeError(f"{self.__egg_typed_expr__.tp.name} has no method {__name}") from e
614
- return method(self)
615
626
 
616
- setattr(RuntimeExpr, name, _preserved_method)
627
+ for name in TYPE_DEFINED_METHODS:
628
+ define_expr_method(name)
629
+
630
+
631
+ for name, r_method in itertools.product(NUMERIC_BINARY_METHODS, (False, True)):
632
+
633
+ def _numeric_binary_method(self: object, other: object, name: str = name, r_method: bool = r_method) -> object:
634
+ """
635
+ Implements numeric binary operations.
636
+
637
+ Tries to find the minimum cost conversion of either the LHS or the RHS, by finding all methods with either
638
+ the LHS or the RHS as exactly the right type and then upcasting the other to that type.
639
+ """
640
+ # 1. switch if reversed method
641
+ if r_method:
642
+ self, other = other, self
643
+ # If the types don't exactly match to start, then we need to try converting one of them, by finding the cheapest conversion
644
+ if not (
645
+ isinstance(self, RuntimeExpr)
646
+ and isinstance(other, RuntimeExpr)
647
+ and (
648
+ self.__egg_decls__.check_binary_method_with_types(
649
+ name, self.__egg_typed_expr__.tp, other.__egg_typed_expr__.tp
650
+ )
651
+ )
652
+ ):
653
+ from .conversion import min_binary_conversion, resolve_type # noqa: PLC0415
654
+
655
+ best_method = min_binary_conversion(name, resolve_type(self), resolve_type(other))
656
+
657
+ if not best_method:
658
+ raise RuntimeError(f"Cannot resolve {name} for {self} and {other}, no conversion found")
659
+ self, other = best_method[0](self), best_method[1](other)
660
+
661
+ method_ref = MethodRef(self.__egg_class_name__, name)
662
+ fn = RuntimeFunction(Thunk.value(self.__egg_decls__), Thunk.value(method_ref), self)
663
+ return fn(other)
664
+
665
+ method_name = f"__r{name[2:]}" if r_method else name
666
+ setattr(RuntimeExpr, method_name, _numeric_binary_method)
617
667
 
618
668
 
619
669
  def resolve_callable(callable: object) -> tuple[CallableRef, Declarations]:
egglog/version_compat.py CHANGED
@@ -30,7 +30,7 @@ if BEFORE_3_11:
30
30
 
31
31
  _collect_type_vars((T, List[S, T])) == (T, S)
32
32
  """
33
- from .runtime import RuntimeClass
33
+ from .runtime import RuntimeClass # noqa: PLC0415
34
34
 
35
35
  if typevar_types is None:
36
36
  typevar_types = typing.TypeVar
@@ -39,7 +39,7 @@ if BEFORE_3_11:
39
39
  if isinstance(t, typevar_types) and t not in tvars:
40
40
  tvars.append(t)
41
41
  # **MONKEYPATCH CHANGE HERE TO ADD RuntimeClass**
42
- if isinstance(t, (typing._GenericAlias, typing.GenericAlias, types.UnionType, RuntimeClass)): # type: ignore[name-defined]
42
+ if isinstance(t, (typing._GenericAlias, typing.GenericAlias, types.UnionType, RuntimeClass)):
43
43
  tvars.extend([t for t in t.__parameters__ if t not in tvars])
44
44
  return tuple(tvars)
45
45
 
@@ -48,7 +48,7 @@ if BEFORE_3_11:
48
48
  @typing.no_type_check
49
49
  @typing._tp_cache
50
50
  def __getitem__monkeypatch(self, params): # noqa: C901, PLR0912
51
- from .runtime import RuntimeClass
51
+ from .runtime import RuntimeClass # noqa: PLC0415
52
52
 
53
53
  if self.__origin__ in (typing.Generic, typing.Protocol):
54
54
  # Can't subscript Generic[...] or Protocol[...].
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: egglog
3
- Version: 10.0.2
3
+ Version: 11.1.0
4
4
  Classifier: Environment :: MacOS X
5
5
  Classifier: Environment :: Win32 (MS Windows)
6
6
  Classifier: Intended Audience :: Developers