effectful 0.1.0.tar.gz → 0.2.0.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51)
  1. {effectful-0.1.0 → effectful-0.2.0}/PKG-INFO +59 -56
  2. {effectful-0.1.0 → effectful-0.2.0}/README.rst +26 -11
  3. {effectful-0.1.0 → effectful-0.2.0}/effectful/handlers/indexed.py +23 -24
  4. effectful-0.2.0/effectful/handlers/jax/__init__.py +14 -0
  5. effectful-0.2.0/effectful/handlers/jax/_handlers.py +293 -0
  6. effectful-0.2.0/effectful/handlers/jax/_terms.py +502 -0
  7. effectful-0.2.0/effectful/handlers/jax/numpy/__init__.py +23 -0
  8. effectful-0.2.0/effectful/handlers/jax/numpy/linalg.py +13 -0
  9. effectful-0.2.0/effectful/handlers/jax/scipy/special.py +11 -0
  10. effectful-0.2.0/effectful/handlers/numpyro.py +562 -0
  11. effectful-0.2.0/effectful/handlers/pyro.py +817 -0
  12. {effectful-0.1.0 → effectful-0.2.0}/effectful/handlers/torch.py +297 -168
  13. {effectful-0.1.0 → effectful-0.2.0}/effectful/internals/runtime.py +6 -13
  14. effectful-0.2.0/effectful/internals/tensor_utils.py +32 -0
  15. effectful-0.2.0/effectful/internals/unification.py +900 -0
  16. {effectful-0.1.0 → effectful-0.2.0}/effectful/ops/semantics.py +101 -77
  17. {effectful-0.1.0 → effectful-0.2.0}/effectful/ops/syntax.py +813 -251
  18. effectful-0.2.0/effectful/ops/types.py +216 -0
  19. {effectful-0.1.0 → effectful-0.2.0}/effectful.egg-info/PKG-INFO +59 -56
  20. {effectful-0.1.0 → effectful-0.2.0}/effectful.egg-info/SOURCES.txt +14 -4
  21. effectful-0.2.0/effectful.egg-info/requires.txt +37 -0
  22. effectful-0.2.0/pyproject.toml +81 -0
  23. effectful-0.2.0/setup.cfg +4 -0
  24. {effectful-0.1.0 → effectful-0.2.0}/tests/test_examples_minipyro.py +0 -1
  25. {effectful-0.1.0 → effectful-0.2.0}/tests/test_handlers_indexed.py +8 -6
  26. effectful-0.2.0/tests/test_handlers_jax.py +941 -0
  27. effectful-0.2.0/tests/test_handlers_numpyro.py +855 -0
  28. {effectful-0.1.0 → effectful-0.2.0}/tests/test_handlers_pyro.py +71 -46
  29. {effectful-0.1.0 → effectful-0.2.0}/tests/test_handlers_pyro_dist.py +359 -219
  30. {effectful-0.1.0 → effectful-0.2.0}/tests/test_handlers_torch.py +270 -149
  31. effectful-0.2.0/tests/test_internals_unification.py +1646 -0
  32. {effectful-0.1.0 → effectful-0.2.0}/tests/test_ops_semantics.py +195 -24
  33. effectful-0.2.0/tests/test_ops_syntax.py +873 -0
  34. effectful-0.2.0/tests/test_ops_types.py +12 -0
  35. {effectful-0.1.0 → effectful-0.2.0}/tests/test_semi_ring.py +23 -23
  36. effectful-0.1.0/effectful/handlers/numbers.py +0 -263
  37. effectful-0.1.0/effectful/handlers/pyro.py +0 -466
  38. effectful-0.1.0/effectful/ops/types.py +0 -124
  39. effectful-0.1.0/effectful.egg-info/requires.txt +0 -37
  40. effectful-0.1.0/setup.cfg +0 -25
  41. effectful-0.1.0/setup.py +0 -81
  42. effectful-0.1.0/tests/test_handlers_numbers.py +0 -266
  43. effectful-0.1.0/tests/test_ops_syntax.py +0 -124
  44. {effectful-0.1.0 → effectful-0.2.0}/LICENSE.md +0 -0
  45. {effectful-0.1.0 → effectful-0.2.0}/effectful/__init__.py +0 -0
  46. {effectful-0.1.0 → effectful-0.2.0}/effectful/handlers/__init__.py +0 -0
  47. {effectful-0.1.0 → effectful-0.2.0}/effectful/internals/__init__.py +0 -0
  48. {effectful-0.1.0 → effectful-0.2.0}/effectful/ops/__init__.py +0 -0
  49. {effectful-0.1.0 → effectful-0.2.0}/effectful/py.typed +0 -0
  50. {effectful-0.1.0 → effectful-0.2.0}/effectful.egg-info/dependency_links.txt +0 -0
  51. {effectful-0.1.0 → effectful-0.2.0}/effectful.egg-info/top_level.txt +0 -0
