egglog 11.2.0__cp314-cp314-manylinux_2_17_x86_64.manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of egglog might be problematic; see the registry's release page for more details.

Files changed (46)
  1. egglog/__init__.py +13 -0
  2. egglog/bindings.cpython-314-x86_64-linux-gnu.so +0 -0
  3. egglog/bindings.pyi +734 -0
  4. egglog/builtins.py +1133 -0
  5. egglog/config.py +8 -0
  6. egglog/conversion.py +286 -0
  7. egglog/declarations.py +912 -0
  8. egglog/deconstruct.py +173 -0
  9. egglog/egraph.py +1875 -0
  10. egglog/egraph_state.py +680 -0
  11. egglog/examples/README.rst +5 -0
  12. egglog/examples/__init__.py +3 -0
  13. egglog/examples/bignum.py +32 -0
  14. egglog/examples/bool.py +38 -0
  15. egglog/examples/eqsat_basic.py +44 -0
  16. egglog/examples/fib.py +28 -0
  17. egglog/examples/higher_order_functions.py +42 -0
  18. egglog/examples/jointree.py +67 -0
  19. egglog/examples/lambda_.py +287 -0
  20. egglog/examples/matrix.py +175 -0
  21. egglog/examples/multiset.py +60 -0
  22. egglog/examples/ndarrays.py +144 -0
  23. egglog/examples/resolution.py +84 -0
  24. egglog/examples/schedule_demo.py +34 -0
  25. egglog/exp/__init__.py +3 -0
  26. egglog/exp/array_api.py +2019 -0
  27. egglog/exp/array_api_jit.py +51 -0
  28. egglog/exp/array_api_loopnest.py +74 -0
  29. egglog/exp/array_api_numba.py +69 -0
  30. egglog/exp/array_api_program_gen.py +510 -0
  31. egglog/exp/program_gen.py +425 -0
  32. egglog/exp/siu_examples.py +32 -0
  33. egglog/ipython_magic.py +41 -0
  34. egglog/pretty.py +509 -0
  35. egglog/py.typed +0 -0
  36. egglog/runtime.py +712 -0
  37. egglog/thunk.py +97 -0
  38. egglog/type_constraint_solver.py +113 -0
  39. egglog/version_compat.py +87 -0
  40. egglog/visualizer.css +1 -0
  41. egglog/visualizer.js +35777 -0
  42. egglog/visualizer_widget.py +39 -0
  43. egglog-11.2.0.dist-info/METADATA +74 -0
  44. egglog-11.2.0.dist-info/RECORD +46 -0
  45. egglog-11.2.0.dist-info/WHEEL +4 -0
  46. egglog-11.2.0.dist-info/licenses/LICENSE +21 -0
egglog/runtime.py ADDED
@@ -0,0 +1,712 @@
1
+ """
2
+ Holds a number of types which are only used at runtime to emulate Python objects.
3
+
4
+ Users will not import anything from this module, and statically they won't know these are the types they are using.
5
+
6
+ But at runtime they will be exposed.
7
+
8
+ Note that all their internal fields are prefixed with __egg_ to avoid name collisions with user code, but will end in __
9
+ so they are not mangled by Python and can be accessed by the user.
10
+ """
11
+
12
+ from __future__ import annotations
13
+
14
+ import itertools
15
+ import operator
16
+ import types
17
+ from collections.abc import Callable
18
+ from dataclasses import InitVar, dataclass, replace
19
+ from inspect import Parameter, Signature
20
+ from itertools import zip_longest
21
+ from typing import TYPE_CHECKING, Any, TypeVar, Union, cast, get_args, get_origin
22
+
23
+ from typing_extensions import assert_never
24
+
25
+ from .declarations import *
26
+ from .pretty import *
27
+ from .thunk import Thunk
28
+ from .type_constraint_solver import *
29
+ from .version_compat import *
30
+
31
+ if TYPE_CHECKING:
32
+ from collections.abc import Iterable
33
+
34
+
__all__ = [
    "LIT_CLASS_NAMES",
    "NUMERIC_BINARY_METHODS",
    "RuntimeClass",
    "RuntimeExpr",
    "RuntimeFunction",
    "create_callable",
    "define_expr_method",
    "resolve_callable",
    "resolve_type_annotation",
    "resolve_type_annotation_mutate",
]


# Name of the egglog unit class (constructed with no arguments as a `None` literal).
UNIT_CLASS_NAME = "Unit"
# Builtin classes that are constructed from exactly one Python literal value.
UNARY_LIT_CLASS_NAMES = {"i64", "f64", "Bool", "String"}
# Every builtin class whose values are literals.
LIT_CLASS_NAMES = {*UNARY_LIT_CLASS_NAMES, UNIT_CLASS_NAME, "PyObject"}

