egglog 6.1.0__cp311-none-win_amd64.whl → 7.1.0__cp311-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of egglog might be problematic. Click here for more details.
- egglog/__init__.py +1 -1
- egglog/bindings.cp311-win_amd64.pyd +0 -0
- egglog/bindings.pyi +9 -0
- egglog/builtins.py +42 -2
- egglog/conversion.py +177 -0
- egglog/declarations.py +354 -734
- egglog/egraph.py +602 -800
- egglog/egraph_state.py +456 -0
- egglog/exp/array_api.py +100 -88
- egglog/exp/array_api_numba.py +6 -1
- egglog/exp/siu_examples.py +35 -0
- egglog/pretty.py +464 -0
- egglog/runtime.py +279 -431
- egglog/thunk.py +71 -0
- egglog/type_constraint_solver.py +5 -2
- {egglog-6.1.0.dist-info → egglog-7.1.0.dist-info}/METADATA +7 -7
- {egglog-6.1.0.dist-info → egglog-7.1.0.dist-info}/RECORD +19 -14
- {egglog-6.1.0.dist-info → egglog-7.1.0.dist-info}/WHEEL +0 -0
- {egglog-6.1.0.dist-info → egglog-7.1.0.dist-info}/license_files/LICENSE +0 -0
egglog/runtime.py
CHANGED
|
@@ -11,177 +11,67 @@ so they are not mangled by Python and can be accessed by the user.
|
|
|
11
11
|
|
|
12
12
|
from __future__ import annotations
|
|
13
13
|
|
|
14
|
-
from
|
|
14
|
+
from collections.abc import Callable
|
|
15
|
+
from dataclasses import dataclass, replace
|
|
16
|
+
from inspect import Parameter, Signature
|
|
15
17
|
from itertools import zip_longest
|
|
16
18
|
from typing import TYPE_CHECKING, NoReturn, TypeVar, Union, cast, get_args, get_origin
|
|
17
19
|
|
|
18
|
-
import black
|
|
19
|
-
import black.parsing
|
|
20
|
-
from typing_extensions import assert_never
|
|
21
|
-
|
|
22
|
-
from . import bindings, config
|
|
23
20
|
from .declarations import *
|
|
24
|
-
from .
|
|
21
|
+
from .pretty import *
|
|
22
|
+
from .thunk import Thunk
|
|
25
23
|
from .type_constraint_solver import *
|
|
26
24
|
|
|
27
25
|
if TYPE_CHECKING:
|
|
28
|
-
from collections.abc import
|
|
26
|
+
from collections.abc import Iterable
|
|
29
27
|
|
|
30
28
|
from .egraph import Expr
|
|
31
29
|
|
|
32
30
|
__all__ = [
|
|
33
31
|
"LIT_CLASS_NAMES",
|
|
34
|
-
"class_to_ref",
|
|
35
|
-
"resolve_literal",
|
|
36
32
|
"resolve_callable",
|
|
37
33
|
"resolve_type_annotation",
|
|
38
|
-
"convert_to_same_type",
|
|
39
34
|
"RuntimeClass",
|
|
40
|
-
"RuntimeParamaterizedClass",
|
|
41
|
-
"RuntimeClassMethod",
|
|
42
35
|
"RuntimeExpr",
|
|
43
36
|
"RuntimeFunction",
|
|
44
|
-
"
|
|
45
|
-
"converter",
|
|
37
|
+
"REFLECTED_BINARY_METHODS",
|
|
46
38
|
]
|
|
47
39
|
|
|
48
40
|
|
|
49
|
-
BLACK_MODE = black.Mode(line_length=180)
|
|
50
|
-
|
|
51
41
|
UNIT_CLASS_NAME = "Unit"
|
|
52
42
|
UNARY_LIT_CLASS_NAMES = {"i64", "f64", "Bool", "String"}
|
|
53
43
|
LIT_CLASS_NAMES = UNARY_LIT_CLASS_NAMES | {UNIT_CLASS_NAME, "PyObject"}
|
|
54
44
|
|
|
45
|
+
REFLECTED_BINARY_METHODS = {
|
|
46
|
+
"__radd__": "__add__",
|
|
47
|
+
"__rsub__": "__sub__",
|
|
48
|
+
"__rmul__": "__mul__",
|
|
49
|
+
"__rmatmul__": "__matmul__",
|
|
50
|
+
"__rtruediv__": "__truediv__",
|
|
51
|
+
"__rfloordiv__": "__floordiv__",
|
|
52
|
+
"__rmod__": "__mod__",
|
|
53
|
+
"__rpow__": "__pow__",
|
|
54
|
+
"__rlshift__": "__lshift__",
|
|
55
|
+
"__rrshift__": "__rshift__",
|
|
56
|
+
"__rand__": "__and__",
|
|
57
|
+
"__rxor__": "__xor__",
|
|
58
|
+
"__ror__": "__or__",
|
|
59
|
+
}
|
|
60
|
+
|
|
55
61
|
# Set this globally so we can get access to PyObject when we have a type annotation of just object.
|
|
56
62
|
# This is the only time a type annotation doesn't need to include the egglog type b/c object is top so that would be redundant statically.
|
|
57
63
|
_PY_OBJECT_CLASS: RuntimeClass | None = None
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
# Converters
|
|
61
|
-
##
|
|
62
|
-
|
|
63
|
-
# Mapping from (source type, target type) to and function which takes in the runtimes values of the source and return the target
|
|
64
|
-
CONVERSIONS: dict[tuple[type | JustTypeRef, JustTypeRef], tuple[int, Callable]] = {}
|
|
65
|
-
# Global declerations to store all convertable types so we can query if they have certain methods or not
|
|
66
|
-
CONVERSIONS_DECLS = Declarations()
|
|
64
|
+
# Same for functions
|
|
65
|
+
_UNSTABLE_FN_CLASS: RuntimeClass | None = None
|
|
67
66
|
|
|
68
67
|
T = TypeVar("T")
|
|
69
|
-
V = TypeVar("V", bound="Expr")
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
class ConvertError(Exception):
|
|
73
|
-
pass
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
def converter(from_type: type[T], to_type: type[V], fn: Callable[[T], V], cost: int = 1) -> None:
|
|
77
|
-
"""
|
|
78
|
-
Register a converter from some type to an egglog type.
|
|
79
|
-
"""
|
|
80
|
-
to_type_name = process_tp(to_type)
|
|
81
|
-
if not isinstance(to_type_name, JustTypeRef):
|
|
82
|
-
raise TypeError(f"Expected return type to be a egglog type, got {to_type_name}")
|
|
83
|
-
_register_converter(process_tp(from_type), to_type_name, fn, cost)
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
def _register_converter(a: type | JustTypeRef, b: JustTypeRef, a_b: Callable, cost: int) -> None:
|
|
87
|
-
"""
|
|
88
|
-
Registers a converter from some type to an egglog type, if not already registered.
|
|
89
|
-
|
|
90
|
-
Also adds transitive converters, i.e. if registering A->B and there is already B->C, then A->C will be registered.
|
|
91
|
-
Also, if registering A->B and there is already D->A, then D->B will be registered.
|
|
92
|
-
"""
|
|
93
|
-
if a == b:
|
|
94
|
-
return
|
|
95
|
-
if (a, b) in CONVERSIONS and CONVERSIONS[(a, b)][0] <= cost:
|
|
96
|
-
return
|
|
97
|
-
CONVERSIONS[(a, b)] = (cost, a_b)
|
|
98
|
-
for (c, d), (other_cost, c_d) in list(CONVERSIONS.items()):
|
|
99
|
-
if b == c:
|
|
100
|
-
_register_converter(a, d, _ComposedConverter(a_b, c_d), cost + other_cost)
|
|
101
|
-
if a == d:
|
|
102
|
-
_register_converter(c, b, _ComposedConverter(c_d, a_b), cost + other_cost)
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
@dataclass
|
|
106
|
-
class _ComposedConverter:
|
|
107
|
-
"""
|
|
108
|
-
A converter which is composed of multiple converters.
|
|
109
|
-
|
|
110
|
-
_ComposeConverter(a_b, b_c) is equivalent to lambda x: b_c(a_b(x))
|
|
111
|
-
|
|
112
|
-
We use the dataclass instead of the lambda to make it easier to debug.
|
|
113
|
-
"""
|
|
114
|
-
|
|
115
|
-
a_b: Callable
|
|
116
|
-
b_c: Callable
|
|
117
|
-
|
|
118
|
-
def __call__(self, x: object) -> object:
|
|
119
|
-
return self.b_c(self.a_b(x))
|
|
120
|
-
|
|
121
|
-
def __str__(self) -> str:
|
|
122
|
-
return f"{self.b_c} ∘ {self.a_b}"
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
def convert(source: object, target: type[V]) -> V:
|
|
126
|
-
"""
|
|
127
|
-
Convert a source object to a target type.
|
|
128
|
-
"""
|
|
129
|
-
target_ref = class_to_ref(cast(RuntimeTypeArgType, target))
|
|
130
|
-
return cast(V, resolve_literal(target_ref.to_var(), source))
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
def convert_to_same_type(source: object, target: RuntimeExpr) -> RuntimeExpr:
|
|
134
|
-
"""
|
|
135
|
-
Convert a source object to the same type as the target.
|
|
136
|
-
"""
|
|
137
|
-
tp = target.__egg_typed_expr__.tp
|
|
138
|
-
return resolve_literal(tp.to_var(), source)
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
def process_tp(tp: type | RuntimeTypeArgType) -> JustTypeRef | type:
|
|
142
|
-
"""
|
|
143
|
-
Process a type before converting it, to add it to the global declerations and resolve to a ref.
|
|
144
|
-
"""
|
|
145
|
-
global CONVERSIONS_DECLS
|
|
146
|
-
if isinstance(tp, RuntimeClass | RuntimeParamaterizedClass):
|
|
147
|
-
CONVERSIONS_DECLS |= tp
|
|
148
|
-
return class_to_ref(tp)
|
|
149
|
-
return tp
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
def min_convertable_tp(a: object, b: object, name: str) -> JustTypeRef:
|
|
153
|
-
"""
|
|
154
|
-
Returns the minimum convertable type between a and b, that has a method `name`, raising a TypeError if no such type exists.
|
|
155
|
-
"""
|
|
156
|
-
a_tp = _get_tp(a)
|
|
157
|
-
b_tp = _get_tp(b)
|
|
158
|
-
a_converts_to = {
|
|
159
|
-
to: c
|
|
160
|
-
for ((from_, to), (c, _)) in CONVERSIONS.items()
|
|
161
|
-
if from_ == a_tp and CONVERSIONS_DECLS.has_method(to.name, name)
|
|
162
|
-
}
|
|
163
|
-
b_converts_to = {
|
|
164
|
-
to: c
|
|
165
|
-
for ((from_, to), (c, _)) in CONVERSIONS.items()
|
|
166
|
-
if from_ == b_tp and CONVERSIONS_DECLS.has_method(to.name, name)
|
|
167
|
-
}
|
|
168
|
-
if isinstance(a_tp, JustTypeRef):
|
|
169
|
-
a_converts_to[a_tp] = 0
|
|
170
|
-
if isinstance(b_tp, JustTypeRef):
|
|
171
|
-
b_converts_to[b_tp] = 0
|
|
172
|
-
common = set(a_converts_to) & set(b_converts_to)
|
|
173
|
-
if not common:
|
|
174
|
-
raise ConvertError(f"Cannot convert {a_tp} and {b_tp} to a common type")
|
|
175
|
-
return min(common, key=lambda tp: a_converts_to[tp] + b_converts_to[tp])
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
def identity(x: object) -> object:
|
|
179
|
-
return x
|
|
180
68
|
|
|
181
69
|
|
|
182
70
|
def resolve_type_annotation(decls: Declarations, tp: object) -> TypeOrVarRef:
|
|
183
71
|
"""
|
|
184
72
|
Resolves a type object into a type reference.
|
|
73
|
+
|
|
74
|
+
Any runtime type object decls will be add to those passed in.
|
|
185
75
|
"""
|
|
186
76
|
if isinstance(tp, TypeVar):
|
|
187
77
|
return ClassTypeVarRef(tp.__name__)
|
|
@@ -194,100 +84,92 @@ def resolve_type_annotation(decls: Declarations, tp: object) -> TypeOrVarRef:
|
|
|
194
84
|
if tp == object:
|
|
195
85
|
assert _PY_OBJECT_CLASS
|
|
196
86
|
return resolve_type_annotation(decls, _PY_OBJECT_CLASS)
|
|
87
|
+
# If the type is a `Callable` then convert it into a UnstableFn
|
|
88
|
+
if get_origin(tp) == Callable:
|
|
89
|
+
assert _UNSTABLE_FN_CLASS
|
|
90
|
+
args, ret = get_args(tp)
|
|
91
|
+
return resolve_type_annotation(decls, _UNSTABLE_FN_CLASS[(ret, *args)])
|
|
197
92
|
if isinstance(tp, RuntimeClass):
|
|
198
|
-
decls |= tp
|
|
199
|
-
return tp.__egg_tp__.to_var()
|
|
200
|
-
if isinstance(tp, RuntimeParamaterizedClass):
|
|
201
93
|
decls |= tp
|
|
202
94
|
return tp.__egg_tp__
|
|
203
95
|
raise TypeError(f"Unexpected type annotation {tp}")
|
|
204
96
|
|
|
205
97
|
|
|
206
|
-
def resolve_literal(tp: TypeOrVarRef, arg: object) -> RuntimeExpr:
|
|
207
|
-
arg_type = _get_tp(arg)
|
|
208
|
-
|
|
209
|
-
# If we have any type variables, dont bother trying to resolve the literal, just return the arg
|
|
210
|
-
try:
|
|
211
|
-
tp_just = tp.to_just()
|
|
212
|
-
except NotImplementedError:
|
|
213
|
-
# If this is a var, it has to be a runtime exprssions
|
|
214
|
-
assert isinstance(arg, RuntimeExpr)
|
|
215
|
-
return arg
|
|
216
|
-
if arg_type == tp_just:
|
|
217
|
-
# If the type is an egg type, it has to be a runtime expr
|
|
218
|
-
assert isinstance(arg, RuntimeExpr)
|
|
219
|
-
return arg
|
|
220
|
-
# Try all parent types as well, if we are converting from a Python type
|
|
221
|
-
for arg_type_instance in arg_type.__mro__ if isinstance(arg_type, type) else [arg_type]:
|
|
222
|
-
try:
|
|
223
|
-
fn = CONVERSIONS[(cast(JustTypeRef | type, arg_type_instance), tp_just)][1]
|
|
224
|
-
except KeyError:
|
|
225
|
-
continue
|
|
226
|
-
break
|
|
227
|
-
else:
|
|
228
|
-
arg_type_str = arg_type.pretty() if isinstance(arg_type, JustTypeRef) else arg_type.__name__
|
|
229
|
-
raise ConvertError(f"Cannot convert {arg_type_str} to {tp_just.pretty()}")
|
|
230
|
-
return fn(arg)
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
def _get_tp(x: object) -> JustTypeRef | type:
|
|
234
|
-
if isinstance(x, RuntimeExpr):
|
|
235
|
-
return x.__egg_typed_expr__.tp
|
|
236
|
-
tp = type(x)
|
|
237
|
-
# If this value has a custom metaclass, let's use that as our index instead of the type
|
|
238
|
-
if type(tp) != type:
|
|
239
|
-
return type(tp)
|
|
240
|
-
return tp
|
|
241
|
-
|
|
242
|
-
|
|
243
98
|
##
|
|
244
99
|
# Runtime objects
|
|
245
100
|
##
|
|
246
101
|
|
|
247
102
|
|
|
248
103
|
@dataclass
|
|
249
|
-
class RuntimeClass:
|
|
250
|
-
|
|
251
|
-
# This function should mutate the declerations and add to them
|
|
252
|
-
# Used this instead of a lazy property so we can have a reference to the decls in the class as its computing
|
|
253
|
-
lazy_decls: Callable[[Declarations], None] = field(repr=False)
|
|
254
|
-
# Cached declerations
|
|
255
|
-
_inner_decls: Declarations | None = field(init=False, repr=False, default=None)
|
|
256
|
-
__egg_name__: str
|
|
104
|
+
class RuntimeClass(DelayedDeclerations):
|
|
105
|
+
__egg_tp__: TypeRefWithVars
|
|
257
106
|
|
|
258
107
|
def __post_init__(self) -> None:
|
|
259
|
-
global _PY_OBJECT_CLASS
|
|
260
|
-
if self.
|
|
108
|
+
global _PY_OBJECT_CLASS, _UNSTABLE_FN_CLASS
|
|
109
|
+
if (name := self.__egg_tp__.name) == "PyObject":
|
|
261
110
|
_PY_OBJECT_CLASS = self
|
|
111
|
+
elif name == "UnstableFn" and not self.__egg_tp__.args:
|
|
112
|
+
_UNSTABLE_FN_CLASS = self
|
|
262
113
|
|
|
263
|
-
|
|
264
|
-
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
|
|
114
|
+
def verify(self) -> None:
|
|
115
|
+
if not self.__egg_tp__.args:
|
|
116
|
+
return
|
|
117
|
+
|
|
118
|
+
# Raise error if we have args, but they are the wrong number
|
|
119
|
+
desired_args = self.__egg_decls__.get_class_decl(self.__egg_tp__.name).type_vars
|
|
120
|
+
if len(self.__egg_tp__.args) != len(desired_args):
|
|
121
|
+
raise ValueError(f"Expected {desired_args} type args, got {len(self.__egg_tp__.args)}")
|
|
270
122
|
|
|
271
123
|
def __call__(self, *args: object, **kwargs: object) -> RuntimeExpr | None:
|
|
272
124
|
"""
|
|
273
125
|
Create an instance of this kind by calling the __init__ classmethod
|
|
274
126
|
"""
|
|
275
127
|
# If this is a literal type, initializing it with a literal should return a literal
|
|
276
|
-
if self.
|
|
128
|
+
if (name := self.__egg_tp__.name) == "PyObject":
|
|
277
129
|
assert len(args) == 1
|
|
278
|
-
return RuntimeExpr
|
|
279
|
-
|
|
130
|
+
return RuntimeExpr.__from_value__(
|
|
131
|
+
self.__egg_decls__, TypedExprDecl(self.__egg_tp__.to_just(), PyObjectDecl(args[0]))
|
|
132
|
+
)
|
|
133
|
+
if name == "UnstableFn":
|
|
134
|
+
assert not kwargs
|
|
135
|
+
fn_arg, *partial_args = args
|
|
136
|
+
del args
|
|
137
|
+
# Assumes we don't have types set for UnstableFn w/ generics, that they have to be inferred
|
|
138
|
+
|
|
139
|
+
# 1. Create a runtime function for the first arg
|
|
140
|
+
assert isinstance(fn_arg, RuntimeFunction)
|
|
141
|
+
# 2. Call it with the partial args, and use untyped vars for the rest of the args
|
|
142
|
+
res = fn_arg(*partial_args, _egg_partial_function=True)
|
|
143
|
+
assert res is not None, "Mutable partial functions not supported"
|
|
144
|
+
# 3. Use the inferred return type and inferred rest arg types as the types of the function, and
|
|
145
|
+
# the partially applied args as the args.
|
|
146
|
+
call = (res_typed_expr := res.__egg_typed_expr__).expr
|
|
147
|
+
return_tp = res_typed_expr.tp
|
|
148
|
+
assert isinstance(call, CallDecl), "partial function must be a call"
|
|
149
|
+
n_args = len(partial_args)
|
|
150
|
+
value = PartialCallDecl(replace(call, args=call.args[:n_args]))
|
|
151
|
+
remaining_arg_types = [a.tp for a in call.args[n_args:]]
|
|
152
|
+
type_ref = JustTypeRef("UnstableFn", (return_tp, *remaining_arg_types))
|
|
153
|
+
return RuntimeExpr.__from_value__(Declarations.create(self, res), TypedExprDecl(type_ref, value))
|
|
154
|
+
|
|
155
|
+
if name in UNARY_LIT_CLASS_NAMES:
|
|
280
156
|
assert len(args) == 1
|
|
281
157
|
assert isinstance(args[0], int | float | str | bool)
|
|
282
|
-
return RuntimeExpr
|
|
283
|
-
|
|
158
|
+
return RuntimeExpr.__from_value__(
|
|
159
|
+
self.__egg_decls__, TypedExprDecl(self.__egg_tp__.to_just(), LitDecl(args[0]))
|
|
160
|
+
)
|
|
161
|
+
if name == UNIT_CLASS_NAME:
|
|
284
162
|
assert len(args) == 0
|
|
285
|
-
return RuntimeExpr
|
|
163
|
+
return RuntimeExpr.__from_value__(
|
|
164
|
+
self.__egg_decls__, TypedExprDecl(self.__egg_tp__.to_just(), LitDecl(None))
|
|
165
|
+
)
|
|
286
166
|
|
|
287
|
-
return
|
|
167
|
+
return RuntimeFunction(
|
|
168
|
+
Thunk.value(self.__egg_decls__), ClassMethodRef(name, "__init__"), self.__egg_tp__.to_just()
|
|
169
|
+
)(*args, **kwargs) # type: ignore[arg-type]
|
|
288
170
|
|
|
289
171
|
def __dir__(self) -> list[str]:
|
|
290
|
-
cls_decl = self.__egg_decls__.get_class_decl(self.
|
|
172
|
+
cls_decl = self.__egg_decls__.get_class_decl(self.__egg_tp__.name)
|
|
291
173
|
possible_methods = (
|
|
292
174
|
list(cls_decl.class_methods) + list(cls_decl.class_variables) + list(cls_decl.preserved_methods)
|
|
293
175
|
)
|
|
@@ -296,14 +178,19 @@ class RuntimeClass:
|
|
|
296
178
|
possible_methods.append("__call__")
|
|
297
179
|
return possible_methods
|
|
298
180
|
|
|
299
|
-
def __getitem__(self, args: object) ->
|
|
181
|
+
def __getitem__(self, args: object) -> RuntimeClass:
|
|
182
|
+
if self.__egg_tp__.args:
|
|
183
|
+
raise TypeError(f"Cannot index into a paramaterized class {self}")
|
|
300
184
|
if not isinstance(args, tuple):
|
|
301
185
|
args = (args,)
|
|
302
186
|
decls = self.__egg_decls__.copy()
|
|
303
|
-
tp = TypeRefWithVars(self.
|
|
304
|
-
return
|
|
187
|
+
tp = TypeRefWithVars(self.__egg_tp__.name, tuple(resolve_type_annotation(decls, arg) for arg in args))
|
|
188
|
+
return RuntimeClass(Thunk.value(decls), tp)
|
|
189
|
+
|
|
190
|
+
def __getattr__(self, name: str) -> RuntimeFunction | RuntimeExpr | Callable:
|
|
191
|
+
if name == "__origin__" and self.__egg_tp__.args:
|
|
192
|
+
return RuntimeClass(self.__egg_decls_thunk__, TypeRefWithVars(self.__egg_tp__.name))
|
|
305
193
|
|
|
306
|
-
def __getattr__(self, name: str) -> RuntimeClassMethod | RuntimeExpr | Callable:
|
|
307
194
|
# Special case some names that don't exist so we can exit early without resolving decls
|
|
308
195
|
# Important so if we take union of RuntimeClass it won't try to resolve decls
|
|
309
196
|
if name in {
|
|
@@ -314,7 +201,7 @@ class RuntimeClass:
|
|
|
314
201
|
}:
|
|
315
202
|
raise AttributeError
|
|
316
203
|
|
|
317
|
-
cls_decl = self.__egg_decls__.
|
|
204
|
+
cls_decl = self.__egg_decls__._classes[self.__egg_tp__.name]
|
|
318
205
|
|
|
319
206
|
preserved_methods = cls_decl.preserved_methods
|
|
320
207
|
if name in preserved_methods:
|
|
@@ -323,159 +210,151 @@ class RuntimeClass:
|
|
|
323
210
|
# if this is a class variable, return an expr for it, otherwise, assume it's a method
|
|
324
211
|
if name in cls_decl.class_variables:
|
|
325
212
|
return_tp = cls_decl.class_variables[name]
|
|
326
|
-
return RuntimeExpr(
|
|
327
|
-
self.__egg_decls__,
|
|
213
|
+
return RuntimeExpr.__from_value__(
|
|
214
|
+
self.__egg_decls__,
|
|
215
|
+
TypedExprDecl(return_tp.type_ref, CallDecl(ClassVariableRef(self.__egg_tp__.name, name))),
|
|
328
216
|
)
|
|
329
|
-
|
|
217
|
+
if name in cls_decl.class_methods:
|
|
218
|
+
return RuntimeFunction(
|
|
219
|
+
Thunk.value(self.__egg_decls__), ClassMethodRef(self.__egg_tp__.name, name), self.__egg_tp__.to_just()
|
|
220
|
+
)
|
|
221
|
+
# allow referencing properties and methods as class variables as well
|
|
222
|
+
if name in cls_decl.properties:
|
|
223
|
+
return RuntimeFunction(Thunk.value(self.__egg_decls__), PropertyRef(self.__egg_tp__.name, name))
|
|
224
|
+
if name in cls_decl.methods:
|
|
225
|
+
return RuntimeFunction(Thunk.value(self.__egg_decls__), MethodRef(self.__egg_tp__.name, name))
|
|
226
|
+
|
|
227
|
+
msg = f"Class {self.__egg_tp__.name} has no method {name}"
|
|
228
|
+
if name == "__ne__":
|
|
229
|
+
msg += ". Did you mean to use the ne(...).to(...)?"
|
|
230
|
+
raise AttributeError(msg) from None
|
|
330
231
|
|
|
331
232
|
def __str__(self) -> str:
|
|
332
|
-
return self.
|
|
233
|
+
return str(self.__egg_tp__)
|
|
333
234
|
|
|
334
235
|
# Make hashable so can go in Union
|
|
335
236
|
def __hash__(self) -> int:
|
|
336
|
-
return hash((id(self.
|
|
237
|
+
return hash((id(self.__egg_decls_thunk__), self.__egg_tp__))
|
|
337
238
|
|
|
338
239
|
# Support unioning like types
|
|
339
240
|
def __or__(self, __value: type) -> object:
|
|
340
241
|
return Union[self, __value] # noqa: UP007
|
|
341
242
|
|
|
342
|
-
@property
|
|
343
|
-
def __egg_tp__(self) -> JustTypeRef:
|
|
344
|
-
return JustTypeRef(self.__egg_name__)
|
|
345
|
-
|
|
346
|
-
|
|
347
|
-
@dataclass
|
|
348
|
-
class RuntimeParamaterizedClass:
|
|
349
|
-
__egg_decls__: Declarations
|
|
350
|
-
__egg_tp__: TypeRefWithVars
|
|
351
|
-
|
|
352
|
-
def __post_init__(self) -> None:
|
|
353
|
-
desired_args = self.__egg_decls__.get_class_decl(self.__egg_tp__.name).type_vars
|
|
354
|
-
if len(self.__egg_tp__.args) != len(desired_args):
|
|
355
|
-
raise ValueError(f"Expected {desired_args} type args, got {len(self.__egg_tp__.args)}")
|
|
356
|
-
|
|
357
|
-
def __call__(self, *args: object) -> RuntimeExpr | None:
|
|
358
|
-
return RuntimeClassMethod(self.__egg_decls__, class_to_ref(self), "__init__")(*args)
|
|
359
|
-
|
|
360
|
-
def __getattr__(self, name: str) -> RuntimeClassMethod | RuntimeClass:
|
|
361
|
-
# Special case so when get_type_annotations proccessed it can work
|
|
362
|
-
if name == "__origin__":
|
|
363
|
-
return RuntimeClass(self.__egg_decls__.update_other, self.__egg_tp__.name)
|
|
364
|
-
return RuntimeClassMethod(self.__egg_decls__, class_to_ref(self), name)
|
|
365
|
-
|
|
366
|
-
def __str__(self) -> str:
|
|
367
|
-
return self.__egg_tp__.pretty()
|
|
368
|
-
|
|
369
|
-
# Support unioning
|
|
370
|
-
def __or__(self, __value: type) -> object:
|
|
371
|
-
return Union[self, __value] # noqa: UP007
|
|
372
|
-
|
|
373
|
-
|
|
374
|
-
# Type args can either be typevars or classes
|
|
375
|
-
RuntimeTypeArgType = RuntimeClass | RuntimeParamaterizedClass
|
|
376
|
-
|
|
377
|
-
|
|
378
|
-
def class_to_ref(cls: RuntimeTypeArgType) -> JustTypeRef:
|
|
379
|
-
if isinstance(cls, RuntimeClass):
|
|
380
|
-
return JustTypeRef(cls.__egg_name__)
|
|
381
|
-
if isinstance(cls, RuntimeParamaterizedClass):
|
|
382
|
-
# Currently this is used when calling methods on a parametrized class, which is only possible when we
|
|
383
|
-
# have actualy types currently, not typevars, currently.
|
|
384
|
-
return cls.__egg_tp__.to_just()
|
|
385
|
-
assert_never(cls)
|
|
386
|
-
|
|
387
243
|
|
|
388
244
|
@dataclass
|
|
389
|
-
class RuntimeFunction:
|
|
390
|
-
|
|
391
|
-
|
|
392
|
-
|
|
393
|
-
__egg_fn_decl__: FunctionDecl = field(init=False)
|
|
394
|
-
|
|
395
|
-
def __post_init__(self) -> None:
|
|
396
|
-
self.__egg_fn_ref__ = FunctionRef(self.__egg_name__)
|
|
397
|
-
self.__egg_fn_decl__ = self.__egg_decls__.get_function_decl(self.__egg_fn_ref__)
|
|
398
|
-
|
|
399
|
-
def __call__(self, *args: object, **kwargs: object) -> RuntimeExpr | None:
|
|
400
|
-
return _call(self.__egg_decls__, self.__egg_fn_ref__, self.__egg_fn_decl__, args, kwargs)
|
|
401
|
-
|
|
402
|
-
def __str__(self) -> str:
|
|
403
|
-
return self.__egg_name__
|
|
404
|
-
|
|
405
|
-
|
|
406
|
-
def _call(
|
|
407
|
-
decls_from_fn: Declarations,
|
|
408
|
-
callable_ref: CallableRef,
|
|
409
|
-
fn_decl: FunctionDecl,
|
|
410
|
-
args: Collection[object],
|
|
411
|
-
kwargs: dict[str, object],
|
|
412
|
-
bound_class: JustTypeRef | None = None,
|
|
413
|
-
) -> RuntimeExpr | None:
|
|
414
|
-
# Turn all keyword args into positional args
|
|
415
|
-
bound = fn_decl.to_signature(lambda expr: RuntimeExpr(decls_from_fn, expr)).bind(*args, **kwargs)
|
|
416
|
-
bound.apply_defaults()
|
|
417
|
-
assert not bound.kwargs
|
|
418
|
-
del args, kwargs
|
|
419
|
-
|
|
420
|
-
upcasted_args = [
|
|
421
|
-
resolve_literal(cast(TypeOrVarRef, tp), arg)
|
|
422
|
-
for arg, tp in zip_longest(bound.args, fn_decl.arg_types, fillvalue=fn_decl.var_arg_type)
|
|
423
|
-
]
|
|
245
|
+
class RuntimeFunction(DelayedDeclerations):
|
|
246
|
+
__egg_ref__: CallableRef
|
|
247
|
+
# bound methods need to store RuntimeExpr not just TypedExprDecl, so they can mutate the expr if required on self
|
|
248
|
+
__egg_bound__: JustTypeRef | RuntimeExpr | None = None
|
|
424
249
|
|
|
425
|
-
|
|
426
|
-
|
|
250
|
+
def __call__(self, *args: object, _egg_partial_function: bool = False, **kwargs: object) -> RuntimeExpr | None:
|
|
251
|
+
from .conversion import resolve_literal
|
|
427
252
|
|
|
428
|
-
|
|
429
|
-
|
|
430
|
-
|
|
431
|
-
|
|
432
|
-
|
|
253
|
+
if isinstance(self.__egg_bound__, RuntimeExpr):
|
|
254
|
+
args = (self.__egg_bound__, *args)
|
|
255
|
+
signature = self.__egg_decls__.get_callable_decl(self.__egg_ref__).to_function_decl().signature
|
|
256
|
+
decls = self.__egg_decls__.copy()
|
|
257
|
+
# Special case function application bc we dont support variadic generics yet generally
|
|
258
|
+
if signature == "fn-app":
|
|
259
|
+
fn, *rest_args = args
|
|
260
|
+
args = tuple(rest_args)
|
|
261
|
+
assert not kwargs
|
|
262
|
+
assert isinstance(fn, RuntimeExpr)
|
|
263
|
+
decls.update(fn)
|
|
264
|
+
function_value = fn.__egg_typed_expr__
|
|
265
|
+
fn_tp = function_value.tp
|
|
266
|
+
assert fn_tp.name == "UnstableFn"
|
|
267
|
+
fn_return_tp, *fn_arg_tps = fn_tp.args
|
|
268
|
+
signature = FunctionSignature(
|
|
269
|
+
tuple(tp.to_var() for tp in fn_arg_tps),
|
|
270
|
+
tuple(f"_{i}" for i in range(len(fn_arg_tps))),
|
|
271
|
+
(None,) * len(fn_arg_tps),
|
|
272
|
+
fn_return_tp.to_var(),
|
|
273
|
+
)
|
|
274
|
+
else:
|
|
275
|
+
function_value = None
|
|
276
|
+
assert isinstance(signature, FunctionSignature)
|
|
277
|
+
|
|
278
|
+
# Turn all keyword args into positional args
|
|
279
|
+
py_signature = to_py_signature(signature, self.__egg_decls__, _egg_partial_function)
|
|
280
|
+
bound = py_signature.bind(*args, **kwargs)
|
|
281
|
+
del kwargs
|
|
282
|
+
bound.apply_defaults()
|
|
283
|
+
assert not bound.kwargs
|
|
284
|
+
args = bound.args
|
|
285
|
+
|
|
286
|
+
upcasted_args = [
|
|
287
|
+
resolve_literal(cast(TypeOrVarRef, tp), arg)
|
|
288
|
+
for arg, tp in zip_longest(args, signature.arg_types, fillvalue=signature.var_arg_type)
|
|
289
|
+
]
|
|
290
|
+
decls.update(*upcasted_args)
|
|
291
|
+
|
|
292
|
+
tcs = TypeConstraintSolver(decls)
|
|
293
|
+
bound_tp = (
|
|
294
|
+
None
|
|
295
|
+
if self.__egg_bound__ is None
|
|
296
|
+
else self.__egg_bound__.__egg_typed_expr__.tp
|
|
297
|
+
if isinstance(self.__egg_bound__, RuntimeExpr)
|
|
298
|
+
else self.__egg_bound__
|
|
299
|
+
)
|
|
300
|
+
if (
|
|
301
|
+
bound_tp
|
|
302
|
+
and bound_tp.args
|
|
303
|
+
# Don't bind class if we have a first class function arg, b/c we don't support that yet
|
|
304
|
+
and not function_value
|
|
305
|
+
):
|
|
306
|
+
tcs.bind_class(bound_tp)
|
|
307
|
+
arg_exprs = tuple(arg.__egg_typed_expr__ for arg in upcasted_args)
|
|
433
308
|
arg_types = [expr.tp for expr in arg_exprs]
|
|
434
|
-
cls_name =
|
|
309
|
+
cls_name = bound_tp.name if bound_tp else None
|
|
435
310
|
return_tp = tcs.infer_return_type(
|
|
436
|
-
|
|
311
|
+
signature.arg_types, signature.semantic_return_type, signature.var_arg_type, arg_types, cls_name
|
|
437
312
|
)
|
|
438
|
-
|
|
439
|
-
|
|
440
|
-
|
|
441
|
-
|
|
442
|
-
|
|
443
|
-
|
|
444
|
-
|
|
445
|
-
|
|
446
|
-
|
|
447
|
-
|
|
448
|
-
|
|
449
|
-
return
|
|
450
|
-
return RuntimeExpr(decls, TypedExprDecl(return_tp, expr_decl))
|
|
451
|
-
|
|
313
|
+
bound_params = cast(JustTypeRef, bound_tp).args if isinstance(self.__egg_ref__, ClassMethodRef) else None
|
|
314
|
+
# If we were using unstable-app to call a funciton, add that function back as the first arg.
|
|
315
|
+
if function_value:
|
|
316
|
+
arg_exprs = (function_value, *arg_exprs)
|
|
317
|
+
expr_decl = CallDecl(self.__egg_ref__, arg_exprs, bound_params)
|
|
318
|
+
typed_expr_decl = TypedExprDecl(return_tp, expr_decl)
|
|
319
|
+
# If there is not return type, we are mutating the first arg
|
|
320
|
+
if not signature.return_type:
|
|
321
|
+
first_arg = upcasted_args[0]
|
|
322
|
+
first_arg.__egg_thunk__ = Thunk.value((decls, typed_expr_decl))
|
|
323
|
+
return None
|
|
324
|
+
return RuntimeExpr.__from_value__(decls, typed_expr_decl)
|
|
452
325
|
|
|
453
|
-
|
|
454
|
-
|
|
455
|
-
|
|
456
|
-
|
|
457
|
-
|
|
458
|
-
|
|
459
|
-
|
|
326
|
+
def __str__(self) -> str:
|
|
327
|
+
first_arg, bound_tp_params = None, None
|
|
328
|
+
match self.__egg_bound__:
|
|
329
|
+
case RuntimeExpr(_):
|
|
330
|
+
first_arg = self.__egg_bound__.__egg_typed_expr__.expr
|
|
331
|
+
case JustTypeRef(_, args):
|
|
332
|
+
bound_tp_params = args
|
|
333
|
+
return pretty_callable_ref(self.__egg_decls__, self.__egg_ref__, first_arg, bound_tp_params)
|
|
460
334
|
|
|
461
|
-
def __post_init__(self) -> None:
|
|
462
|
-
self.__egg_callable_ref__ = ClassMethodRef(self.class_name, self.__egg_method_name__)
|
|
463
|
-
try:
|
|
464
|
-
self.__egg_fn_decl__ = self.__egg_decls__.get_function_decl(self.__egg_callable_ref__)
|
|
465
|
-
except KeyError as e:
|
|
466
|
-
raise AttributeError(f"Class {self.class_name} does not have method {self.__egg_method_name__}") from e
|
|
467
335
|
|
|
468
|
-
|
|
469
|
-
|
|
336
|
+
def to_py_signature(sig: FunctionSignature, decls: Declarations, optional_args: bool) -> Signature:
|
|
337
|
+
"""
|
|
338
|
+
Convert to a Python signature.
|
|
470
339
|
|
|
471
|
-
|
|
472
|
-
|
|
340
|
+
If optional_args is true, then all args will be treated as optional, as if a default was provided that makes them
|
|
341
|
+
a var with that arg name as the value.
|
|
473
342
|
|
|
474
|
-
|
|
475
|
-
|
|
476
|
-
|
|
477
|
-
|
|
478
|
-
|
|
343
|
+
Used for partial application to try binding a function with only some of its args.
|
|
344
|
+
"""
|
|
345
|
+
parameters = [
|
|
346
|
+
Parameter(
|
|
347
|
+
n,
|
|
348
|
+
Parameter.POSITIONAL_OR_KEYWORD,
|
|
349
|
+
default=RuntimeExpr.__from_value__(decls, TypedExprDecl(t.to_just(), d if d is not None else VarDecl(n)))
|
|
350
|
+
if d is not None or optional_args
|
|
351
|
+
else Parameter.empty,
|
|
352
|
+
)
|
|
353
|
+
for n, d, t in zip(sig.arg_names, sig.arg_defaults, sig.arg_types, strict=True)
|
|
354
|
+
]
|
|
355
|
+
if isinstance(sig, FunctionSignature) and sig.var_arg_type is not None:
|
|
356
|
+
parameters.append(Parameter("__rest", Parameter.VAR_POSITIONAL))
|
|
357
|
+
return Signature(parameters)
|
|
479
358
|
|
|
480
359
|
|
|
481
360
|
# All methods which should return NotImplemented if they fail to resolve
|
|
@@ -505,63 +384,34 @@ PARTIAL_METHODS = {
|
|
|
505
384
|
|
|
506
385
|
|
|
507
386
|
@dataclass
|
|
508
|
-
class
|
|
509
|
-
|
|
510
|
-
|
|
511
|
-
__egg_callable_ref__: MethodRef | PropertyRef = field(init=False)
|
|
512
|
-
__egg_fn_decl__: FunctionDecl = field(init=False, repr=False)
|
|
513
|
-
__egg_decls__: Declarations = field(init=False)
|
|
514
|
-
|
|
515
|
-
def __post_init__(self) -> None:
|
|
516
|
-
self.__egg_decls__ = self.__egg_self__.__egg_decls__
|
|
517
|
-
if self.__egg_method_name__ in self.__egg_decls__.get_class_decl(self.class_name).properties:
|
|
518
|
-
self.__egg_callable_ref__ = PropertyRef(self.class_name, self.__egg_method_name__)
|
|
519
|
-
else:
|
|
520
|
-
self.__egg_callable_ref__ = MethodRef(self.class_name, self.__egg_method_name__)
|
|
521
|
-
try:
|
|
522
|
-
self.__egg_fn_decl__ = self.__egg_decls__.get_function_decl(self.__egg_callable_ref__)
|
|
523
|
-
except KeyError:
|
|
524
|
-
msg = f"Class {self.class_name} does not have method {self.__egg_method_name__}"
|
|
525
|
-
if self.__egg_method_name__ == "__ne__":
|
|
526
|
-
msg += ". Did you mean to use the ne(...).to(...)?"
|
|
527
|
-
raise AttributeError(msg) from None
|
|
387
|
+
class RuntimeExpr:
|
|
388
|
+
# Defer needing decls/expr so we can make constants that don't resolve their class types
|
|
389
|
+
__egg_thunk__: Callable[[], tuple[Declarations, TypedExprDecl]]
|
|
528
390
|
|
|
529
|
-
|
|
530
|
-
|
|
531
|
-
|
|
532
|
-
return _call(
|
|
533
|
-
self.__egg_decls__,
|
|
534
|
-
self.__egg_callable_ref__,
|
|
535
|
-
self.__egg_fn_decl__,
|
|
536
|
-
args,
|
|
537
|
-
kwargs,
|
|
538
|
-
self.__egg_self__.__egg_typed_expr__.tp,
|
|
539
|
-
)
|
|
540
|
-
except ConvertError as e:
|
|
541
|
-
name = self.__egg_method_name__
|
|
542
|
-
raise TypeError(f"Wrong types for {self.__egg_self__.__egg_typed_expr__.tp.pretty()}.{name}") from e
|
|
391
|
+
@classmethod
|
|
392
|
+
def __from_value__(cls, d: Declarations, e: TypedExprDecl) -> RuntimeExpr:
|
|
393
|
+
return cls(Thunk.value((d, e)))
|
|
543
394
|
|
|
544
395
|
@property
|
|
545
|
-
def
|
|
546
|
-
return self.
|
|
547
|
-
|
|
396
|
+
def __egg_decls__(self) -> Declarations:
|
|
397
|
+
return self.__egg_thunk__()[0]
|
|
548
398
|
|
|
549
|
-
@
|
|
550
|
-
|
|
551
|
-
|
|
552
|
-
__egg_typed_expr__: TypedExprDecl
|
|
399
|
+
@property
|
|
400
|
+
def __egg_typed_expr__(self) -> TypedExprDecl:
|
|
401
|
+
return self.__egg_thunk__()[1]
|
|
553
402
|
|
|
554
|
-
def __getattr__(self, name: str) ->
|
|
555
|
-
|
|
403
|
+
def __getattr__(self, name: str) -> RuntimeFunction | RuntimeExpr | Callable | None:
|
|
404
|
+
cls_name = self.__egg_class_name__
|
|
405
|
+
class_decl = self.__egg_class_decl__
|
|
556
406
|
|
|
557
|
-
preserved_methods
|
|
558
|
-
if name in preserved_methods:
|
|
407
|
+
if name in (preserved_methods := class_decl.preserved_methods):
|
|
559
408
|
return preserved_methods[name].__get__(self)
|
|
560
409
|
|
|
561
|
-
|
|
562
|
-
|
|
563
|
-
|
|
564
|
-
|
|
410
|
+
if name in class_decl.methods:
|
|
411
|
+
return RuntimeFunction(Thunk.value(self.__egg_decls__), MethodRef(cls_name, name), self)
|
|
412
|
+
if name in class_decl.properties:
|
|
413
|
+
return RuntimeFunction(Thunk.value(self.__egg_decls__), PropertyRef(cls_name, name), self)()
|
|
414
|
+
raise AttributeError(f"{cls_name} has no method {name}") from None
|
|
565
415
|
|
|
566
416
|
def __repr__(self) -> str:
|
|
567
417
|
"""
|
|
@@ -570,18 +420,10 @@ class RuntimeExpr:
|
|
|
570
420
|
return str(self)
|
|
571
421
|
|
|
572
422
|
def __str__(self) -> str:
|
|
573
|
-
|
|
574
|
-
|
|
575
|
-
|
|
576
|
-
|
|
577
|
-
if config.SHOW_TYPES:
|
|
578
|
-
raise NotImplementedError
|
|
579
|
-
# s = f"_: {self.__egg_typed_expr__.tp.pretty()} = {pretty_expr}"
|
|
580
|
-
# return black.format_str(s, mode=black.FileMode()).strip()
|
|
581
|
-
pretty_statements = context.render(pretty_expr)
|
|
582
|
-
return black.format_str(pretty_statements, mode=BLACK_MODE).strip()
|
|
583
|
-
except black.parsing.InvalidInput:
|
|
584
|
-
return pretty_expr
|
|
423
|
+
return self.__egg_pretty__(None)
|
|
424
|
+
|
|
425
|
+
def __egg_pretty__(self, wrapping_fn: str | None) -> str:
|
|
426
|
+
return pretty_decl(self.__egg_decls__, self.__egg_typed_expr__.expr, wrapping_fn=wrapping_fn)
|
|
585
427
|
|
|
586
428
|
def _ipython_display_(self) -> None:
|
|
587
429
|
from IPython.display import Code, display
|
|
@@ -589,28 +431,32 @@ class RuntimeExpr:
|
|
|
589
431
|
display(Code(str(self), language="python"))
|
|
590
432
|
|
|
591
433
|
def __dir__(self) -> Iterable[str]:
|
|
592
|
-
|
|
434
|
+
class_decl = self.__egg_class_decl__
|
|
435
|
+
return list(class_decl.methods) + list(class_decl.properties) + list(class_decl.preserved_methods)
|
|
593
436
|
|
|
594
437
|
@property
|
|
595
|
-
def
|
|
596
|
-
return self.__egg_typed_expr__.
|
|
438
|
+
def __egg_class_name__(self) -> str:
|
|
439
|
+
return self.__egg_typed_expr__.tp.name
|
|
440
|
+
|
|
441
|
+
@property
|
|
442
|
+
def __egg_class_decl__(self) -> ClassDecl:
|
|
443
|
+
return self.__egg_decls__.get_class_decl(self.__egg_class_name__)
|
|
597
444
|
|
|
598
445
|
# Have __eq__ take no NoReturn (aka Never https://docs.python.org/3/library/typing.html#typing.Never) because
|
|
599
446
|
# we don't wany any type that MyPy thinks is an expr to be used with __eq__.
|
|
600
447
|
# That's because we want to reserve __eq__ for domain specific equality checks, overloading this method.
|
|
601
448
|
# To check if two exprs are equal, use the expr_eq method.
|
|
602
|
-
|
|
603
|
-
|
|
604
|
-
raise NotImplementedError(msg)
|
|
449
|
+
# At runtime, this will resolve if there is a defined egg function for `__eq__`
|
|
450
|
+
def __eq__(self, other: NoReturn) -> Expr: ... # type: ignore[override, empty-body]
|
|
605
451
|
|
|
606
452
|
# Implement these so that copy() works on this object
|
|
607
453
|
# otherwise copy will try to call `__getstate__` before object is initialized with properties which will cause inifinite recursion
|
|
608
454
|
|
|
609
455
|
def __getstate__(self) -> tuple[Declarations, TypedExprDecl]:
|
|
610
|
-
return
|
|
456
|
+
return self.__egg_thunk__()
|
|
611
457
|
|
|
612
458
|
def __setstate__(self, d: tuple[Declarations, TypedExprDecl]) -> None:
|
|
613
|
-
self.
|
|
459
|
+
self.__egg_thunk__ = Thunk.value(d)
|
|
614
460
|
|
|
615
461
|
def __hash__(self) -> int:
|
|
616
462
|
return hash(self.__egg_typed_expr__)
|
|
@@ -625,12 +471,17 @@ for name in list(BINARY_METHODS) + list(UNARY_METHODS) + ["__getitem__", "__call
|
|
|
625
471
|
__name: str = name,
|
|
626
472
|
**kwargs: object,
|
|
627
473
|
) -> RuntimeExpr | None:
|
|
474
|
+
from .conversion import ConvertError
|
|
475
|
+
|
|
476
|
+
class_name = self.__egg_class_name__
|
|
477
|
+
class_decl = self.__egg_class_decl__
|
|
628
478
|
# First, try to resolve as preserved method
|
|
629
479
|
try:
|
|
630
|
-
method =
|
|
631
|
-
return method(self, *args, **kwargs)
|
|
480
|
+
method = class_decl.preserved_methods[__name]
|
|
632
481
|
except KeyError:
|
|
633
482
|
pass
|
|
483
|
+
else:
|
|
484
|
+
return method(self, *args, **kwargs)
|
|
634
485
|
# If this is a "partial" method meaning that it can return NotImplemented,
|
|
635
486
|
# we want to find the "best" superparent (lowest cost) of the arg types to call with it, instead of just
|
|
636
487
|
# using the arg type of the self arg.
|
|
@@ -639,8 +490,15 @@ for name in list(BINARY_METHODS) + list(UNARY_METHODS) + ["__getitem__", "__call
|
|
|
639
490
|
try:
|
|
640
491
|
return call_method_min_conversion(self, args[0], __name)
|
|
641
492
|
except ConvertError:
|
|
642
|
-
|
|
643
|
-
|
|
493
|
+
# Defer raising not imeplemented in case the dunder method is not symmetrical, then
|
|
494
|
+
# we use the standard process
|
|
495
|
+
pass
|
|
496
|
+
if __name in class_decl.methods:
|
|
497
|
+
fn = RuntimeFunction(Thunk.value(self.__egg_decls__), MethodRef(class_name, __name), self)
|
|
498
|
+
return fn(*args, **kwargs) # type: ignore[arg-type]
|
|
499
|
+
if __name in PARTIAL_METHODS:
|
|
500
|
+
return NotImplemented
|
|
501
|
+
raise TypeError(f"{class_name!r} object does not support {__name}")
|
|
644
502
|
|
|
645
503
|
setattr(RuntimeExpr, name, _special_method)
|
|
646
504
|
|
|
@@ -655,12 +513,14 @@ for reflected, non_reflected in REFLECTED_BINARY_METHODS.items():
|
|
|
655
513
|
|
|
656
514
|
|
|
657
515
|
def call_method_min_conversion(slf: object, other: object, name: str) -> RuntimeExpr | None:
|
|
516
|
+
from .conversion import min_convertable_tp, resolve_literal
|
|
517
|
+
|
|
658
518
|
# find a minimum type that both can be converted to
|
|
659
519
|
# This is so so that calls like `-0.1 * Int("x")` work by upcasting both to floats.
|
|
660
520
|
min_tp = min_convertable_tp(slf, other, name)
|
|
661
|
-
slf = resolve_literal(min_tp
|
|
662
|
-
other = resolve_literal(min_tp
|
|
663
|
-
method =
|
|
521
|
+
slf = resolve_literal(TypeRefWithVars(min_tp), slf)
|
|
522
|
+
other = resolve_literal(TypeRefWithVars(min_tp), other)
|
|
523
|
+
method = RuntimeFunction(Thunk.value(slf.__egg_decls__), MethodRef(slf.__egg_class_name__, name), slf)
|
|
664
524
|
return method(other)
|
|
665
525
|
|
|
666
526
|
|
|
@@ -680,21 +540,9 @@ def resolve_callable(callable: object) -> tuple[CallableRef, Declarations]:
|
|
|
680
540
|
"""
|
|
681
541
|
Resolves a runtime callable into a ref
|
|
682
542
|
"""
|
|
683
|
-
|
|
684
|
-
|
|
685
|
-
|
|
686
|
-
|
|
687
|
-
|
|
688
|
-
|
|
689
|
-
elif isinstance(callable, RuntimeClassMethod):
|
|
690
|
-
ref = ClassMethodRef(callable.class_name, callable.__egg_method_name__)
|
|
691
|
-
decls = callable.__egg_decls__
|
|
692
|
-
elif isinstance(callable, RuntimeMethod):
|
|
693
|
-
ref = MethodRef(callable.__egg_self__.__egg_typed_expr__.tp.name, callable.__egg_method_name__)
|
|
694
|
-
decls = callable.__egg_decls__
|
|
695
|
-
elif isinstance(callable, RuntimeClass):
|
|
696
|
-
ref = ClassMethodRef(callable.__egg_name__, "__init__")
|
|
697
|
-
decls = callable.__egg_decls__
|
|
698
|
-
else:
|
|
699
|
-
raise NotImplementedError(f"Cannot turn {callable} into a callable ref")
|
|
700
|
-
return (ref, decls)
|
|
543
|
+
match callable:
|
|
544
|
+
case RuntimeFunction(decls, ref, _):
|
|
545
|
+
return ref, decls()
|
|
546
|
+
case RuntimeClass(thunk, tp):
|
|
547
|
+
return ClassMethodRef(tp.name, "__init__"), thunk()
|
|
548
|
+
raise NotImplementedError(f"Cannot turn {callable} into a callable ref")
|