dask-array 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (144) hide show
  1. dask_array/__init__.py +228 -0
  2. dask_array/_backends.py +76 -0
  3. dask_array/_backends_array.py +99 -0
  4. dask_array/_blockwise.py +1410 -0
  5. dask_array/_broadcast.py +272 -0
  6. dask_array/_chunk.py +445 -0
  7. dask_array/_chunk_types.py +54 -0
  8. dask_array/_collection.py +1644 -0
  9. dask_array/_concatenate.py +331 -0
  10. dask_array/_core_utils.py +1365 -0
  11. dask_array/_dispatch.py +141 -0
  12. dask_array/_einsum.py +277 -0
  13. dask_array/_expr.py +544 -0
  14. dask_array/_expr_flow.py +586 -0
  15. dask_array/_gufunc.py +805 -0
  16. dask_array/_histogram.py +617 -0
  17. dask_array/_map_blocks.py +652 -0
  18. dask_array/_new_collection.py +10 -0
  19. dask_array/_numpy_compat.py +135 -0
  20. dask_array/_overlap.py +1159 -0
  21. dask_array/_rechunk.py +1050 -0
  22. dask_array/_reshape.py +710 -0
  23. dask_array/_routines.py +102 -0
  24. dask_array/_shuffle.py +448 -0
  25. dask_array/_stack.py +264 -0
  26. dask_array/_svg.py +291 -0
  27. dask_array/_templates.py +29 -0
  28. dask_array/_test_utils.py +257 -0
  29. dask_array/_ufunc.py +385 -0
  30. dask_array/_utils.py +349 -0
  31. dask_array/_visualize.py +223 -0
  32. dask_array/_xarray.py +337 -0
  33. dask_array/core/__init__.py +34 -0
  34. dask_array/core/_blockwise_funcs.py +312 -0
  35. dask_array/core/_conversion.py +422 -0
  36. dask_array/core/_from_graph.py +97 -0
  37. dask_array/creation/__init__.py +71 -0
  38. dask_array/creation/_arange.py +121 -0
  39. dask_array/creation/_diag.py +116 -0
  40. dask_array/creation/_diagonal.py +241 -0
  41. dask_array/creation/_eye.py +103 -0
  42. dask_array/creation/_linspace.py +102 -0
  43. dask_array/creation/_mesh.py +134 -0
  44. dask_array/creation/_ones_zeros.py +454 -0
  45. dask_array/creation/_pad.py +270 -0
  46. dask_array/creation/_repeat.py +55 -0
  47. dask_array/creation/_tile.py +36 -0
  48. dask_array/creation/_tri.py +28 -0
  49. dask_array/creation/_utils.py +296 -0
  50. dask_array/fft.py +320 -0
  51. dask_array/io/__init__.py +39 -0
  52. dask_array/io/_base.py +10 -0
  53. dask_array/io/_from_array.py +257 -0
  54. dask_array/io/_from_delayed.py +95 -0
  55. dask_array/io/_from_graph.py +54 -0
  56. dask_array/io/_from_npy_stack.py +67 -0
  57. dask_array/io/_store.py +336 -0
  58. dask_array/io/_tiledb.py +159 -0
  59. dask_array/io/_to_npy_stack.py +65 -0
  60. dask_array/io/_zarr.py +449 -0
  61. dask_array/linalg/__init__.py +39 -0
  62. dask_array/linalg/_cholesky.py +234 -0
  63. dask_array/linalg/_lu.py +300 -0
  64. dask_array/linalg/_norm.py +94 -0
  65. dask_array/linalg/_qr.py +601 -0
  66. dask_array/linalg/_solve.py +349 -0
  67. dask_array/linalg/_svd.py +394 -0
  68. dask_array/linalg/_tensordot.py +334 -0
  69. dask_array/linalg/_utils.py +74 -0
  70. dask_array/manipulation/__init__.py +45 -0
  71. dask_array/manipulation/_expand.py +321 -0
  72. dask_array/manipulation/_flip.py +92 -0
  73. dask_array/manipulation/_roll.py +78 -0
  74. dask_array/manipulation/_transpose.py +309 -0
  75. dask_array/random/__init__.py +125 -0
  76. dask_array/random/_choice.py +181 -0
  77. dask_array/random/_expr.py +256 -0
  78. dask_array/random/_generator.py +441 -0
  79. dask_array/random/_random_state.py +259 -0
  80. dask_array/random/_utils.py +84 -0
  81. dask_array/reductions/__init__.py +84 -0
  82. dask_array/reductions/_arg_reduction.py +130 -0
  83. dask_array/reductions/_common.py +1082 -0
  84. dask_array/reductions/_cumulative.py +522 -0
  85. dask_array/reductions/_percentile.py +261 -0
  86. dask_array/reductions/_reduction.py +725 -0
  87. dask_array/reductions/_trace.py +56 -0
  88. dask_array/routines/__init__.py +133 -0
  89. dask_array/routines/_apply.py +84 -0
  90. dask_array/routines/_bincount.py +112 -0
  91. dask_array/routines/_broadcast.py +111 -0
  92. dask_array/routines/_coarsen.py +115 -0
  93. dask_array/routines/_diff.py +79 -0
  94. dask_array/routines/_gradient.py +158 -0
  95. dask_array/routines/_indexing.py +65 -0
  96. dask_array/routines/_insert_delete.py +132 -0
  97. dask_array/routines/_misc.py +122 -0
  98. dask_array/routines/_nonzero.py +72 -0
  99. dask_array/routines/_search.py +123 -0
  100. dask_array/routines/_select.py +113 -0
  101. dask_array/routines/_statistics.py +171 -0
  102. dask_array/routines/_topk.py +82 -0
  103. dask_array/routines/_triangular.py +74 -0
  104. dask_array/routines/_unique.py +232 -0
  105. dask_array/routines/_where.py +62 -0
  106. dask_array/slicing/__init__.py +67 -0
  107. dask_array/slicing/_basic.py +550 -0
  108. dask_array/slicing/_blocks.py +138 -0
  109. dask_array/slicing/_bool_index.py +145 -0
  110. dask_array/slicing/_setitem.py +329 -0
  111. dask_array/slicing/_squeeze.py +101 -0
  112. dask_array/slicing/_utils.py +1133 -0
  113. dask_array/slicing/_vindex.py +282 -0
  114. dask_array/stacking/__init__.py +15 -0
  115. dask_array/stacking/_block.py +83 -0
  116. dask_array/stacking/_simple.py +58 -0
  117. dask_array/templates/array.html.j2 +48 -0
  118. dask_array/tests/__init__.py +0 -0
  119. dask_array/tests/conftest.py +22 -0
  120. dask_array/tests/test_api.py +40 -0
  121. dask_array/tests/test_binary_op_chunks.py +107 -0
  122. dask_array/tests/test_coarse_slice_through_blockwise.py +362 -0
  123. dask_array/tests/test_collection.py +799 -0
  124. dask_array/tests/test_creation.py +1102 -0
  125. dask_array/tests/test_expr_flow.py +143 -0
  126. dask_array/tests/test_linalg.py +1130 -0
  127. dask_array/tests/test_map_blocks_multi_output.py +104 -0
  128. dask_array/tests/test_rechunk_pushdown.py +214 -0
  129. dask_array/tests/test_reductions.py +1091 -0
  130. dask_array/tests/test_routines.py +2853 -0
  131. dask_array/tests/test_shuffle_chunks.py +67 -0
  132. dask_array/tests/test_slice_pushdown.py +968 -0
  133. dask_array/tests/test_slice_through_blockwise.py +678 -0
  134. dask_array/tests/test_slice_through_overlap.py +366 -0
  135. dask_array/tests/test_slice_through_reshape.py +272 -0
  136. dask_array/tests/test_slicing.py +839 -0
  137. dask_array/tests/test_transpose_slice_pushdown.py +208 -0
  138. dask_array/tests/test_visualize.py +94 -0
  139. dask_array/tests/test_xarray.py +193 -0
  140. dask_array-0.1.0.dist-info/METADATA +48 -0
  141. dask_array-0.1.0.dist-info/RECORD +144 -0
  142. dask_array-0.1.0.dist-info/WHEEL +4 -0
  143. dask_array-0.1.0.dist-info/entry_points.txt +2 -0
  144. dask_array-0.1.0.dist-info/licenses/LICENSE +29 -0