# Binary operator dunders that RuntimeExpr implements in both forward and reflected (`__r*__`) form, converting
# one operand to the other's type when needed.
# From https://docs.python.org/3/reference/datamodel.html#emulating-numeric-types
# NOTE(review): these were meant to return NotImplemented when they fail to resolve, but the implementation
# currently raises RuntimeError instead.
NUMERIC_BINARY_METHODS = {
    f"__{op}__"
    for op in (
        # arithmetic
        "add",
        "sub",
        "mul",
        "matmul",
        "truediv",
        "floordiv",
        "mod",
        "divmod",
        "pow",
        # bitwise
        "lshift",
        "rshift",
        "and",
        "xor",
        "or",
        # ordering comparisons
        "lt",
        "le",
        "gt",
        "ge",
    )
}


# Protocol methods that Python looks up on the type itself, so they must be defined on the runtime type holding
# `Expr` objects (instance-level `__getattr__` is never consulted for them).
TYPE_DEFINED_METHODS = {
    f"__{op}__"
    for op in (
        "bool",
        "len",
        "complex",
        "int",
        "float",
        "iter",
        "index",
        "call",
        "getitem",
        "setitem",
        "delitem",
        "pos",
        "neg",
        "invert",
        "round",
    )
}

# Set this globally so we can get access to PyObject when we have a type annotation of just object.
# This is the only time a type annotation doesn't need to include the egglog type b/c object is top so that would be redundant statically.
_PY_OBJECT_CLASS: RuntimeClass | None = None
# Same for functions: the bare UnstableFn class, used to resolve `Callable[...]` annotations.
_UNSTABLE_FN_CLASS: RuntimeClass | None = None

T = TypeVar("T")
105
+
106
+
def resolve_type_annotation_mutate(decls: Declarations, tp: object) -> TypeOrVarRef:
    """
    Resolve ``tp`` like `resolve_type_annotation`, folding the declarations it carries into ``decls``.

    An internal convenience wrapper for situations where accumulating declarations in place is more ergonomic.

    :param decls: Declarations that the annotation's declarations are merged into (via ``decls |= ...``).
    :param tp: The type annotation to resolve.
    :return: The resolved type or type-variable reference.
    """
    extra_decls, type_ref = resolve_type_annotation(tp)
    decls |= extra_decls
    return type_ref
114
+
115
+
def resolve_type_annotation(tp: object) -> tuple[DeclerationsLike, TypeOrVarRef]:
    """
    Resolve a Python type annotation into an egglog type reference.

    The declarations attached to a runtime class are returned alongside the reference instead of being resolved
    eagerly, so callers can defer resolving them.

    Supported annotations:

    * a ``TypeVar`` -> a class type-var reference (with no declarations);
    * ``Union[first, ...]`` -> resolved from ``first``; the other members are types convertible to it;
    * ``object`` -> the ``PyObject`` class;
    * ``Callable[[args...], ret]`` -> ``UnstableFn[ret, args...]``;
    * a ``RuntimeClass`` -> its own type reference (and the class itself as the declarations).

    :raises TypeError: If ``tp`` is none of the above.
    """
    if isinstance(tp, TypeVar):
        return None, ClassTypeVarRef.from_type_var(tp)

    origin = get_origin(tp)
    if origin == Union:
        # Only the first member is the target type; the remaining ones merely convert to it.
        head, *_ = get_args(tp)
        return resolve_type_annotation(head)

    if tp is object:
        # `object` is top, so a bare `object` annotation means "anything convertible to a PyObject".
        assert _PY_OBJECT_CLASS
        return resolve_type_annotation(_PY_OBJECT_CLASS)

    if origin == Callable:
        assert _UNSTABLE_FN_CLASS
        arg_types, return_type = get_args(tp)
        # UnstableFn puts the return type first, followed by the argument types.
        return resolve_type_annotation(_UNSTABLE_FN_CLASS[(return_type, *arg_types)])

    if isinstance(tp, RuntimeClass):
        return tp, tp.__egg_tp__

    raise TypeError(f"Unexpected type annotation {tp}")
+
143
+
def inverse_resolve_type_annotation(decls_thunk: Callable[[], Declarations], tp: TypeOrVarRef) -> object:
    """
    Turn a type reference back into a Python annotation object; the inverse of `resolve_type_annotation`.

    :param decls_thunk: Lazily produces the declarations that a resulting runtime class should use.
    :param tp: A class type-var reference or a concrete type reference.
    :return: A ``TypeVar`` for type-var references, otherwise a ``RuntimeClass`` wrapping ``tp``.
    """
    return RuntimeClass(decls_thunk, tp) if not isinstance(tp, ClassTypeVarRef) else tp.to_type_var()
151
+
152
+
153
+ ##
154
+ # Runtime objects
155
+ ##
156
+
157
+
class BaseClassFactoryMeta(type):
    """
    Common metaclass base for the runtime classes that `ClassFactory` generates.

    Its ``isinstance`` hook treats a value as an instance when it is a `RuntimeExpr` whose egglog type name equals
    the class's egglog type name.
    """

    def __instancecheck__(cls, instance: object) -> bool:
        assert isinstance(cls, RuntimeClass)
        if not isinstance(instance, RuntimeExpr):
            return False
        return cls.__egg_tp__.name == instance.__egg_typed_expr__.tp.name
