dask-array 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (144) hide show
  1. dask_array/__init__.py +228 -0
  2. dask_array/_backends.py +76 -0
  3. dask_array/_backends_array.py +99 -0
  4. dask_array/_blockwise.py +1410 -0
  5. dask_array/_broadcast.py +272 -0
  6. dask_array/_chunk.py +445 -0
  7. dask_array/_chunk_types.py +54 -0
  8. dask_array/_collection.py +1644 -0
  9. dask_array/_concatenate.py +331 -0
  10. dask_array/_core_utils.py +1365 -0
  11. dask_array/_dispatch.py +141 -0
  12. dask_array/_einsum.py +277 -0
  13. dask_array/_expr.py +544 -0
  14. dask_array/_expr_flow.py +586 -0
  15. dask_array/_gufunc.py +805 -0
  16. dask_array/_histogram.py +617 -0
  17. dask_array/_map_blocks.py +652 -0
  18. dask_array/_new_collection.py +10 -0
  19. dask_array/_numpy_compat.py +135 -0
  20. dask_array/_overlap.py +1159 -0
  21. dask_array/_rechunk.py +1050 -0
  22. dask_array/_reshape.py +710 -0
  23. dask_array/_routines.py +102 -0
  24. dask_array/_shuffle.py +448 -0
  25. dask_array/_stack.py +264 -0
  26. dask_array/_svg.py +291 -0
  27. dask_array/_templates.py +29 -0
  28. dask_array/_test_utils.py +257 -0
  29. dask_array/_ufunc.py +385 -0
  30. dask_array/_utils.py +349 -0
  31. dask_array/_visualize.py +223 -0
  32. dask_array/_xarray.py +337 -0
  33. dask_array/core/__init__.py +34 -0
  34. dask_array/core/_blockwise_funcs.py +312 -0
  35. dask_array/core/_conversion.py +422 -0
  36. dask_array/core/_from_graph.py +97 -0
  37. dask_array/creation/__init__.py +71 -0
  38. dask_array/creation/_arange.py +121 -0
  39. dask_array/creation/_diag.py +116 -0
  40. dask_array/creation/_diagonal.py +241 -0
  41. dask_array/creation/_eye.py +103 -0
  42. dask_array/creation/_linspace.py +102 -0
  43. dask_array/creation/_mesh.py +134 -0
  44. dask_array/creation/_ones_zeros.py +454 -0
  45. dask_array/creation/_pad.py +270 -0
  46. dask_array/creation/_repeat.py +55 -0
  47. dask_array/creation/_tile.py +36 -0
  48. dask_array/creation/_tri.py +28 -0
  49. dask_array/creation/_utils.py +296 -0
  50. dask_array/fft.py +320 -0
  51. dask_array/io/__init__.py +39 -0
  52. dask_array/io/_base.py +10 -0
  53. dask_array/io/_from_array.py +257 -0
  54. dask_array/io/_from_delayed.py +95 -0
  55. dask_array/io/_from_graph.py +54 -0
  56. dask_array/io/_from_npy_stack.py +67 -0
  57. dask_array/io/_store.py +336 -0
  58. dask_array/io/_tiledb.py +159 -0
  59. dask_array/io/_to_npy_stack.py +65 -0
  60. dask_array/io/_zarr.py +449 -0
  61. dask_array/linalg/__init__.py +39 -0
  62. dask_array/linalg/_cholesky.py +234 -0
  63. dask_array/linalg/_lu.py +300 -0
  64. dask_array/linalg/_norm.py +94 -0
  65. dask_array/linalg/_qr.py +601 -0
  66. dask_array/linalg/_solve.py +349 -0
  67. dask_array/linalg/_svd.py +394 -0
  68. dask_array/linalg/_tensordot.py +334 -0
  69. dask_array/linalg/_utils.py +74 -0
  70. dask_array/manipulation/__init__.py +45 -0
  71. dask_array/manipulation/_expand.py +321 -0
  72. dask_array/manipulation/_flip.py +92 -0
  73. dask_array/manipulation/_roll.py +78 -0
  74. dask_array/manipulation/_transpose.py +309 -0
  75. dask_array/random/__init__.py +125 -0
  76. dask_array/random/_choice.py +181 -0
  77. dask_array/random/_expr.py +256 -0
  78. dask_array/random/_generator.py +441 -0
  79. dask_array/random/_random_state.py +259 -0
  80. dask_array/random/_utils.py +84 -0
  81. dask_array/reductions/__init__.py +84 -0
  82. dask_array/reductions/_arg_reduction.py +130 -0
  83. dask_array/reductions/_common.py +1082 -0
  84. dask_array/reductions/_cumulative.py +522 -0
  85. dask_array/reductions/_percentile.py +261 -0
  86. dask_array/reductions/_reduction.py +725 -0
  87. dask_array/reductions/_trace.py +56 -0
  88. dask_array/routines/__init__.py +133 -0
  89. dask_array/routines/_apply.py +84 -0
  90. dask_array/routines/_bincount.py +112 -0
  91. dask_array/routines/_broadcast.py +111 -0
  92. dask_array/routines/_coarsen.py +115 -0
  93. dask_array/routines/_diff.py +79 -0
  94. dask_array/routines/_gradient.py +158 -0
  95. dask_array/routines/_indexing.py +65 -0
  96. dask_array/routines/_insert_delete.py +132 -0
  97. dask_array/routines/_misc.py +122 -0
  98. dask_array/routines/_nonzero.py +72 -0
  99. dask_array/routines/_search.py +123 -0
  100. dask_array/routines/_select.py +113 -0
  101. dask_array/routines/_statistics.py +171 -0
  102. dask_array/routines/_topk.py +82 -0
  103. dask_array/routines/_triangular.py +74 -0
  104. dask_array/routines/_unique.py +232 -0
  105. dask_array/routines/_where.py +62 -0
  106. dask_array/slicing/__init__.py +67 -0
  107. dask_array/slicing/_basic.py +550 -0
  108. dask_array/slicing/_blocks.py +138 -0
  109. dask_array/slicing/_bool_index.py +145 -0
  110. dask_array/slicing/_setitem.py +329 -0
  111. dask_array/slicing/_squeeze.py +101 -0
  112. dask_array/slicing/_utils.py +1133 -0
  113. dask_array/slicing/_vindex.py +282 -0
  114. dask_array/stacking/__init__.py +15 -0
  115. dask_array/stacking/_block.py +83 -0
  116. dask_array/stacking/_simple.py +58 -0
  117. dask_array/templates/array.html.j2 +48 -0
  118. dask_array/tests/__init__.py +0 -0
  119. dask_array/tests/conftest.py +22 -0
  120. dask_array/tests/test_api.py +40 -0
  121. dask_array/tests/test_binary_op_chunks.py +107 -0
  122. dask_array/tests/test_coarse_slice_through_blockwise.py +362 -0
  123. dask_array/tests/test_collection.py +799 -0
  124. dask_array/tests/test_creation.py +1102 -0
  125. dask_array/tests/test_expr_flow.py +143 -0
  126. dask_array/tests/test_linalg.py +1130 -0
  127. dask_array/tests/test_map_blocks_multi_output.py +104 -0
  128. dask_array/tests/test_rechunk_pushdown.py +214 -0
  129. dask_array/tests/test_reductions.py +1091 -0
  130. dask_array/tests/test_routines.py +2853 -0
  131. dask_array/tests/test_shuffle_chunks.py +67 -0
  132. dask_array/tests/test_slice_pushdown.py +968 -0
  133. dask_array/tests/test_slice_through_blockwise.py +678 -0
  134. dask_array/tests/test_slice_through_overlap.py +366 -0
  135. dask_array/tests/test_slice_through_reshape.py +272 -0
  136. dask_array/tests/test_slicing.py +839 -0
  137. dask_array/tests/test_transpose_slice_pushdown.py +208 -0
  138. dask_array/tests/test_visualize.py +94 -0
  139. dask_array/tests/test_xarray.py +193 -0
  140. dask_array-0.1.0.dist-info/METADATA +48 -0
  141. dask_array-0.1.0.dist-info/RECORD +144 -0
  142. dask_array-0.1.0.dist-info/WHEEL +4 -0
  143. dask_array-0.1.0.dist-info/entry_points.txt +2 -0
  144. dask_array-0.1.0.dist-info/licenses/LICENSE +29 -0
