array-api-compat 1.5.1__tar.gz → 1.6__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {array_api_compat-1.5.1 → array_api_compat-1.6}/PKG-INFO +7 -3
- {array_api_compat-1.5.1 → array_api_compat-1.6}/array_api_compat/__init__.py +1 -1
- {array_api_compat-1.5.1 → array_api_compat-1.6}/array_api_compat/common/_aliases.py +4 -88
- {array_api_compat-1.5.1 → array_api_compat-1.6}/array_api_compat/common/_helpers.py +40 -14
- {array_api_compat-1.5.1 → array_api_compat-1.6}/array_api_compat/cupy/_aliases.py +51 -6
- {array_api_compat-1.5.1 → array_api_compat-1.6}/array_api_compat/dask/array/_aliases.py +42 -5
- {array_api_compat-1.5.1 → array_api_compat-1.6}/array_api_compat/numpy/__init__.py +6 -0
- {array_api_compat-1.5.1 → array_api_compat-1.6}/array_api_compat/numpy/_aliases.py +56 -6
- {array_api_compat-1.5.1 → array_api_compat-1.6}/array_api_compat.egg-info/PKG-INFO +7 -3
- {array_api_compat-1.5.1 → array_api_compat-1.6}/array_api_compat.egg-info/SOURCES.txt +1 -0
- array_api_compat-1.6/array_api_compat.egg-info/requires.txt +15 -0
- {array_api_compat-1.5.1 → array_api_compat-1.6}/setup.py +3 -2
- {array_api_compat-1.5.1 → array_api_compat-1.6}/tests/test_all.py +2 -2
- {array_api_compat-1.5.1 → array_api_compat-1.6}/tests/test_array_namespace.py +21 -16
- array_api_compat-1.6/tests/test_common.py +184 -0
- {array_api_compat-1.5.1 → array_api_compat-1.6}/tests/test_isdtype.py +3 -3
- array_api_compat-1.6/tests/test_no_dependencies.py +73 -0
- array_api_compat-1.5.1/array_api_compat.egg-info/requires.txt +0 -6
- array_api_compat-1.5.1/tests/test_common.py +0 -85
- {array_api_compat-1.5.1 → array_api_compat-1.6}/LICENSE +0 -0
- {array_api_compat-1.5.1 → array_api_compat-1.6}/README.md +0 -0
- {array_api_compat-1.5.1 → array_api_compat-1.6}/array_api_compat/_internal.py +0 -0
- {array_api_compat-1.5.1 → array_api_compat-1.6}/array_api_compat/common/__init__.py +0 -0
- {array_api_compat-1.5.1 → array_api_compat-1.6}/array_api_compat/common/_fft.py +0 -0
- {array_api_compat-1.5.1 → array_api_compat-1.6}/array_api_compat/common/_linalg.py +0 -0
- {array_api_compat-1.5.1 → array_api_compat-1.6}/array_api_compat/common/_typing.py +0 -0
- {array_api_compat-1.5.1 → array_api_compat-1.6}/array_api_compat/cupy/__init__.py +0 -0
- {array_api_compat-1.5.1 → array_api_compat-1.6}/array_api_compat/cupy/_typing.py +0 -0
- {array_api_compat-1.5.1 → array_api_compat-1.6}/array_api_compat/cupy/fft.py +0 -0
- {array_api_compat-1.5.1 → array_api_compat-1.6}/array_api_compat/cupy/linalg.py +0 -0
- {array_api_compat-1.5.1 → array_api_compat-1.6}/array_api_compat/dask/__init__.py +0 -0
- {array_api_compat-1.5.1 → array_api_compat-1.6}/array_api_compat/dask/array/__init__.py +0 -0
- {array_api_compat-1.5.1 → array_api_compat-1.6}/array_api_compat/dask/array/linalg.py +0 -0
- {array_api_compat-1.5.1 → array_api_compat-1.6}/array_api_compat/numpy/_typing.py +0 -0
- {array_api_compat-1.5.1 → array_api_compat-1.6}/array_api_compat/numpy/fft.py +0 -0
- {array_api_compat-1.5.1 → array_api_compat-1.6}/array_api_compat/numpy/linalg.py +0 -0
- {array_api_compat-1.5.1 → array_api_compat-1.6}/array_api_compat/torch/__init__.py +0 -0
- {array_api_compat-1.5.1 → array_api_compat-1.6}/array_api_compat/torch/_aliases.py +0 -0
- {array_api_compat-1.5.1 → array_api_compat-1.6}/array_api_compat/torch/fft.py +0 -0
- {array_api_compat-1.5.1 → array_api_compat-1.6}/array_api_compat/torch/linalg.py +0 -0
- {array_api_compat-1.5.1 → array_api_compat-1.6}/array_api_compat.egg-info/dependency_links.txt +0 -0
- {array_api_compat-1.5.1 → array_api_compat-1.6}/array_api_compat.egg-info/top_level.txt +0 -0
- {array_api_compat-1.5.1 → array_api_compat-1.6}/setup.cfg +0 -0
- {array_api_compat-1.5.1 → array_api_compat-1.6}/tests/test_vendoring.py +0 -0
|
@@ -1,24 +1,28 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: array_api_compat
|
|
3
|
-
Version: 1.
|
|
3
|
+
Version: 1.6
|
|
4
4
|
Summary: A wrapper around NumPy and other array libraries to make them compatible with the Array API standard
|
|
5
5
|
Home-page: https://data-apis.org/array-api-compat/
|
|
6
6
|
Author: Consortium for Python Data API Standards
|
|
7
7
|
License: MIT
|
|
8
8
|
Classifier: Programming Language :: Python :: 3
|
|
9
|
-
Classifier: Programming Language :: Python :: 3.8
|
|
10
9
|
Classifier: Programming Language :: Python :: 3.9
|
|
11
10
|
Classifier: Programming Language :: Python :: 3.10
|
|
12
11
|
Classifier: Programming Language :: Python :: 3.11
|
|
13
12
|
Classifier: License :: OSI Approved :: MIT License
|
|
14
13
|
Classifier: Operating System :: OS Independent
|
|
15
|
-
Requires-Python: >=3.8
|
|
16
14
|
Description-Content-Type: text/markdown
|
|
17
15
|
License-File: LICENSE
|
|
18
16
|
Provides-Extra: numpy
|
|
19
17
|
Requires-Dist: numpy; extra == "numpy"
|
|
20
18
|
Provides-Extra: cupy
|
|
21
19
|
Requires-Dist: cupy; extra == "cupy"
|
|
20
|
+
Provides-Extra: jax
|
|
21
|
+
Requires-Dist: jax; extra == "jax"
|
|
22
|
+
Provides-Extra: pytorch
|
|
23
|
+
Requires-Dist: pytorch; extra == "pytorch"
|
|
24
|
+
Provides-Extra: dask
|
|
25
|
+
Requires-Dist: dask; extra == "dask"
|
|
22
26
|
|
|
23
27
|
# Array API compatibility library
|
|
24
28
|
|
|
@@ -6,18 +6,18 @@ from __future__ import annotations
|
|
|
6
6
|
|
|
7
7
|
from typing import TYPE_CHECKING
|
|
8
8
|
if TYPE_CHECKING:
|
|
9
|
-
import numpy as np
|
|
10
9
|
from typing import Optional, Sequence, Tuple, Union
|
|
11
|
-
from ._typing import ndarray, Device, Dtype
|
|
10
|
+
from ._typing import ndarray, Device, Dtype
|
|
12
11
|
|
|
13
12
|
from typing import NamedTuple
|
|
14
|
-
from types import ModuleType
|
|
15
13
|
import inspect
|
|
16
14
|
|
|
17
|
-
from ._helpers import _check_device
|
|
15
|
+
from ._helpers import _check_device
|
|
18
16
|
|
|
19
17
|
# These functions are modified from the NumPy versions.
|
|
20
18
|
|
|
19
|
+
# Creation functions add the device keyword (which does nothing for NumPy)
|
|
20
|
+
|
|
21
21
|
def arange(
|
|
22
22
|
start: Union[int, float],
|
|
23
23
|
/,
|
|
@@ -268,90 +268,6 @@ def var(
|
|
|
268
268
|
def permute_dims(x: ndarray, /, axes: Tuple[int, ...], xp) -> ndarray:
|
|
269
269
|
return xp.transpose(x, axes)
|
|
270
270
|
|
|
271
|
-
# Creation functions add the device keyword (which does nothing for NumPy)
|
|
272
|
-
|
|
273
|
-
# asarray also adds the copy keyword
|
|
274
|
-
def _asarray(
|
|
275
|
-
obj: Union[
|
|
276
|
-
ndarray,
|
|
277
|
-
bool,
|
|
278
|
-
int,
|
|
279
|
-
float,
|
|
280
|
-
NestedSequence[bool | int | float],
|
|
281
|
-
SupportsBufferProtocol,
|
|
282
|
-
],
|
|
283
|
-
/,
|
|
284
|
-
*,
|
|
285
|
-
dtype: Optional[Dtype] = None,
|
|
286
|
-
device: Optional[Device] = None,
|
|
287
|
-
copy: "Optional[Union[bool, np._CopyMode]]" = None,
|
|
288
|
-
namespace = None,
|
|
289
|
-
**kwargs,
|
|
290
|
-
) -> ndarray:
|
|
291
|
-
"""
|
|
292
|
-
Array API compatibility wrapper for asarray().
|
|
293
|
-
|
|
294
|
-
See the corresponding documentation in NumPy/CuPy and/or the array API
|
|
295
|
-
specification for more details.
|
|
296
|
-
|
|
297
|
-
"""
|
|
298
|
-
if namespace is None:
|
|
299
|
-
try:
|
|
300
|
-
xp = array_namespace(obj, _use_compat=False)
|
|
301
|
-
except ValueError:
|
|
302
|
-
# TODO: What about lists of arrays?
|
|
303
|
-
raise ValueError("A namespace must be specified for asarray() with non-array input")
|
|
304
|
-
elif isinstance(namespace, ModuleType):
|
|
305
|
-
xp = namespace
|
|
306
|
-
elif namespace == 'numpy':
|
|
307
|
-
import numpy as xp
|
|
308
|
-
elif namespace == 'cupy':
|
|
309
|
-
import cupy as xp
|
|
310
|
-
elif namespace == 'dask.array':
|
|
311
|
-
import dask.array as xp
|
|
312
|
-
else:
|
|
313
|
-
raise ValueError("Unrecognized namespace argument to asarray()")
|
|
314
|
-
|
|
315
|
-
_check_device(xp, device)
|
|
316
|
-
if is_numpy_array(obj):
|
|
317
|
-
import numpy as np
|
|
318
|
-
if hasattr(np, '_CopyMode'):
|
|
319
|
-
# Not present in older NumPys
|
|
320
|
-
COPY_FALSE = (False, np._CopyMode.IF_NEEDED)
|
|
321
|
-
COPY_TRUE = (True, np._CopyMode.ALWAYS)
|
|
322
|
-
else:
|
|
323
|
-
COPY_FALSE = (False,)
|
|
324
|
-
COPY_TRUE = (True,)
|
|
325
|
-
else:
|
|
326
|
-
COPY_FALSE = (False,)
|
|
327
|
-
COPY_TRUE = (True,)
|
|
328
|
-
if copy in COPY_FALSE and namespace != "dask.array":
|
|
329
|
-
# copy=False is not yet implemented in xp.asarray
|
|
330
|
-
raise NotImplementedError("copy=False is not yet implemented")
|
|
331
|
-
if (hasattr(xp, "ndarray") and isinstance(obj, xp.ndarray)):
|
|
332
|
-
if dtype is not None and obj.dtype != dtype:
|
|
333
|
-
copy = True
|
|
334
|
-
if copy in COPY_TRUE:
|
|
335
|
-
return xp.array(obj, copy=True, dtype=dtype)
|
|
336
|
-
return obj
|
|
337
|
-
elif namespace == "dask.array":
|
|
338
|
-
if copy in COPY_TRUE:
|
|
339
|
-
if dtype is None:
|
|
340
|
-
return obj.copy()
|
|
341
|
-
# Go through numpy, since dask copy is no-op by default
|
|
342
|
-
import numpy as np
|
|
343
|
-
obj = np.array(obj, dtype=dtype, copy=True)
|
|
344
|
-
return xp.array(obj, dtype=dtype)
|
|
345
|
-
else:
|
|
346
|
-
import dask.array as da
|
|
347
|
-
import numpy as np
|
|
348
|
-
if not isinstance(obj, da.Array):
|
|
349
|
-
obj = np.asarray(obj, dtype=dtype)
|
|
350
|
-
return da.from_array(obj)
|
|
351
|
-
return obj
|
|
352
|
-
|
|
353
|
-
return xp.asarray(obj, dtype=dtype, **kwargs)
|
|
354
|
-
|
|
355
271
|
# np.reshape calls the keyword argument 'newshape' instead of 'shape'
|
|
356
272
|
def reshape(x: ndarray,
|
|
357
273
|
/,
|
|
@@ -178,7 +178,7 @@ def _check_api_version(api_version):
|
|
|
178
178
|
elif api_version is not None and api_version != '2022.12':
|
|
179
179
|
raise ValueError("Only the 2022.12 version of the array API specification is currently supported")
|
|
180
180
|
|
|
181
|
-
def array_namespace(*xs, api_version=None,
|
|
181
|
+
def array_namespace(*xs, api_version=None, use_compat=None):
|
|
182
182
|
"""
|
|
183
183
|
Get the array API compatible namespace for the arrays `xs`.
|
|
184
184
|
|
|
@@ -191,6 +191,12 @@ def array_namespace(*xs, api_version=None, _use_compat=True):
|
|
|
191
191
|
The newest version of the spec that you need support for (currently
|
|
192
192
|
the compat library wrapped APIs support v2022.12).
|
|
193
193
|
|
|
194
|
+
use_compat: bool or None
|
|
195
|
+
If None (the default), the native namespace will be returned if it is
|
|
196
|
+
already array API compatible, otherwise a compat wrapper is used. If
|
|
197
|
+
True, the compat library wrapped library will be returned. If False,
|
|
198
|
+
the native library namespace is returned.
|
|
199
|
+
|
|
194
200
|
Returns
|
|
195
201
|
-------
|
|
196
202
|
|
|
@@ -234,46 +240,66 @@ def array_namespace(*xs, api_version=None, _use_compat=True):
|
|
|
234
240
|
is_jax_array
|
|
235
241
|
|
|
236
242
|
"""
|
|
243
|
+
if use_compat not in [None, True, False]:
|
|
244
|
+
raise ValueError("use_compat must be None, True, or False")
|
|
245
|
+
|
|
246
|
+
_use_compat = use_compat in [None, True]
|
|
247
|
+
|
|
237
248
|
namespaces = set()
|
|
238
249
|
for x in xs:
|
|
239
250
|
if is_numpy_array(x):
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
|
|
251
|
+
from .. import numpy as numpy_namespace
|
|
252
|
+
import numpy as np
|
|
253
|
+
if use_compat is True:
|
|
254
|
+
_check_api_version(api_version)
|
|
243
255
|
namespaces.add(numpy_namespace)
|
|
244
|
-
|
|
245
|
-
import numpy as np
|
|
256
|
+
elif use_compat is False:
|
|
246
257
|
namespaces.add(np)
|
|
258
|
+
else:
|
|
259
|
+
# numpy 2.0 has __array_namespace__ and is fully array API
|
|
260
|
+
# compatible.
|
|
261
|
+
if hasattr(x, '__array_namespace__'):
|
|
262
|
+
namespaces.add(x.__array_namespace__(api_version=api_version))
|
|
263
|
+
else:
|
|
264
|
+
namespaces.add(numpy_namespace)
|
|
247
265
|
elif is_cupy_array(x):
|
|
248
|
-
_check_api_version(api_version)
|
|
249
266
|
if _use_compat:
|
|
267
|
+
_check_api_version(api_version)
|
|
250
268
|
from .. import cupy as cupy_namespace
|
|
251
269
|
namespaces.add(cupy_namespace)
|
|
252
270
|
else:
|
|
253
271
|
import cupy as cp
|
|
254
272
|
namespaces.add(cp)
|
|
255
273
|
elif is_torch_array(x):
|
|
256
|
-
_check_api_version(api_version)
|
|
257
274
|
if _use_compat:
|
|
275
|
+
_check_api_version(api_version)
|
|
258
276
|
from .. import torch as torch_namespace
|
|
259
277
|
namespaces.add(torch_namespace)
|
|
260
278
|
else:
|
|
261
279
|
import torch
|
|
262
280
|
namespaces.add(torch)
|
|
263
281
|
elif is_dask_array(x):
|
|
264
|
-
_check_api_version(api_version)
|
|
265
282
|
if _use_compat:
|
|
283
|
+
_check_api_version(api_version)
|
|
266
284
|
from ..dask import array as dask_namespace
|
|
267
285
|
namespaces.add(dask_namespace)
|
|
268
286
|
else:
|
|
269
|
-
|
|
287
|
+
import dask.array as da
|
|
288
|
+
namespaces.add(da)
|
|
270
289
|
elif is_jax_array(x):
|
|
271
|
-
|
|
272
|
-
|
|
273
|
-
|
|
274
|
-
|
|
290
|
+
if use_compat is True:
|
|
291
|
+
_check_api_version(api_version)
|
|
292
|
+
raise ValueError("JAX does not have an array-api-compat wrapper")
|
|
293
|
+
elif use_compat is False:
|
|
294
|
+
import jax.numpy as jnp
|
|
295
|
+
else:
|
|
296
|
+
# jax.experimental.array_api is already an array namespace. We do
|
|
297
|
+
# not have a wrapper submodule for it.
|
|
298
|
+
import jax.experimental.array_api as jnp
|
|
275
299
|
namespaces.add(jnp)
|
|
276
300
|
elif hasattr(x, '__array_namespace__'):
|
|
301
|
+
if use_compat is True:
|
|
302
|
+
raise ValueError("The given array does not have an array-api-compat wrapper")
|
|
277
303
|
namespaces.add(x.__array_namespace__(api_version=api_version))
|
|
278
304
|
else:
|
|
279
305
|
# TODO: Support Python scalars?
|
|
@@ -1,15 +1,14 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
-
from functools import partial
|
|
4
|
-
|
|
5
3
|
import cupy as cp
|
|
6
4
|
|
|
7
5
|
from ..common import _aliases
|
|
8
6
|
from .._internal import get_xp
|
|
9
7
|
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
|
|
8
|
+
from typing import TYPE_CHECKING
|
|
9
|
+
if TYPE_CHECKING:
|
|
10
|
+
from typing import Optional, Union
|
|
11
|
+
from ._typing import ndarray, Device, Dtype, NestedSequence, SupportsBufferProtocol
|
|
13
12
|
|
|
14
13
|
bool = cp.bool_
|
|
15
14
|
|
|
@@ -62,6 +61,52 @@ matmul = get_xp(cp)(_aliases.matmul)
|
|
|
62
61
|
matrix_transpose = get_xp(cp)(_aliases.matrix_transpose)
|
|
63
62
|
tensordot = get_xp(cp)(_aliases.tensordot)
|
|
64
63
|
|
|
64
|
+
_copy_default = object()
|
|
65
|
+
|
|
66
|
+
# asarray also adds the copy keyword, which is not present in numpy 1.0.
|
|
67
|
+
def asarray(
|
|
68
|
+
obj: Union[
|
|
69
|
+
ndarray,
|
|
70
|
+
bool,
|
|
71
|
+
int,
|
|
72
|
+
float,
|
|
73
|
+
NestedSequence[bool | int | float],
|
|
74
|
+
SupportsBufferProtocol,
|
|
75
|
+
],
|
|
76
|
+
/,
|
|
77
|
+
*,
|
|
78
|
+
dtype: Optional[Dtype] = None,
|
|
79
|
+
device: Optional[Device] = None,
|
|
80
|
+
copy: Optional[bool] = _copy_default,
|
|
81
|
+
**kwargs,
|
|
82
|
+
) -> ndarray:
|
|
83
|
+
"""
|
|
84
|
+
Array API compatibility wrapper for asarray().
|
|
85
|
+
|
|
86
|
+
See the corresponding documentation in the array library and/or the array API
|
|
87
|
+
specification for more details.
|
|
88
|
+
"""
|
|
89
|
+
with cp.cuda.Device(device):
|
|
90
|
+
# cupy is like NumPy 1.26 (except without _CopyMode). See the comments
|
|
91
|
+
# in asarray in numpy/_aliases.py.
|
|
92
|
+
if copy is not _copy_default:
|
|
93
|
+
# A future version of CuPy will change the meaning of copy=False
|
|
94
|
+
# to mean no-copy. We don't know for certain what version it will
|
|
95
|
+
# be yet, so to avoid breaking that version, we use a different
|
|
96
|
+
# default value for copy so asarray(obj) with no copy kwarg will
|
|
97
|
+
# always do the copy-if-needed behavior.
|
|
98
|
+
|
|
99
|
+
# This will still need to be updated to remove the
|
|
100
|
+
# NotImplementedError for copy=False, but at least this won't
|
|
101
|
+
# break the default or existing behavior.
|
|
102
|
+
if copy is None:
|
|
103
|
+
copy = False
|
|
104
|
+
elif copy is False:
|
|
105
|
+
raise NotImplementedError("asarray(copy=False) is not yet supported in cupy")
|
|
106
|
+
kwargs['copy'] = copy
|
|
107
|
+
|
|
108
|
+
return cp.array(obj, dtype=dtype, **kwargs)
|
|
109
|
+
|
|
65
110
|
# These functions are completely new here. If the library already has them
|
|
66
111
|
# (i.e., numpy 2.0), use the library version instead of our wrapper.
|
|
67
112
|
if hasattr(cp, 'vecdot'):
|
|
@@ -73,7 +118,7 @@ if hasattr(cp, 'isdtype'):
|
|
|
73
118
|
else:
|
|
74
119
|
isdtype = get_xp(cp)(_aliases.isdtype)
|
|
75
120
|
|
|
76
|
-
__all__ = _aliases.__all__ + ['asarray', '
|
|
121
|
+
__all__ = _aliases.__all__ + ['asarray', 'bool', 'acos',
|
|
77
122
|
'acosh', 'asin', 'asinh', 'atan', 'atan2',
|
|
78
123
|
'atanh', 'bitwise_left_shift', 'bitwise_invert',
|
|
79
124
|
'bitwise_right_shift', 'concat', 'pow']
|
|
@@ -37,7 +37,7 @@ from typing import TYPE_CHECKING
|
|
|
37
37
|
if TYPE_CHECKING:
|
|
38
38
|
from typing import Optional, Union
|
|
39
39
|
|
|
40
|
-
from ...common._typing import Device, Dtype, Array
|
|
40
|
+
from ...common._typing import Device, Dtype, Array, NestedSequence, SupportsBufferProtocol
|
|
41
41
|
|
|
42
42
|
import dask.array as da
|
|
43
43
|
|
|
@@ -76,10 +76,6 @@ def _dask_arange(
|
|
|
76
76
|
arange = get_xp(da)(_dask_arange)
|
|
77
77
|
eye = get_xp(da)(_aliases.eye)
|
|
78
78
|
|
|
79
|
-
from functools import partial
|
|
80
|
-
asarray = partial(_aliases._asarray, namespace='dask.array')
|
|
81
|
-
asarray.__doc__ = _aliases._asarray.__doc__
|
|
82
|
-
|
|
83
79
|
linspace = get_xp(da)(_aliases.linspace)
|
|
84
80
|
eye = get_xp(da)(_aliases.eye)
|
|
85
81
|
UniqueAllResult = get_xp(da)(_aliases.UniqueAllResult)
|
|
@@ -113,6 +109,47 @@ trunc = get_xp(np)(_aliases.trunc)
|
|
|
113
109
|
matmul = get_xp(np)(_aliases.matmul)
|
|
114
110
|
tensordot = get_xp(np)(_aliases.tensordot)
|
|
115
111
|
|
|
112
|
+
|
|
113
|
+
# asarray also adds the copy keyword, which is not present in numpy 1.0.
|
|
114
|
+
def asarray(
|
|
115
|
+
obj: Union[
|
|
116
|
+
Array,
|
|
117
|
+
bool,
|
|
118
|
+
int,
|
|
119
|
+
float,
|
|
120
|
+
NestedSequence[bool | int | float],
|
|
121
|
+
SupportsBufferProtocol,
|
|
122
|
+
],
|
|
123
|
+
/,
|
|
124
|
+
*,
|
|
125
|
+
dtype: Optional[Dtype] = None,
|
|
126
|
+
device: Optional[Device] = None,
|
|
127
|
+
copy: "Optional[Union[bool, np._CopyMode]]" = None,
|
|
128
|
+
**kwargs,
|
|
129
|
+
) -> Array:
|
|
130
|
+
"""
|
|
131
|
+
Array API compatibility wrapper for asarray().
|
|
132
|
+
|
|
133
|
+
See the corresponding documentation in the array library and/or the array API
|
|
134
|
+
specification for more details.
|
|
135
|
+
"""
|
|
136
|
+
if copy is False:
|
|
137
|
+
# copy=False is not yet implemented in dask
|
|
138
|
+
raise NotImplementedError("copy=False is not yet implemented")
|
|
139
|
+
elif copy is True:
|
|
140
|
+
if isinstance(obj, da.Array) and dtype is None:
|
|
141
|
+
return obj.copy()
|
|
142
|
+
# Go through numpy, since dask copy is no-op by default
|
|
143
|
+
obj = np.array(obj, dtype=dtype, copy=True)
|
|
144
|
+
return da.array(obj, dtype=dtype)
|
|
145
|
+
else:
|
|
146
|
+
if not isinstance(obj, da.Array) or dtype is not None and obj.dtype != dtype:
|
|
147
|
+
obj = np.asarray(obj, dtype=dtype)
|
|
148
|
+
return da.from_array(obj)
|
|
149
|
+
return obj
|
|
150
|
+
|
|
151
|
+
return da.asarray(obj, dtype=dtype, **kwargs)
|
|
152
|
+
|
|
116
153
|
from dask.array import (
|
|
117
154
|
# Element wise aliases
|
|
118
155
|
arccos as acos,
|
|
@@ -21,4 +21,10 @@ from .linalg import matrix_transpose, vecdot # noqa: F401
|
|
|
21
21
|
|
|
22
22
|
from ..common._helpers import * # noqa: F403
|
|
23
23
|
|
|
24
|
+
try:
|
|
25
|
+
# Used in asarray(). Not present in older versions.
|
|
26
|
+
from numpy import _CopyMode # noqa: F401
|
|
27
|
+
except ImportError:
|
|
28
|
+
pass
|
|
29
|
+
|
|
24
30
|
__array_api_version__ = '2022.12'
|
|
@@ -1,14 +1,13 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
-
from functools import partial
|
|
4
|
-
|
|
5
3
|
from ..common import _aliases
|
|
6
4
|
|
|
7
5
|
from .._internal import get_xp
|
|
8
6
|
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
7
|
+
from typing import TYPE_CHECKING
|
|
8
|
+
if TYPE_CHECKING:
|
|
9
|
+
from typing import Optional, Union
|
|
10
|
+
from ._typing import ndarray, Device, Dtype, NestedSequence, SupportsBufferProtocol
|
|
12
11
|
|
|
13
12
|
import numpy as np
|
|
14
13
|
bool = np.bool_
|
|
@@ -62,6 +61,57 @@ matmul = get_xp(np)(_aliases.matmul)
|
|
|
62
61
|
matrix_transpose = get_xp(np)(_aliases.matrix_transpose)
|
|
63
62
|
tensordot = get_xp(np)(_aliases.tensordot)
|
|
64
63
|
|
|
64
|
+
def _supports_buffer_protocol(obj):
|
|
65
|
+
try:
|
|
66
|
+
memoryview(obj)
|
|
67
|
+
except TypeError:
|
|
68
|
+
return False
|
|
69
|
+
return True
|
|
70
|
+
|
|
71
|
+
# asarray also adds the copy keyword, which is not present in numpy 1.0.
|
|
72
|
+
# asarray() is different enough between numpy, cupy, and dask, the logic
|
|
73
|
+
# complicated enough that it's easier to define it separately for each module
|
|
74
|
+
# rather than trying to combine everything into one function in common/
|
|
75
|
+
def asarray(
|
|
76
|
+
obj: Union[
|
|
77
|
+
ndarray,
|
|
78
|
+
bool,
|
|
79
|
+
int,
|
|
80
|
+
float,
|
|
81
|
+
NestedSequence[bool | int | float],
|
|
82
|
+
SupportsBufferProtocol,
|
|
83
|
+
],
|
|
84
|
+
/,
|
|
85
|
+
*,
|
|
86
|
+
dtype: Optional[Dtype] = None,
|
|
87
|
+
device: Optional[Device] = None,
|
|
88
|
+
copy: "Optional[Union[bool, np._CopyMode]]" = None,
|
|
89
|
+
**kwargs,
|
|
90
|
+
) -> ndarray:
|
|
91
|
+
"""
|
|
92
|
+
Array API compatibility wrapper for asarray().
|
|
93
|
+
|
|
94
|
+
See the corresponding documentation in the array library and/or the array API
|
|
95
|
+
specification for more details.
|
|
96
|
+
"""
|
|
97
|
+
if device not in ["cpu", None]:
|
|
98
|
+
raise ValueError(f"Unsupported device for NumPy: {device!r}")
|
|
99
|
+
|
|
100
|
+
if hasattr(np, '_CopyMode'):
|
|
101
|
+
if copy is None:
|
|
102
|
+
copy = np._CopyMode.IF_NEEDED
|
|
103
|
+
elif copy is False:
|
|
104
|
+
copy = np._CopyMode.NEVER
|
|
105
|
+
elif copy is True:
|
|
106
|
+
copy = np._CopyMode.ALWAYS
|
|
107
|
+
else:
|
|
108
|
+
# Not present in older NumPys. In this case, we cannot really support
|
|
109
|
+
# copy=False.
|
|
110
|
+
if copy is False:
|
|
111
|
+
raise NotImplementedError("asarray(copy=False) requires a newer version of NumPy.")
|
|
112
|
+
|
|
113
|
+
return np.array(obj, copy=copy, dtype=dtype, **kwargs)
|
|
114
|
+
|
|
65
115
|
# These functions are completely new here. If the library already has them
|
|
66
116
|
# (i.e., numpy 2.0), use the library version instead of our wrapper.
|
|
67
117
|
if hasattr(np, 'vecdot'):
|
|
@@ -73,7 +123,7 @@ if hasattr(np, 'isdtype'):
|
|
|
73
123
|
else:
|
|
74
124
|
isdtype = get_xp(np)(_aliases.isdtype)
|
|
75
125
|
|
|
76
|
-
__all__ = _aliases.__all__ + ['asarray', '
|
|
126
|
+
__all__ = _aliases.__all__ + ['asarray', 'bool', 'acos',
|
|
77
127
|
'acosh', 'asin', 'asinh', 'atan', 'atan2',
|
|
78
128
|
'atanh', 'bitwise_left_shift', 'bitwise_invert',
|
|
79
129
|
'bitwise_right_shift', 'concat', 'pow']
|
|
@@ -1,24 +1,28 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: array_api_compat
|
|
3
|
-
Version: 1.
|
|
3
|
+
Version: 1.6
|
|
4
4
|
Summary: A wrapper around NumPy and other array libraries to make them compatible with the Array API standard
|
|
5
5
|
Home-page: https://data-apis.org/array-api-compat/
|
|
6
6
|
Author: Consortium for Python Data API Standards
|
|
7
7
|
License: MIT
|
|
8
8
|
Classifier: Programming Language :: Python :: 3
|
|
9
|
-
Classifier: Programming Language :: Python :: 3.8
|
|
10
9
|
Classifier: Programming Language :: Python :: 3.9
|
|
11
10
|
Classifier: Programming Language :: Python :: 3.10
|
|
12
11
|
Classifier: Programming Language :: Python :: 3.11
|
|
13
12
|
Classifier: License :: OSI Approved :: MIT License
|
|
14
13
|
Classifier: Operating System :: OS Independent
|
|
15
|
-
Requires-Python: >=3.8
|
|
16
14
|
Description-Content-Type: text/markdown
|
|
17
15
|
License-File: LICENSE
|
|
18
16
|
Provides-Extra: numpy
|
|
19
17
|
Requires-Dist: numpy; extra == "numpy"
|
|
20
18
|
Provides-Extra: cupy
|
|
21
19
|
Requires-Dist: cupy; extra == "cupy"
|
|
20
|
+
Provides-Extra: jax
|
|
21
|
+
Requires-Dist: jax; extra == "jax"
|
|
22
|
+
Provides-Extra: pytorch
|
|
23
|
+
Requires-Dist: pytorch; extra == "pytorch"
|
|
24
|
+
Provides-Extra: dask
|
|
25
|
+
Requires-Dist: dask; extra == "dask"
|
|
22
26
|
|
|
23
27
|
# Array API compatibility library
|
|
24
28
|
|
|
@@ -15,14 +15,15 @@ setup(
|
|
|
15
15
|
long_description_content_type="text/markdown",
|
|
16
16
|
url="https://data-apis.org/array-api-compat/",
|
|
17
17
|
license="MIT",
|
|
18
|
-
python_requires=">=3.8",
|
|
19
18
|
extras_require={
|
|
20
19
|
"numpy": "numpy",
|
|
21
20
|
"cupy": "cupy",
|
|
21
|
+
"jax": "jax",
|
|
22
|
+
"pytorch": "pytorch",
|
|
23
|
+
"dask": "dask",
|
|
22
24
|
},
|
|
23
25
|
classifiers=[
|
|
24
26
|
"Programming Language :: Python :: 3",
|
|
25
|
-
"Programming Language :: Python :: 3.8",
|
|
26
27
|
"Programming Language :: Python :: 3.9",
|
|
27
28
|
"Programming Language :: Python :: 3.10",
|
|
28
29
|
"Programming Language :: Python :: 3.11",
|
|
@@ -12,11 +12,11 @@ used inside of a function. Note that names starting with an underscore are autom
|
|
|
12
12
|
|
|
13
13
|
import sys
|
|
14
14
|
|
|
15
|
-
from ._helpers import import_
|
|
15
|
+
from ._helpers import import_, wrapped_libraries
|
|
16
16
|
|
|
17
17
|
import pytest
|
|
18
18
|
|
|
19
|
-
@pytest.mark.parametrize("library", ["common"
|
|
19
|
+
@pytest.mark.parametrize("library", ["common"] + wrapped_libraries)
|
|
20
20
|
def test_all(library):
|
|
21
21
|
import_(library, wrapper=True)
|
|
22
22
|
|
|
@@ -9,24 +9,29 @@ import torch
|
|
|
9
9
|
import array_api_compat
|
|
10
10
|
from array_api_compat import array_namespace
|
|
11
11
|
|
|
12
|
-
from ._helpers import import_
|
|
12
|
+
from ._helpers import import_, all_libraries, wrapped_libraries
|
|
13
13
|
|
|
14
|
-
@pytest.mark.parametrize("
|
|
15
|
-
@pytest.mark.parametrize("api_version", [None, "2021.12"])
|
|
16
|
-
|
|
14
|
+
@pytest.mark.parametrize("use_compat", [True, False, None])
|
|
15
|
+
@pytest.mark.parametrize("api_version", [None, "2021.12", "2022.12"])
|
|
16
|
+
@pytest.mark.parametrize("library", all_libraries + ['array_api_strict'])
|
|
17
|
+
def test_array_namespace(library, api_version, use_compat):
|
|
17
18
|
xp = import_(library)
|
|
18
19
|
|
|
19
20
|
array = xp.asarray([1.0, 2.0, 3.0])
|
|
20
|
-
|
|
21
|
+
if use_compat is True and library in ['array_api_strict', 'jax.numpy']:
|
|
22
|
+
pytest.raises(ValueError, lambda: array_namespace(array, use_compat=use_compat))
|
|
23
|
+
return
|
|
24
|
+
namespace = array_api_compat.array_namespace(array, api_version=api_version, use_compat=use_compat)
|
|
21
25
|
|
|
22
|
-
if
|
|
23
|
-
|
|
26
|
+
if use_compat is False or use_compat is None and library not in wrapped_libraries:
|
|
27
|
+
if library == "jax.numpy" and use_compat is None:
|
|
28
|
+
import jax.experimental.array_api
|
|
29
|
+
assert namespace == jax.experimental.array_api
|
|
30
|
+
else:
|
|
31
|
+
assert namespace == xp
|
|
24
32
|
else:
|
|
25
33
|
if library == "dask.array":
|
|
26
34
|
assert namespace == array_api_compat.dask.array
|
|
27
|
-
elif library == "jax.numpy":
|
|
28
|
-
import jax.experimental.array_api
|
|
29
|
-
assert namespace == jax.experimental.array_api
|
|
30
35
|
else:
|
|
31
36
|
assert namespace == getattr(array_api_compat, library)
|
|
32
37
|
|
|
@@ -64,14 +69,14 @@ def test_array_namespace_errors_torch():
|
|
|
64
69
|
pytest.raises(TypeError, lambda: array_namespace(x, y))
|
|
65
70
|
|
|
66
71
|
def test_api_version():
|
|
67
|
-
x =
|
|
68
|
-
|
|
69
|
-
assert array_namespace(x, api_version="2022.12") ==
|
|
70
|
-
assert array_namespace(x, api_version=None) ==
|
|
71
|
-
assert array_namespace(x) ==
|
|
72
|
+
x = torch.asarray([1, 2])
|
|
73
|
+
torch_ = import_("torch", wrapper=True)
|
|
74
|
+
assert array_namespace(x, api_version="2022.12") == torch_
|
|
75
|
+
assert array_namespace(x, api_version=None) == torch_
|
|
76
|
+
assert array_namespace(x) == torch_
|
|
72
77
|
# Should issue a warning
|
|
73
78
|
with warnings.catch_warnings(record=True) as w:
|
|
74
|
-
assert array_namespace(x, api_version="2021.12") ==
|
|
79
|
+
assert array_namespace(x, api_version="2021.12") == torch_
|
|
75
80
|
assert len(w) == 1
|
|
76
81
|
assert "2021.12" in str(w[0].message)
|
|
77
82
|
|
|
@@ -0,0 +1,184 @@
|
|
|
1
|
+
from array_api_compat import (is_numpy_array, is_cupy_array, is_torch_array, # noqa: F401
|
|
2
|
+
is_dask_array, is_jax_array)
|
|
3
|
+
|
|
4
|
+
from array_api_compat import is_array_api_obj, device, to_device
|
|
5
|
+
|
|
6
|
+
from ._helpers import import_, wrapped_libraries, all_libraries
|
|
7
|
+
|
|
8
|
+
import pytest
|
|
9
|
+
import numpy as np
|
|
10
|
+
import array
|
|
11
|
+
from numpy.testing import assert_allclose
|
|
12
|
+
|
|
13
|
+
is_functions = {
|
|
14
|
+
'numpy': 'is_numpy_array',
|
|
15
|
+
'cupy': 'is_cupy_array',
|
|
16
|
+
'torch': 'is_torch_array',
|
|
17
|
+
'dask.array': 'is_dask_array',
|
|
18
|
+
'jax.numpy': 'is_jax_array',
|
|
19
|
+
}
|
|
20
|
+
|
|
21
|
+
@pytest.mark.parametrize('library', is_functions.keys())
|
|
22
|
+
@pytest.mark.parametrize('func', is_functions.values())
|
|
23
|
+
def test_is_xp_array(library, func):
|
|
24
|
+
lib = import_(library)
|
|
25
|
+
is_func = globals()[func]
|
|
26
|
+
|
|
27
|
+
x = lib.asarray([1, 2, 3])
|
|
28
|
+
|
|
29
|
+
assert is_func(x) == (func == is_functions[library])
|
|
30
|
+
|
|
31
|
+
assert is_array_api_obj(x)
|
|
32
|
+
|
|
33
|
+
@pytest.mark.parametrize("library", all_libraries)
|
|
34
|
+
def test_device(library):
|
|
35
|
+
xp = import_(library, wrapper=True)
|
|
36
|
+
|
|
37
|
+
# We can't test much for device() and to_device() other than that
|
|
38
|
+
# x.to_device(x.device) works.
|
|
39
|
+
|
|
40
|
+
x = xp.asarray([1, 2, 3])
|
|
41
|
+
dev = device(x)
|
|
42
|
+
|
|
43
|
+
x2 = to_device(x, dev)
|
|
44
|
+
assert device(x) == device(x2)
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
@pytest.mark.parametrize("library", wrapped_libraries)
|
|
48
|
+
def test_to_device_host(library):
|
|
49
|
+
# different libraries have different semantics
|
|
50
|
+
# for DtoH transfers; ensure that we support a portable
|
|
51
|
+
# shim for common array libs
|
|
52
|
+
# see: https://github.com/scipy/scipy/issues/18286#issuecomment-1527552919
|
|
53
|
+
xp = import_(library, wrapper=True)
|
|
54
|
+
|
|
55
|
+
expected = np.array([1, 2, 3])
|
|
56
|
+
x = xp.asarray([1, 2, 3])
|
|
57
|
+
x = to_device(x, "cpu")
|
|
58
|
+
# torch will return a genuine Device object, but
|
|
59
|
+
# the other libs will do something different with
|
|
60
|
+
# a `device(x)` query; however, what's really important
|
|
61
|
+
# here is that we can test portably after calling
|
|
62
|
+
# to_device(x, "cpu") to return to host
|
|
63
|
+
assert_allclose(x, expected)
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
@pytest.mark.parametrize("target_library", is_functions.keys())
|
|
67
|
+
@pytest.mark.parametrize("source_library", is_functions.keys())
|
|
68
|
+
def test_asarray_cross_library(source_library, target_library, request):
|
|
69
|
+
if source_library == "dask.array" and target_library == "torch":
|
|
70
|
+
# Allow rest of test to execute instead of immediately xfailing
|
|
71
|
+
# xref https://github.com/pandas-dev/pandas/issues/38902
|
|
72
|
+
|
|
73
|
+
# TODO: remove xfail once
|
|
74
|
+
# https://github.com/dask/dask/issues/8260 is resolved
|
|
75
|
+
request.node.add_marker(pytest.mark.xfail(reason="Bug in dask raising error on conversion"))
|
|
76
|
+
if source_library == "cupy" and target_library != "cupy":
|
|
77
|
+
# cupy explicitly disallows implicit conversions to CPU
|
|
78
|
+
pytest.skip(reason="cupy does not support implicit conversion to CPU")
|
|
79
|
+
src_lib = import_(source_library, wrapper=True)
|
|
80
|
+
tgt_lib = import_(target_library, wrapper=True)
|
|
81
|
+
is_tgt_type = globals()[is_functions[target_library]]
|
|
82
|
+
|
|
83
|
+
a = src_lib.asarray([1, 2, 3])
|
|
84
|
+
b = tgt_lib.asarray(a)
|
|
85
|
+
|
|
86
|
+
assert is_tgt_type(b), f"Expected {b} to be a {tgt_lib.ndarray}, but was {type(b)}"
|
|
87
|
+
|
|
88
|
+
@pytest.mark.parametrize("library", wrapped_libraries)
|
|
89
|
+
def test_asarray_copy(library):
|
|
90
|
+
# Note, we have this test here because the test suite currently doesn't
|
|
91
|
+
# test the copy flag to asarray() very rigorously. Once
|
|
92
|
+
# https://github.com/data-apis/array-api-tests/issues/241 is fixed we
|
|
93
|
+
# should be able to delete this.
|
|
94
|
+
xp = import_(library, wrapper=True)
|
|
95
|
+
asarray = xp.asarray
|
|
96
|
+
is_lib_func = globals()[is_functions[library]]
|
|
97
|
+
all = xp.all if library != 'dask.array' else lambda x: xp.all(x).compute()
|
|
98
|
+
|
|
99
|
+
if library == 'numpy' and xp.__version__[0] < '2' and not hasattr(xp, '_CopyMode') :
|
|
100
|
+
supports_copy_false = False
|
|
101
|
+
elif library in ['cupy', 'dask.array']:
|
|
102
|
+
supports_copy_false = False
|
|
103
|
+
else:
|
|
104
|
+
supports_copy_false = True
|
|
105
|
+
|
|
106
|
+
a = asarray([1])
|
|
107
|
+
b = asarray(a, copy=True)
|
|
108
|
+
assert is_lib_func(b)
|
|
109
|
+
a[0] = 0
|
|
110
|
+
assert all(b[0] == 1)
|
|
111
|
+
assert all(a[0] == 0)
|
|
112
|
+
|
|
113
|
+
a = asarray([1])
|
|
114
|
+
if supports_copy_false:
|
|
115
|
+
b = asarray(a, copy=False)
|
|
116
|
+
assert is_lib_func(b)
|
|
117
|
+
a[0] = 0
|
|
118
|
+
assert all(b[0] == 0)
|
|
119
|
+
else:
|
|
120
|
+
pytest.raises(NotImplementedError, lambda: asarray(a, copy=False))
|
|
121
|
+
|
|
122
|
+
a = asarray([1])
|
|
123
|
+
if supports_copy_false:
|
|
124
|
+
pytest.raises(ValueError, lambda: asarray(a, copy=False,
|
|
125
|
+
dtype=xp.float64))
|
|
126
|
+
else:
|
|
127
|
+
pytest.raises(NotImplementedError, lambda: asarray(a, copy=False, dtype=xp.float64))
|
|
128
|
+
|
|
129
|
+
a = asarray([1])
|
|
130
|
+
b = asarray(a, copy=None)
|
|
131
|
+
assert is_lib_func(b)
|
|
132
|
+
a[0] = 0
|
|
133
|
+
assert all(b[0] == 0)
|
|
134
|
+
|
|
135
|
+
a = asarray([1.0], dtype=xp.float32)
|
|
136
|
+
assert a.dtype == xp.float32
|
|
137
|
+
b = asarray(a, dtype=xp.float64, copy=None)
|
|
138
|
+
assert is_lib_func(b)
|
|
139
|
+
assert b.dtype == xp.float64
|
|
140
|
+
a[0] = 0.0
|
|
141
|
+
assert all(b[0] == 1.0)
|
|
142
|
+
|
|
143
|
+
a = asarray([1.0], dtype=xp.float64)
|
|
144
|
+
assert a.dtype == xp.float64
|
|
145
|
+
b = asarray(a, dtype=xp.float64, copy=None)
|
|
146
|
+
assert is_lib_func(b)
|
|
147
|
+
assert b.dtype == xp.float64
|
|
148
|
+
a[0] = 0.0
|
|
149
|
+
assert all(b[0] == 0.0)
|
|
150
|
+
|
|
151
|
+
# Python built-in types
|
|
152
|
+
for obj in [True, 0, 0.0, 0j, [0], [[0]]]:
|
|
153
|
+
asarray(obj, copy=True) # No error
|
|
154
|
+
asarray(obj, copy=None) # No error
|
|
155
|
+
if supports_copy_false:
|
|
156
|
+
pytest.raises(ValueError, lambda: asarray(obj, copy=False))
|
|
157
|
+
else:
|
|
158
|
+
pytest.raises(NotImplementedError, lambda: asarray(obj, copy=False))
|
|
159
|
+
|
|
160
|
+
# Use the standard library array to test the buffer protocol
|
|
161
|
+
a = array.array('f', [1.0])
|
|
162
|
+
b = asarray(a, copy=True)
|
|
163
|
+
assert is_lib_func(b)
|
|
164
|
+
a[0] = 0.0
|
|
165
|
+
assert all(b[0] == 1.0)
|
|
166
|
+
|
|
167
|
+
a = array.array('f', [1.0])
|
|
168
|
+
if supports_copy_false:
|
|
169
|
+
b = asarray(a, copy=False)
|
|
170
|
+
assert is_lib_func(b)
|
|
171
|
+
a[0] = 0.0
|
|
172
|
+
assert all(b[0] == 0.0)
|
|
173
|
+
else:
|
|
174
|
+
pytest.raises(NotImplementedError, lambda: asarray(a, copy=False))
|
|
175
|
+
|
|
176
|
+
a = array.array('f', [1.0])
|
|
177
|
+
b = asarray(a, copy=None)
|
|
178
|
+
assert is_lib_func(b)
|
|
179
|
+
a[0] = 0.0
|
|
180
|
+
if library == 'cupy':
|
|
181
|
+
# A copy is required for libraries where the default device is not CPU
|
|
182
|
+
assert all(b[0] == 1.0)
|
|
183
|
+
else:
|
|
184
|
+
assert all(b[0] == 0.0)
|
|
@@ -5,7 +5,7 @@ non-spec dtypes
|
|
|
5
5
|
|
|
6
6
|
import pytest
|
|
7
7
|
|
|
8
|
-
from ._helpers import import_
|
|
8
|
+
from ._helpers import import_, wrapped_libraries
|
|
9
9
|
|
|
10
10
|
# Check the known dtypes by their string names
|
|
11
11
|
|
|
@@ -64,7 +64,7 @@ def isdtype_(dtype_, kind):
|
|
|
64
64
|
assert type(res) is bool # noqa: E721
|
|
65
65
|
return res
|
|
66
66
|
|
|
67
|
-
@pytest.mark.parametrize("library",
|
|
67
|
+
@pytest.mark.parametrize("library", wrapped_libraries)
|
|
68
68
|
def test_isdtype_spec_dtypes(library):
|
|
69
69
|
xp = import_(library, wrapper=True)
|
|
70
70
|
|
|
@@ -98,7 +98,7 @@ additional_dtypes = [
|
|
|
98
98
|
'bfloat16',
|
|
99
99
|
]
|
|
100
100
|
|
|
101
|
-
@pytest.mark.parametrize("library",
|
|
101
|
+
@pytest.mark.parametrize("library", wrapped_libraries)
|
|
102
102
|
@pytest.mark.parametrize("dtype_", additional_dtypes)
|
|
103
103
|
def test_isdtype_additional_dtypes(library, dtype_):
|
|
104
104
|
xp = import_(library, wrapper=True)
|
|
@@ -0,0 +1,73 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Test that array_api_compat has no "hard" dependencies.
|
|
3
|
+
|
|
4
|
+
Libraries like NumPy should only be imported if a numpy array is passed to
|
|
5
|
+
array_namespace or if array_api_compat.numpy is explicitly imported.
|
|
6
|
+
|
|
7
|
+
We have to test this in a subprocess because these libraries have already been
|
|
8
|
+
imported from the other tests.
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
import sys
|
|
12
|
+
import subprocess
|
|
13
|
+
|
|
14
|
+
import pytest
|
|
15
|
+
|
|
16
|
+
class Array:
|
|
17
|
+
# Dummy array namespace that doesn't depend on any array library
|
|
18
|
+
def __array_namespace__(self, api_version=None):
|
|
19
|
+
class Namespace:
|
|
20
|
+
pass
|
|
21
|
+
return Namespace()
|
|
22
|
+
|
|
23
|
+
def _test_dependency(mod):
|
|
24
|
+
assert mod not in sys.modules
|
|
25
|
+
|
|
26
|
+
# Run various functions that shouldn't depend on mod and check that they
|
|
27
|
+
# don't import it.
|
|
28
|
+
|
|
29
|
+
import array_api_compat
|
|
30
|
+
assert mod not in sys.modules
|
|
31
|
+
|
|
32
|
+
a = Array()
|
|
33
|
+
|
|
34
|
+
# array-api-strict is an example of an array API library that isn't
|
|
35
|
+
# wrapped by array-api-compat.
|
|
36
|
+
if "strict" not in mod:
|
|
37
|
+
is_mod_array = getattr(array_api_compat, f"is_{mod.split('.')[0]}_array")
|
|
38
|
+
assert not is_mod_array(a)
|
|
39
|
+
assert mod not in sys.modules
|
|
40
|
+
|
|
41
|
+
is_array_api_obj = getattr(array_api_compat, "is_array_api_obj")
|
|
42
|
+
assert is_array_api_obj(a)
|
|
43
|
+
assert mod not in sys.modules
|
|
44
|
+
|
|
45
|
+
array_namespace = getattr(array_api_compat, "array_namespace")
|
|
46
|
+
array_namespace(Array())
|
|
47
|
+
assert mod not in sys.modules
|
|
48
|
+
|
|
49
|
+
# TODO: Test that wrapper for library X doesn't depend on wrappers for library
|
|
50
|
+
# Y (except most array libraries actually do themselves depend on numpy).
|
|
51
|
+
|
|
52
|
+
@pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask.array",
|
|
53
|
+
"jax.numpy", "array_api_strict"])
|
|
54
|
+
def test_numpy_dependency(library):
|
|
55
|
+
# This import is here because it imports numpy
|
|
56
|
+
from ._helpers import import_
|
|
57
|
+
|
|
58
|
+
# This unfortunately won't go through any of the pytest machinery. We
|
|
59
|
+
# reraise the exception as an AssertionError so that pytest will show it
|
|
60
|
+
# in a semi-reasonable way
|
|
61
|
+
|
|
62
|
+
# Import (in this process) to make sure 'library' is actually installed and
|
|
63
|
+
# so that cupy can be skipped.
|
|
64
|
+
import_(library)
|
|
65
|
+
|
|
66
|
+
try:
|
|
67
|
+
subprocess.run([sys.executable, '-c', f'''\
|
|
68
|
+
from tests.test_no_dependencies import _test_dependency
|
|
69
|
+
|
|
70
|
+
_test_dependency({library!r})'''], check=True, capture_output=True, encoding='utf-8')
|
|
71
|
+
except subprocess.CalledProcessError as e:
|
|
72
|
+
print(e.stdout, end='')
|
|
73
|
+
raise AssertionError(e.stderr)
|
|
@@ -1,85 +0,0 @@
|
|
|
1
|
-
from array_api_compat import (is_numpy_array, is_cupy_array, is_torch_array, # noqa: F401
|
|
2
|
-
is_dask_array, is_jax_array)
|
|
3
|
-
|
|
4
|
-
from array_api_compat import is_array_api_obj, device, to_device
|
|
5
|
-
|
|
6
|
-
from ._helpers import import_
|
|
7
|
-
|
|
8
|
-
import pytest
|
|
9
|
-
import numpy as np
|
|
10
|
-
from numpy.testing import assert_allclose
|
|
11
|
-
|
|
12
|
-
is_functions = {
|
|
13
|
-
'numpy': 'is_numpy_array',
|
|
14
|
-
'cupy': 'is_cupy_array',
|
|
15
|
-
'torch': 'is_torch_array',
|
|
16
|
-
'dask.array': 'is_dask_array',
|
|
17
|
-
'jax.numpy': 'is_jax_array',
|
|
18
|
-
}
|
|
19
|
-
|
|
20
|
-
@pytest.mark.parametrize('library', is_functions.keys())
|
|
21
|
-
@pytest.mark.parametrize('func', is_functions.values())
|
|
22
|
-
def test_is_xp_array(library, func):
|
|
23
|
-
lib = import_(library)
|
|
24
|
-
is_func = globals()[func]
|
|
25
|
-
|
|
26
|
-
x = lib.asarray([1, 2, 3])
|
|
27
|
-
|
|
28
|
-
assert is_func(x) == (func == is_functions[library])
|
|
29
|
-
|
|
30
|
-
assert is_array_api_obj(x)
|
|
31
|
-
|
|
32
|
-
@pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask.array", "jax.numpy"])
|
|
33
|
-
def test_device(library):
|
|
34
|
-
xp = import_(library, wrapper=True)
|
|
35
|
-
|
|
36
|
-
# We can't test much for device() and to_device() other than that
|
|
37
|
-
# x.to_device(x.device) works.
|
|
38
|
-
|
|
39
|
-
x = xp.asarray([1, 2, 3])
|
|
40
|
-
dev = device(x)
|
|
41
|
-
|
|
42
|
-
x2 = to_device(x, dev)
|
|
43
|
-
assert device(x) == device(x2)
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
@pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask.array"])
|
|
47
|
-
def test_to_device_host(library):
|
|
48
|
-
# different libraries have different semantics
|
|
49
|
-
# for DtoH transfers; ensure that we support a portable
|
|
50
|
-
# shim for common array libs
|
|
51
|
-
# see: https://github.com/scipy/scipy/issues/18286#issuecomment-1527552919
|
|
52
|
-
xp = import_(library, wrapper=True)
|
|
53
|
-
|
|
54
|
-
expected = np.array([1, 2, 3])
|
|
55
|
-
x = xp.asarray([1, 2, 3])
|
|
56
|
-
x = to_device(x, "cpu")
|
|
57
|
-
# torch will return a genuine Device object, but
|
|
58
|
-
# the other libs will do something different with
|
|
59
|
-
# a `device(x)` query; however, what's really important
|
|
60
|
-
# here is that we can test portably after calling
|
|
61
|
-
# to_device(x, "cpu") to return to host
|
|
62
|
-
assert_allclose(x, expected)
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
@pytest.mark.parametrize("target_library", is_functions.keys())
|
|
66
|
-
@pytest.mark.parametrize("source_library", is_functions.keys())
|
|
67
|
-
def test_asarray(source_library, target_library, request):
|
|
68
|
-
if source_library == "dask.array" and target_library == "torch":
|
|
69
|
-
# Allow rest of test to execute instead of immediately xfailing
|
|
70
|
-
# xref https://github.com/pandas-dev/pandas/issues/38902
|
|
71
|
-
|
|
72
|
-
# TODO: remove xfail once
|
|
73
|
-
# https://github.com/dask/dask/issues/8260 is resolved
|
|
74
|
-
request.node.add_marker(pytest.mark.xfail(reason="Bug in dask raising error on conversion"))
|
|
75
|
-
if source_library == "cupy" and target_library != "cupy":
|
|
76
|
-
# cupy explicitly disallows implicit conversions to CPU
|
|
77
|
-
pytest.skip(reason="cupy does not support implicit conversion to CPU")
|
|
78
|
-
src_lib = import_(source_library, wrapper=True)
|
|
79
|
-
tgt_lib = import_(target_library, wrapper=True)
|
|
80
|
-
is_tgt_type = globals()[is_functions[target_library]]
|
|
81
|
-
|
|
82
|
-
a = src_lib.asarray([1, 2, 3])
|
|
83
|
-
b = tgt_lib.asarray(a)
|
|
84
|
-
|
|
85
|
-
assert is_tgt_type(b), f"Expected {b} to be a {tgt_lib.ndarray}, but was {type(b)}"
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{array_api_compat-1.5.1 → array_api_compat-1.6}/array_api_compat.egg-info/dependency_links.txt
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|