effectful-0.0.1-py3-none-any.whl → effectful-0.2.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- effectful/handlers/indexed.py +27 -46
- effectful/handlers/jax/__init__.py +14 -0
- effectful/handlers/jax/_handlers.py +293 -0
- effectful/handlers/jax/_terms.py +502 -0
- effectful/handlers/jax/numpy/__init__.py +23 -0
- effectful/handlers/jax/numpy/linalg.py +13 -0
- effectful/handlers/jax/scipy/special.py +11 -0
- effectful/handlers/numpyro.py +562 -0
- effectful/handlers/pyro.py +565 -214
- effectful/handlers/torch.py +321 -169
- effectful/internals/runtime.py +6 -13
- effectful/internals/tensor_utils.py +32 -0
- effectful/internals/unification.py +900 -0
- effectful/ops/semantics.py +104 -84
- effectful/ops/syntax.py +1276 -167
- effectful/ops/types.py +141 -35
- {effectful-0.0.1.dist-info → effectful-0.2.0.dist-info}/METADATA +65 -57
- effectful-0.2.0.dist-info/RECORD +26 -0
- {effectful-0.0.1.dist-info → effectful-0.2.0.dist-info}/WHEEL +1 -1
- effectful/handlers/numbers.py +0 -259
- effectful/internals/base_impl.py +0 -259
- effectful-0.0.1.dist-info/RECORD +0 -19
- {effectful-0.0.1.dist-info → effectful-0.2.0.dist-info/licenses}/LICENSE.md +0 -0
- {effectful-0.0.1.dist-info → effectful-0.2.0.dist-info}/top_level.txt +0 -0
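
For orientation, here is a minimal sketch of the named-dimension workflow that the new effectful.handlers.jax package is built around. It is adapted from the bind_dims docstring in _terms.py below; the imports and calls mirror that docstring, and anything beyond it should be read as an assumption rather than a documented guarantee.

    import jax
    from effectful.ops.syntax import defop
    from effectful.handlers.jax import bind_dims, unbind_dims
    import effectful.handlers.jax.numpy as jnp

    # Create two named dimensions and "unbind" the positional axes of a (2, 3) array.
    a, b = defop(jax.Array, name="a"), defop(jax.Array, name="b")
    t = unbind_dims(jnp.ones((2, 3)), a, b)

    # Re-binding the named dimensions in a different order permutes the axes,
    # per the docstring example: the result has shape (3, 2).
    assert bind_dims(t, b, a).shape == (3, 2)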
effectful/handlers/jax/_terms.py (new file)
@@ -0,0 +1,502 @@
import functools
import operator
from collections.abc import Sequence
from typing import Any, cast

import jax
import tree

import effectful.handlers.jax.numpy as jnp
from effectful.handlers.jax._handlers import (
    IndexElement,
    _partial_eval,
    _register_jax_op,
    bind_dims,
    jax_getitem,
    unbind_dims,
)
from effectful.internals.tensor_utils import _desugar_tensor_index
from effectful.ops.syntax import defdata
from effectful.ops.types import Expr, NotHandled, Operation, Term


class _IndexUpdateHelper:
    """Helper class to implement array-style .at[index].set() updates for effectful arrays."""

    def __init__(self, array):
        self.array = array

    def __getitem__(self, key):
        return _IndexUpdateRef(self.array, key)


class _IndexUpdateRef:
    """Reference to an array position for updates via .at[index]."""

    def __init__(self, array, key):
        self.array = array
        self.key = key

    def set(self, value):
        """Set values at the indexed positions."""

        # Create a JAX at operation that properly handles the indexing
        @_register_jax_op
        def jax_at_set(arr, index_key, val):
            # JAX's at expects the index to be unpacked correctly
            if isinstance(index_key, tuple) and len(index_key) == 1:
                # Single index case
                return arr.at[index_key[0]].set(val)
            elif isinstance(index_key, tuple):
                # Multiple indices case
                return arr.at[index_key].set(val)
            else:
                # Direct index case
                return arr.at[index_key].set(val)

        return jax_at_set(self.array, self.key, value)


@defdata.register(jax.Array)
def _embed_array(op, *args, **kwargs):
    if (
        op is jax_getitem
        and not isinstance(args[0], Term)
        and all(not k.args and not k.kwargs for k in args[1] if isinstance(k, Term))
    ):
        return _EagerArrayTerm(jax_getitem, args[0], args[1])
    else:
        return _ArrayTerm(op, *args, **kwargs)


class _ArrayTerm(Term[jax.Array]):
    def __init__(self, op: Operation[..., jax.Array], *args: Expr, **kwargs: Expr):
        self._op = op
        self._args = args
        self._kwargs = kwargs

    @property
    def op(self) -> Operation[..., jax.Array]:
        return self._op

    @property
    def args(self) -> tuple:
        return self._args

    @property
    def kwargs(self) -> dict:
        return self._kwargs

    def __getitem__(
        self, key: Expr[IndexElement] | tuple[Expr[IndexElement], ...]
    ) -> Expr[jax.Array]:
        return jax_getitem(self, key if isinstance(key, tuple) else (key,))

    @property
    def shape(self) -> Expr[tuple[int, ...]]:
        return jnp.shape(cast(jax.Array, self))

    @property
    def size(self) -> Expr[int]:
        return jnp.size(cast(jax.Array, self))

    def __len__(self):
        return self.shape[0]

    @property
    def ndim(self) -> Expr[int]:
        return jnp.ndim(cast(jax.Array, self))

    def __add__(self, other: jax.Array) -> jax.Array:
        return jnp.add(cast(jax.Array, self), other)

    def __radd__(self, other: jax.Array) -> jax.Array:
        return jnp.add(other, cast(jax.Array, self))

    def __neg__(self) -> jax.Array:
        return jnp.negative(cast(jax.Array, self))

    def __pos__(self) -> jax.Array:
        return jnp.positive(cast(jax.Array, self))

    def __sub__(self, other: jax.Array) -> jax.Array:
        return jnp.subtract(cast(jax.Array, self), other)

    def __rsub__(self, other: jax.Array) -> jax.Array:
        return jnp.subtract(other, cast(jax.Array, self))

    def __mul__(self, other: jax.Array) -> jax.Array:
        return jnp.multiply(cast(jax.Array, self), other)

    def __rmul__(self, other: jax.Array) -> jax.Array:
        return jnp.multiply(other, cast(jax.Array, self))

    def __truediv__(self, other: jax.Array) -> jax.Array:
        return jnp.divide(cast(jax.Array, self), other)

    def __rtruediv__(self, other: jax.Array) -> jax.Array:
        return jnp.divide(other, cast(jax.Array, self))

    def __pow__(self, other: jax.Array) -> jax.Array:
        return jnp.power(cast(jax.Array, self), other)

    def __rpow__(self, other: jax.Array) -> jax.Array:
        return jnp.power(other, cast(jax.Array, self))

    def __abs__(self) -> jax.Array:
        return jnp.abs(cast(jax.Array, self))

    def __eq__(self, other: Any):
        return jnp.equal(cast(jax.Array, self), other)

    def __ne__(self, other: Any):
        return jnp.not_equal(cast(jax.Array, self), other)

    def __floordiv__(self, other: jax.Array) -> jax.Array:
        return jnp.floor_divide(cast(jax.Array, self), other)

    def __rfloordiv__(self, other: jax.Array) -> jax.Array:
        return jnp.floor_divide(other, cast(jax.Array, self))

    def __mod__(self, other: jax.Array) -> jax.Array:
        return jnp.mod(cast(jax.Array, self), other)

    def __rmod__(self, other: jax.Array) -> jax.Array:
        return jnp.mod(other, cast(jax.Array, self))

    def __lt__(self, other: jax.Array) -> jax.Array:
        return jnp.less(cast(jax.Array, self), other)

    def __le__(self, other: jax.Array) -> jax.Array:
        return jnp.less_equal(cast(jax.Array, self), other)

    def __gt__(self, other: jax.Array) -> jax.Array:
        return jnp.greater(cast(jax.Array, self), other)

    def __ge__(self, other: jax.Array) -> jax.Array:
        return jnp.greater_equal(cast(jax.Array, self), other)

    def __lshift__(self, other: jax.Array) -> jax.Array:
        return jnp.left_shift(cast(jax.Array, self), other)

    def __rlshift__(self, other: jax.Array) -> jax.Array:
        return jnp.left_shift(other, cast(jax.Array, self))

    def __rshift__(self, other: jax.Array) -> jax.Array:
        return jnp.right_shift(cast(jax.Array, self), other)

    def __rrshift__(self, other: jax.Array) -> jax.Array:
        return jnp.right_shift(other, cast(jax.Array, self))

    def __and__(self, other: jax.Array) -> jax.Array:
        return jnp.bitwise_and(cast(jax.Array, self), other)

    def __rand__(self, other: jax.Array) -> jax.Array:
        return jnp.bitwise_and(other, cast(jax.Array, self))

    def __xor__(self, other: jax.Array) -> jax.Array:
        return jnp.bitwise_xor(cast(jax.Array, self), other)

    def __rxor__(self, other: jax.Array) -> jax.Array:
        return jnp.bitwise_xor(other, cast(jax.Array, self))

    def __or__(self, other: jax.Array) -> jax.Array:
        return jnp.bitwise_or(cast(jax.Array, self), other)

    def __ror__(self, other: jax.Array) -> jax.Array:
        return jnp.bitwise_or(other, cast(jax.Array, self))

    def __invert__(self) -> jax.Array:
        return jnp.bitwise_not(cast(jax.Array, self))

    def __matmul__(self, other: jax.Array) -> jax.Array:
        return jnp.matmul(cast(jax.Array, self), other)

    def __rmatmul__(self, other: jax.Array) -> jax.Array:
        return jnp.matmul(other, cast(jax.Array, self))

    @property
    def at(self) -> _IndexUpdateHelper:
        """Return an IndexUpdateHelper for array updates."""
        return _IndexUpdateHelper(self)

    def __iter__(self):
        raise TypeError("A free array is not iterable.")

    def all(self, axis=None, keepdims=False, *, where=None):
        return jnp.all(cast(jax.Array, self), axis=axis, keepdims=keepdims, where=where)

    def any(self, axis=None, keepdims=False, *, where=None):
        return jnp.any(cast(jax.Array, self), axis=axis, keepdims=keepdims, where=where)

    def argmax(self, axis=None, keepdims=False):
        return jnp.argmax(cast(jax.Array, self), axis=axis, keepdims=keepdims)

    def argmin(self, axis=None, keepdims=False):
        return jnp.argmin(cast(jax.Array, self), axis=axis, keepdims=keepdims)

    def argpartition(self, kth, axis=-1):
        return jnp.argpartition(cast(jax.Array, self), kth, axis=axis)

    def argsort(self, axis=-1, descending=False, stable=True):
        return jnp.argsort(
            cast(jax.Array, self), axis=axis, descending=descending, stable=stable
        )

    def astype(self, dtype):
        return jnp.astype(cast(jax.Array, self), dtype)

    def choose(self, choices, mode="raise"):
        return jnp.choose(cast(jax.Array, self), choices, mode=mode)

    def clip(self, min=None, max=None):
        return jnp.clip(cast(jax.Array, self), min=min, max=max)

    def compress(self, condition, axis=None):
        return jnp.compress(condition, cast(jax.Array, self), axis=axis)

    def conj(self):
        return jnp.conj(cast(jax.Array, self))

    def conjugate(self):
        return jnp.conjugate(cast(jax.Array, self))

    def copy(self):
        return jnp.copy(cast(jax.Array, self))

    def cumprod(self, axis=None, dtype=None):
        return jnp.cumprod(cast(jax.Array, self), axis=axis, dtype=dtype)

    def cumsum(self, axis=None, dtype=None):
        return jnp.cumsum(cast(jax.Array, self), axis=axis, dtype=dtype)

    def diagonal(self, offset=0, axis1=0, axis2=1):
        return jnp.diagonal(
            cast(jax.Array, self), offset=offset, axis1=axis1, axis2=axis2
        )

    def dot(self, b):
        return jnp.dot(cast(jax.Array, self), b)

    def max(self, axis=None, keepdims=False, initial=None, where=None):
        return jnp.max(
            cast(jax.Array, self),
            axis=axis,
            keepdims=keepdims,
            initial=initial,
            where=where,
        )

    def mean(self, axis=None, keepdims=False, *, where=None):
        return jnp.mean(
            cast(jax.Array, self), axis=axis, keepdims=keepdims, where=where
        )

    def min(self, axis=None, keepdims=False, initial=None, where=None):
        return jnp.min(
            cast(jax.Array, self),
            axis=axis,
            keepdims=keepdims,
            initial=initial,
            where=where,
        )

    def nonzero(self, *, size=None, fill_value=None):
        return jnp.nonzero(cast(jax.Array, self), size=size, fill_value=fill_value)

    def prod(
        self, axis=None, keepdims=False, initial=None, where=None, promote_integers=True
    ):
        return jnp.prod(
            cast(jax.Array, self),
            axis=axis,
            keepdims=keepdims,
            initial=initial,
            where=where,
            promote_integers=promote_integers,
        )

    def ptp(self, axis=None, keepdims=False):
        return jnp.ptp(cast(jax.Array, self), axis=axis, keepdims=keepdims)

    def ravel(self, order="C"):
        return jnp.ravel(cast(jax.Array, self), order=order)

    def repeat(self, repeats, axis=None, *, total_repeat_length=None):
        return jnp.repeat(
            cast(jax.Array, self),
            repeats,
            axis=axis,
            total_repeat_length=total_repeat_length,
        )

    def reshape(self, *shape, order="C"):
        if len(shape) == 1 and isinstance(shape[0], tuple | list):
            shape = shape[0]
        return jnp.reshape(cast(jax.Array, self), shape, order=order)

    def round(self, decimals=0):
        return jnp.round(cast(jax.Array, self), decimals=decimals)

    def searchsorted(self, v, side="left", sorter=None):
        return jnp.searchsorted(cast(jax.Array, self), v, side=side, sorter=sorter)

    def sort(self, axis=-1, descending=False, stable=True):
        return jnp.sort(
            cast(jax.Array, self), axis=axis, descending=descending, stable=stable
        )

    def squeeze(self, axis=None):
        return jnp.squeeze(cast(jax.Array, self), axis=axis)

    def std(self, axis=None, keepdims=False, ddof=0, *, where=None):
        return jnp.std(
            cast(jax.Array, self), axis=axis, keepdims=keepdims, ddof=ddof, where=where
        )

    def sum(
        self, axis=None, keepdims=False, initial=None, where=None, promote_integers=True
    ):
        return jnp.sum(
            cast(jax.Array, self),
            axis=axis,
            keepdims=keepdims,
            initial=initial,
            where=where,
            promote_integers=promote_integers,
        )

    def swapaxes(self, axis1, axis2):
        return jnp.swapaxes(cast(jax.Array, self), axis1, axis2)

    def take(
        self,
        indices,
        axis=None,
        mode=None,
        unique_indices=False,
        indices_are_sorted=False,
        fill_value=None,
    ):
        return jnp.take(
            cast(jax.Array, self),
            indices,
            axis=axis,
            mode=mode,
            unique_indices=unique_indices,
            indices_are_sorted=indices_are_sorted,
            fill_value=fill_value,
        )

    def trace(self, offset=0, axis1=0, axis2=1, dtype=None):
        return jnp.trace(
            cast(jax.Array, self), offset=offset, axis1=axis1, axis2=axis2, dtype=dtype
        )

    def transpose(self, axes=None):
        return jnp.transpose(cast(jax.Array, self), axes=axes)

    def var(self, axis=None, keepdims=False, ddof=0, *, where=None):
        return jnp.var(
            cast(jax.Array, self), axis=axis, keepdims=keepdims, ddof=ddof, where=where
        )


class _EagerArrayTerm(_ArrayTerm):
    def __init__(self, op, tensor, key):
        new_shape, new_key = _desugar_tensor_index(tensor.shape, key)
        super().__init__(op, jax.numpy.reshape(tensor, new_shape), new_key)

    def __iter__(self):
        for i in range(len(self)):
            yield self[i]

    @property
    def shape(self) -> tuple[int, ...]:
        return tuple(
            s
            for s, k in zip(self.args[0].shape, self.args[1])
            if not isinstance(k, Term)
        )

    @property
    def size(self) -> int:
        return functools.reduce(operator.mul, self.shape, 1)

    @property
    def ndim(self) -> int:
        return len(self.shape)


@bind_dims.register  # type: ignore
def _bind_dims_array(t: jax.Array, *args: Operation[[], jax.Array]) -> jax.Array:
    """Convert named dimensions to positional dimensions.

    :param t: An array.
    :param args: Named dimensions to convert to positional dimensions.
        These positional dimensions will appear at the beginning of the
        shape.
    :return: An array with the named dimensions in ``args`` converted to positional dimensions.

    **Example usage**:

    >>> from effectful.ops.syntax import defop
    >>> from effectful.handlers.jax import bind_dims, unbind_dims
    >>> a, b = defop(jax.Array, name='a'), defop(jax.Array, name='b')
    >>> t = unbind_dims(jnp.ones((2, 3)), a, b)
    >>> bind_dims(t, b, a).shape
    (3, 2)
    """

    def _evaluate(expr):
        if isinstance(expr, Term):
            (args, kwargs) = tree.map_structure(_evaluate, (expr.args, expr.kwargs))
            return _partial_eval(expr)
        if tree.is_nested(expr):
            return tree.map_structure(_evaluate, expr)
        return expr

    if not isinstance(t, Term):
        return t

    result = _evaluate(t)
    if not isinstance(result, Term) or not args:
        return result

    # ensure that the result is a jax_getitem with an array as the first argument
    if not (result.op is jax_getitem and isinstance(result.args[0], jax.Array)):
        raise NotHandled

    array = result.args[0]
    dims = result.args[1]
    assert isinstance(dims, Sequence)

    # ensure that the order is a subset of the named dimensions
    order_set = set(args)
    if not order_set <= set(a.op for a in dims if isinstance(a, Term)):
        raise NotHandled

    # permute the inner array so that the leading dimensions are in the order
    # specified and the trailing dimensions are the remaining named dimensions
    # (or slices)
    reindex_dims = [
        i
        for i, o in enumerate(dims)
        if not isinstance(o, Term) or o.op not in order_set
    ]
    dim_ops = [a.op if isinstance(a, Term) else None for a in dims]
    perm = (
        [dim_ops.index(o) for o in args]
        + reindex_dims
        + list(range(len(dims), len(array.shape)))
    )
    array = jnp.transpose(array, perm)
    reindexed = jax_getitem(
        array, (slice(None),) * len(args) + tuple(dims[i] for i in reindex_dims)
    )
    return reindexed


@unbind_dims.register  # type: ignore
def _unbind_dims_array(t: jax.Array, *args: Operation[[], jax.Array]) -> jax.Array:
    return jax_getitem(t, tuple(n() for n in args))
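
A rough usage sketch of the term classes and the index-update helper defined in _terms.py above. The names and shapes are illustrative assumptions; whether a given expression stays symbolic or partially evaluates depends on _partial_eval and the handlers in _handlers.py, which are not reproduced here.

    import jax
    from effectful.ops.syntax import defop
    from effectful.handlers.jax import unbind_dims
    import effectful.handlers.jax.numpy as jnp

    i = defop(jax.Array, name="i")
    x = unbind_dims(jnp.ones((3, 4)), i)  # array with free (named) dimension i

    y = x + 1          # _ArrayTerm.__add__ defers to the wrapped jnp.add
    z = x.sum()        # reductions defer to the wrapped jax.numpy functions
    w = x.at[0].set(0.0)  # functional update via the _IndexUpdateHelper above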
effectful/handlers/jax/numpy/__init__.py (new file)
@@ -0,0 +1,23 @@
from typing import TYPE_CHECKING

import jax.numpy

from .._handlers import _register_jax_op, _register_jax_op_no_partial_eval

_no_overload = ["array", "asarray"]

for name, op in jax.numpy.__dict__.items():
    if not callable(op):
        continue

    jax_op = (
        _register_jax_op_no_partial_eval(op)
        if name in _no_overload
        else _register_jax_op(op)
    )
    globals()[name] = jax_op


# Tell mypy about our wrapped functions.
if TYPE_CHECKING:
    from jax.numpy import *  # noqa: F403
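
A small sketch of what the wrapping loop above provides. The re-export of names follows directly from the loop; how the wrappers treat concrete (non-term) inputs is an assumption about _register_jax_op, which is defined in _handlers.py and not shown in full in this diff.

    import jax.numpy
    import effectful.handlers.jax.numpy as jnp

    # Every callable in jax.numpy is re-exported under the same name.
    assert all(
        hasattr(jnp, name) for name, f in jax.numpy.__dict__.items() if callable(f)
    )

    # Presumably the wrapped functions behave like their jax.numpy originals on
    # concrete arrays, while "array" and "asarray" are registered without
    # partial evaluation (see _no_overload above).
    x = jnp.ones((2, 3))
    y = jnp.add(x, x)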
effectful/handlers/jax/numpy/linalg.py (new file)
@@ -0,0 +1,13 @@
from typing import TYPE_CHECKING

import jax.numpy.linalg

from effectful.handlers.jax._handlers import _register_jax_op

for name, op in jax.numpy.linalg.__dict__.items():
    if callable(op):
        globals()[name] = _register_jax_op(op)

# Tell mypy about our wrapped functions.
if TYPE_CHECKING:
    from jax.numpy.linalg import *  # noqa: F403
effectful/handlers/jax/scipy/special.py (new file)
@@ -0,0 +1,11 @@
from typing import TYPE_CHECKING

import jax.scipy.special

from effectful.handlers.jax._handlers import _register_jax_op

logsumexp = _register_jax_op(jax.scipy.special.logsumexp)

# Tell mypy about our wrapped functions.
if TYPE_CHECKING:
    from jax.scipy.special import logsumexp  # noqa: F401