dask-array 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (144)
  1. dask_array/__init__.py +228 -0
  2. dask_array/_backends.py +76 -0
  3. dask_array/_backends_array.py +99 -0
  4. dask_array/_blockwise.py +1410 -0
  5. dask_array/_broadcast.py +272 -0
  6. dask_array/_chunk.py +445 -0
  7. dask_array/_chunk_types.py +54 -0
  8. dask_array/_collection.py +1644 -0
  9. dask_array/_concatenate.py +331 -0
  10. dask_array/_core_utils.py +1365 -0
  11. dask_array/_dispatch.py +141 -0
  12. dask_array/_einsum.py +277 -0
  13. dask_array/_expr.py +544 -0
  14. dask_array/_expr_flow.py +586 -0
  15. dask_array/_gufunc.py +805 -0
  16. dask_array/_histogram.py +617 -0
  17. dask_array/_map_blocks.py +652 -0
  18. dask_array/_new_collection.py +10 -0
  19. dask_array/_numpy_compat.py +135 -0
  20. dask_array/_overlap.py +1159 -0
  21. dask_array/_rechunk.py +1050 -0
  22. dask_array/_reshape.py +710 -0
  23. dask_array/_routines.py +102 -0
  24. dask_array/_shuffle.py +448 -0
  25. dask_array/_stack.py +264 -0
  26. dask_array/_svg.py +291 -0
  27. dask_array/_templates.py +29 -0
  28. dask_array/_test_utils.py +257 -0
  29. dask_array/_ufunc.py +385 -0
  30. dask_array/_utils.py +349 -0
  31. dask_array/_visualize.py +223 -0
  32. dask_array/_xarray.py +337 -0
  33. dask_array/core/__init__.py +34 -0
  34. dask_array/core/_blockwise_funcs.py +312 -0
  35. dask_array/core/_conversion.py +422 -0
  36. dask_array/core/_from_graph.py +97 -0
  37. dask_array/creation/__init__.py +71 -0
  38. dask_array/creation/_arange.py +121 -0
  39. dask_array/creation/_diag.py +116 -0
  40. dask_array/creation/_diagonal.py +241 -0
  41. dask_array/creation/_eye.py +103 -0
  42. dask_array/creation/_linspace.py +102 -0
  43. dask_array/creation/_mesh.py +134 -0
  44. dask_array/creation/_ones_zeros.py +454 -0
  45. dask_array/creation/_pad.py +270 -0
  46. dask_array/creation/_repeat.py +55 -0
  47. dask_array/creation/_tile.py +36 -0
  48. dask_array/creation/_tri.py +28 -0
  49. dask_array/creation/_utils.py +296 -0
  50. dask_array/fft.py +320 -0
  51. dask_array/io/__init__.py +39 -0
  52. dask_array/io/_base.py +10 -0
  53. dask_array/io/_from_array.py +257 -0
  54. dask_array/io/_from_delayed.py +95 -0
  55. dask_array/io/_from_graph.py +54 -0
  56. dask_array/io/_from_npy_stack.py +67 -0
  57. dask_array/io/_store.py +336 -0
  58. dask_array/io/_tiledb.py +159 -0
  59. dask_array/io/_to_npy_stack.py +65 -0
  60. dask_array/io/_zarr.py +449 -0
  61. dask_array/linalg/__init__.py +39 -0
  62. dask_array/linalg/_cholesky.py +234 -0
  63. dask_array/linalg/_lu.py +300 -0
  64. dask_array/linalg/_norm.py +94 -0
  65. dask_array/linalg/_qr.py +601 -0
  66. dask_array/linalg/_solve.py +349 -0
  67. dask_array/linalg/_svd.py +394 -0
  68. dask_array/linalg/_tensordot.py +334 -0
  69. dask_array/linalg/_utils.py +74 -0
  70. dask_array/manipulation/__init__.py +45 -0
  71. dask_array/manipulation/_expand.py +321 -0
  72. dask_array/manipulation/_flip.py +92 -0
  73. dask_array/manipulation/_roll.py +78 -0
  74. dask_array/manipulation/_transpose.py +309 -0
  75. dask_array/random/__init__.py +125 -0
  76. dask_array/random/_choice.py +181 -0
  77. dask_array/random/_expr.py +256 -0
  78. dask_array/random/_generator.py +441 -0
  79. dask_array/random/_random_state.py +259 -0
  80. dask_array/random/_utils.py +84 -0
  81. dask_array/reductions/__init__.py +84 -0
  82. dask_array/reductions/_arg_reduction.py +130 -0
  83. dask_array/reductions/_common.py +1082 -0
  84. dask_array/reductions/_cumulative.py +522 -0
  85. dask_array/reductions/_percentile.py +261 -0
  86. dask_array/reductions/_reduction.py +725 -0
  87. dask_array/reductions/_trace.py +56 -0
  88. dask_array/routines/__init__.py +133 -0
  89. dask_array/routines/_apply.py +84 -0
  90. dask_array/routines/_bincount.py +112 -0
  91. dask_array/routines/_broadcast.py +111 -0
  92. dask_array/routines/_coarsen.py +115 -0
  93. dask_array/routines/_diff.py +79 -0
  94. dask_array/routines/_gradient.py +158 -0
  95. dask_array/routines/_indexing.py +65 -0
  96. dask_array/routines/_insert_delete.py +132 -0
  97. dask_array/routines/_misc.py +122 -0
  98. dask_array/routines/_nonzero.py +72 -0
  99. dask_array/routines/_search.py +123 -0
  100. dask_array/routines/_select.py +113 -0
  101. dask_array/routines/_statistics.py +171 -0
  102. dask_array/routines/_topk.py +82 -0
  103. dask_array/routines/_triangular.py +74 -0
  104. dask_array/routines/_unique.py +232 -0
  105. dask_array/routines/_where.py +62 -0
  106. dask_array/slicing/__init__.py +67 -0
  107. dask_array/slicing/_basic.py +550 -0
  108. dask_array/slicing/_blocks.py +138 -0
  109. dask_array/slicing/_bool_index.py +145 -0
  110. dask_array/slicing/_setitem.py +329 -0
  111. dask_array/slicing/_squeeze.py +101 -0
  112. dask_array/slicing/_utils.py +1133 -0
  113. dask_array/slicing/_vindex.py +282 -0
  114. dask_array/stacking/__init__.py +15 -0
  115. dask_array/stacking/_block.py +83 -0
  116. dask_array/stacking/_simple.py +58 -0
  117. dask_array/templates/array.html.j2 +48 -0
  118. dask_array/tests/__init__.py +0 -0
  119. dask_array/tests/conftest.py +22 -0
  120. dask_array/tests/test_api.py +40 -0
  121. dask_array/tests/test_binary_op_chunks.py +107 -0
  122. dask_array/tests/test_coarse_slice_through_blockwise.py +362 -0
  123. dask_array/tests/test_collection.py +799 -0
  124. dask_array/tests/test_creation.py +1102 -0
  125. dask_array/tests/test_expr_flow.py +143 -0
  126. dask_array/tests/test_linalg.py +1130 -0
  127. dask_array/tests/test_map_blocks_multi_output.py +104 -0
  128. dask_array/tests/test_rechunk_pushdown.py +214 -0
  129. dask_array/tests/test_reductions.py +1091 -0
  130. dask_array/tests/test_routines.py +2853 -0
  131. dask_array/tests/test_shuffle_chunks.py +67 -0
  132. dask_array/tests/test_slice_pushdown.py +968 -0
  133. dask_array/tests/test_slice_through_blockwise.py +678 -0
  134. dask_array/tests/test_slice_through_overlap.py +366 -0
  135. dask_array/tests/test_slice_through_reshape.py +272 -0
  136. dask_array/tests/test_slicing.py +839 -0
  137. dask_array/tests/test_transpose_slice_pushdown.py +208 -0
  138. dask_array/tests/test_visualize.py +94 -0
  139. dask_array/tests/test_xarray.py +193 -0
  140. dask_array-0.1.0.dist-info/METADATA +48 -0
  141. dask_array-0.1.0.dist-info/RECORD +144 -0
  142. dask_array-0.1.0.dist-info/WHEEL +4 -0
  143. dask_array-0.1.0.dist-info/entry_points.txt +2 -0
  144. dask_array-0.1.0.dist-info/licenses/LICENSE +29 -0
