egglog 7.0.0__cp312-none-win_amd64.whl → 7.1.0__cp312-none-win_amd64.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of egglog might be problematic. (The original page links to more details.)

egglog/egraph_state.py CHANGED
@@ -87,17 +87,26 @@ class EGraphState:
87
87
  """
88
88
  Registers a ruleset if it's not already registered.
89
89
  """
90
- if name not in self.rulesets:
91
- if name:
92
- self.egraph.run_program(bindings.AddRuleset(name))
93
- rules = self.rulesets[name] = set()
94
- else:
95
- rules = self.rulesets[name]
96
- for rule in self.__egg_decls__._rulesets[name].rules:
97
- if rule in rules:
98
- continue
99
- self.egraph.run_program(self.command_to_egg(rule, name))
100
- rules.add(rule)
90
+ match self.__egg_decls__._rulesets[name]:
91
+ case RulesetDecl(rules):
92
+ if name not in self.rulesets:
93
+ if name:
94
+ self.egraph.run_program(bindings.AddRuleset(name))
95
+ added_rules = self.rulesets[name] = set()
96
+ else:
97
+ added_rules = self.rulesets[name]
98
+ for rule in rules:
99
+ if rule in added_rules:
100
+ continue
101
+ self.egraph.run_program(self.command_to_egg(rule, name))
102
+ added_rules.add(rule)
103
+ case CombinedRulesetDecl(rulesets):
104
+ if name in self.rulesets:
105
+ return
106
+ self.rulesets[name] = set()
107
+ for ruleset in rulesets:
108
+ self.ruleset_to_egg(ruleset)
109
+ self.egraph.run_program(bindings.UnstableCombinedRuleset(name, list(rulesets)))
101
110
 
102
111
  def command_to_egg(self, cmd: CommandDecl, ruleset: str) -> bindings._Command:
103
112
  match cmd:
@@ -184,11 +193,13 @@ class EGraphState:
184
193
  )
185
194
  case FunctionDecl():
186
195
  if not decl.builtin:
