egglog: egglog-10.0.2-cp312-cp312-win_amd64.whl → egglog-11.1.0-cp312-cp312-win_amd64.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.

Potentially problematic release.


This version of egglog might be problematic. (The original page linked to more details here; the link was not preserved.)

egglog/conversion.py CHANGED
@@ -1,10 +1,11 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  from collections import defaultdict
4
+ from collections.abc import Callable
4
5
  from contextlib import contextmanager
5
6
  from contextvars import ContextVar
6
7
  from dataclasses import dataclass
7
- from typing import TYPE_CHECKING, TypeVar, cast
8
+ from typing import TYPE_CHECKING, Any, TypeVar, cast
8
9
 
9
10
  from .declarations import *
10
11
  from .pretty import *
@@ -13,14 +14,14 @@ from .thunk import *
13
14
  from .type_constraint_solver import TypeConstraintError
14
15
 
15
16
  if TYPE_CHECKING:
16
- from collections.abc import Callable, Generator
17
+ from collections.abc import Generator
17
18
 
18
19
  from .egraph import BaseExpr
19
20
  from .type_constraint_solver import TypeConstraintSolver
20
21
 
21
- __all__ = ["ConvertError", "convert", "convert_to_same_type", "converter", "resolve_literal"]
22
+ __all__ = ["ConvertError", "convert", "converter", "get_type_args"]
22
23
  # Mapping from (source type, target type) to and function which takes in the runtimes values of the source and return the target
23
- CONVERSIONS: dict[tuple[type | JustTypeRef, JustTypeRef], tuple[int, Callable]] = {}
24
+ CONVERSIONS: dict[tuple[type | JustTypeRef, JustTypeRef], tuple[int, Callable[[Any], RuntimeExpr]]] = {}
24
25
  # Global declerations to store all convertable types so we can query if they have certain methods or not
25
26
  _CONVERSION_DECLS = Declarations.create()
26
27
  # Defer a list of declerations to be added to the global declerations, so that we can not trigger them procesing
@@ -28,7 +29,7 @@ _CONVERSION_DECLS = Declarations.create()
28
29
  _TO_PROCESS_DECLS: list[DeclerationsLike] = []
29
30
 
30
31
 
31
- def _retrieve_conversion_decls() -> Declarations:
32
+ def retrieve_conversion_decls() -> Declarations:
32
33
  _CONVERSION_DECLS.update(*_TO_PROCESS_DECLS)
33
34
  _TO_PROCESS_DECLS.clear()
34
35
  return _CONVERSION_DECLS
