array-api-compat 1.5__tar.gz → 1.6__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- array_api_compat-1.6/PKG-INFO +35 -0
- array_api_compat-1.6/README.md +9 -0
- {array_api_compat-1.5 → array_api_compat-1.6}/array_api_compat/__init__.py +1 -1
- {array_api_compat-1.5 → array_api_compat-1.6}/array_api_compat/common/_aliases.py +4 -89
- {array_api_compat-1.5 → array_api_compat-1.6}/array_api_compat/common/_helpers.py +245 -32
- {array_api_compat-1.5 → array_api_compat-1.6}/array_api_compat/cupy/_aliases.py +51 -6
- {array_api_compat-1.5 → array_api_compat-1.6}/array_api_compat/dask/array/_aliases.py +42 -5
- {array_api_compat-1.5 → array_api_compat-1.6}/array_api_compat/dask/array/linalg.py +13 -4
- {array_api_compat-1.5 → array_api_compat-1.6}/array_api_compat/numpy/__init__.py +6 -0
- {array_api_compat-1.5 → array_api_compat-1.6}/array_api_compat/numpy/_aliases.py +56 -6
- {array_api_compat-1.5 → array_api_compat-1.6}/array_api_compat/torch/_aliases.py +4 -0
- array_api_compat-1.6/array_api_compat.egg-info/PKG-INFO +35 -0
- {array_api_compat-1.5 → array_api_compat-1.6}/array_api_compat.egg-info/SOURCES.txt +1 -0
- array_api_compat-1.6/array_api_compat.egg-info/requires.txt +15 -0
- {array_api_compat-1.5 → array_api_compat-1.6}/setup.py +3 -2
- {array_api_compat-1.5 → array_api_compat-1.6}/tests/test_all.py +2 -2
- {array_api_compat-1.5 → array_api_compat-1.6}/tests/test_array_namespace.py +21 -16
- array_api_compat-1.6/tests/test_common.py +184 -0
- {array_api_compat-1.5 → array_api_compat-1.6}/tests/test_isdtype.py +3 -3
- array_api_compat-1.6/tests/test_no_dependencies.py +73 -0
- array_api_compat-1.5/PKG-INFO +0 -421
- array_api_compat-1.5/README.md +0 -399
- array_api_compat-1.5/array_api_compat.egg-info/PKG-INFO +0 -421
- array_api_compat-1.5/array_api_compat.egg-info/requires.txt +0 -6
- array_api_compat-1.5/tests/test_common.py +0 -62
- {array_api_compat-1.5 → array_api_compat-1.6}/LICENSE +0 -0
- {array_api_compat-1.5 → array_api_compat-1.6}/array_api_compat/_internal.py +0 -0
- {array_api_compat-1.5 → array_api_compat-1.6}/array_api_compat/common/__init__.py +0 -0
- {array_api_compat-1.5 → array_api_compat-1.6}/array_api_compat/common/_fft.py +0 -0
- {array_api_compat-1.5 → array_api_compat-1.6}/array_api_compat/common/_linalg.py +0 -0
- {array_api_compat-1.5 → array_api_compat-1.6}/array_api_compat/common/_typing.py +0 -0
- {array_api_compat-1.5 → array_api_compat-1.6}/array_api_compat/cupy/__init__.py +0 -0
- {array_api_compat-1.5 → array_api_compat-1.6}/array_api_compat/cupy/_typing.py +0 -0
- {array_api_compat-1.5 → array_api_compat-1.6}/array_api_compat/cupy/fft.py +0 -0
- {array_api_compat-1.5 → array_api_compat-1.6}/array_api_compat/cupy/linalg.py +0 -0
- {array_api_compat-1.5 → array_api_compat-1.6}/array_api_compat/dask/__init__.py +0 -0
- {array_api_compat-1.5 → array_api_compat-1.6}/array_api_compat/dask/array/__init__.py +0 -0
- {array_api_compat-1.5 → array_api_compat-1.6}/array_api_compat/numpy/_typing.py +0 -0
- {array_api_compat-1.5 → array_api_compat-1.6}/array_api_compat/numpy/fft.py +0 -0
- {array_api_compat-1.5 → array_api_compat-1.6}/array_api_compat/numpy/linalg.py +0 -0
- {array_api_compat-1.5 → array_api_compat-1.6}/array_api_compat/torch/__init__.py +0 -0
- {array_api_compat-1.5 → array_api_compat-1.6}/array_api_compat/torch/fft.py +0 -0
- {array_api_compat-1.5 → array_api_compat-1.6}/array_api_compat/torch/linalg.py +0 -0
- {array_api_compat-1.5 → array_api_compat-1.6}/array_api_compat.egg-info/dependency_links.txt +0 -0
- {array_api_compat-1.5 → array_api_compat-1.6}/array_api_compat.egg-info/top_level.txt +0 -0
- {array_api_compat-1.5 → array_api_compat-1.6}/setup.cfg +0 -0
- {array_api_compat-1.5 → array_api_compat-1.6}/tests/test_vendoring.py +0 -0
|
@@ -0,0 +1,35 @@
|
|
|
1
|
+
Metadata-Version: 2.1
|
|
2
|
+
Name: array_api_compat
|
|
3
|
+
Version: 1.6
|
|
4
|
+
Summary: A wrapper around NumPy and other array libraries to make them compatible with the Array API standard
|
|
5
|
+
Home-page: https://data-apis.org/array-api-compat/
|
|
6
|
+
Author: Consortium for Python Data API Standards
|
|
7
|
+
License: MIT
|
|
8
|
+
Classifier: Programming Language :: Python :: 3
|
|
9
|
+
Classifier: Programming Language :: Python :: 3.9
|
|
10
|
+
Classifier: Programming Language :: Python :: 3.10
|
|
11
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
12
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
13
|
+
Classifier: Operating System :: OS Independent
|
|
14
|
+
Description-Content-Type: text/markdown
|
|
15
|
+
License-File: LICENSE
|
|
16
|
+
Provides-Extra: numpy
|
|
17
|
+
Requires-Dist: numpy; extra == "numpy"
|
|
18
|
+
Provides-Extra: cupy
|
|
19
|
+
Requires-Dist: cupy; extra == "cupy"
|
|
20
|
+
Provides-Extra: jax
|
|
21
|
+
Requires-Dist: jax; extra == "jax"
|
|
22
|
+
Provides-Extra: pytorch
|
|
23
|
+
Requires-Dist: pytorch; extra == "pytorch"
|
|
24
|
+
Provides-Extra: dask
|
|
25
|
+
Requires-Dist: dask; extra == "dask"
|
|
26
|
+
|
|
27
|
+
# Array API compatibility library
|
|
28
|
+
|
|
29
|
+
This is a small wrapper around common array libraries that is compatible with
|
|
30
|
+
the [Array API standard](https://data-apis.org/array-api/latest/). Currently,
|
|
31
|
+
NumPy, CuPy, PyTorch, Dask, and JAX are supported. If you want support for other array
|
|
32
|
+
libraries, or if you encounter any issues, please [open an
|
|
33
|
+
issue](https://github.com/data-apis/array-api-compat/issues).
|
|
34
|
+
|
|
35
|
+
See the documentation for more details: https://data-apis.org/array-api-compat/
|
|
@@ -0,0 +1,9 @@
|
|
|
1
|
+
# Array API compatibility library
|
|
2
|
+
|
|
3
|
+
This is a small wrapper around common array libraries that is compatible with
|
|
4
|
+
the [Array API standard](https://data-apis.org/array-api/latest/). Currently,
|
|
5
|
+
NumPy, CuPy, PyTorch, Dask, and JAX are supported. If you want support for other array
|
|
6
|
+
libraries, or if you encounter any issues, please [open an
|
|
7
|
+
issue](https://github.com/data-apis/array-api-compat/issues).
|
|
8
|
+
|
|
9
|
+
See the documentation for more details: https://data-apis.org/array-api-compat/
|
|
@@ -6,18 +6,18 @@ from __future__ import annotations
|
|
|
6
6
|
|
|
7
7
|
from typing import TYPE_CHECKING
|
|
8
8
|
if TYPE_CHECKING:
|
|
9
|
-
import numpy as np
|
|
10
9
|
from typing import Optional, Sequence, Tuple, Union
|
|
11
|
-
from ._typing import ndarray, Device, Dtype
|
|
10
|
+
from ._typing import ndarray, Device, Dtype
|
|
12
11
|
|
|
13
12
|
from typing import NamedTuple
|
|
14
|
-
from types import ModuleType
|
|
15
13
|
import inspect
|
|
16
14
|
|
|
17
|
-
from ._helpers import _check_device
|
|
15
|
+
from ._helpers import _check_device
|
|
18
16
|
|
|
19
17
|
# These functions are modified from the NumPy versions.
|
|
20
18
|
|
|
19
|
+
# Creation functions add the device keyword (which does nothing for NumPy)
|
|
20
|
+
|
|
21
21
|
def arange(
|
|
22
22
|
start: Union[int, float],
|
|
23
23
|
/,
|
|
@@ -268,91 +268,6 @@ def var(
|
|
|
268
268
|
def permute_dims(x: ndarray, /, axes: Tuple[int, ...], xp) -> ndarray:
    """
    Reorder the axes of ``x``.

    Implements the array API ``permute_dims`` by delegating to the wrapped
    namespace's ``transpose`` with the explicit axis order ``axes``.
    """
    permuted = xp.transpose(x, axes)
    return permuted
|
|
270
270
|
|
|
271
|
-
# Creation functions add the device keyword (which does nothing for NumPy)
|
|
272
|
-
|
|
273
|
-
# asarray also adds the copy keyword
|
|
274
|
-
def _asarray(
    obj: Union[
        ndarray,
        bool,
        int,
        float,
        NestedSequence[bool | int | float],
        SupportsBufferProtocol,
    ],
    /,
    *,
    dtype: Optional[Dtype] = None,
    device: Optional[Device] = None,
    copy: "Optional[Union[bool, np._CopyMode]]" = None,
    namespace = None,
    **kwargs,
) -> ndarray:
    """
    Array API compatibility wrapper for asarray().

    See the corresponding documentation in NumPy/CuPy and/or the array API
    specification for more details.

    Parameters
    ----------
    obj
        Object to convert to an array.
    dtype
        Requested dtype. If ``obj`` is already an array of a different
        dtype, a copy is forced.
    device
        Validated against the resolved namespace with ``_check_device``.
    copy
        ``True`` (or ``np._CopyMode.ALWAYS`` for NumPy input) forces a copy.
        ``False`` (or ``np._CopyMode.IF_NEEDED`` for NumPy input) raises
        ``NotImplementedError``. ``None`` converts without forcing a copy.
    namespace
        ``None`` to infer the namespace from ``obj``, a module object, or
        one of the strings ``'numpy'``, ``'cupy'``, ``'dask.array'``.
    **kwargs
        Forwarded to ``xp.asarray`` on the non-copy path.

    Raises
    ------
    ValueError
        If ``namespace`` is unrecognized, or is None and cannot be inferred.
    NotImplementedError
        If ``copy`` requests a no-copy conversion.
    """
    if namespace is None:
        try:
            xp = array_namespace(obj, _use_compat=False)
        except ValueError:
            # TODO: What about lists of arrays?
            # NOTE(review): array_namespace documents TypeError for non-array
            # input, so this handler may not trigger -- confirm.
            raise ValueError("A namespace must be specified for asarray() with non-array input")
    elif isinstance(namespace, ModuleType):
        xp = namespace
    elif namespace == 'numpy':
        import numpy as xp
    elif namespace == 'cupy':
        import cupy as xp
    elif namespace == 'dask.array':
        import dask.array as xp
    else:
        raise ValueError("Unrecognized namespace argument to asarray()")

    _check_device(xp, device)

    if is_numpy_array(obj):
        import numpy as np
        if hasattr(np, '_CopyMode'):
            # Not present in older NumPys
            COPY_FALSE = (False, np._CopyMode.IF_NEEDED)
            COPY_TRUE = (True, np._CopyMode.ALWAYS)
        else:
            COPY_FALSE = (False,)
            COPY_TRUE = (True,)
    else:
        COPY_FALSE = (False,)
        COPY_TRUE = (True,)

    if copy in COPY_FALSE:
        # copy=False is not yet implemented in xp.asarray
        raise NotImplementedError("copy=False is not yet implemented")

    if (hasattr(xp, "ndarray") and isinstance(obj, xp.ndarray)) or hasattr(obj, "__array__"):
        if dtype is not None and obj.dtype != dtype:
            copy = True
        if copy in COPY_TRUE:
            # Decide on the resolved module, not the ``namespace`` argument, so
            # dask is also handled when passed as a module or inferred from obj.
            if getattr(xp, '__name__', None) == 'dask.array':
                # dask.array.array() has no copy keyword, so go through
                # np.asarray first (like dask also does) and copy there.
                if dtype is None and isinstance(obj, xp.Array):
                    return obj.copy()
                import numpy as np
                obj = np.asarray(obj).copy()
                return xp.array(obj, dtype=dtype)
            return xp.array(obj, dtype=dtype, copy=True)
        return obj

    return xp.asarray(obj, dtype=dtype, **kwargs)
|
|
355
|
-
|
|
356
271
|
# np.reshape calls the keyword argument 'newshape' instead of 'shape'
|
|
357
272
|
def reshape(x: ndarray,
|
|
358
273
|
/,
|
|
@@ -19,6 +19,24 @@ import inspect
|
|
|
19
19
|
import warnings
|
|
20
20
|
|
|
21
21
|
def is_numpy_array(x):
|
|
22
|
+
"""
|
|
23
|
+
Return True if `x` is a NumPy array.
|
|
24
|
+
|
|
25
|
+
This function does not import NumPy if it has not already been imported
|
|
26
|
+
and is therefore cheap to use.
|
|
27
|
+
|
|
28
|
+
This also returns True for `ndarray` subclasses and NumPy scalar objects.
|
|
29
|
+
|
|
30
|
+
See Also
|
|
31
|
+
--------
|
|
32
|
+
|
|
33
|
+
array_namespace
|
|
34
|
+
is_array_api_obj
|
|
35
|
+
is_cupy_array
|
|
36
|
+
is_torch_array
|
|
37
|
+
is_dask_array
|
|
38
|
+
is_jax_array
|
|
39
|
+
"""
|
|
22
40
|
# Avoid importing NumPy if it isn't already
|
|
23
41
|
if 'numpy' not in sys.modules:
|
|
24
42
|
return False
|
|
@@ -29,6 +47,24 @@ def is_numpy_array(x):
|
|
|
29
47
|
return isinstance(x, (np.ndarray, np.generic))
|
|
30
48
|
|
|
31
49
|
def is_cupy_array(x):
|
|
50
|
+
"""
|
|
51
|
+
Return True if `x` is a CuPy array.
|
|
52
|
+
|
|
53
|
+
This function does not import CuPy if it has not already been imported
|
|
54
|
+
and is therefore cheap to use.
|
|
55
|
+
|
|
56
|
+
This also returns True for `cupy.ndarray` subclasses and CuPy scalar objects.
|
|
57
|
+
|
|
58
|
+
See Also
|
|
59
|
+
--------
|
|
60
|
+
|
|
61
|
+
array_namespace
|
|
62
|
+
is_array_api_obj
|
|
63
|
+
is_numpy_array
|
|
64
|
+
is_torch_array
|
|
65
|
+
is_dask_array
|
|
66
|
+
is_jax_array
|
|
67
|
+
"""
|
|
32
68
|
# Avoid importing NumPy if it isn't already
|
|
33
69
|
if 'cupy' not in sys.modules:
|
|
34
70
|
return False
|
|
@@ -39,6 +75,22 @@ def is_cupy_array(x):
|
|
|
39
75
|
return isinstance(x, (cp.ndarray, cp.generic))
|
|
40
76
|
|
|
41
77
|
def is_torch_array(x):
|
|
78
|
+
"""
|
|
79
|
+
Return True if `x` is a PyTorch tensor.
|
|
80
|
+
|
|
81
|
+
This function does not import PyTorch if it has not already been imported
|
|
82
|
+
and is therefore cheap to use.
|
|
83
|
+
|
|
84
|
+
See Also
|
|
85
|
+
--------
|
|
86
|
+
|
|
87
|
+
array_namespace
|
|
88
|
+
is_array_api_obj
|
|
89
|
+
is_numpy_array
|
|
90
|
+
is_cupy_array
|
|
91
|
+
is_dask_array
|
|
92
|
+
is_jax_array
|
|
93
|
+
"""
|
|
42
94
|
# Avoid importing torch if it isn't already
|
|
43
95
|
if 'torch' not in sys.modules:
|
|
44
96
|
return False
|
|
@@ -49,6 +101,22 @@ def is_torch_array(x):
|
|
|
49
101
|
return isinstance(x, torch.Tensor)
|
|
50
102
|
|
|
51
103
|
def is_dask_array(x):
|
|
104
|
+
"""
|
|
105
|
+
Return True if `x` is a dask.array Array.
|
|
106
|
+
|
|
107
|
+
This function does not import dask if it has not already been imported
|
|
108
|
+
and is therefore cheap to use.
|
|
109
|
+
|
|
110
|
+
See Also
|
|
111
|
+
--------
|
|
112
|
+
|
|
113
|
+
array_namespace
|
|
114
|
+
is_array_api_obj
|
|
115
|
+
is_numpy_array
|
|
116
|
+
is_cupy_array
|
|
117
|
+
is_torch_array
|
|
118
|
+
is_jax_array
|
|
119
|
+
"""
|
|
52
120
|
# Avoid importing dask if it isn't already
|
|
53
121
|
if 'dask.array' not in sys.modules:
|
|
54
122
|
return False
|
|
@@ -58,6 +126,23 @@ def is_dask_array(x):
|
|
|
58
126
|
return isinstance(x, dask.array.Array)
|
|
59
127
|
|
|
60
128
|
def is_jax_array(x):
|
|
129
|
+
"""
|
|
130
|
+
Return True if `x` is a JAX array.
|
|
131
|
+
|
|
132
|
+
This function does not import JAX if it has not already been imported
|
|
133
|
+
and is therefore cheap to use.
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
See Also
|
|
137
|
+
--------
|
|
138
|
+
|
|
139
|
+
array_namespace
|
|
140
|
+
is_array_api_obj
|
|
141
|
+
is_numpy_array
|
|
142
|
+
is_cupy_array
|
|
143
|
+
is_torch_array
|
|
144
|
+
is_dask_array
|
|
145
|
+
"""
|
|
61
146
|
# Avoid importing jax if it isn't already
|
|
62
147
|
if 'jax' not in sys.modules:
|
|
63
148
|
return False
|
|
@@ -68,7 +153,17 @@ def is_jax_array(x):
|
|
|
68
153
|
|
|
69
154
|
def is_array_api_obj(x):
|
|
70
155
|
"""
|
|
71
|
-
|
|
156
|
+
Return True if `x` is an array API compatible array object.
|
|
157
|
+
|
|
158
|
+
See Also
|
|
159
|
+
--------
|
|
160
|
+
|
|
161
|
+
array_namespace
|
|
162
|
+
is_numpy_array
|
|
163
|
+
is_cupy_array
|
|
164
|
+
is_torch_array
|
|
165
|
+
is_dask_array
|
|
166
|
+
is_jax_array
|
|
72
167
|
"""
|
|
73
168
|
return is_numpy_array(x) \
|
|
74
169
|
or is_cupy_array(x) \
|
|
@@ -83,62 +178,128 @@ def _check_api_version(api_version):
|
|
|
83
178
|
elif api_version is not None and api_version != '2022.12':
|
|
84
179
|
raise ValueError("Only the 2022.12 version of the array API specification is currently supported")
|
|
85
180
|
|
|
86
|
-
def array_namespace(*xs, api_version=None,
|
|
181
|
+
def array_namespace(*xs, api_version=None, use_compat=None):
|
|
87
182
|
"""
|
|
88
183
|
Get the array API compatible namespace for the arrays `xs`.
|
|
89
184
|
|
|
90
|
-
|
|
185
|
+
Parameters
|
|
186
|
+
----------
|
|
187
|
+
xs: arrays
|
|
188
|
+
one or more arrays.
|
|
189
|
+
|
|
190
|
+
api_version: str
|
|
191
|
+
The newest version of the spec that you need support for (currently
|
|
192
|
+
the compat library wrapped APIs support v2022.12).
|
|
91
193
|
|
|
92
|
-
|
|
194
|
+
use_compat: bool or None
|
|
195
|
+
If None (the default), the native namespace will be returned if it is
|
|
196
|
+
already array API compatible, otherwise a compat wrapper is used. If
|
|
197
|
+
True, the compat library wrapped library will be returned. If False,
|
|
198
|
+
the native library namespace is returned.
|
|
93
199
|
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
200
|
+
Returns
|
|
201
|
+
-------
|
|
202
|
+
|
|
203
|
+
out: namespace
|
|
204
|
+
The array API compatible namespace corresponding to the arrays in `xs`.
|
|
205
|
+
|
|
206
|
+
Raises
|
|
207
|
+
------
|
|
208
|
+
TypeError
|
|
209
|
+
If `xs` contains arrays from different array libraries or contains a
|
|
210
|
+
non-array.
|
|
211
|
+
|
|
212
|
+
|
|
213
|
+
Typical usage is to pass the arguments of a function to
|
|
214
|
+
`array_namespace()` at the top of a function to get the corresponding
|
|
215
|
+
array API namespace:
|
|
216
|
+
|
|
217
|
+
.. code:: python
|
|
218
|
+
|
|
219
|
+
def your_function(x, y):
|
|
220
|
+
xp = array_api_compat.array_namespace(x, y)
|
|
221
|
+
# Now use xp as the array library namespace
|
|
222
|
+
return xp.mean(x, axis=0) + 2*xp.std(y, axis=0)
|
|
223
|
+
|
|
224
|
+
|
|
225
|
+
Wrapped array namespaces can also be imported directly. For example,
|
|
226
|
+
`array_namespace(np.array(...))` will return `array_api_compat.numpy`.
|
|
227
|
+
This function will also work for any array library not wrapped by
|
|
228
|
+
array-api-compat if it explicitly defines `__array_namespace__
|
|
229
|
+
<https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__array_namespace__.html>`__
|
|
230
|
+
(the wrapped namespace is always preferred if it exists).
|
|
231
|
+
|
|
232
|
+
See Also
|
|
233
|
+
--------
|
|
234
|
+
|
|
235
|
+
is_array_api_obj
|
|
236
|
+
is_numpy_array
|
|
237
|
+
is_cupy_array
|
|
238
|
+
is_torch_array
|
|
239
|
+
is_dask_array
|
|
240
|
+
is_jax_array
|
|
98
241
|
|
|
99
|
-
api_version should be the newest version of the spec that you need support
|
|
100
|
-
for (currently the compat library wrapped APIs only support v2021.12).
|
|
101
242
|
"""
|
|
243
|
+
if use_compat not in [None, True, False]:
|
|
244
|
+
raise ValueError("use_compat must be None, True, or False")
|
|
245
|
+
|
|
246
|
+
_use_compat = use_compat in [None, True]
|
|
247
|
+
|
|
102
248
|
namespaces = set()
|
|
103
249
|
for x in xs:
|
|
104
250
|
if is_numpy_array(x):
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
251
|
+
from .. import numpy as numpy_namespace
|
|
252
|
+
import numpy as np
|
|
253
|
+
if use_compat is True:
|
|
254
|
+
_check_api_version(api_version)
|
|
108
255
|
namespaces.add(numpy_namespace)
|
|
109
|
-
|
|
110
|
-
import numpy as np
|
|
256
|
+
elif use_compat is False:
|
|
111
257
|
namespaces.add(np)
|
|
258
|
+
else:
|
|
259
|
+
# numpy 2.0 has __array_namespace__ and is fully array API
|
|
260
|
+
# compatible.
|
|
261
|
+
if hasattr(x, '__array_namespace__'):
|
|
262
|
+
namespaces.add(x.__array_namespace__(api_version=api_version))
|
|
263
|
+
else:
|
|
264
|
+
namespaces.add(numpy_namespace)
|
|
112
265
|
elif is_cupy_array(x):
|
|
113
|
-
_check_api_version(api_version)
|
|
114
266
|
if _use_compat:
|
|
267
|
+
_check_api_version(api_version)
|
|
115
268
|
from .. import cupy as cupy_namespace
|
|
116
269
|
namespaces.add(cupy_namespace)
|
|
117
270
|
else:
|
|
118
271
|
import cupy as cp
|
|
119
272
|
namespaces.add(cp)
|
|
120
273
|
elif is_torch_array(x):
|
|
121
|
-
_check_api_version(api_version)
|
|
122
274
|
if _use_compat:
|
|
275
|
+
_check_api_version(api_version)
|
|
123
276
|
from .. import torch as torch_namespace
|
|
124
277
|
namespaces.add(torch_namespace)
|
|
125
278
|
else:
|
|
126
279
|
import torch
|
|
127
280
|
namespaces.add(torch)
|
|
128
281
|
elif is_dask_array(x):
|
|
129
|
-
_check_api_version(api_version)
|
|
130
282
|
if _use_compat:
|
|
283
|
+
_check_api_version(api_version)
|
|
131
284
|
from ..dask import array as dask_namespace
|
|
132
285
|
namespaces.add(dask_namespace)
|
|
133
286
|
else:
|
|
134
|
-
|
|
287
|
+
import dask.array as da
|
|
288
|
+
namespaces.add(da)
|
|
135
289
|
elif is_jax_array(x):
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
290
|
+
if use_compat is True:
|
|
291
|
+
_check_api_version(api_version)
|
|
292
|
+
raise ValueError("JAX does not have an array-api-compat wrapper")
|
|
293
|
+
elif use_compat is False:
|
|
294
|
+
import jax.numpy as jnp
|
|
295
|
+
else:
|
|
296
|
+
# jax.experimental.array_api is already an array namespace. We do
|
|
297
|
+
# not have a wrapper submodule for it.
|
|
298
|
+
import jax.experimental.array_api as jnp
|
|
140
299
|
namespaces.add(jnp)
|
|
141
300
|
elif hasattr(x, '__array_namespace__'):
|
|
301
|
+
if use_compat is True:
|
|
302
|
+
raise ValueError("The given array does not have an array-api-compat wrapper")
|
|
142
303
|
namespaces.add(x.__array_namespace__(api_version=api_version))
|
|
143
304
|
else:
|
|
144
305
|
# TODO: Support Python scalars?
|
|
@@ -181,15 +342,33 @@ def device(x: Array, /) -> Device:
|
|
|
181
342
|
"""
|
|
182
343
|
Hardware device the array data resides on.
|
|
183
344
|
|
|
345
|
+
This is equivalent to `x.device` according to the `standard
|
|
346
|
+
<https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.device.html>`__.
|
|
347
|
+
This helper is included because some array libraries either do not have
|
|
348
|
+
the `device` attribute or include it with an incompatible API.
|
|
349
|
+
|
|
184
350
|
Parameters
|
|
185
351
|
----------
|
|
186
352
|
x: array
|
|
187
|
-
array instance from
|
|
353
|
+
array instance from an array API compatible library.
|
|
188
354
|
|
|
189
355
|
Returns
|
|
190
356
|
-------
|
|
191
357
|
out: device
|
|
192
|
-
a ``device`` object (see the
|
|
358
|
+
a ``device`` object (see the `Device Support <https://data-apis.org/array-api/latest/design_topics/device_support.html>`__
|
|
359
|
+
section of the array API specification).
|
|
360
|
+
|
|
361
|
+
Notes
|
|
362
|
+
-----
|
|
363
|
+
|
|
364
|
+
For NumPy the device is always `"cpu"`. For Dask, the device is always a
|
|
365
|
+
special `DASK_DEVICE` object.
|
|
366
|
+
|
|
367
|
+
See Also
|
|
368
|
+
--------
|
|
369
|
+
|
|
370
|
+
to_device : Move array data to a different device.
|
|
371
|
+
|
|
193
372
|
"""
|
|
194
373
|
if is_numpy_array(x):
|
|
195
374
|
return "cpu"
|
|
@@ -262,22 +441,50 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]]
|
|
|
262
441
|
"""
|
|
263
442
|
Copy the array from the device on which it currently resides to the specified ``device``.
|
|
264
443
|
|
|
444
|
+
This is equivalent to `x.to_device(device, stream=stream)` according to
|
|
445
|
+
the `standard
|
|
446
|
+
<https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.to_device.html>`__.
|
|
447
|
+
This helper is included because some array libraries do not have the
|
|
448
|
+
`to_device` method.
|
|
449
|
+
|
|
265
450
|
Parameters
|
|
266
451
|
----------
|
|
452
|
+
|
|
267
453
|
x: array
|
|
268
|
-
array instance from
|
|
454
|
+
array instance from an array API compatible library.
|
|
455
|
+
|
|
269
456
|
device: device
|
|
270
|
-
a ``device`` object (see the
|
|
457
|
+
a ``device`` object (see the `Device Support <https://data-apis.org/array-api/latest/design_topics/device_support.html>`__
|
|
458
|
+
section of the array API specification).
|
|
459
|
+
|
|
271
460
|
stream: Optional[Union[int, Any]]
|
|
272
|
-
stream object to use during copy. In addition to the types supported
|
|
461
|
+
stream object to use during copy. In addition to the types supported
|
|
462
|
+
in ``array.__dlpack__``, implementations may choose to support any
|
|
463
|
+
library-specific stream object with the caveat that any code using
|
|
464
|
+
such an object would not be portable.
|
|
273
465
|
|
|
274
466
|
Returns
|
|
275
467
|
-------
|
|
468
|
+
|
|
276
469
|
out: array
|
|
277
|
-
an array with the same data and data type as ``x`` and located on the
|
|
470
|
+
an array with the same data and data type as ``x`` and located on the
|
|
471
|
+
specified ``device``.
|
|
472
|
+
|
|
473
|
+
Notes
|
|
474
|
+
-----
|
|
475
|
+
|
|
476
|
+
For NumPy, this function effectively does nothing since the only supported
|
|
477
|
+
device is the CPU. For CuPy, this method supports CuPy CUDA
|
|
478
|
+
:external+cupy:class:`Device <cupy.cuda.Device>` and
|
|
479
|
+
:external+cupy:class:`Stream <cupy.cuda.Stream>` objects. For PyTorch,
|
|
480
|
+
this is the same as :external+torch:meth:`x.to(device) <torch.Tensor.to>`
|
|
481
|
+
(the ``stream`` argument is not supported in PyTorch).
|
|
482
|
+
|
|
483
|
+
See Also
|
|
484
|
+
--------
|
|
485
|
+
|
|
486
|
+
device : Hardware device the array data resides on.
|
|
278
487
|
|
|
279
|
-
.. note::
|
|
280
|
-
If ``stream`` is given, the copy operation should be enqueued on the provided ``stream``; otherwise, the copy operation should be enqueued on the default stream/queue. Whether the copy is performed synchronously or asynchronously is implementation-dependent. Accordingly, if synchronization is required to guarantee data safety, this must be clearly explained in a conforming library's documentation.
|
|
281
488
|
"""
|
|
282
489
|
if is_numpy_array(x):
|
|
283
490
|
if stream is not None:
|
|
@@ -305,7 +512,13 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]]
|
|
|
305
512
|
|
|
306
513
|
def size(x):
|
|
307
514
|
"""
|
|
308
|
-
Return the total number of elements of x
|
|
515
|
+
Return the total number of elements of x.
|
|
516
|
+
|
|
517
|
+
This is equivalent to `x.size` according to the `standard
|
|
518
|
+
<https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.size.html>`__.
|
|
519
|
+
This helper is included because PyTorch defines `size` in an
|
|
520
|
+
:external+torch:meth:`incompatible way <torch.Tensor.size>`.
|
|
521
|
+
|
|
309
522
|
"""
|
|
310
523
|
if None in x.shape:
|
|
311
524
|
return None
|
|
@@ -1,15 +1,14 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
-
from functools import partial
|
|
4
|
-
|
|
5
3
|
import cupy as cp
|
|
6
4
|
|
|
7
5
|
from ..common import _aliases
|
|
8
6
|
from .._internal import get_xp
|
|
9
7
|
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
|
|
8
|
+
from typing import TYPE_CHECKING
|
|
9
|
+
if TYPE_CHECKING:
|
|
10
|
+
from typing import Optional, Union
|
|
11
|
+
from ._typing import ndarray, Device, Dtype, NestedSequence, SupportsBufferProtocol
|
|
13
12
|
|
|
14
13
|
bool = cp.bool_
|
|
15
14
|
|
|
@@ -62,6 +61,52 @@ matmul = get_xp(cp)(_aliases.matmul)
|
|
|
62
61
|
matrix_transpose = get_xp(cp)(_aliases.matrix_transpose)
|
|
63
62
|
tensordot = get_xp(cp)(_aliases.tensordot)
|
|
64
63
|
|
|
64
|
+
_copy_default = object()


# asarray also adds the copy keyword, which is not present in NumPy 1.x.
def asarray(
    obj: Union[
        ndarray,
        bool,
        int,
        float,
        NestedSequence[bool | int | float],
        SupportsBufferProtocol,
    ],
    /,
    *,
    dtype: Optional[Dtype] = None,
    device: Optional[Device] = None,
    copy: Optional[bool] = _copy_default,
    **kwargs,
) -> ndarray:
    """
    Array API compatibility wrapper for asarray().

    See the corresponding documentation in the array library and/or the array API
    specification for more details.

    Parameters
    ----------
    obj
        Object to convert to a CuPy array.
    dtype
        Requested dtype; ``None`` lets CuPy infer it.
    device
        Passed to ``cp.cuda.Device``; the array is created inside that
        device context.
    copy
        Omitted or ``None``: copy only if needed. ``True``: always copy.
        ``False`` (never copy) raises ``NotImplementedError``.
    **kwargs
        Forwarded to ``cp.asarray`` / ``cp.array``.

    Raises
    ------
    NotImplementedError
        If ``copy=False`` is requested.
    """
    with cp.cuda.Device(device):
        if copy is _copy_default or copy is None:
            # Copy-if-needed. cp.array() defaults to copy=True (it would always
            # copy), and a future CuPy changes copy=False to mean "never copy";
            # cp.asarray() has the copy-if-needed meaning in both cases.
            return cp.asarray(obj, dtype=dtype, **kwargs)
        if copy is False:
            raise NotImplementedError("asarray(copy=False) is not yet supported in cupy")
        return cp.array(obj, dtype=dtype, copy=copy, **kwargs)
|
|
109
|
+
|
|
65
110
|
# These functions are completely new here. If the library already has them
|
|
66
111
|
# (i.e., numpy 2.0), use the library version instead of our wrapper.
|
|
67
112
|
if hasattr(cp, 'vecdot'):
|
|
@@ -73,7 +118,7 @@ if hasattr(cp, 'isdtype'):
|
|
|
73
118
|
else:
|
|
74
119
|
isdtype = get_xp(cp)(_aliases.isdtype)
|
|
75
120
|
|
|
76
|
-
__all__ = _aliases.__all__ + ['asarray', '
|
|
121
|
+
# Public names: the shared common aliases plus the CuPy-specific wrappers
# and spec-name aliases defined in this module.
__all__ = _aliases.__all__ + [
    'asarray', 'bool',
    'acos', 'acosh', 'asin', 'asinh', 'atan', 'atan2', 'atanh',
    'bitwise_left_shift', 'bitwise_invert', 'bitwise_right_shift',
    'concat', 'pow',
]
|