effectful 0.0.1__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,502 @@
1
+ import functools
2
+ import operator
3
+ from collections.abc import Sequence
4
+ from typing import Any, cast
5
+
6
+ import jax
7
+ import tree
8
+
9
+ import effectful.handlers.jax.numpy as jnp
10
+ from effectful.handlers.jax._handlers import (
11
+ IndexElement,
12
+ _partial_eval,
13
+ _register_jax_op,
14
+ bind_dims,
15
+ jax_getitem,
16
+ unbind_dims,
17
+ )
18
+ from effectful.internals.tensor_utils import _desugar_tensor_index
19
+ from effectful.ops.syntax import defdata
20
+ from effectful.ops.types import Expr, NotHandled, Operation, Term
21
+
22
+
23
+ class _IndexUpdateHelper:
24
+ """Helper class to implement array-style .at[index].set() updates for effectful arrays."""
25
+
26
+ def __init__(self, array):
27
+ self.array = array
28
+
29
+ def __getitem__(self, key):
30
+ return _IndexUpdateRef(self.array, key)
31
+
32
+
33
@_register_jax_op
def jax_at_set(arr, index_key, val):
    """Return ``arr`` with the entries at ``index_key`` replaced by ``val``.

    This uses JAX's functional ``arr.at[index].set(val)`` update. It is
    registered once at import time, so every ``.at[...].set(...)`` call goes
    through the same operation instead of a freshly registered one per call.
    """
    # A 1-tuple key is unwrapped so JAX receives the bare index.
    if isinstance(index_key, tuple) and len(index_key) == 1:
        index_key = index_key[0]
    return arr.at[index_key].set(val)


class _IndexUpdateRef:
    """Reference to an array position for updates via ``.at[index]``."""

    def __init__(self, array, key):
        # The array to update and the index recorded by ``.at[key]``.
        self.array = array
        self.key = key

    def set(self, value):
        """Set values at the indexed positions.

        :param value: Value(s) to write at ``self.key``.
        :return: The result of applying ``jax_at_set`` to ``self.array``,
            ``self.key`` and ``value``.
        """
        return jax_at_set(self.array, self.key, value)
58
+
59
+
60
@defdata.register(jax.Array)
def _embed_array(op, *args, **kwargs):
    """Construct the term for ``op(*args, **kwargs)`` producing a ``jax.Array``.

    A ``jax_getitem`` of a non-term array whose term indices are all nullary
    operation calls becomes an ``_EagerArrayTerm``. Every other case becomes
    a generic ``_ArrayTerm``.
    """
    if op is not jax_getitem or isinstance(args[0], Term):
        return _ArrayTerm(op, *args, **kwargs)

    array, index = args[0], args[1]
    # Term indices must be bare free variables (no args / kwargs).
    only_free_vars = all(
        not (k.args or k.kwargs) for k in index if isinstance(k, Term)
    )
    if only_free_vars:
        return _EagerArrayTerm(jax_getitem, array, index)
    return _ArrayTerm(op, *args, **kwargs)
