array-api-extra 0.8.2__tar.gz → 0.9.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50) hide show
  1. {array_api_extra-0.8.2 → array_api_extra-0.9.0}/.all-contributorsrc +9 -0
  2. {array_api_extra-0.8.2 → array_api_extra-0.9.0}/.github/workflows/ci.yml +1 -0
  3. {array_api_extra-0.8.2 → array_api_extra-0.9.0}/CONTRIBUTORS.md +1 -0
  4. {array_api_extra-0.8.2 → array_api_extra-0.9.0}/PKG-INFO +2 -1
  5. {array_api_extra-0.8.2 → array_api_extra-0.9.0}/README.md +1 -0
  6. {array_api_extra-0.8.2 → array_api_extra-0.9.0}/pyproject.toml +20 -12
  7. {array_api_extra-0.8.2 → array_api_extra-0.9.0}/src/array_api_extra/__init__.py +1 -1
  8. {array_api_extra-0.8.2 → array_api_extra-0.9.0}/src/array_api_extra/_lib/_utils/_helpers.py +17 -5
  9. {array_api_extra-0.8.2 → array_api_extra-0.9.0}/tests/test_helpers.py +14 -0
  10. {array_api_extra-0.8.2 → array_api_extra-0.9.0}/tests/test_testing.py +20 -1
  11. {array_api_extra-0.8.2 → array_api_extra-0.9.0}/.dprint.jsonc +0 -0
  12. {array_api_extra-0.8.2 → array_api_extra-0.9.0}/.editorconfig +0 -0
  13. {array_api_extra-0.8.2 → array_api_extra-0.9.0}/.gitattributes +0 -0
  14. {array_api_extra-0.8.2 → array_api_extra-0.9.0}/.github/pull_request_template.md +0 -0
  15. {array_api_extra-0.8.2 → array_api_extra-0.9.0}/.github/workflows/cd.yml +0 -0
  16. {array_api_extra-0.8.2 → array_api_extra-0.9.0}/.github/workflows/docs-build.yml +0 -0
  17. {array_api_extra-0.8.2 → array_api_extra-0.9.0}/.github/workflows/docs-deploy.yml +0 -0
  18. {array_api_extra-0.8.2 → array_api_extra-0.9.0}/.gitignore +0 -0
  19. {array_api_extra-0.8.2 → array_api_extra-0.9.0}/LICENSE +0 -0
  20. {array_api_extra-0.8.2 → array_api_extra-0.9.0}/docs/api-lazy.md +0 -0
  21. {array_api_extra-0.8.2 → array_api_extra-0.9.0}/docs/api-reference.md +0 -0
  22. {array_api_extra-0.8.2 → array_api_extra-0.9.0}/docs/conf.py +0 -0
  23. {array_api_extra-0.8.2 → array_api_extra-0.9.0}/docs/contributing.md +0 -0
  24. {array_api_extra-0.8.2 → array_api_extra-0.9.0}/docs/contributors.md +0 -0
  25. {array_api_extra-0.8.2 → array_api_extra-0.9.0}/docs/index.md +0 -0
  26. {array_api_extra-0.8.2 → array_api_extra-0.9.0}/lefthook.yml +0 -0
  27. {array_api_extra-0.8.2 → array_api_extra-0.9.0}/src/array_api_extra/_delegation.py +0 -0
  28. {array_api_extra-0.8.2 → array_api_extra-0.9.0}/src/array_api_extra/_lib/__init__.py +0 -0
  29. {array_api_extra-0.8.2 → array_api_extra-0.9.0}/src/array_api_extra/_lib/_at.py +0 -0
  30. {array_api_extra-0.8.2 → array_api_extra-0.9.0}/src/array_api_extra/_lib/_backends.py +0 -0
  31. {array_api_extra-0.8.2 → array_api_extra-0.9.0}/src/array_api_extra/_lib/_funcs.py +0 -0
  32. {array_api_extra-0.8.2 → array_api_extra-0.9.0}/src/array_api_extra/_lib/_lazy.py +0 -0
  33. {array_api_extra-0.8.2 → array_api_extra-0.9.0}/src/array_api_extra/_lib/_testing.py +0 -0
  34. {array_api_extra-0.8.2 → array_api_extra-0.9.0}/src/array_api_extra/_lib/_utils/__init__.py +0 -0
  35. {array_api_extra-0.8.2 → array_api_extra-0.9.0}/src/array_api_extra/_lib/_utils/_compat.py +0 -0
  36. {array_api_extra-0.8.2 → array_api_extra-0.9.0}/src/array_api_extra/_lib/_utils/_compat.pyi +0 -0
  37. {array_api_extra-0.8.2 → array_api_extra-0.9.0}/src/array_api_extra/_lib/_utils/_typing.py +0 -0
  38. {array_api_extra-0.8.2 → array_api_extra-0.9.0}/src/array_api_extra/_lib/_utils/_typing.pyi +0 -0
  39. {array_api_extra-0.8.2 → array_api_extra-0.9.0}/src/array_api_extra/py.typed +0 -0
  40. {array_api_extra-0.8.2 → array_api_extra-0.9.0}/src/array_api_extra/testing.py +0 -0
  41. {array_api_extra-0.8.2 → array_api_extra-0.9.0}/tests/__init__.py +0 -0
  42. {array_api_extra-0.8.2 → array_api_extra-0.9.0}/tests/conftest.py +0 -0
  43. {array_api_extra-0.8.2 → array_api_extra-0.9.0}/tests/test_at.py +0 -0
  44. {array_api_extra-0.8.2 → array_api_extra-0.9.0}/tests/test_funcs.py +0 -0
  45. {array_api_extra-0.8.2 → array_api_extra-0.9.0}/tests/test_lazy.py +0 -0
  46. {array_api_extra-0.8.2 → array_api_extra-0.9.0}/tests/test_version.py +0 -0
  47. {array_api_extra-0.8.2 → array_api_extra-0.9.0}/typos.toml +0 -0
  48. {array_api_extra-0.8.2 → array_api_extra-0.9.0}/vendor_tests/__init__.py +0 -0
  49. {array_api_extra-0.8.2 → array_api_extra-0.9.0}/vendor_tests/_array_api_compat_vendor.py +0 -0
  50. {array_api_extra-0.8.2 → array_api_extra-0.9.0}/vendor_tests/test_vendor.py +0 -0
