array-api-extra 0.2.1.dev0 (tar.gz) → 0.3.2 (tar.gz)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33)
  1. {array_api_extra-0.2.1.dev0 → array_api_extra-0.3.2}/PKG-INFO +1 -2
  2. {array_api_extra-0.2.1.dev0 → array_api_extra-0.3.2}/pixi.lock +2 -18
  3. {array_api_extra-0.2.1.dev0 → array_api_extra-0.3.2}/pyproject.toml +2 -2
  4. {array_api_extra-0.2.1.dev0 → array_api_extra-0.3.2}/src/array_api_extra/__init__.py +2 -2
  5. {array_api_extra-0.2.1.dev0 → array_api_extra-0.3.2}/src/array_api_extra/_funcs.py +1 -21
  6. {array_api_extra-0.2.1.dev0 → array_api_extra-0.3.2}/src/array_api_extra/_lib/_compat.py +1 -1
  7. array_api_extra-0.3.2/src/array_api_extra/_lib/_typing.py +22 -0
  8. {array_api_extra-0.2.1.dev0 → array_api_extra-0.3.2}/src/array_api_extra/_lib/_utils.py +21 -1
  9. array_api_extra-0.2.1.dev0/src/array_api_extra/_lib/_typing.py +0 -10
  10. {array_api_extra-0.2.1.dev0 → array_api_extra-0.3.2}/.all-contributorsrc +0 -0
  11. {array_api_extra-0.2.1.dev0 → array_api_extra-0.3.2}/.gitattributes +0 -0
  12. {array_api_extra-0.2.1.dev0 → array_api_extra-0.3.2}/.github/dependabot.yml +0 -0
  13. {array_api_extra-0.2.1.dev0 → array_api_extra-0.3.2}/.github/workflows/cd.yml +0 -0
  14. {array_api_extra-0.2.1.dev0 → array_api_extra-0.3.2}/.github/workflows/ci.yml +0 -0
  15. {array_api_extra-0.2.1.dev0 → array_api_extra-0.3.2}/.github/workflows/dependabot-auto-merge.yml +0 -0
  16. {array_api_extra-0.2.1.dev0 → array_api_extra-0.3.2}/.github/workflows/docs-build.yml +0 -0
  17. {array_api_extra-0.2.1.dev0 → array_api_extra-0.3.2}/.github/workflows/docs-deploy.yml +0 -0
  18. {array_api_extra-0.2.1.dev0 → array_api_extra-0.3.2}/.gitignore +0 -0
  19. {array_api_extra-0.2.1.dev0 → array_api_extra-0.3.2}/.pre-commit-config.yaml +0 -0
  20. {array_api_extra-0.2.1.dev0 → array_api_extra-0.3.2}/CONTRIBUTORS.md +0 -0
  21. {array_api_extra-0.2.1.dev0 → array_api_extra-0.3.2}/LICENSE +0 -0
  22. {array_api_extra-0.2.1.dev0 → array_api_extra-0.3.2}/README.md +0 -0
  23. {array_api_extra-0.2.1.dev0 → array_api_extra-0.3.2}/codecov.yml +0 -0
  24. {array_api_extra-0.2.1.dev0 → array_api_extra-0.3.2}/docs/api-reference.md +0 -0
  25. {array_api_extra-0.2.1.dev0 → array_api_extra-0.3.2}/docs/conf.py +0 -0
  26. {array_api_extra-0.2.1.dev0 → array_api_extra-0.3.2}/docs/contributing.md +0 -0
  27. {array_api_extra-0.2.1.dev0 → array_api_extra-0.3.2}/docs/contributors.md +0 -0
  28. {array_api_extra-0.2.1.dev0 → array_api_extra-0.3.2}/docs/index.md +0 -0
  29. {array_api_extra-0.2.1.dev0 → array_api_extra-0.3.2}/docs/requirements.txt +0 -0
  30. {array_api_extra-0.2.1.dev0 → array_api_extra-0.3.2}/src/array_api_extra/py.typed +0 -0
  31. {array_api_extra-0.2.1.dev0 → array_api_extra-0.3.2}/tests/test_funcs.py +49 -49
  32. {array_api_extra-0.2.1.dev0 → array_api_extra-0.3.2}/tests/test_utils.py +0 -0
  33. {array_api_extra-0.2.1.dev0 → array_api_extra-0.3.2}/tests/test_version.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: array-api-extra
