array-api-extra 0.10.0__tar.gz → 0.10.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. {array_api_extra-0.10.0 → array_api_extra-0.10.1}/.all-contributorsrc +11 -0
  2. {array_api_extra-0.10.0 → array_api_extra-0.10.1}/CONTRIBUTORS.md +3 -0
  3. {array_api_extra-0.10.0 → array_api_extra-0.10.1}/PKG-INFO +4 -1
  4. {array_api_extra-0.10.0 → array_api_extra-0.10.1}/README.md +3 -0
  5. {array_api_extra-0.10.0 → array_api_extra-0.10.1}/meson.build +1 -1
  6. {array_api_extra-0.10.0 → array_api_extra-0.10.1}/pyproject.toml +6 -5
  7. {array_api_extra-0.10.0 → array_api_extra-0.10.1}/src/array_api_extra/__init__.py +1 -1
  8. {array_api_extra-0.10.0 → array_api_extra-0.10.1}/src/array_api_extra/testing.py +132 -31
  9. {array_api_extra-0.10.0 → array_api_extra-0.10.1}/tests/test_testing.py +124 -1
  10. {array_api_extra-0.10.0 → array_api_extra-0.10.1}/.dprint.jsonc +0 -0
  11. {array_api_extra-0.10.0 → array_api_extra-0.10.1}/.editorconfig +0 -0
  12. {array_api_extra-0.10.0 → array_api_extra-0.10.1}/LICENSE +0 -0
  13. {array_api_extra-0.10.0 → array_api_extra-0.10.1}/docs/api-lazy.md +0 -0
  14. {array_api_extra-0.10.0 → array_api_extra-0.10.1}/docs/api-reference.md +0 -0
  15. {array_api_extra-0.10.0 → array_api_extra-0.10.1}/docs/conf.py +0 -0
  16. {array_api_extra-0.10.0 → array_api_extra-0.10.1}/docs/contributing.md +0 -0
  17. {array_api_extra-0.10.0 → array_api_extra-0.10.1}/docs/contributors.md +0 -0
  18. {array_api_extra-0.10.0 → array_api_extra-0.10.1}/docs/index.md +0 -0
  19. {array_api_extra-0.10.0 → array_api_extra-0.10.1}/lefthook.yml +0 -0
  20. {array_api_extra-0.10.0 → array_api_extra-0.10.1}/src/array_api_extra/_delegation.py +0 -0
  21. {array_api_extra-0.10.0 → array_api_extra-0.10.1}/src/array_api_extra/_lib/__init__.py +0 -0
  22. {array_api_extra-0.10.0 → array_api_extra-0.10.1}/src/array_api_extra/_lib/_at.py +0 -0
  23. {array_api_extra-0.10.0 → array_api_extra-0.10.1}/src/array_api_extra/_lib/_backends.py +0 -0
  24. {array_api_extra-0.10.0 → array_api_extra-0.10.1}/src/array_api_extra/_lib/_funcs.py +0 -0
  25. {array_api_extra-0.10.0 → array_api_extra-0.10.1}/src/array_api_extra/_lib/_lazy.py +0 -0
  26. {array_api_extra-0.10.0 → array_api_extra-0.10.1}/src/array_api_extra/_lib/_testing.py +0 -0
  27. {array_api_extra-0.10.0 → array_api_extra-0.10.1}/src/array_api_extra/_lib/_utils/__init__.py +0 -0
  28. {array_api_extra-0.10.0 → array_api_extra-0.10.1}/src/array_api_extra/_lib/_utils/_compat.py +0 -0
  29. {array_api_extra-0.10.0 → array_api_extra-0.10.1}/src/array_api_extra/_lib/_utils/_compat.pyi +0 -0
  30. {array_api_extra-0.10.0 → array_api_extra-0.10.1}/src/array_api_extra/_lib/_utils/_helpers.py +0 -0
  31. {array_api_extra-0.10.0 → array_api_extra-0.10.1}/src/array_api_extra/_lib/_utils/_typing.py +0 -0
  32. {array_api_extra-0.10.0 → array_api_extra-0.10.1}/src/array_api_extra/_lib/_utils/_typing.pyi +0 -0
  33. {array_api_extra-0.10.0 → array_api_extra-0.10.1}/src/array_api_extra/py.typed +0 -0
  34. {array_api_extra-0.10.0 → array_api_extra-0.10.1}/tests/__init__.py +0 -0
  35. {array_api_extra-0.10.0 → array_api_extra-0.10.1}/tests/conftest.py +0 -0
  36. {array_api_extra-0.10.0 → array_api_extra-0.10.1}/tests/meson.build +0 -0
  37. {array_api_extra-0.10.0 → array_api_extra-0.10.1}/tests/test_at.py +0 -0
  38. {array_api_extra-0.10.0 → array_api_extra-0.10.1}/tests/test_funcs.py +0 -0
  39. {array_api_extra-0.10.0 → array_api_extra-0.10.1}/tests/test_helpers.py +0 -0
  40. {array_api_extra-0.10.0 → array_api_extra-0.10.1}/tests/test_lazy.py +0 -0
  41. {array_api_extra-0.10.0 → array_api_extra-0.10.1}/tests/test_version.py +0 -0
  42. {array_api_extra-0.10.0 → array_api_extra-0.10.1}/typos.toml +0 -0
  43. {array_api_extra-0.10.0 → array_api_extra-0.10.1}/vendor_tests/__init__.py +0 -0
  44. {array_api_extra-0.10.0 → array_api_extra-0.10.1}/vendor_tests/_array_api_compat_vendor.py +0 -0
  45. {array_api_extra-0.10.0 → array_api_extra-0.10.1}/vendor_tests/meson.build +0 -0
  46. {array_api_extra-0.10.0 → array_api_extra-0.10.1}/vendor_tests/test_vendor.py +0 -0