@@ -261,6 +261,15 @@
261
261
  "example",
262
262
  "test"
263
263
  ]
264
+ },
265
+ {
266
+ "login": "amacati",
267
+ "name": "Martin Schuck",
268
+ "avatar_url": "https://avatars.githubusercontent.com/u/57562633?v=4",
269
+ "profile": "https://amacati.github.io/",
270
+ "contributions": [
271
+ "ideas"
272
+ ]
264
273
  }
265
274
  ]
266
275
  }
@@ -48,6 +48,7 @@ jobs:
48
48
  - tests-py313
49
49
  - tests-numpy1
50
50
  - tests-backends
51
+ - tests-backends-py310
51
52
  - tests-nogil
52
53
  runs-on: [ubuntu-latest]
53
54
 
@@ -39,6 +39,7 @@ This project exists thanks to the following contributors
39
39
  <td align="center" valign="top" width="14.28%"><a href="https://github.com/lithomas1"><img src="https://avatars.githubusercontent.com/u/47963215?v=4?s=100" width="100px;" alt="Thomas Li"/><br /><sub><b>Thomas Li</b></sub></a><br /><a href="https://github.com/data-apis/array-api-extra/issues?q=author%3Alithomas1" title="Bug reports">🐛</a> <a href="#tool-lithomas1" title="Tools">🔧</a></td>
40
40
  <td align="center" valign="top" width="14.28%"><a href="https://github.com/pearu"><img src="https://avatars.githubusercontent.com/u/402156?v=4?s=100" width="100px;" alt="Pearu Peterson"/><br /><sub><b>Pearu Peterson</b></sub></a><br /><a href="https://github.com/data-apis/array-api-extra/pulls?q=is%3Apr+reviewed-by%3Apearu" title="Reviewed Pull Requests">👀</a></td>
