egglog 9.0.0__cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of egglog might be problematic. Click here for more details.

egglog/runtime.py ADDED
@@ -0,0 +1,633 @@
1
+ """
2
+ Holds a number of types which are only used at runtime to emulate Python objects.
3
+
4
+ Users will not import anything from this module, and statically they won't know these are the types they are using.
5
+
6
+ But at runtime they will be exposed.
7
+
8
+ Note that all their internal fields are prefixed with __egg_ to avoid name collisions with user code, but will end in __
9
+ so they are not mangled by Python and can be accessed by the user.
10
+ """
11
+
12
+ from __future__ import annotations
13
+
14
+ import operator
15
+ from collections.abc import Callable
16
+ from dataclasses import dataclass, replace
17
+ from inspect import Parameter, Signature
18
+ from itertools import zip_longest
19
+ from typing import TYPE_CHECKING, TypeVar, Union, cast, get_args, get_origin
20
+
21
+ from .declarations import *
22
+ from .pretty import *
23
+ from .thunk import Thunk
24
+ from .type_constraint_solver import *
25
+
26
+ if TYPE_CHECKING:
27
+ from collections.abc import Iterable
28
+
29
+ from .egraph import Fact
30
+
31
+
32
# Names exported by a star-import of this module.
__all__ = [
    "LIT_CLASS_NAMES",
    "REFLECTED_BINARY_METHODS",
    "RuntimeClass",
    "RuntimeExpr",
    "RuntimeFunction",
    "resolve_callable",
    "resolve_type_annotation",
    "resolve_type_annotation_mutate",
]
42
+
43
+
44
# Class whose constructor takes no arguments and yields the unit literal.
UNIT_CLASS_NAME = "Unit"
# Literal classes whose constructor takes exactly one Python value.
UNARY_LIT_CLASS_NAMES = {"i64", "f64", "Bool", "String"}
# Every class name that is constructed as a literal rather than through an `__init__` call.
LIT_CLASS_NAMES = {*UNARY_LIT_CLASS_NAMES, UNIT_CLASS_NAME, "PyObject"}
47
+
48
# Maps every reflected (right-hand operand) binary dunder to its regular counterpart,
# e.g. "__radd__" -> "__add__". Insertion order follows the operator list below.
REFLECTED_BINARY_METHODS = {
    f"__r{op}__": f"__{op}__"
    for op in (
        "add",
        "sub",
        "mul",
        "matmul",
        "truediv",
        "floordiv",
        "mod",
        "pow",
        "lshift",
        "rshift",
        "and",
        "xor",
        "or",
    )
}
63
+
64
# Methods that need to return real Python values not expressions.
# Each name is listed exactly once: the module installs one forwarding method on RuntimeExpr per entry,
# so a duplicate entry would only repeat that work.
PRESERVED_METHODS = [
    "__bool__",
    "__len__",
    "__complex__",
    "__int__",
    "__float__",
    "__iter__",
    "__index__",
]
76
+
77
# Set this globally so we can get access to PyObject when we have a type annotation of just object.
# This is the only time a type annotation doesn't need to include the egglog type b/c object is top so that would be redundant statically.
# Assigned from RuntimeClass.__post_init__ whenever a RuntimeClass named "PyObject" is created.
_PY_OBJECT_CLASS: RuntimeClass | None = None
# Same for functions: assigned from RuntimeClass.__post_init__ when an "UnstableFn" class without type args is created.
_UNSTABLE_FN_CLASS: RuntimeClass | None = None

T = TypeVar("T")
84
+
85
+
86
def resolve_type_annotation_mutate(decls: Declarations, tp: object) -> TypeOrVarRef:
    """
    Resolve `tp` with `resolve_type_annotation` and merge any declarations it yields into `decls`.

    Internal helper for situations where updating an existing `Declarations` is more ergonomic than
    handling the returned (declarations, type reference) pair.
    """
    extra_decls, type_ref = resolve_type_annotation(tp)
    decls |= extra_decls
    return type_ref
