dask-array 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (144)
  1. dask_array/__init__.py +228 -0
  2. dask_array/_backends.py +76 -0
  3. dask_array/_backends_array.py +99 -0
  4. dask_array/_blockwise.py +1410 -0
  5. dask_array/_broadcast.py +272 -0
  6. dask_array/_chunk.py +445 -0
  7. dask_array/_chunk_types.py +54 -0
  8. dask_array/_collection.py +1644 -0
  9. dask_array/_concatenate.py +331 -0
  10. dask_array/_core_utils.py +1365 -0
  11. dask_array/_dispatch.py +141 -0
  12. dask_array/_einsum.py +277 -0
  13. dask_array/_expr.py +544 -0
  14. dask_array/_expr_flow.py +586 -0
  15. dask_array/_gufunc.py +805 -0
  16. dask_array/_histogram.py +617 -0
  17. dask_array/_map_blocks.py +652 -0
  18. dask_array/_new_collection.py +10 -0
  19. dask_array/_numpy_compat.py +135 -0
  20. dask_array/_overlap.py +1159 -0
  21. dask_array/_rechunk.py +1050 -0
  22. dask_array/_reshape.py +710 -0
  23. dask_array/_routines.py +102 -0
  24. dask_array/_shuffle.py +448 -0
  25. dask_array/_stack.py +264 -0
  26. dask_array/_svg.py +291 -0
  27. dask_array/_templates.py +29 -0
  28. dask_array/_test_utils.py +257 -0
  29. dask_array/_ufunc.py +385 -0
  30. dask_array/_utils.py +349 -0
  31. dask_array/_visualize.py +223 -0
  32. dask_array/_xarray.py +337 -0
  33. dask_array/core/__init__.py +34 -0
  34. dask_array/core/_blockwise_funcs.py +312 -0
  35. dask_array/core/_conversion.py +422 -0
  36. dask_array/core/_from_graph.py +97 -0
  37. dask_array/creation/__init__.py +71 -0
  38. dask_array/creation/_arange.py +121 -0
  39. dask_array/creation/_diag.py +116 -0
  40. dask_array/creation/_diagonal.py +241 -0
  41. dask_array/creation/_eye.py +103 -0
  42. dask_array/creation/_linspace.py +102 -0
  43. dask_array/creation/_mesh.py +134 -0
  44. dask_array/creation/_ones_zeros.py +454 -0
  45. dask_array/creation/_pad.py +270 -0
  46. dask_array/creation/_repeat.py +55 -0
  47. dask_array/creation/_tile.py +36 -0
  48. dask_array/creation/_tri.py +28 -0
  49. dask_array/creation/_utils.py +296 -0
  50. dask_array/fft.py +320 -0
  51. dask_array/io/__init__.py +39 -0
  52. dask_array/io/_base.py +10 -0
  53. dask_array/io/_from_array.py +257 -0
  54. dask_array/io/_from_delayed.py +95 -0
  55. dask_array/io/_from_graph.py +54 -0
  56. dask_array/io/_from_npy_stack.py +67 -0
  57. dask_array/io/_store.py +336 -0
  58. dask_array/io/_tiledb.py +159 -0
  59. dask_array/io/_to_npy_stack.py +65 -0
  60. dask_array/io/_zarr.py +449 -0
  61. dask_array/linalg/__init__.py +39 -0
  62. dask_array/linalg/_cholesky.py +234 -0
  63. dask_array/linalg/_lu.py +300 -0
  64. dask_array/linalg/_norm.py +94 -0
  65. dask_array/linalg/_qr.py +601 -0
  66. dask_array/linalg/_solve.py +349 -0
  67. dask_array/linalg/_svd.py +394 -0
  68. dask_array/linalg/_tensordot.py +334 -0
  69. dask_array/linalg/_utils.py +74 -0
  70. dask_array/manipulation/__init__.py +45 -0
  71. dask_array/manipulation/_expand.py +321 -0
  72. dask_array/manipulation/_flip.py +92 -0
  73. dask_array/manipulation/_roll.py +78 -0
  74. dask_array/manipulation/_transpose.py +309 -0
  75. dask_array/random/__init__.py +125 -0
  76. dask_array/random/_choice.py +181 -0
  77. dask_array/random/_expr.py +256 -0
  78. dask_array/random/_generator.py +441 -0
  79. dask_array/random/_random_state.py +259 -0
  80. dask_array/random/_utils.py +84 -0
  81. dask_array/reductions/__init__.py +84 -0
  82. dask_array/reductions/_arg_reduction.py +130 -0
  83. dask_array/reductions/_common.py +1082 -0
  84. dask_array/reductions/_cumulative.py +522 -0
  85. dask_array/reductions/_percentile.py +261 -0
  86. dask_array/reductions/_reduction.py +725 -0
  87. dask_array/reductions/_trace.py +56 -0
  88. dask_array/routines/__init__.py +133 -0
  89. dask_array/routines/_apply.py +84 -0
  90. dask_array/routines/_bincount.py +112 -0
  91. dask_array/routines/_broadcast.py +111 -0
  92. dask_array/routines/_coarsen.py +115 -0
  93. dask_array/routines/_diff.py +79 -0
  94. dask_array/routines/_gradient.py +158 -0
  95. dask_array/routines/_indexing.py +65 -0
  96. dask_array/routines/_insert_delete.py +132 -0
  97. dask_array/routines/_misc.py +122 -0
  98. dask_array/routines/_nonzero.py +72 -0
  99. dask_array/routines/_search.py +123 -0
  100. dask_array/routines/_select.py +113 -0
  101. dask_array/routines/_statistics.py +171 -0
  102. dask_array/routines/_topk.py +82 -0
  103. dask_array/routines/_triangular.py +74 -0
  104. dask_array/routines/_unique.py +232 -0
  105. dask_array/routines/_where.py +62 -0
  106. dask_array/slicing/__init__.py +67 -0
  107. dask_array/slicing/_basic.py +550 -0
  108. dask_array/slicing/_blocks.py +138 -0
  109. dask_array/slicing/_bool_index.py +145 -0
  110. dask_array/slicing/_setitem.py +329 -0
  111. dask_array/slicing/_squeeze.py +101 -0
  112. dask_array/slicing/_utils.py +1133 -0
  113. dask_array/slicing/_vindex.py +282 -0
  114. dask_array/stacking/__init__.py +15 -0
  115. dask_array/stacking/_block.py +83 -0
  116. dask_array/stacking/_simple.py +58 -0
  117. dask_array/templates/array.html.j2 +48 -0
  118. dask_array/tests/__init__.py +0 -0
  119. dask_array/tests/conftest.py +22 -0
  120. dask_array/tests/test_api.py +40 -0
  121. dask_array/tests/test_binary_op_chunks.py +107 -0
  122. dask_array/tests/test_coarse_slice_through_blockwise.py +362 -0
  123. dask_array/tests/test_collection.py +799 -0
  124. dask_array/tests/test_creation.py +1102 -0
  125. dask_array/tests/test_expr_flow.py +143 -0
  126. dask_array/tests/test_linalg.py +1130 -0
  127. dask_array/tests/test_map_blocks_multi_output.py +104 -0
  128. dask_array/tests/test_rechunk_pushdown.py +214 -0
  129. dask_array/tests/test_reductions.py +1091 -0
  130. dask_array/tests/test_routines.py +2853 -0
  131. dask_array/tests/test_shuffle_chunks.py +67 -0
  132. dask_array/tests/test_slice_pushdown.py +968 -0
  133. dask_array/tests/test_slice_through_blockwise.py +678 -0
  134. dask_array/tests/test_slice_through_overlap.py +366 -0
  135. dask_array/tests/test_slice_through_reshape.py +272 -0
  136. dask_array/tests/test_slicing.py +839 -0
  137. dask_array/tests/test_transpose_slice_pushdown.py +208 -0
  138. dask_array/tests/test_visualize.py +94 -0
  139. dask_array/tests/test_xarray.py +193 -0
  140. dask_array-0.1.0.dist-info/METADATA +48 -0
  141. dask_array-0.1.0.dist-info/RECORD +144 -0
  142. dask_array-0.1.0.dist-info/WHEEL +4 -0
  143. dask_array-0.1.0.dist-info/entry_points.txt +2 -0
  144. dask_array-0.1.0.dist-info/licenses/LICENSE +29 -0