3
- Version: 0.2.1.dev0
3
+ Version: 0.3.2
4
4
  Summary: Extra array functions built on top of the array API standard.
5
5
  Project-URL: Homepage, https://github.com/data-apis/array-api-extra
6
6
  Project-URL: Bug Tracker, https://github.com/data-apis/array-api-extra/issues
@@ -41,7 +41,6 @@ Classifier: Programming Language :: Python :: 3.12
41
41
  Classifier: Programming Language :: Python :: 3.13
42
42
  Classifier: Typing :: Typed
43
43
  Requires-Python: >=3.10
44
- Requires-Dist: typing-extensions
45
44
  Provides-Extra: docs
46
45
  Requires-Dist: furo>=2023.08.17; extra == 'docs'
47
46
  Requires-Dist: myst-parser>=0.13; extra == 'docs'
@@ -46,7 +46,6 @@ environments:
46
46
  - conda: https://prefix.dev/conda-forge/linux-64/tk-8.6.13-noxft_h4845f30_101.conda
47
47
  - conda: https://prefix.dev/conda-forge/noarch/toml-0.10.2-pyhd8ed1ab_0.tar.bz2
48
48
  - conda: https://prefix.dev/conda-forge/noarch/tomli-2.1.0-pyhff2d567_0.conda
49
- - conda: https://prefix.dev/conda-forge/noarch/typing_extensions-4.12.2-pyha770c72_0.conda
50
49
  - conda: https://prefix.dev/conda-forge/noarch/tzdata-2024b-hc8b5060_0.conda
51
50
  - conda: https://prefix.dev/conda-forge/linux-64/xz-5.2.6-h166bdaf_0.tar.bz2
52
51
  - pypi: .
@@ -82,7 +81,6 @@ environments:
82
81
  - conda: https://prefix.dev/conda-forge/osx-arm64/tk-8.6.13-h5083fa2_1.conda
83
82
  - conda: https://prefix.dev/conda-forge/noarch/toml-0.10.2-pyhd8ed1ab_0.tar.bz2
84
83
  - conda: https://prefix.dev/conda-forge/noarch/tomli-2.1.0-pyhff2d567_0.conda
85
- - conda: https://prefix.dev/conda-forge/noarch/typing_extensions-4.12.2-pyha770c72_0.conda
86
84
  - conda: https://prefix.dev/conda-forge/noarch/tzdata-2024b-hc8b5060_0.conda
87
85
  - conda: https://prefix.dev/conda-forge/osx-arm64/xz-5.2.6-h57fd34a_0.tar.bz2
88
86
  - pypi: .
@@ -118,7 +116,6 @@ environments:
118
116
  - conda: https://prefix.dev/conda-forge/win-64/tk-8.6.13-h5226925_1.conda
119
117
  - conda: https://prefix.dev/conda-forge/noarch/toml-0.10.2-pyhd8ed1ab_0.tar.bz2
120
118
  - conda: https://prefix.dev/conda-forge/noarch/tomli-2.1.0-pyhff2d567_0.conda
121
- - conda: https://prefix.dev/conda-forge/noarch/typing_extensions-4.12.2-pyha770c72_0.conda
122
119
  - conda: https://prefix.dev/conda-forge/noarch/tzdata-2024b-hc8b5060_0.conda
123
120
  - conda: https://prefix.dev/conda-forge/win-64/ucrt-10.0.22621.0-h57928b3_1.conda
124
121
  - conda: https://prefix.dev/conda-forge/win-64/vc-14.3-ha32ba9b_23.conda
@@ -172,7 +169,6 @@ environments:
172
169
  - conda: https://prefix.dev/conda-forge/linux-64/tk-8.6.13-noxft_h4845f30_101.conda