@@ -0,0 +1,1091 @@
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ import warnings
5
+ from contextlib import nullcontext as does_not_warn
6
+ from itertools import permutations, zip_longest
7
+
8
+ import pytest
9
+
10
+ np = pytest.importorskip("numpy")
11
+
12
+ import itertools
13
+
14
+ import dask_array as da
15
+ from dask import config
16
+ from dask_array._numpy_compat import ComplexWarning
17
+ from dask_array._test_utils import assert_eq, same_keys
18
+ from dask.core import get_deps
19
+
20
+
21
@pytest.mark.parametrize("dtype", ["f4", "i4"])
@pytest.mark.parametrize("keepdims", [True, False])
@pytest.mark.parametrize("nan", [True, False])
def test_numel(dtype, keepdims, nan):
    """Compare ``numel``/``nannumel`` with a NumPy sum over an array of ones.

    When ``nan`` is true, a random subset of the entries is replaced by NaN
    and the reference becomes a masked sum, so NaN entries are not counted.
    Axes are checked as ``()``, ``0``, ``None``, every proper combination of
    axes, and every such combination again in shuffled order.
    """
    rng = np.random.default_rng()
    x = np.ones((2, 3, 4))
    if nan:
        y = rng.uniform(-1, 1, size=(2, 3, 4))
        x[y < 0] = np.nan
        numel = da.reductions.nannumel

        def _sum(arr, **kwargs):
            n = np.sum(np.ma.masked_where(np.isnan(arr), arr), **kwargs)
            return n.filled(0) if isinstance(n, np.ma.MaskedArray) else n

    else:
        numel = da.reductions.numel
        _sum = np.sum

    def check(axis):
        assert_eq(
            numel(x, axis=axis, keepdims=keepdims, dtype=dtype),
            _sum(x, axis=axis, keepdims=keepdims, dtype=dtype),
        )

    check(())
    check(0)
    # The shuffled loop below used to pass ``None`` by accident (see the
    # comment there); keep that case as an explicit check.
    check(None)

    for length in range(x.ndim):
        for sub in itertools.combinations(range(x.ndim), length):
            check(sub)

    for length in range(x.ndim):
        for sub in itertools.combinations(range(x.ndim), length):
            # Generator.shuffle permutes in place and returns None, so the
            # shuffled sequence has to be held in a list of our own rather
            # than taken from the call's return value.
            shuffled = list(sub)
            rng.shuffle(shuffled)
            check(tuple(shuffled))
62
+
63
+
64
def reduction_0d_test(da_func, darr, np_func, narr):
    """Check ``da_func`` against ``np_func`` for a single-element input.

    ``da_func`` is applied both to ``darr`` and directly to ``narr``; each
    result must equal ``np_func(narr)``, and the ``darr`` result must have
    ``size == 1``.
    """
    reference = np_func(narr)
    result = da_func(darr)

    assert_eq(result, reference)
    # Ensure Dask reductions work with NumPy arrays
    assert_eq(da_func(narr), reference)
    assert result.size == 1
71
+
72
+
73
def test_reductions_0D():
    """Run each plain and nan-aware reduction on a NumPy integer scalar."""
    # np.int_ has a dtype attribute, np.int does not.
    scalar = np.int_(3)
    lazy = da.from_array(scalar, chunks=(1,))

    pairs = [
        (da.sum, np.sum),
        (da.prod, np.prod),
        (da.mean, np.mean),
        (da.var, np.var),
        (da.std, np.std),
        (da.min, np.min),
        (da.max, np.max),
        (da.any, np.any),
        (da.all, np.all),
        (da.nansum, np.nansum),
        (da.nanprod, np.nanprod),
        # nanmean/nanvar/nanstd are compared with the plain NumPy functions.
        (da.nanmean, np.mean),
        (da.nanvar, np.var),
        (da.nanstd, np.std),
        (da.nanmin, np.nanmin),
        (da.nanmax, np.nanmax),
    ]
    for dask_reduce, numpy_reduce in pairs:
        reduction_0d_test(dask_reduce, lazy, numpy_reduce, scalar)
94
+
95
+
96
def reduction_1d_test(da_func, darr, np_func, narr, use_dtype=True, split_every=True):
    """Check ``da_func`` against ``np_func`` on one-dimensional input.

    Parameters
    ----------
    da_func, np_func : callable
        The dask reduction and its NumPy reference.
    darr, narr
        The inputs passed to ``da_func`` and ``np_func`` respectively;
        ``narr`` is also fed directly to ``da_func``.
    use_dtype : bool
        Also compare results for ``dtype="f8"`` and ``dtype="i8"``.  A
        ``ComplexWarning`` is required when ``narr`` is complex.
    split_every : bool
        Also compare results when ``split_every`` is given as an int and as
        a per-axis dict.
    """
    assert_eq(da_func(darr), np_func(narr))
    # Ensure Dask reductions work with NumPy arrays
    assert_eq(da_func(narr), np_func(narr))
    assert_eq(da_func(darr, keepdims=True), np_func(narr, keepdims=True))
    assert_eq(da_func(darr, axis=()), np_func(narr, axis=()))

    # Building the same reduction twice must give the same keys.
    for extra in ({}, {"keepdims": True}):
        assert same_keys(da_func(darr, **extra), da_func(darr, **extra))

    if use_dtype:
        if np.iscomplexobj(narr):
            expectation = pytest.warns(ComplexWarning)
        else:
            expectation = does_not_warn()
        with expectation:
            for code in ("f8", "i8"):
                assert_eq(da_func(darr, dtype=code), np_func(narr, dtype=code))
        assert same_keys(da_func(darr, dtype="i8"), da_func(darr, dtype="i8"))

    if split_every:
        by_int = da_func(darr, split_every=2)
        by_axis = da_func(darr, split_every={0: 2})
        assert same_keys(by_int, by_axis)
        assert_eq(by_int, np_func(narr))
        assert_eq(by_axis, np_func(narr))
        assert_eq(
            da_func(darr, keepdims=True, split_every=2),
            np_func(narr, keepdims=True),
        )
115
+
116
+
117
@pytest.mark.parametrize("dtype", ["f4", "i4", "c8"])
def test_reductions_1D(dtype):
    """Run the 1-d reduction checks for float, integer and complex input."""
    with warnings.catch_warnings():
        # Casting the complex ramp to a real dtype warns; that is expected.
        warnings.simplefilter("ignore", ComplexWarning)
        ramp = np.arange(5) + 1j * np.arange(5)
        x = ramp.astype(dtype)
    a = da.from_array(x, chunks=(2,))

    # (dask reduction, NumPy reference, whether to also test dtype=)
    cases = [
        (da.sum, np.sum, True),
        (da.prod, np.prod, True),
        (da.mean, np.mean, True),
        (da.var, np.var, True),
        (da.std, np.std, True),
        (da.min, np.min, False),
        (da.max, np.max, False),
        (da.any, np.any, False),
        (da.all, np.all, False),
        (da.nansum, np.nansum, True),
        (da.nanprod, np.nanprod, True),
        (da.nanmean, np.mean, True),
        (da.nanvar, np.var, True),
        (da.nanstd, np.std, True),
        (da.nanmin, np.nanmin, False),
        (da.nanmax, np.nanmax, False),
    ]
    for dask_reduce, numpy_reduce, check_dtype in cases:
        reduction_1d_test(dask_reduce, a, numpy_reduce, x, check_dtype)
141
+
142
+
143
def test_reductions_1D_datetime():
    """Run the dtype-free 1-d reduction checks on ``datetime64[ns]`` data."""
    stamps = np.arange(5).astype("datetime64[ns]")
    lazy = da.from_array(stamps, chunks=(2,))

    pairs = [
        (da.min, np.min),
        (da.max, np.max),
        (da.any, np.any),
        (da.all, np.all),
        (da.nanmin, np.nanmin),
        (da.nanmax, np.nanmax),
    ]
    for dask_reduce, numpy_reduce in pairs:
        reduction_1d_test(dask_reduce, lazy, numpy_reduce, stamps, False)