@@ -0,0 +1,799 @@
1
+ from __future__ import annotations
2
+
3
+ import operator
4
+
5
+ import numpy as np
6
+ import pytest
7
+
8
+ import dask_array as da
9
+ from dask import is_dask_collection
10
+ from dask_array._test_utils import assert_eq
11
+ from dask_array._collection import Array
12
+ from dask_array._rechunk import Rechunk
13
+
14
+
15
@pytest.fixture()
def arr():
    """Provide a 10x10 random dask array created with chunks=(5, 6)."""
    shape = (10, 10)
    return da.random.random(shape, chunks=(5, 6))
18
+
19
+
20
@pytest.mark.parametrize(
    "op",
    # Forward dunders first, then their reflected ("r"-prefixed) variants.
    [
        f"__{prefix}{stem}__"
        for prefix in ("", "r")
        for stem in ("add", "sub", "mul", "truediv", "floordiv", "pow")
    ],
)
def test_arithmetic_ops(arr, op):
    """Calling a binary dunder with the scalar 2 matches the computed array."""
    operand = 2
    lazy = getattr(arr, op)(operand)
    eager = getattr(arr.compute(), op)(operand)
    assert_eq(lazy, eager)
41
+
42
+
43
def test_rechunk(arr):
    """Rechunking to (7, 3) must not change the computed values."""
    rechunked = arr.rechunk((7, 3))
    assert_eq(rechunked, arr.compute())
47
+
48
+
49
def test_blockwise():
    """Exercise ``da.blockwise`` with a literal, a repeated array, and misaligned chunks."""
    # Case 1: array combined with a literal operand.
    x = da.random.random((10, 10), chunks=(5, 5))
    plus_hundred = da.blockwise(operator.add, "ij", x, "ij", 100, None, dtype=x.dtype)
    assert_eq(plus_hundred, x.compute() + 100)

    # Case 2: the same array twice -- the optimized expression must hold no Rechunk.
    x = da.random.random((10, 10), chunks=(5, 5))
    doubled = da.blockwise(operator.add, "ij", x, "ij", x, "ij", dtype=x.dtype)
    rechunk_nodes = list(doubled.expr.optimize().find_operations(Rechunk))
    assert len(rechunk_nodes) == 0
    assert_eq(doubled, x.compute() * 2)

    # Case 3: inputs with different chunking -- at least one Rechunk is expected.
    x = da.random.random((10, 10), chunks=(5, 5))
    y = da.random.random((10, 10), chunks=(7, 3))
    aligned = da.blockwise(operator.add, "ij", x, "ij", y, "ij", dtype=x.dtype)
    rechunk_nodes = list(aligned.expr.optimize().find_operations(Rechunk))
    assert len(rechunk_nodes) > 0
    assert_eq(aligned, x.compute() + y.compute())
