array-api-compat 1.8__tar.gz → 1.9__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46):
  1. {array_api_compat-1.8 → array_api_compat-1.9}/PKG-INFO +1 -1
  2. {array_api_compat-1.8 → array_api_compat-1.9}/array_api_compat/__init__.py +1 -1
  3. {array_api_compat-1.8 → array_api_compat-1.9}/array_api_compat/common/_aliases.py +56 -46
  4. {array_api_compat-1.8 → array_api_compat-1.9}/array_api_compat/common/_helpers.py +181 -12
  5. {array_api_compat-1.8 → array_api_compat-1.9}/array_api_compat/common/_linalg.py +0 -5
  6. {array_api_compat-1.8 → array_api_compat-1.9}/array_api_compat/cupy/_aliases.py +21 -6
  7. array_api_compat-1.9/array_api_compat/cupy/_info.py +326 -0
  8. {array_api_compat-1.8 → array_api_compat-1.9}/array_api_compat/dask/array/_aliases.py +49 -9
  9. array_api_compat-1.9/array_api_compat/dask/array/_info.py +345 -0
  10. {array_api_compat-1.8 → array_api_compat-1.9}/array_api_compat/dask/array/linalg.py +2 -1
  11. {array_api_compat-1.8 → array_api_compat-1.9}/array_api_compat/numpy/_aliases.py +14 -6
  12. array_api_compat-1.9/array_api_compat/numpy/_info.py +346 -0
  13. {array_api_compat-1.8 → array_api_compat-1.9}/array_api_compat/torch/_aliases.py +40 -9
  14. array_api_compat-1.9/array_api_compat/torch/_info.py +358 -0
  15. {array_api_compat-1.8 → array_api_compat-1.9}/array_api_compat.egg-info/PKG-INFO +1 -1
  16. {array_api_compat-1.8 → array_api_compat-1.9}/array_api_compat.egg-info/SOURCES.txt +4 -0
  17. {array_api_compat-1.8 → array_api_compat-1.9}/tests/test_all.py +2 -0
  18. {array_api_compat-1.8 → array_api_compat-1.9}/tests/test_array_namespace.py +20 -4
  19. {array_api_compat-1.8 → array_api_compat-1.9}/tests/test_common.py +34 -10
  20. {array_api_compat-1.8 → array_api_compat-1.9}/tests/test_vendoring.py +1 -0
  21. {array_api_compat-1.8 → array_api_compat-1.9}/LICENSE +0 -0
  22. {array_api_compat-1.8 → array_api_compat-1.9}/README.md +0 -0
  23. {array_api_compat-1.8 → array_api_compat-1.9}/array_api_compat/_internal.py +0 -0
  24. {array_api_compat-1.8 → array_api_compat-1.9}/array_api_compat/common/__init__.py +0 -0
  25. {array_api_compat-1.8 → array_api_compat-1.9}/array_api_compat/common/_fft.py +0 -0
  26. {array_api_compat-1.8 → array_api_compat-1.9}/array_api_compat/common/_typing.py +0 -0
  27. {array_api_compat-1.8 → array_api_compat-1.9}/array_api_compat/cupy/__init__.py +0 -0
  28. {array_api_compat-1.8 → array_api_compat-1.9}/array_api_compat/cupy/_typing.py +0 -0
  29. {array_api_compat-1.8 → array_api_compat-1.9}/array_api_compat/cupy/fft.py +0 -0
  30. {array_api_compat-1.8 → array_api_compat-1.9}/array_api_compat/cupy/linalg.py +0 -0
  31. {array_api_compat-1.8 → array_api_compat-1.9}/array_api_compat/dask/__init__.py +0 -0
  32. {array_api_compat-1.8 → array_api_compat-1.9}/array_api_compat/dask/array/__init__.py +0 -0
  33. {array_api_compat-1.8 → array_api_compat-1.9}/array_api_compat/numpy/__init__.py +0 -0
  34. {array_api_compat-1.8 → array_api_compat-1.9}/array_api_compat/numpy/_typing.py +0 -0
  35. {array_api_compat-1.8 → array_api_compat-1.9}/array_api_compat/numpy/fft.py +0 -0
  36. {array_api_compat-1.8 → array_api_compat-1.9}/array_api_compat/numpy/linalg.py +0 -0
  37. {array_api_compat-1.8 → array_api_compat-1.9}/array_api_compat/torch/__init__.py +0 -0
  38. {array_api_compat-1.8 → array_api_compat-1.9}/array_api_compat/torch/fft.py +0 -0
  39. {array_api_compat-1.8 → array_api_compat-1.9}/array_api_compat/torch/linalg.py +0 -0
  40. {array_api_compat-1.8 → array_api_compat-1.9}/array_api_compat.egg-info/dependency_links.txt +0 -0
  41. {array_api_compat-1.8 → array_api_compat-1.9}/array_api_compat.egg-info/requires.txt +0 -0
  42. {array_api_compat-1.8 → array_api_compat-1.9}/array_api_compat.egg-info/top_level.txt +0 -0
  43. {array_api_compat-1.8 → array_api_compat-1.9}/setup.cfg +0 -0
  44. {array_api_compat-1.8 → array_api_compat-1.9}/setup.py +0 -0
  45. {array_api_compat-1.8 → array_api_compat-1.9}/tests/test_isdtype.py +0 -0
  46. {array_api_compat-1.8 → array_api_compat-1.9}/tests/test_no_dependencies.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: array_api_compat
