dask-array 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (144)
  1. dask_array/__init__.py +228 -0
  2. dask_array/_backends.py +76 -0
  3. dask_array/_backends_array.py +99 -0
  4. dask_array/_blockwise.py +1410 -0
  5. dask_array/_broadcast.py +272 -0
  6. dask_array/_chunk.py +445 -0
  7. dask_array/_chunk_types.py +54 -0
  8. dask_array/_collection.py +1644 -0
  9. dask_array/_concatenate.py +331 -0
  10. dask_array/_core_utils.py +1365 -0
  11. dask_array/_dispatch.py +141 -0
  12. dask_array/_einsum.py +277 -0
  13. dask_array/_expr.py +544 -0
  14. dask_array/_expr_flow.py +586 -0
  15. dask_array/_gufunc.py +805 -0
  16. dask_array/_histogram.py +617 -0
  17. dask_array/_map_blocks.py +652 -0
  18. dask_array/_new_collection.py +10 -0
  19. dask_array/_numpy_compat.py +135 -0
  20. dask_array/_overlap.py +1159 -0
  21. dask_array/_rechunk.py +1050 -0
  22. dask_array/_reshape.py +710 -0
  23. dask_array/_routines.py +102 -0
  24. dask_array/_shuffle.py +448 -0
  25. dask_array/_stack.py +264 -0
  26. dask_array/_svg.py +291 -0
  27. dask_array/_templates.py +29 -0
  28. dask_array/_test_utils.py +257 -0
  29. dask_array/_ufunc.py +385 -0
  30. dask_array/_utils.py +349 -0
  31. dask_array/_visualize.py +223 -0
  32. dask_array/_xarray.py +337 -0
  33. dask_array/core/__init__.py +34 -0
  34. dask_array/core/_blockwise_funcs.py +312 -0
  35. dask_array/core/_conversion.py +422 -0
  36. dask_array/core/_from_graph.py +97 -0
  37. dask_array/creation/__init__.py +71 -0
  38. dask_array/creation/_arange.py +121 -0
  39. dask_array/creation/_diag.py +116 -0
  40. dask_array/creation/_diagonal.py +241 -0
  41. dask_array/creation/_eye.py +103 -0
  42. dask_array/creation/_linspace.py +102 -0
  43. dask_array/creation/_mesh.py +134 -0
  44. dask_array/creation/_ones_zeros.py +454 -0
  45. dask_array/creation/_pad.py +270 -0
  46. dask_array/creation/_repeat.py +55 -0
  47. dask_array/creation/_tile.py +36 -0
  48. dask_array/creation/_tri.py +28 -0
  49. dask_array/creation/_utils.py +296 -0
  50. dask_array/fft.py +320 -0
  51. dask_array/io/__init__.py +39 -0
  52. dask_array/io/_base.py +10 -0
  53. dask_array/io/_from_array.py +257 -0
  54. dask_array/io/_from_delayed.py +95 -0
  55. dask_array/io/_from_graph.py +54 -0
  56. dask_array/io/_from_npy_stack.py +67 -0
  57. dask_array/io/_store.py +336 -0
  58. dask_array/io/_tiledb.py +159 -0
  59. dask_array/io/_to_npy_stack.py +65 -0
  60. dask_array/io/_zarr.py +449 -0
  61. dask_array/linalg/__init__.py +39 -0
  62. dask_array/linalg/_cholesky.py +234 -0
  63. dask_array/linalg/_lu.py +300 -0
  64. dask_array/linalg/_norm.py +94 -0
  65. dask_array/linalg/_qr.py +601 -0
  66. dask_array/linalg/_solve.py +349 -0
  67. dask_array/linalg/_svd.py +394 -0
  68. dask_array/linalg/_tensordot.py +334 -0
  69. dask_array/linalg/_utils.py +74 -0
  70. dask_array/manipulation/__init__.py +45 -0
  71. dask_array/manipulation/_expand.py +321 -0
  72. dask_array/manipulation/_flip.py +92 -0
  73. dask_array/manipulation/_roll.py +78 -0
  74. dask_array/manipulation/_transpose.py +309 -0
  75. dask_array/random/__init__.py +125 -0
  76. dask_array/random/_choice.py +181 -0
  77. dask_array/random/_expr.py +256 -0
  78. dask_array/random/_generator.py +441 -0
  79. dask_array/random/_random_state.py +259 -0
  80. dask_array/random/_utils.py +84 -0
  81. dask_array/reductions/__init__.py +84 -0
  82. dask_array/reductions/_arg_reduction.py +130 -0
  83. dask_array/reductions/_common.py +1082 -0
  84. dask_array/reductions/_cumulative.py +522 -0
  85. dask_array/reductions/_percentile.py +261 -0
  86. dask_array/reductions/_reduction.py +725 -0
  87. dask_array/reductions/_trace.py +56 -0
  88. dask_array/routines/__init__.py +133 -0
  89. dask_array/routines/_apply.py +84 -0
  90. dask_array/routines/_bincount.py +112 -0
  91. dask_array/routines/_broadcast.py +111 -0
  92. dask_array/routines/_coarsen.py +115 -0
  93. dask_array/routines/_diff.py +79 -0
  94. dask_array/routines/_gradient.py +158 -0
  95. dask_array/routines/_indexing.py +65 -0
  96. dask_array/routines/_insert_delete.py +132 -0
  97. dask_array/routines/_misc.py +122 -0
  98. dask_array/routines/_nonzero.py +72 -0
  99. dask_array/routines/_search.py +123 -0
  100. dask_array/routines/_select.py +113 -0
  101. dask_array/routines/_statistics.py +171 -0
  102. dask_array/routines/_topk.py +82 -0
  103. dask_array/routines/_triangular.py +74 -0
  104. dask_array/routines/_unique.py +232 -0
  105. dask_array/routines/_where.py +62 -0
  106. dask_array/slicing/__init__.py +67 -0
  107. dask_array/slicing/_basic.py +550 -0
  108. dask_array/slicing/_blocks.py +138 -0
  109. dask_array/slicing/_bool_index.py +145 -0
  110. dask_array/slicing/_setitem.py +329 -0
  111. dask_array/slicing/_squeeze.py +101 -0
  112. dask_array/slicing/_utils.py +1133 -0
  113. dask_array/slicing/_vindex.py +282 -0
  114. dask_array/stacking/__init__.py +15 -0
  115. dask_array/stacking/_block.py +83 -0
  116. dask_array/stacking/_simple.py +58 -0
  117. dask_array/templates/array.html.j2 +48 -0
  118. dask_array/tests/__init__.py +0 -0
  119. dask_array/tests/conftest.py +22 -0
  120. dask_array/tests/test_api.py +40 -0
  121. dask_array/tests/test_binary_op_chunks.py +107 -0
  122. dask_array/tests/test_coarse_slice_through_blockwise.py +362 -0
  123. dask_array/tests/test_collection.py +799 -0
  124. dask_array/tests/test_creation.py +1102 -0
  125. dask_array/tests/test_expr_flow.py +143 -0
  126. dask_array/tests/test_linalg.py +1130 -0
  127. dask_array/tests/test_map_blocks_multi_output.py +104 -0
  128. dask_array/tests/test_rechunk_pushdown.py +214 -0
  129. dask_array/tests/test_reductions.py +1091 -0
  130. dask_array/tests/test_routines.py +2853 -0
  131. dask_array/tests/test_shuffle_chunks.py +67 -0
  132. dask_array/tests/test_slice_pushdown.py +968 -0
  133. dask_array/tests/test_slice_through_blockwise.py +678 -0
  134. dask_array/tests/test_slice_through_overlap.py +366 -0
  135. dask_array/tests/test_slice_through_reshape.py +272 -0
  136. dask_array/tests/test_slicing.py +839 -0
  137. dask_array/tests/test_transpose_slice_pushdown.py +208 -0
  138. dask_array/tests/test_visualize.py +94 -0
  139. dask_array/tests/test_xarray.py +193 -0
  140. dask_array-0.1.0.dist-info/METADATA +48 -0
  141. dask_array-0.1.0.dist-info/RECORD +144 -0
  142. dask_array-0.1.0.dist-info/WHEEL +4 -0
  143. dask_array-0.1.0.dist-info/entry_points.txt +2 -0
  144. dask_array-0.1.0.dist-info/licenses/LICENSE +29 -0