93
+
94
+
95
def resolve_type_annotation(tp: object) -> tuple[DeclerationsLike, TypeOrVarRef]:
    """
    Turn a Python type annotation object into an egglog type reference.

    The declarations carried by any runtime class found in the annotation are returned alongside the
    reference (unresolved), so callers decide whether and when to resolve them.

    Supported annotations, checked in this order: a `TypeVar`, a `Union` (only its first member is used),
    `object` (treated as PyObject), a `Callable` (treated as UnstableFn) and a `RuntimeClass`.
    Anything else raises `TypeError`.
    """
    if isinstance(tp, TypeVar):
        return None, ClassTypeVarRef.from_type_var(tp)

    origin = get_origin(tp)

    # For a union, the first member is the real type; the others are types convertible to it.
    if origin == Union:
        primary, *_convertible = get_args(tp)
        return resolve_type_annotation(primary)

    # A bare `object` annotation means "anything convertible to a PyObject".
    if tp is object:
        assert _PY_OBJECT_CLASS
        return resolve_type_annotation(_PY_OBJECT_CLASS)

    # A `Callable` annotation becomes an UnstableFn, parameterized as (return type, *argument types).
    if origin == Callable:
        assert _UNSTABLE_FN_CLASS
        param_tps, return_tp = get_args(tp)
        return resolve_type_annotation(_UNSTABLE_FN_CLASS[(return_tp, *param_tps)])

    if isinstance(tp, RuntimeClass):
        return tp, tp.__egg_tp__

    raise TypeError(f"Unexpected type annotation {tp}")
121
+
122
+
123
def inverse_resolve_type_annotation(decls_thunk: Callable[[], Declarations], tp: TypeOrVarRef) -> object:
    """
    Convert a type reference back into an annotation object: the inverse of resolve_type_annotation.

    A type variable reference becomes its `TypeVar`; any other reference becomes a `RuntimeClass`
    built from `decls_thunk`.
    """
    return tp.to_type_var() if isinstance(tp, ClassTypeVarRef) else RuntimeClass(decls_thunk, tp)
