array-api-compat 1.5__tar.gz → 1.6__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47)
  1. array_api_compat-1.6/PKG-INFO +35 -0
  2. array_api_compat-1.6/README.md +9 -0
  3. {array_api_compat-1.5 → array_api_compat-1.6}/array_api_compat/__init__.py +1 -1
  4. {array_api_compat-1.5 → array_api_compat-1.6}/array_api_compat/common/_aliases.py +4 -89
  5. {array_api_compat-1.5 → array_api_compat-1.6}/array_api_compat/common/_helpers.py +245 -32
  6. {array_api_compat-1.5 → array_api_compat-1.6}/array_api_compat/cupy/_aliases.py +51 -6
  7. {array_api_compat-1.5 → array_api_compat-1.6}/array_api_compat/dask/array/_aliases.py +42 -5
  8. {array_api_compat-1.5 → array_api_compat-1.6}/array_api_compat/dask/array/linalg.py +13 -4
  9. {array_api_compat-1.5 → array_api_compat-1.6}/array_api_compat/numpy/__init__.py +6 -0
  10. {array_api_compat-1.5 → array_api_compat-1.6}/array_api_compat/numpy/_aliases.py +56 -6
  11. {array_api_compat-1.5 → array_api_compat-1.6}/array_api_compat/torch/_aliases.py +4 -0
  12. array_api_compat-1.6/array_api_compat.egg-info/PKG-INFO +35 -0
  13. {array_api_compat-1.5 → array_api_compat-1.6}/array_api_compat.egg-info/SOURCES.txt +1 -0
  14. array_api_compat-1.6/array_api_compat.egg-info/requires.txt +15 -0
  15. {array_api_compat-1.5 → array_api_compat-1.6}/setup.py +3 -2
  16. {array_api_compat-1.5 → array_api_compat-1.6}/tests/test_all.py +2 -2
  17. {array_api_compat-1.5 → array_api_compat-1.6}/tests/test_array_namespace.py +21 -16
  18. array_api_compat-1.6/tests/test_common.py +184 -0
  19. {array_api_compat-1.5 → array_api_compat-1.6}/tests/test_isdtype.py +3 -3
  20. array_api_compat-1.6/tests/test_no_dependencies.py +73 -0
  21. array_api_compat-1.5/PKG-INFO +0 -421
  22. array_api_compat-1.5/README.md +0 -399
  23. array_api_compat-1.5/array_api_compat.egg-info/PKG-INFO +0 -421
  24. array_api_compat-1.5/array_api_compat.egg-info/requires.txt +0 -6
  25. array_api_compat-1.5/tests/test_common.py +0 -62
  26. {array_api_compat-1.5 → array_api_compat-1.6}/LICENSE +0 -0
  27. {array_api_compat-1.5 → array_api_compat-1.6}/array_api_compat/_internal.py +0 -0
  28. {array_api_compat-1.5 → array_api_compat-1.6}/array_api_compat/common/__init__.py +0 -0
  29. {array_api_compat-1.5 → array_api_compat-1.6}/array_api_compat/common/_fft.py +0 -0
  30. {array_api_compat-1.5 → array_api_compat-1.6}/array_api_compat/common/_linalg.py +0 -0
  31. {array_api_compat-1.5 → array_api_compat-1.6}/array_api_compat/common/_typing.py +0 -0
  32. {array_api_compat-1.5 → array_api_compat-1.6}/array_api_compat/cupy/__init__.py +0 -0
  33. {array_api_compat-1.5 → array_api_compat-1.6}/array_api_compat/cupy/_typing.py +0 -0
  34. {array_api_compat-1.5 → array_api_compat-1.6}/array_api_compat/cupy/fft.py +0 -0
  35. {array_api_compat-1.5 → array_api_compat-1.6}/array_api_compat/cupy/linalg.py +0 -0
  36. {array_api_compat-1.5 → array_api_compat-1.6}/array_api_compat/dask/__init__.py +0 -0
  37. {array_api_compat-1.5 → array_api_compat-1.6}/array_api_compat/dask/array/__init__.py +0 -0
  38. {array_api_compat-1.5 → array_api_compat-1.6}/array_api_compat/numpy/_typing.py +0 -0
  39. {array_api_compat-1.5 → array_api_compat-1.6}/array_api_compat/numpy/fft.py +0 -0
  40. {array_api_compat-1.5 → array_api_compat-1.6}/array_api_compat/numpy/linalg.py +0 -0
  41. {array_api_compat-1.5 → array_api_compat-1.6}/array_api_compat/torch/__init__.py +0 -0
  42. {array_api_compat-1.5 → array_api_compat-1.6}/array_api_compat/torch/fft.py +0 -0
  43. {array_api_compat-1.5 → array_api_compat-1.6}/array_api_compat/torch/linalg.py +0 -0
  44. {array_api_compat-1.5 → array_api_compat-1.6}/array_api_compat.egg-info/dependency_links.txt +0 -0
  45. {array_api_compat-1.5 → array_api_compat-1.6}/array_api_compat.egg-info/top_level.txt +0 -0
  46. {array_api_compat-1.5 → array_api_compat-1.6}/setup.cfg +0 -0
  47. {array_api_compat-1.5 → array_api_compat-1.6}/tests/test_vendoring.py +0 -0