41
41
  <td align="center" valign="top" width="14.28%"><a href="https://paddyroddy.github.io/"><img src="https://avatars.githubusercontent.com/u/15052188?v=4?s=100" width="100px;" alt="Patrick J. Roddy"/><br /><sub><b>Patrick J. Roddy</b></sub></a><br /><a href="https://github.com/data-apis/array-api-extra/commits?author=paddyroddy" title="Code">💻</a> <a href="https://github.com/data-apis/array-api-extra/commits?author=paddyroddy" title="Documentation">📖</a> <a href="#example-paddyroddy" title="Examples">💡</a> <a href="https://github.com/data-apis/array-api-extra/commits?author=paddyroddy" title="Tests">⚠️</a></td>
42
+ <td align="center" valign="top" width="14.28%"><a href="https://amacati.github.io/"><img src="https://avatars.githubusercontent.com/u/57562633?v=4?s=100" width="100px;" alt="Martin Schuck"/><br /><sub><b>Martin Schuck</b></sub></a><br /><a href="#ideas-amacati" title="Ideas, Planning, & Feedback">🤔</a></td>
42
43
  </tr>
43
44
  </tbody>
44
45
  </table>
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: array-api-extra
3
- Version: 0.8.2
3
+ Version: 0.9.0
4
4
  Summary: Extra array functions built on top of the array API standard.
5
5
  Project-URL: Homepage, https://github.com/data-apis/array-api-extra
6
6
  Project-URL: Bug Tracker, https://github.com/data-apis/array-api-extra/issues
@@ -140,6 +140,7 @@ This project exists thanks to the following contributors
140
140
  <td align="center" valign="top" width="14.28%"><a href="https://github.com/lithomas1"><img src="https://avatars.githubusercontent.com/u/47963215?v=4?s=100" width="100px;" alt="Thomas Li"/><br /><sub><b>Thomas Li</b></sub></a><br /><a href="https://github.com/data-apis/array-api-extra/issues?q=author%3Alithomas1" title="Bug reports">🐛</a> <a href="#tool-lithomas1" title="Tools">🔧</a></td>
141
141
  <td align="center" valign="top" width="14.28%"><a href="https://github.com/pearu"><img src="https://avatars.githubusercontent.com/u/402156?v=4?s=100" width="100px;" alt="Pearu Peterson"/><br /><sub><b>Pearu Peterson</b></sub></a><br /><a href="https://github.com/data-apis/array-api-extra/pulls?q=is%3Apr+reviewed-by%3Apearu" title="Reviewed Pull Requests">👀</a></td>
142
142
  <td align="center" valign="top" width="14.28%"><a href="https://paddyroddy.github.io/"><img src="https://avatars.githubusercontent.com/u/15052188?v=4?s=100" width="100px;" alt="Patrick J. Roddy"/><br /><sub><b>Patrick J. Roddy</b></sub></a><br /><a href="https://github.com/data-apis/array-api-extra/commits?author=paddyroddy" title="Code">💻</a> <a href="https://github.com/data-apis/array-api-extra/commits?author=paddyroddy" title="Documentation">📖</a> <a href="#example-paddyroddy" title="Examples">💡</a> <a href="https://github.com/data-apis/array-api-extra/commits?author=paddyroddy" title="Tests">⚠️</a></td>
143
+ <td align="center" valign="top" width="14.28%"><a href="https://amacati.github.io/"><img src="https://avatars.githubusercontent.com/u/57562633?v=4?s=100" width="100px;" alt="Martin Schuck"/><br /><sub><b>Martin Schuck</b></sub></a><br /><a href="#ideas-amacati" title="Ideas, Planning, & Feedback">🤔</a></td>
143
144
  </tr>
144
145
  </tbody>
145
146
  </table>
@@ -94,6 +94,7 @@ This project exists thanks to the following contributors
94
94
  <td align="center" valign="top" width="14.28%"><a href="https://github.com/lithomas1"><img src="https://avatars.githubusercontent.com/u/47963215?v=4?s=100" width="100px;" alt="Thomas Li"/><br /><sub><b>Thomas Li</b></sub></a><br /><a href="https://github.com/data-apis/array-api-extra/issues?q=author%3Alithomas1" title="Bug reports">🐛</a> <a href="#tool-lithomas1" title="Tools">🔧</a></td>