130
+
131
+
132
+ ##
133
+ # Runtime objects
134
+ ##
135
+
136
+
137
@dataclass
class RuntimeClass(DelayedDeclerations):
    """
    Runtime object standing in for an egglog class, optionally with type arguments bound.

    Calling it builds an expression, subscripting it binds type arguments, and attribute access resolves
    preserved methods, class variables, class methods, properties and methods from the class declaration.
    """

    # Name of the class together with its (possibly still generic) type arguments.
    __egg_tp__: TypeRefWithVars

    def __post_init__(self) -> None:
        """
        Record this instance in the module globals when it is the PyObject or the unparameterized UnstableFn class.

        resolve_type_annotation relies on those globals to handle bare `object` and `Callable` annotations.
        """
        global _PY_OBJECT_CLASS, _UNSTABLE_FN_CLASS
        if (name := self.__egg_tp__.name) == "PyObject":
            _PY_OBJECT_CLASS = self
        elif name == "UnstableFn" and not self.__egg_tp__.args:
            _UNSTABLE_FN_CLASS = self

    def verify(self) -> None:
        """
        Check that the number of bound type arguments matches the number of declared type variables.

        Does nothing when no type arguments are bound. Raises ValueError on a count mismatch.
        """
        if not self.__egg_tp__.args:
            return

        # Raise error if we have args, but they are the wrong number
        desired_args = self.__egg_decls__.get_class_decl(self.__egg_tp__.name).type_vars
        n_given = len(self.__egg_tp__.args)
        if n_given != len(desired_args):
            # Report the expected *count*, so both halves of the message are comparable numbers.
            raise ValueError(f"Expected {len(desired_args)} type args, got {n_given}")

    def __call__(self, *args: object, **kwargs: object) -> RuntimeExpr | None:
        """
        Create an instance of this kind by calling the __init__ classmethod

        PyObject, UnstableFn, the unary literal classes and Unit are special cased and build their
        expression directly instead of going through `__init__`.
        """
        # If this is a literal type, initializing it with a literal should return a literal
        if (name := self.__egg_tp__.name) == "PyObject":
            assert len(args) == 1
            return RuntimeExpr(
                self.__egg_decls_thunk__, Thunk.value(TypedExprDecl(self.__egg_tp__.to_just(), PyObjectDecl(args[0])))
            )
        if name == "UnstableFn":
            assert not kwargs
            fn_arg, *partial_args = args
            del args
            # Assumes we don't have types set for UnstableFn w/ generics, that they have to be inferred

            # 1. Call it with the partial args, and use untyped vars for the rest of the args
            res = cast("Callable", fn_arg)(*partial_args, _egg_partial_function=True)
            assert res is not None, "Mutable partial functions not supported"
            # 2. Use the inferred return type and inferred rest arg types as the types of the function, and
            #    the partially applied args as the args.
            call = (res_typed_expr := res.__egg_typed_expr__).expr
            return_tp = res_typed_expr.tp
            assert isinstance(call, CallDecl), "partial function must be a call"
            n_args = len(partial_args)
            value = PartialCallDecl(replace(call, args=call.args[:n_args]))
            remaining_arg_types = [a.tp for a in call.args[n_args:]]
            type_ref = JustTypeRef("UnstableFn", (return_tp, *remaining_arg_types))
            return RuntimeExpr.__from_values__(Declarations.create(self, res), TypedExprDecl(type_ref, value))

        if name in UNARY_LIT_CLASS_NAMES:
            assert len(args) == 1
            assert isinstance(args[0], int | float | str | bool)
            return RuntimeExpr(
                self.__egg_decls_thunk__, Thunk.value(TypedExprDecl(self.__egg_tp__.to_just(), LitDecl(args[0])))
            )
        if name == UNIT_CLASS_NAME:
            assert len(args) == 0
            return RuntimeExpr(
                self.__egg_decls_thunk__, Thunk.value(TypedExprDecl(self.__egg_tp__.to_just(), LitDecl(None)))
            )
        # Every other class is constructed through its `__init__`, bound to this (possibly parameterized) type.
        fn = RuntimeFunction(self.__egg_decls_thunk__, Thunk.value(InitRef(name)), self.__egg_tp__.to_just())
        return fn(*args, **kwargs)  # type: ignore[arg-type]

    def __dir__(self) -> list[str]:
        """
        List class methods, class variables and preserved methods, presenting `__init__` as `__call__`.
        """
        cls_decl = self.__egg_decls__.get_class_decl(self.__egg_tp__.name)
        possible_methods = (
            list(cls_decl.class_methods) + list(cls_decl.class_variables) + list(cls_decl.preserved_methods)
        )
        if "__init__" in possible_methods:
            possible_methods.remove("__init__")
            possible_methods.append("__call__")
        return possible_methods

    def __getitem__(self, args: object) -> RuntimeClass:
        """
        Bind type arguments, returning a new RuntimeClass.

        If some arguments are already bound, the new arguments replace the existing type variables from
        left to right; the number of new arguments must then equal the number of remaining type variables,
        otherwise TypeError is raised.
        """
        if not isinstance(args, tuple):
            args = (args,)
        # defer resolving decls so that we can do generic instantiation for converters before all
        # method types are defined.
        decls_like, new_args = cast(
            "tuple[tuple[DeclerationsLike, ...], tuple[TypeOrVarRef, ...]]",
            zip(*(resolve_type_annotation(arg) for arg in args), strict=False),
        )
        # if we already have some args bound and some not, then we should replace all existing args of typevars with new
        # args
        if old_args := self.__egg_tp__.args:
            is_typevar = [isinstance(arg, ClassTypeVarRef) for arg in old_args]
            if sum(is_typevar) != len(new_args):
                raise TypeError(f"Expected {sum(is_typevar)} typevars, got {len(new_args)}")
            # The count check above guarantees the iterator holds exactly one new arg per typevar slot.
            pending_args = iter(new_args)
            final_args = tuple(
                next(pending_args) if is_var else old_arg
                for is_var, old_arg in zip(is_typevar, old_args, strict=True)
            )
        else:
            final_args = new_args
        tp = TypeRefWithVars(self.__egg_tp__.name, final_args)
        return RuntimeClass(Thunk.fn(Declarations.create, self, *decls_like), tp)

    def __getattr__(self, name: str) -> RuntimeFunction | RuntimeExpr | Callable:
        """
        Resolve `name` against the class declaration.

        Lookup order: preserved methods, class variables, class methods, properties, methods.
        Raises AttributeError when nothing matches.
        """
        if name == "__origin__" and self.__egg_tp__.args:
            return RuntimeClass(self.__egg_decls_thunk__, TypeRefWithVars(self.__egg_tp__.name))

        # Special case some names that don't exist so we can exit early without resolving decls
        # Important so if we take union of RuntimeClass it won't try to resolve decls
        if name in {
            "__typing_subst__",
            "__parameters__",
            # Origin is used in get_type_hints which is used when resolving the class itself
            "__origin__",
            "__typing_unpacked_tuple_args__",
            "__typing_is_unpacked_typevartuple__",
        }:
            raise AttributeError

        try:
            cls_decl = self.__egg_decls__._classes[self.__egg_tp__.name]
        except Exception as e:
            e.add_note(f"Error processing class {self.__egg_tp__.name}")
            raise

        preserved_methods = cls_decl.preserved_methods
        if name in preserved_methods:
            return preserved_methods[name].__get__(self)

        # if this is a class variable, return an expr for it, otherwise, assume it's a method
        if name in cls_decl.class_variables:
            return_tp = cls_decl.class_variables[name]
            return RuntimeExpr(
                self.__egg_decls_thunk__,
                Thunk.value(TypedExprDecl(return_tp.type_ref, CallDecl(ClassVariableRef(self.__egg_tp__.name, name)))),
            )
        if name in cls_decl.class_methods:
            return RuntimeFunction(
                self.__egg_decls_thunk__,
                Thunk.value(ClassMethodRef(self.__egg_tp__.name, name)),
                self.__egg_tp__.to_just(),
            )
        # allow referencing properties and methods as class variables as well
        if name in cls_decl.properties:
            return RuntimeFunction(self.__egg_decls_thunk__, Thunk.value(PropertyRef(self.__egg_tp__.name, name)))
        if name in cls_decl.methods:
            return RuntimeFunction(self.__egg_decls_thunk__, Thunk.value(MethodRef(self.__egg_tp__.name, name)))

        msg = f"Class {self.__egg_tp__.name} has no method {name}"
        raise AttributeError(msg) from None

    def __str__(self) -> str:
        """Return the string form of the underlying type reference."""
        return str(self.__egg_tp__)

    # Make hashable so can go in Union
    def __hash__(self) -> int:
        # The declarations thunk is hashed by identity, the type reference by value.
        return hash((id(self.__egg_decls_thunk__), self.__egg_tp__))

    # Support unioning like types
    def __or__(self, value: type) -> object:
        return Union[self, value]  # noqa: UP007

    @property
    def __parameters__(self) -> tuple[object, ...]:
        """
        Emit a number of typevar params so that when using generic type aliases, we know how to resolve these properly.
        """
        return tuple(inverse_resolve_type_annotation(self.__egg_decls_thunk__, tp) for tp in self.__egg_tp__.args)
