array-api-compat 1.5.1__tar.gz → 1.6__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. {array_api_compat-1.5.1 → array_api_compat-1.6}/PKG-INFO +7 -3
  2. {array_api_compat-1.5.1 → array_api_compat-1.6}/array_api_compat/__init__.py +1 -1
  3. {array_api_compat-1.5.1 → array_api_compat-1.6}/array_api_compat/common/_aliases.py +4 -88
  4. {array_api_compat-1.5.1 → array_api_compat-1.6}/array_api_compat/common/_helpers.py +40 -14
  5. {array_api_compat-1.5.1 → array_api_compat-1.6}/array_api_compat/cupy/_aliases.py +51 -6
  6. {array_api_compat-1.5.1 → array_api_compat-1.6}/array_api_compat/dask/array/_aliases.py +42 -5
  7. {array_api_compat-1.5.1 → array_api_compat-1.6}/array_api_compat/numpy/__init__.py +6 -0
  8. {array_api_compat-1.5.1 → array_api_compat-1.6}/array_api_compat/numpy/_aliases.py +56 -6
  9. {array_api_compat-1.5.1 → array_api_compat-1.6}/array_api_compat.egg-info/PKG-INFO +7 -3
  10. {array_api_compat-1.5.1 → array_api_compat-1.6}/array_api_compat.egg-info/SOURCES.txt +1 -0
  11. array_api_compat-1.6/array_api_compat.egg-info/requires.txt +15 -0
  12. {array_api_compat-1.5.1 → array_api_compat-1.6}/setup.py +3 -2
  13. {array_api_compat-1.5.1 → array_api_compat-1.6}/tests/test_all.py +2 -2
  14. {array_api_compat-1.5.1 → array_api_compat-1.6}/tests/test_array_namespace.py +21 -16
  15. array_api_compat-1.6/tests/test_common.py +184 -0
  16. {array_api_compat-1.5.1 → array_api_compat-1.6}/tests/test_isdtype.py +3 -3
  17. array_api_compat-1.6/tests/test_no_dependencies.py +73 -0
  18. array_api_compat-1.5.1/array_api_compat.egg-info/requires.txt +0 -6
  19. array_api_compat-1.5.1/tests/test_common.py +0 -85
  20. {array_api_compat-1.5.1 → array_api_compat-1.6}/LICENSE +0 -0
  21. {array_api_compat-1.5.1 → array_api_compat-1.6}/README.md +0 -0
  22. {array_api_compat-1.5.1 → array_api_compat-1.6}/array_api_compat/_internal.py +0 -0
  23. {array_api_compat-1.5.1 → array_api_compat-1.6}/array_api_compat/common/__init__.py +0 -0
  24. {array_api_compat-1.5.1 → array_api_compat-1.6}/array_api_compat/common/_fft.py +0 -0
  25. {array_api_compat-1.5.1 → array_api_compat-1.6}/array_api_compat/common/_linalg.py +0 -0
  26. {array_api_compat-1.5.1 → array_api_compat-1.6}/array_api_compat/common/_typing.py +0 -0
  27. {array_api_compat-1.5.1 → array_api_compat-1.6}/array_api_compat/cupy/__init__.py +0 -0
  28. {array_api_compat-1.5.1 → array_api_compat-1.6}/array_api_compat/cupy/_typing.py +0 -0
  29. {array_api_compat-1.5.1 → array_api_compat-1.6}/array_api_compat/cupy/fft.py +0 -0
  30. {array_api_compat-1.5.1 → array_api_compat-1.6}/array_api_compat/cupy/linalg.py +0 -0
  31. {array_api_compat-1.5.1 → array_api_compat-1.6}/array_api_compat/dask/__init__.py +0 -0
  32. {array_api_compat-1.5.1 → array_api_compat-1.6}/array_api_compat/dask/array/__init__.py +0 -0
  33. {array_api_compat-1.5.1 → array_api_compat-1.6}/array_api_compat/dask/array/linalg.py +0 -0
  34. {array_api_compat-1.5.1 → array_api_compat-1.6}/array_api_compat/numpy/_typing.py +0 -0
  35. {array_api_compat-1.5.1 → array_api_compat-1.6}/array_api_compat/numpy/fft.py +0 -0
  36. {array_api_compat-1.5.1 → array_api_compat-1.6}/array_api_compat/numpy/linalg.py +0 -0
  37. {array_api_compat-1.5.1 → array_api_compat-1.6}/array_api_compat/torch/__init__.py +0 -0
  38. {array_api_compat-1.5.1 → array_api_compat-1.6}/array_api_compat/torch/_aliases.py +0 -0
  39. {array_api_compat-1.5.1 → array_api_compat-1.6}/array_api_compat/torch/fft.py +0 -0
  40. {array_api_compat-1.5.1 → array_api_compat-1.6}/array_api_compat/torch/linalg.py +0 -0
  41. {array_api_compat-1.5.1 → array_api_compat-1.6}/array_api_compat.egg-info/dependency_links.txt +0 -0
  42. {array_api_compat-1.5.1 → array_api_compat-1.6}/array_api_compat.egg-info/top_level.txt +0 -0
  43. {array_api_compat-1.5.1 → array_api_compat-1.6}/setup.cfg +0 -0
  44. {array_api_compat-1.5.1 → array_api_compat-1.6}/tests/test_vendoring.py +0 -0
@@ -1,24 +1,28 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: array_api_compat
3
- Version: 1.5.1
3
+ Version: 1.6
4
4
  Summary: A wrapper around NumPy and other array libraries to make them compatible with the Array API standard
5
5
  Home-page: https://data-apis.org/array-api-compat/
6
6
  Author: Consortium for Python Data API Standards
7
7
  License: MIT
8
8
  Classifier: Programming Language :: Python :: 3
9
- Classifier: Programming Language :: Python :: 3.8
10
9
  Classifier: Programming Language :: Python :: 3.9
11
10
  Classifier: Programming Language :: Python :: 3.10
12
11
  Classifier: Programming Language :: Python :: 3.11
13
12
  Classifier: License :: OSI Approved :: MIT License
14
13
  Classifier: Operating System :: OS Independent