95
95
  <td align="center" valign="top" width="14.28%"><a href="https://github.com/pearu"><img src="https://avatars.githubusercontent.com/u/402156?v=4?s=100" width="100px;" alt="Pearu Peterson"/><br /><sub><b>Pearu Peterson</b></sub></a><br /><a href="https://github.com/data-apis/array-api-extra/pulls?q=is%3Apr+reviewed-by%3Apearu" title="Reviewed Pull Requests">👀</a></td>
96
96
  <td align="center" valign="top" width="14.28%"><a href="https://paddyroddy.github.io/"><img src="https://avatars.githubusercontent.com/u/15052188?v=4?s=100" width="100px;" alt="Patrick J. Roddy"/><br /><sub><b>Patrick J. Roddy</b></sub></a><br /><a href="https://github.com/data-apis/array-api-extra/commits?author=paddyroddy" title="Code">💻</a> <a href="https://github.com/data-apis/array-api-extra/commits?author=paddyroddy" title="Documentation">📖</a> <a href="#example-paddyroddy" title="Examples">💡</a> <a href="https://github.com/data-apis/array-api-extra/commits?author=paddyroddy" title="Tests">⚠️</a></td>
97
+ <td align="center" valign="top" width="14.28%"><a href="https://amacati.github.io/"><img src="https://avatars.githubusercontent.com/u/57562633?v=4?s=100" width="100px;" alt="Martin Schuck"/><br /><sub><b>Martin Schuck</b></sub></a><br /><a href="#ideas-amacati" title="Ideas, Planning, & Feedback">🤔</a></td>
97
98
  </tr>
98
99
  </tbody>
99
100
  </table>
@@ -67,7 +67,7 @@ hypothesis = ">=6.136.4"
67
67
  dask-core = ">=2025.7.0" # No distributed, tornado, etc.
68
68
  dprint = ">=0.50.0,<0.51"
69
69
  lefthook = ">=1.12.3,<2"
70
- ruff = ">=0.12.8,<0.13"
70
+ ruff = ">=0.12.11,<0.13"
71
71
  typos = ">=1.35.5,<2"
72
72
  actionlint = ">=1.7.7,<2"
73
73
  blacken-docs = ">=1.19.1,<2"
@@ -155,13 +155,16 @@ dask-core = ">=2025.7.0" # No distributed, tornado, etc.
155
155
  sparse = ">=0.17.0"
156
156
 
157
157
  [tool.pixi.feature.backends.target.linux-64.dependencies]
158
- jax = ">=0.6.0,!=0.6.2" # 0.6.2 segfaults on Linux CUDA
158
+ # On CPU Python 3.10, use 0.6.2
159
+ # On CPU Python >=3.11, use >=0.7.0
160
+ # On GPU, use 0.6.0 (0.6.2 and 0.7.0 both segfault); see jaxlib pin below.
161
+ jax = ">=0.6.0"
159
162
 
160
163
  [tool.pixi.feature.backends.target.osx-64.dependencies]
161
- jax = ">=0.6.0,!=0.6.2"
164
+ jax = ">=0.6.0"
162
165
 
163
166
  [tool.pixi.feature.backends.target.osx-arm64.dependencies]
164
- jax = ">=0.6.0,!=0.6.2"
167
+ jax = ">=0.6.0"
165
168
 
166
169
  [tool.pixi.feature.backends.target.win-64.dependencies]
167
170
  # jax = "*" # unavailable
@@ -176,8 +179,9 @@ jax = ">=0.6.0,!=0.6.2"
176
179
  system-requirements = { cuda = "12" }
177
180
 
178
181
  [tool.pixi.feature.cuda-backends.target.linux-64.dependencies]
