egglog 6.1.0__cp310-none-win_amd64.whl → 7.0.0__cp310-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of egglog might be problematic. (The original page linked to further details about this release.)

egglog/runtime.py CHANGED
@@ -11,172 +11,57 @@ so they are not mangled by Python and can be accessed by the user.
11
11
 
12
12
  from __future__ import annotations
13
13
 
14
- from dataclasses import dataclass, field
14
+ from dataclasses import dataclass
15
+ from inspect import Parameter, Signature
15
16
  from itertools import zip_longest
16
17
  from typing import TYPE_CHECKING, NoReturn, TypeVar, Union, cast, get_args, get_origin
17
18
 
18
- import black
19
- import black.parsing
20
- from typing_extensions import assert_never
21
-
22
- from . import bindings, config
23
19
  from .declarations import *
24
- from .declarations import BINARY_METHODS, REFLECTED_BINARY_METHODS, UNARY_METHODS
20
+ from .pretty import *
21
+ from .thunk import Thunk
25
22
  from .type_constraint_solver import *
26
23
 
27
24
  if TYPE_CHECKING:
28
- from collections.abc import Callable, Collection, Iterable
25
+ from collections.abc import Callable, Iterable
29
26
 
30
27
  from .egraph import Expr
31
28
 
32
29
  __all__ = [
33
30
  "LIT_CLASS_NAMES",
34
- "class_to_ref",
35
- "resolve_literal",
36
31
  "resolve_callable",
37
32
  "resolve_type_annotation",
38
- "convert_to_same_type",
39
33
  "RuntimeClass",
40
- "RuntimeParamaterizedClass",
41
- "RuntimeClassMethod",
42
34
  "RuntimeExpr",
43
35
  "RuntimeFunction",
44
- "convert",
45
- "converter",
36
+ "REFLECTED_BINARY_METHODS",
46
37
  ]
47
38
 
48
39
 
49
- BLACK_MODE = black.Mode(line_length=180)
50
-
51
40
  UNIT_CLASS_NAME = "Unit"
52
41
  UNARY_LIT_CLASS_NAMES = {"i64", "f64", "Bool", "String"}
53
42
  LIT_CLASS_NAMES = UNARY_LIT_CLASS_NAMES | {UNIT_CLASS_NAME, "PyObject"}
54
43
 
44
+ REFLECTED_BINARY_METHODS = {
45
+ "__radd__": "__add__",
46
+ "__rsub__": "__sub__",
47
+ "__rmul__": "__mul__",
48
+ "__rmatmul__": "__matmul__",
49
+ "__rtruediv__": "__truediv__",
50
+ "__rfloordiv__": "__floordiv__",
51
+ "__rmod__": "__mod__",
52
+ "__rpow__": "__pow__",
53
+ "__rlshift__": "__lshift__",
54
+ "__rrshift__": "__rshift__",
55
+ "__rand__": "__and__",
56
+ "__rxor__": "__xor__",
57
+ "__ror__": "__or__",
58
+ }
59
+
55
60
  # Set this globally so we can get access to PyObject when we have a type annotation of just object.
56
61
  # This is the only time a type annotation doesn't need to include the egglog type b/c object is top so that would be redundant statically.
57
62
  _PY_OBJECT_CLASS: RuntimeClass | None = None
58
63
 
59
- ##
60
- # Converters
61
- ##
62
-
63
- # Mapping from (source type, target type) to and function which takes in the runtimes values of the source and return the target
64
- CONVERSIONS: dict[tuple[type | JustTypeRef, JustTypeRef], tuple[int, Callable]] = {}
65
- # Global declerations to store all convertable types so we can query if they have certain methods or not
66
- CONVERSIONS_DECLS = Declarations()
67
-
68
64
  T = TypeVar("T")