67
+
68
+
69
@pytest.mark.parametrize("func", ["min", "max", "sum", "prod", "mean", "any", "all"])
def test_reductions(arr, func):
    """Reduction methods along axis 0 agree with the same method on the computed array."""
    # var and std are left out: they need __array_function__
    axis = 0
    lazy = getattr(arr, func)(axis=axis)
    eager = getattr(arr.compute(), func)(axis=axis)
    assert_eq(lazy, eager)
75
+
76
+
77
@pytest.mark.parametrize(
    "func",
    ["sum", "mean", "any", "all", "max", "min", "nanmin", "nanmax", "nanmean", "nansum", "nanprod"],
)
def test_reductions_toplevel(arr, func):
    """Module-level reductions in ``da`` agree with their ``np`` namesakes along axis 0."""
    # var and std are left out: they need __array_function__
    lazy_reduce = getattr(da, func)
    numpy_reduce = getattr(np, func)
    lazy = lazy_reduce(arr, axis=0)
    eager = numpy_reduce(arr.compute(), axis=0)
    assert_eq(lazy, eager)
98
+
99
+
100
def test_from_array():
    """``from_array`` keeps the data and yields the requested 5x5 chunk grid."""
    source = np.random.random((10, 10))
    wrapped = da.from_array(source, chunks=(5, 5))
    assert_eq(wrapped, source)
    expected_chunks = ((5, 5), (5, 5))
    assert wrapped.chunks == expected_chunks
105
+
106
+
107
def test_from_graph_same_key_prefix_different_layers():
    """Two graphs sharing the key ``("x", 0)`` must stay distinct expressions.

    The two ``from_graph`` calls differ only in the stored value and in the
    final positional argument ("a" vs "b"); each must compute its own value.
    """
    from dask_array.core import from_graph

    shared_key = ("x", 0)
    first = from_graph(
        {shared_key: np.array([1])},
        np.empty((0,), dtype=int),
        ((1,),),
        [shared_key],
        "a",
    )
    second = from_graph(
        {shared_key: np.array([2])},
        np.empty((0,), dtype=int),
        ((1,),),
        [shared_key],
        "b",
    )

    assert first.expr is not second.expr
    assert_eq(first, np.array([1]))
    assert_eq(second, np.array([2]))
128
+
129
+
130
def test_from_graph_tracks_expression_dependencies():
    """A ``from_graph`` array built on another array keeps every dependency in its graph.

    After slicing and optimizing with ``fuse=True``, no task may reference a key
    that is absent from the resulting graph.
    """
    from dask._task_spec import DependenciesMapping, Task, TaskRef
    from dask_array.core import from_graph

    x = da.from_array(np.arange(6), chunks=(3,)).rechunk((2,))
    name = "plus-one"
    block_count = len(x.chunks[0])

    # One "add 1" task per block of ``x``, each pointing at the matching block key.
    layer = {}
    for i in range(block_count):
        layer[(name, i)] = Task((name, i), operator.add, TaskRef((x.name, i)), 1)
    output_keys = [(name, i) for i in range(block_count)]

    y = from_graph(
        layer,
        np.empty((0,), dtype=x.dtype),
        x.chunks,
        output_keys,
        name,
        dependencies=[x],
    )
    optimized = da.Array(y[:4].expr.optimize(fuse=True))
    graph = optimized.__dask_graph__()

    # Gather every referenced key that the graph itself does not define.
    missing = []
    for deps in DependenciesMapping(graph).values():
        missing.extend(dep for dep in deps if dep not in graph)

    assert not missing
    assert_eq(optimized, np.arange(4) + 1)
160
+
161
+
162
@pytest.mark.xfail(reason="Requires dask core to recognize 'dask_array' module in is_dask_collection")
def test_is_dask_collection_doesnt_materialize():
    """``is_dask_collection`` must answer without calling ``__dask_graph__``."""

    class ArrayTest(Array):
        # Any attempt to build the graph is turned into an error.
        def __dask_graph__(self):
            raise NotImplementedError

    base_expr = da.random.random((10, 10), chunks=(5, 5)).expr
    guarded = ArrayTest(base_expr)
    assert is_dask_collection(guarded)
    # Sanity check: the override really is the method that gets called.
    with pytest.raises(NotImplementedError):
        guarded.__dask_graph__()
172
+
173
+
174
def test_astype():
    """Casting a random integer array to float64 matches NumPy's ``astype``."""
    ints = da.random.randint(1, 100, (10, 10), chunks=(5, 5))
    lazy = ints.astype(np.float64)
    eager = ints.compute().astype(np.float64)
    assert_eq(lazy, eager)
179
+
180
+
181
def test_stack_promote_type():
    """Stacking an int32 array with a float32 array matches ``np.stack``."""
    ints = np.arange(10, dtype="i4")
    floats = np.arange(10, dtype="f4")
    lazy_ints = da.from_array(ints, chunks=5)
    lazy_floats = da.from_array(floats, chunks=5)
    stacked = da.stack([lazy_ints, lazy_floats])
    assert_eq(stacked, np.stack([ints, floats]))