196
+ signature = decl.signature
197
+ assert isinstance(signature, FunctionSignature), "Cannot turn special function to egg"
187
198
  egg_fn_decl = bindings.FunctionDecl(
188
199
  egg_name,
189
200
  bindings.Schema(
190
- [self.type_ref_to_egg(a.to_just()) for a in decl.arg_types],
191
- self.type_ref_to_egg(decl.semantic_return_type.to_just()),
201
+ [self.type_ref_to_egg(a.to_just()) for a in signature.arg_types],
202
+ self.type_ref_to_egg(signature.semantic_return_type.to_just()),
192
203
  ),
193
204
  self.expr_to_egg(decl.default) if decl.default else None,
194
205
  self.expr_to_egg(decl.merge) if decl.merge else None,
@@ -212,19 +223,22 @@ class EGraphState:
212
223
  decl = self.__egg_decls__._classes[ref.name]
213
224
  self.type_ref_to_egg_sort[ref] = egg_name = decl.egg_name or _generate_type_egg_name(ref)
214
225
  if not decl.builtin or ref.args:
215
- self.egraph.run_program(
216
- bindings.Sort(
217
- egg_name,
218
- (
219
- (
220
- self.type_ref_to_egg(JustTypeRef(ref.name)),
221
- [bindings.Var(self.type_ref_to_egg(a)) for a in ref.args],
222
- )
223
- if ref.args
224
- else None
225
- ),
226
- )
227
- )
226
+ if ref.args:
227
+ if ref.name == "UnstableFn":
228
+ # UnstableFn is a special case, where the rest of args are collected into a call
229
+ type_args: list[bindings._Expr] = [
230
+ bindings.Call(
231
+ self.type_ref_to_egg(ref.args[1]),
232
+ [bindings.Var(self.type_ref_to_egg(a)) for a in ref.args[2:]],
233
+ ),
234
+ bindings.Var(self.type_ref_to_egg(ref.args[0])),
235
+ ]
236
+ else:
237
+ type_args = [bindings.Var(self.type_ref_to_egg(a)) for a in ref.args]
238
+ args = (self.type_ref_to_egg(JustTypeRef(ref.name)), type_args)
239
+ else:
240
+ args = None
241
+ self.egraph.run_program(bindings.Sort(egg_name, args))
228
242
  # For builtin classes, let's also make sure we have the mapping of all egg fn names for class methods, because
229
243
  # these can be created even without adding them to the e-graph, like `vec-empty` which can be extracted
230
244
  # even if you never use that function.
@@ -292,6 +306,9 @@ class EGraphState:
292
306
  res = bindings.Call(egg_fn, egg_args)
293
307
  case PyObjectDecl(value):
294
308
  res = GLOBAL_PY_OBJECT_SORT.store(value)
309
+ case PartialCallDecl(call_decl):
310
+ egg_fn_call = self.expr_to_egg(call_decl)
311
+ res = bindings.Call("unstable-fn", [bindings.Lit(bindings.String(egg_fn_call.name)), *egg_fn_call.args])
295
312
  case _:
296
313
  assert_never(expr_decl.expr)
297
314
 
@@ -371,24 +388,45 @@ class FromEggState:
371
388
  if term.name == "py-object":
372
389
  call = bindings.termdag_term_to_expr(self.termdag, term)
373
390
  expr_decl = PyObjectDecl(self.state.egraph.eval_py_object(call))
391
+ if term.name == "unstable-fn":
392
+ # Get function name
393
+ fn_term, *arg_terms = term.args
394
+ fn_value = self.resolve_term(fn_term, JustTypeRef("String"))
395
+ assert isinstance(fn_value.expr, LitDecl)
396
+ fn_name = fn_value.expr.value
397
+ assert isinstance(fn_name, str)
398
+
399
+ # Resolve what types the partiallied applied args are
400
+ assert tp.name == "UnstableFn"
401
+ call_decl = self.from_call(tp.args[0], bindings.TermApp(fn_name, arg_terms))
402
+ expr_decl = PartialCallDecl(call_decl)
374
403
  else:
375
404
  expr_decl = self.from_call(tp, term)
376
405
  else:
377
406
  assert_never(term)
378
407
  return TypedExprDecl(tp, expr_decl)
379
408
 
380
- def from_call(self, tp: JustTypeRef, term: bindings.TermApp) -> CallDecl:
409
+ def from_call(
410
+ self,
411
+ tp: JustTypeRef,
412
+ term: bindings.TermApp, # additional_arg_tps: tuple[JustTypeRef, ...]
413
+ ) -> CallDecl:
381
414
  """
382
415
  Convert a call to a CallDecl.
383
416
 
384
417
  There could be Python call refs which match the call, so we need to find the correct one.
418
+
419
+ The additional_arg_tps are known types for arguments that come after the term args, used to infer types
420
+ for partially applied functions, where we know the types of the later args, but not of the earlier ones where
421
+ we have values for.
385
422
  """
386
423
  # Find the first callable ref that matches the call
387
424
  for callable_ref in self.state.egg_fn_to_callable_refs[term.name]:
388
425
  # If this is a classmethod, we might need the type params that were bound for this type
389
426
  # This could be multiple types if the classmethod is ambiguous, like map create.
390
427
  possible_types: Iterable[JustTypeRef | None]
391
- fn_decl = self.decls.get_callable_decl(callable_ref).to_function_decl()
428
+ signature = self.decls.get_callable_decl(callable_ref).to_function_decl().signature
429
+ assert isinstance(signature, FunctionSignature)
392
430
  if isinstance(callable_ref, ClassMethodRef):
393
431
  possible_types = self.state._get_possible_types(callable_ref.class_name)
394
432
  cls_name = callable_ref.class_name
@@ -402,16 +440,17 @@ class FromEggState:
402
440
 
403
441
  try:
404
442
  arg_types, bound_tp_params = tcs.infer_arg_types(
405
- fn_decl.arg_types, fn_decl.semantic_return_type, fn_decl.var_arg_type, tp, cls_name
443
+ signature.arg_types, signature.semantic_return_type, signature.var_arg_type, tp, cls_name
406
444
  )
407
445
  except TypeConstraintError:
408
446
  continue
409
- args: list[TypedExprDecl] = []
410
- for a, tp in zip(term.args, arg_types, strict=False):
411
- try:
412
- res = self.cache[a]
413
- except KeyError:
414
- res = self.cache[a] = self.from_expr(tp, self.termdag.nodes[a])
415
- args.append(res)
416
- return CallDecl(callable_ref, tuple(args), bound_tp_params)
447
+ args = tuple(self.resolve_term(a, tp) for a, tp in zip(term.args, arg_types, strict=False))
448
+ return CallDecl(callable_ref, args, bound_tp_params)
417
449
  raise ValueError(f"Could not find callable ref for call {term}")
450
+
451
+ def resolve_term(self, term_id: int, tp: JustTypeRef) -> TypedExprDecl:
452
+ try:
453
+ return self.cache[term_id]
454
+ except KeyError:
455
+ res = self.cache[term_id] = self.from_expr(tp, self.termdag.nodes[term_id])
456
+ return res
egglog/exp/array_api.py CHANGED
@@ -18,7 +18,7 @@ from egglog.runtime import RuntimeExpr
18
18
  from .program_gen import *
19
19
 
20
20
  if TYPE_CHECKING:
21
- from collections.abc import Iterator
21
+ from collections.abc import Callable, Iterator
22
22
  from types import ModuleType
23
23
 
24
24
  # Pretend that exprs are numbers b/c sklearn does isinstance checks
@@ -257,7 +257,7 @@ class TupleInt(Expr):
257
257
 
258
258
  def __getitem__(self, i: Int) -> Int: ...
259
259
 
260
- def product(self) -> Int: ...
260
+ def fold(self, init: Int, f: Callable[[Int, Int], Int]) -> Int: ...
261
261
 
262
262
 
263
263
  converter(
@@ -272,7 +272,7 @@ converter(
272
272
 
273
273
 
274
274
  @array_api_ruleset.register
275
- def _tuple_int(ti: TupleInt, ti2: TupleInt, i: Int, i2: Int, k: i64):
275
+ def _tuple_int(ti: TupleInt, ti2: TupleInt, i: Int, i2: Int, k: i64, f: Callable[[Int, Int], Int]):
276
276
  return [
277
277
  rewrite(ti + TupleInt.EMPTY).to(ti),
278
278
  rewrite(TupleInt(i).length()).to(Int(1)),
@@ -281,10 +281,10 @@ def _tuple_int(ti: TupleInt, ti2: TupleInt, i: Int, i2: Int, k: i64):
281
281
  rewrite((TupleInt(i) + ti)[Int(0)]).to(i),
282
282
  # Rule for indexing > 0
283
283
  rule(eq(i).to((TupleInt(i2) + ti)[Int(k)]), k > 0).then(union(i).with_(ti[Int(k - 1)])),
284
- # Product
285
- rewrite(TupleInt(i).product()).to(i),
286
- rewrite((TupleInt(i) + ti).product()).to(i * ti.product()),
287
- rewrite(TupleInt.EMPTY.product()).to(Int(1)),
284
+ # fold
285
+ rewrite(TupleInt.EMPTY.fold(i, f)).to(i),
286
+ rewrite(TupleInt(i2).fold(i, f)).to(f(i, i2)),
287
+ rewrite((TupleInt(i2) + ti).fold(i, f)).to(ti.fold(f(i, i2), f)),
288
288
  ]
289
289
 
290
290
 
@@ -1346,7 +1346,7 @@ def _unique(xs: TupleValue, a: NDArray, shape: TupleInt, copy: OptionalBool):
1346
1346
 
1347
1347
  @array_api_ruleset.register
1348
1348
  def _size(x: NDArray):
1349
- yield rewrite(x.size).to(x.shape.product())
1349
+ yield rewrite(x.size).to(x.shape.fold(Int(1), Int.__mul__))
1350
1350
 
1351
1351
 
1352
1352
  @overload
egglog/pretty.py CHANGED
@@ -66,7 +66,7 @@ UNARY_METHODS = {
66
66
  "__invert__": "~",
67
67
  }
68
68
 
69
- AllDecls: TypeAlias = RulesetDecl | CommandDecl | ActionDecl | FactDecl | ExprDecl | ScheduleDecl
69
+ AllDecls: TypeAlias = RulesetDecl | CombinedRulesetDecl | CommandDecl | ActionDecl | FactDecl | ExprDecl | ScheduleDecl
70
70
 
71
71
 
72
72
  def pretty_decl(
@@ -106,7 +106,7 @@ def pretty_callable_ref(
106
106
  """
107
107
  # Pass in three dummy args, which are the max used for any operation that
108
108
  # is not a generic function call
109
- args: list[ExprDecl] = [LitDecl(ARG_STR)] * 3
109
+ args: list[ExprDecl] = [VarDecl(ARG_STR)] * 3
110
110
  if first_arg:
111
111
  args.insert(0, first_arg)
112
112
  res = PrettyContext(decls, defaultdict(lambda: 0))._call_inner(
@@ -117,6 +117,10 @@ def pretty_callable_ref(
117
117
  return res[0] if isinstance(res, tuple) else res
118
118
 
119
119
 
120
+ # TODO: Add a different pretty callable ref that doesnt fill in wholes but instead returns the function
121
+ # so that things like Math.__add__ will be represented properly
122
+
123
+
120
124
  @dataclass
121
125
  class TraverseContext:
122
126
  """
@@ -170,6 +174,10 @@ class TraverseContext:
170
174
  if until:
171
175
  for f in until:
172
176
  self(f)
177
+ case PartialCallDecl(c):
178
+ self(c)
179
+ case CombinedRulesetDecl(_):
180
+ pass
173
181
  case _:
174
182
  assert_never(decl)
175
183
 
@@ -231,6 +239,9 @@ class PrettyContext:
231
239
  return name, name
232
240
  case CallDecl(_, _, _):
233
241
  return self._call(decl, parens)
242
+ case PartialCallDecl(CallDecl(ref, typed_args, _)):
243
+ arg_strs = (_pretty_callable(ref), *(self(a.expr, parens=False, unwrap_lit=True) for a in typed_args))
244
+ return f"UnstableFn({', '.join(arg_strs)})", "fn"
234
245
  case PyObjectDecl(value):
235
246
  return repr(value) if unwrap_lit else f"PyObject({value!r})", "PyObject"
236
247
  case ActionCommandDecl(action):
@@ -267,6 +278,10 @@ class PrettyContext:
267
278
  return f"ruleset(name={ruleset_name!r})", f"ruleset_{ruleset_name}"
268
279
  args = ", ".join(map(self, rules))
269
280
  return f"ruleset({args})", "ruleset"
281
+ case CombinedRulesetDecl(rulesets):
282
+ if ruleset_name:
283
+ rulesets = (*rulesets, f"name={ruleset_name!r})")
284
+ return f"unstable_combine_rulesets({', '.join(rulesets)})", "combined_ruleset"
270
285
  case SaturateDecl(schedule):
271
286
  return f"{self(schedule, parens=True)}.saturate()", "schedule"
272
287
  case RepeatDecl(schedule, times):
@@ -302,19 +317,28 @@ class PrettyContext:
302
317
  l, r = self(args[0]), self(args[1])
303
318
  return f"ne({l}).to({r})", "Unit"
304
319
  function_decl = self.decls.get_callable_decl(ref).to_function_decl()
320
+ signature = function_decl.signature
321
+
305
322
  # Determine how many of the last arguments are defaults, by iterating from the end and comparing the arg with the default
306
323
  n_defaults = 0
307
- for arg, default in zip(
308
- reversed(args), reversed(function_decl.arg_defaults), strict=not function_decl.var_arg_type
309
- ):
310
- if arg != default:
311
- break
312
- n_defaults += 1
324
+ # Dont try counting defaults for function application
325
+ if isinstance(signature, FunctionSignature):
326
+ for arg, default in zip(
327
+ reversed(args), reversed(signature.arg_defaults), strict=not signature.var_arg_type
328
+ ):
329
+ if arg != default:
330
+ break
331
+ n_defaults += 1
313
332
  if n_defaults:
314
333
  args = args[:-n_defaults]
315
334
 
316
- tp_name = function_decl.semantic_return_type.name
317
- if function_decl.mutates:
335
+ # If this is a function application, the type is the first type arg of the function object
336
+ if signature == "fn-app":
337
+ tp_name = decl.args[0].tp.args[0].name
338
+ else:
339
+ assert isinstance(signature, FunctionSignature)
340
+ tp_name = signature.semantic_return_type.name
341
+ if isinstance(signature, FunctionSignature) and signature.mutates:
318
342
  first_arg = args[0]
319
343
  expr_str = self(first_arg)
320
344
  # copy an identifier expression iff it has multiple parents (b/c then we can't mutate it directly)
@@ -397,6 +421,28 @@ class PrettyContext:
397
421
  return name
398
422
 
399
423
 
424
+ def _pretty_callable(ref: CallableRef) -> str:
425
+ """
426
+ Returns a function call as a string.
427
+ """
428
+ match ref:
429
+ case FunctionRef(name):
430
+ return name
431
+ case (
432
+ ClassMethodRef(class_name, method_name)
433
+ | MethodRef(class_name, method_name)
434
+ | PropertyRef(class_name, method_name)
435
+ ):
436
+ return f"{class_name}.{method_name}"
437
+ case ConstantRef(_):
438
+ msg = "Constants should not be callable"
439
+ raise NotImplementedError(msg)
440
+ case ClassVariableRef(_, _):
441
+ msg = "Class variables should not be callable"
442
+ raise NotADirectoryError(msg)
443
+ assert_never(ref)
444
+
445
+
400
446
  def _plot_line_length(expr: object): # pragma: no cover
401
447
  """
402
448
  Plots the number of line lengths based on different max lengths
egglog/runtime.py CHANGED
@@ -11,7 +11,8 @@ so they are not mangled by Python and can be accessed by the user.
11
11
 
12
12
  from __future__ import annotations
13
13
 
14
- from dataclasses import dataclass
14
+ from collections.abc import Callable
15
+ from dataclasses import dataclass, replace
15
16
  from inspect import Parameter, Signature
16
17
  from itertools import zip_longest
17
18
  from typing import TYPE_CHECKING, NoReturn, TypeVar, Union, cast, get_args, get_origin
@@ -22,7 +23,7 @@ from .thunk import Thunk
22
23
  from .type_constraint_solver import *
23
24
 
24
25
  if TYPE_CHECKING:
25
- from collections.abc import Callable, Iterable
26
+ from collections.abc import Iterable
26
27
 
27
28
  from .egraph import Expr
28
29
 
@@ -60,6 +61,8 @@ REFLECTED_BINARY_METHODS = {
60
61
  # Set this globally so we can get access to PyObject when we have a type annotation of just object.
61
62
  # This is the only time a type annotation doesn't need to include the egglog type b/c object is top so that would be redundant statically.
62
63
  _PY_OBJECT_CLASS: RuntimeClass | None = None
64
+ # Same for functions
65
+ _UNSTABLE_FN_CLASS: RuntimeClass | None = None
63
66
 
64
67
  T = TypeVar("T")
65
68
 
@@ -67,6 +70,8 @@ T = TypeVar("T")
67
70
  def resolve_type_annotation(decls: Declarations, tp: object) -> TypeOrVarRef:
68
71
  """
69
72
  Resolves a type object into a type reference.
73
+
74
+ Any runtime type object decls will be add to those passed in.
70
75
  """
71
76
  if isinstance(tp, TypeVar):
72
77
  return ClassTypeVarRef(tp.__name__)
@@ -79,6 +84,11 @@ def resolve_type_annotation(decls: Declarations, tp: object) -> TypeOrVarRef:
79
84
  if tp == object:
80
85
  assert _PY_OBJECT_CLASS
81
86
  return resolve_type_annotation(decls, _PY_OBJECT_CLASS)
87
+ # If the type is a `Callable` then convert it into a UnstableFn
88
+ if get_origin(tp) == Callable:
89
+ assert _UNSTABLE_FN_CLASS
90
+ args, ret = get_args(tp)
91
+ return resolve_type_annotation(decls, _UNSTABLE_FN_CLASS[(ret, *args)])
82
92
  if isinstance(tp, RuntimeClass):
83
93
  decls |= tp
84
94
  return tp.__egg_tp__
@@ -95,9 +105,11 @@ class RuntimeClass(DelayedDeclerations):
95
105
  __egg_tp__: TypeRefWithVars
96
106
 
97
107
  def __post_init__(self) -> None:
98
- global _PY_OBJECT_CLASS
99
- if self.__egg_tp__.name == "PyObject":
108
+ global _PY_OBJECT_CLASS, _UNSTABLE_FN_CLASS
109
+ if (name := self.__egg_tp__.name) == "PyObject":
100
110
  _PY_OBJECT_CLASS = self
111
+ elif name == "UnstableFn" and not self.__egg_tp__.args:
112
+ _UNSTABLE_FN_CLASS = self
101
113
 
102
114
  def verify(self) -> None:
103
115
  if not self.__egg_tp__.args:
@@ -113,26 +125,48 @@ class RuntimeClass(DelayedDeclerations):
113
125
  Create an instance of this kind by calling the __init__ classmethod
114
126
  """
115
127
  # If this is a literal type, initializing it with a literal should return a literal
116
- if self.__egg_tp__.name == "PyObject":
128
+ if (name := self.__egg_tp__.name) == "PyObject":
117
129
  assert len(args) == 1
118
130
  return RuntimeExpr.__from_value__(
119
131
  self.__egg_decls__, TypedExprDecl(self.__egg_tp__.to_just(), PyObjectDecl(args[0]))
120
132
  )
121
- if self.__egg_tp__.name in UNARY_LIT_CLASS_NAMES:
133
+ if name == "UnstableFn":
134
+ assert not kwargs
135
+ fn_arg, *partial_args = args
136
+ del args
137
+ # Assumes we don't have types set for UnstableFn w/ generics, that they have to be inferred
138
+
139
+ # 1. Create a runtime function for the first arg
140
+ assert isinstance(fn_arg, RuntimeFunction)
141
+ # 2. Call it with the partial args, and use untyped vars for the rest of the args
142
+ res = fn_arg(*partial_args, _egg_partial_function=True)
143
+ assert res is not None, "Mutable partial functions not supported"
144
+ # 3. Use the inferred return type and inferred rest arg types as the types of the function, and
145
+ # the partially applied args as the args.
146
+ call = (res_typed_expr := res.__egg_typed_expr__).expr
147
+ return_tp = res_typed_expr.tp
148
+ assert isinstance(call, CallDecl), "partial function must be a call"
149
+ n_args = len(partial_args)
150
+ value = PartialCallDecl(replace(call, args=call.args[:n_args]))
151
+ remaining_arg_types = [a.tp for a in call.args[n_args:]]
152
+ type_ref = JustTypeRef("UnstableFn", (return_tp, *remaining_arg_types))
153
+ return RuntimeExpr.__from_value__(Declarations.create(self, res), TypedExprDecl(type_ref, value))
154
+
155
+ if name in UNARY_LIT_CLASS_NAMES:
122
156
  assert len(args) == 1
123
157
  assert isinstance(args[0], int | float | str | bool)
124
158
  return RuntimeExpr.__from_value__(
125
159
  self.__egg_decls__, TypedExprDecl(self.__egg_tp__.to_just(), LitDecl(args[0]))
126
160
  )
127
- if self.__egg_tp__.name == UNIT_CLASS_NAME:
161
+ if name == UNIT_CLASS_NAME:
128
162
  assert len(args) == 0
129
163
  return RuntimeExpr.__from_value__(
130
164
  self.__egg_decls__, TypedExprDecl(self.__egg_tp__.to_just(), LitDecl(None))
131
165
  )
132
166
 
133
167
  return RuntimeFunction(
134
- Thunk.value(self.__egg_decls__), ClassMethodRef(self.__egg_tp__.name, "__init__"), self.__egg_tp__.to_just()
135
- )(*args, **kwargs)
168
+ Thunk.value(self.__egg_decls__), ClassMethodRef(name, "__init__"), self.__egg_tp__.to_just()
169
+ )(*args, **kwargs) # type: ignore[arg-type]
136
170
 
137
171
  def __dir__(self) -> list[str]:
138
172
  cls_decl = self.__egg_decls__.get_class_decl(self.__egg_tp__.name)
@@ -184,6 +218,12 @@ class RuntimeClass(DelayedDeclerations):
184
218
  return RuntimeFunction(
185
219
  Thunk.value(self.__egg_decls__), ClassMethodRef(self.__egg_tp__.name, name), self.__egg_tp__.to_just()
186
220
  )
221
+ # allow referencing properties and methods as class variables as well
222
+ if name in cls_decl.properties:
223
+ return RuntimeFunction(Thunk.value(self.__egg_decls__), PropertyRef(self.__egg_tp__.name, name))
224
+ if name in cls_decl.methods:
225
+ return RuntimeFunction(Thunk.value(self.__egg_decls__), MethodRef(self.__egg_tp__.name, name))
226
+
187
227
  msg = f"Class {self.__egg_tp__.name} has no method {name}"
188
228
  if name == "__ne__":
189
229
  msg += ". Did you mean to use the ne(...).to(...)?"
@@ -207,24 +247,47 @@ class RuntimeFunction(DelayedDeclerations):
207
247
  # bound methods need to store RuntimeExpr not just TypedExprDecl, so they can mutate the expr if required on self
208
248
  __egg_bound__: JustTypeRef | RuntimeExpr | None = None
209
249
 
210
- def __call__(self, *args: object, **kwargs: object) -> RuntimeExpr | None:
250
+ def __call__(self, *args: object, _egg_partial_function: bool = False, **kwargs: object) -> RuntimeExpr | None:
211
251
  from .conversion import resolve_literal
212
252
 
213
253
  if isinstance(self.__egg_bound__, RuntimeExpr):
214
254
  args = (self.__egg_bound__, *args)
215
- fn_decl = self.__egg_decls__.get_callable_decl(self.__egg_ref__).to_function_decl()
255
+ signature = self.__egg_decls__.get_callable_decl(self.__egg_ref__).to_function_decl().signature
256
+ decls = self.__egg_decls__.copy()
257
+ # Special case function application bc we dont support variadic generics yet generally
258
+ if signature == "fn-app":
259
+ fn, *rest_args = args
260
+ args = tuple(rest_args)
261
+ assert not kwargs
262
+ assert isinstance(fn, RuntimeExpr)
263
+ decls.update(fn)
264
+ function_value = fn.__egg_typed_expr__
265
+ fn_tp = function_value.tp
266
+ assert fn_tp.name == "UnstableFn"
267
+ fn_return_tp, *fn_arg_tps = fn_tp.args
268
+ signature = FunctionSignature(
269
+ tuple(tp.to_var() for tp in fn_arg_tps),
270
+ tuple(f"_{i}" for i in range(len(fn_arg_tps))),
271
+ (None,) * len(fn_arg_tps),
272
+ fn_return_tp.to_var(),
273
+ )
274
+ else:
275
+ function_value = None
276
+ assert isinstance(signature, FunctionSignature)
277
+
216
278
  # Turn all keyword args into positional args
217
- bound = callable_decl_to_signature(fn_decl, self.__egg_decls__).bind(*args, **kwargs)
279
+ py_signature = to_py_signature(signature, self.__egg_decls__, _egg_partial_function)
280
+ bound = py_signature.bind(*args, **kwargs)
281
+ del kwargs
218
282
  bound.apply_defaults()
219
283
  assert not bound.kwargs
220
- del args, kwargs
284
+ args = bound.args
221
285
 
222
286
  upcasted_args = [
223
287
  resolve_literal(cast(TypeOrVarRef, tp), arg)
224
- for arg, tp in zip_longest(bound.args, fn_decl.arg_types, fillvalue=fn_decl.var_arg_type)
288
+ for arg, tp in zip_longest(args, signature.arg_types, fillvalue=signature.var_arg_type)
225
289
  ]
226
-
227
- decls = Declarations.create(self, *upcasted_args)
290
+ decls.update(*upcasted_args)
228
291
 
229
292
  tcs = TypeConstraintSolver(decls)
230
293
  bound_tp = (
@@ -234,19 +297,27 @@ class RuntimeFunction(DelayedDeclerations):
234
297
  if isinstance(self.__egg_bound__, RuntimeExpr)
235
298
  else self.__egg_bound__
236
299
  )
237
- if bound_tp and bound_tp.args:
300
+ if (
301
+ bound_tp
302
+ and bound_tp.args
303
+ # Don't bind class if we have a first class function arg, b/c we don't support that yet
304
+ and not function_value
305
+ ):
238
306
  tcs.bind_class(bound_tp)
239
307
  arg_exprs = tuple(arg.__egg_typed_expr__ for arg in upcasted_args)
240
308
  arg_types = [expr.tp for expr in arg_exprs]
241
309
  cls_name = bound_tp.name if bound_tp else None
242
310
  return_tp = tcs.infer_return_type(
243
- fn_decl.arg_types, fn_decl.return_type or fn_decl.arg_types[0], fn_decl.var_arg_type, arg_types, cls_name
311
+ signature.arg_types, signature.semantic_return_type, signature.var_arg_type, arg_types, cls_name
244
312
  )
245
313
  bound_params = cast(JustTypeRef, bound_tp).args if isinstance(self.__egg_ref__, ClassMethodRef) else None
314
+ # If we were using unstable-app to call a funciton, add that function back as the first arg.
315
+ if function_value:
316
+ arg_exprs = (function_value, *arg_exprs)
246
317
  expr_decl = CallDecl(self.__egg_ref__, arg_exprs, bound_params)
247
318
  typed_expr_decl = TypedExprDecl(return_tp, expr_decl)
248
319
  # If there is not return type, we are mutating the first arg
249
- if not fn_decl.return_type:
320
+ if not signature.return_type:
250
321
  first_arg = upcasted_args[0]
251
322
  first_arg.__egg_thunk__ = Thunk.value((decls, typed_expr_decl))
252
323
  return None
@@ -262,19 +333,26 @@ class RuntimeFunction(DelayedDeclerations):
262
333
  return pretty_callable_ref(self.__egg_decls__, self.__egg_ref__, first_arg, bound_tp_params)
263
334
 
264
335
 
265
- def callable_decl_to_signature(
266
- decl: FunctionDecl,
267
- decls: Declarations,
268
- ) -> Signature:
336
+ def to_py_signature(sig: FunctionSignature, decls: Declarations, optional_args: bool) -> Signature:
337
+ """
338
+ Convert to a Python signature.
339
+
340
+ If optional_args is true, then all args will be treated as optional, as if a default was provided that makes them
341
+ a var with that arg name as the value.
342
+
343
+ Used for partial application to try binding a function with only some of its args.
344
+ """
269
345
  parameters = [
270
346
  Parameter(
271
347
  n,
272
348
  Parameter.POSITIONAL_OR_KEYWORD,
273
- default=RuntimeExpr.__from_value__(decls, TypedExprDecl(t.to_just(), d)) if d else Parameter.empty,
349
+ default=RuntimeExpr.__from_value__(decls, TypedExprDecl(t.to_just(), d if d is not None else VarDecl(n)))
350
+ if d is not None or optional_args
351
+ else Parameter.empty,
274
352
  )
275
- for n, d, t in zip(decl.arg_names, decl.arg_defaults, decl.arg_types, strict=True)
353
+ for n, d, t in zip(sig.arg_names, sig.arg_defaults, sig.arg_types, strict=True)
276
354
  ]
277
- if isinstance(decl, FunctionDecl) and decl.var_arg_type is not None:
355
+ if isinstance(sig, FunctionSignature) and sig.var_arg_type is not None:
278
356
  parameters.append(Parameter("__rest", Parameter.VAR_POSITIONAL))
279
357
  return Signature(parameters)
280
358
 
@@ -412,10 +490,14 @@ for name in list(BINARY_METHODS) + list(UNARY_METHODS) + ["__getitem__", "__call
412
490
  try:
413
491
  return call_method_min_conversion(self, args[0], __name)
414
492
  except ConvertError:
415
- return NotImplemented
493
+ # Defer raising not imeplemented in case the dunder method is not symmetrical, then
494
+ # we use the standard process
495
+ pass
416
496
  if __name in class_decl.methods:
417
497
  fn = RuntimeFunction(Thunk.value(self.__egg_decls__), MethodRef(class_name, __name), self)
418
- return fn(*args, **kwargs)
498
+ return fn(*args, **kwargs) # type: ignore[arg-type]
499
+ if __name in PARTIAL_METHODS:
500
+ return NotImplemented
419
501
  raise TypeError(f"{class_name!r} object does not support {__name}")
420
502
 
421
503
  setattr(RuntimeExpr, name, _special_method)
@@ -436,8 +518,8 @@ def call_method_min_conversion(slf: object, other: object, name: str) -> Runtime
436
518
  # find a minimum type that both can be converted to
437
519
  # This is so so that calls like `-0.1 * Int("x")` work by upcasting both to floats.
438
520
  min_tp = min_convertable_tp(slf, other, name)
439
- slf = resolve_literal(min_tp.to_var(), slf)
440
- other = resolve_literal(min_tp.to_var(), other)
521
+ slf = resolve_literal(TypeRefWithVars(min_tp), slf)
522
+ other = resolve_literal(TypeRefWithVars(min_tp), other)
441
523
  method = RuntimeFunction(Thunk.value(slf.__egg_decls__), MethodRef(slf.__egg_class_name__, name), slf)
442
524
  return method(other)
443
525
 
egglog/thunk.py CHANGED
@@ -3,7 +3,7 @@ from __future__ import annotations
3
3
  from dataclasses import dataclass
4
4
  from typing import TYPE_CHECKING, Generic, TypeVar
5
5
 
6
- from typing_extensions import ParamSpec, TypeVarTuple, Unpack
6
+ from typing_extensions import TypeVarTuple, Unpack
7
7
 
8
8
  if TYPE_CHECKING:
9
9
  from collections.abc import Callable
@@ -12,7 +12,6 @@ if TYPE_CHECKING:
12
12
  __all__ = ["Thunk"]
13
13
 
14
14
  T = TypeVar("T")
15
- P = ParamSpec("P")
16
15
  TS = TypeVarTuple("TS")
17
16
 
18
17