array-api-compat 1.10.0__tar.gz → 1.11__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50) hide show
  1. {array_api_compat-1.10.0 → array_api_compat-1.11}/PKG-INFO +11 -2
  2. {array_api_compat-1.10.0 → array_api_compat-1.11}/array_api_compat/__init__.py +1 -1
  3. {array_api_compat-1.10.0 → array_api_compat-1.11}/array_api_compat/common/_aliases.py +31 -6
  4. {array_api_compat-1.10.0 → array_api_compat-1.11}/array_api_compat/common/_fft.py +27 -5
  5. {array_api_compat-1.10.0 → array_api_compat-1.11}/array_api_compat/common/_helpers.py +131 -43
  6. {array_api_compat-1.10.0 → array_api_compat-1.11}/array_api_compat/common/_typing.py +3 -0
  7. {array_api_compat-1.10.0 → array_api_compat-1.11}/array_api_compat/cupy/__init__.py +1 -1
  8. {array_api_compat-1.10.0 → array_api_compat-1.11}/array_api_compat/cupy/_aliases.py +19 -4
  9. {array_api_compat-1.10.0 → array_api_compat-1.11}/array_api_compat/cupy/_info.py +1 -1
  10. {array_api_compat-1.10.0 → array_api_compat-1.11}/array_api_compat/dask/array/__init__.py +1 -1
  11. array_api_compat-1.11/array_api_compat/dask/array/_aliases.py +348 -0
  12. {array_api_compat-1.10.0 → array_api_compat-1.11}/array_api_compat/dask/array/_info.py +1 -1
  13. {array_api_compat-1.10.0 → array_api_compat-1.11}/array_api_compat/numpy/__init__.py +1 -1
  14. {array_api_compat-1.10.0 → array_api_compat-1.11}/array_api_compat/numpy/_aliases.py +15 -3
  15. {array_api_compat-1.10.0 → array_api_compat-1.11}/array_api_compat/numpy/_info.py +1 -1
  16. {array_api_compat-1.10.0 → array_api_compat-1.11}/array_api_compat/torch/__init__.py +1 -1
  17. {array_api_compat-1.10.0 → array_api_compat-1.11}/array_api_compat/torch/_aliases.py +63 -6
  18. {array_api_compat-1.10.0 → array_api_compat-1.11}/array_api_compat/torch/_info.py +1 -1
  19. {array_api_compat-1.10.0 → array_api_compat-1.11}/array_api_compat.egg-info/PKG-INFO +11 -2
  20. {array_api_compat-1.10.0 → array_api_compat-1.11}/array_api_compat.egg-info/SOURCES.txt +2 -0
  21. {array_api_compat-1.10.0 → array_api_compat-1.11}/setup.py +1 -1
  22. {array_api_compat-1.10.0 → array_api_compat-1.11}/tests/test_all.py +8 -3
  23. {array_api_compat-1.10.0 → array_api_compat-1.11}/tests/test_array_namespace.py +10 -8
  24. {array_api_compat-1.10.0 → array_api_compat-1.11}/tests/test_common.py +143 -25
  25. array_api_compat-1.11/tests/test_dask.py +179 -0
  26. array_api_compat-1.11/tests/test_jax.py +34 -0
  27. array_api_compat-1.10.0/array_api_compat/dask/array/_aliases.py +0 -219
  28. {array_api_compat-1.10.0 → array_api_compat-1.11}/LICENSE +0 -0
  29. {array_api_compat-1.10.0 → array_api_compat-1.11}/README.md +0 -0
  30. {array_api_compat-1.10.0 → array_api_compat-1.11}/array_api_compat/_internal.py +0 -0
  31. {array_api_compat-1.10.0 → array_api_compat-1.11}/array_api_compat/common/__init__.py +0 -0
  32. {array_api_compat-1.10.0 → array_api_compat-1.11}/array_api_compat/common/_linalg.py +0 -0
  33. {array_api_compat-1.10.0 → array_api_compat-1.11}/array_api_compat/cupy/_typing.py +0 -0
  34. {array_api_compat-1.10.0 → array_api_compat-1.11}/array_api_compat/cupy/fft.py +0 -0
  35. {array_api_compat-1.10.0 → array_api_compat-1.11}/array_api_compat/cupy/linalg.py +0 -0
  36. {array_api_compat-1.10.0 → array_api_compat-1.11}/array_api_compat/dask/__init__.py +0 -0
  37. {array_api_compat-1.10.0 → array_api_compat-1.11}/array_api_compat/dask/array/fft.py +0 -0
  38. {array_api_compat-1.10.0 → array_api_compat-1.11}/array_api_compat/dask/array/linalg.py +0 -0
  39. {array_api_compat-1.10.0 → array_api_compat-1.11}/array_api_compat/numpy/_typing.py +0 -0
  40. {array_api_compat-1.10.0 → array_api_compat-1.11}/array_api_compat/numpy/fft.py +0 -0
  41. {array_api_compat-1.10.0 → array_api_compat-1.11}/array_api_compat/numpy/linalg.py +0 -0
  42. {array_api_compat-1.10.0 → array_api_compat-1.11}/array_api_compat/torch/fft.py +0 -0
  43. {array_api_compat-1.10.0 → array_api_compat-1.11}/array_api_compat/torch/linalg.py +0 -0
  44. {array_api_compat-1.10.0 → array_api_compat-1.11}/array_api_compat.egg-info/dependency_links.txt +0 -0
  45. {array_api_compat-1.10.0 → array_api_compat-1.11}/array_api_compat.egg-info/requires.txt +0 -0
  46. {array_api_compat-1.10.0 → array_api_compat-1.11}/array_api_compat.egg-info/top_level.txt +0 -0
  47. {array_api_compat-1.10.0 → array_api_compat-1.11}/setup.cfg +0 -0
  48. {array_api_compat-1.10.0 → array_api_compat-1.11}/tests/test_isdtype.py +0 -0
  49. {array_api_compat-1.10.0 → array_api_compat-1.11}/tests/test_no_dependencies.py +0 -0
  50. {array_api_compat-1.10.0 → array_api_compat-1.11}/tests/test_vendoring.py +0 -0