@@ -0,0 +1,678 @@
1
+ """Tests for slice pushdown through Blockwise expressions.
2
+
3
+ These tests explore when slice pushdown is safe and correct for different
4
+ Blockwise configurations.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import numpy as np
10
+ import pytest
11
+
12
+ import dask_array as da
13
+ from dask_array._test_utils import assert_eq
14
+
15
+ # =============================================================================
16
+ # Case 1: Standard Blockwise (reduction chunk step)
17
+ # - out_ind matches input indices
18
+ # - No new_axes, no adjust_chunks
19
+ # - Slice should push through directly
20
+ # =============================================================================
21
+
22
+
23
def test_slice_through_reduction_blockwise():
    """A slice taken after ``sum(axis=0)`` simplifies like slicing first.

    ``x.sum(axis=0)[:5]`` is expected to simplify to the same expression as
    ``x[:, :5].sum(axis=0)``.
    """
    x = da.ones((100, 100), chunks=(10, 10))

    slice_after_sum = x.sum(axis=0)[:5]
    slice_before_sum = x[:, :5].sum(axis=0)

    simplified_after = slice_after_sum.expr.simplify()
    simplified_before = slice_before_sum.expr.simplify()
    assert simplified_after._name == simplified_before._name
32
+
33
+
34
def test_slice_through_reduction_blockwise_axis1():
    """Row-slicing ``sum(axis=1)`` simplifies like slicing the rows first."""
    x = da.ones((100, 100), chunks=(10, 10))

    # Both candidates are built before either is simplified.
    candidates = (x.sum(axis=1)[:5], x[:5, :].sum(axis=1))
    names = [arr.expr.simplify()._name for arr in candidates]

    assert names[0] == names[1]
42
+
43
+
44
+ # =============================================================================
45
+ # Case 2: Elemwise operations
46
+ # - Already handled by _pushdown_through_elemwise
47
+ # - Included here for completeness
48
+ # =============================================================================
49
+
50
+
51
def test_slice_through_elemwise_add():
    """Slicing a sum of two arrays equals adding the sliced operands."""
    shape, chunks = (100, 100), (10, 10)
    lhs = da.ones(shape, chunks=chunks)
    rhs = da.ones(shape, chunks=chunks)
    window = (slice(None, 5), slice(None, 10))

    pushed = (lhs + rhs)[window]
    manual = lhs[window] + rhs[window]

    assert pushed.expr.simplify()._name == manual.expr.simplify()._name
60
+
61
+
62
def test_slice_through_elemwise_unary():
    """``sin(x)[window]`` simplifies to ``sin(x[window])``."""
    x = da.ones((100, 100), chunks=(10, 10))
    window = (slice(None, 5), slice(None, 10))

    sine_then_slice = da.sin(x)[window]
    slice_then_sine = da.sin(x[window])

    lhs_name = sine_then_slice.expr.simplify()._name
    rhs_name = slice_then_sine.expr.simplify()._name
    assert lhs_name == rhs_name
70
+
71
+
72
+ # =============================================================================
73
+ # Case 3: Broadcasting
74
+ # - Smaller input has fewer indices
75
+ # - Need to only slice dimensions that exist in the smaller input
76
+ # =============================================================================
77
+
78
+
79
def test_slice_through_broadcast_row():
    """Slicing ``matrix + row`` slices the row operand only along its own axis."""
    matrix_np = np.arange(100).reshape(10, 10)
    row_np = np.arange(10)

    matrix = da.from_array(matrix_np, chunks=(5, 5))
    row = da.from_array(row_np, chunks=5)

    # The output slice [:3, :4] maps to [:4] on the 1-D row. The manual form
    # is simplified too, since slices also push into from_array regions.
    got = (matrix + row)[:3, :4]
    want = matrix[:3, :4] + row[:4]

    assert got.expr.simplify()._name == want.expr.simplify()._name
    assert_eq(got, matrix_np[:3, :4] + row_np[:4])
94
+
95
+
96
def test_slice_through_broadcast_column():
    """Slicing ``matrix + column`` keeps the column's size-1 axis whole."""
    matrix_np = np.arange(100).reshape(10, 10)
    column_np = np.arange(10).reshape(10, 1)

    matrix = da.from_array(matrix_np, chunks=(5, 5))
    column = da.from_array(column_np, chunks=(5, 1))

    # Only the row part of [:3, :4] applies to the (10, 1) column operand.
    got = (matrix + column)[:3, :4]
    want = matrix[:3, :4] + column[:3, :]

    assert got.expr.simplify()._name == want.expr.simplify()._name
    assert_eq(got, matrix_np[:3, :4] + column_np[:3, :])