69
- V = TypeVar("V", bound="Expr")
70
-
71
-
72
- class ConvertError(Exception):
73
- pass
74
-
75
-
76
- def converter(from_type: type[T], to_type: type[V], fn: Callable[[T], V], cost: int = 1) -> None:
77
- """
78
- Register a converter from some type to an egglog type.
79
- """
80
- to_type_name = process_tp(to_type)
81
- if not isinstance(to_type_name, JustTypeRef):
82
- raise TypeError(f"Expected return type to be a egglog type, got {to_type_name}")
83
- _register_converter(process_tp(from_type), to_type_name, fn, cost)
84
-
85
-
86
- def _register_converter(a: type | JustTypeRef, b: JustTypeRef, a_b: Callable, cost: int) -> None:
87
- """
88
- Registers a converter from some type to an egglog type, if not already registered.
89
-
90
- Also adds transitive converters, i.e. if registering A->B and there is already B->C, then A->C will be registered.
91
- Also, if registering A->B and there is already D->A, then D->B will be registered.
92
- """
93
- if a == b:
94
- return
95
- if (a, b) in CONVERSIONS and CONVERSIONS[(a, b)][0] <= cost:
96
- return
97
- CONVERSIONS[(a, b)] = (cost, a_b)
98
- for (c, d), (other_cost, c_d) in list(CONVERSIONS.items()):
99
- if b == c:
100
- _register_converter(a, d, _ComposedConverter(a_b, c_d), cost + other_cost)
101
- if a == d:
102
- _register_converter(c, b, _ComposedConverter(c_d, a_b), cost + other_cost)
103
-
104
-
105
- @dataclass
106
- class _ComposedConverter:
107
- """
108
- A converter which is composed of multiple converters.
109
-
110
- _ComposeConverter(a_b, b_c) is equivalent to lambda x: b_c(a_b(x))
111
-
112
- We use the dataclass instead of the lambda to make it easier to debug.
113
- """
114
-
115
- a_b: Callable
116
- b_c: Callable
117
-
118
- def __call__(self, x: object) -> object:
119
- return self.b_c(self.a_b(x))
120
-
121
- def __str__(self) -> str:
122
- return f"{self.b_c} ∘ {self.a_b}"
123
-
124
-
125
- def convert(source: object, target: type[V]) -> V:
126
- """
127
- Convert a source object to a target type.
128
- """
129
- target_ref = class_to_ref(cast(RuntimeTypeArgType, target))
130
- return cast(V, resolve_literal(target_ref.to_var(), source))
131
-
132
-
133
- def convert_to_same_type(source: object, target: RuntimeExpr) -> RuntimeExpr:
134
- """
135
- Convert a source object to the same type as the target.
136
- """
137
- tp = target.__egg_typed_expr__.tp
138
- return resolve_literal(tp.to_var(), source)
139
-
140
-
141
- def process_tp(tp: type | RuntimeTypeArgType) -> JustTypeRef | type:
142
- """
143
- Process a type before converting it, to add it to the global declerations and resolve to a ref.
144
- """
145
- global CONVERSIONS_DECLS
146
- if isinstance(tp, RuntimeClass | RuntimeParamaterizedClass):
147
- CONVERSIONS_DECLS |= tp
148
- return class_to_ref(tp)
149
- return tp
150
-
151
-
152
- def min_convertable_tp(a: object, b: object, name: str) -> JustTypeRef:
153
- """
154
- Returns the minimum convertable type between a and b, that has a method `name`, raising a TypeError if no such type exists.
155
- """
156
- a_tp = _get_tp(a)
157
- b_tp = _get_tp(b)
158
- a_converts_to = {
159
- to: c
160
- for ((from_, to), (c, _)) in CONVERSIONS.items()
161
- if from_ == a_tp and CONVERSIONS_DECLS.has_method(to.name, name)
162
- }
163
- b_converts_to = {
164
- to: c
165
- for ((from_, to), (c, _)) in CONVERSIONS.items()
166
- if from_ == b_tp and CONVERSIONS_DECLS.has_method(to.name, name)
167
- }
168
- if isinstance(a_tp, JustTypeRef):
169
- a_converts_to[a_tp] = 0
170
- if isinstance(b_tp, JustTypeRef):
171
- b_converts_to[b_tp] = 0
172
- common = set(a_converts_to) & set(b_converts_to)
173
- if not common:
174
- raise ConvertError(f"Cannot convert {a_tp} and {b_tp} to a common type")
175
- return min(common, key=lambda tp: a_converts_to[tp] + b_converts_to[tp])
176
-
177
-
178
- def identity(x: object) -> object:
179
- return x
180
65
 
181
66
 
182
67
  def resolve_type_annotation(decls: Declarations, tp: object) -> TypeOrVarRef:
@@ -195,99 +80,62 @@ def resolve_type_annotation(decls: Declarations, tp: object) -> TypeOrVarRef:
195
80
  assert _PY_OBJECT_CLASS
196
81
  return resolve_type_annotation(decls, _PY_OBJECT_CLASS)
197
82
  if isinstance(tp, RuntimeClass):
198
- decls |= tp
199
- return tp.__egg_tp__.to_var()
200
- if isinstance(tp, RuntimeParamaterizedClass):
201
83
  decls |= tp
202
84
  return tp.__egg_tp__
203
85
  raise TypeError(f"Unexpected type annotation {tp}")
204
86
 
205
87
 
206
- def resolve_literal(tp: TypeOrVarRef, arg: object) -> RuntimeExpr:
207
- arg_type = _get_tp(arg)
208
-
209
- # If we have any type variables, dont bother trying to resolve the literal, just return the arg
210
- try:
211
- tp_just = tp.to_just()
212
- except NotImplementedError:
213
- # If this is a var, it has to be a runtime exprssions
214
- assert isinstance(arg, RuntimeExpr)
215
- return arg
216
- if arg_type == tp_just:
217
- # If the type is an egg type, it has to be a runtime expr
218
- assert isinstance(arg, RuntimeExpr)
219
- return arg
220
- # Try all parent types as well, if we are converting from a Python type
221
- for arg_type_instance in arg_type.__mro__ if isinstance(arg_type, type) else [arg_type]:
222
- try:
223
- fn = CONVERSIONS[(cast(JustTypeRef | type, arg_type_instance), tp_just)][1]
224
- except KeyError:
225
- continue
226
- break
227
- else:
228
- arg_type_str = arg_type.pretty() if isinstance(arg_type, JustTypeRef) else arg_type.__name__
229
- raise ConvertError(f"Cannot convert {arg_type_str} to {tp_just.pretty()}")
230
- return fn(arg)
231
-
232
-
233
- def _get_tp(x: object) -> JustTypeRef | type:
234
- if isinstance(x, RuntimeExpr):
235
- return x.__egg_typed_expr__.tp
236
- tp = type(x)
237
- # If this value has a custom metaclass, let's use that as our index instead of the type
238
- if type(tp) != type:
239
- return type(tp)
240
- return tp
241
-
242
-
243
88
  ##