188
+
189
+
190
def test_field_access():
    """Single-field and multi-field indexing on a structured array match NumPy."""
    records = np.array([(1, 1.0), (2, 2.0)], dtype=[("a", "i4"), ("b", "f4")])
    lazy = da.from_array(records, chunks=(1,))
    for selector in ("a", ["b", "a"]):
        assert_eq(lazy[selector], records[selector])
195
+
196
+
197
def test_field_access_with_shape():
    """Field indexing works when the structured dtype has sub-array fields."""
    dtype = [("col1", ("f4", (3, 2))), ("col2", ("f4", 3))]
    data = np.ones((100, 50), dtype=dtype)
    x = da.from_array(data, 10)
    # Same four selections, in the same order: one field, a one-field list,
    # the other field, then both fields together.
    selectors = ("col1", ["col1"], "col2", ["col1", "col2"])
    for selector in selectors:
        assert_eq(x[selector], data[selector])
205
+
206
+
207
+ # =============================================================================
208
+ # Optimization tests (ported from dask-expr prototype)
209
+ # =============================================================================
210
+
211
+
212
def test_transpose_optimize():
    """Transposing twice optimizes away; chained axis permutations compose."""
    data = np.random.random((10, 20))
    lazy = da.from_array(data, chunks=(2, 5))

    # .T.T must optimize to the same expression name as the untouched array.
    double_t_name = lazy.T.T.expr.optimize()._name
    plain_name = lazy.expr.optimize()._name
    assert double_t_name == plain_name
    assert_eq(lazy.T.T, data)

    # (2, 0, 1) followed by (1, 2, 0) composes to (0, 1, 2), i.e. the identity.
    cube = da.from_array(np.random.random((3, 4, 5)), chunks=(1, 2, 3))
    round_trip = cube.transpose((2, 0, 1)).transpose((1, 2, 0))
    assert_eq(round_trip, cube)
225
+
226
+
227
def test_rechunk_optimize():
    """Two rechunks in a row optimize to the same expression as the last one alone."""
    data = np.random.random((10, 10))
    lazy = da.from_array(data, chunks=(4, 4))

    chained = lazy.rechunk((2, 5)).rechunk((5, 2))
    direct = lazy.rechunk((5, 2))

    # The intermediate (2, 5) rechunk must leave no trace after optimization.
    chained_name = chained.expr.optimize()._name
    direct_name = direct.expr.optimize()._name
    assert chained_name == direct_name
    assert_eq(chained, data)
238
+
239
+
240
def test_slicing_optimize_identity():
    """A full ``[:]`` slice optimizes back to the expression it was taken from."""
    data = np.random.random((10, 20))
    lazy = da.from_array(data, chunks=(2, 5))

    # The optimized no-op slice is compared with the unoptimized source expression.
    sliced_name = lazy[:].expr.optimize()._name
    assert sliced_name == lazy.expr._name
    assert_eq(lazy[:], data)
248
+
249
+
250
def test_slicing_optimize_fusion():
    """Consecutive slices fuse into one equivalent slice."""
    data = np.random.random((10, 20))
    lazy = da.from_array(data, chunks=(2, 5))

    # lazy[5:, 4][::2] is expected to optimize like lazy[5::2, 4]
    two_step = lazy[5:, 4][::2]
    one_step = lazy[5::2, 4]
    two_step_name = two_step.expr.optimize()._name
    one_step_name = one_step.expr.optimize()._name
    assert two_step_name == one_step_name
    assert_eq(two_step, data[5::2, 4])
260
+
261
+
262
def test_slicing_pushdown_elemwise():
    """Slicing an elementwise result optimizes like slicing its input first."""
    data = np.random.random((10, 20))
    lazy = da.from_array(data, chunks=(2, 5))

    # Range slice: (lazy + 1)[:5]  ->  lazy[:5] + 1
    sliced_after = (lazy + 1)[:5]
    sliced_before = lazy[:5] + 1
    assert sliced_after.expr.optimize()._name == sliced_before.expr.optimize()._name
    assert_eq(sliced_after, (data + 1)[:5])

    # Integer index, which also drops a dimension: (lazy + 1)[5]  ->  lazy[5] + 1
    indexed_after = (lazy + 1)[5]
    indexed_before = lazy[5] + 1
    assert indexed_after.expr.optimize()._name == indexed_before.expr.optimize()._name
    assert_eq(indexed_after, (data + 1)[5])
278
+
279
+
280
def test_slicing_pushdown_elemwise_broadcast():
    """Slice pushdown through an elementwise op whose second input broadcasts."""
    a = np.random.random((10, 20))
    c = np.random.random((20,))  # 1-D operand, broadcast along axis 0 of ``a``
    aa = da.from_array(a, chunks=(2, 5))
    cc = da.from_array(c, chunks=(5,))

    # Slicing axis 0 touches only ``aa``; ``cc`` has no such axis to slice.
    rows_after = (aa + cc)[:5]
    rows_before = aa[:5] + cc
    assert rows_after.expr.simplify()._name == rows_before.expr.simplify()._name
    assert_eq(rows_after, (a + c)[:5])

    # Slicing axis 1 must reach both operands: aa[:, ::2] + cc[::2]
    cols_after = (aa + cc)[:, ::2]
    cols_before = aa[:, ::2] + cc[::2]
    assert cols_after.expr.simplify()._name == cols_before.expr.simplify()._name
    assert_eq(cols_after, (a + c)[:, ::2])