166
+
167
+
class ClassFactory(type):
    """
    Metaclass whose instantiation produces brand-new ``type`` objects.

    Calling a class that uses this metaclass does the following:

    * builds a fresh metaclass (deriving from `BaseClassFactoryMeta`) that carries every attribute found along the
      class's MRO;
    * creates an empty class from that metaclass;
    * runs the original ``__init__`` on it.

    The result is a real ``type``, so it works with ``isinstance`` and can be placed in ``match ClassName()``.
    """

    def __call__(cls, *args, **kwargs) -> type:
        # If we have params, don't inherit from `type` because we don't need to match against this and also
        # this won't work with `Union[X]` because it won't look at `__parameters__` for instances of `type`.
        if kwargs.pop("_egg_has_params", False):
            return super().__call__(*args, **kwargs)

        # Flatten the MRO, base-most first so that subclasses override, into one namespace for the new metaclass.
        merged_ns: dict[str, Any] = {}
        for klass in cls.__mro__[::-1]:
            merged_ns.update(klass.__dict__)
        init_fn = merged_ns.pop("__init__")

        def _populate(ns: dict[str, Any]) -> None:
            ns.update(**merged_ns)

        generated_meta = types.new_class("type(RuntimeClass)", (BaseClassFactoryMeta,), {}, _populate)
        new_tp = types.new_class("RuntimeClass", (), {"metaclass": generated_meta})
        # Initialise the generated class itself, as if it were an instance of `cls`.
        init_fn(new_tp, *args, **kwargs)
        return new_tp

    def __instancecheck__(cls, instance: object) -> bool:
        # Instances are classes whose metaclass derives from BaseClassFactoryMeta, such as those built in __call__.
        return isinstance(instance, BaseClassFactoryMeta)