@@ -1,72 +1,59 @@
1
- Metadata-Version: 2.2
1
+ Metadata-Version: 2.4
2
2
  Name: effectful
3
- Version: 0.1.0
3
+ Version: 0.2.0
4
4
  Summary: Metaprogramming infrastructure
5
- Home-page: https://www.basis.ai/
6
5
  Author: Basis
7
- License: Apache 2.0
6
+ License-Expression: Apache-2.0
7
+ Project-URL: Homepage, https://www.basis.ai/
8
8
  Project-URL: Source, https://github.com/BasisResearch/effectful
9
- Keywords: machine learning statistics probabilistic programming bayesian modeling pytorch
9
+ Project-URL: Bug Tracker, https://github.com/BasisResearch/effectful/issues
10
+ Keywords: machine learning,statistics,probabilistic programming,bayesian modeling,pytorch
10
11
  Classifier: Intended Audience :: Developers
11
12
  Classifier: Intended Audience :: Education
12
13
  Classifier: Intended Audience :: Science/Research
13
- Classifier: License :: OSI Approved :: Apache Software License
14
14
  Classifier: Operating System :: POSIX :: Linux
15
15
  Classifier: Operating System :: MacOS :: MacOS X
16
- Classifier: Programming Language :: Python :: 3.10
17
- Classifier: Programming Language :: Python :: 3.11
18
- Requires-Python: >=3.10
16
+ Classifier: Programming Language :: Python :: 3.12
17
+ Classifier: Programming Language :: Python :: 3.13
18
+ Requires-Python: >=3.12
19
+ Description-Content-Type: text/x-rst
19
20
  License-File: LICENSE.md
20
- Requires-Dist: typing_extensions
21
- Requires-Dist: dm-tree
22
21
  Provides-Extra: torch
23
22
  Requires-Dist: torch; extra == "torch"
23
+ Requires-Dist: dm-tree; extra == "torch"
24
24
  Provides-Extra: pyro
25
- Requires-Dist: torch; extra == "pyro"
26
- Requires-Dist: pyro-ppl; extra == "pyro"
27
- Provides-Extra: dev
28
- Requires-Dist: torch; extra == "dev"
29
- Requires-Dist: pyro-ppl; extra == "dev"
30
- Requires-Dist: torch; extra == "dev"
31
- Requires-Dist: setuptools; extra == "dev"
32
- Requires-Dist: sphinx; extra == "dev"
33
- Requires-Dist: sphinxcontrib-bibtex; extra == "dev"
34
- Requires-Dist: sphinx_rtd_theme; extra == "dev"
35
- Requires-Dist: myst-parser; extra == "dev"
36
- Requires-Dist: nbsphinx; extra == "dev"
37
- Requires-Dist: pytest; extra == "dev"
38
- Requires-Dist: pytest-cov; extra == "dev"
39
- Requires-Dist: pytest-xdist; extra == "dev"
40
- Requires-Dist: pytest-benchmark; extra == "dev"
41
- Requires-Dist: mypy; extra == "dev"
42
- Requires-Dist: black; extra == "dev"
43
- Requires-Dist: flake8; extra == "dev"
44
- Requires-Dist: isort; extra == "dev"
45
- Requires-Dist: nbval; extra == "dev"
46
- Requires-Dist: nbqa; extra == "dev"
25
+ Requires-Dist: pyro-ppl>=1.9.1; extra == "pyro"
26
+ Requires-Dist: dm-tree; extra == "pyro"
27
+ Provides-Extra: jax
28
+ Requires-Dist: jax; extra == "jax"
29
+ Requires-Dist: dm-tree; extra == "jax"
30
+ Provides-Extra: numpyro
31
+ Requires-Dist: numpyro>=0.19; extra == "numpyro"
32
+ Requires-Dist: dm-tree; extra == "numpyro"
47
33
  Provides-Extra: docs