@@ -0,0 +1,366 @@
1
+ """Tests for slice pushdown through MapOverlap.
2
+
3
+ These tests verify that slicing operations can be pushed through map_overlap
4
+ operations, reducing computation by slicing input arrays before applying
5
+ overlap boundaries.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import numpy as np
11
+ import pytest
12
+
13
+ import dask_array as da
14
+ from dask_array._test_utils import assert_eq
15
+
16
+
17
def add_neighbors(x):
    """Return a copy of ``x`` with each interior slice along axis 0 replaced by
    the sum of itself and its two neighbours along that axis.

    The first and last slices along axis 0 are left unchanged, and inputs with
    at most two entries along axis 0 come back as an unmodified copy. Under
    map_overlap this reads values from the overlap region of adjacent blocks.
    """
    out = x.copy()
    if x.shape[0] <= 2:
        # No interior positions to update.
        return out
    out[1:-1] = x[:-2] + x[1:-1] + x[2:]
    return out
23
+
24
+
25
+ # =============================================================================
26
+ # Case 1: Slice on non-overlap axis (should push through)
27
+ # =============================================================================
28
+
29
+
30
def test_slice_through_overlap_non_overlap_axis():
    """A leading slice on an axis with zero depth commutes with map_overlap."""
    data = np.arange(10000).reshape((100, 100)).astype(float)
    darr = da.from_array(data, chunks=(10, 10))

    def overlap(a):
        # Depth 2 along axis 0 only; axis 1 has no overlap.
        return a.map_overlap(add_neighbors, depth={0: 2, 1: 0}, boundary="none")

    # Slicing after map_overlap should simplify to slicing the input first.
    pushed = overlap(darr)[:, :20]
    reference = overlap(darr[:, :20])

    assert pushed.expr.simplify()._name == reference.expr.simplify()._name
44
+
45
+
46
def test_slice_through_overlap_middle_slice():
    """An interior slice on the zero-depth axis also commutes with map_overlap."""
    source = da.from_array(
        np.arange(10000).reshape((100, 100)).astype(float), chunks=(10, 10)
    )

    def overlap(a):
        # Depth 2 along axis 0, none along axis 1.
        return a.map_overlap(add_neighbors, depth={0: 2, 1: 0}, boundary="none")

    window = (slice(None), slice(30, 70))
    lhs = overlap(source)[window]
    rhs = overlap(source[window])

    assert lhs.expr.simplify()._name == rhs.expr.simplify()._name
58
+
59
+
60
def test_slice_through_overlap_correctness():
    """Verify slice through overlap pushes down and produces correct values.

    The slice ``[:, 2:6]`` is on axis 1, which has zero depth, so it should be
    moved below map_overlap. Besides the expression-structure check, the
    computed values are compared with slicing the full (unsliced) result.
    """
    arr = np.arange(64).reshape((8, 8)).astype(float)
    x = da.from_array(arr, chunks=(4, 4))

    result = x.map_overlap(add_neighbors, depth={0: 2, 1: 0}, boundary="none")

    # Slice on axis 1 (no overlap)
    sliced = result[:, 2:6]
    expected = x[:, 2:6].map_overlap(add_neighbors, depth={0: 2, 1: 0}, boundary="none")

    # Structure: the slice was pushed below map_overlap.
    assert sliced.expr.simplify()._name == expected.expr.simplify()._name
    # Values: the pushed-down graph matches slicing the fully computed result.
    assert_eq(sliced, result.compute()[:, 2:6])
72
+
73
+
74
+ # =============================================================================
75
+ # Case 2: Slice on overlap axis (pushes through with padding)
76
+ # =============================================================================
77
+
78
+
79
def test_slice_on_overlap_axis_pushes_with_padding():
    """A slice on the overlapped axis pushes down using a padded input slice.

    Output rows ``[:50]`` with ``depth=2`` need input rows ``[:52]``; the
    rewritten expression applies map_overlap to that padded input and trims
    the result back to ``[:50]``.
    """
    values = np.arange(10000).reshape((100, 100)).astype(float)
    x = da.from_array(values, chunks=(10, 10))
    depth = 2
    stop = 50

    def overlapped(a):
        return a.map_overlap(add_neighbors, depth={0: depth, 1: 0}, boundary="none")

    sliced = overlapped(x)[:stop, :]
    expected = overlapped(x[: stop + depth, :])[:stop, :]

    assert sliced.expr.simplify()._name == expected.expr.simplify()._name
92
+
93
+
94
def test_slice_on_both_axes_one_has_overlap():
    """Slicing both axes: padded on the overlap axis, direct on the other."""
    x = da.from_array(
        np.arange(10000).reshape((100, 100)).astype(float), chunks=(10, 10)
    )
    mapped = x.map_overlap(add_neighbors, depth={0: 2, 1: 0}, boundary="none")
    actual = mapped[:50, :50]

    # Axis 1 (depth 0) is sliced directly on the input; axis 0 (depth 2)
    # needs two extra input rows, which are trimmed off afterwards.
    padded_input = x[:52, :50]
    reference = padded_input.map_overlap(
        add_neighbors, depth={0: 2, 1: 0}, boundary="none"
    )[:50, :]

    assert actual.expr.simplify()._name == reference.expr.simplify()._name
107
+
108
+
109
+ # =============================================================================
110
+ # Case 3: Multi-dimensional overlap
111
+ # =============================================================================
112
+
113
+
114
def add_neighbors_2d(x):
    """Return a copy of ``x`` with neighbour sums added along axes 0 and 1.

    Interior rows gain the rows directly above and below; interior columns
    gain the columns directly left and right. Both sums read from the original
    ``x``, not from the partially updated copy. An axis of length <= 2 gets
    no contribution.
    """
    out = x.copy()
    n_rows, n_cols = x.shape[0], x.shape[1]
    if n_rows > 2:
        # Vertical neighbours along axis 0.
        up, down = x[:-2, :], x[2:, :]
        out[1:-1, :] += up + down
    if n_cols > 2:
        # Horizontal neighbours along axis 1.
        left, right = x[:, :-2], x[:, 2:]
        out[:, 1:-1] += left + right
    return out
122
+
123
+
124
def test_slice_through_2d_overlap():
    """With overlap on both axes, a slice on axis 1 pushes down with padding."""
    grid = np.arange(10000).reshape((100, 100)).astype(float)
    x = da.from_array(grid, chunks=(10, 10))

    def run(a):
        return a.map_overlap(add_neighbors_2d, depth={0: 1, 1: 1}, boundary="none")

    # [:, :40] with depth 1 on axis 1 needs input columns [:, :41];
    # the extra column is trimmed after map_overlap.
    pushed = run(x)[:, :40]
    manual = run(x[:, :41])[:, :40]

    assert pushed.expr.simplify()._name == manual.expr.simplify()._name
136
+
137
+
138
def test_slice_through_2d_overlap_middle():
    """Middle slice on the non-overlap axis of a 2D array.

    Despite the test name, the overlap here is one-dimensional: depth 2 on
    axis 0 and depth 0 on axis 1. The slice ``[:, 25:75]`` is on axis 1, so
    it should push straight through to the input with no padding.
    """
    arr = np.arange(10000).reshape((100, 100)).astype(float)
    x = da.from_array(arr, chunks=(10, 10))

    # Overlap only on axis 0
    result = x.map_overlap(add_neighbors, depth={0: 2, 1: 0}, boundary="none")

    # Middle slice on axis 1 (no overlap)
    sliced = result[:, 25:75]
    expected = x[:, 25:75].map_overlap(add_neighbors, depth={0: 2, 1: 0}, boundary="none")

    assert sliced.expr.simplify()._name == expected.expr.simplify()._name
151
+
152
+
153
def test_slice_through_1d_overlap_on_3d_array():
    """Slices on two zero-depth axes of a 3D array push through together."""
    cube = np.arange(1000).reshape((10, 10, 10)).astype(float)
    x = da.from_array(cube, chunks=(5, 5, 5))

    def overlap_axis0(a):
        # Depth 1 on axis 0 only; axes 1 and 2 have no overlap.
        return a.map_overlap(add_neighbors, depth={0: 1, 1: 0, 2: 0}, boundary="none")

    index = (slice(None), slice(3), slice(3))
    got = overlap_axis0(x)[index]
    want = overlap_axis0(x[index])

    assert got.expr.simplify()._name == want.expr.simplify()._name
166
+
167
+
168
+ # =============================================================================
169
+ # Case 4: Asymmetric overlap
170
+ # =============================================================================
171
+
172
+
173
def test_slice_through_asymmetric_overlap():
    """Zero-depth-axis slice pushes through an asymmetric (left, right) depth."""
    small = np.arange(64).reshape((8, 8)).astype(float)
    x = da.from_array(small, chunks=(4, 4))

    def asym(a):
        # Axis 0 depth is (2, 1): different left/right overlap; axis 1: none.
        return a.map_overlap(add_neighbors, depth={0: (2, 1), 1: 0}, boundary="none")

    via_pushdown = asym(x)[:, 2:6]
    via_input_slice = asym(x[:, 2:6])

    assert via_pushdown.expr.simplify()._name == via_input_slice.expr.simplify()._name
186
+
187
+
188
def test_slice_on_asymmetric_overlap_axis_pushes():
    """A slice on the asymmetric-depth axis pushes through with right padding."""
    x = da.from_array(
        np.arange(10000).reshape((100, 100)).astype(float), chunks=(10, 10)
    )

    def asym(a):
        return a.map_overlap(add_neighbors, depth={0: (2, 1), 1: 0}, boundary="none")

    # depth (2, 1) on axis 0: a leading slice [:50] needs one extra input row
    # on the right, so the input becomes [:51] and the output is trimmed.
    lhs = asym(x)[:50, :]
    rhs = asym(x[:51, :])[:50, :]

    assert lhs.expr.simplify()._name == rhs.expr.simplify()._name
201
+
202
+
203
+ # =============================================================================
204
+ # Case 5: Zero overlap (edge case)
205
+ # =============================================================================
206
+
207
+
208
def test_slice_through_zero_overlap():
    """With depth=0, a slice on axis 0 pushes through without padding."""
    x = da.from_array(
        np.arange(10000).reshape((100, 100)).astype(float), chunks=(10, 10)
    )
    head = slice(None, 50)

    sliced = x.map_overlap(add_neighbors, depth=0, boundary="none")[head, :]
    expected = x[head, :].map_overlap(add_neighbors, depth=0, boundary="none")

    assert sliced.expr.simplify()._name == expected.expr.simplify()._name
221
+
222
+
223
+ # =============================================================================
224
+ # Case 6: Task reduction verification
225
+ # =============================================================================
226
+
227
+
228
def test_slice_through_overlap_reduces_tasks():
    """Slicing down to one column of blocks yields a smaller optimized graph."""
    x = da.from_array(
        np.arange(10000).reshape((100, 100)).astype(float), chunks=(10, 10)
    )
    full = x.map_overlap(add_neighbors, depth={0: 2, 1: 0}, boundary="none")
    sliced = full[:, :10]  # first 10 columns = first column of blocks

    def graph_size(a):
        return len(a.optimize().__dask_graph__())

    full_count = graph_size(full)
    sliced_count = graph_size(sliced)

    # 1 column of chunks vs. 10 should mean strictly fewer tasks.
    assert sliced_count < full_count
243
+
244
+
245
def test_slice_through_overlap_reduces_numblocks():
    """Verify slice pushdown reduces the number of blocks processed.

    ``sliced.numblocks`` is ``(10, 1)`` whether or not the slice is pushed
    down, so the block-count checks alone cannot detect a missing pushdown.
    The final assertion checks that the slice is moved below map_overlap,
    so only one column of input blocks goes through the overlap step.
    """
    arr = np.arange(10000).reshape((100, 100)).astype(float)
    x = da.from_array(arr, chunks=(10, 10))

    result = x.map_overlap(add_neighbors, depth={0: 2, 1: 0}, boundary="none")
    sliced = result[:, :10]

    # Full result: 10x10 chunks
    assert result.numblocks == (10, 10)

    # Sliced result: 10x1 chunks (only 1 column of blocks)
    assert sliced.numblocks == (10, 1)

    # Pushdown: equivalent to slicing the input before map_overlap
    expected = x[:, :10].map_overlap(add_neighbors, depth={0: 2, 1: 0}, boundary="none")
    assert sliced.expr.simplify()._name == expected.expr.simplify()._name
258
+
259
+
260
+ # =============================================================================
261
+ # Case 7: Correctness with computed values
262
+ # =============================================================================
263
+
264
+
265
@pytest.mark.parametrize(
    "shape,chunks,depth,slice_",
    [
        # Start slices (:n form) on non-overlap axes
        ((80, 80), (20, 20), {0: 2, 1: 0}, (slice(None), slice(20))),
        ((80, 80), (20, 20), {0: 0, 1: 2}, (slice(20), slice(None))),
        # Middle slices (k:n form) on non-overlap axes
        ((80, 80), (20, 20), {0: 2, 1: 0}, (slice(None), slice(20, 60))),
        ((80, 80), (20, 20), {0: 0, 1: 2}, (slice(20, 60), slice(None))),
        # End slices (k: form) on non-overlap axes
        ((80, 80), (20, 20), {0: 2, 1: 0}, (slice(None), slice(40, None))),
        ((80, 80), (20, 20), {0: 0, 1: 2}, (slice(40, None), slice(None))),
    ],
)
def test_slice_through_overlap_parametrized(shape, chunks, depth, slice_):
    """Parametrized slice-through-overlap tests: structure and computed values.

    Every case slices only axes whose depth is 0, at chunk-aligned
    boundaries. The slice should therefore push straight through to the
    input, and the values must match slicing the full (unsliced) computation.
    """
    arr = np.arange(np.prod(shape)).reshape(shape).astype(float)
    x = da.from_array(arr, chunks=chunks)

    result = x.map_overlap(add_neighbors, depth=depth, boundary="none")
    sliced = result[slice_]

    # Build expected: slice input first, then overlap
    input_sliced = x[slice_]
    expected = input_sliced.map_overlap(add_neighbors, depth=depth, boundary="none")

    # Verify expression structure matches
    assert sliced.expr.simplify()._name == expected.expr.simplify()._name
    # Verify computed values match the unsliced computation, sliced afterwards
    assert_eq(sliced, result.compute()[slice_])
293
+
294
+
295
+ # =============================================================================
296
+ # Case 8: Special cases (trim=False, uniform depth)
297
+ # =============================================================================
298
+
299
+
300
def test_map_overlap_no_trim_slice_pushes():
    """With trim=False, a slice on an axis without overlap reaches the input."""
    x = da.from_array(
        np.arange(10000).reshape((100, 100)).astype(float), chunks=(10, 10)
    )

    def untrimmed(a):
        # Depth only on axis 0; trim=False means there is no Trim wrapper.
        return a.map_overlap(add_neighbors, depth={0: 2}, boundary="none", trim=False)

    after = untrimmed(x)[:, :30]
    before = untrimmed(x[:, :30])

    assert after.expr.simplify()._name == before.expr.simplify()._name
313
+
314
+
315
def test_map_overlap_uniform_depth_correctness():
    """Test with uniform depth (int instead of dict).

    When slicing on an axis with overlap, the optimization pads the input
    slice to include data needed for overlap, then trims the output. Both the
    rewritten expression and the computed values are checked.
    """
    arr = np.arange(10000).reshape((100, 100)).astype(float)
    x = da.from_array(arr, chunks=(10, 10))

    result = x.map_overlap(add_neighbors_2d, depth=2, boundary="none")
    sliced = result[:, :30]

    # Expected: pad input by depth on sliced axis, apply overlap, then trim
    # [:, :30] with depth=2 needs input [:, :32] to preserve overlap semantics
    expected = x[:, :32].map_overlap(add_neighbors_2d, depth=2, boundary="none")[:, :30]

    assert sliced.expr.simplify()._name == expected.expr.simplify()._name
    # Values must match slicing the full (unsliced) computation.
    assert_eq(sliced, result.compute()[:, :30])
332
+
333
+
334
+ # =============================================================================
335
+ # Case 9: Value correctness verification
336
+ # =============================================================================
337
+
338
+
339
def test_slice_through_overlap_value_correctness():
    """A slice on the zero-depth axis yields the same values as slicing the
    fully computed result."""
    x = da.from_array(
        np.arange(10000).reshape((100, 100)).astype(float), chunks=(10, 10)
    )
    mapped = x.map_overlap(add_neighbors, depth={0: 2, 1: 0}, boundary="none")
    columns = (slice(None), slice(50))

    # Reference: compute the whole (unsliced) array, then slice it in NumPy.
    assert_eq(mapped[columns], mapped.compute()[columns])
352
+
353
+
354
def test_slice_on_overlap_axis_value_correctness():
    """A slice on axes that carry overlap (depth 2 on both) gives correct values."""
    x = da.from_array(
        np.arange(10000).reshape((100, 100)).astype(float), chunks=(10, 10)
    )
    mapped = x.map_overlap(add_neighbors_2d, depth=2, boundary="none")
    window = (slice(50), slice(50))

    # Reference: compute the whole (unsliced) array, then slice it in NumPy.
    assert_eq(mapped[window], mapped.compute()[window])
@@ -0,0 +1,272 @@
1
+ """Tests for slice pushdown through Reshape expressions.
2
+
3
+ Slice can push through Reshape when leading dimensions are preserved,
4
+ i.e., the reshape only affects trailing dimensions.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import numpy as np
10
+ import pytest
11
+
12
+ import dask_array as da
13
+ from dask_array._test_utils import assert_eq
14
+
15
+ # =============================================================================
16
+ # Case 1: Leading dimension preserved - slice should push through
17
+ # =============================================================================
18
+
19
+
20
def test_slice_through_reshape_leading_dim_preserved():
    """A slice on the untouched leading axis moves below the reshape."""
    base = np.arange(60).reshape((10, 6))
    x = da.from_array(base, chunks=(5, 3))

    # (10, 6) -> (10, 2, 3): axis 0 is left alone by the reshape.
    sliced_after = x.reshape((10, 2, 3))[:3]
    sliced_before = x[:3].reshape((3, 2, 3))

    assert sliced_after.expr.simplify()._name == sliced_before.expr.simplify()._name
    assert_eq(sliced_after, base.reshape((10, 2, 3))[:3])