110
+
111
+
112
def test_slice_through_broadcast_scalar():
    """Slicing ``x + 5`` equals adding 5 to the sliced ``x``."""
    data = np.arange(100).reshape(10, 10)
    x = da.from_array(data, chunks=(5, 5))
    offset = 5

    got = (x + offset)[:3, :4]
    want = x[:3, :4] + offset

    assert got.expr.simplify()._name == want.expr.simplify()._name
    assert_eq(got, data[:3, :4] + offset)
124
+
125
+
126
def test_slice_through_broadcast_size_one_dims():
    """Slice an Elemwise whose inputs broadcast along different size-1 axes.

    ``a`` has shape (1, 20, 1) and ``b`` has shape (1, 1, 30), so ``a + b``
    has shape (1, 20, 30). When ``[:, 5:10, 10:20]`` is pushed down, each
    input's size-1 axes must stay whole (``a[:, 5:10, :] + b[:, :, 10:20]``)
    rather than receive the output slice.
    """
    a_np = np.arange(20).reshape(1, 20, 1)
    b_np = np.arange(30).reshape(1, 1, 30)
    a = da.from_array(a_np, chunks=(1, 10, 1))
    b = da.from_array(b_np, chunks=(1, 1, 15))

    combined = a + b
    assert combined.shape == (1, 20, 30)

    window = (slice(None), slice(5, 10), slice(10, 20))
    sliced = combined[window]
    assert sliced.shape == (1, 5, 10)

    # Regression: simplify() used to fail on this expression.
    assert sliced.expr.simplify() is not None

    assert_eq(sliced, (a_np + b_np)[window])