@@ -0,0 +1,968 @@
1
+ """Tests for slice pushdown into IO expressions."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import numpy as np
6
+ import pytest
7
+
8
+ import dask_array as da
9
+ from dask_array.io import FromArray
10
+ from dask_array.slicing import SliceSlicesIntegers
11
+ from dask_array._test_utils import assert_eq
12
+
13
# Parametrized correctness cases, each an (array_shape, chunks, slice_tuple)
# triple. The list is assembled from three groups in a fixed order so the
# generated pytest ids stay stable.
SLICE_CASES = (
    [
        # Basic slices
        ((10, 10), (2, 2), (slice(0, 2), slice(0, 2))),  # corner
        ((10, 10), (2, 2), (slice(0, 4), slice(0, 4))),  # 2x2 chunks
        ((10, 10), (5, 5), (slice(0, 5), slice(0, 5))),  # chunk boundary
        ((10, 10), (5, 5), (slice(2, 7), slice(3, 8))),  # mid-chunk
        ((10, 10), (5, 5), (slice(None), slice(None))),  # full slice
    ]
    + [
        # Edge cases
        ((10, 10), (2, 2), (slice(5, 5), slice(None))),  # empty
        ((10, 10), (2, 2), (slice(3, 4), slice(None))),  # single row
        ((10, 10), (2, 2), (slice(-4, -1), slice(-3, None))),  # negative
        ((10, 10, 10), (3, 3, 3), (slice(1, 4), slice(2, 5), slice(3, 6))),  # 3D
    ]
    + [
        # Adversarial
        ((10, 10), (10, 10), (slice(2, 5), slice(3, 7))),  # single chunk source
        ((10, 10), (3, 4), (slice(1, 8), slice(2, 9))),  # uneven chunks
        ((10, 10), (3, 3), (slice(9, None), slice(9, None))),  # last chunk
        ((10, 10, 10), (3, 3, 3), (slice(2, 5),)),  # partial dims
    ]
)
32
+
33
+
34
@pytest.mark.parametrize("shape,chunks,slc", SLICE_CASES)
def test_slice_correctness(shape, chunks, slc):
    """Slicing a dask array gives the same values as slicing its numpy source."""
    source = np.arange(np.prod(shape)).reshape(shape)
    dask_arr = da.from_array(source, chunks=chunks)
    expected = source[slc]
    assert_eq(dask_arr[slc], expected)
40
+
41
+
42
# Task-count cases, each an (array_shape, chunks, slice_tuple, expected_tasks)
# quadruple; expected_tasks is the number of source chunks the slice touches.
# Groups are concatenated in a fixed order so pytest ids stay stable.
TASK_COUNT_CASES = (
    [
        # 10x10 sources
        ((10, 10), (2, 2), (slice(0, 2), slice(0, 2)), 1),  # 1 chunk
        ((10, 10), (2, 2), (slice(0, 4), slice(0, 4)), 4),  # 2x2 chunks
        ((10, 10), (5, 5), (slice(0, 5), slice(0, 5)), 1),  # boundary
        ((10, 10), (5, 5), (slice(2, 7), slice(3, 8)), 4),  # spans 2x2
    ]
    + [
        # 20x20 sources
        ((20, 20), (5, 5), (slice(0, 3), slice(None)), 4),  # 1x4 row
        ((20, 20), (5, 5), (slice(None), slice(0, 3)), 4),  # 4x1 col
        ((20, 20), (5, 5), (slice(0, 12), slice(0, 12)), 9),  # 3x3
    ]
    + [
        # Larger or higher-dimensional sources
        ((100, 100), (10, 10), (slice(0, 5), slice(0, 5)), 1),  # small from large
        ((10, 10, 10), (5, 5, 5), (slice(0, 3), slice(0, 3), slice(0, 3)), 1),  # 3D corner
    ]
)
54
+
55
+
56
@pytest.mark.parametrize("shape,chunks,slc,expected", TASK_COUNT_CASES)
def test_task_count(shape, chunks, slc, expected):
    """The optimized graph has exactly one task per chunk the slice touches."""
    source = np.arange(np.prod(shape)).reshape(shape)
    optimized = da.from_array(source, chunks=chunks)[slc].optimize()
    n_tasks = len(optimized.__dask_graph__())
    assert n_tasks == expected
63
+
64
+
65
def test_slice_optimize_slice():
    """Slicing an already-optimized slice keeps pushing down and stays correct."""
    source = np.arange(100).reshape(10, 10)
    base = da.from_array(source, chunks=(2, 2))

    # First pass: rows/cols 0..5 cover a 3x3 grid of 2x2 chunks.
    first = base[0:6, 0:6].optimize()
    assert len(first.__dask_graph__()) == 9

    # Second pass: the 2x2 corner lives in a single chunk.
    second = first[0:2, 0:2].optimize()
    assert len(second.__dask_graph__()) == 1

    assert_eq(second, source[0:6, 0:6][0:2, 0:2])
77
+
78
+
79
def test_slice_through_elemwise():
    """A slice of an elemwise expression is pushed down into the IO layer."""
    source = np.arange(100).reshape(10, 10)
    base = da.from_array(source, chunks=(2, 2))
    optimized = ((base + 1) * 2)[0:2, 0:2].optimize()

    n_tasks = len(optimized.__dask_graph__())
    assert n_tasks <= 2
    assert_eq(optimized, ((source + 1) * 2)[0:2, 0:2])
86
+
87
+
88
def test_nested_slices():
    """Nested slices (expected to fuse) produce numpy-equivalent values."""
    source = np.arange(100).reshape(10, 10)
    base = da.from_array(source, chunks=(2, 2))
    outer = (slice(1, 8), slice(2, 9))
    inner = (slice(1, 4), slice(1, 4))
    assert_eq(base[outer][inner], source[outer][inner])
94
+
95
+
96
def test_expression_structure():
    """The raw expr is a slice node; after optimization it is a FromArray."""
    sliced = da.from_array(np.arange(100).reshape(10, 10), chunks=(2, 2))[0:2, 0:2]

    # Before optimization the slice is still an explicit node ...
    assert isinstance(sliced.expr, SliceSlicesIntegers)
    # ... and optimization folds it into the FromArray source.
    optimized_expr = sliced.optimize().expr
    assert isinstance(optimized_expr, FromArray)
103
+
104
+
105
def test_steps_and_reverse():
    """Strided and reversed slices still compute the right values."""
    source = np.arange(100).reshape(10, 10)
    base = da.from_array(source, chunks=(2, 2))

    # Same step on both axes: ::2, then ::-1, then ::5.
    for step in (2, -1, 5):
        key = (slice(None, None, step), slice(None, None, step))
        assert_eq(base[key], source[key])
113
+
114
+
115
def test_non_pushdown_cases():
    """Integer, fancy and newaxis indexing keep working next to pushdown."""
    source = np.arange(100).reshape(10, 10)
    base = da.from_array(source, chunks=(2, 2))

    keys = [
        (5, slice(None)),  # integer index, i.e. [5, :]
        ([1, 3, 5], slice(None)),  # fancy index, i.e. [[1, 3, 5], :]
        (None, slice(None, 5), slice(None, 5)),  # newaxis, i.e. [None, :5, :5]
    ]
    for key in keys:
        assert_eq(base[key], source[key])
123
+
124
+
125
def test_broadcast_to_empty_slice():
    """An empty slice of a broadcast array has zero-length chunks and computes."""
    single = da.from_array(np.array([1]), (1,))
    result = da.broadcast_to(single, (5,))[:0]
    expected = np.array([], dtype=int)

    assert result.chunks == ((0,),)
    assert_eq(result, expected)
    # The optimized expression without fusion must also compute correctly.
    unfused = result.expr.optimize(fuse=False)
    assert_eq(da.Array(unfused), expected)
131
+ assert_eq(da.Array(result.expr.optimize(fuse=False)), expected)
132
+
133
+
134
def test_masked_array():
    """Pushed-down slices keep the mask of a masked source array."""
    source = np.ma.array(np.arange(100).reshape(10, 10), mask=False)
    source.mask[5, 5] = True
    window = (slice(4, 7), slice(4, 7))

    computed = da.from_array(source, chunks=(3, 3))[window].compute()
    expected = source[window]
    assert_eq(computed, expected)
    assert_eq(computed.mask, expected.mask)
143
+
144
+
145
def test_deterministic_names():
    """Equal slices give equal names; different slices give different names."""
    source = np.arange(100).reshape(10, 10)
    first = da.from_array(source, chunks=(2, 2))
    second = da.from_array(source, chunks=(2, 2))

    def optimized_name(arr, stop):
        # Name of arr[0:stop, 0:stop] after optimization.
        return arr[0:stop, 0:stop].optimize().name

    assert optimized_name(first, 2) == optimized_name(second, 2)
    assert optimized_name(first, 2) != optimized_name(first, 3)
153
+
154
+
155
def test_slice_then_reduction():
    """A reduction applied after a slice matches numpy."""
    source = np.arange(100).reshape(10, 10)
    window = (slice(0, 4), slice(0, 4))
    lazy_total = da.from_array(source, chunks=(2, 2))[window].sum()
    assert_eq(lazy_total, source[window].sum())
160
+
161
+
162
def test_region_deferred_slice():
    """Pushdown records the slice in ``_region`` instead of slicing eagerly."""
    source = np.arange(10000).reshape(100, 100)
    # This window fits inside a single 10x10 chunk.
    sliced = da.from_array(source, chunks=(10, 10))[12:18, 34:39]

    optimized = sliced.expr.optimize()

    # The requested window is stored as the deferred region ...
    assert optimized.operand("_region") == (slice(12, 18, None), slice(34, 39, None))
    # ... the wrapped source keeps its full shape ...
    assert optimized.array.shape == (100, 100)
    # ... and the output chunks describe only the 6x5 region.
    assert optimized.chunks == ((6,), (5,))

    assert_eq(sliced, source[12:18, 34:39])
180
+
181
+
182
def test_region_single_chunk():
    """Slice within a single chunk produces one task with a direct slice.

    The source keeps the 10000x10000 shape and 1000x1000 chunks, so a
    chunk-aligned read would show up as a ``1000`` bound in the graph.

    The array is built with ``np.zeros`` plus a filled window. The previous
    ``np.arange(10000 * 10000)`` eagerly allocated and filled ~800 MB (with
    64-bit ints) just to read a 50x50 window. ``np.zeros`` is usually backed
    by lazily zeroed pages, so only the touched window is committed.
    NOTE(review): that is platform dependent; confirm memory use on CI.
    """
    window = (slice(1500, 1550), slice(2300, 2350))
    arr = np.zeros((10000, 10000), dtype=int)
    # Give the sliced window distinct, known values for the correctness check.
    arr[window] = np.arange(2500).reshape(50, 50)
    x = da.from_array(arr, chunks=(1000, 1000))
    y = x[window]

    opt = y.expr.optimize()
    graph = dict(opt.__dask_graph__())

    # Should be a single task (the slice fits within one chunk); only
    # 3-element tuple keys are counted as array tasks.
    task_keys = [k for k in graph if isinstance(k, tuple) and len(k) == 3]
    assert len(task_keys) == 1

    # The slice should be direct (1500:1550, 2300:2350), not via a 1000x1000 chunk.
    graph_str = str(graph)
    assert "1000" not in graph_str, "Should slice directly, not via full chunk"

    # Verify correctness
    assert_eq(y, arr[window])
202
+
203
+
204
def test_region_multiple_chunks():
    """A slice crossing chunk boundaries yields one task per touched chunk."""
    source = np.arange(10000).reshape(100, 100)
    # Rows 15:25 touch chunks 1 and 2; columns 35:45 touch chunks 3 and 4.
    sliced = da.from_array(source, chunks=(10, 10))[15:25, 35:45]

    graph = dict(sliced.expr.optimize().__dask_graph__())

    # 2 x 2 touched chunks -> 4 tasks (3-element tuple keys).
    n_tasks = sum(1 for key in graph if isinstance(key, tuple) and len(key) == 3)
    assert n_tasks == 4

    assert_eq(sliced, source[15:25, 35:45])
221
+
222
+
223
def test_region_zarr_deferred(tmp_path):
    """Zarr slicing is deferred: the graph holds the zarr array, not loaded data."""
    zarr = pytest.importorskip("zarr")

    store_path = str(tmp_path / "test.zarr")
    source = zarr.open(
        store_path,
        mode="w",
        shape=(10000, 10000),
        dtype="float64",
        chunks=(1000, 1000),
    )
    window = (slice(1500, 1550), slice(2300, 2350))
    source[window] = np.arange(2500).reshape(50, 50)

    sliced = da.from_zarr(store_path)[window]
    graph_values = list(dict(sliced.expr.optimize().__dask_graph__()).values())

    zarr_arrays = [v for v in graph_values if isinstance(v, zarr.Array)]
    numpy_arrays = [v for v in graph_values if isinstance(v, np.ndarray)]

    assert len(zarr_arrays) == 1, "Graph should contain the zarr array"
    assert len(numpy_arrays) == 0, "Graph should not contain numpy arrays (data not loaded)"

    # The embedded zarr array is the whole store, not a sliced copy.
    assert zarr_arrays[0].shape == (10000, 10000)

    assert_eq(sliced, source[window])
255
+
256
+
257
def test_integer_indexing_pushdown():
    """Integer indexing uses region pushdown to minimize data loading."""
    source = np.arange(100).reshape(10, 10)
    base = da.from_array(source, chunks=(5, 5))

    # Pure integer indexing: 2 tasks (FromArray + extract).
    point = base[3, 7]
    optimized = point.optimize()
    assert len(optimized.__dask_graph__()) == 2

    # The inner FromArray reads only the 1x1 region at (3, 7) ...
    inner = optimized.expr.array
    assert inner.operand("_region") == (slice(3, 4), slice(7, 8))
    # ... while still wrapping the original, unsliced source.
    assert inner.array.shape == (10, 10)

    assert_eq(point, source[3, 7])

    # Mixed slice + integer keys: [:3, 5] and [5, 2:8].
    for key in [(slice(None, 3), 5), (5, slice(2, 8))]:
        assert_eq(base[key], source[key])
280
+
281
+
282
+ # ============================================================
283
+ # Slice through reduction tests
284
+ # ============================================================
285
+
286
+
287
def test_slice_through_reduction_optimization():
    """Slicing after a reduction simplifies to slicing before it.

    ``x.sum(axis=0)[:5]`` should simplify to ``x[:, :5].sum(axis=0)``.
    """
    ones = da.ones((100, 100), chunks=(10, 10))

    reduce_then_slice = ones.sum(axis=0)[:5]
    slice_then_reduce = ones[:, :5].sum(axis=0)

    # Both sides are simplified, because slices also simplify through ones().
    assert (
        reduce_then_slice.expr.simplify()._name
        == slice_then_reduce.expr.simplify()._name
    )
303
+
304
+
305
def test_slice_through_reduction_reduces_tasks():
    """Pushing a slice through a reduction shrinks the optimized graph.

    With a from_array source chunked (10, 10), reducing and then slicing
    should need fewer tasks than the unsliced reduction.
    """
    source = np.arange(10000).reshape(100, 100)
    base = da.from_array(source, chunks=(10, 10))

    def n_tasks(arr):
        # Size of the optimized graph for arr.
        return len(arr.optimize().__dask_graph__())

    # The full reduction reads all 10 x 10 input chunks.
    full_tasks = n_tasks(base.sum(axis=0))

    # The first 5 output elements come from a single chunk column.
    sliced = base.sum(axis=0)[:5]
    assert n_tasks(sliced) < full_tasks

    assert_eq(sliced, source.sum(axis=0)[:5])
327
+
328
+
329
def test_slice_through_reduction_axis1():
    """``sum(axis=1)[:5]`` simplifies to slicing rows first, then summing."""
    ones = da.ones((100, 100), chunks=(10, 10))
    lhs = ones.sum(axis=1)[:5].expr.simplify()
    rhs = ones[:5, :].sum(axis=1).expr.simplify()
    assert lhs._name == rhs._name
338
+
339
+
340
def test_slice_through_reduction_3d():
    """Slicing a 3-D reduction maps output axes back to input axes."""
    cube = da.ones((20, 20, 20), chunks=(5, 5, 5))
    # Reducing axis 1 turns input axes (0, 2) into output axes (0, 1), so the
    # output slice [:3, :4] corresponds to the input slice [:3, :, :4].
    lhs = cube.sum(axis=1)[:3, :4].expr.simplify()
    rhs = cube[:3, :, :4].sum(axis=1).expr.simplify()
    assert lhs._name == rhs._name
350
+
351
+
352
def test_slice_through_reduction_multiple_axes():
    """Slicing a reduction over several axes targets the surviving axis."""
    cube = da.ones((20, 20, 20), chunks=(5, 5, 5))
    # Reducing axes 0 and 2 leaves input axis 1 as the only output axis.
    lhs = cube.sum(axis=(0, 2))[:5].expr.simplify()
    rhs = cube[:, :5, :].sum(axis=(0, 2)).expr.simplify()
    assert lhs._name == rhs._name
362
+
363
+
364
def test_slice_through_reduction_correctness():
    """Slice-through-reduction results match numpy."""
    source = np.arange(10000).reshape(100, 100)
    base = da.from_array(source, chunks=(10, 10))

    # (reduction axis, slice applied to the reduced result)
    cases = [(0, slice(None, 5)), (1, slice(None, 5)), (0, slice(10, 20))]
    for axis, window in cases:
        assert_eq(base.sum(axis=axis)[window], source.sum(axis=axis)[window])
373
+
374
+
375
def test_slice_through_reduction_integer_index():
    """An integer index after a reduction also shrinks the graph.

    The integer is turned into a length-1 slice, pushed through the
    reduction, and the single element is extracted with ``[0]`` afterwards.
    """
    source = np.arange(10000).reshape(100, 100)
    base = da.from_array(source, chunks=(10, 10))

    full_graph = base.sum(axis=0).optimize().__dask_graph__()

    indexed = base.sum(axis=0)[5]
    indexed_graph = indexed.optimize().__dask_graph__()

    assert len(indexed_graph) < len(full_graph)
    assert_eq(indexed, source.sum(axis=0)[5])
393
+
394
+
395
+ # =============================================================================
396
+ # Slice through creation expressions (ones, zeros, full, empty)
397
+ # =============================================================================
398
+
399
+
400
def test_slice_ones_returns_smaller_ones():
    """Slicing ones() simplifies to a new, smaller ones()."""
    from dask_array.creation import Ones

    simplified = da.ones((100, 100), chunks=(10, 10))[:15, :25].expr.simplify()

    # Expect a plain Ones node of the sliced shape rather than Slice(Ones).
    assert isinstance(simplified, Ones)
    assert simplified.shape == (15, 25)
411
+
412
+
413
def test_slice_zeros_returns_smaller_zeros():
    """Slicing zeros() simplifies to a new, smaller zeros()."""
    from dask_array.creation import Zeros

    simplified = da.zeros((100, 100), chunks=(10, 10))[:15, :25].expr.simplify()

    # Expect a plain Zeros node of the sliced shape.
    assert isinstance(simplified, Zeros)
    assert simplified.shape == (15, 25)
423
+
424
+
425
def test_slice_full_returns_smaller_full():
    """Slicing full() simplifies to a new, smaller full()."""
    from dask_array.creation import Full

    sliced = da.full((100, 100), 42, chunks=(10, 10))[:15, :25]
    simplified = sliced.expr.simplify()

    assert isinstance(simplified, Full)
    assert simplified.shape == (15, 25)
    # The fill value must survive the rewrite.
    assert_eq(sliced, np.full((15, 25), 42))
437
+
438
+
439
def test_slice_creation_correctness():
    """Sliced ones/zeros/full compute the expected values."""
    pairs = [
        (da.ones((100, 100), chunks=10), np.ones((15, 25))),
        (da.zeros((100, 100), chunks=10), np.zeros((15, 25))),
        (da.full((100, 100), 7.5, chunks=10), np.full((15, 25), 7.5)),
    ]
    for lazy, expected in pairs:
        assert_eq(lazy[:15, :25], expected)
444
+
445
+
446
def test_slice_creation_preserves_dtype():
    """A sliced ones() keeps the requested dtype."""
    wanted = np.dtype("int32")
    sliced = da.ones((100, 100), chunks=10, dtype="int32")[:15, :25]
    assert sliced.dtype == wanted
    assert_eq(sliced, np.ones((15, 25), dtype="int32"))
451
+
452
+
453
+ # =============================================================================
454
+ # Slice through Concatenate
455
+ # =============================================================================
456
+
457
+
458
def test_slice_through_concat_same_axis_first_array():
    """A slice inside the first concat input reduces to slicing that input."""
    first = da.ones((10, 5), chunks=5)
    second = da.ones((10, 5), chunks=5)
    # Rows :5 come entirely from ``first``.
    sliced = da.concatenate([first, second], axis=0)[:5]
    assert sliced.expr.simplify()._name == first[:5].expr.simplify()._name
466
+
467
+
468
def test_slice_through_concat_same_axis_spans_arrays():
    """A slice spanning several concat inputs keeps only the overlapping pieces."""
    parts = [da.ones((10, 5), chunks=5) for _ in range(3)]
    # Rows 5:15 are parts[0][5:10] followed by parts[1][0:5].
    sliced = da.concatenate(parts, axis=0)[5:15]
    expected = da.concatenate([parts[0][5:], parts[1][:5]], axis=0)
    assert sliced.expr.simplify()._name == expected.expr.simplify()._name
478
+
479
+
480
def test_slice_through_concat_different_axis():
    """A slice on a non-concat axis is pushed into every input."""
    a = da.ones((10, 20), chunks=5)
    b = da.ones((10, 20), chunks=5)
    # Concatenation is along axis 0; the slice restricts axis 1.
    cols = (slice(None), slice(None, 5))
    sliced = da.concatenate([a, b], axis=0)[cols]
    expected = da.concatenate([a[cols], b[cols]], axis=0)
    assert sliced.expr.simplify()._name == expected.expr.simplify()._name
488
+
489
+
490
def test_slice_through_concat_correctness():
    """Slices of a concatenation match numpy for several slice positions."""
    a = np.arange(20).reshape(4, 5)
    b = np.arange(20, 40).reshape(4, 5)
    lazy = da.concatenate(
        [da.from_array(a, chunks=2), da.from_array(b, chunks=2)], axis=0
    )
    eager = np.concatenate([a, b], axis=0)

    keys = [
        slice(None, 3),  # along the concat axis
        (slice(None), slice(None, 3)),  # along the other axis
        slice(2, 6),  # spanning both inputs
    ]
    for key in keys:
        assert_eq(lazy[key], eager[key])
508
+
509
+
510
def test_slice_through_concat_reduces_tasks():
    """Slicing a concatenation needs fewer tasks than the whole concatenation."""
    combined = da.concatenate(
        [da.ones((100, 100), chunks=10), da.ones((100, 100), chunks=10)], axis=0
    )
    full_tasks = len(combined.optimize().__dask_graph__())
    # The first 5 rows come from the first input only.
    head_tasks = len(combined[:5].optimize().__dask_graph__())
    assert head_tasks < full_tasks
521
+
522
+
523
+ # =============================================================================
524
+ # Slice through Stack
525
+ # =============================================================================
526
+
527
+
528
def test_slice_through_stack_selects_subset():
    """A slice along the stacked axis keeps only the selected inputs."""
    a, b, c = [da.ones((10, 5), chunks=5) for _ in range(3)]
    # The stack has shape (3, 10, 5); [:1] is equivalent to stack([a]).
    sliced = da.stack([a, b, c], axis=0)[:1]
    expected = da.stack([a], axis=0)
    assert sliced.expr.simplify()._name == expected.expr.simplify()._name
538
+
539
+
540
def test_slice_through_stack_other_axis():
    """A slice on non-stacked axes is pushed into every input."""
    inputs = [da.ones((10, 20), chunks=5) for _ in range(2)]
    # The stack has shape (2, 10, 20); [:, :5, :10] slices each input by [:5, :10].
    sliced = da.stack(inputs, axis=0)[:, :5, :10]
    expected = da.stack([arr[:5, :10] for arr in inputs], axis=0)
    assert sliced.expr.simplify()._name == expected.expr.simplify()._name
549
+
550
+
551
def test_slice_through_stack_mixed():
    """A slice on both the stacked axis and another axis."""
    a, b, c = [da.ones((10, 20), chunks=5) for _ in range(3)]
    # The stack has shape (3, 10, 20); [:2, :5] keeps a and b, each sliced by [:5].
    sliced = da.stack([a, b, c], axis=0)[:2, :5]
    expected = da.stack([a[:5], b[:5]], axis=0)
    assert sliced.expr.simplify()._name == expected.expr.simplify()._name
561
+
562
+
563
def test_slice_through_stack_correctness():
    """Slices of a stack match numpy."""
    sources = [np.arange(lo, lo + 20).reshape(4, 5) for lo in (0, 20, 40)]
    lazy = da.stack([da.from_array(s, chunks=2) for s in sources], axis=0)
    eager = np.stack(sources, axis=0)

    keys = [
        slice(None, 2),  # along the stacked axis
        (slice(None), slice(None, 2), slice(None, 3)),  # along the other axes
    ]
    for key in keys:
        assert_eq(lazy[key], eager[key])
579
+
580
+
581
def test_slice_through_stack_reduces_tasks():
    """Slicing a stack needs fewer tasks than the whole stack."""
    stacked = da.stack([da.ones((100, 100), chunks=10) for _ in range(3)], axis=0)
    full_tasks = len(stacked.optimize().__dask_graph__())
    # [:1] only needs the first input.
    first_tasks = len(stacked[:1].optimize().__dask_graph__())
    assert first_tasks < full_tasks
593
+
594
+
595
+ # =============================================================================
596
+ # Slice through BroadcastTo
597
+ # =============================================================================
598
+
599
+
600
def test_slice_through_broadcast_to_new_dim():
    """A slice on a dimension created by broadcasting shrinks the target shape."""
    vec = da.ones((10,), chunks=5)
    # (10,) -> (20, 10): axis 0 is new, so [:5, :] is just a (5, 10) broadcast.
    lhs = da.broadcast_to(vec, (20, 10))[:5, :]
    rhs = da.broadcast_to(vec, (5, 10))
    assert lhs.expr.simplify()._name == rhs.expr.simplify()._name
608
+
609
+
610
def test_slice_through_broadcast_to_existing_dim():
    """A slice on a dimension that exists in the input moves into the input."""
    vec = da.ones((10,), chunks=5)
    # (10,) -> (20, 10): axis 1 comes from the input, so [:, :5] slices vec.
    lhs = da.broadcast_to(vec, (20, 10))[:, :5]
    rhs = da.broadcast_to(vec[:5], (20, 5))
    assert lhs.expr.simplify()._name == rhs.expr.simplify()._name
618
+
619
+
620
def test_slice_through_broadcast_to_both_dims():
    """A slice on both a new and an existing dimension."""
    vec = da.ones((10,), chunks=5)
    lhs = da.broadcast_to(vec, (20, 10))[:5, :3]
    rhs = da.broadcast_to(vec[:3], (5, 3))
    assert lhs.expr.simplify()._name == rhs.expr.simplify()._name
627
+
628
+
629
def test_slice_through_broadcast_to_broadcasted_dim():
    """A slice on a dimension that had size 1 in the input."""
    row = da.ones((1, 10), chunks=(1, 5))
    # (1, 10) -> (20, 10). Axis 0 had size 1, so its slice cannot be pushed
    # into the input; the axis-1 slice can.
    lhs = da.broadcast_to(row, (20, 10))[:5, :3]
    rhs = da.broadcast_to(row[:, :3], (5, 3))
    assert lhs.expr.simplify()._name == rhs.expr.simplify()._name
638
+
639
+
640
def test_slice_through_broadcast_to_correctness():
    """Broadcasting to 2-D and then slicing matches numpy."""
    vec = np.arange(10)
    target = (20, 10)
    key = (slice(None, 5), slice(None, 3))
    result = da.broadcast_to(da.from_array(vec, chunks=5), target)[key]
    assert_eq(result, np.broadcast_to(vec, target)[key])
649
+
650
+
651
def test_slice_through_broadcast_to_reduces_tasks():
    """Slicing a broadcast array needs fewer tasks than the whole broadcast."""
    broadcasted = da.broadcast_to(da.ones((100,), chunks=10), (100, 100))
    full_tasks = len(broadcasted.optimize().__dask_graph__())
    # A 5x5 corner of the output.
    corner_tasks = len(broadcasted[:5, :5].optimize().__dask_graph__())
    assert corner_tasks < full_tasks
661
+
662
+
663
+ # --- Shuffle (take) through Elemwise Tests ---
664
+
665
+
666
def test_shuffle_pushes_through_elemwise_add():
    """``(x + y)[idx]`` simplifies to ``x[idx] + y[idx]`` for an index list."""
    left = da.arange(20, chunks=5)
    right = da.arange(20, chunks=5)
    idx = [1, 3, 5, 7, 9]

    shuffled = (left + right)[idx]
    expected = left[idx] + right[idx]

    # The simplified expression trees should be identical.
    assert shuffled.expr.simplify()._name == expected.expr.simplify()._name

    # The values should match numpy.
    ref = np.arange(20)
    assert_eq(shuffled, (ref + ref)[idx])
682
+
683
+
684
def test_shuffle_pushes_through_elemwise_mul():
    """``(x * y)[idx]`` simplifies to ``x[idx] * y[idx]``."""
    idx = [2, 4, 6, 8]
    lhs = da.arange(30, chunks=10)
    rhs = da.arange(30, chunks=10)

    taken = (lhs * rhs)[idx]
    pushed = lhs[idx] * rhs[idx]

    assert taken.expr.simplify()._name == pushed.expr.simplify()._name
    assert_eq(taken, pushed)
695
+
696
+
697
def test_shuffle_pushes_through_elemwise_2d():
    """A row-wise take on a 2D elementwise sum moves onto each operand."""
    shape, blocks = (10, 8), (5, 4)
    rows = [0, 2, 4, 6]

    lhs = da.ones(shape, chunks=blocks)
    rhs = da.ones(shape, chunks=blocks)

    taken = (lhs + rhs)[rows, :]
    pushed = lhs[rows, :] + rhs[rows, :]

    assert taken.expr.simplify()._name == pushed.expr.simplify()._name
    assert_eq(taken, pushed)
708
+
709
+
710
def test_shuffle_pushes_through_elemwise_scalar():
    """A take on ``x + scalar`` moves onto ``x``; the scalar is untouched."""
    idx = [1, 5, 9, 13]
    arr = da.arange(20, chunks=5)

    taken = (arr + 1)[idx]
    pushed = arr[idx] + 1

    assert taken.expr.simplify()._name == pushed.expr.simplify()._name
    assert_eq(taken, pushed)
720
+
721
+
722
def test_shuffle_pushes_through_unary_elemwise():
    """A take on a unary elementwise op (negation) moves beneath it."""
    idx = [2, 4, 6, 8]
    arr = da.arange(20, chunks=5)

    taken = (-arr)[idx]
    # Subscription binds tighter than unary minus: this is -(arr[idx]).
    pushed = -arr[idx]

    assert taken.expr.simplify()._name == pushed.expr.simplify()._name
    assert_eq(taken, pushed)
732
+
733
+
734
def test_shuffle_through_elemwise_reduces_work():
    """Taking a subset through an elementwise op must not add work.

    ``(x + y)[indices]`` can be rewritten as ``x[indices] + y[indices]`` so
    only the needed elements are computed. The size check below is
    deliberately non-strict: it only guarantees that the optimized graph is
    no larger than the unoptimized one. The values are verified as well.
    """
    x = da.ones((100,), chunks=10)
    y = da.ones((100,), chunks=10)

    # Take only 10 of 100 elements (one per chunk).
    indices = list(range(0, 100, 10))  # [0, 10, 20, ..., 90]
    result = (x + y)[indices]

    unopt_tasks = len(result.__dask_graph__())
    opt_tasks = len(result.optimize().__dask_graph__())

    # Optimization must never increase the task count.
    # NOTE(review): confirm whether a strict `<` holds before tightening this.
    assert opt_tasks <= unopt_tasks

    # The rewrite must not change the computed values.
    assert_eq(result, (np.ones(100) + np.ones(100))[indices])
749
+
750
+
751
def test_shuffle_through_elemwise_with_broadcast_2d():
    """Take through an elementwise op whose other operand has a size-1 dim.

    ``(a * y2d)[[5]]`` should simplify to ``a[[5]] * y2d``. Only the operand
    that is not broadcast along axis 0 gets shuffled.
    """
    rows = [5]
    a = da.from_array(np.arange(200).reshape(10, 20), chunks=(4, 5))
    y2d = da.from_array(np.arange(20).reshape(1, 20), chunks=(1, 20))

    taken = (a * y2d)[rows]
    pushed = a[rows] * y2d

    assert taken.expr.simplify()._name == pushed.expr.simplify()._name
    assert_eq(taken, pushed)
764
+
765
+
766
def test_shuffle_through_elemwise_with_broadcast_1d():
    """Take through an elementwise op with a lower-rank (1D) operand.

    ``(a * y1d)[[5]]`` should simplify to ``a[[5]] * y1d``. Only the 2D
    operand gets shuffled.
    """
    rows = [5]
    a = da.from_array(np.arange(200).reshape(10, 20), chunks=(4, 5))
    y1d = da.from_array(np.arange(20), chunks=20)

    taken = (a * y1d)[rows]
    pushed = a[rows] * y1d

    assert taken.expr.simplify()._name == pushed.expr.simplify()._name
    assert_eq(taken, pushed)
779
+
780
+
781
+ # --- Shuffle through Transpose Tests ---
782
+
783
+
784
def test_shuffle_pushes_through_transpose():
    """``x.T[idx, :]`` simplifies to ``x[:, idx].T``."""
    mat = da.arange(20, chunks=5).reshape((4, 5))
    idx = [1, 3]

    # Rows 1 and 3 of the (5, 4) transpose ...
    taken = mat.T[idx, :]
    # ... are columns 1 and 3 of the (4, 5) original, transposed.
    pushed = mat[:, idx].T

    assert taken.expr.simplify()._name == pushed.expr.simplify()._name
    assert_eq(taken, pushed)
794
+
795
+
796
def test_shuffle_pushes_through_transpose_axis1():
    """``x.T[:, idx]`` simplifies to ``x[idx, :].T``."""
    mat = da.arange(20, chunks=5).reshape((4, 5))
    idx = [0, 2]

    # Columns 0 and 2 of the (5, 4) transpose ...
    taken = mat.T[:, idx]
    # ... are rows 0 and 2 of the (4, 5) original, transposed.
    pushed = mat[idx, :].T

    assert taken.expr.simplify()._name == pushed.expr.simplify()._name
    assert_eq(taken, pushed)
806
+
807
+
808
def test_shuffle_pushes_through_transpose_3d():
    """A take after a 3D axis permutation maps back to the source axis."""
    perm = (2, 1, 0)
    idx = [0, 2]
    cube = da.ones((2, 3, 4), chunks=2)

    # (2, 3, 4) -> (4, 3, 2), then take along the new axis 0 ...
    taken = cube.transpose(perm)[idx, :, :]
    # ... which is axis 2 of the original, permuted afterwards.
    pushed = cube[:, :, idx].transpose(perm)

    assert taken.expr.simplify()._name == pushed.expr.simplify()._name
    assert_eq(taken, pushed)
820
+
821
+
822
+ # --- Shuffle through Concatenate/Stack Tests ---
823
+
824
+
825
def test_shuffle_pushes_through_concatenate():
    """A take on a non-concatenation axis is applied to every input."""
    left = da.arange(20, chunks=5).reshape((4, 5))
    right = da.arange(20, 40, chunks=5).reshape((4, 5))
    rows = [0, 2]

    # (4, 10) concatenation, keep rows 0 and 2.
    taken = da.concatenate([left, right], axis=1)[rows, :]
    pushed = da.concatenate([part[rows, :] for part in (left, right)], axis=1)

    assert taken.expr.simplify()._name == pushed.expr.simplify()._name
    assert_eq(taken, pushed)
838
+
839
+
840
def test_shuffle_pushes_through_stack():
    """A take on a non-stack axis is applied to every stacked input."""
    first = da.arange(12, chunks=4).reshape((3, 4))
    second = da.arange(12, 24, chunks=4).reshape((3, 4))
    rows = [0, 2]

    # Stacking gives (2, 3, 4); the take runs along axis 1, i.e. the
    # inputs' axis 0.
    taken = da.stack([first, second], axis=0)[:, rows, :]
    pushed = da.stack([part[rows, :] for part in (first, second)], axis=0)

    assert taken.expr.simplify()._name == pushed.expr.simplify()._name
    assert_eq(taken, pushed)
853
+
854
+
855
+ # --- Shuffle through Blockwise Tests ---
856
+
857
+
858
def test_shuffle_pushes_through_blockwise():
    """A take moves beneath a blockwise op that has no ``adjust_chunks``.

    ``map_blocks`` without ``chunks=`` builds a plain ``Blockwise``, so after
    simplification the take should sit below it.
    """
    from dask_array._blockwise import Blockwise

    rows = [0, 2]
    arr = da.ones((4, 6), chunks=(2, 3))

    taken = arr.map_blocks(lambda blk: blk * 2)[rows, :]
    # Reference: take first, then apply the same per-block function.
    reference = arr[rows, :].map_blocks(lambda blk: blk * 2)

    # The optimization happened if Blockwise ends up as the root node.
    assert isinstance(taken.expr.simplify(), Blockwise)

    assert_eq(taken, reference)
878
+
879
+
880
def test_shuffle_does_not_push_through_blockwise_adjust_chunks():
    """Shuffle does NOT push through blockwise when adjust_chunks affects shuffle axis."""
    from dask_array._shuffle import Shuffle

    # Distinct values, so the correctness check can tell rows apart
    # (an all-ones input would hide a wrong row selection).
    x_np = np.arange(48).reshape(8, 6)
    x = da.from_array(x_np, chunks=(2, 3))

    # Passing ``chunks`` declares every output block as (1, 3)
    # (adjust_chunks). Output: shape (4, 6), chunks (1, 3).
    # The mapped function must really return (1, 3) blocks to honor that
    # declaration, so it keeps only the first row of each (2, 3) input block.
    mapped = x.map_blocks(lambda b: b[:1] * 2, chunks=(1, 3))

    indices = [0, 2]  # Taking along axis 0 - NOT all indices
    result = mapped[indices, :]

    # Shuffle should stay at top (not push through) because axis 0 has adjust_chunks
    opt = result.expr.simplify()
    assert isinstance(opt, Shuffle)

    # Mapped rows 0 and 2 are the first rows of input blocks 0 and 2,
    # i.e. input rows 0 and 4, doubled.
    assert_eq(result, x_np[[0, 4], :] * 2)
899
+
900
+
901
+ # --- ExpandDims (None indexing) Pushdown Tests ---
902
+
903
+
904
def test_none_slice_pushes_through_elemwise():
    """With a ``None``+slice index, the slice moves below the elemwise op.

    ``(x + y)[None, :5, :]`` should simplify to
    ``(x[:5, :] + y[:5, :])[None, :, :]``: the slice is applied to the
    operands, while the new axis stays on top.
    """
    lhs = da.ones((10, 10), chunks=5)
    rhs = da.ones((10, 10), chunks=5)

    indexed = (lhs + rhs)[None, :5, :]
    reference = (lhs[:5, :] + rhs[:5, :])[None, :, :]

    assert indexed.expr.simplify()._name == reference.expr.simplify()._name
    assert_eq(indexed, reference)
918
+
919
+
920
def test_none_slice_multiple_nones():
    """Several ``None`` entries mixed with slices still push the slices down.

    ``(x + y)[None, :2, None, :3]`` should simplify to
    ``(x[:2, :3] + y[:2, :3])[None, :, None, :]``.
    """
    lhs = da.arange(20, chunks=5).reshape((4, 5))
    rhs = da.ones((4, 5), chunks=(4, 5))

    indexed = (lhs + rhs)[None, :2, None, :3]
    reference = (lhs[:2, :3] + rhs[:2, :3])[None, :, None, :]

    assert indexed.expr.simplify()._name == reference.expr.simplify()._name
    assert_eq(indexed, reference)
934
+
935
+
936
def test_none_slice_no_slicing():
    """Indexing with only ``None`` and full slices simplifies to ExpandDims."""
    from dask_array.manipulation._expand import ExpandDims

    shape = (10, 10)
    lhs = da.ones(shape, chunks=5)
    rhs = da.ones(shape, chunks=5)

    # Pure dimension expansion: no axis is actually sliced.
    expanded = (lhs + rhs)[None, :, :]

    # ExpandDims (rather than Reshape) is the expected node, for fusion compat.
    assert isinstance(expanded.expr.simplify(), ExpandDims)

    ref = np.ones(shape) + np.ones(shape)
    assert_eq(expanded, ref[None, :, :])
954
+
955
+
956
+ def test_none_slice_through_transpose():
957
+ """Slice with None pushes through transpose."""
958
+ x = da.arange(20, chunks=5).reshape((4, 5))
959
+
960
+ # x.T[None, :3, :2] -> x[:2, :3].T[None, :, :]
961
+ result = x.T[None, :3, :2]
962
+ expected = x[:2, :3].T[None, :, :]
963
+
964
+ # Structure should match after optimization
965
+ assert result.expr.simplify()._name == expected.expr.simplify()._name
966
+
967
+ # Verify correctness
968
+ assert_eq(result, expected)