array-api-compat 1.7__tar.gz → 1.8__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42) hide show
  1. {array_api_compat-1.7 → array_api_compat-1.8}/PKG-INFO +6 -6
  2. {array_api_compat-1.7 → array_api_compat-1.8}/README.md +3 -3
  3. {array_api_compat-1.7 → array_api_compat-1.8}/array_api_compat/__init__.py +1 -1
  4. {array_api_compat-1.7 → array_api_compat-1.8}/array_api_compat/common/_aliases.py +62 -2
  5. {array_api_compat-1.7 → array_api_compat-1.8}/array_api_compat/common/_helpers.py +35 -1
  6. {array_api_compat-1.7 → array_api_compat-1.8}/array_api_compat/cupy/_aliases.py +1 -0
  7. {array_api_compat-1.7 → array_api_compat-1.8}/array_api_compat/dask/array/_aliases.py +1 -0
  8. {array_api_compat-1.7 → array_api_compat-1.8}/array_api_compat/numpy/_aliases.py +1 -0
  9. {array_api_compat-1.7 → array_api_compat-1.8}/array_api_compat/torch/_aliases.py +9 -6
  10. {array_api_compat-1.7 → array_api_compat-1.8}/array_api_compat.egg-info/PKG-INFO +6 -6
  11. {array_api_compat-1.7 → array_api_compat-1.8}/array_api_compat.egg-info/requires.txt +1 -1
  12. {array_api_compat-1.7 → array_api_compat-1.8}/setup.py +1 -1
  13. {array_api_compat-1.7 → array_api_compat-1.8}/LICENSE +0 -0
  14. {array_api_compat-1.7 → array_api_compat-1.8}/array_api_compat/_internal.py +0 -0
  15. {array_api_compat-1.7 → array_api_compat-1.8}/array_api_compat/common/__init__.py +0 -0
  16. {array_api_compat-1.7 → array_api_compat-1.8}/array_api_compat/common/_fft.py +0 -0
  17. {array_api_compat-1.7 → array_api_compat-1.8}/array_api_compat/common/_linalg.py +0 -0
  18. {array_api_compat-1.7 → array_api_compat-1.8}/array_api_compat/common/_typing.py +0 -0
  19. {array_api_compat-1.7 → array_api_compat-1.8}/array_api_compat/cupy/__init__.py +0 -0
  20. {array_api_compat-1.7 → array_api_compat-1.8}/array_api_compat/cupy/_typing.py +0 -0
  21. {array_api_compat-1.7 → array_api_compat-1.8}/array_api_compat/cupy/fft.py +0 -0
  22. {array_api_compat-1.7 → array_api_compat-1.8}/array_api_compat/cupy/linalg.py +0 -0
  23. {array_api_compat-1.7 → array_api_compat-1.8}/array_api_compat/dask/__init__.py +0 -0
  24. {array_api_compat-1.7 → array_api_compat-1.8}/array_api_compat/dask/array/__init__.py +0 -0
  25. {array_api_compat-1.7 → array_api_compat-1.8}/array_api_compat/dask/array/linalg.py +0 -0
  26. {array_api_compat-1.7 → array_api_compat-1.8}/array_api_compat/numpy/__init__.py +0 -0
  27. {array_api_compat-1.7 → array_api_compat-1.8}/array_api_compat/numpy/_typing.py +0 -0
  28. {array_api_compat-1.7 → array_api_compat-1.8}/array_api_compat/numpy/fft.py +0 -0
  29. {array_api_compat-1.7 → array_api_compat-1.8}/array_api_compat/numpy/linalg.py +0 -0
  30. {array_api_compat-1.7 → array_api_compat-1.8}/array_api_compat/torch/__init__.py +0 -0
  31. {array_api_compat-1.7 → array_api_compat-1.8}/array_api_compat/torch/fft.py +0 -0
  32. {array_api_compat-1.7 → array_api_compat-1.8}/array_api_compat/torch/linalg.py +0 -0
  33. {array_api_compat-1.7 → array_api_compat-1.8}/array_api_compat.egg-info/SOURCES.txt +0 -0
  34. {array_api_compat-1.7 → array_api_compat-1.8}/array_api_compat.egg-info/dependency_links.txt +0 -0
  35. {array_api_compat-1.7 → array_api_compat-1.8}/array_api_compat.egg-info/top_level.txt +0 -0
  36. {array_api_compat-1.7 → array_api_compat-1.8}/setup.cfg +0 -0
  37. {array_api_compat-1.7 → array_api_compat-1.8}/tests/test_all.py +0 -0
  38. {array_api_compat-1.7 → array_api_compat-1.8}/tests/test_array_namespace.py +0 -0
  39. {array_api_compat-1.7 → array_api_compat-1.8}/tests/test_common.py +0 -0
  40. {array_api_compat-1.7 → array_api_compat-1.8}/tests/test_isdtype.py +0 -0
  41. {array_api_compat-1.7 → array_api_compat-1.8}/tests/test_no_dependencies.py +0 -0
  42. {array_api_compat-1.7 → array_api_compat-1.8}/tests/test_vendoring.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: array_api_compat