@@ -308,6 +308,17 @@
308
308
  "ideas",
309
309
  "test"
310
310
  ]
311
+ },
312
+ {
313
+ "login": "steppi",
314
+ "name": "Albert Steppi",
315
+ "avatar_url": "https://avatars.githubusercontent.com/u/1953382?v=4",
316
+ "profile": "http://steppi.github.io",
317
+ "contributions": [
318
+ "code",
319
+ "ideas",
320
+ "test"
321
+ ]
311
322
  }
312
323
  ]
313
324
  }
@@ -44,6 +44,9 @@ This project exists thanks to the following contributors
44
44
  <td align="center" valign="top" width="14.28%"><a href="https://github.com/adriagarp"><img src="https://avatars.githubusercontent.com/u/96059447?v=4?s=100" width="100px;" alt="Adrián García Pitarch"/><br /><sub><b>Adrián García Pitarch</b></sub></a><br /><a href="https://github.com/data-apis/array-api-extra/commits?author=adriagarp" title="Code">💻</a> <a href="https://github.com/data-apis/array-api-extra/commits?author=adriagarp" title="Tests">⚠️</a></td>
45
45
  <td align="center" valign="top" width="14.28%"><a href="https://github.com/cakedev0"><img src="https://avatars.githubusercontent.com/u/25986961?v=4?s=100" width="100px;" alt="Arthur Lacote"/><br /><sub><b>Arthur Lacote</b></sub></a><br /><a href="https://github.com/data-apis/array-api-extra/commits?author=cakedev0" title="Code">💻</a> <a href="https://github.com/data-apis/array-api-extra/commits?author=cakedev0" title="Documentation">📖</a> <a href="#ideas-cakedev0" title="Ideas, Planning, & Feedback">🤔</a> <a href="https://github.com/data-apis/array-api-extra/commits?author=cakedev0" title="Tests">⚠️</a></td>
46
46
  </tr>
47
+ <tr>
48
+ <td align="center" valign="top" width="14.28%"><a href="http://steppi.github.io"><img src="https://avatars.githubusercontent.com/u/1953382?v=4?s=100" width="100px;" alt="Albert Steppi"/><br /><sub><b>Albert Steppi</b></sub></a><br /><a href="https://github.com/data-apis/array-api-extra/commits?author=steppi" title="Code">💻</a> <a href="#ideas-steppi" title="Ideas, Planning, & Feedback">🤔</a> <a href="https://github.com/data-apis/array-api-extra/commits?author=steppi" title="Tests">⚠️</a></td>
49
+ </tr>
47
50
  </tbody>
