array-api-extra 0.10.0__tar.gz → 0.10.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {array_api_extra-0.10.0 → array_api_extra-0.10.1}/.all-contributorsrc +11 -0
- {array_api_extra-0.10.0 → array_api_extra-0.10.1}/CONTRIBUTORS.md +3 -0
- {array_api_extra-0.10.0 → array_api_extra-0.10.1}/PKG-INFO +4 -1
- {array_api_extra-0.10.0 → array_api_extra-0.10.1}/README.md +3 -0
- {array_api_extra-0.10.0 → array_api_extra-0.10.1}/meson.build +1 -1
- {array_api_extra-0.10.0 → array_api_extra-0.10.1}/pyproject.toml +6 -5
- {array_api_extra-0.10.0 → array_api_extra-0.10.1}/src/array_api_extra/__init__.py +1 -1
- {array_api_extra-0.10.0 → array_api_extra-0.10.1}/src/array_api_extra/testing.py +132 -31
- {array_api_extra-0.10.0 → array_api_extra-0.10.1}/tests/test_testing.py +124 -1
- {array_api_extra-0.10.0 → array_api_extra-0.10.1}/.dprint.jsonc +0 -0
- {array_api_extra-0.10.0 → array_api_extra-0.10.1}/.editorconfig +0 -0
- {array_api_extra-0.10.0 → array_api_extra-0.10.1}/LICENSE +0 -0
- {array_api_extra-0.10.0 → array_api_extra-0.10.1}/docs/api-lazy.md +0 -0
- {array_api_extra-0.10.0 → array_api_extra-0.10.1}/docs/api-reference.md +0 -0
- {array_api_extra-0.10.0 → array_api_extra-0.10.1}/docs/conf.py +0 -0
- {array_api_extra-0.10.0 → array_api_extra-0.10.1}/docs/contributing.md +0 -0
- {array_api_extra-0.10.0 → array_api_extra-0.10.1}/docs/contributors.md +0 -0
- {array_api_extra-0.10.0 → array_api_extra-0.10.1}/docs/index.md +0 -0
- {array_api_extra-0.10.0 → array_api_extra-0.10.1}/lefthook.yml +0 -0
- {array_api_extra-0.10.0 → array_api_extra-0.10.1}/src/array_api_extra/_delegation.py +0 -0
- {array_api_extra-0.10.0 → array_api_extra-0.10.1}/src/array_api_extra/_lib/__init__.py +0 -0
- {array_api_extra-0.10.0 → array_api_extra-0.10.1}/src/array_api_extra/_lib/_at.py +0 -0
- {array_api_extra-0.10.0 → array_api_extra-0.10.1}/src/array_api_extra/_lib/_backends.py +0 -0
- {array_api_extra-0.10.0 → array_api_extra-0.10.1}/src/array_api_extra/_lib/_funcs.py +0 -0
- {array_api_extra-0.10.0 → array_api_extra-0.10.1}/src/array_api_extra/_lib/_lazy.py +0 -0
- {array_api_extra-0.10.0 → array_api_extra-0.10.1}/src/array_api_extra/_lib/_testing.py +0 -0
- {array_api_extra-0.10.0 → array_api_extra-0.10.1}/src/array_api_extra/_lib/_utils/__init__.py +0 -0
- {array_api_extra-0.10.0 → array_api_extra-0.10.1}/src/array_api_extra/_lib/_utils/_compat.py +0 -0
- {array_api_extra-0.10.0 → array_api_extra-0.10.1}/src/array_api_extra/_lib/_utils/_compat.pyi +0 -0
- {array_api_extra-0.10.0 → array_api_extra-0.10.1}/src/array_api_extra/_lib/_utils/_helpers.py +0 -0
- {array_api_extra-0.10.0 → array_api_extra-0.10.1}/src/array_api_extra/_lib/_utils/_typing.py +0 -0
- {array_api_extra-0.10.0 → array_api_extra-0.10.1}/src/array_api_extra/_lib/_utils/_typing.pyi +0 -0
- {array_api_extra-0.10.0 → array_api_extra-0.10.1}/src/array_api_extra/py.typed +0 -0
- {array_api_extra-0.10.0 → array_api_extra-0.10.1}/tests/__init__.py +0 -0
- {array_api_extra-0.10.0 → array_api_extra-0.10.1}/tests/conftest.py +0 -0
- {array_api_extra-0.10.0 → array_api_extra-0.10.1}/tests/meson.build +0 -0
- {array_api_extra-0.10.0 → array_api_extra-0.10.1}/tests/test_at.py +0 -0
- {array_api_extra-0.10.0 → array_api_extra-0.10.1}/tests/test_funcs.py +0 -0
- {array_api_extra-0.10.0 → array_api_extra-0.10.1}/tests/test_helpers.py +0 -0
- {array_api_extra-0.10.0 → array_api_extra-0.10.1}/tests/test_lazy.py +0 -0
- {array_api_extra-0.10.0 → array_api_extra-0.10.1}/tests/test_version.py +0 -0
- {array_api_extra-0.10.0 → array_api_extra-0.10.1}/typos.toml +0 -0
- {array_api_extra-0.10.0 → array_api_extra-0.10.1}/vendor_tests/__init__.py +0 -0
- {array_api_extra-0.10.0 → array_api_extra-0.10.1}/vendor_tests/_array_api_compat_vendor.py +0 -0
- {array_api_extra-0.10.0 → array_api_extra-0.10.1}/vendor_tests/meson.build +0 -0
- {array_api_extra-0.10.0 → array_api_extra-0.10.1}/vendor_tests/test_vendor.py +0 -0
|
@@ -308,6 +308,17 @@
|
|
|
308
308
|
"ideas",
|
|
309
309
|
"test"
|
|
310
310
|
]
|
|
311
|
+
},
|
|
312
|
+
{
|
|
313
|
+
"login": "steppi",
|
|
314
|
+
"name": "Albert Steppi",
|
|
315
|
+
"avatar_url": "https://avatars.githubusercontent.com/u/1953382?v=4",
|
|
316
|
+
"profile": "http://steppi.github.io",
|
|
317
|
+
"contributions": [
|
|
318
|
+
"code",
|
|
319
|
+
"ideas",
|
|
320
|
+
"test"
|
|
321
|
+
]
|
|
311
322
|
}
|
|
312
323
|
]
|
|
313
324
|
}
|
|
@@ -44,6 +44,9 @@ This project exists thanks to the following contributors
|
|
|
44
44
|
<td align="center" valign="top" width="14.28%"><a href="https://github.com/adriagarp"><img src="https://avatars.githubusercontent.com/u/96059447?v=4?s=100" width="100px;" alt="Adrián García Pitarch"/><br /><sub><b>Adrián García Pitarch</b></sub></a><br /><a href="https://github.com/data-apis/array-api-extra/commits?author=adriagarp" title="Code">💻</a> <a href="https://github.com/data-apis/array-api-extra/commits?author=adriagarp" title="Tests">⚠️</a></td>
|
|
45
45
|
<td align="center" valign="top" width="14.28%"><a href="https://github.com/cakedev0"><img src="https://avatars.githubusercontent.com/u/25986961?v=4?s=100" width="100px;" alt="Arthur Lacote"/><br /><sub><b>Arthur Lacote</b></sub></a><br /><a href="https://github.com/data-apis/array-api-extra/commits?author=cakedev0" title="Code">💻</a> <a href="https://github.com/data-apis/array-api-extra/commits?author=cakedev0" title="Documentation">📖</a> <a href="#ideas-cakedev0" title="Ideas, Planning, & Feedback">🤔</a> <a href="https://github.com/data-apis/array-api-extra/commits?author=cakedev0" title="Tests">⚠️</a></td>
|
|
46
46
|
</tr>
|
|
47
|
+
<tr>
|
|
48
|
+
<td align="center" valign="top" width="14.28%"><a href="http://steppi.github.io"><img src="https://avatars.githubusercontent.com/u/1953382?v=4?s=100" width="100px;" alt="Albert Steppi"/><br /><sub><b>Albert Steppi</b></sub></a><br /><a href="https://github.com/data-apis/array-api-extra/commits?author=steppi" title="Code">💻</a> <a href="#ideas-steppi" title="Ideas, Planning, & Feedback">🤔</a> <a href="https://github.com/data-apis/array-api-extra/commits?author=steppi" title="Tests">⚠️</a></td>
|
|
49
|
+
</tr>
|
|
47
50
|
</tbody>
|
|
48
51
|
</table>
|
|
49
52
|
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: array-api-extra
|
|
3
|
-
Version: 0.10.0
|
|
3
|
+
Version: 0.10.1
|
|
4
4
|
Summary: Extra array functions built on top of the array API standard.
|
|
5
5
|
Author-Email: Lucas Colley <lucas.colley8@gmail.com>
|
|
6
6
|
License-Expression: MIT
|
|
@@ -124,6 +124,9 @@ This project exists thanks to the following contributors
|
|
|
124
124
|
<td align="center" valign="top" width="14.28%"><a href="https://github.com/adriagarp"><img src="https://avatars.githubusercontent.com/u/96059447?v=4?s=100" width="100px;" alt="Adrián García Pitarch"/><br /><sub><b>Adrián García Pitarch</b></sub></a><br /><a href="https://github.com/data-apis/array-api-extra/commits?author=adriagarp" title="Code">💻</a> <a href="https://github.com/data-apis/array-api-extra/commits?author=adriagarp" title="Tests">⚠️</a></td>
|
|
125
125
|
<td align="center" valign="top" width="14.28%"><a href="https://github.com/cakedev0"><img src="https://avatars.githubusercontent.com/u/25986961?v=4?s=100" width="100px;" alt="Arthur Lacote"/><br /><sub><b>Arthur Lacote</b></sub></a><br /><a href="https://github.com/data-apis/array-api-extra/commits?author=cakedev0" title="Code">💻</a> <a href="https://github.com/data-apis/array-api-extra/commits?author=cakedev0" title="Documentation">📖</a> <a href="#ideas-cakedev0" title="Ideas, Planning, & Feedback">🤔</a> <a href="https://github.com/data-apis/array-api-extra/commits?author=cakedev0" title="Tests">⚠️</a></td>
|
|
126
126
|
</tr>
|
|
127
|
+
<tr>
|
|
128
|
+
<td align="center" valign="top" width="14.28%"><a href="http://steppi.github.io"><img src="https://avatars.githubusercontent.com/u/1953382?v=4?s=100" width="100px;" alt="Albert Steppi"/><br /><sub><b>Albert Steppi</b></sub></a><br /><a href="https://github.com/data-apis/array-api-extra/commits?author=steppi" title="Code">💻</a> <a href="#ideas-steppi" title="Ideas, Planning, & Feedback">🤔</a> <a href="https://github.com/data-apis/array-api-extra/commits?author=steppi" title="Tests">⚠️</a></td>
|
|
129
|
+
</tr>
|
|
127
130
|
</tbody>
|
|
128
131
|
</table>
|
|
129
132
|
|
|
@@ -99,6 +99,9 @@ This project exists thanks to the following contributors
|
|
|
99
99
|
<td align="center" valign="top" width="14.28%"><a href="https://github.com/adriagarp"><img src="https://avatars.githubusercontent.com/u/96059447?v=4?s=100" width="100px;" alt="Adrián García Pitarch"/><br /><sub><b>Adrián García Pitarch</b></sub></a><br /><a href="https://github.com/data-apis/array-api-extra/commits?author=adriagarp" title="Code">💻</a> <a href="https://github.com/data-apis/array-api-extra/commits?author=adriagarp" title="Tests">⚠️</a></td>
|
|
100
100
|
<td align="center" valign="top" width="14.28%"><a href="https://github.com/cakedev0"><img src="https://avatars.githubusercontent.com/u/25986961?v=4?s=100" width="100px;" alt="Arthur Lacote"/><br /><sub><b>Arthur Lacote</b></sub></a><br /><a href="https://github.com/data-apis/array-api-extra/commits?author=cakedev0" title="Code">💻</a> <a href="https://github.com/data-apis/array-api-extra/commits?author=cakedev0" title="Documentation">📖</a> <a href="#ideas-cakedev0" title="Ideas, Planning, & Feedback">🤔</a> <a href="https://github.com/data-apis/array-api-extra/commits?author=cakedev0" title="Tests">⚠️</a></td>
|
|
101
101
|
</tr>
|
|
102
|
+
<tr>
|
|
103
|
+
<td align="center" valign="top" width="14.28%"><a href="http://steppi.github.io"><img src="https://avatars.githubusercontent.com/u/1953382?v=4?s=100" width="100px;" alt="Albert Steppi"/><br /><sub><b>Albert Steppi</b></sub></a><br /><a href="https://github.com/data-apis/array-api-extra/commits?author=steppi" title="Code">💻</a> <a href="#ideas-steppi" title="Ideas, Planning, & Feedback">🤔</a> <a href="https://github.com/data-apis/array-api-extra/commits?author=steppi" title="Tests">⚠️</a></td>
|
|
104
|
+
</tr>
|
|
102
105
|
</tbody>
|
|
103
106
|
</table>
|
|
104
107
|
|
|
@@ -4,7 +4,7 @@ build-backend = "mesonpy"
|
|
|
4
4
|
|
|
5
5
|
[project]
|
|
6
6
|
name = "array-api-extra"
|
|
7
|
-
version = "0.10.0"
|
|
7
|
+
version = "0.10.1"
|
|
8
8
|
authors = [
|
|
9
9
|
{ name = "Lucas Colley", email = "lucas.colley8@gmail.com" },
|
|
10
10
|
# { name = "Open Source Contributors" }, # https://github.com/pypi/warehouse/issues/14813
|
|
@@ -93,7 +93,7 @@ array-api-extra.path = "."
|
|
|
93
93
|
typing-extensions = ">=4.15.0"
|
|
94
94
|
pylint = ">=4.0.4"
|
|
95
95
|
mypy = ">=1.19.1"
|
|
96
|
-
basedpyright = ">=1.
|
|
96
|
+
basedpyright = ">=1.38.0"
|
|
97
97
|
numpydoc = ">=1.10.0,<2"
|
|
98
98
|
# import dependencies for mypy:
|
|
99
99
|
array-api-strict = ">=2.4.1,<2.5"
|
|
@@ -101,9 +101,9 @@ numpy = ">=2.1.3"
|
|
|
101
101
|
hypothesis = ">=6.151.2"
|
|
102
102
|
dask-core = ">=2026.1.2" # No distributed, tornado, etc.
|
|
103
103
|
dprint = ">=0.50.0,<0.51"
|
|
104
|
-
lefthook = ">=2.1.
|
|
105
|
-
ruff = ">=0.15.
|
|
106
|
-
typos = ">=1.43.
|
|
104
|
+
lefthook = ">=2.1.1,<3"
|
|
105
|
+
ruff = ">=0.15.1,<0.16"
|
|
106
|
+
typos = ">=1.43.4,<2"
|
|
107
107
|
actionlint = ">=1.7.10,<2"
|
|
108
108
|
blacken-docs = ">=1.20.0,<2"
|
|
109
109
|
pytest = ">=9.0.2,<10"
|
|
@@ -134,6 +134,7 @@ pytest-cov = ">=7.0.0"
|
|
|
134
134
|
hypothesis = ">=6.151.2"
|
|
135
135
|
array-api-strict = ">=2.4.1,<2.5"
|
|
136
136
|
numpy = ">=1.22.0"
|
|
137
|
+
scipy = ">=1.15.2,<2"
|
|
137
138
|
|
|
138
139
|
[tool.pixi.feature.tests.tasks]
|
|
139
140
|
tests = { cmd = "pytest -v", description = "Run tests" }
|
|
@@ -10,8 +10,9 @@ import contextlib
|
|
|
10
10
|
import enum
|
|
11
11
|
import warnings
|
|
12
12
|
from collections.abc import Callable, Generator, Iterator, Sequence
|
|
13
|
-
from functools import wraps
|
|
14
|
-
from types import ModuleType
|
|
13
|
+
from functools import update_wrapper, wraps
|
|
14
|
+
from inspect import getattr_static
|
|
15
|
+
from types import FunctionType, ModuleType
|
|
15
16
|
from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar, cast
|
|
16
17
|
|
|
17
18
|
from ._lib._utils._compat import is_dask_namespace, is_jax_namespace
|
|
@@ -48,8 +49,21 @@ class Deprecated(enum.Enum):
|
|
|
48
49
|
DEPRECATED = Deprecated.DEPRECATED
|
|
49
50
|
|
|
50
51
|
|
|
52
|
+
def _clone_function(f: Callable[..., Any]) -> Callable[..., Any]:
|
|
53
|
+
"""Returns a clone of an existing function."""
|
|
54
|
+
f_new = FunctionType(
|
|
55
|
+
f.__code__,
|
|
56
|
+
f.__globals__,
|
|
57
|
+
name=f.__name__,
|
|
58
|
+
argdefs=f.__defaults__,
|
|
59
|
+
closure=f.__closure__,
|
|
60
|
+
)
|
|
61
|
+
f_new.__kwdefaults__ = f.__kwdefaults__
|
|
62
|
+
return update_wrapper(f_new, f)
|
|
63
|
+
|
|
64
|
+
|
|
51
65
|
def lazy_xp_function(
|
|
52
|
-
func: Callable[..., Any],
|
|
66
|
+
func: Callable[..., Any] | tuple[type, str],
|
|
53
67
|
*,
|
|
54
68
|
allow_dask_compute: bool | int = False,
|
|
55
69
|
jax_jit: bool = True,
|
|
@@ -69,8 +83,9 @@ def lazy_xp_function(
|
|
|
69
83
|
|
|
70
84
|
Parameters
|
|
71
85
|
----------
|
|
72
|
-
func : callable
|
|
73
|
-
Function to be tested
|
|
86
|
+
func : callable | tuple[type, str]
|
|
87
|
+
Function to be tested, or a tuple containing an (uninstantiated) class and a
|
|
88
|
+
method name to specify a class method to be tested.
|
|
74
89
|
allow_dask_compute : bool | int, optional
|
|
75
90
|
Whether `func` is allowed to internally materialize the Dask graph, or maximum
|
|
76
91
|
number of times it is allowed to do so. This is typically triggered by
|
|
@@ -204,15 +219,49 @@ def lazy_xp_function(
|
|
|
204
219
|
DeprecationWarning,
|
|
205
220
|
stacklevel=2,
|
|
206
221
|
)
|
|
207
|
-
tags = {
|
|
222
|
+
tags: dict[str, bool | int | type] = {
|
|
208
223
|
"allow_dask_compute": allow_dask_compute,
|
|
209
224
|
"jax_jit": jax_jit,
|
|
210
225
|
}
|
|
211
226
|
|
|
227
|
+
if isinstance(func, tuple):
|
|
228
|
+
# Replace the method with a clone before adding tags
|
|
229
|
+
# to avoid adding unwanted tags to a parent method when
|
|
230
|
+
# the method was inherited from a parent class.
|
|
231
|
+
# Note: can't just accept an unbound method `cls.method_name` because in
|
|
232
|
+
# case of inheritance it would be impossible to attribute it to the child class.
|
|
233
|
+
# This also makes it so tagged methods will appear in their class's ``__dict__``
|
|
234
|
+
# and thus findable by ``iter_tagged`` in ``patch_lazy_xp_functions`` below.
|
|
235
|
+
cls, method_name = func
|
|
236
|
+
# The method might be a staticmethod or classmethod so we need to do a dance
|
|
237
|
+
# to ensure that this is preserved.
|
|
238
|
+
raw_attr = getattr_static(cls, method_name)
|
|
239
|
+
method = getattr(cls, method_name)
|
|
240
|
+
if isinstance(raw_attr, classmethod):
|
|
241
|
+
method = method.__func__
|
|
242
|
+
cloned_method = _clone_function(method)
|
|
243
|
+
|
|
244
|
+
method_to_set: Any
|
|
245
|
+
if isinstance(raw_attr, staticmethod):
|
|
246
|
+
method_to_set = staticmethod(cloned_method)
|
|
247
|
+
elif isinstance(raw_attr, classmethod):
|
|
248
|
+
method_to_set = classmethod(cloned_method)
|
|
249
|
+
else:
|
|
250
|
+
method_to_set = cloned_method
|
|
251
|
+
|
|
252
|
+
setattr(cls, method_name, method_to_set)
|
|
253
|
+
f = getattr(cls, method_name)
|
|
254
|
+
if isinstance(raw_attr, classmethod):
|
|
255
|
+
f = f.__func__
|
|
256
|
+
# Annotate that cls owns this method so we can check that later.
|
|
257
|
+
tags["owner"] = cls
|
|
258
|
+
else:
|
|
259
|
+
f = func
|
|
260
|
+
|
|
212
261
|
try:
|
|
213
|
-
|
|
262
|
+
f._lazy_xp_function = tags # pylint: disable=protected-access # pyright: ignore[reportFunctionMemberAccess]
|
|
214
263
|
except AttributeError: # @cython.vectorize
|
|
215
|
-
_ufuncs_tags[
|
|
264
|
+
_ufuncs_tags[f] = tags
|
|
216
265
|
|
|
217
266
|
|
|
218
267
|
def patch_lazy_xp_functions(
|
|
@@ -224,10 +273,11 @@ def patch_lazy_xp_functions(
|
|
|
224
273
|
"""
|
|
225
274
|
Test lazy execution of functions tagged with :func:`lazy_xp_function`.
|
|
226
275
|
|
|
227
|
-
If ``xp==jax.numpy``, search for all functions which have been tagged
|
|
228
|
-
:func:`lazy_xp_function` in the globals of the module that defines the current
|
|
229
|
-
as well as in the ``lazy_xp_modules`` list in the globals of the same module,
|
|
230
|
-
and wrap them with :func:`jax.jit`.
|
|
276
|
+
If ``xp==jax.numpy``, search for all functions and methods which have been tagged
|
|
277
|
+
with :func:`lazy_xp_function` in the globals of the module that defines the current
|
|
278
|
+
test, as well as in the ``lazy_xp_modules`` list in the globals of the same module,
|
|
279
|
+
and wrap them with :func:`jax.jit`.
|
|
280
|
+
Unwrap them at the end of the test.
|
|
231
281
|
|
|
232
282
|
If ``xp==dask.array``, wrap the functions with a decorator that disables
|
|
233
283
|
``compute()`` and ``persist()`` and ensures that exceptions and warnings are raised
|
|
@@ -271,18 +321,34 @@ def patch_lazy_xp_functions(
|
|
|
271
321
|
the example above.
|
|
272
322
|
"""
|
|
273
323
|
mod = cast(ModuleType, request.module)
|
|
274
|
-
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
|
|
278
|
-
|
|
324
|
+
search_targets: list[ModuleType | type] = [
|
|
325
|
+
mod,
|
|
326
|
+
*cast(list[ModuleType], getattr(mod, "lazy_xp_modules", [])),
|
|
327
|
+
]
|
|
328
|
+
# Also search for classes within the above modules which have had lazy_xp_function
|
|
329
|
+
# applied to methods through ``lazy_xp_function((cls, method_name))`` syntax.
|
|
330
|
+
# We might end up adding classes incidentally imported into modules, so using a
|
|
331
|
+
# set here to cut down on potential redundancy.
|
|
332
|
+
classes: set[type] = set()
|
|
333
|
+
for target in search_targets:
|
|
334
|
+
for obj in target.__dict__.values():
|
|
335
|
+
if isinstance(obj, type):
|
|
336
|
+
classes.add(obj)
|
|
337
|
+
search_targets.extend(classes)
|
|
338
|
+
|
|
339
|
+
to_revert: list[tuple[ModuleType | type, str, object]] = []
|
|
340
|
+
|
|
341
|
+
def temp_setattr(target: ModuleType | type, name: str, func: object) -> None:
|
|
279
342
|
"""
|
|
280
343
|
Variant of monkeypatch.setattr, which allows monkey-patching only selected
|
|
281
344
|
parameters of a test so that pytest-run-parallel can run on the remainder.
|
|
282
345
|
"""
|
|
283
|
-
assert hasattr(
|
|
284
|
-
|
|
285
|
-
|
|
346
|
+
assert hasattr(target, name)
|
|
347
|
+
# Need getattr_static because the attr could be a staticmethod or other
|
|
348
|
+
# descriptor and we don't want that to be stripped away.
|
|
349
|
+
original = getattr_static(target, name)
|
|
350
|
+
to_revert.append((target, name, original))
|
|
351
|
+
setattr(target, name, func)
|
|
286
352
|
|
|
287
353
|
if monkeypatch is not None:
|
|
288
354
|
warnings.warn(
|
|
@@ -298,10 +364,19 @@ def patch_lazy_xp_functions(
|
|
|
298
364
|
temp_setattr = monkeypatch.setattr # type: ignore[assignment] # pyright: ignore[reportAssignmentType]
|
|
299
365
|
|
|
300
366
|
def iter_tagged() -> Iterator[
|
|
301
|
-
tuple[ModuleType, str, Callable[..., Any], dict[str, Any]]
|
|
367
|
+
tuple[ModuleType | type, str, Any, Callable[..., Any], dict[str, Any]]
|
|
302
368
|
]:
|
|
303
|
-
for
|
|
304
|
-
for name,
|
|
369
|
+
for target in search_targets:
|
|
370
|
+
for name, attr in target.__dict__.items():
|
|
371
|
+
# attr might be a staticmethod or classmethod. If so we need
|
|
372
|
+
# to peel it back and wrap the underlying function and later
|
|
373
|
+
# make sure not to accidentally replace it with a regular
|
|
374
|
+
# method.
|
|
375
|
+
func: Any = (
|
|
376
|
+
attr.__func__
|
|
377
|
+
if isinstance(attr, (staticmethod, classmethod))
|
|
378
|
+
else attr
|
|
379
|
+
)
|
|
305
380
|
tags: dict[str, Any] | None = None
|
|
306
381
|
with contextlib.suppress(AttributeError):
|
|
307
382
|
tags = func._lazy_xp_function # pylint: disable=protected-access
|
|
@@ -309,23 +384,49 @@ def patch_lazy_xp_functions(
|
|
|
309
384
|
with contextlib.suppress(KeyError, TypeError):
|
|
310
385
|
tags = _ufuncs_tags[func]
|
|
311
386
|
if tags is not None:
|
|
312
|
-
|
|
313
|
-
|
|
387
|
+
if isinstance(target, type) and tags.get("owner") is not target:
|
|
388
|
+
# There's a common pattern to wrap functions in namespace
|
|
389
|
+
# classes to bypass lazy_xp_function like this:
|
|
390
|
+
#
|
|
391
|
+
# class naked:
|
|
392
|
+
# myfunc = mymodule.myfunc
|
|
393
|
+
#
|
|
394
|
+
# To ensure this still works when checking for tags in
|
|
395
|
+
# attributes of classes, ensure that target is the actual
|
|
396
|
+
# owning class where func was defined.
|
|
397
|
+
continue
|
|
398
|
+
# Put attr and func in the outputs so we can later tell
|
|
399
|
+
# if this was a staticmethod or classmethod.
|
|
400
|
+
yield target, name, attr, func, tags
|
|
401
|
+
|
|
402
|
+
wrapped: Any
|
|
314
403
|
if is_dask_namespace(xp):
|
|
315
|
-
for
|
|
404
|
+
for target, name, attr, func, tags in iter_tagged():
|
|
316
405
|
n = tags["allow_dask_compute"]
|
|
317
406
|
if n is True:
|
|
318
407
|
n = 1_000_000
|
|
319
408
|
elif n is False:
|
|
320
409
|
n = 0
|
|
321
410
|
wrapped = _dask_wrap(func, n)
|
|
322
|
-
|
|
411
|
+
# If we're dealing with a staticmethod or classmethod, make
|
|
412
|
+
# sure things stay that way.
|
|
413
|
+
if isinstance(attr, staticmethod):
|
|
414
|
+
wrapped = staticmethod(wrapped)
|
|
415
|
+
elif isinstance(attr, classmethod):
|
|
416
|
+
wrapped = classmethod(wrapped)
|
|
417
|
+
temp_setattr(target, name, wrapped)
|
|
323
418
|
|
|
324
419
|
elif is_jax_namespace(xp):
|
|
325
|
-
for
|
|
420
|
+
for target, name, attr, func, tags in iter_tagged():
|
|
326
421
|
if tags["jax_jit"]:
|
|
327
422
|
wrapped = jax_autojit(func)
|
|
328
|
-
|
|
423
|
+
# If we're dealing with a staticmethod or classmethod, make
|
|
424
|
+
# sure things stay that way.
|
|
425
|
+
if isinstance(attr, staticmethod):
|
|
426
|
+
wrapped = staticmethod(wrapped)
|
|
427
|
+
elif isinstance(attr, classmethod):
|
|
428
|
+
wrapped = classmethod(wrapped)
|
|
429
|
+
temp_setattr(target, name, wrapped)
|
|
329
430
|
|
|
330
431
|
# We can't just decorate patch_lazy_xp_functions with
|
|
331
432
|
# @contextlib.contextmanager because it would not work with the
|
|
@@ -335,8 +436,8 @@ def patch_lazy_xp_functions(
|
|
|
335
436
|
try:
|
|
336
437
|
yield
|
|
337
438
|
finally:
|
|
338
|
-
for
|
|
339
|
-
setattr(
|
|
439
|
+
for target, name, orig_func in to_revert:
|
|
440
|
+
setattr(target, name, orig_func)
|
|
340
441
|
|
|
341
442
|
return revert_on_exit()
|
|
342
443
|
|
|
@@ -1,9 +1,10 @@
|
|
|
1
1
|
from collections.abc import Callable, Iterator
|
|
2
2
|
from types import ModuleType
|
|
3
|
-
from typing import cast
|
|
3
|
+
from typing import Any, cast, final
|
|
4
4
|
|
|
5
5
|
import numpy as np
|
|
6
6
|
import pytest
|
|
7
|
+
from typing_extensions import override
|
|
7
8
|
|
|
8
9
|
from array_api_extra._lib._backends import Backend
|
|
9
10
|
from array_api_extra._lib._testing import (
|
|
@@ -321,6 +322,128 @@ def test_lazy_xp_function_cython_ufuncs(xp: ModuleType, library: Backend):
|
|
|
321
322
|
xp_assert_equal(cast(Array, erf(x)), xp.asarray([1.0, 1.0]))
|
|
322
323
|
|
|
323
324
|
|
|
325
|
+
class A:
|
|
326
|
+
def __init__(self, x: Array):
|
|
327
|
+
xp = array_namespace(x)
|
|
328
|
+
self._xp: ModuleType = xp
|
|
329
|
+
self.x: Any = np.asarray(x)
|
|
330
|
+
|
|
331
|
+
def f(self, y: Array) -> Array:
|
|
332
|
+
return self._xp.asarray(np.matmul(self.x, np.asarray(y)))
|
|
333
|
+
|
|
334
|
+
def g(self, y: Array, z: Array) -> Array:
|
|
335
|
+
return self.f(y) + self.f(z)
|
|
336
|
+
|
|
337
|
+
def h(self, y: Array) -> bool:
|
|
338
|
+
return bool(self._xp.any(y))
|
|
339
|
+
|
|
340
|
+
|
|
341
|
+
class B(A):
|
|
342
|
+
@override
|
|
343
|
+
def __init__(self, x: Array): # pyright: ignore[reportMissingSuperCall]
|
|
344
|
+
xp = array_namespace(x)
|
|
345
|
+
self._xp: ModuleType = xp
|
|
346
|
+
self.x: Any = xp.asarray(x)
|
|
347
|
+
|
|
348
|
+
@override
|
|
349
|
+
def f(self, y: Array) -> Array:
|
|
350
|
+
return self._xp.matmul(self.x, y)
|
|
351
|
+
|
|
352
|
+
@staticmethod
|
|
353
|
+
def k(y: Array) -> "B":
|
|
354
|
+
return B(2.0 * y)
|
|
355
|
+
|
|
356
|
+
@staticmethod
|
|
357
|
+
def j(y: Array) -> "B":
|
|
358
|
+
xp = array_namespace(y)
|
|
359
|
+
y = xp.asarray(y)
|
|
360
|
+
if bool(xp.any(y)):
|
|
361
|
+
return B(y)
|
|
362
|
+
return B(y + 1.0)
|
|
363
|
+
|
|
364
|
+
@classmethod
|
|
365
|
+
def w(cls, y: Array) -> "B":
|
|
366
|
+
xp = array_namespace(y)
|
|
367
|
+
y = xp.asarray(y)
|
|
368
|
+
if bool(xp.any(y)):
|
|
369
|
+
return B(y)
|
|
370
|
+
return B(y + 1.0)
|
|
371
|
+
|
|
372
|
+
|
|
373
|
+
@final
|
|
374
|
+
class eager:
|
|
375
|
+
# this needs to be a staticmethod to appease the type checker
|
|
376
|
+
non_materializable5 = staticmethod(non_materializable5)
|
|
377
|
+
|
|
378
|
+
|
|
379
|
+
lazy_xp_function((B, "g"))
|
|
380
|
+
lazy_xp_function((B, "h"))
|
|
381
|
+
lazy_xp_function((B, "k"))
|
|
382
|
+
lazy_xp_function((B, "j"))
|
|
383
|
+
lazy_xp_function((B, "w"))
|
|
384
|
+
|
|
385
|
+
|
|
386
|
+
class TestLazyXpFunctionClasses:
|
|
387
|
+
def test_parent_method_not_tagged(self):
|
|
388
|
+
assert hasattr(B.g, "_lazy_xp_function")
|
|
389
|
+
assert not hasattr(A.g, "_lazy_xp_function")
|
|
390
|
+
|
|
391
|
+
@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="converts to NumPy")
|
|
392
|
+
@pytest.mark.skip_xp_backend(Backend.CUPY, reason="converts to NumPy")
|
|
393
|
+
@pytest.mark.skip_xp_backend(Backend.JAX_GPU, reason="converts to NumPy")
|
|
394
|
+
@pytest.mark.skip_xp_backend(Backend.TORCH_GPU, reason="converts to NumPy")
|
|
395
|
+
def test_lazy_xp_function_classes(self, xp: ModuleType, library: Backend):
|
|
396
|
+
x = xp.asarray([1.1, 2.2, 3.3])
|
|
397
|
+
y = xp.asarray([1.0, 2.0, 3.0])
|
|
398
|
+
foo = A(x)
|
|
399
|
+
bar = B(x)
|
|
400
|
+
|
|
401
|
+
if library.like(Backend.JAX):
|
|
402
|
+
with pytest.raises(
|
|
403
|
+
TypeError, match="Attempted boolean conversion of traced array"
|
|
404
|
+
):
|
|
405
|
+
assert bar.h(y)
|
|
406
|
+
|
|
407
|
+
assert foo.h(y)
|
|
408
|
+
|
|
409
|
+
def test_static_methods_preserved(self, xp: ModuleType):
|
|
410
|
+
# Tests that static methods stay static methods when
|
|
411
|
+
# lazy_xp_function is applied.
|
|
412
|
+
x = xp.asarray([1.1, 2.2, 3.3])
|
|
413
|
+
foo = B(x)
|
|
414
|
+
bar = foo.k(x)
|
|
415
|
+
xp_assert_equal(bar.x, 2.0 * foo.x)
|
|
416
|
+
|
|
417
|
+
@pytest.mark.skip_xp_backend(Backend.DASK, reason="calls dask.compute()")
|
|
418
|
+
def test_static_methods_wrapped(self, xp: ModuleType, library: Backend):
|
|
419
|
+
x = xp.asarray([1.1, 2.2, 3.3])
|
|
420
|
+
foo = B(x)
|
|
421
|
+
|
|
422
|
+
if library.like(Backend.JAX):
|
|
423
|
+
with pytest.raises(
|
|
424
|
+
TypeError, match="Attempted boolean conversion of traced array"
|
|
425
|
+
):
|
|
426
|
+
assert isinstance(foo.j(x), B)
|
|
427
|
+
else:
|
|
428
|
+
assert isinstance(foo.j(x), B)
|
|
429
|
+
|
|
430
|
+
@pytest.mark.skip_xp_backend(Backend.DASK, reason="calls dask.compute()")
|
|
431
|
+
def test_class_methods_wrapped(self, xp: ModuleType, library: Backend):
|
|
432
|
+
x = xp.asarray([1.1, 2.2, 3.3])
|
|
433
|
+
if library.like(Backend.JAX):
|
|
434
|
+
with pytest.raises(
|
|
435
|
+
TypeError, match="Attempted boolean conversion of traced array"
|
|
436
|
+
):
|
|
437
|
+
assert isinstance(B.w(x), B)
|
|
438
|
+
else:
|
|
439
|
+
assert isinstance(B.w(x), B)
|
|
440
|
+
|
|
441
|
+
def test_circumvention(self, xp: ModuleType):
|
|
442
|
+
x = xp.asarray([1.0, 2.0])
|
|
443
|
+
y = eager.non_materializable5(x)
|
|
444
|
+
xp_assert_equal(y, x)
|
|
445
|
+
|
|
446
|
+
|
|
324
447
|
def dask_raises(x: Array) -> Array:
|
|
325
448
|
def _raises(x: Array) -> Array:
|
|
326
449
|
# Test that map_blocks doesn't eagerly call the function;
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{array_api_extra-0.10.0 → array_api_extra-0.10.1}/src/array_api_extra/_lib/_utils/__init__.py
RENAMED
|
File without changes
|
{array_api_extra-0.10.0 → array_api_extra-0.10.1}/src/array_api_extra/_lib/_utils/_compat.py
RENAMED
|
File without changes
|
{array_api_extra-0.10.0 → array_api_extra-0.10.1}/src/array_api_extra/_lib/_utils/_compat.pyi
RENAMED
|
File without changes
|
{array_api_extra-0.10.0 → array_api_extra-0.10.1}/src/array_api_extra/_lib/_utils/_helpers.py
RENAMED
|
File without changes
|
{array_api_extra-0.10.0 → array_api_extra-0.10.1}/src/array_api_extra/_lib/_utils/_typing.py
RENAMED
|
File without changes
|
{array_api_extra-0.10.0 → array_api_extra-0.10.1}/src/array_api_extra/_lib/_utils/_typing.pyi
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|