173
170
  - conda: https://prefix.dev/conda-forge/noarch/toml-0.10.2-pyhd8ed1ab_0.tar.bz2
174
171
  - conda: https://prefix.dev/conda-forge/noarch/tomli-2.1.0-pyhff2d567_0.conda
175
- - conda: https://prefix.dev/conda-forge/noarch/typing_extensions-4.12.2-pyha770c72_0.conda
176
172
  - conda: https://prefix.dev/conda-forge/noarch/tzdata-2024b-hc8b5060_0.conda
177
173
  - conda: https://prefix.dev/conda-forge/linux-64/xz-5.2.6-h166bdaf_0.tar.bz2
178
174
  - pypi: .
@@ -210,7 +206,6 @@ environments:
210
206
  - conda: https://prefix.dev/conda-forge/osx-arm64/tk-8.6.13-h5083fa2_1.conda
211
207
  - conda: https://prefix.dev/conda-forge/noarch/toml-0.10.2-pyhd8ed1ab_0.tar.bz2
212
208
  - conda: https://prefix.dev/conda-forge/noarch/tomli-2.1.0-pyhff2d567_0.conda
213
- - conda: https://prefix.dev/conda-forge/noarch/typing_extensions-4.12.2-pyha770c72_0.conda
214
209
  - conda: https://prefix.dev/conda-forge/noarch/tzdata-2024b-hc8b5060_0.conda
215
210
  - conda: https://prefix.dev/conda-forge/osx-arm64/xz-5.2.6-h57fd34a_0.tar.bz2
216
211
  - pypi: .
@@ -248,7 +243,6 @@ environments:
248
243
  - conda: https://prefix.dev/conda-forge/win-64/tk-8.6.13-h5226925_1.conda
249
244
  - conda: https://prefix.dev/conda-forge/noarch/toml-0.10.2-pyhd8ed1ab_0.tar.bz2
250
245
  - conda: https://prefix.dev/conda-forge/noarch/tomli-2.1.0-pyhff2d567_0.conda
251
- - conda: https://prefix.dev/conda-forge/noarch/typing_extensions-4.12.2-pyha770c72_0.conda
252
246
  - conda: https://prefix.dev/conda-forge/noarch/tzdata-2024b-hc8b5060_0.conda
253
247
  - conda: https://prefix.dev/conda-forge/win-64/ucrt-10.0.22621.0-h57928b3_1.conda
254
248
  - conda: https://prefix.dev/conda-forge/win-64/vc-14.3-ha32ba9b_23.conda
@@ -283,7 +277,6 @@ environments:
283
277
  - conda: https://prefix.dev/conda-forge/linux-64/python_abi-3.13-5_cp313.conda
284
278
  - conda: https://prefix.dev/conda-forge/linux-64/readline-8.2-h8228510_1.conda
285
279
  - conda: https://prefix.dev/conda-forge/linux-64/tk-8.6.13-noxft_h4845f30_101.conda
286
- - conda: https://prefix.dev/conda-forge/noarch/typing_extensions-4.12.2-pyha770c72_0.conda
287
280
  - conda: https://prefix.dev/conda-forge/noarch/tzdata-2024b-hc8b5060_0.conda
288
281
  - conda: https://prefix.dev/conda-forge/linux-64/xz-5.2.6-h166bdaf_0.tar.bz2
289
282
  - pypi: .
@@ -301,7 +294,6 @@ environments:
301
294
  - conda: https://prefix.dev/conda-forge/osx-arm64/python_abi-3.13-5_cp313.conda
302
295
  - conda: https://prefix.dev/conda-forge/osx-arm64/readline-8.2-h92ec313_1.conda
303
296
  - conda: https://prefix.dev/conda-forge/osx-arm64/tk-8.6.13-h5083fa2_1.conda
304
- - conda: https://prefix.dev/conda-forge/noarch/typing_extensions-4.12.2-pyha770c72_0.conda
305
297
  - conda: https://prefix.dev/conda-forge/noarch/tzdata-2024b-hc8b5060_0.conda