244
89
  # Runtime objects
245
90
  ##
246
91
 
247
92
 
248
93
  @dataclass
249
- class RuntimeClass:
250
- # Pass in a constructor to make the declarations lazy, so we can have classes reference each other in their type constructors
251
- # This function should mutate the declerations and add to them
252
- # Used this instead of a lazy property so we can have a reference to the decls in the class as its computing
253
- lazy_decls: Callable[[Declarations], None] = field(repr=False)
254
- # Cached declerations
255
- _inner_decls: Declarations | None = field(init=False, repr=False, default=None)
256
- __egg_name__: str
94
+ class RuntimeClass(DelayedDeclerations):
95
+ __egg_tp__: TypeRefWithVars
257
96
 
258
97
  def __post_init__(self) -> None:
259
98
  global _PY_OBJECT_CLASS
260
- if self.__egg_name__ == "PyObject":
99
+ if self.__egg_tp__.name == "PyObject":
261
100
  _PY_OBJECT_CLASS = self
262
101
 
263
- @property
264
- def __egg_decls__(self) -> Declarations:
265
- if self._inner_decls is None:
266
- # Set it like this so we can have a reference to the decls in the class as its computing
267
- self._inner_decls = Declarations()
268
- self.lazy_decls(self._inner_decls)
269
- return self._inner_decls
102
+ def verify(self) -> None:
103
+ if not self.__egg_tp__.args:
104
+ return
105
+
106
+ # Raise error if we have args, but they are the wrong number
107
+ desired_args = self.__egg_decls__.get_class_decl(self.__egg_tp__.name).type_vars
108
+ if len(self.__egg_tp__.args) != len(desired_args):
109
+ raise ValueError(f"Expected {desired_args} type args, got {len(self.__egg_tp__.args)}")
270
110
 
271
111
  def __call__(self, *args: object, **kwargs: object) -> RuntimeExpr | None:
272
112
  """
273
113
  Create an instance of this kind by calling the __init__ classmethod
274
114
  """
275
115
  # If this is a literal type, initializing it with a literal should return a literal
276
- if self.__egg_name__ == "PyObject":
116
+ if self.__egg_tp__.name == "PyObject":
277
117
  assert len(args) == 1
278
- return RuntimeExpr(self.__egg_decls__, TypedExprDecl(self.__egg_tp__, PyObjectDecl(args[0])))
279
- if self.__egg_name__ in UNARY_LIT_CLASS_NAMES:
118
+ return RuntimeExpr.__from_value__(
119
+ self.__egg_decls__, TypedExprDecl(self.__egg_tp__.to_just(), PyObjectDecl(args[0]))
120
+ )
121
+ if self.__egg_tp__.name in UNARY_LIT_CLASS_NAMES:
280
122
  assert len(args) == 1
281
123
  assert isinstance(args[0], int | float | str | bool)
282
- return RuntimeExpr(self.__egg_decls__, TypedExprDecl(self.__egg_tp__, LitDecl(args[0])))
283
- if self.__egg_name__ == UNIT_CLASS_NAME:
124
+ return RuntimeExpr.__from_value__(
125
+ self.__egg_decls__, TypedExprDecl(self.__egg_tp__.to_just(), LitDecl(args[0]))
126
+ )
127
+ if self.__egg_tp__.name == UNIT_CLASS_NAME:
284
128
  assert len(args) == 0
285
- return RuntimeExpr(self.__egg_decls__, TypedExprDecl(self.__egg_tp__, LitDecl(None)))
129
+ return RuntimeExpr.__from_value__(
130
+ self.__egg_decls__, TypedExprDecl(self.__egg_tp__.to_just(), LitDecl(None))
131
+ )
286
132
 
287
- return RuntimeClassMethod(self.__egg_decls__, self.__egg_tp__, "__init__")(*args, **kwargs)
133
+ return RuntimeFunction(
134
+ Thunk.value(self.__egg_decls__), ClassMethodRef(self.__egg_tp__.name, "__init__"), self.__egg_tp__.to_just()
135
+ )(*args, **kwargs)
288
136
 
289
137
  def __dir__(self) -> list[str]:
290
- cls_decl = self.__egg_decls__.get_class_decl(self.__egg_name__)
138
+ cls_decl = self.__egg_decls__.get_class_decl(self.__egg_tp__.name)
291
139
  possible_methods = (
292
140
  list(cls_decl.class_methods) + list(cls_decl.class_variables) + list(cls_decl.preserved_methods)
293
141
  )
@@ -296,14 +144,19 @@ class RuntimeClass:
296
144
  possible_methods.append("__call__")
297
145
  return possible_methods
298
146
 
299
- def __getitem__(self, args: object) -> RuntimeParamaterizedClass:
147
+ def __getitem__(self, args: object) -> RuntimeClass:
148
+ if self.__egg_tp__.args:
149
+ raise TypeError(f"Cannot index into a paramaterized class {self}")
300
150
  if not isinstance(args, tuple):
301
151
  args = (args,)