298
+
299
+
300
@dataclass
class RuntimeFunction(DelayedDeclerations):
    """
    Runtime object standing in for an egglog callable, optionally bound to a type or to an expression.
    """

    # Zero-argument callable producing the reference to the underlying callable.
    __egg_ref_thunk__: Callable[[], CallableRef]
    # bound methods need to store RuntimeExpr not just TypedExprDecl, so they can mutate the expr if required on self
    __egg_bound__: JustTypeRef | RuntimeExpr | None = None

    @property
    def __egg_ref__(self) -> CallableRef:
        """The callable reference, obtained by evaluating the thunk."""
        return self.__egg_ref_thunk__()

    def __call__(self, *args: object, _egg_partial_function: bool = False, **kwargs: object) -> RuntimeExpr | None:
        """
        Call the function with the given arguments and build the resulting expression.

        When bound to an expression, that expression is passed as the first argument. If
        `_egg_partial_function` is true, every parameter is treated as optional when binding (see
        `to_py_signature`). Returns None when the signature has no return type; in that case the first
        argument is updated in place to hold the resulting call.
        """
        from .conversion import resolve_literal

        if isinstance(self.__egg_bound__, RuntimeExpr):
            args = (self.__egg_bound__, *args)
        try:
            signature = self.__egg_decls__.get_callable_decl(self.__egg_ref__).signature
        except Exception as e:
            e.add_note(f"Failed to find callable {self}")
            raise
        # Work on a copy so declarations gathered from the arguments don't leak into this function's own decls.
        decls = self.__egg_decls__.copy()
        # Special case function application bc we dont support variadic generics yet generally
        if signature == "fn-app":
            # The first argument is the function value being applied; the rest are its arguments.
            fn, *rest_args = args
            args = tuple(rest_args)
            assert not kwargs
            assert isinstance(fn, RuntimeExpr)
            decls.update(fn)
            function_value = fn.__egg_typed_expr__
            fn_tp = function_value.tp
            assert fn_tp.name == "UnstableFn"
            # UnstableFn type args are laid out as (return type, *argument types).
            fn_return_tp, *fn_arg_tps = fn_tp.args
            # Synthesize a signature with positional names _0, _1, ... and no defaults.
            signature = FunctionSignature(
                tuple(tp.to_var() for tp in fn_arg_tps),
                tuple(f"_{i}" for i in range(len(fn_arg_tps))),
                (None,) * len(fn_arg_tps),
                fn_return_tp.to_var(),
            )
        else:
            function_value = None
            assert isinstance(signature, FunctionSignature)

        # Turn all keyword args into positional args
        py_signature = to_py_signature(signature, self.__egg_decls__, _egg_partial_function)
        try:
            bound = py_signature.bind(*args, **kwargs)
        except TypeError as err:
            raise TypeError(f"Failed to call {self} with args {args} and kwargs {kwargs}") from err
        del kwargs
        bound.apply_defaults()
        assert not bound.kwargs
        args = bound.args

        tcs = TypeConstraintSolver(decls)
        # The type this function is bound to: taken from the bound expression, or the bound type itself.
        bound_tp = (
            None
            if self.__egg_bound__ is None
            else self.__egg_bound__.__egg_typed_expr__.tp
            if isinstance(self.__egg_bound__, RuntimeExpr)
            else self.__egg_bound__
        )
        if (
            bound_tp
            and bound_tp.args
            # Don't bind class if we have a first class function arg, b/c we don't support that yet
            and not function_value
        ):
            tcs.bind_class(bound_tp)
        # With a var arg type there may be extra args; otherwise the counts must match exactly.
        assert (operator.ge if signature.var_arg_type else operator.eq)(len(args), len(signature.arg_types))
        cls_name = bound_tp.name if bound_tp else None
        # Args beyond the declared ones are paired with the var arg type via the zip_longest fill value.
        upcasted_args = [
            resolve_literal(cast("TypeOrVarRef", tp), arg, Thunk.value(decls), tcs=tcs, cls_name=cls_name)
            for arg, tp in zip_longest(args, signature.arg_types, fillvalue=signature.var_arg_type)
        ]
        decls.update(*upcasted_args)
        arg_exprs = tuple(arg.__egg_typed_expr__ for arg in upcasted_args)
        return_tp = tcs.substitute_typevars(signature.semantic_return_type, cls_name)
        # Only class methods and init calls carry the bound type's args on the call itself.
        bound_params = (
            cast("JustTypeRef", bound_tp).args if isinstance(self.__egg_ref__, ClassMethodRef | InitRef) else None
        )
        # If we were using unstable-app to call a function, add that function back as the first arg.
        if function_value:
            arg_exprs = (function_value, *arg_exprs)
        expr_decl = CallDecl(self.__egg_ref__, arg_exprs, bound_params)
        typed_expr_decl = TypedExprDecl(return_tp, expr_decl)
        # If there is not return type, we are mutating the first arg
        if not signature.return_type:
            first_arg = upcasted_args[0]
            first_arg.__egg_decls_thunk__ = Thunk.value(decls)
            first_arg.__egg_typed_expr_thunk__ = Thunk.value(typed_expr_decl)
            return None
        return RuntimeExpr.__from_values__(decls, typed_expr_decl)

    def __str__(self) -> str:
        """Pretty print the callable reference, including the bound first argument or bound type params if any."""
        first_arg, bound_tp_params = None, None
        match self.__egg_bound__:
            case RuntimeExpr(_):
                first_arg = self.__egg_bound__.__egg_typed_expr__.expr
            case JustTypeRef(_, args):
                bound_tp_params = args
        return pretty_callable_ref(self.__egg_decls__, self.__egg_ref__, first_arg, bound_tp_params)