191
+
192
+
@dataclass(match_args=False)
class RuntimeClass(DelayedDeclerations, metaclass=ClassFactory):
    """
    Runtime stand-in for an egglog class, optionally with type arguments applied.

    Because of the `ClassFactory` metaclass, instantiating this normally yields a freshly generated ``type`` object,
    so it works with ``isinstance`` and ``match``. When ``_egg_has_params=True`` is passed, a plain instance is
    created instead, so that ``typing.Union`` will look at its ``__parameters__``.
    """

    # The egglog type (possibly with type-variable arguments) this class stands for.
    __egg_tp__: TypeRefWithVars
    # True if we want `__parameters__` to be recognized by `Union`, which means we can't inherit from `type` directly.
    _egg_has_params: InitVar[bool] = False

    def __post_init__(self, _egg_has_params: bool) -> None:
        # Remember the bare PyObject / UnstableFn classes module-wide, so `resolve_type_annotation` can handle plain
        # `object` and `Callable[...]` annotations.
        global _PY_OBJECT_CLASS, _UNSTABLE_FN_CLASS
        if (name := self.__egg_tp__.name) == "PyObject":
            _PY_OBJECT_CLASS = self
        elif name == "UnstableFn" and not self.__egg_tp__.args:
            _UNSTABLE_FN_CLASS = self

    def verify(self) -> None:
        """
        Check that the type arguments applied to this class match the number of type variables it declares.

        Does nothing for classes without type arguments.

        :raises ValueError: If the number of type arguments differs from the declared type variables.
        """
        if not self.__egg_tp__.args:
            return

        # Raise error if we have args, but they are the wrong number
        desired_args = self.__egg_decls__.get_class_decl(self.__egg_tp__.name).type_vars
        if len(self.__egg_tp__.args) != len(desired_args):
            raise ValueError(f"Expected {len(desired_args)} type args, got {len(self.__egg_tp__.args)}")

    def __call__(self, *args: object, **kwargs: object) -> RuntimeExpr | None:
        """
        Create an instance of this kind.

        Handling depends on the class:

        * literal classes (``PyObject``, ``i64``/``f64``/``Bool``/``String``, ``Unit``) return literal expressions;
        * ``UnstableFn`` builds a partially applied function value;
        * any other class calls its ``__init__`` classmethod.
        """
        # If this is a literal type, initializing it with a literal should return a literal
        if (name := self.__egg_tp__.name) == "PyObject":
            assert len(args) == 1
            return RuntimeExpr(
                self.__egg_decls_thunk__, Thunk.value(TypedExprDecl(self.__egg_tp__.to_just(), PyObjectDecl(args[0])))
            )
        if name == "UnstableFn":
            assert not kwargs
            fn_arg, *partial_args = args
            del args
            # Assumes we don't have types set for UnstableFn w/ generics, that they have to be inferred

            # 1. Call it with the partial args, and use untyped vars for the rest of the args
            res = cast("Callable", fn_arg)(*partial_args, _egg_partial_function=True)
            assert res is not None, "Mutable partial functions not supported"
            # 2. Use the inferred return type and inferred rest arg types as the types of the function, and
            # the partially applied args as the args.
            call = (res_typed_expr := res.__egg_typed_expr__).expr
            return_tp = res_typed_expr.tp
            assert isinstance(call, CallDecl), "partial function must be a call"
            n_args = len(partial_args)
            value = PartialCallDecl(replace(call, args=call.args[:n_args]))
            remaining_arg_types = [a.tp for a in call.args[n_args:]]
            type_ref = JustTypeRef("UnstableFn", (return_tp, *remaining_arg_types))
            return RuntimeExpr.__from_values__(Declarations.create(self, res), TypedExprDecl(type_ref, value))

        if name in UNARY_LIT_CLASS_NAMES:
            assert len(args) == 1
            assert isinstance(args[0], int | float | str | bool)
            return RuntimeExpr(
                self.__egg_decls_thunk__, Thunk.value(TypedExprDecl(self.__egg_tp__.to_just(), LitDecl(args[0])))
            )
        if name == UNIT_CLASS_NAME:
            assert len(args) == 0
            return RuntimeExpr(
                self.__egg_decls_thunk__, Thunk.value(TypedExprDecl(self.__egg_tp__.to_just(), LitDecl(None)))
            )
        fn = RuntimeFunction(self.__egg_decls_thunk__, Thunk.value(InitRef(name)), self.__egg_tp__.to_just())
        return fn(*args, **kwargs)  # type: ignore[arg-type]

    def __dir__(self) -> list[str]:
        """List class methods, class variables and preserved methods, exposing ``__init__`` as ``__call__``."""
        cls_decl = self.__egg_decls__.get_class_decl(self.__egg_tp__.name)
        possible_methods = (
            list(cls_decl.class_methods) + list(cls_decl.class_variables) + list(cls_decl.preserved_methods)
        )
        if "__init__" in possible_methods:
            possible_methods.remove("__init__")
            possible_methods.append("__call__")
        return possible_methods

    def __getitem__(self, args: object) -> RuntimeClass:
        """
        Apply type arguments to this class, e.g. ``Vec[i64]``.

        If the class already has some arguments bound, the new arguments replace only its type-variable arguments,
        in order.

        :raises TypeError: If the number of new arguments does not match the number of remaining type variables.
        """
        if not isinstance(args, tuple):
            args = (args,)
        # defer resolving decls so that we can do generic instantiation for converters before all
        # method types are defined.
        decls_like, new_args = cast(
            "tuple[tuple[DeclerationsLike, ...], tuple[TypeOrVarRef, ...]]",
            zip(*(resolve_type_annotation(arg) for arg in args), strict=False),
        )
        # if we already have some args bound and some not, then we should replace all existing args of typevars
        # with new args
        if old_args := self.__egg_tp__.args:
            is_typevar = [isinstance(arg, ClassTypeVarRef) for arg in old_args]
            if sum(is_typevar) != len(new_args):
                raise TypeError(f"Expected {sum(is_typevar)} typevars, got {len(new_args)}")
            # Counts match (checked above), so the iterator is never exhausted early.
            replacements = iter(new_args)
            final_args = tuple(
                next(replacements) if typevar else old for typevar, old in zip(is_typevar, old_args, strict=True)
            )
        else:
            final_args = new_args
        tp = TypeRefWithVars(self.__egg_tp__.name, final_args)
        return RuntimeClass(Thunk.fn(Declarations.create, self, *decls_like), tp, _egg_has_params=True)

    def __getattr__(self, name: str) -> RuntimeFunction | RuntimeExpr | Callable:
        """
        Look up a class-level member.

        Members are resolved in this order: preserved method, class variable (as an expression), class method,
        property or method (as an unbound function).

        :raises AttributeError: If the class has no such member, or for typing-introspection names that are
            deliberately reported missing without resolving declarations.
        """
        if name == "__origin__" and self.__egg_tp__.args:
            return RuntimeClass(self.__egg_decls_thunk__, TypeRefWithVars(self.__egg_tp__.name))

        # Special case some names that don't exist so we can exit early without resolving decls
        # Important so if we take union of RuntimeClass it won't try to resolve decls
        if name in {
            "__typing_subst__",
            "__parameters__",
            # Origin is used in get_type_hints which is used when resolving the class itself
            "__origin__",
            "__typing_unpacked_tuple_args__",
            "__typing_is_unpacked_typevartuple__",
        }:
            raise AttributeError

        try:
            cls_decl = self.__egg_decls__._classes[self.__egg_tp__.name]
        except Exception as e:
            raise add_note(f"Error processing class {self.__egg_tp__.name}", e) from None

        preserved_methods = cls_decl.preserved_methods
        if name in preserved_methods:
            return preserved_methods[name].__get__(self)

        # if this is a class variable, return an expr for it, otherwise, assume it's a method
        if name in cls_decl.class_variables:
            return_tp = cls_decl.class_variables[name]
            return RuntimeExpr(
                self.__egg_decls_thunk__,
                Thunk.value(TypedExprDecl(return_tp.type_ref, CallDecl(ClassVariableRef(self.__egg_tp__.name, name)))),
            )
        if name in cls_decl.class_methods:
            return RuntimeFunction(
                self.__egg_decls_thunk__,
                Thunk.value(ClassMethodRef(self.__egg_tp__.name, name)),
                self.__egg_tp__.to_just(),
            )
        # allow referencing properties and methods as class variables as well
        if name in cls_decl.properties:
            return RuntimeFunction(self.__egg_decls_thunk__, Thunk.value(PropertyRef(self.__egg_tp__.name, name)))
        if name in cls_decl.methods:
            return RuntimeFunction(self.__egg_decls_thunk__, Thunk.value(MethodRef(self.__egg_tp__.name, name)))

        msg = f"Class {self.__egg_tp__.name} has no method {name}"
        raise AttributeError(msg) from None

    def __str__(self) -> str:
        return str(self.__egg_tp__)

    def __repr__(self) -> str:
        return str(self)

    # Make hashable so can go in Union
    def __hash__(self) -> int:
        return hash(self.__egg_tp__)

    def __eq__(self, other: object) -> bool:
        """
        Support equality for runtime comparison of egglog classes.
        """
        if not isinstance(other, RuntimeClass):
            return NotImplemented
        return self.__egg_tp__ == other.__egg_tp__

    # Support unioning like types
    def __or__(self, value: type) -> object:
        return Union[self, value]  # noqa: UP007

    @property
    def __parameters__(self) -> tuple[object, ...]:
        """
        Emit a number of typevar params so that when using generic type aliases, we know how to resolve these properly.
        """
        return tuple(inverse_resolve_type_annotation(self.__egg_decls_thunk__, tp) for tp in self.__egg_tp__.args)

    @property
    def __match_args__(self) -> tuple[str, ...]:
        """Positional-pattern attribute names declared for this class, used by ``match`` class patterns."""
        return self.__egg_decls__._classes[self.__egg_tp__.name].match_args