299
+
300
+
301
def test_slicing_pushdown_transpose():
    """Slicing a transposed array optimizes like transposing the sliced input."""
    data = np.random.random((10, 20))
    lazy = da.from_array(data, chunks=(2, 5))

    # lazy.T[5:] is expected to be rewritten as lazy[:, 5:].T
    slice_of_transpose = lazy.T[5:]
    transpose_of_slice = lazy[:, 5:].T
    lhs = slice_of_transpose.expr.optimize()._name
    rhs = transpose_of_slice.expr.optimize()._name
    assert lhs == rhs
    assert_eq(slice_of_transpose, data.T[5:])
311
+
312
+
313
def test_rechunk_pushdown_transpose():
    """Rechunking a transpose leaves a Transpose as the outermost optimized node."""
    data = np.random.random((10, 20))
    lazy = da.from_array(data, chunks=(2, 5))

    # Wanted shape after optimization: Transpose(Rechunk(...)), never
    # Rechunk(Transpose(...)).
    rechunked = lazy.T.rechunk((10, 5))
    top_node = rechunked.expr.optimize()
    assert type(top_node).__name__ == "Transpose"
    assert_eq(rechunked, data.T)
325
+
326
+
327
def test_rechunk_pushdown_elemwise():
    """Rechunking an elementwise result leaves an Elemwise as the outermost node."""
    data = np.random.random((10, 10))
    lazy = da.from_array(data, chunks=(4, 4))

    # Wanted shape after optimization: Elemwise on top, not Rechunk(Elemwise(...)).
    rechunked = (lazy + 1).rechunk((5, 5))
    top_node = rechunked.expr.optimize()
    assert type(top_node).__name__ == "Elemwise"
    assert_eq(rechunked, data + 1)
339
+
340
+
341
def test_rechunk_pushdown_elemwise_broadcast():
    """Rechunk pushdown through an elementwise op with a broadcast 1-D input."""
    a = np.random.random((10,))
    aa = da.from_array(a)
    b = np.random.random((10, 10))
    bb = da.from_array(b)

    rechunked_sum = (aa + bb).rechunk((5, 2))
    # Equivalent form with the rechunk applied to each input separately.
    pushed_down = aa.rechunk((2,)) + bb.rechunk((5, 2))
    assert rechunked_sum.expr.simplify()._name == pushed_down.expr.simplify()._name

    # After full optimization the outermost node must be the Elemwise.
    top_node = rechunked_sum.expr.optimize()
    assert type(top_node).__name__ == "Elemwise"
    assert_eq(rechunked_sum, a + b)
358
+
359
+
360
+ # =============================================================================
361
+ # Optimization correctness and safety tests
362
+ # =============================================================================
363
+
364
+
365
def test_optimization_correctness_various_chains():
    """Several chained operations still compute the values NumPy produces."""
    # NOTE(review): this seeds NumPy's global RNG; whether da.random draws from
    # it is not visible from this file -- confirm.
    np.random.seed(42)
    a = da.random.random((15, 25), chunks=(3, 7))
    a_np = a.compute()

    # (lazy chain, expected NumPy value) pairs, checked in order.
    cases = [
        (a.T.T, a_np),
        (a.T[5:].T, a_np[:, 5:]),
        ((a + 1).rechunk((5, 5))[:10], (a_np + 1)[:10]),
        (a.rechunk((5, 5)).rechunk((3, 3)), a_np),
        (a[::2, 1:][::2], a_np[::2, 1:][::2]),
        ((a * 2)[:, 10:][5:], (a_np * 2)[:, 10:][5:]),
    ]
    for lazy, eager in cases:
        assert_eq(lazy, eager)
378
+
379
+
380
def test_optimize_empty_array():
    """Elementwise-then-slice works on an array whose first dimension is empty."""
    empty = da.zeros((0, 10), chunks=(1, 5))
    sliced = (empty + 1)[:, :5]
    expected_shape = (0, 5)
    assert sliced.shape == expected_shape
    assert_eq(sliced, np.zeros(expected_shape))
386
+
387
+
388
def test_optimize_3d_transpose():
    """Two 3-D transposes that cancel out optimize to an identity."""
    # NOTE(review): this seeds NumPy's global RNG; whether da.random draws from
    # it is not visible from this file -- confirm.
    np.random.seed(42)
    a = da.random.random((4, 5, 6), chunks=2)

    # (2, 0, 1) followed by (1, 2, 0) composes to the identity permutation.
    round_trip = a.transpose((2, 0, 1)).transpose((1, 2, 0))
    top_node = round_trip.expr.optimize()
    # A Transpose may remain on top only if it is the identity permutation.
    if type(top_node).__name__ == "Transpose":
        assert top_node.axes == tuple(range(3))
    assert_eq(round_trip, a)
