effectful 0.1.0__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
effectful/ops/types.py CHANGED
@@ -2,20 +2,21 @@ from __future__ import annotations
2
2
 
3
3
  import abc
4
4
  import collections.abc
5
+ import functools
5
6
  import inspect
6
7
  import typing
7
- from typing import Any, Callable, Generic, Mapping, Sequence, Type, TypeVar, Union
8
+ from collections.abc import Callable, Mapping, Sequence
9
+ from typing import Any, _ProtocolMeta, overload, runtime_checkable
8
10
 
9
- from typing_extensions import ParamSpec
10
11
 
11
- P = ParamSpec("P")
12
- Q = ParamSpec("Q")
13
- S = TypeVar("S")
14
- T = TypeVar("T")
15
- V = TypeVar("V")
12
+ class NotHandled(Exception):
13
+ """Raised by an operation when the operation should remain unhandled."""
16
14
 
15
+ pass
17
16
 
18
- class Operation(abc.ABC, Generic[Q, V]):
17
+
18
+ @functools.total_ordering
19
+ class Operation[**Q, V](abc.ABC):
19
20
  """An abstract class representing an effect that can be implemented by an effect handler.
20
21
 
21
22
  .. note::
@@ -36,7 +37,11 @@ class Operation(abc.ABC, Generic[Q, V]):
36
37
  raise NotImplementedError
37
38
 
38
39
  @abc.abstractmethod
39
- def __default_rule__(self, *args: Q.args, **kwargs: Q.kwargs) -> "Expr[V]":
40
+ def __lt__(self, other):
41
+ raise NotImplementedError
42
+
43
+ @abc.abstractmethod
44
+ def __default_rule__(self, *args: Q.args, **kwargs: Q.kwargs) -> Expr[V]:
40
45
  """The default rule is used when the operation is not handled.
41
46
 
42
47
  If no default rule is supplied, the free rule is used instead.
@@ -44,15 +49,12 @@ class Operation(abc.ABC, Generic[Q, V]):
44
49
  raise NotImplementedError
45
50
 
46
51
  @abc.abstractmethod
47
- def __type_rule__(self, *args: Q.args, **kwargs: Q.kwargs) -> Type[V]:
52
+ def __type_rule__(self, *args: Q.args, **kwargs: Q.kwargs) -> type[V]:
48
53
  """Returns the type of the operation applied to arguments."""
49
54
  raise NotImplementedError
50
55
 
51
56
  @abc.abstractmethod
52
- def __fvs_rule__(self, *args: Q.args, **kwargs: Q.kwargs) -> tuple[
53
- tuple[collections.abc.Set["Operation"], ...],
54
- dict[str, collections.abc.Set["Operation"]],
55
- ]:
57
+ def __fvs_rule__(self, *args: Q.args, **kwargs: Q.kwargs) -> inspect.BoundArguments:
56
58
  """
57
59
  Returns the sets of variables that appear free in each argument and keyword argument
58
60
  but not in the result of the operation, i.e. the variables bound by the operation.
@@ -63,19 +65,17 @@ class Operation(abc.ABC, Generic[Q, V]):
63
65
  """
64
66
  raise NotImplementedError
65
67
 
66
- @abc.abstractmethod
67
- def __repr_rule__(self, *args: Q.args, **kwargs: Q.kwargs) -> str:
68
- raise NotImplementedError
69
-
70
68
  @typing.final
71
69
  def __call__(self, *args: Q.args, **kwargs: Q.kwargs) -> V:
72
- from effectful.internals.runtime import get_interpretation
73
70
  from effectful.ops.semantics import apply
74
71
 
75
- return apply.__default_rule__(get_interpretation(), self, *args, **kwargs) # type: ignore
72
+ return apply.__default_rule__(self, *args, **kwargs) # type: ignore
76
73
 
74
+ def __repr__(self):
75
+ return f"{self.__class__.__name__}({self.__name__}, {self.__signature__})"
77
76
 
