egglog 7.0.0__cp312-none-win_amd64.whl → 7.1.0__cp312-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of egglog might be problematic. Click here for more details.
- egglog/bindings.cp312-win_amd64.pyd +0 -0
- egglog/bindings.pyi +7 -0
- egglog/builtins.py +41 -1
- egglog/conversion.py +22 -17
- egglog/declarations.py +57 -31
- egglog/egraph.py +93 -18
- egglog/egraph_state.py +76 -37
- egglog/exp/array_api.py +8 -8
- egglog/pretty.py +56 -10
- egglog/runtime.py +112 -30
- egglog/thunk.py +1 -2
- {egglog-7.0.0.dist-info → egglog-7.1.0.dist-info}/METADATA +20 -20
- {egglog-7.0.0.dist-info → egglog-7.1.0.dist-info}/RECORD +15 -15
- {egglog-7.0.0.dist-info → egglog-7.1.0.dist-info}/WHEEL +0 -0
- {egglog-7.0.0.dist-info → egglog-7.1.0.dist-info}/license_files/LICENSE +0 -0
egglog/egraph_state.py
CHANGED
|
@@ -87,17 +87,26 @@ class EGraphState:
|
|
|
87
87
|
"""
|
|
88
88
|
Registers a ruleset if it's not already registered.
|
|
89
89
|
"""
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
self.
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
90
|
+
match self.__egg_decls__._rulesets[name]:
|
|
91
|
+
case RulesetDecl(rules):
|
|
92
|
+
if name not in self.rulesets:
|
|
93
|
+
if name:
|
|
94
|
+
self.egraph.run_program(bindings.AddRuleset(name))
|
|
95
|
+
added_rules = self.rulesets[name] = set()
|
|
96
|
+
else:
|
|
97
|
+
added_rules = self.rulesets[name]
|
|
98
|
+
for rule in rules:
|
|
99
|
+
if rule in added_rules:
|
|
100
|
+
continue
|
|
101
|
+
self.egraph.run_program(self.command_to_egg(rule, name))
|
|
102
|
+
added_rules.add(rule)
|
|
103
|
+
case CombinedRulesetDecl(rulesets):
|
|
104
|
+
if name in self.rulesets:
|
|
105
|
+
return
|
|
106
|
+
self.rulesets[name] = set()
|
|
107
|
+
for ruleset in rulesets:
|
|
108
|
+
self.ruleset_to_egg(ruleset)
|
|
109
|
+
self.egraph.run_program(bindings.UnstableCombinedRuleset(name, list(rulesets)))
|
|
101
110
|
|
|
102
111
|
def command_to_egg(self, cmd: CommandDecl, ruleset: str) -> bindings._Command:
|
|
103
112
|
match cmd:
|
|
@@ -184,11 +193,13 @@ class EGraphState:
|
|
|
184
193
|
)
|
|
185
194
|
case FunctionDecl():
|
|
186
195
|
if not decl.builtin:
|
|
196
|
+
signature = decl.signature
|
|
197
|
+
assert isinstance(signature, FunctionSignature), "Cannot turn special function to egg"
|
|
187
198
|
egg_fn_decl = bindings.FunctionDecl(
|
|
188
199
|
egg_name,
|
|
189
200
|
bindings.Schema(
|
|
190
|
-
[self.type_ref_to_egg(a.to_just()) for a in
|
|
191
|
-
self.type_ref_to_egg(
|
|
201
|
+
[self.type_ref_to_egg(a.to_just()) for a in signature.arg_types],
|
|
202
|
+
self.type_ref_to_egg(signature.semantic_return_type.to_just()),
|
|
192
203
|
),
|
|
193
204
|
self.expr_to_egg(decl.default) if decl.default else None,
|
|
194
205
|
self.expr_to_egg(decl.merge) if decl.merge else None,
|
|
@@ -212,19 +223,22 @@ class EGraphState:
|
|
|
212
223
|
decl = self.__egg_decls__._classes[ref.name]
|
|
213
224
|
self.type_ref_to_egg_sort[ref] = egg_name = decl.egg_name or _generate_type_egg_name(ref)
|
|
214
225
|
if not decl.builtin or ref.args:
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
(
|
|
220
|
-
self.type_ref_to_egg(
|
|
221
|
-
[bindings.Var(self.type_ref_to_egg(a)) for a in ref.args],
|
|
222
|
-
)
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
226
|
+
if ref.args:
|
|
227
|
+
if ref.name == "UnstableFn":
|
|
228
|
+
# UnstableFn is a special case, where the rest of args are collected into a call
|
|
229
|
+
type_args: list[bindings._Expr] = [
|
|
230
|
+
bindings.Call(
|
|
231
|
+
self.type_ref_to_egg(ref.args[1]),
|
|
232
|
+
[bindings.Var(self.type_ref_to_egg(a)) for a in ref.args[2:]],
|
|
233
|
+
),
|
|
234
|
+
bindings.Var(self.type_ref_to_egg(ref.args[0])),
|
|
235
|
+
]
|
|
236
|
+
else:
|
|
237
|
+
type_args = [bindings.Var(self.type_ref_to_egg(a)) for a in ref.args]
|
|
238
|
+
args = (self.type_ref_to_egg(JustTypeRef(ref.name)), type_args)
|
|
239
|
+
else:
|
|
240
|
+
args = None
|
|
241
|
+
self.egraph.run_program(bindings.Sort(egg_name, args))
|
|
228
242
|
# For builtin classes, let's also make sure we have the mapping of all egg fn names for class methods, because
|
|
229
243
|
# these can be created even without adding them to the e-graph, like `vec-empty` which can be extracted
|
|
230
244
|
# even if you never use that function.
|
|
@@ -292,6 +306,9 @@ class EGraphState:
|
|
|
292
306
|
res = bindings.Call(egg_fn, egg_args)
|
|
293
307
|
case PyObjectDecl(value):
|
|
294
308
|
res = GLOBAL_PY_OBJECT_SORT.store(value)
|
|
309
|
+
case PartialCallDecl(call_decl):
|
|
310
|
+
egg_fn_call = self.expr_to_egg(call_decl)
|
|
311
|
+
res = bindings.Call("unstable-fn", [bindings.Lit(bindings.String(egg_fn_call.name)), *egg_fn_call.args])
|
|
295
312
|
case _:
|
|
296
313
|
assert_never(expr_decl.expr)
|
|
297
314
|
|
|
@@ -371,24 +388,45 @@ class FromEggState:
|
|
|
371
388
|
if term.name == "py-object":
|
|
372
389
|
call = bindings.termdag_term_to_expr(self.termdag, term)
|
|
373
390
|
expr_decl = PyObjectDecl(self.state.egraph.eval_py_object(call))
|
|
391
|
+
if term.name == "unstable-fn":
|
|
392
|
+
# Get function name
|
|
393
|
+
fn_term, *arg_terms = term.args
|
|
394
|
+
fn_value = self.resolve_term(fn_term, JustTypeRef("String"))
|
|
395
|
+
assert isinstance(fn_value.expr, LitDecl)
|
|
396
|
+
fn_name = fn_value.expr.value
|
|
397
|
+
assert isinstance(fn_name, str)
|
|
398
|
+
|
|
399
|
+
# Resolve what types the partiallied applied args are
|
|
400
|
+
assert tp.name == "UnstableFn"
|
|
401
|
+
call_decl = self.from_call(tp.args[0], bindings.TermApp(fn_name, arg_terms))
|
|
402
|
+
expr_decl = PartialCallDecl(call_decl)
|
|
374
403
|
else:
|
|
375
404
|
expr_decl = self.from_call(tp, term)
|
|
376
405
|
else:
|
|
377
406
|
assert_never(term)
|
|
378
407
|
return TypedExprDecl(tp, expr_decl)
|
|
379
408
|
|
|
380
|
-
def from_call(
|
|
409
|
+
def from_call(
|
|
410
|
+
self,
|
|
411
|
+
tp: JustTypeRef,
|
|
412
|
+
term: bindings.TermApp, # additional_arg_tps: tuple[JustTypeRef, ...]
|
|
413
|
+
) -> CallDecl:
|
|
381
414
|
"""
|
|
382
415
|
Convert a call to a CallDecl.
|
|
383
416
|
|
|
384
417
|
There could be Python call refs which match the call, so we need to find the correct one.
|
|
418
|
+
|
|
419
|
+
The additional_arg_tps are known types for arguments that come after the term args, used to infer types
|
|
420
|
+
for partially applied functions, where we know the types of the later args, but not of the earlier ones where
|
|
421
|
+
we have values for.
|
|
385
422
|
"""
|
|
386
423
|
# Find the first callable ref that matches the call
|
|
387
424
|
for callable_ref in self.state.egg_fn_to_callable_refs[term.name]:
|
|
388
425
|
# If this is a classmethod, we might need the type params that were bound for this type
|
|
389
426
|
# This could be multiple types if the classmethod is ambiguous, like map create.
|
|
390
427
|
possible_types: Iterable[JustTypeRef | None]
|
|
391
|
-
|
|
428
|
+
signature = self.decls.get_callable_decl(callable_ref).to_function_decl().signature
|
|
429
|
+
assert isinstance(signature, FunctionSignature)
|
|
392
430
|
if isinstance(callable_ref, ClassMethodRef):
|
|
393
431
|
possible_types = self.state._get_possible_types(callable_ref.class_name)
|
|
394
432
|
cls_name = callable_ref.class_name
|
|
@@ -402,16 +440,17 @@ class FromEggState:
|
|
|
402
440
|
|
|
403
441
|
try:
|
|
404
442
|
arg_types, bound_tp_params = tcs.infer_arg_types(
|
|
405
|
-
|
|
443
|
+
signature.arg_types, signature.semantic_return_type, signature.var_arg_type, tp, cls_name
|
|
406
444
|
)
|
|
407
445
|
except TypeConstraintError:
|
|
408
446
|
continue
|
|
409
|
-
args
|
|
410
|
-
|
|
411
|
-
try:
|
|
412
|
-
res = self.cache[a]
|
|
413
|
-
except KeyError:
|
|
414
|
-
res = self.cache[a] = self.from_expr(tp, self.termdag.nodes[a])
|
|
415
|
-
args.append(res)
|
|
416
|
-
return CallDecl(callable_ref, tuple(args), bound_tp_params)
|
|
447
|
+
args = tuple(self.resolve_term(a, tp) for a, tp in zip(term.args, arg_types, strict=False))
|
|
448
|
+
return CallDecl(callable_ref, args, bound_tp_params)
|
|
417
449
|
raise ValueError(f"Could not find callable ref for call {term}")
|
|
450
|
+
|
|
451
|
+
def resolve_term(self, term_id: int, tp: JustTypeRef) -> TypedExprDecl:
|
|
452
|
+
try:
|
|
453
|
+
return self.cache[term_id]
|
|
454
|
+
except KeyError:
|
|
455
|
+
res = self.cache[term_id] = self.from_expr(tp, self.termdag.nodes[term_id])
|
|
456
|
+
return res
|
egglog/exp/array_api.py
CHANGED
|
@@ -18,7 +18,7 @@ from egglog.runtime import RuntimeExpr
|
|
|
18
18
|
from .program_gen import *
|
|
19
19
|
|
|
20
20
|
if TYPE_CHECKING:
|
|
21
|
-
from collections.abc import Iterator
|
|
21
|
+
from collections.abc import Callable, Iterator
|
|
22
22
|
from types import ModuleType
|
|
23
23
|
|
|
24
24
|
# Pretend that exprs are numbers b/c sklearn does isinstance checks
|
|
@@ -257,7 +257,7 @@ class TupleInt(Expr):
|
|
|
257
257
|
|
|
258
258
|
def __getitem__(self, i: Int) -> Int: ...
|
|
259
259
|
|
|
260
|
-
def
|
|
260
|
+
def fold(self, init: Int, f: Callable[[Int, Int], Int]) -> Int: ...
|
|
261
261
|
|
|
262
262
|
|
|
263
263
|
converter(
|
|
@@ -272,7 +272,7 @@ converter(
|
|
|
272
272
|
|
|
273
273
|
|
|
274
274
|
@array_api_ruleset.register
|
|
275
|
-
def _tuple_int(ti: TupleInt, ti2: TupleInt, i: Int, i2: Int, k: i64):
|
|
275
|
+
def _tuple_int(ti: TupleInt, ti2: TupleInt, i: Int, i2: Int, k: i64, f: Callable[[Int, Int], Int]):
|
|
276
276
|
return [
|
|
277
277
|
rewrite(ti + TupleInt.EMPTY).to(ti),
|
|
278
278
|
rewrite(TupleInt(i).length()).to(Int(1)),
|
|
@@ -281,10 +281,10 @@ def _tuple_int(ti: TupleInt, ti2: TupleInt, i: Int, i2: Int, k: i64):
|
|
|
281
281
|
rewrite((TupleInt(i) + ti)[Int(0)]).to(i),
|
|
282
282
|
# Rule for indexing > 0
|
|
283
283
|
rule(eq(i).to((TupleInt(i2) + ti)[Int(k)]), k > 0).then(union(i).with_(ti[Int(k - 1)])),
|
|
284
|
-
#
|
|
285
|
-
rewrite(TupleInt(i)
|
|
286
|
-
rewrite(
|
|
287
|
-
rewrite(TupleInt.
|
|
284
|
+
# fold
|
|
285
|
+
rewrite(TupleInt.EMPTY.fold(i, f)).to(i),
|
|
286
|
+
rewrite(TupleInt(i2).fold(i, f)).to(f(i, i2)),
|
|
287
|
+
rewrite((TupleInt(i2) + ti).fold(i, f)).to(ti.fold(f(i, i2), f)),
|
|
288
288
|
]
|
|
289
289
|
|
|
290
290
|
|
|
@@ -1346,7 +1346,7 @@ def _unique(xs: TupleValue, a: NDArray, shape: TupleInt, copy: OptionalBool):
|
|
|
1346
1346
|
|
|
1347
1347
|
@array_api_ruleset.register
|
|
1348
1348
|
def _size(x: NDArray):
|
|
1349
|
-
yield rewrite(x.size).to(x.shape.
|
|
1349
|
+
yield rewrite(x.size).to(x.shape.fold(Int(1), Int.__mul__))
|
|
1350
1350
|
|
|
1351
1351
|
|
|
1352
1352
|
@overload
|
egglog/pretty.py
CHANGED
|
@@ -66,7 +66,7 @@ UNARY_METHODS = {
|
|
|
66
66
|
"__invert__": "~",
|
|
67
67
|
}
|
|
68
68
|
|
|
69
|
-
AllDecls: TypeAlias = RulesetDecl | CommandDecl | ActionDecl | FactDecl | ExprDecl | ScheduleDecl
|
|
69
|
+
AllDecls: TypeAlias = RulesetDecl | CombinedRulesetDecl | CommandDecl | ActionDecl | FactDecl | ExprDecl | ScheduleDecl
|
|
70
70
|
|
|
71
71
|
|
|
72
72
|
def pretty_decl(
|
|
@@ -106,7 +106,7 @@ def pretty_callable_ref(
|
|
|
106
106
|
"""
|
|
107
107
|
# Pass in three dummy args, which are the max used for any operation that
|
|
108
108
|
# is not a generic function call
|
|
109
|
-
args: list[ExprDecl] = [
|
|
109
|
+
args: list[ExprDecl] = [VarDecl(ARG_STR)] * 3
|
|
110
110
|
if first_arg:
|
|
111
111
|
args.insert(0, first_arg)
|
|
112
112
|
res = PrettyContext(decls, defaultdict(lambda: 0))._call_inner(
|
|
@@ -117,6 +117,10 @@ def pretty_callable_ref(
|
|
|
117
117
|
return res[0] if isinstance(res, tuple) else res
|
|
118
118
|
|
|
119
119
|
|
|
120
|
+
# TODO: Add a different pretty callable ref that doesnt fill in wholes but instead returns the function
|
|
121
|
+
# so that things like Math.__add__ will be represented properly
|
|
122
|
+
|
|
123
|
+
|
|
120
124
|
@dataclass
|
|
121
125
|
class TraverseContext:
|
|
122
126
|
"""
|
|
@@ -170,6 +174,10 @@ class TraverseContext:
|
|
|
170
174
|
if until:
|
|
171
175
|
for f in until:
|
|
172
176
|
self(f)
|
|
177
|
+
case PartialCallDecl(c):
|
|
178
|
+
self(c)
|
|
179
|
+
case CombinedRulesetDecl(_):
|
|
180
|
+
pass
|
|
173
181
|
case _:
|
|
174
182
|
assert_never(decl)
|
|
175
183
|
|
|
@@ -231,6 +239,9 @@ class PrettyContext:
|
|
|
231
239
|
return name, name
|
|
232
240
|
case CallDecl(_, _, _):
|
|
233
241
|
return self._call(decl, parens)
|
|
242
|
+
case PartialCallDecl(CallDecl(ref, typed_args, _)):
|
|
243
|
+
arg_strs = (_pretty_callable(ref), *(self(a.expr, parens=False, unwrap_lit=True) for a in typed_args))
|
|
244
|
+
return f"UnstableFn({', '.join(arg_strs)})", "fn"
|
|
234
245
|
case PyObjectDecl(value):
|
|
235
246
|
return repr(value) if unwrap_lit else f"PyObject({value!r})", "PyObject"
|
|
236
247
|
case ActionCommandDecl(action):
|
|
@@ -267,6 +278,10 @@ class PrettyContext:
|
|
|
267
278
|
return f"ruleset(name={ruleset_name!r})", f"ruleset_{ruleset_name}"
|
|
268
279
|
args = ", ".join(map(self, rules))
|
|
269
280
|
return f"ruleset({args})", "ruleset"
|
|
281
|
+
case CombinedRulesetDecl(rulesets):
|
|
282
|
+
if ruleset_name:
|
|
283
|
+
rulesets = (*rulesets, f"name={ruleset_name!r})")
|
|
284
|
+
return f"unstable_combine_rulesets({', '.join(rulesets)})", "combined_ruleset"
|
|
270
285
|
case SaturateDecl(schedule):
|
|
271
286
|
return f"{self(schedule, parens=True)}.saturate()", "schedule"
|
|
272
287
|
case RepeatDecl(schedule, times):
|
|
@@ -302,19 +317,28 @@ class PrettyContext:
|
|
|
302
317
|
l, r = self(args[0]), self(args[1])
|
|
303
318
|
return f"ne({l}).to({r})", "Unit"
|
|
304
319
|
function_decl = self.decls.get_callable_decl(ref).to_function_decl()
|
|
320
|
+
signature = function_decl.signature
|
|
321
|
+
|
|
305
322
|
# Determine how many of the last arguments are defaults, by iterating from the end and comparing the arg with the default
|
|
306
323
|
n_defaults = 0
|
|
307
|
-
|
|
308
|
-
|
|
309
|
-
|
|
310
|
-
|
|
311
|
-
|
|
312
|
-
|
|
324
|
+
# Dont try counting defaults for function application
|
|
325
|
+
if isinstance(signature, FunctionSignature):
|
|
326
|
+
for arg, default in zip(
|
|
327
|
+
reversed(args), reversed(signature.arg_defaults), strict=not signature.var_arg_type
|
|
328
|
+
):
|
|
329
|
+
if arg != default:
|
|
330
|
+
break
|
|
331
|
+
n_defaults += 1
|
|
313
332
|
if n_defaults:
|
|
314
333
|
args = args[:-n_defaults]
|
|
315
334
|
|
|
316
|
-
|
|
317
|
-
if
|
|
335
|
+
# If this is a function application, the type is the first type arg of the function object
|
|
336
|
+
if signature == "fn-app":
|
|
337
|
+
tp_name = decl.args[0].tp.args[0].name
|
|
338
|
+
else:
|
|
339
|
+
assert isinstance(signature, FunctionSignature)
|
|
340
|
+
tp_name = signature.semantic_return_type.name
|
|
341
|
+
if isinstance(signature, FunctionSignature) and signature.mutates:
|
|
318
342
|
first_arg = args[0]
|
|
319
343
|
expr_str = self(first_arg)
|
|
320
344
|
# copy an identifier expression iff it has multiple parents (b/c then we can't mutate it directly)
|
|
@@ -397,6 +421,28 @@ class PrettyContext:
|
|
|
397
421
|
return name
|
|
398
422
|
|
|
399
423
|
|
|
424
|
+
def _pretty_callable(ref: CallableRef) -> str:
|
|
425
|
+
"""
|
|
426
|
+
Returns a function call as a string.
|
|
427
|
+
"""
|
|
428
|
+
match ref:
|
|
429
|
+
case FunctionRef(name):
|
|
430
|
+
return name
|
|
431
|
+
case (
|
|
432
|
+
ClassMethodRef(class_name, method_name)
|
|
433
|
+
| MethodRef(class_name, method_name)
|
|
434
|
+
| PropertyRef(class_name, method_name)
|
|
435
|
+
):
|
|
436
|
+
return f"{class_name}.{method_name}"
|
|
437
|
+
case ConstantRef(_):
|
|
438
|
+
msg = "Constants should not be callable"
|
|
439
|
+
raise NotImplementedError(msg)
|
|
440
|
+
case ClassVariableRef(_, _):
|
|
441
|
+
msg = "Class variables should not be callable"
|
|
442
|
+
raise NotADirectoryError(msg)
|
|
443
|
+
assert_never(ref)
|
|
444
|
+
|
|
445
|
+
|
|
400
446
|
def _plot_line_length(expr: object): # pragma: no cover
|
|
401
447
|
"""
|
|
402
448
|
Plots the number of line lengths based on different max lengths
|
egglog/runtime.py
CHANGED
|
@@ -11,7 +11,8 @@ so they are not mangled by Python and can be accessed by the user.
|
|
|
11
11
|
|
|
12
12
|
from __future__ import annotations
|
|
13
13
|
|
|
14
|
-
from
|
|
14
|
+
from collections.abc import Callable
|
|
15
|
+
from dataclasses import dataclass, replace
|
|
15
16
|
from inspect import Parameter, Signature
|
|
16
17
|
from itertools import zip_longest
|
|
17
18
|
from typing import TYPE_CHECKING, NoReturn, TypeVar, Union, cast, get_args, get_origin
|
|
@@ -22,7 +23,7 @@ from .thunk import Thunk
|
|
|
22
23
|
from .type_constraint_solver import *
|
|
23
24
|
|
|
24
25
|
if TYPE_CHECKING:
|
|
25
|
-
from collections.abc import
|
|
26
|
+
from collections.abc import Iterable
|
|
26
27
|
|
|
27
28
|
from .egraph import Expr
|
|
28
29
|
|
|
@@ -60,6 +61,8 @@ REFLECTED_BINARY_METHODS = {
|
|
|
60
61
|
# Set this globally so we can get access to PyObject when we have a type annotation of just object.
|
|
61
62
|
# This is the only time a type annotation doesn't need to include the egglog type b/c object is top so that would be redundant statically.
|
|
62
63
|
_PY_OBJECT_CLASS: RuntimeClass | None = None
|
|
64
|
+
# Same for functions
|
|
65
|
+
_UNSTABLE_FN_CLASS: RuntimeClass | None = None
|
|
63
66
|
|
|
64
67
|
T = TypeVar("T")
|
|
65
68
|
|
|
@@ -67,6 +70,8 @@ T = TypeVar("T")
|
|
|
67
70
|
def resolve_type_annotation(decls: Declarations, tp: object) -> TypeOrVarRef:
|
|
68
71
|
"""
|
|
69
72
|
Resolves a type object into a type reference.
|
|
73
|
+
|
|
74
|
+
Any runtime type object decls will be add to those passed in.
|
|
70
75
|
"""
|
|
71
76
|
if isinstance(tp, TypeVar):
|
|
72
77
|
return ClassTypeVarRef(tp.__name__)
|
|
@@ -79,6 +84,11 @@ def resolve_type_annotation(decls: Declarations, tp: object) -> TypeOrVarRef:
|
|
|
79
84
|
if tp == object:
|
|
80
85
|
assert _PY_OBJECT_CLASS
|
|
81
86
|
return resolve_type_annotation(decls, _PY_OBJECT_CLASS)
|
|
87
|
+
# If the type is a `Callable` then convert it into a UnstableFn
|
|
88
|
+
if get_origin(tp) == Callable:
|
|
89
|
+
assert _UNSTABLE_FN_CLASS
|
|
90
|
+
args, ret = get_args(tp)
|
|
91
|
+
return resolve_type_annotation(decls, _UNSTABLE_FN_CLASS[(ret, *args)])
|
|
82
92
|
if isinstance(tp, RuntimeClass):
|
|
83
93
|
decls |= tp
|
|
84
94
|
return tp.__egg_tp__
|
|
@@ -95,9 +105,11 @@ class RuntimeClass(DelayedDeclerations):
|
|
|
95
105
|
__egg_tp__: TypeRefWithVars
|
|
96
106
|
|
|
97
107
|
def __post_init__(self) -> None:
|
|
98
|
-
global _PY_OBJECT_CLASS
|
|
99
|
-
if self.__egg_tp__.name == "PyObject":
|
|
108
|
+
global _PY_OBJECT_CLASS, _UNSTABLE_FN_CLASS
|
|
109
|
+
if (name := self.__egg_tp__.name) == "PyObject":
|
|
100
110
|
_PY_OBJECT_CLASS = self
|
|
111
|
+
elif name == "UnstableFn" and not self.__egg_tp__.args:
|
|
112
|
+
_UNSTABLE_FN_CLASS = self
|
|
101
113
|
|
|
102
114
|
def verify(self) -> None:
|
|
103
115
|
if not self.__egg_tp__.args:
|
|
@@ -113,26 +125,48 @@ class RuntimeClass(DelayedDeclerations):
|
|
|
113
125
|
Create an instance of this kind by calling the __init__ classmethod
|
|
114
126
|
"""
|
|
115
127
|
# If this is a literal type, initializing it with a literal should return a literal
|
|
116
|
-
if self.__egg_tp__.name == "PyObject":
|
|
128
|
+
if (name := self.__egg_tp__.name) == "PyObject":
|
|
117
129
|
assert len(args) == 1
|
|
118
130
|
return RuntimeExpr.__from_value__(
|
|
119
131
|
self.__egg_decls__, TypedExprDecl(self.__egg_tp__.to_just(), PyObjectDecl(args[0]))
|
|
120
132
|
)
|
|
121
|
-
if
|
|
133
|
+
if name == "UnstableFn":
|
|
134
|
+
assert not kwargs
|
|
135
|
+
fn_arg, *partial_args = args
|
|
136
|
+
del args
|
|
137
|
+
# Assumes we don't have types set for UnstableFn w/ generics, that they have to be inferred
|
|
138
|
+
|
|
139
|
+
# 1. Create a runtime function for the first arg
|
|
140
|
+
assert isinstance(fn_arg, RuntimeFunction)
|
|
141
|
+
# 2. Call it with the partial args, and use untyped vars for the rest of the args
|
|
142
|
+
res = fn_arg(*partial_args, _egg_partial_function=True)
|
|
143
|
+
assert res is not None, "Mutable partial functions not supported"
|
|
144
|
+
# 3. Use the inferred return type and inferred rest arg types as the types of the function, and
|
|
145
|
+
# the partially applied args as the args.
|
|
146
|
+
call = (res_typed_expr := res.__egg_typed_expr__).expr
|
|
147
|
+
return_tp = res_typed_expr.tp
|
|
148
|
+
assert isinstance(call, CallDecl), "partial function must be a call"
|
|
149
|
+
n_args = len(partial_args)
|
|
150
|
+
value = PartialCallDecl(replace(call, args=call.args[:n_args]))
|
|
151
|
+
remaining_arg_types = [a.tp for a in call.args[n_args:]]
|
|
152
|
+
type_ref = JustTypeRef("UnstableFn", (return_tp, *remaining_arg_types))
|
|
153
|
+
return RuntimeExpr.__from_value__(Declarations.create(self, res), TypedExprDecl(type_ref, value))
|
|
154
|
+
|
|
155
|
+
if name in UNARY_LIT_CLASS_NAMES:
|
|
122
156
|
assert len(args) == 1
|
|
123
157
|
assert isinstance(args[0], int | float | str | bool)
|
|
124
158
|
return RuntimeExpr.__from_value__(
|
|
125
159
|
self.__egg_decls__, TypedExprDecl(self.__egg_tp__.to_just(), LitDecl(args[0]))
|
|
126
160
|
)
|
|
127
|
-
if
|
|
161
|
+
if name == UNIT_CLASS_NAME:
|
|
128
162
|
assert len(args) == 0
|
|
129
163
|
return RuntimeExpr.__from_value__(
|
|
130
164
|
self.__egg_decls__, TypedExprDecl(self.__egg_tp__.to_just(), LitDecl(None))
|
|
131
165
|
)
|
|
132
166
|
|
|
133
167
|
return RuntimeFunction(
|
|
134
|
-
Thunk.value(self.__egg_decls__), ClassMethodRef(
|
|
135
|
-
)(*args, **kwargs)
|
|
168
|
+
Thunk.value(self.__egg_decls__), ClassMethodRef(name, "__init__"), self.__egg_tp__.to_just()
|
|
169
|
+
)(*args, **kwargs) # type: ignore[arg-type]
|
|
136
170
|
|
|
137
171
|
def __dir__(self) -> list[str]:
|
|
138
172
|
cls_decl = self.__egg_decls__.get_class_decl(self.__egg_tp__.name)
|
|
@@ -184,6 +218,12 @@ class RuntimeClass(DelayedDeclerations):
|
|
|
184
218
|
return RuntimeFunction(
|
|
185
219
|
Thunk.value(self.__egg_decls__), ClassMethodRef(self.__egg_tp__.name, name), self.__egg_tp__.to_just()
|
|
186
220
|
)
|
|
221
|
+
# allow referencing properties and methods as class variables as well
|
|
222
|
+
if name in cls_decl.properties:
|
|
223
|
+
return RuntimeFunction(Thunk.value(self.__egg_decls__), PropertyRef(self.__egg_tp__.name, name))
|
|
224
|
+
if name in cls_decl.methods:
|
|
225
|
+
return RuntimeFunction(Thunk.value(self.__egg_decls__), MethodRef(self.__egg_tp__.name, name))
|
|
226
|
+
|
|
187
227
|
msg = f"Class {self.__egg_tp__.name} has no method {name}"
|
|
188
228
|
if name == "__ne__":
|
|
189
229
|
msg += ". Did you mean to use the ne(...).to(...)?"
|
|
@@ -207,24 +247,47 @@ class RuntimeFunction(DelayedDeclerations):
|
|
|
207
247
|
# bound methods need to store RuntimeExpr not just TypedExprDecl, so they can mutate the expr if required on self
|
|
208
248
|
__egg_bound__: JustTypeRef | RuntimeExpr | None = None
|
|
209
249
|
|
|
210
|
-
def __call__(self, *args: object, **kwargs: object) -> RuntimeExpr | None:
|
|
250
|
+
def __call__(self, *args: object, _egg_partial_function: bool = False, **kwargs: object) -> RuntimeExpr | None:
|
|
211
251
|
from .conversion import resolve_literal
|
|
212
252
|
|
|
213
253
|
if isinstance(self.__egg_bound__, RuntimeExpr):
|
|
214
254
|
args = (self.__egg_bound__, *args)
|
|
215
|
-
|
|
255
|
+
signature = self.__egg_decls__.get_callable_decl(self.__egg_ref__).to_function_decl().signature
|
|
256
|
+
decls = self.__egg_decls__.copy()
|
|
257
|
+
# Special case function application bc we dont support variadic generics yet generally
|
|
258
|
+
if signature == "fn-app":
|
|
259
|
+
fn, *rest_args = args
|
|
260
|
+
args = tuple(rest_args)
|
|
261
|
+
assert not kwargs
|
|
262
|
+
assert isinstance(fn, RuntimeExpr)
|
|
263
|
+
decls.update(fn)
|
|
264
|
+
function_value = fn.__egg_typed_expr__
|
|
265
|
+
fn_tp = function_value.tp
|
|
266
|
+
assert fn_tp.name == "UnstableFn"
|
|
267
|
+
fn_return_tp, *fn_arg_tps = fn_tp.args
|
|
268
|
+
signature = FunctionSignature(
|
|
269
|
+
tuple(tp.to_var() for tp in fn_arg_tps),
|
|
270
|
+
tuple(f"_{i}" for i in range(len(fn_arg_tps))),
|
|
271
|
+
(None,) * len(fn_arg_tps),
|
|
272
|
+
fn_return_tp.to_var(),
|
|
273
|
+
)
|
|
274
|
+
else:
|
|
275
|
+
function_value = None
|
|
276
|
+
assert isinstance(signature, FunctionSignature)
|
|
277
|
+
|
|
216
278
|
# Turn all keyword args into positional args
|
|
217
|
-
|
|
279
|
+
py_signature = to_py_signature(signature, self.__egg_decls__, _egg_partial_function)
|
|
280
|
+
bound = py_signature.bind(*args, **kwargs)
|
|
281
|
+
del kwargs
|
|
218
282
|
bound.apply_defaults()
|
|
219
283
|
assert not bound.kwargs
|
|
220
|
-
|
|
284
|
+
args = bound.args
|
|
221
285
|
|
|
222
286
|
upcasted_args = [
|
|
223
287
|
resolve_literal(cast(TypeOrVarRef, tp), arg)
|
|
224
|
-
for arg, tp in zip_longest(
|
|
288
|
+
for arg, tp in zip_longest(args, signature.arg_types, fillvalue=signature.var_arg_type)
|
|
225
289
|
]
|
|
226
|
-
|
|
227
|
-
decls = Declarations.create(self, *upcasted_args)
|
|
290
|
+
decls.update(*upcasted_args)
|
|
228
291
|
|
|
229
292
|
tcs = TypeConstraintSolver(decls)
|
|
230
293
|
bound_tp = (
|
|
@@ -234,19 +297,27 @@ class RuntimeFunction(DelayedDeclerations):
|
|
|
234
297
|
if isinstance(self.__egg_bound__, RuntimeExpr)
|
|
235
298
|
else self.__egg_bound__
|
|
236
299
|
)
|
|
237
|
-
if
|
|
300
|
+
if (
|
|
301
|
+
bound_tp
|
|
302
|
+
and bound_tp.args
|
|
303
|
+
# Don't bind class if we have a first class function arg, b/c we don't support that yet
|
|
304
|
+
and not function_value
|
|
305
|
+
):
|
|
238
306
|
tcs.bind_class(bound_tp)
|
|
239
307
|
arg_exprs = tuple(arg.__egg_typed_expr__ for arg in upcasted_args)
|
|
240
308
|
arg_types = [expr.tp for expr in arg_exprs]
|
|
241
309
|
cls_name = bound_tp.name if bound_tp else None
|
|
242
310
|
return_tp = tcs.infer_return_type(
|
|
243
|
-
|
|
311
|
+
signature.arg_types, signature.semantic_return_type, signature.var_arg_type, arg_types, cls_name
|
|
244
312
|
)
|
|
245
313
|
bound_params = cast(JustTypeRef, bound_tp).args if isinstance(self.__egg_ref__, ClassMethodRef) else None
|
|
314
|
+
# If we were using unstable-app to call a funciton, add that function back as the first arg.
|
|
315
|
+
if function_value:
|
|
316
|
+
arg_exprs = (function_value, *arg_exprs)
|
|
246
317
|
expr_decl = CallDecl(self.__egg_ref__, arg_exprs, bound_params)
|
|
247
318
|
typed_expr_decl = TypedExprDecl(return_tp, expr_decl)
|
|
248
319
|
# If there is not return type, we are mutating the first arg
|
|
249
|
-
if not
|
|
320
|
+
if not signature.return_type:
|
|
250
321
|
first_arg = upcasted_args[0]
|
|
251
322
|
first_arg.__egg_thunk__ = Thunk.value((decls, typed_expr_decl))
|
|
252
323
|
return None
|
|
@@ -262,19 +333,26 @@ class RuntimeFunction(DelayedDeclerations):
|
|
|
262
333
|
return pretty_callable_ref(self.__egg_decls__, self.__egg_ref__, first_arg, bound_tp_params)
|
|
263
334
|
|
|
264
335
|
|
|
265
|
-
def
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
|
|
336
|
+
def to_py_signature(sig: FunctionSignature, decls: Declarations, optional_args: bool) -> Signature:
|
|
337
|
+
"""
|
|
338
|
+
Convert to a Python signature.
|
|
339
|
+
|
|
340
|
+
If optional_args is true, then all args will be treated as optional, as if a default was provided that makes them
|
|
341
|
+
a var with that arg name as the value.
|
|
342
|
+
|
|
343
|
+
Used for partial application to try binding a function with only some of its args.
|
|
344
|
+
"""
|
|
269
345
|
parameters = [
|
|
270
346
|
Parameter(
|
|
271
347
|
n,
|
|
272
348
|
Parameter.POSITIONAL_OR_KEYWORD,
|
|
273
|
-
default=RuntimeExpr.__from_value__(decls, TypedExprDecl(t.to_just(), d
|
|
349
|
+
default=RuntimeExpr.__from_value__(decls, TypedExprDecl(t.to_just(), d if d is not None else VarDecl(n)))
|
|
350
|
+
if d is not None or optional_args
|
|
351
|
+
else Parameter.empty,
|
|
274
352
|
)
|
|
275
|
-
for n, d, t in zip(
|
|
353
|
+
for n, d, t in zip(sig.arg_names, sig.arg_defaults, sig.arg_types, strict=True)
|
|
276
354
|
]
|
|
277
|
-
if isinstance(
|
|
355
|
+
if isinstance(sig, FunctionSignature) and sig.var_arg_type is not None:
|
|
278
356
|
parameters.append(Parameter("__rest", Parameter.VAR_POSITIONAL))
|
|
279
357
|
return Signature(parameters)
|
|
280
358
|
|
|
@@ -412,10 +490,14 @@ for name in list(BINARY_METHODS) + list(UNARY_METHODS) + ["__getitem__", "__call
|
|
|
412
490
|
try:
|
|
413
491
|
return call_method_min_conversion(self, args[0], __name)
|
|
414
492
|
except ConvertError:
|
|
415
|
-
|
|
493
|
+
# Defer raising not imeplemented in case the dunder method is not symmetrical, then
|
|
494
|
+
# we use the standard process
|
|
495
|
+
pass
|
|
416
496
|
if __name in class_decl.methods:
|
|
417
497
|
fn = RuntimeFunction(Thunk.value(self.__egg_decls__), MethodRef(class_name, __name), self)
|
|
418
|
-
return fn(*args, **kwargs)
|
|
498
|
+
return fn(*args, **kwargs) # type: ignore[arg-type]
|
|
499
|
+
if __name in PARTIAL_METHODS:
|
|
500
|
+
return NotImplemented
|
|
419
501
|
raise TypeError(f"{class_name!r} object does not support {__name}")
|
|
420
502
|
|
|
421
503
|
setattr(RuntimeExpr, name, _special_method)
|
|
@@ -436,8 +518,8 @@ def call_method_min_conversion(slf: object, other: object, name: str) -> Runtime
|
|
|
436
518
|
# find a minimum type that both can be converted to
|
|
437
519
|
# This is so so that calls like `-0.1 * Int("x")` work by upcasting both to floats.
|
|
438
520
|
min_tp = min_convertable_tp(slf, other, name)
|
|
439
|
-
slf = resolve_literal(min_tp
|
|
440
|
-
other = resolve_literal(min_tp
|
|
521
|
+
slf = resolve_literal(TypeRefWithVars(min_tp), slf)
|
|
522
|
+
other = resolve_literal(TypeRefWithVars(min_tp), other)
|
|
441
523
|
method = RuntimeFunction(Thunk.value(slf.__egg_decls__), MethodRef(slf.__egg_class_name__, name), slf)
|
|
442
524
|
return method(other)
|
|
443
525
|
|
egglog/thunk.py
CHANGED
|
@@ -3,7 +3,7 @@ from __future__ import annotations
|
|
|
3
3
|
from dataclasses import dataclass
|
|
4
4
|
from typing import TYPE_CHECKING, Generic, TypeVar
|
|
5
5
|
|
|
6
|
-
from typing_extensions import
|
|
6
|
+
from typing_extensions import TypeVarTuple, Unpack
|
|
7
7
|
|
|
8
8
|
if TYPE_CHECKING:
|
|
9
9
|
from collections.abc import Callable
|
|
@@ -12,7 +12,6 @@ if TYPE_CHECKING:
|
|
|
12
12
|
__all__ = ["Thunk"]
|
|
13
13
|
|
|
14
14
|
T = TypeVar("T")
|
|
15
|
-
P = ParamSpec("P")
|
|
16
15
|
TS = TypeVarTuple("TS")
|
|
17
16
|
|
|
18
17
|
|