32
+
33
+
34
def test_slice_through_reshape_flatten_trailing():
    """Merging trailing axes still lets a leading-axis slice push through."""
    base = np.arange(60).reshape((10, 2, 3))
    x = da.from_array(base, chunks=(5, 2, 3))

    # (10, 2, 3) -> (10, 6): the two trailing axes merge, axis 0 is kept.
    got = x.reshape((10, 6))[:4]
    want = x[:4].reshape((4, 6))

    # Compare optimized graphs: a Reshape vs ReshapeLowered difference is
    # only resolved by lowering, not by simplify() alone.
    assert got.optimize()._name == want.optimize()._name
    assert_eq(got, base.reshape((10, 6))[:4])
47
+
48
+
49
def test_slice_through_reshape_multiple_leading_dims():
    """Slices on two untouched leading axes both move below the reshape."""
    base = np.arange(120).reshape((4, 5, 6))
    x = da.from_array(base, chunks=(2, 5, 3))
    target = (4, 5, 2, 3)  # only the last axis (6 -> 2 x 3) is reshaped
    index = (slice(2), slice(3))

    after = x.reshape(target)[index]
    before = x[index].reshape((2, 3, 2, 3))

    assert after.expr.simplify()._name == before.expr.simplify()._name
    assert_eq(after, base.reshape(target)[index])
