effectful 0.0.1__tar.gz → 0.2.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54)
  1. {effectful-0.0.1 → effectful-0.2.0}/PKG-INFO +65 -57
  2. {effectful-0.0.1 → effectful-0.2.0}/README.rst +26 -14
  3. {effectful-0.0.1 → effectful-0.2.0}/effectful/handlers/indexed.py +27 -46
  4. effectful-0.2.0/effectful/handlers/jax/__init__.py +14 -0
  5. effectful-0.2.0/effectful/handlers/jax/_handlers.py +293 -0
  6. effectful-0.2.0/effectful/handlers/jax/_terms.py +502 -0
  7. effectful-0.2.0/effectful/handlers/jax/numpy/__init__.py +23 -0
  8. effectful-0.2.0/effectful/handlers/jax/numpy/linalg.py +13 -0
  9. effectful-0.2.0/effectful/handlers/jax/scipy/special.py +11 -0
  10. effectful-0.2.0/effectful/handlers/numpyro.py +562 -0
  11. effectful-0.2.0/effectful/handlers/pyro.py +817 -0
  12. effectful-0.2.0/effectful/handlers/torch.py +724 -0
  13. {effectful-0.0.1 → effectful-0.2.0}/effectful/internals/runtime.py +6 -13
  14. effectful-0.2.0/effectful/internals/tensor_utils.py +32 -0
  15. effectful-0.2.0/effectful/internals/unification.py +900 -0
  16. {effectful-0.0.1 → effectful-0.2.0}/effectful/ops/semantics.py +104 -84
  17. effectful-0.2.0/effectful/ops/syntax.py +1632 -0
  18. effectful-0.2.0/effectful/ops/types.py +216 -0
  19. {effectful-0.0.1 → effectful-0.2.0}/effectful.egg-info/PKG-INFO +65 -57
  20. {effectful-0.0.1 → effectful-0.2.0}/effectful.egg-info/SOURCES.txt +14 -5
  21. effectful-0.2.0/effectful.egg-info/requires.txt +37 -0
  22. effectful-0.2.0/pyproject.toml +81 -0
  23. effectful-0.2.0/setup.cfg +4 -0
  24. {effectful-0.0.1 → effectful-0.2.0}/tests/test_examples_minipyro.py +0 -1
  25. {effectful-0.0.1 → effectful-0.2.0}/tests/test_handlers_indexed.py +8 -6
  26. effectful-0.2.0/tests/test_handlers_jax.py +941 -0
  27. effectful-0.2.0/tests/test_handlers_numpyro.py +855 -0
  28. {effectful-0.0.1 → effectful-0.2.0}/tests/test_handlers_pyro.py +71 -46
  29. {effectful-0.0.1 → effectful-0.2.0}/tests/test_handlers_pyro_dist.py +359 -219
  30. {effectful-0.0.1 → effectful-0.2.0}/tests/test_handlers_torch.py +270 -149
  31. effectful-0.2.0/tests/test_internals_unification.py +1646 -0
  32. {effectful-0.0.1 → effectful-0.2.0}/tests/test_ops_semantics.py +242 -21
  33. effectful-0.2.0/tests/test_ops_syntax.py +873 -0
  34. effectful-0.2.0/tests/test_ops_types.py +12 -0
  35. {effectful-0.0.1 → effectful-0.2.0}/tests/test_semi_ring.py +23 -23
  36. effectful-0.0.1/effectful/handlers/numbers.py +0 -259
  37. effectful-0.0.1/effectful/handlers/pyro.py +0 -466
  38. effectful-0.0.1/effectful/handlers/torch.py +0 -572
  39. effectful-0.0.1/effectful/internals/base_impl.py +0 -259
  40. effectful-0.0.1/effectful/ops/syntax.py +0 -523
  41. effectful-0.0.1/effectful/ops/types.py +0 -110
  42. effectful-0.0.1/effectful.egg-info/requires.txt +0 -28
  43. effectful-0.0.1/setup.cfg +0 -25
  44. effectful-0.0.1/setup.py +0 -72
  45. effectful-0.0.1/tests/test_handlers_numbers.py +0 -263
  46. effectful-0.0.1/tests/test_ops_syntax.py +0 -92
  47. {effectful-0.0.1 → effectful-0.2.0}/LICENSE.md +0 -0
  48. {effectful-0.0.1 → effectful-0.2.0}/effectful/__init__.py +0 -0
  49. {effectful-0.0.1 → effectful-0.2.0}/effectful/handlers/__init__.py +0 -0
  50. {effectful-0.0.1 → effectful-0.2.0}/effectful/internals/__init__.py +0 -0
  51. {effectful-0.0.1 → effectful-0.2.0}/effectful/ops/__init__.py +0 -0
  52. {effectful-0.0.1 → effectful-0.2.0}/effectful/py.typed +0 -0
  53. {effectful-0.0.1 → effectful-0.2.0}/effectful.egg-info/dependency_links.txt +0 -0
  54. {effectful-0.0.1 → effectful-0.2.0}/effectful.egg-info/top_level.txt +0 -0
