array-api-extra 0.8.1__tar.gz → 0.9.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50)
  1. {array_api_extra-0.8.1 → array_api_extra-0.9.0}/.all-contributorsrc +21 -0
  2. {array_api_extra-0.8.1 → array_api_extra-0.9.0}/.github/workflows/ci.yml +1 -0
  3. {array_api_extra-0.8.1 → array_api_extra-0.9.0}/CONTRIBUTORS.md +2 -0
  4. {array_api_extra-0.8.1 → array_api_extra-0.9.0}/PKG-INFO +3 -1
  5. {array_api_extra-0.8.1 → array_api_extra-0.9.0}/README.md +2 -0
  6. {array_api_extra-0.8.1 → array_api_extra-0.9.0}/docs/api-reference.md +1 -0
  7. {array_api_extra-0.8.1 → array_api_extra-0.9.0}/pyproject.toml +23 -15
  8. {array_api_extra-0.8.1 → array_api_extra-0.9.0}/src/array_api_extra/__init__.py +3 -2
  9. {array_api_extra-0.8.1 → array_api_extra-0.9.0}/src/array_api_extra/_delegation.py +80 -1
  10. {array_api_extra-0.8.1 → array_api_extra-0.9.0}/src/array_api_extra/_lib/_funcs.py +41 -0
  11. {array_api_extra-0.8.1 → array_api_extra-0.9.0}/src/array_api_extra/_lib/_utils/_helpers.py +17 -5
  12. {array_api_extra-0.8.1 → array_api_extra-0.9.0}/tests/conftest.py +8 -0
  13. {array_api_extra-0.8.1 → array_api_extra-0.9.0}/tests/test_funcs.py +136 -0
  14. {array_api_extra-0.8.1 → array_api_extra-0.9.0}/tests/test_helpers.py +14 -0
  15. {array_api_extra-0.8.1 → array_api_extra-0.9.0}/tests/test_testing.py +20 -1
  16. {array_api_extra-0.8.1 → array_api_extra-0.9.0}/.dprint.jsonc +0 -0
  17. {array_api_extra-0.8.1 → array_api_extra-0.9.0}/.editorconfig +0 -0
  18. {array_api_extra-0.8.1 → array_api_extra-0.9.0}/.gitattributes +0 -0
  19. {array_api_extra-0.8.1 → array_api_extra-0.9.0}/.github/pull_request_template.md +0 -0
  20. {array_api_extra-0.8.1 → array_api_extra-0.9.0}/.github/workflows/cd.yml +0 -0
  21. {array_api_extra-0.8.1 → array_api_extra-0.9.0}/.github/workflows/docs-build.yml +0 -0
  22. {array_api_extra-0.8.1 → array_api_extra-0.9.0}/.github/workflows/docs-deploy.yml +0 -0
  23. {array_api_extra-0.8.1 → array_api_extra-0.9.0}/.gitignore +0 -0
  24. {array_api_extra-0.8.1 → array_api_extra-0.9.0}/LICENSE +0 -0
  25. {array_api_extra-0.8.1 → array_api_extra-0.9.0}/docs/api-lazy.md +0 -0
  26. {array_api_extra-0.8.1 → array_api_extra-0.9.0}/docs/conf.py +0 -0
  27. {array_api_extra-0.8.1 → array_api_extra-0.9.0}/docs/contributing.md +0 -0
  28. {array_api_extra-0.8.1 → array_api_extra-0.9.0}/docs/contributors.md +0 -0
  29. {array_api_extra-0.8.1 → array_api_extra-0.9.0}/docs/index.md +0 -0
  30. {array_api_extra-0.8.1 → array_api_extra-0.9.0}/lefthook.yml +0 -0
  31. {array_api_extra-0.8.1 → array_api_extra-0.9.0}/src/array_api_extra/_lib/__init__.py +0 -0
  32. {array_api_extra-0.8.1 → array_api_extra-0.9.0}/src/array_api_extra/_lib/_at.py +0 -0
  33. {array_api_extra-0.8.1 → array_api_extra-0.9.0}/src/array_api_extra/_lib/_backends.py +0 -0
  34. {array_api_extra-0.8.1 → array_api_extra-0.9.0}/src/array_api_extra/_lib/_lazy.py +0 -0
  35. {array_api_extra-0.8.1 → array_api_extra-0.9.0}/src/array_api_extra/_lib/_testing.py +0 -0
  36. {array_api_extra-0.8.1 → array_api_extra-0.9.0}/src/array_api_extra/_lib/_utils/__init__.py +0 -0
  37. {array_api_extra-0.8.1 → array_api_extra-0.9.0}/src/array_api_extra/_lib/_utils/_compat.py +0 -0
  38. {array_api_extra-0.8.1 → array_api_extra-0.9.0}/src/array_api_extra/_lib/_utils/_compat.pyi +0 -0
  39. {array_api_extra-0.8.1 → array_api_extra-0.9.0}/src/array_api_extra/_lib/_utils/_typing.py +0 -0
  40. {array_api_extra-0.8.1 → array_api_extra-0.9.0}/src/array_api_extra/_lib/_utils/_typing.pyi +0 -0
  41. {array_api_extra-0.8.1 → array_api_extra-0.9.0}/src/array_api_extra/py.typed +0 -0
  42. {array_api_extra-0.8.1 → array_api_extra-0.9.0}/src/array_api_extra/testing.py +0 -0
  43. {array_api_extra-0.8.1 → array_api_extra-0.9.0}/tests/__init__.py +0 -0
  44. {array_api_extra-0.8.1 → array_api_extra-0.9.0}/tests/test_at.py +0 -0
  45. {array_api_extra-0.8.1 → array_api_extra-0.9.0}/tests/test_lazy.py +0 -0
  46. {array_api_extra-0.8.1 → array_api_extra-0.9.0}/tests/test_version.py +0 -0
  47. {array_api_extra-0.8.1 → array_api_extra-0.9.0}/typos.toml +0 -0
  48. {array_api_extra-0.8.1 → array_api_extra-0.9.0}/vendor_tests/__init__.py +0 -0
  49. {array_api_extra-0.8.1 → array_api_extra-0.9.0}/vendor_tests/_array_api_compat_vendor.py +0 -0
  50. {array_api_extra-0.8.1 → array_api_extra-0.9.0}/vendor_tests/test_vendor.py +0 -0
