array-api-compat 1.6__tar.gz → 1.7__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. {array_api_compat-1.6 → array_api_compat-1.7}/PKG-INFO +5 -3
  2. {array_api_compat-1.6 → array_api_compat-1.7}/README.md +2 -2
  3. {array_api_compat-1.6 → array_api_compat-1.7}/array_api_compat/__init__.py +1 -1
  4. {array_api_compat-1.6 → array_api_compat-1.7}/array_api_compat/common/_helpers.py +82 -2
  5. {array_api_compat-1.6 → array_api_compat-1.7}/array_api_compat/torch/linalg.py +33 -1
  6. {array_api_compat-1.6 → array_api_compat-1.7}/array_api_compat.egg-info/PKG-INFO +5 -3
  7. {array_api_compat-1.6 → array_api_compat-1.7}/array_api_compat.egg-info/requires.txt +3 -0
  8. {array_api_compat-1.6 → array_api_compat-1.7}/setup.py +1 -0
  9. {array_api_compat-1.6 → array_api_compat-1.7}/tests/test_array_namespace.py +8 -1
  10. {array_api_compat-1.6 → array_api_compat-1.7}/tests/test_common.py +4 -1
  11. {array_api_compat-1.6 → array_api_compat-1.7}/tests/test_no_dependencies.py +2 -2
  12. {array_api_compat-1.6 → array_api_compat-1.7}/LICENSE +0 -0
  13. {array_api_compat-1.6 → array_api_compat-1.7}/array_api_compat/_internal.py +0 -0
  14. {array_api_compat-1.6 → array_api_compat-1.7}/array_api_compat/common/__init__.py +0 -0
  15. {array_api_compat-1.6 → array_api_compat-1.7}/array_api_compat/common/_aliases.py +0 -0
  16. {array_api_compat-1.6 → array_api_compat-1.7}/array_api_compat/common/_fft.py +0 -0
  17. {array_api_compat-1.6 → array_api_compat-1.7}/array_api_compat/common/_linalg.py +0 -0
  18. {array_api_compat-1.6 → array_api_compat-1.7}/array_api_compat/common/_typing.py +0 -0
  19. {array_api_compat-1.6 → array_api_compat-1.7}/array_api_compat/cupy/__init__.py +0 -0
  20. {array_api_compat-1.6 → array_api_compat-1.7}/array_api_compat/cupy/_aliases.py +0 -0
  21. {array_api_compat-1.6 → array_api_compat-1.7}/array_api_compat/cupy/_typing.py +0 -0
  22. {array_api_compat-1.6 → array_api_compat-1.7}/array_api_compat/cupy/fft.py +0 -0
  23. {array_api_compat-1.6 → array_api_compat-1.7}/array_api_compat/cupy/linalg.py +0 -0
  24. {array_api_compat-1.6 → array_api_compat-1.7}/array_api_compat/dask/__init__.py +0 -0
  25. {array_api_compat-1.6 → array_api_compat-1.7}/array_api_compat/dask/array/__init__.py +0 -0
  26. {array_api_compat-1.6 → array_api_compat-1.7}/array_api_compat/dask/array/_aliases.py +0 -0
  27. {array_api_compat-1.6 → array_api_compat-1.7}/array_api_compat/dask/array/linalg.py +0 -0
  28. {array_api_compat-1.6 → array_api_compat-1.7}/array_api_compat/numpy/__init__.py +0 -0
  29. {array_api_compat-1.6 → array_api_compat-1.7}/array_api_compat/numpy/_aliases.py +0 -0
  30. {array_api_compat-1.6 → array_api_compat-1.7}/array_api_compat/numpy/_typing.py +0 -0
  31. {array_api_compat-1.6 → array_api_compat-1.7}/array_api_compat/numpy/fft.py +0 -0
  32. {array_api_compat-1.6 → array_api_compat-1.7}/array_api_compat/numpy/linalg.py +0 -0
  33. {array_api_compat-1.6 → array_api_compat-1.7}/array_api_compat/torch/__init__.py +0 -0
  34. {array_api_compat-1.6 → array_api_compat-1.7}/array_api_compat/torch/_aliases.py +0 -0
  35. {array_api_compat-1.6 → array_api_compat-1.7}/array_api_compat/torch/fft.py +0 -0
  36. {array_api_compat-1.6 → array_api_compat-1.7}/array_api_compat.egg-info/SOURCES.txt +0 -0
  37. {array_api_compat-1.6 → array_api_compat-1.7}/array_api_compat.egg-info/dependency_links.txt +0 -0
  38. {array_api_compat-1.6 → array_api_compat-1.7}/array_api_compat.egg-info/top_level.txt +0 -0
  39. {array_api_compat-1.6 → array_api_compat-1.7}/setup.cfg +0 -0
  40. {array_api_compat-1.6 → array_api_compat-1.7}/tests/test_all.py +0 -0
  41. {array_api_compat-1.6 → array_api_compat-1.7}/tests/test_isdtype.py +0 -0
  42. {array_api_compat-1.6 → array_api_compat-1.7}/tests/test_vendoring.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: array_api_compat