60
+
61
+
62
def test_slice_through_reshape_middle_slice():
    """An offset slice (start > 0) on the preserved axis still pushes through."""
    base = np.arange(100).reshape((10, 10))
    x = da.from_array(base, chunks=(5, 5))
    rows = slice(3, 7)

    # (10, 10) -> (10, 2, 5) keeps axis 0 intact.
    after = x.reshape((10, 2, 5))[rows]
    before = x[rows].reshape((4, 2, 5))

    assert after.expr.simplify()._name == before.expr.simplify()._name
    assert_eq(after, base.reshape((10, 2, 5))[rows])
73
+
74
+
75
+ # =============================================================================
76
+ # Case 2: Dimension NOT preserved - slice should NOT push through
77
+ # =============================================================================
78
+
79
+
80
def test_slice_not_pushed_when_dim_changes():
    """Splitting the sliced axis (6 -> 2 x 3) blocks pushdown; values stay correct."""
    base = np.arange(60).reshape((6, 10))
    x = da.from_array(base, chunks=(3, 5))

    # Output axis 0 has length 2 instead of 6, so [:1] is not expected to
    # push through. Only value correctness is checked here.
    new_shape = (2, 3, 10)
    assert_eq(x.reshape(new_shape)[:1], base.reshape(new_shape)[:1])