162
+
163
+
164
def test_slice_through_where_with_broadcast():
    """Slice ``where`` whose condition is built from broadcast size-1 inputs.

    Regression test from the xarray integration. Simplifying a slice of
    ``Where`` with broadcast inputs used to fail because size-1 axes were
    handled incorrectly.
    """
    masks = [
        da.ones((10, 1, 1), dtype=bool, chunks=(5, 1, 1)),
        da.ones((1, 20, 1), dtype=bool, chunks=(1, 10, 1)),
        da.ones((1, 1, 30), dtype=bool, chunks=(1, 1, 15)),
    ]
    cond = masks[0] & masks[1] & masks[2]
    values = da.ones((10, 20, 30), chunks=(5, 10, 15))

    sliced = da.where(cond, values, np.nan)[:, 5:15, 10:25]

    sliced.expr.simplify()  # must not raise
    assert_eq(sliced, np.ones((10, 10, 15)))
183
+
184
+
185
def test_slice_through_shuffle_non_shuffle_axis():
    """A slice on axes other than the shuffled one moves below the Shuffle."""
    data = np.arange(100 * 50 * 60).reshape(100, 50, 60)
    # One-element chunks along the shuffled axis 0.
    x = da.from_array(data, chunks=(1, 25, 30))

    # Fancy indexing builds a Shuffle. The order is deliberately not the
    # identity (0..49, then 99 down to 50) so it is not simplified away.
    order = [*range(50), *range(99, 49, -1)]
    window = (slice(None), slice(10, 20), slice(30, 40))

    got = x[order, :, :][window]
    # Pushed-down form: slice the input first, then shuffle.
    want = x[window][order, :, :]

    assert got.expr.simplify()._name == want.expr.simplify()._name
    assert_eq(got, data[order, :, :][window])
201
+
202
+
203
def test_slice_through_shuffle_on_shuffle_axis():
    """A slice along the shuffled axis pushes through a contiguous indexer.

    This mirrors xarray's unstack pattern, where the shuffle indexer maps
    contiguous ranges (identity-like, possibly padded). Slicing the output
    should become slicing the input plus an indexer re-based to zero.
    """
    from dask_array._new_collection import new_collection
    from dask_array._shuffle import _shuffle

    def build_shuffle(expr, n_groups):
        # One single-element group per position: an identity indexer, as
        # xarray produces when restructuring a time dimension.
        groups = [[pos] for pos in range(n_groups)]
        return new_collection(_shuffle(expr, groups, axis=0, name="shuffle"))

    data = np.arange(100 * 50).reshape(100, 50)
    x = da.from_array(data, chunks=(1, 25))

    got = build_shuffle(x.expr, 100)[20:40, :]
    # Input restricted to rows 20..39; the indexer re-based to start at 0.
    want = build_shuffle(x[20:40, :].expr, 20)

    assert got.expr.simplify()._name == want.expr.simplify()._name
    assert_eq(got, data[20:40, :])
227
+
228
+
229
def test_slice_through_grouped_shuffle_on_shuffle_axis():
    """Slicing a permuted array along the permuted axis gives correct values.

    The permutation mixes positions from both input chunks. Values are
    checked on the default path and on the ``optimize(fuse=False)``
    expression.
    """
    source = np.arange(8)
    permutation = np.array([6, 5, 2, 4, 1, 3, 0, 7])
    want = source[permutation][1:4]

    x = da.from_array(source, chunks=4)
    got = x[permutation][1:4]

    assert_eq(got, want)
    assert_eq(da.Array(got.expr.optimize(fuse=False)), want)