48
- Requires-Dist: setuptools; extra == "docs"
34
+ Requires-Dist: effectful[jax,numpyro,pyro,torch]; extra == "docs"
49
35
  Requires-Dist: sphinx; extra == "docs"
50
36
  Requires-Dist: sphinxcontrib-bibtex; extra == "docs"
51
37
  Requires-Dist: sphinx_rtd_theme; extra == "docs"
52
38
  Requires-Dist: myst-parser; extra == "docs"
53
39
  Requires-Dist: nbsphinx; extra == "docs"
54
- Dynamic: author
55
- Dynamic: classifier
56
- Dynamic: description
57
- Dynamic: home-page
58
- Dynamic: keywords
59
- Dynamic: license
60
- Dynamic: project-url
61
- Dynamic: provides-extra
62
- Dynamic: requires-dist
63
- Dynamic: requires-python
64
- Dynamic: summary
65
-
40
+ Requires-Dist: sphinx_autodoc_typehints; extra == "docs"
41
+ Requires-Dist: pypandoc_binary; extra == "docs"
42
+ Provides-Extra: test
43
+ Requires-Dist: effectful[docs,jax,numpyro,pyro,torch]; extra == "test"
44
+ Requires-Dist: pytest; extra == "test"
45
+ Requires-Dist: pytest-cov; extra == "test"
46
+ Requires-Dist: pytest-xdist; extra == "test"
47
+ Requires-Dist: pytest-benchmark; extra == "test"
48
+ Requires-Dist: mypy; extra == "test"
49
+ Requires-Dist: ruff; extra == "test"
50
+ Requires-Dist: nbval; extra == "test"
51
+ Requires-Dist: nbqa; extra == "test"
52
+ Dynamic: license-file
66
53
 
67
54
  .. index-inclusion-marker
68
55
 
69
- Effectful
56
+ Effectful
70
57
  =========
71
58
 
72
59
  Effectful is an algebraic effect system for Python, intended for use in the
@@ -89,9 +76,12 @@ Install From Source
89
76
  Install With Optional PyTorch/Pyro Support
90
77
  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
91
78
 
92
- ``effectful`` has optional support for `PyTorch <https://pytorch.org/>`_ (tensors
93
- with named dimensions) and `Pyro <https://pyro.ai/>`_ (wrappers for Pyro
94
- effects).
79
+ ``effectful`` has optional support for:
80
+
81
+ - `PyTorch <https://pytorch.org/>`_ (tensors with named dimensions)
82
+ - `Pyro <https://pyro.ai/>`_ (wrappers for Pyro effects)
83
+ - `Jax <https://docs.jax.dev/en/latest/index.html>`_ (tensors with named dimensions)
84
+ - `Numpyro <https://num.pyro.ai>`_ (operations for Numpyro distributions)
95
85
 
96
86
  To enable PyTorch support:
97
87
 
@@ -105,6 +95,18 @@ Pyro support (which includes PyTorch support):
105
95
 
106
96
  pip install effectful[pyro]
107
97
 
98
+ Jax support:
99
+
100
+ .. code:: sh
101
+
102
+ pip install effectful[jax]
103
+
104
+ Numpyro support (which includes Jax support):
105
+
106
+ .. code:: sh
107
+
108
+ pip install effectful[numpyro]
109
+
108
110
  Getting Started
109
111
  ---------------
110
112
 
@@ -115,11 +117,12 @@ Here's an example demonstrating how ``effectful`` can be used to implement a sim
115
117
  import functools
116
118
 
117
119
  from effectful.ops.types import Term