302
152
  decls = self.__egg_decls__.copy()
303
- tp = TypeRefWithVars(self.__egg_name__, tuple(resolve_type_annotation(decls, arg) for arg in args))
304
- return RuntimeParamaterizedClass(self.__egg_decls__, tp)
153
+ tp = TypeRefWithVars(self.__egg_tp__.name, tuple(resolve_type_annotation(decls, arg) for arg in args))
154
+ return RuntimeClass(Thunk.value(decls), tp)
155
+
156
+ def __getattr__(self, name: str) -> RuntimeFunction | RuntimeExpr | Callable:
157
+ if name == "__origin__" and self.__egg_tp__.args:
158
+ return RuntimeClass(self.__egg_decls_thunk__, TypeRefWithVars(self.__egg_tp__.name))
305
159
 
306
- def __getattr__(self, name: str) -> RuntimeClassMethod | RuntimeExpr | Callable:
307
160
  # Special case some names that don't exist so we can exit early without resolving decls
308
161
  # Important so if we take union of RuntimeClass it won't try to resolve decls
309
162
  if name in {
@@ -314,7 +167,7 @@ class RuntimeClass:
314
167
  }:
315
168
  raise AttributeError
316
169
 
317
- cls_decl = self.__egg_decls__.get_class_decl(self.__egg_name__)
170
+ cls_decl = self.__egg_decls__._classes[self.__egg_tp__.name]
318
171
 
319
172
  preserved_methods = cls_decl.preserved_methods
320
173
  if name in preserved_methods:
@@ -323,159 +176,107 @@ class RuntimeClass:
323
176
  # if this is a class variable, return an expr for it, otherwise, assume it's a method
324
177
  if name in cls_decl.class_variables:
325
178
  return_tp = cls_decl.class_variables[name]