91
+
92
+
93
def test_slice_not_pushed_through_flatten():
    """Flattening to 1D leaves no axis correspondence; only values are checked."""
    base = np.arange(100).reshape((10, 10))
    x = da.from_array(base, chunks=(5, 5))
    flat_head = slice(None, 30)

    assert_eq(x.reshape((100,))[flat_head], base.reshape((100,))[flat_head])
103
+
104
+
105
def test_slice_on_reshaped_axis_not_pushed():
    """A slice on an axis created by the reshape cannot push through."""
    base = np.arange(60).reshape((10, 6))
    x = da.from_array(base, chunks=(5, 3))

    # (10, 6) -> (10, 2, 3): axis 1 comes from splitting 6 into 2 x 3, so
    # slicing it has no input-side equivalent. Check values only.
    index = (slice(None), slice(1))
    assert_eq(x.reshape((10, 2, 3))[index], base.reshape((10, 2, 3))[index])
115
+
116
+
117
+ # =============================================================================
118
+ # Case 3: Slice with None (newaxis) - should still push through
119
+ # =============================================================================
120
+
121
+
122
def test_slice_with_none_pushes_through():
    """A slice containing None pushes through; the new axis is re-applied after."""
    base = np.arange(60).reshape((10, 6))
    x = da.from_array(base, chunks=(5, 3))

    got = x.reshape((10, 2, 3))[:5, None]
    # Equivalent form: slice the input, reshape, then insert the new axis.
    want = x[:5].reshape((5, 2, 3))[:, None]

    assert got.expr.simplify()._name == want.expr.simplify()._name
    assert_eq(got, base.reshape((10, 2, 3))[:5, None])