@@ -1,63 +1,59 @@
1
- Metadata-Version: 2.2
1
+ Metadata-Version: 2.4
2
2
  Name: effectful
3
- Version: 0.0.1
3
+ Version: 0.2.0
4
4
  Summary: Metaprogramming infrastructure
5
- Home-page: https://www.basis.ai/
6
5
  Author: Basis
7
- License: Apache 2.0
6
+ License-Expression: Apache-2.0
7
+ Project-URL: Homepage, https://www.basis.ai/
8
8
  Project-URL: Source, https://github.com/BasisResearch/effectful
9
- Keywords: machine learning statistics probabilistic programming bayesian modeling pytorch
9
+ Project-URL: Bug Tracker, https://github.com/BasisResearch/effectful/issues
10
+ Keywords: machine learning,statistics,probabilistic programming,bayesian modeling,pytorch
10
11
  Classifier: Intended Audience :: Developers
11
12
  Classifier: Intended Audience :: Education
12
13
  Classifier: Intended Audience :: Science/Research
13
- Classifier: License :: OSI Approved :: Apache Software License
14
14
  Classifier: Operating System :: POSIX :: Linux
15
15
  Classifier: Operating System :: MacOS :: MacOS X
16
- Classifier: Programming Language :: Python :: 3.10
17
- Classifier: Programming Language :: Python :: 3.11
18
- Requires-Python: >=3.10
16
+ Classifier: Programming Language :: Python :: 3.12
17
+ Classifier: Programming Language :: Python :: 3.13
18
+ Requires-Python: >=3.12
19
+ Description-Content-Type: text/x-rst
19
20
  License-File: LICENSE.md
20
- Requires-Dist: typing_extensions
21
- Requires-Dist: dm-tree
22
21
  Provides-Extra: torch
23
22
  Requires-Dist: torch; extra == "torch"
23
+ Requires-Dist: dm-tree; extra == "torch"
24
24
  Provides-Extra: pyro
25
- Requires-Dist: torch; extra == "pyro"
26
- Requires-Dist: pyro-ppl; extra == "pyro"
27
- Provides-Extra: dev
28
- Requires-Dist: torch; extra == "dev"
29
- Requires-Dist: pyro-ppl; extra == "dev"
30
- Requires-Dist: torch; extra == "dev"
31
- Requires-Dist: pytest; extra == "dev"
32
- Requires-Dist: pytest-cov; extra == "dev"
33
- Requires-Dist: pytest-xdist; extra == "dev"
34
- Requires-Dist: pytest-benchmark; extra == "dev"
35
- Requires-Dist: mypy; extra == "dev"
36
- Requires-Dist: black; extra == "dev"
37
- Requires-Dist: flake8; extra == "dev"
38
- Requires-Dist: isort; extra == "dev"
39
- Requires-Dist: sphinx; extra == "dev"
40
- Requires-Dist: sphinxcontrib-bibtex; extra == "dev"
41
- Requires-Dist: sphinx_rtd_theme; extra == "dev"
42
- Requires-Dist: myst-parser; extra == "dev"
43
- Requires-Dist: nbsphinx; extra == "dev"
44
- Requires-Dist: nbval; extra == "dev"
45
- Requires-Dist: nbqa; extra == "dev"
46
- Dynamic: author
47
- Dynamic: classifier
48
- Dynamic: description
49
- Dynamic: home-page
50
- Dynamic: keywords
51
- Dynamic: license
52
- Dynamic: project-url
53
- Dynamic: provides-extra
54
- Dynamic: requires-dist
55
- Dynamic: requires-python
56
- Dynamic: summary
25
+ Requires-Dist: pyro-ppl>=1.9.1; extra == "pyro"
26
+ Requires-Dist: dm-tree; extra == "pyro"
27
+ Provides-Extra: jax
28
+ Requires-Dist: jax; extra == "jax"
29
+ Requires-Dist: dm-tree; extra == "jax"
30
+ Provides-Extra: numpyro
31
+ Requires-Dist: numpyro>=0.19; extra == "numpyro"
32
+ Requires-Dist: dm-tree; extra == "numpyro"
33
+ Provides-Extra: docs
34
+ Requires-Dist: effectful[jax,numpyro,pyro,torch]; extra == "docs"
35
+ Requires-Dist: sphinx; extra == "docs"
36
+ Requires-Dist: sphinxcontrib-bibtex; extra == "docs"
37
+ Requires-Dist: sphinx_rtd_theme; extra == "docs"
38
+ Requires-Dist: myst-parser; extra == "docs"
39
+ Requires-Dist: nbsphinx; extra == "docs"
40
+ Requires-Dist: sphinx_autodoc_typehints; extra == "docs"
41
+ Requires-Dist: pypandoc_binary; extra == "docs"
42
+ Provides-Extra: test
43
+ Requires-Dist: effectful[docs,jax,numpyro,pyro,torch]; extra == "test"
44
+ Requires-Dist: pytest; extra == "test"
45
+ Requires-Dist: pytest-cov; extra == "test"
46
+ Requires-Dist: pytest-xdist; extra == "test"
47
+ Requires-Dist: pytest-benchmark; extra == "test"
48
+ Requires-Dist: mypy; extra == "test"
49
+ Requires-Dist: ruff; extra == "test"
50
+ Requires-Dist: nbval; extra == "test"
51
+ Requires-Dist: nbqa; extra == "test"
52
+ Dynamic: license-file
57
53
 
