array-api-extra 0.8.1__tar.gz → 0.9.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {array_api_extra-0.8.1 → array_api_extra-0.9.0}/.all-contributorsrc +21 -0
- {array_api_extra-0.8.1 → array_api_extra-0.9.0}/.github/workflows/ci.yml +1 -0
- {array_api_extra-0.8.1 → array_api_extra-0.9.0}/CONTRIBUTORS.md +2 -0
- {array_api_extra-0.8.1 → array_api_extra-0.9.0}/PKG-INFO +3 -1
- {array_api_extra-0.8.1 → array_api_extra-0.9.0}/README.md +2 -0
- {array_api_extra-0.8.1 → array_api_extra-0.9.0}/docs/api-reference.md +1 -0
- {array_api_extra-0.8.1 → array_api_extra-0.9.0}/pyproject.toml +23 -15
- {array_api_extra-0.8.1 → array_api_extra-0.9.0}/src/array_api_extra/__init__.py +3 -2
- {array_api_extra-0.8.1 → array_api_extra-0.9.0}/src/array_api_extra/_delegation.py +80 -1
- {array_api_extra-0.8.1 → array_api_extra-0.9.0}/src/array_api_extra/_lib/_funcs.py +41 -0
- {array_api_extra-0.8.1 → array_api_extra-0.9.0}/src/array_api_extra/_lib/_utils/_helpers.py +17 -5
- {array_api_extra-0.8.1 → array_api_extra-0.9.0}/tests/conftest.py +8 -0
- {array_api_extra-0.8.1 → array_api_extra-0.9.0}/tests/test_funcs.py +136 -0
- {array_api_extra-0.8.1 → array_api_extra-0.9.0}/tests/test_helpers.py +14 -0
- {array_api_extra-0.8.1 → array_api_extra-0.9.0}/tests/test_testing.py +20 -1
- {array_api_extra-0.8.1 → array_api_extra-0.9.0}/.dprint.jsonc +0 -0
- {array_api_extra-0.8.1 → array_api_extra-0.9.0}/.editorconfig +0 -0
- {array_api_extra-0.8.1 → array_api_extra-0.9.0}/.gitattributes +0 -0
- {array_api_extra-0.8.1 → array_api_extra-0.9.0}/.github/pull_request_template.md +0 -0
- {array_api_extra-0.8.1 → array_api_extra-0.9.0}/.github/workflows/cd.yml +0 -0
- {array_api_extra-0.8.1 → array_api_extra-0.9.0}/.github/workflows/docs-build.yml +0 -0
- {array_api_extra-0.8.1 → array_api_extra-0.9.0}/.github/workflows/docs-deploy.yml +0 -0
- {array_api_extra-0.8.1 → array_api_extra-0.9.0}/.gitignore +0 -0
- {array_api_extra-0.8.1 → array_api_extra-0.9.0}/LICENSE +0 -0
- {array_api_extra-0.8.1 → array_api_extra-0.9.0}/docs/api-lazy.md +0 -0
- {array_api_extra-0.8.1 → array_api_extra-0.9.0}/docs/conf.py +0 -0
- {array_api_extra-0.8.1 → array_api_extra-0.9.0}/docs/contributing.md +0 -0
- {array_api_extra-0.8.1 → array_api_extra-0.9.0}/docs/contributors.md +0 -0
- {array_api_extra-0.8.1 → array_api_extra-0.9.0}/docs/index.md +0 -0
- {array_api_extra-0.8.1 → array_api_extra-0.9.0}/lefthook.yml +0 -0
- {array_api_extra-0.8.1 → array_api_extra-0.9.0}/src/array_api_extra/_lib/__init__.py +0 -0
- {array_api_extra-0.8.1 → array_api_extra-0.9.0}/src/array_api_extra/_lib/_at.py +0 -0
- {array_api_extra-0.8.1 → array_api_extra-0.9.0}/src/array_api_extra/_lib/_backends.py +0 -0
- {array_api_extra-0.8.1 → array_api_extra-0.9.0}/src/array_api_extra/_lib/_lazy.py +0 -0
- {array_api_extra-0.8.1 → array_api_extra-0.9.0}/src/array_api_extra/_lib/_testing.py +0 -0
- {array_api_extra-0.8.1 → array_api_extra-0.9.0}/src/array_api_extra/_lib/_utils/__init__.py +0 -0
- {array_api_extra-0.8.1 → array_api_extra-0.9.0}/src/array_api_extra/_lib/_utils/_compat.py +0 -0
- {array_api_extra-0.8.1 → array_api_extra-0.9.0}/src/array_api_extra/_lib/_utils/_compat.pyi +0 -0
- {array_api_extra-0.8.1 → array_api_extra-0.9.0}/src/array_api_extra/_lib/_utils/_typing.py +0 -0
- {array_api_extra-0.8.1 → array_api_extra-0.9.0}/src/array_api_extra/_lib/_utils/_typing.pyi +0 -0
- {array_api_extra-0.8.1 → array_api_extra-0.9.0}/src/array_api_extra/py.typed +0 -0
- {array_api_extra-0.8.1 → array_api_extra-0.9.0}/src/array_api_extra/testing.py +0 -0
- {array_api_extra-0.8.1 → array_api_extra-0.9.0}/tests/__init__.py +0 -0
- {array_api_extra-0.8.1 → array_api_extra-0.9.0}/tests/test_at.py +0 -0
- {array_api_extra-0.8.1 → array_api_extra-0.9.0}/tests/test_lazy.py +0 -0
- {array_api_extra-0.8.1 → array_api_extra-0.9.0}/tests/test_version.py +0 -0
- {array_api_extra-0.8.1 → array_api_extra-0.9.0}/typos.toml +0 -0
- {array_api_extra-0.8.1 → array_api_extra-0.9.0}/vendor_tests/__init__.py +0 -0
- {array_api_extra-0.8.1 → array_api_extra-0.9.0}/vendor_tests/_array_api_compat_vendor.py +0 -0
- {array_api_extra-0.8.1 → array_api_extra-0.9.0}/vendor_tests/test_vendor.py +0 -0
|
@@ -249,6 +249,27 @@
|
|
|
249
249
|
"contributions": [
|
|
250
250
|
"review"
|
|
251
251
|
]
|
|
252
|
+
},
|
|
253
|
+
{
|
|
254
|
+
"login": "paddyroddy",
|
|
255
|
+
"name": "Patrick J. Roddy",
|
|
256
|
+
"avatar_url": "https://avatars.githubusercontent.com/u/15052188?v=4",
|
|
257
|
+
"profile": "https://paddyroddy.github.io/",
|
|
258
|
+
"contributions": [
|
|
259
|
+
"code",
|
|
260
|
+
"doc",
|
|
261
|
+
"example",
|
|
262
|
+
"test"
|
|
263
|
+
]
|
|
264
|
+
},
|
|
265
|
+
{
|
|
266
|
+
"login": "amacati",
|
|
267
|
+
"name": "Martin Schuck",
|
|
268
|
+
"avatar_url": "https://avatars.githubusercontent.com/u/57562633?v=4",
|
|
269
|
+
"profile": "https://amacati.github.io/",
|
|
270
|
+
"contributions": [
|
|
271
|
+
"ideas"
|
|
272
|
+
]
|
|
252
273
|
}
|
|
253
274
|
]
|
|
254
275
|
}
|
|
@@ -38,6 +38,8 @@ This project exists thanks to the following contributors
|
|
|
38
38
|
<tr>
|
|
39
39
|
<td align="center" valign="top" width="14.28%"><a href="https://github.com/lithomas1"><img src="https://avatars.githubusercontent.com/u/47963215?v=4?s=100" width="100px;" alt="Thomas Li"/><br /><sub><b>Thomas Li</b></sub></a><br /><a href="https://github.com/data-apis/array-api-extra/issues?q=author%3Alithomas1" title="Bug reports">🐛</a> <a href="#tool-lithomas1" title="Tools">🔧</a></td>
|
|
40
40
|
<td align="center" valign="top" width="14.28%"><a href="https://github.com/pearu"><img src="https://avatars.githubusercontent.com/u/402156?v=4?s=100" width="100px;" alt="Pearu Peterson"/><br /><sub><b>Pearu Peterson</b></sub></a><br /><a href="https://github.com/data-apis/array-api-extra/pulls?q=is%3Apr+reviewed-by%3Apearu" title="Reviewed Pull Requests">👀</a></td>
|
|
41
|
+
<td align="center" valign="top" width="14.28%"><a href="https://paddyroddy.github.io/"><img src="https://avatars.githubusercontent.com/u/15052188?v=4?s=100" width="100px;" alt="Patrick J. Roddy"/><br /><sub><b>Patrick J. Roddy</b></sub></a><br /><a href="https://github.com/data-apis/array-api-extra/commits?author=paddyroddy" title="Code">💻</a> <a href="https://github.com/data-apis/array-api-extra/commits?author=paddyroddy" title="Documentation">📖</a> <a href="#example-paddyroddy" title="Examples">💡</a> <a href="https://github.com/data-apis/array-api-extra/commits?author=paddyroddy" title="Tests">⚠️</a></td>
|
|
42
|
+
<td align="center" valign="top" width="14.28%"><a href="https://amacati.github.io/"><img src="https://avatars.githubusercontent.com/u/57562633?v=4?s=100" width="100px;" alt="Martin Schuck"/><br /><sub><b>Martin Schuck</b></sub></a><br /><a href="#ideas-amacati" title="Ideas, Planning, & Feedback">🤔</a></td>
|
|
41
43
|
</tr>
|
|
42
44
|
</tbody>
|
|
43
45
|
</table>
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: array-api-extra
|
|
3
|
-
Version: 0.8.1
|
|
3
|
+
Version: 0.9.0
|
|
4
4
|
Summary: Extra array functions built on top of the array API standard.
|
|
5
5
|
Project-URL: Homepage, https://github.com/data-apis/array-api-extra
|
|
6
6
|
Project-URL: Bug Tracker, https://github.com/data-apis/array-api-extra/issues
|
|
@@ -139,6 +139,8 @@ This project exists thanks to the following contributors
|
|
|
139
139
|
<tr>
|
|
140
140
|
<td align="center" valign="top" width="14.28%"><a href="https://github.com/lithomas1"><img src="https://avatars.githubusercontent.com/u/47963215?v=4?s=100" width="100px;" alt="Thomas Li"/><br /><sub><b>Thomas Li</b></sub></a><br /><a href="https://github.com/data-apis/array-api-extra/issues?q=author%3Alithomas1" title="Bug reports">🐛</a> <a href="#tool-lithomas1" title="Tools">🔧</a></td>
|
|
141
141
|
<td align="center" valign="top" width="14.28%"><a href="https://github.com/pearu"><img src="https://avatars.githubusercontent.com/u/402156?v=4?s=100" width="100px;" alt="Pearu Peterson"/><br /><sub><b>Pearu Peterson</b></sub></a><br /><a href="https://github.com/data-apis/array-api-extra/pulls?q=is%3Apr+reviewed-by%3Apearu" title="Reviewed Pull Requests">👀</a></td>
|
|
142
|
+
<td align="center" valign="top" width="14.28%"><a href="https://paddyroddy.github.io/"><img src="https://avatars.githubusercontent.com/u/15052188?v=4?s=100" width="100px;" alt="Patrick J. Roddy"/><br /><sub><b>Patrick J. Roddy</b></sub></a><br /><a href="https://github.com/data-apis/array-api-extra/commits?author=paddyroddy" title="Code">💻</a> <a href="https://github.com/data-apis/array-api-extra/commits?author=paddyroddy" title="Documentation">📖</a> <a href="#example-paddyroddy" title="Examples">💡</a> <a href="https://github.com/data-apis/array-api-extra/commits?author=paddyroddy" title="Tests">⚠️</a></td>
|
|
143
|
+
<td align="center" valign="top" width="14.28%"><a href="https://amacati.github.io/"><img src="https://avatars.githubusercontent.com/u/57562633?v=4?s=100" width="100px;" alt="Martin Schuck"/><br /><sub><b>Martin Schuck</b></sub></a><br /><a href="#ideas-amacati" title="Ideas, Planning, & Feedback">🤔</a></td>
|
|
142
144
|
</tr>
|
|
143
145
|
</tbody>
|
|
144
146
|
</table>
|
|
@@ -93,6 +93,8 @@ This project exists thanks to the following contributors
|
|
|
93
93
|
<tr>
|
|
94
94
|
<td align="center" valign="top" width="14.28%"><a href="https://github.com/lithomas1"><img src="https://avatars.githubusercontent.com/u/47963215?v=4?s=100" width="100px;" alt="Thomas Li"/><br /><sub><b>Thomas Li</b></sub></a><br /><a href="https://github.com/data-apis/array-api-extra/issues?q=author%3Alithomas1" title="Bug reports">🐛</a> <a href="#tool-lithomas1" title="Tools">🔧</a></td>
|
|
95
95
|
<td align="center" valign="top" width="14.28%"><a href="https://github.com/pearu"><img src="https://avatars.githubusercontent.com/u/402156?v=4?s=100" width="100px;" alt="Pearu Peterson"/><br /><sub><b>Pearu Peterson</b></sub></a><br /><a href="https://github.com/data-apis/array-api-extra/pulls?q=is%3Apr+reviewed-by%3Apearu" title="Reviewed Pull Requests">👀</a></td>
|
|
96
|
+
<td align="center" valign="top" width="14.28%"><a href="https://paddyroddy.github.io/"><img src="https://avatars.githubusercontent.com/u/15052188?v=4?s=100" width="100px;" alt="Patrick J. Roddy"/><br /><sub><b>Patrick J. Roddy</b></sub></a><br /><a href="https://github.com/data-apis/array-api-extra/commits?author=paddyroddy" title="Code">💻</a> <a href="https://github.com/data-apis/array-api-extra/commits?author=paddyroddy" title="Documentation">📖</a> <a href="#example-paddyroddy" title="Examples">💡</a> <a href="https://github.com/data-apis/array-api-extra/commits?author=paddyroddy" title="Tests">⚠️</a></td>
|
|
97
|
+
<td align="center" valign="top" width="14.28%"><a href="https://amacati.github.io/"><img src="https://avatars.githubusercontent.com/u/57562633?v=4?s=100" width="100px;" alt="Martin Schuck"/><br /><sub><b>Martin Schuck</b></sub></a><br /><a href="#ideas-amacati" title="Ideas, Planning, & Feedback">🤔</a></td>
|
|
96
98
|
</tr>
|
|
97
99
|
</tbody>
|
|
98
100
|
</table>
|
|
@@ -58,7 +58,7 @@ array-api-extra = { path = ".", editable = true }
|
|
|
58
58
|
typing-extensions = ">=4.14.1"
|
|
59
59
|
pylint = ">=3.3.8"
|
|
60
60
|
mypy = ">=1.17.1"
|
|
61
|
-
basedpyright = ">=1.31.
|
|
61
|
+
basedpyright = ">=1.31.3"
|
|
62
62
|
numpydoc = ">=1.9.0,<2"
|
|
63
63
|
# import dependencies for mypy:
|
|
64
64
|
array-api-strict = ">=2.4.1,<2.5"
|
|
@@ -66,9 +66,9 @@ numpy = ">=2.1.3"
|
|
|
66
66
|
hypothesis = ">=6.136.4"
|
|
67
67
|
dask-core = ">=2025.7.0" # No distributed, tornado, etc.
|
|
68
68
|
dprint = ">=0.50.0,<0.51"
|
|
69
|
-
lefthook = ">=1.12.
|
|
70
|
-
ruff = ">=0.12.
|
|
71
|
-
typos = ">=1.35.
|
|
69
|
+
lefthook = ">=1.12.3,<2"
|
|
70
|
+
ruff = ">=0.12.11,<0.13"
|
|
71
|
+
typos = ">=1.35.5,<2"
|
|
72
72
|
actionlint = ">=1.7.7,<2"
|
|
73
73
|
blacken-docs = ">=1.19.1,<2"
|
|
74
74
|
pytest = ">=8.4.1,<9"
|
|
@@ -155,13 +155,16 @@ dask-core = ">=2025.7.0" # No distributed, tornado, etc.
|
|
|
155
155
|
sparse = ">=0.17.0"
|
|
156
156
|
|
|
157
157
|
[tool.pixi.feature.backends.target.linux-64.dependencies]
|
|
158
|
-
|
|
158
|
+
# On CPU Python 3.10, use 0.6.2
|
|
159
|
+
# On CPU Python >=3.11, use >=0.7.0
|
|
160
|
+
# On GPU, use 0.6.0 (0.6.2 and 0.7.0 both segfault); see jaxlib pin below.
|
|
161
|
+
jax = ">=0.6.0"
|
|
159
162
|
|
|
160
163
|
[tool.pixi.feature.backends.target.osx-64.dependencies]
|
|
161
|
-
jax = ">=0.6.0
|
|
164
|
+
jax = ">=0.6.0"
|
|
162
165
|
|
|
163
166
|
[tool.pixi.feature.backends.target.osx-arm64.dependencies]
|
|
164
|
-
jax = ">=0.6.0
|
|
167
|
+
jax = ">=0.6.0"
|
|
165
168
|
|
|
166
169
|
[tool.pixi.feature.backends.target.win-64.dependencies]
|
|
167
170
|
# jax = "*" # unavailable
|
|
@@ -176,8 +179,9 @@ jax = ">=0.6.0,!=0.6.2"
|
|
|
176
179
|
system-requirements = { cuda = "12" }
|
|
177
180
|
|
|
178
181
|
[tool.pixi.feature.cuda-backends.target.linux-64.dependencies]
|
|
179
|
-
cupy = ">=13.
|
|
180
|
-
|
|
182
|
+
cupy = ">=13.6.0"
|
|
183
|
+
# JAX 0.6.2 and 0.7.0 segfault on CUDA
|
|
184
|
+
jaxlib = { version = ">=0.6.0,!=0.6.2,!=0.7.0", build = "cuda12*" }
|
|
181
185
|
pytorch = { version = ">=2.7.1", build = "cuda12*" }
|
|
182
186
|
|
|
183
187
|
[tool.pixi.feature.cuda-backends.target.osx-64.dependencies]
|
|
@@ -191,13 +195,13 @@ pytorch = { version = ">=2.7.1", build = "cuda12*" }
|
|
|
191
195
|
# pytorch = { version = "*", build = "cuda12*" } # unavailable
|
|
192
196
|
|
|
193
197
|
[tool.pixi.feature.cuda-backends.target.win-64.dependencies]
|
|
194
|
-
cupy = ">=13.
|
|
198
|
+
cupy = ">=13.6.0"
|
|
195
199
|
# jaxlib = { version = "*", build = "cuda12*" } # unavailable
|
|
196
200
|
pytorch = { version = ">=2.7.1", build = "cuda12*" }
|
|
197
201
|
|
|
198
202
|
[tool.pixi.feature.nogil.dependencies]
|
|
199
203
|
python-freethreading = "~=3.13.0"
|
|
200
|
-
pytest-run-parallel = ">=0.6.
|
|
204
|
+
pytest-run-parallel = ">=0.6.1"
|
|
201
205
|
numpy = ">=2.3.2"
|
|
202
206
|
# pytorch = "*" # Not available on Python 3.13t yet
|
|
203
207
|
dask-core = ">=2025.7.0" # No distributed, tornado, etc.
|
|
@@ -212,12 +216,16 @@ tests = { features = ["py313", "tests"], solve-group = "py313" }
|
|
|
212
216
|
tests-py313 = { features = ["py313", "tests"], solve-group = "py313" } # alias of tests
|
|
213
217
|
|
|
214
218
|
# Some backends may pin numpy; use separate solve-group
|
|
215
|
-
dev = { features = ["
|
|
216
|
-
tests-backends = { features = ["
|
|
219
|
+
dev = { features = ["py313", "lint", "tests", "docs", "dev", "backends"], solve-group = "backends" }
|
|
220
|
+
tests-backends = { features = ["py313", "tests", "backends"], solve-group = "backends" }
|
|
221
|
+
# Note: Python 3.10 has already been dropped by some backends (like JAX),
|
|
222
|
+
# so this is testing older versions.
|
|
223
|
+
tests-backends-py310 = { features = ["py310", "tests", "backends"] }
|
|
217
224
|
|
|
218
225
|
# CUDA not available on free github actions and on some developers' PCs
|
|
219
|
-
dev-cuda = { features = ["
|
|
220
|
-
tests-cuda = { features = ["
|
|
226
|
+
dev-cuda = { features = ["py313", "lint", "tests", "docs", "dev", "backends", "cuda-backends"], solve-group = "cuda" }
|
|
227
|
+
tests-cuda = { features = ["py313", "tests", "backends", "cuda-backends"], solve-group = "cuda" }
|
|
228
|
+
tests-cuda-py310 = { features = ["py310", "tests", "backends", "cuda-backends"] }
|
|
221
229
|
|
|
222
230
|
# Ungrouped environments
|
|
223
231
|
tests-numpy1 = ["py310", "tests", "numpy1"]
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
"""Extra array functions built on top of the array API standard."""
|
|
2
2
|
|
|
3
|
-
from ._delegation import isclose, one_hot, pad
|
|
3
|
+
from ._delegation import isclose, nan_to_num, one_hot, pad
|
|
4
4
|
from ._lib._at import at
|
|
5
5
|
from ._lib._funcs import (
|
|
6
6
|
apply_where,
|
|
@@ -17,7 +17,7 @@ from ._lib._funcs import (
|
|
|
17
17
|
)
|
|
18
18
|
from ._lib._lazy import lazy_apply
|
|
19
19
|
|
|
20
|
-
__version__ = "0.8.1"
|
|
20
|
+
__version__ = "0.9.0"
|
|
21
21
|
|
|
22
22
|
# pylint: disable=duplicate-code
|
|
23
23
|
__all__ = [
|
|
@@ -33,6 +33,7 @@ __all__ = [
|
|
|
33
33
|
"isclose",
|
|
34
34
|
"kron",
|
|
35
35
|
"lazy_apply",
|
|
36
|
+
"nan_to_num",
|
|
36
37
|
"nunique",
|
|
37
38
|
"one_hot",
|
|
38
39
|
"pad",
|
|
@@ -18,7 +18,7 @@ from ._lib._utils._compat import device as get_device
|
|
|
18
18
|
from ._lib._utils._helpers import asarrays
|
|
19
19
|
from ._lib._utils._typing import Array, DType
|
|
20
20
|
|
|
21
|
-
__all__ = ["isclose", "one_hot", "pad"]
|
|
21
|
+
__all__ = ["isclose", "nan_to_num", "one_hot", "pad"]
|
|
22
22
|
|
|
23
23
|
|
|
24
24
|
def isclose(
|
|
@@ -113,6 +113,85 @@ def isclose(
|
|
|
113
113
|
return _funcs.isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan, xp=xp)
|
|
114
114
|
|
|
115
115
|
|
|
116
|
+
def nan_to_num(
|
|
117
|
+
x: Array | float | complex,
|
|
118
|
+
/,
|
|
119
|
+
*,
|
|
120
|
+
fill_value: int | float = 0.0,
|
|
121
|
+
xp: ModuleType | None = None,
|
|
122
|
+
) -> Array:
|
|
123
|
+
"""
|
|
124
|
+
Replace NaN with zero and infinity with large finite numbers (default behaviour).
|
|
125
|
+
|
|
126
|
+
If `x` is inexact, NaN is replaced by zero or by the user defined value in the
|
|
127
|
+
`fill_value` keyword, infinity is replaced by the largest finite floating
|
|
128
|
+
point value representable by ``x.dtype``, and -infinity is replaced by the
|
|
129
|
+
most negative finite floating point value representable by ``x.dtype``.
|
|
130
|
+
|
|
131
|
+
For complex dtypes, the above is applied to each of the real and
|
|
132
|
+
imaginary components of `x` separately.
|
|
133
|
+
|
|
134
|
+
Parameters
|
|
135
|
+
----------
|
|
136
|
+
x : array | float | complex
|
|
137
|
+
Input data.
|
|
138
|
+
fill_value : int | float, optional
|
|
139
|
+
Value to be used to fill NaN values. If no value is passed
|
|
140
|
+
then NaN values will be replaced with 0.0.
|
|
141
|
+
xp : array_namespace, optional
|
|
142
|
+
The standard-compatible namespace for `x`. Default: infer.
|
|
143
|
+
|
|
144
|
+
Returns
|
|
145
|
+
-------
|
|
146
|
+
array
|
|
147
|
+
`x`, with the non-finite values replaced.
|
|
148
|
+
|
|
149
|
+
See Also
|
|
150
|
+
--------
|
|
151
|
+
array_api.isnan : Shows which elements are Not a Number (NaN).
|
|
152
|
+
|
|
153
|
+
Examples
|
|
154
|
+
--------
|
|
155
|
+
>>> import array_api_extra as xpx
|
|
156
|
+
>>> import array_api_strict as xp
|
|
157
|
+
>>> xpx.nan_to_num(xp.inf)
|
|
158
|
+
1.7976931348623157e+308
|
|
159
|
+
>>> xpx.nan_to_num(-xp.inf)
|
|
160
|
+
-1.7976931348623157e+308
|
|
161
|
+
>>> xpx.nan_to_num(xp.nan)
|
|
162
|
+
0.0
|
|
163
|
+
>>> x = xp.asarray([xp.inf, -xp.inf, xp.nan, -128, 128])
|
|
164
|
+
>>> xpx.nan_to_num(x)
|
|
165
|
+
array([ 1.79769313e+308, -1.79769313e+308, 0.00000000e+000, # may vary
|
|
166
|
+
-1.28000000e+002, 1.28000000e+002])
|
|
167
|
+
>>> y = xp.asarray([complex(xp.inf, xp.nan), xp.nan, complex(xp.nan, xp.inf)])
|
|
168
|
+
array([ 1.79769313e+308, -1.79769313e+308, 0.00000000e+000, # may vary
|
|
169
|
+
-1.28000000e+002, 1.28000000e+002])
|
|
170
|
+
>>> xpx.nan_to_num(y)
|
|
171
|
+
array([ 1.79769313e+308 +0.00000000e+000j, # may vary
|
|
172
|
+
0.00000000e+000 +0.00000000e+000j,
|
|
173
|
+
0.00000000e+000 +1.79769313e+308j])
|
|
174
|
+
"""
|
|
175
|
+
if isinstance(fill_value, complex):
|
|
176
|
+
msg = "Complex fill values are not supported."
|
|
177
|
+
raise TypeError(msg)
|
|
178
|
+
|
|
179
|
+
xp = array_namespace(x) if xp is None else xp
|
|
180
|
+
|
|
181
|
+
# for scalars we want to output an array
|
|
182
|
+
y = xp.asarray(x)
|
|
183
|
+
|
|
184
|
+
if (
|
|
185
|
+
is_cupy_namespace(xp)
|
|
186
|
+
or is_jax_namespace(xp)
|
|
187
|
+
or is_numpy_namespace(xp)
|
|
188
|
+
or is_torch_namespace(xp)
|
|
189
|
+
):
|
|
190
|
+
return xp.nan_to_num(y, nan=fill_value)
|
|
191
|
+
|
|
192
|
+
return _funcs.nan_to_num(y, fill_value=fill_value, xp=xp)
|
|
193
|
+
|
|
194
|
+
|
|
116
195
|
def one_hot(
|
|
117
196
|
x: Array,
|
|
118
197
|
/,
|
|
@@ -738,6 +738,47 @@ def kron(
|
|
|
738
738
|
return xp.reshape(result, res_shape)
|
|
739
739
|
|
|
740
740
|
|
|
741
|
+
def nan_to_num( # numpydoc ignore=PR01,RT01
|
|
742
|
+
x: Array,
|
|
743
|
+
/,
|
|
744
|
+
fill_value: int | float = 0.0,
|
|
745
|
+
*,
|
|
746
|
+
xp: ModuleType,
|
|
747
|
+
) -> Array:
|
|
748
|
+
"""See docstring in `array_api_extra._delegation.py`."""
|
|
749
|
+
|
|
750
|
+
def perform_replacements( # numpydoc ignore=PR01,RT01
|
|
751
|
+
x: Array,
|
|
752
|
+
fill_value: int | float,
|
|
753
|
+
xp: ModuleType,
|
|
754
|
+
) -> Array:
|
|
755
|
+
"""Internal function to perform the replacements."""
|
|
756
|
+
x = xp.where(xp.isnan(x), fill_value, x)
|
|
757
|
+
|
|
758
|
+
# convert infinities to finite values
|
|
759
|
+
finfo = xp.finfo(x.dtype)
|
|
760
|
+
idx_posinf = xp.isinf(x) & ~xp.signbit(x)
|
|
761
|
+
idx_neginf = xp.isinf(x) & xp.signbit(x)
|
|
762
|
+
x = xp.where(idx_posinf, finfo.max, x)
|
|
763
|
+
return xp.where(idx_neginf, finfo.min, x)
|
|
764
|
+
|
|
765
|
+
if xp.isdtype(x.dtype, "complex floating"):
|
|
766
|
+
return perform_replacements(
|
|
767
|
+
xp.real(x),
|
|
768
|
+
fill_value,
|
|
769
|
+
xp,
|
|
770
|
+
) + 1j * perform_replacements(
|
|
771
|
+
xp.imag(x),
|
|
772
|
+
fill_value,
|
|
773
|
+
xp,
|
|
774
|
+
)
|
|
775
|
+
|
|
776
|
+
if xp.isdtype(x.dtype, "numeric"):
|
|
777
|
+
return perform_replacements(x, fill_value, xp)
|
|
778
|
+
|
|
779
|
+
return x
|
|
780
|
+
|
|
781
|
+
|
|
741
782
|
def nunique(x: Array, /, *, xp: ModuleType | None = None) -> Array:
|
|
742
783
|
"""
|
|
743
784
|
Count the number of unique elements in an array.
|
|
@@ -6,7 +6,7 @@ import io
|
|
|
6
6
|
import math
|
|
7
7
|
import pickle
|
|
8
8
|
import types
|
|
9
|
-
from collections.abc import Callable, Generator, Iterable
|
|
9
|
+
from collections.abc import Callable, Generator, Iterable, Iterator
|
|
10
10
|
from functools import wraps
|
|
11
11
|
from types import ModuleType
|
|
12
12
|
from typing import (
|
|
@@ -512,13 +512,24 @@ class _AutoJITWrapper(Generic[T]): # numpydoc ignore=PR01
|
|
|
512
512
|
convert them to/from PyTrees.
|
|
513
513
|
"""
|
|
514
514
|
|
|
515
|
-
|
|
515
|
+
_obj: Any
|
|
516
|
+
_is_iter: bool
|
|
516
517
|
_registered: ClassVar[bool] = False
|
|
517
|
-
__slots__: tuple[str, ...] = ("
|
|
518
|
+
__slots__: tuple[str, ...] = ("_is_iter", "_obj")
|
|
518
519
|
|
|
519
520
|
def __init__(self, obj: T) -> None: # numpydoc ignore=GL08
|
|
520
521
|
self._register()
|
|
521
|
-
|
|
522
|
+
if isinstance(obj, Iterator):
|
|
523
|
+
self._obj = list(obj)
|
|
524
|
+
self._is_iter = True
|
|
525
|
+
else:
|
|
526
|
+
self._obj = obj
|
|
527
|
+
self._is_iter = False
|
|
528
|
+
|
|
529
|
+
@property
|
|
530
|
+
def obj(self) -> T: # numpydoc ignore=RT01
|
|
531
|
+
"""Return wrapped object."""
|
|
532
|
+
return iter(self._obj) if self._is_iter else self._obj
|
|
522
533
|
|
|
523
534
|
@classmethod
|
|
524
535
|
def _register(cls) -> None: # numpydoc ignore=SS06
|
|
@@ -531,7 +542,7 @@ class _AutoJITWrapper(Generic[T]): # numpydoc ignore=PR01
|
|
|
531
542
|
|
|
532
543
|
jax.tree_util.register_pytree_node(
|
|
533
544
|
cls,
|
|
534
|
-
lambda
|
|
545
|
+
lambda instance: pickle_flatten(instance, jax.Array), # pyright: ignore[reportUnknownArgumentType]
|
|
535
546
|
lambda aux_data, children: pickle_unflatten(children, aux_data), # pyright: ignore[reportUnknownArgumentType]
|
|
536
547
|
)
|
|
537
548
|
cls._registered = True
|
|
@@ -556,6 +567,7 @@ def jax_autojit(
|
|
|
556
567
|
- Automatically descend into non-array return values and find ``jax.Array`` objects
|
|
557
568
|
inside them, then rebuild them downstream of exiting the JIT, swapping the JAX
|
|
558
569
|
tracer objects with concrete arrays.
|
|
570
|
+
- Returned iterators are immediately completely consumed.
|
|
559
571
|
|
|
560
572
|
See Also
|
|
561
573
|
--------
|
|
@@ -232,3 +232,11 @@ def device(
|
|
|
232
232
|
if library == Backend.TORCH_GPU:
|
|
233
233
|
return xp.device("cpu")
|
|
234
234
|
return get_device(xp.empty(0))
|
|
235
|
+
|
|
236
|
+
|
|
237
|
+
@pytest.fixture
|
|
238
|
+
def infinity(library: Backend) -> float:
|
|
239
|
+
"""Retrieve the positive infinity value for the given backend."""
|
|
240
|
+
if library in (Backend.TORCH, Backend.TORCH_GPU):
|
|
241
|
+
return 3.4028235e38
|
|
242
|
+
return 1.7976931348623157e308
|
|
@@ -21,6 +21,7 @@ from array_api_extra import (
|
|
|
21
21
|
expand_dims,
|
|
22
22
|
isclose,
|
|
23
23
|
kron,
|
|
24
|
+
nan_to_num,
|
|
24
25
|
nunique,
|
|
25
26
|
one_hot,
|
|
26
27
|
pad,
|
|
@@ -40,6 +41,7 @@ lazy_xp_function(cov)
|
|
|
40
41
|
lazy_xp_function(create_diagonal)
|
|
41
42
|
lazy_xp_function(expand_dims)
|
|
42
43
|
lazy_xp_function(kron)
|
|
44
|
+
lazy_xp_function(nan_to_num)
|
|
43
45
|
lazy_xp_function(nunique)
|
|
44
46
|
lazy_xp_function(one_hot)
|
|
45
47
|
lazy_xp_function(pad)
|
|
@@ -941,6 +943,140 @@ class TestKron:
|
|
|
941
943
|
xp_assert_equal(kron(a, b, xp=xp), k)
|
|
942
944
|
|
|
943
945
|
|
|
946
|
+
class TestNanToNum:
|
|
947
|
+
def test_bool(self, xp: ModuleType) -> None:
|
|
948
|
+
a = xp.asarray([True])
|
|
949
|
+
xp_assert_equal(nan_to_num(a, xp=xp), a)
|
|
950
|
+
|
|
951
|
+
def test_scalar_pos_inf(self, xp: ModuleType, infinity: float) -> None:
|
|
952
|
+
a = xp.inf
|
|
953
|
+
xp_assert_equal(nan_to_num(a, xp=xp), xp.asarray(infinity))
|
|
954
|
+
|
|
955
|
+
def test_scalar_neg_inf(self, xp: ModuleType, infinity: float) -> None:
|
|
956
|
+
a = -xp.inf
|
|
957
|
+
xp_assert_equal(nan_to_num(a, xp=xp), -xp.asarray(infinity))
|
|
958
|
+
|
|
959
|
+
def test_scalar_nan(self, xp: ModuleType) -> None:
|
|
960
|
+
a = xp.nan
|
|
961
|
+
xp_assert_equal(nan_to_num(a, xp=xp), xp.asarray(0.0))
|
|
962
|
+
|
|
963
|
+
def test_real(self, xp: ModuleType, infinity: float) -> None:
|
|
964
|
+
a = xp.asarray([xp.inf, -xp.inf, xp.nan, -128, 128])
|
|
965
|
+
xp_assert_equal(
|
|
966
|
+
nan_to_num(a, xp=xp),
|
|
967
|
+
xp.asarray(
|
|
968
|
+
[
|
|
969
|
+
infinity,
|
|
970
|
+
-infinity,
|
|
971
|
+
0.0,
|
|
972
|
+
-128,
|
|
973
|
+
128,
|
|
974
|
+
]
|
|
975
|
+
),
|
|
976
|
+
)
|
|
977
|
+
|
|
978
|
+
def test_complex(self, xp: ModuleType, infinity: float) -> None:
|
|
979
|
+
a = xp.asarray(
|
|
980
|
+
[
|
|
981
|
+
complex(xp.inf, xp.nan),
|
|
982
|
+
xp.nan,
|
|
983
|
+
complex(xp.nan, xp.inf),
|
|
984
|
+
]
|
|
985
|
+
)
|
|
986
|
+
xp_assert_equal(
|
|
987
|
+
nan_to_num(a),
|
|
988
|
+
xp.asarray([complex(infinity, 0), complex(0, 0), complex(0, infinity)]),
|
|
989
|
+
)
|
|
990
|
+
|
|
991
|
+
def test_empty_array(self, xp: ModuleType) -> None:
|
|
992
|
+
a = xp.asarray([], dtype=xp.float32) # forced dtype due to torch
|
|
993
|
+
xp_assert_equal(nan_to_num(a, xp=xp), a)
|
|
994
|
+
assert xp.isdtype(nan_to_num(a, xp=xp).dtype, xp.float32)
|
|
995
|
+
|
|
996
|
+
@pytest.mark.parametrize(
|
|
997
|
+
("in_vals", "fill_value", "out_vals"),
|
|
998
|
+
[
|
|
999
|
+
([1, 2, np.nan, 4], 3, [1.0, 2.0, 3.0, 4.0]),
|
|
1000
|
+
([1, 2, np.nan, 4], 3.0, [1.0, 2.0, 3.0, 4.0]),
|
|
1001
|
+
(
|
|
1002
|
+
[
|
|
1003
|
+
complex(1, 1),
|
|
1004
|
+
complex(2, 2),
|
|
1005
|
+
complex(np.nan, 0),
|
|
1006
|
+
complex(4, 4),
|
|
1007
|
+
],
|
|
1008
|
+
3,
|
|
1009
|
+
[
|
|
1010
|
+
complex(1.0, 1.0),
|
|
1011
|
+
complex(2.0, 2.0),
|
|
1012
|
+
complex(3.0, 0.0),
|
|
1013
|
+
complex(4.0, 4.0),
|
|
1014
|
+
],
|
|
1015
|
+
),
|
|
1016
|
+
(
|
|
1017
|
+
[
|
|
1018
|
+
complex(1, 1),
|
|
1019
|
+
complex(2, 2),
|
|
1020
|
+
complex(0, np.nan),
|
|
1021
|
+
complex(4, 4),
|
|
1022
|
+
],
|
|
1023
|
+
3.0,
|
|
1024
|
+
[
|
|
1025
|
+
complex(1.0, 1.0),
|
|
1026
|
+
complex(2.0, 2.0),
|
|
1027
|
+
complex(0.0, 3.0),
|
|
1028
|
+
complex(4.0, 4.0),
|
|
1029
|
+
],
|
|
1030
|
+
),
|
|
1031
|
+
(
|
|
1032
|
+
[
|
|
1033
|
+
complex(1, 1),
|
|
1034
|
+
complex(2, 2),
|
|
1035
|
+
complex(np.nan, np.nan),
|
|
1036
|
+
complex(4, 4),
|
|
1037
|
+
],
|
|
1038
|
+
3.0,
|
|
1039
|
+
[
|
|
1040
|
+
complex(1.0, 1.0),
|
|
1041
|
+
complex(2.0, 2.0),
|
|
1042
|
+
complex(3.0, 3.0),
|
|
1043
|
+
complex(4.0, 4.0),
|
|
1044
|
+
],
|
|
1045
|
+
),
|
|
1046
|
+
],
|
|
1047
|
+
)
|
|
1048
|
+
def test_fill_value_success(
|
|
1049
|
+
self,
|
|
1050
|
+
xp: ModuleType,
|
|
1051
|
+
in_vals: Array,
|
|
1052
|
+
fill_value: int | float,
|
|
1053
|
+
out_vals: Array,
|
|
1054
|
+
) -> None:
|
|
1055
|
+
a = xp.asarray(in_vals)
|
|
1056
|
+
xp_assert_equal(
|
|
1057
|
+
nan_to_num(a, fill_value=fill_value, xp=xp),
|
|
1058
|
+
xp.asarray(out_vals),
|
|
1059
|
+
)
|
|
1060
|
+
|
|
1061
|
+
def test_fill_value_failure(self, xp: ModuleType) -> None:
|
|
1062
|
+
a = xp.asarray(
|
|
1063
|
+
[
|
|
1064
|
+
complex(1, 1),
|
|
1065
|
+
complex(xp.nan, xp.nan),
|
|
1066
|
+
complex(3, 3),
|
|
1067
|
+
]
|
|
1068
|
+
)
|
|
1069
|
+
with pytest.raises(
|
|
1070
|
+
TypeError,
|
|
1071
|
+
match="Complex fill values are not supported",
|
|
1072
|
+
):
|
|
1073
|
+
_ = nan_to_num(
|
|
1074
|
+
a,
|
|
1075
|
+
fill_value=complex(2, 2), # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
|
|
1076
|
+
xp=xp,
|
|
1077
|
+
)
|
|
1078
|
+
|
|
1079
|
+
|
|
944
1080
|
class TestNUnique:
|
|
945
1081
|
def test_simple(self, xp: ModuleType):
|
|
946
1082
|
a = xp.asarray([[1, 1], [0, 2], [2, 2]])
|
|
@@ -1,3 +1,4 @@
|
|
|
1
|
+
from collections.abc import Iterator
|
|
1
2
|
from types import ModuleType
|
|
2
3
|
from typing import TYPE_CHECKING, Generic, TypeVar, cast
|
|
3
4
|
|
|
@@ -417,3 +418,16 @@ class TestJAXAutoJIT:
|
|
|
417
418
|
out = f([1, 2])
|
|
418
419
|
assert isinstance(out, list)
|
|
419
420
|
assert out == [3, 4]
|
|
421
|
+
|
|
422
|
+
def test_iterators(self, jnp: ModuleType):
|
|
423
|
+
@jax_autojit
|
|
424
|
+
def f(x: Array) -> Iterator[Array]:
|
|
425
|
+
return (x + i for i in range(2))
|
|
426
|
+
|
|
427
|
+
inp = jnp.asarray([1, 2])
|
|
428
|
+
out = f(inp)
|
|
429
|
+
assert isinstance(out, Iterator)
|
|
430
|
+
xp_assert_equal(next(out), jnp.asarray([1, 2]))
|
|
431
|
+
xp_assert_equal(next(out), jnp.asarray([2, 3]))
|
|
432
|
+
with pytest.raises(StopIteration):
|
|
433
|
+
_ = next(out)
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
from collections.abc import Callable
|
|
1
|
+
from collections.abc import Callable, Iterator
|
|
2
2
|
from types import ModuleType
|
|
3
3
|
from typing import cast
|
|
4
4
|
|
|
@@ -468,3 +468,22 @@ def test_patch_lazy_xp_functions_deprecated_monkeypatch(
|
|
|
468
468
|
monkeypatch.undo()
|
|
469
469
|
y = non_materializable5(x)
|
|
470
470
|
xp_assert_equal(y, x)
|
|
471
|
+
|
|
472
|
+
|
|
473
|
+
def my_iter(x: Array) -> Iterator[Array]:
|
|
474
|
+
yield x[0, :]
|
|
475
|
+
yield x[1, :]
|
|
476
|
+
|
|
477
|
+
|
|
478
|
+
lazy_xp_function(my_iter)
|
|
479
|
+
|
|
480
|
+
|
|
481
|
+
def test_patch_lazy_xp_functions_iter(xp: ModuleType):
|
|
482
|
+
x = xp.asarray([[1.0, 2.0], [3.0, 4.0]])
|
|
483
|
+
it = my_iter(x)
|
|
484
|
+
|
|
485
|
+
assert isinstance(it, Iterator)
|
|
486
|
+
xp_assert_equal(next(it), x[0, :])
|
|
487
|
+
xp_assert_equal(next(it), x[1, :])
|
|
488
|
+
with pytest.raises(StopIteration):
|
|
489
|
+
_ = next(it)
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|