70
+
71
+
72
class _ArrayTerm(Term[jax.Array]):
    """Term whose value is a ``jax.Array``.

    Indexing, arithmetic/comparison/bitwise operators and the usual array
    methods are forwarded to the matching operation in
    ``effectful.handlers.jax.numpy`` (imported as ``jnp``). This lets
    array-style code be written directly against terms.

    Because ``__eq__`` is defined without ``__hash__``, instances are
    unhashable, as Python's data model specifies.
    """

    def __init__(self, op: Operation[..., jax.Array], *args: Expr, **kwargs: Expr):
        self._op = op
        self._args = args
        self._kwargs = kwargs

    @property
    def op(self) -> Operation[..., jax.Array]:
        return self._op

    @property
    def args(self) -> tuple:
        return self._args

    @property
    def kwargs(self) -> dict:
        return self._kwargs

    def __getitem__(
        self, key: Expr[IndexElement] | tuple[Expr[IndexElement], ...]
    ) -> Expr[jax.Array]:
        # A bare (non-tuple) key is normalised to a 1-tuple.
        return jax_getitem(self, key if isinstance(key, tuple) else (key,))

    @property
    def shape(self) -> Expr[tuple[int, ...]]:
        return jnp.shape(cast(jax.Array, self))

    @property
    def size(self) -> Expr[int]:
        return jnp.size(cast(jax.Array, self))

    def __len__(self):
        # NOTE(review): ``self.shape`` is an expression here; ``len()`` only
        # succeeds when its first entry is a concrete int — confirm for
        # non-eager terms.
        return self.shape[0]

    @property
    def ndim(self) -> Expr[int]:
        return jnp.ndim(cast(jax.Array, self))

    # -- arithmetic operators ------------------------------------------------

    def __add__(self, other: jax.Array) -> jax.Array:
        return jnp.add(cast(jax.Array, self), other)

    def __radd__(self, other: jax.Array) -> jax.Array:
        return jnp.add(other, cast(jax.Array, self))

    def __neg__(self) -> jax.Array:
        return jnp.negative(cast(jax.Array, self))

    def __pos__(self) -> jax.Array:
        return jnp.positive(cast(jax.Array, self))

    def __sub__(self, other: jax.Array) -> jax.Array:
        return jnp.subtract(cast(jax.Array, self), other)

    def __rsub__(self, other: jax.Array) -> jax.Array:
        return jnp.subtract(other, cast(jax.Array, self))

    def __mul__(self, other: jax.Array) -> jax.Array:
        return jnp.multiply(cast(jax.Array, self), other)

    def __rmul__(self, other: jax.Array) -> jax.Array:
        return jnp.multiply(other, cast(jax.Array, self))

    def __truediv__(self, other: jax.Array) -> jax.Array:
        return jnp.divide(cast(jax.Array, self), other)

    def __rtruediv__(self, other: jax.Array) -> jax.Array:
        return jnp.divide(other, cast(jax.Array, self))

    def __pow__(self, other: jax.Array) -> jax.Array:
        return jnp.power(cast(jax.Array, self), other)

    def __rpow__(self, other: jax.Array) -> jax.Array:
        return jnp.power(other, cast(jax.Array, self))

    def __abs__(self) -> jax.Array:
        return jnp.abs(cast(jax.Array, self))

    # -- comparison operators (elementwise, return arrays) --------------------

    def __eq__(self, other: Any):
        return jnp.equal(cast(jax.Array, self), other)

    def __ne__(self, other: Any):
        return jnp.not_equal(cast(jax.Array, self), other)

    def __floordiv__(self, other: jax.Array) -> jax.Array:
        return jnp.floor_divide(cast(jax.Array, self), other)

    def __rfloordiv__(self, other: jax.Array) -> jax.Array:
        return jnp.floor_divide(other, cast(jax.Array, self))

    def __mod__(self, other: jax.Array) -> jax.Array:
        return jnp.mod(cast(jax.Array, self), other)

    def __rmod__(self, other: jax.Array) -> jax.Array:
        return jnp.mod(other, cast(jax.Array, self))

    def __lt__(self, other: jax.Array) -> jax.Array:
        return jnp.less(cast(jax.Array, self), other)

    def __le__(self, other: jax.Array) -> jax.Array:
        return jnp.less_equal(cast(jax.Array, self), other)

    def __gt__(self, other: jax.Array) -> jax.Array:
        return jnp.greater(cast(jax.Array, self), other)

    def __ge__(self, other: jax.Array) -> jax.Array:
        return jnp.greater_equal(cast(jax.Array, self), other)

    # -- bitwise operators ----------------------------------------------------

    def __lshift__(self, other: jax.Array) -> jax.Array:
        return jnp.left_shift(cast(jax.Array, self), other)

    def __rlshift__(self, other: jax.Array) -> jax.Array:
        return jnp.left_shift(other, cast(jax.Array, self))

    def __rshift__(self, other: jax.Array) -> jax.Array:
        return jnp.right_shift(cast(jax.Array, self), other)

    def __rrshift__(self, other: jax.Array) -> jax.Array:
        return jnp.right_shift(other, cast(jax.Array, self))

    def __and__(self, other: jax.Array) -> jax.Array:
        return jnp.bitwise_and(cast(jax.Array, self), other)

    def __rand__(self, other: jax.Array) -> jax.Array:
        return jnp.bitwise_and(other, cast(jax.Array, self))

    def __xor__(self, other: jax.Array) -> jax.Array:
        return jnp.bitwise_xor(cast(jax.Array, self), other)

    def __rxor__(self, other: jax.Array) -> jax.Array:
        return jnp.bitwise_xor(other, cast(jax.Array, self))

    def __or__(self, other: jax.Array) -> jax.Array:
        return jnp.bitwise_or(cast(jax.Array, self), other)

    def __ror__(self, other: jax.Array) -> jax.Array:
        return jnp.bitwise_or(other, cast(jax.Array, self))

    def __invert__(self) -> jax.Array:
        return jnp.bitwise_not(cast(jax.Array, self))

    def __matmul__(self, other: jax.Array) -> jax.Array:
        return jnp.matmul(cast(jax.Array, self), other)

    def __rmatmul__(self, other: jax.Array) -> jax.Array:
        return jnp.matmul(other, cast(jax.Array, self))

    # -- array attributes -----------------------------------------------------

    @property
    def at(self) -> _IndexUpdateHelper:
        """Return an IndexUpdateHelper for array updates."""
        return _IndexUpdateHelper(self)

    @property
    def T(self) -> jax.Array:
        """Transpose with all axes reversed, mirroring ``jax.Array.T``."""
        return jnp.transpose(cast(jax.Array, self))

    def __iter__(self):
        raise TypeError("A free array is not iterable.")

    # -- array methods (delegate to the jnp function of the same name) --------

    def all(self, axis=None, keepdims=False, *, where=None):
        return jnp.all(cast(jax.Array, self), axis=axis, keepdims=keepdims, where=where)

    def any(self, axis=None, keepdims=False, *, where=None):
        return jnp.any(cast(jax.Array, self), axis=axis, keepdims=keepdims, where=where)

    def argmax(self, axis=None, keepdims=False):
        return jnp.argmax(cast(jax.Array, self), axis=axis, keepdims=keepdims)

    def argmin(self, axis=None, keepdims=False):
        return jnp.argmin(cast(jax.Array, self), axis=axis, keepdims=keepdims)

    def argpartition(self, kth, axis=-1):
        return jnp.argpartition(cast(jax.Array, self), kth, axis=axis)

    def argsort(self, axis=-1, descending=False, stable=True):
        return jnp.argsort(
            cast(jax.Array, self), axis=axis, descending=descending, stable=stable
        )

    def astype(self, dtype):
        return jnp.astype(cast(jax.Array, self), dtype)

    def choose(self, choices, mode="raise"):
        return jnp.choose(cast(jax.Array, self), choices, mode=mode)

    def clip(self, min=None, max=None):
        return jnp.clip(cast(jax.Array, self), min=min, max=max)

    def compress(self, condition, axis=None):
        # jnp.compress takes the condition first and the array second.
        return jnp.compress(condition, cast(jax.Array, self), axis=axis)

    def conj(self):
        return jnp.conj(cast(jax.Array, self))

    def conjugate(self):
        return jnp.conjugate(cast(jax.Array, self))

    def copy(self):
        return jnp.copy(cast(jax.Array, self))

    def cumprod(self, axis=None, dtype=None):
        return jnp.cumprod(cast(jax.Array, self), axis=axis, dtype=dtype)

    def cumsum(self, axis=None, dtype=None):
        return jnp.cumsum(cast(jax.Array, self), axis=axis, dtype=dtype)

    def diagonal(self, offset=0, axis1=0, axis2=1):
        return jnp.diagonal(
            cast(jax.Array, self), offset=offset, axis1=axis1, axis2=axis2
        )

    def dot(self, b):
        return jnp.dot(cast(jax.Array, self), b)

    def max(self, axis=None, keepdims=False, initial=None, where=None):
        return jnp.max(
            cast(jax.Array, self),
            axis=axis,
            keepdims=keepdims,
            initial=initial,
            where=where,
        )

    def mean(self, axis=None, keepdims=False, *, where=None):
        return jnp.mean(
            cast(jax.Array, self), axis=axis, keepdims=keepdims, where=where
        )

    def min(self, axis=None, keepdims=False, initial=None, where=None):
        return jnp.min(
            cast(jax.Array, self),
            axis=axis,
            keepdims=keepdims,
            initial=initial,
            where=where,
        )

    def nonzero(self, *, size=None, fill_value=None):
        return jnp.nonzero(cast(jax.Array, self), size=size, fill_value=fill_value)

    def prod(
        self, axis=None, keepdims=False, initial=None, where=None, promote_integers=True
    ):
        return jnp.prod(
            cast(jax.Array, self),
            axis=axis,
            keepdims=keepdims,
            initial=initial,
            where=where,
            promote_integers=promote_integers,
        )

    def ptp(self, axis=None, keepdims=False):
        return jnp.ptp(cast(jax.Array, self), axis=axis, keepdims=keepdims)

    def ravel(self, order="C"):
        return jnp.ravel(cast(jax.Array, self), order=order)

    def repeat(self, repeats, axis=None, *, total_repeat_length=None):
        return jnp.repeat(
            cast(jax.Array, self),
            repeats,
            axis=axis,
            total_repeat_length=total_repeat_length,
        )

    def reshape(self, *shape, order="C"):
        # Accept both ``reshape((a, b))`` and ``reshape(a, b)``.
        if len(shape) == 1 and isinstance(shape[0], tuple | list):
            shape = shape[0]
        return jnp.reshape(cast(jax.Array, self), shape, order=order)

    def round(self, decimals=0):
        return jnp.round(cast(jax.Array, self), decimals=decimals)

    def searchsorted(self, v, side="left", sorter=None):
        return jnp.searchsorted(cast(jax.Array, self), v, side=side, sorter=sorter)

    def sort(self, axis=-1, descending=False, stable=True):
        return jnp.sort(
            cast(jax.Array, self), axis=axis, descending=descending, stable=stable
        )

    def squeeze(self, axis=None):
        return jnp.squeeze(cast(jax.Array, self), axis=axis)

    def std(self, axis=None, keepdims=False, ddof=0, *, where=None):
        return jnp.std(
            cast(jax.Array, self), axis=axis, keepdims=keepdims, ddof=ddof, where=where
        )

    def sum(
        self, axis=None, keepdims=False, initial=None, where=None, promote_integers=True
    ):
        return jnp.sum(
            cast(jax.Array, self),
            axis=axis,
            keepdims=keepdims,
            initial=initial,
            where=where,
            promote_integers=promote_integers,
        )

    def swapaxes(self, axis1, axis2):
        return jnp.swapaxes(cast(jax.Array, self), axis1, axis2)

    def take(
        self,
        indices,
        axis=None,
        mode=None,
        unique_indices=False,
        indices_are_sorted=False,
        fill_value=None,
    ):
        return jnp.take(
            cast(jax.Array, self),
            indices,
            axis=axis,
            mode=mode,
            unique_indices=unique_indices,
            indices_are_sorted=indices_are_sorted,
            fill_value=fill_value,
        )

    def trace(self, offset=0, axis1=0, axis2=1, dtype=None):
        return jnp.trace(
            cast(jax.Array, self), offset=offset, axis1=axis1, axis2=axis2, dtype=dtype
        )

    def transpose(self, *args, axes=None):
        """Permute the axes of the array.

        Supports the call forms of ``jax.Array.transpose``: ``transpose()``,
        ``transpose(axes_tuple)`` and ``transpose(*axes)``. The keyword form
        ``transpose(axes=...)`` is also supported.

        :raises TypeError: If axes are given both positionally and by keyword.
        """
        if args:
            if axes is not None:
                raise TypeError(
                    "transpose() got axes both positionally and by keyword"
                )
            axes = args[0] if len(args) == 1 else args
        if isinstance(axes, int):
            # A single int axis (e.g. ``x.transpose(0)`` on a 1-D array).
            axes = (axes,)
        return jnp.transpose(cast(jax.Array, self), axes=axes)

    def var(self, axis=None, keepdims=False, ddof=0, *, where=None):
        return jnp.var(
            cast(jax.Array, self), axis=axis, keepdims=keepdims, ddof=ddof, where=where
        )