15
- Requires-Python: >=3.8
16
14
  Description-Content-Type: text/markdown
17
15
  License-File: LICENSE
18
16
  Provides-Extra: numpy
19
17
  Requires-Dist: numpy; extra == "numpy"
20
18
  Provides-Extra: cupy
21
19
  Requires-Dist: cupy; extra == "cupy"
20
+ Provides-Extra: jax
21
+ Requires-Dist: jax; extra == "jax"
22
+ Provides-Extra: pytorch
23
+ Requires-Dist: pytorch; extra == "pytorch"
24
+ Provides-Extra: dask
25
+ Requires-Dist: dask; extra == "dask"
22
26
 
23
27
  # Array API compatibility library
24
28
 
@@ -17,6 +17,6 @@ to ensure they are not using functionality outside of the standard, but prefer
17
17
  this implementation for the default when working with NumPy arrays.
18
18
 
19
19
  """
20
- __version__ = '1.5.1'
20
+ __version__ = '1.6'
21
21
 
22
22
  from .common import * # noqa: F401, F403
@@ -6,18 +6,18 @@ from __future__ import annotations
6
6
 
7
7
  from typing import TYPE_CHECKING
8
8
  if TYPE_CHECKING:
9
- import numpy as np
10
9
  from typing import Optional, Sequence, Tuple, Union
11
- from ._typing import ndarray, Device, Dtype, NestedSequence, SupportsBufferProtocol
10
+ from ._typing import ndarray, Device, Dtype
12
11
 
13
12
  from typing import NamedTuple
14
- from types import ModuleType
15
13
  import inspect
16
14
 
17
- from ._helpers import _check_device, is_numpy_array, array_namespace
15
+ from ._helpers import _check_device
18
16
 
19
17
  # These functions are modified from the NumPy versions.
20
18
 
19
+ # Creation functions add the device keyword (which does nothing for NumPy)
20
+
21
21
  def arange(
22
22
  start: Union[int, float],
23
23
  /,
@@ -268,90 +268,6 @@ def var(
268
268
  def permute_dims(x: ndarray, /, axes: Tuple[int, ...], xp) -> ndarray:
269
269
  return xp.transpose(x, axes)
270
270
 
271
- # Creation functions add the device keyword (which does nothing for NumPy)
272
-
273
- # asarray also adds the copy keyword
274
- def _asarray(
275
- obj: Union[
276
- ndarray,
277
- bool,
278
- int,
279
- float,
280
- NestedSequence[bool | int | float],
281
- SupportsBufferProtocol,
282
- ],
283
- /,
284
- *,
285
- dtype: Optional[Dtype] = None,
286
- device: Optional[Device] = None,
287
- copy: "Optional[Union[bool, np._CopyMode]]" = None,
288
- namespace = None,
289
- **kwargs,
290
- ) -> ndarray:
291
- """
292
- Array API compatibility wrapper for asarray().
293
-
294
- See the corresponding documentation in NumPy/CuPy and/or the array API
295
- specification for more details.
296
-
297
- """
298
- if namespace is None:
299
- try:
300
- xp = array_namespace(obj, _use_compat=False)
301
- except ValueError:
302
- # TODO: What about lists of arrays?
303
- raise ValueError("A namespace must be specified for asarray() with non-array input")
304
- elif isinstance(namespace, ModuleType):
305
- xp = namespace
306
- elif namespace == 'numpy':
307
- import numpy as xp
308
- elif namespace == 'cupy':
309
- import cupy as xp
310
- elif namespace == 'dask.array':
311
- import dask.array as xp
312
- else:
313
- raise ValueError("Unrecognized namespace argument to asarray()")
314
-
315
- _check_device(xp, device)
316
- if is_numpy_array(obj):
317
- import numpy as np
318
- if hasattr(np, '_CopyMode'):
319
- # Not present in older NumPys
320
- COPY_FALSE = (False, np._CopyMode.IF_NEEDED)
321
- COPY_TRUE = (True, np._CopyMode.ALWAYS)
322
- else:
323
- COPY_FALSE = (False,)
324
- COPY_TRUE = (True,)
325
- else:
326
- COPY_FALSE = (False,)
327
- COPY_TRUE = (True,)
328
- if copy in COPY_FALSE and namespace != "dask.array":
329
- # copy=False is not yet implemented in xp.asarray
330
- raise NotImplementedError("copy=False is not yet implemented")
331
- if (hasattr(xp, "ndarray") and isinstance(obj, xp.ndarray)):
332
- if dtype is not None and obj.dtype != dtype:
333
- copy = True
334
- if copy in COPY_TRUE:
335
- return xp.array(obj, copy=True, dtype=dtype)
336
- return obj
337
- elif namespace == "dask.array":
338
- if copy in COPY_TRUE:
339
- if dtype is None:
340
- return obj.copy()
341
- # Go through numpy, since dask copy is no-op by default
342
- import numpy as np
343
- obj = np.array(obj, dtype=dtype, copy=True)
344
- return xp.array(obj, dtype=dtype)
345
- else:
346
- import dask.array as da
347
- import numpy as np
348
- if not isinstance(obj, da.Array):
349
- obj = np.asarray(obj, dtype=dtype)
350
- return da.from_array(obj)
351
- return obj
352
-
353
- return xp.asarray(obj, dtype=dtype, **kwargs)
354
-
355
271
  # np.reshape calls the keyword argument 'newshape' instead of 'shape'
356
272
  def reshape(x: ndarray,
357
273
  /,
@@ -178,7 +178,7 @@ def _check_api_version(api_version):
178
178
  elif api_version is not None and api_version != '2022.12':
179
179
  raise ValueError("Only the 2022.12 version of the array API specification is currently supported")
180
180
 
181
- def array_namespace(*xs, api_version=None, _use_compat=True):
181
+ def array_namespace(*xs, api_version=None, use_compat=None):
182
182
  """
183
183
  Get the array API compatible namespace for the arrays `xs`.
184
184
 
@@ -191,6 +191,12 @@ def array_namespace(*xs, api_version=None, _use_compat=True):
191
191
  The newest version of the spec that you need support for (currently
192
192
  the compat library wrapped APIs support v2022.12).