152
+
153
+
154
@pytest.mark.parametrize("dtype", ["f4", "c8"])
@pytest.mark.parametrize(
    "x",
    [
        np.array([np.inf, np.nan, -np.inf, 2]),
        np.array([np.nan, np.nan, 3, 2]),
    ],
)
def test_reductions_1D_nans(x, dtype):
    """Run the nan-aware 1-d reduction checks on data containing NaN/inf."""
    x = x.astype(dtype)
    a = da.from_array(x, chunks=(1,))

    # (dask reduction, NumPy reference, whether to also test dtype=)
    cases = [
        (da.nansum, np.nansum, True),
        (da.nanprod, np.nanprod, True),
        (da.nanmean, np.nanmean, False),
        (da.nanvar, np.nanvar, False),
        (da.nanstd, np.nanstd, False),
        (da.nanmin, np.nanmin, False),
        (da.nanmax, np.nanmax, False),
    ]
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", RuntimeWarning)
        for dask_reduce, numpy_reduce, check_dtype in cases:
            reduction_1d_test(dask_reduce, a, numpy_reduce, x, check_dtype)
168
+
169
+
170
def reduction_2d_test(da_func, darr, np_func, narr, use_dtype=True, split_every=True):
    """Check ``da_func`` against ``np_func`` on two-dimensional input.

    Parameters
    ----------
    da_func, np_func : callable
        The dask reduction and its NumPy reference.
    darr, narr
        The inputs passed to ``da_func`` and ``np_func`` respectively.
    use_dtype : bool
        Also compare results for ``dtype="f8"`` and ``dtype="i8"``.  A
        ``ComplexWarning`` is required when ``narr`` is complex.
    split_every : bool
        Also compare results for several ``split_every`` settings.
    """
    assert_eq(da_func(darr), np_func(narr))
    assert_eq(da_func(darr, keepdims=True), np_func(narr, keepdims=True))
    for ax in ((), 0, 1, -1, -2):
        assert_eq(da_func(darr, axis=ax), np_func(narr, axis=ax))
    for ax in (1, ()):
        assert_eq(
            da_func(darr, axis=ax, keepdims=True),
            np_func(narr, axis=ax, keepdims=True),
        )
    assert_eq(da_func(darr, axis=(1, 0)), np_func(narr, axis=(1, 0)))

    # Building the same reduction twice must give the same keys.
    for ax in ((), 1, (1, 0)):
        assert same_keys(da_func(darr, axis=ax), da_func(darr, axis=ax))

    if use_dtype:
        if np.iscomplexobj(narr):
            expectation = pytest.warns(ComplexWarning)
        else:
            expectation = does_not_warn()
        with expectation:
            for code in ("f8", "i8"):
                assert_eq(da_func(darr, dtype=code), np_func(narr, dtype=code))

    if split_every:
        by_int = da_func(darr, split_every=4)
        by_axis = da_func(darr, split_every={0: 2, 1: 2})
        assert same_keys(by_int, by_axis)
        assert_eq(by_int, np_func(narr))
        assert_eq(by_axis, np_func(narr))
        assert_eq(
            da_func(darr, keepdims=True, split_every=4),
            np_func(narr, keepdims=True),
        )

        # (axis, extra keyword arguments shared by both sides)
        variants = [
            ((), {}),
            (0, {}),
            ((), {"keepdims": True}),
            (0, {"keepdims": True}),
            (1, {}),
            (1, {"keepdims": True}),
        ]
        for ax, extra in variants:
            assert_eq(
                da_func(darr, axis=ax, split_every=2, **extra),
                np_func(narr, axis=ax, **extra),
            )
216
+
217
+
218
def test_reduction_errors():
    """``sum`` over an axis outside a 2-d array raises ``ValueError``."""
    arr = da.ones((5, 5), chunks=(3, 3))
    for bad_axis in (2, -3):
        with pytest.raises(ValueError):
            arr.sum(axis=bad_axis)
224
+
225
+
226
@pytest.mark.slow
@pytest.mark.parametrize("dtype", ["f4", "i4", "c8"])
def test_reductions_2D(dtype):
    """Run the 2-d reduction checks for float, integer and complex input."""
    with warnings.catch_warnings():
        # Casting the complex ramp to a real dtype warns; that is expected.
        warnings.simplefilter("ignore", ComplexWarning)
        ramp = np.arange(1, 122) + 1j * np.arange(1, 122)
        x = ramp.reshape((11, 11)).astype(dtype)
    a = da.from_array(x, chunks=(4, 4))

    total = a.sum(keepdims=True)
    assert total.__dask_keys__() == [[(total.name, 0, 0)]]

    # (dask reduction, NumPy reference, whether to also test dtype=)
    cases = [
        (da.sum, np.sum, True),
        (da.mean, np.mean, True),
        (da.var, np.var, False),  # Difference in dtype algo
        (da.std, np.std, False),  # Difference in dtype algo
        (da.min, np.min, False),
        (da.max, np.max, False),
        (da.any, np.any, False),
        (da.all, np.all, False),
        (da.nansum, np.nansum, True),
        (da.nanmean, np.mean, True),
        (da.nanvar, np.nanvar, False),  # Difference in dtype algo
        (da.nanstd, np.nanstd, False),  # Difference in dtype algo
        (da.nanmin, np.nanmin, False),
        (da.nanmax, np.nanmax, False),
    ]
    for dask_reduce, numpy_reduce, check_dtype in cases:
        reduction_2d_test(dask_reduce, a, numpy_reduce, x, check_dtype)

    # prod/nanprod overflow for data at this size, leading to warnings about
    # overflow/invalid values.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", RuntimeWarning)
        for dask_reduce, numpy_reduce in ((da.prod, np.prod), (da.nanprod, np.nanprod)):
            reduction_2d_test(dask_reduce, a, numpy_reduce, x)
259
+
260
+
261
def test_reductions_2D_datetime():
    """Run the dtype-free 2-d reduction checks on ``datetime64[ns]`` data."""
    stamps = np.arange(1, 122).reshape(11, 11).astype("datetime64[ns]")
    lazy = da.from_array(stamps, chunks=(4, 4))

    pairs = [
        (da.min, np.min),
        (da.max, np.max),
        (da.any, np.any),
        (da.all, np.all),
        (da.nanmin, np.nanmin),
        (da.nanmax, np.nanmax),
    ]
    for dask_reduce, numpy_reduce in pairs:
        reduction_2d_test(dask_reduce, lazy, numpy_reduce, stamps, False)
270
+
271
+
272
@pytest.mark.parametrize(
    ["dfunc", "func"],
    [
        (da.argmin, np.argmin),
        (da.argmax, np.argmax),
        (da.nanargmin, np.nanargmin),
        (da.nanargmax, np.nanargmax),
    ],
)
def test_arg_reductions(dfunc, func):
    """Compare arg-reductions with NumPy on 3-d, 1-d and 0-d input."""
    x = np.random.default_rng().random((10, 10, 10))
    a = da.from_array(x, chunks=(3, 4, 5))

    def check_all_axes():
        # Flattened result first, then each axis passed positionally.
        assert_eq(dfunc(a), func(x))
        for ax in (0, 1, 2):
            assert_eq(dfunc(a, ax), func(x, ax))

    check_all_axes()
    with config.set(split_every=2):
        check_all_axes()
    assert_eq(dfunc(a, keepdims=True), func(x, keepdims=True))

    # Out-of-range axis and tuple axis are both rejected.
    with pytest.raises(ValueError):
        dfunc(a, 3)
    with pytest.raises(TypeError):
        dfunc(a, (0, 1))

    x2 = np.arange(10)
    a2 = da.from_array(x2, chunks=3)
    assert_eq(dfunc(a2), func(x2))
    assert_eq(dfunc(a2, 0), func(x2, 0))
    assert_eq(dfunc(a2, 0, split_every=2), func(x2, 0))

    x3 = np.array(1)
    a3 = da.from_array(x3)
    assert_eq(dfunc(a3), func(x3))
308
+
309
+
310
@pytest.mark.parametrize(
    ["dfunc", "func"],
    [(da.nanmin, np.nanmin), (da.nanmax, np.nanmax)],
)
def test_nan_reduction_warnings(dfunc, func):
    """``nanmin``/``nanmax`` along axis 1 match NumPy when one slab is all NaN."""
    data = np.random.default_rng().random((10, 10, 10))
    data[5] = np.nan
    lazy = da.from_array(data, chunks=(3, 4, 5))

    with warnings.catch_warnings():
        # All-NaN slice encountered
        warnings.simplefilter("ignore", RuntimeWarning)
        expected = func(data, 1)
    assert_eq(dfunc(lazy, 1), expected)