3
- Version: 1.6
3
+ Version: 1.7
4
4
  Summary: A wrapper around NumPy and other array libraries to make them compatible with the Array API standard
5
5
  Home-page: https://data-apis.org/array-api-compat/
6
6
  Author: Consortium for Python Data API Standards
@@ -23,13 +23,15 @@ Provides-Extra: pytorch
23
23
  Requires-Dist: pytorch; extra == "pytorch"
24
24
  Provides-Extra: dask
25
25
  Requires-Dist: dask; extra == "dask"
26
+ Provides-Extra: sparse
27
+ Requires-Dist: sparse>=0.15.1; extra == "sparse"
26
28
 
27
29
  # Array API compatibility library
28
30
 
29
31
  This is a small wrapper around common array libraries that is compatible with
30
32
  the [Array API standard](https://data-apis.org/array-api/latest/). Currently,
31
- NumPy, CuPy, PyTorch, Dask, and JAX are supported. If you want support for other array
32
- libraries, or if you encounter any issues, please [open an
33
+ NumPy, CuPy, PyTorch, Dask, JAX and `sparse` are supported. If you want support
34
+ for other array libraries, or if you encounter any issues, please [open an
33
35
  issue](https://github.com/data-apis/array-api-compat/issues).
34
36
 
35
37
  See the documentation for more details https://data-apis.org/array-api-compat/
@@ -2,8 +2,8 @@
2
2
 
3
3
  This is a small wrapper around common array libraries that is compatible with
4
4
  the [Array API standard](https://data-apis.org/array-api/latest/). Currently,
5
- NumPy, CuPy, PyTorch, Dask, and JAX are supported. If you want support for other array
6
- libraries, or if you encounter any issues, please [open an
5
+ NumPy, CuPy, PyTorch, Dask, JAX and `sparse` are supported. If you want support
6
+ for other array libraries, or if you encounter any issues, please [open an
7
7
  issue](https://github.com/data-apis/array-api-compat/issues).
8
8
 
9
9
  See the documentation for more details https://data-apis.org/array-api-compat/
@@ -17,6 +17,6 @@ to ensure they are not using functionality outside of the standard, but prefer
17
17
  this implementation for the default when working with NumPy arrays.
18
18
 
19
19
  """
20
- __version__ = '1.6'
20
+ __version__ = '1.7'
21
21
 
22
22
  from .common import * # noqa: F401, F403
@@ -18,6 +18,20 @@ import math
18
18
  import inspect
19
19
  import warnings
20
20
 
21
+ def _is_jax_zero_gradient_array(x):
22
+ """Return True if `x` is a zero-gradient array.
23
+
24
+ These arrays are a design quirk of Jax that may one day be removed.
25
+ See https://github.com/google/jax/issues/20620.
26
+ """
27
+ if 'numpy' not in sys.modules or 'jax' not in sys.modules:
28
+ return False
29
+
30
+ import numpy as np
31
+ import jax
32
+
33
+ return isinstance(x, np.ndarray) and x.dtype == jax.float0
34
+
21
35
  def is_numpy_array(x):
22
36
  """
23
37
  Return True if `x` is a NumPy array.
@@ -36,6 +50,7 @@ def is_numpy_array(x):
36
50
  is_torch_array
37
51
  is_dask_array
38
52
  is_jax_array
53
+ is_pydata_sparse_array
39
54
  """
40
55
  # Avoid importing NumPy if it isn't already
41
56
  if 'numpy' not in sys.modules:
@@ -44,7 +59,8 @@ def is_numpy_array(x):
44
59
  import numpy as np
45
60
 
46
61
  # TODO: Should we reject ndarray subclasses?
47
- return isinstance(x, (np.ndarray, np.generic))
62
+ return (isinstance(x, (np.ndarray, np.generic))
63
+ and not _is_jax_zero_gradient_array(x))
48
64
 
49
65
  def is_cupy_array(x):
50
66
  """
@@ -64,6 +80,7 @@ def is_cupy_array(x):
64
80
  is_torch_array
65
81
  is_dask_array
66
82
  is_jax_array
83
+ is_pydata_sparse_array
67
84
  """
68
85
  # Avoid importing NumPy if it isn't already
69
86
  if 'cupy' not in sys.modules:
@@ -90,6 +107,7 @@ def is_torch_array(x):
90
107
  is_cupy_array
91
108
  is_dask_array
92
109
  is_jax_array
110
+ is_pydata_sparse_array
93
111
  """
94
112
  # Avoid importing torch if it isn't already
95
113
  if 'torch' not in sys.modules:
@@ -116,6 +134,7 @@ def is_dask_array(x):
116
134
  is_cupy_array
117
135
  is_torch_array
118
136
  is_jax_array
137
+ is_pydata_sparse_array
119
138
  """
120
139
  # Avoid importing dask if it isn't already
121
140
  if 'dask.array' not in sys.modules:
@@ -142,6 +161,7 @@ def is_jax_array(x):
142
161
  is_cupy_array
143
162
  is_torch_array
144
163
  is_dask_array
164
+ is_pydata_sparse_array
145
165
  """
146
166
  # Avoid importing jax if it isn't already
147
167
  if 'jax' not in sys.modules:
@@ -149,7 +169,36 @@ def is_jax_array(x):
149
169
 
150
170
  import jax
151
171
 
152
- return isinstance(x, jax.Array)
172
+ return isinstance(x, jax.Array) or _is_jax_zero_gradient_array(x)
173
+
174
+
175
def is_pydata_sparse_array(x) -> bool:
    """
    Return True if `x` is an array from the `sparse` package.

    This function does not import `sparse` if it has not already been imported
    and is therefore cheap to use.

    See Also
    --------

    array_namespace
    is_array_api_obj
    is_numpy_array
    is_cupy_array
    is_torch_array
    is_dask_array
    is_jax_array
    """
    # Avoid importing sparse if it isn't already: if the module was never
    # loaded, no object can be one of its arrays.
    if 'sparse' not in sys.modules:
        return False

    import sparse

    # TODO: Account for other backends.
    return isinstance(x, sparse.SparseArray)
153
202
 
154
203
  def is_array_api_obj(x):
155
204
  """
@@ -170,6 +219,7 @@ def is_array_api_obj(x):
170
219
  or is_torch_array(x) \
171
220
  or is_dask_array(x) \
172
221
  or is_jax_array(x) \
222
+ or is_pydata_sparse_array(x) \
173
223
  or hasattr(x, '__array_namespace__')
174
224
 
175
225
  def _check_api_version(api_version):
@@ -238,6 +288,7 @@ def array_namespace(*xs, api_version=None, use_compat=None):
238
288
  is_torch_array
239
289
  is_dask_array
240
290
  is_jax_array
291
+ is_pydata_sparse_array
241
292
 
242
293
  """
243
294
  if use_compat not in [None, True, False]:
@@ -297,6 +348,15 @@ def array_namespace(*xs, api_version=None, use_compat=None):
297
348
  # not have a wrapper submodule for it.
298
349
  import jax.experimental.array_api as jnp
299
350
  namespaces.add(jnp)
351
+ elif is_pydata_sparse_array(x):
352
+ if use_compat is True:
353
+ _check_api_version(api_version)
354
+ raise ValueError("`sparse` does not have an array-api-compat wrapper")
355
+ else:
356
+ import sparse
357
+ # `sparse` is already an array namespace. We do not have a wrapper
358
+ # submodule for it.
359
+ namespaces.add(sparse)
300
360
  elif hasattr(x, '__array_namespace__'):
301
361
  if use_compat is True:
302
362
  raise ValueError("The given array does not have an array-api-compat wrapper")
@@ -391,8 +451,23 @@ def device(x: Array, /) -> Device:
391
451
  return x.device()
392
452
  else:
393
453
  return x.device
454
+ elif is_pydata_sparse_array(x):
455
+ # `sparse` will gain `.device`, so check for this first.
456
+ x_device = getattr(x, 'device', None)
457
+ if x_device is not None:
458
+ return x_device
459
+ # Everything but DOK has this attr.
460
+ try:
461
+ inner = x.data
462
+ except AttributeError:
463
+ return "cpu"
464
+ # Return the device of the constituent array
465
+ return device(inner)
394
466
  return x.device
395
467
 
468
+ # Alias of `device` for use inside `to_device`, whose `device` parameter shadows this function there (it calls `_device(x)`).
469
+ _device = device
470
+
396
471
  # Based on cupy.array_api.Array.to_device
397
472
  def _cupy_to_device(x, device, /, stream=None):
398
473
  import cupy as cp
@@ -508,6 +583,10 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]]
508
583
  # This import adds to_device to x
509
584
  import jax.experimental.array_api # noqa: F401
510
585
  return x.to_device(device, stream=stream)
586
+ elif is_pydata_sparse_array(x) and device == _device(x):
587
+ # Perform trivial check to return the same array if
588
+ # device is same instead of err-ing.
589
+ return x
511
590
  return x.to_device(device, stream=stream)
512
591
 
513
592
  def size(x):
@@ -534,6 +613,7 @@ __all__ = [
534
613
  "is_jax_array",
535
614
  "is_numpy_array",
536
615
  "is_torch_array",
616
+ "is_pydata_sparse_array",
537
617
  "size",
538
618
  "to_device",
539
619
  ]
@@ -60,6 +60,22 @@ def vecdot(x1: array, x2: array, /, *, axis: int = -1, **kwargs) -> array:
60
60
 
61
61
def solve(x1: array, x2: array, /, **kwargs) -> array:
    """Solve ``x1 @ result == x2`` via ``torch.linalg.solve``.

    Operands are first passed through ``_fix_promotion`` (with
    ``only_scalar=False``). Extra keyword arguments are forwarded unchanged
    to ``torch.linalg.solve``.
    """
    x1, x2 = _fix_promotion(x1, x2, only_scalar=False)

    # PyTorch emulates NumPy 1 solve semantics. It treats x2 as a batch of
    # 1-D right-hand sides whenever x1.ndim - 1 == x2.ndim and
    # x1.shape[:-1] == x2.shape. See linalg_solve_is_vector_rhs in
    # aten/src/ATen/native/LinearAlgebraUtils.h and
    # TORCH_META_FUNC(_linalg_solve_ex) in
    # aten/src/ATen/native/BatchLinearAlgebra.cpp in the PyTorch source.
    #
    # x2 is already one dimension shorter than x1 in that case, so
    # prepending a size-1 axis defeats the heuristic. A 1-D x2 is excluded
    # from this workaround.
    #
    # See https://github.com/pytorch/pytorch/issues/52915
    looks_like_batched_vector = (
        x2.ndim != 1
        and x1.ndim - 1 == x2.ndim
        and x1.shape[:-1] == x2.shape
    )
    if looks_like_batched_vector:
        x2 = x2[None]

    return torch.linalg.solve(x1, x2, **kwargs)
64
80
 
65
81
  # torch.trace doesn't support the offset argument and doesn't support stacking
@@ -78,7 +94,23 @@ def vector_norm(
78
94
  ) -> array:
79
95
  # torch.vector_norm incorrectly treats axis=() the same as axis=None
80
96
  if axis == ():
81
- keepdims = True
97
+ out = kwargs.get('out')
98
+ if out is None:
99
+ dtype = None
100
+ if x.dtype == torch.complex64:
101
+ dtype = torch.float32
102
+ elif x.dtype == torch.complex128:
103
+ dtype = torch.float64
104
+
105
+ out = torch.zeros_like(x, dtype=dtype)
106
+
107
+ # The norm of a single scalar works out to abs(x) in every case except
108
+ # for ord=0, which is x != 0.
109
+ if ord == 0:
110
+ out[:] = (x != 0)
111
+ else:
112
+ out[:] = torch.abs(x)
113
+ return out
82
114
  return torch.linalg.vector_norm(x, ord=ord, axis=axis, keepdim=keepdims, **kwargs)
83
115
 
84
116
  __all__ = linalg_all + ['outer', 'matmul', 'matrix_transpose', 'tensordot',
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: array_api_compat
3
- Version: 1.6
3
+ Version: 1.7
4
4
  Summary: A wrapper around NumPy and other array libraries to make them compatible with the Array API standard
5
5
  Home-page: https://data-apis.org/array-api-compat/
6
6
  Author: Consortium for Python Data API Standards
@@ -23,13 +23,15 @@ Provides-Extra: pytorch
23
23
  Requires-Dist: pytorch; extra == "pytorch"
24
24
  Provides-Extra: dask
25
25
  Requires-Dist: dask; extra == "dask"
26
+ Provides-Extra: sparse
27
+ Requires-Dist: sparse>=0.15.1; extra == "sparse"
26
28
 
27
29
  # Array API compatibility library
28
30
 
29
31
  This is a small wrapper around common array libraries that is compatible with
30
32
  the [Array API standard](https://data-apis.org/array-api/latest/). Currently,
31
- NumPy, CuPy, PyTorch, Dask, and JAX are supported. If you want support for other array
32
- libraries, or if you encounter any issues, please [open an
33
+ NumPy, CuPy, PyTorch, Dask, JAX and `sparse` are supported. If you want support
34
+ for other array libraries, or if you encounter any issues, please [open an
33
35
  issue](https://github.com/data-apis/array-api-compat/issues).
34
36
 
35
37
  See the documentation for more details https://data-apis.org/array-api-compat/
@@ -13,3 +13,6 @@ numpy
13
13
 
14
14
  [pytorch]
15
15
  pytorch
16
+
17
+ [sparse]
18
+ sparse>=0.15.1
@@ -21,6 +21,7 @@ setup(
21
21
  "jax": "jax",
22
22
  "pytorch": "pytorch",
23
23
  "dask": "dask",
24
+ "sprase": "sparse >=0.15.1",
24
25
  },
25
26
  classifiers=[
26
27
  "Programming Language :: Python :: 3",
@@ -2,6 +2,7 @@ import subprocess
2
2
  import sys
3
3
  import warnings
4
4
 
5
+ import jax
5
6
  import numpy as np
6
7
  import pytest
7
8
  import torch
@@ -18,7 +19,7 @@ def test_array_namespace(library, api_version, use_compat):
18
19
  xp = import_(library)
19
20
 
20
21
  array = xp.asarray([1.0, 2.0, 3.0])
21
- if use_compat is True and library in ['array_api_strict', 'jax.numpy']:
22
+ if use_compat is True and library in {'array_api_strict', 'jax.numpy', 'sparse'}:
22
23
  pytest.raises(ValueError, lambda: array_namespace(array, use_compat=use_compat))
23
24
  return
24
25
  namespace = array_api_compat.array_namespace(array, api_version=api_version, use_compat=use_compat)
@@ -55,6 +56,12 @@ assert namespace == jax.experimental.array_api
55
56
  """
56
57
  subprocess.run([sys.executable, "-c", code], check=True)
57
58
 
59
def test_jax_zero_gradient():
    """Zero-gradient JAX arrays must resolve to the same namespace as JAX arrays."""
    # Differentiating with respect to integer inputs (allow_int=True) yields
    # JAX zero-gradient arrays.
    ints = jax.numpy.arange(4)
    batched_grad = jax.vmap(jax.grad(jax.numpy.float32, allow_int=True))
    zero_grads = batched_grad(ints)

    actual = array_api_compat.get_namespace(zero_grads)
    assert actual is array_api_compat.get_namespace(ints)
64
+
58
65
  def test_array_namespace_errors():
59
66
  pytest.raises(TypeError, lambda: array_namespace([1]))
60
67
  pytest.raises(TypeError, lambda: array_namespace())
@@ -1,5 +1,5 @@
1
1
  from array_api_compat import (is_numpy_array, is_cupy_array, is_torch_array, # noqa: F401
2
- is_dask_array, is_jax_array)
2
+ is_dask_array, is_jax_array, is_pydata_sparse_array)
3
3
 
4
4
  from array_api_compat import is_array_api_obj, device, to_device
5
5
 
@@ -16,6 +16,7 @@ is_functions = {
16
16
  'torch': 'is_torch_array',
17
17
  'dask.array': 'is_dask_array',
18
18
  'jax.numpy': 'is_jax_array',
19
+ 'sparse': 'is_pydata_sparse_array',
19
20
  }
20
21
 
21
22
  @pytest.mark.parametrize('library', is_functions.keys())
@@ -76,6 +77,8 @@ def test_asarray_cross_library(source_library, target_library, request):
76
77
  if source_library == "cupy" and target_library != "cupy":
77
78
  # cupy explicitly disallows implicit conversions to CPU
78
79
  pytest.skip(reason="cupy does not support implicit conversion to CPU")
80
+ elif source_library == "sparse" and target_library != "sparse":
81
+ pytest.skip(reason="`sparse` does not allow implicit densification")
79
82
  src_lib = import_(source_library, wrapper=True)
80
83
  tgt_lib = import_(target_library, wrapper=True)
81
84
  is_tgt_type = globals()[is_functions[target_library]]
@@ -33,7 +33,7 @@ def _test_dependency(mod):
33
33
 
34
34
  # array-api-strict is an example of an array API library that isn't
35
35
  # wrapped by array-api-compat.
36
- if "strict" not in mod:
36
+ if "strict" not in mod and mod != "sparse":
37
37
  is_mod_array = getattr(array_api_compat, f"is_{mod.split('.')[0]}_array")
38
38
  assert not is_mod_array(a)
39
39
  assert mod not in sys.modules
@@ -50,7 +50,7 @@ def _test_dependency(mod):
50
50
  # Y (except most array libraries actually do themselves depend on numpy).
51
51
 
52
52
  @pytest.mark.parametrize("library", ["cupy", "numpy", "torch", "dask.array",
53
- "jax.numpy", "array_api_strict"])
53
+ "jax.numpy", "sparse", "array_api_strict"])
54
54
  def test_numpy_dependency(library):
55
55
  # This import is here because it imports numpy
56
56
  from ._helpers import import_
File without changes
File without changes