403
+
404
+
405
class _EagerArrayTerm(_ArrayTerm):
    """``jax_getitem`` term over a concrete array indexed by free variables.

    The constructor normalises the index with ``_desugar_tensor_index`` and
    reshapes the array to the shape that helper returns. Shape, size and
    ndim are then computed from the concrete array rather than built as
    terms. NOTE(review): assumes the desugared key has one entry per array
    dimension — confirm.
    """

    def __init__(self, op, tensor, key):
        desugared_shape, desugared_key = _desugar_tensor_index(tensor.shape, key)
        super().__init__(
            op, jax.numpy.reshape(tensor, desugared_shape), desugared_key
        )

    def __iter__(self):
        # One sub-term per index along the leading positional dimension.
        yield from (self[i] for i in range(len(self)))

    @property
    def shape(self) -> tuple[int, ...]:
        # Dimensions indexed by a Term are named, not positional.
        array, key = self.args[0], self.args[1]
        positional = []
        for extent, entry in zip(array.shape, key):
            if isinstance(entry, Term):
                continue
            positional.append(extent)
        return tuple(positional)

    @property
    def size(self) -> int:
        total = 1
        for extent in self.shape:
            total = total * extent
        return total

    @property
    def ndim(self) -> int:
        return len(self.shape)
429
+
430
+
431
@bind_dims.register  # type: ignore
def _bind_dims_array(t: jax.Array, *args: Operation[[], jax.Array]) -> jax.Array:
    """Convert named dimensions to positional dimensions.

    :param t: An array.
    :param args: Named dimensions to convert to positional dimensions.
        These positional dimensions will appear at the beginning of the
        shape.
    :return: An array with the named dimensions in ``args`` converted to positional dimensions.
    :raises NotHandled: If the partially evaluated term is not a ``jax_getitem``
        of a concrete array. Also raised if ``args`` contains a dimension that
        is not one of that array's named indices.

    **Example usage**:

    >>> from effectful.ops.syntax import defop
    >>> from effectful.handlers.jax import bind_dims, unbind_dims
    >>> a, b = defop(jax.Array, name='a'), defop(jax.Array, name='b')
    >>> t = unbind_dims(jnp.ones((2, 3)), a, b)
    >>> bind_dims(t, b, a).shape
    (3, 2)
    """
    # Concrete arrays carry no named dimensions; nothing to bind.
    if not isinstance(t, Term):
        return t

    # Partially evaluate the whole term in one call. Sub-terms are not
    # evaluated separately beforehand: those results would not feed into
    # this call, so doing so would only repeat work.
    result = _partial_eval(t)
    if not isinstance(result, Term) or not args:
        return result

    # ensure that the result is a jax_getitem with an array as the first argument
    if not (result.op is jax_getitem and isinstance(result.args[0], jax.Array)):
        raise NotHandled

    array = result.args[0]
    dims = result.args[1]
    assert isinstance(dims, Sequence)

    # ensure that the order is a subset of the named dimensions
    order_set = set(args)
    if not order_set <= {a.op for a in dims if isinstance(a, Term)}:
        raise NotHandled

    # permute the inner array so that the leading dimensions are in the order
    # specified and the trailing dimensions are the remaining named dimensions
    # (or slices)
    reindex_dims = [
        i
        for i, o in enumerate(dims)
        if not isinstance(o, Term) or o.op not in order_set
    ]
    dim_ops = [a.op if isinstance(a, Term) else None for a in dims]
    perm = (
        [dim_ops.index(o) for o in args]
        + reindex_dims
        + list(range(len(dims), len(array.shape)))
    )
    array = jnp.transpose(array, perm)
    # The newly positional dimensions are kept whole; the rest keep their
    # original index entries.
    reindexed = jax_getitem(
        array, (slice(None),) * len(args) + tuple(dims[i] for i in reindex_dims)
    )
    return reindexed