319
+
320
+
321
@pytest.mark.parametrize(
    ["dfunc", "func"],
    [(da.nanargmin, np.nanargmin), (da.nanargmax, np.nanargmax)],
)
def test_nanarg_reductions(dfunc, func):
    """nan-arg reductions match NumPy, and raise on all-NaN slices."""
    x = np.random.default_rng().random((10, 10, 10))
    x[5] = np.nan
    a = da.from_array(x, chunks=(3, 4, 5))
    assert_eq(dfunc(a), func(x))
    assert_eq(dfunc(a, 0), func(x, 0))

    with warnings.catch_warnings():
        # All-NaN slice encountered
        warnings.simplefilter("ignore", RuntimeWarning)

        # x[5] is entirely NaN, so reducing over axis 1 or 2 meets an
        # all-NaN slice.
        for ax in (1, 2):
            with pytest.raises(ValueError):
                dfunc(a, ax).compute()

        x[:] = np.nan
        a = da.from_array(x, chunks=(3, 4, 5))
        with pytest.raises(ValueError):
            dfunc(a).compute()
340
+
341
+
342
@pytest.mark.parametrize(
    ["dfunc", "func"],
    [(da.min, np.min), (da.max, np.max)],
)
def test_min_max_empty_chunks(dfunc, func):
    """``min``/``max`` match NumPy when some or all chunks are empty."""
    # Boolean mask over single-element chunks.
    flat = np.arange(10)
    flat_lazy = da.from_array(flat, chunks=1)
    assert_eq(dfunc(flat_lazy[flat_lazy < 2]), func(flat[flat < 2]))

    # Explicit zero-length chunk in the middle.
    gapped = np.arange(10)
    gapped_lazy = da.from_array(gapped, chunks=((5, 0, 5),))
    assert_eq(dfunc(gapped_lazy), func(gapped))

    # Boolean mask over a 2-d array.
    grid = np.array([[1, 1, 2, 3], [1, 1, 4, 0]])
    grid_lazy = da.from_array(grid, chunks=1)
    assert_eq(dfunc(grid_lazy[grid_lazy >= 2]), func(grid[grid >= 2]))

    # Checking it mimics numpy behavior when all chunks are empty
    nothing = da.arange(10)
    with pytest.raises(ValueError):
        dfunc(nothing[nothing < 0]).compute()
359
+
360
+
361
@pytest.mark.parametrize("func", ["argmax", "nanargmax"])
def test_arg_reductions_unknown_chunksize(func):
    """Arg-reductions on a boolean-masked 1-d array raise a clear error."""
    base = da.arange(10, chunks=5)
    masked = base[base > 1]
    reducer = getattr(da, func)

    with pytest.raises(ValueError, match="unknown chunksize"):
        reducer(masked)
370
+
371
+
372
@pytest.mark.parametrize("func", ["argmax", "nanargmax"])
def test_arg_reductions_unknown_chunksize_2d(func):
    """With several chunks, only the axis with unknown chunks is rejected."""
    reducer = getattr(da, func)
    grid = da.ones((10, 10), chunks=(5, 5))
    # unknown chunks in first dimension only
    grid = grid[grid[0, :] > 0, :]

    with pytest.raises(ValueError):
        reducer(grid, axis=0)

    reducer(grid, axis=1).compute()
381
+
382
+
383
@pytest.mark.parametrize("func", ["argmax", "nanargmax"])
def test_arg_reductions_unknown_single_chunksize(func):
    """With a single chunk, arg-reductions compute along either axis."""
    reducer = getattr(da, func)
    grid = da.ones((10, 10), chunks=(10, 10))
    # unknown chunks in first dimension only
    grid = grid[grid[0, :] > 0, :]

    for ax in (0, 1):
        reducer(grid, axis=ax).compute()
390
+
391
+
392
def test_reductions_2D_nans():
    """Run the 2-d reduction checks on chunks with some, all and no NaNs."""
    # chunks are a mix of some/all/no NaNs
    x = np.full((4, 4), np.nan)
    x[:2, :2] = np.array([[1, 2], [3, 4]])
    x[2, 2] = 5
    x[3, 3] = 6
    a = da.from_array(x, chunks=(2, 2))

    plain = [
        (da.sum, np.sum),
        (da.prod, np.prod),
        (da.mean, np.mean),
        (da.var, np.var),
        (da.std, np.std),
        (da.min, np.min),
        (da.max, np.max),
        (da.any, np.any),
        (da.all, np.all),
    ]
    for dask_reduce, numpy_reduce in plain:
        reduction_2d_test(dask_reduce, a, numpy_reduce, x, False, False)

    reduction_2d_test(da.nansum, a, np.nansum, x)
    reduction_2d_test(da.nanprod, a, np.nanprod, x)

    nan_aware = [
        (da.nanmean, np.nanmean),
        (da.nanvar, np.nanvar),
        (da.nanstd, np.nanstd),
        (da.nanmin, np.nanmin),
        (da.nanmax, np.nanmax),
    ]
    arg_pairs = [
        (da.argmax, np.argmax),
        (da.argmin, np.argmin),
        (da.nanargmax, np.nanargmax),
        (da.nanargmin, np.nanargmin),
    ]
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", RuntimeWarning)
        for dask_reduce, numpy_reduce in nan_aware:
            reduction_2d_test(dask_reduce, a, numpy_reduce, x, False, False)

        # Flattened first, then along each axis.
        for axis_kw in ({}, {"axis": 0}, {"axis": 1}):
            for dask_arg, numpy_arg in arg_pairs:
                assert_eq(dask_arg(a, **axis_kw), numpy_arg(x, **axis_kw))
435
+
436
+
437
def test_moment():
    """``Array.moment`` matches a direct NumPy central-moment computation."""

    def moment(x, n, axis=None):
        # n-th central moment: mean of (x - mean)**n over ``axis``.
        centered = x - x.mean(axis=axis, keepdims=True)
        count = np.ones_like(x).sum(axis=axis)
        return (centered**n).sum(axis=axis) / count

    # Poorly conditioned
    x = np.array([1.0, 2.0, 3.0] * 10).reshape((3, 10)) + 1e8
    a = da.from_array(x, chunks=5)
    for order in (2, 3, 4):
        assert_eq(a.moment(order), moment(x, order))

    x = np.arange(1, 122).reshape((11, 11)).astype("f8")
    a = da.from_array(x, chunks=(4, 4))
    for ax in (1, (1, 0)):
        assert_eq(a.moment(4, axis=ax), moment(x, 4, axis=ax))

    # Tree reduction
    assert_eq(a.moment(order=4, split_every=4), moment(x, 4))
    for ax in (0, 1):
        assert_eq(a.moment(order=4, axis=ax, split_every=4), moment(x, 4, axis=ax))
457
+
458
+
459
def test_reductions_with_negative_axes():
    """``argmin`` and ``sum`` accept negative axis values."""
    cube = np.random.default_rng().random((4, 4, 4))
    lazy = da.from_array(cube, chunks=2)

    expected_argmin = cube.argmin(axis=-1)
    assert_eq(lazy.argmin(axis=-1), expected_argmin)
    assert_eq(lazy.argmin(axis=-1, split_every=2), expected_argmin)

    for ax in (-1, (0, -1)):
        assert_eq(lazy.sum(axis=ax), cube.sum(axis=ax))
468
+
469
+
470
def test_nan():
    """Each nan-aware reduction matches NumPy on a small array with NaNs."""
    x = np.array(
        [
            [1, np.nan, 3, 4],
            [5, 6, 7, np.nan],
            [9, 10, 11, 12],
        ]
    )
    d = da.from_array(x, chunks=(2, 2))

    # (NumPy function, dask function, keyword arguments for both)
    cases = [
        (np.nansum, da.nansum, {}),
        (np.nansum, da.nansum, {"axis": 0}),
        (np.nanmean, da.nanmean, {"axis": 1}),
        (np.nanmin, da.nanmin, {"axis": 1}),
        (np.nanmax, da.nanmax, {"axis": (0, 1)}),
        (np.nanvar, da.nanvar, {}),
        (np.nanstd, da.nanstd, {"axis": 0}),
        (np.nanargmin, da.nanargmin, {"axis": 0}),
        (np.nanargmax, da.nanargmax, {"axis": 0}),
        (np.nanprod, da.nanprod, {}),
    ]
    for numpy_reduce, dask_reduce, kwargs in cases:
        assert_eq(numpy_reduce(x, **kwargs), dask_reduce(d, **kwargs))