179
- cupy = ">=13.5.1"
180
- jaxlib = { version = ">=0.6.0", build = "cuda12*" }
182
+ cupy = ">=13.6.0"
183
+ # JAX 0.6.2 and 0.7.0 segfault on CUDA
184
+ jaxlib = { version = ">=0.6.0,!=0.6.2,!=0.7.0", build = "cuda12*" }
181
185
  pytorch = { version = ">=2.7.1", build = "cuda12*" }
182
186
 
183
187
  [tool.pixi.feature.cuda-backends.target.osx-64.dependencies]
@@ -191,13 +195,13 @@ pytorch = { version = ">=2.7.1", build = "cuda12*" }
191
195
  # pytorch = { version = "*", build = "cuda12*" } # unavailable
192
196
 
193
197
  [tool.pixi.feature.cuda-backends.target.win-64.dependencies]
194
- cupy = ">=13.5.1"
198
+ cupy = ">=13.6.0"
195
199
  # jaxlib = { version = "*", build = "cuda12*" } # unavailable
196
200
  pytorch = { version = ">=2.7.1", build = "cuda12*" }
197
201
 
198
202
  [tool.pixi.feature.nogil.dependencies]
199
203
  python-freethreading = "~=3.13.0"
200
- pytest-run-parallel = ">=0.6.0"
204
+ pytest-run-parallel = ">=0.6.1"
201
205
  numpy = ">=2.3.2"
202
206
  # pytorch = "*" # Not available on Python 3.13t yet
203
207
  dask-core = ">=2025.7.0" # No distributed, tornado, etc.
@@ -212,12 +216,16 @@ tests = { features = ["py313", "tests"], solve-group = "py313" }
212
216
  tests-py313 = { features = ["py313", "tests"], solve-group = "py313" } # alias of tests
213
217
 
214
218
  # Some backends may pin numpy; use separate solve-group
215
- dev = { features = ["py310", "lint", "tests", "docs", "dev", "backends"], solve-group = "backends" }
216
- tests-backends = { features = ["py310", "tests", "backends"], solve-group = "backends" }
219
+ dev = { features = ["py313", "lint", "tests", "docs", "dev", "backends"], solve-group = "backends" }
220
+ tests-backends = { features = ["py313", "tests", "backends"], solve-group = "backends" }
221
+ # Note: Python 3.10 has already been dropped by some backends (like JAX),
222
+ # so this is testing older versions.
223
+ tests-backends-py310 = { features = ["py310", "tests", "backends"] }
217
224
 
218
225
  # CUDA not available on free github actions and on some developers' PCs
219
- dev-cuda = { features = ["py310", "lint", "tests", "docs", "dev", "backends", "cuda-backends"], solve-group = "cuda" }
220
- tests-cuda = { features = ["py310", "tests", "backends", "cuda-backends"], solve-group = "cuda" }
226
+ dev-cuda = { features = ["py313", "lint", "tests", "docs", "dev", "backends", "cuda-backends"], solve-group = "cuda" }
227
+ tests-cuda = { features = ["py313", "tests", "backends", "cuda-backends"], solve-group = "cuda" }
228
+ tests-cuda-py310 = { features = ["py310", "tests", "backends", "cuda-backends"] }
221
229
 
222
230
  # Ungrouped environments
223
231
  tests-numpy1 = ["py310", "tests", "numpy1"]
@@ -17,7 +17,7 @@ from ._lib._funcs import (
17
17
  )
18
18
  from ._lib._lazy import lazy_apply
19
19
 
20
- __version__ = "0.8.2"
20
+ __version__ = "0.9.0"
21
21
 
22
22
  # pylint: disable=duplicate-code