238
+
239
+
240
+ # =============================================================================
241
+ # Case 4: new_axes - Blockwise adds dimensions
242
+ # - Slice on a new axis doesn't correspond to input
243
+ # - Slices that restrict a new axis should NOT push through (or need special handling); slices on the other axes still can
244
+ # =============================================================================
245
+
246
+
247
def test_slice_new_axis_not_pushed():
    """Slice the original axes of a map_blocks result that adds a new axis.

    ``y`` gains a trailing size-1 axis (``new_axis=2``). The slice below
    restricts only the two original axes and takes the new axis whole.
    Whether or not the slice is pushed down, the computed values must match
    NumPy.
    """
    arr = np.arange(100).reshape(10, 10)
    x = da.from_array(arr, chunks=(5, 5))

    # map_blocks that appends a new trailing axis of length 1
    y = da.map_blocks(lambda b: b[..., np.newaxis], x, new_axis=2, dtype=arr.dtype)

    # Slice the original axes only; the new axis (axis 2) is kept whole.
    result = y[:3, :4, :]
    expected = arr[:3, :4, np.newaxis]

    assert_eq(result, expected)
260
+
261
+
262
def test_slice_symbolic_new_axis_not_pushed():
    """Slice a new axis declared through ``blockwise(new_axes=...)``.

    ``y`` has index ``"az"``, where ``"z"`` (length 5) does not exist on the
    input. The slice ``[:, :2]`` restricts that new axis. Values must be
    correct both before and after ``optimize(fuse=False)``.
    """
    source = np.arange(6)
    x = da.from_array(source, chunks=3)

    def widen(block):
        return np.broadcast_to(block[:, None], (block.shape[0], 5))

    y = da.blockwise(widen, "az", x, "a", new_axes={"z": 5}, dtype=x.dtype)
    result = y[:, :2]
    expected = np.broadcast_to(source[:, None], (6, 5))[:, :2]

    assert_eq(result, expected)
    assert_eq(da.Array(result.expr.optimize(fuse=False)), expected)
279
+
280
+
281
def test_slice_only_new_axis():
    """Slice only along a length-3 axis added by map_blocks."""
    base = np.arange(100).reshape(10, 10)
    x = da.from_array(base, chunks=(5, 5))

    def triplicate(block):
        return np.repeat(block[..., np.newaxis], 3, axis=2)

    y = da.map_blocks(triplicate, x, new_axis=2, chunks=(5, 5, 3), dtype=base.dtype)

    # The input has no axis 2, so this slice has no direct counterpart on it.
    assert_eq(y[:, :, :2], triplicate(base)[:, :, :2])
300
+
301
+
302
+ # =============================================================================
303
+ # Case 5: drop_axis / contraction
304
+ # - Input has more dimensions than output
305
+ # - Some input indices don't appear in output
306
+ # =============================================================================
307
+
308
+
309
def test_slice_through_drop_axis():
    """Slice the output of a map_blocks call that drops axis 0."""
    data = np.arange(100).reshape(10, 10)
    x = da.from_array(data, chunks=(5, 5))

    column_sums = da.map_blocks(
        lambda block: block.sum(axis=0), x, drop_axis=0, dtype=data.dtype
    )

    # column_sums has shape (10,); [:5] corresponds to x[:, :5] on the input.
    assert_eq(column_sums[:5], data.sum(axis=0)[:5])
322
+
323
+
324
def test_slice_through_drop_axis_1():
    """Slice the output of a map_blocks call that drops axis 1."""
    data = np.arange(100).reshape(10, 10)
    x = da.from_array(data, chunks=(5, 5))

    row_sums = da.map_blocks(
        lambda block: block.sum(axis=1), x, drop_axis=1, dtype=data.dtype
    )

    # row_sums has shape (10,); [:5] corresponds to x[:5, :] on the input.
    assert_eq(row_sums[:5], data.sum(axis=1)[:5])