401
+
402
+
403
def to_py_signature(sig: FunctionSignature, decls: Declarations, optional_args: bool) -> Signature:
    """
    Build a Python `inspect.Signature` that mirrors an egglog function signature.

    Arguments with a declared default get that default wrapped as an expression. When optional_args is
    true, arguments without a default also become optional: their default is a variable expression named
    after the argument.

    Used for partial application to try binding a function with only some of its args.
    """
    parameters: list[Parameter] = []
    for arg_name, arg_default, arg_type in zip(sig.arg_names, sig.arg_defaults, sig.arg_types, strict=True):
        if arg_default is None and not optional_args:
            # Required argument: no default at all.
            default_value = Parameter.empty
        else:
            default_decl = VarDecl(arg_name, True) if arg_default is None else arg_default
            default_value = RuntimeExpr.__from_values__(decls, TypedExprDecl(arg_type.to_just(), default_decl))
        parameters.append(Parameter(arg_name, Parameter.POSITIONAL_OR_KEYWORD, default=default_value))

    # A variadic signature additionally accepts any number of trailing positional args.
    if isinstance(sig, FunctionSignature) and sig.var_arg_type is not None:
        parameters.append(Parameter("__rest", Parameter.VAR_POSITIONAL))
    return Signature(parameters)
427
+
428
+
429
# All methods which should return NotImplemented if they fail to resolve
# From https://docs.python.org/3/reference/datamodel.html
PARTIAL_METHODS = {
    f"__{op}__"
    for op in (
        # rich comparisons
        "lt",
        "le",
        "eq",
        "ne",
        "gt",
        "ge",
        # binary arithmetic and bitwise operators
        "add",
        "sub",
        "mul",
        "matmul",
        "truediv",
        "floordiv",
        "mod",
        "divmod",
        "pow",
        "lshift",
        "rshift",
        "and",
        "xor",
        "or",
    )
}
453
+
454
+
455
@dataclass
class RuntimeExpr(DelayedDeclerations):
    """
    Runtime object standing in for an egglog expression: a lazily evaluated typed expression plus its declarations.
    """

    # Zero-argument callable producing the typed expression.
    __egg_typed_expr_thunk__: Callable[[], TypedExprDecl]

    @classmethod
    def __from_values__(cls, d: Declarations, e: TypedExprDecl) -> RuntimeExpr:
        """Build an expression from already computed declarations and typed expression."""
        return cls(Thunk.value(d), Thunk.value(e))

    def __with_expr__(self, e: TypedExprDecl) -> RuntimeExpr:
        """Return a new expression sharing this one's declarations thunk but holding `e`."""
        return RuntimeExpr(self.__egg_decls_thunk__, Thunk.value(e))

    @property
    def __egg_typed_expr__(self) -> TypedExprDecl:
        """The typed expression, obtained by evaluating the thunk."""
        return self.__egg_typed_expr_thunk__()

    def __getattr__(self, name: str) -> RuntimeFunction | RuntimeExpr | Callable | None:
        """
        Resolve `name` against the expression's class declaration.

        Lookup order: preserved methods, methods (returned bound to this expression), properties
        (called immediately, so their value is returned). Raises AttributeError when nothing matches.
        """
        cls_name = self.__egg_class_name__
        class_decl = self.__egg_class_decl__

        if name in (preserved_methods := class_decl.preserved_methods):
            return preserved_methods[name].__get__(self)

        if name in class_decl.methods:
            return RuntimeFunction(self.__egg_decls_thunk__, Thunk.value(MethodRef(cls_name, name)), self)
        if name in class_decl.properties:
            return RuntimeFunction(self.__egg_decls_thunk__, Thunk.value(PropertyRef(cls_name, name)), self)()
        raise AttributeError(f"{cls_name} has no method {name}") from None

    def __repr__(self) -> str:
        """
        The repr of the expr is the pretty printed version of the expr.
        """
        return str(self)

    def __str__(self) -> str:
        return self.__egg_pretty__(None)

    def __egg_pretty__(self, wrapping_fn: str | None) -> str:
        """Pretty print the expression, forwarding `wrapping_fn` to `pretty_decl`."""
        return pretty_decl(self.__egg_decls__, self.__egg_typed_expr__.expr, wrapping_fn=wrapping_fn)

    def _ipython_display_(self) -> None:
        """Display the pretty printed expression as Python code in IPython."""
        # Imported lazily so IPython is only needed when this display hook is actually invoked.
        from IPython.display import Code, display

        display(Code(str(self), language="python"))

    def __dir__(self) -> Iterable[str]:
        """List the methods, properties and preserved methods of the expression's class."""
        class_decl = self.__egg_class_decl__
        return list(class_decl.methods) + list(class_decl.properties) + list(class_decl.preserved_methods)

    @property
    def __egg_class_name__(self) -> str:
        """Name of the type of this expression."""
        return self.__egg_typed_expr__.tp.name

    @property
    def __egg_class_decl__(self) -> ClassDecl:
        """Class declaration for the type of this expression."""
        return self.__egg_decls__.get_class_decl(self.__egg_class_name__)

    # These both will be overridden below in the special methods section, but add these here for type hinting purposes
    def __eq__(self, other: object) -> Fact:  # type: ignore[override, empty-body]
        ...

    def __ne__(self, other: object) -> RuntimeExpr:  # type: ignore[override, empty-body]
        ...

    # Implement these so that copy() works on this object
    # otherwise copy will try to call `__getstate__` before object is initialized with properties which will cause infinite recursion

    def __getstate__(self) -> tuple[Declarations, TypedExprDecl]:
        # The state is the evaluated declarations and typed expression rather than the thunks.
        return self.__egg_decls__, self.__egg_typed_expr__

    def __setstate__(self, d: tuple[Declarations, TypedExprDecl]) -> None:
        # Re-wrap the stored values in thunks that return them directly.
        self.__egg_decls_thunk__ = Thunk.value(d[0])
        self.__egg_typed_expr_thunk__ = Thunk.value(d[1])

    def __hash__(self) -> int:
        # Hash only depends on the typed expression, not on the declarations.
        return hash(self.__egg_typed_expr__)