@@ -249,6 +249,27 @@
249
249
  "contributions": [
250
250
  "review"
251
251
  ]
252
+ },
253
+ {
254
+ "login": "paddyroddy",
255
+ "name": "Patrick J. Roddy",
256
+ "avatar_url": "https://avatars.githubusercontent.com/u/15052188?v=4",
257
+ "profile": "https://paddyroddy.github.io/",
258
+ "contributions": [
259
+ "code",
260
+ "doc",
261
+ "example",
262
+ "test"
263
+ ]
264
+ },
265
+ {
266
+ "login": "amacati",
267
+ "name": "Martin Schuck",
268
+ "avatar_url": "https://avatars.githubusercontent.com/u/57562633?v=4",
269
+ "profile": "https://amacati.github.io/",
270
+ "contributions": [
271
+ "ideas"
272
+ ]
252
273
  }
253
274
  ]
254
275
  }
@@ -48,6 +48,7 @@ jobs:
48
48
  - tests-py313
49
49
  - tests-numpy1
50
50
  - tests-backends
51
+ - tests-backends-py310
51
52
  - tests-nogil
52
53
  runs-on: [ubuntu-latest]
53
54
 
@@ -38,6 +38,8 @@ This project exists thanks to the following contributors
38
38
  <tr>
39
39
  <td align="center" valign="top" width="14.28%"><a href="https://github.com/lithomas1"><img src="https://avatars.githubusercontent.com/u/47963215?v=4?s=100" width="100px;" alt="Thomas Li"/><br /><sub><b>Thomas Li</b></sub></a><br /><a href="https://github.com/data-apis/array-api-extra/issues?q=author%3Alithomas1" title="Bug reports">🐛</a> <a href="#tool-lithomas1" title="Tools">🔧</a></td>
40
40
  <td align="center" valign="top" width="14.28%"><a href="https://github.com/pearu"><img src="https://avatars.githubusercontent.com/u/402156?v=4?s=100" width="100px;" alt="Pearu Peterson"/><br /><sub><b>Pearu Peterson</b></sub></a><br /><a href="https://github.com/data-apis/array-api-extra/pulls?q=is%3Apr+reviewed-by%3Apearu" title="Reviewed Pull Requests">👀</a></td>
41
+ <td align="center" valign="top" width="14.28%"><a href="https://paddyroddy.github.io/"><img src="https://avatars.githubusercontent.com/u/15052188?v=4?s=100" width="100px;" alt="Patrick J. Roddy"/><br /><sub><b>Patrick J. Roddy</b></sub></a><br /><a href="https://github.com/data-apis/array-api-extra/commits?author=paddyroddy" title="Code">💻</a> <a href="https://github.com/data-apis/array-api-extra/commits?author=paddyroddy" title="Documentation">📖</a> <a href="#example-paddyroddy" title="Examples">💡</a> <a href="https://github.com/data-apis/array-api-extra/commits?author=paddyroddy" title="Tests">⚠️</a></td>
42
+ <td align="center" valign="top" width="14.28%"><a href="https://amacati.github.io/"><img src="https://avatars.githubusercontent.com/u/57562633?v=4?s=100" width="100px;" alt="Martin Schuck"/><br /><sub><b>Martin Schuck</b></sub></a><br /><a href="#ideas-amacati" title="Ideas, Planning, & Feedback">🤔</a></td>
41
43
  </tr>