3
- Version: 1.8
3
+ Version: 1.9
4
4
  Summary: A wrapper around NumPy and other array libraries to make them compatible with the Array API standard
5
5
  Home-page: https://data-apis.org/array-api-compat/
6
6
  Author: Consortium for Python Data API Standards
@@ -17,6 +17,6 @@ to ensure they are not using functionality outside of the standard, but prefer
17
17
  this implementation for the default when working with NumPy arrays.
18
18
 
19
19
  """
20
- __version__ = '1.8'
20
+ __version__ = '1.9'
21
21
 
22
22
  from .common import * # noqa: F401, F403
@@ -12,7 +12,7 @@ if TYPE_CHECKING:
12
12
  from typing import NamedTuple
13
13
  import inspect
14
14
 
15
- from ._helpers import array_namespace, _check_device
15
+ from ._helpers import array_namespace, _check_device, device, is_torch_array
16
16
 
17
17
  # These functions are modified from the NumPy versions.
18
18
 
@@ -264,6 +264,38 @@ def var(
264
264
  ) -> ndarray:
265
265
  return xp.var(x, axis=axis, ddof=correction, keepdims=keepdims, **kwargs)
266
266
 
267
+ # cumulative_sum is renamed from cumsum, and adds the include_initial keyword
268
+ # argument
269
+
270
+ def cumulative_sum(
271
+ x: ndarray,
272
+ /,
273
+ xp,
274
+ *,
275
+ axis: Optional[int] = None,
276
+ dtype: Optional[Dtype] = None,
277
+ include_initial: bool = False,
278
+ **kwargs
279
+ ) -> ndarray:
280
+ wrapped_xp = array_namespace(x)
281
+
282
+ # TODO: The standard is not clear about what should happen when x.ndim == 0.
283
+ if axis is None:
284
+ if x.ndim > 1:
285
+ raise ValueError("axis must be specified in cumulative_sum for more than one dimension")
286
+ axis = 0
287
+
288
+ res = xp.cumsum(x, axis=axis, dtype=dtype, **kwargs)
289
+
290
+ # np.cumsum does not support include_initial
291
+ if include_initial:
292
+ initial_shape = list(x.shape)
293
+ initial_shape[axis] = 1
294
+ res = xp.concatenate(
295
+ [wrapped_xp.zeros(shape=initial_shape, dtype=res.dtype, device=device(res)), res],
296
+ axis=axis,
297
+ )
298
+ return res
267
299
 
268
300
  # The min and max argument names in clip are different and not optional in numpy, and type
269
301
  # promotion behavior is different.
@@ -281,10 +313,11 @@ def clip(
281
313
  return isinstance(a, (int, float, type(None)))
282
314
  min_shape = () if _isscalar(min) else min.shape
283
315
  max_shape = () if _isscalar(max) else max.shape
284
- result_shape = xp.broadcast_shapes(x.shape, min_shape, max_shape)
285
316
 
286
317
  wrapped_xp = array_namespace(x)
287
318
 
319
+ result_shape = xp.broadcast_shapes(x.shape, min_shape, max_shape)
320
+
288
321
  # np.clip does type promotion but the array API clip requires that the
289
322
  # output have the same dtype as x. We do this instead of just downcasting
290
323
  # the result of xp.clip() to handle some corner cases better (e.g.,
@@ -305,20 +338,26 @@ def clip(
305
338
 
306
339
  # At least handle the case of Python integers correctly (see
307
340
  # https://github.com/numpy/numpy/pull/26892).
308
- if type(min) is int and min <= xp.iinfo(x.dtype).min:
341
+ if type(min) is int and min <= wrapped_xp.iinfo(x.dtype).min:
309
342
  min = None
310
- if type(max) is int and max >= xp.iinfo(x.dtype).max:
343
+ if type(max) is int and max >= wrapped_xp.iinfo(x.dtype).max:
311
344
  max = None
312
345
 
313
346
  if out is None:
314
- out = wrapped_xp.asarray(xp.broadcast_to(x, result_shape), copy=True)
347
+ out = wrapped_xp.asarray(xp.broadcast_to(x, result_shape),
348
+ copy=True, device=device(x))
315
349
  if min is not None:
316
- a = xp.broadcast_to(xp.asarray(min), result_shape)
350
+ if is_torch_array(x) and x.dtype == xp.float64 and _isscalar(min):
351
+ # Avoid loss of precision due to torch defaulting to float32
352
+ min = wrapped_xp.asarray(min, dtype=xp.float64)
353
+ a = xp.broadcast_to(wrapped_xp.asarray(min, device=device(x)), result_shape)
317
354
  ia = (out < a) | xp.isnan(a)
318
355
  # torch requires an explicit cast here
319
356
  out[ia] = wrapped_xp.astype(a[ia], out.dtype)
320
357
  if max is not None:
321
- b = xp.broadcast_to(xp.asarray(max), result_shape)
358
+ if is_torch_array(x) and x.dtype == xp.float64 and _isscalar(max):
359
+ max = wrapped_xp.asarray(max, dtype=xp.float64)
360
+ b = xp.broadcast_to(wrapped_xp.asarray(max, device=device(x)), result_shape)
322
361
  ib = (out > b) | xp.isnan(b)
323
362
  out[ib] = wrapped_xp.astype(b[ib], out.dtype)
324
363
  # Return a scalar for 0-D
@@ -389,42 +428,6 @@ def nonzero(x: ndarray, /, xp, **kwargs) -> Tuple[ndarray, ...]:
389
428
  raise ValueError("nonzero() does not support zero-dimensional arrays")
390
429
  return xp.nonzero(x, **kwargs)
391
430
 
392
- # sum() and prod() should always upcast when dtype=None
393
- def sum(
394
- x: ndarray,
395
- /,
396
- xp,
397
- *,
398
- axis: Optional[Union[int, Tuple[int, ...]]] = None,
399
- dtype: Optional[Dtype] = None,
400
- keepdims: bool = False,
401
- **kwargs,
402
- ) -> ndarray:
403
- # `xp.sum` already upcasts integers, but not floats or complexes
404
- if dtype is None:
405
- if x.dtype == xp.float32:
406
- dtype = xp.float64
407
- elif x.dtype == xp.complex64:
408
- dtype = xp.complex128
409
- return xp.sum(x, axis=axis, dtype=dtype, keepdims=keepdims, **kwargs)
410
-
411
- def prod(
412
- x: ndarray,
413
- /,
414
- xp,
415
- *,
416
- axis: Optional[Union[int, Tuple[int, ...]]] = None,
417
- dtype: Optional[Dtype] = None,
418
- keepdims: bool = False,
419
- **kwargs,
420
- ) -> ndarray:
421
- if dtype is None:
422
- if x.dtype == xp.float32:
423
- dtype = xp.float64
424
- elif x.dtype == xp.complex64:
425
- dtype = xp.complex128
426
- return xp.prod(x, dtype=dtype, axis=axis, keepdims=keepdims, **kwargs)
427
-
428
431
  # ceil, floor, and trunc return integers for integer inputs
429
432
 
430
433
  def ceil(x: ndarray, /, xp, **kwargs) -> ndarray:
@@ -521,10 +524,17 @@ def isdtype(
521
524
  # array_api_strict implementation will be very strict.
522
525
  return dtype == kind
523
526
 
527
+ # unstack is a new function in the 2023.12 array API standard
528
+ def unstack(x: ndarray, /, xp, *, axis: int = 0) -> Tuple[ndarray, ...]:
529
+ if x.ndim == 0:
530
+ raise ValueError("Input array must be at least 1-d.")
531
+ return tuple(xp.moveaxis(x, axis, 0))
532
+
524
533
  __all__ = ['arange', 'empty', 'empty_like', 'eye', 'full', 'full_like',
525
534
  'linspace', 'ones', 'ones_like', 'zeros', 'zeros_like',
526
535
  'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult',
527
536
  'unique_all', 'unique_counts', 'unique_inverse', 'unique_values',
528
- 'astype', 'std', 'var', 'clip', 'permute_dims', 'reshape', 'argsort',
529
- 'sort', 'nonzero', 'sum', 'prod', 'ceil', 'floor', 'trunc',
530
- 'matmul', 'matrix_transpose', 'tensordot', 'vecdot', 'isdtype']
537
+ 'astype', 'std', 'var', 'cumulative_sum', 'clip', 'permute_dims',
538
+ 'reshape', 'argsort', 'sort', 'nonzero', 'ceil', 'floor', 'trunc',
539
+ 'matmul', 'matrix_transpose', 'tensordot', 'vecdot', 'isdtype',
540
+ 'unstack']
@@ -145,7 +145,7 @@ def is_ndonnx_array(x):
145
145
 
146
146
  import ndonnx as ndx
147
147
 
148
- return isinstance(x, ndx.Array)
148
+ return isinstance(x, ndx.Array)
149
149
 
150
150
  def is_dask_array(x):
151
151
  """
