effectful 0.1.0__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,18 +1,16 @@
1
1
  import functools
2
2
  import operator
3
- from typing import Any, Dict, Iterable, Optional, Set, TypeVar, Union
3
+ from collections.abc import Iterable
4
+ from typing import Any
4
5
 
5
6
  import torch
6
7
 
7
- from effectful.handlers.torch import Indexable, sizesof
8
+ from effectful.handlers.torch import sizesof
8
9
  from effectful.ops.syntax import deffn, defop
9
10
  from effectful.ops.types import Operation
10
11
 
11
- K = TypeVar("K")
12
- T = TypeVar("T")
13
12
 
14
-
15
- class IndexSet(Dict[str, Set[int]]):
13
+ class IndexSet(dict[str, set[int]]):
16
14
  """
17
15
  :class:`IndexSet` s represent the support of an indexed value,
18
16
  for which free variables correspond to single interventions and indices
@@ -32,13 +30,13 @@ class IndexSet(Dict[str, Set[int]]):
32
30
  for which a value is defined::
33
31
 
34
32
  >>> IndexSet(x={0, 1}, y={2, 3})
35
- IndexSet({x: {0, 1}, y: {2, 3}})
33
+ IndexSet({'x': {0, 1}, 'y': {2, 3}})
36
34
 
37
35
  :class:`IndexSet` 's constructor will automatically drop empty entries
38
36
  and attempt to convert input values to :class:`set` s::
39
37
 
40
38
  >>> IndexSet(x=[0, 0, 1], y=set(), z=2)
41
- IndexSet({x: {0, 1}, z: {2}})
39
+ IndexSet({'x': {0, 1}, 'z': {2}})
42
40
 
43
41
  :class:`IndexSet` s are also hashable and can be used as keys in :class:`dict` s::
44
42
 
@@ -47,7 +45,7 @@ class IndexSet(Dict[str, Set[int]]):
47
45
  True
48
46
  """
49
47
 
50
- def __init__(self, **mapping: Union[int, Iterable[int]]):
48
+ def __init__(self, **mapping: int | Iterable[int]):
51
49
  index_set = {}
52
50
  for k, vs in mapping.items():
53
51
  indexes = {vs} if isinstance(vs, int) else set(vs)
@@ -161,12 +159,12 @@ def indices_of(value: Any) -> IndexSet:
161
159
  )
162
160
 
163
161
 
164
- @functools.lru_cache(maxsize=None)
165
- def name_to_sym(name: str) -> Operation[[], int]:
166
- return defop(int, name=name)
162
+ @functools.cache
163
+ def name_to_sym(name: str) -> Operation[[], torch.Tensor]:
164
+ return defop(torch.Tensor, name=name)
167
165
 
168
166
 
169
- def gather(value: torch.Tensor, indexset: IndexSet, **kwargs) -> torch.Tensor:
167
+ def gather(value: torch.Tensor, indexset: IndexSet) -> torch.Tensor:
170
168
  """
171
169
  Selects entries from an indexed value at the indices in a :class:`IndexSet` .
172
170
  :func:`gather` is useful in conjunction with :class:`MultiWorldCounterfactual`
@@ -230,9 +228,7 @@ def gather(value: torch.Tensor, indexset: IndexSet, **kwargs) -> torch.Tensor:
230
228
  """
231
229
  indexset_vars = {name_to_sym(name): inds for name, inds in indexset.items()}
232
230
  binding = {
233
- k: functools.partial(
234
- lambda v: v, Indexable(torch.tensor(list(indexset_vars[k])))[k()]
235
- )
231
+ k: functools.partial(lambda v: v, torch.tensor(list(indexset_vars[k]))[k()])
236
232
  for k in sizesof(value).keys()
237
233
  if k in indexset_vars
238
234
  }
@@ -241,14 +237,15 @@ def gather(value: torch.Tensor, indexset: IndexSet, **kwargs) -> torch.Tensor:
241
237
 
242
238
 