42
44
  </tbody>
43
45
  </table>
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: array-api-extra
3
- Version: 0.8.1
3
+ Version: 0.9.0
4
4
  Summary: Extra array functions built on top of the array API standard.
5
5
  Project-URL: Homepage, https://github.com/data-apis/array-api-extra
6
6
  Project-URL: Bug Tracker, https://github.com/data-apis/array-api-extra/issues
@@ -139,6 +139,8 @@ This project exists thanks to the following contributors
139
139
  <tr>
140
140
  <td align="center" valign="top" width="14.28%"><a href="https://github.com/lithomas1"><img src="https://avatars.githubusercontent.com/u/47963215?v=4?s=100" width="100px;" alt="Thomas Li"/><br /><sub><b>Thomas Li</b></sub></a><br /><a href="https://github.com/data-apis/array-api-extra/issues?q=author%3Alithomas1" title="Bug reports">🐛</a> <a href="#tool-lithomas1" title="Tools">🔧</a></td>
141
141
  <td align="center" valign="top" width="14.28%"><a href="https://github.com/pearu"><img src="https://avatars.githubusercontent.com/u/402156?v=4?s=100" width="100px;" alt="Pearu Peterson"/><br /><sub><b>Pearu Peterson</b></sub></a><br /><a href="https://github.com/data-apis/array-api-extra/pulls?q=is%3Apr+reviewed-by%3Apearu" title="Reviewed Pull Requests">👀</a></td>
142
+ <td align="center" valign="top" width="14.28%"><a href="https://paddyroddy.github.io/"><img src="https://avatars.githubusercontent.com/u/15052188?v=4?s=100" width="100px;" alt="Patrick J. Roddy"/><br /><sub><b>Patrick J. Roddy</b></sub></a><br /><a href="https://github.com/data-apis/array-api-extra/commits?author=paddyroddy" title="Code">💻</a> <a href="https://github.com/data-apis/array-api-extra/commits?author=paddyroddy" title="Documentation">📖</a> <a href="#example-paddyroddy" title="Examples">💡</a> <a href="https://github.com/data-apis/array-api-extra/commits?author=paddyroddy" title="Tests">⚠️</a></td>
143
+ <td align="center" valign="top" width="14.28%"><a href="https://amacati.github.io/"><img src="https://avatars.githubusercontent.com/u/57562633?v=4?s=100" width="100px;" alt="Martin Schuck"/><br /><sub><b>Martin Schuck</b></sub></a><br /><a href="#ideas-amacati" title="Ideas, Planning, & Feedback">🤔</a></td>
142
144
  </tr>
143
145
  </tbody>
144
146
  </table>
@@ -93,6 +93,8 @@ This project exists thanks to the following contributors
93
93
  <tr>
94
94
  <td align="center" valign="top" width="14.28%"><a href="https://github.com/lithomas1"><img src="https://avatars.githubusercontent.com/u/47963215?v=4?s=100" width="100px;" alt="Thomas Li"/><br /><sub><b>Thomas Li</b></sub></a><br /><a href="https://github.com/data-apis/array-api-extra/issues?q=author%3Alithomas1" title="Bug reports">🐛</a> <a href="#tool-lithomas1" title="Tools">🔧</a></td>
95
95
  <td align="center" valign="top" width="14.28%"><a href="https://github.com/pearu"><img src="https://avatars.githubusercontent.com/u/402156?v=4?s=100" width="100px;" alt="Pearu Peterson"/><br /><sub><b>Pearu Peterson</b></sub></a><br /><a href="https://github.com/data-apis/array-api-extra/pulls?q=is%3Apr+reviewed-by%3Apearu" title="Reviewed Pull Requests">👀</a></td>
96
+ <td align="center" valign="top" width="14.28%"><a href="https://paddyroddy.github.io/"><img src="https://avatars.githubusercontent.com/u/15052188?v=4?s=100" width="100px;" alt="Patrick J. Roddy"/><br /><sub><b>Patrick J. Roddy</b></sub></a><br /><a href="https://github.com/data-apis/array-api-extra/commits?author=paddyroddy" title="Code">💻</a> <a href="https://github.com/data-apis/array-api-extra/commits?author=paddyroddy" title="Documentation">📖</a> <a href="#example-paddyroddy" title="Examples">💡</a> <a href="https://github.com/data-apis/array-api-extra/commits?author=paddyroddy" title="Tests">⚠️</a></td>
97
+ <td align="center" valign="top" width="14.28%"><a href="https://amacati.github.io/"><img src="https://avatars.githubusercontent.com/u/57562633?v=4?s=100" width="100px;" alt="Martin Schuck"/><br /><sub><b>Martin Schuck</b></sub></a><br /><a href="#ideas-amacati" title="Ideas, Planning, & Feedback">🤔</a></td>
96
98
  </tr>