@@ -1,6 +1,6 @@
1
- Metadata-Version: 2.1
1
+ Metadata-Version: 2.2
2
2
  Name: array_api_compat
3
- Version: 1.10.0
3
+ Version: 1.11
4
4
  Summary: A wrapper around NumPy and other array libraries to make them compatible with the Array API standard
5
5
  Home-page: https://data-apis.org/array-api-compat/
6
6
  Author: Consortium for Python Data API Standards
@@ -28,6 +28,15 @@ Provides-Extra: dask
28
28
  Requires-Dist: dask; extra == "dask"
29
29
  Provides-Extra: sparse
30
30
  Requires-Dist: sparse>=0.15.1; extra == "sparse"
31
+ Dynamic: author
32
+ Dynamic: classifier
33
+ Dynamic: description
34
+ Dynamic: description-content-type
35
+ Dynamic: home-page
36
+ Dynamic: license
37
+ Dynamic: provides-extra
38
+ Dynamic: requires-python
39
+ Dynamic: summary
31
40
 
32
41
  # Array API compatibility library
33
42
 
@@ -17,6 +17,6 @@ to ensure they are not using functionality outside of the standard, but prefer
17
17
  this implementation for the default when working with NumPy arrays.
18
18
 
19
19
  """
20
- __version__ = '1.10.0'
20
+ __version__ = '1.11'
21
21
 
22
22
  from .common import * # noqa: F401, F403
@@ -233,11 +233,6 @@ def unique_values(x: ndarray, /, xp) -> ndarray:
233
233
  **kwargs,
234
234
  )
235
235
 
236
- def astype(x: ndarray, dtype: Dtype, /, *, copy: bool = True) -> ndarray:
237
- if not copy and dtype == x.dtype:
238
- return x
239
- return x.astype(dtype=dtype, copy=copy)
240
-
241
236
  # These functions have different keyword argument names
242
237
 
243
238
  def std(
@@ -297,6 +292,36 @@ def cumulative_sum(
297
292
  )
298
293
  return res
299
294
 
295
+
296
+ def cumulative_prod(
297
+ x: ndarray,
298
+ /,
299
+ xp,
300
+ *,
301
+ axis: Optional[int] = None,
302
+ dtype: Optional[Dtype] = None,
303
+ include_initial: bool = False,
304
+ **kwargs
305
+ ) -> ndarray:
306
+ wrapped_xp = array_namespace(x)
307
+
308
+ if axis is None:
309
+ if x.ndim > 1:
310
+ raise ValueError("axis must be specified in cumulative_prod for more than one dimension")
311
+ axis = 0
312
+
313
+ res = xp.cumprod(x, axis=axis, dtype=dtype, **kwargs)
314
+
315
+ # np.cumprod does not support include_initial
316
+ if include_initial:
317
+ initial_shape = list(x.shape)
318
+ initial_shape[axis] = 1
319
+ res = xp.concatenate(
320
+ [wrapped_xp.ones(shape=initial_shape, dtype=res.dtype, device=device(res)), res],
321
+ axis=axis,
322
+ )
323
+ return res
324
+
300
325
  # The min and max argument names in clip are different and not optional in numpy, and type
301
326
  # promotion behavior is different.
302
327
  def clip(
@@ -549,7 +574,7 @@ __all__ = ['arange', 'empty', 'empty_like', 'eye', 'full', 'full_like',
549
574
  'linspace', 'ones', 'ones_like', 'zeros', 'zeros_like',
550
575
  'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult',
551
576
  'unique_all', 'unique_counts', 'unique_inverse', 'unique_values',
552
- 'astype', 'std', 'var', 'cumulative_sum', 'clip', 'permute_dims',
577
+ 'std', 'var', 'cumulative_sum', 'cumulative_prod', 'clip', 'permute_dims',
553
578
  'reshape', 'argsort', 'sort', 'nonzero', 'ceil', 'floor', 'trunc',
554
579
  'matmul', 'matrix_transpose', 'tensordot', 'vecdot', 'isdtype',
555
580
  'unstack', 'sign']
@@ -3,7 +3,7 @@ from __future__ import annotations
3
3
  from typing import TYPE_CHECKING, Union, Optional, Literal
4
4
 
5
5
  if TYPE_CHECKING:
6
- from ._typing import Device, ndarray
6
+ from ._typing import Device, ndarray, DType
7
7
  from collections.abc import Sequence
8
8
 
9
9
  # Note: NumPy fft functions improperly upcast float32 and complex64 to
@@ -149,15 +149,37 @@ def ihfft(
149
149
  return res.astype(xp.complex64)
150
150
  return res
151
151
 
152
- def fftfreq(n: int, /, xp, *, d: float = 1.0, device: Optional[Device] = None) -> ndarray:
152
+ def fftfreq(
153
+ n: int,
154
+ /,
155
+ xp,
156
+ *,
157
+ d: float = 1.0,
158
+ dtype: Optional[DType] = None,
159
+ device: Optional[Device] = None
160
+ ) -> ndarray:
153
161
  if device not in ["cpu", None]:
154
162
  raise ValueError(f"Unsupported device {device!r}")
155
- return xp.fft.fftfreq(n, d=d)
163
+ res = xp.fft.fftfreq(n, d=d)
164
+ if dtype is not None:
165
+ return res.astype(dtype)
166
+ return res
156
167
 
157
- def rfftfreq(n: int, /, xp, *, d: float = 1.0, device: Optional[Device] = None) -> ndarray:
168
+ def rfftfreq(
169
+ n: int,
170
+ /,
171
+ xp,
172
+ *,
173
+ d: float = 1.0,
174
+ dtype: Optional[DType] = None,
175
+ device: Optional[Device] = None
176
+ ) -> ndarray:
158
177
  if device not in ["cpu", None]:
159
178
  raise ValueError(f"Unsupported device {device!r}")
160
- return xp.fft.rfftfreq(n, d=d)
179
+ res = xp.fft.rfftfreq(n, d=d)
180
+ if dtype is not None:
181
+ return res.astype(dtype)
182
+ return res
161
183
 
162
184
  def fftshift(x: ndarray, /, xp, *, axes: Union[int, Sequence[int]] = None) -> ndarray:
163
185
  return xp.fft.fftshift(x, axes=axes)
@@ -11,14 +11,14 @@ from typing import TYPE_CHECKING
11
11
 
12
12
  if TYPE_CHECKING:
13
13
  from typing import Optional, Union, Any
14
- from ._typing import Array, Device
14
+ from ._typing import Array, Device, Namespace
15
15
 
16
16
  import sys
17
17
  import math
18
18
  import inspect
19
19
  import warnings
20
20
 
21
- def _is_jax_zero_gradient_array(x):
21
+ def _is_jax_zero_gradient_array(x: object) -> bool:
22
22
  """Return True if `x` is a zero-gradient array.
