array-api-compat 1.10.0__tar.gz → 1.11__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {array_api_compat-1.10.0 → array_api_compat-1.11}/PKG-INFO +11 -2
- {array_api_compat-1.10.0 → array_api_compat-1.11}/array_api_compat/__init__.py +1 -1
- {array_api_compat-1.10.0 → array_api_compat-1.11}/array_api_compat/common/_aliases.py +31 -6
- {array_api_compat-1.10.0 → array_api_compat-1.11}/array_api_compat/common/_fft.py +27 -5
- {array_api_compat-1.10.0 → array_api_compat-1.11}/array_api_compat/common/_helpers.py +131 -43
- {array_api_compat-1.10.0 → array_api_compat-1.11}/array_api_compat/common/_typing.py +3 -0
- {array_api_compat-1.10.0 → array_api_compat-1.11}/array_api_compat/cupy/__init__.py +1 -1
- {array_api_compat-1.10.0 → array_api_compat-1.11}/array_api_compat/cupy/_aliases.py +19 -4
- {array_api_compat-1.10.0 → array_api_compat-1.11}/array_api_compat/cupy/_info.py +1 -1
- {array_api_compat-1.10.0 → array_api_compat-1.11}/array_api_compat/dask/array/__init__.py +1 -1
- array_api_compat-1.11/array_api_compat/dask/array/_aliases.py +348 -0
- {array_api_compat-1.10.0 → array_api_compat-1.11}/array_api_compat/dask/array/_info.py +1 -1
- {array_api_compat-1.10.0 → array_api_compat-1.11}/array_api_compat/numpy/__init__.py +1 -1
- {array_api_compat-1.10.0 → array_api_compat-1.11}/array_api_compat/numpy/_aliases.py +15 -3
- {array_api_compat-1.10.0 → array_api_compat-1.11}/array_api_compat/numpy/_info.py +1 -1
- {array_api_compat-1.10.0 → array_api_compat-1.11}/array_api_compat/torch/__init__.py +1 -1
- {array_api_compat-1.10.0 → array_api_compat-1.11}/array_api_compat/torch/_aliases.py +63 -6
- {array_api_compat-1.10.0 → array_api_compat-1.11}/array_api_compat/torch/_info.py +1 -1
- {array_api_compat-1.10.0 → array_api_compat-1.11}/array_api_compat.egg-info/PKG-INFO +11 -2
- {array_api_compat-1.10.0 → array_api_compat-1.11}/array_api_compat.egg-info/SOURCES.txt +2 -0
- {array_api_compat-1.10.0 → array_api_compat-1.11}/setup.py +1 -1
- {array_api_compat-1.10.0 → array_api_compat-1.11}/tests/test_all.py +8 -3
- {array_api_compat-1.10.0 → array_api_compat-1.11}/tests/test_array_namespace.py +10 -8
- {array_api_compat-1.10.0 → array_api_compat-1.11}/tests/test_common.py +143 -25
- array_api_compat-1.11/tests/test_dask.py +179 -0
- array_api_compat-1.11/tests/test_jax.py +34 -0
- array_api_compat-1.10.0/array_api_compat/dask/array/_aliases.py +0 -219
- {array_api_compat-1.10.0 → array_api_compat-1.11}/LICENSE +0 -0
- {array_api_compat-1.10.0 → array_api_compat-1.11}/README.md +0 -0
- {array_api_compat-1.10.0 → array_api_compat-1.11}/array_api_compat/_internal.py +0 -0
- {array_api_compat-1.10.0 → array_api_compat-1.11}/array_api_compat/common/__init__.py +0 -0
- {array_api_compat-1.10.0 → array_api_compat-1.11}/array_api_compat/common/_linalg.py +0 -0
- {array_api_compat-1.10.0 → array_api_compat-1.11}/array_api_compat/cupy/_typing.py +0 -0
- {array_api_compat-1.10.0 → array_api_compat-1.11}/array_api_compat/cupy/fft.py +0 -0
- {array_api_compat-1.10.0 → array_api_compat-1.11}/array_api_compat/cupy/linalg.py +0 -0
- {array_api_compat-1.10.0 → array_api_compat-1.11}/array_api_compat/dask/__init__.py +0 -0
- {array_api_compat-1.10.0 → array_api_compat-1.11}/array_api_compat/dask/array/fft.py +0 -0
- {array_api_compat-1.10.0 → array_api_compat-1.11}/array_api_compat/dask/array/linalg.py +0 -0
- {array_api_compat-1.10.0 → array_api_compat-1.11}/array_api_compat/numpy/_typing.py +0 -0
- {array_api_compat-1.10.0 → array_api_compat-1.11}/array_api_compat/numpy/fft.py +0 -0
- {array_api_compat-1.10.0 → array_api_compat-1.11}/array_api_compat/numpy/linalg.py +0 -0
- {array_api_compat-1.10.0 → array_api_compat-1.11}/array_api_compat/torch/fft.py +0 -0
- {array_api_compat-1.10.0 → array_api_compat-1.11}/array_api_compat/torch/linalg.py +0 -0
- {array_api_compat-1.10.0 → array_api_compat-1.11}/array_api_compat.egg-info/dependency_links.txt +0 -0
- {array_api_compat-1.10.0 → array_api_compat-1.11}/array_api_compat.egg-info/requires.txt +0 -0
- {array_api_compat-1.10.0 → array_api_compat-1.11}/array_api_compat.egg-info/top_level.txt +0 -0
- {array_api_compat-1.10.0 → array_api_compat-1.11}/setup.cfg +0 -0
- {array_api_compat-1.10.0 → array_api_compat-1.11}/tests/test_isdtype.py +0 -0
- {array_api_compat-1.10.0 → array_api_compat-1.11}/tests/test_no_dependencies.py +0 -0
- {array_api_compat-1.10.0 → array_api_compat-1.11}/tests/test_vendoring.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
|
-
Metadata-Version: 2.
|
|
1
|
+
Metadata-Version: 2.2
|
|
2
2
|
Name: array_api_compat
|
|
3
|
-
Version: 1.10.0
|
|
3
|
+
Version: 1.11
|
|
4
4
|
Summary: A wrapper around NumPy and other array libraries to make them compatible with the Array API standard
|
|
5
5
|
Home-page: https://data-apis.org/array-api-compat/
|
|
6
6
|
Author: Consortium for Python Data API Standards
|
|
@@ -28,6 +28,15 @@ Provides-Extra: dask
|
|
|
28
28
|
Requires-Dist: dask; extra == "dask"
|
|
29
29
|
Provides-Extra: sparse
|
|
30
30
|
Requires-Dist: sparse>=0.15.1; extra == "sparse"
|
|
31
|
+
Dynamic: author
|
|
32
|
+
Dynamic: classifier
|
|
33
|
+
Dynamic: description
|
|
34
|
+
Dynamic: description-content-type
|
|
35
|
+
Dynamic: home-page
|
|
36
|
+
Dynamic: license
|
|
37
|
+
Dynamic: provides-extra
|
|
38
|
+
Dynamic: requires-python
|
|
39
|
+
Dynamic: summary
|
|
31
40
|
|
|
32
41
|
# Array API compatibility library
|
|
33
42
|
|
|
@@ -233,11 +233,6 @@ def unique_values(x: ndarray, /, xp) -> ndarray:
|
|
|
233
233
|
**kwargs,
|
|
234
234
|
)
|
|
235
235
|
|
|
236
|
-
def astype(x: ndarray, dtype: Dtype, /, *, copy: bool = True) -> ndarray:
|
|
237
|
-
if not copy and dtype == x.dtype:
|
|
238
|
-
return x
|
|
239
|
-
return x.astype(dtype=dtype, copy=copy)
|
|
240
|
-
|
|
241
236
|
# These functions have different keyword argument names
|
|
242
237
|
|
|
243
238
|
def std(
|
|
@@ -297,6 +292,36 @@ def cumulative_sum(
|
|
|
297
292
|
)
|
|
298
293
|
return res
|
|
299
294
|
|
|
295
|
+
|
|
296
|
+
def cumulative_prod(
|
|
297
|
+
x: ndarray,
|
|
298
|
+
/,
|
|
299
|
+
xp,
|
|
300
|
+
*,
|
|
301
|
+
axis: Optional[int] = None,
|
|
302
|
+
dtype: Optional[Dtype] = None,
|
|
303
|
+
include_initial: bool = False,
|
|
304
|
+
**kwargs
|
|
305
|
+
) -> ndarray:
|
|
306
|
+
wrapped_xp = array_namespace(x)
|
|
307
|
+
|
|
308
|
+
if axis is None:
|
|
309
|
+
if x.ndim > 1:
|
|
310
|
+
raise ValueError("axis must be specified in cumulative_prod for more than one dimension")
|
|
311
|
+
axis = 0
|
|
312
|
+
|
|
313
|
+
res = xp.cumprod(x, axis=axis, dtype=dtype, **kwargs)
|
|
314
|
+
|
|
315
|
+
# np.cumprod does not support include_initial
|
|
316
|
+
if include_initial:
|
|
317
|
+
initial_shape = list(x.shape)
|
|
318
|
+
initial_shape[axis] = 1
|
|
319
|
+
res = xp.concatenate(
|
|
320
|
+
[wrapped_xp.ones(shape=initial_shape, dtype=res.dtype, device=device(res)), res],
|
|
321
|
+
axis=axis,
|
|
322
|
+
)
|
|
323
|
+
return res
|
|
324
|
+
|
|
300
325
|
# The min and max argument names in clip are different and not optional in numpy, and type
|
|
301
326
|
# promotion behavior is different.
|
|
302
327
|
def clip(
|
|
@@ -549,7 +574,7 @@ __all__ = ['arange', 'empty', 'empty_like', 'eye', 'full', 'full_like',
|
|
|
549
574
|
'linspace', 'ones', 'ones_like', 'zeros', 'zeros_like',
|
|
550
575
|
'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult',
|
|
551
576
|
'unique_all', 'unique_counts', 'unique_inverse', 'unique_values',
|
|
552
|
-
'
|
|
577
|
+
'std', 'var', 'cumulative_sum', 'cumulative_prod','clip', 'permute_dims',
|
|
553
578
|
'reshape', 'argsort', 'sort', 'nonzero', 'ceil', 'floor', 'trunc',
|
|
554
579
|
'matmul', 'matrix_transpose', 'tensordot', 'vecdot', 'isdtype',
|
|
555
580
|
'unstack', 'sign']
|
|
@@ -3,7 +3,7 @@ from __future__ import annotations
|
|
|
3
3
|
from typing import TYPE_CHECKING, Union, Optional, Literal
|
|
4
4
|
|
|
5
5
|
if TYPE_CHECKING:
|
|
6
|
-
from ._typing import Device, ndarray
|
|
6
|
+
from ._typing import Device, ndarray, DType
|
|
7
7
|
from collections.abc import Sequence
|
|
8
8
|
|
|
9
9
|
# Note: NumPy fft functions improperly upcast float32 and complex64 to
|
|
@@ -149,15 +149,37 @@ def ihfft(
|
|
|
149
149
|
return res.astype(xp.complex64)
|
|
150
150
|
return res
|
|
151
151
|
|
|
152
|
-
def fftfreq(
|
|
152
|
+
def fftfreq(
|
|
153
|
+
n: int,
|
|
154
|
+
/,
|
|
155
|
+
xp,
|
|
156
|
+
*,
|
|
157
|
+
d: float = 1.0,
|
|
158
|
+
dtype: Optional[DType] = None,
|
|
159
|
+
device: Optional[Device] = None
|
|
160
|
+
) -> ndarray:
|
|
153
161
|
if device not in ["cpu", None]:
|
|
154
162
|
raise ValueError(f"Unsupported device {device!r}")
|
|
155
|
-
|
|
163
|
+
res = xp.fft.fftfreq(n, d=d)
|
|
164
|
+
if dtype is not None:
|
|
165
|
+
return res.astype(dtype)
|
|
166
|
+
return res
|
|
156
167
|
|
|
157
|
-
def rfftfreq(
|
|
168
|
+
def rfftfreq(
|
|
169
|
+
n: int,
|
|
170
|
+
/,
|
|
171
|
+
xp,
|
|
172
|
+
*,
|
|
173
|
+
d: float = 1.0,
|
|
174
|
+
dtype: Optional[DType] = None,
|
|
175
|
+
device: Optional[Device] = None
|
|
176
|
+
) -> ndarray:
|
|
158
177
|
if device not in ["cpu", None]:
|
|
159
178
|
raise ValueError(f"Unsupported device {device!r}")
|
|
160
|
-
|
|
179
|
+
res = xp.fft.rfftfreq(n, d=d)
|
|
180
|
+
if dtype is not None:
|
|
181
|
+
return res.astype(dtype)
|
|
182
|
+
return res
|
|
161
183
|
|
|
162
184
|
def fftshift(x: ndarray, /, xp, *, axes: Union[int, Sequence[int]] = None) -> ndarray:
|
|
163
185
|
return xp.fft.fftshift(x, axes=axes)
|
|
@@ -11,14 +11,14 @@ from typing import TYPE_CHECKING
|
|
|
11
11
|
|
|
12
12
|
if TYPE_CHECKING:
|
|
13
13
|
from typing import Optional, Union, Any
|
|
14
|
-
from ._typing import Array, Device
|
|
14
|
+
from ._typing import Array, Device, Namespace
|
|
15
15
|
|
|
16
16
|
import sys
|
|
17
17
|
import math
|
|
18
18
|
import inspect
|
|
19
19
|
import warnings
|
|
20
20
|
|
|
21
|
-
def _is_jax_zero_gradient_array(x):
|
|
21
|
+
def _is_jax_zero_gradient_array(x: object) -> bool:
|
|
22
22
|
"""Return True if `x` is a zero-gradient array.
|
|
23
23
|
|
|
24
24
|
These arrays are a design quirk of Jax that may one day be removed.
|
|
@@ -32,7 +32,8 @@ def _is_jax_zero_gradient_array(x):
|
|
|
32
32
|
|
|
33
33
|
return isinstance(x, np.ndarray) and x.dtype == jax.float0
|
|
34
34
|
|
|
35
|
-
|
|
35
|
+
|
|
36
|
+
def is_numpy_array(x: object) -> bool:
|
|
36
37
|
"""
|
|
37
38
|
Return True if `x` is a NumPy array.
|
|
38
39
|
|
|
@@ -63,7 +64,8 @@ def is_numpy_array(x):
|
|
|
63
64
|
return (isinstance(x, (np.ndarray, np.generic))
|
|
64
65
|
and not _is_jax_zero_gradient_array(x))
|
|
65
66
|
|
|
66
|
-
|
|
67
|
+
|
|
68
|
+
def is_cupy_array(x: object) -> bool:
|
|
67
69
|
"""
|
|
68
70
|
Return True if `x` is a CuPy array.
|
|
69
71
|
|
|
@@ -93,7 +95,8 @@ def is_cupy_array(x):
|
|
|
93
95
|
# TODO: Should we reject ndarray subclasses?
|
|
94
96
|
return isinstance(x, cp.ndarray)
|
|
95
97
|
|
|
96
|
-
|
|
98
|
+
|
|
99
|
+
def is_torch_array(x: object) -> bool:
|
|
97
100
|
"""
|
|
98
101
|
Return True if `x` is a PyTorch tensor.
|
|
99
102
|
|
|
@@ -120,7 +123,8 @@ def is_torch_array(x):
|
|
|
120
123
|
# TODO: Should we reject ndarray subclasses?
|
|
121
124
|
return isinstance(x, torch.Tensor)
|
|
122
125
|
|
|
123
|
-
|
|
126
|
+
|
|
127
|
+
def is_ndonnx_array(x: object) -> bool:
|
|
124
128
|
"""
|
|
125
129
|
Return True if `x` is a ndonnx Array.
|
|
126
130
|
|
|
@@ -147,7 +151,8 @@ def is_ndonnx_array(x):
|
|
|
147
151
|
|
|
148
152
|
return isinstance(x, ndx.Array)
|
|
149
153
|
|
|
150
|
-
|
|
154
|
+
|
|
155
|
+
def is_dask_array(x: object) -> bool:
|
|
151
156
|
"""
|
|
152
157
|
Return True if `x` is a dask.array Array.
|
|
153
158
|
|
|
@@ -174,7 +179,8 @@ def is_dask_array(x):
|
|
|
174
179
|
|
|
175
180
|
return isinstance(x, dask.array.Array)
|
|
176
181
|
|
|
177
|
-
|
|
182
|
+
|
|
183
|
+
def is_jax_array(x: object) -> bool:
|
|
178
184
|
"""
|
|
179
185
|
Return True if `x` is a JAX array.
|
|
180
186
|
|
|
@@ -202,6 +208,7 @@ def is_jax_array(x):
|
|
|
202
208
|
|
|
203
209
|
return isinstance(x, jax.Array) or _is_jax_zero_gradient_array(x)
|
|
204
210
|
|
|
211
|
+
|
|
205
212
|
def is_pydata_sparse_array(x) -> bool:
|
|
206
213
|
"""
|
|
207
214
|
Return True if `x` is an array from the `sparse` package.
|
|
@@ -231,7 +238,8 @@ def is_pydata_sparse_array(x) -> bool:
|
|
|
231
238
|
# TODO: Account for other backends.
|
|
232
239
|
return isinstance(x, sparse.SparseArray)
|
|
233
240
|
|
|
234
|
-
|
|
241
|
+
|
|
242
|
+
def is_array_api_obj(x: object) -> bool:
|
|
235
243
|
"""
|
|
236
244
|
Return True if `x` is an array API compatible array object.
|
|
237
245
|
|
|
@@ -254,10 +262,12 @@ def is_array_api_obj(x):
|
|
|
254
262
|
or is_pydata_sparse_array(x) \
|
|
255
263
|
or hasattr(x, '__array_namespace__')
|
|
256
264
|
|
|
257
|
-
|
|
265
|
+
|
|
266
|
+
def _compat_module_name() -> str:
|
|
258
267
|
assert __name__.endswith('.common._helpers')
|
|
259
268
|
return __name__.removesuffix('.common._helpers')
|
|
260
269
|
|
|
270
|
+
|
|
261
271
|
def is_numpy_namespace(xp) -> bool:
|
|
262
272
|
"""
|
|
263
273
|
Returns True if `xp` is a NumPy namespace.
|
|
@@ -278,6 +288,7 @@ def is_numpy_namespace(xp) -> bool:
|
|
|
278
288
|
"""
|
|
279
289
|
return xp.__name__ in {'numpy', _compat_module_name() + '.numpy'}
|
|
280
290
|
|
|
291
|
+
|
|
281
292
|
def is_cupy_namespace(xp) -> bool:
|
|
282
293
|
"""
|
|
283
294
|
Returns True if `xp` is a CuPy namespace.
|
|
@@ -298,6 +309,7 @@ def is_cupy_namespace(xp) -> bool:
|
|
|
298
309
|
"""
|
|
299
310
|
return xp.__name__ in {'cupy', _compat_module_name() + '.cupy'}
|
|
300
311
|
|
|
312
|
+
|
|
301
313
|
def is_torch_namespace(xp) -> bool:
|
|
302
314
|
"""
|
|
303
315
|
Returns True if `xp` is a PyTorch namespace.
|
|
@@ -319,7 +331,7 @@ def is_torch_namespace(xp) -> bool:
|
|
|
319
331
|
return xp.__name__ in {'torch', _compat_module_name() + '.torch'}
|
|
320
332
|
|
|
321
333
|
|
|
322
|
-
def is_ndonnx_namespace(xp):
|
|
334
|
+
def is_ndonnx_namespace(xp) -> bool:
|
|
323
335
|
"""
|
|
324
336
|
Returns True if `xp` is an NDONNX namespace.
|
|
325
337
|
|
|
@@ -337,7 +349,8 @@ def is_ndonnx_namespace(xp):
|
|
|
337
349
|
"""
|
|
338
350
|
return xp.__name__ == 'ndonnx'
|
|
339
351
|
|
|
340
|
-
|
|
352
|
+
|
|
353
|
+
def is_dask_namespace(xp) -> bool:
|
|
341
354
|
"""
|
|
342
355
|
Returns True if `xp` is a Dask namespace.
|
|
343
356
|
|
|
@@ -357,7 +370,8 @@ def is_dask_namespace(xp):
|
|
|
357
370
|
"""
|
|
358
371
|
return xp.__name__ in {'dask.array', _compat_module_name() + '.dask.array'}
|
|
359
372
|
|
|
360
|
-
|
|
373
|
+
|
|
374
|
+
def is_jax_namespace(xp) -> bool:
|
|
361
375
|
"""
|
|
362
376
|
Returns True if `xp` is a JAX namespace.
|
|
363
377
|
|
|
@@ -378,7 +392,8 @@ def is_jax_namespace(xp):
|
|
|
378
392
|
"""
|
|
379
393
|
return xp.__name__ in {'jax.numpy', 'jax.experimental.array_api'}
|
|
380
394
|
|
|
381
|
-
|
|
395
|
+
|
|
396
|
+
def is_pydata_sparse_namespace(xp) -> bool:
|
|
382
397
|
"""
|
|
383
398
|
Returns True if `xp` is a pydata/sparse namespace.
|
|
384
399
|
|
|
@@ -396,7 +411,8 @@ def is_pydata_sparse_namespace(xp):
|
|
|
396
411
|
"""
|
|
397
412
|
return xp.__name__ == 'sparse'
|
|
398
413
|
|
|
399
|
-
|
|
414
|
+
|
|
415
|
+
def is_array_api_strict_namespace(xp) -> bool:
|
|
400
416
|
"""
|
|
401
417
|
Returns True if `xp` is an array-api-strict namespace.
|
|
402
418
|
|
|
@@ -414,14 +430,16 @@ def is_array_api_strict_namespace(xp):
|
|
|
414
430
|
"""
|
|
415
431
|
return xp.__name__ == 'array_api_strict'
|
|
416
432
|
|
|
417
|
-
|
|
418
|
-
|
|
419
|
-
|
|
433
|
+
|
|
434
|
+
def _check_api_version(api_version: str) -> None:
|
|
435
|
+
if api_version in ['2021.12', '2022.12', '2023.12']:
|
|
436
|
+
warnings.warn(f"The {api_version} version of the array API specification was requested but the returned namespace is actually version 2024.12")
|
|
420
437
|
elif api_version is not None and api_version not in ['2021.12', '2022.12',
|
|
421
|
-
'2023.12']:
|
|
422
|
-
raise ValueError("Only the
|
|
438
|
+
'2023.12', '2024.12']:
|
|
439
|
+
raise ValueError("Only the 2024.12 version of the array API specification is currently supported")
|
|
423
440
|
|
|
424
|
-
|
|
441
|
+
|
|
442
|
+
def array_namespace(*xs, api_version=None, use_compat=None) -> Namespace:
|
|
425
443
|
"""
|
|
426
444
|
Get the array API compatible namespace for the arrays `xs`.
|
|
427
445
|
|
|
@@ -433,7 +451,7 @@ def array_namespace(*xs, api_version=None, use_compat=None):
|
|
|
433
451
|
|
|
434
452
|
api_version: str
|
|
435
453
|
The newest version of the spec that you need support for (currently
|
|
436
|
-
the compat library wrapped APIs support
|
|
454
|
+
the compat library wrapped APIs support v2024.12).
|
|
437
455
|
|
|
438
456
|
use_compat: bool or None
|
|
439
457
|
If None (the default), the native namespace will be returned if it is
|
|
@@ -630,24 +648,24 @@ def device(x: Array, /) -> Device:
|
|
|
630
648
|
if is_numpy_array(x):
|
|
631
649
|
return "cpu"
|
|
632
650
|
elif is_dask_array(x):
|
|
633
|
-
# Peek at the metadata of the
|
|
634
|
-
|
|
635
|
-
|
|
636
|
-
|
|
637
|
-
# Must be on CPU since backed by numpy
|
|
638
|
-
return "cpu"
|
|
639
|
-
except ImportError:
|
|
640
|
-
pass
|
|
651
|
+
# Peek at the metadata of the Dask array to determine type
|
|
652
|
+
if is_numpy_array(x._meta):
|
|
653
|
+
# Must be on CPU since backed by numpy
|
|
654
|
+
return "cpu"
|
|
641
655
|
return _DASK_DEVICE
|
|
642
656
|
elif is_jax_array(x):
|
|
643
|
-
#
|
|
644
|
-
#
|
|
645
|
-
#
|
|
646
|
-
#
|
|
647
|
-
|
|
648
|
-
|
|
657
|
+
# FIXME Jitted JAX arrays do not have a device attribute
|
|
658
|
+
# https://github.com/jax-ml/jax/issues/26000
|
|
659
|
+
# Return None in this case. Note that this workaround breaks
|
|
660
|
+
# the standard and will result in new arrays being created on the
|
|
661
|
+
# default device instead of the same device as the input array(s).
|
|
662
|
+
x_device = getattr(x, 'device', None)
|
|
663
|
+
# Older JAX releases had .device() as a method, which has been replaced
|
|
664
|
+
# with a property in accordance with the standard.
|
|
665
|
+
if inspect.ismethod(x_device):
|
|
666
|
+
return x_device()
|
|
649
667
|
else:
|
|
650
|
-
return
|
|
668
|
+
return x_device
|
|
651
669
|
elif is_pydata_sparse_array(x):
|
|
652
670
|
# `sparse` will gain `.device`, so check for this first.
|
|
653
671
|
x_device = getattr(x, 'device', None)
|
|
@@ -778,8 +796,11 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]]
|
|
|
778
796
|
raise ValueError(f"Unsupported device {device!r}")
|
|
779
797
|
elif is_jax_array(x):
|
|
780
798
|
if not hasattr(x, "__array_namespace__"):
|
|
781
|
-
# In JAX v0.4.31 and older, this import adds to_device method to x
|
|
799
|
+
# In JAX v0.4.31 and older, this import adds to_device method to x...
|
|
782
800
|
import jax.experimental.array_api # noqa: F401
|
|
801
|
+
# ... but only on eager JAX. It won't work inside jax.jit.
|
|
802
|
+
if not hasattr(x, "to_device"):
|
|
803
|
+
return x
|
|
783
804
|
return x.to_device(device, stream=stream)
|
|
784
805
|
elif is_pydata_sparse_array(x) and device == _device(x):
|
|
785
806
|
# Perform trivial check to return the same array if
|
|
@@ -788,24 +809,30 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]]
|
|
|
788
809
|
return x.to_device(device, stream=stream)
|
|
789
810
|
|
|
790
811
|
|
|
791
|
-
def size(x):
|
|
812
|
+
def size(x: Array) -> int | None:
|
|
792
813
|
"""
|
|
793
814
|
Return the total number of elements of x.
|
|
794
815
|
|
|
795
816
|
This is equivalent to `x.size` according to the `standard
|
|
796
817
|
<https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.size.html>`__.
|
|
818
|
+
|
|
797
819
|
This helper is included because PyTorch defines `size` in an
|
|
798
820
|
:external+torch:meth:`incompatible way <torch.Tensor.size>`.
|
|
799
|
-
|
|
821
|
+
It also fixes dask.array's behaviour which returns nan for unknown sizes, whereas
|
|
822
|
+
the standard requires None.
|
|
800
823
|
"""
|
|
824
|
+
# Lazy API compliant arrays, such as ndonnx, can contain None in their shape
|
|
801
825
|
if None in x.shape:
|
|
802
826
|
return None
|
|
803
|
-
|
|
827
|
+
out = math.prod(x.shape)
|
|
828
|
+
# dask.array.Array.shape can contain NaN
|
|
829
|
+
return None if math.isnan(out) else out
|
|
804
830
|
|
|
805
831
|
|
|
806
|
-
def is_writeable_array(x) -> bool:
|
|
832
|
+
def is_writeable_array(x: object) -> bool:
|
|
807
833
|
"""
|
|
808
834
|
Return False if ``x.__setitem__`` is expected to raise; True otherwise.
|
|
835
|
+
Return False if `x` is not an array API compatible object.
|
|
809
836
|
|
|
810
837
|
Warning
|
|
811
838
|
-------
|
|
@@ -816,7 +843,67 @@ def is_writeable_array(x) -> bool:
|
|
|
816
843
|
return x.flags.writeable
|
|
817
844
|
if is_jax_array(x) or is_pydata_sparse_array(x):
|
|
818
845
|
return False
|
|
819
|
-
return
|
|
846
|
+
return is_array_api_obj(x)
|
|
847
|
+
|
|
848
|
+
|
|
849
|
+
def is_lazy_array(x: object) -> bool:
|
|
850
|
+
"""Return True if x is potentially a future or it may be otherwise impossible or
|
|
851
|
+
expensive to eagerly read its contents, regardless of their size, e.g. by
|
|
852
|
+
calling ``bool(x)`` or ``float(x)``.
|
|
853
|
+
|
|
854
|
+
Return False otherwise; e.g. ``bool(x)`` etc. is guaranteed to succeed and to be
|
|
855
|
+
cheap as long as the array has the right dtype and size.
|
|
856
|
+
|
|
857
|
+
Note
|
|
858
|
+
----
|
|
859
|
+
This function errs on the side of caution for array types that may or may not be
|
|
860
|
+
lazy, e.g. JAX arrays, by always returning True for them.
|
|
861
|
+
"""
|
|
862
|
+
if (
|
|
863
|
+
is_numpy_array(x)
|
|
864
|
+
or is_cupy_array(x)
|
|
865
|
+
or is_torch_array(x)
|
|
866
|
+
or is_pydata_sparse_array(x)
|
|
867
|
+
):
|
|
868
|
+
return False
|
|
869
|
+
|
|
870
|
+
# **JAX note:** while it is possible to determine if you're inside or outside
|
|
871
|
+
# jax.jit by testing the subclass of a jax.Array object, as well as testing bool()
|
|
872
|
+
# as we do below for unknown arrays, this is not recommended by JAX best practices.
|
|
873
|
+
|
|
874
|
+
# **Dask note:** Dask eagerly computes the graph on __bool__, __float__, and so on.
|
|
875
|
+
# This behaviour, while impossible to change without breaking backwards
|
|
876
|
+
# compatibility, is highly detrimental to performance as the whole graph will end
|
|
877
|
+
# up being computed multiple times.
|
|
878
|
+
|
|
879
|
+
if is_jax_array(x) or is_dask_array(x) or is_ndonnx_array(x):
|
|
880
|
+
return True
|
|
881
|
+
|
|
882
|
+
if not is_array_api_obj(x):
|
|
883
|
+
return False
|
|
884
|
+
|
|
885
|
+
# Unknown Array API compatible object. Note that this test may have dire consequences
|
|
886
|
+
# in terms of performance, e.g. for a lazy object that eagerly computes the graph
|
|
887
|
+
# on __bool__ (dask is one such example, which however is special-cased above).
|
|
888
|
+
|
|
889
|
+
# Select a single point of the array
|
|
890
|
+
s = size(x)
|
|
891
|
+
if s is None:
|
|
892
|
+
return True
|
|
893
|
+
xp = array_namespace(x)
|
|
894
|
+
if s > 1:
|
|
895
|
+
x = xp.reshape(x, (-1,))[0]
|
|
896
|
+
# Cast to dtype=bool and deal with size 0 arrays
|
|
897
|
+
x = xp.any(x)
|
|
898
|
+
|
|
899
|
+
try:
|
|
900
|
+
bool(x)
|
|
901
|
+
return False
|
|
902
|
+
# The Array API standard dictates that __bool__ should raise TypeError if the
|
|
903
|
+
# output cannot be defined.
|
|
904
|
+
# Here we allow for it to raise arbitrary exceptions, e.g. like Dask does.
|
|
905
|
+
except Exception:
|
|
906
|
+
return True
|
|
820
907
|
|
|
821
908
|
|
|
822
909
|
__all__ = [
|
|
@@ -840,6 +927,7 @@ __all__ = [
|
|
|
840
927
|
"is_pydata_sparse_array",
|
|
841
928
|
"is_pydata_sparse_namespace",
|
|
842
929
|
"is_writeable_array",
|
|
930
|
+
"is_lazy_array",
|
|
843
931
|
"size",
|
|
844
932
|
"to_device",
|
|
845
933
|
]
|
|
@@ -2,7 +2,7 @@ from __future__ import annotations
|
|
|
2
2
|
|
|
3
3
|
import cupy as cp
|
|
4
4
|
|
|
5
|
-
from ..common import _aliases
|
|
5
|
+
from ..common import _aliases, _helpers
|
|
6
6
|
from .._internal import get_xp
|
|
7
7
|
|
|
8
8
|
from ._info import __array_namespace_info__
|
|
@@ -46,10 +46,10 @@ unique_all = get_xp(cp)(_aliases.unique_all)
|
|
|
46
46
|
unique_counts = get_xp(cp)(_aliases.unique_counts)
|
|
47
47
|
unique_inverse = get_xp(cp)(_aliases.unique_inverse)
|
|
48
48
|
unique_values = get_xp(cp)(_aliases.unique_values)
|
|
49
|
-
astype = _aliases.astype
|
|
50
49
|
std = get_xp(cp)(_aliases.std)
|
|
51
50
|
var = get_xp(cp)(_aliases.var)
|
|
52
51
|
cumulative_sum = get_xp(cp)(_aliases.cumulative_sum)
|
|
52
|
+
cumulative_prod = get_xp(cp)(_aliases.cumulative_prod)
|
|
53
53
|
clip = get_xp(cp)(_aliases.clip)
|
|
54
54
|
permute_dims = get_xp(cp)(_aliases.permute_dims)
|
|
55
55
|
reshape = get_xp(cp)(_aliases.reshape)
|
|
@@ -110,6 +110,21 @@ def asarray(
|
|
|
110
110
|
|
|
111
111
|
return cp.array(obj, dtype=dtype, **kwargs)
|
|
112
112
|
|
|
113
|
+
|
|
114
|
+
def astype(
|
|
115
|
+
x: ndarray,
|
|
116
|
+
dtype: Dtype,
|
|
117
|
+
/,
|
|
118
|
+
*,
|
|
119
|
+
copy: bool = True,
|
|
120
|
+
device: Optional[Device] = None,
|
|
121
|
+
) -> ndarray:
|
|
122
|
+
if device is None:
|
|
123
|
+
return x.astype(dtype=dtype, copy=copy)
|
|
124
|
+
out = _helpers.to_device(x.astype(dtype=dtype, copy=False), device)
|
|
125
|
+
return out.copy() if copy and out is x else out
|
|
126
|
+
|
|
127
|
+
|
|
113
128
|
# These functions are completely new here. If the library already has them
|
|
114
129
|
# (i.e., numpy 2.0), use the library version instead of our wrapper.
|
|
115
130
|
if hasattr(cp, 'vecdot'):
|
|
@@ -127,10 +142,10 @@ if hasattr(cp, 'unstack'):
|
|
|
127
142
|
else:
|
|
128
143
|
unstack = get_xp(cp)(_aliases.unstack)
|
|
129
144
|
|
|
130
|
-
__all__ = _aliases.__all__ + ['__array_namespace_info__', 'asarray', '
|
|
145
|
+
__all__ = _aliases.__all__ + ['__array_namespace_info__', 'asarray', 'astype',
|
|
131
146
|
'acos', 'acosh', 'asin', 'asinh', 'atan',
|
|
132
147
|
'atan2', 'atanh', 'bitwise_left_shift',
|
|
133
148
|
'bitwise_invert', 'bitwise_right_shift',
|
|
134
|
-
'concat', 'pow', 'sign']
|
|
149
|
+
'bool', 'concat', 'pow', 'sign']
|
|
135
150
|
|
|
136
151
|
_all_ignore = ['cp', 'get_xp']
|
|
@@ -3,7 +3,7 @@ from dask.array import * # noqa: F403
|
|
|
3
3
|
# These imports may overwrite names from the import * above.
|
|
4
4
|
from ._aliases import * # noqa: F403
|
|
5
5
|
|
|
6
|
-
__array_api_version__ = '
|
|
6
|
+
__array_api_version__ = '2024.12'
|
|
7
7
|
|
|
8
8
|
__import__(__package__ + '.linalg')
|
|
9
9
|
__import__(__package__ + '.fft')
|