dask-array 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dask_array/__init__.py +228 -0
- dask_array/_backends.py +76 -0
- dask_array/_backends_array.py +99 -0
- dask_array/_blockwise.py +1410 -0
- dask_array/_broadcast.py +272 -0
- dask_array/_chunk.py +445 -0
- dask_array/_chunk_types.py +54 -0
- dask_array/_collection.py +1644 -0
- dask_array/_concatenate.py +331 -0
- dask_array/_core_utils.py +1365 -0
- dask_array/_dispatch.py +141 -0
- dask_array/_einsum.py +277 -0
- dask_array/_expr.py +544 -0
- dask_array/_expr_flow.py +586 -0
- dask_array/_gufunc.py +805 -0
- dask_array/_histogram.py +617 -0
- dask_array/_map_blocks.py +652 -0
- dask_array/_new_collection.py +10 -0
- dask_array/_numpy_compat.py +135 -0
- dask_array/_overlap.py +1159 -0
- dask_array/_rechunk.py +1050 -0
- dask_array/_reshape.py +710 -0
- dask_array/_routines.py +102 -0
- dask_array/_shuffle.py +448 -0
- dask_array/_stack.py +264 -0
- dask_array/_svg.py +291 -0
- dask_array/_templates.py +29 -0
- dask_array/_test_utils.py +257 -0
- dask_array/_ufunc.py +385 -0
- dask_array/_utils.py +349 -0
- dask_array/_visualize.py +223 -0
- dask_array/_xarray.py +337 -0
- dask_array/core/__init__.py +34 -0
- dask_array/core/_blockwise_funcs.py +312 -0
- dask_array/core/_conversion.py +422 -0
- dask_array/core/_from_graph.py +97 -0
- dask_array/creation/__init__.py +71 -0
- dask_array/creation/_arange.py +121 -0
- dask_array/creation/_diag.py +116 -0
- dask_array/creation/_diagonal.py +241 -0
- dask_array/creation/_eye.py +103 -0
- dask_array/creation/_linspace.py +102 -0
- dask_array/creation/_mesh.py +134 -0
- dask_array/creation/_ones_zeros.py +454 -0
- dask_array/creation/_pad.py +270 -0
- dask_array/creation/_repeat.py +55 -0
- dask_array/creation/_tile.py +36 -0
- dask_array/creation/_tri.py +28 -0
- dask_array/creation/_utils.py +296 -0
- dask_array/fft.py +320 -0
- dask_array/io/__init__.py +39 -0
- dask_array/io/_base.py +10 -0
- dask_array/io/_from_array.py +257 -0
- dask_array/io/_from_delayed.py +95 -0
- dask_array/io/_from_graph.py +54 -0
- dask_array/io/_from_npy_stack.py +67 -0
- dask_array/io/_store.py +336 -0
- dask_array/io/_tiledb.py +159 -0
- dask_array/io/_to_npy_stack.py +65 -0
- dask_array/io/_zarr.py +449 -0
- dask_array/linalg/__init__.py +39 -0
- dask_array/linalg/_cholesky.py +234 -0
- dask_array/linalg/_lu.py +300 -0
- dask_array/linalg/_norm.py +94 -0
- dask_array/linalg/_qr.py +601 -0
- dask_array/linalg/_solve.py +349 -0
- dask_array/linalg/_svd.py +394 -0
- dask_array/linalg/_tensordot.py +334 -0
- dask_array/linalg/_utils.py +74 -0
- dask_array/manipulation/__init__.py +45 -0
- dask_array/manipulation/_expand.py +321 -0
- dask_array/manipulation/_flip.py +92 -0
- dask_array/manipulation/_roll.py +78 -0
- dask_array/manipulation/_transpose.py +309 -0
- dask_array/random/__init__.py +125 -0
- dask_array/random/_choice.py +181 -0
- dask_array/random/_expr.py +256 -0
- dask_array/random/_generator.py +441 -0
- dask_array/random/_random_state.py +259 -0
- dask_array/random/_utils.py +84 -0
- dask_array/reductions/__init__.py +84 -0
- dask_array/reductions/_arg_reduction.py +130 -0
- dask_array/reductions/_common.py +1082 -0
- dask_array/reductions/_cumulative.py +522 -0
- dask_array/reductions/_percentile.py +261 -0
- dask_array/reductions/_reduction.py +725 -0
- dask_array/reductions/_trace.py +56 -0
- dask_array/routines/__init__.py +133 -0
- dask_array/routines/_apply.py +84 -0
- dask_array/routines/_bincount.py +112 -0
- dask_array/routines/_broadcast.py +111 -0
- dask_array/routines/_coarsen.py +115 -0
- dask_array/routines/_diff.py +79 -0
- dask_array/routines/_gradient.py +158 -0
- dask_array/routines/_indexing.py +65 -0
- dask_array/routines/_insert_delete.py +132 -0
- dask_array/routines/_misc.py +122 -0
- dask_array/routines/_nonzero.py +72 -0
- dask_array/routines/_search.py +123 -0
- dask_array/routines/_select.py +113 -0
- dask_array/routines/_statistics.py +171 -0
- dask_array/routines/_topk.py +82 -0
- dask_array/routines/_triangular.py +74 -0
- dask_array/routines/_unique.py +232 -0
- dask_array/routines/_where.py +62 -0
- dask_array/slicing/__init__.py +67 -0
- dask_array/slicing/_basic.py +550 -0
- dask_array/slicing/_blocks.py +138 -0
- dask_array/slicing/_bool_index.py +145 -0
- dask_array/slicing/_setitem.py +329 -0
- dask_array/slicing/_squeeze.py +101 -0
- dask_array/slicing/_utils.py +1133 -0
- dask_array/slicing/_vindex.py +282 -0
- dask_array/stacking/__init__.py +15 -0
- dask_array/stacking/_block.py +83 -0
- dask_array/stacking/_simple.py +58 -0
- dask_array/templates/array.html.j2 +48 -0
- dask_array/tests/__init__.py +0 -0
- dask_array/tests/conftest.py +22 -0
- dask_array/tests/test_api.py +40 -0
- dask_array/tests/test_binary_op_chunks.py +107 -0
- dask_array/tests/test_coarse_slice_through_blockwise.py +362 -0
- dask_array/tests/test_collection.py +799 -0
- dask_array/tests/test_creation.py +1102 -0
- dask_array/tests/test_expr_flow.py +143 -0
- dask_array/tests/test_linalg.py +1130 -0
- dask_array/tests/test_map_blocks_multi_output.py +104 -0
- dask_array/tests/test_rechunk_pushdown.py +214 -0
- dask_array/tests/test_reductions.py +1091 -0
- dask_array/tests/test_routines.py +2853 -0
- dask_array/tests/test_shuffle_chunks.py +67 -0
- dask_array/tests/test_slice_pushdown.py +968 -0
- dask_array/tests/test_slice_through_blockwise.py +678 -0
- dask_array/tests/test_slice_through_overlap.py +366 -0
- dask_array/tests/test_slice_through_reshape.py +272 -0
- dask_array/tests/test_slicing.py +839 -0
- dask_array/tests/test_transpose_slice_pushdown.py +208 -0
- dask_array/tests/test_visualize.py +94 -0
- dask_array/tests/test_xarray.py +193 -0
- dask_array-0.1.0.dist-info/METADATA +48 -0
- dask_array-0.1.0.dist-info/RECORD +144 -0
- dask_array-0.1.0.dist-info/WHEEL +4 -0
- dask_array-0.1.0.dist-info/entry_points.txt +2 -0
- dask_array-0.1.0.dist-info/licenses/LICENSE +29 -0
|
@@ -0,0 +1,366 @@
|
|
|
1
|
+
"""Tests for slice pushdown through MapOverlap.
|
|
2
|
+
|
|
3
|
+
These tests verify that slicing operations can be pushed through map_overlap
|
|
4
|
+
operations, reducing computation by slicing input arrays before applying
|
|
5
|
+
overlap boundaries.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
import numpy as np
|
|
11
|
+
import pytest
|
|
12
|
+
|
|
13
|
+
import dask_array as da
|
|
14
|
+
from dask_array._test_utils import assert_eq
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def add_neighbors(x):
    """Return a copy of ``x`` whose interior rows hold a 3-point sum along axis 0.

    Each interior row ``i`` becomes ``x[i - 1] + x[i] + x[i + 1]``. The first
    and last rows are copied unchanged, and inputs with two or fewer rows are
    returned as a plain copy. The input array is never modified.
    """
    out = x.copy()
    if x.shape[0] <= 2:
        # Too short to have an interior row; nothing to sum.
        return out
    out[1:-1] = x[:-2] + x[1:-1] + x[2:]
    return out
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
# =============================================================================
|
|
26
|
+
# Case 1: Slice on non-overlap axis (should push through)
|
|
27
|
+
# =============================================================================
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def test_slice_through_overlap_non_overlap_axis():
    """A slice on the axis without overlap simplifies to slicing the input first."""

    def overlapped(a):
        # Depth 2 on axis 0 only; axis 1 carries no overlap.
        return a.map_overlap(add_neighbors, depth={0: 2, 1: 0}, boundary="none")

    data = np.arange(10000).reshape((100, 100)).astype(float)
    x = da.from_array(data, chunks=(10, 10))

    # Slice applied to the map_overlap output on axis 1 ...
    slice_after = overlapped(x)[:, :20]
    # ... versus the same slice applied to the input beforehand.
    slice_before = overlapped(x[:, :20])

    # Both forms must simplify to the same expression.
    assert slice_after.expr.simplify()._name == slice_before.expr.simplify()._name
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def test_slice_through_overlap_middle_slice():
    """A middle slice (``30:70``) on the non-overlap axis pushes through."""

    def overlapped(a):
        # Depth 2 on axis 0 only; axis 1 carries no overlap.
        return a.map_overlap(add_neighbors, depth={0: 2, 1: 0}, boundary="none")

    data = np.arange(10000).reshape((100, 100)).astype(float)
    x = da.from_array(data, chunks=(10, 10))

    slice_after = overlapped(x)[:, 30:70]
    slice_before = overlapped(x[:, 30:70])

    assert slice_after.expr.simplify()._name == slice_before.expr.simplify()._name
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def test_slice_through_overlap_correctness():
    """Verify slice through overlap produces correct values.

    Checks both that the sliced expression simplifies to the "slice the input
    first" form and that the computed values match slicing the fully computed
    map_overlap result.
    """
    arr = np.arange(64).reshape((8, 8)).astype(float)
    x = da.from_array(arr, chunks=(4, 4))

    result = x.map_overlap(add_neighbors, depth={0: 2, 1: 0}, boundary="none")

    # Slice on axis 1 (no overlap on that axis)
    sliced = result[:, 2:6]
    expected = x[:, 2:6].map_overlap(add_neighbors, depth={0: 2, 1: 0}, boundary="none")

    # Structural check: the slice was pushed to the input.
    assert sliced.expr.simplify()._name == expected.expr.simplify()._name

    # Value check: previously missing although the test name and docstring
    # promise it. Compare against the unsliced computation.
    assert_eq(sliced, result.compute()[:, 2:6])
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
# =============================================================================
|
|
75
|
+
# Case 2: Slice on overlap axis (pushes through with padding)
|
|
76
|
+
# =============================================================================
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def test_slice_on_overlap_axis_pushes_with_padding():
    """A slice on the overlap axis simplifies to a padded input slice plus trim."""

    def overlapped(a):
        # Depth 2 on axis 0 only; axis 1 carries no overlap.
        return a.map_overlap(add_neighbors, depth={0: 2, 1: 0}, boundary="none")

    data = np.arange(10000).reshape((100, 100)).astype(float)
    x = da.from_array(data, chunks=(10, 10))

    # Output slice on axis 0, which has overlap depth 2.
    sliced = overlapped(x)[:50, :]
    # Expected form: input padded by the depth ([:52]), then trimmed to [:50].
    padded_then_trimmed = overlapped(x[:52, :])[:50, :]

    assert sliced.expr.simplify()._name == padded_then_trimmed.expr.simplify()._name
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
def test_slice_on_both_axes_one_has_overlap():
    """Slicing both axes: the overlap axis is padded, the other is sliced directly."""

    def overlapped(a):
        # Depth 2 on axis 0 only; axis 1 carries no overlap.
        return a.map_overlap(add_neighbors, depth={0: 2, 1: 0}, boundary="none")

    data = np.arange(10000).reshape((100, 100)).astype(float)
    x = da.from_array(data, chunks=(10, 10))

    sliced = overlapped(x)[:50, :50]

    # Axis 1 (no overlap): input sliced to [:50] directly.
    # Axis 0 (depth 2): input padded to [:52], output trimmed back to [:50].
    padded_then_trimmed = overlapped(x[:52, :50])[:50, :]

    assert sliced.expr.simplify()._name == padded_then_trimmed.expr.simplify()._name
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
# =============================================================================
|
|
110
|
+
# Case 3: Multi-dimensional overlap
|
|
111
|
+
# =============================================================================
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
def add_neighbors_2d(x):
    """Return a copy of ``x`` with neighbor sums added along axis 0 and axis 1.

    Interior rows gain the rows above and below them; interior columns gain
    the columns to their left and right. Neighbor values are always read from
    the unmodified input ``x``, so the two passes do not feed into each other.
    An axis of length two or less is left untouched.
    """
    out = x.copy()

    n_rows = x.shape[0]
    if n_rows > 2:
        out[1:-1, :] += x[:-2, :] + x[2:, :]

    n_cols = x.shape[1]
    if n_cols > 2:
        out[:, 1:-1] += x[:, :-2] + x[:, 2:]

    return out
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
def test_slice_through_2d_overlap():
    """Slice on an axis of a 2D (depth 1 on both axes) overlap pushes with padding."""

    def overlapped(a):
        # Overlap of depth 1 on both axes.
        return a.map_overlap(add_neighbors_2d, depth={0: 1, 1: 1}, boundary="none")

    data = np.arange(10000).reshape((100, 100)).astype(float)
    x = da.from_array(data, chunks=(10, 10))

    sliced = overlapped(x)[:, :40]
    # Axis 1 has depth 1: input padded to [:, :41], output trimmed to [:, :40].
    padded_then_trimmed = overlapped(x[:, :41])[:, :40]

    assert sliced.expr.simplify()._name == padded_then_trimmed.expr.simplify()._name
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
def test_slice_through_2d_overlap_middle():
    """Middle slice on the non-overlap axis of a 2D array.

    Despite the test name, the overlap here is on axis 0 only; the slice
    ``25:75`` is taken on axis 1, which has depth 0.
    """

    def overlapped(a):
        # Depth 2 on axis 0 only; axis 1 carries no overlap.
        return a.map_overlap(add_neighbors, depth={0: 2, 1: 0}, boundary="none")

    data = np.arange(10000).reshape((100, 100)).astype(float)
    x = da.from_array(data, chunks=(10, 10))

    slice_after = overlapped(x)[:, 25:75]
    slice_before = overlapped(x[:, 25:75])

    assert slice_after.expr.simplify()._name == slice_before.expr.simplify()._name
|
|
151
|
+
|
|
152
|
+
|
|
153
|
+
def test_slice_through_1d_overlap_on_3d_array():
    """Slices on two non-overlap axes of a 3D array push through together."""

    def overlapped(a):
        # Depth 1 on axis 0 only; axes 1 and 2 carry no overlap.
        return a.map_overlap(add_neighbors, depth={0: 1, 1: 0, 2: 0}, boundary="none")

    data = np.arange(1000).reshape((10, 10, 10)).astype(float)
    x = da.from_array(data, chunks=(5, 5, 5))

    # Slice axes 1 and 2 after the overlap ...
    slice_after = overlapped(x)[:, :3, :3]
    # ... and before it.
    slice_before = overlapped(x[:, :3, :3])

    assert slice_after.expr.simplify()._name == slice_before.expr.simplify()._name
|
|
166
|
+
|
|
167
|
+
|
|
168
|
+
# =============================================================================
|
|
169
|
+
# Case 4: Asymmetric overlap
|
|
170
|
+
# =============================================================================
|
|
171
|
+
|
|
172
|
+
|
|
173
|
+
def test_slice_through_asymmetric_overlap():
    """Slice on the non-overlap axis pushes through an asymmetric-depth overlap."""

    def overlapped(a):
        # Asymmetric depth (2, 1) on axis 0; axis 1 carries no overlap.
        return a.map_overlap(add_neighbors, depth={0: (2, 1), 1: 0}, boundary="none")

    data = np.arange(64).reshape((8, 8)).astype(float)
    x = da.from_array(data, chunks=(4, 4))

    slice_after = overlapped(x)[:, 2:6]
    slice_before = overlapped(x[:, 2:6])

    assert slice_after.expr.simplify()._name == slice_before.expr.simplify()._name
|
|
186
|
+
|
|
187
|
+
|
|
188
|
+
def test_slice_on_asymmetric_overlap_axis_pushes():
    """Slice on an asymmetric-overlap axis pads the input by the right-hand depth."""

    def overlapped(a):
        # Asymmetric depth (2, 1) on axis 0; axis 1 carries no overlap.
        return a.map_overlap(add_neighbors, depth={0: (2, 1), 1: 0}, boundary="none")

    data = np.arange(10000).reshape((100, 100)).astype(float)
    x = da.from_array(data, chunks=(10, 10))

    sliced = overlapped(x)[:50, :]
    # Right-hand depth is 1: input padded to [:51], output trimmed to [:50].
    padded_then_trimmed = overlapped(x[:51, :])[:50, :]

    assert sliced.expr.simplify()._name == padded_then_trimmed.expr.simplify()._name
|
|
201
|
+
|
|
202
|
+
|
|
203
|
+
# =============================================================================
|
|
204
|
+
# Case 5: Zero overlap (edge case)
|
|
205
|
+
# =============================================================================
|
|
206
|
+
|
|
207
|
+
|
|
208
|
+
def test_slice_through_zero_overlap():
    """With ``depth=0`` a slice pushes straight through without padding."""

    def overlapped(a):
        # Zero depth everywhere: no overlap is exchanged between chunks.
        return a.map_overlap(add_neighbors, depth=0, boundary="none")

    data = np.arange(10000).reshape((100, 100)).astype(float)
    x = da.from_array(data, chunks=(10, 10))

    slice_after = overlapped(x)[:50, :]
    slice_before = overlapped(x[:50, :])

    assert slice_after.expr.simplify()._name == slice_before.expr.simplify()._name
|
|
221
|
+
|
|
222
|
+
|
|
223
|
+
# =============================================================================
|
|
224
|
+
# Case 6: Task reduction verification
|
|
225
|
+
# =============================================================================
|
|
226
|
+
|
|
227
|
+
|
|
228
|
+
def test_slice_through_overlap_reduces_tasks():
    """The optimized graph of a sliced overlap has fewer tasks than the full one."""

    def task_count(array):
        # Size of the task graph after optimization.
        return len(array.optimize().__dask_graph__())

    data = np.arange(10000).reshape((100, 100)).astype(float)
    x = da.from_array(data, chunks=(10, 10))

    full = x.map_overlap(add_neighbors, depth={0: 2, 1: 0}, boundary="none")
    first_columns = full[:, :10]  # only the first 10 columns

    n_full = task_count(full)
    n_sliced = task_count(first_columns)

    # One column of chunks is processed instead of ten.
    assert n_sliced < n_full
|
|
243
|
+
|
|
244
|
+
|
|
245
|
+
def test_slice_through_overlap_reduces_numblocks():
    """Slicing the overlap output to 10 columns leaves a single column of blocks."""
    data = np.arange(10000).reshape((100, 100)).astype(float)
    x = da.from_array(data, chunks=(10, 10))

    full = x.map_overlap(add_neighbors, depth={0: 2, 1: 0}, boundary="none")
    first_columns = full[:, :10]

    # Unsliced output keeps the 10x10 block grid of the input.
    assert full.numblocks == (10, 10)
    # Sliced output has 10 row-blocks and only 1 column-block.
    assert first_columns.numblocks == (10, 1)
|
|
258
|
+
|
|
259
|
+
|
|
260
|
+
# =============================================================================
|
|
261
|
+
# Case 7: Correctness with computed values
|
|
262
|
+
# =============================================================================
|
|
263
|
+
|
|
264
|
+
|
|
265
|
+
@pytest.mark.parametrize(
    "shape,chunks,depth,slice_",
    [
        # Start slices (:n form) on non-overlap axes
        ((80, 80), (20, 20), {0: 2, 1: 0}, (slice(None), slice(20))),
        ((80, 80), (20, 20), {0: 0, 1: 2}, (slice(20), slice(None))),
        # Middle slices (k:n form) on non-overlap axes
        ((80, 80), (20, 20), {0: 2, 1: 0}, (slice(None), slice(20, 60))),
        ((80, 80), (20, 20), {0: 0, 1: 2}, (slice(20, 60), slice(None))),
        # End slices (k: form) on non-overlap axes
        ((80, 80), (20, 20), {0: 2, 1: 0}, (slice(None), slice(40, None))),
        ((80, 80), (20, 20), {0: 0, 1: 2}, (slice(40, None), slice(None))),
    ],
)
def test_slice_through_overlap_parametrized(shape, chunks, depth, slice_):
    """Parametrized correctness tests for slice through overlap.

    For each case the slice is taken on an axis whose overlap depth is 0.
    The test checks that the sliced expression simplifies to the "slice the
    input first" form, and that the computed values equal the corresponding
    slice of the fully computed map_overlap result.
    """
    arr = np.arange(np.prod(shape)).reshape(shape).astype(float)
    x = da.from_array(arr, chunks=chunks)

    result = x.map_overlap(add_neighbors, depth=depth, boundary="none")
    sliced = result[slice_]

    # Build expected: slice input first, then overlap
    input_sliced = x[slice_]
    expected = input_sliced.map_overlap(add_neighbors, depth=depth, boundary="none")

    # Verify expression structure matches
    assert sliced.expr.simplify()._name == expected.expr.simplify()._name

    # Verify values too: the structural check alone does not show that the
    # pushed-down form computes the same numbers as the original.
    assert_eq(sliced, result.compute()[slice_])
|
|
293
|
+
|
|
294
|
+
|
|
295
|
+
# =============================================================================
|
|
296
|
+
# Case 8: Special cases (trim=False, uniform depth)
|
|
297
|
+
# =============================================================================
|
|
298
|
+
|
|
299
|
+
|
|
300
|
+
def test_map_overlap_no_trim_slice_pushes():
    """With ``trim=False`` a slice on the non-overlap axis pushes to the input."""

    def overlapped(a):
        # Depth 2 on axis 0, untrimmed output; axis 1 is not listed in depth.
        return a.map_overlap(add_neighbors, depth={0: 2}, boundary="none", trim=False)

    data = np.arange(10000).reshape((100, 100)).astype(float)
    x = da.from_array(data, chunks=(10, 10))

    # Slice on axis 1 after the overlap versus before it.
    slice_after = overlapped(x)[:, :30]
    slice_before = overlapped(x[:, :30])

    assert slice_after.expr.simplify()._name == slice_before.expr.simplify()._name
|
|
313
|
+
|
|
314
|
+
|
|
315
|
+
def test_map_overlap_uniform_depth_correctness():
    """Test with uniform depth (int instead of dict).

    When slicing on an axis with overlap, the optimization pads the input
    slice to include data needed for overlap, then trims the output. Both the
    simplified expression structure and the computed values are checked.
    """
    arr = np.arange(10000).reshape((100, 100)).astype(float)
    x = da.from_array(arr, chunks=(10, 10))

    result = x.map_overlap(add_neighbors_2d, depth=2, boundary="none")
    sliced = result[:, :30]

    # Expected: pad input by depth on sliced axis, apply overlap, then trim
    # [:, :30] with depth=2 needs input [:, :32] to preserve overlap semantics
    expected = x[:, :32].map_overlap(add_neighbors_2d, depth=2, boundary="none")[:, :30]

    assert sliced.expr.simplify()._name == expected.expr.simplify()._name

    # Value check: previously missing although the test is named
    # "correctness". Compare against the unsliced computation.
    assert_eq(sliced, result.compute()[:, :30])
|
|
332
|
+
|
|
333
|
+
|
|
334
|
+
# =============================================================================
|
|
335
|
+
# Case 9: Value correctness verification
|
|
336
|
+
# =============================================================================
|
|
337
|
+
|
|
338
|
+
|
|
339
|
+
def test_slice_through_overlap_value_correctness():
    """Values of a slice on the non-overlap axis match the full computation."""
    data = np.arange(10000).reshape((100, 100)).astype(float)
    x = da.from_array(data, chunks=(10, 10))

    overlapped = x.map_overlap(add_neighbors, depth={0: 2, 1: 0}, boundary="none")

    # Lazy slice on axis 1 (no overlap).
    left_half = overlapped[:, :50]

    # Reference: compute everything, then slice the in-memory result.
    reference = overlapped.compute()[:, :50]
    assert_eq(left_half, reference)
|
|
352
|
+
|
|
353
|
+
|
|
354
|
+
def test_slice_on_overlap_axis_value_correctness():
    """Values of a slice on overlap axes match the full computation."""
    data = np.arange(10000).reshape((100, 100)).astype(float)
    x = da.from_array(data, chunks=(10, 10))

    # Uniform depth 2, so both sliced axes carry overlap.
    overlapped = x.map_overlap(add_neighbors_2d, depth=2, boundary="none")

    corner = overlapped[:50, :50]

    # Reference: compute everything, then slice the in-memory result.
    reference = overlapped.compute()[:50, :50]
    assert_eq(corner, reference)
|
|
@@ -0,0 +1,272 @@
|
|
|
1
|
+
"""Tests for slice pushdown through Reshape expressions.
|
|
2
|
+
|
|
3
|
+
Slice can push through Reshape when leading dimensions are preserved,
|
|
4
|
+
i.e., the reshape only affects trailing dimensions.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
import numpy as np
|
|
10
|
+
import pytest
|
|
11
|
+
|
|
12
|
+
import dask_array as da
|
|
13
|
+
from dask_array._test_utils import assert_eq
|
|
14
|
+
|
|
15
|
+
# =============================================================================
|
|
16
|
+
# Case 1: Leading dimension preserved - slice should push through
|
|
17
|
+
# =============================================================================
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def test_slice_through_reshape_leading_dim_preserved():
    """Slicing axis 0 after a reshape that keeps axis 0 pushes to the input."""
    data = np.arange(60).reshape((10, 6))
    x = da.from_array(data, chunks=(5, 3))

    # (10, 6) -> (10, 2, 3): axis 0 keeps its length.
    reshaped = x.reshape((10, 2, 3))
    result = reshaped[:3]
    expected = x[:3].reshape((3, 2, 3))

    # Same simplified expression, and same values as NumPy.
    assert result.expr.simplify()._name == expected.expr.simplify()._name
    assert_eq(result, data.reshape((10, 2, 3))[:3])
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def test_slice_through_reshape_flatten_trailing():
    """Slicing axis 0 pushes through a reshape that merges the trailing axes."""
    data = np.arange(60).reshape((10, 2, 3))
    x = da.from_array(data, chunks=(5, 2, 3))

    # (10, 2, 3) -> (10, 6): axis 0 keeps its length.
    flattened = x.reshape((10, 6))
    result = flattened[:4]
    expected = x[:4].reshape((4, 6))

    # Compared after full optimization rather than simplify(): per the
    # original note, the Reshape vs ReshapeLowered difference is resolved
    # by lowering.
    assert result.optimize()._name == expected.optimize()._name
    assert_eq(result, data.reshape((10, 6))[:4])
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def test_slice_through_reshape_multiple_leading_dims():
    """Slices on two preserved leading axes push through the reshape."""
    data = np.arange(120).reshape((4, 5, 6))
    x = da.from_array(data, chunks=(2, 5, 3))

    # (4, 5, 6) -> (4, 5, 2, 3): axes 0 and 1 keep their lengths.
    reshaped = x.reshape((4, 5, 2, 3))
    result = reshaped[:2, :3]
    expected = x[:2, :3].reshape((2, 3, 2, 3))

    assert result.expr.simplify()._name == expected.expr.simplify()._name
    assert_eq(result, data.reshape((4, 5, 2, 3))[:2, :3])
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def test_slice_through_reshape_middle_slice():
    """A slice with a non-zero start on the preserved axis pushes through."""
    data = np.arange(100).reshape((10, 10))
    x = da.from_array(data, chunks=(5, 5))

    # (10, 10) -> (10, 2, 5): axis 0 keeps its length.
    reshaped = x.reshape((10, 2, 5))
    result = reshaped[3:7]
    expected = x[3:7].reshape((4, 2, 5))

    assert result.expr.simplify()._name == expected.expr.simplify()._name
    assert_eq(result, data.reshape((10, 2, 5))[3:7])
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
# =============================================================================
|
|
76
|
+
# Case 2: Dimension NOT preserved - slice should NOT push through
|
|
77
|
+
# =============================================================================
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def test_slice_not_pushed_when_dim_changes():
    """Values stay correct when the sliced axis is changed by the reshape.

    Only the computed result is checked here; the expression structure is
    not inspected.
    """
    data = np.arange(60).reshape((6, 10))
    x = da.from_array(data, chunks=(3, 5))

    # (6, 10) -> (2, 3, 10) splits axis 0 (length 6 becomes 2), so the
    # slice cannot be moved onto the input.
    target_shape = (2, 3, 10)
    result = x.reshape(target_shape)[:1]

    assert_eq(result, data.reshape(target_shape)[:1])
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
def test_slice_not_pushed_through_flatten():
    """Values stay correct when slicing a fully flattened array.

    Only the computed result is checked here; the expression structure is
    not inspected.
    """
    data = np.arange(100).reshape((10, 10))
    x = da.from_array(data, chunks=(5, 5))

    # Flattening to 1D leaves no axis that corresponds to an input axis.
    flat_shape = (100,)
    result = x.reshape(flat_shape)[:30]

    assert_eq(result, data.reshape(flat_shape)[:30])
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
def test_slice_on_reshaped_axis_not_pushed():
    """Values stay correct when slicing an axis that the reshape created.

    Only the computed result is checked here; the expression structure is
    not inspected.
    """
    data = np.arange(60).reshape((10, 6))
    x = da.from_array(data, chunks=(5, 3))

    # (10, 6) -> (10, 2, 3): output axis 1 comes from splitting 6 into
    # (2, 3), so a slice on it has no matching input axis.
    target_shape = (10, 2, 3)
    result = x.reshape(target_shape)[:, :1]

    assert_eq(result, data.reshape(target_shape)[:, :1])
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
# =============================================================================
|
|
118
|
+
# Case 3: Slice with None (newaxis) - should still push through
|
|
119
|
+
# =============================================================================
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
def test_slice_with_none_pushes_through():
    """An index mixing a slice and ``None`` pushes the slice and keeps the new axis."""
    data = np.arange(60).reshape((10, 6))
    x = da.from_array(data, chunks=(5, 3))

    # (10, 6) -> (10, 2, 3), then slice axis 0 and insert a new axis.
    reshaped = x.reshape((10, 2, 3))
    result = reshaped[:5, None]
    # Equivalent form: slice the input, reshape, then add the new axis.
    expected = x[:5].reshape((5, 2, 3))[:, None]

    assert result.expr.simplify()._name == expected.expr.simplify()._name
    assert_eq(result, data.reshape((10, 2, 3))[:5, None])
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
def test_slice_with_none_at_end():
    """A trailing ``None`` in the index does not stop the slice from pushing."""
    data = np.arange(60).reshape((10, 6))
    x = da.from_array(data, chunks=(5, 3))

    reshaped = x.reshape((10, 2, 3))
    result = reshaped[:5, :, :, None]
    # Equivalent form: slice the input, reshape, then append the new axis.
    expected = x[:5].reshape((5, 2, 3))[:, :, :, None]

    assert result.expr.simplify()._name == expected.expr.simplify()._name
    assert_eq(result, data.reshape((10, 2, 3))[:5, :, :, None])
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
def test_slice_with_multiple_nones():
    """Several ``None`` entries in the index do not stop the slice from pushing."""
    data = np.arange(60).reshape((10, 6))
    x = da.from_array(data, chunks=(5, 3))

    reshaped = x.reshape((10, 2, 3))
    result = reshaped[None, :5, None]
    # Equivalent form: slice the input, reshape, then add both new axes.
    expected = x[:5].reshape((5, 2, 3))[None, :, None]

    assert result.expr.simplify()._name == expected.expr.simplify()._name
    assert_eq(result, data.reshape((10, 2, 3))[None, :5, None])
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
def test_slice_with_none_correctness():
    """Computed values match NumPy for several placements of ``None``."""
    data = np.arange(120).reshape((10, 12))
    x = da.from_array(data, chunks=(5, 6))

    # (10, 12) -> (10, 3, 4) on both the dask and the NumPy side.
    reshaped = x.reshape((10, 3, 4))
    reference = data.reshape((10, 3, 4))

    # Same index applied to both; np.s_ builds the index tuple.
    indices = (
        np.s_[:5, None, :, :],
        np.s_[None, :5],
        np.s_[:5, :, None, :],
    )
    for index in indices:
        assert_eq(reshaped[index], reference[index])
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
# =============================================================================
|
|
177
|
+
# Correctness tests with various shapes
|
|
178
|
+
# =============================================================================
|
|
179
|
+
|
|
180
|
+
|
|
181
|
+
@pytest.mark.parametrize(
    "in_shape,out_shape,slice_",
    [
        # Trailing split
        ((20, 6), (20, 2, 3), (slice(10),)),
        ((20, 6), (20, 2, 3), (slice(5, 15),)),
        ((20, 12), (20, 3, 4), (slice(None, 8),)),
        # Trailing merge
        ((20, 2, 3), (20, 6), (slice(10),)),
        ((20, 4, 5), (20, 20), (slice(5, 15),)),
        # Multiple preserved dims
        ((10, 5, 6), (10, 5, 2, 3), (slice(5), slice(3))),
        ((10, 5, 4), (10, 5, 2, 2), (slice(3, 8), slice(None, 4))),
    ],
)
def test_slice_through_reshape_correctness(in_shape, out_shape, slice_):
    """Reshape-then-slice on a dask array matches the same operations in NumPy."""
    data = np.arange(np.prod(in_shape)).reshape(in_shape)
    # Half-size chunks on every axis, never smaller than 1.
    half_chunks = tuple(max(1, extent // 2) for extent in in_shape)
    x = da.from_array(data, chunks=half_chunks)

    lazy = x.reshape(out_shape)[slice_]
    reference = data.reshape(out_shape)[slice_]

    assert_eq(lazy, reference)
|
|
206
|
+
|
|
207
|
+
|
|
208
|
+
# =============================================================================
|
|
209
|
+
# Task reduction tests
|
|
210
|
+
# =============================================================================
|
|
211
|
+
|
|
212
|
+
|
|
213
|
+
def test_slice_through_reshape_reduces_tasks():
    """The optimized graph of a sliced reshape has fewer tasks than the full one."""

    def task_count(array):
        # Size of the task graph after optimization.
        return len(array.optimize().__dask_graph__())

    data = np.ones((100, 10))
    x = da.from_array(data, chunks=(10, 5))

    # The reshape keeps axis 0; the second array additionally slices it.
    full = x.reshape((100, 2, 5))
    first_rows = x.reshape((100, 2, 5))[:10]

    n_full = task_count(full)
    n_sliced = task_count(first_rows)

    # Only a tenth of the rows are requested.
    assert n_sliced < n_full
|
|
227
|
+
|
|
228
|
+
|
|
229
|
+
def test_slice_through_reshape_reduces_numblocks():
    """After optimization a 20-row slice spans two blocks along axis 0."""
    data = np.ones((100, 20))
    x = da.from_array(data, chunks=(10, 10))

    first_rows = x.reshape((100, 4, 5))[:20]
    blocks_per_axis = first_rows.optimize().numblocks

    # 20 rows in chunks of 10 -> 2 blocks on axis 0.
    assert blocks_per_axis[0] == 2
|
|
239
|
+
|
|
240
|
+
|
|
241
|
+
# =============================================================================
|
|
242
|
+
# Expression structure tests
|
|
243
|
+
# =============================================================================
|
|
244
|
+
|
|
245
|
+
|
|
246
|
+
def test_expression_structure_slice_pushed():
    """The slice node is the root before simplify() and is not the root after."""
    from dask_array.slicing import SliceSlicesIntegers

    x = da.ones((20, 6), chunks=(5, 3))
    result = x.reshape((20, 2, 3))[:5]

    # As built: Slice(Reshape(...)), so the slice sits at the root.
    root_before = result.expr
    assert isinstance(root_before, SliceSlicesIntegers)

    # After simplification the slice has moved below the root.
    root_after = root_before.simplify()
    assert not isinstance(root_after, SliceSlicesIntegers)
|
|
261
|
+
|
|
262
|
+
|
|
263
|
+
def test_expression_structure_slice_blocked():
    """Values stay correct when the reshape changes the sliced axis.

    Only the computed result is checked here; the expression structure is
    not inspected.
    """
    x = da.ones((6, 10), chunks=(3, 5))

    # (6, 10) -> (2, 3, 10) splits axis 0 (length 6 becomes 2), so the
    # pushdown does not apply to this slice.
    target_shape = (2, 3, 10)
    result = x.reshape(target_shape)[:1]

    assert_eq(result, np.ones(target_shape)[:1])
|