58
54
  .. index-inclusion-marker
59
55
 
60
- Effectful
56
+ Effectful
61
57
  =========
62
58
 
63
59
  Effectful is an algebraic effect system for Python, intended for use in the
@@ -75,14 +71,17 @@ Install From Source
75
71
  git clone git@github.com:BasisResearch/effectful.git
76
72
  cd effectful
77
73
  git checkout master
78
- pip install -e .[test]
74
+ pip install -e .[pyro]
79
75
 
80
76
  Install With Optional PyTorch/Pyro Support
81
77
  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
82
78
 
83
- ``effectful`` has optional support for `PyTorch <https://pytorch.org/>`_ (tensors
84
- with named dimensions) and `Pyro <https://pyro.ai/>`_ (wrappers for Pyro
85
- effects).
79
+ ``effectful`` has optional support for:
80
+
81
+ - `PyTorch <https://pytorch.org/>`_ (tensors with named dimensions)
82
+ - `Pyro <https://pyro.ai/>`_ (wrappers for Pyro effects)
83
+ - `Jax <https://docs.jax.dev/en/latest/index.html>`_ (tensors with named dimensions)
84
+ - `Numpyro <https://num.pyro.ai>`_ (operations for Numpyro distributions)
86
85
 
87
86
  To enable PyTorch support:
88
87
 
@@ -96,6 +95,18 @@ Pyro support (which includes PyTorch support):
96
95
 
97
96
  pip install effectful[pyro]
98
97
 
98
+ Jax support:
99
+
100
+ .. code:: sh
101
+
102
+ pip install effectful[jax]
103
+
104
+ Numpyro support (which includes Jax support):
105
+
106
+ .. code:: sh
107
+
108
+ pip install effectful[numpyro]
109
+
99
110
  Getting Started
100
111
  ---------------
101
112
 
@@ -103,18 +114,15 @@ Here's an example demonstrating how ``effectful`` can be used to implement a sim
103
114
 
104
115
  .. code:: python
105
116
 
106
- import operator
107
117
  import functools
108
118
 
109
- from effectful.handlers.operator import OPERATORS
110
119
  from effectful.ops.types import Term
111
- from effectful.ops.syntax import defop
120
+ from effectful.ops.syntax import defdata, defop
112
121
  from effectful.ops.semantics import handler, evaluate, coproduct, fwd
113
- import effectful.handlers.operator
114
122
 
115
- add = OPERATORS[operator.add]
123
+ add = defdata.dispatch(int).__add__
116
124
 
117
- def beta_add(x: int, y: int) -> int:
125
+ def beta_add(x: int, y: int) -> int:
118
126
  match x, y:
119
127
  case int(), int():
120
128
  return x + y
@@ -124,14 +132,14 @@ Here's an example demonstrating how ``effectful`` can be used to implement a sim
124
132
  def commute_add(x: int, y: int) -> int:
