egglog 6.1.0 → 7.1.0 (egglog-6.1.0-cp310-none-win_amd64.whl → egglog-7.1.0-cp310-none-win_amd64.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of egglog might be problematic. See the registry's page for this release for more details.

egglog/runtime.py CHANGED
@@ -11,177 +11,67 @@ so they are not mangled by Python and can be accessed by the user.
11
11
 
12
12
  from __future__ import annotations
13
13
 
14
- from dataclasses import dataclass, field
14
+ from collections.abc import Callable
15
+ from dataclasses import dataclass, replace
16
+ from inspect import Parameter, Signature
15
17
  from itertools import zip_longest
16
18
  from typing import TYPE_CHECKING, NoReturn, TypeVar, Union, cast, get_args, get_origin
17
19
 
18
- import black
19
- import black.parsing
20
- from typing_extensions import assert_never
21
-
22
- from . import bindings, config
23
20
  from .declarations import *
24
- from .declarations import BINARY_METHODS, REFLECTED_BINARY_METHODS, UNARY_METHODS
21
+ from .pretty import *
22
+ from .thunk import Thunk
25
23
  from .type_constraint_solver import *
26
24
 
27
25
  if TYPE_CHECKING:
28
- from collections.abc import Callable, Collection, Iterable
26
+ from collections.abc import Iterable
29
27
 
30
28
  from .egraph import Expr
31
29
 
32
30
  __all__ = [
33
31
  "LIT_CLASS_NAMES",
34
- "class_to_ref",
35
- "resolve_literal",
36
32
  "resolve_callable",
37
33
  "resolve_type_annotation",
38
- "convert_to_same_type",
39
34
  "RuntimeClass",
40
- "RuntimeParamaterizedClass",
41
- "RuntimeClassMethod",
42
35
  "RuntimeExpr",
43
36
  "RuntimeFunction",
44
- "convert",
45
- "converter",
37
+ "REFLECTED_BINARY_METHODS",
46
38
  ]
47
39
 
48
40
 
49
- BLACK_MODE = black.Mode(line_length=180)
50
-
51
41
  UNIT_CLASS_NAME = "Unit"
52
42
  UNARY_LIT_CLASS_NAMES = {"i64", "f64", "Bool", "String"}
53
43
  LIT_CLASS_NAMES = UNARY_LIT_CLASS_NAMES | {UNIT_CLASS_NAME, "PyObject"}
54
44
 
45
+ REFLECTED_BINARY_METHODS = {
46
+ "__radd__": "__add__",
47
+ "__rsub__": "__sub__",
48
+ "__rmul__": "__mul__",
49
+ "__rmatmul__": "__matmul__",
50
+ "__rtruediv__": "__truediv__",
51
+ "__rfloordiv__": "__floordiv__",
52
+ "__rmod__": "__mod__",
53
+ "__rpow__": "__pow__",
54
+ "__rlshift__": "__lshift__",
55
+ "__rrshift__": "__rshift__",
56
+ "__rand__": "__and__",
57
+ "__rxor__": "__xor__",
58
+ "__ror__": "__or__",
59
+ }
60
+
55
61
  # Set this globally so we can get access to PyObject when we have a type annotation of just object.
56
62
  # This is the only time a type annotation doesn't need to include the egglog type b/c object is top so that would be redundant statically.
57
63
  _PY_OBJECT_CLASS: RuntimeClass | None = None
58
-
59
- ##
60
- # Converters
61
- ##
62
-
63
- # Mapping from (source type, target type) to and function which takes in the runtimes values of the source and return the target
64
- CONVERSIONS: dict[tuple[type | JustTypeRef, JustTypeRef], tuple[int, Callable]] = {}
65
- # Global declerations to store all convertable types so we can query if they have certain methods or not
66
- CONVERSIONS_DECLS = Declarations()
64
+ # Same for functions
65
+ _UNSTABLE_FN_CLASS: RuntimeClass | None = None
67
66
 
68
67
  T = TypeVar("T")
69
- V = TypeVar("V", bound="Expr")
70
-
71
-
72
- class ConvertError(Exception):
73
- pass
74
-
75
-
76
- def converter(from_type: type[T], to_type: type[V], fn: Callable[[T], V], cost: int = 1) -> None:
77
- """
78
- Register a converter from some type to an egglog type.
79
- """
80
- to_type_name = process_tp(to_type)
81
- if not isinstance(to_type_name, JustTypeRef):
82
- raise TypeError(f"Expected return type to be a egglog type, got {to_type_name}")
83
- _register_converter(process_tp(from_type), to_type_name, fn, cost)
84
-
85
-
86
- def _register_converter(a: type | JustTypeRef, b: JustTypeRef, a_b: Callable, cost: int) -> None:
87
- """
88
- Registers a converter from some type to an egglog type, if not already registered.
89
-
90
- Also adds transitive converters, i.e. if registering A->B and there is already B->C, then A->C will be registered.
91
- Also, if registering A->B and there is already D->A, then D->B will be registered.
92
- """
93
- if a == b:
94
- return
95
- if (a, b) in CONVERSIONS and CONVERSIONS[(a, b)][0] <= cost:
96
- return
97
- CONVERSIONS[(a, b)] = (cost, a_b)
98
- for (c, d), (other_cost, c_d) in list(CONVERSIONS.items()):
99
- if b == c:
100
- _register_converter(a, d, _ComposedConverter(a_b, c_d), cost + other_cost)
101
- if a == d:
102
- _register_converter(c, b, _ComposedConverter(c_d, a_b), cost + other_cost)
103
-
104
-
105
- @dataclass
106
- class _ComposedConverter:
107
- """
108
- A converter which is composed of multiple converters.
109
-
110
- _ComposeConverter(a_b, b_c) is equivalent to lambda x: b_c(a_b(x))
111
-
112
- We use the dataclass instead of the lambda to make it easier to debug.
113
- """
114
-
115
- a_b: Callable
116
- b_c: Callable
117
-
118
- def __call__(self, x: object) -> object:
119
- return self.b_c(self.a_b(x))
120
-
121
- def __str__(self) -> str:
122
- return f"{self.b_c} ∘ {self.a_b}"
123
-
124
-
125
- def convert(source: object, target: type[V]) -> V:
126
- """
127
- Convert a source object to a target type.
128
- """
129
- target_ref = class_to_ref(cast(RuntimeTypeArgType, target))
130
- return cast(V, resolve_literal(target_ref.to_var(), source))
131
-
132
-
133
- def convert_to_same_type(source: object, target: RuntimeExpr) -> RuntimeExpr:
134
- """
135
- Convert a source object to the same type as the target.
136
- """
137
- tp = target.__egg_typed_expr__.tp
138
- return resolve_literal(tp.to_var(), source)
139
-
140
-
141
- def process_tp(tp: type | RuntimeTypeArgType) -> JustTypeRef | type:
142
- """
143
- Process a type before converting it, to add it to the global declerations and resolve to a ref.
144
- """
145
- global CONVERSIONS_DECLS
146
- if isinstance(tp, RuntimeClass | RuntimeParamaterizedClass):
147
- CONVERSIONS_DECLS |= tp
148
- return class_to_ref(tp)
149
- return tp
150
-
151
-
152
- def min_convertable_tp(a: object, b: object, name: str) -> JustTypeRef:
153
- """
154
- Returns the minimum convertable type between a and b, that has a method `name`, raising a TypeError if no such type exists.
155
- """
156
- a_tp = _get_tp(a)
157
- b_tp = _get_tp(b)
158
- a_converts_to = {
159
- to: c
160
- for ((from_, to), (c, _)) in CONVERSIONS.items()
161
- if from_ == a_tp and CONVERSIONS_DECLS.has_method(to.name, name)
162
- }
163
- b_converts_to = {
164
- to: c
165
- for ((from_, to), (c, _)) in CONVERSIONS.items()
166
- if from_ == b_tp and CONVERSIONS_DECLS.has_method(to.name, name)
167
- }
168
- if isinstance(a_tp, JustTypeRef):
169
- a_converts_to[a_tp] = 0
170
- if isinstance(b_tp, JustTypeRef):
171
- b_converts_to[b_tp] = 0
172
- common = set(a_converts_to) & set(b_converts_to)
173
- if not common:
174
- raise ConvertError(f"Cannot convert {a_tp} and {b_tp} to a common type")
175
- return min(common, key=lambda tp: a_converts_to[tp] + b_converts_to[tp])
176
-
177
-
178
- def identity(x: object) -> object:
179
- return x
180
68
 
181
69
 
182
70
  def resolve_type_annotation(decls: Declarations, tp: object) -> TypeOrVarRef:
183
71
  """
184
72
  Resolves a type object into a type reference.
73
+
74
+ Any runtime type object decls will be add to those passed in.
185
75
  """
186
76
  if isinstance(tp, TypeVar):
187
77
  return ClassTypeVarRef(tp.__name__)
@@ -194,100 +84,92 @@ def resolve_type_annotation(decls: Declarations, tp: object) -> TypeOrVarRef:
194
84
  if tp == object:
195
85
  assert _PY_OBJECT_CLASS
196
86
  return resolve_type_annotation(decls, _PY_OBJECT_CLASS)
87
+ # If the type is a `Callable` then convert it into a UnstableFn
88
+ if get_origin(tp) == Callable:
89
+ assert _UNSTABLE_FN_CLASS
90
+ args, ret = get_args(tp)
91
+ return resolve_type_annotation(decls, _UNSTABLE_FN_CLASS[(ret, *args)])
197
92
  if isinstance(tp, RuntimeClass):
198
- decls |= tp
199
- return tp.__egg_tp__.to_var()
200
- if isinstance(tp, RuntimeParamaterizedClass):
201
93
  decls |= tp
202
94
  return tp.__egg_tp__
203
95
  raise TypeError(f"Unexpected type annotation {tp}")
204
96
 
205
97
 
206
- def resolve_literal(tp: TypeOrVarRef, arg: object) -> RuntimeExpr:
207
- arg_type = _get_tp(arg)
208
-
209
- # If we have any type variables, dont bother trying to resolve the literal, just return the arg
210
- try:
211
- tp_just = tp.to_just()
212
- except NotImplementedError:
213
- # If this is a var, it has to be a runtime exprssions
214
- assert isinstance(arg, RuntimeExpr)
215
- return arg
216
- if arg_type == tp_just:
217
- # If the type is an egg type, it has to be a runtime expr
218
- assert isinstance(arg, RuntimeExpr)
219
- return arg
220
- # Try all parent types as well, if we are converting from a Python type
221
- for arg_type_instance in arg_type.__mro__ if isinstance(arg_type, type) else [arg_type]:
222
- try:
223
- fn = CONVERSIONS[(cast(JustTypeRef | type, arg_type_instance), tp_just)][1]
224
- except KeyError:
225
- continue
226
- break
227
- else:
228
- arg_type_str = arg_type.pretty() if isinstance(arg_type, JustTypeRef) else arg_type.__name__
229
- raise ConvertError(f"Cannot convert {arg_type_str} to {tp_just.pretty()}")
230
- return fn(arg)
231
-
232
-
233
- def _get_tp(x: object) -> JustTypeRef | type:
234
- if isinstance(x, RuntimeExpr):
235
- return x.__egg_typed_expr__.tp
236
- tp = type(x)
237
- # If this value has a custom metaclass, let's use that as our index instead of the type
238
- if type(tp) != type:
239
- return type(tp)
240
- return tp
241
-
242
-
243
98
  ##
244
99
  # Runtime objects
245
100
  ##
246
101
 
247
102
 
248
103
  @dataclass
249
- class RuntimeClass:
250
- # Pass in a constructor to make the declarations lazy, so we can have classes reference each other in their type constructors
251
- # This function should mutate the declerations and add to them
252
- # Used this instead of a lazy property so we can have a reference to the decls in the class as its computing
253
- lazy_decls: Callable[[Declarations], None] = field(repr=False)
254
- # Cached declerations
255
- _inner_decls: Declarations | None = field(init=False, repr=False, default=None)
256
- __egg_name__: str
104
+ class RuntimeClass(DelayedDeclerations):
105
+ __egg_tp__: TypeRefWithVars
257
106
 
258
107
  def __post_init__(self) -> None:
259
- global _PY_OBJECT_CLASS
260
- if self.__egg_name__ == "PyObject":
108
+ global _PY_OBJECT_CLASS, _UNSTABLE_FN_CLASS
109
+ if (name := self.__egg_tp__.name) == "PyObject":
261
110
  _PY_OBJECT_CLASS = self
111
+ elif name == "UnstableFn" and not self.__egg_tp__.args:
112
+ _UNSTABLE_FN_CLASS = self
262
113
 
263
- @property
264
- def __egg_decls__(self) -> Declarations:
265
- if self._inner_decls is None:
266
- # Set it like this so we can have a reference to the decls in the class as its computing
267
- self._inner_decls = Declarations()
268
- self.lazy_decls(self._inner_decls)
269
- return self._inner_decls
114
+ def verify(self) -> None:
115
+ if not self.__egg_tp__.args:
116
+ return
117
+
118
+ # Raise error if we have args, but they are the wrong number
119
+ desired_args = self.__egg_decls__.get_class_decl(self.__egg_tp__.name).type_vars
120
+ if len(self.__egg_tp__.args) != len(desired_args):
121
+ raise ValueError(f"Expected {desired_args} type args, got {len(self.__egg_tp__.args)}")
270
122
 
271
123
  def __call__(self, *args: object, **kwargs: object) -> RuntimeExpr | None:
272
124
  """
273
125
  Create an instance of this kind by calling the __init__ classmethod
274
126
  """
275
127
  # If this is a literal type, initializing it with a literal should return a literal
276
- if self.__egg_name__ == "PyObject":
128
+ if (name := self.__egg_tp__.name) == "PyObject":
277
129
  assert len(args) == 1
278
- return RuntimeExpr(self.__egg_decls__, TypedExprDecl(self.__egg_tp__, PyObjectDecl(args[0])))
279
- if self.__egg_name__ in UNARY_LIT_CLASS_NAMES:
130
+ return RuntimeExpr.__from_value__(
131
+ self.__egg_decls__, TypedExprDecl(self.__egg_tp__.to_just(), PyObjectDecl(args[0]))
132
+ )
133
+ if name == "UnstableFn":
134
+ assert not kwargs
135
+ fn_arg, *partial_args = args
136
+ del args
137
+ # Assumes we don't have types set for UnstableFn w/ generics, that they have to be inferred
138
+
139
+ # 1. Create a runtime function for the first arg
140
+ assert isinstance(fn_arg, RuntimeFunction)
141
+ # 2. Call it with the partial args, and use untyped vars for the rest of the args
142
+ res = fn_arg(*partial_args, _egg_partial_function=True)
143
+ assert res is not None, "Mutable partial functions not supported"
144
+ # 3. Use the inferred return type and inferred rest arg types as the types of the function, and
145
+ # the partially applied args as the args.
146
+ call = (res_typed_expr := res.__egg_typed_expr__).expr
147
+ return_tp = res_typed_expr.tp
148
+ assert isinstance(call, CallDecl), "partial function must be a call"
149
+ n_args = len(partial_args)
150
+ value = PartialCallDecl(replace(call, args=call.args[:n_args]))
151
+ remaining_arg_types = [a.tp for a in call.args[n_args:]]
152
+ type_ref = JustTypeRef("UnstableFn", (return_tp, *remaining_arg_types))
153
+ return RuntimeExpr.__from_value__(Declarations.create(self, res), TypedExprDecl(type_ref, value))
154
+
155
+ if name in UNARY_LIT_CLASS_NAMES:
280
156
  assert len(args) == 1
281
157
  assert isinstance(args[0], int | float | str | bool)
282
- return RuntimeExpr(self.__egg_decls__, TypedExprDecl(self.__egg_tp__, LitDecl(args[0])))
283
- if self.__egg_name__ == UNIT_CLASS_NAME:
158
+ return RuntimeExpr.__from_value__(
159
+ self.__egg_decls__, TypedExprDecl(self.__egg_tp__.to_just(), LitDecl(args[0]))
160
+ )
161
+ if name == UNIT_CLASS_NAME:
284
162
  assert len(args) == 0
285
- return RuntimeExpr(self.__egg_decls__, TypedExprDecl(self.__egg_tp__, LitDecl(None)))
163
+ return RuntimeExpr.__from_value__(
164
+ self.__egg_decls__, TypedExprDecl(self.__egg_tp__.to_just(), LitDecl(None))
165
+ )
286
166
 
287
- return RuntimeClassMethod(self.__egg_decls__, self.__egg_tp__, "__init__")(*args, **kwargs)
167
+ return RuntimeFunction(
168
+ Thunk.value(self.__egg_decls__), ClassMethodRef(name, "__init__"), self.__egg_tp__.to_just()
169
+ )(*args, **kwargs) # type: ignore[arg-type]
288
170
 
289
171
  def __dir__(self) -> list[str]:
290
- cls_decl = self.__egg_decls__.get_class_decl(self.__egg_name__)
172
+ cls_decl = self.__egg_decls__.get_class_decl(self.__egg_tp__.name)
291
173
  possible_methods = (
292
174
  list(cls_decl.class_methods) + list(cls_decl.class_variables) + list(cls_decl.preserved_methods)
293
175
  )
@@ -296,14 +178,19 @@ class RuntimeClass:
296
178
  possible_methods.append("__call__")
297
179
  return possible_methods
298
180
 
299
- def __getitem__(self, args: object) -> RuntimeParamaterizedClass:
181
+ def __getitem__(self, args: object) -> RuntimeClass:
182
+ if self.__egg_tp__.args:
183
+ raise TypeError(f"Cannot index into a paramaterized class {self}")
300
184
  if not isinstance(args, tuple):
301
185
  args = (args,)
302
186
  decls = self.__egg_decls__.copy()
303
- tp = TypeRefWithVars(self.__egg_name__, tuple(resolve_type_annotation(decls, arg) for arg in args))
304
- return RuntimeParamaterizedClass(self.__egg_decls__, tp)
187
+ tp = TypeRefWithVars(self.__egg_tp__.name, tuple(resolve_type_annotation(decls, arg) for arg in args))
188
+ return RuntimeClass(Thunk.value(decls), tp)
189
+
190
+ def __getattr__(self, name: str) -> RuntimeFunction | RuntimeExpr | Callable:
191
+ if name == "__origin__" and self.__egg_tp__.args:
192
+ return RuntimeClass(self.__egg_decls_thunk__, TypeRefWithVars(self.__egg_tp__.name))
305
193
 
306
- def __getattr__(self, name: str) -> RuntimeClassMethod | RuntimeExpr | Callable:
307
194
  # Special case some names that don't exist so we can exit early without resolving decls
308
195
  # Important so if we take union of RuntimeClass it won't try to resolve decls
309
196
  if name in {
@@ -314,7 +201,7 @@ class RuntimeClass:
314
201
  }:
315
202
  raise AttributeError
316
203
 
317
- cls_decl = self.__egg_decls__.get_class_decl(self.__egg_name__)
204
+ cls_decl = self.__egg_decls__._classes[self.__egg_tp__.name]
318
205
 
319
206
  preserved_methods = cls_decl.preserved_methods
320
207
  if name in preserved_methods:
@@ -323,159 +210,151 @@ class RuntimeClass:
323
210
  # if this is a class variable, return an expr for it, otherwise, assume it's a method
324
211
  if name in cls_decl.class_variables:
325
212
  return_tp = cls_decl.class_variables[name]
326
- return RuntimeExpr(
327
- self.__egg_decls__, TypedExprDecl(return_tp, CallDecl(ClassVariableRef(self.__egg_name__, name)))
213
+ return RuntimeExpr.__from_value__(
214
+ self.__egg_decls__,
215
+ TypedExprDecl(return_tp.type_ref, CallDecl(ClassVariableRef(self.__egg_tp__.name, name))),
328
216
  )
329
- return RuntimeClassMethod(self.__egg_decls__, self.__egg_tp__, name)
217
+ if name in cls_decl.class_methods:
218
+ return RuntimeFunction(
219
+ Thunk.value(self.__egg_decls__), ClassMethodRef(self.__egg_tp__.name, name), self.__egg_tp__.to_just()
220
+ )
221
+ # allow referencing properties and methods as class variables as well
222
+ if name in cls_decl.properties:
223
+ return RuntimeFunction(Thunk.value(self.__egg_decls__), PropertyRef(self.__egg_tp__.name, name))
224
+ if name in cls_decl.methods:
225
+ return RuntimeFunction(Thunk.value(self.__egg_decls__), MethodRef(self.__egg_tp__.name, name))
226
+
227
+ msg = f"Class {self.__egg_tp__.name} has no method {name}"
228
+ if name == "__ne__":
229
+ msg += ". Did you mean to use the ne(...).to(...)?"
230
+ raise AttributeError(msg) from None
330
231
 
331
232
  def __str__(self) -> str:
332
- return self.__egg_name__
233
+ return str(self.__egg_tp__)
333
234
 
334
235
  # Make hashable so can go in Union
335
236
  def __hash__(self) -> int:
336
- return hash((id(self.lazy_decls), self.__egg_name__))
237
+ return hash((id(self.__egg_decls_thunk__), self.__egg_tp__))
337
238
 
338
239
  # Support unioning like types
339
240
  def __or__(self, __value: type) -> object:
340
241
  return Union[self, __value] # noqa: UP007
341
242
 
342
- @property
343
- def __egg_tp__(self) -> JustTypeRef:
344
- return JustTypeRef(self.__egg_name__)
345
-
346
-
347
- @dataclass
348
- class RuntimeParamaterizedClass:
349
- __egg_decls__: Declarations
350
- __egg_tp__: TypeRefWithVars
351
-
352
- def __post_init__(self) -> None:
353
- desired_args = self.__egg_decls__.get_class_decl(self.__egg_tp__.name).type_vars
354
- if len(self.__egg_tp__.args) != len(desired_args):
355
- raise ValueError(f"Expected {desired_args} type args, got {len(self.__egg_tp__.args)}")
356
-
357
- def __call__(self, *args: object) -> RuntimeExpr | None:
358
- return RuntimeClassMethod(self.__egg_decls__, class_to_ref(self), "__init__")(*args)
359
-
360
- def __getattr__(self, name: str) -> RuntimeClassMethod | RuntimeClass:
361
- # Special case so when get_type_annotations proccessed it can work
362
- if name == "__origin__":
363
- return RuntimeClass(self.__egg_decls__.update_other, self.__egg_tp__.name)
364
- return RuntimeClassMethod(self.__egg_decls__, class_to_ref(self), name)
365
-
366
- def __str__(self) -> str:
367
- return self.__egg_tp__.pretty()
368
-
369
- # Support unioning
370
- def __or__(self, __value: type) -> object:
371
- return Union[self, __value] # noqa: UP007
372
-
373
-
374
- # Type args can either be typevars or classes
375
- RuntimeTypeArgType = RuntimeClass | RuntimeParamaterizedClass
376
-
377
-
378
- def class_to_ref(cls: RuntimeTypeArgType) -> JustTypeRef:
379
- if isinstance(cls, RuntimeClass):
380
- return JustTypeRef(cls.__egg_name__)
381
- if isinstance(cls, RuntimeParamaterizedClass):
382
- # Currently this is used when calling methods on a parametrized class, which is only possible when we
383
- # have actualy types currently, not typevars, currently.
384
- return cls.__egg_tp__.to_just()
385
- assert_never(cls)
386
-
387
243
 
388
244
  @dataclass
389
- class RuntimeFunction:
390
- __egg_decls__: Declarations
391
- __egg_name__: str
392
- __egg_fn_ref__: FunctionRef = field(init=False)
393
- __egg_fn_decl__: FunctionDecl = field(init=False)
394
-
395
- def __post_init__(self) -> None:
396
- self.__egg_fn_ref__ = FunctionRef(self.__egg_name__)
397
- self.__egg_fn_decl__ = self.__egg_decls__.get_function_decl(self.__egg_fn_ref__)
398
-
399
- def __call__(self, *args: object, **kwargs: object) -> RuntimeExpr | None:
400
- return _call(self.__egg_decls__, self.__egg_fn_ref__, self.__egg_fn_decl__, args, kwargs)
401
-
402
- def __str__(self) -> str:
403
- return self.__egg_name__
404
-
405
-
406
- def _call(
407
- decls_from_fn: Declarations,
408
- callable_ref: CallableRef,
409
- fn_decl: FunctionDecl,
410
- args: Collection[object],
411
- kwargs: dict[str, object],
412
- bound_class: JustTypeRef | None = None,
413
- ) -> RuntimeExpr | None:
414
- # Turn all keyword args into positional args
415
- bound = fn_decl.to_signature(lambda expr: RuntimeExpr(decls_from_fn, expr)).bind(*args, **kwargs)
416
- bound.apply_defaults()
417
- assert not bound.kwargs
418
- del args, kwargs
419
-
420
- upcasted_args = [
421
- resolve_literal(cast(TypeOrVarRef, tp), arg)
422
- for arg, tp in zip_longest(bound.args, fn_decl.arg_types, fillvalue=fn_decl.var_arg_type)
423
- ]
245
+ class RuntimeFunction(DelayedDeclerations):
246
+ __egg_ref__: CallableRef
247
+ # bound methods need to store RuntimeExpr not just TypedExprDecl, so they can mutate the expr if required on self
248
+ __egg_bound__: JustTypeRef | RuntimeExpr | None = None
424
249
 
425
- arg_exprs = tuple(arg.__egg_typed_expr__ for arg in upcasted_args)
426
- decls = Declarations.create(decls_from_fn, *upcasted_args)
250
+ def __call__(self, *args: object, _egg_partial_function: bool = False, **kwargs: object) -> RuntimeExpr | None:
251
+ from .conversion import resolve_literal
427
252
 
428
- tcs = TypeConstraintSolver(decls)
429
- if bound_class is not None and bound_class.args:
430
- tcs.bind_class(bound_class)
431
-
432
- if fn_decl is not None:
253
+ if isinstance(self.__egg_bound__, RuntimeExpr):
254
+ args = (self.__egg_bound__, *args)
255
+ signature = self.__egg_decls__.get_callable_decl(self.__egg_ref__).to_function_decl().signature
256
+ decls = self.__egg_decls__.copy()
257
+ # Special case function application bc we dont support variadic generics yet generally
258
+ if signature == "fn-app":
259
+ fn, *rest_args = args
260
+ args = tuple(rest_args)
261
+ assert not kwargs
262
+ assert isinstance(fn, RuntimeExpr)
263
+ decls.update(fn)
264
+ function_value = fn.__egg_typed_expr__
265
+ fn_tp = function_value.tp
266
+ assert fn_tp.name == "UnstableFn"
267
+ fn_return_tp, *fn_arg_tps = fn_tp.args
268
+ signature = FunctionSignature(
269
+ tuple(tp.to_var() for tp in fn_arg_tps),
270
+ tuple(f"_{i}" for i in range(len(fn_arg_tps))),
271
+ (None,) * len(fn_arg_tps),
272
+ fn_return_tp.to_var(),
273
+ )
274
+ else:
275
+ function_value = None
276
+ assert isinstance(signature, FunctionSignature)
277
+
278
+ # Turn all keyword args into positional args
279
+ py_signature = to_py_signature(signature, self.__egg_decls__, _egg_partial_function)
280
+ bound = py_signature.bind(*args, **kwargs)
281
+ del kwargs
282
+ bound.apply_defaults()
283
+ assert not bound.kwargs
284
+ args = bound.args
285
+
286
+ upcasted_args = [
287
+ resolve_literal(cast(TypeOrVarRef, tp), arg)
288
+ for arg, tp in zip_longest(args, signature.arg_types, fillvalue=signature.var_arg_type)
289
+ ]
290
+ decls.update(*upcasted_args)
291
+
292
+ tcs = TypeConstraintSolver(decls)
293
+ bound_tp = (
294
+ None
295
+ if self.__egg_bound__ is None
296
+ else self.__egg_bound__.__egg_typed_expr__.tp
297
+ if isinstance(self.__egg_bound__, RuntimeExpr)
298
+ else self.__egg_bound__
299
+ )
300
+ if (
301
+ bound_tp
302
+ and bound_tp.args
303
+ # Don't bind class if we have a first class function arg, b/c we don't support that yet
304
+ and not function_value
305
+ ):
306
+ tcs.bind_class(bound_tp)
307
+ arg_exprs = tuple(arg.__egg_typed_expr__ for arg in upcasted_args)
433
308
  arg_types = [expr.tp for expr in arg_exprs]
434
- cls_name = bound_class.name if bound_class is not None else None
309
+ cls_name = bound_tp.name if bound_tp else None
435
310
  return_tp = tcs.infer_return_type(
436
- fn_decl.arg_types, fn_decl.return_type, fn_decl.var_arg_type, arg_types, cls_name
311
+ signature.arg_types, signature.semantic_return_type, signature.var_arg_type, arg_types, cls_name
437
312
  )
438
- else:
439
- return_tp = JustTypeRef("Unit")
440
- bound_params = cast(JustTypeRef, bound_class).args if isinstance(callable_ref, ClassMethodRef) else None
441
- expr_decl = CallDecl(callable_ref, arg_exprs, bound_params)
442
- typed_expr_decl = TypedExprDecl(return_tp, expr_decl)
443
- # Register return type sort in case it's a variadic generic that needs to be created
444
- decls.register_sort(return_tp, False)
445
- if fn_decl.mutates_first_arg:
446
- first_arg = upcasted_args[0]
447
- first_arg.__egg_typed_expr__ = typed_expr_decl
448
- first_arg.__egg_decls__ = decls
449
- return None
450
- return RuntimeExpr(decls, TypedExprDecl(return_tp, expr_decl))
451
-
313
+ bound_params = cast(JustTypeRef, bound_tp).args if isinstance(self.__egg_ref__, ClassMethodRef) else None
314
+ # If we were using unstable-app to call a funciton, add that function back as the first arg.
315
+ if function_value:
316
+ arg_exprs = (function_value, *arg_exprs)
317
+ expr_decl = CallDecl(self.__egg_ref__, arg_exprs, bound_params)
318
+ typed_expr_decl = TypedExprDecl(return_tp, expr_decl)
319
+ # If there is not return type, we are mutating the first arg
320
+ if not signature.return_type:
321
+ first_arg = upcasted_args[0]
322
+ first_arg.__egg_thunk__ = Thunk.value((decls, typed_expr_decl))
323
+ return None
324
+ return RuntimeExpr.__from_value__(decls, typed_expr_decl)
452
325
 
453
- @dataclass
454
- class RuntimeClassMethod:
455
- __egg_decls__: Declarations
456
- __egg_tp__: JustTypeRef
457
- __egg_method_name__: str
458
- __egg_callable_ref__: ClassMethodRef = field(init=False)
459
- __egg_fn_decl__: FunctionDecl = field(init=False)
326
+ def __str__(self) -> str:
327
+ first_arg, bound_tp_params = None, None
328
+ match self.__egg_bound__:
329
+ case RuntimeExpr(_):
330
+ first_arg = self.__egg_bound__.__egg_typed_expr__.expr
331
+ case JustTypeRef(_, args):
332
+ bound_tp_params = args
333
+ return pretty_callable_ref(self.__egg_decls__, self.__egg_ref__, first_arg, bound_tp_params)
460
334
 
461
- def __post_init__(self) -> None:
462
- self.__egg_callable_ref__ = ClassMethodRef(self.class_name, self.__egg_method_name__)
463
- try:
464
- self.__egg_fn_decl__ = self.__egg_decls__.get_function_decl(self.__egg_callable_ref__)
465
- except KeyError as e:
466
- raise AttributeError(f"Class {self.class_name} does not have method {self.__egg_method_name__}") from e
467
335
 
468
- def __call__(self, *args: object, **kwargs) -> RuntimeExpr | None:
469
- return _call(self.__egg_decls__, self.__egg_callable_ref__, self.__egg_fn_decl__, args, kwargs, self.__egg_tp__)
336
+ def to_py_signature(sig: FunctionSignature, decls: Declarations, optional_args: bool) -> Signature:
337
+ """
338
+ Convert to a Python signature.
470
339
 
471
- def __str__(self) -> str:
472
- return f"{self.class_name}.{self.__egg_method_name__}"
340
+ If optional_args is true, then all args will be treated as optional, as if a default was provided that makes them
341
+ a var with that arg name as the value.
473
342
 
474
- @property
475
- def class_name(self) -> str:
476
- if isinstance(self.__egg_tp__, str):
477
- return self.__egg_tp__
478
- return self.__egg_tp__.name
343
+ Used for partial application to try binding a function with only some of its args.
344
+ """
345
+ parameters = [
346
+ Parameter(
347
+ n,
348
+ Parameter.POSITIONAL_OR_KEYWORD,
349
+ default=RuntimeExpr.__from_value__(decls, TypedExprDecl(t.to_just(), d if d is not None else VarDecl(n)))
350
+ if d is not None or optional_args
351
+ else Parameter.empty,
352
+ )
353
+ for n, d, t in zip(sig.arg_names, sig.arg_defaults, sig.arg_types, strict=True)
354
+ ]
355
+ if isinstance(sig, FunctionSignature) and sig.var_arg_type is not None:
356
+ parameters.append(Parameter("__rest", Parameter.VAR_POSITIONAL))
357
+ return Signature(parameters)
479
358
 
480
359
 
481
360
  # All methods which should return NotImplemented if they fail to resolve
@@ -505,63 +384,34 @@ PARTIAL_METHODS = {
505
384
 
506
385
 
507
386
  @dataclass
508
- class RuntimeMethod:
509
- __egg_self__: RuntimeExpr
510
- __egg_method_name__: str
511
- __egg_callable_ref__: MethodRef | PropertyRef = field(init=False)
512
- __egg_fn_decl__: FunctionDecl = field(init=False, repr=False)
513
- __egg_decls__: Declarations = field(init=False)
514
-
515
- def __post_init__(self) -> None:
516
- self.__egg_decls__ = self.__egg_self__.__egg_decls__
517
- if self.__egg_method_name__ in self.__egg_decls__.get_class_decl(self.class_name).properties:
518
- self.__egg_callable_ref__ = PropertyRef(self.class_name, self.__egg_method_name__)
519
- else:
520
- self.__egg_callable_ref__ = MethodRef(self.class_name, self.__egg_method_name__)
521
- try:
522
- self.__egg_fn_decl__ = self.__egg_decls__.get_function_decl(self.__egg_callable_ref__)
523
- except KeyError:
524
- msg = f"Class {self.class_name} does not have method {self.__egg_method_name__}"
525
- if self.__egg_method_name__ == "__ne__":
526
- msg += ". Did you mean to use the ne(...).to(...)?"
527
- raise AttributeError(msg) from None
387
+ class RuntimeExpr:
388
+ # Defer needing decls/expr so we can make constants that don't resolve their class types
389
+ __egg_thunk__: Callable[[], tuple[Declarations, TypedExprDecl]]
528
390
 
529
- def __call__(self, *args: object, **kwargs) -> RuntimeExpr | None:
530
- args = (self.__egg_self__, *args)
531
- try:
532
- return _call(
533
- self.__egg_decls__,
534
- self.__egg_callable_ref__,
535
- self.__egg_fn_decl__,
536
- args,
537
- kwargs,
538
- self.__egg_self__.__egg_typed_expr__.tp,
539
- )
540
- except ConvertError as e:
541
- name = self.__egg_method_name__
542
- raise TypeError(f"Wrong types for {self.__egg_self__.__egg_typed_expr__.tp.pretty()}.{name}") from e
391
+ @classmethod
392
+ def __from_value__(cls, d: Declarations, e: TypedExprDecl) -> RuntimeExpr:
393
+ return cls(Thunk.value((d, e)))
543
394
 
544
395
  @property
545
- def class_name(self) -> str:
546
- return self.__egg_self__.__egg_typed_expr__.tp.name
547
-
396
+ def __egg_decls__(self) -> Declarations:
397
+ return self.__egg_thunk__()[0]
548
398
 
549
- @dataclass
550
- class RuntimeExpr:
551
- __egg_decls__: Declarations
552
- __egg_typed_expr__: TypedExprDecl
399
+ @property
400
+ def __egg_typed_expr__(self) -> TypedExprDecl:
401
+ return self.__egg_thunk__()[1]
553
402
 
554
- def __getattr__(self, name: str) -> RuntimeMethod | RuntimeExpr | Callable | None:
555
- class_decl = self.__egg_decls__.get_class_decl(self.__egg_typed_expr__.tp.name)
403
+ def __getattr__(self, name: str) -> RuntimeFunction | RuntimeExpr | Callable | None:
404
+ cls_name = self.__egg_class_name__
405
+ class_decl = self.__egg_class_decl__
556
406
 
557
- preserved_methods = class_decl.preserved_methods
558
- if name in preserved_methods:
407
+ if name in (preserved_methods := class_decl.preserved_methods):
559
408
  return preserved_methods[name].__get__(self)
560
409
 
561
- method = RuntimeMethod(self, name)
562
- if isinstance(method.__egg_callable_ref__, PropertyRef):
563
- return method()
564
- return method
410
+ if name in class_decl.methods:
411
+ return RuntimeFunction(Thunk.value(self.__egg_decls__), MethodRef(cls_name, name), self)
412
+ if name in class_decl.properties:
413
+ return RuntimeFunction(Thunk.value(self.__egg_decls__), PropertyRef(cls_name, name), self)()
414
+ raise AttributeError(f"{cls_name} has no method {name}") from None
565
415
 
566
416
  def __repr__(self) -> str:
567
417
  """
@@ -570,18 +420,10 @@ class RuntimeExpr:
570
420
  return str(self)
571
421
 
572
422
  def __str__(self) -> str:
573
- context = PrettyContext(self.__egg_decls__)
574
- context.traverse_for_parents(self.__egg_typed_expr__.expr)
575
- pretty_expr = self.__egg_typed_expr__.expr.pretty(context, parens=False)
576
- try:
577
- if config.SHOW_TYPES:
578
- raise NotImplementedError
579
- # s = f"_: {self.__egg_typed_expr__.tp.pretty()} = {pretty_expr}"
580
- # return black.format_str(s, mode=black.FileMode()).strip()
581
- pretty_statements = context.render(pretty_expr)
582
- return black.format_str(pretty_statements, mode=BLACK_MODE).strip()
583
- except black.parsing.InvalidInput:
584
- return pretty_expr
423
+ return self.__egg_pretty__(None)
424
+
425
+ def __egg_pretty__(self, wrapping_fn: str | None) -> str:
426
+ return pretty_decl(self.__egg_decls__, self.__egg_typed_expr__.expr, wrapping_fn=wrapping_fn)
585
427
 
586
428
  def _ipython_display_(self) -> None:
587
429
  from IPython.display import Code, display
@@ -589,28 +431,32 @@ class RuntimeExpr:
589
431
  display(Code(str(self), language="python"))
590
432
 
591
433
  def __dir__(self) -> Iterable[str]:
592
- return list(self.__egg_decls__.get_class_decl(self.__egg_typed_expr__.tp.name).methods)
434
+ class_decl = self.__egg_class_decl__
435
+ return list(class_decl.methods) + list(class_decl.properties) + list(class_decl.preserved_methods)
593
436
 
594
437
  @property
595
- def __egg__(self) -> bindings._Expr:
596
- return self.__egg_typed_expr__.to_egg(self.__egg_decls__)
438
+ def __egg_class_name__(self) -> str:
439
+ return self.__egg_typed_expr__.tp.name
440
+
441
+ @property
442
+ def __egg_class_decl__(self) -> ClassDecl:
443
+ return self.__egg_decls__.get_class_decl(self.__egg_class_name__)
597
444
 
598
445
  # Have __eq__ take no NoReturn (aka Never https://docs.python.org/3/library/typing.html#typing.Never) because
599
446
  # we don't wany any type that MyPy thinks is an expr to be used with __eq__.
600
447
  # That's because we want to reserve __eq__ for domain specific equality checks, overloading this method.
601
448
  # To check if two exprs are equal, use the expr_eq method.
602
- def __eq__(self, other: NoReturn) -> Expr: # type: ignore[override]
603
- msg = "Do not use == on RuntimeExpr. Compare the __egg_typed_expr__ attribute instead for structural equality."
604
- raise NotImplementedError(msg)
449
+ # At runtime, this will resolve if there is a defined egg function for `__eq__`
450
+ def __eq__(self, other: NoReturn) -> Expr: ... # type: ignore[override, empty-body]
605
451
 
606
452
  # Implement these so that copy() works on this object
607
453
  # otherwise copy will try to call `__getstate__` before object is initialized with properties which will cause inifinite recursion
608
454
 
609
455
  def __getstate__(self) -> tuple[Declarations, TypedExprDecl]:
610
- return (self.__egg_decls__, self.__egg_typed_expr__)
456
+ return self.__egg_thunk__()
611
457
 
612
458
  def __setstate__(self, d: tuple[Declarations, TypedExprDecl]) -> None:
613
- self.__egg_decls__, self.__egg_typed_expr__ = d
459
+ self.__egg_thunk__ = Thunk.value(d)
614
460
 
615
461
  def __hash__(self) -> int:
616
462
  return hash(self.__egg_typed_expr__)
@@ -625,12 +471,17 @@ for name in list(BINARY_METHODS) + list(UNARY_METHODS) + ["__getitem__", "__call
625
471
  __name: str = name,
626
472
  **kwargs: object,
627
473
  ) -> RuntimeExpr | None:
474
+ from .conversion import ConvertError
475
+
476
+ class_name = self.__egg_class_name__
477
+ class_decl = self.__egg_class_decl__
628
478
  # First, try to resolve as preserved method
629
479
  try:
630
- method = self.__egg_decls__.get_class_decl(self.__egg_typed_expr__.tp.name).preserved_methods[__name]
631
- return method(self, *args, **kwargs)
480
+ method = class_decl.preserved_methods[__name]
632
481
  except KeyError:
633
482
  pass
483
+ else:
484
+ return method(self, *args, **kwargs)
634
485
  # If this is a "partial" method meaning that it can return NotImplemented,
635
486
  # we want to find the "best" superparent (lowest cost) of the arg types to call with it, instead of just
636
487
  # using the arg type of the self arg.
@@ -639,8 +490,15 @@ for name in list(BINARY_METHODS) + list(UNARY_METHODS) + ["__getitem__", "__call
639
490
  try:
640
491
  return call_method_min_conversion(self, args[0], __name)
641
492
  except ConvertError:
642
- return NotImplemented
643
- return RuntimeMethod(self, __name)(*args, **kwargs)
493
+ # Defer raising not imeplemented in case the dunder method is not symmetrical, then
494
+ # we use the standard process
495
+ pass
496
+ if __name in class_decl.methods:
497
+ fn = RuntimeFunction(Thunk.value(self.__egg_decls__), MethodRef(class_name, __name), self)
498
+ return fn(*args, **kwargs) # type: ignore[arg-type]
499
+ if __name in PARTIAL_METHODS:
500
+ return NotImplemented
501
+ raise TypeError(f"{class_name!r} object does not support {__name}")
644
502
 
645
503
  setattr(RuntimeExpr, name, _special_method)
646
504
 
@@ -655,12 +513,14 @@ for reflected, non_reflected in REFLECTED_BINARY_METHODS.items():
655
513
 
656
514
 
657
515
  def call_method_min_conversion(slf: object, other: object, name: str) -> RuntimeExpr | None:
516
+ from .conversion import min_convertable_tp, resolve_literal
517
+
658
518
  # find a minimum type that both can be converted to
659
519
  # This is so so that calls like `-0.1 * Int("x")` work by upcasting both to floats.
660
520
  min_tp = min_convertable_tp(slf, other, name)
661
- slf = resolve_literal(min_tp.to_var(), slf)
662
- other = resolve_literal(min_tp.to_var(), other)
663
- method = RuntimeMethod(slf, name)
521
+ slf = resolve_literal(TypeRefWithVars(min_tp), slf)
522
+ other = resolve_literal(TypeRefWithVars(min_tp), other)
523
+ method = RuntimeFunction(Thunk.value(slf.__egg_decls__), MethodRef(slf.__egg_class_name__, name), slf)
664
524
  return method(other)
665
525
 
666
526
 
@@ -680,21 +540,9 @@ def resolve_callable(callable: object) -> tuple[CallableRef, Declarations]:
680
540
  """
681
541
  Resolves a runtime callable into a ref
682
542
  """
683
- # TODO: Fix these typings.
684
- ref: CallableRef
685
- decls: Declarations
686
- if isinstance(callable, RuntimeFunction):
687
- ref = FunctionRef(callable.__egg_name__)
688
- decls = callable.__egg_decls__
689
- elif isinstance(callable, RuntimeClassMethod):
690
- ref = ClassMethodRef(callable.class_name, callable.__egg_method_name__)
691
- decls = callable.__egg_decls__
692
- elif isinstance(callable, RuntimeMethod):
693
- ref = MethodRef(callable.__egg_self__.__egg_typed_expr__.tp.name, callable.__egg_method_name__)
694
- decls = callable.__egg_decls__
695
- elif isinstance(callable, RuntimeClass):
696
- ref = ClassMethodRef(callable.__egg_name__, "__init__")
697
- decls = callable.__egg_decls__
698
- else:
699
- raise NotImplementedError(f"Cannot turn {callable} into a callable ref")
700
- return (ref, decls)
543
+ match callable:
544
+ case RuntimeFunction(decls, ref, _):
545
+ return ref, decls()
546
+ case RuntimeClass(thunk, tp):
547
+ return ClassMethodRef(tp.name, "__init__"), thunk()
548
+ raise NotImplementedError(f"Cannot turn {callable} into a callable ref")