+
371
+
372
+ @dataclass
373
+ class RuntimeFunction(DelayedDeclerations):
374
+ __egg_ref_thunk__: Callable[[], CallableRef]
375
+ # bound methods need to store RuntimeExpr not just TypedExprDecl, so they can mutate the expr if required on self
376
+ __egg_bound__: JustTypeRef | RuntimeExpr | None = None
377
+
378
+ def __eq__(self, other: object) -> bool:
379
+ """
380
+ Support equality for runtime comparison of egglog functions.
381
+ """
382
+ if not isinstance(other, RuntimeFunction):
383
+ return NotImplemented
384
+ return self.__egg_ref__ == other.__egg_ref__ and bool(self.__egg_bound__ == other.__egg_bound__)
385
+
386
+ def __hash__(self) -> int:
387
+ return hash((self.__egg_ref__, self.__egg_bound__))
388
+
389
+ @property
390
+ def __egg_ref__(self) -> CallableRef:
391
+ return self.__egg_ref_thunk__()
392
+
393
+ def __call__(self, *args: object, _egg_partial_function: bool = False, **kwargs: object) -> RuntimeExpr | None:
394
+ from .conversion import resolve_literal # noqa: PLC0415
395
+
396
+ if isinstance(self.__egg_bound__, RuntimeExpr):
397
+ args = (self.__egg_bound__, *args)
398
+ try:
399
+ signature = self.__egg_decls__.get_callable_decl(self.__egg_ref__).signature
400
+ except Exception as e:
401
+ raise add_note(f"Failed to find callable {self}", e) # noqa: B904
402
+ decls = self.__egg_decls__.copy()
403
+ # Special case function application bc we dont support variadic generics yet generally
404
+ if signature == "fn-app":
405
+ fn, *rest_args = args
406
+ args = tuple(rest_args)
407
+ assert not kwargs
408
+ assert isinstance(fn, RuntimeExpr)
409
+ decls.update(fn)
410
+ function_value = fn.__egg_typed_expr__
411
+ fn_tp = function_value.tp
412
+ assert fn_tp.name == "UnstableFn"
413
+ fn_return_tp, *fn_arg_tps = fn_tp.args
414
+ signature = FunctionSignature(
415
+ tuple(tp.to_var() for tp in fn_arg_tps),
416
+ tuple(f"_{i}" for i in range(len(fn_arg_tps))),
417
+ (None,) * len(fn_arg_tps),
418
+ fn_return_tp.to_var(),
419
+ )
420
+ else:
421
+ function_value = None
422
+ assert isinstance(signature, FunctionSignature)
423
+
424
+ # Turn all keyword args into positional args
425
+ py_signature = to_py_signature(signature, self.__egg_decls__, _egg_partial_function)
426
+ try:
427
+ bound = py_signature.bind(*args, **kwargs)
428
+ except TypeError as err:
429
+ raise TypeError(f"Failed to bind arguments for {self} with args {args} and kwargs {kwargs}: {err}") from err
430
+ del kwargs
431
+ bound.apply_defaults()
432
+ assert not bound.kwargs
433
+ args = bound.args
434
+
435
+ tcs = TypeConstraintSolver(decls)
436
+ bound_tp = (
437
+ None
438
+ if self.__egg_bound__ is None
439
+ else self.__egg_bound__.__egg_typed_expr__.tp
440
+ if isinstance(self.__egg_bound__, RuntimeExpr)
441
+ else self.__egg_bound__
442
+ )
443
+ if (
444
+ bound_tp
445
+ and bound_tp.args
446
+ # Don't bind class if we have a first class function arg, b/c we don't support that yet
447
+ and not function_value
448
+ ):
449
+ tcs.bind_class(bound_tp)
450
+ assert (operator.ge if signature.var_arg_type else operator.eq)(len(args), len(signature.arg_types))
451
+ cls_name = bound_tp.name if bound_tp else None
452
+ upcasted_args = [
453
+ resolve_literal(cast("TypeOrVarRef", tp), arg, Thunk.value(decls), tcs=tcs, cls_name=cls_name)
454
+ for arg, tp in zip_longest(args, signature.arg_types, fillvalue=signature.var_arg_type)
455
+ ]
456
+ decls.update(*upcasted_args)
457
+ arg_exprs = tuple(arg.__egg_typed_expr__ for arg in upcasted_args)
458
+ return_tp = tcs.substitute_typevars(signature.semantic_return_type, cls_name)
459
+ bound_params = (
460
+ cast("JustTypeRef", bound_tp).args if isinstance(self.__egg_ref__, ClassMethodRef | InitRef) else None
461
+ )
462
+ # If we were using unstable-app to call a funciton, add that function back as the first arg.
463
+ if function_value:
464
+ arg_exprs = (function_value, *arg_exprs)
465
+ expr_decl = CallDecl(self.__egg_ref__, arg_exprs, bound_params)
466
+ typed_expr_decl = TypedExprDecl(return_tp, expr_decl)
467
+ # If there is not return type, we are mutating the first arg
468
+ if not signature.return_type:
469
+ first_arg = upcasted_args[0]
470
+ first_arg.__egg_decls_thunk__ = Thunk.value(decls)
471
+ first_arg.__egg_typed_expr_thunk__ = Thunk.value(typed_expr_decl)
472
+ return None
473
+ return RuntimeExpr.__from_values__(decls, typed_expr_decl)
474
+
475
+ def __str__(self) -> str:
476
+ first_arg, bound_tp_params = None, None
477
+ match self.__egg_bound__:
478
+ case RuntimeExpr(_):
479
+ first_arg = self.__egg_bound__.__egg_typed_expr__.expr
480
+ case JustTypeRef(_, args):
481
+ bound_tp_params = args
482
+ return pretty_callable_ref(self.__egg_decls__, self.__egg_ref__, first_arg, bound_tp_params)
483
+
484
+ def __repr__(self) -> str:
485
+ return str(self)
486
+
487
+
def to_py_signature(sig: FunctionSignature, decls: Declarations, optional_args: bool) -> Signature:
    """
    Build an ``inspect.Signature`` equivalent to an egglog function signature.

    Every argument becomes a positional-or-keyword parameter. An argument gets a default when it has an egglog
    default, or when ``optional_args`` is true. When ``optional_args`` is true, an argument without an explicit
    default falls back to a variable named after the argument. A trailing ``*__rest`` parameter is added when the
    signature accepts variable arguments.

    Used for partial application to try binding a function with only some of its args.
    """
    parameters: list[Parameter] = []
    for arg_name, arg_default, arg_type in zip(sig.arg_names, sig.arg_defaults, sig.arg_types, strict=True):
        if arg_default is None and not optional_args:
            default = Parameter.empty
        else:
            default = RuntimeExpr.__from_values__(
                decls, TypedExprDecl(arg_type.to_just(), arg_default or LetRefDecl(arg_name))
            )
        parameters.append(Parameter(arg_name, Parameter.POSITIONAL_OR_KEYWORD, default=default))
    if isinstance(sig, FunctionSignature) and sig.var_arg_type is not None:
        parameters.append(Parameter("__rest", Parameter.VAR_POSITIONAL))
    return Signature(parameters)