97
99
  </tbody>
98
100
  </table>
@@ -16,6 +16,7 @@
16
16
  expand_dims
17
17
  isclose
18
18
  kron
19
+ nan_to_num
19
20
  nunique
20
21
  one_hot
21
22
  pad
@@ -58,7 +58,7 @@ array-api-extra = { path = ".", editable = true }
58
58
  typing-extensions = ">=4.14.1"
59
59
  pylint = ">=3.3.8"
60
60
  mypy = ">=1.17.1"
61
- basedpyright = ">=1.31.1"
61
+ basedpyright = ">=1.31.3"
62
62
  numpydoc = ">=1.9.0,<2"
63
63
  # import dependencies for mypy:
64
64
  array-api-strict = ">=2.4.1,<2.5"
@@ -66,9 +66,9 @@ numpy = ">=2.1.3"
66
66
  hypothesis = ">=6.136.4"
67
67
  dask-core = ">=2025.7.0" # No distributed, tornado, etc.
68
68
  dprint = ">=0.50.0,<0.51"
69
- lefthook = ">=1.12.2,<2"
70
- ruff = ">=0.12.8,<0.13"
71
- typos = ">=1.35.3,<2"
69
+ lefthook = ">=1.12.3,<2"
70
+ ruff = ">=0.12.11,<0.13"
71
+ typos = ">=1.35.5,<2"
72
72
  actionlint = ">=1.7.7,<2"
73
73
  blacken-docs = ">=1.19.1,<2"
74
74
  pytest = ">=8.4.1,<9"
@@ -155,13 +155,16 @@ dask-core = ">=2025.7.0" # No distributed, tornado, etc.
155
155
  sparse = ">=0.17.0"
156
156
 
157
157
  [tool.pixi.feature.backends.target.linux-64.dependencies]
158
- jax = ">=0.6.0,!=0.6.2" # 0.6.2 segfaults on Linux CUDA
158
+ # On CPU Python 3.10, use 0.6.2
159
+ # On CPU Python >=3.11, use >=0.7.0
160
+ # On GPU, use 0.6.0 (0.6.2 and 0.7.0 both segfault); see jaxlib pin below.
161
+ jax = ">=0.6.0"
159
162
 
160
163
  [tool.pixi.feature.backends.target.osx-64.dependencies]
161
- jax = ">=0.6.0,!=0.6.2"
164
+ jax = ">=0.6.0"
162
165
 
163
166
  [tool.pixi.feature.backends.target.osx-arm64.dependencies]
164
- jax = ">=0.6.0,!=0.6.2"
167
+ jax = ">=0.6.0"
165
168
 
166
169
  [tool.pixi.feature.backends.target.win-64.dependencies]
167
170
  # jax = "*" # unavailable
@@ -176,8 +179,9 @@ jax = ">=0.6.0,!=0.6.2"
176
179
  system-requirements = { cuda = "12" }
177
180
 
178
181
  [tool.pixi.feature.cuda-backends.target.linux-64.dependencies]
179
- cupy = ">=13.5.1"
180
- jaxlib = { version = ">=0.6.0", build = "cuda12*" }
182
+ cupy = ">=13.6.0"
183
+ # JAX 0.6.2 and 0.7.0 segfault on CUDA
184
+ jaxlib = { version = ">=0.6.0,!=0.6.2,!=0.7.0", build = "cuda12*" }
181
185
  pytorch = { version = ">=2.7.1", build = "cuda12*" }
182
186
 
183
187
  [tool.pixi.feature.cuda-backends.target.osx-64.dependencies]
@@ -191,13 +195,13 @@ pytorch = { version = ">=2.7.1", build = "cuda12*" }
191
195
  # pytorch = { version = "*", build = "cuda12*" } # unavailable
192
196
 
193
197
  [tool.pixi.feature.cuda-backends.target.win-64.dependencies]
194
- cupy = ">=13.5.1"
198
+ cupy = ">=13.6.0"
195
199
  # jaxlib = { version = "*", build = "cuda12*" } # unavailable
196
200
  pytorch = { version = ">=2.7.1", build = "cuda12*" }
197
201
 
198
202
  [tool.pixi.feature.nogil.dependencies]
199
203
  python-freethreading = "~=3.13.0"
200
- pytest-run-parallel = ">=0.6.0"
204
+ pytest-run-parallel = ">=0.6.1"
201
205
  numpy = ">=2.3.2"
202
206
  # pytorch = "*" # Not available on Python 3.13t yet
203
207
  dask-core = ">=2025.7.0" # No distributed, tornado, etc.