531
+
532
+
533
# Define each of the special methods, since we have already declared them for pretty printing
for name in list(BINARY_METHODS) + list(UNARY_METHODS) + ["__getitem__", "__call__", "__setitem__", "__delitem__"]:

    def _special_method(
        self: RuntimeExpr,
        *args: object,
        # Bound as a default so each generated function keeps the name from its own loop iteration.
        __name: str = name,
        **kwargs: object,
    ) -> RuntimeExpr | Fact | None:
        """
        Dispatch a special method call on an expression.

        Tried in order: a preserved method on the class; for partial methods, a call after converting both
        operands to a common type; a method declared on the class; the generic eq/ne helpers for
        `__eq__`/`__ne__`. If nothing applies, partial methods return NotImplemented and all others
        raise TypeError.
        """
        from .conversion import ConvertError

        class_name = self.__egg_class_name__
        class_decl = self.__egg_class_decl__
        # First, try to resolve as preserved method
        try:
            method = class_decl.preserved_methods[__name]
        except KeyError:
            pass
        else:
            return method(self, *args, **kwargs)
        # If this is a "partial" method meaning that it can return NotImplemented,
        # we want to find the "best" superparent (lowest cost) of the arg types to call with it, instead of just
        # using the arg type of the self arg.
        # This is necessary so if we add like an int to a ndarray, it will upcast the int to an ndarray, instead of vice versa.
        if __name in PARTIAL_METHODS:
            try:
                return call_method_min_conversion(self, args[0], __name)
            except ConvertError:
                # Defer raising not implemented in case the dunder method is not symmetrical, then
                # we use the standard process
                pass
        if __name in class_decl.methods:
            fn = RuntimeFunction(self.__egg_decls_thunk__, Thunk.value(MethodRef(class_name, __name)), self)
            return fn(*args, **kwargs)  # type: ignore[arg-type]
        # Handle == and != fallbacks to eq and ne helpers if the methods aren't defined on the class explicitly.
        if __name == "__eq__":
            from .egraph import BaseExpr, eq

            return eq(cast("BaseExpr", self)).to(cast("BaseExpr", args[0]))
        if __name == "__ne__":
            from .egraph import BaseExpr, ne

            return cast("RuntimeExpr", ne(cast("BaseExpr", self)).to(cast("BaseExpr", args[0])))

        if __name in PARTIAL_METHODS:
            return NotImplemented
        raise TypeError(f"{class_name!r} object does not support {__name}")

    setattr(RuntimeExpr, name, _special_method)