243
239
  def stack(
244
- values: Union[tuple[torch.Tensor, ...], list[torch.Tensor]], name: str, **kwargs
240
+ values: tuple[torch.Tensor, ...] | list[torch.Tensor], name: str
245
241
  ) -> torch.Tensor:
246
242
  """Stack a sequence of indexed values, creating a new dimension. The new
247
243
  dimension is indexed by `dim`. The indexed values in the stack must have
248
244
  identical shapes.
249
245
 
250
246
  """
251
- return Indexable(torch.stack(values))[name_to_sym(name)()]
247
+ values = torch.distributions.utils.broadcast_all(*values)
248
+ return torch.stack(values)[name_to_sym(name)()]
252
249
 
253
250
 
254
251
  def cond(fst: torch.Tensor, snd: torch.Tensor, case_: torch.Tensor) -> torch.Tensor:
@@ -263,12 +260,14 @@ def cond(fst: torch.Tensor, snd: torch.Tensor, case_: torch.Tensor) -> torch.Ten
263
260
  Unlike a Python conditional expression, however, the case may be a tensor,
264
261
  and both branches are evaluated, as with :func:`torch.where` ::
265
262
 
266
- >>> from effectful.internals.sugar import gensym
267
- >>> b = gensym(int, name="b")
268
- >>> fst, snd = Indexable(torch.randn(2, 3))[b()], Indexable(torch.randn(2, 3))[b()]
263
+ >>> from effectful.ops.syntax import defop
264
+ >>> from effectful.handlers.torch import bind_dims
265
+
266
+ >>> b = defop(torch.Tensor, name="b")
267
+ >>> fst, snd = torch.randn(2, 3)[b()], torch.randn(2, 3)[b()]
269
268
  >>> case = (fst < snd).all(-1)
270
269
  >>> x = cond(fst, snd, case)
271
- >>> assert (to_tensor(x, [b]) == to_tensor(torch.where(case[..., None], snd, fst), [b])).all()
270
+ >>> assert (bind_dims(x, b) == bind_dims(torch.where(case[..., None], snd, fst), b)).all()
272
271
 
273
272
  .. note::
274
273
 
@@ -286,10 +285,10 @@ def cond(fst: torch.Tensor, snd: torch.Tensor, case_: torch.Tensor) -> torch.Ten
286
285
  )
287
286
 
288
287
 
289
- def cond_n(values: Dict[IndexSet, torch.Tensor], case: torch.Tensor) -> torch.Tensor:
288
+ def cond_n(values: dict[IndexSet, torch.Tensor], case: torch.Tensor) -> torch.Tensor:
290
289
  assert len(values) > 0
291
290
  assert all(isinstance(k, IndexSet) for k in values.keys())
292
- result: Optional[torch.Tensor] = None
291
+ result: torch.Tensor | None = None
293
292
  for indices, value in values.items():