78
- class Term(abc.ABC, Generic[T]):
77
+
78
+ class Term[T](abc.ABC):
79
79
  """A term in an effectful computation is a tree of :class:`Operation`
80
80
  applied to values.
81
81
 
@@ -91,33 +91,125 @@ class Term(abc.ABC, Generic[T]):
91
91
 
92
92
  @property
93
93
  @abc.abstractmethod
94
- def args(self) -> Sequence["Expr[Any]"]:
94
+ def args(self) -> Sequence[Expr[Any]]:
95
95
  """Abstract property for the arguments."""
96
96
  raise NotImplementedError
97
97
 
98
98
  @property
99
99
  @abc.abstractmethod
100
- def kwargs(self) -> Mapping[str, "Expr[Any]"]:
100
+ def kwargs(self) -> Mapping[str, Expr[Any]]:
101
101
  """Abstract property for the keyword arguments."""
102
102
  raise NotImplementedError
103
103
 
104
104
  def __repr__(self) -> str:
105
+ return f"{self.__class__.__name__}({self.op!r}, {self.args!r}, {self.kwargs!r})"
106
+
107
+ def __str__(self) -> str:
105
108
  from effectful.internals.runtime import interpreter
106
109
  from effectful.ops.semantics import apply, evaluate
107
110
 
108
- with interpreter({apply: lambda _, op, *a, **k: op.__repr_rule__(*a, **k)}):
109
- return evaluate(self) # type: ignore
111
+ fresh: dict[str, dict[Operation, int]] = collections.defaultdict(dict)
112
+
113
+ def op_str(op):
114
+ """Return a unique (in this term) name for the operation."""
115
+ name = op.__name__
116
+ if name not in fresh:
117
+ fresh[name] = {op: 0}
118
+ if op not in fresh[name]:
119
+ fresh[name][op] = len(fresh[name])
120
+
121
+ n = fresh[name][op]
122
+ if n == 0:
123
+ return name
124
+ return f"{name}!{n}"
125
+
126
+ def term_str(term):
127
+ if isinstance(term, Operation):
128
+ return op_str(term)
129
+ elif isinstance(term, list):
130
+ return "[" + ", ".join(map(term_str, term)) + "]"
131
+ elif isinstance(term, tuple):
132
+ return "(" + ", ".join(map(term_str, term)) + ")"
133
+ elif isinstance(term, dict):
134
+ return (
135
+ "{"
136
+ + ", ".join(
137
+ f"{term_str(k)}:{term_str(v)}" for (k, v) in term.items()
138
+ )
139
+ + "}"
140
+ )
141
+ return str(term)
142
+
143
+ def _apply(op, *args, **kwargs) -> str:
144
+ args_str = ", ".join(map(term_str, args)) if args else ""
145
+ kwargs_str = (
146
+ ", ".join(f"{k}={term_str(v)}" for k, v in kwargs.items())
147
+ if kwargs
148
+ else ""
149
+ )
150
+
151
+ ret = f"{op_str(op)}({args_str}"
152
+ if kwargs:
153
+ ret += f"{', ' if args else ''}"
154
+ ret += f"{kwargs_str})"
155
+ return ret
156
+
157
+ with interpreter({apply: _apply}):
158
+ return typing.cast(str, evaluate(self))
110
159
 
111
160
 
112
161
  #: An expression is either a value or a term.
113
- Expr = Union[T, Term[T]]
162
+ type Expr[T] = T | Term[T]
114
163
 
115
- #: An interpretation is a mapping from operations to their implementations.
116
- Interpretation = Mapping[Operation[..., T], Callable[..., V]]
117
164
 
165
+ class _InterpretationMeta(_ProtocolMeta):
166
+ def __instancecheck__(cls, instance):
167
+ return isinstance(instance, collections.abc.Mapping) and all(
168
+ isinstance(k, Operation) and callable(v) for k, v in instance.items()
169
+ )
118
170
 
119
- class Annotation(abc.ABC):
120
171
 
172
+ @runtime_checkable
173
+ class Interpretation[T, V](typing.Protocol, metaclass=_InterpretationMeta):
174
+ """An interpretation is a mapping from operations to their implementations."""
175
+
176
+ def keys(self):
177
+ raise NotImplementedError
178
+
179
+ def values(self):
180
+ raise NotImplementedError
181
+
182
+ def items(self):
183
+ raise NotImplementedError
184
+
185
+ @overload
186
+ def get(self, key: Operation[..., T], /) -> Callable[..., V] | None:
187
+ raise NotImplementedError
188
+
189
+ @overload
190
+ def get(
191
+ self, key: Operation[..., T], default: Callable[..., V], /
192
+ ) -> Callable[..., V]:
193
+ raise NotImplementedError
194
+
195
+ @overload
196
+ def get[S](self, key: Operation[..., T], default: S, /) -> Callable[..., V] | S:
197
+ raise NotImplementedError
198
+
199
+ def __getitem__(self, key: Operation[..., T]) -> Callable[..., V]:
200
+ raise NotImplementedError
201
+
202
+ def __contains__(self, key: Operation[..., T]) -> bool:
203
+ raise NotImplementedError
204
+
205
+ def __iter__(self):
206
+ raise NotImplementedError
207
+
208
+ def __len__(self) -> int:
209
+ raise NotImplementedError
210
+
211
+
212
+ class Annotation(abc.ABC):
121
213
  @classmethod
122
214
  @abc.abstractmethod
123
215
  def infer_annotations(cls, sig: inspect.Signature) -> inspect.Signature:
@@ -1,72 +1,59 @@
1
- Metadata-Version: 2.2
1
+ Metadata-Version: 2.4
2
2
  Name: effectful
3
- Version: 0.1.0
3
+ Version: 0.2.0
4
4
  Summary: Metaprogramming infrastructure
5
- Home-page: https://www.basis.ai/
6
5
  Author: Basis
7
- License: Apache 2.0
6
+ License-Expression: Apache-2.0
7
+ Project-URL: Homepage, https://www.basis.ai/
8
8
  Project-URL: Source, https://github.com/BasisResearch/effectful
9
- Keywords: machine learning statistics probabilistic programming bayesian modeling pytorch
9
+ Project-URL: Bug Tracker, https://github.com/BasisResearch/effectful/issues
10
+ Keywords: machine learning,statistics,probabilistic programming,bayesian modeling,pytorch
10
11
  Classifier: Intended Audience :: Developers
11
12
  Classifier: Intended Audience :: Education
12
13
  Classifier: Intended Audience :: Science/Research
13
- Classifier: License :: OSI Approved :: Apache Software License
14
14
  Classifier: Operating System :: POSIX :: Linux
15
15
  Classifier: Operating System :: MacOS :: MacOS X
16
- Classifier: Programming Language :: Python :: 3.10
17
- Classifier: Programming Language :: Python :: 3.11
18
- Requires-Python: >=3.10
16
+ Classifier: Programming Language :: Python :: 3.12
17
+ Classifier: Programming Language :: Python :: 3.13
18
+ Requires-Python: >=3.12
19
+ Description-Content-Type: text/x-rst
19
20
  License-File: LICENSE.md
20
- Requires-Dist: typing_extensions
21
- Requires-Dist: dm-tree
22
21
  Provides-Extra: torch
23
22
  Requires-Dist: torch; extra == "torch"
23
+ Requires-Dist: dm-tree; extra == "torch"
24
24
  Provides-Extra: pyro
25
- Requires-Dist: torch; extra == "pyro"
26
- Requires-Dist: pyro-ppl; extra == "pyro"
27
- Provides-Extra: dev
28
- Requires-Dist: torch; extra == "dev"
29
- Requires-Dist: pyro-ppl; extra == "dev"
30
- Requires-Dist: torch; extra == "dev"
31
- Requires-Dist: setuptools; extra == "dev"
32
- Requires-Dist: sphinx; extra == "dev"
33
- Requires-Dist: sphinxcontrib-bibtex; extra == "dev"
34
- Requires-Dist: sphinx_rtd_theme; extra == "dev"
35
- Requires-Dist: myst-parser; extra == "dev"
36
- Requires-Dist: nbsphinx; extra == "dev"
37
- Requires-Dist: pytest; extra == "dev"
38
- Requires-Dist: pytest-cov; extra == "dev"
39
- Requires-Dist: pytest-xdist; extra == "dev"
40
- Requires-Dist: pytest-benchmark; extra == "dev"
41
- Requires-Dist: mypy; extra == "dev"
42
- Requires-Dist: black; extra == "dev"
43
- Requires-Dist: flake8; extra == "dev"
44
- Requires-Dist: isort; extra == "dev"
45
- Requires-Dist: nbval; extra == "dev"
46
- Requires-Dist: nbqa; extra == "dev"
25
+ Requires-Dist: pyro-ppl>=1.9.1; extra == "pyro"
26
+ Requires-Dist: dm-tree; extra == "pyro"
27
+ Provides-Extra: jax
28
+ Requires-Dist: jax; extra == "jax"
29
+ Requires-Dist: dm-tree; extra == "jax"
30
+ Provides-Extra: numpyro
31
+ Requires-Dist: numpyro>=0.19; extra == "numpyro"
32
+ Requires-Dist: dm-tree; extra == "numpyro"
47
33
  Provides-Extra: docs
48
- Requires-Dist: setuptools; extra == "docs"
34
+ Requires-Dist: effectful[jax,numpyro,pyro,torch]; extra == "docs"
49
35
  Requires-Dist: sphinx; extra == "docs"
50
36
  Requires-Dist: sphinxcontrib-bibtex; extra == "docs"
51
37
  Requires-Dist: sphinx_rtd_theme; extra == "docs"
52
38
  Requires-Dist: myst-parser; extra == "docs"
53
39
  Requires-Dist: nbsphinx; extra == "docs"
54
- Dynamic: author
55
- Dynamic: classifier
56
- Dynamic: description
57
- Dynamic: home-page
58
- Dynamic: keywords
59
- Dynamic: license
60
- Dynamic: project-url
61
- Dynamic: provides-extra
62
- Dynamic: requires-dist
63
- Dynamic: requires-python
64
- Dynamic: summary
65
-
40
+ Requires-Dist: sphinx_autodoc_typehints; extra == "docs"
41
+ Requires-Dist: pypandoc_binary; extra == "docs"
42
+ Provides-Extra: test
43
+ Requires-Dist: effectful[docs,jax,numpyro,pyro,torch]; extra == "test"
44
+ Requires-Dist: pytest; extra == "test"
45
+ Requires-Dist: pytest-cov; extra == "test"
46
+ Requires-Dist: pytest-xdist; extra == "test"
47
+ Requires-Dist: pytest-benchmark; extra == "test"
48
+ Requires-Dist: mypy; extra == "test"
49
+ Requires-Dist: ruff; extra == "test"
50
+ Requires-Dist: nbval; extra == "test"
51
+ Requires-Dist: nbqa; extra == "test"
52
+ Dynamic: license-file
66
53
 
67
54
  .. index-inclusion-marker
68
55
 
69
- Effectful
56
+ Effectful
70
57
  =========
71
58
 
72
59
  Effectful is an algebraic effect system for Python, intended for use in the
@@ -89,9 +76,12 @@ Install From Source
89
76
  Install With Optional Backend Support
90
77
  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
91
78
 
92
- ``effectful`` has optional support for `PyTorch <https://pytorch.org/>`_ (tensors
93
- with named dimensions) and `Pyro <https://pyro.ai/>`_ (wrappers for Pyro
94
- effects).
79
+ ``effectful`` has optional support for:
80
+
81
+ - `PyTorch <https://pytorch.org/>`_ (tensors with named dimensions)
82
+ - `Pyro <https://pyro.ai/>`_ (wrappers for Pyro effects)
83
+ - `JAX <https://docs.jax.dev/en/latest/index.html>`_ (tensors with named dimensions)
84
+ - `NumPyro <https://num.pyro.ai>`_ (operations for NumPyro distributions)
95
85
 
96
86
  To enable PyTorch support:
97
87
 
@@ -105,6 +95,18 @@ Pyro support (which includes PyTorch support):
105
95
 
106
96
  pip install effectful[pyro]
107
97
 
98
+ JAX support:
99
+
100
+ .. code:: sh
101
+
102
+ pip install effectful[jax]
103
+
104
+ NumPyro support (which includes JAX support):
105
+
106
+ .. code:: sh
107
+
108
+ pip install effectful[numpyro]
109
+
108
110
  Getting Started
109
111
  ---------------
110
112
 
@@ -115,11 +117,12 @@ Here's an example demonstrating how ``effectful`` can be used to implement a sim
115
117
  import functools
116
118
 
117
119
  from effectful.ops.types import Term
118
- from effectful.ops.syntax import defop
120
+ from effectful.ops.syntax import defdata, defop
119
121
  from effectful.ops.semantics import handler, evaluate, coproduct, fwd
120
- from effectful.handlers.numbers import add
121
122
 
122
- def beta_add(x: int, y: int) -> int:
123
+ add = defdata.dispatch(int).__add__
124
+
125
+ def beta_add(x: int, y: int) -> int:
123
126
  match x, y:
124
127
  case int(), int():
125
128
  return x + y
@@ -129,14 +132,14 @@ Here's an example demonstrating how ``effectful`` can be used to implement a sim
129
132
  def commute_add(x: int, y: int) -> int:
130
133
  match x, y:
131
134
  case Term(), int():
132
- return y + x
135
+ return y + x
133
136
  case _:
134
137
  return fwd()
135
138
 
136
139
  def assoc_add(x: int, y: int) -> int:
137
140
  match x, y:
138
141
  case _, Term(op, (a, b)) if op == add:
139
- return (x + a) + b
142
+ return (x + a) + b
140
143
  case _:
141
144
  return fwd()
142
145
 
@@ -168,7 +171,7 @@ We can make the evaluation strategy smarter by taking advantage of the commutati
168
171
  >>> with handler(eager_mixed):
169
172
  >>> print(evaluate(e))
170
173
  add(8, add(x(), y()))
171
-
174
+
172
175
  Learn More
173
176
  ----------
174
177
 
@@ -0,0 +1,26 @@
1
+ effectful/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
+ effectful/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
+ effectful/handlers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
4
+ effectful/handlers/indexed.py,sha256=ZY-8w32a1PKGVScwXjbgByI3wRHvfxuuXJVwLlp0rgw,12622
5
+ effectful/handlers/numpyro.py,sha256=RWoBNpHLr5KdotN5Vu118jY0kn_p6NaejcnJgZJLehw,19457
6
+ effectful/handlers/pyro.py,sha256=qVl1wson02pyV8YHGf93KDnYEp5pGmhKEwji95OYBl8,26486
7
+ effectful/handlers/torch.py,sha256=NNM7mxqZskEBCjsl25kHI95WlXG9aeD7FaSkXkoLZ_I,24330
8
+ effectful/handlers/jax/__init__.py,sha256=O-BygB2HCkDftP1B98mQnsrt7sOdZpvC2YoHXd61tmI,494
9
+ effectful/handlers/jax/_handlers.py,sha256=95fGOZoTQfYehs0ip_zR91TKwnJB-beK6wr_SH73Bg8,9252
10
+ effectful/handlers/jax/_terms.py,sha256=k_nLe5jvj76ZHVr4LWWFbQT4lxhGUVTEWQB6W38_M8I,16341
11
+ effectful/handlers/jax/numpy/__init__.py,sha256=Kmvya0QI-GA56pPf1as-wYOuZFngOBLtsawa35vPKhg,516
12
+ effectful/handlers/jax/numpy/linalg.py,sha256=9DiaYYG4SztmO-VkmMH3dVvULtMK-zEgbV9oNQFkFo8,350
13
+ effectful/handlers/jax/scipy/special.py,sha256=yTIECFtQVPgraonrPlyenjvcnEYchZwIZC-5CSkF-lA,299
14
+ effectful/internals/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
15
+ effectful/internals/runtime.py,sha256=aLWol7sR1yHekn7zNz1evHKHARjiT1tnkmByLHPHBGc,1811
16
+ effectful/internals/tensor_utils.py,sha256=3QCSUqdxCXod3dsY3oRMcg36Rqr8pVX-ktEyCEkeODo,1173
17
+ effectful/internals/unification.py,sha256=CTzRDVqziYKMjqedB1IKHLK4YkHK8TZrlPaP5Ivb0-o,30180
18
+ effectful/ops/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
19
+ effectful/ops/semantics.py,sha256=eFBKOKZ-ah0rN47GGUVW_hdBo1HXfk4hiuqjEgQCd-A,11596
20
+ effectful/ops/syntax.py,sha256=5YZh7vBVsWqMRNGok0_yNYLhsJXjrnq2sL1VN9aiN7M,55531
21
+ effectful/ops/types.py,sha256=W1gZJaBnX7_nFpWrG3vfCBQPSun3Gc9PqT61ls8B3EA,6599
22
+ effectful-0.2.0.dist-info/licenses/LICENSE.md,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
23
+ effectful-0.2.0.dist-info/METADATA,sha256=OXKK_zuSeSxST7LbwMCobitLX0OsOQu-wRfd7LzKJGM,5300
24
+ effectful-0.2.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
25
+ effectful-0.2.0.dist-info/top_level.txt,sha256=gtuJfrE2nXil_lZLCnqWF2KAbOnJs9ILNvK8WnkRzbs,10
26
+ effectful-0.2.0.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (75.8.0)
2
+ Generator: setuptools (80.9.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5