399
+
400
+
401
def test_optimize_scalar_in_elemwise():
    """Scalar operands in an elementwise op survive slice and rechunk pushdown."""
    # NOTE(review): this seeds NumPy's global RNG; whether da.random draws from
    # it is not visible from this file -- confirm.
    np.random.seed(42)
    b = da.random.random((10, 10), chunks=5)
    b_np = b.compute()

    # Scalar on the left-hand side, followed by a slice.
    sliced = (5 + b)[:5]
    assert_eq(sliced, (5 + b_np)[:5])

    # Scalar multiply, followed by a rechunk.
    rechunked = (b * 2).rechunk((5, 5))
    assert_eq(rechunked, b_np * 2)
414
+
415
+
416
def test_chunks_preserved_after_optimization():
    """The reported ``chunks`` are right for three different operation chains."""
    a = da.random.random((20, 20), chunks=(4, 5))

    # transpose -> rechunk
    transposed = a.T.rechunk((10, 10))
    assert transposed.chunks == ((10, 10), (10, 10))

    # elementwise -> slice
    sliced = (a + 1)[:10, :15]
    assert sliced.chunks == ((4, 4, 2), (5, 5, 5))

    # slice -> rechunk
    rechunked = a[:12, :8].rechunk((6, 4))
    assert rechunked.chunks == ((6, 6), (4, 4))
431
+
432
+
433
def test_pushdown_broadcast_both_arrays():
    """Pushdown when a (10, 1) and a (1, 20) array broadcast to (10, 20)."""
    a = da.from_array(np.random.random((10, 1)), chunks=(5, 1))
    b = da.from_array(np.random.random((1, 20)), chunks=(1, 10))
    a_np, b_np = a.compute(), b.compute()

    # Slice pushdown: each input is sliced only along its non-broadcast axis.
    sliced = (a + b)[:5, :10]
    sliced_top = sliced.expr.optimize()
    assert type(sliced_top).__name__ == "Elemwise"
    sliced_inputs = sliced_top.elemwise_args
    assert sliced_inputs[0].shape == (5, 1)
    assert sliced_inputs[1].shape == (1, 10)
    assert_eq(sliced, (a_np + b_np)[:5, :10])

    # Rechunk pushdown: each input is rechunked only along its non-broadcast axis.
    rechunked = (a + b).rechunk((2, 5))
    rechunked_top = rechunked.expr.optimize()
    assert type(rechunked_top).__name__ == "Elemwise"
    rechunked_inputs = rechunked_top.elemwise_args
    assert rechunked_inputs[0].chunks == ((2, 2, 2, 2, 2), (1,))
    assert rechunked_inputs[1].chunks == ((1,), (5, 5, 5, 5))
    assert_eq(rechunked, a_np + b_np)
457
+
458
+
459
def test_rechunk_pushdown_to_io():
    """A rechunk on top of ``from_array`` optimizes into a plain FromArray."""
    from dask_array.io import FromArray

    data = np.random.random((10, 10))
    lazy = da.from_array(data, chunks=(4, 4))

    optimized = lazy.rechunk((5, 2)).expr.optimize()
    # Reference: the same data loaded directly with the target chunking.
    reference = da.from_array(data, chunks=((5, 5), (2, 2, 2, 2, 2))).expr

    assert type(optimized) is FromArray
    assert optimized._name == reference._name
472
+
473
+
474
def test_rechunk_chain_optimize():
    """Two chained rechunks on ``from_array`` optimize into a plain FromArray."""
    from dask_array.io import FromArray

    data = np.random.random((10, 10))
    lazy = da.from_array(data, chunks=(4, 4))

    optimized = lazy.rechunk((2, 5)).rechunk((5, 2)).expr.optimize()
    # Reference: the same data loaded directly with the final chunking.
    reference = da.from_array(data, chunks=((5, 5), (2, 2, 2, 2, 2))).expr

    # No Rechunk node is left; only the FromArray remains.
    assert type(optimized) is FromArray
    assert optimized._name == reference._name
487
+
488
+
489
def test_rechunk_transpose_pushdown_to_io():
    """A rechunk after a transpose travels through it and into the FromArray."""
    from dask_array.io import FromArray
    from dask_array.manipulation._transpose import Transpose

    data = np.random.random((10, 10))
    lazy = da.from_array(data, chunks=(4, 4))

    optimized = lazy.T.rechunk((5, 2)).expr.optimize()
    # Reference: load with the transposed target chunking (2, 5), then transpose.
    reference = da.from_array(data, chunks=((2, 2, 2, 2, 2), (5, 5))).T.expr

    assert type(optimized) is Transpose
    assert type(optimized.array) is FromArray
    assert optimized._name == reference._name
504
+
505
+
506
def test_rechunk_elemwise_pushdown_to_io():
    """A rechunk after an elementwise op ends up inside the FromArray input."""
    from dask_array._blockwise import Elemwise
    from dask_array.io import FromArray

    data = np.random.random((10, 10))
    lazy = da.from_array(data, chunks=(4, 4))

    optimized = (lazy + 1).rechunk((5, 5)).expr.optimize()

    # The Elemwise stays on top and its first input is the rechunked FromArray.
    assert type(optimized) is Elemwise
    first_input = optimized.elemwise_args[0]
    assert type(first_input) is FromArray
    assert first_input.chunks == ((5, 5), (5, 5))
    # The input name must still start with the "array-" prefix.
    assert first_input.name.startswith("array-")