3
- Version: 1.7
3
+ Version: 1.8
4
4
  Summary: A wrapper around NumPy and other array libraries to make them compatible with the Array API standard
5
5
  Home-page: https://data-apis.org/array-api-compat/
6
6
  Author: Consortium for Python Data API Standards
@@ -23,15 +23,15 @@ Provides-Extra: pytorch
23
23
  Requires-Dist: pytorch; extra == "pytorch"
24
24
  Provides-Extra: dask
25
25
  Requires-Dist: dask; extra == "dask"
26
- Provides-Extra: sprase
27
- Requires-Dist: sparse>=0.15.1; extra == "sprase"
26
+ Provides-Extra: sparse
27
+ Requires-Dist: sparse>=0.15.1; extra == "sparse"
28
28
 
29
29
  # Array API compatibility library
30
30
 
31
31
  This is a small wrapper around common array libraries that is compatible with
32
32
  the [Array API standard](https://data-apis.org/array-api/latest/). Currently,
33
- NumPy, CuPy, PyTorch, Dask, JAX and `sparse` are supported. If you want support
34
- for other array libraries, or if you encounter any issues, please [open an
35
- issue](https://github.com/data-apis/array-api-compat/issues).
33
+ NumPy, CuPy, PyTorch, Dask, JAX, ndonnx and `sparse` are supported. If you want
34
+ support for other array libraries, or if you encounter any issues, please [open
35
+ an issue](https://github.com/data-apis/array-api-compat/issues).
36
36
 
37
37
  See the documentation for more details https://data-apis.org/array-api-compat/
@@ -2,8 +2,8 @@
2
2
 
3
3
  This is a small wrapper around common array libraries that is compatible with
4
4
  the [Array API standard](https://data-apis.org/array-api/latest/). Currently,
5
- NumPy, CuPy, PyTorch, Dask, JAX and `sparse` are supported. If you want support
6
- for other array libraries, or if you encounter any issues, please [open an
7
- issue](https://github.com/data-apis/array-api-compat/issues).
5
+ NumPy, CuPy, PyTorch, Dask, JAX, ndonnx and `sparse` are supported. If you want
6
+ support for other array libraries, or if you encounter any issues, please [open
7
+ an issue](https://github.com/data-apis/array-api-compat/issues).
8
8
 
9
9
  See the documentation for more details https://data-apis.org/array-api-compat/
@@ -17,6 +17,6 @@ to ensure they are not using functionality outside of the standard, but prefer
17
17
  this implementation for the default when working with NumPy arrays.
18
18
 
19
19
  """
20
- __version__ = '1.7'
20
+ __version__ = '1.8'
21
21
 
22
22
  from .common import * # noqa: F401, F403
@@ -12,7 +12,7 @@ if TYPE_CHECKING:
12
12
  from typing import NamedTuple
13
13
  import inspect
14
14
 
15
- from ._helpers import _check_device
15
+ from ._helpers import array_namespace, _check_device
16
16
 
17
17
  # These functions are modified from the NumPy versions.
18
18
 
@@ -264,6 +264,66 @@ def var(
264
264
  ) -> ndarray:
265
265
  return xp.var(x, axis=axis, ddof=correction, keepdims=keepdims, **kwargs)
266
266
 
267
+
268
+ # The min and max argument names in clip are different and not optional in numpy, and type
269
+ # promotion behavior is different.
270
+ def clip(
271
+ x: ndarray,
272
+ /,
273
+ min: Optional[Union[int, float, ndarray]] = None,
274
+ max: Optional[Union[int, float, ndarray]] = None,
275
+ *,
276
+ xp,
277
+ # TODO: np.clip has other ufunc kwargs
278
+ out: Optional[ndarray] = None,
279
+ ) -> ndarray:
280
+ def _isscalar(a):
281
+ return isinstance(a, (int, float, type(None)))
282
+ min_shape = () if _isscalar(min) else min.shape
283
+ max_shape = () if _isscalar(max) else max.shape
284
+ result_shape = xp.broadcast_shapes(x.shape, min_shape, max_shape)
285
+
286
+ wrapped_xp = array_namespace(x)
287
+
288
+ # np.clip does type promotion but the array API clip requires that the
289
+ # output have the same dtype as x. We do this instead of just downcasting
290
+ # the result of xp.clip() to handle some corner cases better (e.g.,
291
+ # avoiding uint64 -> float64 promotion).
292
+
293
+ # Note: cases where min or max overflow (integer) or round (float) in the
294
+ # wrong direction when downcasting to x.dtype are unspecified. This code
295
+ # just does whatever NumPy does when it downcasts in the assignment, but
296
+ # other behavior could be preferred, especially for integers. For example,
297
+ # this code produces:
298
+
299
+ # >>> clip(asarray(0, dtype=int8), asarray(128, dtype=int16), None)
300
+ # -128
301
+
302
+ # but an answer of 0 might be preferred. See
303
+ # https://github.com/numpy/numpy/issues/24976 for more discussion on this issue.
304
+
305
+
306
+ # At least handle the case of Python integers correctly (see
307
+ # https://github.com/numpy/numpy/pull/26892).
308
+ if type(min) is int and min <= xp.iinfo(x.dtype).min:
309
+ min = None
310
+ if type(max) is int and max >= xp.iinfo(x.dtype).max:
311
+ max = None
312
+
313
+ if out is None:
314
+ out = wrapped_xp.asarray(xp.broadcast_to(x, result_shape), copy=True)
315
+ if min is not None:
316
+ a = xp.broadcast_to(xp.asarray(min), result_shape)
317
+ ia = (out < a) | xp.isnan(a)
318
+ # torch requires an explicit cast here
319
+ out[ia] = wrapped_xp.astype(a[ia], out.dtype)
320
+ if max is not None:
321
+ b = xp.broadcast_to(xp.asarray(max), result_shape)
322
+ ib = (out > b) | xp.isnan(b)
323
+ out[ib] = wrapped_xp.astype(b[ib], out.dtype)
324
+ # Return a scalar for 0-D
325
+ return out[()]
326
+
267
327
  # Unlike transpose(), the axes argument to permute_dims() is required.
268
328
  def permute_dims(x: ndarray, /, axes: Tuple[int, ...], xp) -> ndarray:
269
329
  return xp.transpose(x, axes)
@@ -465,6 +525,6 @@ __all__ = ['arange', 'empty', 'empty_like', 'eye', 'full', 'full_like',
465
525
  'linspace', 'ones', 'ones_like', 'zeros', 'zeros_like',
466
526
  'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult',
467
527
  'unique_all', 'unique_counts', 'unique_inverse', 'unique_values',
468
- 'astype', 'std', 'var', 'permute_dims', 'reshape', 'argsort',
528
+ 'astype', 'std', 'var', 'clip', 'permute_dims', 'reshape', 'argsort',
469
529
  'sort', 'nonzero', 'sum', 'prod', 'ceil', 'floor', 'trunc',
470
530
  'matmul', 'matrix_transpose', 'tensordot', 'vecdot', 'isdtype']
@@ -48,6 +48,7 @@ def is_numpy_array(x):
48
48
  is_array_api_obj
49
49
  is_cupy_array
50
50
  is_torch_array
51
+ is_ndonnx_array
51
52
  is_dask_array
52
53
  is_jax_array
53
54
  is_pydata_sparse_array
@@ -78,11 +79,12 @@ def is_cupy_array(x):
78
79
  is_array_api_obj
79
80
  is_numpy_array
80
81
  is_torch_array
82
+ is_ndonnx_array
81
83
  is_dask_array
82
84
  is_jax_array
83
85
  is_pydata_sparse_array
84
86
  """
85
- # Avoid importing NumPy if it isn't already
87
+ # Avoid importing CuPy if it isn't already
86
88
  if 'cupy' not in sys.modules:
87
89
  return False
88
90
 
@@ -118,6 +120,33 @@ def is_torch_array(x):
118
120
  # TODO: Should we reject ndarray subclasses?
119
121
  return isinstance(x, torch.Tensor)
120
122
 
123
+ def is_ndonnx_array(x):
124
+ """
125
+ Return True if `x` is a ndonnx Array.
126
+
127
+ This function does not import ndonnx if it has not already been imported
128
+ and is therefore cheap to use.
129
+
130
+ See Also
131
+ --------
132
+
133
+ array_namespace
134
+ is_array_api_obj
135
+ is_numpy_array
136
+ is_cupy_array
137
+ is_torch_array
138
+ is_dask_array
139
+ is_jax_array
140
+ is_pydata_sparse_array
141
+ """
142
+ # Avoid importing ndonnx if it isn't already
143
+ if 'ndonnx' not in sys.modules:
144
+ return False
145
+
146
+ import ndonnx as ndx
147
+
148
+ return isinstance(x, ndx.Array)
149
+
121
150
  def is_dask_array(x):
122
151
  """
123
152
  Return True if `x` is a dask.array Array.
@@ -133,6 +162,7 @@ def is_dask_array(x):
133
162
  is_numpy_array
134
163
  is_cupy_array
135
164
  is_torch_array
165
+ is_ndonnx_array
136
166
  is_jax_array
137
167
  is_pydata_sparse_array
138
168
  """
@@ -160,6 +190,7 @@ def is_jax_array(x):
160
190
  is_numpy_array
161
191
  is_cupy_array
162
192
  is_torch_array
193
+ is_ndonnx_array
163
194
  is_dask_array
164
195
  is_pydata_sparse_array
165
196
  """
@@ -188,6 +219,7 @@ def is_pydata_sparse_array(x) -> bool:
188
219
  is_numpy_array
189
220
  is_cupy_array
190
221
  is_torch_array
222
+ is_ndonnx_array
191
223
  is_dask_array
192
224
  is_jax_array
193
225
  """
@@ -211,6 +243,7 @@ def is_array_api_obj(x):
211
243
  is_numpy_array
212
244
  is_cupy_array
213
245
  is_torch_array
246
+ is_ndonnx_array
214
247
  is_dask_array
215
248
  is_jax_array
216
249
  """
@@ -613,6 +646,7 @@ __all__ = [
613
646
  "is_jax_array",
614
647
  "is_numpy_array",
615
648
  "is_torch_array",
649
+ "is_ndonnx_array",
616
650
  "is_pydata_sparse_array",
617
651
  "size",
618
652
  "to_device",
@@ -47,6 +47,7 @@ unique_values = get_xp(cp)(_aliases.unique_values)
47
47
  astype = _aliases.astype
48
48
  std = get_xp(cp)(_aliases.std)
49
49
  var = get_xp(cp)(_aliases.var)
50
+ clip = get_xp(cp)(_aliases.clip)
50
51
  permute_dims = get_xp(cp)(_aliases.permute_dims)
51
52
  reshape = get_xp(cp)(_aliases.reshape)
52
53
  argsort = get_xp(cp)(_aliases.argsort)
@@ -88,6 +88,7 @@ unique_values = get_xp(da)(_aliases.unique_values)
88
88
  permute_dims = get_xp(da)(_aliases.permute_dims)
89
89
  std = get_xp(da)(_aliases.std)
90
90
  var = get_xp(da)(_aliases.var)
91
+ clip = get_xp(da)(_aliases.clip)
91
92
  empty = get_xp(da)(_aliases.empty)
92
93
  empty_like = get_xp(da)(_aliases.empty_like)
93
94
  full = get_xp(da)(_aliases.full)
@@ -47,6 +47,7 @@ unique_values = get_xp(np)(_aliases.unique_values)
47
47
  astype = _aliases.astype
48
48
  std = get_xp(np)(_aliases.std)
49
49
  var = get_xp(np)(_aliases.var)
50
+ clip = get_xp(np)(_aliases.clip)
50
51
  permute_dims = get_xp(np)(_aliases.permute_dims)
51
52
  reshape = get_xp(np)(_aliases.reshape)
52
53
  argsort = get_xp(np)(_aliases.argsort)
@@ -4,7 +4,7 @@ from functools import wraps as _wraps
4
4
  from builtins import all as _builtin_all, any as _builtin_any
5
5
 
6
6
  from ..common._aliases import (matrix_transpose as _aliases_matrix_transpose,
7
- vecdot as _aliases_vecdot)
7
+ vecdot as _aliases_vecdot, clip as _aliases_clip)
8
8
  from .._internal import get_xp
9
9
 
10
10
  import torch
@@ -155,6 +155,7 @@ bitwise_left_shift = _two_arg(torch.bitwise_left_shift)
155
155
  bitwise_or = _two_arg(torch.bitwise_or)
156
156
  bitwise_right_shift = _two_arg(torch.bitwise_right_shift)
157
157
  bitwise_xor = _two_arg(torch.bitwise_xor)
158
+ copysign = _two_arg(torch.copysign)
158
159
  divide = _two_arg(torch.divide)
159
160
  # Also a rename. torch.equal does not broadcast
160
161
  equal = _two_arg(torch.eq)
@@ -188,6 +189,8 @@ def min(x: array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = None, keep
188
189
  return torch.clone(x)
189
190
  return torch.amin(x, axis, keepdims=keepdims)
190
191
 
192
+ clip = get_xp(torch)(_aliases_clip)
193
+
191
194
  # torch.sort also returns a tuple
192
195
  # https://github.com/pytorch/pytorch/issues/70921
193
196
  def sort(x: array, /, *, axis: int = -1, descending: bool = False, stable: bool = True, **kwargs) -> array:
@@ -702,11 +705,11 @@ def take(x: array, indices: array, /, *, axis: Optional[int] = None, **kwargs) -
702
705
 
703
706
  __all__ = ['result_type', 'can_cast', 'permute_dims', 'bitwise_invert',
704
707
  'newaxis', 'add', 'atan2', 'bitwise_and', 'bitwise_left_shift',
705
- 'bitwise_or', 'bitwise_right_shift', 'bitwise_xor', 'divide',
706
- 'equal', 'floor_divide', 'greater', 'greater_equal', 'less',
707
- 'less_equal', 'logaddexp', 'multiply', 'not_equal', 'pow',
708
- 'remainder', 'subtract', 'max', 'min', 'sort', 'prod', 'sum',
709
- 'any', 'all', 'mean', 'std', 'var', 'concat', 'squeeze',
708
+ 'bitwise_or', 'bitwise_right_shift', 'bitwise_xor', 'copysign',
709
+ 'divide', 'equal', 'floor_divide', 'greater', 'greater_equal',
710
+ 'less', 'less_equal', 'logaddexp', 'multiply', 'not_equal', 'pow',
711
+ 'remainder', 'subtract', 'max', 'min', 'clip', 'sort', 'prod',
712
+ 'sum', 'any', 'all', 'mean', 'std', 'var', 'concat', 'squeeze',
710
713
  'broadcast_to', 'flip', 'roll', 'nonzero', 'where', 'reshape',
711
714
  'arange', 'eye', 'linspace', 'full', 'ones', 'zeros', 'empty',
712
715
  'tril', 'triu', 'expand_dims', 'astype', 'broadcast_arrays',
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: array_api_compat
3
- Version: 1.7
3
+ Version: 1.8
4
4
  Summary: A wrapper around NumPy and other array libraries to make them compatible with the Array API standard
5
5
  Home-page: https://data-apis.org/array-api-compat/
6
6
  Author: Consortium for Python Data API Standards
@@ -23,15 +23,15 @@ Provides-Extra: pytorch
23
23
  Requires-Dist: pytorch; extra == "pytorch"
24
24
  Provides-Extra: dask
25
25
  Requires-Dist: dask; extra == "dask"
26
- Provides-Extra: sprase
27
- Requires-Dist: sparse>=0.15.1; extra == "sprase"
26
+ Provides-Extra: sparse
27
+ Requires-Dist: sparse>=0.15.1; extra == "sparse"
28
28
 
29
29
  # Array API compatibility library
30
30
 
31
31
  This is a small wrapper around common array libraries that is compatible with
32
32
  the [Array API standard](https://data-apis.org/array-api/latest/). Currently,
33
- NumPy, CuPy, PyTorch, Dask, JAX and `sparse` are supported. If you want support
34
- for other array libraries, or if you encounter any issues, please [open an
35
- issue](https://github.com/data-apis/array-api-compat/issues).
33
+ NumPy, CuPy, PyTorch, Dask, JAX, ndonnx and `sparse` are supported. If you want
34
+ support for other array libraries, or if you encounter any issues, please [open
35
+ an issue](https://github.com/data-apis/array-api-compat/issues).
36
36
 
37
37
  See the documentation for more details https://data-apis.org/array-api-compat/
@@ -14,5 +14,5 @@ numpy
14
14
  [pytorch]
15
15
  pytorch
16
16
 
17
- [sprase]
17
+ [sparse]
18
18
  sparse>=0.15.1
@@ -21,7 +21,7 @@ setup(
21
21
  "jax": "jax",
22
22
  "pytorch": "pytorch",
23
23
  "dask": "dask",
24
- "sprase": "sparse >=0.15.1",
24
+ "sparse": "sparse >=0.15.1",
25
25
  },
26
26
  classifiers=[
27
27
  "Programming Language :: Python :: 3",
File without changes
File without changes