337
+
338
+
339
+ # =============================================================================
340
+ # Case 6: adjust_chunks
341
+ # - Chunk sizes change in the output
342
+ # - Slice indices may not map correctly
343
+ # =============================================================================
344
+
345
+
346
def test_slice_adjust_chunks():
    """Slice a map_blocks result whose chunks are larger than its input's."""
    data = np.arange(100).reshape(10, 10)
    x = da.from_array(data, chunks=(5, 5))

    def repeat_each_row(block):
        # Every row appears twice, so each block doubles along axis 0.
        return np.repeat(block, 2, axis=0)

    # Declared output chunks are (10, 5) against input (5, 5): shape (20, 10).
    y = da.map_blocks(repeat_each_row, x, chunks=(10, 5), dtype=data.dtype)

    assert_eq(y[:5, :5], repeat_each_row(data)[:5, :5])
367
+
368
+
369
+ # =============================================================================
370
+ # Case 7: Multiple inputs with different shapes
371
+ # - Inputs align via broadcasting
372
+ # - Need to map slice to each input appropriately
373
+ # =============================================================================
374
+
375
+
376
def test_slice_multiple_inputs_same_shape():
    """Slicing ``x + y`` for equal-shaped inputs slices each input the same way."""
    first_np = np.arange(100).reshape(10, 10)
    second_np = np.arange(100, 200).reshape(10, 10)
    first = da.from_array(first_np, chunks=(5, 5))
    second = da.from_array(second_np, chunks=(5, 5))

    window = (slice(None, 3), slice(None, 4))
    got = (first + second)[window]
    want = first[window] + second[window]

    assert got.expr.simplify()._name == want.expr.simplify()._name
    assert_eq(got, first_np[window] + second_np[window])
390
+
391
+
392
def test_slice_multiple_inputs_broadcast():
    """Slicing ``matrix * vector`` slices the vector only along its own axis."""
    matrix_np = np.arange(100).reshape(10, 10)
    vector_np = np.arange(10)
    matrix = da.from_array(matrix_np, chunks=(5, 5))
    vector = da.from_array(vector_np, chunks=5)

    # Output slice [:3, :4] becomes [:4] on the trailing-axis vector.
    got = (matrix * vector)[:3, :4]
    want = matrix[:3, :4] * vector[:4]

    assert got.expr.simplify()._name == want.expr.simplify()._name
    assert_eq(got, matrix_np[:3, :4] * vector_np[:4])
406
+
407
+
408
+ # =============================================================================
409
+ # Correctness tests - verify computed values
410
+ # =============================================================================
411
+
412
+
413
@pytest.mark.parametrize(
    "shape,chunks,axis,slice_",
    [
        ((100, 100), (10, 10), 0, slice(5)),
        ((100, 100), (10, 10), 1, slice(5)),
        ((100, 100), (10, 10), 0, slice(10, 20)),
        ((100, 100), (10, 10), 1, slice(10, 20)),
        ((50, 50, 50), (10, 10, 10), 0, slice(5)),
        ((50, 50, 50), (10, 10, 10), 1, slice(5)),
        ((50, 50, 50), (10, 10, 10), 2, slice(5)),
    ],
)
def test_slice_through_reduction_correctness(shape, chunks, axis, slice_):
    """Verify slice-through-reduction produces correct values.

    ``slice_`` is applied to the first axis of the reduced result. Any other
    remaining axes are taken whole.
    """
    # Seeded generator so a failing parametrization reproduces exactly.
    arr = np.random.default_rng(0).random(shape)
    x = da.from_array(arr, chunks=chunks)

    # Build the slice tuple for the output (one axis fewer than the input).
    out_ndim = len(shape) - 1
    slices = [slice(None)] * out_ndim
    slices[0] = slice_

    result = x.sum(axis=axis)[tuple(slices)]
    expected = arr.sum(axis=axis)[tuple(slices)]

    assert_eq(result, expected)
439
+
440
+
441
+ # =============================================================================
442
+ # Verify optimization is/isn't applied
443
+ # =============================================================================
444
+
445
+
446
def test_optimization_applied_to_reduction():
    """simplify() turns Slice(Reduction(...)) into Reduction(Slice(...))."""
    from dask_array.reductions._reduction import Reduction
    from dask_array.slicing import SliceSlicesIntegers

    expr = da.ones((100, 100), chunks=(10, 10)).sum(axis=0)[:5].expr

    # Before simplification the slice sits on top of the reduction.
    assert isinstance(expr, SliceSlicesIntegers)

    # Afterwards the reduction is on top: the slice has moved below it.
    simplified = expr.simplify()
    assert not isinstance(simplified, SliceSlicesIntegers)
    assert isinstance(simplified, Reduction)