582
+
583
# For each of the reflected binary methods, translate to the corresponding non-reflected method
for reflected, non_reflected in REFLECTED_BINARY_METHODS.items():

    # `__non_reflected` is bound as a default so each generated function keeps the name from its own loop iteration.
    def _reflected_method(self: RuntimeExpr, other: object, __non_reflected: str = non_reflected) -> RuntimeExpr | None:
        """Handle `other <op> self` by calling the regular method with the operands swapped back."""
        # All binary methods are also "partial" meaning we should try to upcast first.
        return call_method_min_conversion(other, self, __non_reflected)

    setattr(RuntimeExpr, reflected, _reflected_method)
591
+
592
+
593
def call_method_min_conversion(slf: object, other: object, name: str) -> RuntimeExpr | None:
    """
    Call method `name` on `slf` with `other`, after converting both to a common type.

    The common type is the minimum type both operands can be converted to, which is what makes calls
    like `-0.1 * Int("x")` work by upcasting both sides to floats.
    """
    from .conversion import min_convertable_tp, resolve_literal

    common_tp = min_convertable_tp(slf, other, name).to_var()
    lhs = resolve_literal(common_tp, slf)
    rhs = resolve_literal(common_tp, other)

    # Look the method up on the converted left operand's class and bind it to that operand.
    decls_thunk = lhs.__egg_decls_thunk__
    method_ref_thunk = Thunk.value(MethodRef(lhs.__egg_class_name__, name))
    bound_method = RuntimeFunction(decls_thunk, method_ref_thunk, lhs)
    return bound_method(rhs)