294
293
  tst = torch.as_tensor(
295
294
  functools.reduce(
@@ -0,0 +1,14 @@
1
+ try:
2
+ # Dummy import to check if jax is installed
3
+ import jax # noqa: F401
4
+ except ImportError:
5
+ raise ImportError("Jax is required to use effectful.handlers.jax")
6
+
7
+ # side effect: register defdata for jax.Array
8
+ import effectful.handlers.jax._terms # noqa: F401
9
+
10
+ from ._handlers import bind_dims as bind_dims
11
+ from ._handlers import jax_getitem as jax_getitem
12
+ from ._handlers import jit as jit
13
+ from ._handlers import sizesof as sizesof
14
+ from ._handlers import unbind_dims as unbind_dims
@@ -0,0 +1,293 @@
1
+ import functools
2
+ import typing
3
+ from collections.abc import Callable, Mapping, Sequence
4
+ from types import EllipsisType
5
+ from typing import Annotated
6
+
7
+ try:
8
+ import jax
9
+ import jax.numpy as jnp
10
+ except ImportError:
11
+ raise ImportError("JAX is required to use effectful.handlers.jax")
12
+
13
+ import tree
14
+
15
+ from effectful.ops.semantics import fvsof, typeof
16
+ from effectful.ops.syntax import (
17
+ Scoped,
18
+ _CustomSingleDispatchCallable,
19
+ defdata,
20
+ deffn,
21
+ defop,
22
+ defterm,
23
+ syntactic_eq,
24
+ )
25
+ from effectful.ops.types import Expr, NotHandled, Operation, Term
26
+
27
+ # + An element of an array index expression.
28
+ IndexElement = None | int | slice | Sequence[int] | EllipsisType | jax.Array
29
+
30
+
31
+ def is_eager_array(x):
32
+ return isinstance(x, jax.Array) or (
33
+ isinstance(x, Term)
34
+ and x.op is jax_getitem
35
+ and isinstance(x.args[0], jax.Array)
36
+ and all(
37
+ (not isinstance(k, Term)) or (not k.args and not k.kwargs)
38
+ for k in x.args[1]
39
+ )
40
+ and not x.kwargs
41
+ )
42
+
43
+
44
+ def sizesof(value) -> Mapping[Operation[[], jax.Array], int]:
45
+ """Return the sizes of named dimensions in an array expression.
46
+
47
+ Sizes are inferred from the array shape.
48
+
49
+ :param value: An array expression.
50
+ :return: A mapping from named dimensions to their sizes.
51
+
52
+ **Example usage**:
53
+
54
+ >>> a, b = defop(jax.Array, name='a'), defop(jax.Array, name='b')
55
+ >>> sizes = sizesof(jax_getitem(jnp.ones((2, 3)), [a(), b()]))
56
+ >>> assert sizes[a] == 2 and sizes[b] == 3
57
+ """
58
+ sizes: dict[Operation[[], jax.Array], int] = {}
59
+
60
+ def update_sizes(sizes, op, size):
61
+ old_size = sizes.get(op)
62
+ if old_size is not None and size != old_size:
63
+ raise ValueError(
64
+ f"Named index {op} used in incompatible dimensions of size {old_size} and {size}"
65
+ )
66
+ sizes[op] = size
67
+
68
+ def _getitem_sizeof(x: jax.Array, key: tuple[Expr[IndexElement], ...]):
69
+ if is_eager_array(x):
70
+ for i, k in enumerate(key):
71
+ if isinstance(k, Term) and len(k.args) == 0 and len(k.kwargs) == 0:
72
+ update_sizes(sizes, k.op, x.shape[i])
73
+
74
+ def _sizesof(expr):
75
+ expr = defterm(expr)
76
+ if isinstance(expr, Term):
77
+ for x in tree.flatten((expr.args, expr.kwargs)):
78
+ _sizesof(x)
79
+ if expr.op is jax_getitem:
80
+ _getitem_sizeof(*expr.args)
81
+ elif tree.is_nested(expr):
82
+ for x in tree.flatten(expr):
83
+ _sizesof(x)
84
+
85
+ _sizesof(value)
86
+ return sizes
87
+
88
+
89
+ def _partial_eval(t: Expr[jax.Array]) -> Expr[jax.Array]:
90
+ """Partially evaluate a term with respect to its sized free variables."""
91
+
92
+ sized_fvs = sizesof(t)
93
+ if not sized_fvs:
94
+ return t
95
+
96
+ def _is_eager(t):
97
+ return not isinstance(t, Term) or t.op in sized_fvs or is_eager_array(t)
98
+
99
+ if not (
100
+ isinstance(t, Term)
101
+ and all(_is_eager(a) for a in tree.flatten((t.args, t.kwargs)))
102
+ ):
103
+ return t
104
+
105
+ tpe_jax_fn = jax.vmap(deffn(t, *sized_fvs.keys()))
106
+
107
+ # Create indices for each dimension
108
+ indices = jnp.meshgrid(
109
+ *[jnp.arange(size) for size in sized_fvs.values()], indexing="ij"
110
+ )
111
+
112
+ # Flatten indices for vmap
113
+ flat_indices = [idx.reshape(-1) for idx in indices]
114
+
115
+ # Apply vmap
116
+ flat_result = tpe_jax_fn(*flat_indices)
117
+
118
+ def reindex_flat_array(t):
119
+ if not isinstance(t, jax.Array):
120
+ return t
121
+
122
+ result_shape = indices[0].shape + t.shape[1:]
123
+ result = jnp.reshape(t, result_shape)
124
+ return jax_getitem(result, tuple(k() for k in sized_fvs.keys()))
125
+
126
+ result = tree.map_structure(reindex_flat_array, flat_result)
127
+ return result
128
+
129
+
130
+ @functools.cache
131
+ def _register_jax_op[**P, T](jax_fn: Callable[P, T]):
132
+ if getattr(jax_fn, "__name__", None) == "__getitem__":
133
+ return jax_getitem
134
+
135
+ @defop
136
+ def _jax_op(*args, **kwargs) -> jax.Array:
137
+ tm = defdata(_jax_op, *args, **kwargs)
138
+ sized_fvs = sizesof(tm)
139
+
140
+ if (
141
+ _jax_op is jax_getitem
142
+ and not isinstance(args[0], Term)
143
+ and sized_fvs
144
+ and args[1]
145
+ and all(isinstance(k, Term) and k.op in sized_fvs for k in args[1])
146
+ ):
147
+ raise NotHandled
148
+ elif sized_fvs and set(sized_fvs.keys()) == fvsof(tm) - {jax_getitem, _jax_op}:
149
+ # note: this cast is a lie. partial_eval can return non-arrays, as
150
+ # can jax_fn. for example, some jax functions return tuples,
151
+ # which partial_eval handles.
152
+ return typing.cast(jax.Array, _partial_eval(tm))
153
+ elif not any(
154
+ tree.flatten(
155
+ tree.map_structure(lambda x: isinstance(x, Term), (args, kwargs))
156
+ )
157
+ ):
158
+ return typing.cast(jax.Array, jax_fn(*args, **kwargs))
159
+ else:
160
+ raise NotHandled
161
+
162
+ functools.update_wrapper(_jax_op, jax_fn)
163
+ return _jax_op
164
+
165
+
166
+ @functools.cache
167
+ def _register_jax_op_no_partial_eval[**P, T](jax_fn: Callable[P, T]):
168
+ # FIXME: Presumably not all jax ops return arrays. In other cases, we won't
169
+ # get the right kind of term.
170
+ @defop
171
+ def _jax_op(*args, **kwargs) -> jax.Array:
172
+ if not any(
173
+ tree.flatten(
174
+ tree.map_structure(lambda x: isinstance(x, Term), (args, kwargs))
175
+ )
176
+ ):
177
+ return typing.cast(jax.Array, jax_fn(*args, **kwargs))
178
+ else:
179
+ raise NotHandled
180
+
181
+ functools.update_wrapper(_jax_op, jax_fn)
182
+ return _jax_op
183
+
184
+
185
+ @_register_jax_op
186
+ def jax_getitem(x: jax.Array, key: tuple[IndexElement, ...]) -> jax.Array:
187
+ """Operation for indexing an array. Unlike the standard __getitem__ method,
188
+ this operation correctly handles indexing with terms.
189
+
190
+ """
191
+ return x[tuple(key)]
192
+
193
+
194
+ @defop
195
+ @_CustomSingleDispatchCallable
196
+ def bind_dims[T, A, B](
197
+ __dispatch: Callable[[type], Callable[..., T]],
198
+ value: Annotated[T, Scoped[A | B]],
199
+ *names: Annotated[Operation[[], jax.Array], Scoped[B]],
200
+ ) -> Annotated[T, Scoped[A]]:
201
+ """Convert named dimensions to positional dimensions.
202
+
203
+ :param t: An array.
204
+ :param args: Named dimensions to convert to positional dimensions.
205
+ These positional dimensions will appear at the beginning of the
206
+ shape.
207
+ :return: An array with the named dimensions in ``args`` converted to positional dimensions.
208
+
209
+ **Example usage**:
210
+
211
+ >>> import jax.numpy as jnp
212
+ >>> from effectful.ops.syntax import defop
213
+ >>> a, b = defop(jax.Array, name='a'), defop(jax.Array, name='b')
214
+ >>> t = jax_getitem(jnp.ones((2, 3)), [a(), b()])
215
+ >>> bind_dims(t, b, a).shape
216
+ (3, 2)
217
+ """
218
+ if tree.is_nested(value):
219
+ return tree.map_structure(lambda v: bind_dims(v, *names), value)
220
+
221
+ semantic_type = typeof(value)
222
+ return __dispatch(semantic_type)(value, *names)
223
+
224
+
225
+ @defop
226
+ @_CustomSingleDispatchCallable
227
+ def unbind_dims[T, A, B](
228
+ __dispatch: Callable[[type], Callable[..., T]],
229
+ value: Annotated[T, Scoped[A | B]],
230
+ *names: Annotated[Operation[[], jax.Array], Scoped[B]],
231
+ ) -> Annotated[T, Scoped[A | B]]:
232
+ """Convert positional dimensions to named dimensions."""
233
+ if tree.is_nested(value):
234
+ return tree.map_structure(lambda v: unbind_dims(v, *names), value)
235
+
236
+ semantic_type = typeof(value)
237
+ return __dispatch(semantic_type)(value, *names)
238
+
239
+
240
+ def jit(f, *args, **kwargs):
241
+ f_noindex, f_reindex = _indexed_func_wrapper(f, jax_getitem, sizesof)
242
+ f_noindex_jitted = jax.jit(f_noindex, *args, **kwargs)
243
+ return lambda *args, **kwargs: f_reindex(f_noindex_jitted(*args, **kwargs))
244
+
245
+
246
+ def _indexed_func_wrapper[**P, S, T](
247
+ func: Callable[P, T], getitem, sizesof
248
+ ) -> tuple[Callable[P, S], Callable[[S], T]]:
249
+ # index expressions for the result of the function
250
+ indexes = None
251
+
252
+ # hide index lists from tree.map_structure
253
+ class Indexes:
254
+ def __init__(self, sizes):
255
+ self.sizes = sizes
256
+ self.indexes = list(sizes.keys())
257
+
258
+ # strip named indexes from the result of the function and store them
259
+ def deindexed(*args, **kwargs):
260
+ nonlocal indexes
261
+
262
+ def deindex_tensor(t, i):
263
+ t_ = bind_dims(t, *i.sizes.keys())
264
+ assert all(t_.shape[j] == i.sizes[v] for j, v in enumerate(i.sizes))
265
+ return t_
266
+
267
+ ret = func(*args, **kwargs)
268
+ indexes = tree.map_structure(lambda t: Indexes(sizesof(t)), ret)
269
+ tensors = tree.map_structure(lambda t, i: deindex_tensor(t, i), ret, indexes)
270
+ return tensors
271
+
272
+ # reapply the stored indexes to a result
273
+ def reindex(ret, starting_dim=0):
274
+ def index_expr(i):
275
+ return (slice(None),) * (starting_dim) + tuple(x() for x in i.indexes)
276
+
277
+ if tree.is_nested(ret):
278
+ indexed_ret = tree.map_structure(
279
+ lambda t, i: getitem(t, index_expr(i)), ret, indexes
280
+ )
281
+ else:
282
+ indexed_ret = getitem(ret, index_expr(indexes))
283
+
284
+ return indexed_ret
285
+
286
+ return deindexed, reindex
287
+
288
+
289
+ @syntactic_eq.register
290
+ def _(x: jax.typing.ArrayLike, other) -> bool:
291
+ return isinstance(other, jax.typing.ArrayLike) and bool( # type: ignore[arg-type]
292
+ (jnp.asarray(x) == jnp.asarray(other)).all()
293
+ )