23
23
 
24
24
  These arrays are a design quirk of Jax that may one day be removed.
@@ -32,7 +32,8 @@ def _is_jax_zero_gradient_array(x):
32
32
 
33
33
  return isinstance(x, np.ndarray) and x.dtype == jax.float0
34
34
 
35
- def is_numpy_array(x):
35
+
36
+ def is_numpy_array(x: object) -> bool:
36
37
  """
37
38
  Return True if `x` is a NumPy array.
38
39
 
@@ -63,7 +64,8 @@ def is_numpy_array(x):
63
64
  return (isinstance(x, (np.ndarray, np.generic))
64
65
  and not _is_jax_zero_gradient_array(x))
65
66
 
66
- def is_cupy_array(x):
67
+
68
+ def is_cupy_array(x: object) -> bool:
67
69
  """
68
70
  Return True if `x` is a CuPy array.
69
71
 
@@ -93,7 +95,8 @@ def is_cupy_array(x):
93
95
  # TODO: Should we reject ndarray subclasses?
94
96
  return isinstance(x, cp.ndarray)
95
97
 
96
- def is_torch_array(x):
98
+
99
+ def is_torch_array(x: object) -> bool:
97
100
  """
98
101
  Return True if `x` is a PyTorch tensor.
99
102
 
@@ -120,7 +123,8 @@ def is_torch_array(x):
120
123
  # TODO: Should we reject ndarray subclasses?
121
124
  return isinstance(x, torch.Tensor)
122
125
 
123
- def is_ndonnx_array(x):
126
+
127
+ def is_ndonnx_array(x: object) -> bool:
124
128
  """