48
51
  </table>
49
52
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: array-api-extra
3
- Version: 0.10.0
3
+ Version: 0.10.1
4
4
  Summary: Extra array functions built on top of the array API standard.
5
5
  Author-Email: Lucas Colley <lucas.colley8@gmail.com>
6
6
  License-Expression: MIT
@@ -124,6 +124,9 @@ This project exists thanks to the following contributors
124
124
  <td align="center" valign="top" width="14.28%"><a href="https://github.com/adriagarp"><img src="https://avatars.githubusercontent.com/u/96059447?v=4?s=100" width="100px;" alt="Adrián García Pitarch"/><br /><sub><b>Adrián García Pitarch</b></sub></a><br /><a href="https://github.com/data-apis/array-api-extra/commits?author=adriagarp" title="Code">💻</a> <a href="https://github.com/data-apis/array-api-extra/commits?author=adriagarp" title="Tests">⚠️</a></td>
125
125
  <td align="center" valign="top" width="14.28%"><a href="https://github.com/cakedev0"><img src="https://avatars.githubusercontent.com/u/25986961?v=4?s=100" width="100px;" alt="Arthur Lacote"/><br /><sub><b>Arthur Lacote</b></sub></a><br /><a href="https://github.com/data-apis/array-api-extra/commits?author=cakedev0" title="Code">💻</a> <a href="https://github.com/data-apis/array-api-extra/commits?author=cakedev0" title="Documentation">📖</a> <a href="#ideas-cakedev0" title="Ideas, Planning, & Feedback">🤔</a> <a href="https://github.com/data-apis/array-api-extra/commits?author=cakedev0" title="Tests">⚠️</a></td>
126
126
  </tr>
127
+ <tr>
128
+ <td align="center" valign="top" width="14.28%"><a href="http://steppi.github.io"><img src="https://avatars.githubusercontent.com/u/1953382?v=4?s=100" width="100px;" alt="Albert Steppi"/><br /><sub><b>Albert Steppi</b></sub></a><br /><a href="https://github.com/data-apis/array-api-extra/commits?author=steppi" title="Code">💻</a> <a href="#ideas-steppi" title="Ideas, Planning, & Feedback">🤔</a> <a href="https://github.com/data-apis/array-api-extra/commits?author=steppi" title="Tests">⚠️</a></td>
129
+ </tr>
127
130
  </tbody>
128
131
  </table>
129
132
 
@@ -99,6 +99,9 @@ This project exists thanks to the following contributors
99
99
  <td align="center" valign="top" width="14.28%"><a href="https://github.com/adriagarp"><img src="https://avatars.githubusercontent.com/u/96059447?v=4?s=100" width="100px;" alt="Adrián García Pitarch"/><br /><sub><b>Adrián García Pitarch</b></sub></a><br /><a href="https://github.com/data-apis/array-api-extra/commits?author=adriagarp" title="Code">💻</a> <a href="https://github.com/data-apis/array-api-extra/commits?author=adriagarp" title="Tests">⚠️</a></td>
100
100
  <td align="center" valign="top" width="14.28%"><a href="https://github.com/cakedev0"><img src="https://avatars.githubusercontent.com/u/25986961?v=4?s=100" width="100px;" alt="Arthur Lacote"/><br /><sub><b>Arthur Lacote</b></sub></a><br /><a href="https://github.com/data-apis/array-api-extra/commits?author=cakedev0" title="Code">💻</a> <a href="https://github.com/data-apis/array-api-extra/commits?author=cakedev0" title="Documentation">📖</a> <a href="#ideas-cakedev0" title="Ideas, Planning, & Feedback">🤔</a> <a href="https://github.com/data-apis/array-api-extra/commits?author=cakedev0" title="Tests">⚠️</a></td>
101
101
  </tr>
102
+ <tr>
103
+ <td align="center" valign="top" width="14.28%"><a href="http://steppi.github.io"><img src="https://avatars.githubusercontent.com/u/1953382?v=4?s=100" width="100px;" alt="Albert Steppi"/><br /><sub><b>Albert Steppi</b></sub></a><br /><a href="https://github.com/data-apis/array-api-extra/commits?author=steppi" title="Code">💻</a> <a href="#ideas-steppi" title="Ideas, Planning, & Feedback">🤔</a> <a href="https://github.com/data-apis/array-api-extra/commits?author=steppi" title="Tests">⚠️</a></td>
104
+ </tr>
102
105
  </tbody>