461
+
462
+
463
def test_optimization_pushes_through_new_axes_when_safe():
    """A slice that leaves the new axis whole is pushed through map_blocks."""
    from dask_array.slicing import SliceSlicesIntegers

    x = da.ones((20, 20), chunks=(5, 5))
    expanded = da.map_blocks(lambda b: b[..., np.newaxis], x, new_axis=2, dtype=float)
    # Axes 0 and 1 are sliced; the new axis 2 is taken whole.
    z = expanded[:5, :5, :]

    assert not isinstance(z.expr.simplify(), SliceSlicesIntegers)
    assert_eq(z, np.ones((20, 20))[:5, :5, np.newaxis])
475
+
476
+
477
def test_optimization_not_applied_slicing_new_axes():
    """A slice that restricts the new axis stays on top after simplify()."""
    from dask_array.slicing import SliceSlicesIntegers

    x = da.ones((20, 20), chunks=(5, 5))
    # map_blocks appends a new axis of length 3.
    y = da.map_blocks(
        lambda b: np.repeat(b[..., np.newaxis], 3, axis=2),
        x,
        new_axis=2,
        chunks=(5, 5, 3),
        dtype=float,
    )

    # [:2] on axis 2 touches the new axis, which has no input counterpart.
    simplified = y[:5, :5, :2].expr.simplify()
    assert isinstance(simplified, SliceSlicesIntegers)
495
+
496
+
497
def test_optimization_reduces_tasks():
    """Slicing a from_array reduction yields a smaller optimized graph."""
    x = da.from_array(np.ones((100, 100)), chunks=(10, 10))

    def n_tasks(arr):
        return len(arr.optimize().__dask_graph__())

    full_tasks = n_tasks(x.sum(axis=0))
    sliced_tasks = n_tasks(x.sum(axis=0)[:5])

    # [:5] lies within one column of chunks, so fewer tasks should remain.
    assert sliced_tasks < full_tasks
510
+
511
+
512
+ # =============================================================================
513
+ # Case 8: Tensordot / Matmul
514
+ # - adjust_chunks only affects contracted dimension
515
+ # - Slices on non-contracted dimensions can push through
516
+ # =============================================================================
517
+
518
+
519
@pytest.mark.filterwarnings("ignore::dask.array.core.PerformanceWarning")
def test_slice_through_tensordot_correctness():
    """Verify slice through tensordot produces correct values."""
    # Seeded generator so any mismatch reproduces exactly.
    arr = np.random.default_rng(0).random((100, 100))
    x = da.from_array(arr, chunks=(10, 10))

    result = x.dot(x.T)[:5, :5]
    expected = arr.dot(arr.T)[:5, :5]

    assert_eq(result, expected)
529
+
530
+
531
@pytest.mark.filterwarnings("ignore::dask.array.core.PerformanceWarning")
def test_slice_through_matmul_correctness():
    """Verify slice through matmul produces correct values."""
    # Seeded generator so any mismatch reproduces exactly.
    rng = np.random.default_rng(0)
    arr1 = rng.random((100, 50))
    arr2 = rng.random((50, 100))
    x = da.from_array(arr1, chunks=(10, 10))
    y = da.from_array(arr2, chunks=(10, 10))

    result = (x @ y)[:5, :5]
    expected = (arr1 @ arr2)[:5, :5]

    assert_eq(result, expected)
543
+
544
+
545
@pytest.mark.filterwarnings("ignore::dask.array.core.PerformanceWarning")
def test_slice_through_matmul_expression_structure():
    """``(x @ y)[a:b, c:d]`` simplifies to ``x[a:b, :] @ y[:, c:d]``."""
    left = da.ones((100, 50), chunks=(10, 10))
    right = da.ones((50, 100), chunks=(10, 10))

    # Unequal row/column bounds, so a swapped operand mapping would be caught.
    rows, cols = slice(None, 15), slice(None, 25)
    got = (left @ right)[rows, cols]
    want = left[rows, :] @ right[:, cols]

    assert got.expr.simplify()._name == want.expr.simplify()._name
557
+
558
+
559
@pytest.mark.filterwarnings("ignore::dask.array.core.PerformanceWarning")
def test_slice_through_tensordot_reduces_tasks():
    """Slicing ``x.dot(x.T)`` down to one output block cuts the task count.

    ``x.dot(x.T)[:5, :5]`` should optimize to computing only that submatrix,
    not the full product followed by a slice.
    """
    x = da.ones((100, 100), chunks=(10, 10))
    product = x.dot(x.T)
    corner = x.dot(x.T)[:5, :5]

    full_count = len(product.optimize().__dask_graph__())
    corner_count = len(corner.optimize().__dask_graph__())

    # The full product has 10x10 = 100 output blocks; [:5, :5] lies in one.
    # Require at least a 5x reduction in total tasks.
    assert corner_count < full_count / 5