498
+
499
+
500
@unbind_dims.register  # type: ignore
def _unbind_dims_array(t: jax.Array, *args: Operation[[], jax.Array]) -> jax.Array:
    """Name dimensions of ``t`` by indexing it with free variables.

    ``t`` is indexed by a key built from calling each operation in ``args``.
    This makes its first ``len(args)`` positional dimensions named by those
    operations.
    """
    index = tuple(dim() for dim in args)
    return jax_getitem(t, index)
@@ -0,0 +1,23 @@
1
+ from typing import TYPE_CHECKING
2
+
3
+ import jax.numpy
4
+
5
+ from .._handlers import _register_jax_op, _register_jax_op_no_partial_eval
6
+
7
# Names wrapped with ``_register_jax_op_no_partial_eval`` rather than
# ``_register_jax_op``.
_no_overload = ["array", "asarray"]

for name, op in jax.numpy.__dict__.items():
    # Only callables become operations. Dunder entries (for example a
    # module-level ``__getattr__``) are skipped. Wrapping them would replace
    # this module's own special attributes with jax operations.
    if not callable(op) or (name.startswith("__") and name.endswith("__")):
        continue

    jax_op = (
        _register_jax_op_no_partial_eval(op)
        if name in _no_overload
        else _register_jax_op(op)
    )
    globals()[name] = jax_op