125
129
  Return True if `x` is a ndonnx Array.
126
130
 
@@ -147,7 +151,8 @@ def is_ndonnx_array(x):
147
151
 
148
152
  return isinstance(x, ndx.Array)
149
153
 
150
- def is_dask_array(x):
154
+
155
+ def is_dask_array(x: object) -> bool:
151
156
  """
152
157
  Return True if `x` is a dask.array Array.
153
158
 
@@ -174,7 +179,8 @@ def is_dask_array(x):
174
179
 
175
180
  return isinstance(x, dask.array.Array)
176
181
 
177
- def is_jax_array(x):
182
+
183
+ def is_jax_array(x: object) -> bool:
178
184
  """
179
185
  Return True if `x` is a JAX array.
180
186
 
@@ -202,6 +208,7 @@ def is_jax_array(x):
202
208
 
203
209
  return isinstance(x, jax.Array) or _is_jax_zero_gradient_array(x)
204
210
 
211
+
205
212
  def is_pydata_sparse_array(x) -> bool:
206
213
  """
207
214
  Return True if `x` is an array from the `sparse` package.
@@ -231,7 +238,8 @@ def is_pydata_sparse_array(x) -> bool:
231
238
  # TODO: Account for other backends.
232
239
  return isinstance(x, sparse.SparseArray)
233
240
 
234
- def is_array_api_obj(x):
241
+
242
+ def is_array_api_obj(x: object) -> bool:
235
243
  """
236
244
  Return True if `x` is an array API compatible array object.
237
245
 
@@ -254,10 +262,12 @@ def is_array_api_obj(x):
254
262
  or is_pydata_sparse_array(x) \