118
- from effectful.ops.syntax import defop
120
+ from effectful.ops.syntax import defdata, defop
119
121
  from effectful.ops.semantics import handler, evaluate, coproduct, fwd
120
- from effectful.handlers.numbers import add
121
122
 
122
- def beta_add(x: int, y: int) -> int:
123
+ add = defdata.dispatch(int).__add__
124
+
125
+ def beta_add(x: int, y: int) -> int:
123
126
  match x, y:
124
127
  case int(), int():
125
128
  return x + y
@@ -129,14 +132,14 @@ Here's an example demonstrating how ``effectful`` can be used to implement a sim
129
132
  def commute_add(x: int, y: int) -> int:
130
133
  match x, y:
131
134
  case Term(), int():
132
- return y + x
135
+ return y + x
133
136
  case _:
134
137
  return fwd()
135
138
 
136
139
  def assoc_add(x: int, y: int) -> int:
137
140
  match x, y:
138
141
  case _, Term(op, (a, b)) if op == add:
139
- return (x + a) + b
142
+ return (x + a) + b
140
143
  case _:
141
144
  return fwd()
142
145
 
@@ -168,7 +171,7 @@ We can make the evaluation strategy smarter by taking advantage of the commutati
168
171
  >>> with handler(eager_mixed):
169
172
  >>> print(evaluate(e))
170
173
  add(8, add(x(), y()))
171
-
174
+
172
175
  Learn More
173
176
  ----------
174
177
 
@@ -1,7 +1,6 @@
1
-
2
1
  .. index-inclusion-marker
3
2
 
4
- Effectful
3
+ Effectful
5
4
  =========
6
5
 
7
6
  Effectful is an algebraic effect system for Python, intended for use in the
@@ -24,9 +23,12 @@ Install From Source
24
23
  Install With Optional PyTorch/Pyro Support
25
24
  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
26
25
 
27
- ``effectful`` has optional support for `PyTorch <https://pytorch.org/>`_ (tensors
28
- with named dimensions) and `Pyro <https://pyro.ai/>`_ (wrappers for Pyro
29
- effects).
26
+ ``effectful`` has optional support for:
27
+
28
+ - `PyTorch <https://pytorch.org/>`_ (tensors with named dimensions)
29
+ - `Pyro <https://pyro.ai/>`_ (wrappers for Pyro effects)
30
+ - `Jax <https://docs.jax.dev/en/latest/index.html>`_ (tensors with named dimensions)
31
+ - `Numpyro <https://num.pyro.ai>`_ (operations for Numpyro distributions)
30
32
 
31
33
  To enable PyTorch support:
32
34
 
@@ -40,6 +42,18 @@ Pyro support (which includes PyTorch support):
40
42
 
41
43
  pip install effectful[pyro]
42
44
 
45
+ Jax support:
46
+
47
+ .. code:: sh
48
+
49
+ pip install effectful[jax]
50
+
51
+ Numpyro support (which includes Jax support):
52
+
53
+ .. code:: sh
54
+
55
+ pip install effectful[numpyro]
56
+
43
57
  Getting Started
44
58
  ---------------
45
59
 
@@ -50,11 +64,12 @@ Here's an example demonstrating how ``effectful`` can be used to implement a sim
50
64
  import functools
51
65
 
52
66
  from effectful.ops.types import Term
53
- from effectful.ops.syntax import defop
67
+ from effectful.ops.syntax import defdata, defop
54
68
  from effectful.ops.semantics import handler, evaluate, coproduct, fwd
55
- from effectful.handlers.numbers import add
56
69
 
57
- def beta_add(x: int, y: int) -> int:
70
+ add = defdata.dispatch(int).__add__
71
+
72
+ def beta_add(x: int, y: int) -> int:
58
73
  match x, y:
59
74
  case int(), int():
60
75
  return x + y
@@ -64,14 +79,14 @@ Here's an example demonstrating how ``effectful`` can be used to implement a sim
64
79
  def commute_add(x: int, y: int) -> int:
65
80
  match x, y:
66
81
  case Term(), int():
67
- return y + x
82
+ return y + x
68
83
  case _:
69
84
  return fwd()
70
85
 
71
86
  def assoc_add(x: int, y: int) -> int:
72
87
  match x, y:
73
88
  case _, Term(op, (a, b)) if op == add:
74
- return (x + a) + b
89
+ return (x + a) + b
75
90
  case _:
76
91
  return fwd()
77
92
 
@@ -103,7 +118,7 @@ We can make the evaluation strategy smarter by taking advantage of the commutati
103
118
  >>> with handler(eager_mixed):
104
119
  >>> print(evaluate(e))
105
120
  add(8, add(x(), y()))
106
-
121
+
107
122
  Learn More
108
123
  ----------
109
124
 
@@ -1,18 +1,16 @@
1
1
  import functools
2
2
  import operator
3
- from typing import Any, Dict, Iterable, Optional, Set, TypeVar, Union
3
+ from collections.abc import Iterable
4
+ from typing import Any
4
5
 
5
6
  import torch
6
7
 
7
- from effectful.handlers.torch import Indexable, sizesof
8
+ from effectful.handlers.torch import sizesof
8
9
  from effectful.ops.syntax import deffn, defop
9
10
  from effectful.ops.types import Operation
10
11
 
11
- K = TypeVar("K")
12
- T = TypeVar("T")
13
12
 
14
-
15
- class IndexSet(Dict[str, Set[int]]):
13
+ class IndexSet(dict[str, set[int]]):
16
14
  """
17
15
  :class:`IndexSet` s represent the support of an indexed value,
18
16
  for which free variables correspond to single interventions and indices
@@ -32,13 +30,13 @@ class IndexSet(Dict[str, Set[int]]):
32
30
  for which a value is defined::
33
31
 
34
32
  >>> IndexSet(x={0, 1}, y={2, 3})
35
- IndexSet({x: {0, 1}, y: {2, 3}})
33
+ IndexSet({'x': {0, 1}, 'y': {2, 3}})
36
34
 
37
35
  :class:`IndexSet` 's constructor will automatically drop empty entries
38
36
  and attempt to convert input values to :class:`set` s::
39
37
 
40
38
  >>> IndexSet(x=[0, 0, 1], y=set(), z=2)
41
- IndexSet({x: {0, 1}, z: {2}})
39
+ IndexSet({'x': {0, 1}, 'z': {2}})
42
40
 
43
41
  :class:`IndexSet` s are also hashable and can be used as keys in :class:`dict` s::
44
42
 
@@ -47,7 +45,7 @@ class IndexSet(Dict[str, Set[int]]):
47
45
  True
48
46
  """
49
47
 
50
- def __init__(self, **mapping: Union[int, Iterable[int]]):
48
+ def __init__(self, **mapping: int | Iterable[int]):
51
49
  index_set = {}
52
50
  for k, vs in mapping.items():
53
51
  indexes = {vs} if isinstance(vs, int) else set(vs)
@@ -161,12 +159,12 @@ def indices_of(value: Any) -> IndexSet:
161
159
  )
162
160
 
163
161
 
164
- @functools.lru_cache(maxsize=None)
165
- def name_to_sym(name: str) -> Operation[[], int]:
166
- return defop(int, name=name)
162
+ @functools.cache
163
+ def name_to_sym(name: str) -> Operation[[], torch.Tensor]:
164
+ return defop(torch.Tensor, name=name)
167
165
 
168
166
 
169
- def gather(value: torch.Tensor, indexset: IndexSet, **kwargs) -> torch.Tensor:
167
+ def gather(value: torch.Tensor, indexset: IndexSet) -> torch.Tensor:
170
168
  """
171
169
  Selects entries from an indexed value at the indices in a :class:`IndexSet` .
172
170
  :func:`gather` is useful in conjunction with :class:`MultiWorldCounterfactual`
@@ -230,9 +228,7 @@ def gather(value: torch.Tensor, indexset: IndexSet, **kwargs) -> torch.Tensor:
230
228
  """
231
229
  indexset_vars = {name_to_sym(name): inds for name, inds in indexset.items()}