603
+
604
+
605
# Install a forwarding method on RuntimeExpr for every preserved method name.
for name in PRESERVED_METHODS:

    def _preserved_method(self: RuntimeExpr, __name: str = name):
        """Call the class's preserved method of this name, or raise TypeError if the class has none."""
        try:
            class_decl = self.__egg_decls__.get_class_decl(self.__egg_typed_expr__.tp.name)
            implementation = class_decl.preserved_methods[__name]
        except KeyError as e:
            raise TypeError(f"{self.__egg_typed_expr__.tp.name} has no method {__name}") from e
        else:
            return implementation(self)

    setattr(RuntimeExpr, name, _preserved_method)
615
+
616
+
617
+ def resolve_callable(callable: object) -> tuple[CallableRef, Declarations]:
618
+ """
619
+ Resolves a runtime callable into a ref
620
+ """
621
+ match callable:
622
+ case RuntimeFunction(decls, ref, _):
623
+ return ref(), decls()
624
+ case RuntimeClass(thunk, tp):
625
+ return InitRef(tp.name), thunk()
626
+ case RuntimeExpr(decl_thunk, expr_thunk):
627
+ if not isinstance((expr := expr_thunk().expr), CallDecl) or not isinstance(
628
+ expr.callable, ConstantRef | ClassVariableRef
629
+ ):
630
+ raise NotImplementedError(f"Can only turn constants or classvars into callable refs, not {expr}")
631
+ return expr.callable, decl_thunk()
632
+ case _:
633
+ raise NotImplementedError(f"Cannot turn {callable} of type {type(callable)} into a callable ref")