array-api-compat 1.8__tar.gz → 1.9__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {array_api_compat-1.8 → array_api_compat-1.9}/PKG-INFO +1 -1
- {array_api_compat-1.8 → array_api_compat-1.9}/array_api_compat/__init__.py +1 -1
- {array_api_compat-1.8 → array_api_compat-1.9}/array_api_compat/common/_aliases.py +56 -46
- {array_api_compat-1.8 → array_api_compat-1.9}/array_api_compat/common/_helpers.py +181 -12
- {array_api_compat-1.8 → array_api_compat-1.9}/array_api_compat/common/_linalg.py +0 -5
- {array_api_compat-1.8 → array_api_compat-1.9}/array_api_compat/cupy/_aliases.py +21 -6
- array_api_compat-1.9/array_api_compat/cupy/_info.py +326 -0
- {array_api_compat-1.8 → array_api_compat-1.9}/array_api_compat/dask/array/_aliases.py +49 -9
- array_api_compat-1.9/array_api_compat/dask/array/_info.py +345 -0
- {array_api_compat-1.8 → array_api_compat-1.9}/array_api_compat/dask/array/linalg.py +2 -1
- {array_api_compat-1.8 → array_api_compat-1.9}/array_api_compat/numpy/_aliases.py +14 -6
- array_api_compat-1.9/array_api_compat/numpy/_info.py +346 -0
- {array_api_compat-1.8 → array_api_compat-1.9}/array_api_compat/torch/_aliases.py +40 -9
- array_api_compat-1.9/array_api_compat/torch/_info.py +358 -0
- {array_api_compat-1.8 → array_api_compat-1.9}/array_api_compat.egg-info/PKG-INFO +1 -1
- {array_api_compat-1.8 → array_api_compat-1.9}/array_api_compat.egg-info/SOURCES.txt +4 -0
- {array_api_compat-1.8 → array_api_compat-1.9}/tests/test_all.py +2 -0
- {array_api_compat-1.8 → array_api_compat-1.9}/tests/test_array_namespace.py +20 -4
- {array_api_compat-1.8 → array_api_compat-1.9}/tests/test_common.py +34 -10
- {array_api_compat-1.8 → array_api_compat-1.9}/tests/test_vendoring.py +1 -0
- {array_api_compat-1.8 → array_api_compat-1.9}/LICENSE +0 -0
- {array_api_compat-1.8 → array_api_compat-1.9}/README.md +0 -0
- {array_api_compat-1.8 → array_api_compat-1.9}/array_api_compat/_internal.py +0 -0
- {array_api_compat-1.8 → array_api_compat-1.9}/array_api_compat/common/__init__.py +0 -0
- {array_api_compat-1.8 → array_api_compat-1.9}/array_api_compat/common/_fft.py +0 -0
- {array_api_compat-1.8 → array_api_compat-1.9}/array_api_compat/common/_typing.py +0 -0
- {array_api_compat-1.8 → array_api_compat-1.9}/array_api_compat/cupy/__init__.py +0 -0
- {array_api_compat-1.8 → array_api_compat-1.9}/array_api_compat/cupy/_typing.py +0 -0
- {array_api_compat-1.8 → array_api_compat-1.9}/array_api_compat/cupy/fft.py +0 -0
- {array_api_compat-1.8 → array_api_compat-1.9}/array_api_compat/cupy/linalg.py +0 -0
- {array_api_compat-1.8 → array_api_compat-1.9}/array_api_compat/dask/__init__.py +0 -0
- {array_api_compat-1.8 → array_api_compat-1.9}/array_api_compat/dask/array/__init__.py +0 -0
- {array_api_compat-1.8 → array_api_compat-1.9}/array_api_compat/numpy/__init__.py +0 -0
- {array_api_compat-1.8 → array_api_compat-1.9}/array_api_compat/numpy/_typing.py +0 -0
- {array_api_compat-1.8 → array_api_compat-1.9}/array_api_compat/numpy/fft.py +0 -0
- {array_api_compat-1.8 → array_api_compat-1.9}/array_api_compat/numpy/linalg.py +0 -0
- {array_api_compat-1.8 → array_api_compat-1.9}/array_api_compat/torch/__init__.py +0 -0
- {array_api_compat-1.8 → array_api_compat-1.9}/array_api_compat/torch/fft.py +0 -0
- {array_api_compat-1.8 → array_api_compat-1.9}/array_api_compat/torch/linalg.py +0 -0
- {array_api_compat-1.8 → array_api_compat-1.9}/array_api_compat.egg-info/dependency_links.txt +0 -0
- {array_api_compat-1.8 → array_api_compat-1.9}/array_api_compat.egg-info/requires.txt +0 -0
- {array_api_compat-1.8 → array_api_compat-1.9}/array_api_compat.egg-info/top_level.txt +0 -0
- {array_api_compat-1.8 → array_api_compat-1.9}/setup.cfg +0 -0
- {array_api_compat-1.8 → array_api_compat-1.9}/setup.py +0 -0
- {array_api_compat-1.8 → array_api_compat-1.9}/tests/test_isdtype.py +0 -0
- {array_api_compat-1.8 → array_api_compat-1.9}/tests/test_no_dependencies.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: array_api_compat
|
|
3
|
-
Version: 1.8
|
|
3
|
+
Version: 1.9
|
|
4
4
|
Summary: A wrapper around NumPy and other array libraries to make them compatible with the Array API standard
|
|
5
5
|
Home-page: https://data-apis.org/array-api-compat/
|
|
6
6
|
Author: Consortium for Python Data API Standards
|
|
@@ -12,7 +12,7 @@ if TYPE_CHECKING:
|
|
|
12
12
|
from typing import NamedTuple
|
|
13
13
|
import inspect
|
|
14
14
|
|
|
15
|
-
from ._helpers import array_namespace, _check_device
|
|
15
|
+
from ._helpers import array_namespace, _check_device, device, is_torch_array
|
|
16
16
|
|
|
17
17
|
# These functions are modified from the NumPy versions.
|
|
18
18
|
|
|
@@ -264,6 +264,38 @@ def var(
|
|
|
264
264
|
) -> ndarray:
|
|
265
265
|
return xp.var(x, axis=axis, ddof=correction, keepdims=keepdims, **kwargs)
|
|
266
266
|
|
|
267
|
+
# cumulative_sum is renamed from cumsum, and adds the include_initial keyword
|
|
268
|
+
# argument
|
|
269
|
+
|
|
270
|
+
def cumulative_sum(
|
|
271
|
+
x: ndarray,
|
|
272
|
+
/,
|
|
273
|
+
xp,
|
|
274
|
+
*,
|
|
275
|
+
axis: Optional[int] = None,
|
|
276
|
+
dtype: Optional[Dtype] = None,
|
|
277
|
+
include_initial: bool = False,
|
|
278
|
+
**kwargs
|
|
279
|
+
) -> ndarray:
|
|
280
|
+
wrapped_xp = array_namespace(x)
|
|
281
|
+
|
|
282
|
+
# TODO: The standard is not clear about what should happen when x.ndim == 0.
|
|
283
|
+
if axis is None:
|
|
284
|
+
if x.ndim > 1:
|
|
285
|
+
raise ValueError("axis must be specified in cumulative_sum for more than one dimension")
|
|
286
|
+
axis = 0
|
|
287
|
+
|
|
288
|
+
res = xp.cumsum(x, axis=axis, dtype=dtype, **kwargs)
|
|
289
|
+
|
|
290
|
+
# np.cumsum does not support include_initial
|
|
291
|
+
if include_initial:
|
|
292
|
+
initial_shape = list(x.shape)
|
|
293
|
+
initial_shape[axis] = 1
|
|
294
|
+
res = xp.concatenate(
|
|
295
|
+
[wrapped_xp.zeros(shape=initial_shape, dtype=res.dtype, device=device(res)), res],
|
|
296
|
+
axis=axis,
|
|
297
|
+
)
|
|
298
|
+
return res
|
|
267
299
|
|
|
268
300
|
# The min and max argument names in clip are different and not optional in numpy, and type
|
|
269
301
|
# promotion behavior is different.
|
|
@@ -281,10 +313,11 @@ def clip(
|
|
|
281
313
|
return isinstance(a, (int, float, type(None)))
|
|
282
314
|
min_shape = () if _isscalar(min) else min.shape
|
|
283
315
|
max_shape = () if _isscalar(max) else max.shape
|
|
284
|
-
result_shape = xp.broadcast_shapes(x.shape, min_shape, max_shape)
|
|
285
316
|
|
|
286
317
|
wrapped_xp = array_namespace(x)
|
|
287
318
|
|
|
319
|
+
result_shape = xp.broadcast_shapes(x.shape, min_shape, max_shape)
|
|
320
|
+
|
|
288
321
|
# np.clip does type promotion but the array API clip requires that the
|
|
289
322
|
# output have the same dtype as x. We do this instead of just downcasting
|
|
290
323
|
# the result of xp.clip() to handle some corner cases better (e.g.,
|
|
@@ -305,20 +338,26 @@ def clip(
|
|
|
305
338
|
|
|
306
339
|
# At least handle the case of Python integers correctly (see
|
|
307
340
|
# https://github.com/numpy/numpy/pull/26892).
|
|
308
|
-
if type(min) is int and min <= xp.iinfo(x.dtype).min:
|
|
341
|
+
if type(min) is int and min <= wrapped_xp.iinfo(x.dtype).min:
|
|
309
342
|
min = None
|
|
310
|
-
if type(max) is int and max >= xp.iinfo(x.dtype).max:
|
|
343
|
+
if type(max) is int and max >= wrapped_xp.iinfo(x.dtype).max:
|
|
311
344
|
max = None
|
|
312
345
|
|
|
313
346
|
if out is None:
|
|
314
|
-
out = wrapped_xp.asarray(xp.broadcast_to(x, result_shape), copy=True)
|
|
347
|
+
out = wrapped_xp.asarray(xp.broadcast_to(x, result_shape),
|
|
348
|
+
copy=True, device=device(x))
|
|
315
349
|
if min is not None:
|
|
316
|
-
|
|
350
|
+
if is_torch_array(x) and x.dtype == xp.float64 and _isscalar(min):
|
|
351
|
+
# Avoid loss of precision due to torch defaulting to float32
|
|
352
|
+
min = wrapped_xp.asarray(min, dtype=xp.float64)
|
|
353
|
+
a = xp.broadcast_to(wrapped_xp.asarray(min, device=device(x)), result_shape)
|
|
317
354
|
ia = (out < a) | xp.isnan(a)
|
|
318
355
|
# torch requires an explicit cast here
|
|
319
356
|
out[ia] = wrapped_xp.astype(a[ia], out.dtype)
|
|
320
357
|
if max is not None:
|
|
321
|
-
|
|
358
|
+
if is_torch_array(x) and x.dtype == xp.float64 and _isscalar(max):
|
|
359
|
+
max = wrapped_xp.asarray(max, dtype=xp.float64)
|
|
360
|
+
b = xp.broadcast_to(wrapped_xp.asarray(max, device=device(x)), result_shape)
|
|
322
361
|
ib = (out > b) | xp.isnan(b)
|
|
323
362
|
out[ib] = wrapped_xp.astype(b[ib], out.dtype)
|
|
324
363
|
# Return a scalar for 0-D
|
|
@@ -389,42 +428,6 @@ def nonzero(x: ndarray, /, xp, **kwargs) -> Tuple[ndarray, ...]:
|
|
|
389
428
|
raise ValueError("nonzero() does not support zero-dimensional arrays")
|
|
390
429
|
return xp.nonzero(x, **kwargs)
|
|
391
430
|
|
|
392
|
-
# sum() and prod() should always upcast when dtype=None
|
|
393
|
-
def sum(
|
|
394
|
-
x: ndarray,
|
|
395
|
-
/,
|
|
396
|
-
xp,
|
|
397
|
-
*,
|
|
398
|
-
axis: Optional[Union[int, Tuple[int, ...]]] = None,
|
|
399
|
-
dtype: Optional[Dtype] = None,
|
|
400
|
-
keepdims: bool = False,
|
|
401
|
-
**kwargs,
|
|
402
|
-
) -> ndarray:
|
|
403
|
-
# `xp.sum` already upcasts integers, but not floats or complexes
|
|
404
|
-
if dtype is None:
|
|
405
|
-
if x.dtype == xp.float32:
|
|
406
|
-
dtype = xp.float64
|
|
407
|
-
elif x.dtype == xp.complex64:
|
|
408
|
-
dtype = xp.complex128
|
|
409
|
-
return xp.sum(x, axis=axis, dtype=dtype, keepdims=keepdims, **kwargs)
|
|
410
|
-
|
|
411
|
-
def prod(
|
|
412
|
-
x: ndarray,
|
|
413
|
-
/,
|
|
414
|
-
xp,
|
|
415
|
-
*,
|
|
416
|
-
axis: Optional[Union[int, Tuple[int, ...]]] = None,
|
|
417
|
-
dtype: Optional[Dtype] = None,
|
|
418
|
-
keepdims: bool = False,
|
|
419
|
-
**kwargs,
|
|
420
|
-
) -> ndarray:
|
|
421
|
-
if dtype is None:
|
|
422
|
-
if x.dtype == xp.float32:
|
|
423
|
-
dtype = xp.float64
|
|
424
|
-
elif x.dtype == xp.complex64:
|
|
425
|
-
dtype = xp.complex128
|
|
426
|
-
return xp.prod(x, dtype=dtype, axis=axis, keepdims=keepdims, **kwargs)
|
|
427
|
-
|
|
428
431
|
# ceil, floor, and trunc return integers for integer inputs
|
|
429
432
|
|
|
430
433
|
def ceil(x: ndarray, /, xp, **kwargs) -> ndarray:
|
|
@@ -521,10 +524,17 @@ def isdtype(
|
|
|
521
524
|
# array_api_strict implementation will be very strict.
|
|
522
525
|
return dtype == kind
|
|
523
526
|
|
|
527
|
+
# unstack is a new function in the 2023.12 array API standard
|
|
528
|
+
def unstack(x: ndarray, /, xp, *, axis: int = 0) -> Tuple[ndarray, ...]:
|
|
529
|
+
if x.ndim == 0:
|
|
530
|
+
raise ValueError("Input array must be at least 1-d.")
|
|
531
|
+
return tuple(xp.moveaxis(x, axis, 0))
|
|
532
|
+
|
|
524
533
|
__all__ = ['arange', 'empty', 'empty_like', 'eye', 'full', 'full_like',
|
|
525
534
|
'linspace', 'ones', 'ones_like', 'zeros', 'zeros_like',
|
|
526
535
|
'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult',
|
|
527
536
|
'unique_all', 'unique_counts', 'unique_inverse', 'unique_values',
|
|
528
|
-
'astype', 'std', 'var', 'clip', 'permute_dims', 'reshape', 'argsort',
|
|
529
|
-
'sort', 'nonzero', 'sum', 'prod', 'ceil', 'floor', 'trunc',
|
|
530
|
-
'matmul', 'matrix_transpose', 'tensordot', 'vecdot', 'isdtype']
|
|
537
|
+
'astype', 'std', 'var', 'cumulative_sum', 'clip', 'permute_dims',
|
|
538
|
+
'reshape', 'argsort', 'sort', 'nonzero', 'ceil', 'floor', 'trunc',
|
|
539
|
+
'matmul', 'matrix_transpose', 'tensordot', 'vecdot', 'isdtype',
|
|
540
|
+
'unstack']
|
|
@@ -145,7 +145,7 @@ def is_ndonnx_array(x):
|
|
|
145
145
|
|
|
146
146
|
import ndonnx as ndx
|
|
147
147
|
|
|
148
|
-
return isinstance(x, ndx.Array)
|
|
148
|
+
return isinstance(x, ndx.Array)
|
|
149
149
|
|
|
150
150
|
def is_dask_array(x):
|
|
151
151
|
"""
|
|
@@ -202,7 +202,6 @@ def is_jax_array(x):
|
|
|
202
202
|
|
|
203
203
|
return isinstance(x, jax.Array) or _is_jax_zero_gradient_array(x)
|
|
204
204
|
|
|
205
|
-
|
|
206
205
|
def is_pydata_sparse_array(x) -> bool:
|
|
207
206
|
"""
|
|
208
207
|
Return True if `x` is an array from the `sparse` package.
|
|
@@ -255,6 +254,166 @@ def is_array_api_obj(x):
|
|
|
255
254
|
or is_pydata_sparse_array(x) \
|
|
256
255
|
or hasattr(x, '__array_namespace__')
|
|
257
256
|
|
|
257
|
+
def _compat_module_name():
|
|
258
|
+
assert __name__.endswith('.common._helpers')
|
|
259
|
+
return __name__.removesuffix('.common._helpers')
|
|
260
|
+
|
|
261
|
+
def is_numpy_namespace(xp) -> bool:
|
|
262
|
+
"""
|
|
263
|
+
Returns True if `xp` is a NumPy namespace.
|
|
264
|
+
|
|
265
|
+
This includes both NumPy itself and the version wrapped by array-api-compat.
|
|
266
|
+
|
|
267
|
+
See Also
|
|
268
|
+
--------
|
|
269
|
+
|
|
270
|
+
array_namespace
|
|
271
|
+
is_cupy_namespace
|
|
272
|
+
is_torch_namespace
|
|
273
|
+
is_ndonnx_namespace
|
|
274
|
+
is_dask_namespace
|
|
275
|
+
is_jax_namespace
|
|
276
|
+
is_pydata_sparse_namespace
|
|
277
|
+
is_array_api_strict_namespace
|
|
278
|
+
"""
|
|
279
|
+
return xp.__name__ in {'numpy', _compat_module_name() + '.numpy'}
|
|
280
|
+
|
|
281
|
+
def is_cupy_namespace(xp) -> bool:
|
|
282
|
+
"""
|
|
283
|
+
Returns True if `xp` is a CuPy namespace.
|
|
284
|
+
|
|
285
|
+
This includes both CuPy itself and the version wrapped by array-api-compat.
|
|
286
|
+
|
|
287
|
+
See Also
|
|
288
|
+
--------
|
|
289
|
+
|
|
290
|
+
array_namespace
|
|
291
|
+
is_numpy_namespace
|
|
292
|
+
is_torch_namespace
|
|
293
|
+
is_ndonnx_namespace
|
|
294
|
+
is_dask_namespace
|
|
295
|
+
is_jax_namespace
|
|
296
|
+
is_pydata_sparse_namespace
|
|
297
|
+
is_array_api_strict_namespace
|
|
298
|
+
"""
|
|
299
|
+
return xp.__name__ in {'cupy', _compat_module_name() + '.cupy'}
|
|
300
|
+
|
|
301
|
+
def is_torch_namespace(xp) -> bool:
|
|
302
|
+
"""
|
|
303
|
+
Returns True if `xp` is a PyTorch namespace.
|
|
304
|
+
|
|
305
|
+
This includes both PyTorch itself and the version wrapped by array-api-compat.
|
|
306
|
+
|
|
307
|
+
See Also
|
|
308
|
+
--------
|
|
309
|
+
|
|
310
|
+
array_namespace
|
|
311
|
+
is_numpy_namespace
|
|
312
|
+
is_cupy_namespace
|
|
313
|
+
is_ndonnx_namespace
|
|
314
|
+
is_dask_namespace
|
|
315
|
+
is_jax_namespace
|
|
316
|
+
is_pydata_sparse_namespace
|
|
317
|
+
is_array_api_strict_namespace
|
|
318
|
+
"""
|
|
319
|
+
return xp.__name__ in {'torch', _compat_module_name() + '.torch'}
|
|
320
|
+
|
|
321
|
+
|
|
322
|
+
def is_ndonnx_namespace(xp):
|
|
323
|
+
"""
|
|
324
|
+
Returns True if `xp` is an NDONNX namespace.
|
|
325
|
+
|
|
326
|
+
See Also
|
|
327
|
+
--------
|
|
328
|
+
|
|
329
|
+
array_namespace
|
|
330
|
+
is_numpy_namespace
|
|
331
|
+
is_cupy_namespace
|
|
332
|
+
is_torch_namespace
|
|
333
|
+
is_dask_namespace
|
|
334
|
+
is_jax_namespace
|
|
335
|
+
is_pydata_sparse_namespace
|
|
336
|
+
is_array_api_strict_namespace
|
|
337
|
+
"""
|
|
338
|
+
return xp.__name__ == 'ndonnx'
|
|
339
|
+
|
|
340
|
+
def is_dask_namespace(xp):
|
|
341
|
+
"""
|
|
342
|
+
Returns True if `xp` is a Dask namespace.
|
|
343
|
+
|
|
344
|
+
This includes both ``dask.array`` itself and the version wrapped by array-api-compat.
|
|
345
|
+
|
|
346
|
+
See Also
|
|
347
|
+
--------
|
|
348
|
+
|
|
349
|
+
array_namespace
|
|
350
|
+
is_numpy_namespace
|
|
351
|
+
is_cupy_namespace
|
|
352
|
+
is_torch_namespace
|
|
353
|
+
is_ndonnx_namespace
|
|
354
|
+
is_jax_namespace
|
|
355
|
+
is_pydata_sparse_namespace
|
|
356
|
+
is_array_api_strict_namespace
|
|
357
|
+
"""
|
|
358
|
+
return xp.__name__ in {'dask.array', _compat_module_name() + '.dask.array'}
|
|
359
|
+
|
|
360
|
+
def is_jax_namespace(xp):
|
|
361
|
+
"""
|
|
362
|
+
Returns True if `xp` is a JAX namespace.
|
|
363
|
+
|
|
364
|
+
This includes ``jax.numpy`` and ``jax.experimental.array_api`` which existed in
|
|
365
|
+
older versions of JAX.
|
|
366
|
+
|
|
367
|
+
See Also
|
|
368
|
+
--------
|
|
369
|
+
|
|
370
|
+
array_namespace
|
|
371
|
+
is_numpy_namespace
|
|
372
|
+
is_cupy_namespace
|
|
373
|
+
is_torch_namespace
|
|
374
|
+
is_ndonnx_namespace
|
|
375
|
+
is_dask_namespace
|
|
376
|
+
is_pydata_sparse_namespace
|
|
377
|
+
is_array_api_strict_namespace
|
|
378
|
+
"""
|
|
379
|
+
return xp.__name__ in {'jax.numpy', 'jax.experimental.array_api'}
|
|
380
|
+
|
|
381
|
+
def is_pydata_sparse_namespace(xp):
|
|
382
|
+
"""
|
|
383
|
+
Returns True if `xp` is a pydata/sparse namespace.
|
|
384
|
+
|
|
385
|
+
See Also
|
|
386
|
+
--------
|
|
387
|
+
|
|
388
|
+
array_namespace
|
|
389
|
+
is_numpy_namespace
|
|
390
|
+
is_cupy_namespace
|
|
391
|
+
is_torch_namespace
|
|
392
|
+
is_ndonnx_namespace
|
|
393
|
+
is_dask_namespace
|
|
394
|
+
is_jax_namespace
|
|
395
|
+
is_array_api_strict_namespace
|
|
396
|
+
"""
|
|
397
|
+
return xp.__name__ == 'sparse'
|
|
398
|
+
|
|
399
|
+
def is_array_api_strict_namespace(xp):
|
|
400
|
+
"""
|
|
401
|
+
Returns True if `xp` is an array-api-strict namespace.
|
|
402
|
+
|
|
403
|
+
See Also
|
|
404
|
+
--------
|
|
405
|
+
|
|
406
|
+
array_namespace
|
|
407
|
+
is_numpy_namespace
|
|
408
|
+
is_cupy_namespace
|
|
409
|
+
is_torch_namespace
|
|
410
|
+
is_ndonnx_namespace
|
|
411
|
+
is_dask_namespace
|
|
412
|
+
is_jax_namespace
|
|
413
|
+
is_pydata_sparse_namespace
|
|
414
|
+
"""
|
|
415
|
+
return xp.__name__ == 'array_api_strict'
|
|
416
|
+
|
|
258
417
|
def _check_api_version(api_version):
|
|
259
418
|
if api_version == '2021.12':
|
|
260
419
|
warnings.warn("The 2021.12 version of the array API specification was requested but the returned namespace is actually version 2022.12")
|
|
@@ -340,12 +499,9 @@ def array_namespace(*xs, api_version=None, use_compat=None):
|
|
|
340
499
|
elif use_compat is False:
|
|
341
500
|
namespaces.add(np)
|
|
342
501
|
else:
|
|
343
|
-
# numpy 2.0 has __array_namespace__ and should be fully array API
|
|
502
|
+
# numpy 2.0+ have __array_namespace__, however, they are not yet fully array API
|
|
344
503
|
# compatible.
|
|
345
|
-
if hasattr(x, '__array_namespace__'):
|
346
|
-
namespaces.add(x.__array_namespace__(api_version=api_version))
|
|
347
|
-
else:
|
|
348
|
-
namespaces.add(numpy_namespace)
|
|
504
|
+
namespaces.add(numpy_namespace)
|
|
349
505
|
elif is_cupy_array(x):
|
|
350
506
|
if _use_compat:
|
|
351
507
|
_check_api_version(api_version)
|
|
@@ -377,9 +533,13 @@ def array_namespace(*xs, api_version=None, use_compat=None):
|
|
|
377
533
|
elif use_compat is False:
|
|
378
534
|
import jax.numpy as jnp
|
|
379
535
|
else:
|
|
380
|
-
#
|
|
381
|
-
#
|
|
382
|
-
import jax.experimental.array_api as jnp
|
|
536
|
+
# JAX v0.4.32 and newer implements the array API directly in jax.numpy.
|
|
537
|
+
# For older JAX versions, it is available via jax.experimental.array_api.
|
|
538
|
+
import jax.numpy
|
|
539
|
+
if hasattr(jax.numpy, "__array_api_version__"):
|
|
540
|
+
jnp = jax.numpy
|
|
541
|
+
else:
|
|
542
|
+
import jax.experimental.array_api as jnp
|
|
383
543
|
namespaces.add(jnp)
|
|
384
544
|
elif is_pydata_sparse_array(x):
|
|
385
545
|
if use_compat is True:
|
|
@@ -613,8 +773,9 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]]
|
|
|
613
773
|
return x
|
|
614
774
|
raise ValueError(f"Unsupported device {device!r}")
|
|
615
775
|
elif is_jax_array(x):
|
|
616
|
-
|
|
617
|
-
|
|
776
|
+
if not hasattr(x, "__array_namespace__"):
|
|
777
|
+
# In JAX v0.4.31 and older, this import adds to_device method to x.
|
|
778
|
+
import jax.experimental.array_api # noqa: F401
|
|
618
779
|
return x.to_device(device, stream=stream)
|
|
619
780
|
elif is_pydata_sparse_array(x) and device == _device(x):
|
|
620
781
|
# Perform trivial check to return the same array if
|
|
@@ -641,13 +802,21 @@ __all__ = [
|
|
|
641
802
|
"device",
|
|
642
803
|
"get_namespace",
|
|
643
804
|
"is_array_api_obj",
|
|
805
|
+
"is_array_api_strict_namespace",
|
|
644
806
|
"is_cupy_array",
|
|
807
|
+
"is_cupy_namespace",
|
|
645
808
|
"is_dask_array",
|
|
809
|
+
"is_dask_namespace",
|
|
646
810
|
"is_jax_array",
|
|
811
|
+
"is_jax_namespace",
|
|
647
812
|
"is_numpy_array",
|
|
813
|
+
"is_numpy_namespace",
|
|
648
814
|
"is_torch_array",
|
|
815
|
+
"is_torch_namespace",
|
|
649
816
|
"is_ndonnx_array",
|
|
817
|
+
"is_ndonnx_namespace",
|
|
650
818
|
"is_pydata_sparse_array",
|
|
819
|
+
"is_pydata_sparse_namespace",
|
|
651
820
|
"size",
|
|
652
821
|
"to_device",
|
|
653
822
|
]
|
|
@@ -147,11 +147,6 @@ def diagonal(x: ndarray, /, xp, *, offset: int = 0, **kwargs) -> ndarray:
|
|
|
147
147
|
return xp.diagonal(x, offset=offset, axis1=-2, axis2=-1, **kwargs)
|
|
148
148
|
|
|
149
149
|
def trace(x: ndarray, /, xp, *, offset: int = 0, dtype=None, **kwargs) -> ndarray:
|
|
150
|
-
if dtype is None:
|
|
151
|
-
if x.dtype == xp.float32:
|
|
152
|
-
dtype = xp.float64
|
|
153
|
-
elif x.dtype == xp.complex64:
|
|
154
|
-
dtype = xp.complex128
|
|
155
150
|
return xp.asarray(xp.trace(x, offset=offset, dtype=dtype, axis1=-2, axis2=-1, **kwargs))
|
|
156
151
|
|
|
157
152
|
__all__ = ['cross', 'matmul', 'outer', 'tensordot', 'EighResult',
|
|
@@ -5,6 +5,8 @@ import cupy as cp
|
|
|
5
5
|
from ..common import _aliases
|
|
6
6
|
from .._internal import get_xp
|
|
7
7
|
|
|
8
|
+
from ._info import __array_namespace_info__
|
|
9
|
+
|
|
8
10
|
from typing import TYPE_CHECKING
|
|
9
11
|
if TYPE_CHECKING:
|
|
10
12
|
from typing import Optional, Union
|
|
@@ -47,14 +49,13 @@ unique_values = get_xp(cp)(_aliases.unique_values)
|
|
|
47
49
|
astype = _aliases.astype
|
|
48
50
|
std = get_xp(cp)(_aliases.std)
|
|
49
51
|
var = get_xp(cp)(_aliases.var)
|
|
52
|
+
cumulative_sum = get_xp(cp)(_aliases.cumulative_sum)
|
|
50
53
|
clip = get_xp(cp)(_aliases.clip)
|
|
51
54
|
permute_dims = get_xp(cp)(_aliases.permute_dims)
|
|
52
55
|
reshape = get_xp(cp)(_aliases.reshape)
|
|
53
56
|
argsort = get_xp(cp)(_aliases.argsort)
|
|
54
57
|
sort = get_xp(cp)(_aliases.sort)
|
|
55
58
|
nonzero = get_xp(cp)(_aliases.nonzero)
|
|
56
|
-
sum = get_xp(cp)(_aliases.sum)
|
|
57
|
-
prod = get_xp(cp)(_aliases.prod)
|
|
58
59
|
ceil = get_xp(cp)(_aliases.ceil)
|
|
59
60
|
floor = get_xp(cp)(_aliases.floor)
|
|
60
61
|
trunc = get_xp(cp)(_aliases.trunc)
|
|
@@ -108,20 +109,34 @@ def asarray(
|
|
|
108
109
|
|
|
109
110
|
return cp.array(obj, dtype=dtype, **kwargs)
|
|
110
111
|
|
|
112
|
+
def sign(x: ndarray, /) -> ndarray:
|
|
113
|
+
# CuPy sign() does not propagate nans. See
|
|
114
|
+
# https://github.com/data-apis/array-api-compat/issues/136
|
|
115
|
+
out = cp.sign(x)
|
|
116
|
+
out[cp.isnan(x)] = cp.nan
|
|
117
|
+
return out
|
|
118
|
+
|
|
111
119
|
# These functions are completely new here. If the library already has them
|
|
112
120
|
# (i.e., numpy 2.0), use the library version instead of our wrapper.
|
|
113
121
|
if hasattr(cp, 'vecdot'):
|
|
114
122
|
vecdot = cp.vecdot
|
|
115
123
|
else:
|
|
116
124
|
vecdot = get_xp(cp)(_aliases.vecdot)
|
|
125
|
+
|
|
117
126
|
if hasattr(cp, 'isdtype'):
|
|
118
127
|
isdtype = cp.isdtype
|
|
119
128
|
else:
|
|
120
129
|
isdtype = get_xp(cp)(_aliases.isdtype)
|
|
121
130
|
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
131
|
+
if hasattr(cp, 'unstack'):
|
|
132
|
+
unstack = cp.unstack
|
|
133
|
+
else:
|
|
134
|
+
unstack = get_xp(cp)(_aliases.unstack)
|
|
135
|
+
|
|
136
|
+
__all__ = _aliases.__all__ + ['__array_namespace_info__', 'asarray', 'bool',
|
|
137
|
+
'acos', 'acosh', 'asin', 'asinh', 'atan',
|
|
138
|
+
'atan2', 'atanh', 'bitwise_left_shift',
|
|
139
|
+
'bitwise_invert', 'bitwise_right_shift',
|
|
140
|
+
'concat', 'pow', 'sign']
|
|
126
141
|
|
|
127
142
|
_all_ignore = ['cp', 'get_xp']
|