effectful 0.0.1__tar.gz → 0.2.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {effectful-0.0.1 → effectful-0.2.0}/PKG-INFO +65 -57
- {effectful-0.0.1 → effectful-0.2.0}/README.rst +26 -14
- {effectful-0.0.1 → effectful-0.2.0}/effectful/handlers/indexed.py +27 -46
- effectful-0.2.0/effectful/handlers/jax/__init__.py +14 -0
- effectful-0.2.0/effectful/handlers/jax/_handlers.py +293 -0
- effectful-0.2.0/effectful/handlers/jax/_terms.py +502 -0
- effectful-0.2.0/effectful/handlers/jax/numpy/__init__.py +23 -0
- effectful-0.2.0/effectful/handlers/jax/numpy/linalg.py +13 -0
- effectful-0.2.0/effectful/handlers/jax/scipy/special.py +11 -0
- effectful-0.2.0/effectful/handlers/numpyro.py +562 -0
- effectful-0.2.0/effectful/handlers/pyro.py +817 -0
- effectful-0.2.0/effectful/handlers/torch.py +724 -0
- {effectful-0.0.1 → effectful-0.2.0}/effectful/internals/runtime.py +6 -13
- effectful-0.2.0/effectful/internals/tensor_utils.py +32 -0
- effectful-0.2.0/effectful/internals/unification.py +900 -0
- {effectful-0.0.1 → effectful-0.2.0}/effectful/ops/semantics.py +104 -84
- effectful-0.2.0/effectful/ops/syntax.py +1632 -0
- effectful-0.2.0/effectful/ops/types.py +216 -0
- {effectful-0.0.1 → effectful-0.2.0}/effectful.egg-info/PKG-INFO +65 -57
- {effectful-0.0.1 → effectful-0.2.0}/effectful.egg-info/SOURCES.txt +14 -5
- effectful-0.2.0/effectful.egg-info/requires.txt +37 -0
- effectful-0.2.0/pyproject.toml +81 -0
- effectful-0.2.0/setup.cfg +4 -0
- {effectful-0.0.1 → effectful-0.2.0}/tests/test_examples_minipyro.py +0 -1
- {effectful-0.0.1 → effectful-0.2.0}/tests/test_handlers_indexed.py +8 -6
- effectful-0.2.0/tests/test_handlers_jax.py +941 -0
- effectful-0.2.0/tests/test_handlers_numpyro.py +855 -0
- {effectful-0.0.1 → effectful-0.2.0}/tests/test_handlers_pyro.py +71 -46
- {effectful-0.0.1 → effectful-0.2.0}/tests/test_handlers_pyro_dist.py +359 -219
- {effectful-0.0.1 → effectful-0.2.0}/tests/test_handlers_torch.py +270 -149
- effectful-0.2.0/tests/test_internals_unification.py +1646 -0
- {effectful-0.0.1 → effectful-0.2.0}/tests/test_ops_semantics.py +242 -21
- effectful-0.2.0/tests/test_ops_syntax.py +873 -0
- effectful-0.2.0/tests/test_ops_types.py +12 -0
- {effectful-0.0.1 → effectful-0.2.0}/tests/test_semi_ring.py +23 -23
- effectful-0.0.1/effectful/handlers/numbers.py +0 -259
- effectful-0.0.1/effectful/handlers/pyro.py +0 -466
- effectful-0.0.1/effectful/handlers/torch.py +0 -572
- effectful-0.0.1/effectful/internals/base_impl.py +0 -259
- effectful-0.0.1/effectful/ops/syntax.py +0 -523
- effectful-0.0.1/effectful/ops/types.py +0 -110
- effectful-0.0.1/effectful.egg-info/requires.txt +0 -28
- effectful-0.0.1/setup.cfg +0 -25
- effectful-0.0.1/setup.py +0 -72
- effectful-0.0.1/tests/test_handlers_numbers.py +0 -263
- effectful-0.0.1/tests/test_ops_syntax.py +0 -92
- {effectful-0.0.1 → effectful-0.2.0}/LICENSE.md +0 -0
- {effectful-0.0.1 → effectful-0.2.0}/effectful/__init__.py +0 -0
- {effectful-0.0.1 → effectful-0.2.0}/effectful/handlers/__init__.py +0 -0
- {effectful-0.0.1 → effectful-0.2.0}/effectful/internals/__init__.py +0 -0
- {effectful-0.0.1 → effectful-0.2.0}/effectful/ops/__init__.py +0 -0
- {effectful-0.0.1 → effectful-0.2.0}/effectful/py.typed +0 -0
- {effectful-0.0.1 → effectful-0.2.0}/effectful.egg-info/dependency_links.txt +0 -0
- {effectful-0.0.1 → effectful-0.2.0}/effectful.egg-info/top_level.txt +0 -0
@@ -1,63 +1,59 @@
|
|
1
|
-
Metadata-Version: 2.
|
1
|
+
Metadata-Version: 2.4
|
2
2
|
Name: effectful
|
3
|
-
Version: 0.0
|
3
|
+
Version: 0.2.0
|
4
4
|
Summary: Metaprogramming infrastructure
|
5
|
-
Home-page: https://www.basis.ai/
|
6
5
|
Author: Basis
|
7
|
-
License: Apache
|
6
|
+
License-Expression: Apache-2.0
|
7
|
+
Project-URL: Homepage, https://www.basis.ai/
|
8
8
|
Project-URL: Source, https://github.com/BasisResearch/effectful
|
9
|
-
|
9
|
+
Project-URL: Bug Tracker, https://github.com/BasisResearch/effectful/issues
|
10
|
+
Keywords: machine learning,statistics,probabilistic programming,bayesian modeling,pytorch
|
10
11
|
Classifier: Intended Audience :: Developers
|
11
12
|
Classifier: Intended Audience :: Education
|
12
13
|
Classifier: Intended Audience :: Science/Research
|
13
|
-
Classifier: License :: OSI Approved :: Apache Software License
|
14
14
|
Classifier: Operating System :: POSIX :: Linux
|
15
15
|
Classifier: Operating System :: MacOS :: MacOS X
|
16
|
-
Classifier: Programming Language :: Python :: 3.
|
17
|
-
Classifier: Programming Language :: Python :: 3.
|
18
|
-
Requires-Python: >=3.
|
16
|
+
Classifier: Programming Language :: Python :: 3.12
|
17
|
+
Classifier: Programming Language :: Python :: 3.13
|
18
|
+
Requires-Python: >=3.12
|
19
|
+
Description-Content-Type: text/x-rst
|
19
20
|
License-File: LICENSE.md
|
20
|
-
Requires-Dist: typing_extensions
|
21
|
-
Requires-Dist: dm-tree
|
22
21
|
Provides-Extra: torch
|
23
22
|
Requires-Dist: torch; extra == "torch"
|
23
|
+
Requires-Dist: dm-tree; extra == "torch"
|
24
24
|
Provides-Extra: pyro
|
25
|
-
Requires-Dist:
|
26
|
-
Requires-Dist:
|
27
|
-
Provides-Extra:
|
28
|
-
Requires-Dist:
|
29
|
-
Requires-Dist:
|
30
|
-
|
31
|
-
Requires-Dist:
|
32
|
-
Requires-Dist:
|
33
|
-
|
34
|
-
Requires-Dist:
|
35
|
-
Requires-Dist:
|
36
|
-
Requires-Dist:
|
37
|
-
Requires-Dist:
|
38
|
-
Requires-Dist:
|
39
|
-
Requires-Dist:
|
40
|
-
Requires-Dist:
|
41
|
-
Requires-Dist:
|
42
|
-
|
43
|
-
Requires-Dist:
|
44
|
-
Requires-Dist:
|
45
|
-
Requires-Dist:
|
46
|
-
|
47
|
-
|
48
|
-
|
49
|
-
|
50
|
-
|
51
|
-
|
52
|
-
Dynamic:
|
53
|
-
Dynamic: provides-extra
|
54
|
-
Dynamic: requires-dist
|
55
|
-
Dynamic: requires-python
|
56
|
-
Dynamic: summary
|
25
|
+
Requires-Dist: pyro-ppl>=1.9.1; extra == "pyro"
|
26
|
+
Requires-Dist: dm-tree; extra == "pyro"
|
27
|
+
Provides-Extra: jax
|
28
|
+
Requires-Dist: jax; extra == "jax"
|
29
|
+
Requires-Dist: dm-tree; extra == "jax"
|
30
|
+
Provides-Extra: numpyro
|
31
|
+
Requires-Dist: numpyro>=0.19; extra == "numpyro"
|
32
|
+
Requires-Dist: dm-tree; extra == "numpyro"
|
33
|
+
Provides-Extra: docs
|
34
|
+
Requires-Dist: effectful[jax,numpyro,pyro,torch]; extra == "docs"
|
35
|
+
Requires-Dist: sphinx; extra == "docs"
|
36
|
+
Requires-Dist: sphinxcontrib-bibtex; extra == "docs"
|
37
|
+
Requires-Dist: sphinx_rtd_theme; extra == "docs"
|
38
|
+
Requires-Dist: myst-parser; extra == "docs"
|
39
|
+
Requires-Dist: nbsphinx; extra == "docs"
|
40
|
+
Requires-Dist: sphinx_autodoc_typehints; extra == "docs"
|
41
|
+
Requires-Dist: pypandoc_binary; extra == "docs"
|
42
|
+
Provides-Extra: test
|
43
|
+
Requires-Dist: effectful[docs,jax,numpyro,pyro,torch]; extra == "test"
|
44
|
+
Requires-Dist: pytest; extra == "test"
|
45
|
+
Requires-Dist: pytest-cov; extra == "test"
|
46
|
+
Requires-Dist: pytest-xdist; extra == "test"
|
47
|
+
Requires-Dist: pytest-benchmark; extra == "test"
|
48
|
+
Requires-Dist: mypy; extra == "test"
|
49
|
+
Requires-Dist: ruff; extra == "test"
|
50
|
+
Requires-Dist: nbval; extra == "test"
|
51
|
+
Requires-Dist: nbqa; extra == "test"
|
52
|
+
Dynamic: license-file
|
57
53
|
|
58
54
|
.. index-inclusion-marker
|
59
55
|
|
60
|
-
Effectful
|
56
|
+
Effectful
|
61
57
|
=========
|
62
58
|
|
63
59
|
Effectful is an algebraic effect system for Python, intended for use in the
|
@@ -75,14 +71,17 @@ Install From Source
|
|
75
71
|
git clone git@github.com:BasisResearch/effectful.git
|
76
72
|
cd effectful
|
77
73
|
git checkout master
|
78
|
-
pip install -e .[
|
74
|
+
pip install -e .[pyro]
|
79
75
|
|
80
76
|
Install With Optional PyTorch/Pyro Support
|
81
77
|
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
82
78
|
|
83
|
-
``effectful`` has optional support for
|
84
|
-
|
85
|
-
|
79
|
+
``effectful`` has optional support for:
|
80
|
+
|
81
|
+
- `PyTorch <https://pytorch.org/>`_ (tensors with named dimensions)
|
82
|
+
- `Pyro <https://pyro.ai/>`_ (wrappers for Pyro effects)
|
83
|
+
- `Jax <https://docs.jax.dev/en/latest/index.html>`_ (tensors with named dimensions)
|
84
|
+
- `Numpyro <https://num.pyro.ai>`_ (operations for Numpyro distributions)
|
86
85
|
|
87
86
|
To enable PyTorch support:
|
88
87
|
|
@@ -96,6 +95,18 @@ Pyro support (which includes PyTorch support):
|
|
96
95
|
|
97
96
|
pip install effectful[pyro]
|
98
97
|
|
98
|
+
Jax support:
|
99
|
+
|
100
|
+
.. code:: sh
|
101
|
+
|
102
|
+
pip install effectful[jax]
|
103
|
+
|
104
|
+
Numpyro support (which includes Jax support):
|
105
|
+
|
106
|
+
.. code:: sh
|
107
|
+
|
108
|
+
pip install effectful[numpyro]
|
109
|
+
|
99
110
|
Getting Started
|
100
111
|
---------------
|
101
112
|
|
@@ -103,18 +114,15 @@ Here's an example demonstrating how ``effectful`` can be used to implement a sim
|
|
103
114
|
|
104
115
|
.. code:: python
|
105
116
|
|
106
|
-
import operator
|
107
117
|
import functools
|
108
118
|
|
109
|
-
from effectful.handlers.operator import OPERATORS
|
110
119
|
from effectful.ops.types import Term
|
111
|
-
from effectful.ops.syntax import defop
|
120
|
+
from effectful.ops.syntax import defdata, defop
|
112
121
|
from effectful.ops.semantics import handler, evaluate, coproduct, fwd
|
113
|
-
import effectful.handlers.operator
|
114
122
|
|
115
|
-
add =
|
123
|
+
add = defdata.dispatch(int).__add__
|
116
124
|
|
117
|
-
def beta_add(x: int, y: int) -> int:
|
125
|
+
def beta_add(x: int, y: int) -> int:
|
118
126
|
match x, y:
|
119
127
|
case int(), int():
|
120
128
|
return x + y
|
@@ -124,14 +132,14 @@ Here's an example demonstrating how ``effectful`` can be used to implement a sim
|
|
124
132
|
def commute_add(x: int, y: int) -> int:
|
125
133
|
match x, y:
|
126
134
|
case Term(), int():
|
127
|
-
return y + x
|
135
|
+
return y + x
|
128
136
|
case _:
|
129
137
|
return fwd()
|
130
138
|
|
131
139
|
def assoc_add(x: int, y: int) -> int:
|
132
140
|
match x, y:
|
133
141
|
case _, Term(op, (a, b)) if op == add:
|
134
|
-
return (x + a) + b
|
142
|
+
return (x + a) + b
|
135
143
|
case _:
|
136
144
|
return fwd()
|
137
145
|
|
@@ -163,7 +171,7 @@ We can make the evaluation strategy smarter by taking advantage of the commutati
|
|
163
171
|
>>> with handler(eager_mixed):
|
164
172
|
>>> print(evaluate(e))
|
165
173
|
add(8, add(x(), y()))
|
166
|
-
|
174
|
+
|
167
175
|
Learn More
|
168
176
|
----------
|
169
177
|
|
@@ -1,6 +1,6 @@
|
|
1
1
|
.. index-inclusion-marker
|
2
2
|
|
3
|
-
Effectful
|
3
|
+
Effectful
|
4
4
|
=========
|
5
5
|
|
6
6
|
Effectful is an algebraic effect system for Python, intended for use in the
|
@@ -18,14 +18,17 @@ Install From Source
|
|
18
18
|
git clone git@github.com:BasisResearch/effectful.git
|
19
19
|
cd effectful
|
20
20
|
git checkout master
|
21
|
-
pip install -e .[
|
21
|
+
pip install -e .[pyro]
|
22
22
|
|
23
23
|
Install With Optional PyTorch/Pyro Support
|
24
24
|
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
25
25
|
|
26
|
-
``effectful`` has optional support for
|
27
|
-
|
28
|
-
|
26
|
+
``effectful`` has optional support for:
|
27
|
+
|
28
|
+
- `PyTorch <https://pytorch.org/>`_ (tensors with named dimensions)
|
29
|
+
- `Pyro <https://pyro.ai/>`_ (wrappers for Pyro effects)
|
30
|
+
- `Jax <https://docs.jax.dev/en/latest/index.html>`_ (tensors with named dimensions)
|
31
|
+
- `Numpyro <https://num.pyro.ai>`_ (operations for Numpyro distributions)
|
29
32
|
|
30
33
|
To enable PyTorch support:
|
31
34
|
|
@@ -39,6 +42,18 @@ Pyro support (which includes PyTorch support):
|
|
39
42
|
|
40
43
|
pip install effectful[pyro]
|
41
44
|
|
45
|
+
Jax support:
|
46
|
+
|
47
|
+
.. code:: sh
|
48
|
+
|
49
|
+
pip install effectful[jax]
|
50
|
+
|
51
|
+
Numpyro support (which includes Jax support):
|
52
|
+
|
53
|
+
.. code:: sh
|
54
|
+
|
55
|
+
pip install effectful[numpyro]
|
56
|
+
|
42
57
|
Getting Started
|
43
58
|
---------------
|
44
59
|
|
@@ -46,18 +61,15 @@ Here's an example demonstrating how ``effectful`` can be used to implement a sim
|
|
46
61
|
|
47
62
|
.. code:: python
|
48
63
|
|
49
|
-
import operator
|
50
64
|
import functools
|
51
65
|
|
52
|
-
from effectful.handlers.operator import OPERATORS
|
53
66
|
from effectful.ops.types import Term
|
54
|
-
from effectful.ops.syntax import defop
|
67
|
+
from effectful.ops.syntax import defdata, defop
|
55
68
|
from effectful.ops.semantics import handler, evaluate, coproduct, fwd
|
56
|
-
import effectful.handlers.operator
|
57
69
|
|
58
|
-
add =
|
70
|
+
add = defdata.dispatch(int).__add__
|
59
71
|
|
60
|
-
def beta_add(x: int, y: int) -> int:
|
72
|
+
def beta_add(x: int, y: int) -> int:
|
61
73
|
match x, y:
|
62
74
|
case int(), int():
|
63
75
|
return x + y
|
@@ -67,14 +79,14 @@ Here's an example demonstrating how ``effectful`` can be used to implement a sim
|
|
67
79
|
def commute_add(x: int, y: int) -> int:
|
68
80
|
match x, y:
|
69
81
|
case Term(), int():
|
70
|
-
return y + x
|
82
|
+
return y + x
|
71
83
|
case _:
|
72
84
|
return fwd()
|
73
85
|
|
74
86
|
def assoc_add(x: int, y: int) -> int:
|
75
87
|
match x, y:
|
76
88
|
case _, Term(op, (a, b)) if op == add:
|
77
|
-
return (x + a) + b
|
89
|
+
return (x + a) + b
|
78
90
|
case _:
|
79
91
|
return fwd()
|
80
92
|
|
@@ -106,7 +118,7 @@ We can make the evaluation strategy smarter by taking advantage of the commutati
|
|
106
118
|
>>> with handler(eager_mixed):
|
107
119
|
>>> print(evaluate(e))
|
108
120
|
add(8, add(x(), y()))
|
109
|
-
|
121
|
+
|
110
122
|
Learn More
|
111
123
|
----------
|
112
124
|
|
@@ -1,18 +1,16 @@
|
|
1
1
|
import functools
|
2
2
|
import operator
|
3
|
-
from
|
3
|
+
from collections.abc import Iterable
|
4
|
+
from typing import Any
|
4
5
|
|
5
6
|
import torch
|
6
7
|
|
7
|
-
from effectful.handlers.torch import
|
8
|
+
from effectful.handlers.torch import sizesof
|
8
9
|
from effectful.ops.syntax import deffn, defop
|
9
|
-
from effectful.ops.types import Operation
|
10
|
+
from effectful.ops.types import Operation
|
10
11
|
|
11
|
-
K = TypeVar("K")
|
12
|
-
T = TypeVar("T")
|
13
12
|
|
14
|
-
|
15
|
-
class IndexSet(Dict[str, Set[int]]):
|
13
|
+
class IndexSet(dict[str, set[int]]):
|
16
14
|
"""
|
17
15
|
:class:`IndexSet` s represent the support of an indexed value,
|
18
16
|
for which free variables correspond to single interventions and indices
|
@@ -32,13 +30,13 @@ class IndexSet(Dict[str, Set[int]]):
|
|
32
30
|
for which a value is defined::
|
33
31
|
|
34
32
|
>>> IndexSet(x={0, 1}, y={2, 3})
|
35
|
-
IndexSet({x: {0, 1}, y: {2, 3}})
|
33
|
+
IndexSet({'x': {0, 1}, 'y': {2, 3}})
|
36
34
|
|
37
35
|
:class:`IndexSet` 's constructor will automatically drop empty entries
|
38
36
|
and attempt to convert input values to :class:`set` s::
|
39
37
|
|
40
38
|
>>> IndexSet(x=[0, 0, 1], y=set(), z=2)
|
41
|
-
IndexSet({x: {0, 1}, z: {2}})
|
39
|
+
IndexSet({'x': {0, 1}, 'z': {2}})
|
42
40
|
|
43
41
|
:class:`IndexSet` s are also hashable and can be used as keys in :class:`dict` s::
|
44
42
|
|
@@ -47,7 +45,7 @@ class IndexSet(Dict[str, Set[int]]):
|
|
47
45
|
True
|
48
46
|
"""
|
49
47
|
|
50
|
-
def __init__(self, **mapping:
|
48
|
+
def __init__(self, **mapping: int | Iterable[int]):
|
51
49
|
index_set = {}
|
52
50
|
for k, vs in mapping.items():
|
53
51
|
indexes = {vs} if isinstance(vs, int) else set(vs)
|
@@ -61,16 +59,6 @@ class IndexSet(Dict[str, Set[int]]):
|
|
61
59
|
def __hash__(self):
|
62
60
|
return hash(frozenset((k, frozenset(vs)) for k, vs in self.items()))
|
63
61
|
|
64
|
-
def _to_handler(self):
|
65
|
-
"""Return an effectful handler that binds each index variable to a
|
66
|
-
tensor of its possible index values.
|
67
|
-
|
68
|
-
"""
|
69
|
-
return {
|
70
|
-
name_to_sym(k): functools.partial(lambda v: v, torch.tensor(list(v)))
|
71
|
-
for k, v in self.items()
|
72
|
-
}
|
73
|
-
|
74
62
|
|
75
63
|
def union(*indexsets: IndexSet) -> IndexSet:
|
76
64
|
"""
|
@@ -166,25 +154,17 @@ def indices_of(value: Any) -> IndexSet:
|
|
166
154
|
:param kwargs: Additional keyword arguments used by specific implementations.
|
167
155
|
:return: A :class:`IndexSet` containing the indices on which the value is supported.
|
168
156
|
"""
|
169
|
-
|
170
|
-
|
171
|
-
|
172
|
-
k.__name__: set(range(v)) # type:ignore
|
173
|
-
for (k, v) in sizesof(value).items()
|
174
|
-
}
|
175
|
-
)
|
176
|
-
elif isinstance(value, torch.distributions.Distribution):
|
177
|
-
return indices_of(value.sample())
|
178
|
-
|
179
|
-
return IndexSet()
|
157
|
+
return IndexSet(
|
158
|
+
**{getattr(k, "__name__"): set(range(v)) for (k, v) in sizesof(value).items()}
|
159
|
+
)
|
180
160
|
|
181
161
|
|
182
|
-
@functools.
|
183
|
-
def name_to_sym(name: str) -> Operation[[],
|
184
|
-
return defop(
|
162
|
+
@functools.cache
|
163
|
+
def name_to_sym(name: str) -> Operation[[], torch.Tensor]:
|
164
|
+
return defop(torch.Tensor, name=name)
|
185
165
|
|
186
166
|
|
187
|
-
def gather(value: torch.Tensor, indexset: IndexSet
|
167
|
+
def gather(value: torch.Tensor, indexset: IndexSet) -> torch.Tensor:
|
188
168
|
"""
|
189
169
|
Selects entries from an indexed value at the indices in a :class:`IndexSet` .
|
190
170
|
:func:`gather` is useful in conjunction with :class:`MultiWorldCounterfactual`
|
@@ -248,9 +228,7 @@ def gather(value: torch.Tensor, indexset: IndexSet, **kwargs) -> torch.Tensor:
|
|
248
228
|
"""
|
249
229
|
indexset_vars = {name_to_sym(name): inds for name, inds in indexset.items()}
|
250
230
|
binding = {
|
251
|
-
k: functools.partial(
|
252
|
-
lambda v: v, Indexable(torch.tensor(list(indexset_vars[k])))[k()]
|
253
|
-
)
|
231
|
+
k: functools.partial(lambda v: v, torch.tensor(list(indexset_vars[k]))[k()])
|
254
232
|
for k in sizesof(value).keys()
|
255
233
|
if k in indexset_vars
|
256
234
|
}
|
@@ -259,14 +237,15 @@ def gather(value: torch.Tensor, indexset: IndexSet, **kwargs) -> torch.Tensor:
|
|
259
237
|
|
260
238
|
|
261
239
|
def stack(
|
262
|
-
values:
|
240
|
+
values: tuple[torch.Tensor, ...] | list[torch.Tensor], name: str
|
263
241
|
) -> torch.Tensor:
|
264
242
|
"""Stack a sequence of indexed values, creating a new dimension. The new
|
265
243
|
dimension is indexed by `dim`. The indexed values in the stack must have
|
266
244
|
identical shapes.
|
267
245
|
|
268
246
|
"""
|
269
|
-
|
247
|
+
values = torch.distributions.utils.broadcast_all(*values)
|
248
|
+
return torch.stack(values)[name_to_sym(name)()]
|
270
249
|
|
271
250
|
|
272
251
|
def cond(fst: torch.Tensor, snd: torch.Tensor, case_: torch.Tensor) -> torch.Tensor:
|
@@ -281,12 +260,14 @@ def cond(fst: torch.Tensor, snd: torch.Tensor, case_: torch.Tensor) -> torch.Ten
|
|
281
260
|
Unlike a Python conditional expression, however, the case may be a tensor,
|
282
261
|
and both branches are evaluated, as with :func:`torch.where` ::
|
283
262
|
|
284
|
-
>>> from effectful.
|
285
|
-
>>>
|
286
|
-
|
263
|
+
>>> from effectful.ops.syntax import defop
|
264
|
+
>>> from effectful.handlers.torch import bind_dims
|
265
|
+
|
266
|
+
>>> b = defop(torch.Tensor, name="b")
|
267
|
+
>>> fst, snd = torch.randn(2, 3)[b()], torch.randn(2, 3)[b()]
|
287
268
|
>>> case = (fst < snd).all(-1)
|
288
269
|
>>> x = cond(fst, snd, case)
|
289
|
-
>>> assert (
|
270
|
+
>>> assert (bind_dims(x, b) == bind_dims(torch.where(case[..., None], snd, fst), b)).all()
|
290
271
|
|
291
272
|
.. note::
|
292
273
|
|
@@ -304,10 +285,10 @@ def cond(fst: torch.Tensor, snd: torch.Tensor, case_: torch.Tensor) -> torch.Ten
|
|
304
285
|
)
|
305
286
|
|
306
287
|
|
307
|
-
def cond_n(values:
|
288
|
+
def cond_n(values: dict[IndexSet, torch.Tensor], case: torch.Tensor) -> torch.Tensor:
|
308
289
|
assert len(values) > 0
|
309
290
|
assert all(isinstance(k, IndexSet) for k in values.keys())
|
310
|
-
result:
|
291
|
+
result: torch.Tensor | None = None
|
311
292
|
for indices, value in values.items():
|
312
293
|
tst = torch.as_tensor(
|
313
294
|
functools.reduce(
|
@@ -0,0 +1,14 @@
|
|
1
|
+
try:
|
2
|
+
# Dummy import to check if jax is installed
|
3
|
+
import jax # noqa: F401
|
4
|
+
except ImportError:
|
5
|
+
raise ImportError("Jax is required to use effectful.handlers.jax")
|
6
|
+
|
7
|
+
# side effect: register defdata for jax.Array
|
8
|
+
import effectful.handlers.jax._terms # noqa: F401
|
9
|
+
|
10
|
+
from ._handlers import bind_dims as bind_dims
|
11
|
+
from ._handlers import jax_getitem as jax_getitem
|
12
|
+
from ._handlers import jit as jit
|
13
|
+
from ._handlers import sizesof as sizesof
|
14
|
+
from ._handlers import unbind_dims as unbind_dims
|