510
+
511
+
@dataclass
class RuntimeExpr(DelayedDeclerations):
    """
    Runtime value standing in for an egglog expression.

    It holds its declarations and its typed expression lazily (as thunks). Attribute access, ``==``/``!=`` and
    hashing are forwarded to methods declared on the expression's egglog class when such methods exist.
    """

    # Lazily produces the typed expression this object wraps.
    __egg_typed_expr_thunk__: Callable[[], TypedExprDecl]

    @classmethod
    def __from_values__(cls, d: Declarations, e: TypedExprDecl) -> RuntimeExpr:
        """Build an expression from already-computed declarations and typed expression."""
        decls_thunk = Thunk.value(d)
        expr_thunk = Thunk.value(e)
        return cls(decls_thunk, expr_thunk)

    def __with_expr__(self, e: TypedExprDecl) -> RuntimeExpr:
        """Return a new expression sharing these declarations but wrapping ``e``."""
        return RuntimeExpr(self.__egg_decls_thunk__, Thunk.value(e))

    @property
    def __egg_typed_expr__(self) -> TypedExprDecl:
        return self.__egg_typed_expr_thunk__()

    def __getattr__(self, name: str) -> RuntimeFunction | RuntimeExpr | Callable | None:
        """
        Resolve ``name`` as a method first (preserved or declared), otherwise as a property.

        Properties are evaluated immediately.

        :raises AttributeError: If the class has neither a method nor a property called ``name``.
        """
        method = _get_expr_method(self, name)
        if method is not None:
            return method
        if name not in self.__egg_class_decl__.properties:
            raise AttributeError(f"{self.__egg_class_name__} has no method {name}") from None
        prop_fn = RuntimeFunction(
            self.__egg_decls_thunk__, Thunk.value(PropertyRef(self.__egg_class_name__, name)), self
        )
        return prop_fn()

    def __repr__(self) -> str:
        """
        Show the expression as its pretty-printed source form.
        """
        return str(self)

    def __str__(self) -> str:
        return self.__egg_pretty__(None)

    def __egg_pretty__(self, wrapping_fn: str | None) -> str:
        return pretty_decl(self.__egg_decls__, self.__egg_typed_expr__.expr, wrapping_fn=wrapping_fn)

    def _ipython_display_(self) -> None:
        from IPython.display import Code, display  # noqa: PLC0415

        display(Code(str(self), language="python"))

    def __dir__(self) -> Iterable[str]:
        decl = self.__egg_class_decl__
        return [*decl.methods, *decl.properties, *decl.preserved_methods]

    @property
    def __egg_class_name__(self) -> str:
        """Name of this expression's egglog class."""
        return self.__egg_typed_expr__.tp.name

    @property
    def __egg_class_decl__(self) -> ClassDecl:
        """Declaration of this expression's egglog class."""
        return self.__egg_decls__.get_class_decl(self.__egg_class_name__)

    # Implement these so that copy() works on this object
    # otherwise copy will try to call `__getstate__` before object is initialized with properties which will cause infinite recursion

    def __getstate__(self) -> tuple[Declarations, TypedExprDecl]:
        return self.__egg_decls__, self.__egg_typed_expr__

    def __setstate__(self, d: tuple[Declarations, TypedExprDecl]) -> None:
        decls, typed_expr = d
        self.__egg_decls_thunk__ = Thunk.value(decls)
        self.__egg_typed_expr_thunk__ = Thunk.value(typed_expr)

    def __hash__(self) -> int:
        method = _get_expr_method(self, "__hash__")
        if method is None:
            return hash(self.__egg_typed_expr__)
        return cast("int", cast("Any", method()))

    # Implement this directly to special case behavior where it transforms to an egraph equality, if it is not a
    # preserved method or defined on the class
    def __eq__(self, other: object) -> object:  # type: ignore[override]
        method = _get_expr_method(self, "__eq__")
        if method is not None:
            return method(other)

        # TODO: Check if two objects can be upcasted to be the same. If not, then return NotImplemented so other
        # expr gets a chance to resolve __eq__ which could be a preserved method.
        from .egraph import BaseExpr, eq  # noqa: PLC0415

        return eq(cast("BaseExpr", self)).to(cast("BaseExpr", other))

    def __ne__(self, other: object) -> object:  # type: ignore[override]
        method = _get_expr_method(self, "__ne__")
        if method is not None:
            return method(other)

        from .egraph import BaseExpr, ne  # noqa: PLC0415

        return ne(cast("BaseExpr", self)).to(cast("BaseExpr", other))

    def __call__(
        self, *args: object, **kwargs: object
    ) -> object:  # define it here only for type checking, it will be overridden below
        ...