255
263
  or hasattr(x, '__array_namespace__')
256
264
 
257
- def _compat_module_name():
265
+
266
+ def _compat_module_name() -> str:
258
267
  assert __name__.endswith('.common._helpers')
259
268
  return __name__.removesuffix('.common._helpers')
260
269
 
270
+
261
271
  def is_numpy_namespace(xp) -> bool:
262
272
  """
263
273
  Returns True if `xp` is a NumPy namespace.
@@ -278,6 +288,7 @@ def is_numpy_namespace(xp) -> bool:
278
288
  """
279
289
  return xp.__name__ in {'numpy', _compat_module_name() + '.numpy'}
280
290
 
291
+
281
292
  def is_cupy_namespace(xp) -> bool:
282
293
  """
283
294
  Returns True if `xp` is a CuPy namespace.
@@ -298,6 +309,7 @@ def is_cupy_namespace(xp) -> bool:
298
309
  """
299
310
  return xp.__name__ in {'cupy', _compat_module_name() + '.cupy'}
300
311
 
312
+
301
313
  def is_torch_namespace(xp) -> bool:
302
314
  """
303
315
  Returns True if `xp` is a PyTorch namespace.
@@ -319,7 +331,7 @@ def is_torch_namespace(xp) -> bool:
319
331
  return xp.__name__ in {'torch', _compat_module_name() + '.torch'}
320
332
 
321
333
 
322
- def is_ndonnx_namespace(xp):
334
+ def is_ndonnx_namespace(xp) -> bool:
323
335
  """
324
336
  Returns True if `xp` is an NDONNX namespace.
325
337
 
@@ -337,7 +349,8 @@ def is_ndonnx_namespace(xp):
337
349
  """
338
350
  return xp.__name__ == 'ndonnx'
339
351
 
340
- def is_dask_namespace(xp):
352
+
353
+ def is_dask_namespace(xp) -> bool:
341
354
  """
342
355
  Returns True if `xp` is a Dask namespace.
343
356
 
@@ -357,7 +370,8 @@ def is_dask_namespace(xp):
357
370
  """
358
371
  return xp.__name__ in {'dask.array', _compat_module_name() + '.dask.array'}
359
372
 
360
- def is_jax_namespace(xp):
373
+
374
+ def is_jax_namespace(xp) -> bool:
361
375
  """
362
376
  Returns True if `xp` is a JAX namespace.
363
377
 
@@ -378,7 +392,8 @@ def is_jax_namespace(xp):
378
392
  """
379
393
  return xp.__name__ in {'jax.numpy', 'jax.experimental.array_api'}
380
394
 
381
- def is_pydata_sparse_namespace(xp):
395
+
396
+ def is_pydata_sparse_namespace(xp) -> bool:
382
397
  """
383
398
  Returns True if `xp` is a pydata/sparse namespace.
384
399
 
@@ -396,7 +411,8 @@ def is_pydata_sparse_namespace(xp):
396
411
  """
397
412
  return xp.__name__ == 'sparse'
398
413
 
399
- def is_array_api_strict_namespace(xp):
414
+
415
+ def is_array_api_strict_namespace(xp) -> bool:
400
416
  """
401
417
  Returns True if `xp` is an array-api-strict namespace.
402
418
 
@@ -414,14 +430,16 @@ def is_array_api_strict_namespace(xp):
414
430
  """
415
431
  return xp.__name__ == 'array_api_strict'
416
432
 
417
- def _check_api_version(api_version):
418
- if api_version in ['2021.12', '2022.12']:
419
- warnings.warn(f"The {api_version} version of the array API specification was requested but the returned namespace is actually version 2023.12")
433
+
434
+ def _check_api_version(api_version: str) -> None:
435
+ if api_version in ['2021.12', '2022.12', '2023.12']:
436
+ warnings.warn(f"The {api_version} version of the array API specification was requested but the returned namespace is actually version 2024.12")
420
437
  elif api_version is not None and api_version not in ['2021.12', '2022.12',
421
- '2023.12']:
422
- raise ValueError("Only the 2023.12 version of the array API specification is currently supported")
438
+ '2023.12', '2024.12']:
439
+ raise ValueError("Only the 2024.12 version of the array API specification is currently supported")
423
440
 
424
- def array_namespace(*xs, api_version=None, use_compat=None):
441
+
442
+ def array_namespace(*xs, api_version=None, use_compat=None) -> Namespace:
425
443
  """