134
+
135
+
136
def test_slice_with_none_at_end():
    """None as the last index entry is handled by the pushdown."""
    base = np.arange(60).reshape((10, 6))
    x = da.from_array(base, chunks=(5, 3))
    trailing_newaxis = (slice(None), slice(None), slice(None), None)

    got = x.reshape((10, 2, 3))[:5, :, :, None]
    want = x[:5].reshape((5, 2, 3))[trailing_newaxis]

    assert got.expr.simplify()._name == want.expr.simplify()._name
    assert_eq(got, base.reshape((10, 2, 3))[:5, :, :, None])
147
+
148
+
149
def test_slice_with_multiple_nones():
    """Several None entries around the sliced axis survive the pushdown."""
    base = np.arange(60).reshape((10, 6))
    x = da.from_array(base, chunks=(5, 3))

    got = x.reshape((10, 2, 3))[None, :5, None]
    want = x[:5].reshape((5, 2, 3))[None, :, None]

    assert got.expr.simplify()._name == want.expr.simplify()._name
    assert_eq(got, base.reshape((10, 2, 3))[None, :5, None])
160
+
161
+
162
def test_slice_with_none_correctness():
    """Values stay correct for None inserted at several positions."""
    base = np.arange(120).reshape((10, 12))
    x = da.from_array(base, chunks=(5, 6))

    # (10, 12) -> (10, 3, 4)
    reshaped = x.reshape((10, 3, 4))
    reference = base.reshape((10, 3, 4))

    for index in (
        (slice(5), None, slice(None), slice(None)),
        (None, slice(5)),
        (slice(5), slice(None), None, slice(None)),
    ):
        assert_eq(reshaped[index], reference[index])