193
193
 
194
+ use_compat: bool or None
195
+ If None (the default), the native namespace will be returned if it is
196
+ already array API compatible, otherwise a compat wrapper is used. If
197
+ True, the compat library wrapped library will be returned. If False,
198
+ the native library namespace is returned.
199
+
194
200
  Returns
195
201
  -------
196
202
 
@@ -234,46 +240,66 @@ def array_namespace(*xs, api_version=None, _use_compat=True):
234
240
  is_jax_array
235
241
 
236
242
  """
243
+ if use_compat not in [None, True, False]:
244
+ raise ValueError("use_compat must be None, True, or False")
245
+
246
+ _use_compat = use_compat in [None, True]
247
+
237
248
  namespaces = set()
238
249
  for x in xs:
239
250
  if is_numpy_array(x):
240
- _check_api_version(api_version)
241
- if _use_compat:
242
- from .. import numpy as numpy_namespace
251
+ from .. import numpy as numpy_namespace
252
+ import numpy as np
253
+ if use_compat is True:
254
+ _check_api_version(api_version)
243
255
  namespaces.add(numpy_namespace)
244
- else:
245
- import numpy as np
256
+ elif use_compat is False:
246
257
  namespaces.add(np)
258
+ else:
259
+ # numpy 2.0 has __array_namespace__ and is fully array API
260
+ # compatible.
261
+ if hasattr(x, '__array_namespace__'):
262
+ namespaces.add(x.__array_namespace__(api_version=api_version))
263
+ else:
264
+ namespaces.add(numpy_namespace)
247
265
  elif is_cupy_array(x):
248
- _check_api_version(api_version)
249
266
  if _use_compat:
267
+ _check_api_version(api_version)
250
268
  from .. import cupy as cupy_namespace
251
269
  namespaces.add(cupy_namespace)
252
270
  else:
253
271
  import cupy as cp
254
272
  namespaces.add(cp)
255
273
  elif is_torch_array(x):
256
- _check_api_version(api_version)
257
274
  if _use_compat:
275
+ _check_api_version(api_version)
258
276
  from .. import torch as torch_namespace
259
277
  namespaces.add(torch_namespace)
260
278
  else:
261
279
  import torch
262
280
  namespaces.add(torch)
263
281
  elif is_dask_array(x):
264
- _check_api_version(api_version)
265
282
  if _use_compat:
283
+ _check_api_version(api_version)
266
284
  from ..dask import array as dask_namespace
267
285
  namespaces.add(dask_namespace)
268
286
  else:
269
- raise TypeError("_use_compat cannot be False if input array is a dask array!")
287
+ import dask.array as da
288
+ namespaces.add(da)
270
289
  elif is_jax_array(x):
271
- _check_api_version(api_version)
272
- # jax.experimental.array_api is already an array namespace. We do
273
- # not have a wrapper submodule for it.
274
- import jax.experimental.array_api as jnp
290
+ if use_compat is True:
291
+ _check_api_version(api_version)
292
+ raise ValueError("JAX does not have an array-api-compat wrapper")
293
+ elif use_compat is False:
294
+ import jax.numpy as jnp
295
+ else:
296
+ # jax.experimental.array_api is already an array namespace. We do
297
+ # not have a wrapper submodule for it.
298
+ import jax.experimental.array_api as jnp
275
299
  namespaces.add(jnp)
276
300
  elif hasattr(x, '__array_namespace__'):
301
+ if use_compat is True:
302
+ raise ValueError("The given array does not have an array-api-compat wrapper")
277
303
  namespaces.add(x.__array_namespace__(api_version=api_version))
278
304
  else:
279
305
  # TODO: Support Python scalars?
@@ -1,15 +1,14 @@
1
1
  from __future__ import annotations
2
2
 
3
- from functools import partial
4
-
5
3
  import cupy as cp
6
4
 
7
5
  from ..common import _aliases
8
6
  from .._internal import get_xp
9
7
 
10
- asarray = asarray_cupy = partial(_aliases._asarray, namespace='cupy')
11
- asarray.__doc__ = _aliases._asarray.__doc__
12
- del partial
8
+ from typing import TYPE_CHECKING
9
+ if TYPE_CHECKING:
10
+ from typing import Optional, Union
11
+ from ._typing import ndarray, Device, Dtype, NestedSequence, SupportsBufferProtocol
13
12
 
14
13
  bool = cp.bool_
15
14
 
@@ -62,6 +61,52 @@ matmul = get_xp(cp)(_aliases.matmul)
62
61
  matrix_transpose = get_xp(cp)(_aliases.matrix_transpose)
63
62
  tensordot = get_xp(cp)(_aliases.tensordot)
64
63
 
64
+ _copy_default = object()
65
+
66
+ # asarray also adds the copy keyword, which is not present in numpy 1.0.
67
+ def asarray(
68
+ obj: Union[
69
+ ndarray,
70
+ bool,
71
+ int,
72
+ float,
73
+ NestedSequence[bool | int | float],
74
+ SupportsBufferProtocol,
75
+ ],
76
+ /,
77
+ *,
78
+ dtype: Optional[Dtype] = None,
79
+ device: Optional[Device] = None,
80
+ copy: Optional[bool] = _copy_default,
81
+ **kwargs,
82
+ ) -> ndarray:
83
+ """
84
+ Array API compatibility wrapper for asarray().
85
+
86
+ See the corresponding documentation in the array library and/or the array API
87
+ specification for more details.
88
+ """
89
+ with cp.cuda.Device(device):
90
+ # cupy is like NumPy 1.26 (except without _CopyMode). See the comments
91
+ # in asarray in numpy/_aliases.py.
92
+ if copy is not _copy_default:
93
+ # A future version of CuPy will change the meaning of copy=False
94
+ # to mean no-copy. We don't know for certain what version it will
95
+ # be yet, so to avoid breaking that version, we use a different
96
+ # default value for copy so asarray(obj) with no copy kwarg will
97
+ # always do the copy-if-needed behavior.
98
+
99
+ # This will still need to be updated to remove the
100
+ # NotImplementedError for copy=False, but at least this won't
101
+ # break the default or existing behavior.
102
+ if copy is None:
103
+ copy = False
104
+ elif copy is False:
105
+ raise NotImplementedError("asarray(copy=False) is not yet supported in cupy")
106
+ kwargs['copy'] = copy
107
+
108
+ return cp.array(obj, dtype=dtype, **kwargs)
109
+
65
110
  # These functions are completely new here. If the library already has them
66
111
  # (i.e., numpy 2.0), use the library version instead of our wrapper.
67
112
  if hasattr(cp, 'vecdot'):
@@ -73,7 +118,7 @@ if hasattr(cp, 'isdtype'):
73
118
  else:
74
119
  isdtype = get_xp(cp)(_aliases.isdtype)
75
120
 
76
- __all__ = _aliases.__all__ + ['asarray', 'asarray_cupy', 'bool', 'acos',
121
+ __all__ = _aliases.__all__ + ['asarray', 'bool', 'acos',
77
122
  'acosh', 'asin', 'asinh', 'atan', 'atan2',
78
123
  'atanh', 'bitwise_left_shift', 'bitwise_invert',
79
124
  'bitwise_right_shift', 'concat', 'pow']
@@ -37,7 +37,7 @@ from typing import TYPE_CHECKING
37
37
  if TYPE_CHECKING:
38
38
  from typing import Optional, Union
39
39
 
40
- from ...common._typing import Device, Dtype, Array
40
+ from ...common._typing import Device, Dtype, Array, NestedSequence, SupportsBufferProtocol
41
41
 
42
42
  import dask.array as da
43
43
 
@@ -76,10 +76,6 @@ def _dask_arange(
76
76
  arange = get_xp(da)(_dask_arange)
77
77
  eye = get_xp(da)(_aliases.eye)
78
78
 
79
- from functools import partial
80
- asarray = partial(_aliases._asarray, namespace='dask.array')
81
- asarray.__doc__ = _aliases._asarray.__doc__
82
-
83
79
  linspace = get_xp(da)(_aliases.linspace)
84
80
  eye = get_xp(da)(_aliases.eye)
85
81
  UniqueAllResult = get_xp(da)(_aliases.UniqueAllResult)
@@ -113,6 +109,47 @@ trunc = get_xp(np)(_aliases.trunc)
113
109
  matmul = get_xp(np)(_aliases.matmul)
114
110
  tensordot = get_xp(np)(_aliases.tensordot)
115
111
 
112
+
113
+ # asarray also adds the copy keyword, which is not present in numpy 1.0.
114
+ def asarray(
115
+ obj: Union[
116
+ Array,
117
+ bool,
118
+ int,
119
+ float,
120
+ NestedSequence[bool | int | float],
121
+ SupportsBufferProtocol,
122
+ ],
123
+ /,
124
+ *,
125
+ dtype: Optional[Dtype] = None,
126
+ device: Optional[Device] = None,
127
+ copy: "Optional[Union[bool, np._CopyMode]]" = None,
128
+ **kwargs,
129
+ ) -> Array:
130
+ """
131
+ Array API compatibility wrapper for asarray().
132
+
133
+ See the corresponding documentation in the array library and/or the array API
134
+ specification for more details.
135
+ """
136
+ if copy is False:
137
+ # copy=False is not yet implemented in dask
138
+ raise NotImplementedError("copy=False is not yet implemented")
139
+ elif copy is True:
140
+ if isinstance(obj, da.Array) and dtype is None:
141
+ return obj.copy()
142
+ # Go through numpy, since dask copy is no-op by default
143
+ obj = np.array(obj, dtype=dtype, copy=True)
144
+ return da.array(obj, dtype=dtype)
145
+ else:
146
+ if not isinstance(obj, da.Array) or dtype is not None and obj.dtype != dtype:
147
+ obj = np.asarray(obj, dtype=dtype)
148
+ return da.from_array(obj)
149
+ return obj
150
+
151
+ return da.asarray(obj, dtype=dtype, **kwargs)
152
+
116
153
  from dask.array import (
117
154
  # Element wise aliases
118
155
  arccos as acos,
@@ -21,4 +21,10 @@ from .linalg import matrix_transpose, vecdot # noqa: F401
21
21
 
22
22
  from ..common._helpers import * # noqa: F403
23
23
 
24
+ try:
25
+ # Used in asarray(). Not present in older versions.
26
+ from numpy import _CopyMode # noqa: F401
27
+ except ImportError:
28
+ pass
29
+
24
30
  __array_api_version__ = '2022.12'
@@ -1,14 +1,13 @@
1
1
  from __future__ import annotations
2
2
 
3
- from functools import partial
4
-
5
3
  from ..common import _aliases
6
4
 
7
5
  from .._internal import get_xp
8
6
 
9
- asarray = asarray_numpy = partial(_aliases._asarray, namespace='numpy')
10
- asarray.__doc__ = _aliases._asarray.__doc__
11
- del partial
7
+ from typing import TYPE_CHECKING
8
+ if TYPE_CHECKING:
9
+ from typing import Optional, Union
10
+ from ._typing import ndarray, Device, Dtype, NestedSequence, SupportsBufferProtocol
12
11
 
13
12
  import numpy as np
14
13
  bool = np.bool_
@@ -62,6 +61,57 @@ matmul = get_xp(np)(_aliases.matmul)
62
61
  matrix_transpose = get_xp(np)(_aliases.matrix_transpose)
63
62
  tensordot = get_xp(np)(_aliases.tensordot)
64
63
 
64
+ def _supports_buffer_protocol(obj):
65
+ try:
66
+ memoryview(obj)
67
+ except TypeError:
68
+ return False
69
+ return True
70
+
71
+ # asarray also adds the copy keyword, which is not present in numpy 1.0.
72
+ # asarray() is different enough between numpy, cupy, and dask, the logic
73
+ # complicated enough that it's easier to define it separately for each module
74
+ # rather than trying to combine everything into one function in common/
75
+ def asarray(
76
+ obj: Union[
77
+ ndarray,
78
+ bool,
79
+ int,
80
+ float,
81
+ NestedSequence[bool | int | float],
82
+ SupportsBufferProtocol,
83
+ ],
84
+ /,
85
+ *,
86
+ dtype: Optional[Dtype] = None,
87
+ device: Optional[Device] = None,
88
+ copy: "Optional[Union[bool, np._CopyMode]]" = None,
89
+ **kwargs,
90
+ ) -> ndarray:
91
+ """
92
+ Array API compatibility wrapper for asarray().
93
+
94
+ See the corresponding documentation in the array library and/or the array API
95
+ specification for more details.
96
+ """
97
+ if device not in ["cpu", None]:
98
+ raise ValueError(f"Unsupported device for NumPy: {device!r}")
99
+
100
+ if hasattr(np, '_CopyMode'):
101
+ if copy is None:
102
+ copy = np._CopyMode.IF_NEEDED
103
+ elif copy is False:
104
+ copy = np._CopyMode.NEVER
105
+ elif copy is True:
106
+ copy = np._CopyMode.ALWAYS
107
+ else:
108
+ # Not present in older NumPys. In this case, we cannot really support
109
+ # copy=False.
110
+ if copy is False:
111
+ raise NotImplementedError("asarray(copy=False) requires a newer version of NumPy.")
112
+
113
+ return np.array(obj, copy=copy, dtype=dtype, **kwargs)
114
+
65
115
  # These functions are completely new here. If the library already has them
66
116
  # (i.e., numpy 2.0), use the library version instead of our wrapper.
67
117
  if hasattr(np, 'vecdot'):
@@ -73,7 +123,7 @@ if hasattr(np, 'isdtype'):
73
123
  else:
74
124
  isdtype = get_xp(np)(_aliases.isdtype)
75
125
 
76
- __all__ = _aliases.__all__ + ['asarray', 'asarray_numpy', 'bool', 'acos',
126
+ __all__ = _aliases.__all__ + ['asarray', 'bool', 'acos',
77
127
  'acosh', 'asin', 'asinh', 'atan', 'atan2',
78
128
  'atanh', 'bitwise_left_shift', 'bitwise_invert',
79
129
  'bitwise_right_shift', 'concat', 'pow']
@@ -1,24 +1,28 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: array_api_compat
3
- Version: 1.5.1
3
+ Version: 1.6
4
4
  Summary: A wrapper around NumPy and other array libraries to make them compatible with the Array API standard
5
5
  Home-page: https://data-apis.org/array-api-compat/
6
6
  Author: Consortium for Python Data API Standards
7
7
  License: MIT
8
8
  Classifier: Programming Language :: Python :: 3
9
- Classifier: Programming Language :: Python :: 3.8
10
9
  Classifier: Programming Language :: Python :: 3.9
11
10
  Classifier: Programming Language :: Python :: 3.10
12
11
  Classifier: Programming Language :: Python :: 3.11
13
12
  Classifier: License :: OSI Approved :: MIT License
14
13
  Classifier: Operating System :: OS Independent
15
- Requires-Python: >=3.8
16
14
  Description-Content-Type: text/markdown
17
15
  License-File: LICENSE
18
16
  Provides-Extra: numpy
19
17
  Requires-Dist: numpy; extra == "numpy"
20
18
  Provides-Extra: cupy
21
19
  Requires-Dist: cupy; extra == "cupy"
20
+ Provides-Extra: jax
21
+ Requires-Dist: jax; extra == "jax"
22
+ Provides-Extra: pytorch
23
+ Requires-Dist: pytorch; extra == "pytorch"
24
+ Provides-Extra: dask
25
+ Requires-Dist: dask; extra == "dask"
22
26
 
23
27
  # Array API compatibility library
24
28
 
@@ -36,4 +36,5 @@ tests/test_all.py
36
36
  tests/test_array_namespace.py
37
37
  tests/test_common.py
38
38
  tests/test_isdtype.py
39
+ tests/test_no_dependencies.py
39
40
  tests/test_vendoring.py
@@ -0,0 +1,15 @@
1
+
2
+ [cupy]
3
+ cupy
4
+
5
+ [dask]
6
+ dask
7
+
8
+ [jax]
9
+ jax
10
+
11
+ [numpy]
12
+ numpy
13
+
14
+ [pytorch]
15
+ pytorch
@@ -15,14 +15,15 @@ setup(
15
15
  long_description_content_type="text/markdown",
16
16
  url="https://data-apis.org/array-api-compat/",
17
17
  license="MIT",
18
- python_requires=">=3.8",
19
18
  extras_require={
20
19
  "numpy": "numpy",
21
20
  "cupy": "cupy",
21
+ "jax": "jax",
22
+ "pytorch": "pytorch",
23
+ "dask": "dask",
22
24
  },
23
25
  classifiers=[
24
26
  "Programming Language :: Python :: 3",
25
- "Programming Language :: Python :: 3.8",
26
27
  "Programming Language :: Python :: 3.9",
27
28
  "Programming Language :: Python :: 3.10",
28
29
  "Programming Language :: Python :: 3.11",
@@ -12,11 +12,11 @@ used inside of a function. Note that names starting with an underscore are autom
12
12
 
13
13
  import sys
14
14
 
15
- from ._helpers import import_
15
+ from ._helpers import import_, wrapped_libraries
16
16
 
17
17
  import pytest
18
18
 
19
- @pytest.mark.parametrize("library", ["common", "cupy", "numpy", "torch", "dask.array"])
19
+ @pytest.mark.parametrize("library", ["common"] + wrapped_libraries)
20
20
  def test_all(library):
21
21
  import_(library, wrapper=True)
22
22
 
@@ -9,24 +9,29 @@ import torch
9
9
  import array_api_compat
10
10
  from array_api_compat import array_namespace
11
11
 
12
- from ._helpers import import_
12
+ from ._helpers import import_, all_libraries, wrapped_libraries
13
13
 
14
- @pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask.array", "jax.numpy"])
15
- @pytest.mark.parametrize("api_version", [None, "2021.12"])
16
- def test_array_namespace(library, api_version):
14
+ @pytest.mark.parametrize("use_compat", [True, False, None])
15
+ @pytest.mark.parametrize("api_version", [None, "2021.12", "2022.12"])
16
+ @pytest.mark.parametrize("library", all_libraries + ['array_api_strict'])
17
+ def test_array_namespace(library, api_version, use_compat):
17
18
  xp = import_(library)
18
19
 
19
20
  array = xp.asarray([1.0, 2.0, 3.0])
20
- namespace = array_api_compat.array_namespace(array, api_version=api_version)
21
+ if use_compat is True and library in ['array_api_strict', 'jax.numpy']:
22
+ pytest.raises(ValueError, lambda: array_namespace(array, use_compat=use_compat))
23
+ return
24
+ namespace = array_api_compat.array_namespace(array, api_version=api_version, use_compat=use_compat)
21
25
 
22
- if "array_api" in library:
23
- assert namespace == xp
26
+ if use_compat is False or use_compat is None and library not in wrapped_libraries:
27
+ if library == "jax.numpy" and use_compat is None:
28
+ import jax.experimental.array_api
29
+ assert namespace == jax.experimental.array_api
30
+ else:
31
+ assert namespace == xp
24
32
  else:
25
33
  if library == "dask.array":
26
34
  assert namespace == array_api_compat.dask.array
27
- elif library == "jax.numpy":
28
- import jax.experimental.array_api
29
- assert namespace == jax.experimental.array_api
30
35
  else:
31
36
  assert namespace == getattr(array_api_compat, library)
32
37
 
@@ -64,14 +69,14 @@ def test_array_namespace_errors_torch():
64
69
  pytest.raises(TypeError, lambda: array_namespace(x, y))
65
70
 
66
71
  def test_api_version():
67
- x = np.asarray([1, 2])
68
- np_ = import_("numpy", wrapper=True)
69
- assert array_namespace(x, api_version="2022.12") == np_
70
- assert array_namespace(x, api_version=None) == np_
71
- assert array_namespace(x) == np_
72
+ x = torch.asarray([1, 2])
73
+ torch_ = import_("torch", wrapper=True)
74
+ assert array_namespace(x, api_version="2022.12") == torch_
75
+ assert array_namespace(x, api_version=None) == torch_
76
+ assert array_namespace(x) == torch_
72
77
  # Should issue a warning
73
78
  with warnings.catch_warnings(record=True) as w:
74
- assert array_namespace(x, api_version="2021.12") == np_
79
+ assert array_namespace(x, api_version="2021.12") == torch_
75
80
  assert len(w) == 1
76
81
  assert "2021.12" in str(w[0].message)
77
82
 
@@ -0,0 +1,184 @@
1
+ from array_api_compat import (is_numpy_array, is_cupy_array, is_torch_array, # noqa: F401
2
+ is_dask_array, is_jax_array)
3
+
4
+ from array_api_compat import is_array_api_obj, device, to_device
5
+
6
+ from ._helpers import import_, wrapped_libraries, all_libraries
7
+
8
+ import pytest
9
+ import numpy as np
10
+ import array
11
+ from numpy.testing import assert_allclose
12
+
13
+ is_functions = {
14
+ 'numpy': 'is_numpy_array',
15
+ 'cupy': 'is_cupy_array',
16
+ 'torch': 'is_torch_array',
17
+ 'dask.array': 'is_dask_array',
18
+ 'jax.numpy': 'is_jax_array',
19
+ }
20
+
21
+ @pytest.mark.parametrize('library', is_functions.keys())
22
+ @pytest.mark.parametrize('func', is_functions.values())
23
+ def test_is_xp_array(library, func):
24
+ lib = import_(library)
25
+ is_func = globals()[func]
26
+
27
+ x = lib.asarray([1, 2, 3])
28
+
29
+ assert is_func(x) == (func == is_functions[library])
30
+
31
+ assert is_array_api_obj(x)
32
+
33
+ @pytest.mark.parametrize("library", all_libraries)
34
+ def test_device(library):
35
+ xp = import_(library, wrapper=True)
36
+
37
+ # We can't test much for device() and to_device() other than that
38
+ # x.to_device(x.device) works.
39
+
40
+ x = xp.asarray([1, 2, 3])
41
+ dev = device(x)
42
+
43
+ x2 = to_device(x, dev)
44
+ assert device(x) == device(x2)
45
+
46
+
47
+ @pytest.mark.parametrize("library", wrapped_libraries)
48
+ def test_to_device_host(library):
49
+ # different libraries have different semantics
50
+ # for DtoH transfers; ensure that we support a portable
51
+ # shim for common array libs
52
+ # see: https://github.com/scipy/scipy/issues/18286#issuecomment-1527552919
53
+ xp = import_(library, wrapper=True)
54
+
55
+ expected = np.array([1, 2, 3])
56
+ x = xp.asarray([1, 2, 3])
57
+ x = to_device(x, "cpu")
58
+ # torch will return a genuine Device object, but
59
+ # the other libs will do something different with
60
+ # a `device(x)` query; however, what's really important
61
+ # here is that we can test portably after calling
62
+ # to_device(x, "cpu") to return to host
63
+ assert_allclose(x, expected)
64
+
65
+
66
+ @pytest.mark.parametrize("target_library", is_functions.keys())
67
+ @pytest.mark.parametrize("source_library", is_functions.keys())
68
+ def test_asarray_cross_library(source_library, target_library, request):
69
+ if source_library == "dask.array" and target_library == "torch":
70
+ # Allow rest of test to execute instead of immediately xfailing
71
+ # xref https://github.com/pandas-dev/pandas/issues/38902
72
+
73
+ # TODO: remove xfail once
74
+ # https://github.com/dask/dask/issues/8260 is resolved
75
+ request.node.add_marker(pytest.mark.xfail(reason="Bug in dask raising error on conversion"))
76
+ if source_library == "cupy" and target_library != "cupy":
77
+ # cupy explicitly disallows implicit conversions to CPU
78
+ pytest.skip(reason="cupy does not support implicit conversion to CPU")
79
+ src_lib = import_(source_library, wrapper=True)
80
+ tgt_lib = import_(target_library, wrapper=True)
81
+ is_tgt_type = globals()[is_functions[target_library]]
82
+
83
+ a = src_lib.asarray([1, 2, 3])
84
+ b = tgt_lib.asarray(a)
85
+
86
+ assert is_tgt_type(b), f"Expected {b} to be a {tgt_lib.ndarray}, but was {type(b)}"
87
+
88
+ @pytest.mark.parametrize("library", wrapped_libraries)
89
+ def test_asarray_copy(library):
90
+ # Note, we have this test here because the test suite currently doesn't
91
+ # test the copy flag to asarray() very rigorously. Once
92
+ # https://github.com/data-apis/array-api-tests/issues/241 is fixed we
93
+ # should be able to delete this.
94
+ xp = import_(library, wrapper=True)
95
+ asarray = xp.asarray
96
+ is_lib_func = globals()[is_functions[library]]
97
+ all = xp.all if library != 'dask.array' else lambda x: xp.all(x).compute()
98
+
99
+ if library == 'numpy' and xp.__version__[0] < '2' and not hasattr(xp, '_CopyMode') :
100
+ supports_copy_false = False
101
+ elif library in ['cupy', 'dask.array']:
102
+ supports_copy_false = False
103
+ else:
104
+ supports_copy_false = True
105
+
106
+ a = asarray([1])
107
+ b = asarray(a, copy=True)
108
+ assert is_lib_func(b)
109
+ a[0] = 0
110
+ assert all(b[0] == 1)
111
+ assert all(a[0] == 0)
112
+
113
+ a = asarray([1])
114
+ if supports_copy_false:
115
+ b = asarray(a, copy=False)
116
+ assert is_lib_func(b)
117
+ a[0] = 0
118
+ assert all(b[0] == 0)
119
+ else:
120
+ pytest.raises(NotImplementedError, lambda: asarray(a, copy=False))
121
+
122
+ a = asarray([1])
123
+ if supports_copy_false:
124
+ pytest.raises(ValueError, lambda: asarray(a, copy=False,
125
+ dtype=xp.float64))
126
+ else:
127
+ pytest.raises(NotImplementedError, lambda: asarray(a, copy=False, dtype=xp.float64))
128
+
129
+ a = asarray([1])
130
+ b = asarray(a, copy=None)
131
+ assert is_lib_func(b)
132
+ a[0] = 0
133
+ assert all(b[0] == 0)
134
+
135
+ a = asarray([1.0], dtype=xp.float32)
136
+ assert a.dtype == xp.float32
137
+ b = asarray(a, dtype=xp.float64, copy=None)
138
+ assert is_lib_func(b)
139
+ assert b.dtype == xp.float64
140
+ a[0] = 0.0
141
+ assert all(b[0] == 1.0)
142
+
143
+ a = asarray([1.0], dtype=xp.float64)
144
+ assert a.dtype == xp.float64
145
+ b = asarray(a, dtype=xp.float64, copy=None)
146
+ assert is_lib_func(b)
147
+ assert b.dtype == xp.float64
148
+ a[0] = 0.0
149
+ assert all(b[0] == 0.0)
150
+
151
+ # Python built-in types
152
+ for obj in [True, 0, 0.0, 0j, [0], [[0]]]:
153
+ asarray(obj, copy=True) # No error
154
+ asarray(obj, copy=None) # No error
155
+ if supports_copy_false:
156
+ pytest.raises(ValueError, lambda: asarray(obj, copy=False))
157
+ else:
158
+ pytest.raises(NotImplementedError, lambda: asarray(obj, copy=False))
159
+
160
+ # Use the standard library array to test the buffer protocol
161
+ a = array.array('f', [1.0])
162
+ b = asarray(a, copy=True)
163
+ assert is_lib_func(b)
164
+ a[0] = 0.0
165
+ assert all(b[0] == 1.0)
166
+
167
+ a = array.array('f', [1.0])
168
+ if supports_copy_false:
169
+ b = asarray(a, copy=False)
170
+ assert is_lib_func(b)
171
+ a[0] = 0.0
172
+ assert all(b[0] == 0.0)
173
+ else:
174
+ pytest.raises(NotImplementedError, lambda: asarray(a, copy=False))
175
+
176
+ a = array.array('f', [1.0])
177
+ b = asarray(a, copy=None)
178
+ assert is_lib_func(b)
179
+ a[0] = 0.0
180
+ if library == 'cupy':
181
+ # A copy is required for libraries where the default device is not CPU
182
+ assert all(b[0] == 1.0)
183
+ else:
184
+ assert all(b[0] == 0.0)
@@ -5,7 +5,7 @@ non-spec dtypes
5
5
 
6
6
  import pytest
7
7
 
8
- from ._helpers import import_
8
+ from ._helpers import import_, wrapped_libraries
9
9
 
10
10
  # Check the known dtypes by their string names
11
11
 
@@ -64,7 +64,7 @@ def isdtype_(dtype_, kind):
64
64
  assert type(res) is bool # noqa: E721
65
65
  return res
66
66
 
67
- @pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask.array", "jax.numpy"])
67
+ @pytest.mark.parametrize("library", wrapped_libraries)
68
68
  def test_isdtype_spec_dtypes(library):
69
69
  xp = import_(library, wrapper=True)
70
70
 
@@ -98,7 +98,7 @@ additional_dtypes = [
98
98
  'bfloat16',
99
99
  ]
100
100
 
101
- @pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask.array", "jax.numpy"])
101
+ @pytest.mark.parametrize("library", wrapped_libraries)
102
102
  @pytest.mark.parametrize("dtype_", additional_dtypes)
103
103
  def test_isdtype_additional_dtypes(library, dtype_):
104
104
  xp = import_(library, wrapper=True)
@@ -0,0 +1,73 @@
1
+ """
2
+ Test that array_api_compat has no "hard" dependencies.
3
+
4
+ Libraries like NumPy should only be imported if a numpy array is passed to
5
+ array_namespace or if array_api_compat.numpy is explicitly imported.
6
+
7
+ We have to test this in a subprocess because these libraries have already been
8
+ imported from the other tests.
9
+ """
10
+
11
+ import sys
12
+ import subprocess
13
+
14
+ import pytest
15
+
16
+ class Array:
17
+ # Dummy array namespace that doesn't depend on any array library
18
+ def __array_namespace__(self, api_version=None):
19
+ class Namespace:
20
+ pass
21
+ return Namespace()
22
+
23
+ def _test_dependency(mod):
24
+ assert mod not in sys.modules
25
+
26
+ # Run various functions that shouldn't depend on mod and check that they
27
+ # don't import it.
28
+
29
+ import array_api_compat
30
+ assert mod not in sys.modules
31
+
32
+ a = Array()
33
+
34
+ # array-api-strict is an example of an array API library that isn't
35
+ # wrapped by array-api-compat.
36
+ if "strict" not in mod:
37
+ is_mod_array = getattr(array_api_compat, f"is_{mod.split('.')[0]}_array")
38
+ assert not is_mod_array(a)
39
+ assert mod not in sys.modules
40
+
41
+ is_array_api_obj = getattr(array_api_compat, "is_array_api_obj")
42
+ assert is_array_api_obj(a)
43
+ assert mod not in sys.modules
44
+
45
+ array_namespace = getattr(array_api_compat, "array_namespace")
46
+ array_namespace(Array())
47
+ assert mod not in sys.modules
48
+
49
+ # TODO: Test that wrapper for library X doesn't depend on wrappers for library
50
+ # Y (except most array libraries actually do themselves depend on numpy).
51
+
52
+ @pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask.array",
53
+ "jax.numpy", "array_api_strict"])
54
+ def test_numpy_dependency(library):
55
+ # This import is here because it imports numpy
56
+ from ._helpers import import_
57
+
58
+ # This unfortunately won't go through any of the pytest machinery. We
59
+ # reraise the exception as an AssertionError so that pytest will show it
60
+ # in a semi-reasonable way
61
+
62
+ # Import (in this process) to make sure 'library' is actually installed and
63
+ # so that cupy can be skipped.
64
+ import_(library)
65
+
66
+ try:
67
+ subprocess.run([sys.executable, '-c', f'''\
68
+ from tests.test_no_dependencies import _test_dependency
69
+
70
+ _test_dependency({library!r})'''], check=True, capture_output=True, encoding='utf-8')
71
+ except subprocess.CalledProcessError as e:
72
+ print(e.stdout, end='')
73
+ raise AssertionError(e.stderr)
@@ -1,6 +0,0 @@
1
-
2
- [cupy]
3
- cupy
4
-
5
- [numpy]
6
- numpy
@@ -1,85 +0,0 @@
1
- from array_api_compat import (is_numpy_array, is_cupy_array, is_torch_array, # noqa: F401
2
- is_dask_array, is_jax_array)
3
-
4
- from array_api_compat import is_array_api_obj, device, to_device
5
-
6
- from ._helpers import import_
7
-
8
- import pytest
9
- import numpy as np
10
- from numpy.testing import assert_allclose
11
-
12
- is_functions = {
13
- 'numpy': 'is_numpy_array',
14
- 'cupy': 'is_cupy_array',
15
- 'torch': 'is_torch_array',
16
- 'dask.array': 'is_dask_array',
17
- 'jax.numpy': 'is_jax_array',
18
- }
19
-
20
- @pytest.mark.parametrize('library', is_functions.keys())
21
- @pytest.mark.parametrize('func', is_functions.values())
22
- def test_is_xp_array(library, func):
23
- lib = import_(library)
24
- is_func = globals()[func]
25
-
26
- x = lib.asarray([1, 2, 3])
27
-
28
- assert is_func(x) == (func == is_functions[library])
29
-
30
- assert is_array_api_obj(x)
31
-
32
- @pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask.array", "jax.numpy"])
33
- def test_device(library):
34
- xp = import_(library, wrapper=True)
35
-
36
- # We can't test much for device() and to_device() other than that
37
- # x.to_device(x.device) works.
38
-
39
- x = xp.asarray([1, 2, 3])
40
- dev = device(x)
41
-
42
- x2 = to_device(x, dev)
43
- assert device(x) == device(x2)
44
-
45
-
46
- @pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask.array"])
47
- def test_to_device_host(library):
48
- # different libraries have different semantics
49
- # for DtoH transfers; ensure that we support a portable
50
- # shim for common array libs
51
- # see: https://github.com/scipy/scipy/issues/18286#issuecomment-1527552919
52
- xp = import_(library, wrapper=True)
53
-
54
- expected = np.array([1, 2, 3])
55
- x = xp.asarray([1, 2, 3])
56
- x = to_device(x, "cpu")
57
- # torch will return a genuine Device object, but
58
- # the other libs will do something different with
59
- # a `device(x)` query; however, what's really important
60
- # here is that we can test portably after calling
61
- # to_device(x, "cpu") to return to host
62
- assert_allclose(x, expected)
63
-
64
-
65
- @pytest.mark.parametrize("target_library", is_functions.keys())
66
- @pytest.mark.parametrize("source_library", is_functions.keys())
67
- def test_asarray(source_library, target_library, request):
68
- if source_library == "dask.array" and target_library == "torch":
69
- # Allow rest of test to execute instead of immediately xfailing
70
- # xref https://github.com/pandas-dev/pandas/issues/38902
71
-
72
- # TODO: remove xfail once
73
- # https://github.com/dask/dask/issues/8260 is resolved
74
- request.node.add_marker(pytest.mark.xfail(reason="Bug in dask raising error on conversion"))
75
- if source_library == "cupy" and target_library != "cupy":
76
- # cupy explicitly disallows implicit conversions to CPU
77
- pytest.skip(reason="cupy does not support implicit conversion to CPU")
78
- src_lib = import_(source_library, wrapper=True)
79
- tgt_lib = import_(target_library, wrapper=True)
80
- is_tgt_type = globals()[is_functions[target_library]]
81
-
82
- a = src_lib.asarray([1, 2, 3])
83
- b = tgt_lib.asarray(a)
84
-
85
- assert is_tgt_type(b), f"Expected {b} to be a {tgt_lib.ndarray}, but was {type(b)}"
File without changes