array-api-compat 1.8__tar.gz → 1.9.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {array_api_compat-1.8 → array_api_compat-1.9.1}/PKG-INFO +4 -1
- {array_api_compat-1.8 → array_api_compat-1.9.1}/array_api_compat/__init__.py +1 -1
- {array_api_compat-1.8 → array_api_compat-1.9.1}/array_api_compat/common/_aliases.py +71 -46
- {array_api_compat-1.8 → array_api_compat-1.9.1}/array_api_compat/common/_helpers.py +187 -17
- {array_api_compat-1.8 → array_api_compat-1.9.1}/array_api_compat/common/_linalg.py +0 -5
- {array_api_compat-1.8 → array_api_compat-1.9.1}/array_api_compat/cupy/__init__.py +1 -1
- {array_api_compat-1.8 → array_api_compat-1.9.1}/array_api_compat/cupy/_aliases.py +15 -6
- array_api_compat-1.9.1/array_api_compat/cupy/_info.py +326 -0
- {array_api_compat-1.8 → array_api_compat-1.9.1}/array_api_compat/dask/array/__init__.py +2 -1
- {array_api_compat-1.8 → array_api_compat-1.9.1}/array_api_compat/dask/array/_aliases.py +55 -22
- array_api_compat-1.9.1/array_api_compat/dask/array/_info.py +345 -0
- array_api_compat-1.9.1/array_api_compat/dask/array/fft.py +24 -0
- {array_api_compat-1.8 → array_api_compat-1.9.1}/array_api_compat/dask/array/linalg.py +2 -1
- {array_api_compat-1.8 → array_api_compat-1.9.1}/array_api_compat/numpy/__init__.py +1 -1
- {array_api_compat-1.8 → array_api_compat-1.9.1}/array_api_compat/numpy/_aliases.py +15 -6
- array_api_compat-1.9.1/array_api_compat/numpy/_info.py +346 -0
- {array_api_compat-1.8 → array_api_compat-1.9.1}/array_api_compat/torch/__init__.py +1 -1
- {array_api_compat-1.8 → array_api_compat-1.9.1}/array_api_compat/torch/_aliases.py +40 -9
- array_api_compat-1.9.1/array_api_compat/torch/_info.py +358 -0
- {array_api_compat-1.8 → array_api_compat-1.9.1}/array_api_compat.egg-info/PKG-INFO +4 -1
- {array_api_compat-1.8 → array_api_compat-1.9.1}/array_api_compat.egg-info/SOURCES.txt +5 -0
- {array_api_compat-1.8 → array_api_compat-1.9.1}/setup.py +3 -0
- {array_api_compat-1.8 → array_api_compat-1.9.1}/tests/test_all.py +2 -0
- {array_api_compat-1.8 → array_api_compat-1.9.1}/tests/test_array_namespace.py +30 -8
- {array_api_compat-1.8 → array_api_compat-1.9.1}/tests/test_common.py +34 -10
- {array_api_compat-1.8 → array_api_compat-1.9.1}/tests/test_vendoring.py +1 -0
- {array_api_compat-1.8 → array_api_compat-1.9.1}/LICENSE +0 -0
- {array_api_compat-1.8 → array_api_compat-1.9.1}/README.md +0 -0
- {array_api_compat-1.8 → array_api_compat-1.9.1}/array_api_compat/_internal.py +0 -0
- {array_api_compat-1.8 → array_api_compat-1.9.1}/array_api_compat/common/__init__.py +0 -0
- {array_api_compat-1.8 → array_api_compat-1.9.1}/array_api_compat/common/_fft.py +0 -0
- {array_api_compat-1.8 → array_api_compat-1.9.1}/array_api_compat/common/_typing.py +0 -0
- {array_api_compat-1.8 → array_api_compat-1.9.1}/array_api_compat/cupy/_typing.py +0 -0
- {array_api_compat-1.8 → array_api_compat-1.9.1}/array_api_compat/cupy/fft.py +0 -0
- {array_api_compat-1.8 → array_api_compat-1.9.1}/array_api_compat/cupy/linalg.py +0 -0
- {array_api_compat-1.8 → array_api_compat-1.9.1}/array_api_compat/dask/__init__.py +0 -0
- {array_api_compat-1.8 → array_api_compat-1.9.1}/array_api_compat/numpy/_typing.py +0 -0
- {array_api_compat-1.8 → array_api_compat-1.9.1}/array_api_compat/numpy/fft.py +0 -0
- {array_api_compat-1.8 → array_api_compat-1.9.1}/array_api_compat/numpy/linalg.py +0 -0
- {array_api_compat-1.8 → array_api_compat-1.9.1}/array_api_compat/torch/fft.py +0 -0
- {array_api_compat-1.8 → array_api_compat-1.9.1}/array_api_compat/torch/linalg.py +0 -0
- {array_api_compat-1.8 → array_api_compat-1.9.1}/array_api_compat.egg-info/dependency_links.txt +0 -0
- {array_api_compat-1.8 → array_api_compat-1.9.1}/array_api_compat.egg-info/requires.txt +0 -0
- {array_api_compat-1.8 → array_api_compat-1.9.1}/array_api_compat.egg-info/top_level.txt +0 -0
- {array_api_compat-1.8 → array_api_compat-1.9.1}/setup.cfg +0 -0
- {array_api_compat-1.8 → array_api_compat-1.9.1}/tests/test_isdtype.py +0 -0
- {array_api_compat-1.8 → array_api_compat-1.9.1}/tests/test_no_dependencies.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: array_api_compat
|
|
3
|
-
Version: 1.
|
|
3
|
+
Version: 1.9.1
|
|
4
4
|
Summary: A wrapper around NumPy and other array libraries to make them compatible with the Array API standard
|
|
5
5
|
Home-page: https://data-apis.org/array-api-compat/
|
|
6
6
|
Author: Consortium for Python Data API Standards
|
|
@@ -9,8 +9,11 @@ Classifier: Programming Language :: Python :: 3
|
|
|
9
9
|
Classifier: Programming Language :: Python :: 3.9
|
|
10
10
|
Classifier: Programming Language :: Python :: 3.10
|
|
11
11
|
Classifier: Programming Language :: Python :: 3.11
|
|
12
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
13
|
+
Classifier: Programming Language :: Python :: 3.13
|
|
12
14
|
Classifier: License :: OSI Approved :: MIT License
|
|
13
15
|
Classifier: Operating System :: OS Independent
|
|
16
|
+
Requires-Python: >=3.9
|
|
14
17
|
Description-Content-Type: text/markdown
|
|
15
18
|
License-File: LICENSE
|
|
16
19
|
Provides-Extra: numpy
|
|
@@ -12,7 +12,7 @@ if TYPE_CHECKING:
|
|
|
12
12
|
from typing import NamedTuple
|
|
13
13
|
import inspect
|
|
14
14
|
|
|
15
|
-
from ._helpers import array_namespace, _check_device
|
|
15
|
+
from ._helpers import array_namespace, _check_device, device, is_torch_array, is_cupy_namespace
|
|
16
16
|
|
|
17
17
|
# These functions are modified from the NumPy versions.
|
|
18
18
|
|
|
@@ -264,6 +264,38 @@ def var(
|
|
|
264
264
|
) -> ndarray:
|
|
265
265
|
return xp.var(x, axis=axis, ddof=correction, keepdims=keepdims, **kwargs)
|
|
266
266
|
|
|
267
|
+
# cumulative_sum is renamed from cumsum, and adds the include_initial keyword
|
|
268
|
+
# argument
|
|
269
|
+
|
|
270
|
+
def cumulative_sum(
|
|
271
|
+
x: ndarray,
|
|
272
|
+
/,
|
|
273
|
+
xp,
|
|
274
|
+
*,
|
|
275
|
+
axis: Optional[int] = None,
|
|
276
|
+
dtype: Optional[Dtype] = None,
|
|
277
|
+
include_initial: bool = False,
|
|
278
|
+
**kwargs
|
|
279
|
+
) -> ndarray:
|
|
280
|
+
wrapped_xp = array_namespace(x)
|
|
281
|
+
|
|
282
|
+
# TODO: The standard is not clear about what should happen when x.ndim == 0.
|
|
283
|
+
if axis is None:
|
|
284
|
+
if x.ndim > 1:
|
|
285
|
+
raise ValueError("axis must be specified in cumulative_sum for more than one dimension")
|
|
286
|
+
axis = 0
|
|
287
|
+
|
|
288
|
+
res = xp.cumsum(x, axis=axis, dtype=dtype, **kwargs)
|
|
289
|
+
|
|
290
|
+
# np.cumsum does not support include_initial
|
|
291
|
+
if include_initial:
|
|
292
|
+
initial_shape = list(x.shape)
|
|
293
|
+
initial_shape[axis] = 1
|
|
294
|
+
res = xp.concatenate(
|
|
295
|
+
[wrapped_xp.zeros(shape=initial_shape, dtype=res.dtype, device=device(res)), res],
|
|
296
|
+
axis=axis,
|
|
297
|
+
)
|
|
298
|
+
return res
|
|
267
299
|
|
|
268
300
|
# The min and max argument names in clip are different and not optional in numpy, and type
|
|
269
301
|
# promotion behavior is different.
|
|
@@ -281,10 +313,11 @@ def clip(
|
|
|
281
313
|
return isinstance(a, (int, float, type(None)))
|
|
282
314
|
min_shape = () if _isscalar(min) else min.shape
|
|
283
315
|
max_shape = () if _isscalar(max) else max.shape
|
|
284
|
-
result_shape = xp.broadcast_shapes(x.shape, min_shape, max_shape)
|
|
285
316
|
|
|
286
317
|
wrapped_xp = array_namespace(x)
|
|
287
318
|
|
|
319
|
+
result_shape = xp.broadcast_shapes(x.shape, min_shape, max_shape)
|
|
320
|
+
|
|
288
321
|
# np.clip does type promotion but the array API clip requires that the
|
|
289
322
|
# output have the same dtype as x. We do this instead of just downcasting
|
|
290
323
|
# the result of xp.clip() to handle some corner cases better (e.g.,
|
|
@@ -305,20 +338,26 @@ def clip(
|
|
|
305
338
|
|
|
306
339
|
# At least handle the case of Python integers correctly (see
|
|
307
340
|
# https://github.com/numpy/numpy/pull/26892).
|
|
308
|
-
if type(min) is int and min <=
|
|
341
|
+
if type(min) is int and min <= wrapped_xp.iinfo(x.dtype).min:
|
|
309
342
|
min = None
|
|
310
|
-
if type(max) is int and max >=
|
|
343
|
+
if type(max) is int and max >= wrapped_xp.iinfo(x.dtype).max:
|
|
311
344
|
max = None
|
|
312
345
|
|
|
313
346
|
if out is None:
|
|
314
|
-
out = wrapped_xp.asarray(xp.broadcast_to(x, result_shape),
|
|
347
|
+
out = wrapped_xp.asarray(xp.broadcast_to(x, result_shape),
|
|
348
|
+
copy=True, device=device(x))
|
|
315
349
|
if min is not None:
|
|
316
|
-
|
|
350
|
+
if is_torch_array(x) and x.dtype == xp.float64 and _isscalar(min):
|
|
351
|
+
# Avoid loss of precision due to torch defaulting to float32
|
|
352
|
+
min = wrapped_xp.asarray(min, dtype=xp.float64)
|
|
353
|
+
a = xp.broadcast_to(wrapped_xp.asarray(min, device=device(x)), result_shape)
|
|
317
354
|
ia = (out < a) | xp.isnan(a)
|
|
318
355
|
# torch requires an explicit cast here
|
|
319
356
|
out[ia] = wrapped_xp.astype(a[ia], out.dtype)
|
|
320
357
|
if max is not None:
|
|
321
|
-
|
|
358
|
+
if is_torch_array(x) and x.dtype == xp.float64 and _isscalar(max):
|
|
359
|
+
max = wrapped_xp.asarray(max, dtype=xp.float64)
|
|
360
|
+
b = xp.broadcast_to(wrapped_xp.asarray(max, device=device(x)), result_shape)
|
|
322
361
|
ib = (out > b) | xp.isnan(b)
|
|
323
362
|
out[ib] = wrapped_xp.astype(b[ib], out.dtype)
|
|
324
363
|
# Return a scalar for 0-D
|
|
@@ -389,42 +428,6 @@ def nonzero(x: ndarray, /, xp, **kwargs) -> Tuple[ndarray, ...]:
|
|
|
389
428
|
raise ValueError("nonzero() does not support zero-dimensional arrays")
|
|
390
429
|
return xp.nonzero(x, **kwargs)
|
|
391
430
|
|
|
392
|
-
# sum() and prod() should always upcast when dtype=None
|
|
393
|
-
def sum(
|
|
394
|
-
x: ndarray,
|
|
395
|
-
/,
|
|
396
|
-
xp,
|
|
397
|
-
*,
|
|
398
|
-
axis: Optional[Union[int, Tuple[int, ...]]] = None,
|
|
399
|
-
dtype: Optional[Dtype] = None,
|
|
400
|
-
keepdims: bool = False,
|
|
401
|
-
**kwargs,
|
|
402
|
-
) -> ndarray:
|
|
403
|
-
# `xp.sum` already upcasts integers, but not floats or complexes
|
|
404
|
-
if dtype is None:
|
|
405
|
-
if x.dtype == xp.float32:
|
|
406
|
-
dtype = xp.float64
|
|
407
|
-
elif x.dtype == xp.complex64:
|
|
408
|
-
dtype = xp.complex128
|
|
409
|
-
return xp.sum(x, axis=axis, dtype=dtype, keepdims=keepdims, **kwargs)
|
|
410
|
-
|
|
411
|
-
def prod(
|
|
412
|
-
x: ndarray,
|
|
413
|
-
/,
|
|
414
|
-
xp,
|
|
415
|
-
*,
|
|
416
|
-
axis: Optional[Union[int, Tuple[int, ...]]] = None,
|
|
417
|
-
dtype: Optional[Dtype] = None,
|
|
418
|
-
keepdims: bool = False,
|
|
419
|
-
**kwargs,
|
|
420
|
-
) -> ndarray:
|
|
421
|
-
if dtype is None:
|
|
422
|
-
if x.dtype == xp.float32:
|
|
423
|
-
dtype = xp.float64
|
|
424
|
-
elif x.dtype == xp.complex64:
|
|
425
|
-
dtype = xp.complex128
|
|
426
|
-
return xp.prod(x, dtype=dtype, axis=axis, keepdims=keepdims, **kwargs)
|
|
427
|
-
|
|
428
431
|
# ceil, floor, and trunc return integers for integer inputs
|
|
429
432
|
|
|
430
433
|
def ceil(x: ndarray, /, xp, **kwargs) -> ndarray:
|
|
@@ -521,10 +524,32 @@ def isdtype(
|
|
|
521
524
|
# array_api_strict implementation will be very strict.
|
|
522
525
|
return dtype == kind
|
|
523
526
|
|
|
527
|
+
# unstack is a new function in the 2023.12 array API standard
|
|
528
|
+
def unstack(x: ndarray, /, xp, *, axis: int = 0) -> Tuple[ndarray, ...]:
|
|
529
|
+
if x.ndim == 0:
|
|
530
|
+
raise ValueError("Input array must be at least 1-d.")
|
|
531
|
+
return tuple(xp.moveaxis(x, axis, 0))
|
|
532
|
+
|
|
533
|
+
# numpy 1.26 does not use the standard definition for sign on complex numbers
|
|
534
|
+
|
|
535
|
+
def sign(x: ndarray, /, xp, **kwargs) -> ndarray:
|
|
536
|
+
if isdtype(x.dtype, 'complex floating', xp=xp):
|
|
537
|
+
out = (x/xp.abs(x, **kwargs))[...]
|
|
538
|
+
# sign(0) = 0 but the above formula would give nan
|
|
539
|
+
out[x == 0+0j] = 0+0j
|
|
540
|
+
else:
|
|
541
|
+
out = xp.sign(x, **kwargs)
|
|
542
|
+
# CuPy sign() does not propagate nans. See
|
|
543
|
+
# https://github.com/data-apis/array-api-compat/issues/136
|
|
544
|
+
if is_cupy_namespace(xp) and isdtype(x.dtype, 'real floating', xp=xp):
|
|
545
|
+
out[xp.isnan(x)] = xp.nan
|
|
546
|
+
return out[()]
|
|
547
|
+
|
|
524
548
|
__all__ = ['arange', 'empty', 'empty_like', 'eye', 'full', 'full_like',
|
|
525
549
|
'linspace', 'ones', 'ones_like', 'zeros', 'zeros_like',
|
|
526
550
|
'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult',
|
|
527
551
|
'unique_all', 'unique_counts', 'unique_inverse', 'unique_values',
|
|
528
|
-
'astype', 'std', 'var', '
|
|
529
|
-
'
|
|
530
|
-
'matmul', 'matrix_transpose', 'tensordot', 'vecdot', 'isdtype'
|
|
552
|
+
'astype', 'std', 'var', 'cumulative_sum', 'clip', 'permute_dims',
|
|
553
|
+
'reshape', 'argsort', 'sort', 'nonzero', 'ceil', 'floor', 'trunc',
|
|
554
|
+
'matmul', 'matrix_transpose', 'tensordot', 'vecdot', 'isdtype',
|
|
555
|
+
'unstack', 'sign']
|
|
@@ -145,7 +145,7 @@ def is_ndonnx_array(x):
|
|
|
145
145
|
|
|
146
146
|
import ndonnx as ndx
|
|
147
147
|
|
|
148
|
-
return isinstance(x, ndx.Array)
|
|
148
|
+
return isinstance(x, ndx.Array)
|
|
149
149
|
|
|
150
150
|
def is_dask_array(x):
|
|
151
151
|
"""
|
|
@@ -202,7 +202,6 @@ def is_jax_array(x):
|
|
|
202
202
|
|
|
203
203
|
return isinstance(x, jax.Array) or _is_jax_zero_gradient_array(x)
|
|
204
204
|
|
|
205
|
-
|
|
206
205
|
def is_pydata_sparse_array(x) -> bool:
|
|
207
206
|
"""
|
|
208
207
|
Return True if `x` is an array from the `sparse` package.
|
|
@@ -255,11 +254,172 @@ def is_array_api_obj(x):
|
|
|
255
254
|
or is_pydata_sparse_array(x) \
|
|
256
255
|
or hasattr(x, '__array_namespace__')
|
|
257
256
|
|
|
257
|
+
def _compat_module_name():
|
|
258
|
+
assert __name__.endswith('.common._helpers')
|
|
259
|
+
return __name__.removesuffix('.common._helpers')
|
|
260
|
+
|
|
261
|
+
def is_numpy_namespace(xp) -> bool:
|
|
262
|
+
"""
|
|
263
|
+
Returns True if `xp` is a NumPy namespace.
|
|
264
|
+
|
|
265
|
+
This includes both NumPy itself and the version wrapped by array-api-compat.
|
|
266
|
+
|
|
267
|
+
See Also
|
|
268
|
+
--------
|
|
269
|
+
|
|
270
|
+
array_namespace
|
|
271
|
+
is_cupy_namespace
|
|
272
|
+
is_torch_namespace
|
|
273
|
+
is_ndonnx_namespace
|
|
274
|
+
is_dask_namespace
|
|
275
|
+
is_jax_namespace
|
|
276
|
+
is_pydata_sparse_namespace
|
|
277
|
+
is_array_api_strict_namespace
|
|
278
|
+
"""
|
|
279
|
+
return xp.__name__ in {'numpy', _compat_module_name() + '.numpy'}
|
|
280
|
+
|
|
281
|
+
def is_cupy_namespace(xp) -> bool:
|
|
282
|
+
"""
|
|
283
|
+
Returns True if `xp` is a CuPy namespace.
|
|
284
|
+
|
|
285
|
+
This includes both CuPy itself and the version wrapped by array-api-compat.
|
|
286
|
+
|
|
287
|
+
See Also
|
|
288
|
+
--------
|
|
289
|
+
|
|
290
|
+
array_namespace
|
|
291
|
+
is_numpy_namespace
|
|
292
|
+
is_torch_namespace
|
|
293
|
+
is_ndonnx_namespace
|
|
294
|
+
is_dask_namespace
|
|
295
|
+
is_jax_namespace
|
|
296
|
+
is_pydata_sparse_namespace
|
|
297
|
+
is_array_api_strict_namespace
|
|
298
|
+
"""
|
|
299
|
+
return xp.__name__ in {'cupy', _compat_module_name() + '.cupy'}
|
|
300
|
+
|
|
301
|
+
def is_torch_namespace(xp) -> bool:
|
|
302
|
+
"""
|
|
303
|
+
Returns True if `xp` is a PyTorch namespace.
|
|
304
|
+
|
|
305
|
+
This includes both PyTorch itself and the version wrapped by array-api-compat.
|
|
306
|
+
|
|
307
|
+
See Also
|
|
308
|
+
--------
|
|
309
|
+
|
|
310
|
+
array_namespace
|
|
311
|
+
is_numpy_namespace
|
|
312
|
+
is_cupy_namespace
|
|
313
|
+
is_ndonnx_namespace
|
|
314
|
+
is_dask_namespace
|
|
315
|
+
is_jax_namespace
|
|
316
|
+
is_pydata_sparse_namespace
|
|
317
|
+
is_array_api_strict_namespace
|
|
318
|
+
"""
|
|
319
|
+
return xp.__name__ in {'torch', _compat_module_name() + '.torch'}
|
|
320
|
+
|
|
321
|
+
|
|
322
|
+
def is_ndonnx_namespace(xp):
|
|
323
|
+
"""
|
|
324
|
+
Returns True if `xp` is an NDONNX namespace.
|
|
325
|
+
|
|
326
|
+
See Also
|
|
327
|
+
--------
|
|
328
|
+
|
|
329
|
+
array_namespace
|
|
330
|
+
is_numpy_namespace
|
|
331
|
+
is_cupy_namespace
|
|
332
|
+
is_torch_namespace
|
|
333
|
+
is_dask_namespace
|
|
334
|
+
is_jax_namespace
|
|
335
|
+
is_pydata_sparse_namespace
|
|
336
|
+
is_array_api_strict_namespace
|
|
337
|
+
"""
|
|
338
|
+
return xp.__name__ == 'ndonnx'
|
|
339
|
+
|
|
340
|
+
def is_dask_namespace(xp):
|
|
341
|
+
"""
|
|
342
|
+
Returns True if `xp` is a Dask namespace.
|
|
343
|
+
|
|
344
|
+
This includes both ``dask.array`` itself and the version wrapped by array-api-compat.
|
|
345
|
+
|
|
346
|
+
See Also
|
|
347
|
+
--------
|
|
348
|
+
|
|
349
|
+
array_namespace
|
|
350
|
+
is_numpy_namespace
|
|
351
|
+
is_cupy_namespace
|
|
352
|
+
is_torch_namespace
|
|
353
|
+
is_ndonnx_namespace
|
|
354
|
+
is_jax_namespace
|
|
355
|
+
is_pydata_sparse_namespace
|
|
356
|
+
is_array_api_strict_namespace
|
|
357
|
+
"""
|
|
358
|
+
return xp.__name__ in {'dask.array', _compat_module_name() + '.dask.array'}
|
|
359
|
+
|
|
360
|
+
def is_jax_namespace(xp):
|
|
361
|
+
"""
|
|
362
|
+
Returns True if `xp` is a JAX namespace.
|
|
363
|
+
|
|
364
|
+
This includes ``jax.numpy`` and ``jax.experimental.array_api`` which existed in
|
|
365
|
+
older versions of JAX.
|
|
366
|
+
|
|
367
|
+
See Also
|
|
368
|
+
--------
|
|
369
|
+
|
|
370
|
+
array_namespace
|
|
371
|
+
is_numpy_namespace
|
|
372
|
+
is_cupy_namespace
|
|
373
|
+
is_torch_namespace
|
|
374
|
+
is_ndonnx_namespace
|
|
375
|
+
is_dask_namespace
|
|
376
|
+
is_pydata_sparse_namespace
|
|
377
|
+
is_array_api_strict_namespace
|
|
378
|
+
"""
|
|
379
|
+
return xp.__name__ in {'jax.numpy', 'jax.experimental.array_api'}
|
|
380
|
+
|
|
381
|
+
def is_pydata_sparse_namespace(xp):
|
|
382
|
+
"""
|
|
383
|
+
Returns True if `xp` is a pydata/sparse namespace.
|
|
384
|
+
|
|
385
|
+
See Also
|
|
386
|
+
--------
|
|
387
|
+
|
|
388
|
+
array_namespace
|
|
389
|
+
is_numpy_namespace
|
|
390
|
+
is_cupy_namespace
|
|
391
|
+
is_torch_namespace
|
|
392
|
+
is_ndonnx_namespace
|
|
393
|
+
is_dask_namespace
|
|
394
|
+
is_jax_namespace
|
|
395
|
+
is_array_api_strict_namespace
|
|
396
|
+
"""
|
|
397
|
+
return xp.__name__ == 'sparse'
|
|
398
|
+
|
|
399
|
+
def is_array_api_strict_namespace(xp):
|
|
400
|
+
"""
|
|
401
|
+
Returns True if `xp` is an array-api-strict namespace.
|
|
402
|
+
|
|
403
|
+
See Also
|
|
404
|
+
--------
|
|
405
|
+
|
|
406
|
+
array_namespace
|
|
407
|
+
is_numpy_namespace
|
|
408
|
+
is_cupy_namespace
|
|
409
|
+
is_torch_namespace
|
|
410
|
+
is_ndonnx_namespace
|
|
411
|
+
is_dask_namespace
|
|
412
|
+
is_jax_namespace
|
|
413
|
+
is_pydata_sparse_namespace
|
|
414
|
+
"""
|
|
415
|
+
return xp.__name__ == 'array_api_strict'
|
|
416
|
+
|
|
258
417
|
def _check_api_version(api_version):
|
|
259
|
-
if api_version
|
|
260
|
-
warnings.warn("The
|
|
261
|
-
elif api_version is not None and api_version
|
|
262
|
-
|
|
418
|
+
if api_version in ['2021.12', '2022.12']:
|
|
419
|
+
warnings.warn(f"The {api_version} version of the array API specification was requested but the returned namespace is actually version 2023.12")
|
|
420
|
+
elif api_version is not None and api_version not in ['2021.12', '2022.12',
|
|
421
|
+
'2023.12']:
|
|
422
|
+
raise ValueError("Only the 2023.12 version of the array API specification is currently supported")
|
|
263
423
|
|
|
264
424
|
def array_namespace(*xs, api_version=None, use_compat=None):
|
|
265
425
|
"""
|
|
@@ -272,7 +432,7 @@ def array_namespace(*xs, api_version=None, use_compat=None):
|
|
|
272
432
|
|
|
273
433
|
api_version: str
|
|
274
434
|
The newest version of the spec that you need support for (currently
|
|
275
|
-
the compat library wrapped APIs support
|
|
435
|
+
the compat library wrapped APIs support v2023.12).
|
|
276
436
|
|
|
277
437
|
use_compat: bool or None
|
|
278
438
|
If None (the default), the native namespace will be returned if it is
|
|
@@ -340,12 +500,9 @@ def array_namespace(*xs, api_version=None, use_compat=None):
|
|
|
340
500
|
elif use_compat is False:
|
|
341
501
|
namespaces.add(np)
|
|
342
502
|
else:
|
|
343
|
-
# numpy 2.0
|
|
503
|
+
# numpy 2.0+ have __array_namespace__, however, they are not yet fully array API
|
|
344
504
|
# compatible.
|
|
345
|
-
|
|
346
|
-
namespaces.add(x.__array_namespace__(api_version=api_version))
|
|
347
|
-
else:
|
|
348
|
-
namespaces.add(numpy_namespace)
|
|
505
|
+
namespaces.add(numpy_namespace)
|
|
349
506
|
elif is_cupy_array(x):
|
|
350
507
|
if _use_compat:
|
|
351
508
|
_check_api_version(api_version)
|
|
@@ -377,9 +534,13 @@ def array_namespace(*xs, api_version=None, use_compat=None):
|
|
|
377
534
|
elif use_compat is False:
|
|
378
535
|
import jax.numpy as jnp
|
|
379
536
|
else:
|
|
380
|
-
#
|
|
381
|
-
#
|
|
382
|
-
import jax.
|
|
537
|
+
# JAX v0.4.32 and newer implements the array API directly in jax.numpy.
|
|
538
|
+
# For older JAX versions, it is available via jax.experimental.array_api.
|
|
539
|
+
import jax.numpy
|
|
540
|
+
if hasattr(jax.numpy, "__array_api_version__"):
|
|
541
|
+
jnp = jax.numpy
|
|
542
|
+
else:
|
|
543
|
+
import jax.experimental.array_api as jnp
|
|
383
544
|
namespaces.add(jnp)
|
|
384
545
|
elif is_pydata_sparse_array(x):
|
|
385
546
|
if use_compat is True:
|
|
@@ -613,8 +774,9 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]]
|
|
|
613
774
|
return x
|
|
614
775
|
raise ValueError(f"Unsupported device {device!r}")
|
|
615
776
|
elif is_jax_array(x):
|
|
616
|
-
|
|
617
|
-
|
|
777
|
+
if not hasattr(x, "__array_namespace__"):
|
|
778
|
+
# In JAX v0.4.31 and older, this import adds to_device method to x.
|
|
779
|
+
import jax.experimental.array_api # noqa: F401
|
|
618
780
|
return x.to_device(device, stream=stream)
|
|
619
781
|
elif is_pydata_sparse_array(x) and device == _device(x):
|
|
620
782
|
# Perform trivial check to return the same array if
|
|
@@ -641,13 +803,21 @@ __all__ = [
|
|
|
641
803
|
"device",
|
|
642
804
|
"get_namespace",
|
|
643
805
|
"is_array_api_obj",
|
|
806
|
+
"is_array_api_strict_namespace",
|
|
644
807
|
"is_cupy_array",
|
|
808
|
+
"is_cupy_namespace",
|
|
645
809
|
"is_dask_array",
|
|
810
|
+
"is_dask_namespace",
|
|
646
811
|
"is_jax_array",
|
|
812
|
+
"is_jax_namespace",
|
|
647
813
|
"is_numpy_array",
|
|
814
|
+
"is_numpy_namespace",
|
|
648
815
|
"is_torch_array",
|
|
816
|
+
"is_torch_namespace",
|
|
649
817
|
"is_ndonnx_array",
|
|
818
|
+
"is_ndonnx_namespace",
|
|
650
819
|
"is_pydata_sparse_array",
|
|
820
|
+
"is_pydata_sparse_namespace",
|
|
651
821
|
"size",
|
|
652
822
|
"to_device",
|
|
653
823
|
]
|
|
@@ -147,11 +147,6 @@ def diagonal(x: ndarray, /, xp, *, offset: int = 0, **kwargs) -> ndarray:
|
|
|
147
147
|
return xp.diagonal(x, offset=offset, axis1=-2, axis2=-1, **kwargs)
|
|
148
148
|
|
|
149
149
|
def trace(x: ndarray, /, xp, *, offset: int = 0, dtype=None, **kwargs) -> ndarray:
|
|
150
|
-
if dtype is None:
|
|
151
|
-
if x.dtype == xp.float32:
|
|
152
|
-
dtype = xp.float64
|
|
153
|
-
elif x.dtype == xp.complex64:
|
|
154
|
-
dtype = xp.complex128
|
|
155
150
|
return xp.asarray(xp.trace(x, offset=offset, dtype=dtype, axis1=-2, axis2=-1, **kwargs))
|
|
156
151
|
|
|
157
152
|
__all__ = ['cross', 'matmul', 'outer', 'tensordot', 'EighResult',
|
|
@@ -5,6 +5,8 @@ import cupy as cp
|
|
|
5
5
|
from ..common import _aliases
|
|
6
6
|
from .._internal import get_xp
|
|
7
7
|
|
|
8
|
+
from ._info import __array_namespace_info__
|
|
9
|
+
|
|
8
10
|
from typing import TYPE_CHECKING
|
|
9
11
|
if TYPE_CHECKING:
|
|
10
12
|
from typing import Optional, Union
|
|
@@ -47,20 +49,20 @@ unique_values = get_xp(cp)(_aliases.unique_values)
|
|
|
47
49
|
astype = _aliases.astype
|
|
48
50
|
std = get_xp(cp)(_aliases.std)
|
|
49
51
|
var = get_xp(cp)(_aliases.var)
|
|
52
|
+
cumulative_sum = get_xp(cp)(_aliases.cumulative_sum)
|
|
50
53
|
clip = get_xp(cp)(_aliases.clip)
|
|
51
54
|
permute_dims = get_xp(cp)(_aliases.permute_dims)
|
|
52
55
|
reshape = get_xp(cp)(_aliases.reshape)
|
|
53
56
|
argsort = get_xp(cp)(_aliases.argsort)
|
|
54
57
|
sort = get_xp(cp)(_aliases.sort)
|
|
55
58
|
nonzero = get_xp(cp)(_aliases.nonzero)
|
|
56
|
-
sum = get_xp(cp)(_aliases.sum)
|
|
57
|
-
prod = get_xp(cp)(_aliases.prod)
|
|
58
59
|
ceil = get_xp(cp)(_aliases.ceil)
|
|
59
60
|
floor = get_xp(cp)(_aliases.floor)
|
|
60
61
|
trunc = get_xp(cp)(_aliases.trunc)
|
|
61
62
|
matmul = get_xp(cp)(_aliases.matmul)
|
|
62
63
|
matrix_transpose = get_xp(cp)(_aliases.matrix_transpose)
|
|
63
64
|
tensordot = get_xp(cp)(_aliases.tensordot)
|
|
65
|
+
sign = get_xp(cp)(_aliases.sign)
|
|
64
66
|
|
|
65
67
|
_copy_default = object()
|
|
66
68
|
|
|
@@ -114,14 +116,21 @@ if hasattr(cp, 'vecdot'):
|
|
|
114
116
|
vecdot = cp.vecdot
|
|
115
117
|
else:
|
|
116
118
|
vecdot = get_xp(cp)(_aliases.vecdot)
|
|
119
|
+
|
|
117
120
|
if hasattr(cp, 'isdtype'):
|
|
118
121
|
isdtype = cp.isdtype
|
|
119
122
|
else:
|
|
120
123
|
isdtype = get_xp(cp)(_aliases.isdtype)
|
|
121
124
|
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
125
|
+
if hasattr(cp, 'unstack'):
|
|
126
|
+
unstack = cp.unstack
|
|
127
|
+
else:
|
|
128
|
+
unstack = get_xp(cp)(_aliases.unstack)
|
|
129
|
+
|
|
130
|
+
__all__ = _aliases.__all__ + ['__array_namespace_info__', 'asarray', 'bool',
|
|
131
|
+
'acos', 'acosh', 'asin', 'asinh', 'atan',
|
|
132
|
+
'atan2', 'atanh', 'bitwise_left_shift',
|
|
133
|
+
'bitwise_invert', 'bitwise_right_shift',
|
|
134
|
+
'concat', 'pow', 'sign']
|
|
126
135
|
|
|
127
136
|
_all_ignore = ['cp', 'get_xp']
|