@@ -49,10 +50,10 @@ def converter(from_type: type[T], to_type: type[V], fn: Callable[[T], V], cost:
49
50
  to_type_name = process_tp(to_type)
50
51
  if not isinstance(to_type_name, JustTypeRef):
51
52
  raise TypeError(f"Expected return type to be a egglog type, got {to_type_name}")
52
- _register_converter(process_tp(from_type), to_type_name, fn, cost)
53
+ _register_converter(process_tp(from_type), to_type_name, cast("Callable[[Any], RuntimeExpr]", fn), cost)
53
54
 
54
55
 
55
- def _register_converter(a: type | JustTypeRef, b: JustTypeRef, a_b: Callable, cost: int) -> None:
56
+ def _register_converter(a: type | JustTypeRef, b: JustTypeRef, a_b: Callable[[Any], RuntimeExpr], cost: int) -> None:
56
57
  """
57
58
  Registers a converter from some type to an egglog type, if not already registered.
58
59
 
@@ -97,15 +98,15 @@ class _ComposedConverter:
97
98
  We use the dataclass instead of the lambda to make it easier to debug.
98
99
  """
99
100
 
100
- a_b: Callable
101
- b_c: Callable
101
+ a_b: Callable[[Any], RuntimeExpr]
102
+ b_c: Callable[[Any], RuntimeExpr]
102
103
  b_args: tuple[JustTypeRef, ...]
103
104
 
104
- def __call__(self, x: object) -> object:
105
+ def __call__(self, x: Any) -> RuntimeExpr:
105
106
  # if we have A -> B and B[C] -> D then we should use (C,) as the type args
106
107
  # when converting from A -> B
107
108
  if self.b_args:
108
- with with_type_args(self.b_args, _retrieve_conversion_decls):
109
+ with with_type_args(self.b_args, retrieve_conversion_decls):
109
110
  first_res = self.a_b(x)
110
111
  else:
111
112
  first_res = self.a_b(x)
@@ -142,36 +143,53 @@ def process_tp(tp: type | RuntimeClass) -> JustTypeRef | type:
142
143
  return tp
143
144
 
144
145
 
145
- def min_convertable_tp(a: object, b: object, name: str) -> JustTypeRef:
146
+ def min_binary_conversion(
147
+ method_name: str, lhs: type | JustTypeRef, rhs: type | JustTypeRef
148
+ ) -> tuple[Callable[[Any], RuntimeExpr], Callable[[Any], RuntimeExpr]] | None:
146
149
  """
147
- Returns the minimum convertable type between a and b, that has a method `name`, raising a ConvertError if no such type exists.
150
+ Given a binary method and two starting types for the LHS and RHS, return a pair of callable which will convert
151
+ the LHS and RHS to appropriate types which support this method. If no such conversion is possible, return None.
152
+
153
+ It should return the types which minimize the total conversion cost. If one of the types is a Python type, then
154
+ both of them can be converted. However, if both are egglog types, then only one of them can be converted.
155
+ """
156
+ decls = retrieve_conversion_decls()
157
+ # tuple of (cost, convert lhs, convert rhs)
158
+ best_method: tuple[int, Callable[[Any], RuntimeExpr], Callable[[Any], RuntimeExpr]] | None = None
159
+
160
+ possible_lhs = _all_conversions_from(lhs) if isinstance(lhs, type) else [(0, lhs, identity)]
161
+ possible_rhs = _all_conversions_from(rhs) if isinstance(rhs, type) else [(0, rhs, identity)]
162
+ for lhs_cost, lhs_converted_type, lhs_convert in possible_lhs:
163
+ # Start by checking if we have a LHS that matches exactly and a RHS which can be converted
164
+ if (desired_other_type := decls.check_binary_method_with_self_type(method_name, lhs_converted_type)) and (
165
+ converter := CONVERSIONS.get((rhs, desired_other_type))
166
+ ):
167
+ cost = lhs_cost + converter[0]
168
+ if best_method is None or best_method[0] > cost:
169
+ best_method = (cost, lhs_convert, converter[1])
170
+
171
+ for rhs_cost, rhs_converted_type, rhs_convert in possible_rhs:
172
+ # Next see if it's possible to convert the LHS and keep the RHS as is
173
+ for desired_self_type in decls.check_binary_method_with_other_type(method_name, rhs_converted_type):
174
+ if converter := CONVERSIONS.get((lhs, desired_self_type)):
175
+ cost = rhs_cost + converter[0]
176
+ if best_method is None or best_method[0] > cost:
177
+ best_method = (cost, converter[1], rhs_convert)
178
+ if best_method is None:
179
+ return None
180
+ return best_method[1], best_method[2]
181
+
182
+
183
+ def _all_conversions_from(tp: JustTypeRef | type) -> list[tuple[int, JustTypeRef, Callable[[Any], RuntimeExpr]]]:
184
+ """
185
+ Get all conversions from a type to other types.
186
+
187
+ Returns a list of tuples of (cost, target type, conversion function).
148
188
  """
149
- decls = _retrieve_conversion_decls()
150
- a_tp = _get_tp(a)
151
- b_tp = _get_tp(b)
152
- # Make sure at least one of the types has the method, to avoid issue with == upcasting improperly
153
- if not (
154
- (isinstance(a_tp, JustTypeRef) and decls.has_method(a_tp.name, name))
155
- or (isinstance(b_tp, JustTypeRef) and decls.has_method(b_tp.name, name))
156
- ):
157
- raise ConvertError(f"Neither {a_tp} nor {b_tp} has method {name}")
158
- a_converts_to = {
159
- to: c for ((from_, to), (c, _)) in CONVERSIONS.items() if from_ == a_tp and decls.has_method(to.name, name)
160
- }
161
- b_converts_to = {
162
- to: c for ((from_, to), (c, _)) in CONVERSIONS.items() if from_ == b_tp and decls.has_method(to.name, name)
163
- }
164
- if isinstance(a_tp, JustTypeRef) and decls.has_method(a_tp.name, name):
165
- a_converts_to[a_tp] = 0
166
- if isinstance(b_tp, JustTypeRef) and decls.has_method(b_tp.name, name):
167
- b_converts_to[b_tp] = 0
168
- common = set(a_converts_to) & set(b_converts_to)
169
- if not common:
170
- raise ConvertError(f"Cannot convert {a_tp} and {b_tp} to a common type")
171
- return min(common, key=lambda tp: a_converts_to[tp] + b_converts_to[tp])
172
-
173
-
174
- def identity(x: object) -> object:
189
+ return [(cost, target, fn) for (source, target), (cost, fn) in CONVERSIONS.items() if source == tp]
190
+
191
+
192
+ def identity(x: Any) -> Any:
175
193
  return x
176
194
 
177
195
 
@@ -197,7 +215,7 @@ def with_type_args(args: tuple[JustTypeRef, ...], decls: Callable[[], Declaratio
197
215
  def resolve_literal(
198
216
  tp: TypeOrVarRef,
199
217
  arg: object,
200
- decls: Callable[[], Declarations] = _retrieve_conversion_decls,
218
+ decls: Callable[[], Declarations] = retrieve_conversion_decls,
201
219
  tcs: TypeConstraintSolver | None = None,
202
220
  cls_name: str | None = None,
203
221
  ) -> RuntimeExpr:
@@ -208,12 +226,12 @@ def resolve_literal(
208
226
 
209
227
  If it cannot be resolved, we assume that the value passed in will resolve it.
210
228
  """
211
- arg_type = _get_tp(arg)
229
+ arg_type = resolve_type(arg)
212
230
 
213
231
  # If we have any type variables, dont bother trying to resolve the literal, just return the arg
214
232
  try:
215
233
  tp_just = tp.to_just()
216
- except NotImplementedError:
234
+ except TypeVarError:
217
235
  # If this is a generic arg but passed in a non runtime expression, try to resolve the generic
218
236
  # args first based on the existing type constraint solver
219
237
  if tcs:
@@ -258,7 +276,7 @@ def _debug_print_converers():
258
276
  source_to_targets[source].append(target)
259
277
 
260
278
 
261
- def _get_tp(x: object) -> JustTypeRef | type:
279
+ def resolve_type(x: object) -> JustTypeRef | type:
262
280
  if isinstance(x, RuntimeExpr):
263
281
  return x.__egg_typed_expr__.tp
264
282
  tp = type(x)
egglog/declarations.py CHANGED
@@ -51,6 +51,7 @@ __all__ = [
51
51
  "InitRef",
52
52
  "JustTypeRef",
53
53
  "LetDecl",
54
+ "LetRefDecl",
54
55
  "LitDecl",
55
56
  "LitType",
56
57
  "MethodRef",
@@ -72,16 +73,18 @@ __all__ = [
72
73
  "SpecialFunctions",
73
74
  "TypeOrVarRef",
74
75
  "TypeRefWithVars",
76
+ "TypeVarError",
75
77
  "TypedExprDecl",
78
+ "UnboundVarDecl",
76
79
  "UnionDecl",
77
80
  "UnnamedFunctionRef",
78
- "VarDecl",
81
+ "collect_unbound_vars",
79
82
  "replace_typed_expr",
80
83
  "upcast_declerations",
81
84
  ]
82
85
 
83
86
 
84
- @dataclass
87
+ @dataclass(match_args=False)
85
88
  class DelayedDeclerations:
86
89
  __egg_decls_thunk__: Callable[[], Declarations] = field(repr=False)
87
90
 
@@ -93,7 +96,7 @@ class DelayedDeclerations:
93
96
  # Catch attribute error, so that it isn't bubbled up as a missing attribute and fallbacks on `__getattr__`
94
97
  # instead raise explicitly
95
98
  except AttributeError as err:
96
- msg = f"Cannot resolve declarations for {self}"
99
+ msg = f"Cannot resolve declarations for {self}: {err}"
97
100
  raise RuntimeError(msg) from err
98
101
 
99
102
 
@@ -223,14 +226,46 @@ class Declarations:
223
226
  case _:
224
227
  assert_never(ref)
225
228
 
226
- def has_method(self, class_name: str, method_name: str) -> bool | None:
229
+ def check_binary_method_with_types(self, method_name: str, self_type: JustTypeRef, other_type: JustTypeRef) -> bool:
230
+ """
231
+ Checks if the class has a binary method compatible with the given types.
232
+ """
233
+ vars: dict[ClassTypeVarRef, JustTypeRef] = {}
234
+ if callable_decl := self._classes[self_type.name].methods.get(method_name):
235
+ match callable_decl.signature:
236
+ case FunctionSignature((self_arg_type, other_arg_type)) if self_arg_type.matches_just(
237
+ vars, self_type
238
+ ) and other_arg_type.matches_just(vars, other_type):
239
+ return True
240
+ return False
241
+
242
+ def check_binary_method_with_self_type(self, method_name: str, self_type: JustTypeRef) -> JustTypeRef | None:
227
243
  """
228
- Returns whether the given class has the given method, or None if we cant find the class.
244
+ Checks if the class has a binary method with the given name and self type. Returns the other type if it exists.
229
245
  """
230
- if class_name in self._classes:
231
- return method_name in self._classes[class_name].methods
246
+ vars: dict[ClassTypeVarRef, JustTypeRef] = {}
247
+ class_decl = self._classes.get(self_type.name)
248
+ if class_decl is None:
249
+ return None
250
+ if callable_decl := class_decl.methods.get(method_name):
251
+ match callable_decl.signature:
252
+ case FunctionSignature((self_arg_type, other_arg_type)) if self_arg_type.matches_just(vars, self_type):
253
+ return other_arg_type.to_just(vars)
232
254
  return None
233
255
 
256
+ def check_binary_method_with_other_type(self, method_name: str, other_type: JustTypeRef) -> Iterable[JustTypeRef]:
257
+ """
258
+ Returns the types which are compatible with the given binary method name and other type.
259
+ """
260
+ for class_decl in self._classes.values():
261
+ vars: dict[ClassTypeVarRef, JustTypeRef] = {}
262
+ if callable_decl := class_decl.methods.get(method_name):
263
+ match callable_decl.signature:
264
+ case FunctionSignature((self_arg_type, other_arg_type)) if other_arg_type.matches_just(
265
+ vars, other_type
266
+ ):
267
+ yield self_arg_type.to_just(vars)
268
+
234
269
  def get_class_decl(self, name: str) -> ClassDecl:
235
270
  return self._classes[name]
236
271
 
@@ -254,6 +289,7 @@ class ClassDecl:
254
289
  methods: dict[str, FunctionDecl | ConstructorDecl] = field(default_factory=dict)
255
290
  properties: dict[str, FunctionDecl | ConstructorDecl] = field(default_factory=dict)
256
291
  preserved_methods: dict[str, Callable] = field(default_factory=dict)
292
+ match_args: tuple[str, ...] = field(default=())
257
293
 
258
294
 
259
295
  @dataclass(frozen=True)
@@ -298,6 +334,10 @@ class JustTypeRef:
298
334
  _RESOLVED_TYPEVARS: dict[ClassTypeVarRef, TypeVar] = {}
299
335
 
300
336
 
337
+ class TypeVarError(RuntimeError):
338
+ """Error when trying to resolve a type variable that doesn't exist."""
339
+
340
+
301
341
  @dataclass(frozen=True)
302
342
  class ClassTypeVarRef:
303
343
  """
@@ -307,9 +347,10 @@ class ClassTypeVarRef:
307
347
  name: str
308
348
  module: str
309
349
 
310
- def to_just(self) -> JustTypeRef:
311
- msg = f"{self}: egglog does not support generic classes yet."
312
- raise NotImplementedError(msg)
350
+ def to_just(self, vars: dict[ClassTypeVarRef, JustTypeRef] | None = None) -> JustTypeRef:
351
+ if vars is None or self not in vars:
352
+ raise TypeVarError(f"Cannot convert type variable {self} to concrete type without variable bindings")
353
+ return vars[self]
313
354
 
314
355
  def __str__(self) -> str:
315
356
  return str(self.to_type_var())
@@ -323,20 +364,39 @@ class ClassTypeVarRef:
323
364
  def to_type_var(self) -> TypeVar:
324
365
  return _RESOLVED_TYPEVARS[self]
325
366
 
367
+ def matches_just(self, vars: dict[ClassTypeVarRef, JustTypeRef], other: JustTypeRef) -> bool:
368
+ """
369
+ Checks if this type variable matches the given JustTypeRef, including type variables.
370
+ """
371
+ if self in vars:
372
+ return vars[self] == other
373
+ vars[self] = other
374
+ return True
375
+
326
376
 
327
377
  @dataclass(frozen=True)
328
378
  class TypeRefWithVars:
329
379
  name: str
330
380
  args: tuple[TypeOrVarRef, ...] = ()
331
381
 
332
- def to_just(self) -> JustTypeRef:
333
- return JustTypeRef(self.name, tuple(a.to_just() for a in self.args))
382
+ def to_just(self, vars: dict[ClassTypeVarRef, JustTypeRef] | None = None) -> JustTypeRef:
383
+ return JustTypeRef(self.name, tuple(a.to_just(vars) for a in self.args))
334
384
 
335
385
  def __str__(self) -> str:
336
386
  if self.args:
337
387
  return f"{self.name}[{', '.join(str(a) for a in self.args)}]"
338
388
  return self.name
339
389
 
390
+ def matches_just(self, vars: dict[ClassTypeVarRef, JustTypeRef], other: JustTypeRef) -> bool:
391
+ """
392
+ Checks if this type reference matches the given JustTypeRef, including type variables.
393
+ """
394
+ return (
395
+ self.name == other.name
396
+ and len(self.args) == len(other.args)
397
+ and all(a.matches_just(vars, b) for a, b in zip(self.args, other.args, strict=True))
398
+ )
399
+
340
400
 
341
401
  TypeOrVarRef: TypeAlias = ClassTypeVarRef | TypeRefWithVars
342
402
 
@@ -361,7 +421,7 @@ class UnnamedFunctionRef:
361
421
  arg_names = []
362
422
  for a in self.args:
363
423
  arg_types.append(a.tp.to_var())
364
- assert isinstance(a.expr, VarDecl)
424
+ assert isinstance(a.expr, UnboundVarDecl)
365
425
  arg_names.append(a.expr.name)
366
426
  return FunctionSignature(
367
427
  arg_types=tuple(arg_types),
@@ -514,10 +574,14 @@ CallableDecl: TypeAlias = RelationDecl | ConstantDecl | FunctionDecl | Construct
514
574
 
515
575
 
516
576
  @dataclass(frozen=True)
517
- class VarDecl:
577
+ class UnboundVarDecl:
578
+ name: str
579
+ egg_name: str | None = None
580
+
581
+
582
+ @dataclass(frozen=True)
583
+ class LetRefDecl:
518
584
  name: str
519
- # Differentiate between let bound vars and vars created in rules so that they won't shadow in egglog, by adding a prefix
520
- is_let: bool
521
585
 
522
586
 
523
587
  @dataclass(frozen=True)
@@ -628,7 +692,7 @@ class PartialCallDecl:
628
692
  call: CallDecl
629
693
 
630
694
 
631
- ExprDecl: TypeAlias = VarDecl | LitDecl | CallDecl | PyObjectDecl | PartialCallDecl
695
+ ExprDecl: TypeAlias = UnboundVarDecl | LetRefDecl | LitDecl | CallDecl | PyObjectDecl | PartialCallDecl
632
696
 
633
697
 
634
698
  @dataclass(frozen=True)
@@ -678,6 +742,28 @@ def replace_typed_expr(typed_expr: TypedExprDecl, replacements: Mapping[TypedExp
678
742
  return _inner(typed_expr)
679
743
 
680
744
 
745
+ def collect_unbound_vars(typed_expr: TypedExprDecl) -> set[TypedExprDecl]:
746
+ """
747
+ Returns the set of all unbound vars
748
+ """
749
+ seen = set[TypedExprDecl]()
750
+ unbound_vars = set[TypedExprDecl]()
751
+
752
+ def visit(typed_expr: TypedExprDecl) -> None:
753
+ if typed_expr in seen:
754
+ return
755
+ seen.add(typed_expr)
756
+ match typed_expr.expr:
757
+ case CallDecl(_, args) | PartialCallDecl(CallDecl(_, args)):
758
+ for arg in args:
759
+ visit(arg)
760
+ case UnboundVarDecl(_):
761
+ unbound_vars.add(typed_expr)
762
+
763
+ visit(typed_expr)
764
+ return unbound_vars
765
+
766
+
681
767
  ##
682
768
  # Schedules
683
769
  ##
egglog/deconstruct.py ADDED
@@ -0,0 +1,173 @@
1
+ """
2
+ Utility functions to deconstruct expressions in Python.
3
+ """
4
+
5
+ from __future__ import annotations
6
+
7
+ from collections.abc import Callable
8
+ from functools import partial
9
+ from typing import TYPE_CHECKING, TypeVar, overload
10
+
11
+ from typing_extensions import TypeVarTuple, Unpack
12
+
13
+ from .declarations import *
14
+ from .egraph import BaseExpr
15
+ from .runtime import *
16
+ from .thunk import *
17
+
18
+ if TYPE_CHECKING:
19
+ from .builtins import Bool, PyObject, String, UnstableFn, f64, i64
20
+
21
+
22
+ T = TypeVar("T", bound=BaseExpr)
23
+ TS = TypeVarTuple("TS", default=Unpack[tuple[BaseExpr, ...]])
24
+
25
+ __all__ = ["get_callable_args", "get_callable_fn", "get_let_name", "get_literal_value", "get_var_name"]
26
+
27
+
28
+ @overload
29
+ def get_literal_value(x: String) -> str | None: ...
30
+
31
+
32
+ @overload
33
+ def get_literal_value(x: Bool) -> bool | None: ...
34
+
35
+
36
+ @overload
37
+ def get_literal_value(x: i64) -> int | None: ...
38
+
39
+
40
+ @overload
41
+ def get_literal_value(x: f64) -> float | None: ...
42
+
43
+
44
+ @overload
45
+ def get_literal_value(x: PyObject) -> object: ...
46
+
47
+
48
+ @overload
49
+ def get_literal_value(x: UnstableFn[T, Unpack[TS]]) -> Callable[[Unpack[TS]], T] | None: ...
50
+
51
+
52
+ def get_literal_value(x: String | Bool | i64 | f64 | PyObject | UnstableFn) -> object:
53
+ """
54
+ Returns the literal value of an expression if it is a literal.
55
+ If it is not a literal, returns None.
56
+ """
57
+ if not isinstance(x, RuntimeExpr):
58
+ raise TypeError(f"Expected Expression, got {type(x).__name__}")
59
+ match x.__egg_typed_expr__.expr:
60
+ case LitDecl(v):
61
+ return v
62
+ case PyObjectDecl(obj):
63
+ return obj
64
+ case PartialCallDecl(call):
65
+ fn, args = _deconstruct_call_decl(x.__egg_decls_thunk__, call)
66
+ if not args:
67
+ return fn
68
+ return partial(fn, *args)
69
+ return None
70
+
71
+
72
+ def get_let_name(x: BaseExpr) -> str | None:
73
+ """
74
+ Check if the expression is a `let` expression and return the name of the variable.
75
+ If it is not a `let` expression, return None.
76
+ """
77
+ if not isinstance(x, RuntimeExpr):
78
+ raise TypeError(f"Expected Expression, got {type(x).__name__}")
79
+ match x.__egg_typed_expr__.expr:
80
+ case LetRefDecl(name):
81
+ return name
82
+ return None
83
+
84
+
85
+ def get_var_name(x: BaseExpr) -> str | None:
86
+ """
87
+ Check if the expression is a variable and return its name.
88
+ If it is not a variable, return None.
89
+ """
90
+ if not isinstance(x, RuntimeExpr):
91
+ raise TypeError(f"Expected Expression, got {type(x).__name__}")
92
+ match x.__egg_typed_expr__.expr:
93
+ case UnboundVarDecl(name, _egg_name):
94
+ return name
95
+ return None
96
+
97
+
98
+ def get_callable_fn(x: T) -> Callable[..., T] | None:
99
+ """
100
+ Gets the function of an expression if it is a call expression.
101
+ If it is not a call expression (a property, a primitive value, constants, classvars, a let value), return None.
102
+ For those values, you can check them by comparing them directly with equality or for primitives calling `.eval()`
103
+ to return the Python value.
104
+ """
105
+ if not isinstance(x, RuntimeExpr):
106
+ raise TypeError(f"Expected Expression, got {type(x).__name__}")
107
+ match x.__egg_typed_expr__.expr:
108
+ case CallDecl() as call:
109
+ fn, _ = _deconstruct_call_decl(x.__egg_decls_thunk__, call)
110
+ return fn
111
+ return None
112
+
113
+
114
+ @overload
115
+ def get_callable_args(x: T, fn: None = ...) -> tuple[BaseExpr, ...]: ...
116
+
117
+
118
+ @overload
119
+ def get_callable_args(x: T, fn: Callable[[Unpack[TS]], T]) -> tuple[Unpack[TS]] | None: ...
120
+
121
+
122
+ def get_callable_args(x: T, fn: Callable[[Unpack[TS]], T] | None = None) -> tuple[Unpack[TS]] | None:
123
+ """
124
+ Gets all the arguments of an expression.
125
+ If a function is provided, it will only return the arguments if the expression is a call
126
+ to that function.
127
+
128
+ Note that recursively calling the arguments is the safe way to walk the expression tree.
129
+ """
130
+ if not isinstance(x, RuntimeExpr):
131
+ raise TypeError(f"Expected Expression, got {type(x).__name__}")
132
+ match x.__egg_typed_expr__.expr:
133
+ case CallDecl() as call:
134
+ actual_fn, args = _deconstruct_call_decl(x.__egg_decls_thunk__, call)
135
+ if fn is None:
136
+ return args
137
+ # Compare functions and classes without considering bound type parameters, so that you can pass
138
+ # in a binding like Vec[i64] and match Vec[i64](...) or Vec(...) calls.
139
+ if (
140
+ isinstance(actual_fn, RuntimeFunction)
141
+ and isinstance(fn, RuntimeFunction)
142
+ and actual_fn.__egg_ref__ == fn.__egg_ref__
143
+ ):
144
+ return args
145
+ if (
146
+ isinstance(actual_fn, RuntimeClass)
147
+ and isinstance(fn, RuntimeClass)
148
+ and actual_fn.__egg_tp__.name == fn.__egg_tp__.name
149
+ ):
150
+ return args
151
+ return None
152
+
153
+
154
+ def _deconstruct_call_decl(
155
+ decls_thunk: Callable[[], Declarations], call: CallDecl
156
+ ) -> tuple[Callable, tuple[object, ...]]:
157
+ """
158
+ Deconstructs a CallDecl into a runtime callable and its arguments.
159
+ """
160
+ args = call.args
161
+ arg_exprs = tuple(RuntimeExpr(decls_thunk, Thunk.value(a)) for a in args)
162
+ if isinstance(call.callable, InitRef):
163
+ return RuntimeClass(
164
+ decls_thunk,
165
+ TypeRefWithVars(call.callable.class_name, tuple(tp.to_var() for tp in (call.bound_tp_params or []))),
166
+ ), arg_exprs
167
+ egg_bound = (
168
+ JustTypeRef(call.callable.class_name, call.bound_tp_params or ())
169
+ if isinstance(call.callable, ClassMethodRef)
170
+ else None
171
+ )
172
+
173
+ return RuntimeFunction(decls_thunk, Thunk.value(call.callable), egg_bound), arg_exprs