@@ -202,7 +202,6 @@ def is_jax_array(x):
202
202
 
203
203
  return isinstance(x, jax.Array) or _is_jax_zero_gradient_array(x)
204
204
 
205
-
206
205
  def is_pydata_sparse_array(x) -> bool:
207
206
  """
208
207
  Return True if `x` is an array from the `sparse` package.
@@ -255,6 +254,166 @@ def is_array_api_obj(x):
255
254
  or is_pydata_sparse_array(x) \
256
255
  or hasattr(x, '__array_namespace__')
257
256
 
257
+ def _compat_module_name():
258
+ assert __name__.endswith('.common._helpers')
259
+ return __name__.removesuffix('.common._helpers')
260
+
261
+ def is_numpy_namespace(xp) -> bool:
262
+ """
263
+ Returns True if `xp` is a NumPy namespace.
264
+
265
+ This includes both NumPy itself and the version wrapped by array-api-compat.
266
+
267
+ See Also
268
+ --------
269
+
270
+ array_namespace
271
+ is_cupy_namespace
272
+ is_torch_namespace
273
+ is_ndonnx_namespace
274
+ is_dask_namespace
275
+ is_jax_namespace
276
+ is_pydata_sparse_namespace
277
+ is_array_api_strict_namespace
278
+ """
279
+ return xp.__name__ in {'numpy', _compat_module_name() + '.numpy'}
280
+
281
+ def is_cupy_namespace(xp) -> bool:
282
+ """
283
+ Returns True if `xp` is a CuPy namespace.
284
+
285
+ This includes both CuPy itself and the version wrapped by array-api-compat.
286
+
287
+ See Also
288
+ --------
289
+
290
+ array_namespace
291
+ is_numpy_namespace
292
+ is_torch_namespace
293
+ is_ndonnx_namespace
294
+ is_dask_namespace
295
+ is_jax_namespace
296
+ is_pydata_sparse_namespace
297
+ is_array_api_strict_namespace
298
+ """
299
+ return xp.__name__ in {'cupy', _compat_module_name() + '.cupy'}
300
+
301
+ def is_torch_namespace(xp) -> bool:
302
+ """
303
+ Returns True if `xp` is a PyTorch namespace.
304
+
305
+ This includes both PyTorch itself and the version wrapped by array-api-compat.
306
+
307
+ See Also
308
+ --------
309
+
310
+ array_namespace
311
+ is_numpy_namespace
312
+ is_cupy_namespace
313
+ is_ndonnx_namespace
314
+ is_dask_namespace
315
+ is_jax_namespace
316
+ is_pydata_sparse_namespace
317
+ is_array_api_strict_namespace
318
+ """
319
+ return xp.__name__ in {'torch', _compat_module_name() + '.torch'}
320
+
321
+
322
+ def is_ndonnx_namespace(xp):
323
+ """
324
+ Returns True if `xp` is an NDONNX namespace.
325
+
326
+ See Also
327
+ --------
328
+
329
+ array_namespace
330
+ is_numpy_namespace
331
+ is_cupy_namespace
332
+ is_torch_namespace
333
+ is_dask_namespace
334
+ is_jax_namespace
335
+ is_pydata_sparse_namespace
336
+ is_array_api_strict_namespace
337
+ """
338
+ return xp.__name__ == 'ndonnx'
339
+
340
+ def is_dask_namespace(xp):
341
+ """
342
+ Returns True if `xp` is a Dask namespace.
343
+
344
+ This includes both ``dask.array`` itself and the version wrapped by array-api-compat.
345
+
346
+ See Also
347
+ --------
348
+
349
+ array_namespace
350
+ is_numpy_namespace
351
+ is_cupy_namespace
352
+ is_torch_namespace
353
+ is_ndonnx_namespace
354
+ is_jax_namespace
355
+ is_pydata_sparse_namespace
356
+ is_array_api_strict_namespace
357
+ """
358
+ return xp.__name__ in {'dask.array', _compat_module_name() + '.dask.array'}
359
+
360
+ def is_jax_namespace(xp):
361
+ """
362
+ Returns True if `xp` is a JAX namespace.
363
+
364
+ This includes ``jax.numpy`` and ``jax.experimental.array_api`` which existed in
365
+ older versions of JAX.
366
+
367
+ See Also
368
+ --------
369
+
370
+ array_namespace
371
+ is_numpy_namespace
372
+ is_cupy_namespace
373
+ is_torch_namespace
374
+ is_ndonnx_namespace
375
+ is_dask_namespace
376
+ is_pydata_sparse_namespace
377
+ is_array_api_strict_namespace
378
+ """
379
+ return xp.__name__ in {'jax.numpy', 'jax.experimental.array_api'}
380
+
381
+ def is_pydata_sparse_namespace(xp):
382
+ """
383
+ Returns True if `xp` is a pydata/sparse namespace.
384
+
385
+ See Also
386
+ --------
387
+
388
+ array_namespace
389
+ is_numpy_namespace
390
+ is_cupy_namespace
391
+ is_torch_namespace
392
+ is_ndonnx_namespace
393
+ is_dask_namespace
394
+ is_jax_namespace
395
+ is_array_api_strict_namespace
396
+ """
397
+ return xp.__name__ == 'sparse'
398
+
399
+ def is_array_api_strict_namespace(xp):
400
+ """
401
+ Returns True if `xp` is an array-api-strict namespace.
402
+
403
+ See Also
404
+ --------
405
+
406
+ array_namespace
407
+ is_numpy_namespace
408
+ is_cupy_namespace
409
+ is_torch_namespace
410
+ is_ndonnx_namespace
411
+ is_dask_namespace
412
+ is_jax_namespace
413
+ is_pydata_sparse_namespace
414
+ """
415
+ return xp.__name__ == 'array_api_strict'
416
+
258
417
  def _check_api_version(api_version):
259
418
  if api_version == '2021.12':
260
419
  warnings.warn("The 2021.12 version of the array API specification was requested but the returned namespace is actually version 2022.12")
@@ -340,12 +499,9 @@ def array_namespace(*xs, api_version=None, use_compat=None):
340
499
  elif use_compat is False:
341
500
  namespaces.add(np)
342
501
  else:
343
- # numpy 2.0 has __array_namespace__ and is fully array API
502
+ # numpy 2.0+ have __array_namespace__, however, they are not yet fully array API
344
503
  # compatible.
345
- if hasattr(x, '__array_namespace__'):
346
- namespaces.add(x.__array_namespace__(api_version=api_version))
347
- else:
348
- namespaces.add(numpy_namespace)
504
+ namespaces.add(numpy_namespace)
349
505
  elif is_cupy_array(x):
350
506
  if _use_compat:
351
507
  _check_api_version(api_version)
@@ -377,9 +533,13 @@ def array_namespace(*xs, api_version=None, use_compat=None):
377
533
  elif use_compat is False:
378
534
  import jax.numpy as jnp
379
535
  else:
380
- # jax.experimental.array_api is already an array namespace. We do
381
- # not have a wrapper submodule for it.
382
- import jax.experimental.array_api as jnp
536
+ # JAX v0.4.32 and newer implements the array API directly in jax.numpy.
537
+ # For older JAX versions, it is available via jax.experimental.array_api.
538
+ import jax.numpy
539
+ if hasattr(jax.numpy, "__array_api_version__"):
540
+ jnp = jax.numpy
541
+ else:
542
+ import jax.experimental.array_api as jnp
383
543
  namespaces.add(jnp)
384
544
  elif is_pydata_sparse_array(x):
385
545
  if use_compat is True:
@@ -613,8 +773,9 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]]
613
773
  return x
614
774
  raise ValueError(f"Unsupported device {device!r}")
615
775
  elif is_jax_array(x):
616
- # This import adds to_device to x
617
- import jax.experimental.array_api # noqa: F401
776
+ if not hasattr(x, "__array_namespace__"):
777
+ # In JAX v0.4.31 and older, this import adds to_device method to x.
778
+ import jax.experimental.array_api # noqa: F401
618
779
  return x.to_device(device, stream=stream)
619
780
  elif is_pydata_sparse_array(x) and device == _device(x):
620
781
  # Perform trivial check to return the same array if
@@ -641,13 +802,21 @@ __all__ = [
641
802
  "device",
642
803
  "get_namespace",
643
804
  "is_array_api_obj",
805
+ "is_array_api_strict_namespace",
644
806
  "is_cupy_array",
807
+ "is_cupy_namespace",
645
808
  "is_dask_array",
809
+ "is_dask_namespace",
646
810
  "is_jax_array",
811
+ "is_jax_namespace",
647
812
  "is_numpy_array",
813
+ "is_numpy_namespace",
648
814
  "is_torch_array",
815
+ "is_torch_namespace",
649
816
  "is_ndonnx_array",
817
+ "is_ndonnx_namespace",
650
818
  "is_pydata_sparse_array",
819
+ "is_pydata_sparse_namespace",
651
820
  "size",
652
821
  "to_device",
653
822
  ]
@@ -147,11 +147,6 @@ def diagonal(x: ndarray, /, xp, *, offset: int = 0, **kwargs) -> ndarray:
147
147
  return xp.diagonal(x, offset=offset, axis1=-2, axis2=-1, **kwargs)
148
148
 
149
149
  def trace(x: ndarray, /, xp, *, offset: int = 0, dtype=None, **kwargs) -> ndarray:
150
- if dtype is None:
151
- if x.dtype == xp.float32:
152
- dtype = xp.float64
153
- elif x.dtype == xp.complex64:
154
- dtype = xp.complex128
155
150
  return xp.asarray(xp.trace(x, offset=offset, dtype=dtype, axis1=-2, axis2=-1, **kwargs))
156
151
 
157
152
  __all__ = ['cross', 'matmul', 'outer', 'tensordot', 'EighResult',
@@ -5,6 +5,8 @@ import cupy as cp
5
5
  from ..common import _aliases
6
6
  from .._internal import get_xp
7
7
 
8
+ from ._info import __array_namespace_info__
9
+
8
10
  from typing import TYPE_CHECKING
9
11
  if TYPE_CHECKING:
10
12
  from typing import Optional, Union
@@ -47,14 +49,13 @@ unique_values = get_xp(cp)(_aliases.unique_values)
47
49
  astype = _aliases.astype
48
50
  std = get_xp(cp)(_aliases.std)
49
51
  var = get_xp(cp)(_aliases.var)
52
+ cumulative_sum = get_xp(cp)(_aliases.cumulative_sum)
50
53
  clip = get_xp(cp)(_aliases.clip)
51
54
  permute_dims = get_xp(cp)(_aliases.permute_dims)
52
55
  reshape = get_xp(cp)(_aliases.reshape)
53
56
  argsort = get_xp(cp)(_aliases.argsort)
54
57
  sort = get_xp(cp)(_aliases.sort)
55
58
  nonzero = get_xp(cp)(_aliases.nonzero)
56
- sum = get_xp(cp)(_aliases.sum)
57
- prod = get_xp(cp)(_aliases.prod)
58
59
  ceil = get_xp(cp)(_aliases.ceil)
59
60
  floor = get_xp(cp)(_aliases.floor)
60
61
  trunc = get_xp(cp)(_aliases.trunc)
@@ -108,20 +109,34 @@ def asarray(
108
109
 
109
110
  return cp.array(obj, dtype=dtype, **kwargs)
110
111
 
112
+ def sign(x: ndarray, /) -> ndarray:
113
+ # CuPy sign() does not propagate nans. See
114
+ # https://github.com/data-apis/array-api-compat/issues/136
115
+ out = cp.sign(x)
116
+ out[cp.isnan(x)] = cp.nan
117
+ return out
118
+
111
119
  # These functions are completely new here. If the library already has them
112
120
  # (i.e., numpy 2.0), use the library version instead of our wrapper.
113
121
  if hasattr(cp, 'vecdot'):
114
122
  vecdot = cp.vecdot
115
123
  else:
116
124
  vecdot = get_xp(cp)(_aliases.vecdot)
125
+
117
126
  if hasattr(cp, 'isdtype'):
118
127
  isdtype = cp.isdtype
119
128
  else:
120
129
  isdtype = get_xp(cp)(_aliases.isdtype)
121
130
 
122
- __all__ = _aliases.__all__ + ['asarray', 'bool', 'acos',
123
- 'acosh', 'asin', 'asinh', 'atan', 'atan2',
124
- 'atanh', 'bitwise_left_shift', 'bitwise_invert',
125
- 'bitwise_right_shift', 'concat', 'pow']
131
+ if hasattr(cp, 'unstack'):
132
+ unstack = cp.unstack
133
+ else:
134
+ unstack = get_xp(cp)(_aliases.unstack)
135
+
136
+ __all__ = _aliases.__all__ + ['__array_namespace_info__', 'asarray', 'bool',
137
+ 'acos', 'acosh', 'asin', 'asinh', 'atan',
138
+ 'atan2', 'atanh', 'bitwise_left_shift',
139
+ 'bitwise_invert', 'bitwise_right_shift',
140
+ 'concat', 'pow', 'sign']
126
141
 
127
142
  _all_ignore = ['cp', 'get_xp']