306
298
  - conda: https://prefix.dev/conda-forge/osx-arm64/xz-5.2.6-h57fd34a_0.tar.bz2
307
299
  - pypi: .
@@ -317,7 +309,6 @@ environments:
317
309
  - conda: https://prefix.dev/conda-forge/win-64/python-3.13.0-hf5aa216_101_cp313.conda
318
310
  - conda: https://prefix.dev/conda-forge/win-64/python_abi-3.13-5_cp313.conda
319
311
  - conda: https://prefix.dev/conda-forge/win-64/tk-8.6.13-h5226925_1.conda
320
- - conda: https://prefix.dev/conda-forge/noarch/typing_extensions-4.12.2-pyha770c72_0.conda
321
312
  - conda: https://prefix.dev/conda-forge/noarch/tzdata-2024b-hc8b5060_0.conda
322
313
  - conda: https://prefix.dev/conda-forge/win-64/ucrt-10.0.22621.0-h57928b3_1.conda
323
314
  - conda: https://prefix.dev/conda-forge/win-64/vc-14.3-ha32ba9b_23.conda
@@ -768,7 +759,6 @@ environments:
768
759
  - conda: https://prefix.dev/conda-forge/noarch/sphinxcontrib-serializinghtml-1.1.10-pyhd8ed1ab_0.conda
769
760
  - conda: https://prefix.dev/conda-forge/linux-64/tk-8.6.13-noxft_h4845f30_101.conda
770
761
  - conda: https://prefix.dev/conda-forge/noarch/tomli-2.1.0-pyhff2d567_0.conda
771
- - conda: https://prefix.dev/conda-forge/noarch/typing_extensions-4.12.2-pyha770c72_0.conda
772
762
  - conda: https://prefix.dev/conda-forge/noarch/tzdata-2024b-hc8b5060_0.conda
773
763
  - conda: https://prefix.dev/conda-forge/noarch/urllib3-2.2.3-pyhd8ed1ab_0.conda
774
764
  - conda: https://prefix.dev/conda-forge/linux-64/xz-5.2.6-h166bdaf_0.tar.bz2
@@ -832,7 +822,6 @@ environments:
832
822
  - conda: https://prefix.dev/conda-forge/noarch/sphinxcontrib-serializinghtml-1.1.10-pyhd8ed1ab_0.conda
833
823
  - conda: https://prefix.dev/conda-forge/osx-arm64/tk-8.6.13-h5083fa2_1.conda
834
824
  - conda: https://prefix.dev/conda-forge/noarch/tomli-2.1.0-pyhff2d567_0.conda
835
- - conda: https://prefix.dev/conda-forge/noarch/typing_extensions-4.12.2-pyha770c72_0.conda
836
825
  - conda: https://prefix.dev/conda-forge/noarch/tzdata-2024b-hc8b5060_0.conda
837
826
  - conda: https://prefix.dev/conda-forge/noarch/urllib3-2.2.3-pyhd8ed1ab_0.conda
838
827
  - conda: https://prefix.dev/conda-forge/osx-arm64/xz-5.2.6-h57fd34a_0.tar.bz2
@@ -893,7 +882,6 @@ environments:
893
882
  - conda: https://prefix.dev/conda-forge/noarch/sphinxcontrib-serializinghtml-1.1.10-pyhd8ed1ab_0.conda
894
883
  - conda: https://prefix.dev/conda-forge/win-64/tk-8.6.13-h5226925_1.conda
895
884
  - conda: https://prefix.dev/conda-forge/noarch/tomli-2.1.0-pyhff2d567_0.conda
896
- - conda: https://prefix.dev/conda-forge/noarch/typing_extensions-4.12.2-pyha770c72_0.conda
897
885
  - conda: https://prefix.dev/conda-forge/noarch/tzdata-2024b-hc8b5060_0.conda
898
886
  - conda: https://prefix.dev/conda-forge/win-64/ucrt-10.0.22621.0-h57928b3_1.conda
