effectful 0.2.0__py3-none-any.whl → 0.2.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- effectful/handlers/indexed.py +2 -1
- effectful/internals/product_n.py +163 -0
- effectful/internals/unification.py +4 -3
- effectful/ops/semantics.py +24 -13
- effectful/ops/syntax.py +60 -17
- {effectful-0.2.0.dist-info → effectful-0.2.2.dist-info}/METADATA +2 -2
- {effectful-0.2.0.dist-info → effectful-0.2.2.dist-info}/RECORD +10 -9
- {effectful-0.2.0.dist-info → effectful-0.2.2.dist-info}/WHEEL +0 -0
- {effectful-0.2.0.dist-info → effectful-0.2.2.dist-info}/licenses/LICENSE.md +0 -0
- {effectful-0.2.0.dist-info → effectful-0.2.2.dist-info}/top_level.txt +0 -0
effectful/handlers/indexed.py
CHANGED
@@ -233,7 +233,8 @@ def gather(value: torch.Tensor, indexset: IndexSet) -> torch.Tensor:
|
|
233
233
|
if k in indexset_vars
|
234
234
|
}
|
235
235
|
|
236
|
-
|
236
|
+
args = [v() for v in binding.values()]
|
237
|
+
return deffn(value, *binding.keys())(*args)
|
237
238
|
|
238
239
|
|
239
240
|
def stack(
|
@@ -0,0 +1,163 @@
|
|
1
|
+
import dataclasses
|
2
|
+
import functools
|
3
|
+
from collections.abc import Callable, Mapping
|
4
|
+
from typing import Any
|
5
|
+
|
6
|
+
import tree
|
7
|
+
|
8
|
+
from effectful.ops.semantics import apply, coproduct, handler
|
9
|
+
from effectful.ops.syntax import defop
|
10
|
+
from effectful.ops.types import Interpretation, Operation
|
11
|
+
|
12
|
+
|
13
|
+
@dataclasses.dataclass
|
14
|
+
class CallByNeed[**P, T]:
|
15
|
+
func: Callable[P, T]
|
16
|
+
args: Any # P.args
|
17
|
+
kwargs: Any # P.kwargs
|
18
|
+
value: T | None = None
|
19
|
+
initialized: bool = False
|
20
|
+
|
21
|
+
def __init__(self, func, *args, **kwargs):
|
22
|
+
self.func = func
|
23
|
+
self.args = args
|
24
|
+
self.kwargs = kwargs
|
25
|
+
|
26
|
+
def __call__(self):
|
27
|
+
if not self.initialized:
|
28
|
+
self.value = self.func(*self.args, **self.kwargs)
|
29
|
+
self.initialized = True
|
30
|
+
return self.value
|
31
|
+
|
32
|
+
|
33
|
+
@defop
|
34
|
+
def argsof(op: Operation) -> tuple[list, dict]:
|
35
|
+
raise RuntimeError("Prompt argsof not bound.")
|
36
|
+
|
37
|
+
|
38
|
+
@dataclasses.dataclass
|
39
|
+
class Product[S, T]:
|
40
|
+
values: object
|
41
|
+
|
42
|
+
|
43
|
+
def _pack(intp):
|
44
|
+
from effectful.internals.runtime import interpreter
|
45
|
+
|
46
|
+
return Product(interpreter(intp)(lambda x: x()))
|
47
|
+
|
48
|
+
|
49
|
+
def _unpack(x, prompt):
|
50
|
+
if isinstance(x, Product):
|
51
|
+
return x.values(prompt)
|
52
|
+
return x
|
53
|
+
|
54
|
+
|
55
|
+
def productN(intps: Mapping[Operation, Interpretation]) -> Interpretation:
|
56
|
+
# The resulting interpretation supports ops that exist in at least one input
|
57
|
+
# interpretation
|
58
|
+
result_ops = set(op for intp in intps.values() for op in intp)
|
59
|
+
if result_ops is None:
|
60
|
+
return {}
|
61
|
+
|
62
|
+
renaming = {(prompt, op): defop(op) for prompt in intps for op in result_ops}
|
63
|
+
|
64
|
+
# We enforce isolation between the named interpretations by giving every
|
65
|
+
# operation a fresh name and giving each operation a translation from
|
66
|
+
# the fresh names back to the names from their interpretation.
|
67
|
+
#
|
68
|
+
# E.g. { a: { f, g }, b: { f, h } } =>
|
69
|
+
# { handler({f: f_a, g: g_a, h: h_default})(f_a), handler({f: f_a, g: g_a})(g_a),
|
70
|
+
# handler({f: f_b, h: h_b})(f_b), handler({f: f_b, h: h_b})(h_b) }
|
71
|
+
translation_intps: dict[Operation, Interpretation] = {
|
72
|
+
prompt: {op: renaming[(prompt, op)] for op in result_ops} for prompt in intps
|
73
|
+
}
|
74
|
+
|
75
|
+
# For every prompt, build an isolated interpretation that binds all operations.
|
76
|
+
isolated_intps = {
|
77
|
+
prompt: {
|
78
|
+
renaming[(prompt, op)]: handler(translation_intps[prompt])(func)
|
79
|
+
for op, func in intp.items()
|
80
|
+
}
|
81
|
+
for prompt, intp in intps.items()
|
82
|
+
}
|
83
|
+
|
84
|
+
def product_op(op, *args, **kwargs):
|
85
|
+
"""Compute the product of operation `op` in named interpretations
|
86
|
+
`intps`. The product operation consumes product arguments and
|
87
|
+
returns product results. These products are represented as
|
88
|
+
interpretations.
|
89
|
+
|
90
|
+
"""
|
91
|
+
assert isinstance(op, Operation)
|
92
|
+
|
93
|
+
result_intp = {}
|
94
|
+
|
95
|
+
def argsof_direct_call(prompt):
|
96
|
+
return result_intp[prompt].args, result_intp[prompt].kwargs
|
97
|
+
|
98
|
+
def argsof_apply(prompt):
|
99
|
+
return result_intp[prompt].args[2:], result_intp[prompt].kwargs
|
100
|
+
|
101
|
+
# Every prompt gets an argsof implementation. The implementation is
|
102
|
+
# either for a direct call to a handler or for a call to an apply
|
103
|
+
# handler.
|
104
|
+
argsof_prompts = {}
|
105
|
+
|
106
|
+
for prompt, intp in intps.items():
|
107
|
+
# Args and kwargs are expected to be either interpretations with
|
108
|
+
# bindings for each named analysis in intps or concrete values.
|
109
|
+
# `get_for_intp` extracts the value that corresponds to this
|
110
|
+
# analysis.
|
111
|
+
#
|
112
|
+
# TODO: `get_for_intp` has to guess whether a dict value is an
|
113
|
+
# interpretation or not. This is probably a latent bug.
|
114
|
+
intp_args, intp_kwargs = tree.map_structure(
|
115
|
+
lambda x: _unpack(x, prompt), (args, kwargs)
|
116
|
+
)
|
117
|
+
|
118
|
+
# Making result a CallByNeed has two functions. It avoids some
|
119
|
+
# work when the result is not requested and it delays evaluation
|
120
|
+
# so that when the result is requested in `get_for_intp`, it
|
121
|
+
# evaluates in a context that binds the results of the other
|
122
|
+
# named interpretations.
|
123
|
+
isolated_intp = isolated_intps[prompt]
|
124
|
+
renamed_op = renaming[(prompt, op)]
|
125
|
+
if op in intp:
|
126
|
+
result = CallByNeed(
|
127
|
+
handler(isolated_intp)(renamed_op), *intp_args, **intp_kwargs
|
128
|
+
)
|
129
|
+
argsof_impl = argsof_direct_call
|
130
|
+
elif apply in intp:
|
131
|
+
result = CallByNeed(
|
132
|
+
handler(isolated_intp)(renaming[(prompt, apply)]),
|
133
|
+
renamed_op,
|
134
|
+
*intp_args,
|
135
|
+
**intp_kwargs,
|
136
|
+
)
|
137
|
+
argsof_impl = argsof_apply
|
138
|
+
else:
|
139
|
+
# TODO: If an intp does not handle an operation and has no apply
|
140
|
+
# handler, use the default rule. In the future, we would like to
|
141
|
+
# instead defer to the enclosing interpretation. This is
|
142
|
+
# difficult right now, because the output interpretation handles
|
143
|
+
# all operations with product handlers which would have to be
|
144
|
+
# skipped over.
|
145
|
+
result = CallByNeed(
|
146
|
+
handler(coproduct(isolated_intp, translation_intps[prompt]))(
|
147
|
+
op.__default_rule__
|
148
|
+
),
|
149
|
+
*intp_args,
|
150
|
+
**intp_kwargs,
|
151
|
+
)
|
152
|
+
argsof_impl = argsof_direct_call
|
153
|
+
|
154
|
+
result_intp[prompt] = result
|
155
|
+
argsof_prompts[prompt] = argsof_impl
|
156
|
+
|
157
|
+
result_intp[argsof] = lambda prompt: argsof_prompts[prompt](prompt)
|
158
|
+
return _pack(result_intp)
|
159
|
+
|
160
|
+
product_intp: Interpretation = {
|
161
|
+
op: functools.partial(product_op, op) for op in result_ops
|
162
|
+
}
|
163
|
+
return product_intp
|
@@ -440,9 +440,10 @@ def _(typ: type | abc.ABCMeta):
|
|
440
440
|
return typ
|
441
441
|
elif types.get_original_bases(typ):
|
442
442
|
for base in types.get_original_bases(typ):
|
443
|
-
|
444
|
-
|
445
|
-
|
443
|
+
if typing.get_origin(base) is not typing.Generic:
|
444
|
+
cbase = canonicalize(base)
|
445
|
+
if cbase != object:
|
446
|
+
return cbase
|
446
447
|
return typ
|
447
448
|
else:
|
448
449
|
raise TypeError(f"Cannot canonicalize type {typ}.")
|
effectful/ops/semantics.py
CHANGED
@@ -250,6 +250,14 @@ def evaluate[T](expr: Expr[T], *, intp: Interpretation | None = None) -> Expr[T]
|
|
250
250
|
elif isinstance(expr, collections.abc.Sequence):
|
251
251
|
if isinstance(expr, str | bytes):
|
252
252
|
return typing.cast(T, expr) # mypy doesnt like ignore here, so we use cast
|
253
|
+
elif (
|
254
|
+
isinstance(expr, tuple)
|
255
|
+
and hasattr(expr, "_fields")
|
256
|
+
and all(hasattr(expr, field) for field in getattr(expr, "_fields"))
|
257
|
+
): # namedtuple
|
258
|
+
return type(expr)(
|
259
|
+
**{field: evaluate(getattr(expr, field)) for field in expr._fields}
|
260
|
+
)
|
253
261
|
else:
|
254
262
|
return type(expr)(evaluate(item) for item in expr) # type: ignore
|
255
263
|
elif isinstance(expr, collections.abc.Set):
|
@@ -274,6 +282,21 @@ def evaluate[T](expr: Expr[T], *, intp: Interpretation | None = None) -> Expr[T]
|
|
274
282
|
return typing.cast(T, expr)
|
275
283
|
|
276
284
|
|
285
|
+
def _simple_type(tp: type) -> type:
|
286
|
+
"""Convert a type object into a type that can be dispatched on."""
|
287
|
+
if isinstance(tp, typing.TypeVar):
|
288
|
+
tp = (
|
289
|
+
tp.__bound__
|
290
|
+
if tp.__bound__
|
291
|
+
else tp.__constraints__[0]
|
292
|
+
if tp.__constraints__
|
293
|
+
else object
|
294
|
+
)
|
295
|
+
if isinstance(tp, types.UnionType):
|
296
|
+
raise TypeError(f"Union types are not supported: {tp}")
|
297
|
+
return typing.get_origin(tp) or tp
|
298
|
+
|
299
|
+
|
277
300
|
def typeof[T](term: Expr[T]) -> type[T]:
|
278
301
|
"""Return the type of an expression.
|
279
302
|
|
@@ -302,19 +325,7 @@ def typeof[T](term: Expr[T]) -> type[T]:
|
|
302
325
|
if isinstance(term, Term):
|
303
326
|
# If term is a Term, we evaluate it to get its type
|
304
327
|
tp = evaluate(term)
|
305
|
-
|
306
|
-
tp = (
|
307
|
-
tp.__bound__
|
308
|
-
if tp.__bound__
|
309
|
-
else tp.__constraints__[0]
|
310
|
-
if tp.__constraints__
|
311
|
-
else object
|
312
|
-
)
|
313
|
-
if isinstance(tp, types.UnionType):
|
314
|
-
raise TypeError(
|
315
|
-
f"Cannot determine type of {term} because it is a union type: {tp}"
|
316
|
-
)
|
317
|
-
return typing.get_origin(tp) or tp # type: ignore
|
328
|
+
return _simple_type(typing.cast(type, tp))
|
318
329
|
else:
|
319
330
|
return type(term)
|
320
331
|
|
effectful/ops/syntax.py
CHANGED
@@ -924,8 +924,12 @@ def defdata[T](
|
|
924
924
|
When an Operation whose return type is `Callable` is passed to :func:`defdata`,
|
925
925
|
it is reconstructed as a :class:`_CallableTerm`, which implements the :func:`__call__` method.
|
926
926
|
"""
|
927
|
-
from effectful.
|
927
|
+
from effectful.internals.product_n import productN
|
928
|
+
from effectful.internals.runtime import interpreter
|
929
|
+
from effectful.ops.semantics import _simple_type, apply, evaluate
|
928
930
|
|
931
|
+
# If this operation binds variables, we need to rename them in the
|
932
|
+
# appropriate parts of the child term.
|
929
933
|
bindings: inspect.BoundArguments = op.__fvs_rule__(*args, **kwargs)
|
930
934
|
renaming = {
|
931
935
|
var: defop(var)
|
@@ -933,24 +937,54 @@ def defdata[T](
|
|
933
937
|
for var in bound_vars
|
934
938
|
}
|
935
939
|
|
936
|
-
|
940
|
+
# Analysis for type computation and term reconstruction
|
941
|
+
typ = defop(object, name="typ")
|
942
|
+
cast = defop(object, name="cast")
|
943
|
+
|
944
|
+
def apply_type(op, *args, **kwargs):
|
945
|
+
assert isinstance(op, Operation)
|
946
|
+
tp = op.__type_rule__(*args, **kwargs)
|
947
|
+
return tp
|
948
|
+
|
949
|
+
def apply_cast(op, *args, **kwargs):
|
950
|
+
assert isinstance(op, Operation)
|
951
|
+
full_type = typ()
|
952
|
+
dispatch_type = _simple_type(full_type)
|
953
|
+
return __dispatch(dispatch_type)(op, *args, **kwargs)
|
954
|
+
|
955
|
+
analysis = productN({typ: {apply: apply_type}, cast: {apply: apply_cast}})
|
956
|
+
|
957
|
+
def evaluate_with_renaming(expr, ctx):
|
958
|
+
"""Evaluate an expression with renaming applied."""
|
959
|
+
renaming_ctx = {
|
960
|
+
old_var: new_var for old_var, new_var in renaming.items() if old_var in ctx
|
961
|
+
}
|
962
|
+
|
963
|
+
# Note: coproduct cannot be used to compose these interpretations
|
964
|
+
# because evaluate will only do operation replacement when the handler
|
965
|
+
# is operation typed, which coproduct does not satisfy.
|
966
|
+
with interpreter(analysis | renaming_ctx):
|
967
|
+
result = evaluate(expr)
|
968
|
+
|
969
|
+
return result
|
970
|
+
|
971
|
+
renamed_args = op.__signature__.bind(*args, **kwargs)
|
937
972
|
renamed_args.apply_defaults()
|
938
973
|
|
939
974
|
args_ = [
|
940
|
-
|
941
|
-
|
942
|
-
)
|
943
|
-
for i, arg in enumerate(renamed_args.args)
|
975
|
+
evaluate_with_renaming(arg, bindings.args[i])
|
976
|
+
for (i, arg) in enumerate(renamed_args.args)
|
944
977
|
]
|
945
978
|
kwargs_ = {
|
946
|
-
k:
|
947
|
-
|
948
|
-
)
|
949
|
-
for k, arg in renamed_args.kwargs.items()
|
979
|
+
k: evaluate_with_renaming(v, bindings.kwargs[k])
|
980
|
+
for (k, v) in renamed_args.kwargs.items()
|
950
981
|
}
|
951
982
|
|
952
|
-
|
953
|
-
|
983
|
+
# Build the final term with type analysis
|
984
|
+
with interpreter(analysis):
|
985
|
+
result = op(*args_, **kwargs_)
|
986
|
+
|
987
|
+
return result.values(cast) # type: ignore
|
954
988
|
|
955
989
|
|
956
990
|
@defterm.register(object)
|
@@ -1161,11 +1195,20 @@ def _(x: collections.abc.Mapping, other) -> bool:
|
|
1161
1195
|
|
1162
1196
|
@syntactic_eq.register
|
1163
1197
|
def _(x: collections.abc.Sequence, other) -> bool:
|
1164
|
-
|
1165
|
-
isinstance(
|
1166
|
-
and
|
1167
|
-
and all(
|
1168
|
-
)
|
1198
|
+
if (
|
1199
|
+
isinstance(x, tuple)
|
1200
|
+
and hasattr(x, "_fields")
|
1201
|
+
and all(hasattr(x, f) for f in x._fields)
|
1202
|
+
):
|
1203
|
+
return type(other) == type(x) and all(
|
1204
|
+
syntactic_eq(getattr(x, f), getattr(other, f)) for f in x._fields
|
1205
|
+
)
|
1206
|
+
else:
|
1207
|
+
return (
|
1208
|
+
isinstance(other, collections.abc.Sequence)
|
1209
|
+
and len(x) == len(other)
|
1210
|
+
and all(syntactic_eq(a, b) for a, b in zip(x, other))
|
1211
|
+
)
|
1169
1212
|
|
1170
1213
|
|
1171
1214
|
@syntactic_eq.register(object)
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: effectful
|
3
|
-
Version: 0.2.
|
3
|
+
Version: 0.2.2
|
4
4
|
Summary: Metaprogramming infrastructure
|
5
5
|
Author: Basis
|
6
6
|
License-Expression: Apache-2.0
|
@@ -15,7 +15,7 @@ Classifier: Operating System :: POSIX :: Linux
|
|
15
15
|
Classifier: Operating System :: MacOS :: MacOS X
|
16
16
|
Classifier: Programming Language :: Python :: 3.12
|
17
17
|
Classifier: Programming Language :: Python :: 3.13
|
18
|
-
Requires-Python:
|
18
|
+
Requires-Python: <3.14,>=3.12
|
19
19
|
Description-Content-Type: text/x-rst
|
20
20
|
License-File: LICENSE.md
|
21
21
|
Provides-Extra: torch
|
@@ -1,7 +1,7 @@
|
|
1
1
|
effectful/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
2
2
|
effectful/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
3
3
|
effectful/handlers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
4
|
-
effectful/handlers/indexed.py,sha256=
|
4
|
+
effectful/handlers/indexed.py,sha256=TL3Y9EZ9Q4b12qtIEg1iWZ-mk97rZHw3zY9XAXKXCCM,12638
|
5
5
|
effectful/handlers/numpyro.py,sha256=RWoBNpHLr5KdotN5Vu118jY0kn_p6NaejcnJgZJLehw,19457
|
6
6
|
effectful/handlers/pyro.py,sha256=qVl1wson02pyV8YHGf93KDnYEp5pGmhKEwji95OYBl8,26486
|
7
7
|
effectful/handlers/torch.py,sha256=NNM7mxqZskEBCjsl25kHI95WlXG9aeD7FaSkXkoLZ_I,24330
|
@@ -12,15 +12,16 @@ effectful/handlers/jax/numpy/__init__.py,sha256=Kmvya0QI-GA56pPf1as-wYOuZFngOBLt
|
|
12
12
|
effectful/handlers/jax/numpy/linalg.py,sha256=9DiaYYG4SztmO-VkmMH3dVvULtMK-zEgbV9oNQFkFo8,350
|
13
13
|
effectful/handlers/jax/scipy/special.py,sha256=yTIECFtQVPgraonrPlyenjvcnEYchZwIZC-5CSkF-lA,299
|
14
14
|
effectful/internals/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
15
|
+
effectful/internals/product_n.py,sha256=Ii06tiGvuvMAzv_wUDBnjUGPB1ZdVxhJPj2uqBeAJyQ,5842
|
15
16
|
effectful/internals/runtime.py,sha256=aLWol7sR1yHekn7zNz1evHKHARjiT1tnkmByLHPHBGc,1811
|
16
17
|
effectful/internals/tensor_utils.py,sha256=3QCSUqdxCXod3dsY3oRMcg36Rqr8pVX-ktEyCEkeODo,1173
|
17
|
-
effectful/internals/unification.py,sha256=
|
18
|
+
effectful/internals/unification.py,sha256=MD0beZ29by9hbJaowafBLIZuKPF5MxHXezuGpSG_wCA,30254
|
18
19
|
effectful/ops/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
19
|
-
effectful/ops/semantics.py,sha256=
|
20
|
-
effectful/ops/syntax.py,sha256=
|
20
|
+
effectful/ops/semantics.py,sha256=9fvBwER-ZF940b28ZS1fKvxl7ra6GafBg9FTIw2IR4c,11915
|
21
|
+
effectful/ops/syntax.py,sha256=lP0_vX8Q_B6XD2OuDoMxBMV_fQIAOLy0QtFYtKV74Mc,57072
|
21
22
|
effectful/ops/types.py,sha256=W1gZJaBnX7_nFpWrG3vfCBQPSun3Gc9PqT61ls8B3EA,6599
|
22
|
-
effectful-0.2.
|
23
|
-
effectful-0.2.
|
24
|
-
effectful-0.2.
|
25
|
-
effectful-0.2.
|
26
|
-
effectful-0.2.
|
23
|
+
effectful-0.2.2.dist-info/licenses/LICENSE.md,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
|
24
|
+
effectful-0.2.2.dist-info/METADATA,sha256=9jZF6agI9xsyZ0ZUXYfSMH6ku4P-IamE_A-I9oi-T4U,5306
|
25
|
+
effectful-0.2.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
26
|
+
effectful-0.2.2.dist-info/top_level.txt,sha256=gtuJfrE2nXil_lZLCnqWF2KAbOnJs9ILNvK8WnkRzbs,10
|
27
|
+
effectful-0.2.2.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|