326
- return RuntimeExpr(
327
- self.__egg_decls__, TypedExprDecl(return_tp, CallDecl(ClassVariableRef(self.__egg_name__, name)))
179
+ return RuntimeExpr.__from_value__(
180
+ self.__egg_decls__,
181
+ TypedExprDecl(return_tp.type_ref, CallDecl(ClassVariableRef(self.__egg_tp__.name, name))),
328
182
  )
329
- return RuntimeClassMethod(self.__egg_decls__, self.__egg_tp__, name)
183
+ if name in cls_decl.class_methods:
184
+ return RuntimeFunction(
185
+ Thunk.value(self.__egg_decls__), ClassMethodRef(self.__egg_tp__.name, name), self.__egg_tp__.to_just()
186
+ )
187
+ msg = f"Class {self.__egg_tp__.name} has no method {name}"
188
+ if name == "__ne__":
189
+ msg += ". Did you mean to use the ne(...).to(...)?"
190
+ raise AttributeError(msg) from None
330
191
 
331
192
  def __str__(self) -> str:
332
- return self.__egg_name__
193
+ return str(self.__egg_tp__)
333
194
 
334
195
  # Make hashable so can go in Union
335
196
  def __hash__(self) -> int:
336
- return hash((id(self.lazy_decls), self.__egg_name__))
197
+ return hash((id(self.__egg_decls_thunk__), self.__egg_tp__))
337
198
 
338
199
  # Support unioning like types
339
200
  def __or__(self, __value: type) -> object:
340
201
  return Union[self, __value] # noqa: UP007
341
202
 
342
- @property
343
- def __egg_tp__(self) -> JustTypeRef:
344
- return JustTypeRef(self.__egg_name__)
345
-
346
-
347
- @dataclass
348
- class RuntimeParamaterizedClass:
349
- __egg_decls__: Declarations
350
- __egg_tp__: TypeRefWithVars
351
-
352
- def __post_init__(self) -> None:
353
- desired_args = self.__egg_decls__.get_class_decl(self.__egg_tp__.name).type_vars
354
- if len(self.__egg_tp__.args) != len(desired_args):
355
- raise ValueError(f"Expected {desired_args} type args, got {len(self.__egg_tp__.args)}")
356
-
357
- def __call__(self, *args: object) -> RuntimeExpr | None:
358
- return RuntimeClassMethod(self.__egg_decls__, class_to_ref(self), "__init__")(*args)
359
-
360
- def __getattr__(self, name: str) -> RuntimeClassMethod | RuntimeClass:
361
- # Special case so when get_type_annotations proccessed it can work
362
- if name == "__origin__":
363
- return RuntimeClass(self.__egg_decls__.update_other, self.__egg_tp__.name)
364
- return RuntimeClassMethod(self.__egg_decls__, class_to_ref(self), name)
365
-
366
- def __str__(self) -> str:
367
- return self.__egg_tp__.pretty()
368
-
369
- # Support unioning
370
- def __or__(self, __value: type) -> object:
371
- return Union[self, __value] # noqa: UP007
372
-
373
-
374
- # Type args can either be typevars or classes
375
- RuntimeTypeArgType = RuntimeClass | RuntimeParamaterizedClass
376
-
377
-
378
- def class_to_ref(cls: RuntimeTypeArgType) -> JustTypeRef:
379
- if isinstance(cls, RuntimeClass):
380
- return JustTypeRef(cls.__egg_name__)
381
- if isinstance(cls, RuntimeParamaterizedClass):
382
- # Currently this is used when calling methods on a parametrized class, which is only possible when we
383
- # have actualy types currently, not typevars, currently.
384
- return cls.__egg_tp__.to_just()
385
- assert_never(cls)
386
-
387
203
 
388
204
  @dataclass
389
- class RuntimeFunction:
390
- __egg_decls__: Declarations
391
- __egg_name__: str
392
- __egg_fn_ref__: FunctionRef = field(init=False)
393
- __egg_fn_decl__: FunctionDecl = field(init=False)
394
-
395
- def __post_init__(self) -> None:
396
- self.__egg_fn_ref__ = FunctionRef(self.__egg_name__)
397
- self.__egg_fn_decl__ = self.__egg_decls__.get_function_decl(self.__egg_fn_ref__)
205
+ class RuntimeFunction(DelayedDeclerations):
206
+ __egg_ref__: CallableRef
207
+ # bound methods need to store RuntimeExpr not just TypedExprDecl, so they can mutate the expr if required on self
208
+ __egg_bound__: JustTypeRef | RuntimeExpr | None = None
398
209
 
399
210
  def __call__(self, *args: object, **kwargs: object) -> RuntimeExpr | None:
400
- return _call(self.__egg_decls__, self.__egg_fn_ref__, self.__egg_fn_decl__, args, kwargs)
401
-
402
- def __str__(self) -> str:
403
- return self.__egg_name__
404
-
405
-
406
- def _call(
407
- decls_from_fn: Declarations,
408
- callable_ref: CallableRef,
409
- fn_decl: FunctionDecl,
410
- args: Collection[object],
411
- kwargs: dict[str, object],
412
- bound_class: JustTypeRef | None = None,
413
- ) -> RuntimeExpr | None:
414
- # Turn all keyword args into positional args
415
- bound = fn_decl.to_signature(lambda expr: RuntimeExpr(decls_from_fn, expr)).bind(*args, **kwargs)
416
- bound.apply_defaults()
417
- assert not bound.kwargs
418
- del args, kwargs
419
-
420
- upcasted_args = [
421
- resolve_literal(cast(TypeOrVarRef, tp), arg)
422
- for arg, tp in zip_longest(bound.args, fn_decl.arg_types, fillvalue=fn_decl.var_arg_type)
423
- ]
424
-
425
- arg_exprs = tuple(arg.__egg_typed_expr__ for arg in upcasted_args)
426
- decls = Declarations.create(decls_from_fn, *upcasted_args)
427
-
428
- tcs = TypeConstraintSolver(decls)
429
- if bound_class is not None and bound_class.args:
430
- tcs.bind_class(bound_class)
431
-
432
- if fn_decl is not None:
211
+ from .conversion import resolve_literal
212
+
213
+ if isinstance(self.__egg_bound__, RuntimeExpr):
214
+ args = (self.__egg_bound__, *args)
215
+ fn_decl = self.__egg_decls__.get_callable_decl(self.__egg_ref__).to_function_decl()
216
+ # Turn all keyword args into positional args
217
+ bound = callable_decl_to_signature(fn_decl, self.__egg_decls__).bind(*args, **kwargs)
218
+ bound.apply_defaults()
219
+ assert not bound.kwargs
220
+ del args, kwargs
221
+
222
+ upcasted_args = [
223
+ resolve_literal(cast(TypeOrVarRef, tp), arg)
224
+ for arg, tp in zip_longest(bound.args, fn_decl.arg_types, fillvalue=fn_decl.var_arg_type)
225
+ ]
226
+
227
+ decls = Declarations.create(self, *upcasted_args)
228
+
229
+ tcs = TypeConstraintSolver(decls)
230
+ bound_tp = (
231
+ None
232
+ if self.__egg_bound__ is None
233
+ else self.__egg_bound__.__egg_typed_expr__.tp
234
+ if isinstance(self.__egg_bound__, RuntimeExpr)
235
+ else self.__egg_bound__
236
+ )
237
+ if bound_tp and bound_tp.args:
238
+ tcs.bind_class(bound_tp)
239
+ arg_exprs = tuple(arg.__egg_typed_expr__ for arg in upcasted_args)
433
240
  arg_types = [expr.tp for expr in arg_exprs]
434
- cls_name = bound_class.name if bound_class is not None else None
241
+ cls_name = bound_tp.name if bound_tp else None
435
242
  return_tp = tcs.infer_return_type(
436
- fn_decl.arg_types, fn_decl.return_type, fn_decl.var_arg_type, arg_types, cls_name
243
+ fn_decl.arg_types, fn_decl.return_type or fn_decl.arg_types[0], fn_decl.var_arg_type, arg_types, cls_name
437
244
  )
438
- else:
439
- return_tp = JustTypeRef("Unit")
440
- bound_params = cast(JustTypeRef, bound_class).args if isinstance(callable_ref, ClassMethodRef) else None
441
- expr_decl = CallDecl(callable_ref, arg_exprs, bound_params)
442
- typed_expr_decl = TypedExprDecl(return_tp, expr_decl)
443
- # Register return type sort in case it's a variadic generic that needs to be created
444
- decls.register_sort(return_tp, False)
445
- if fn_decl.mutates_first_arg:
446
- first_arg = upcasted_args[0]
447
- first_arg.__egg_typed_expr__ = typed_expr_decl
448
- first_arg.__egg_decls__ = decls
449
- return None
450
- return RuntimeExpr(decls, TypedExprDecl(return_tp, expr_decl))
451
-
452
-
453
- @dataclass
454
- class RuntimeClassMethod:
455
- __egg_decls__: Declarations
456
- __egg_tp__: JustTypeRef
457
- __egg_method_name__: str
458
- __egg_callable_ref__: ClassMethodRef = field(init=False)
459
- __egg_fn_decl__: FunctionDecl = field(init=False)
460
-
461
- def __post_init__(self) -> None:
462
- self.__egg_callable_ref__ = ClassMethodRef(self.class_name, self.__egg_method_name__)
463
- try:
464
- self.__egg_fn_decl__ = self.__egg_decls__.get_function_decl(self.__egg_callable_ref__)
465
- except KeyError as e:
466
- raise AttributeError(f"Class {self.class_name} does not have method {self.__egg_method_name__}") from e
467
-
468
- def __call__(self, *args: object, **kwargs) -> RuntimeExpr | None:
469
- return _call(self.__egg_decls__, self.__egg_callable_ref__, self.__egg_fn_decl__, args, kwargs, self.__egg_tp__)
245
+ bound_params = cast(JustTypeRef, bound_tp).args if isinstance(self.__egg_ref__, ClassMethodRef) else None
246
+ expr_decl = CallDecl(self.__egg_ref__, arg_exprs, bound_params)
247
+ typed_expr_decl = TypedExprDecl(return_tp, expr_decl)
248
+ # If there is not return type, we are mutating the first arg
249
+ if not fn_decl.return_type:
250
+ first_arg = upcasted_args[0]
251
+ first_arg.__egg_thunk__ = Thunk.value((decls, typed_expr_decl))
252
+ return None
253
+ return RuntimeExpr.__from_value__(decls, typed_expr_decl)
470
254
 
471
255
  def __str__(self) -> str:
472
- return f"{self.class_name}.{self.__egg_method_name__}"
473
-
474
- @property
475
- def class_name(self) -> str:
476
- if isinstance(self.__egg_tp__, str):
477
- return self.__egg_tp__
478
- return self.__egg_tp__.name
256
+ first_arg, bound_tp_params = None, None
257
+ match self.__egg_bound__:
258
+ case RuntimeExpr(_):
259
+ first_arg = self.__egg_bound__.__egg_typed_expr__.expr
260
+ case JustTypeRef(_, args):
261
+ bound_tp_params = args
262
+ return pretty_callable_ref(self.__egg_decls__, self.__egg_ref__, first_arg, bound_tp_params)
263
+
264
+
265
+ def callable_decl_to_signature(
266
+ decl: FunctionDecl,
267
+ decls: Declarations,
268
+ ) -> Signature:
269
+ parameters = [
270
+ Parameter(
271
+ n,
272
+ Parameter.POSITIONAL_OR_KEYWORD,
273
+ default=RuntimeExpr.__from_value__(decls, TypedExprDecl(t.to_just(), d)) if d else Parameter.empty,
274
+ )
275
+ for n, d, t in zip(decl.arg_names, decl.arg_defaults, decl.arg_types, strict=True)
276
+ ]
277
+ if isinstance(decl, FunctionDecl) and decl.var_arg_type is not None:
278
+ parameters.append(Parameter("__rest", Parameter.VAR_POSITIONAL))
279
+ return Signature(parameters)
479
280
 
480
281
 
481
282
  # All methods which should return NotImplemented if they fail to resolve
@@ -505,63 +306,34 @@ PARTIAL_METHODS = {
505
306
 
506
307
 
507
308
  @dataclass
508
- class RuntimeMethod:
509
- __egg_self__: RuntimeExpr
510
- __egg_method_name__: str
511
- __egg_callable_ref__: MethodRef | PropertyRef = field(init=False)
512
- __egg_fn_decl__: FunctionDecl = field(init=False, repr=False)
513
- __egg_decls__: Declarations = field(init=False)
514
-
515
- def __post_init__(self) -> None:
516
- self.__egg_decls__ = self.__egg_self__.__egg_decls__
517
- if self.__egg_method_name__ in self.__egg_decls__.get_class_decl(self.class_name).properties:
518
- self.__egg_callable_ref__ = PropertyRef(self.class_name, self.__egg_method_name__)
519
- else:
520
- self.__egg_callable_ref__ = MethodRef(self.class_name, self.__egg_method_name__)
521
- try:
522
- self.__egg_fn_decl__ = self.__egg_decls__.get_function_decl(self.__egg_callable_ref__)
523
- except KeyError:
524
- msg = f"Class {self.class_name} does not have method {self.__egg_method_name__}"
525
- if self.__egg_method_name__ == "__ne__":
526
- msg += ". Did you mean to use the ne(...).to(...)?"
527
- raise AttributeError(msg) from None
309
+ class RuntimeExpr:
310
+ # Defer needing decls/expr so we can make constants that don't resolve their class types
311
+ __egg_thunk__: Callable[[], tuple[Declarations, TypedExprDecl]]
528
312
 
529
- def __call__(self, *args: object, **kwargs) -> RuntimeExpr | None:
530
- args = (self.__egg_self__, *args)
531
- try:
532
- return _call(
533
- self.__egg_decls__,
534
- self.__egg_callable_ref__,
535
- self.__egg_fn_decl__,
536
- args,
537
- kwargs,
538
- self.__egg_self__.__egg_typed_expr__.tp,
539
- )
540
- except ConvertError as e:
541
- name = self.__egg_method_name__
542
- raise TypeError(f"Wrong types for {self.__egg_self__.__egg_typed_expr__.tp.pretty()}.{name}") from e
313
+ @classmethod
314
+ def __from_value__(cls, d: Declarations, e: TypedExprDecl) -> RuntimeExpr:
315
+ return cls(Thunk.value((d, e)))
543
316
 
544
317
  @property
545
- def class_name(self) -> str:
546
- return self.__egg_self__.__egg_typed_expr__.tp.name
547
-
318
+ def __egg_decls__(self) -> Declarations:
319
+ return self.__egg_thunk__()[0]
548
320
 
549
- @dataclass
550
- class RuntimeExpr:
551
- __egg_decls__: Declarations
552
- __egg_typed_expr__: TypedExprDecl
321
+ @property
322
+ def __egg_typed_expr__(self) -> TypedExprDecl:
323
+ return self.__egg_thunk__()[1]
553
324
 
554
- def __getattr__(self, name: str) -> RuntimeMethod | RuntimeExpr | Callable | None:
555
- class_decl = self.__egg_decls__.get_class_decl(self.__egg_typed_expr__.tp.name)
325
+ def __getattr__(self, name: str) -> RuntimeFunction | RuntimeExpr | Callable | None:
326
+ cls_name = self.__egg_class_name__
327
+ class_decl = self.__egg_class_decl__
556
328
 
557
- preserved_methods = class_decl.preserved_methods
558
- if name in preserved_methods:
329
+ if name in (preserved_methods := class_decl.preserved_methods):
559
330
  return preserved_methods[name].__get__(self)
560
331
 
561
- method = RuntimeMethod(self, name)
562
- if isinstance(method.__egg_callable_ref__, PropertyRef):
563
- return method()
564
- return method
332
+ if name in class_decl.methods:
333
+ return RuntimeFunction(Thunk.value(self.__egg_decls__), MethodRef(cls_name, name), self)
334
+ if name in class_decl.properties:
335
+ return RuntimeFunction(Thunk.value(self.__egg_decls__), PropertyRef(cls_name, name), self)()
336
+ raise AttributeError(f"{cls_name} has no method {name}") from None
565
337
 
566
338
  def __repr__(self) -> str:
567
339
  """
@@ -570,18 +342,10 @@ class RuntimeExpr:
570
342
  return str(self)
571
343
 
572
344
  def __str__(self) -> str:
573
- context = PrettyContext(self.__egg_decls__)
574
- context.traverse_for_parents(self.__egg_typed_expr__.expr)
575
- pretty_expr = self.__egg_typed_expr__.expr.pretty(context, parens=False)
576
- try:
577
- if config.SHOW_TYPES:
578
- raise NotImplementedError
579
- # s = f"_: {self.__egg_typed_expr__.tp.pretty()} = {pretty_expr}"
580
- # return black.format_str(s, mode=black.FileMode()).strip()
581
- pretty_statements = context.render(pretty_expr)
582
- return black.format_str(pretty_statements, mode=BLACK_MODE).strip()
583
- except black.parsing.InvalidInput:
584
- return pretty_expr
345
+ return self.__egg_pretty__(None)
346
+
347
+ def __egg_pretty__(self, wrapping_fn: str | None) -> str:
348
+ return pretty_decl(self.__egg_decls__, self.__egg_typed_expr__.expr, wrapping_fn=wrapping_fn)
585
349
 
586
350
  def _ipython_display_(self) -> None:
587
351
  from IPython.display import Code, display
@@ -589,28 +353,32 @@ class RuntimeExpr:
589
353
  display(Code(str(self), language="python"))
590
354
 
591
355
  def __dir__(self) -> Iterable[str]:
592
- return list(self.__egg_decls__.get_class_decl(self.__egg_typed_expr__.tp.name).methods)
356
+ class_decl = self.__egg_class_decl__
357
+ return list(class_decl.methods) + list(class_decl.properties) + list(class_decl.preserved_methods)
593
358
 
594
359
  @property
595
- def __egg__(self) -> bindings._Expr:
596
- return self.__egg_typed_expr__.to_egg(self.__egg_decls__)
360
+ def __egg_class_name__(self) -> str:
361
+ return self.__egg_typed_expr__.tp.name
362
+
363
+ @property
364
+ def __egg_class_decl__(self) -> ClassDecl:
365
+ return self.__egg_decls__.get_class_decl(self.__egg_class_name__)
597
366
 
598
367
  # Have __eq__ take no NoReturn (aka Never https://docs.python.org/3/library/typing.html#typing.Never) because
599
368
  # we don't wany any type that MyPy thinks is an expr to be used with __eq__.
600
369
  # That's because we want to reserve __eq__ for domain specific equality checks, overloading this method.
601
370
  # To check if two exprs are equal, use the expr_eq method.
602
- def __eq__(self, other: NoReturn) -> Expr: # type: ignore[override]
603
- msg = "Do not use == on RuntimeExpr. Compare the __egg_typed_expr__ attribute instead for structural equality."
604
- raise NotImplementedError(msg)
371
+ # At runtime, this will resolve if there is a defined egg function for `__eq__`
372
+ def __eq__(self, other: NoReturn) -> Expr: ... # type: ignore[override, empty-body]
605
373
 
606
374
  # Implement these so that copy() works on this object
607
375
  # otherwise copy will try to call `__getstate__` before object is initialized with properties which will cause inifinite recursion
608
376
 
609
377
  def __getstate__(self) -> tuple[Declarations, TypedExprDecl]:
610
- return (self.__egg_decls__, self.__egg_typed_expr__)
378
+ return self.__egg_thunk__()
611
379
 
612
380
  def __setstate__(self, d: tuple[Declarations, TypedExprDecl]) -> None:
613
- self.__egg_decls__, self.__egg_typed_expr__ = d
381
+ self.__egg_thunk__ = Thunk.value(d)
614
382
 
615
383
  def __hash__(self) -> int:
616
384
  return hash(self.__egg_typed_expr__)
@@ -625,12 +393,17 @@ for name in list(BINARY_METHODS) + list(UNARY_METHODS) + ["__getitem__", "__call
625
393
  __name: str = name,
626
394
  **kwargs: object,
627
395
  ) -> RuntimeExpr | None:
396
+ from .conversion import ConvertError
397
+
398
+ class_name = self.__egg_class_name__
399
+ class_decl = self.__egg_class_decl__
628
400
  # First, try to resolve as preserved method
629
401
  try:
630
- method = self.__egg_decls__.get_class_decl(self.__egg_typed_expr__.tp.name).preserved_methods[__name]
631
- return method(self, *args, **kwargs)
402
+ method = class_decl.preserved_methods[__name]
632
403
  except KeyError:
633
404
  pass
405
+ else:
406
+ return method(self, *args, **kwargs)
634
407
  # If this is a "partial" method meaning that it can return NotImplemented,
635
408
  # we want to find the "best" superparent (lowest cost) of the arg types to call with it, instead of just
636
409
  # using the arg type of the self arg.
@@ -640,7 +413,10 @@ for name in list(BINARY_METHODS) + list(UNARY_METHODS) + ["__getitem__", "__call
640
413
  return call_method_min_conversion(self, args[0], __name)
641
414
  except ConvertError:
642
415
  return NotImplemented
643
- return RuntimeMethod(self, __name)(*args, **kwargs)
416
+ if __name in class_decl.methods:
417
+ fn = RuntimeFunction(Thunk.value(self.__egg_decls__), MethodRef(class_name, __name), self)
418
+ return fn(*args, **kwargs)
419
+ raise TypeError(f"{class_name!r} object does not support {__name}")
644
420
 
645
421
  setattr(RuntimeExpr, name, _special_method)
646
422
 
@@ -655,12 +431,14 @@ for reflected, non_reflected in REFLECTED_BINARY_METHODS.items():
655
431
 
656
432
 
657
433
  def call_method_min_conversion(slf: object, other: object, name: str) -> RuntimeExpr | None:
434
+ from .conversion import min_convertable_tp, resolve_literal
435
+
658
436
  # find a minimum type that both can be converted to
659
437
  # This is so so that calls like `-0.1 * Int("x")` work by upcasting both to floats.
660
438
  min_tp = min_convertable_tp(slf, other, name)
661
439
  slf = resolve_literal(min_tp.to_var(), slf)
662
440
  other = resolve_literal(min_tp.to_var(), other)
663
- method = RuntimeMethod(slf, name)
441
+ method = RuntimeFunction(Thunk.value(slf.__egg_decls__), MethodRef(slf.__egg_class_name__, name), slf)
664
442
  return method(other)
665
443
 
666
444
 
@@ -680,21 +458,9 @@ def resolve_callable(callable: object) -> tuple[CallableRef, Declarations]:
680
458
  """
681
459
  Resolves a runtime callable into a ref
682
460
  """
683
- # TODO: Fix these typings.
684
- ref: CallableRef
685
- decls: Declarations
686
- if isinstance(callable, RuntimeFunction):
687
- ref = FunctionRef(callable.__egg_name__)
688
- decls = callable.__egg_decls__
689
- elif isinstance(callable, RuntimeClassMethod):
690
- ref = ClassMethodRef(callable.class_name, callable.__egg_method_name__)
691
- decls = callable.__egg_decls__
692
- elif isinstance(callable, RuntimeMethod):
693
- ref = MethodRef(callable.__egg_self__.__egg_typed_expr__.tp.name, callable.__egg_method_name__)
694
- decls = callable.__egg_decls__
695
- elif isinstance(callable, RuntimeClass):
696
- ref = ClassMethodRef(callable.__egg_name__, "__init__")
697
- decls = callable.__egg_decls__
698
- else:
699
- raise NotImplementedError(f"Cannot turn {callable} into a callable ref")
700
- return (ref, decls)
461
+ match callable:
462
+ case RuntimeFunction(decls, ref, _):
463
+ return ref, decls()
464
+ case RuntimeClass(thunk, tp):
465
+ return ClassMethodRef(tp.name, "__init__"), thunk()
466
+ raise NotImplementedError(f"Cannot turn {callable} into a callable ref")