899
887
  - conda: https://prefix.dev/conda-forge/noarch/urllib3-2.2.3-pyhd8ed1ab_0.conda
@@ -1167,7 +1155,6 @@ environments:
1167
1155
  - conda: https://prefix.dev/conda-forge/linux-64/tk-8.6.13-noxft_h4845f30_101.conda
1168
1156
  - conda: https://prefix.dev/conda-forge/noarch/toml-0.10.2-pyhd8ed1ab_0.tar.bz2
1169
1157
  - conda: https://prefix.dev/conda-forge/noarch/tomli-2.1.0-pyhff2d567_0.conda
1170
- - conda: https://prefix.dev/conda-forge/noarch/typing_extensions-4.12.2-pyha770c72_0.conda
1171
1158
  - conda: https://prefix.dev/conda-forge/noarch/tzdata-2024b-hc8b5060_0.conda
1172
1159
  - conda: https://prefix.dev/conda-forge/linux-64/xz-5.2.6-h166bdaf_0.tar.bz2
1173
1160
  - pypi: .
@@ -1205,7 +1192,6 @@ environments:
1205
1192
  - conda: https://prefix.dev/conda-forge/osx-arm64/tk-8.6.13-h5083fa2_1.conda
1206
1193
  - conda: https://prefix.dev/conda-forge/noarch/toml-0.10.2-pyhd8ed1ab_0.tar.bz2
1207
1194
  - conda: https://prefix.dev/conda-forge/noarch/tomli-2.1.0-pyhff2d567_0.conda
1208
- - conda: https://prefix.dev/conda-forge/noarch/typing_extensions-4.12.2-pyha770c72_0.conda
1209
1195
  - conda: https://prefix.dev/conda-forge/noarch/tzdata-2024b-hc8b5060_0.conda
1210
1196
  - conda: https://prefix.dev/conda-forge/osx-arm64/xz-5.2.6-h57fd34a_0.tar.bz2
1211
1197
  - pypi: .
@@ -1243,7 +1229,6 @@ environments:
1243
1229
  - conda: https://prefix.dev/conda-forge/win-64/tk-8.6.13-h5226925_1.conda
1244
1230
  - conda: https://prefix.dev/conda-forge/noarch/toml-0.10.2-pyhd8ed1ab_0.tar.bz2
1245
1231
  - conda: https://prefix.dev/conda-forge/noarch/tomli-2.1.0-pyhff2d567_0.conda
1246
- - conda: https://prefix.dev/conda-forge/noarch/typing_extensions-4.12.2-pyha770c72_0.conda
1247
1232
  - conda: https://prefix.dev/conda-forge/noarch/tzdata-2024b-hc8b5060_0.conda
1248
1233
  - conda: https://prefix.dev/conda-forge/win-64/ucrt-10.0.22621.0-h57928b3_1.conda
1249
1234
  - conda: https://prefix.dev/conda-forge/win-64/vc-14.3-ha32ba9b_23.conda
@@ -1306,11 +1291,10 @@ packages:
1306
1291
  timestamp: 1722035895436
1307
1292
  - kind: pypi
1308
1293
  name: array-api-extra
1309
- version: 0.2.1.dev0
1294
+ version: 0.3.2
1310
1295
  path: .
1311
- sha256: 81d59ceda4b873652fada8c13d55e7cc98840538e87ecf744e935b7d1ac3017f
1296
+ sha256: 8f949b727c03da7c3dff8d6ffab9361f273ea2a81a30296f0474707aaad1b227
1312
1297
  requires_dist:
1313
- - typing-extensions
1314
1298
  - furo>=2023.8.17 ; extra == 'docs'
1315
1299
  - myst-parser>=0.13 ; extra == 'docs'
1316
1300
  - sphinx-autodoc-typehints ; extra == 'docs'
@@ -26,7 +26,7 @@ classifiers = [
26
26
  "Typing :: Typed",
27
27
  ]
28
28
  dynamic = ["version"]
29
- dependencies = ["typing-extensions"]
29
+ dependencies = []
30
30
 
