egglog 6.1.0__cp310-none-win_amd64.whl → 7.0.0__cp310-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of egglog might be problematic. Click here for more details.
- egglog/__init__.py +1 -1
- egglog/bindings.cp310-win_amd64.pyd +0 -0
- egglog/bindings.pyi +2 -0
- egglog/builtins.py +1 -1
- egglog/conversion.py +172 -0
- egglog/declarations.py +329 -735
- egglog/egraph.py +531 -804
- egglog/egraph_state.py +417 -0
- egglog/exp/array_api.py +92 -80
- egglog/exp/array_api_numba.py +6 -1
- egglog/exp/siu_examples.py +35 -0
- egglog/pretty.py +418 -0
- egglog/runtime.py +196 -430
- egglog/thunk.py +72 -0
- egglog/type_constraint_solver.py +5 -2
- {egglog-6.1.0.dist-info → egglog-7.0.0.dist-info}/METADATA +19 -19
- {egglog-6.1.0.dist-info → egglog-7.0.0.dist-info}/RECORD +19 -14
- {egglog-6.1.0.dist-info → egglog-7.0.0.dist-info}/WHEEL +0 -0
- {egglog-6.1.0.dist-info → egglog-7.0.0.dist-info}/license_files/LICENSE +0 -0
egglog/runtime.py
CHANGED
|
@@ -11,172 +11,57 @@ so they are not mangled by Python and can be accessed by the user.
|
|
|
11
11
|
|
|
12
12
|
from __future__ import annotations
|
|
13
13
|
|
|
14
|
-
from dataclasses import dataclass
|
|
14
|
+
from dataclasses import dataclass
|
|
15
|
+
from inspect import Parameter, Signature
|
|
15
16
|
from itertools import zip_longest
|
|
16
17
|
from typing import TYPE_CHECKING, NoReturn, TypeVar, Union, cast, get_args, get_origin
|
|
17
18
|
|
|
18
|
-
import black
|
|
19
|
-
import black.parsing
|
|
20
|
-
from typing_extensions import assert_never
|
|
21
|
-
|
|
22
|
-
from . import bindings, config
|
|
23
19
|
from .declarations import *
|
|
24
|
-
from .
|
|
20
|
+
from .pretty import *
|
|
21
|
+
from .thunk import Thunk
|
|
25
22
|
from .type_constraint_solver import *
|
|
26
23
|
|
|
27
24
|
if TYPE_CHECKING:
|
|
28
|
-
from collections.abc import Callable,
|
|
25
|
+
from collections.abc import Callable, Iterable
|
|
29
26
|
|
|
30
27
|
from .egraph import Expr
|
|
31
28
|
|
|
32
29
|
__all__ = [
|
|
33
30
|
"LIT_CLASS_NAMES",
|
|
34
|
-
"class_to_ref",
|
|
35
|
-
"resolve_literal",
|
|
36
31
|
"resolve_callable",
|
|
37
32
|
"resolve_type_annotation",
|
|
38
|
-
"convert_to_same_type",
|
|
39
33
|
"RuntimeClass",
|
|
40
|
-
"RuntimeParamaterizedClass",
|
|
41
|
-
"RuntimeClassMethod",
|
|
42
34
|
"RuntimeExpr",
|
|
43
35
|
"RuntimeFunction",
|
|
44
|
-
"
|
|
45
|
-
"converter",
|
|
36
|
+
"REFLECTED_BINARY_METHODS",
|
|
46
37
|
]
|
|
47
38
|
|
|
48
39
|
|
|
49
|
-
BLACK_MODE = black.Mode(line_length=180)
|
|
50
|
-
|
|
51
40
|
UNIT_CLASS_NAME = "Unit"
|
|
52
41
|
UNARY_LIT_CLASS_NAMES = {"i64", "f64", "Bool", "String"}
|
|
53
42
|
LIT_CLASS_NAMES = UNARY_LIT_CLASS_NAMES | {UNIT_CLASS_NAME, "PyObject"}
|
|
54
43
|
|
|
44
|
+
REFLECTED_BINARY_METHODS = {
|
|
45
|
+
"__radd__": "__add__",
|
|
46
|
+
"__rsub__": "__sub__",
|
|
47
|
+
"__rmul__": "__mul__",
|
|
48
|
+
"__rmatmul__": "__matmul__",
|
|
49
|
+
"__rtruediv__": "__truediv__",
|
|
50
|
+
"__rfloordiv__": "__floordiv__",
|
|
51
|
+
"__rmod__": "__mod__",
|
|
52
|
+
"__rpow__": "__pow__",
|
|
53
|
+
"__rlshift__": "__lshift__",
|
|
54
|
+
"__rrshift__": "__rshift__",
|
|
55
|
+
"__rand__": "__and__",
|
|
56
|
+
"__rxor__": "__xor__",
|
|
57
|
+
"__ror__": "__or__",
|
|
58
|
+
}
|
|
59
|
+
|
|
55
60
|
# Set this globally so we can get access to PyObject when we have a type annotation of just object.
|
|
56
61
|
# This is the only time a type annotation doesn't need to include the egglog type b/c object is top so that would be redundant statically.
|
|
57
62
|
_PY_OBJECT_CLASS: RuntimeClass | None = None
|
|
58
63
|
|
|
59
|
-
##
|
|
60
|
-
# Converters
|
|
61
|
-
##
|
|
62
|
-
|
|
63
|
-
# Mapping from (source type, target type) to and function which takes in the runtimes values of the source and return the target
|
|
64
|
-
CONVERSIONS: dict[tuple[type | JustTypeRef, JustTypeRef], tuple[int, Callable]] = {}
|
|
65
|
-
# Global declerations to store all convertable types so we can query if they have certain methods or not
|
|
66
|
-
CONVERSIONS_DECLS = Declarations()
|
|
67
|
-
|
|
68
64
|
T = TypeVar("T")
|
|
69
|
-
V = TypeVar("V", bound="Expr")
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
class ConvertError(Exception):
|
|
73
|
-
pass
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
def converter(from_type: type[T], to_type: type[V], fn: Callable[[T], V], cost: int = 1) -> None:
|
|
77
|
-
"""
|
|
78
|
-
Register a converter from some type to an egglog type.
|
|
79
|
-
"""
|
|
80
|
-
to_type_name = process_tp(to_type)
|
|
81
|
-
if not isinstance(to_type_name, JustTypeRef):
|
|
82
|
-
raise TypeError(f"Expected return type to be a egglog type, got {to_type_name}")
|
|
83
|
-
_register_converter(process_tp(from_type), to_type_name, fn, cost)
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
def _register_converter(a: type | JustTypeRef, b: JustTypeRef, a_b: Callable, cost: int) -> None:
|
|
87
|
-
"""
|
|
88
|
-
Registers a converter from some type to an egglog type, if not already registered.
|
|
89
|
-
|
|
90
|
-
Also adds transitive converters, i.e. if registering A->B and there is already B->C, then A->C will be registered.
|
|
91
|
-
Also, if registering A->B and there is already D->A, then D->B will be registered.
|
|
92
|
-
"""
|
|
93
|
-
if a == b:
|
|
94
|
-
return
|
|
95
|
-
if (a, b) in CONVERSIONS and CONVERSIONS[(a, b)][0] <= cost:
|
|
96
|
-
return
|
|
97
|
-
CONVERSIONS[(a, b)] = (cost, a_b)
|
|
98
|
-
for (c, d), (other_cost, c_d) in list(CONVERSIONS.items()):
|
|
99
|
-
if b == c:
|
|
100
|
-
_register_converter(a, d, _ComposedConverter(a_b, c_d), cost + other_cost)
|
|
101
|
-
if a == d:
|
|
102
|
-
_register_converter(c, b, _ComposedConverter(c_d, a_b), cost + other_cost)
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
@dataclass
|
|
106
|
-
class _ComposedConverter:
|
|
107
|
-
"""
|
|
108
|
-
A converter which is composed of multiple converters.
|
|
109
|
-
|
|
110
|
-
_ComposeConverter(a_b, b_c) is equivalent to lambda x: b_c(a_b(x))
|
|
111
|
-
|
|
112
|
-
We use the dataclass instead of the lambda to make it easier to debug.
|
|
113
|
-
"""
|
|
114
|
-
|
|
115
|
-
a_b: Callable
|
|
116
|
-
b_c: Callable
|
|
117
|
-
|
|
118
|
-
def __call__(self, x: object) -> object:
|
|
119
|
-
return self.b_c(self.a_b(x))
|
|
120
|
-
|
|
121
|
-
def __str__(self) -> str:
|
|
122
|
-
return f"{self.b_c} ∘ {self.a_b}"
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
def convert(source: object, target: type[V]) -> V:
|
|
126
|
-
"""
|
|
127
|
-
Convert a source object to a target type.
|
|
128
|
-
"""
|
|
129
|
-
target_ref = class_to_ref(cast(RuntimeTypeArgType, target))
|
|
130
|
-
return cast(V, resolve_literal(target_ref.to_var(), source))
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
def convert_to_same_type(source: object, target: RuntimeExpr) -> RuntimeExpr:
|
|
134
|
-
"""
|
|
135
|
-
Convert a source object to the same type as the target.
|
|
136
|
-
"""
|
|
137
|
-
tp = target.__egg_typed_expr__.tp
|
|
138
|
-
return resolve_literal(tp.to_var(), source)
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
def process_tp(tp: type | RuntimeTypeArgType) -> JustTypeRef | type:
|
|
142
|
-
"""
|
|
143
|
-
Process a type before converting it, to add it to the global declerations and resolve to a ref.
|
|
144
|
-
"""
|
|
145
|
-
global CONVERSIONS_DECLS
|
|
146
|
-
if isinstance(tp, RuntimeClass | RuntimeParamaterizedClass):
|
|
147
|
-
CONVERSIONS_DECLS |= tp
|
|
148
|
-
return class_to_ref(tp)
|
|
149
|
-
return tp
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
def min_convertable_tp(a: object, b: object, name: str) -> JustTypeRef:
|
|
153
|
-
"""
|
|
154
|
-
Returns the minimum convertable type between a and b, that has a method `name`, raising a TypeError if no such type exists.
|
|
155
|
-
"""
|
|
156
|
-
a_tp = _get_tp(a)
|
|
157
|
-
b_tp = _get_tp(b)
|
|
158
|
-
a_converts_to = {
|
|
159
|
-
to: c
|
|
160
|
-
for ((from_, to), (c, _)) in CONVERSIONS.items()
|
|
161
|
-
if from_ == a_tp and CONVERSIONS_DECLS.has_method(to.name, name)
|
|
162
|
-
}
|
|
163
|
-
b_converts_to = {
|
|
164
|
-
to: c
|
|
165
|
-
for ((from_, to), (c, _)) in CONVERSIONS.items()
|
|
166
|
-
if from_ == b_tp and CONVERSIONS_DECLS.has_method(to.name, name)
|
|
167
|
-
}
|
|
168
|
-
if isinstance(a_tp, JustTypeRef):
|
|
169
|
-
a_converts_to[a_tp] = 0
|
|
170
|
-
if isinstance(b_tp, JustTypeRef):
|
|
171
|
-
b_converts_to[b_tp] = 0
|
|
172
|
-
common = set(a_converts_to) & set(b_converts_to)
|
|
173
|
-
if not common:
|
|
174
|
-
raise ConvertError(f"Cannot convert {a_tp} and {b_tp} to a common type")
|
|
175
|
-
return min(common, key=lambda tp: a_converts_to[tp] + b_converts_to[tp])
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
def identity(x: object) -> object:
|
|
179
|
-
return x
|
|
180
65
|
|
|
181
66
|
|
|
182
67
|
def resolve_type_annotation(decls: Declarations, tp: object) -> TypeOrVarRef:
|
|
@@ -195,99 +80,62 @@ def resolve_type_annotation(decls: Declarations, tp: object) -> TypeOrVarRef:
|
|
|
195
80
|
assert _PY_OBJECT_CLASS
|
|
196
81
|
return resolve_type_annotation(decls, _PY_OBJECT_CLASS)
|
|
197
82
|
if isinstance(tp, RuntimeClass):
|
|
198
|
-
decls |= tp
|
|
199
|
-
return tp.__egg_tp__.to_var()
|
|
200
|
-
if isinstance(tp, RuntimeParamaterizedClass):
|
|
201
83
|
decls |= tp
|
|
202
84
|
return tp.__egg_tp__
|
|
203
85
|
raise TypeError(f"Unexpected type annotation {tp}")
|
|
204
86
|
|
|
205
87
|
|
|
206
|
-
def resolve_literal(tp: TypeOrVarRef, arg: object) -> RuntimeExpr:
|
|
207
|
-
arg_type = _get_tp(arg)
|
|
208
|
-
|
|
209
|
-
# If we have any type variables, dont bother trying to resolve the literal, just return the arg
|
|
210
|
-
try:
|
|
211
|
-
tp_just = tp.to_just()
|
|
212
|
-
except NotImplementedError:
|
|
213
|
-
# If this is a var, it has to be a runtime exprssions
|
|
214
|
-
assert isinstance(arg, RuntimeExpr)
|
|
215
|
-
return arg
|
|
216
|
-
if arg_type == tp_just:
|
|
217
|
-
# If the type is an egg type, it has to be a runtime expr
|
|
218
|
-
assert isinstance(arg, RuntimeExpr)
|
|
219
|
-
return arg
|
|
220
|
-
# Try all parent types as well, if we are converting from a Python type
|
|
221
|
-
for arg_type_instance in arg_type.__mro__ if isinstance(arg_type, type) else [arg_type]:
|
|
222
|
-
try:
|
|
223
|
-
fn = CONVERSIONS[(cast(JustTypeRef | type, arg_type_instance), tp_just)][1]
|
|
224
|
-
except KeyError:
|
|
225
|
-
continue
|
|
226
|
-
break
|
|
227
|
-
else:
|
|
228
|
-
arg_type_str = arg_type.pretty() if isinstance(arg_type, JustTypeRef) else arg_type.__name__
|
|
229
|
-
raise ConvertError(f"Cannot convert {arg_type_str} to {tp_just.pretty()}")
|
|
230
|
-
return fn(arg)
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
def _get_tp(x: object) -> JustTypeRef | type:
|
|
234
|
-
if isinstance(x, RuntimeExpr):
|
|
235
|
-
return x.__egg_typed_expr__.tp
|
|
236
|
-
tp = type(x)
|
|
237
|
-
# If this value has a custom metaclass, let's use that as our index instead of the type
|
|
238
|
-
if type(tp) != type:
|
|
239
|
-
return type(tp)
|
|
240
|
-
return tp
|
|
241
|
-
|
|
242
|
-
|
|
243
88
|
##
|
|
244
89
|
# Runtime objects
|
|
245
90
|
##
|
|
246
91
|
|
|
247
92
|
|
|
248
93
|
@dataclass
|
|
249
|
-
class RuntimeClass:
|
|
250
|
-
|
|
251
|
-
# This function should mutate the declerations and add to them
|
|
252
|
-
# Used this instead of a lazy property so we can have a reference to the decls in the class as its computing
|
|
253
|
-
lazy_decls: Callable[[Declarations], None] = field(repr=False)
|
|
254
|
-
# Cached declerations
|
|
255
|
-
_inner_decls: Declarations | None = field(init=False, repr=False, default=None)
|
|
256
|
-
__egg_name__: str
|
|
94
|
+
class RuntimeClass(DelayedDeclerations):
|
|
95
|
+
__egg_tp__: TypeRefWithVars
|
|
257
96
|
|
|
258
97
|
def __post_init__(self) -> None:
|
|
259
98
|
global _PY_OBJECT_CLASS
|
|
260
|
-
if self.
|
|
99
|
+
if self.__egg_tp__.name == "PyObject":
|
|
261
100
|
_PY_OBJECT_CLASS = self
|
|
262
101
|
|
|
263
|
-
|
|
264
|
-
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
|
|
102
|
+
def verify(self) -> None:
|
|
103
|
+
if not self.__egg_tp__.args:
|
|
104
|
+
return
|
|
105
|
+
|
|
106
|
+
# Raise error if we have args, but they are the wrong number
|
|
107
|
+
desired_args = self.__egg_decls__.get_class_decl(self.__egg_tp__.name).type_vars
|
|
108
|
+
if len(self.__egg_tp__.args) != len(desired_args):
|
|
109
|
+
raise ValueError(f"Expected {desired_args} type args, got {len(self.__egg_tp__.args)}")
|
|
270
110
|
|
|
271
111
|
def __call__(self, *args: object, **kwargs: object) -> RuntimeExpr | None:
|
|
272
112
|
"""
|
|
273
113
|
Create an instance of this kind by calling the __init__ classmethod
|
|
274
114
|
"""
|
|
275
115
|
# If this is a literal type, initializing it with a literal should return a literal
|
|
276
|
-
if self.
|
|
116
|
+
if self.__egg_tp__.name == "PyObject":
|
|
277
117
|
assert len(args) == 1
|
|
278
|
-
return RuntimeExpr
|
|
279
|
-
|
|
118
|
+
return RuntimeExpr.__from_value__(
|
|
119
|
+
self.__egg_decls__, TypedExprDecl(self.__egg_tp__.to_just(), PyObjectDecl(args[0]))
|
|
120
|
+
)
|
|
121
|
+
if self.__egg_tp__.name in UNARY_LIT_CLASS_NAMES:
|
|
280
122
|
assert len(args) == 1
|
|
281
123
|
assert isinstance(args[0], int | float | str | bool)
|
|
282
|
-
return RuntimeExpr
|
|
283
|
-
|
|
124
|
+
return RuntimeExpr.__from_value__(
|
|
125
|
+
self.__egg_decls__, TypedExprDecl(self.__egg_tp__.to_just(), LitDecl(args[0]))
|
|
126
|
+
)
|
|
127
|
+
if self.__egg_tp__.name == UNIT_CLASS_NAME:
|
|
284
128
|
assert len(args) == 0
|
|
285
|
-
return RuntimeExpr
|
|
129
|
+
return RuntimeExpr.__from_value__(
|
|
130
|
+
self.__egg_decls__, TypedExprDecl(self.__egg_tp__.to_just(), LitDecl(None))
|
|
131
|
+
)
|
|
286
132
|
|
|
287
|
-
return
|
|
133
|
+
return RuntimeFunction(
|
|
134
|
+
Thunk.value(self.__egg_decls__), ClassMethodRef(self.__egg_tp__.name, "__init__"), self.__egg_tp__.to_just()
|
|
135
|
+
)(*args, **kwargs)
|
|
288
136
|
|
|
289
137
|
def __dir__(self) -> list[str]:
|
|
290
|
-
cls_decl = self.__egg_decls__.get_class_decl(self.
|
|
138
|
+
cls_decl = self.__egg_decls__.get_class_decl(self.__egg_tp__.name)
|
|
291
139
|
possible_methods = (
|
|
292
140
|
list(cls_decl.class_methods) + list(cls_decl.class_variables) + list(cls_decl.preserved_methods)
|
|
293
141
|
)
|
|
@@ -296,14 +144,19 @@ class RuntimeClass:
|
|
|
296
144
|
possible_methods.append("__call__")
|
|
297
145
|
return possible_methods
|
|
298
146
|
|
|
299
|
-
def __getitem__(self, args: object) ->
|
|
147
|
+
def __getitem__(self, args: object) -> RuntimeClass:
|
|
148
|
+
if self.__egg_tp__.args:
|
|
149
|
+
raise TypeError(f"Cannot index into a paramaterized class {self}")
|
|
300
150
|
if not isinstance(args, tuple):
|
|
301
151
|
args = (args,)
|
|
302
152
|
decls = self.__egg_decls__.copy()
|
|
303
|
-
tp = TypeRefWithVars(self.
|
|
304
|
-
return
|
|
153
|
+
tp = TypeRefWithVars(self.__egg_tp__.name, tuple(resolve_type_annotation(decls, arg) for arg in args))
|
|
154
|
+
return RuntimeClass(Thunk.value(decls), tp)
|
|
155
|
+
|
|
156
|
+
def __getattr__(self, name: str) -> RuntimeFunction | RuntimeExpr | Callable:
|
|
157
|
+
if name == "__origin__" and self.__egg_tp__.args:
|
|
158
|
+
return RuntimeClass(self.__egg_decls_thunk__, TypeRefWithVars(self.__egg_tp__.name))
|
|
305
159
|
|
|
306
|
-
def __getattr__(self, name: str) -> RuntimeClassMethod | RuntimeExpr | Callable:
|
|
307
160
|
# Special case some names that don't exist so we can exit early without resolving decls
|
|
308
161
|
# Important so if we take union of RuntimeClass it won't try to resolve decls
|
|
309
162
|
if name in {
|
|
@@ -314,7 +167,7 @@ class RuntimeClass:
|
|
|
314
167
|
}:
|
|
315
168
|
raise AttributeError
|
|
316
169
|
|
|
317
|
-
cls_decl = self.__egg_decls__.
|
|
170
|
+
cls_decl = self.__egg_decls__._classes[self.__egg_tp__.name]
|
|
318
171
|
|
|
319
172
|
preserved_methods = cls_decl.preserved_methods
|
|
320
173
|
if name in preserved_methods:
|
|
@@ -323,159 +176,107 @@ class RuntimeClass:
|
|
|
323
176
|
# if this is a class variable, return an expr for it, otherwise, assume it's a method
|
|
324
177
|
if name in cls_decl.class_variables:
|
|
325
178
|
return_tp = cls_decl.class_variables[name]
|
|
326
|
-
return RuntimeExpr(
|
|
327
|
-
self.__egg_decls__,
|
|
179
|
+
return RuntimeExpr.__from_value__(
|
|
180
|
+
self.__egg_decls__,
|
|
181
|
+
TypedExprDecl(return_tp.type_ref, CallDecl(ClassVariableRef(self.__egg_tp__.name, name))),
|
|
328
182
|
)
|
|
329
|
-
|
|
183
|
+
if name in cls_decl.class_methods:
|
|
184
|
+
return RuntimeFunction(
|
|
185
|
+
Thunk.value(self.__egg_decls__), ClassMethodRef(self.__egg_tp__.name, name), self.__egg_tp__.to_just()
|
|
186
|
+
)
|
|
187
|
+
msg = f"Class {self.__egg_tp__.name} has no method {name}"
|
|
188
|
+
if name == "__ne__":
|
|
189
|
+
msg += ". Did you mean to use the ne(...).to(...)?"
|
|
190
|
+
raise AttributeError(msg) from None
|
|
330
191
|
|
|
331
192
|
def __str__(self) -> str:
|
|
332
|
-
return self.
|
|
193
|
+
return str(self.__egg_tp__)
|
|
333
194
|
|
|
334
195
|
# Make hashable so can go in Union
|
|
335
196
|
def __hash__(self) -> int:
|
|
336
|
-
return hash((id(self.
|
|
197
|
+
return hash((id(self.__egg_decls_thunk__), self.__egg_tp__))
|
|
337
198
|
|
|
338
199
|
# Support unioning like types
|
|
339
200
|
def __or__(self, __value: type) -> object:
|
|
340
201
|
return Union[self, __value] # noqa: UP007
|
|
341
202
|
|
|
342
|
-
@property
|
|
343
|
-
def __egg_tp__(self) -> JustTypeRef:
|
|
344
|
-
return JustTypeRef(self.__egg_name__)
|
|
345
|
-
|
|
346
|
-
|
|
347
|
-
@dataclass
|
|
348
|
-
class RuntimeParamaterizedClass:
|
|
349
|
-
__egg_decls__: Declarations
|
|
350
|
-
__egg_tp__: TypeRefWithVars
|
|
351
|
-
|
|
352
|
-
def __post_init__(self) -> None:
|
|
353
|
-
desired_args = self.__egg_decls__.get_class_decl(self.__egg_tp__.name).type_vars
|
|
354
|
-
if len(self.__egg_tp__.args) != len(desired_args):
|
|
355
|
-
raise ValueError(f"Expected {desired_args} type args, got {len(self.__egg_tp__.args)}")
|
|
356
|
-
|
|
357
|
-
def __call__(self, *args: object) -> RuntimeExpr | None:
|
|
358
|
-
return RuntimeClassMethod(self.__egg_decls__, class_to_ref(self), "__init__")(*args)
|
|
359
|
-
|
|
360
|
-
def __getattr__(self, name: str) -> RuntimeClassMethod | RuntimeClass:
|
|
361
|
-
# Special case so when get_type_annotations proccessed it can work
|
|
362
|
-
if name == "__origin__":
|
|
363
|
-
return RuntimeClass(self.__egg_decls__.update_other, self.__egg_tp__.name)
|
|
364
|
-
return RuntimeClassMethod(self.__egg_decls__, class_to_ref(self), name)
|
|
365
|
-
|
|
366
|
-
def __str__(self) -> str:
|
|
367
|
-
return self.__egg_tp__.pretty()
|
|
368
|
-
|
|
369
|
-
# Support unioning
|
|
370
|
-
def __or__(self, __value: type) -> object:
|
|
371
|
-
return Union[self, __value] # noqa: UP007
|
|
372
|
-
|
|
373
|
-
|
|
374
|
-
# Type args can either be typevars or classes
|
|
375
|
-
RuntimeTypeArgType = RuntimeClass | RuntimeParamaterizedClass
|
|
376
|
-
|
|
377
|
-
|
|
378
|
-
def class_to_ref(cls: RuntimeTypeArgType) -> JustTypeRef:
|
|
379
|
-
if isinstance(cls, RuntimeClass):
|
|
380
|
-
return JustTypeRef(cls.__egg_name__)
|
|
381
|
-
if isinstance(cls, RuntimeParamaterizedClass):
|
|
382
|
-
# Currently this is used when calling methods on a parametrized class, which is only possible when we
|
|
383
|
-
# have actualy types currently, not typevars, currently.
|
|
384
|
-
return cls.__egg_tp__.to_just()
|
|
385
|
-
assert_never(cls)
|
|
386
|
-
|
|
387
203
|
|
|
388
204
|
@dataclass
|
|
389
|
-
class RuntimeFunction:
|
|
390
|
-
|
|
391
|
-
|
|
392
|
-
|
|
393
|
-
__egg_fn_decl__: FunctionDecl = field(init=False)
|
|
394
|
-
|
|
395
|
-
def __post_init__(self) -> None:
|
|
396
|
-
self.__egg_fn_ref__ = FunctionRef(self.__egg_name__)
|
|
397
|
-
self.__egg_fn_decl__ = self.__egg_decls__.get_function_decl(self.__egg_fn_ref__)
|
|
205
|
+
class RuntimeFunction(DelayedDeclerations):
|
|
206
|
+
__egg_ref__: CallableRef
|
|
207
|
+
# bound methods need to store RuntimeExpr not just TypedExprDecl, so they can mutate the expr if required on self
|
|
208
|
+
__egg_bound__: JustTypeRef | RuntimeExpr | None = None
|
|
398
209
|
|
|
399
210
|
def __call__(self, *args: object, **kwargs: object) -> RuntimeExpr | None:
|
|
400
|
-
|
|
401
|
-
|
|
402
|
-
|
|
403
|
-
|
|
404
|
-
|
|
405
|
-
|
|
406
|
-
|
|
407
|
-
|
|
408
|
-
|
|
409
|
-
|
|
410
|
-
|
|
411
|
-
|
|
412
|
-
|
|
413
|
-
|
|
414
|
-
|
|
415
|
-
|
|
416
|
-
|
|
417
|
-
|
|
418
|
-
|
|
419
|
-
|
|
420
|
-
|
|
421
|
-
|
|
422
|
-
|
|
423
|
-
|
|
424
|
-
|
|
425
|
-
|
|
426
|
-
|
|
427
|
-
|
|
428
|
-
|
|
429
|
-
if bound_class is not None and bound_class.args:
|
|
430
|
-
tcs.bind_class(bound_class)
|
|
431
|
-
|
|
432
|
-
if fn_decl is not None:
|
|
211
|
+
from .conversion import resolve_literal
|
|
212
|
+
|
|
213
|
+
if isinstance(self.__egg_bound__, RuntimeExpr):
|
|
214
|
+
args = (self.__egg_bound__, *args)
|
|
215
|
+
fn_decl = self.__egg_decls__.get_callable_decl(self.__egg_ref__).to_function_decl()
|
|
216
|
+
# Turn all keyword args into positional args
|
|
217
|
+
bound = callable_decl_to_signature(fn_decl, self.__egg_decls__).bind(*args, **kwargs)
|
|
218
|
+
bound.apply_defaults()
|
|
219
|
+
assert not bound.kwargs
|
|
220
|
+
del args, kwargs
|
|
221
|
+
|
|
222
|
+
upcasted_args = [
|
|
223
|
+
resolve_literal(cast(TypeOrVarRef, tp), arg)
|
|
224
|
+
for arg, tp in zip_longest(bound.args, fn_decl.arg_types, fillvalue=fn_decl.var_arg_type)
|
|
225
|
+
]
|
|
226
|
+
|
|
227
|
+
decls = Declarations.create(self, *upcasted_args)
|
|
228
|
+
|
|
229
|
+
tcs = TypeConstraintSolver(decls)
|
|
230
|
+
bound_tp = (
|
|
231
|
+
None
|
|
232
|
+
if self.__egg_bound__ is None
|
|
233
|
+
else self.__egg_bound__.__egg_typed_expr__.tp
|
|
234
|
+
if isinstance(self.__egg_bound__, RuntimeExpr)
|
|
235
|
+
else self.__egg_bound__
|
|
236
|
+
)
|
|
237
|
+
if bound_tp and bound_tp.args:
|
|
238
|
+
tcs.bind_class(bound_tp)
|
|
239
|
+
arg_exprs = tuple(arg.__egg_typed_expr__ for arg in upcasted_args)
|
|
433
240
|
arg_types = [expr.tp for expr in arg_exprs]
|
|
434
|
-
cls_name =
|
|
241
|
+
cls_name = bound_tp.name if bound_tp else None
|
|
435
242
|
return_tp = tcs.infer_return_type(
|
|
436
|
-
fn_decl.arg_types, fn_decl.return_type, fn_decl.var_arg_type, arg_types, cls_name
|
|
243
|
+
fn_decl.arg_types, fn_decl.return_type or fn_decl.arg_types[0], fn_decl.var_arg_type, arg_types, cls_name
|
|
437
244
|
)
|
|
438
|
-
|
|
439
|
-
|
|
440
|
-
|
|
441
|
-
|
|
442
|
-
|
|
443
|
-
|
|
444
|
-
|
|
445
|
-
|
|
446
|
-
|
|
447
|
-
first_arg.__egg_typed_expr__ = typed_expr_decl
|
|
448
|
-
first_arg.__egg_decls__ = decls
|
|
449
|
-
return None
|
|
450
|
-
return RuntimeExpr(decls, TypedExprDecl(return_tp, expr_decl))
|
|
451
|
-
|
|
452
|
-
|
|
453
|
-
@dataclass
|
|
454
|
-
class RuntimeClassMethod:
|
|
455
|
-
__egg_decls__: Declarations
|
|
456
|
-
__egg_tp__: JustTypeRef
|
|
457
|
-
__egg_method_name__: str
|
|
458
|
-
__egg_callable_ref__: ClassMethodRef = field(init=False)
|
|
459
|
-
__egg_fn_decl__: FunctionDecl = field(init=False)
|
|
460
|
-
|
|
461
|
-
def __post_init__(self) -> None:
|
|
462
|
-
self.__egg_callable_ref__ = ClassMethodRef(self.class_name, self.__egg_method_name__)
|
|
463
|
-
try:
|
|
464
|
-
self.__egg_fn_decl__ = self.__egg_decls__.get_function_decl(self.__egg_callable_ref__)
|
|
465
|
-
except KeyError as e:
|
|
466
|
-
raise AttributeError(f"Class {self.class_name} does not have method {self.__egg_method_name__}") from e
|
|
467
|
-
|
|
468
|
-
def __call__(self, *args: object, **kwargs) -> RuntimeExpr | None:
|
|
469
|
-
return _call(self.__egg_decls__, self.__egg_callable_ref__, self.__egg_fn_decl__, args, kwargs, self.__egg_tp__)
|
|
245
|
+
bound_params = cast(JustTypeRef, bound_tp).args if isinstance(self.__egg_ref__, ClassMethodRef) else None
|
|
246
|
+
expr_decl = CallDecl(self.__egg_ref__, arg_exprs, bound_params)
|
|
247
|
+
typed_expr_decl = TypedExprDecl(return_tp, expr_decl)
|
|
248
|
+
# If there is not return type, we are mutating the first arg
|
|
249
|
+
if not fn_decl.return_type:
|
|
250
|
+
first_arg = upcasted_args[0]
|
|
251
|
+
first_arg.__egg_thunk__ = Thunk.value((decls, typed_expr_decl))
|
|
252
|
+
return None
|
|
253
|
+
return RuntimeExpr.__from_value__(decls, typed_expr_decl)
|
|
470
254
|
|
|
471
255
|
def __str__(self) -> str:
|
|
472
|
-
|
|
473
|
-
|
|
474
|
-
|
|
475
|
-
|
|
476
|
-
|
|
477
|
-
|
|
478
|
-
return self.
|
|
256
|
+
first_arg, bound_tp_params = None, None
|
|
257
|
+
match self.__egg_bound__:
|
|
258
|
+
case RuntimeExpr(_):
|
|
259
|
+
first_arg = self.__egg_bound__.__egg_typed_expr__.expr
|
|
260
|
+
case JustTypeRef(_, args):
|
|
261
|
+
bound_tp_params = args
|
|
262
|
+
return pretty_callable_ref(self.__egg_decls__, self.__egg_ref__, first_arg, bound_tp_params)
|
|
263
|
+
|
|
264
|
+
|
|
265
|
+
def callable_decl_to_signature(
|
|
266
|
+
decl: FunctionDecl,
|
|
267
|
+
decls: Declarations,
|
|
268
|
+
) -> Signature:
|
|
269
|
+
parameters = [
|
|
270
|
+
Parameter(
|
|
271
|
+
n,
|
|
272
|
+
Parameter.POSITIONAL_OR_KEYWORD,
|
|
273
|
+
default=RuntimeExpr.__from_value__(decls, TypedExprDecl(t.to_just(), d)) if d else Parameter.empty,
|
|
274
|
+
)
|
|
275
|
+
for n, d, t in zip(decl.arg_names, decl.arg_defaults, decl.arg_types, strict=True)
|
|
276
|
+
]
|
|
277
|
+
if isinstance(decl, FunctionDecl) and decl.var_arg_type is not None:
|
|
278
|
+
parameters.append(Parameter("__rest", Parameter.VAR_POSITIONAL))
|
|
279
|
+
return Signature(parameters)
|
|
479
280
|
|
|
480
281
|
|
|
481
282
|
# All methods which should return NotImplemented if they fail to resolve
|
|
@@ -505,63 +306,34 @@ PARTIAL_METHODS = {
|
|
|
505
306
|
|
|
506
307
|
|
|
507
308
|
@dataclass
|
|
508
|
-
class
|
|
509
|
-
|
|
510
|
-
|
|
511
|
-
__egg_callable_ref__: MethodRef | PropertyRef = field(init=False)
|
|
512
|
-
__egg_fn_decl__: FunctionDecl = field(init=False, repr=False)
|
|
513
|
-
__egg_decls__: Declarations = field(init=False)
|
|
514
|
-
|
|
515
|
-
def __post_init__(self) -> None:
|
|
516
|
-
self.__egg_decls__ = self.__egg_self__.__egg_decls__
|
|
517
|
-
if self.__egg_method_name__ in self.__egg_decls__.get_class_decl(self.class_name).properties:
|
|
518
|
-
self.__egg_callable_ref__ = PropertyRef(self.class_name, self.__egg_method_name__)
|
|
519
|
-
else:
|
|
520
|
-
self.__egg_callable_ref__ = MethodRef(self.class_name, self.__egg_method_name__)
|
|
521
|
-
try:
|
|
522
|
-
self.__egg_fn_decl__ = self.__egg_decls__.get_function_decl(self.__egg_callable_ref__)
|
|
523
|
-
except KeyError:
|
|
524
|
-
msg = f"Class {self.class_name} does not have method {self.__egg_method_name__}"
|
|
525
|
-
if self.__egg_method_name__ == "__ne__":
|
|
526
|
-
msg += ". Did you mean to use the ne(...).to(...)?"
|
|
527
|
-
raise AttributeError(msg) from None
|
|
309
|
+
class RuntimeExpr:
|
|
310
|
+
# Defer needing decls/expr so we can make constants that don't resolve their class types
|
|
311
|
+
__egg_thunk__: Callable[[], tuple[Declarations, TypedExprDecl]]
|
|
528
312
|
|
|
529
|
-
|
|
530
|
-
|
|
531
|
-
|
|
532
|
-
return _call(
|
|
533
|
-
self.__egg_decls__,
|
|
534
|
-
self.__egg_callable_ref__,
|
|
535
|
-
self.__egg_fn_decl__,
|
|
536
|
-
args,
|
|
537
|
-
kwargs,
|
|
538
|
-
self.__egg_self__.__egg_typed_expr__.tp,
|
|
539
|
-
)
|
|
540
|
-
except ConvertError as e:
|
|
541
|
-
name = self.__egg_method_name__
|
|
542
|
-
raise TypeError(f"Wrong types for {self.__egg_self__.__egg_typed_expr__.tp.pretty()}.{name}") from e
|
|
313
|
+
@classmethod
|
|
314
|
+
def __from_value__(cls, d: Declarations, e: TypedExprDecl) -> RuntimeExpr:
|
|
315
|
+
return cls(Thunk.value((d, e)))
|
|
543
316
|
|
|
544
317
|
@property
|
|
545
|
-
def
|
|
546
|
-
return self.
|
|
547
|
-
|
|
318
|
+
def __egg_decls__(self) -> Declarations:
|
|
319
|
+
return self.__egg_thunk__()[0]
|
|
548
320
|
|
|
549
|
-
@
|
|
550
|
-
|
|
551
|
-
|
|
552
|
-
__egg_typed_expr__: TypedExprDecl
|
|
321
|
+
@property
|
|
322
|
+
def __egg_typed_expr__(self) -> TypedExprDecl:
|
|
323
|
+
return self.__egg_thunk__()[1]
|
|
553
324
|
|
|
554
|
-
def __getattr__(self, name: str) ->
|
|
555
|
-
|
|
325
|
+
def __getattr__(self, name: str) -> RuntimeFunction | RuntimeExpr | Callable | None:
|
|
326
|
+
cls_name = self.__egg_class_name__
|
|
327
|
+
class_decl = self.__egg_class_decl__
|
|
556
328
|
|
|
557
|
-
preserved_methods
|
|
558
|
-
if name in preserved_methods:
|
|
329
|
+
if name in (preserved_methods := class_decl.preserved_methods):
|
|
559
330
|
return preserved_methods[name].__get__(self)
|
|
560
331
|
|
|
561
|
-
|
|
562
|
-
|
|
563
|
-
|
|
564
|
-
|
|
332
|
+
if name in class_decl.methods:
|
|
333
|
+
return RuntimeFunction(Thunk.value(self.__egg_decls__), MethodRef(cls_name, name), self)
|
|
334
|
+
if name in class_decl.properties:
|
|
335
|
+
return RuntimeFunction(Thunk.value(self.__egg_decls__), PropertyRef(cls_name, name), self)()
|
|
336
|
+
raise AttributeError(f"{cls_name} has no method {name}") from None
|
|
565
337
|
|
|
566
338
|
def __repr__(self) -> str:
|
|
567
339
|
"""
|
|
@@ -570,18 +342,10 @@ class RuntimeExpr:
|
|
|
570
342
|
return str(self)
|
|
571
343
|
|
|
572
344
|
def __str__(self) -> str:
|
|
573
|
-
|
|
574
|
-
|
|
575
|
-
|
|
576
|
-
|
|
577
|
-
if config.SHOW_TYPES:
|
|
578
|
-
raise NotImplementedError
|
|
579
|
-
# s = f"_: {self.__egg_typed_expr__.tp.pretty()} = {pretty_expr}"
|
|
580
|
-
# return black.format_str(s, mode=black.FileMode()).strip()
|
|
581
|
-
pretty_statements = context.render(pretty_expr)
|
|
582
|
-
return black.format_str(pretty_statements, mode=BLACK_MODE).strip()
|
|
583
|
-
except black.parsing.InvalidInput:
|
|
584
|
-
return pretty_expr
|
|
345
|
+
return self.__egg_pretty__(None)
|
|
346
|
+
|
|
347
|
+
def __egg_pretty__(self, wrapping_fn: str | None) -> str:
|
|
348
|
+
return pretty_decl(self.__egg_decls__, self.__egg_typed_expr__.expr, wrapping_fn=wrapping_fn)
|
|
585
349
|
|
|
586
350
|
def _ipython_display_(self) -> None:
|
|
587
351
|
from IPython.display import Code, display
|
|
@@ -589,28 +353,32 @@ class RuntimeExpr:
|
|
|
589
353
|
display(Code(str(self), language="python"))
|
|
590
354
|
|
|
591
355
|
def __dir__(self) -> Iterable[str]:
|
|
592
|
-
|
|
356
|
+
class_decl = self.__egg_class_decl__
|
|
357
|
+
return list(class_decl.methods) + list(class_decl.properties) + list(class_decl.preserved_methods)
|
|
593
358
|
|
|
594
359
|
@property
|
|
595
|
-
def
|
|
596
|
-
return self.__egg_typed_expr__.
|
|
360
|
+
def __egg_class_name__(self) -> str:
|
|
361
|
+
return self.__egg_typed_expr__.tp.name
|
|
362
|
+
|
|
363
|
+
@property
|
|
364
|
+
def __egg_class_decl__(self) -> ClassDecl:
|
|
365
|
+
return self.__egg_decls__.get_class_decl(self.__egg_class_name__)
|
|
597
366
|
|
|
598
367
|
# Have __eq__ take no NoReturn (aka Never https://docs.python.org/3/library/typing.html#typing.Never) because
|
|
599
368
|
# we don't wany any type that MyPy thinks is an expr to be used with __eq__.
|
|
600
369
|
# That's because we want to reserve __eq__ for domain specific equality checks, overloading this method.
|
|
601
370
|
# To check if two exprs are equal, use the expr_eq method.
|
|
602
|
-
|
|
603
|
-
|
|
604
|
-
raise NotImplementedError(msg)
|
|
371
|
+
# At runtime, this will resolve if there is a defined egg function for `__eq__`
|
|
372
|
+
def __eq__(self, other: NoReturn) -> Expr: ... # type: ignore[override, empty-body]
|
|
605
373
|
|
|
606
374
|
# Implement these so that copy() works on this object
|
|
607
375
|
# otherwise copy will try to call `__getstate__` before object is initialized with properties which will cause inifinite recursion
|
|
608
376
|
|
|
609
377
|
def __getstate__(self) -> tuple[Declarations, TypedExprDecl]:
|
|
610
|
-
return
|
|
378
|
+
return self.__egg_thunk__()
|
|
611
379
|
|
|
612
380
|
def __setstate__(self, d: tuple[Declarations, TypedExprDecl]) -> None:
|
|
613
|
-
self.
|
|
381
|
+
self.__egg_thunk__ = Thunk.value(d)
|
|
614
382
|
|
|
615
383
|
def __hash__(self) -> int:
|
|
616
384
|
return hash(self.__egg_typed_expr__)
|
|
@@ -625,12 +393,17 @@ for name in list(BINARY_METHODS) + list(UNARY_METHODS) + ["__getitem__", "__call
|
|
|
625
393
|
__name: str = name,
|
|
626
394
|
**kwargs: object,
|
|
627
395
|
) -> RuntimeExpr | None:
|
|
396
|
+
from .conversion import ConvertError
|
|
397
|
+
|
|
398
|
+
class_name = self.__egg_class_name__
|
|
399
|
+
class_decl = self.__egg_class_decl__
|
|
628
400
|
# First, try to resolve as preserved method
|
|
629
401
|
try:
|
|
630
|
-
method =
|
|
631
|
-
return method(self, *args, **kwargs)
|
|
402
|
+
method = class_decl.preserved_methods[__name]
|
|
632
403
|
except KeyError:
|
|
633
404
|
pass
|
|
405
|
+
else:
|
|
406
|
+
return method(self, *args, **kwargs)
|
|
634
407
|
# If this is a "partial" method meaning that it can return NotImplemented,
|
|
635
408
|
# we want to find the "best" superparent (lowest cost) of the arg types to call with it, instead of just
|
|
636
409
|
# using the arg type of the self arg.
|
|
@@ -640,7 +413,10 @@ for name in list(BINARY_METHODS) + list(UNARY_METHODS) + ["__getitem__", "__call
|
|
|
640
413
|
return call_method_min_conversion(self, args[0], __name)
|
|
641
414
|
except ConvertError:
|
|
642
415
|
return NotImplemented
|
|
643
|
-
|
|
416
|
+
if __name in class_decl.methods:
|
|
417
|
+
fn = RuntimeFunction(Thunk.value(self.__egg_decls__), MethodRef(class_name, __name), self)
|
|
418
|
+
return fn(*args, **kwargs)
|
|
419
|
+
raise TypeError(f"{class_name!r} object does not support {__name}")
|
|
644
420
|
|
|
645
421
|
setattr(RuntimeExpr, name, _special_method)
|
|
646
422
|
|
|
@@ -655,12 +431,14 @@ for reflected, non_reflected in REFLECTED_BINARY_METHODS.items():
|
|
|
655
431
|
|
|
656
432
|
|
|
657
433
|
def call_method_min_conversion(slf: object, other: object, name: str) -> RuntimeExpr | None:
|
|
434
|
+
from .conversion import min_convertable_tp, resolve_literal
|
|
435
|
+
|
|
658
436
|
# find a minimum type that both can be converted to
|
|
659
437
|
# This is so so that calls like `-0.1 * Int("x")` work by upcasting both to floats.
|
|
660
438
|
min_tp = min_convertable_tp(slf, other, name)
|
|
661
439
|
slf = resolve_literal(min_tp.to_var(), slf)
|
|
662
440
|
other = resolve_literal(min_tp.to_var(), other)
|
|
663
|
-
method =
|
|
441
|
+
method = RuntimeFunction(Thunk.value(slf.__egg_decls__), MethodRef(slf.__egg_class_name__, name), slf)
|
|
664
442
|
return method(other)
|
|
665
443
|
|
|
666
444
|
|
|
@@ -680,21 +458,9 @@ def resolve_callable(callable: object) -> tuple[CallableRef, Declarations]:
|
|
|
680
458
|
"""
|
|
681
459
|
Resolves a runtime callable into a ref
|
|
682
460
|
"""
|
|
683
|
-
|
|
684
|
-
|
|
685
|
-
|
|
686
|
-
|
|
687
|
-
|
|
688
|
-
|
|
689
|
-
elif isinstance(callable, RuntimeClassMethod):
|
|
690
|
-
ref = ClassMethodRef(callable.class_name, callable.__egg_method_name__)
|
|
691
|
-
decls = callable.__egg_decls__
|
|
692
|
-
elif isinstance(callable, RuntimeMethod):
|
|
693
|
-
ref = MethodRef(callable.__egg_self__.__egg_typed_expr__.tp.name, callable.__egg_method_name__)
|
|
694
|
-
decls = callable.__egg_decls__
|
|
695
|
-
elif isinstance(callable, RuntimeClass):
|
|
696
|
-
ref = ClassMethodRef(callable.__egg_name__, "__init__")
|
|
697
|
-
decls = callable.__egg_decls__
|
|
698
|
-
else:
|
|
699
|
-
raise NotImplementedError(f"Cannot turn {callable} into a callable ref")
|
|
700
|
-
return (ref, decls)
|
|
461
|
+
match callable:
|
|
462
|
+
case RuntimeFunction(decls, ref, _):
|
|
463
|
+
return ref, decls()
|
|
464
|
+
case RuntimeClass(thunk, tp):
|
|
465
|
+
return ClassMethodRef(tp.name, "__init__"), thunk()
|
|
466
|
+
raise NotImplementedError(f"Cannot turn {callable} into a callable ref")
|