232
230
  binding = {
233
- k: functools.partial(
234
- lambda v: v, Indexable(torch.tensor(list(indexset_vars[k])))[k()]
235
- )
231
+ k: functools.partial(lambda v: v, torch.tensor(list(indexset_vars[k]))[k()])
236
232
  for k in sizesof(value).keys()
237
233
  if k in indexset_vars
238
234
  }
@@ -241,14 +237,15 @@ def gather(value: torch.Tensor, indexset: IndexSet, **kwargs) -> torch.Tensor:
241
237
 
242
238
 
243
239
  def stack(
244
- values: Union[tuple[torch.Tensor, ...], list[torch.Tensor]], name: str, **kwargs
240
+ values: tuple[torch.Tensor, ...] | list[torch.Tensor], name: str
245
241
  ) -> torch.Tensor:
246
242
  """Stack a sequence of indexed values, creating a new dimension. The new
247
243
  dimension is indexed by `dim`. The indexed values in the stack must have
248
244
  identical shapes.
249
245
 
250
246
  """
251
- return Indexable(torch.stack(values))[name_to_sym(name)()]
247
+ values = torch.distributions.utils.broadcast_all(*values)
248
+ return torch.stack(values)[name_to_sym(name)()]
252
249
 
253
250
 
254
251
  def cond(fst: torch.Tensor, snd: torch.Tensor, case_: torch.Tensor) -> torch.Tensor:
@@ -263,12 +260,14 @@ def cond(fst: torch.Tensor, snd: torch.Tensor, case_: torch.Tensor) -> torch.Ten
263
260
  Unlike a Python conditional expression, however, the case may be a tensor,
264
261
  and both branches are evaluated, as with :func:`torch.where` ::
265
262
 
266
- >>> from effectful.internals.sugar import gensym
267
- >>> b = gensym(int, name="b")
268
- >>> fst, snd = Indexable(torch.randn(2, 3))[b()], Indexable(torch.randn(2, 3))[b()]
263
+ >>> from effectful.ops.syntax import defop
264
+ >>> from effectful.handlers.torch import bind_dims
265
+
266
+ >>> b = defop(torch.Tensor, name="b")
267
+ >>> fst, snd = torch.randn(2, 3)[b()], torch.randn(2, 3)[b()]
269
268
  >>> case = (fst < snd).all(-1)
270
269
  >>> x = cond(fst, snd, case)
271
- >>> assert (to_tensor(x, [b]) == to_tensor(torch.where(case[..., None], snd, fst), [b])).all()
270
+ >>> assert (bind_dims(x, b) == bind_dims(torch.where(case[..., None], snd, fst), b)).all()
272
271
 
273
272
  .. note::
274
273
 
@@ -286,10 +285,10 @@ def cond(fst: torch.Tensor, snd: torch.Tensor, case_: torch.Tensor) -> torch.Ten
286
285
  )
287
286
 
288
287
 
289
- def cond_n(values: Dict[IndexSet, torch.Tensor], case: torch.Tensor) -> torch.Tensor:
288
+ def cond_n(values: dict[IndexSet, torch.Tensor], case: torch.Tensor) -> torch.Tensor:
290
289
  assert len(values) > 0
291
290
  assert all(isinstance(k, IndexSet) for k in values.keys())
292
- result: Optional[torch.Tensor] = None
291
+ result: torch.Tensor | None = None
293
292
  for indices, value in values.items():
294
293
  tst = torch.as_tensor(
295
294
  functools.reduce(
@@ -0,0 +1,14 @@
1
+ try:
2
+ # Dummy import to check if jax is installed
3
+ import jax # noqa: F401
4
+ except ImportError:
5
+ raise ImportError("Jax is required to use effectful.handlers.jax")
6
+
7
+ # side effect: register defdata for jax.Array
8
+ import effectful.handlers.jax._terms # noqa: F401
9
+
10
+ from ._handlers import bind_dims as bind_dims
11
+ from ._handlers import jax_getitem as jax_getitem
12
+ from ._handlers import jit as jit
13
+ from ._handlers import sizesof as sizesof
14
+ from ._handlers import unbind_dims as unbind_dims