426
444
  Get the array API compatible namespace for the arrays `xs`.
427
445
 
@@ -433,7 +451,7 @@ def array_namespace(*xs, api_version=None, use_compat=None):
433
451
 
434
452
  api_version: str
435
453
  The newest version of the spec that you need support for (currently
436
- the compat library wrapped APIs support v2023.12).
454
+ the compat library wrapped APIs support v2024.12).
437
455
 
438
456
  use_compat: bool or None
439
457
  If None (the default), the native namespace will be returned if it is
@@ -630,24 +648,24 @@ def device(x: Array, /) -> Device:
630
648
  if is_numpy_array(x):
631
649
  return "cpu"
632
650
  elif is_dask_array(x):
633
- # Peek at the metadata of the jax array to determine type
634
- try:
635
- import numpy as np
636
- if isinstance(x._meta, np.ndarray):
637
- # Must be on CPU since backed by numpy
638
- return "cpu"
639
- except ImportError:
640
- pass
651
+ # Peek at the metadata of the Dask array to determine type
652
+ if is_numpy_array(x._meta):
653
+ # Must be on CPU since backed by numpy
654
+ return "cpu"
641
655
  return _DASK_DEVICE
642
656
  elif is_jax_array(x):
643
- # JAX has .device() as a method, but it is being deprecated so that it
644
- # can become a property, in accordance with the standard. In order for
645
- # this function to not break when JAX makes the flip, we check for
646
- # both here.
647
- if inspect.ismethod(x.device):
648
- return x.device()
657
+ # FIXME Jitted JAX arrays do not have a device attribute
658
+ # https://github.com/jax-ml/jax/issues/26000
659
+ # Return None in this case. Note that this workaround breaks
660
+ # the standard and will result in new arrays being created on the
661
+ # default device instead of the same device as the input array(s).
662
+ x_device = getattr(x, 'device', None)
663
+ # Older JAX releases had .device() as a method, which has been replaced
664
+ # with a property in accordance with the standard.
665
+ if inspect.ismethod(x_device):
666
+ return x_device()
649
667
  else:
650
- return x.device
668
+ return x_device
651
669
  elif is_pydata_sparse_array(x):
652
670
  # `sparse` will gain `.device`, so check for this first.
653
671
  x_device = getattr(x, 'device', None)
@@ -778,8 +796,11 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]]
778
796
  raise ValueError(f"Unsupported device {device!r}")
779
797
  elif is_jax_array(x):
780
798
  if not hasattr(x, "__array_namespace__"):
781
- # In JAX v0.4.31 and older, this import adds to_device method to x.
799
+ # In JAX v0.4.31 and older, this import adds to_device method to x...
782
800
  import jax.experimental.array_api # noqa: F401
801
+ # ... but only on eager JAX. It won't work inside jax.jit.
802
+ if not hasattr(x, "to_device"):
803
+ return x
783
804
  return x.to_device(device, stream=stream)
784
805
  elif is_pydata_sparse_array(x) and device == _device(x):
785
806
  # Perform trivial check to return the same array if
@@ -788,24 +809,30 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]]
788
809
  return x.to_device(device, stream=stream)
789
810
 
790
811
 
791
- def size(x):
812
+ def size(x: Array) -> int | None:
792
813
  """
793
814
  Return the total number of elements of x.
794
815
 
795
816
  This is equivalent to `x.size` according to the `standard
796
817
  <https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.size.html>`__.
818
+
797
819
  This helper is included because PyTorch defines `size` in an
798
820
  :external+torch:meth:`incompatible way <torch.Tensor.size>`.
799
-
821
+ It also fixes dask.array's behaviour, which returns NaN for unknown sizes, whereas
822
+ the standard requires None.
800
823
  """