605
+
606
+
607
+ def _get_expr_method(expr: RuntimeExpr, name: str) -> RuntimeFunction | RuntimeExpr | Callable | None:
608
+ if name in (preserved_methods := expr.__egg_class_decl__.preserved_methods):
609
+ return preserved_methods[name].__get__(expr)
610
+
611
+ if name in expr.__egg_class_decl__.methods:
612
+ return RuntimeFunction(expr.__egg_decls_thunk__, Thunk.value(MethodRef(expr.__egg_class_name__, name)), expr)
613
+ return None
614
+
615
+
def define_expr_method(name: str) -> None:
    """
    Install a method called ``name`` directly on `RuntimeExpr` that forwards to the expression's own method.

    Call this for methods that must exist on the type itself, where the instance-level ``__getattr__`` override does
    not suffice, like for NumPy's ``__array_ufunc__``.

    The installed method raises ``TypeError`` when the expression's class has no such method.
    """

    def _defined_method(self: RuntimeExpr, *args, __name: str = name, **kwargs):
        target = _get_expr_method(self, __name)
        if target is not None:
            return target(*args, **kwargs)
        raise TypeError(f"{self.__egg_class_name__} expression has no method {__name}")

    setattr(RuntimeExpr, name, _defined_method)
631
+
632
+
# Special-method lookups bypass instance `__getattr__`, so install forwarding implementations on the type itself.
for name in TYPE_DEFINED_METHODS:
    define_expr_method(name=name)