484
+
485
+
486
@pytest.mark.parametrize("func", ["nansum", "sum", "nanmin", "min", "nanmax", "max"])
def test_nan_object(func):
    """Reductions on an object-dtype array containing NaN match NumPy.

    The warning filter is switched between "ignore" and "default" around
    individual comparisons; the order of those switches is significant.
    """
    np_func = getattr(np, func)
    da_func = getattr(da, func)
    is_extreme = func in {"min", "max"}
    is_nan_extreme = func in {"nanmin", "nanmax"}

    with warnings.catch_warnings():
        if os.name == "nt" and is_extreme:
            # RuntimeWarning: invalid value encountered in reduce in wrapreduction
            # from NumPy.
            warnings.simplefilter("ignore", RuntimeWarning)

        rows = [[1, np.nan, 3, 4], [5, 6, 7, np.nan], [9, 10, 11, 12]]
        x = np.array(rows).astype(object)
        d = da.from_array(x, chunks=(2, 2))

        if is_nan_extreme:
            warnings.simplefilter("ignore", RuntimeWarning)

        assert_eq(np_func(x, axis=()), da_func(d, axis=()))

        if is_nan_extreme:
            warnings.simplefilter("default", RuntimeWarning)

        if is_extreme:
            warnings.simplefilter("ignore", RuntimeWarning)
        assert_eq(np_func(x, axis=0), da_func(d, axis=0))
        if os.name != "nt" and is_extreme:
            warnings.simplefilter("default", RuntimeWarning)

        assert_eq(np_func(x, axis=1), da_func(d, axis=1))
        # wrap the scalar in a numpy array since the dask version cannot know dtype
        assert_eq(np.array(np_func(x)).astype(object), da_func(d))
514
+
515
+
516
def test_0d_array():
    """Computed full reductions have the same Python type as NumPy's."""
    # The axis=() result is only computed; nothing is asserted about it.
    da.mean(da.ones(4, chunks=4), axis=()).compute()

    dask_mean = da.mean(da.ones(4, chunks=4), axis=0).compute()
    numpy_mean = np.mean(np.ones(4))
    assert type(dask_mean) is type(numpy_mean)

    dask_sum = da.sum(da.zeros(4, chunks=1)).compute()
    numpy_sum = np.sum(np.zeros(4))
    assert type(dask_sum) is type(numpy_sum)
525
+
526
+
527
def test_reduction_on_scalar():
    """``all`` over a 0-d comparison result is truthy."""
    scalar = da.from_array(np.array(1.0), chunks=())
    equal_to_itself = scalar == scalar
    assert equal_to_itself.all()
530
+
531
+
532
def test_reductions_with_empty_array():
    """``mean`` over arrays with zero-length dimensions matches NumPy."""
    shapes = [(10, 0, 5), (0, 0, 0)]
    lazies = [da.ones(shape, chunks=4) for shape in shapes]
    pairs = [(lazy, lazy.compute()) for lazy in lazies]

    for dx, x in pairs:
        with warnings.catch_warnings():
            # Mean of empty slice
            warnings.simplefilter("ignore", RuntimeWarning)
            assert_eq(dx.mean(), x.mean())
            for ax in ((), 0, 1, 2):
                assert_eq(dx.mean(axis=ax), x.mean(axis=ax))
546
+
547
+
548
def assert_max_deps(x, n, eq=True):
    """Assert on the largest dependency count of any key in ``x.dask``.

    With ``eq`` true the maximum must equal ``n``; otherwise it must not
    exceed ``n``.
    """
    dependencies, _ = get_deps(x.dask)
    widest = max(len(deps) for deps in dependencies.values())
    if not eq:
        assert widest <= n
    else:
        assert widest == n
554
+
555
+
556
def test_tree_reduce_depth():
    """``split_every`` bounds the fan-in of each task in a tree reduction."""
    # 2D
    x = da.from_array(np.arange(242).reshape((11, 22)), chunks=(3, 4))
    thresh = {0: 2, 1: 3}
    assert_max_deps(x.sum(split_every=thresh), 2 * 3)
    for ax, fan_in in [((), 1), (0, 2), (1, 3)]:
        assert_max_deps(x.sum(axis=ax, split_every=thresh), fan_in)
    assert_max_deps(x.sum(split_every=20), 20, False)
    for ax, fan_in in [((), 1), (0, 4), (1, 6)]:
        assert_max_deps(x.sum(axis=ax, split_every=20), fan_in)

    # 3D
    x = da.from_array(np.arange(11 * 22 * 29).reshape((11, 22, 29)), chunks=(3, 4, 5))
    thresh = {0: 2, 1: 3, 2: 4}
    axis_pairs = [(0, 1), (0, 2), (1, 2)]

    # Per-axis thresholds: fan-in is the product over the reduced axes.
    assert_max_deps(x.sum(split_every=thresh), 2 * 3 * 4)
    for ax, fan_in in [
        ((), 1),
        (0, 2),
        (1, 3),
        (2, 4),
        ((0, 1), 2 * 3),
        ((0, 2), 2 * 4),
        ((1, 2), 3 * 4),
    ]:
        assert_max_deps(x.sum(axis=ax, split_every=thresh), fan_in)

    # A single integer threshold.
    assert_max_deps(x.sum(split_every=20), 20, False)
    for ax, fan_in in [((), 1), (0, 4), (1, 6), (2, 6)]:
        assert_max_deps(x.sum(axis=ax, split_every=20), fan_in)
    for ax in axis_pairs:
        assert_max_deps(x.sum(axis=ax, split_every=20), 20, False)

    for ax, fan_in in zip(axis_pairs, (4 * 6, 4 * 6, 6 * 6)):
        assert_max_deps(x.sum(axis=ax, split_every=40), fan_in)
591
+
592
+
593
def test_tree_reduce_set_options():
    """``split_every`` set through ``config.set`` is used by ``sum``."""
    grid = da.from_array(np.arange(242).reshape((11, 22)), chunks=(3, 4))
    per_axis = {0: 2, 1: 3}
    with config.set(split_every=per_axis):
        assert_max_deps(grid.sum(), 2 * 3)
        for ax, fan_in in [((), 1), (0, 2)]:
            assert_max_deps(grid.sum(axis=ax), fan_in)
599
+
600
+
601
def test_reduction_names():
    """Reduction results are named after the reduction that produced them."""
    ones = da.ones(5, chunks=(2,))

    assert ones.sum().name.startswith("sum")

    max_prefix = ones.max().name.split("-")[0]
    assert "max" in max_prefix

    assert ones.var().name.startswith("var")
    assert ones.all().name.startswith("all")

    nansum_graph = da.nansum(ones).dask
    assert any(key[0].startswith("nansum") for key in nansum_graph)

    assert ones.mean().name.startswith("mean")
609
+
610
+
611
def test_general_reduction_names():
    """``da.reduction(..., name="foo")`` uses ``foo`` in its graph key names."""
    dtype = int
    source = da.ones(10, dtype, chunks=2)
    a = da.reduction(source, np.sum, np.sum, dtype=dtype, name="foo")

    # Each key name is split at its last "-" into (name, token).
    split_keys = [key[0].rsplit("-", 1) for key in a.dask]
    names, tokens = zip_longest(*split_keys)

    # array-expr uses "ones" vs traditional "ones_like" and may skip "foo-partial"
    expected_traditional = {"ones_like", "foo", "foo-partial", "foo-aggregate"}
    expected_expr = {"ones", "foo", "foo-aggregate"}
    assert set(names) in (expected_traditional, expected_expr)
    assert all(tokens)