@@ -212,12 +216,16 @@ tests = { features = ["py313", "tests"], solve-group = "py313" }
212
216
  tests-py313 = { features = ["py313", "tests"], solve-group = "py313" } # alias of tests
213
217
 
214
218
  # Some backends may pin numpy; use separate solve-group
215
- dev = { features = ["py310", "lint", "tests", "docs", "dev", "backends"], solve-group = "backends" }
216
- tests-backends = { features = ["py310", "tests", "backends"], solve-group = "backends" }
219
+ dev = { features = ["py313", "lint", "tests", "docs", "dev", "backends"], solve-group = "backends" }
220
+ tests-backends = { features = ["py313", "tests", "backends"], solve-group = "backends" }
221
+ # Note: Python 3.10 has already been dropped by some backends (like JAX),
222
+ # so this is testing older versions.
223
+ tests-backends-py310 = { features = ["py310", "tests", "backends"] }
217
224
 
218
225
  # CUDA not available on free github actions and on some developers' PCs
219
- dev-cuda = { features = ["py310", "lint", "tests", "docs", "dev", "backends", "cuda-backends"], solve-group = "cuda" }
220
- tests-cuda = { features = ["py310", "tests", "backends", "cuda-backends"], solve-group = "cuda" }
226
+ dev-cuda = { features = ["py313", "lint", "tests", "docs", "dev", "backends", "cuda-backends"], solve-group = "cuda" }
227
+ tests-cuda = { features = ["py313", "tests", "backends", "cuda-backends"], solve-group = "cuda" }
228
+ tests-cuda-py310 = { features = ["py310", "tests", "backends", "cuda-backends"] }
221
229
 
222
230
  # Ungrouped environments
223
231
  tests-numpy1 = ["py310", "tests", "numpy1"]
@@ -1,6 +1,6 @@
1
1
  """Extra array functions built on top of the array API standard."""
2
2
 
3
- from ._delegation import isclose, one_hot, pad
3
+ from ._delegation import isclose, nan_to_num, one_hot, pad
4
4
  from ._lib._at import at