174
+
175
+
176
+ # =============================================================================
177
+ # Correctness tests with various shapes
178
+ # =============================================================================
179
+
180
+
181
@pytest.mark.parametrize(
    "in_shape,out_shape,slice_",
    [
        # Split of the trailing axis
        ((20, 6), (20, 2, 3), (slice(10),)),
        ((20, 6), (20, 2, 3), (slice(5, 15),)),
        ((20, 12), (20, 3, 4), (slice(None, 8),)),
        # Merge of trailing axes
        ((20, 2, 3), (20, 6), (slice(10),)),
        ((20, 4, 5), (20, 20), (slice(5, 15),)),
        # Several leading axes preserved
        ((10, 5, 6), (10, 5, 2, 3), (slice(5), slice(3))),
        ((10, 5, 4), (10, 5, 2, 2), (slice(3, 8), slice(None, 4))),
    ],
)
def test_slice_through_reshape_correctness(in_shape, out_shape, slice_):
    """Reshape-then-slice on a dask array matches the same steps in NumPy."""
    source = np.arange(np.prod(in_shape)).reshape(in_shape)
    # Chunk each axis at half its length, but never below one element.
    block_shape = tuple(max(1, n // 2) for n in in_shape)

    lazy = da.from_array(source, chunks=block_shape).reshape(out_shape)[slice_]
    assert_eq(lazy, source.reshape(out_shape)[slice_])
206
+
207
+
208
+ # =============================================================================
209
+ # Task reduction tests
210
+ # =============================================================================
211
+
212
+
213
def test_slice_through_reshape_reduces_tasks():
    """Slicing after a leading-axis-preserving reshape shrinks the optimized graph."""
    x = da.from_array(np.ones((100, 10)), chunks=(10, 5))

    def graph_size(a):
        return len(a.optimize().__dask_graph__())

    # Same reshape (axis 0 preserved); the second one is sliced to one row-chunk.
    full = x.reshape((100, 2, 5))
    sliced = x.reshape((100, 2, 5))[:10]

    full_count = graph_size(full)
    sliced_count = graph_size(sliced)

    # Only 1 of the 10 row-chunks is needed, so strictly fewer tasks.
    assert sliced_count < full_count
227
+
228
+
229
def test_slice_through_reshape_reduces_numblocks():
    """After optimization, only the row-chunks covered by the slice remain."""
    x = da.from_array(np.ones((100, 20)), chunks=(10, 10))
    optimized = x.reshape((100, 4, 5))[:20].optimize()

    # 20 rows at 10 rows per chunk -> 2 blocks along axis 0.
    assert optimized.numblocks[0] == 20 // 10
239
+
240
+
241
+ # =============================================================================
242
+ # Expression structure tests
243
+ # =============================================================================
244
+
245
+
246
def test_expression_structure_slice_pushed():
    """The root slice node disappears once the slice moves below Reshape."""
    from dask_array.slicing import SliceSlicesIntegers

    expr = da.ones((20, 6), chunks=(5, 3)).reshape((20, 2, 3))[:5].expr

    # Unsimplified tree: Slice(Reshape(...)), with the slice at the root.
    assert isinstance(expr, SliceSlicesIntegers)

    # After simplify() the slice has been pushed below the reshape.
    assert not isinstance(expr.simplify(), SliceSlicesIntegers)
261
+
262
+
263
def test_expression_structure_slice_blocked():
    """Reshape that splits the sliced leading axis: results stay correct.

    Reshape (6, 10) -> (2, 3, 10) splits axis 0, so the slice ``[:1]`` is not
    expected to push through. NOTE: only value correctness is asserted here;
    the simplified expression structure is not inspected.
    """
    x = da.ones((6, 10), chunks=(3, 5))

    # Reshape (6, 10) -> (2, 3, 10) splits first dim
    # First dim changes from 6 to 2, slice should not push through
    result = x.reshape((2, 3, 10))[:1]

    # Verify correctness only - the pushdown optimization does not apply here
    assert_eq(result, np.ones((2, 3, 10))[:1])