effectful 0.0.1__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- effectful/handlers/indexed.py +27 -46
- effectful/handlers/jax/__init__.py +14 -0
- effectful/handlers/jax/_handlers.py +293 -0
- effectful/handlers/jax/_terms.py +502 -0
- effectful/handlers/jax/numpy/__init__.py +23 -0
- effectful/handlers/jax/numpy/linalg.py +13 -0
- effectful/handlers/jax/scipy/special.py +11 -0
- effectful/handlers/numpyro.py +562 -0
- effectful/handlers/pyro.py +565 -214
- effectful/handlers/torch.py +321 -169
- effectful/internals/runtime.py +6 -13
- effectful/internals/tensor_utils.py +32 -0
- effectful/internals/unification.py +900 -0
- effectful/ops/semantics.py +104 -84
- effectful/ops/syntax.py +1276 -167
- effectful/ops/types.py +141 -35
- {effectful-0.0.1.dist-info → effectful-0.2.0.dist-info}/METADATA +65 -57
- effectful-0.2.0.dist-info/RECORD +26 -0
- {effectful-0.0.1.dist-info → effectful-0.2.0.dist-info}/WHEEL +1 -1
- effectful/handlers/numbers.py +0 -259
- effectful/internals/base_impl.py +0 -259
- effectful-0.0.1.dist-info/RECORD +0 -19
- {effectful-0.0.1.dist-info → effectful-0.2.0.dist-info/licenses}/LICENSE.md +0 -0
- {effectful-0.0.1.dist-info → effectful-0.2.0.dist-info}/top_level.txt +0 -0
effectful/handlers/numbers.py
DELETED
@@ -1,259 +0,0 @@
|
|
1
|
-
import numbers
|
2
|
-
import operator
|
3
|
-
from typing import Any, TypeVar
|
4
|
-
|
5
|
-
from typing_extensions import ParamSpec
|
6
|
-
|
7
|
-
from effectful.ops.syntax import NoDefaultRule, defdata, defop, syntactic_eq
|
8
|
-
from effectful.ops.types import Operation, Term
|
9
|
-
|
10
|
-
P = ParamSpec("P")
|
11
|
-
Q = ParamSpec("Q")
|
12
|
-
S = TypeVar("S")
|
13
|
-
T = TypeVar("T")
|
14
|
-
V = TypeVar("V")
|
15
|
-
|
16
|
-
T_Number = TypeVar("T_Number", bound=numbers.Number)
|
17
|
-
|
18
|
-
|
19
|
-
@defdata.register(numbers.Number)
|
20
|
-
@numbers.Number.register
|
21
|
-
class _NumberTerm(Term[numbers.Number]):
|
22
|
-
def __init__(
|
23
|
-
self, op: Operation[..., numbers.Number], args: tuple, kwargs: dict
|
24
|
-
) -> None:
|
25
|
-
self._op = op
|
26
|
-
self._args = args
|
27
|
-
self._kwargs = kwargs
|
28
|
-
|
29
|
-
@property
|
30
|
-
def op(self) -> Operation[..., numbers.Number]:
|
31
|
-
return self._op
|
32
|
-
|
33
|
-
@property
|
34
|
-
def args(self) -> tuple:
|
35
|
-
return self._args
|
36
|
-
|
37
|
-
@property
|
38
|
-
def kwargs(self) -> dict:
|
39
|
-
return self._kwargs
|
40
|
-
|
41
|
-
def __hash__(self):
|
42
|
-
return hash((self.op, tuple(self.args), tuple(self.kwargs.items())))
|
43
|
-
|
44
|
-
|
45
|
-
# Complex specific methods
|
46
|
-
@defop
|
47
|
-
def eq(x: T_Number, y: T_Number) -> bool:
|
48
|
-
if not any(isinstance(a, Term) for a in (x, y)):
|
49
|
-
return operator.eq(x, y)
|
50
|
-
else:
|
51
|
-
return syntactic_eq(x, y)
|
52
|
-
|
53
|
-
|
54
|
-
def _wrap_cmp(op):
|
55
|
-
def _wrapped_op(x: T_Number, y: T_Number) -> bool:
|
56
|
-
if not any(isinstance(a, Term) for a in (x, y)):
|
57
|
-
return op(x, y)
|
58
|
-
else:
|
59
|
-
raise NoDefaultRule
|
60
|
-
|
61
|
-
_wrapped_op.__name__ = op.__name__
|
62
|
-
return _wrapped_op
|
63
|
-
|
64
|
-
|
65
|
-
def _wrap_binop(op):
|
66
|
-
def _wrapped_op(x: T_Number, y: T_Number) -> T_Number:
|
67
|
-
if not any(isinstance(a, Term) for a in (x, y)):
|
68
|
-
return op(x, y)
|
69
|
-
else:
|
70
|
-
raise NoDefaultRule
|
71
|
-
|
72
|
-
_wrapped_op.__name__ = op.__name__
|
73
|
-
return _wrapped_op
|
74
|
-
|
75
|
-
|
76
|
-
def _wrap_unop(op):
|
77
|
-
def _wrapped_op(x: T_Number) -> T_Number:
|
78
|
-
if not isinstance(x, Term):
|
79
|
-
return op(x)
|
80
|
-
else:
|
81
|
-
raise NoDefaultRule
|
82
|
-
|
83
|
-
_wrapped_op.__name__ = op.__name__
|
84
|
-
return _wrapped_op
|
85
|
-
|
86
|
-
|
87
|
-
add = defop(_wrap_binop(operator.add))
|
88
|
-
neg = defop(_wrap_unop(operator.neg))
|
89
|
-
pos = defop(_wrap_unop(operator.pos))
|
90
|
-
sub = defop(_wrap_binop(operator.sub))
|
91
|
-
mul = defop(_wrap_binop(operator.mul))
|
92
|
-
truediv = defop(_wrap_binop(operator.truediv))
|
93
|
-
pow = defop(_wrap_binop(operator.pow))
|
94
|
-
abs = defop(_wrap_unop(operator.abs))
|
95
|
-
|
96
|
-
|
97
|
-
@defdata.register(numbers.Complex)
|
98
|
-
@numbers.Complex.register
|
99
|
-
class _ComplexTerm(_NumberTerm, Term[numbers.Complex]):
|
100
|
-
def __bool__(self) -> bool:
|
101
|
-
raise ValueError("Cannot convert term to bool")
|
102
|
-
|
103
|
-
def __add__(self, other: Any) -> numbers.Real:
|
104
|
-
return add(self, other)
|
105
|
-
|
106
|
-
def __radd__(self, other: Any) -> numbers.Real:
|
107
|
-
return add(other, self)
|
108
|
-
|
109
|
-
def __neg__(self):
|
110
|
-
return neg(self)
|
111
|
-
|
112
|
-
def __pos__(self):
|
113
|
-
return pos(self)
|
114
|
-
|
115
|
-
def __sub__(self, other: Any) -> numbers.Real:
|
116
|
-
return sub(self, other)
|
117
|
-
|
118
|
-
def __rsub__(self, other: Any) -> numbers.Real:
|
119
|
-
return sub(other, self)
|
120
|
-
|
121
|
-
def __mul__(self, other: Any) -> numbers.Real:
|
122
|
-
return mul(self, other)
|
123
|
-
|
124
|
-
def __rmul__(self, other: Any) -> numbers.Real:
|
125
|
-
return mul(other, self)
|
126
|
-
|
127
|
-
def __truediv__(self, other: Any) -> numbers.Real:
|
128
|
-
return truediv(self, other)
|
129
|
-
|
130
|
-
def __rtruediv__(self, other: Any) -> numbers.Real:
|
131
|
-
return truediv(other, self)
|
132
|
-
|
133
|
-
def __pow__(self, other: Any) -> numbers.Real:
|
134
|
-
return pow(self, other)
|
135
|
-
|
136
|
-
def __rpow__(self, other: Any) -> numbers.Real:
|
137
|
-
return pow(other, self)
|
138
|
-
|
139
|
-
def __abs__(self) -> numbers.Real:
|
140
|
-
return abs(self)
|
141
|
-
|
142
|
-
def __eq__(self, other: Any) -> bool:
|
143
|
-
return eq(self, other)
|
144
|
-
|
145
|
-
|
146
|
-
# Real specific methods
|
147
|
-
floordiv = defop(_wrap_binop(operator.floordiv))
|
148
|
-
mod = defop(_wrap_binop(operator.mod))
|
149
|
-
lt = defop(_wrap_cmp(operator.lt))
|
150
|
-
le = defop(_wrap_cmp(operator.le))
|
151
|
-
gt = defop(_wrap_cmp(operator.gt))
|
152
|
-
ge = defop(_wrap_cmp(operator.ge))
|
153
|
-
|
154
|
-
|
155
|
-
@defdata.register(numbers.Real)
|
156
|
-
@numbers.Real.register
|
157
|
-
class _RealTerm(_ComplexTerm, Term[numbers.Real]):
|
158
|
-
# Real specific methods
|
159
|
-
def __float__(self) -> float:
|
160
|
-
raise ValueError("Cannot convert term to float")
|
161
|
-
|
162
|
-
def __trunc__(self) -> numbers.Integral:
|
163
|
-
raise NotImplementedError
|
164
|
-
|
165
|
-
def __floor__(self) -> numbers.Integral:
|
166
|
-
raise NotImplementedError
|
167
|
-
|
168
|
-
def __ceil__(self) -> numbers.Integral:
|
169
|
-
raise NotImplementedError
|
170
|
-
|
171
|
-
def __round__(self, ndigits=None) -> numbers.Integral:
|
172
|
-
raise NotImplementedError
|
173
|
-
|
174
|
-
def __floordiv__(self, other):
|
175
|
-
return floordiv(self, other)
|
176
|
-
|
177
|
-
def __rfloordiv__(self, other):
|
178
|
-
return floordiv(other, self)
|
179
|
-
|
180
|
-
def __mod__(self, other):
|
181
|
-
return mod(self, other)
|
182
|
-
|
183
|
-
def __rmod__(self, other):
|
184
|
-
return mod(other, self)
|
185
|
-
|
186
|
-
def __lt__(self, other):
|
187
|
-
return lt(self, other)
|
188
|
-
|
189
|
-
def __le__(self, other):
|
190
|
-
return le(self, other)
|
191
|
-
|
192
|
-
|
193
|
-
@defdata.register(numbers.Rational)
|
194
|
-
@numbers.Rational.register
|
195
|
-
class _RationalTerm(_RealTerm, Term[numbers.Rational]):
|
196
|
-
@property
|
197
|
-
def numerator(self):
|
198
|
-
raise NotImplementedError
|
199
|
-
|
200
|
-
@property
|
201
|
-
def denominator(self):
|
202
|
-
raise NotImplementedError
|
203
|
-
|
204
|
-
|
205
|
-
# Integral specific methods
|
206
|
-
index = defop(_wrap_unop(operator.index))
|
207
|
-
lshift = defop(_wrap_binop(operator.lshift))
|
208
|
-
rshift = defop(_wrap_binop(operator.rshift))
|
209
|
-
and_ = defop(_wrap_binop(operator.and_))
|
210
|
-
xor = defop(_wrap_binop(operator.xor))
|
211
|
-
or_ = defop(_wrap_binop(operator.or_))
|
212
|
-
invert = defop(_wrap_unop(operator.invert))
|
213
|
-
|
214
|
-
|
215
|
-
@defdata.register(numbers.Integral)
|
216
|
-
@numbers.Integral.register
|
217
|
-
class _IntegralTerm(_RationalTerm, Term[numbers.Integral]):
|
218
|
-
# Integral specific methods
|
219
|
-
def __int__(self) -> int:
|
220
|
-
raise ValueError("Cannot convert term to int")
|
221
|
-
|
222
|
-
def __index__(self) -> numbers.Integral:
|
223
|
-
return index(self)
|
224
|
-
|
225
|
-
def __pow__(self, exponent: Any, modulus=None) -> numbers.Integral:
|
226
|
-
return pow(self, exponent)
|
227
|
-
|
228
|
-
def __lshift__(self, other):
|
229
|
-
return lshift(self, other)
|
230
|
-
|
231
|
-
def __rlshift__(self, other):
|
232
|
-
return lshift(other, self)
|
233
|
-
|
234
|
-
def __rshift__(self, other):
|
235
|
-
return rshift(self, other)
|
236
|
-
|
237
|
-
def __rrshift__(self, other):
|
238
|
-
return rshift(other, self)
|
239
|
-
|
240
|
-
def __and__(self, other):
|
241
|
-
return and_(self, other)
|
242
|
-
|
243
|
-
def __rand__(self, other):
|
244
|
-
return and_(other, self)
|
245
|
-
|
246
|
-
def __xor__(self, other):
|
247
|
-
return xor(self, other)
|
248
|
-
|
249
|
-
def __rxor__(self, other):
|
250
|
-
return xor(other, self)
|
251
|
-
|
252
|
-
def __or__(self, other):
|
253
|
-
return or_(self, other)
|
254
|
-
|
255
|
-
def __ror__(self, other):
|
256
|
-
return or_(other, self)
|
257
|
-
|
258
|
-
def __invert__(self):
|
259
|
-
return invert(self)
|
effectful/internals/base_impl.py
DELETED
@@ -1,259 +0,0 @@
|
|
1
|
-
import collections
|
2
|
-
import functools
|
3
|
-
import inspect
|
4
|
-
import typing
|
5
|
-
from typing import Callable, Generic, Mapping, Sequence, Set, Type, TypeVar
|
6
|
-
|
7
|
-
import tree
|
8
|
-
from typing_extensions import ParamSpec
|
9
|
-
|
10
|
-
from effectful.ops.types import Expr, Operation, Term
|
11
|
-
|
12
|
-
P = ParamSpec("P")
|
13
|
-
Q = ParamSpec("Q")
|
14
|
-
S = TypeVar("S")
|
15
|
-
T = TypeVar("T")
|
16
|
-
V = TypeVar("V")
|
17
|
-
|
18
|
-
|
19
|
-
def rename(
|
20
|
-
subs: Mapping[Operation[..., S], Operation[..., S]],
|
21
|
-
leaf_value: V, # Union[Term[V], Operation[..., V], V],
|
22
|
-
) -> V: # Union[Term[V], Operation[..., V], V]:
|
23
|
-
from effectful.internals.runtime import interpreter
|
24
|
-
from effectful.ops.semantics import apply, evaluate
|
25
|
-
|
26
|
-
if isinstance(leaf_value, Operation):
|
27
|
-
return subs.get(leaf_value, leaf_value) # type: ignore
|
28
|
-
elif isinstance(leaf_value, Term):
|
29
|
-
with interpreter(
|
30
|
-
{apply: lambda _, op, *a, **k: op.__free_rule__(*a, **k), **subs}
|
31
|
-
):
|
32
|
-
return evaluate(leaf_value) # type: ignore
|
33
|
-
else:
|
34
|
-
return leaf_value
|
35
|
-
|
36
|
-
|
37
|
-
class _BaseOperation(Generic[Q, V], Operation[Q, V]):
|
38
|
-
signature: Callable[Q, V]
|
39
|
-
|
40
|
-
def __init__(self, signature: Callable[Q, V]):
|
41
|
-
functools.update_wrapper(self, signature)
|
42
|
-
self.signature = signature
|
43
|
-
|
44
|
-
def __eq__(self, other):
|
45
|
-
if not isinstance(other, Operation):
|
46
|
-
return NotImplemented
|
47
|
-
return self.signature == other.signature
|
48
|
-
|
49
|
-
def __hash__(self):
|
50
|
-
return hash(self.signature)
|
51
|
-
|
52
|
-
def __default_rule__(self, *args: Q.args, **kwargs: Q.kwargs) -> "Expr[V]":
|
53
|
-
from effectful.ops.syntax import NoDefaultRule
|
54
|
-
|
55
|
-
try:
|
56
|
-
return self.signature(*args, **kwargs)
|
57
|
-
except NoDefaultRule:
|
58
|
-
return self.__free_rule__(*args, **kwargs)
|
59
|
-
|
60
|
-
def __free_rule__(self, *args: Q.args, **kwargs: Q.kwargs) -> "Expr[V]":
|
61
|
-
from effectful.ops.syntax import Bound, Scoped, defdata, defop
|
62
|
-
|
63
|
-
sig = inspect.signature(self.signature)
|
64
|
-
bound_sig = sig.bind(*args, **kwargs)
|
65
|
-
bound_sig.apply_defaults()
|
66
|
-
|
67
|
-
bound_vars: dict[int, set[Operation]] = collections.defaultdict(set)
|
68
|
-
scoped_args: dict[int, set[str]] = collections.defaultdict(set)
|
69
|
-
unscoped_args: set[str] = set()
|
70
|
-
for param_name, param in bound_sig.signature.parameters.items():
|
71
|
-
if typing.get_origin(param.annotation) is typing.Annotated:
|
72
|
-
for anno in param.annotation.__metadata__:
|
73
|
-
if isinstance(anno, Bound):
|
74
|
-
scoped_args[anno.scope].add(param_name)
|
75
|
-
if param.kind is inspect.Parameter.VAR_POSITIONAL:
|
76
|
-
assert isinstance(bound_sig.arguments[param_name], tuple)
|
77
|
-
for bound_var in bound_sig.arguments[param_name]:
|
78
|
-
bound_vars[anno.scope].add(bound_var)
|
79
|
-
elif param.kind is inspect.Parameter.VAR_KEYWORD:
|
80
|
-
assert isinstance(bound_sig.arguments[param_name], dict)
|
81
|
-
for bound_var in bound_sig.arguments[param_name].values():
|
82
|
-
bound_vars[anno.scope].add(bound_var)
|
83
|
-
else:
|
84
|
-
bound_vars[anno.scope].add(bound_sig.arguments[param_name])
|
85
|
-
elif isinstance(anno, Scoped):
|
86
|
-
scoped_args[anno.scope].add(param_name)
|
87
|
-
else:
|
88
|
-
unscoped_args.add(param_name)
|
89
|
-
|
90
|
-
# TODO replace this temporary check with more general scope level propagation
|
91
|
-
if bound_vars:
|
92
|
-
min_scope = min(bound_vars.keys(), default=0)
|
93
|
-
scoped_args[min_scope] |= unscoped_args
|
94
|
-
max_scope = max(bound_vars.keys(), default=0)
|
95
|
-
assert all(s in bound_vars or s > max_scope for s in scoped_args.keys())
|
96
|
-
|
97
|
-
# recursively rename bound variables from innermost to outermost scope
|
98
|
-
for scope in sorted(bound_vars.keys()):
|
99
|
-
# create fresh variables for each bound variable in the scope
|
100
|
-
renaming_map = {var: defop(var) for var in bound_vars[scope]}
|
101
|
-
# get just the arguments that are in the scope
|
102
|
-
for name in scoped_args[scope]:
|
103
|
-
bound_sig.arguments[name] = tree.map_structure(
|
104
|
-
lambda a: rename(renaming_map, a),
|
105
|
-
bound_sig.arguments[name],
|
106
|
-
)
|
107
|
-
|
108
|
-
return defdata(_BaseTerm(self, bound_sig.args, bound_sig.kwargs))
|
109
|
-
|
110
|
-
def __type_rule__(self, *args: Q.args, **kwargs: Q.kwargs) -> Type[V]:
|
111
|
-
sig = inspect.signature(self.signature)
|
112
|
-
bound_sig = sig.bind(*args, **kwargs)
|
113
|
-
bound_sig.apply_defaults()
|
114
|
-
|
115
|
-
anno = sig.return_annotation
|
116
|
-
if anno is inspect.Signature.empty:
|
117
|
-
return typing.cast(Type[V], object)
|
118
|
-
elif isinstance(anno, typing.TypeVar):
|
119
|
-
# rudimentary but sound special-case type inference sufficient for syntax ops:
|
120
|
-
# if the return type annotation is a TypeVar,
|
121
|
-
# look for a parameter with the same annotation and return its type,
|
122
|
-
# otherwise give up and return Any/object
|
123
|
-
for name, param in bound_sig.signature.parameters.items():
|
124
|
-
if param.annotation is anno and param.kind not in (
|
125
|
-
inspect.Parameter.VAR_POSITIONAL,
|
126
|
-
inspect.Parameter.VAR_KEYWORD,
|
127
|
-
):
|
128
|
-
arg = bound_sig.arguments[name]
|
129
|
-
tp: Type[V] = type(arg) if not isinstance(arg, type) else arg
|
130
|
-
return tp
|
131
|
-
return typing.cast(Type[V], object)
|
132
|
-
elif typing.get_origin(anno) is typing.Annotated:
|
133
|
-
tp = typing.get_args(anno)[0]
|
134
|
-
if not typing.TYPE_CHECKING:
|
135
|
-
tp = tp if typing.get_origin(tp) is None else typing.get_origin(tp)
|
136
|
-
return tp
|
137
|
-
elif typing.get_origin(anno) is not None:
|
138
|
-
return typing.get_origin(anno)
|
139
|
-
else:
|
140
|
-
return anno
|
141
|
-
|
142
|
-
def __fvs_rule__(self, *args: Q.args, **kwargs: Q.kwargs) -> Set[Operation]:
|
143
|
-
from effectful.ops.syntax import Bound
|
144
|
-
|
145
|
-
sig = inspect.signature(self.signature)
|
146
|
-
bound_sig = sig.bind(*args, **kwargs)
|
147
|
-
bound_sig.apply_defaults()
|
148
|
-
|
149
|
-
bound_vars: Set[Operation] = set()
|
150
|
-
for param_name, param in bound_sig.signature.parameters.items():
|
151
|
-
if typing.get_origin(param.annotation) is typing.Annotated:
|
152
|
-
for anno in param.annotation.__metadata__:
|
153
|
-
if isinstance(anno, Bound):
|
154
|
-
if param.kind is inspect.Parameter.VAR_POSITIONAL:
|
155
|
-
for bound_var in bound_sig.arguments[param_name]:
|
156
|
-
bound_vars.add(bound_var)
|
157
|
-
elif param.kind is inspect.Parameter.VAR_KEYWORD:
|
158
|
-
for bound_var in bound_sig.arguments[param_name].values():
|
159
|
-
bound_vars.add(bound_var)
|
160
|
-
else:
|
161
|
-
bound_var = bound_sig.arguments[param_name]
|
162
|
-
bound_vars.add(bound_var)
|
163
|
-
|
164
|
-
return bound_vars
|
165
|
-
|
166
|
-
def __repr_rule__(self, *args: Q.args, **kwargs: Q.kwargs) -> str:
|
167
|
-
args_str = ", ".join(map(str, args)) if args else ""
|
168
|
-
kwargs_str = (
|
169
|
-
", ".join(f"{k}={str(v)}" for k, v in kwargs.items()) if kwargs else ""
|
170
|
-
)
|
171
|
-
|
172
|
-
ret = f"{self.signature.__name__}({args_str}"
|
173
|
-
if kwargs:
|
174
|
-
ret += f"{', ' if args else ''}"
|
175
|
-
ret += f"{kwargs_str})"
|
176
|
-
return ret
|
177
|
-
|
178
|
-
def __repr__(self):
|
179
|
-
return self.signature.__name__
|
180
|
-
|
181
|
-
|
182
|
-
class _BaseTerm(Generic[T], Term[T]):
|
183
|
-
_op: Operation[..., T]
|
184
|
-
_args: Sequence[Expr]
|
185
|
-
_kwargs: Mapping[str, Expr]
|
186
|
-
|
187
|
-
def __init__(
|
188
|
-
self,
|
189
|
-
op: Operation[..., T],
|
190
|
-
args: Sequence[Expr],
|
191
|
-
kwargs: Mapping[str, Expr],
|
192
|
-
):
|
193
|
-
self._op = op
|
194
|
-
self._args = args
|
195
|
-
self._kwargs = kwargs
|
196
|
-
|
197
|
-
def __eq__(self, other) -> bool:
|
198
|
-
from effectful.ops.syntax import syntactic_eq
|
199
|
-
|
200
|
-
return syntactic_eq(self, other)
|
201
|
-
|
202
|
-
@property
|
203
|
-
def op(self):
|
204
|
-
return self._op
|
205
|
-
|
206
|
-
@property
|
207
|
-
def args(self):
|
208
|
-
return self._args
|
209
|
-
|
210
|
-
@property
|
211
|
-
def kwargs(self):
|
212
|
-
return self._kwargs
|
213
|
-
|
214
|
-
|
215
|
-
class _CallableTerm(Generic[P, T], _BaseTerm[collections.abc.Callable[P, T]]):
|
216
|
-
def __call__(self, *args: Expr, **kwargs: Expr) -> Expr[T]:
|
217
|
-
from effectful.ops.semantics import call
|
218
|
-
|
219
|
-
return call(self, *args, **kwargs) # type: ignore
|
220
|
-
|
221
|
-
|
222
|
-
def _unembed_callable(value: Callable[P, T]) -> Expr[Callable[P, T]]:
|
223
|
-
from effectful.internals.runtime import interpreter
|
224
|
-
from effectful.ops.semantics import apply, call
|
225
|
-
from effectful.ops.syntax import deffn, defop
|
226
|
-
|
227
|
-
assert not isinstance(value, Term)
|
228
|
-
|
229
|
-
try:
|
230
|
-
sig = inspect.signature(value)
|
231
|
-
except ValueError:
|
232
|
-
return value
|
233
|
-
|
234
|
-
for name, param in sig.parameters.items():
|
235
|
-
if param.kind in (
|
236
|
-
inspect.Parameter.VAR_POSITIONAL,
|
237
|
-
inspect.Parameter.VAR_KEYWORD,
|
238
|
-
):
|
239
|
-
raise NotImplementedError(
|
240
|
-
f"cannot unembed {value}: parameter {name} is variadic"
|
241
|
-
)
|
242
|
-
|
243
|
-
bound_sig = sig.bind(
|
244
|
-
**{name: defop(param.annotation) for name, param in sig.parameters.items()}
|
245
|
-
)
|
246
|
-
bound_sig.apply_defaults()
|
247
|
-
|
248
|
-
with interpreter(
|
249
|
-
{
|
250
|
-
apply: lambda _, op, *a, **k: op.__free_rule__(*a, **k),
|
251
|
-
call: call.__default_rule__,
|
252
|
-
}
|
253
|
-
):
|
254
|
-
body = value(
|
255
|
-
*[a() for a in bound_sig.args],
|
256
|
-
**{k: v() for k, v in bound_sig.kwargs.items()},
|
257
|
-
)
|
258
|
-
|
259
|
-
return deffn(body, *bound_sig.args, **bound_sig.kwargs)
|
effectful-0.0.1.dist-info/RECORD
DELETED
@@ -1,19 +0,0 @@
|
|
1
|
-
effectful/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
2
|
-
effectful/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
3
|
-
effectful/handlers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
4
|
-
effectful/handlers/indexed.py,sha256=lNqlHZqD8JBs-KaC3YdCoQDAXHuQ9KhAH0ZhpIqMdFM,13184
|
5
|
-
effectful/handlers/numbers.py,sha256=aKVCJ-l3ISvOGHTfXsQIEp4uEsQsYy7bP-YUCsSO_3o,6557
|
6
|
-
effectful/handlers/pyro.py,sha256=s_CAXe65gqXmKiYLUXJ0_uLQzEEgyQbZv0hQWMMdQJ4,15312
|
7
|
-
effectful/handlers/torch.py,sha256=LJYwud6Dq3xg7A6_6aaDEyP1hT0IB4G_55Gki6izRvo,18401
|
8
|
-
effectful/internals/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
9
|
-
effectful/internals/base_impl.py,sha256=6Yxxg5tS_R2RLlHvtOhj_fuA_9hEa_YKf0LhQOX2448,9671
|
10
|
-
effectful/internals/runtime.py,sha256=E0ce5mfG0-DlTeALYZePWtdxr0AK3y0CPl_LE5pp_0k,1908
|
11
|
-
effectful/ops/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
12
|
-
effectful/ops/semantics.py,sha256=Ihaj2ycildtMg18GSujSJSaAAKsArKoyMmR-gBoWx8E,10208
|
13
|
-
effectful/ops/syntax.py,sha256=nRfguHAol5lDnUv0UMOB0AL6y6By2DCW7hlCYsN3OnM,15331
|
14
|
-
effectful/ops/types.py,sha256=DmVqX8naWNQEW1A6w7ShrXp_LZV6qoK4p_OAM3NitTI,3241
|
15
|
-
effectful-0.0.1.dist-info/LICENSE.md,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
|
16
|
-
effectful-0.0.1.dist-info/METADATA,sha256=WN7j_mJwZao1T1MTWyAIo_1hP6YUCS4L6rPI9wWCScA,4930
|
17
|
-
effectful-0.0.1.dist-info/WHEEL,sha256=In9FTNxeP60KnTkGw7wk6mJPYd_dQSjEZmXdBdMCI-8,91
|
18
|
-
effectful-0.0.1.dist-info/top_level.txt,sha256=gtuJfrE2nXil_lZLCnqWF2KAbOnJs9ILNvK8WnkRzbs,10
|
19
|
-
effectful-0.0.1.dist-info/RECORD,,
|
File without changes
|
File without changes
|