19
+
20
+
21
+ # Tell mypy about our wrapped functions.
22
+ if TYPE_CHECKING:
23
+ from jax.numpy import * # noqa: F403
@@ -0,0 +1,13 @@
1
+ from typing import TYPE_CHECKING
2
+
3
+ import jax.numpy.linalg
4
+
5
+ from effectful.handlers.jax._handlers import _register_jax_op
6
+
7
for name, op in jax.numpy.linalg.__dict__.items():
    # Wrap every callable as an effectful operation. Module dunders are
    # skipped so this wrapper module's special attributes are not overwritten.
    if callable(op) and not (name.startswith("__") and name.endswith("__")):
        globals()[name] = _register_jax_op(op)
10
+
11
+ # Tell mypy about our wrapped functions.
12
+ if TYPE_CHECKING:
13
+ from jax.numpy.linalg import * # noqa: F403
@@ -0,0 +1,11 @@
1
+ from typing import TYPE_CHECKING
2
+
3
+ import jax.scipy.special
4
+
5
+ from effectful.handlers.jax._handlers import _register_jax_op
6
+
7
# Expose jax.scipy.special.logsumexp as an effectful operation.
_special_fns = jax.scipy.special
logsumexp = _register_jax_op(_special_fns.logsumexp)
del _special_fns
8
+
9
+ # Tell mypy about our wrapped functions.
10
+ if TYPE_CHECKING:
11
+ from jax.scipy.special import logsumexp # noqa: F401