824
+ # Lazy API compliant arrays, such as ndonnx, can contain None in their shape
801
825
  if None in x.shape:
802
826
  return None
803
- return math.prod(x.shape)
827
+ out = math.prod(x.shape)
828
+ # dask.array.Array.shape can contain NaN
829
+ return None if math.isnan(out) else out
804
830
 
805
831
 
806
- def is_writeable_array(x) -> bool:
832
+ def is_writeable_array(x: object) -> bool:
807
833
  """
808
834
  Return False if ``x.__setitem__`` is expected to raise; True otherwise.
835
+ Return False if `x` is not an array API compatible object.
809
836
 
810
837
  Warning
811
838
  -------
@@ -816,7 +843,67 @@ def is_writeable_array(x) -> bool:
816
843
  return x.flags.writeable
817
844
  if is_jax_array(x) or is_pydata_sparse_array(x):
818
845
  return False
819
- return True
846
+ return is_array_api_obj(x)
847
+
848
+
849
+ def is_lazy_array(x: object) -> bool:
850
+ """Return True if x is potentially a future or it may be otherwise impossible or
851
+ expensive to eagerly read its contents, regardless of their size, e.g. by
852
+ calling ``bool(x)`` or ``float(x)``.
853
+
854
+ Return False otherwise; e.g. ``bool(x)`` etc. is guaranteed to succeed and to be
855
+ cheap as long as the array has the right dtype and size.
856
+
857
+ Note
858
+ ----
859
+ This function errs on the side of caution for array types that may or may not be
860
+ lazy, e.g. JAX arrays, by always returning True for them.
861
+ """
862
+ if (
863
+ is_numpy_array(x)
864
+ or is_cupy_array(x)
865
+ or is_torch_array(x)
866
+ or is_pydata_sparse_array(x)
867
+ ):
868
+ return False
869
+
870
+ # **JAX note:** while it is possible to determine if you're inside or outside
871
+ # jax.jit by testing the subclass of a jax.Array object, as well as testing bool()
872
+ # as we do below for unknown arrays, this is not recommended by JAX best practices.
873
+
874
+ # **Dask note:** Dask eagerly computes the graph on __bool__, __float__, and so on.
875
+ # This behaviour, while impossible to change without breaking backwards
876
+ # compatibility, is highly detrimental to performance as the whole graph will end
877
+ # up being computed multiple times.
878
+
879
+ if is_jax_array(x) or is_dask_array(x) or is_ndonnx_array(x):
880
+ return True
881
+
882
+ if not is_array_api_obj(x):
883
+ return False
884
+
885
+ # Unknown Array API compatible object. Note that this test may have dire consequences
886
+ # in terms of performance, e.g. for a lazy object that eagerly computes the graph
887
+ # on __bool__ (dask is one such example, which however is special-cased above).
888
+
889
+ # Select a single point of the array
890
+ s = size(x)
891
+ if s is None:
892
+ return True
893
+ xp = array_namespace(x)
894
+ if s > 1:
895
+ x = xp.reshape(x, (-1,))[0]
896
+ # Cast to dtype=bool and deal with size 0 arrays
897
+ x = xp.any(x)
898
+
899
+ try:
900
+ bool(x)
901
+ return False
902
+ # The Array API standard dictates that __bool__ should raise TypeError if the
903
+ # output cannot be defined.
904
+ # Here we allow for it to raise arbitrary exceptions, e.g. like Dask does.
905
+ except Exception:
906
+ return True
820
907
 
821
908
 
822
909
  __all__ = [
@@ -840,6 +927,7 @@ __all__ = [
840
927
  "is_pydata_sparse_array",
841
928
  "is_pydata_sparse_namespace",
842
929
  "is_writeable_array",
930
+ "is_lazy_array",
843
931
  "size",
844
932
  "to_device",
845
933
  ]
@@ -5,6 +5,7 @@ __all__ = [
5
5
  "SupportsBufferProtocol",
6
6
  ]
7
7
 
8
+ from types import ModuleType
8
9
  from typing import (
9
10
  Any,
10
11
  TypeVar,
@@ -21,3 +22,5 @@ SupportsBufferProtocol = Any
21
22
 
22
23
  Array = Any
23
24
  Device = Any
25
+ DType = Any
26
+ Namespace = ModuleType
@@ -13,4 +13,4 @@ __import__(__package__ + '.fft')
13
13
 
14
14
  from ..common._helpers import * # noqa: F401,F403
15
15
 
16
- __array_api_version__ = '2023.12'
16
+ __array_api_version__ = '2024.12'
@@ -2,7 +2,7 @@ from __future__ import annotations
2
2
 
3
3
  import cupy as cp
4
4
 
5
- from ..common import _aliases
5
+ from ..common import _aliases, _helpers
6
6
  from .._internal import get_xp
7
7
 
8
8
  from ._info import __array_namespace_info__
@@ -46,10 +46,10 @@ unique_all = get_xp(cp)(_aliases.unique_all)
46
46
  unique_counts = get_xp(cp)(_aliases.unique_counts)
47
47
  unique_inverse = get_xp(cp)(_aliases.unique_inverse)
48
48
  unique_values = get_xp(cp)(_aliases.unique_values)
49
- astype = _aliases.astype
50
49
  std = get_xp(cp)(_aliases.std)
51
50
  var = get_xp(cp)(_aliases.var)
52
51
  cumulative_sum = get_xp(cp)(_aliases.cumulative_sum)
52
+ cumulative_prod = get_xp(cp)(_aliases.cumulative_prod)
53
53
  clip = get_xp(cp)(_aliases.clip)
54
54
  permute_dims = get_xp(cp)(_aliases.permute_dims)
55
55
  reshape = get_xp(cp)(_aliases.reshape)
@@ -110,6 +110,21 @@ def asarray(
110
110
 
111
111
  return cp.array(obj, dtype=dtype, **kwargs)
112
112
 
113
+
114
+ def astype(
115
+ x: ndarray,
116
+ dtype: Dtype,
117
+ /,
118
+ *,
119
+ copy: bool = True,
120
+ device: Optional[Device] = None,
121
+ ) -> ndarray:
122
+ if device is None:
123
+ return x.astype(dtype=dtype, copy=copy)
124
+ out = _helpers.to_device(x.astype(dtype=dtype, copy=False), device)
125
+ return out.copy() if copy and out is x else out
126
+
127
+
113
128
  # These functions are completely new here. If the library already has them
114
129
  # (i.e., numpy 2.0), use the library version instead of our wrapper.
115
130
  if hasattr(cp, 'vecdot'):
@@ -127,10 +142,10 @@ if hasattr(cp, 'unstack'):
127
142
  else:
128
143
  unstack = get_xp(cp)(_aliases.unstack)
129
144
 
130
- __all__ = _aliases.__all__ + ['__array_namespace_info__', 'asarray', 'bool',
145
+ __all__ = _aliases.__all__ + ['__array_namespace_info__', 'asarray', 'astype',
131
146
  'acos', 'acosh', 'asin', 'asinh', 'atan',
132
147
  'atan2', 'atanh', 'bitwise_left_shift',
133
148
  'bitwise_invert', 'bitwise_right_shift',
134
- 'concat', 'pow', 'sign']
149
+ 'bool', 'concat', 'pow', 'sign']
135
150
 
136
151
  _all_ignore = ['cp', 'get_xp']
@@ -101,7 +101,7 @@ class __array_namespace_info__:
101
101
  "boolean indexing": True,
102
102
  "data-dependent shapes": True,
103
103
  # 'max rank' will be part of the 2024.12 standard
104
- # "max rank": 64,
104
+ "max dimensions": 64,
105
105
  }
106
106
 
107
107
  def default_device(self):
@@ -3,7 +3,7 @@ from dask.array import * # noqa: F403
3
3
  # These imports may overwrite names from the import * above.
4
4
  from ._aliases import * # noqa: F403
5
5
 
6
- __array_api_version__ = '2023.12'
6
+ __array_api_version__ = '2024.12'
7
7
 
8
8
  __import__(__package__ + '.linalg')
9
9
  __import__(__package__ + '.fft')