125
133
  match x, y:
126
134
  case Term(), int():
127
- return y + x
135
+ return y + x
128
136
  case _:
129
137
  return fwd()
130
138
 
131
139
  def assoc_add(x: int, y: int) -> int:
132
140
  match x, y:
133
141
  case _, Term(op, (a, b)) if op == add:
134
- return (x + a) + b
142
+ return (x + a) + b
135
143
  case _:
136
144
  return fwd()
137
145
 
@@ -163,7 +171,7 @@ We can make the evaluation strategy smarter by taking advantage of the commutati
163
171
  >>> with handler(eager_mixed):
164
172
  >>> print(evaluate(e))
165
173
  add(8, add(x(), y()))
166
-
174
+
167
175
  Learn More
168
176
  ----------
169
177
 
@@ -1,6 +1,6 @@
1
1
  .. index-inclusion-marker
2
2
 
3
- Effectful
3
+ Effectful
4
4
  =========
5
5
 
6
6
  Effectful is an algebraic effect system for Python, intended for use in the
@@ -18,14 +18,17 @@ Install From Source
18
18
  git clone git@github.com:BasisResearch/effectful.git
19
19
  cd effectful
20
20
  git checkout master
21
- pip install -e .[test]
21
+ pip install -e .[pyro]
22
22
 
23
23
  Install With Optional PyTorch/Pyro Support
24
24
  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
25
25
 
26
- ``effectful`` has optional support for `PyTorch <https://pytorch.org/>`_ (tensors
27
- with named dimensions) and `Pyro <https://pyro.ai/>`_ (wrappers for Pyro
28
- effects).
26
+ ``effectful`` has optional support for:
27
+
28
+ - `PyTorch <https://pytorch.org/>`_ (tensors with named dimensions)
29
+ - `Pyro <https://pyro.ai/>`_ (wrappers for Pyro effects)
30
+ - `Jax <https://docs.jax.dev/en/latest/index.html>`_ (tensors with named dimensions)
31
+ - `Numpyro <https://num.pyro.ai>`_ (operations for Numpyro distributions)
29
32
 
30
33
  To enable PyTorch support:
31
34
 
@@ -39,6 +42,18 @@ Pyro support (which includes PyTorch support):
39
42
 
40
43
  pip install effectful[pyro]
41
44
 
45
+ Jax support:
46
+
47
+ .. code:: sh
48
+
49
+ pip install effectful[jax]
50
+
51
+ Numpyro support (which includes Jax support):
52
+
53
+ .. code:: sh
54
+
55
+ pip install effectful[numpyro]
56
+
42
57
  Getting Started
43
58
  ---------------
44
59
 
@@ -46,18 +61,15 @@ Here's an example demonstrating how ``effectful`` can be used to implement a sim
46
61
 
47
62
  .. code:: python
48
63
 
49
- import operator
50
64
  import functools
51
65
 
52
- from effectful.handlers.operator import OPERATORS
53
66
  from effectful.ops.types import Term
54
- from effectful.ops.syntax import defop
67
+ from effectful.ops.syntax import defdata, defop
55
68
  from effectful.ops.semantics import handler, evaluate, coproduct, fwd
56
- import effectful.handlers.operator
57
69
 
58
- add = OPERATORS[operator.add]
70
+ add = defdata.dispatch(int).__add__
59
71
 
60
- def beta_add(x: int, y: int) -> int:
72
+ def beta_add(x: int, y: int) -> int:
61
73
  match x, y:
62
74
  case int(), int():
63
75
  return x + y
@@ -67,14 +79,14 @@ Here's an example demonstrating how ``effectful`` can be used to implement a sim
67
79
  def commute_add(x: int, y: int) -> int:
68
80
  match x, y:
69
81
  case Term(), int():
70
- return y + x
82
+ return y + x
71
83
  case _:
72
84
  return fwd()
73
85
 
74
86
  def assoc_add(x: int, y: int) -> int:
75
87
  match x, y:
76
88
  case _, Term(op, (a, b)) if op == add:
77
- return (x + a) + b
89
+ return (x + a) + b
78
90
  case _:
79
91
  return fwd()
80
92
 
@@ -106,7 +118,7 @@ We can make the evaluation strategy smarter by taking advantage of the commutati
106
118
  >>> with handler(eager_mixed):
107
119
  >>> print(evaluate(e))
108
120
  add(8, add(x(), y()))
109
-
121
+
110
122
  Learn More