5
5
  from ._lib._funcs import (
6
6
  apply_where,
@@ -17,7 +17,7 @@ from ._lib._funcs import (
17
17
  )
18
18
  from ._lib._lazy import lazy_apply
19
19
 
20
- __version__ = "0.8.1"
20
+ __version__ = "0.9.0"
21
21
 
22
22
  # pylint: disable=duplicate-code
23
23
  __all__ = [
@@ -33,6 +33,7 @@ __all__ = [
33
33
  "isclose",
34
34
  "kron",
35
35
  "lazy_apply",
36
+ "nan_to_num",
36
37
  "nunique",
37
38
  "one_hot",
38
39
  "pad",
@@ -18,7 +18,7 @@ from ._lib._utils._compat import device as get_device
18
18
  from ._lib._utils._helpers import asarrays
19
19
  from ._lib._utils._typing import Array, DType
20
20
 
21
- __all__ = ["isclose", "one_hot", "pad"]
21
+ __all__ = ["isclose", "nan_to_num", "one_hot", "pad"]
22
22
 
23
23
 
24
24
  def isclose(
@@ -113,6 +113,85 @@ def isclose(
113
113
  return _funcs.isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan, xp=xp)
114
114
 
115
115
 
116
def nan_to_num(
    x: Array | float | complex,
    /,
    *,
    fill_value: int | float = 0.0,
    xp: ModuleType | None = None,
) -> Array:
    """
    Replace NaN with zero and infinity with large finite numbers (default behaviour).

    If `x` is inexact, NaN is replaced by zero or by the user defined value in the
    `fill_value` keyword, infinity is replaced by the largest finite floating
    point value representable by ``x.dtype``, and -infinity is replaced by the
    most negative finite floating point value representable by ``x.dtype``.

    For complex dtypes, the above is applied to each of the real and
    imaginary components of `x` separately.

    Parameters
    ----------
    x : array | float | complex
        Input data. Python scalars are converted to a 0-d array.
    fill_value : int | float, optional
        Value to be used to fill NaN values. If no value is passed
        then NaN values will be replaced with 0.0.
    xp : array_namespace, optional
        The standard-compatible namespace for `x`. Default: infer.

    Returns
    -------
    array
        `x`, with the non-finite values replaced.

    Raises
    ------
    TypeError
        If `fill_value` is a complex number.

    See Also
    --------
    array_api.isnan : Shows which elements are Not a Number (NaN).

    Examples
    --------
    >>> import array_api_extra as xpx
    >>> import array_api_strict as xp
    >>> xpx.nan_to_num(xp.inf)
    1.7976931348623157e+308
    >>> xpx.nan_to_num(-xp.inf)
    -1.7976931348623157e+308
    >>> xpx.nan_to_num(xp.nan)
    0.0
    >>> x = xp.asarray([xp.inf, -xp.inf, xp.nan, -128, 128])
    >>> xpx.nan_to_num(x)
    array([ 1.79769313e+308, -1.79769313e+308,  0.00000000e+000, # may vary
           -1.28000000e+002,  1.28000000e+002])
    >>> y = xp.asarray([complex(xp.inf, xp.nan), xp.nan, complex(xp.nan, xp.inf)])
    >>> xpx.nan_to_num(y)
    array([  1.79769313e+308 +0.00000000e+000j, # may vary
             0.00000000e+000 +0.00000000e+000j,
             0.00000000e+000 +1.79769313e+308j])
    """
    # The fill value is applied to the real and imaginary parts separately,
    # so only a real-valued fill makes sense; reject complex values up front.
    if isinstance(fill_value, complex):
        msg = "Complex fill values are not supported."
        raise TypeError(msg)

    xp = array_namespace(x) if xp is None else xp

    # for scalars we want to output an array
    y = xp.asarray(x)

    # These libraries provide a native ``nan_to_num``; delegate to it.
    if (
        is_cupy_namespace(xp)
        or is_jax_namespace(xp)
        or is_numpy_namespace(xp)
        or is_torch_namespace(xp)
    ):
        return xp.nan_to_num(y, nan=fill_value)

    # Generic Array API implementation for every other namespace.
    return _funcs.nan_to_num(y, fill_value=fill_value, xp=xp)
193
+
194
+
116
195
  def one_hot(
117
196
  x: Array,
118
197
  /,
@@ -738,6 +738,47 @@ def kron(
738
738
  return xp.reshape(result, res_shape)
739
739
 
740
740
 
741
+ def nan_to_num( # numpydoc ignore=PR01,RT01
742
+ x: Array,
743
+ /,
744
+ fill_value: int | float = 0.0,
745
+ *,
746
+ xp: ModuleType,
747
+ ) -> Array:
748
+ """See docstring in `array_api_extra._delegation.py`."""
749
+
750
+ def perform_replacements( # numpydoc ignore=PR01,RT01
751
+ x: Array,
752
+ fill_value: int | float,
753
+ xp: ModuleType,
754
+ ) -> Array:
755
+ """Internal function to perform the replacements."""
756
+ x = xp.where(xp.isnan(x), fill_value, x)
757
+
758
+ # convert infinities to finite values
759
+ finfo = xp.finfo(x.dtype)
760
+ idx_posinf = xp.isinf(x) & ~xp.signbit(x)
761
+ idx_neginf = xp.isinf(x) & xp.signbit(x)
762
+ x = xp.where(idx_posinf, finfo.max, x)
763
+ return xp.where(idx_neginf, finfo.min, x)
764
+
765
+ if xp.isdtype(x.dtype, "complex floating"):
766
+ return perform_replacements(
767
+ xp.real(x),
768
+ fill_value,
769
+ xp,
770
+ ) + 1j * perform_replacements(
771
+ xp.imag(x),
772
+ fill_value,
773
+ xp,
774
+ )
775
+
776
+ if xp.isdtype(x.dtype, "numeric"):
777
+ return perform_replacements(x, fill_value, xp)
778
+
779
+ return x
780
+
781
+
741
782
  def nunique(x: Array, /, *, xp: ModuleType | None = None) -> Array:
742
783
  """
743
784
  Count the number of unique elements in an array.
@@ -6,7 +6,7 @@ import io
6
6
  import math
7
7
  import pickle
8
8
  import types
9
- from collections.abc import Callable, Generator, Iterable
9
+ from collections.abc import Callable, Generator, Iterable, Iterator
10
10
  from functools import wraps
11
11
  from types import ModuleType
12
12
  from typing import (
@@ -512,13 +512,24 @@ class _AutoJITWrapper(Generic[T]): # numpydoc ignore=PR01
512
512
  convert them to/from PyTrees.
513
513
  """
514
514
 
515
- obj: T
515
+ _obj: Any
516
+ _is_iter: bool
516
517
  _registered: ClassVar[bool] = False
517
- __slots__: tuple[str, ...] = ("obj",)
518
+ __slots__: tuple[str, ...] = ("_is_iter", "_obj")
518
519
 
519
520
  def __init__(self, obj: T) -> None: # numpydoc ignore=GL08
520
521
  self._register()
521
- self.obj = obj
522
+ if isinstance(obj, Iterator):
523
+ self._obj = list(obj)
524
+ self._is_iter = True
525
+ else:
526
+ self._obj = obj
527
+ self._is_iter = False
528
+
529
+ @property
530
+ def obj(self) -> T: # numpydoc ignore=RT01
531
+ """Return wrapped object."""
532
+ return iter(self._obj) if self._is_iter else self._obj
522
533
 
523
534
  @classmethod
524
535
  def _register(cls) -> None: # numpydoc ignore=SS06
@@ -531,7 +542,7 @@ class _AutoJITWrapper(Generic[T]): # numpydoc ignore=PR01
531
542
 
532
543
  jax.tree_util.register_pytree_node(
533
544
  cls,
534
- lambda obj: pickle_flatten(obj, jax.Array), # pyright: ignore[reportUnknownArgumentType]
545
+ lambda instance: pickle_flatten(instance, jax.Array), # pyright: ignore[reportUnknownArgumentType]
535
546
  lambda aux_data, children: pickle_unflatten(children, aux_data), # pyright: ignore[reportUnknownArgumentType]
536
547
  )
537
548
  cls._registered = True
@@ -556,6 +567,7 @@ def jax_autojit(
556
567
  - Automatically descend into non-array return values and find ``jax.Array`` objects
557
568
  inside them, then rebuild them downstream of exiting the JIT, swapping the JAX
558
569
  tracer objects with concrete arrays.
570
+ - Returned iterators are immediately completely consumed.
559
571
 
560
572
  See Also
561
573
  --------
@@ -232,3 +232,11 @@ def device(
232
232
  if library == Backend.TORCH_GPU:
233
233
  return xp.device("cpu")
234
234
  return get_device(xp.empty(0))
235
+
236
+
237
@pytest.fixture
def infinity(library: Backend) -> float:
    """
    Largest finite float that ``+inf`` is expected to be replaced with.

    Despite its name this fixture does *not* return infinity: it returns the
    maximum finite value of the backend's default floating dtype, which is
    what ``nan_to_num`` substitutes for ``+inf`` (negate it for ``-inf``).
    Torch backends use float32 here; every other backend uses float64.
    """
    if library in (Backend.TORCH, Backend.TORCH_GPU):
        return 3.4028235e38  # float32 max
    return 1.7976931348623157e308  # float64 max
@@ -21,6 +21,7 @@ from array_api_extra import (
21
21
  expand_dims,
22
22
  isclose,
23
23
  kron,
24
+ nan_to_num,
24
25
  nunique,
25
26
  one_hot,
26
27
  pad,
@@ -40,6 +41,7 @@ lazy_xp_function(cov)
40
41
  lazy_xp_function(create_diagonal)
41
42
  lazy_xp_function(expand_dims)
42
43
  lazy_xp_function(kron)
44
+ lazy_xp_function(nan_to_num)
43
45
  lazy_xp_function(nunique)
44
46
  lazy_xp_function(one_hot)
45
47
  lazy_xp_function(pad)
@@ -941,6 +943,140 @@ class TestKron:
941
943
  xp_assert_equal(kron(a, b, xp=xp), k)
942
944
 
943
945
 
946
class TestNanToNum:
    """
    Tests for `nan_to_num`.

    The ``infinity`` fixture supplies the largest finite float of the backend's
    default floating dtype (float32 for torch, float64 otherwise), i.e. the
    value ``+inf`` is expected to be mapped to.
    """

    def test_bool(self, xp: ModuleType) -> None:
        # Booleans cannot hold NaN/inf and must pass through untouched.
        a = xp.asarray([True])
        xp_assert_equal(nan_to_num(a, xp=xp), a)

    def test_scalar_pos_inf(self, xp: ModuleType, infinity: float) -> None:
        a = xp.inf
        xp_assert_equal(nan_to_num(a, xp=xp), xp.asarray(infinity))

    def test_scalar_neg_inf(self, xp: ModuleType, infinity: float) -> None:
        a = -xp.inf
        xp_assert_equal(nan_to_num(a, xp=xp), -xp.asarray(infinity))

    def test_scalar_nan(self, xp: ModuleType) -> None:
        a = xp.nan
        xp_assert_equal(nan_to_num(a, xp=xp), xp.asarray(0.0))

    def test_real(self, xp: ModuleType, infinity: float) -> None:
        a = xp.asarray([xp.inf, -xp.inf, xp.nan, -128, 128])
        xp_assert_equal(
            nan_to_num(a, xp=xp),
            xp.asarray(
                [
                    infinity,
                    -infinity,
                    0.0,
                    -128,
                    128,
                ]
            ),
        )

    def test_complex(self, xp: ModuleType, infinity: float) -> None:
        a = xp.asarray(
            [
                complex(xp.inf, xp.nan),
                xp.nan,
                complex(xp.nan, xp.inf),
            ]
        )
        # Pass ``xp`` explicitly, like every other test in this class.
        xp_assert_equal(
            nan_to_num(a, xp=xp),
            xp.asarray([complex(infinity, 0), complex(0, 0), complex(0, infinity)]),
        )

    def test_empty_array(self, xp: ModuleType) -> None:
        a = xp.asarray([], dtype=xp.float32)  # forced dtype due to torch
        xp_assert_equal(nan_to_num(a, xp=xp), a)
        assert xp.isdtype(nan_to_num(a, xp=xp).dtype, xp.float32)

    @pytest.mark.parametrize(
        ("in_vals", "fill_value", "out_vals"),
        [
            ([1, 2, np.nan, 4], 3, [1.0, 2.0, 3.0, 4.0]),
            ([1, 2, np.nan, 4], 3.0, [1.0, 2.0, 3.0, 4.0]),
            (
                [
                    complex(1, 1),
                    complex(2, 2),
                    complex(np.nan, 0),
                    complex(4, 4),
                ],
                3,
                [
                    complex(1.0, 1.0),
                    complex(2.0, 2.0),
                    complex(3.0, 0.0),
                    complex(4.0, 4.0),
                ],
            ),
            (
                [
                    complex(1, 1),
                    complex(2, 2),
                    complex(0, np.nan),
                    complex(4, 4),
                ],
                3.0,
                [
                    complex(1.0, 1.0),
                    complex(2.0, 2.0),
                    complex(0.0, 3.0),
                    complex(4.0, 4.0),
                ],
            ),
            (
                [
                    complex(1, 1),
                    complex(2, 2),
                    complex(np.nan, np.nan),
                    complex(4, 4),
                ],
                3.0,
                [
                    complex(1.0, 1.0),
                    complex(2.0, 2.0),
                    complex(3.0, 3.0),
                    complex(4.0, 4.0),
                ],
            ),
        ],
    )
    def test_fill_value_success(
        self,
        xp: ModuleType,
        in_vals: list[float] | list[complex],
        fill_value: int | float,
        out_vals: list[float] | list[complex],
    ) -> None:
        # For complex inputs the fill value is applied to the real and the
        # imaginary component independently.
        a = xp.asarray(in_vals)
        xp_assert_equal(
            nan_to_num(a, fill_value=fill_value, xp=xp),
            xp.asarray(out_vals),
        )

    def test_fill_value_failure(self, xp: ModuleType) -> None:
        a = xp.asarray(
            [
                complex(1, 1),
                complex(xp.nan, xp.nan),
                complex(3, 3),
            ]
        )
        with pytest.raises(
            TypeError,
            match="Complex fill values are not supported",
        ):
            _ = nan_to_num(
                a,
                fill_value=complex(2, 2),  # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
                xp=xp,
            )
1078
+
1079
+
944
1080
  class TestNUnique:
945
1081
  def test_simple(self, xp: ModuleType):
946
1082
  a = xp.asarray([[1, 1], [0, 2], [2, 2]])
@@ -1,3 +1,4 @@
1
+ from collections.abc import Iterator
1
2
  from types import ModuleType
2
3
  from typing import TYPE_CHECKING, Generic, TypeVar, cast
3
4
 
@@ -417,3 +418,16 @@ class TestJAXAutoJIT:
417
418
  out = f([1, 2])
418
419
  assert isinstance(out, list)
419
420
  assert out == [3, 4]
421
+
422
+ def test_iterators(self, jnp: ModuleType):
423
+ @jax_autojit
424
+ def f(x: Array) -> Iterator[Array]:
425
+ return (x + i for i in range(2))
426
+
427
+ inp = jnp.asarray([1, 2])
428
+ out = f(inp)
429
+ assert isinstance(out, Iterator)
430
+ xp_assert_equal(next(out), jnp.asarray([1, 2]))
431
+ xp_assert_equal(next(out), jnp.asarray([2, 3]))
432
+ with pytest.raises(StopIteration):
433
+ _ = next(out)
@@ -1,4 +1,4 @@
1
- from collections.abc import Callable
1
+ from collections.abc import Callable, Iterator
2
2
  from types import ModuleType
3
3
  from typing import cast
4
4
 
@@ -468,3 +468,22 @@ def test_patch_lazy_xp_functions_deprecated_monkeypatch(
468
468
  monkeypatch.undo()
469
469
  y = non_materializable5(x)
470
470
  xp_assert_equal(y, x)
471
+
472
+
473
def my_iter(x: Array) -> Iterator[Array]:
    """Yield rows 0 and 1 of the 2-D input ``x``, one at a time."""
    for row in range(2):
        yield x[row, :]


lazy_xp_function(my_iter)
479
+
480
+
481
def test_patch_lazy_xp_functions_iter(xp: ModuleType):
    """A generator registered with ``lazy_xp_function`` still yields lazily."""
    data = xp.asarray([[1.0, 2.0], [3.0, 4.0]])
    rows = my_iter(data)

    assert isinstance(rows, Iterator)
    for idx in range(2):
        xp_assert_equal(next(rows), data[idx, :])
    with pytest.raises(StopIteration):
        _ = next(rows)
File without changes