effectful 0.1.0__tar.gz → 0.2.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {effectful-0.1.0 → effectful-0.2.1}/PKG-INFO +59 -56
- {effectful-0.1.0 → effectful-0.2.1}/README.rst +26 -11
- {effectful-0.1.0 → effectful-0.2.1}/effectful/handlers/indexed.py +23 -24
- effectful-0.2.1/effectful/handlers/jax/__init__.py +14 -0
- effectful-0.2.1/effectful/handlers/jax/_handlers.py +293 -0
- effectful-0.2.1/effectful/handlers/jax/_terms.py +502 -0
- effectful-0.2.1/effectful/handlers/jax/numpy/__init__.py +23 -0
- effectful-0.2.1/effectful/handlers/jax/numpy/linalg.py +13 -0
- effectful-0.2.1/effectful/handlers/jax/scipy/special.py +11 -0
- effectful-0.2.1/effectful/handlers/numpyro.py +562 -0
- effectful-0.2.1/effectful/handlers/pyro.py +817 -0
- {effectful-0.1.0 → effectful-0.2.1}/effectful/handlers/torch.py +297 -168
- {effectful-0.1.0 → effectful-0.2.1}/effectful/internals/runtime.py +6 -13
- effectful-0.2.1/effectful/internals/tensor_utils.py +32 -0
- effectful-0.2.1/effectful/internals/unification.py +901 -0
- {effectful-0.1.0 → effectful-0.2.1}/effectful/ops/semantics.py +109 -77
- {effectful-0.1.0 → effectful-0.2.1}/effectful/ops/syntax.py +821 -250
- effectful-0.2.1/effectful/ops/types.py +216 -0
- {effectful-0.1.0 → effectful-0.2.1}/effectful.egg-info/PKG-INFO +59 -56
- {effectful-0.1.0 → effectful-0.2.1}/effectful.egg-info/SOURCES.txt +14 -4
- effectful-0.2.1/effectful.egg-info/requires.txt +37 -0
- effectful-0.2.1/pyproject.toml +81 -0
- effectful-0.2.1/setup.cfg +4 -0
- {effectful-0.1.0 → effectful-0.2.1}/tests/test_examples_minipyro.py +0 -1
- {effectful-0.1.0 → effectful-0.2.1}/tests/test_handlers_indexed.py +8 -6
- effectful-0.2.1/tests/test_handlers_jax.py +949 -0
- effectful-0.2.1/tests/test_handlers_numpyro.py +855 -0
- {effectful-0.1.0 → effectful-0.2.1}/tests/test_handlers_pyro.py +71 -46
- {effectful-0.1.0 → effectful-0.2.1}/tests/test_handlers_pyro_dist.py +359 -219
- {effectful-0.1.0 → effectful-0.2.1}/tests/test_handlers_torch.py +270 -149
- effectful-0.2.1/tests/test_internals_unification.py +1651 -0
- {effectful-0.1.0 → effectful-0.2.1}/tests/test_ops_semantics.py +195 -24
- effectful-0.2.1/tests/test_ops_syntax.py +898 -0
- effectful-0.2.1/tests/test_ops_types.py +12 -0
- {effectful-0.1.0 → effectful-0.2.1}/tests/test_semi_ring.py +23 -23
- effectful-0.1.0/effectful/handlers/numbers.py +0 -263
- effectful-0.1.0/effectful/handlers/pyro.py +0 -466
- effectful-0.1.0/effectful/ops/types.py +0 -124
- effectful-0.1.0/effectful.egg-info/requires.txt +0 -37
- effectful-0.1.0/setup.cfg +0 -25
- effectful-0.1.0/setup.py +0 -81
- effectful-0.1.0/tests/test_handlers_numbers.py +0 -266
- effectful-0.1.0/tests/test_ops_syntax.py +0 -124
- {effectful-0.1.0 → effectful-0.2.1}/LICENSE.md +0 -0
- {effectful-0.1.0 → effectful-0.2.1}/effectful/__init__.py +0 -0
- {effectful-0.1.0 → effectful-0.2.1}/effectful/handlers/__init__.py +0 -0
- {effectful-0.1.0 → effectful-0.2.1}/effectful/internals/__init__.py +0 -0
- {effectful-0.1.0 → effectful-0.2.1}/effectful/ops/__init__.py +0 -0
- {effectful-0.1.0 → effectful-0.2.1}/effectful/py.typed +0 -0
- {effectful-0.1.0 → effectful-0.2.1}/effectful.egg-info/dependency_links.txt +0 -0
- {effectful-0.1.0 → effectful-0.2.1}/effectful.egg-info/top_level.txt +0 -0
@@ -1,72 +1,59 @@
|
|
1
|
-
Metadata-Version: 2.
|
1
|
+
Metadata-Version: 2.4
|
2
2
|
Name: effectful
|
3
|
-
Version: 0.1
|
3
|
+
Version: 0.2.1
|
4
4
|
Summary: Metaprogramming infrastructure
|
5
|
-
Home-page: https://www.basis.ai/
|
6
5
|
Author: Basis
|
7
|
-
License: Apache
|
6
|
+
License-Expression: Apache-2.0
|
7
|
+
Project-URL: Homepage, https://www.basis.ai/
|
8
8
|
Project-URL: Source, https://github.com/BasisResearch/effectful
|
9
|
-
|
9
|
+
Project-URL: Bug Tracker, https://github.com/BasisResearch/effectful/issues
|
10
|
+
Keywords: machine learning,statistics,probabilistic programming,bayesian modeling,pytorch
|
10
11
|
Classifier: Intended Audience :: Developers
|
11
12
|
Classifier: Intended Audience :: Education
|
12
13
|
Classifier: Intended Audience :: Science/Research
|
13
|
-
Classifier: License :: OSI Approved :: Apache Software License
|
14
14
|
Classifier: Operating System :: POSIX :: Linux
|
15
15
|
Classifier: Operating System :: MacOS :: MacOS X
|
16
|
-
Classifier: Programming Language :: Python :: 3.
|
17
|
-
Classifier: Programming Language :: Python :: 3.
|
18
|
-
Requires-Python: >=3.
|
16
|
+
Classifier: Programming Language :: Python :: 3.12
|
17
|
+
Classifier: Programming Language :: Python :: 3.13
|
18
|
+
Requires-Python: >=3.12
|
19
|
+
Description-Content-Type: text/x-rst
|
19
20
|
License-File: LICENSE.md
|
20
|
-
Requires-Dist: typing_extensions
|
21
|
-
Requires-Dist: dm-tree
|
22
21
|
Provides-Extra: torch
|
23
22
|
Requires-Dist: torch; extra == "torch"
|
23
|
+
Requires-Dist: dm-tree; extra == "torch"
|
24
24
|
Provides-Extra: pyro
|
25
|
-
Requires-Dist:
|
26
|
-
Requires-Dist:
|
27
|
-
Provides-Extra:
|
28
|
-
Requires-Dist:
|
29
|
-
Requires-Dist:
|
30
|
-
|
31
|
-
Requires-Dist:
|
32
|
-
Requires-Dist:
|
33
|
-
Requires-Dist: sphinxcontrib-bibtex; extra == "dev"
|
34
|
-
Requires-Dist: sphinx_rtd_theme; extra == "dev"
|
35
|
-
Requires-Dist: myst-parser; extra == "dev"
|
36
|
-
Requires-Dist: nbsphinx; extra == "dev"
|
37
|
-
Requires-Dist: pytest; extra == "dev"
|
38
|
-
Requires-Dist: pytest-cov; extra == "dev"
|
39
|
-
Requires-Dist: pytest-xdist; extra == "dev"
|
40
|
-
Requires-Dist: pytest-benchmark; extra == "dev"
|
41
|
-
Requires-Dist: mypy; extra == "dev"
|
42
|
-
Requires-Dist: black; extra == "dev"
|
43
|
-
Requires-Dist: flake8; extra == "dev"
|
44
|
-
Requires-Dist: isort; extra == "dev"
|
45
|
-
Requires-Dist: nbval; extra == "dev"
|
46
|
-
Requires-Dist: nbqa; extra == "dev"
|
25
|
+
Requires-Dist: pyro-ppl>=1.9.1; extra == "pyro"
|
26
|
+
Requires-Dist: dm-tree; extra == "pyro"
|
27
|
+
Provides-Extra: jax
|
28
|
+
Requires-Dist: jax; extra == "jax"
|
29
|
+
Requires-Dist: dm-tree; extra == "jax"
|
30
|
+
Provides-Extra: numpyro
|
31
|
+
Requires-Dist: numpyro>=0.19; extra == "numpyro"
|
32
|
+
Requires-Dist: dm-tree; extra == "numpyro"
|
47
33
|
Provides-Extra: docs
|
48
|
-
Requires-Dist:
|
34
|
+
Requires-Dist: effectful[jax,numpyro,pyro,torch]; extra == "docs"
|
49
35
|
Requires-Dist: sphinx; extra == "docs"
|
50
36
|
Requires-Dist: sphinxcontrib-bibtex; extra == "docs"
|
51
37
|
Requires-Dist: sphinx_rtd_theme; extra == "docs"
|
52
38
|
Requires-Dist: myst-parser; extra == "docs"
|
53
39
|
Requires-Dist: nbsphinx; extra == "docs"
|
54
|
-
|
55
|
-
|
56
|
-
|
57
|
-
|
58
|
-
|
59
|
-
|
60
|
-
|
61
|
-
|
62
|
-
|
63
|
-
|
64
|
-
|
65
|
-
|
40
|
+
Requires-Dist: sphinx_autodoc_typehints; extra == "docs"
|
41
|
+
Requires-Dist: pypandoc_binary; extra == "docs"
|
42
|
+
Provides-Extra: test
|
43
|
+
Requires-Dist: effectful[docs,jax,numpyro,pyro,torch]; extra == "test"
|
44
|
+
Requires-Dist: pytest; extra == "test"
|
45
|
+
Requires-Dist: pytest-cov; extra == "test"
|
46
|
+
Requires-Dist: pytest-xdist; extra == "test"
|
47
|
+
Requires-Dist: pytest-benchmark; extra == "test"
|
48
|
+
Requires-Dist: mypy; extra == "test"
|
49
|
+
Requires-Dist: ruff; extra == "test"
|
50
|
+
Requires-Dist: nbval; extra == "test"
|
51
|
+
Requires-Dist: nbqa; extra == "test"
|
52
|
+
Dynamic: license-file
|
66
53
|
|
67
54
|
.. index-inclusion-marker
|
68
55
|
|
69
|
-
Effectful
|
56
|
+
Effectful
|
70
57
|
=========
|
71
58
|
|
72
59
|
Effectful is an algebraic effect system for Python, intended for use in the
|
@@ -89,9 +76,12 @@ Install From Source
|
|
89
76
|
Install With Optional PyTorch/Pyro Support
|
90
77
|
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
91
78
|
|
92
|
-
``effectful`` has optional support for
|
93
|
-
|
94
|
-
|
79
|
+
``effectful`` has optional support for:
|
80
|
+
|
81
|
+
- `PyTorch <https://pytorch.org/>`_ (tensors with named dimensions)
|
82
|
+
- `Pyro <https://pyro.ai/>`_ (wrappers for Pyro effects)
|
83
|
+
- `Jax <https://docs.jax.dev/en/latest/index.html>`_ (tensors with named dimensions)
|
84
|
+
- `Numpyro <https://num.pyro.ai>`_ (operations for Numpyro distributions)
|
95
85
|
|
96
86
|
To enable PyTorch support:
|
97
87
|
|
@@ -105,6 +95,18 @@ Pyro support (which includes PyTorch support):
|
|
105
95
|
|
106
96
|
pip install effectful[pyro]
|
107
97
|
|
98
|
+
Jax support:
|
99
|
+
|
100
|
+
.. code:: sh
|
101
|
+
|
102
|
+
pip install effectful[jax]
|
103
|
+
|
104
|
+
Numpyro support (which includes Jax support):
|
105
|
+
|
106
|
+
.. code:: sh
|
107
|
+
|
108
|
+
pip install effectful[numpyro]
|
109
|
+
|
108
110
|
Getting Started
|
109
111
|
---------------
|
110
112
|
|
@@ -115,11 +117,12 @@ Here's an example demonstrating how ``effectful`` can be used to implement a sim
|
|
115
117
|
import functools
|
116
118
|
|
117
119
|
from effectful.ops.types import Term
|
118
|
-
from effectful.ops.syntax import defop
|
120
|
+
from effectful.ops.syntax import defdata, defop
|
119
121
|
from effectful.ops.semantics import handler, evaluate, coproduct, fwd
|
120
|
-
from effectful.handlers.numbers import add
|
121
122
|
|
122
|
-
|
123
|
+
add = defdata.dispatch(int).__add__
|
124
|
+
|
125
|
+
def beta_add(x: int, y: int) -> int:
|
123
126
|
match x, y:
|
124
127
|
case int(), int():
|
125
128
|
return x + y
|
@@ -129,14 +132,14 @@ Here's an example demonstrating how ``effectful`` can be used to implement a sim
|
|
129
132
|
def commute_add(x: int, y: int) -> int:
|
130
133
|
match x, y:
|
131
134
|
case Term(), int():
|
132
|
-
return y + x
|
135
|
+
return y + x
|
133
136
|
case _:
|
134
137
|
return fwd()
|
135
138
|
|
136
139
|
def assoc_add(x: int, y: int) -> int:
|
137
140
|
match x, y:
|
138
141
|
case _, Term(op, (a, b)) if op == add:
|
139
|
-
return (x + a) + b
|
142
|
+
return (x + a) + b
|
140
143
|
case _:
|
141
144
|
return fwd()
|
142
145
|
|
@@ -168,7 +171,7 @@ We can make the evaluation strategy smarter by taking advantage of the commutati
|
|
168
171
|
>>> with handler(eager_mixed):
|
169
172
|
>>> print(evaluate(e))
|
170
173
|
add(8, add(x(), y()))
|
171
|
-
|
174
|
+
|
172
175
|
Learn More
|
173
176
|
----------
|
174
177
|
|
@@ -1,7 +1,6 @@
|
|
1
|
-
|
2
1
|
.. index-inclusion-marker
|
3
2
|
|
4
|
-
Effectful
|
3
|
+
Effectful
|
5
4
|
=========
|
6
5
|
|
7
6
|
Effectful is an algebraic effect system for Python, intended for use in the
|
@@ -24,9 +23,12 @@ Install From Source
|
|
24
23
|
Install With Optional PyTorch/Pyro Support
|
25
24
|
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
26
25
|
|
27
|
-
``effectful`` has optional support for
|
28
|
-
|
29
|
-
|
26
|
+
``effectful`` has optional support for:
|
27
|
+
|
28
|
+
- `PyTorch <https://pytorch.org/>`_ (tensors with named dimensions)
|
29
|
+
- `Pyro <https://pyro.ai/>`_ (wrappers for Pyro effects)
|
30
|
+
- `Jax <https://docs.jax.dev/en/latest/index.html>`_ (tensors with named dimensions)
|
31
|
+
- `Numpyro <https://num.pyro.ai>`_ (operations for Numpyro distributions)
|
30
32
|
|
31
33
|
To enable PyTorch support:
|
32
34
|
|
@@ -40,6 +42,18 @@ Pyro support (which includes PyTorch support):
|
|
40
42
|
|
41
43
|
pip install effectful[pyro]
|
42
44
|
|
45
|
+
Jax support:
|
46
|
+
|
47
|
+
.. code:: sh
|
48
|
+
|
49
|
+
pip install effectful[jax]
|
50
|
+
|
51
|
+
Numpyro support (which includes Jax support):
|
52
|
+
|
53
|
+
.. code:: sh
|
54
|
+
|
55
|
+
pip install effectful[numpyro]
|
56
|
+
|
43
57
|
Getting Started
|
44
58
|
---------------
|
45
59
|
|
@@ -50,11 +64,12 @@ Here's an example demonstrating how ``effectful`` can be used to implement a sim
|
|
50
64
|
import functools
|
51
65
|
|
52
66
|
from effectful.ops.types import Term
|
53
|
-
from effectful.ops.syntax import defop
|
67
|
+
from effectful.ops.syntax import defdata, defop
|
54
68
|
from effectful.ops.semantics import handler, evaluate, coproduct, fwd
|
55
|
-
from effectful.handlers.numbers import add
|
56
69
|
|
57
|
-
|
70
|
+
add = defdata.dispatch(int).__add__
|
71
|
+
|
72
|
+
def beta_add(x: int, y: int) -> int:
|
58
73
|
match x, y:
|
59
74
|
case int(), int():
|
60
75
|
return x + y
|
@@ -64,14 +79,14 @@ Here's an example demonstrating how ``effectful`` can be used to implement a sim
|
|
64
79
|
def commute_add(x: int, y: int) -> int:
|
65
80
|
match x, y:
|
66
81
|
case Term(), int():
|
67
|
-
return y + x
|
82
|
+
return y + x
|
68
83
|
case _:
|
69
84
|
return fwd()
|
70
85
|
|
71
86
|
def assoc_add(x: int, y: int) -> int:
|
72
87
|
match x, y:
|
73
88
|
case _, Term(op, (a, b)) if op == add:
|
74
|
-
return (x + a) + b
|
89
|
+
return (x + a) + b
|
75
90
|
case _:
|
76
91
|
return fwd()
|
77
92
|
|
@@ -103,7 +118,7 @@ We can make the evaluation strategy smarter by taking advantage of the commutati
|
|
103
118
|
>>> with handler(eager_mixed):
|
104
119
|
>>> print(evaluate(e))
|
105
120
|
add(8, add(x(), y()))
|
106
|
-
|
121
|
+
|
107
122
|
Learn More
|
108
123
|
----------
|
109
124
|
|
@@ -1,18 +1,16 @@
|
|
1
1
|
import functools
|
2
2
|
import operator
|
3
|
-
from
|
3
|
+
from collections.abc import Iterable
|
4
|
+
from typing import Any
|
4
5
|
|
5
6
|
import torch
|
6
7
|
|
7
|
-
from effectful.handlers.torch import
|
8
|
+
from effectful.handlers.torch import sizesof
|
8
9
|
from effectful.ops.syntax import deffn, defop
|
9
10
|
from effectful.ops.types import Operation
|
10
11
|
|
11
|
-
K = TypeVar("K")
|
12
|
-
T = TypeVar("T")
|
13
12
|
|
14
|
-
|
15
|
-
class IndexSet(Dict[str, Set[int]]):
|
13
|
+
class IndexSet(dict[str, set[int]]):
|
16
14
|
"""
|
17
15
|
:class:`IndexSet` s represent the support of an indexed value,
|
18
16
|
for which free variables correspond to single interventions and indices
|
@@ -32,13 +30,13 @@ class IndexSet(Dict[str, Set[int]]):
|
|
32
30
|
for which a value is defined::
|
33
31
|
|
34
32
|
>>> IndexSet(x={0, 1}, y={2, 3})
|
35
|
-
IndexSet({x: {0, 1}, y: {2, 3}})
|
33
|
+
IndexSet({'x': {0, 1}, 'y': {2, 3}})
|
36
34
|
|
37
35
|
:class:`IndexSet` 's constructor will automatically drop empty entries
|
38
36
|
and attempt to convert input values to :class:`set` s::
|
39
37
|
|
40
38
|
>>> IndexSet(x=[0, 0, 1], y=set(), z=2)
|
41
|
-
IndexSet({x: {0, 1}, z: {2}})
|
39
|
+
IndexSet({'x': {0, 1}, 'z': {2}})
|
42
40
|
|
43
41
|
:class:`IndexSet` s are also hashable and can be used as keys in :class:`dict` s::
|
44
42
|
|
@@ -47,7 +45,7 @@ class IndexSet(Dict[str, Set[int]]):
|
|
47
45
|
True
|
48
46
|
"""
|
49
47
|
|
50
|
-
def __init__(self, **mapping:
|
48
|
+
def __init__(self, **mapping: int | Iterable[int]):
|
51
49
|
index_set = {}
|
52
50
|
for k, vs in mapping.items():
|
53
51
|
indexes = {vs} if isinstance(vs, int) else set(vs)
|
@@ -161,12 +159,12 @@ def indices_of(value: Any) -> IndexSet:
|
|
161
159
|
)
|
162
160
|
|
163
161
|
|
164
|
-
@functools.
|
165
|
-
def name_to_sym(name: str) -> Operation[[],
|
166
|
-
return defop(
|
162
|
+
@functools.cache
|
163
|
+
def name_to_sym(name: str) -> Operation[[], torch.Tensor]:
|
164
|
+
return defop(torch.Tensor, name=name)
|
167
165
|
|
168
166
|
|
169
|
-
def gather(value: torch.Tensor, indexset: IndexSet
|
167
|
+
def gather(value: torch.Tensor, indexset: IndexSet) -> torch.Tensor:
|
170
168
|
"""
|
171
169
|
Selects entries from an indexed value at the indices in a :class:`IndexSet` .
|
172
170
|
:func:`gather` is useful in conjunction with :class:`MultiWorldCounterfactual`
|
@@ -230,9 +228,7 @@ def gather(value: torch.Tensor, indexset: IndexSet, **kwargs) -> torch.Tensor:
|
|
230
228
|
"""
|
231
229
|
indexset_vars = {name_to_sym(name): inds for name, inds in indexset.items()}
|
232
230
|
binding = {
|
233
|
-
k: functools.partial(
|
234
|
-
lambda v: v, Indexable(torch.tensor(list(indexset_vars[k])))[k()]
|
235
|
-
)
|
231
|
+
k: functools.partial(lambda v: v, torch.tensor(list(indexset_vars[k]))[k()])
|
236
232
|
for k in sizesof(value).keys()
|
237
233
|
if k in indexset_vars
|
238
234
|
}
|
@@ -241,14 +237,15 @@ def gather(value: torch.Tensor, indexset: IndexSet, **kwargs) -> torch.Tensor:
|
|
241
237
|
|
242
238
|
|
243
239
|
def stack(
|
244
|
-
values:
|
240
|
+
values: tuple[torch.Tensor, ...] | list[torch.Tensor], name: str
|
245
241
|
) -> torch.Tensor:
|
246
242
|
"""Stack a sequence of indexed values, creating a new dimension. The new
|
247
243
|
dimension is indexed by `dim`. The indexed values in the stack must have
|
248
244
|
identical shapes.
|
249
245
|
|
250
246
|
"""
|
251
|
-
|
247
|
+
values = torch.distributions.utils.broadcast_all(*values)
|
248
|
+
return torch.stack(values)[name_to_sym(name)()]
|
252
249
|
|
253
250
|
|
254
251
|
def cond(fst: torch.Tensor, snd: torch.Tensor, case_: torch.Tensor) -> torch.Tensor:
|
@@ -263,12 +260,14 @@ def cond(fst: torch.Tensor, snd: torch.Tensor, case_: torch.Tensor) -> torch.Ten
|
|
263
260
|
Unlike a Python conditional expression, however, the case may be a tensor,
|
264
261
|
and both branches are evaluated, as with :func:`torch.where` ::
|
265
262
|
|
266
|
-
>>> from effectful.
|
267
|
-
>>>
|
268
|
-
|
263
|
+
>>> from effectful.ops.syntax import defop
|
264
|
+
>>> from effectful.handlers.torch import bind_dims
|
265
|
+
|
266
|
+
>>> b = defop(torch.Tensor, name="b")
|
267
|
+
>>> fst, snd = torch.randn(2, 3)[b()], torch.randn(2, 3)[b()]
|
269
268
|
>>> case = (fst < snd).all(-1)
|
270
269
|
>>> x = cond(fst, snd, case)
|
271
|
-
>>> assert (
|
270
|
+
>>> assert (bind_dims(x, b) == bind_dims(torch.where(case[..., None], snd, fst), b)).all()
|
272
271
|
|
273
272
|
.. note::
|
274
273
|
|
@@ -286,10 +285,10 @@ def cond(fst: torch.Tensor, snd: torch.Tensor, case_: torch.Tensor) -> torch.Ten
|
|
286
285
|
)
|
287
286
|
|
288
287
|
|
289
|
-
def cond_n(values:
|
288
|
+
def cond_n(values: dict[IndexSet, torch.Tensor], case: torch.Tensor) -> torch.Tensor:
|
290
289
|
assert len(values) > 0
|
291
290
|
assert all(isinstance(k, IndexSet) for k in values.keys())
|
292
|
-
result:
|
291
|
+
result: torch.Tensor | None = None
|
293
292
|
for indices, value in values.items():
|
294
293
|
tst = torch.as_tensor(
|
295
294
|
functools.reduce(
|
@@ -0,0 +1,14 @@
|
|
1
|
+
try:
|
2
|
+
# Dummy import to check if jax is installed
|
3
|
+
import jax # noqa: F401
|
4
|
+
except ImportError:
|
5
|
+
raise ImportError("Jax is required to use effectful.handlers.jax")
|
6
|
+
|
7
|
+
# side effect: register defdata for jax.Array
|
8
|
+
import effectful.handlers.jax._terms # noqa: F401
|
9
|
+
|
10
|
+
from ._handlers import bind_dims as bind_dims
|
11
|
+
from ._handlers import jax_getitem as jax_getitem
|
12
|
+
from ._handlers import jit as jit
|
13
|
+
from ._handlers import sizesof as sizesof
|
14
|
+
from ._handlers import unbind_dims as unbind_dims
|