23
23
  __all__ = [
@@ -6,7 +6,7 @@ import io
6
6
  import math
7
7
  import pickle
8
8
  import types
9
- from collections.abc import Callable, Generator, Iterable
9
+ from collections.abc import Callable, Generator, Iterable, Iterator
10
10
  from functools import wraps
11
11
  from types import ModuleType
12
12
  from typing import (
@@ -512,13 +512,24 @@ class _AutoJITWrapper(Generic[T]): # numpydoc ignore=PR01
512
512
  convert them to/from PyTrees.
513
513
  """
514
514
 
515
- obj: T
515
+ _obj: Any
516
+ _is_iter: bool
516
517
  _registered: ClassVar[bool] = False
517
- __slots__: tuple[str, ...] = ("obj",)
518
+ __slots__: tuple[str, ...] = ("_is_iter", "_obj")
518
519
 
519
520
  def __init__(self, obj: T) -> None: # numpydoc ignore=GL08
520
521
  self._register()
521
- self.obj = obj
522
+ if isinstance(obj, Iterator):
523
+ self._obj = list(obj)
524
+ self._is_iter = True
525
+ else:
526
+ self._obj = obj
527
+ self._is_iter = False
528
+
529
+ @property
530
+ def obj(self) -> T: # numpydoc ignore=RT01
531
+ """Return wrapped object."""
532
+ return iter(self._obj) if self._is_iter else self._obj
522
533
 
523
534
  @classmethod
524
535
  def _register(cls) -> None: # numpydoc ignore=SS06
@@ -531,7 +542,7 @@ class _AutoJITWrapper(Generic[T]): # numpydoc ignore=PR01
531
542
 
532
543
  jax.tree_util.register_pytree_node(
533
544
  cls,
534
- lambda obj: pickle_flatten(obj, jax.Array), # pyright: ignore[reportUnknownArgumentType]
545
+ lambda instance: pickle_flatten(instance, jax.Array), # pyright: ignore[reportUnknownArgumentType]
535
546
  lambda aux_data, children: pickle_unflatten(children, aux_data), # pyright: ignore[reportUnknownArgumentType]
536
547
  )
537
548
  cls._registered = True
@@ -556,6 +567,7 @@ def jax_autojit(
556
567
  - Automatically descend into non-array return values and find ``jax.Array`` objects
557
568
  inside them, then rebuild them downstream of exiting the JIT, swapping the JAX
558
569
  tracer objects with concrete arrays.
570
+ - Returned iterators are immediately completely consumed.
559
571
 
560
572
  See Also
561
573
  --------
@@ -1,3 +1,4 @@
1
+ from collections.abc import Iterator
1
2
  from types import ModuleType
2
3
  from typing import TYPE_CHECKING, Generic, TypeVar, cast
3
4
 
@@ -417,3 +418,16 @@ class TestJAXAutoJIT:
417
418
  out = f([1, 2])
418
419
  assert isinstance(out, list)
419
420
  assert out == [3, 4]
421
+
422
+ def test_iterators(self, jnp: ModuleType):
423
+ @jax_autojit
424
+ def f(x: Array) -> Iterator[Array]:
425
+ return (x + i for i in range(2))
426
+
427
+ inp = jnp.asarray([1, 2])
428
+ out = f(inp)
429
+ assert isinstance(out, Iterator)
430
+ xp_assert_equal(next(out), jnp.asarray([1, 2]))
431
+ xp_assert_equal(next(out), jnp.asarray([2, 3]))
432
+ with pytest.raises(StopIteration):
433
+ _ = next(out)
@@ -1,4 +1,4 @@
1
- from collections.abc import Callable
1
+ from collections.abc import Callable, Iterator
2
2
  from types import ModuleType
3
3
  from typing import cast
4
4
 
@@ -468,3 +468,22 @@ def test_patch_lazy_xp_functions_deprecated_monkeypatch(
468
468
  monkeypatch.undo()
469
469
  y = non_materializable5(x)
470
470
  xp_assert_equal(y, x)
471
+
472
+
473
+ def my_iter(x: Array) -> Iterator[Array]:
474
+ yield x[0, :]
475
+ yield x[1, :]
476
+
477
+
478
+ lazy_xp_function(my_iter)
479
+
480
+
481
+ def test_patch_lazy_xp_functions_iter(xp: ModuleType):
482
+ x = xp.asarray([[1.0, 2.0], [3.0, 4.0]])
483
+ it = my_iter(x)
484
+
485
+ assert isinstance(it, Iterator)
486
+ xp_assert_equal(next(it), x[0, :])
487
+ xp_assert_equal(next(it), x[1, :])
488
+ with pytest.raises(StopIteration):
489
+ _ = next(it)
File without changes