620
+
621
+
622
def test_reduction_intermediate_chunks():
    """Test that intermediate reduction results have correct chunk sizes."""
    from dask_array._blockwise import Blockwise

    def first_blockwise(root):
        """Return the first ``Blockwise`` met in a depth-first, pre-order walk, else ``None``."""
        pending = [root]
        while pending:
            node = pending.pop()
            if isinstance(node, Blockwise):
                return node
            # Push in reverse so dependencies are popped in their original order.
            pending.extend(reversed(list(node.dependencies())))
        return None

    summed = da.ones((10, 12), chunks=(5, 4)).sum(axis=0, keepdims=True)

    # Lower the expression to get the physical Blockwise + PartialReduce tree
    lowered = summed.expr.lower_completely()

    # Walk the expression tree to find the Blockwise (chunk step)
    chunk_step = first_blockwise(lowered)
    assert chunk_step is not None

    # The intermediate should have chunks of size 1 along reduced axis
    assert chunk_step.chunks == ((1, 1), (4, 4, 4))
647
+
648
+
649
@pytest.mark.parametrize("func", [np.sum, np.argmax])
def test_array_reduction_out(func):
    """Reduce into an existing dask array via ``out=`` and compare with NumPy."""
    target = da.arange(10, chunks=(5,))
    source = da.ones((10, 10), chunks=(4, 4))
    expected = func(np.ones((10, 10)), axis=0)

    # The result of the reduction is written into ``target``.
    func(source, axis=0, out=target)

    assert_eq(target, expected)
655
+
656
+
657
@pytest.mark.parametrize("func", ["cumsum", "cumprod", "nancumsum", "nancumprod"])
@pytest.mark.parametrize("use_nan", [False, True])
@pytest.mark.parametrize("axis", [None, 0, 1, -1])
@pytest.mark.parametrize("method", ["sequential", "blelloch"])
def test_array_cumreduction_axis(func, use_nan, axis, method):
    """Compare dask cumulative reductions with NumPy for each axis and method."""
    shape = (10, 11, 12)
    values = np.arange(np.prod(shape), dtype=float).reshape(shape)
    if use_nan:
        values[1] = np.nan
    lazy = da.from_array(values, chunks=(4, 5, 6))

    dask_impl = getattr(da, func)
    numpy_impl = getattr(np, func)

    # For flattened products with the "blelloch" method only the emitted
    # RuntimeWarning is checked; the values are not compared.
    warning_only = axis is None and method == "blelloch" and func in ("cumprod", "nancumprod")
    if warning_only:
        with pytest.warns(RuntimeWarning):
            dask_impl(lazy, axis=axis, method=method).compute()
    else:
        assert_eq(numpy_impl(values, axis=axis), dask_impl(lazy, axis=axis, method=method))
679
+
680
+
681
@pytest.mark.parametrize("func", ["cumsum", "cumprod", "nancumsum", "nancumprod"])
@pytest.mark.parametrize("method", ["sequential", "blelloch"])
@pytest.mark.parametrize("target_dtype", [None, int, float])
def test_array_cumreduction_dtype(func, method, target_dtype):
    """Compare dask and NumPy cumulative reductions when an output ``dtype`` is requested."""
    values = np.linspace(0, 1, num=10, dtype=float)
    lazy = da.from_array(values)

    expected = getattr(np, func)(values, axis=0, dtype=target_dtype)
    actual = getattr(da, func)(lazy, method=method, axis=0, dtype=target_dtype)

    assert_eq(expected, actual)
695
+
696
+
697
@pytest.mark.parametrize("ufunc", ["add", "multiply", "maximum"])
@pytest.mark.parametrize("target_dtype", [None, int, float])
def test_array_cumreduction_ufunc(ufunc, target_dtype):
    """Drive ``da.reductions.cumreduction`` with a NumPy ufunc and compare with ``accumulate``."""
    binop = getattr(np, ufunc)

    values = np.linspace(0, 1, num=10, dtype=float)
    lazy = da.from_array(values)

    expected = binop.accumulate(values, dtype=target_dtype)
    # Positional arguments: accumulate function, binary operator, identity, array.
    actual = da.reductions.cumreduction(
        binop.accumulate,
        binop,
        binop.identity,
        lazy,
        dtype=target_dtype,
    )

    assert_eq(expected, actual)
713
+
714
+
715
@pytest.mark.parametrize("func", [np.cumsum, np.cumprod])
def test_array_cumreduction_out(func):
    """Accumulate a dask array into itself via ``out=`` and compare with NumPy."""
    shape = (10, 10)
    arr = da.ones(shape, chunks=(4, 4))
    expected = func(np.ones(shape), axis=0)

    # The same array is passed as both input and output.
    func(arr, axis=0, out=arr)

    assert_eq(arr, expected)
720
+
721
+
722
@pytest.mark.parametrize("npfunc,daskfunc", [(np.sort, da.topk), (np.argsort, da.argtopk)])
@pytest.mark.parametrize("split_every", [None, 2, 4, 8])
def test_topk_argtopk1(npfunc, daskfunc, split_every):
    """Compare ``topk``/``argtopk`` with slices of ``np.sort``/``np.argsort`` on 1-d and 3-d data."""
    k = 5

    def select(arr, count, **kwargs):
        """Call the dask function under test with the parametrized ``split_every``."""
        return daskfunc(arr, count, split_every=split_every, **kwargs)

    # Test at least 3 levels of aggregation when split_every=2
    # to stress the different chunk, combine, aggregate kernels
    rng = np.random.default_rng()
    flat_np = rng.random(800)
    cube_np = rng.random((10, 20, 30))

    flat = da.from_array(flat_np, chunks=((120, 80, 100, 200, 300),))
    cube = da.from_array(cube_np, chunks=(4, 8, 8))

    # 1-dimensional arrays
    # top 5 elements, sorted descending
    assert_eq(npfunc(flat_np)[-k:][::-1], select(flat, k))
    # bottom 5 elements, sorted ascending
    assert_eq(npfunc(flat_np)[:k], select(flat, -k))

    # n-dimensional arrays
    # also testing when k > chunk
    # top 5 elements, sorted descending
    assert_eq(npfunc(cube_np, axis=0)[-k:][::-1], select(cube, k, axis=0))
    assert_eq(npfunc(cube_np, axis=1)[:, -k:][:, ::-1], select(cube, k, axis=1))
    assert_eq(npfunc(cube_np, axis=-1)[..., -k:][..., ::-1], select(cube, k, axis=-1))
    with pytest.raises(ValueError):
        select(cube, k, axis=3)

    # bottom 5 elements, sorted ascending
    assert_eq(npfunc(cube_np, axis=0)[:k], select(cube, -k, axis=0))
    assert_eq(npfunc(cube_np, axis=1)[:, :k], select(cube, -k, axis=1))
    assert_eq(npfunc(cube_np, axis=-1)[..., :k], select(cube, -k, axis=-1))
    with pytest.raises(ValueError):
        select(cube, -k, axis=3)
769
+
770
+
771
@pytest.mark.parametrize("npfunc,daskfunc", [(np.sort, da.topk), (np.argsort, da.argtopk)])
@pytest.mark.parametrize("split_every", [None, 2, 3, 4])
@pytest.mark.parametrize("chunksize", [1, 2, 3, 4, 5, 10])
def test_topk_argtopk2(npfunc, daskfunc, split_every, chunksize):
    """Fine test use cases when k is larger than chunk size"""
    k = 5
    values = np.random.default_rng().random((10,))
    lazy = da.from_array(values, chunks=chunksize)

    # top 5 elements, sorted descending
    expected_top = npfunc(values)[-k:][::-1]
    assert_eq(expected_top, daskfunc(lazy, k, split_every=split_every))

    # bottom 5 elements, sorted ascending
    expected_bottom = npfunc(values)[:k]
    assert_eq(expected_bottom, daskfunc(lazy, -k, split_every=split_every))
784
+
785
+
786
def test_topk_argtopk3():
    """Check that the ``topk``/``argtopk`` array methods agree with the module-level functions."""
    arr = da.random.default_rng().random((10, 20, 30), chunks=(4, 8, 8))
    options = {"axis": 1, "split_every": 2}

    # As Array methods
    assert_eq(arr.topk(5, **options), da.topk(arr, 5, **options))
    assert_eq(arr.argtopk(5, **options), da.argtopk(arr, 5, **options))