522
+
523
+
524
def test_rechunk_pushdown_concatenate_other_axis():
    """Rechunking the non-concatenated axis is pushed into each concatenate input."""
    top = da.ones((10, 20), chunks=(5, 10))
    bottom = da.ones((10, 20), chunks=(5, 10))
    joined = da.concatenate([top, bottom], axis=0)  # overall shape (20, 20)

    # Only axis 1 is rechunked; concatenation happened along axis 0.
    rechunked = joined.rechunk({1: 5})

    # Equivalent form with the rechunk already applied to both inputs.
    pushed_down = da.concatenate(
        [top.rechunk({1: 5}), bottom.rechunk({1: 5})],
        axis=0,
    )

    assert rechunked.expr.simplify()._name == pushed_down.expr.simplify()._name
    assert_eq(rechunked, pushed_down)
539
+
540
+
541
def test_rechunk_pushdown_concatenate_correctness():
    """Rechunk-through-concatenate yields the right values on non-trivial data."""
    a = np.arange(20).reshape(4, 5)
    b = np.arange(20, 40).reshape(4, 5)
    da_a = da.from_array(a, chunks=(2, 3))
    da_b = da.from_array(b, chunks=(2, 3))

    joined = da.concatenate([da_a, da_b], axis=0)  # overall shape (8, 5)

    # Rechunk axis 1, which is not the concatenation axis.
    rechunked = joined.rechunk({1: 2})
    pushed_down = da.concatenate(
        [da_a.rechunk({1: 2}), da_b.rechunk({1: 2})],
        axis=0,
    )

    assert rechunked.expr.simplify()._name == pushed_down.expr.simplify()._name
    assert_eq(rechunked, np.concatenate([a, b], axis=0))
557
+
558
+
559
+ # --- Fusion regression tests ---
560
+
561
+
562
def test_fusion_broadcast_modulo():
    """Fusion must compute correctly when one operand broadcasts.

    If a fused operand has fewer blocks than the output, its block indices
    have to wrap around (modulo). Regression test for matmul-like patterns.
    """
    # A 1-D operand broadcast against a 2-D one, as in the matmul pattern.
    a = da.from_array(np.arange(6).reshape(2, 3), chunks=(1, 3))
    b = da.from_array(np.arange(3), chunks=3)
    expected = np.arange(6).reshape(2, 3) * np.arange(3)

    # ``b`` is a single block while ``a`` has two blocks along axis 0.
    product = a * b
    assert_eq(product, expected)

    # The explicitly fused expression must give the same answer.
    fused = product.expr.optimize(fuse=True)
    assert_eq(da.Array(fused), expected)
580
+
581
+
582
def test_fusion_same_array_different_indices():
    """One array used under two different index mappings must not be mis-fused.

    ``da.dot(x, x)`` reads ``x`` once as 'ij' and once as 'jk'; fusion has to
    notice the conflict and leave the conflicting expression out.
    """
    x = da.from_array(np.arange(9).reshape(3, 3), chunks=(2, 2))
    x_np = x.compute()

    product = da.dot(x, x)
    expected = np.dot(x_np, x_np)
    assert_eq(product, expected)

    # Persisting goes through the conflict path during fusion.
    persisted = product.persist()
    assert_eq(persisted, expected)
600
+
601
+
602
def test_fusion_elemwise_with_out_and_where_true():
    """An ``out=`` array must not break fusion when ``where`` is left at True.

    With ``where=True`` (the default) the ``out`` array only receives the
    result; it does not take part in the computation, so it should not count
    as a dependency.
    """
    lhs = da.from_array(np.arange(4), chunks=2)
    rhs = da.from_array(np.arange(4, 8), chunks=2)
    out = da.zeros(4, chunks=2)

    # ``where`` is not passed, so ``out`` is only the destination.
    summed = da.add(lhs, rhs, out=out)
    assert summed is out

    # Values must still be right once fusion has run.
    assert_eq(summed, np.arange(4) + np.arange(4, 8))
619
+
620
+
621
def test_fusion_elemwise_with_out_and_where_array():
    """With a mask passed as ``where``, the ``out`` array is a real input.

    Positions where the mask is False keep the value already in ``out``, so
    ``out`` has to take part in the computation as a dependency.
    """
    a = da.from_array(np.arange(4), chunks=2)
    b = da.from_array(np.arange(4, 8), chunks=2)
    where = da.from_array(np.array([True, False, True, False]), chunks=2)
    out = da.zeros(4, dtype=int, chunks=2)

    masked_sum = da.add(a, b, where=where, out=out)
    assert masked_sum is out

    # Build the NumPy reference: sums only where the mask is True, zeros elsewhere.
    reference = np.zeros(4, dtype=int)
    mask_np = np.array([True, False, True, False])
    np.add(np.arange(4), np.arange(4, 8), where=mask_np, out=reference)
    assert_eq(masked_sum, reference)
644
+
645
+
646
def test_fusion_out_same_as_input():
    """``out=x`` works when ``x`` is also the input (e.g. ``np.sin(x, out=x)``).

    Here ``out`` is genuinely read by the operation, so it must NOT be dropped
    from the dependencies.
    """
    x = da.from_array(np.array([0.0, 0.5, 1.0, 1.5]), chunks=2)
    # Snapshot the input values before the in-place call rewrites ``x``.
    x_np = x.compute().copy()

    # In-place style call: the destination is the input itself.
    in_place = np.sin(x, out=x)
    assert in_place is x

    reference = np.sin(x_np, out=x_np)
    assert_eq(in_place, reference)