111
123
  ----------
112
124
 
@@ -1,18 +1,16 @@
1
1
  import functools
2
2
  import operator
3
- from typing import Any, Dict, Iterable, Optional, Set, TypeVar, Union
3
+ from collections.abc import Iterable
4
+ from typing import Any
4
5
 
5
6
  import torch
6
7
 
7
- from effectful.handlers.torch import Indexable, sizesof
8
+ from effectful.handlers.torch import sizesof
8
9
  from effectful.ops.syntax import deffn, defop
9
- from effectful.ops.types import Operation, Term
10
+ from effectful.ops.types import Operation
10
11
 
11
- K = TypeVar("K")
12
- T = TypeVar("T")
13
12
 
14
-
15
- class IndexSet(Dict[str, Set[int]]):
13
+ class IndexSet(dict[str, set[int]]):
16
14
  """
17
15
  :class:`IndexSet` s represent the support of an indexed value,
18
16
  for which free variables correspond to single interventions and indices
@@ -32,13 +30,13 @@ class IndexSet(Dict[str, Set[int]]):
32
30
  for which a value is defined::
33
31
 
34
32
  >>> IndexSet(x={0, 1}, y={2, 3})
35
- IndexSet({x: {0, 1}, y: {2, 3}})
33
+ IndexSet({'x': {0, 1}, 'y': {2, 3}})
36
34
 
37
35
  :class:`IndexSet` 's constructor will automatically drop empty entries
38
36
  and attempt to convert input values to :class:`set` s::
39
37
 
40
38
  >>> IndexSet(x=[0, 0, 1], y=set(), z=2)
41
- IndexSet({x: {0, 1}, z: {2}})
39
+ IndexSet({'x': {0, 1}, 'z': {2}})
42
40
 
43
41
  :class:`IndexSet` s are also hashable and can be used as keys in :class:`dict` s::
44
42
 
@@ -47,7 +45,7 @@ class IndexSet(Dict[str, Set[int]]):
47
45
  True
48
46
  """
49
47
 
50
- def __init__(self, **mapping: Union[int, Iterable[int]]):
48
+ def __init__(self, **mapping: int | Iterable[int]):
51
49
  index_set = {}
52
50
  for k, vs in mapping.items():
53
51
  indexes = {vs} if isinstance(vs, int) else set(vs)
@@ -61,16 +59,6 @@ class IndexSet(Dict[str, Set[int]]):
61
59
  def __hash__(self):
62
60
  return hash(frozenset((k, frozenset(vs)) for k, vs in self.items()))
63
61
 
64
- def _to_handler(self):
65
- """Return an effectful handler that binds each index variable to a
66
- tensor of its possible index values.
67
-
68
- """
69
- return {
70
- name_to_sym(k): functools.partial(lambda v: v, torch.tensor(list(v)))
71
- for k, v in self.items()
72
- }
73
-
74
62
 
75
63
  def union(*indexsets: IndexSet) -> IndexSet:
76
64
  """
@@ -166,25 +154,17 @@ def indices_of(value: Any) -> IndexSet:
166
154
  :param kwargs: Additional keyword arguments used by specific implementations.
167
155
  :return: A :class:`IndexSet` containing the indices on which the value is supported.
168
156
  """
169
- if isinstance(value, Term):
170
- return IndexSet(
171
- **{
172
- k.__name__: set(range(v)) # type:ignore
173
- for (k, v) in sizesof(value).items()
174
- }
175
- )
176
- elif isinstance(value, torch.distributions.Distribution):
177
- return indices_of(value.sample())
178
-
179
- return IndexSet()
157
+ return IndexSet(
158
+ **{getattr(k, "__name__"): set(range(v)) for (k, v) in sizesof(value).items()}
159
+ )
180
160
 
181
161
 
182
- @functools.lru_cache(maxsize=None)
183
- def name_to_sym(name: str) -> Operation[[], int]:
184
- return defop(int, name=name)
162
+ @functools.cache
163
+ def name_to_sym(name: str) -> Operation[[], torch.Tensor]:
164
+ return defop(torch.Tensor, name=name)
185
165
 
186
166
 
187
- def gather(value: torch.Tensor, indexset: IndexSet, **kwargs) -> torch.Tensor:
167
+ def gather(value: torch.Tensor, indexset: IndexSet) -> torch.Tensor:
188
168
  """
189
169
  Selects entries from an indexed value at the indices in a :class:`IndexSet` .
190
170
  :func:`gather` is useful in conjunction with :class:`MultiWorldCounterfactual`
@@ -248,9 +228,7 @@ def gather(value: torch.Tensor, indexset: IndexSet, **kwargs) -> torch.Tensor:
248
228
  """