103
106
  </table>
104
107
 
@@ -1,6 +1,6 @@
1
1
  project(
2
2
  'array-api-extra',
3
- version: '0.10.0'
3
+ version: '0.10.1'
4
4
  )
5
5
 
6
6
  py = import('python').find_installation()
@@ -4,7 +4,7 @@ build-backend = "mesonpy"
4
4
 
5
5
  [project]
6
6
  name = "array-api-extra"
7
- version = "0.10.0"
7
+ version = "0.10.1"
8
8
  authors = [
9
9
  { name = "Lucas Colley", email = "lucas.colley8@gmail.com" },
10
10
  # { name = "Open Source Contributors" }, # https://github.com/pypi/warehouse/issues/14813
@@ -93,7 +93,7 @@ array-api-extra.path = "."
93
93
  typing-extensions = ">=4.15.0"
94
94
  pylint = ">=4.0.4"
95
95
  mypy = ">=1.19.1"
96
- basedpyright = ">=1.37.4"
96
+ basedpyright = ">=1.38.0"
97
97
  numpydoc = ">=1.10.0,<2"
98
98
  # import dependencies for mypy:
99
99
  array-api-strict = ">=2.4.1,<2.5"
@@ -101,9 +101,9 @@ numpy = ">=2.1.3"
101
101
  hypothesis = ">=6.151.2"
102
102
  dask-core = ">=2026.1.2" # No distributed, tornado, etc.
103
103
  dprint = ">=0.50.0,<0.51"
104
- lefthook = ">=2.1.0,<3"
105
- ruff = ">=0.15.0,<0.16"
106
- typos = ">=1.43.3,<2"
104
+ lefthook = ">=2.1.1,<3"
105
+ ruff = ">=0.15.1,<0.16"
106
+ typos = ">=1.43.4,<2"
107
107
  actionlint = ">=1.7.10,<2"
108
108
  blacken-docs = ">=1.20.0,<2"
109
109
  pytest = ">=9.0.2,<10"
@@ -134,6 +134,7 @@ pytest-cov = ">=7.0.0"
134
134
  hypothesis = ">=6.151.2"
135
135
  array-api-strict = ">=2.4.1,<2.5"
136
136
  numpy = ">=1.22.0"
137
+ scipy = ">=1.15.2,<2"
137
138
 
138
139
  [tool.pixi.feature.tests.tasks]
139
140
  tests = { cmd = "pytest -v", description = "Run tests" }
@@ -27,7 +27,7 @@ from ._lib._funcs import (
27
27
  )
28
28
  from ._lib._lazy import lazy_apply
29
29
 
30
- __version__ = "0.10.0"
30
+ __version__ = "0.10.1"
31
31
 
32
32
  # pylint: disable=duplicate-code
33
33
  __all__ = [
@@ -10,8 +10,9 @@ import contextlib
10
10
  import enum
11
11
  import warnings
12
12
  from collections.abc import Callable, Generator, Iterator, Sequence
13
- from functools import wraps
14
- from types import ModuleType
13
+ from functools import update_wrapper, wraps
14
+ from inspect import getattr_static
15
+ from types import FunctionType, ModuleType
15
16
  from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar, cast
16
17
 
17
18
  from ._lib._utils._compat import is_dask_namespace, is_jax_namespace
@@ -48,8 +49,21 @@ class Deprecated(enum.Enum):
48
49
  DEPRECATED = Deprecated.DEPRECATED
49
50
 
50
51
 
52
+ def _clone_function(f: Callable[..., Any]) -> Callable[..., Any]:
53
+ """Returns a clone of an existing function."""
54
+ f_new = FunctionType(
55
+ f.__code__,
56
+ f.__globals__,
57
+ name=f.__name__,
58
+ argdefs=f.__defaults__,
59
+ closure=f.__closure__,
60
+ )
61
+ f_new.__kwdefaults__ = f.__kwdefaults__
62
+ return update_wrapper(f_new, f)
63
+
64
+
51
65
  def lazy_xp_function(
52
- func: Callable[..., Any],
66
+ func: Callable[..., Any] | tuple[type, str],
53
67
  *,
54
68
  allow_dask_compute: bool | int = False,
55
69
  jax_jit: bool = True,
@@ -69,8 +83,9 @@ def lazy_xp_function(
69
83
 
70
84
  Parameters
71
85
  ----------
72
- func : callable
73
- Function to be tested.
86
+ func : callable | tuple[type, str]
87
+ Function to be tested, or a tuple containing an (uninstantiated) class and a
88
+ method name to specify a method of that class (regular, static, or class method) to be tested.
74
89
  allow_dask_compute : bool | int, optional
75
90
  Whether `func` is allowed to internally materialize the Dask graph, or maximum
76
91
  number of times it is allowed to do so. This is typically triggered by
@@ -204,15 +219,49 @@ def lazy_xp_function(
204
219
  DeprecationWarning,
205
220
  stacklevel=2,
206
221
  )
207
- tags = {
222
+ tags: dict[str, bool | int | type] = {
208
223
  "allow_dask_compute": allow_dask_compute,
209
224
  "jax_jit": jax_jit,
210
225
  }
211
226
 
227
+ if isinstance(func, tuple):
228
+ # Replace the method with a clone before adding tags
229
+ # to avoid adding unwanted tags to a parent method when
230
+ # the method was inherited from a parent class.
231
+ # Note: can't just accept an unbound method `cls.method_name` because in
232
+ # case of inheritance it would be impossible to attribute it to the child class.
233
+ # This also makes it so tagged methods will appear in their class's ``__dict__``
234
+ # and thus be findable by ``iter_tagged`` in ``patch_lazy_xp_functions`` below.
235
+ cls, method_name = func
236
+ # The method might be a staticmethod or classmethod so we need to do a dance
237
+ # to ensure that this is preserved.
238
+ raw_attr = getattr_static(cls, method_name)
239
+ method = getattr(cls, method_name)
240
+ if isinstance(raw_attr, classmethod):
241
+ method = method.__func__
242
+ cloned_method = _clone_function(method)
243
+
244
+ method_to_set: Any
245
+ if isinstance(raw_attr, staticmethod):
246
+ method_to_set = staticmethod(cloned_method)
247
+ elif isinstance(raw_attr, classmethod):
248
+ method_to_set = classmethod(cloned_method)
249
+ else:
250
+ method_to_set = cloned_method
251
+
252
+ setattr(cls, method_name, method_to_set)
253
+ f = getattr(cls, method_name)
254
+ if isinstance(raw_attr, classmethod):
255
+ f = f.__func__
256
+ # Annotate that cls owns this method so we can check that later.
257
+ tags["owner"] = cls
258
+ else:
259
+ f = func
260
+
212
261
  try:
213
- func._lazy_xp_function = tags # type: ignore[attr-defined] # pylint: disable=protected-access # pyright: ignore[reportFunctionMemberAccess]
262
+ f._lazy_xp_function = tags # pylint: disable=protected-access # pyright: ignore[reportFunctionMemberAccess]
214
263
  except AttributeError: # @cython.vectorize
215
- _ufuncs_tags[func] = tags
264
+ _ufuncs_tags[f] = tags
216
265
 
217
266
 
218
267
  def patch_lazy_xp_functions(
@@ -224,10 +273,11 @@ def patch_lazy_xp_functions(
224
273
  """
225
274
  Test lazy execution of functions tagged with :func:`lazy_xp_function`.
226
275
 
227
- If ``xp==jax.numpy``, search for all functions which have been tagged with
228
- :func:`lazy_xp_function` in the globals of the module that defines the current test,
229
- as well as in the ``lazy_xp_modules`` list in the globals of the same module,
230
- and wrap them with :func:`jax.jit`. Unwrap them at the end of the test.
276
+ If ``xp==jax.numpy``, search for all functions and methods which have been tagged
277
+ with :func:`lazy_xp_function` in the globals of the module that defines the current
278
+ test, as well as in the ``lazy_xp_modules`` list in the globals of the same module,
279
+ and wrap them with :func:`jax.jit`.
280
+ Unwrap them at the end of the test.
231
281
 
232
282
  If ``xp==dask.array``, wrap the functions with a decorator that disables
233
283
  ``compute()`` and ``persist()`` and ensures that exceptions and warnings are raised
@@ -271,18 +321,34 @@ def patch_lazy_xp_functions(
271
321
  the example above.
272
322
  """
273
323
  mod = cast(ModuleType, request.module)
274
- mods = [mod, *cast(list[ModuleType], getattr(mod, "lazy_xp_modules", []))]
275
-
276
- to_revert: list[tuple[ModuleType, str, object]] = []
277
-
278
- def temp_setattr(mod: ModuleType, name: str, func: object) -> None:
324
+ search_targets: list[ModuleType | type] = [
325
+ mod,
326
+ *cast(list[ModuleType], getattr(mod, "lazy_xp_modules", [])),
327
+ ]
328
+ # Also search for classes within the above modules which have had lazy_xp_function
329
+ # applied to methods through ``lazy_xp_function((cls, method_name))`` syntax.
330
+ # We might end up adding classes incidentally imported into modules, so using a
331
+ # set here to cut down on potential redundancy.
332
+ classes: set[type] = set()
333
+ for target in search_targets:
334
+ for obj in target.__dict__.values():
335
+ if isinstance(obj, type):
336
+ classes.add(obj)
337
+ search_targets.extend(classes)
338
+
339
+ to_revert: list[tuple[ModuleType | type, str, object]] = []
340
+
341
+ def temp_setattr(target: ModuleType | type, name: str, func: object) -> None:
279
342
  """
280
343
  Variant of monkeypatch.setattr, which allows monkey-patching only selected
281
344
  parameters of a test so that pytest-run-parallel can run on the remainder.
282
345
  """
283
- assert hasattr(mod, name)
284
- to_revert.append((mod, name, getattr(mod, name)))
285
- setattr(mod, name, func)
346
+ assert hasattr(target, name)
347
+ # Need getattr_static because the attr could be a staticmethod or other
348
+ # descriptor and we don't want that to be stripped away.
349
+ original = getattr_static(target, name)
350
+ to_revert.append((target, name, original))
351
+ setattr(target, name, func)
286
352
 
287
353
  if monkeypatch is not None:
288
354
  warnings.warn(
@@ -298,10 +364,19 @@ def patch_lazy_xp_functions(
298
364
  temp_setattr = monkeypatch.setattr # type: ignore[assignment] # pyright: ignore[reportAssignmentType]
299
365
 
300
366
  def iter_tagged() -> Iterator[
301
- tuple[ModuleType, str, Callable[..., Any], dict[str, Any]]
367
+ tuple[ModuleType | type, str, Any, Callable[..., Any], dict[str, Any]]
302
368
  ]:
303
- for mod in mods:
304
- for name, func in mod.__dict__.items():
369
+ for target in search_targets:
370
+ for name, attr in target.__dict__.items():
371
+ # attr might be a staticmethod or classmethod. If so we need
372
+ # to peel it back and wrap the underlying function and later
373
+ # make sure not to accidentally replace it with a regular
374
+ # method.
375
+ func: Any = (
376
+ attr.__func__
377
+ if isinstance(attr, (staticmethod, classmethod))
378
+ else attr
379
+ )
305
380
  tags: dict[str, Any] | None = None
306
381
  with contextlib.suppress(AttributeError):
307
382
  tags = func._lazy_xp_function # pylint: disable=protected-access
@@ -309,23 +384,49 @@ def patch_lazy_xp_functions(
309
384
  with contextlib.suppress(KeyError, TypeError):
310
385
  tags = _ufuncs_tags[func]
311
386
  if tags is not None:
312
- yield mod, name, func, tags
313
-
387
+ if isinstance(target, type) and tags.get("owner") is not target:
388
+ # There's a common pattern to wrap functions in namespace
389
+ # classes to bypass lazy_xp_function like this:
390
+ #
391
+ # class naked:
392
+ # myfunc = mymodule.myfunc
393
+ #
394
+ # To ensure this still works when checking for tags in
395
+ # attributes of classes, ensure that target is the actual
396
+ # owning class where func was defined.
397
+ continue
398
+ # put both attr and func in the outputs so we can later tell
399
+ # if this was a staticmethod or classmethod.
400
+ yield target, name, attr, func, tags
401
+
402
+ wrapped: Any
314
403
  if is_dask_namespace(xp):
315
- for mod, name, func, tags in iter_tagged():
404
+ for target, name, attr, func, tags in iter_tagged():
316
405
  n = tags["allow_dask_compute"]
317
406
  if n is True:
318
407
  n = 1_000_000
319
408
  elif n is False:
320
409
  n = 0
321
410
  wrapped = _dask_wrap(func, n)
322
- temp_setattr(mod, name, wrapped)
411
+ # If we're dealing with a staticmethod or classmethod, make
412
+ # sure things stay that way.
413
+ if isinstance(attr, staticmethod):
414
+ wrapped = staticmethod(wrapped)
415
+ elif isinstance(attr, classmethod):
416
+ wrapped = classmethod(wrapped)
417
+ temp_setattr(target, name, wrapped)
323
418
 
324
419
  elif is_jax_namespace(xp):
325
- for mod, name, func, tags in iter_tagged():
420
+ for target, name, attr, func, tags in iter_tagged():
326
421
  if tags["jax_jit"]:
327
422
  wrapped = jax_autojit(func)
328
- temp_setattr(mod, name, wrapped)
423
+ # If we're dealing with a staticmethod or classmethod, make
424
+ # sure things stay that way.
425
+ if isinstance(attr, staticmethod):
426
+ wrapped = staticmethod(wrapped)
427
+ elif isinstance(attr, classmethod):
428
+ wrapped = classmethod(wrapped)
429
+ temp_setattr(target, name, wrapped)
329
430
 
330
431
  # We can't just decorate patch_lazy_xp_functions with
331
432
  # @contextlib.contextmanager because it would not work with the
@@ -335,8 +436,8 @@ def patch_lazy_xp_functions(
335
436
  try:
336
437
  yield
337
438
  finally:
338
- for mod, name, orig_func in to_revert:
339
- setattr(mod, name, orig_func)
439
+ for target, name, orig_func in to_revert:
440
+ setattr(target, name, orig_func)
340
441
 
341
442
  return revert_on_exit()
342
443
 
@@ -1,9 +1,10 @@
1
1
  from collections.abc import Callable, Iterator
2
2
  from types import ModuleType
3
- from typing import cast
3
+ from typing import Any, cast, final
4
4
 
5
5
  import numpy as np
6
6
  import pytest
7
+ from typing_extensions import override
7
8
 
8
9
  from array_api_extra._lib._backends import Backend
9
10
  from array_api_extra._lib._testing import (
@@ -321,6 +322,128 @@ def test_lazy_xp_function_cython_ufuncs(xp: ModuleType, library: Backend):
321
322
  xp_assert_equal(cast(Array, erf(x)), xp.asarray([1.0, 1.0]))
322
323
 
323
324
 
325
+ class A:
326
+ def __init__(self, x: Array):
327
+ xp = array_namespace(x)
328
+ self._xp: ModuleType = xp
329
+ self.x: Any = np.asarray(x)
330
+
331
+ def f(self, y: Array) -> Array:
332
+ return self._xp.asarray(np.matmul(self.x, np.asarray(y)))
333
+
334
+ def g(self, y: Array, z: Array) -> Array:
335
+ return self.f(y) + self.f(z)
336
+
337
+ def h(self, y: Array) -> bool:
338
+ return bool(self._xp.any(y))
339
+
340
+
341
+ class B(A):
342
+ @override
343
+ def __init__(self, x: Array): # pyright: ignore[reportMissingSuperCall]
344
+ xp = array_namespace(x)
345
+ self._xp: ModuleType = xp
346
+ self.x: Any = xp.asarray(x)
347
+
348
+ @override
349
+ def f(self, y: Array) -> Array:
350
+ return self._xp.matmul(self.x, y)
351
+
352
+ @staticmethod
353
+ def k(y: Array) -> "B":
354
+ return B(2.0 * y)
355
+
356
+ @staticmethod
357
+ def j(y: Array) -> "B":
358
+ xp = array_namespace(y)
359
+ y = xp.asarray(y)
360
+ if bool(xp.any(y)):
361
+ return B(y)
362
+ return B(y + 1.0)
363
+
364
+ @classmethod
365
+ def w(cls, y: Array) -> "B":
366
+ xp = array_namespace(y)
367
+ y = xp.asarray(y)
368
+ if bool(xp.any(y)):
369
+ return B(y)
370
+ return B(y + 1.0)
371
+
372
+
373
+ @final
374
+ class eager:
375
+ # this needs to be a staticmethod to appease the type checker
376
+ non_materializable5 = staticmethod(non_materializable5)
377
+
378
+
379
+ lazy_xp_function((B, "g"))
380
+ lazy_xp_function((B, "h"))
381
+ lazy_xp_function((B, "k"))
382
+ lazy_xp_function((B, "j"))
383
+ lazy_xp_function((B, "w"))
384
+
385
+
386
+ class TestLazyXpFunctionClasses:
387
+ def test_parent_method_not_tagged(self):
388
+ assert hasattr(B.g, "_lazy_xp_function")
389
+ assert not hasattr(A.g, "_lazy_xp_function")
390
+
391
+ @pytest.mark.skip_xp_backend(Backend.SPARSE, reason="converts to NumPy")
392
+ @pytest.mark.skip_xp_backend(Backend.CUPY, reason="converts to NumPy")
393
+ @pytest.mark.skip_xp_backend(Backend.JAX_GPU, reason="converts to NumPy")
394
+ @pytest.mark.skip_xp_backend(Backend.TORCH_GPU, reason="converts to NumPy")
395
+ def test_lazy_xp_function_classes(self, xp: ModuleType, library: Backend):
396
+ x = xp.asarray([1.1, 2.2, 3.3])
397
+ y = xp.asarray([1.0, 2.0, 3.0])
398
+ foo = A(x)
399
+ bar = B(x)
400
+
401
+ if library.like(Backend.JAX):
402
+ with pytest.raises(
403
+ TypeError, match="Attempted boolean conversion of traced array"
404
+ ):
405
+ assert bar.h(y)
406
+
407
+ assert foo.h(y)
408
+
409
+ def test_static_methods_preserved(self, xp: ModuleType):
410
+ # Tests that static methods stay static methods when
411
+ # lazy_xp_function is applied.
412
+ x = xp.asarray([1.1, 2.2, 3.3])
413
+ foo = B(x)
414
+ bar = foo.k(x)
415
+ xp_assert_equal(bar.x, 2.0 * foo.x)
416
+
417
+ @pytest.mark.skip_xp_backend(Backend.DASK, reason="calls dask.compute()")
418
+ def test_static_methods_wrapped(self, xp: ModuleType, library: Backend):
419
+ x = xp.asarray([1.1, 2.2, 3.3])
420
+ foo = B(x)
421
+
422
+ if library.like(Backend.JAX):
423
+ with pytest.raises(
424
+ TypeError, match="Attempted boolean conversion of traced array"
425
+ ):
426
+ assert isinstance(foo.j(x), B)
427
+ else:
428
+ assert isinstance(foo.j(x), B)
429
+
430
+ @pytest.mark.skip_xp_backend(Backend.DASK, reason="calls dask.compute()")
431
+ def test_class_methods_wrapped(self, xp: ModuleType, library: Backend):
432
+ x = xp.asarray([1.1, 2.2, 3.3])
433
+ if library.like(Backend.JAX):
434
+ with pytest.raises(
435
+ TypeError, match="Attempted boolean conversion of traced array"
436
+ ):
437
+ assert isinstance(B.w(x), B)
438
+ else:
439
+ assert isinstance(B.w(x), B)
440
+
441
+ def test_circumvention(self, xp: ModuleType):
442
+ x = xp.asarray([1.0, 2.0])
443
+ y = eager.non_materializable5(x)
444
+ xp_assert_equal(y, x)
445
+
446
+
324
447
  def dask_raises(x: Array) -> Array:
325
448
  def _raises(x: Array) -> Array:
326
449
  # Test that map_blocks doesn't eagerly call the function;