661
+
662
+
663
def test_fusion_transpose_conflict():
    """``a + a.T`` must fuse safely despite the conflicting access patterns.

    The same array is read both directly and transposed, so one output block
    needs different source blocks of the same expression; fusion has to
    detect that conflict.
    """
    a = da.from_array(np.arange(9).reshape(3, 3), chunks=(2, 2))
    a_np = a.compute()

    # ``a`` appears under two different index mappings.
    symmetric_sum = a + a.T
    expected = a_np + a_np.T
    assert_eq(symmetric_sum, expected)

    # The explicitly fused expression must give the same answer.
    fused = symmetric_sum.expr.optimize(fuse=True)
    assert_eq(da.Array(fused), expected)
681
+
682
+
683
def test_fusion_chained_transpose():
    """``(a + b).T`` fuses cleanly: the transpose is a consistent permutation.

    Unlike ``a + a.T`` there is no conflicting access here, so the add and
    the transpose should fuse without trouble.
    """
    a = da.from_array(np.arange(6).reshape(2, 3), chunks=(1, 2))
    b = da.from_array(np.arange(6, 12).reshape(2, 3), chunks=(1, 2))
    a_np, b_np = a.compute(), b.compute()

    transposed_sum = (a + b).T
    expected = (a_np + b_np).T
    assert_eq(transposed_sum, expected)

    # Fuse the add with the transpose and check the values again.
    fused = transposed_sum.expr.optimize(fuse=True)
    assert_eq(da.Array(fused), expected)
700
+
701
+
702
def test_reduction_scalar_aggregate_meta():
    """Regression: ``_meta`` works when the aggregate returns a Python scalar.

    A custom aggregate that returns a plain Python scalar, instead of keeping
    the array dimensions, used to break the meta computation with::

        ValueError: cannot reshape array of size 1 into shape (0,0)
    """
    arr = da.ones((10, 5, 5), chunks=(5, 5, 5))

    # Aggregate that ignores its input and returns a Python int, not an ndarray.
    def scalar_agg(x, axis=None, keepdims=False):
        return 42

    reduced = da.reduction(
        arr,
        chunk=np.sum,
        aggregate=scalar_agg,
        axis=0,
        dtype=float,
    )
    # Reading ``_meta`` must not raise, and must reflect the reduced rank and dtype.
    meta = reduced._meta
    assert meta.shape == (0, 0)
    assert meta.dtype == np.float64
726
+
727
+
728
def test_fusion_blockwise_contracted_dimensions():
    """Fusion handles a Blockwise whose input has contracted dimensions.

    Contracted dimensions are indices present on the input but absent from
    the output; block lookups during fusion must cope with them.

    Regression test from the xarray integration: groupby operations build a
    Blockwise with out_ind=(2,) (1-D output) from a 3-D input with
    ind=(0, 1, 2). When that is fused with an Elemwise (out_ind=(0,)), the
    idx_to_block mapping must handle contracted dimensions 0 and 1.

    Previously failed with KeyError: 0 in FusedBlockwise._task().
    """
    from dask_array._blockwise import FusedBlockwise

    # 3-D input with a single block along each dimension that gets contracted.
    arr_3d = da.from_array(np.ones((1, 1, 3)), chunks=(1, 1, 1))

    # Average away dims 0 and 1; the output is indexed by the input's dim 2.
    reduced = da.blockwise(
        lambda x: x.mean(axis=(0, 1)),
        (2,),  # out_ind: output dimension taken from input dim 2
        arr_3d.expr,
        (0, 1, 2),  # ind: the input carries all three dimensions
        dtype=arr_3d.dtype,
    )

    # Single-block contracted dims keep the Blockwise fusable.
    assert reduced.expr._is_blockwise_fusable

    # Elementwise comparison on top (out_ind=(0,)).
    ones = np.array([1.0, 1.0, 1.0])
    close = da.isclose(reduced, ones)

    # Elemwise (out_ind=(0,)) and Blockwise (out_ind=(2,)) must fuse together.
    fused = close.expr.optimize(fuse=True)
    assert isinstance(fused, FusedBlockwise)

    # The values must be right as well.
    assert_eq(close, np.array([True, True, True]))
769
+
770
+
771
def test_fusion_blockwise_multiblock_contracted_prevents_fusion():
    """A Blockwise with a multi-block contracted dimension must not be fused.

    If a contracted dimension (on the input, not on the output) spans several
    blocks, every output block would need several input blocks along that
    dimension, so fusion is impossible.
    """
    from dask_array._blockwise import FusedBlockwise

    # 3-D input with two blocks along contracted dimension 0.
    arr_3d = da.from_array(np.ones((2, 1, 3)), chunks=(1, 1, 1))

    reduced = da.blockwise(
        lambda x: x.sum(),
        (2,),  # output indexed by input dim 2
        arr_3d.expr,
        (0, 1, 2),
        dtype=arr_3d.dtype,
    )

    # The multi-block contracted dimension rules fusion out.
    assert not reduced.expr._is_blockwise_fusable

    # Wrap the Blockwise in an Elemwise.
    close = da.isclose(reduced, np.array([1.0, 1.0, 1.0]))

    # Optimizing with fuse=True must therefore not produce a FusedBlockwise.
    optimized = close.expr.optimize(fuse=True)
    assert not isinstance(optimized, FusedBlockwise)