579
+
580
+
581
+ # =============================================================================
582
+ # Regression tests
583
+ # =============================================================================
584
+
585
+
586
def test_integer_index_on_size_one_dim_through_elemwise():
    """Integer indexing on size-1 dims must remove the dimension.

    Regression test: when Elemwise._accept_slice pushed integer indices
    through size-1 dimensions, it incorrectly converted them to
    slice(None), keeping the dimension instead of removing it.
    """
    # Seeded so the value check below is reproducible.
    data = np.random.default_rng(0).standard_normal((8, 9, 10))
    arr = da.from_array(data, chunks=(8, 9, 10))
    shuffled = da.shuffle(arr, [[0]], axis=2)  # -> (8, 9, 1)

    # Elemwise on top of the shuffle; the all-True condition selects `shuffled`.
    cond = da.from_array(np.array([True]), chunks=(1,))
    elemwise = da.where(cond, shuffled, np.nan)

    # Integer index should remove the dimension
    indexed = elemwise[:, :, 0]
    assert indexed.shape == (8, 9)
    assert indexed.compute().shape == (8, 9)
    # The shuffle kept only position 0 of axis 2, so values match data[:, :, 0].
    assert_eq(indexed, data[:, :, 0])
604
+
605
+
606
def test_integer_index_through_elemwise_broadcast():
    """An integer index through a broadcasting Elemwise drops that axis."""
    # x has a size-1 middle axis that broadcasts against y's length 15.
    x = da.ones((10, 1, 20), chunks=(5, 1, 10))
    y = da.ones((10, 15, 20), chunks=(5, 5, 10))

    picked = (x + y)[:, :, 0]

    # Axis 2 is indexed by an integer, so only axes 0 and 1 remain.
    assert picked.shape == (10, 15)
    assert_eq(picked, np.full((10, 15), 2.0))
617
+
618
+
619
+ # =============================================================================
620
+ # Regression tests for empty slice handling
621
+ # =============================================================================
622
+
623
+
624
def test_empty_slice_through_elemwise_broadcast():
    """An empty slice through a broadcasting Elemwise stays empty.

    Regression test: empty slices such as ``[:0]`` on broadcast axes used to
    be replaced by ``[:]``, which produced a non-empty result.
    """
    scalar = da.from_array(np.float32(0.0), chunks=-1)
    matrix = da.from_array(np.array([[0.0]], dtype="float32"), chunks=-1)

    # Shape () broadcast with (1, 1) gives (1, 1).
    total = scalar + matrix
    assert total.shape == (1, 1)

    # [0, :0] drops axis 0 and empties axis 1.
    emptied = total[0, :0]
    assert emptied.shape == (0,)
    assert emptied.compute().shape == (0,)
641
+
642
+
643
def test_integer_index_out_of_bounds_on_broadcast_dim():
    """An integer index beyond an input's size-1 broadcast axis still works.

    Regression test: an index such as ``[1]`` on an axis where one input has
    size 1 (broadcast) used to be applied to that input directly, raising
    IndexError.
    """
    scalar = da.from_array(np.float32(0.0), chunks=-1)
    pair = da.from_array(np.array([[0.0, 1.0]], dtype="float32"), chunks=-1)
    unit = da.from_array(np.zeros((1, 1, 1, 1), dtype="float32"), chunks=-1)

    # () + (1, 2) + (1, 1, 1, 1) broadcasts to (1, 1, 1, 2).
    total = scalar + pair + unit
    assert total.shape == (1, 1, 1, 2)

    # Index 1 on the last axis is valid for the output, but `unit` has size 1
    # there and must be broadcast rather than indexed at position 1.
    point = total[0, 0, 0, 1]
    assert point.shape == ()
    assert point.compute() == 1.0  # pair[0, 1]
662
+
663
+
664
def test_empty_slice_not_pushed_through_reduction():
    """An empty slice after a reduction computes to an empty array.

    Regression test: pushing empty slices through reductions created invalid
    task graphs, because the reduction machinery does not handle empty
    non-reduced dimensions.
    """
    source = da.from_array(np.zeros((1, 2, 1, 1), dtype="float32"), chunks=-1)
    minima = da.nanmin(source, axis=(1, 2, 3))  # shape (1,)

    # [:-1] on a length-1 axis selects nothing.
    emptied = minima[:-1]
    assert emptied.shape == (0,)
    assert emptied.compute().shape == (0,)