effectful 0.0.1-py3-none-any.whl → 0.2.0-py3-none-any.whl
- effectful/handlers/indexed.py +27 -46
- effectful/handlers/jax/__init__.py +14 -0
- effectful/handlers/jax/_handlers.py +293 -0
- effectful/handlers/jax/_terms.py +502 -0
- effectful/handlers/jax/numpy/__init__.py +23 -0
- effectful/handlers/jax/numpy/linalg.py +13 -0
- effectful/handlers/jax/scipy/special.py +11 -0
- effectful/handlers/numpyro.py +562 -0
- effectful/handlers/pyro.py +565 -214
- effectful/handlers/torch.py +321 -169
- effectful/internals/runtime.py +6 -13
- effectful/internals/tensor_utils.py +32 -0
- effectful/internals/unification.py +900 -0
- effectful/ops/semantics.py +104 -84
- effectful/ops/syntax.py +1276 -167
- effectful/ops/types.py +141 -35
- {effectful-0.0.1.dist-info → effectful-0.2.0.dist-info}/METADATA +65 -57
- effectful-0.2.0.dist-info/RECORD +26 -0
- {effectful-0.0.1.dist-info → effectful-0.2.0.dist-info}/WHEEL +1 -1
- effectful/handlers/numbers.py +0 -259
- effectful/internals/base_impl.py +0 -259
- effectful-0.0.1.dist-info/RECORD +0 -19
- {effectful-0.0.1.dist-info → effectful-0.2.0.dist-info/licenses}/LICENSE.md +0 -0
- {effectful-0.0.1.dist-info → effectful-0.2.0.dist-info}/top_level.txt +0 -0
effectful/internals/runtime.py
CHANGED
@@ -1,20 +1,14 @@
 import contextlib
 import dataclasses
 import functools
-from
-
-from typing_extensions import ParamSpec
+from collections.abc import Callable, Mapping
 
 from effectful.ops.syntax import defop
 from effectful.ops.types import Interpretation, Operation
 
-P = ParamSpec("P")
-S = TypeVar("S")
-T = TypeVar("T")
-
 
 @dataclasses.dataclass
-class Runtime
+class Runtime[S, T]:
     interpretation: "Interpretation[S, T]"
 
 
@@ -29,7 +23,6 @@ def get_interpretation():
 
 @contextlib.contextmanager
 def interpreter(intp: "Interpretation"):
-
     r = get_runtime()
     old_intp = r.interpretation
     try:
@@ -40,11 +33,11 @@ def interpreter(intp: "Interpretation"):
 
 
 @defop
-def _get_args() ->
+def _get_args() -> tuple[tuple, Mapping]:
     return ((), {})
 
 
-def _restore_args(fn: Callable[P, T]) -> Callable[P, T]:
+def _restore_args[**P, T](fn: Callable[P, T]) -> Callable[P, T]:
    @functools.wraps(fn)
    def _cont_wrapper(*a: P.args, **k: P.kwargs) -> T:
        a, k = (a, k) if a or k else _get_args()  # type: ignore
@@ -53,7 +46,7 @@ def _restore_args(fn: Callable[P, T]) -> Callable[P, T]:
     return _cont_wrapper
 
 
-def _save_args(fn: Callable[P, T]) -> Callable[P, T]:
+def _save_args[**P, T](fn: Callable[P, T]) -> Callable[P, T]:
     from effectful.ops.semantics import handler
 
     @functools.wraps(fn)
@@ -64,7 +57,7 @@ def _save_args(fn: Callable[P, T]) -> Callable[P, T]:
     return _cont_wrapper
 
 
-def _set_prompt(
+def _set_prompt[**P, T](
     prompt: Operation[P, T], cont: Callable[P, T], body: Callable[P, T]
 ) -> Callable[P, T]:
     from effectful.ops.semantics import handler
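The runtime.py change above replaces the module-level ParamSpec/TypeVar declarations with PEP 695 inline type parameters (class Runtime[S, T]:, def _restore_args[**P, T](...)), a syntax available from Python 3.12 onward. A minimal sketch of that before/after pattern, using illustrative names rather than code from the package:

# Old style (as in 0.0.1): type variables declared once at module level.
from collections.abc import Callable
from typing import ParamSpec, TypeVar

P = ParamSpec("P")
T = TypeVar("T")


def wrap_old(fn: Callable[P, T]) -> Callable[P, T]:
    # illustrative passthrough decorator, not code from the package
    return fn


# New style (as in 0.2.0, PEP 695): type parameters declared inline on the
# definition itself, so the module-level P/S/T declarations disappear.
def wrap_new[**P, T](fn: Callable[P, T]) -> Callable[P, T]:
    return fn


class Pair[S, T]:
    first: S
    second: T

Besides removing three module-level names, the inline form keeps each type parameter scoped to the single definition that uses it.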
effectful/internals/tensor_utils.py
ADDED
@@ -0,0 +1,32 @@
+def _desugar_tensor_index(shape, key):
+    new_shape = []
+    new_key = []
+
+    def extra_dims(key):
+        return sum(1 for k in key if k is None)
+
+    # handle any missing dimensions by adding a trailing Ellipsis
+    if not any(k is Ellipsis for k in key):
+        key = tuple(key) + (...,)
+
+    for i, k in enumerate(key):
+        if k is None:  # add a new singleton dimension
+            new_shape.append(1)
+            new_key.append(slice(None))
+        elif k is Ellipsis:
+            assert not any(k is Ellipsis for k in key[i + 1 :]), (
+                "only one Ellipsis allowed"
+            )
+
+            # determine which of the original dimensions this ellipsis refers to
+            pre_dims = i - extra_dims(key[:i])  # dimensions that precede the ellipsis
+            elided_dims = (
+                len(shape) - pre_dims - (len(key) - i - 1 - extra_dims(key[i + 1 :]))
+            )  #
+            new_shape += shape[pre_dims : pre_dims + elided_dims]
+            new_key += [slice(None)] * elided_dims
+        else:
+            new_shape.append(shape[len(new_shape) - extra_dims(key[:i])])
+            new_key.append(k)
+
+    return new_shape, new_key
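For reference, here is how the new helper behaves on a small input; the values below are worked out by hand from the code above rather than taken from the package's tests:

shape = (2, 3, 4)
key = (None, 0, ...)  # add a leading singleton dim, index dim 0, keep the rest

new_shape, new_key = _desugar_tensor_index(shape, key)

assert new_shape == [1, 2, 3, 4]
assert new_key == [slice(None), 0, slice(None), slice(None)]

new_shape is the shape the tensor should be viewed as (a singleton dimension inserted for each None in the key), and new_key is an equivalent index in which every None and Ellipsis has been replaced by explicit slice(None) entries.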