792
+
793
+
794
@pytest.mark.parametrize(
    "func",
    [da.cumsum, da.cumprod, da.argmin, da.argmax, da.min, da.max, da.nansum, da.nanmax],
)
@pytest.mark.parametrize("method", ["sequential", "blelloch"])
def test_regres_3940(func, method):
    """Check that results for different inputs or different axes get different names."""
    cumulative = (da.cumsum, da.cumprod)
    # Only the cumulative reductions receive the ``method`` keyword.
    kwargs = {"method": method} if func in cumulative else {}

    a = da.ones((5, 2), chunks=(2, 2))

    assert func(a, **kwargs).name != func(a + 1, **kwargs).name
    assert func(a, axis=0, **kwargs).name != func(a, **kwargs).name
    assert func(a, axis=0, **kwargs).name != func(a, axis=1, **kwargs).name

    # ``axis=()`` is not exercised for the cumulative and arg reductions.
    if func in cumulative or func in (da.argmin, da.argmax):
        return
    assert func(a, axis=()).name != func(a).name
    assert func(a, axis=()).name != func(a, axis=0).name
811
+
812
+
813
def test_trace():
    """Compare ``Array.trace`` with ``ndarray.trace`` for several argument combinations."""

    def check(dask_arr, np_arr, *args, **kwargs):
        """Assert both arrays give the same trace for the given arguments."""
        assert_eq(dask_arr.trace(*args, **kwargs), np_arr.trace(*args, **kwargs))

    # No argument, then a positional offset of 0, 1 and -1.
    offset_only = [(), (0,), (1,), (-1,)]

    matrix = np.arange(12).reshape((3, 4))
    lazy_matrix = da.from_array(matrix, 1)
    for args in offset_only:
        check(lazy_matrix, matrix, *args)

    cube = np.arange(8).reshape((2, 2, 2))
    lazy_cube = da.from_array(cube, 2)
    # Positional order is (offset, axis1, axis2, dtype), as the keyword calls below spell out.
    positional = offset_only + [
        (0, 0, 1),
        (0, 0, 2),
        (0, 1, 2, int),
        (0, 1, 2, float),
    ]
    for args in positional:
        check(lazy_cube, cube, *args)
    for dtype in (int, float):
        check(lazy_cube, cube, offset=1, axis1=0, axis2=2, dtype=dtype)
836
+
837
+
838
@pytest.mark.parametrize("func", ["median", "nanmedian"])
@pytest.mark.parametrize("axis", [0, [0, 1], 1, -1])
@pytest.mark.parametrize("keepdims", [True, False])
def test_median(axis, keepdims, func):
    """Compare dask ``median``/``nanmedian`` with NumPy over single and multiple axes."""
    values = np.arange(100).reshape((2, 5, 10))
    lazy = da.from_array(values, chunks=2)

    actual = getattr(da, func)(lazy, axis=axis, keepdims=keepdims)
    expected = getattr(np, func)(values, axis=axis, keepdims=keepdims)

    assert_eq(actual, expected)
848
+
849
+
850
@pytest.mark.parametrize("func", ["median", "nanmedian"])
@pytest.mark.parametrize("axis", [0, [0, 2], 1])
def test_median_does_not_rechunk_if_whole_axis_in_one_chunk(axis, func):
    """Check that a "rechunk" key shows up in the graph only when reducing over axis 1."""
    values = np.arange(100).reshape((2, 5, 10))
    # Axes 0 and 2 each sit in a single chunk; axis 1 is split into chunks of size 1.
    lazy = da.from_array(values, chunks=(2, 1, 10))

    actual = getattr(da, func)(lazy, axis=axis)
    assert_eq(actual, getattr(np, func)(values, axis=axis))

    graph_text = str(dict(actual.__dask_graph__()))
    does_rechunk = "rechunk" in graph_text
    assert does_rechunk is (axis == 1)
864
+
865
+
866
@pytest.mark.parametrize("method", ["sum", "mean", "prod"])
def test_object_reduction(method):
    """Reduce a one-element object-dtype array and expect the value 1."""
    obj_arr = da.ones(1).astype(object)
    reducer = getattr(obj_arr, method)
    assert reducer().compute() == 1
871
+
872
+
873
@pytest.mark.parametrize("func", ["nanmin", "nanmax"])
def test_empty_chunk_nanmin_nanmax(func):
    """Check ``nanmin``/``nanmax`` on a boolean-masked array that contains an empty block.

    The reduction must raise ``ValueError`` while chunk sizes are unknown and
    must agree with NumPy once ``compute_chunk_sizes`` has been called.
    """
    # see https://github.com/dask/dask/issues/8352
    x = np.arange(10).reshape(2, 5)
    d = da.from_array(x, chunks=2)
    x = x[x > 4]
    d = d[d > 4]

    # A dedicated loop name avoids shadowing the NumPy array ``x`` above.
    block_lens = np.array([len(block.compute()) for block in d.blocks])
    assert 0 in block_lens

    # ``match`` is checked against the raised exception's message itself and is
    # always evaluated, unlike an assertion placed after the raising call
    # inside the ``with`` body.
    with pytest.raises(ValueError, match="Arrays chunk sizes are unknown"):
        getattr(da, func)(d)

    d = d.compute_chunk_sizes()
    assert_eq(getattr(da, func)(d), getattr(np, func)(x))
887
+
888
+
889
@pytest.mark.parametrize("func", ["nanmin", "nanmax"])
def test_empty_chunk_nanmin_nanmax_raise(func):
    """Check that dask raises the same ``ValueError`` message as NumPy on an empty selection."""
    # see https://github.com/dask/dask/issues/8352
    values = np.arange(10).reshape(2, 5)
    lazy = da.from_array(values, chunks=2)

    # No element of ``arange(10)`` exceeds 9, so both selections are empty.
    lazy_empty = lazy[lazy > 9].compute_chunk_sizes()
    np_empty = values[values > 9]

    with pytest.raises(ValueError) as err_np:
        getattr(np, func)(np_empty)
    with pytest.raises(ValueError) as err_da:
        # Building and computing both happen inside the context, so the
        # error may come from either step.
        getattr(da, func)(lazy_empty).compute()

    assert str(err_np.value) == str(err_da.value)
903
+
904
+
905
def test_mean_func_does_not_warn():
    """Check that the mean of an all-NaN, dask-backed ``DataArray`` records no warning."""
    # non-regression test for https://github.com/pydata/xarray/issues/5151
    xr = pytest.importorskip("xarray")
    all_nan = np.full((10, 10), np.nan)
    wrapped = xr.DataArray(da.from_array(all_nan))

    with warnings.catch_warnings(record=True) as recorded:
        wrapped.mean().compute()
    # did not warn
    assert not recorded
913
+
914
+
915
@pytest.mark.parametrize("func", ["nanvar", "nanstd"])
def test_nan_func_does_not_warn(func):
    """Check that ``nanvar``/``nanstd`` record no warning on mostly-NaN input."""
    # non-regression test for #6105
    values = np.full((10,), np.nan)
    values[:2] = (1, 2)
    lazy = da.from_array(values, chunks=2)

    with warnings.catch_warnings(record=True) as recorded:
        getattr(da, func)(lazy).compute()
    # did not warn
    assert not recorded
925
+
926
+
927
@pytest.mark.parametrize("chunks", list(permutations(((2, 1) * 8, (3,) * 8, (6,) * 4))))
@pytest.mark.parametrize("split_every", [2, 4])
@pytest.mark.parametrize("axes", list(permutations((0, 1, 2), 2)) + list(permutations((0, 1, 2))))
def test_chunk_structure_independence(axes, split_every, chunks):
    """Run an identity "reduction" and check that the data comes back unchanged."""
    # Reducing an array should not depend on its chunk-structure!!!
    # See Issue #8541: https://github.com/dask/dask/issues/8541

    def passthrough(x, axis, keepdims):
        """Return the block untouched; used for both the chunk and aggregate steps."""
        return x

    shape = tuple(np.sum(sizes) for sizes in chunks)
    np_array = np.arange(np.prod(shape)).reshape(*shape)
    x = da.from_array(np_array, chunks=chunks)

    reduction_options = dict(
        keepdims=True,
        axis=axes,
        split_every=split_every,
        dtype=x.dtype,
        meta=x._meta,
    )
    reduced_x = da.reduction(x, passthrough, passthrough, **reduction_options)

    assert_eq(reduced_x, np_array, check_chunks=False, check_shape=False)