635
+
636
+
# Install forward and reflected implementations of every numeric binary operator on RuntimeExpr.
# NOTE(review): for the comparison dunders this also installs names like `__rlt__`, which Python never calls
# (it reflects `<` as `>` instead); they are harmless but unused.
for name in NUMERIC_BINARY_METHODS:
    for r_method in (False, True):

        def _numeric_binary_method(self: object, other: object, name: str = name, r_method: bool = r_method) -> object:
            """
            Implements numeric binary operations.

            If both operands are expressions and their types already support this method, it is called directly.
            Otherwise the cheapest conversion of either operand (from ``min_binary_conversion``) is applied first.

            :raises RuntimeError: If no conversion makes the operation resolvable.
            """
            # Reflected operators receive their operands swapped.
            lhs, rhs = (other, self) if r_method else (self, other)
            types_match = (
                isinstance(lhs, RuntimeExpr)
                and isinstance(rhs, RuntimeExpr)
                and lhs.__egg_decls__.check_binary_method_with_types(
                    name, lhs.__egg_typed_expr__.tp, rhs.__egg_typed_expr__.tp
                )
            )
            if not types_match:
                from .conversion import min_binary_conversion, resolve_type  # noqa: PLC0415

                best_method = min_binary_conversion(name, resolve_type(lhs), resolve_type(rhs))

                if not best_method:
                    raise RuntimeError(f"Cannot resolve {name} for {lhs} and {rhs}, no conversion found")
                lhs, rhs = best_method[0](lhs), best_method[1](rhs)

            method_ref = MethodRef(lhs.__egg_class_name__, name)
            bound_fn = RuntimeFunction(Thunk.value(lhs.__egg_decls__), Thunk.value(method_ref), lhs)
            return bound_fn(rhs)

        method_name = f"__r{name[2:]}" if r_method else name
        setattr(RuntimeExpr, method_name, _numeric_binary_method)
673
+
674
+
675
+ def resolve_callable(callable: object) -> tuple[CallableRef, Declarations]:
676
+ """
677
+ Resolves a runtime callable into a ref
678
+ """
679
+ # TODO: Make runtime class work with __match_args__
680
+ if isinstance(callable, RuntimeClass):
681
+ return InitRef(callable.__egg_tp__.name), callable.__egg_decls__
682
+ match callable:
683
+ case RuntimeFunction(decls, ref, _):
684
+ return ref(), decls()
685
+ case RuntimeExpr(decl_thunk, expr_thunk):
686
+ if not isinstance((expr := expr_thunk().expr), CallDecl) or not isinstance(
687
+ expr.callable, ConstantRef | ClassVariableRef
688
+ ):
689
+ raise NotImplementedError(f"Can only turn constants or classvars into callable refs, not {expr}")
690
+ return expr.callable, decl_thunk()
691
+ case _:
692
+ raise NotImplementedError(f"Cannot turn {callable} of type {type(callable)} into a callable ref")
693
+
694
+
695
+ def create_callable(decls: Declarations, ref: CallableRef) -> RuntimeClass | RuntimeFunction | RuntimeExpr:
696
+ """
697
+ Creates a callable object from a callable ref. This might not actually be callable, if the ref is a constant
698
+ or classvar then it is a value
699
+ """
700
+ match ref:
701
+ case InitRef(name):
702
+ return RuntimeClass(Thunk.value(decls), TypeRefWithVars(name))
703
+ case FunctionRef() | MethodRef() | ClassMethodRef() | PropertyRef() | UnnamedFunctionRef():
704
+ bound = JustTypeRef(ref.class_name) if isinstance(ref, ClassMethodRef) else None
705
+ return RuntimeFunction(Thunk.value(decls), Thunk.value(ref), bound)
706
+ case ConstantRef(name):
707
+ tp = decls._constants[name].type_ref
708
+ case ClassVariableRef(cls_name, var_name):
709
+ tp = decls._classes[cls_name].class_variables[var_name].type_ref
710
+ case _:
711
+ assert_never(ref)
712
+ return RuntimeExpr.__from_values__(decls, TypedExprDecl(tp, CallDecl(ref)))