@@ -0,0 +1,35 @@
1
+ Metadata-Version: 2.1
2
+ Name: array_api_compat
3
+ Version: 1.6
4
+ Summary: A wrapper around NumPy and other array libraries to make them compatible with the Array API standard
5
+ Home-page: https://data-apis.org/array-api-compat/
6
+ Author: Consortium for Python Data API Standards
7
+ License: MIT
8
+ Classifier: Programming Language :: Python :: 3
9
+ Classifier: Programming Language :: Python :: 3.9
10
+ Classifier: Programming Language :: Python :: 3.10
11
+ Classifier: Programming Language :: Python :: 3.11
12
+ Classifier: License :: OSI Approved :: MIT License
13
+ Classifier: Operating System :: OS Independent
14
+ Description-Content-Type: text/markdown
15
+ License-File: LICENSE
16
+ Provides-Extra: numpy
17
+ Requires-Dist: numpy; extra == "numpy"
18
+ Provides-Extra: cupy
19
+ Requires-Dist: cupy; extra == "cupy"
20
+ Provides-Extra: jax
21
+ Requires-Dist: jax; extra == "jax"
22
+ Provides-Extra: pytorch
23
+ Requires-Dist: pytorch; extra == "pytorch"
24
+ Provides-Extra: dask
25
+ Requires-Dist: dask; extra == "dask"
26
+
27
+ # Array API compatibility library
28
+
29
+ This is a small wrapper around common array libraries that is compatible with
30
+ the [Array API standard](https://data-apis.org/array-api/latest/). Currently,
31
+ NumPy, CuPy, PyTorch, Dask, and JAX are supported. If you want support for other array
32
+ libraries, or if you encounter any issues, please [open an
33
+ issue](https://github.com/data-apis/array-api-compat/issues).
34
+
35
+ See the documentation for more details https://data-apis.org/array-api-compat/
@@ -0,0 +1,9 @@
1
+ # Array API compatibility library
2
+
3
+ This is a small wrapper around common array libraries that is compatible with
4
+ the [Array API standard](https://data-apis.org/array-api/latest/). Currently,
5
+ NumPy, CuPy, PyTorch, Dask, and JAX are supported. If you want support for other array
6
+ libraries, or if you encounter any issues, please [open an
7
+ issue](https://github.com/data-apis/array-api-compat/issues).
8
+
9
+ See the documentation for more details https://data-apis.org/array-api-compat/
@@ -17,6 +17,6 @@ to ensure they are not using functionality outside of the standard, but prefer
17
17
  this implementation for the default when working with NumPy arrays.
18
18
 
19
19
  """
20
- __version__ = '1.5'
20
+ __version__ = '1.6'
21
21
 
22
22
  from .common import * # noqa: F401, F403
@@ -6,18 +6,18 @@ from __future__ import annotations
6
6
 
7
7
  from typing import TYPE_CHECKING
8
8
  if TYPE_CHECKING:
9
- import numpy as np
10
9
  from typing import Optional, Sequence, Tuple, Union
11
- from ._typing import ndarray, Device, Dtype, NestedSequence, SupportsBufferProtocol
10
+ from ._typing import ndarray, Device, Dtype
12
11
 
13
12
  from typing import NamedTuple
14
- from types import ModuleType
15
13
  import inspect
16
14
 
17
- from ._helpers import _check_device, is_numpy_array, array_namespace
15
+ from ._helpers import _check_device
18
16
 
19
17
  # These functions are modified from the NumPy versions.
20
18
 
19
+ # Creation functions add the device keyword (which does nothing for NumPy)
20
+
21
21
  def arange(
22
22
  start: Union[int, float],
23
23
  /,
@@ -268,91 +268,6 @@ def var(
268
268
  def permute_dims(x: ndarray, /, axes: Tuple[int, ...], xp) -> ndarray:
269
269
  return xp.transpose(x, axes)
270
270
 
271
- # Creation functions add the device keyword (which does nothing for NumPy)
272
-
273
- # asarray also adds the copy keyword
274
- def _asarray(
275
- obj: Union[
276
- ndarray,
277
- bool,
278
- int,
279
- float,
280
- NestedSequence[bool | int | float],
281
- SupportsBufferProtocol,
282
- ],
283
- /,
284
- *,
285
- dtype: Optional[Dtype] = None,
286
- device: Optional[Device] = None,
287
- copy: "Optional[Union[bool, np._CopyMode]]" = None,
288
- namespace = None,
289
- **kwargs,
290
- ) -> ndarray:
291
- """
292
- Array API compatibility wrapper for asarray().
293
-
294
- See the corresponding documentation in NumPy/CuPy and/or the array API
295
- specification for more details.
296
-
297
- """
298
- if namespace is None:
299
- try:
300
- xp = array_namespace(obj, _use_compat=False)
301
- except ValueError:
302
- # TODO: What about lists of arrays?
303
- raise ValueError("A namespace must be specified for asarray() with non-array input")
304
- elif isinstance(namespace, ModuleType):
305
- xp = namespace
306
- elif namespace == 'numpy':
307
- import numpy as xp
308
- elif namespace == 'cupy':
309
- import cupy as xp
310
- elif namespace == 'dask.array':
311
- import dask.array as xp
312
- else:
313
- raise ValueError("Unrecognized namespace argument to asarray()")
314
-
315
- _check_device(xp, device)
316
- if is_numpy_array(obj):
317
- import numpy as np
318
- if hasattr(np, '_CopyMode'):
319
- # Not present in older NumPys
320
- COPY_FALSE = (False, np._CopyMode.IF_NEEDED)
321
- COPY_TRUE = (True, np._CopyMode.ALWAYS)
322
- else:
323
- COPY_FALSE = (False,)
324
- COPY_TRUE = (True,)
325
- else:
326
- COPY_FALSE = (False,)
327
- COPY_TRUE = (True,)
328
- if copy in COPY_FALSE:
329
- # copy=False is not yet implemented in xp.asarray
330
- raise NotImplementedError("copy=False is not yet implemented")
331
- if (hasattr(xp, "ndarray") and isinstance(obj, xp.ndarray)) or hasattr(obj, "__array__"):
332
- #print('hit me')
333
- if dtype is not None and obj.dtype != dtype:
334
- copy = True
335
- #print(copy)
336
- if copy in COPY_TRUE:
337
- copy_kwargs = {}
338
- if namespace != "dask.array":
339
- copy_kwargs["copy"] = True
340
- else:
341
- # No copy kw in dask.asarray so we go thorugh np.asarray first
342
- # (like dask also does) but copy after
343
- if dtype is None:
344
- # Same dtype copy is no-op in dask
345
- #print("in here?")
346
- return obj.copy()
347
- import numpy as np
348
- #print(obj)
349
- obj = np.asarray(obj).copy()
350
- #print(obj)
351
- return xp.array(obj, dtype=dtype, **copy_kwargs)
352
- return obj
353
-
354
- return xp.asarray(obj, dtype=dtype, **kwargs)
355
-
356
271
  # np.reshape calls the keyword argument 'newshape' instead of 'shape'
357
272
  def reshape(x: ndarray,
358
273
  /,
@@ -19,6 +19,24 @@ import inspect
19
19
  import warnings
20
20
 
21
21
  def is_numpy_array(x):
22
+ """
23
+ Return True if `x` is a NumPy array.
24
+
25
+ This function does not import NumPy if it has not already been imported
26
+ and is therefore cheap to use.
27
+
28
+ This also returns True for `ndarray` subclasses and NumPy scalar objects.
29
+
30
+ See Also
31
+ --------
32
+
33
+ array_namespace
34
+ is_array_api_obj
35
+ is_cupy_array
36
+ is_torch_array
37
+ is_dask_array
38
+ is_jax_array
39
+ """
22
40
  # Avoid importing NumPy if it isn't already
23
41
  if 'numpy' not in sys.modules:
24
42
  return False
@@ -29,6 +47,24 @@ def is_numpy_array(x):
29
47
  return isinstance(x, (np.ndarray, np.generic))
30
48
 
31
49
  def is_cupy_array(x):
50
+ """
51
+ Return True if `x` is a CuPy array.
52
+
53
+ This function does not import CuPy if it has not already been imported
54
+ and is therefore cheap to use.
55
+
56
+ This also returns True for `cupy.ndarray` subclasses and CuPy scalar objects.
57
+
58
+ See Also
59
+ --------
60
+
61
+ array_namespace
62
+ is_array_api_obj
63
+ is_numpy_array
64
+ is_torch_array
65
+ is_dask_array
66
+ is_jax_array
67
+ """
32
68
  # Avoid importing NumPy if it isn't already
33
69
  if 'cupy' not in sys.modules:
34
70
  return False
@@ -39,6 +75,22 @@ def is_cupy_array(x):
39
75
  return isinstance(x, (cp.ndarray, cp.generic))
40
76
 
41
77
  def is_torch_array(x):
78
+ """
79
+ Return True if `x` is a PyTorch tensor.
80
+
81
+ This function does not import PyTorch if it has not already been imported
82
+ and is therefore cheap to use.
83
+
84
+ See Also
85
+ --------
86
+
87
+ array_namespace
88
+ is_array_api_obj
89
+ is_numpy_array
90
+ is_cupy_array
91
+ is_dask_array
92
+ is_jax_array
93
+ """
42
94
  # Avoid importing torch if it isn't already
43
95
  if 'torch' not in sys.modules:
44
96
  return False
@@ -49,6 +101,22 @@ def is_torch_array(x):
49
101
  return isinstance(x, torch.Tensor)
50
102
 
51
103
  def is_dask_array(x):
104
+ """
105
+ Return True if `x` is a dask.array Array.
106
+
107
+ This function does not import dask if it has not already been imported
108
+ and is therefore cheap to use.
109
+
110
+ See Also
111
+ --------
112
+
113
+ array_namespace
114
+ is_array_api_obj
115
+ is_numpy_array
116
+ is_cupy_array
117
+ is_torch_array
118
+ is_jax_array
119
+ """
52
120
  # Avoid importing dask if it isn't already
53
121
  if 'dask.array' not in sys.modules:
54
122
  return False
@@ -58,6 +126,23 @@ def is_dask_array(x):
58
126
  return isinstance(x, dask.array.Array)
59
127
 
60
128
  def is_jax_array(x):
129
+ """
130
+ Return True if `x` is a JAX array.
131
+
132
+ This function does not import JAX if it has not already been imported
133
+ and is therefore cheap to use.
134
+
135
+
136
+ See Also
137
+ --------
138
+
139
+ array_namespace
140
+ is_array_api_obj
141
+ is_numpy_array
142
+ is_cupy_array
143
+ is_torch_array
144
+ is_dask_array
145
+ """
61
146
  # Avoid importing jax if it isn't already
62
147
  if 'jax' not in sys.modules:
63
148
  return False
@@ -68,7 +153,17 @@ def is_jax_array(x):
68
153
 
69
154
  def is_array_api_obj(x):
70
155
  """
71
- Check if x is an array API compatible array object.
156
+ Return True if `x` is an array API compatible array object.
157
+
158
+ See Also
159
+ --------
160
+
161
+ array_namespace
162
+ is_numpy_array
163
+ is_cupy_array
164
+ is_torch_array
165
+ is_dask_array
166
+ is_jax_array
72
167
  """
73
168
  return is_numpy_array(x) \
74
169
  or is_cupy_array(x) \
@@ -83,62 +178,128 @@ def _check_api_version(api_version):
83
178
  elif api_version is not None and api_version != '2022.12':
84
179
  raise ValueError("Only the 2022.12 version of the array API specification is currently supported")
85
180
 
86
- def array_namespace(*xs, api_version=None, _use_compat=True):
181
+ def array_namespace(*xs, api_version=None, use_compat=None):
87
182
  """
88
183
  Get the array API compatible namespace for the arrays `xs`.
89
184
 
90
- `xs` should contain one or more arrays.
185
+ Parameters
186
+ ----------
187
+ xs: arrays
188
+ one or more arrays.
189
+
190
+ api_version: str
191
+ The newest version of the spec that you need support for (currently
192
+ the compat library wrapped APIs support v2022.12).
91
193
 
92
- Typical usage is
194
+ use_compat: bool or None
195
+ If None (the default), the native namespace will be returned if it is
196
+ already array API compatible, otherwise a compat wrapper is used. If
197
+ True, the compat library wrapped library will be returned. If False,
198
+ the native library namespace is returned.
93
199
 
94
- def your_function(x, y):
95
- xp = array_api_compat.array_namespace(x, y)
96
- # Now use xp as the array library namespace
97
- return xp.mean(x, axis=0) + 2*xp.std(y, axis=0)
200
+ Returns
201
+ -------
202
+
203
+ out: namespace
204
+ The array API compatible namespace corresponding to the arrays in `xs`.
205
+
206
+ Raises
207
+ ------
208
+ TypeError
209
+ If `xs` contains arrays from different array libraries or contains a
210
+ non-array.
211
+
212
+
213
+ Typical usage is to pass the arguments of a function to
214
+ `array_namespace()` at the top of a function to get the corresponding
215
+ array API namespace:
216
+
217
+ .. code:: python
218
+
219
+ def your_function(x, y):
220
+ xp = array_api_compat.array_namespace(x, y)
221
+ # Now use xp as the array library namespace
222
+ return xp.mean(x, axis=0) + 2*xp.std(y, axis=0)
223
+
224
+
225
+ Wrapped array namespaces can also be imported directly. For example,
226
+ `array_namespace(np.array(...))` will return `array_api_compat.numpy`.
227
+ This function will also work for any array library not wrapped by
228
+ array-api-compat if it explicitly defines `__array_namespace__
229
+ <https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__array_namespace__.html>`__
230
+ (the wrapped namespace is always preferred if it exists).
231
+
232
+ See Also
233
+ --------
234
+
235
+ is_array_api_obj
236
+ is_numpy_array
237
+ is_cupy_array
238
+ is_torch_array
239
+ is_dask_array
240
+ is_jax_array
98
241
 
99
- api_version should be the newest version of the spec that you need support
100
- for (currently the compat library wrapped APIs only support v2021.12).
101
242
  """
243
+ if use_compat not in [None, True, False]:
244
+ raise ValueError("use_compat must be None, True, or False")
245
+
246
+ _use_compat = use_compat in [None, True]
247
+
102
248
  namespaces = set()
103
249
  for x in xs:
104
250
  if is_numpy_array(x):
105
- _check_api_version(api_version)
106
- if _use_compat:
107
- from .. import numpy as numpy_namespace
251
+ from .. import numpy as numpy_namespace
252
+ import numpy as np
253
+ if use_compat is True:
254
+ _check_api_version(api_version)
108
255
  namespaces.add(numpy_namespace)
109
- else:
110
- import numpy as np
256
+ elif use_compat is False:
111
257
  namespaces.add(np)
258
+ else:
259
+ # numpy 2.0 has __array_namespace__ and is fully array API
260
+ # compatible.
261
+ if hasattr(x, '__array_namespace__'):
262
+ namespaces.add(x.__array_namespace__(api_version=api_version))
263
+ else:
264
+ namespaces.add(numpy_namespace)
112
265
  elif is_cupy_array(x):
113
- _check_api_version(api_version)
114
266
  if _use_compat:
267
+ _check_api_version(api_version)
115
268
  from .. import cupy as cupy_namespace
116
269
  namespaces.add(cupy_namespace)
117
270
  else:
118
271
  import cupy as cp
119
272
  namespaces.add(cp)
120
273
  elif is_torch_array(x):
121
- _check_api_version(api_version)
122
274
  if _use_compat:
275
+ _check_api_version(api_version)
123
276
  from .. import torch as torch_namespace
124
277
  namespaces.add(torch_namespace)
125
278
  else:
126
279
  import torch
127
280
  namespaces.add(torch)
128
281
  elif is_dask_array(x):
129
- _check_api_version(api_version)
130
282
  if _use_compat:
283
+ _check_api_version(api_version)
131
284
  from ..dask import array as dask_namespace
132
285
  namespaces.add(dask_namespace)
133
286
  else:
134
- raise TypeError("_use_compat cannot be False if input array is a dask array!")
287
+ import dask.array as da
288
+ namespaces.add(da)
135
289
  elif is_jax_array(x):
136
- _check_api_version(api_version)
137
- # jax.experimental.array_api is already an array namespace. We do
138
- # not have a wrapper submodule for it.
139
- import jax.experimental.array_api as jnp
290
+ if use_compat is True:
291
+ _check_api_version(api_version)
292
+ raise ValueError("JAX does not have an array-api-compat wrapper")
293
+ elif use_compat is False:
294
+ import jax.numpy as jnp
295
+ else:
296
+ # jax.experimental.array_api is already an array namespace. We do
297
+ # not have a wrapper submodule for it.
298
+ import jax.experimental.array_api as jnp
140
299
  namespaces.add(jnp)
141
300
  elif hasattr(x, '__array_namespace__'):
301
+ if use_compat is True:
302
+ raise ValueError("The given array does not have an array-api-compat wrapper")
142
303
  namespaces.add(x.__array_namespace__(api_version=api_version))
143
304
  else:
144
305
  # TODO: Support Python scalars?
@@ -181,15 +342,33 @@ def device(x: Array, /) -> Device:
181
342
  """
182
343
  Hardware device the array data resides on.
183
344
 
345
+ This is equivalent to `x.device` according to the `standard
346
+ <https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.device.html>`__.
347
+ This helper is included because some array libraries either do not have
348
+ the `device` attribute or include it with an incompatible API.
349
+
184
350
  Parameters
185
351
  ----------
186
352
  x: array
187
- array instance from NumPy or an array API compatible library.
353
+ array instance from an array API compatible library.
188
354
 
189
355
  Returns
190
356
  -------
191
357
  out: device
192
- a ``device`` object (see the "Device Support" section of the array API specification).
358
+ a ``device`` object (see the `Device Support <https://data-apis.org/array-api/latest/design_topics/device_support.html>`__
359
+ section of the array API specification).
360
+
361
+ Notes
362
+ -----
363
+
364
+ For NumPy the device is always `"cpu"`. For Dask, the device is always a
365
+ special `DASK_DEVICE` object.
366
+
367
+ See Also
368
+ --------
369
+
370
+ to_device : Move array data to a different device.
371
+
193
372
  """
194
373
  if is_numpy_array(x):
195
374
  return "cpu"
@@ -262,22 +441,50 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]]
262
441
  """
263
442
  Copy the array from the device on which it currently resides to the specified ``device``.
264
443
 
444
+ This is equivalent to `x.to_device(device, stream=stream)` according to
445
+ the `standard
446
+ <https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.to_device.html>`__.
447
+ This helper is included because some array libraries do not have the
448
+ `to_device` method.
449
+
265
450
  Parameters
266
451
  ----------
452
+
267
453
  x: array
268
- array instance from NumPy or an array API compatible library.
454
+ array instance from an array API compatible library.
455
+
269
456
  device: device
270
- a ``device`` object (see the "Device Support" section of the array API specification).
457
+ a ``device`` object (see the `Device Support <https://data-apis.org/array-api/latest/design_topics/device_support.html>`__
458
+ section of the array API specification).
459
+
271
460
  stream: Optional[Union[int, Any]]
272
- stream object to use during copy. In addition to the types supported in ``array.__dlpack__``, implementations may choose to support any library-specific stream object with the caveat that any code using such an object would not be portable.
461
+ stream object to use during copy. In addition to the types supported
462
+ in ``array.__dlpack__``, implementations may choose to support any
463
+ library-specific stream object with the caveat that any code using
464
+ such an object would not be portable.
273
465
 
274
466
  Returns
275
467
  -------
468
+
276
469
  out: array
277
- an array with the same data and data type as ``x`` and located on the specified ``device``.
470
+ an array with the same data and data type as ``x`` and located on the
471
+ specified ``device``.
472
+
473
+ Notes
474
+ -----
475
+
476
+ For NumPy, this function effectively does nothing since the only supported
477
+ device is the CPU. For CuPy, this method supports CuPy CUDA
478
+ :external+cupy:class:`Device <cupy.cuda.Device>` and
479
+ :external+cupy:class:`Stream <cupy.cuda.Stream>` objects. For PyTorch,
480
+ this is the same as :external+torch:meth:`x.to(device) <torch.Tensor.to>`
481
+ (the ``stream`` argument is not supported in PyTorch).
482
+
483
+ See Also
484
+ --------
485
+
486
+ device : Hardware device the array data resides on.
278
487
 
279
- .. note::
280
- If ``stream`` is given, the copy operation should be enqueued on the provided ``stream``; otherwise, the copy operation should be enqueued on the default stream/queue. Whether the copy is performed synchronously or asynchronously is implementation-dependent. Accordingly, if synchronization is required to guarantee data safety, this must be clearly explained in a conforming library's documentation.
281
488
  """
282
489
  if is_numpy_array(x):
283
490
  if stream is not None:
@@ -305,7 +512,13 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]]
305
512
 
306
513
  def size(x):
307
514
  """
308
- Return the total number of elements of x
515
+ Return the total number of elements of x.
516
+
517
+ This is equivalent to `x.size` according to the `standard
518
+ <https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.size.html>`__.
519
+ This helper is included because PyTorch defines `size` in an
520
+ :external+torch:meth:`incompatible way <torch.Tensor.size>`.
521
+
309
522
  """
310
523
  if None in x.shape:
311
524
  return None
@@ -1,15 +1,14 @@
1
1
  from __future__ import annotations
2
2
 
3
- from functools import partial
4
-
5
3
  import cupy as cp
6
4
 
7
5
  from ..common import _aliases
8
6
  from .._internal import get_xp
9
7
 
10
- asarray = asarray_cupy = partial(_aliases._asarray, namespace='cupy')
11
- asarray.__doc__ = _aliases._asarray.__doc__
12
- del partial
8
+ from typing import TYPE_CHECKING
9
+ if TYPE_CHECKING:
10
+ from typing import Optional, Union
11
+ from ._typing import ndarray, Device, Dtype, NestedSequence, SupportsBufferProtocol
13
12
 
14
13
  bool = cp.bool_
15
14
 
@@ -62,6 +61,52 @@ matmul = get_xp(cp)(_aliases.matmul)
62
61
  matrix_transpose = get_xp(cp)(_aliases.matrix_transpose)
63
62
  tensordot = get_xp(cp)(_aliases.tensordot)
64
63
 
64
+ _copy_default = object()
65
+
66
+ # asarray also adds the copy keyword, which is not present in numpy 1.0.
67
+ def asarray(
68
+ obj: Union[
69
+ ndarray,
70
+ bool,
71
+ int,
72
+ float,
73
+ NestedSequence[bool | int | float],
74
+ SupportsBufferProtocol,
75
+ ],
76
+ /,
77
+ *,
78
+ dtype: Optional[Dtype] = None,
79
+ device: Optional[Device] = None,
80
+ copy: Optional[bool] = _copy_default,
81
+ **kwargs,
82
+ ) -> ndarray:
83
+ """
84
+ Array API compatibility wrapper for asarray().
85
+
86
+ See the corresponding documentation in the array library and/or the array API
87
+ specification for more details.
88
+ """
89
+ with cp.cuda.Device(device):
90
+ # cupy is like NumPy 1.26 (except without _CopyMode). See the comments
91
+ # in asarray in numpy/_aliases.py.
92
+ if copy is not _copy_default:
93
+ # A future version of CuPy will change the meaning of copy=False
94
+ # to mean no-copy. We don't know for certain what version it will
95
+ # be yet, so to avoid breaking that version, we use a different
96
+ # default value for copy so asarray(obj) with no copy kwarg will
97
+ # always do the copy-if-needed behavior.
98
+
99
+ # This will still need to be updated to remove the
100
+ # NotImplementedError for copy=False, but at least this won't
101
+ # break the default or existing behavior.
102
+ if copy is None:
103
+ copy = False
104
+ elif copy is False:
105
+ raise NotImplementedError("asarray(copy=False) is not yet supported in cupy")
106
+ kwargs['copy'] = copy
107
+
108
+ return cp.array(obj, dtype=dtype, **kwargs)
109
+
65
110
  # These functions are completely new here. If the library already has them
66
111
  # (i.e., numpy 2.0), use the library version instead of our wrapper.
67
112
  if hasattr(cp, 'vecdot'):
@@ -73,7 +118,7 @@ if hasattr(cp, 'isdtype'):
73
118
  else:
74
119
  isdtype = get_xp(cp)(_aliases.isdtype)
75
120
 
76
- __all__ = _aliases.__all__ + ['asarray', 'asarray_cupy', 'bool', 'acos',
121
+ __all__ = _aliases.__all__ + ['asarray', 'bool', 'acos',
77
122
  'acosh', 'asin', 'asinh', 'atan', 'atan2',
78
123
  'atanh', 'bitwise_left_shift', 'bitwise_invert',
79
124
  'bitwise_right_shift', 'concat', 'pow']