947
+
948
+
949
def test_weighted_reduction():
    """Exercise ``da.reduction`` with an optional ``weights`` argument on a masked array."""

    # Weighted reduction
    def w_sum(x, weights=None, dtype=None, computing_meta=False, **kwargs):
        """`chunk` callable for (weighted) sum"""
        if computing_meta:
            return x
        weighted = x if weights is None else x * weights
        return np.sum(weighted, dtype=dtype, **kwargs)

    # Arrays
    masked = 1 + np.ma.arange(60).reshape(6, 10)
    masked[2, 2] = np.ma.masked
    dx = da.from_array(masked, chunks=(4, 5))
    # Weights
    column_weights = np.linspace(1, 2, 6).reshape(6, 1)

    # No weights (i.e. normal sum)
    plain = da.reduction(dx, w_sum, np.sum, dtype=dx.dtype)
    assert_eq(plain, np.sum(masked), check_shape=True)

    # Weighted sum
    weighted = da.reduction(dx, w_sum, np.sum, dtype="f8", weights=column_weights)
    assert_eq(weighted, np.sum(masked * column_weights), check_shape=True)

    # Non-broadcastable weights: short axis first, then too many dims
    for bad_weights in ([1, 2, 3], [[[2]]]):
        with pytest.raises(ValueError):
            da.reduction(dx, w_sum, np.sum, weights=bad_weights)
981
+
982
+
983
def test_cumreduction_no_rechunk_on_1d_array():
    """Check that the graph of ``cumsum`` on a 1-d array contains no "rechunk" key."""
    cumulative = da.cumsum(da.ones((5,)))
    graph_text = str(dict(cumulative.__dask_graph__()))
    assert "rechunk" not in graph_text
988
+
989
+
990
@pytest.mark.parametrize("axis", [3, 0, [1, 3]])
@pytest.mark.parametrize("q", [0.75, [0.75], [0.75, 0.4]])
@pytest.mark.parametrize("rechunk", [True, False])
def test_nanquantile(rechunk, q, axis):
    """Compare ``da.nanquantile``/``da.nanpercentile`` with NumPy on data containing NaNs.

    ``q`` is given as a fraction; the percentile calls receive the same
    value(s) scaled by 100.
    """
    shape = 7, 10, 7, 10
    arr = np.random.randn(*shape)
    indexer = np.random.randint(0, 10, size=shape)
    arr[indexer >= 8] = np.nan
    arr[:, :, :, 1] = 1
    arr[1, :, :, :] = 1

    # ``q * 100`` on a list repeats the list 100 times instead of scaling it,
    # so scale element-wise while keeping the float/list type of ``q``.
    if isinstance(q, list):
        q_percent = [fraction * 100 for fraction in q]
    else:
        q_percent = q * 100

    # With ``rechunk`` the last axis is split into chunks of 5, otherwise it is one chunk.
    darr = da.from_array(arr, chunks=(2, 3, 4, (5 if rechunk else -1)))
    assert_eq(da.nanquantile(darr, q, axis=axis), np.nanquantile(arr, q, axis=axis))
    assert_eq(
        da.nanquantile(darr, q, axis=axis, keepdims=True),
        np.nanquantile(arr, q, axis=axis, keepdims=True),
    )
    assert_eq(
        da.nanpercentile(darr, q_percent, axis=axis),
        np.nanpercentile(arr, q_percent, axis=axis),
    )
    assert_eq(
        da.nanpercentile(darr, q_percent, axis=axis, keepdims=True),
        np.nanpercentile(arr, q_percent, axis=axis, keepdims=True),
    )
1015
+
1016
+
1017
@pytest.mark.parametrize("axis", [3, [1, 3]])
@pytest.mark.parametrize("q", [0.75, [0.75]])
@pytest.mark.parametrize("rechunk", [True, False])
def test_quantile(rechunk, q, axis):
    """Compare ``da.quantile``/``da.percentile`` with NumPy, with and without ``keepdims``."""
    shape = 10, 15, 20, 15
    values = np.random.randn(*shape)
    nan_selector = np.random.randint(0, 10, size=shape)
    values[nan_selector >= 8] = np.nan

    # With ``rechunk`` the last axis is split into chunks of 5, otherwise it is one chunk.
    last_axis_chunk = 5 if rechunk else -1
    lazy = da.from_array(values, chunks=(2, 3, 4, last_axis_chunk))

    # NOTE(review): ``q`` is passed unscaled to the percentile functions as well.
    for dask_func, numpy_func in ((da.quantile, np.quantile), (da.percentile, np.percentile)):
        assert_eq(dask_func(lazy, q, axis=axis), numpy_func(values, q, axis=axis))
        assert_eq(
            dask_func(lazy, q, axis=axis, keepdims=True),
            numpy_func(values, q, axis=axis, keepdims=True),
        )
1037
+
1038
+
1039
@pytest.mark.parametrize("func", [da.quantile, da.nanquantile, da.nanpercentile])
def test_quantile_func_family_with_axis_none(func):
    """Check ``axis=None`` handling for a multi-chunk and a single-chunk input."""
    # Check that these functions raise a NotImplementedError
    # when axis=None and more than one chunk is present
    # along at least one dimension
    multi_chunk = da.ones((3, 3), chunks=(2, 2))
    expected_message = "The full algorithm is difficult to do in parallel"
    with pytest.raises(NotImplementedError, match=expected_message):
        func(multi_chunk, 0.5, axis=None)

    # Check that the functions behave as expected
    # when axis=None and the array is a single chunk
    single_chunk = da.from_array([-1, 0, 1])
    assert_eq(func(single_chunk, 0.0, axis=None), -1.0)
1052
+
1053
+
1054
def test_nanquantile_all_nan():
    """Run ``nanquantile`` on an all-NaN array and expect a ``RuntimeWarning`` to be raised."""
    shape = 10, 15, 20, 15
    all_nan = np.random.randn(*shape)
    all_nan[:] = np.nan
    lazy = da.from_array(all_nan, chunks=(2, 3, 4, -1))

    # Computing on its own is done outside the ``raises`` context.
    da.nanquantile(lazy, 0.75, axis=-1).compute()

    # NOTE(review): both comparisons are kept inside the context; if the first
    # one raises, the second is never executed -- confirm this is intended.
    with pytest.raises(RuntimeWarning):
        assert_eq(
            da.nanquantile(lazy, 0.75, axis=-1),
            np.nanquantile(all_nan, 0.75, axis=-1),
        )
        assert_eq(
            da.percentile(lazy, 0.75, axis=-1),
            np.percentile(all_nan, 0.75, axis=-1),
        )
1063
+
1064
+
1065
def test_nanquantile_method():
    """Compare ``nanquantile``/``nanpercentile`` with NumPy when ``method="weibull"`` is used."""
    shape = 10, 15, 20, 15
    values = np.random.randn(*shape)
    nan_selector = np.random.randint(0, 10, size=shape)
    values[nan_selector >= 8] = np.nan
    lazy = da.from_array(values, chunks=(2, 3, 4, -1))

    options = {"axis": -1, "method": "weibull"}
    pairs = ((da.nanquantile, np.nanquantile), (da.nanpercentile, np.nanpercentile))
    for dask_func, numpy_func in pairs:
        assert_eq(dask_func(lazy, 0.75, **options), numpy_func(values, 0.75, **options))
1079
+
1080
+
1081
def test_nanquantile_one_dim():
    """Compare ``da.nanquantile`` with NumPy on a 1-d array split into chunks of 2."""
    values = np.random.randn(10)
    lazy = da.from_array(values, chunks=(2,))

    expected = np.nanquantile(values, 0.75, axis=-1)
    actual = da.nanquantile(lazy, 0.75, axis=-1)

    assert_eq(actual, expected)
1085
+
1086
+
1087
+ def test_nanquantile_two_dims():
1088
+ arr = np.random.randn(10, 10)
1089
+ darr = da.from_array(arr, chunks=(2, -1))
1090
+ assert_eq(da.nanquantile(darr, 0.75, axis=-1), np.nanquantile(arr, 0.75, axis=-1))
1091
+ assert_eq(da.nanpercentile(darr, 0.75, axis=-1), np.nanpercentile(arr, 0.75, axis=-1))