31
31
  [project.optional-dependencies]
32
32
  tests = [
@@ -64,7 +64,6 @@ platforms = ["linux-64", "osx-arm64", "win-64"]
64
64
 
65
65
  [tool.pixi.dependencies]
66
66
  python = ">=3.10.15,<3.14"
67
- typing_extensions = ">=4.12.2,<4.13"
68
67
 
69
68
  [tool.pixi.pypi-dependencies]
70
69
  array-api-extra = { path = ".", editable = true }
@@ -74,6 +73,7 @@ pre-commit = "*"
74
73
  pylint = "*"
75
74
  basedmypy = "*"
76
75
  basedpyright = "*"
76
+ typing_extensions = ">=4.12.2,<4.13"
77
77
  # import dependencies for mypy:
78
78
  array-api-strict = "*"
79
79
  numpy = "*"
@@ -1,8 +1,8 @@
1
- from __future__ import annotations
1
+ from __future__ import annotations # https://github.com/pylint-dev/pylint/pull/9990
2
2
 
3
3
  from ._funcs import atleast_nd, cov, create_diagonal, expand_dims, kron, setdiff1d, sinc
4
4
 
5
- __version__ = "0.2.1.dev0"
5
+ __version__ = "0.3.2"
6
6
 
7
7
  # pylint: disable=duplicate-code
8
8
  __all__ = [
@@ -133,7 +133,7 @@ def cov(m: Array, /, *, xp: ModuleType) -> Array:
133
133
  m = atleast_nd(m, ndim=2, xp=xp)
134
134
  m = xp.astype(m, dtype)
135
135
 
136
- avg = _mean(m, axis=1, xp=xp)
136
+ avg = _utils.mean(m, axis=1, xp=xp)
137
137
  fact = m.shape[1] - 1
138
138
 
139
139
  if fact <= 0:
@@ -199,26 +199,6 @@ def create_diagonal(x: Array, /, *, offset: int = 0, xp: ModuleType) -> Array:
199
199
  return xp.reshape(diag, (n, n))
200
200
 
201
201
 
202
- def _mean(
203
- x: Array,
204
- /,
205
- *,
206
- axis: int | tuple[int, ...] | None = None,
207
- keepdims: bool = False,
208
- xp: ModuleType,
209
- ) -> Array:
210
- """
211
- Complex mean, https://github.com/data-apis/array-api/issues/846.
212
- """
213
- if xp.isdtype(x.dtype, "complex floating"):
214
- x_real = xp.real(x)
215
- x_imag = xp.imag(x)
216
- mean_real = xp.mean(x_real, axis=axis, keepdims=keepdims)
217
- mean_imag = xp.mean(x_imag, axis=axis, keepdims=keepdims)
218
- return mean_real + (mean_imag * xp.asarray(1j))
219
- return xp.mean(x, axis=axis, keepdims=keepdims)
220
-
221
-
222
202
  def expand_dims(
223
203
  a: Array, /, *, axis: int | tuple[int, ...] = (0,), xp: ModuleType
224
204
  ) -> Array:
@@ -6,7 +6,7 @@ import inspect
6
6
  import sys
7
7
  import typing
8
8
 
9
- from typing_extensions import override
9
+ from ._typing import override
10
10
 
11
11
  if typing.TYPE_CHECKING:
12
12
  from ._typing import Array, Device
@@ -0,0 +1,22 @@
1
+ from __future__ import annotations # https://github.com/pylint-dev/pylint/pull/9990
2
+
3
+ import typing
4
+ from types import ModuleType
5
+ from typing import Any
6
+
7
+ if typing.TYPE_CHECKING:
8
+ from typing_extensions import override
9
+
10
+ # To be changed to a Protocol later (see data-apis/array-api#589)
11
+ Array = Any # type: ignore[no-any-explicit]
12
+ Device = Any # type: ignore[no-any-explicit]
13
+ else:
14
+
15
+ def no_op_decorator(f): # pyright: ignore[reportUnreachable]
16
+ return f
17
+
18
+ override = no_op_decorator
19
+
20
+ __all__ = ["ModuleType", "override"]
21
+ if typing.TYPE_CHECKING:
22
+ __all__ += ["Array", "Device"]
@@ -7,7 +7,7 @@ if typing.TYPE_CHECKING:
7
7
 
8
8
  from . import _compat
9
9
 
10
- __all__ = ["in1d"]
10
+ __all__ = ["in1d", "mean"]
11
11
 
12
12
 
13
13
  def in1d(
@@ -63,3 +63,23 @@ def in1d(
63
63
  if assume_unique:
64
64
  return ret[: x1.shape[0]]
65
65
  return xp.take(ret, rev_idx, axis=0)
66
+
67
+
68
+ def mean(
69
+ x: Array,
70
+ /,
71
+ *,
72
+ axis: int | tuple[int, ...] | None = None,
73
+ keepdims: bool = False,
74
+ xp: ModuleType,
75
+ ) -> Array:
76
+ """
77
+ Complex mean, https://github.com/data-apis/array-api/issues/846.
78
+ """
79
+ if xp.isdtype(x.dtype, "complex floating"):
80
+ x_real = xp.real(x)
81
+ x_imag = xp.imag(x)
82
+ mean_real = xp.mean(x_real, axis=axis, keepdims=keepdims)
83
+ mean_imag = xp.mean(x_imag, axis=axis, keepdims=keepdims)
84
+ return mean_real + (mean_imag * xp.asarray(1j))
85
+ return xp.mean(x, axis=axis, keepdims=keepdims)
@@ -1,10 +0,0 @@
1
- from __future__ import annotations # https://github.com/pylint-dev/pylint/pull/9990
2
-
3
- from types import ModuleType
4
- from typing import Any
5
-
6
- # To be changed to a Protocol later (see data-apis/array-api#589)
7
- Array = Any # type: ignore[no-any-explicit]
8
- Device = Any # type: ignore[no-any-explicit]
9
-
10
- __all__ = ["Array", "Device", "ModuleType"]
@@ -157,6 +157,55 @@ class TestCreateDiagonal:
157
157
  create_diagonal(xp.asarray([[1]]), xp=xp)
158
158
 
159
159
 
160
+ class TestExpandDims:
161
+ def test_functionality(self):
162
+ def _squeeze_all(b: Array) -> Array:
163
+ """Mimics `np.squeeze(b)`. `xpx.squeeze`?"""
164
+ for axis in range(b.ndim):
165
+ with contextlib.suppress(ValueError):
166
+ b = xp.squeeze(b, axis=axis)
167
+ return b
168
+
169
+ s = (2, 3, 4, 5)
170
+ a = xp.empty(s)
171
+ for axis in range(-5, 4):
172
+ b = expand_dims(a, axis=axis, xp=xp)
173
+ assert b.shape[axis] == 1
174
+ assert _squeeze_all(b).shape == s
175
+
176
+ def test_axis_tuple(self):
177
+ a = xp.empty((3, 3, 3))
178
+ assert expand_dims(a, axis=(0, 1, 2), xp=xp).shape == (1, 1, 1, 3, 3, 3)
179
+ assert expand_dims(a, axis=(0, -1, -2), xp=xp).shape == (1, 3, 3, 3, 1, 1)
180
+ assert expand_dims(a, axis=(0, 3, 5), xp=xp).shape == (1, 3, 3, 1, 3, 1)
181
+ assert expand_dims(a, axis=(0, -3, -5), xp=xp).shape == (1, 1, 3, 1, 3, 3)
182
+
183
+ def test_axis_out_of_range(self):
184
+ s = (2, 3, 4, 5)
185
+ a = xp.empty(s)
186
+ with pytest.raises(IndexError, match="out of bounds"):
187
+ expand_dims(a, axis=-6, xp=xp)
188
+ with pytest.raises(IndexError, match="out of bounds"):
189
+ expand_dims(a, axis=5, xp=xp)
190
+
191
+ a = xp.empty((3, 3, 3))
192
+ with pytest.raises(IndexError, match="out of bounds"):
193
+ expand_dims(a, axis=(0, -6), xp=xp)
194
+ with pytest.raises(IndexError, match="out of bounds"):
195
+ expand_dims(a, axis=(0, 5), xp=xp)
196
+
197
+ def test_repeated_axis(self):
198
+ a = xp.empty((3, 3, 3))
199
+ with pytest.raises(ValueError, match="Duplicate dimensions"):
200
+ expand_dims(a, axis=(1, 1), xp=xp)
201
+
202
+ def test_positive_negative_repeated(self):
203
+ # https://github.com/data-apis/array-api/issues/760#issuecomment-1989449817
204
+ a = xp.empty((2, 3, 4, 5))
205
+ with pytest.raises(ValueError, match="Duplicate dimensions"):
206
+ expand_dims(a, axis=(3, -3), xp=xp)
207
+
208
+
160
209
  class TestKron:
161
210
  def test_basic(self):
162
211
  # Using 0-dimensional array
@@ -222,55 +271,6 @@ class TestKron:
222
271
  assert_equal(k.shape, expected_shape, err_msg="Unexpected shape from kron")
223
272
 
224
273
 
225
- class TestExpandDims:
226
- def test_functionality(self):
227
- def _squeeze_all(b: Array) -> Array:
228
- """Mimics `np.squeeze(b)`. `xpx.squeeze`?"""
229
- for axis in range(b.ndim):
230
- with contextlib.suppress(ValueError):
231
- b = xp.squeeze(b, axis=axis)
232
- return b
233
-
234
- s = (2, 3, 4, 5)
235
- a = xp.empty(s)
236
- for axis in range(-5, 4):
237
- b = expand_dims(a, axis=axis, xp=xp)
238
- assert b.shape[axis] == 1
239
- assert _squeeze_all(b).shape == s
240
-
241
- def test_axis_tuple(self):
242
- a = xp.empty((3, 3, 3))
243
- assert expand_dims(a, axis=(0, 1, 2), xp=xp).shape == (1, 1, 1, 3, 3, 3)
244
- assert expand_dims(a, axis=(0, -1, -2), xp=xp).shape == (1, 3, 3, 3, 1, 1)
245
- assert expand_dims(a, axis=(0, 3, 5), xp=xp).shape == (1, 3, 3, 1, 3, 1)
246
- assert expand_dims(a, axis=(0, -3, -5), xp=xp).shape == (1, 1, 3, 1, 3, 3)
247
-
248
- def test_axis_out_of_range(self):
249
- s = (2, 3, 4, 5)
250
- a = xp.empty(s)
251
- with pytest.raises(IndexError, match="out of bounds"):
252
- expand_dims(a, axis=-6, xp=xp)
253
- with pytest.raises(IndexError, match="out of bounds"):
254
- expand_dims(a, axis=5, xp=xp)
255
-
256
- a = xp.empty((3, 3, 3))
257
- with pytest.raises(IndexError, match="out of bounds"):
258
- expand_dims(a, axis=(0, -6), xp=xp)
259
- with pytest.raises(IndexError, match="out of bounds"):
260
- expand_dims(a, axis=(0, 5), xp=xp)
261
-
262
- def test_repeated_axis(self):
263
- a = xp.empty((3, 3, 3))
264
- with pytest.raises(ValueError, match="Duplicate dimensions"):
265
- expand_dims(a, axis=(1, 1), xp=xp)
266
-
267
- def test_positive_negative_repeated(self):
268
- # https://github.com/data-apis/array-api/issues/760#issuecomment-1989449817
269
- a = xp.empty((2, 3, 4, 5))
270
- with pytest.raises(ValueError, match="Duplicate dimensions"):
271
- expand_dims(a, axis=(3, -3), xp=xp)
272
-
273
-
274
274
  class TestSetDiff1D:
275
275
  def test_setdiff1d(self):
276
276
  x1 = xp.asarray([6, 5, 4, 7, 1, 2, 7, 4])