249
229
  indexset_vars = {name_to_sym(name): inds for name, inds in indexset.items()}
250
230
  binding = {
251
- k: functools.partial(
252
- lambda v: v, Indexable(torch.tensor(list(indexset_vars[k])))[k()]
253
- )
231
+ k: functools.partial(lambda v: v, torch.tensor(list(indexset_vars[k]))[k()])
254
232
  for k in sizesof(value).keys()
255
233
  if k in indexset_vars
256
234
  }
@@ -259,14 +237,15 @@ def gather(value: torch.Tensor, indexset: IndexSet, **kwargs) -> torch.Tensor:
259
237
 
260
238
 
261
239
  def stack(
262
- values: Union[tuple[torch.Tensor, ...], list[torch.Tensor]], name: str, **kwargs
240
+ values: tuple[torch.Tensor, ...] | list[torch.Tensor], name: str
263
241
  ) -> torch.Tensor:
264
242
  """Stack a sequence of indexed values, creating a new dimension. The new
265
243
  dimension is indexed by `dim`. The indexed values in the stack must have
266
244
  identical shapes.
267
245
 
268
246
  """
269
- return Indexable(torch.stack(values))[name_to_sym(name)()]
247
+ values = torch.distributions.utils.broadcast_all(*values)
248
+ return torch.stack(values)[name_to_sym(name)()]
270
249
 
271
250
 
272
251
  def cond(fst: torch.Tensor, snd: torch.Tensor, case_: torch.Tensor) -> torch.Tensor:
@@ -281,12 +260,14 @@ def cond(fst: torch.Tensor, snd: torch.Tensor, case_: torch.Tensor) -> torch.Ten
281
260
  Unlike a Python conditional expression, however, the case may be a tensor,
282
261
  and both branches are evaluated, as with :func:`torch.where` ::
283
262
 
284
- >>> from effectful.internals.sugar import gensym
285
- >>> b = gensym(int, name="b")
286
- >>> fst, snd = Indexable(torch.randn(2, 3))[b()], Indexable(torch.randn(2, 3))[b()]
263
+ >>> from effectful.ops.syntax import defop
264
+ >>> from effectful.handlers.torch import bind_dims
265
+
266
+ >>> b = defop(torch.Tensor, name="b")
267
+ >>> fst, snd = torch.randn(2, 3)[b()], torch.randn(2, 3)[b()]
287
268
  >>> case = (fst < snd).all(-1)
288
269
  >>> x = cond(fst, snd, case)
289
- >>> assert (to_tensor(x, [b]) == to_tensor(torch.where(case[..., None], snd, fst), [b])).all()
270
+ >>> assert (bind_dims(x, b) == bind_dims(torch.where(case[..., None], snd, fst), b)).all()
290
271
 
291
272
  .. note::
292
273
 
@@ -304,10 +285,10 @@ def cond(fst: torch.Tensor, snd: torch.Tensor, case_: torch.Tensor) -> torch.Ten
304
285
  )
305
286
 
306
287
 
307
- def cond_n(values: Dict[IndexSet, torch.Tensor], case: torch.Tensor) -> torch.Tensor:
288
+ def cond_n(values: dict[IndexSet, torch.Tensor], case: torch.Tensor) -> torch.Tensor:
308
289
  assert len(values) > 0
309
290
  assert all(isinstance(k, IndexSet) for k in values.keys())
310
- result: Optional[torch.Tensor] = None
291
+ result: torch.Tensor | None = None
311
292
  for indices, value in values.items():
312
293
  tst = torch.as_tensor(
313
294
  functools.reduce(
@@ -0,0 +1,14 @@
1
+ try:
2
+ # Dummy import to check if jax is installed
3
+ import jax # noqa: F401
4
+ except ImportError:
5
+ raise ImportError("Jax is required to use effectful.handlers.jax")
6
+
7
+ # side effect: register defdata for jax.Array
8
+ import effectful.handlers.jax._terms # noqa: F401
9
+
10
+ from ._handlers import bind_dims as bind_dims
11
+ from ._handlers import jax_getitem as jax_getitem
12
+ from ._handlers import jit as jit
13
+ from ._handlers import sizesof as sizesof
14
+ from ._handlers import unbind_dims as unbind_dims