dask-array 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dask_array/__init__.py +228 -0
- dask_array/_backends.py +76 -0
- dask_array/_backends_array.py +99 -0
- dask_array/_blockwise.py +1410 -0
- dask_array/_broadcast.py +272 -0
- dask_array/_chunk.py +445 -0
- dask_array/_chunk_types.py +54 -0
- dask_array/_collection.py +1644 -0
- dask_array/_concatenate.py +331 -0
- dask_array/_core_utils.py +1365 -0
- dask_array/_dispatch.py +141 -0
- dask_array/_einsum.py +277 -0
- dask_array/_expr.py +544 -0
- dask_array/_expr_flow.py +586 -0
- dask_array/_gufunc.py +805 -0
- dask_array/_histogram.py +617 -0
- dask_array/_map_blocks.py +652 -0
- dask_array/_new_collection.py +10 -0
- dask_array/_numpy_compat.py +135 -0
- dask_array/_overlap.py +1159 -0
- dask_array/_rechunk.py +1050 -0
- dask_array/_reshape.py +710 -0
- dask_array/_routines.py +102 -0
- dask_array/_shuffle.py +448 -0
- dask_array/_stack.py +264 -0
- dask_array/_svg.py +291 -0
- dask_array/_templates.py +29 -0
- dask_array/_test_utils.py +257 -0
- dask_array/_ufunc.py +385 -0
- dask_array/_utils.py +349 -0
- dask_array/_visualize.py +223 -0
- dask_array/_xarray.py +337 -0
- dask_array/core/__init__.py +34 -0
- dask_array/core/_blockwise_funcs.py +312 -0
- dask_array/core/_conversion.py +422 -0
- dask_array/core/_from_graph.py +97 -0
- dask_array/creation/__init__.py +71 -0
- dask_array/creation/_arange.py +121 -0
- dask_array/creation/_diag.py +116 -0
- dask_array/creation/_diagonal.py +241 -0
- dask_array/creation/_eye.py +103 -0
- dask_array/creation/_linspace.py +102 -0
- dask_array/creation/_mesh.py +134 -0
- dask_array/creation/_ones_zeros.py +454 -0
- dask_array/creation/_pad.py +270 -0
- dask_array/creation/_repeat.py +55 -0
- dask_array/creation/_tile.py +36 -0
- dask_array/creation/_tri.py +28 -0
- dask_array/creation/_utils.py +296 -0
- dask_array/fft.py +320 -0
- dask_array/io/__init__.py +39 -0
- dask_array/io/_base.py +10 -0
- dask_array/io/_from_array.py +257 -0
- dask_array/io/_from_delayed.py +95 -0
- dask_array/io/_from_graph.py +54 -0
- dask_array/io/_from_npy_stack.py +67 -0
- dask_array/io/_store.py +336 -0
- dask_array/io/_tiledb.py +159 -0
- dask_array/io/_to_npy_stack.py +65 -0
- dask_array/io/_zarr.py +449 -0
- dask_array/linalg/__init__.py +39 -0
- dask_array/linalg/_cholesky.py +234 -0
- dask_array/linalg/_lu.py +300 -0
- dask_array/linalg/_norm.py +94 -0
- dask_array/linalg/_qr.py +601 -0
- dask_array/linalg/_solve.py +349 -0
- dask_array/linalg/_svd.py +394 -0
- dask_array/linalg/_tensordot.py +334 -0
- dask_array/linalg/_utils.py +74 -0
- dask_array/manipulation/__init__.py +45 -0
- dask_array/manipulation/_expand.py +321 -0
- dask_array/manipulation/_flip.py +92 -0
- dask_array/manipulation/_roll.py +78 -0
- dask_array/manipulation/_transpose.py +309 -0
- dask_array/random/__init__.py +125 -0
- dask_array/random/_choice.py +181 -0
- dask_array/random/_expr.py +256 -0
- dask_array/random/_generator.py +441 -0
- dask_array/random/_random_state.py +259 -0
- dask_array/random/_utils.py +84 -0
- dask_array/reductions/__init__.py +84 -0
- dask_array/reductions/_arg_reduction.py +130 -0
- dask_array/reductions/_common.py +1082 -0
- dask_array/reductions/_cumulative.py +522 -0
- dask_array/reductions/_percentile.py +261 -0
- dask_array/reductions/_reduction.py +725 -0
- dask_array/reductions/_trace.py +56 -0
- dask_array/routines/__init__.py +133 -0
- dask_array/routines/_apply.py +84 -0
- dask_array/routines/_bincount.py +112 -0
- dask_array/routines/_broadcast.py +111 -0
- dask_array/routines/_coarsen.py +115 -0
- dask_array/routines/_diff.py +79 -0
- dask_array/routines/_gradient.py +158 -0
- dask_array/routines/_indexing.py +65 -0
- dask_array/routines/_insert_delete.py +132 -0
- dask_array/routines/_misc.py +122 -0
- dask_array/routines/_nonzero.py +72 -0
- dask_array/routines/_search.py +123 -0
- dask_array/routines/_select.py +113 -0
- dask_array/routines/_statistics.py +171 -0
- dask_array/routines/_topk.py +82 -0
- dask_array/routines/_triangular.py +74 -0
- dask_array/routines/_unique.py +232 -0
- dask_array/routines/_where.py +62 -0
- dask_array/slicing/__init__.py +67 -0
- dask_array/slicing/_basic.py +550 -0
- dask_array/slicing/_blocks.py +138 -0
- dask_array/slicing/_bool_index.py +145 -0
- dask_array/slicing/_setitem.py +329 -0
- dask_array/slicing/_squeeze.py +101 -0
- dask_array/slicing/_utils.py +1133 -0
- dask_array/slicing/_vindex.py +282 -0
- dask_array/stacking/__init__.py +15 -0
- dask_array/stacking/_block.py +83 -0
- dask_array/stacking/_simple.py +58 -0
- dask_array/templates/array.html.j2 +48 -0
- dask_array/tests/__init__.py +0 -0
- dask_array/tests/conftest.py +22 -0
- dask_array/tests/test_api.py +40 -0
- dask_array/tests/test_binary_op_chunks.py +107 -0
- dask_array/tests/test_coarse_slice_through_blockwise.py +362 -0
- dask_array/tests/test_collection.py +799 -0
- dask_array/tests/test_creation.py +1102 -0
- dask_array/tests/test_expr_flow.py +143 -0
- dask_array/tests/test_linalg.py +1130 -0
- dask_array/tests/test_map_blocks_multi_output.py +104 -0
- dask_array/tests/test_rechunk_pushdown.py +214 -0
- dask_array/tests/test_reductions.py +1091 -0
- dask_array/tests/test_routines.py +2853 -0
- dask_array/tests/test_shuffle_chunks.py +67 -0
- dask_array/tests/test_slice_pushdown.py +968 -0
- dask_array/tests/test_slice_through_blockwise.py +678 -0
- dask_array/tests/test_slice_through_overlap.py +366 -0
- dask_array/tests/test_slice_through_reshape.py +272 -0
- dask_array/tests/test_slicing.py +839 -0
- dask_array/tests/test_transpose_slice_pushdown.py +208 -0
- dask_array/tests/test_visualize.py +94 -0
- dask_array/tests/test_xarray.py +193 -0
- dask_array-0.1.0.dist-info/METADATA +48 -0
- dask_array-0.1.0.dist-info/RECORD +144 -0
- dask_array-0.1.0.dist-info/WHEEL +4 -0
- dask_array-0.1.0.dist-info/entry_points.txt +2 -0
- dask_array-0.1.0.dist-info/licenses/LICENSE +29 -0
|
@@ -0,0 +1,968 @@
|
|
|
1
|
+
"""Tests for slice pushdown into IO expressions."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
import pytest
|
|
7
|
+
|
|
8
|
+
import dask_array as da
|
|
9
|
+
from dask_array.io import FromArray
|
|
10
|
+
from dask_array.slicing import SliceSlicesIntegers
|
|
11
|
+
from dask_array._test_utils import assert_eq
|
|
12
|
+
|
|
13
|
+
# Correctness matrix consumed by test_slice_correctness.
# Each entry is (array_shape, chunks, index_tuple); the dask result of
# ``x[index_tuple]`` is compared against the same index applied to NumPy.
SLICE_CASES = (
    [
        # --- basic slices ---
        ((10, 10), (2, 2), (slice(0, 2), slice(0, 2))),  # corner block
        ((10, 10), (2, 2), (slice(0, 4), slice(0, 4))),  # covers 2x2 chunks
        ((10, 10), (5, 5), (slice(0, 5), slice(0, 5))),  # aligned to chunk edge
        ((10, 10), (5, 5), (slice(2, 7), slice(3, 8))),  # starts/ends mid-chunk
        ((10, 10), (5, 5), (slice(None), slice(None))),  # whole array
    ]
    + [
        # --- edge cases ---
        ((10, 10), (2, 2), (slice(5, 5), slice(None))),  # zero-length
        ((10, 10), (2, 2), (slice(3, 4), slice(None))),  # one row
        ((10, 10), (2, 2), (slice(-4, -1), slice(-3, None))),  # negative bounds
        ((10, 10, 10), (3, 3, 3), (slice(1, 4), slice(2, 5), slice(3, 6))),  # 3-D
    ]
    + [
        # --- adversarial layouts ---
        ((10, 10), (10, 10), (slice(2, 5), slice(3, 7))),  # source is one chunk
        ((10, 10), (3, 4), (slice(1, 8), slice(2, 9))),  # uneven chunk sizes
        ((10, 10), (3, 3), (slice(9, None), slice(9, None))),  # trailing partial chunk
        ((10, 10, 10), (3, 3, 3), (slice(2, 5),)),  # fewer indices than dims
    ]
)
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
@pytest.mark.parametrize("shape,chunks,slc", SLICE_CASES)
def test_slice_correctness(shape, chunks, slc):
    """Indexing a dask array yields the same values as indexing NumPy."""
    reference = np.arange(np.prod(shape)).reshape(shape)
    lazy = da.from_array(reference, chunks=chunks)
    actual = lazy[slc]
    assert_eq(actual, reference[slc])
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
# Task-count matrix consumed by test_task_count.
# Each entry is (array_shape, chunks, index_tuple, expected_task_count), where
# the expected count equals the number of source chunks the index overlaps.
TASK_COUNT_CASES = (
    [
        # 10x10 sources
        ((10, 10), (2, 2), (slice(0, 2), slice(0, 2)), 1),  # one chunk
        ((10, 10), (2, 2), (slice(0, 4), slice(0, 4)), 4),  # 2x2 chunks
        ((10, 10), (5, 5), (slice(0, 5), slice(0, 5)), 1),  # exactly one chunk
        ((10, 10), (5, 5), (slice(2, 7), slice(3, 8)), 4),  # straddles 2x2
    ]
    + [
        # 20x20 sources with 5x5 chunks
        ((20, 20), (5, 5), (slice(0, 3), slice(None)), 4),  # one chunk-row (1x4)
        ((20, 20), (5, 5), (slice(None), slice(0, 3)), 4),  # one chunk-column (4x1)
        ((20, 20), (5, 5), (slice(0, 12), slice(0, 12)), 9),  # 3x3 chunks
    ]
    + [
        # larger or higher-dimensional sources
        ((100, 100), (10, 10), (slice(0, 5), slice(0, 5)), 1),  # tiny window of a big array
        ((10, 10, 10), (5, 5, 5), (slice(0, 3), slice(0, 3), slice(0, 3)), 1),  # 3-D corner
    ]
)
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
@pytest.mark.parametrize("shape,chunks,slc,expected", TASK_COUNT_CASES)
def test_task_count(shape, chunks, slc, expected):
    """Optimizing a slice leaves one task per source chunk it overlaps."""
    data = np.arange(np.prod(shape)).reshape(shape)
    optimized = da.from_array(data, chunks=chunks)[slc].optimize()
    graph = optimized.__dask_graph__()
    assert len(graph) == expected
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def test_slice_optimize_slice():
    """Re-slicing an already optimized slice optimizes again and stays correct."""
    data = np.arange(100).reshape(10, 10)
    source = da.from_array(data, chunks=(2, 2))

    # 0:6 over 2-wide chunks touches 3 chunks per axis -> 3x3 = 9 tasks.
    outer = source[0:6, 0:6].optimize()
    assert len(outer.__dask_graph__()) == 9

    # 0:2 of that result sits inside a single chunk -> 1 task.
    inner = outer[0:2, 0:2].optimize()
    assert len(inner.__dask_graph__()) == 1

    assert_eq(inner, data[0:6, 0:6][0:2, 0:2])
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def test_slice_through_elemwise():
    """A slice taken after elementwise ops is pushed down toward the IO layer."""
    data = np.arange(100).reshape(10, 10)
    source = da.from_array(data, chunks=(2, 2))
    transformed = (source + 1) * 2
    result = transformed[0:2, 0:2].optimize()

    # The window lies in one 2x2 chunk, so the optimized graph must stay tiny.
    assert len(result.__dask_graph__()) <= 2
    assert_eq(result, ((data + 1) * 2)[0:2, 0:2])
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
def test_nested_slices():
    """Chained slices (which the optimizer may fuse) match chained NumPy slices."""
    data = np.arange(100).reshape(10, 10)
    source = da.from_array(data, chunks=(2, 2))
    outer_idx = (slice(1, 8), slice(2, 9))
    inner_idx = (slice(1, 4), slice(1, 4))
    assert_eq(source[outer_idx][inner_idx], data[outer_idx][inner_idx])
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
def test_expression_structure():
    """A basic slice is a SliceSlicesIntegers node until optimization folds it into FromArray."""
    source = da.from_array(np.arange(100).reshape(10, 10), chunks=(2, 2))
    sliced = source[0:2, 0:2]

    assert isinstance(sliced.expr, SliceSlicesIntegers)
    optimized_expr = sliced.optimize().expr
    assert isinstance(optimized_expr, FromArray)
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
def test_steps_and_reverse():
    """Strided and reversed slices compute the same values as NumPy."""
    data = np.arange(100).reshape(10, 10)
    source = da.from_array(data, chunks=(2, 2))

    # Steps 2, -1 (full reverse) and 5, applied to both axes.
    for step in (2, -1, 5):
        index = (slice(None, None, step), slice(None, None, step))
        assert_eq(source[index], data[index])
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
def test_non_pushdown_cases():
    """Integer, fancy (list) and newaxis indexing still give correct results."""
    data = np.arange(100).reshape(10, 10)
    source = da.from_array(data, chunks=(2, 2))

    indexers = [
        (5, slice(None)),  # integer row
        ([1, 3, 5], slice(None)),  # fancy row selection
        (None, slice(None, 5), slice(None, 5)),  # leading newaxis
    ]
    for index in indexers:
        assert_eq(source[index], data[index])
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
def test_broadcast_to_empty_slice():
    """An empty slice of a broadcast array has a single zero-length chunk."""
    base = da.from_array(np.array([1]), (1,))
    empty = da.broadcast_to(base, (5,))[:0]
    target = np.array([], dtype=int)

    assert empty.chunks == ((0,),)
    assert_eq(empty, target)
    # Also check an Array rebuilt from the expression optimized without fusion.
    unfused = da.Array(empty.expr.optimize(fuse=False))
    assert_eq(unfused, target)
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
def test_masked_array():
    """Masks survive slice pushdown into from_array."""
    # mask=False materializes a full boolean mask, so it can be edited in place.
    data = np.ma.array(np.arange(100).reshape(10, 10), mask=False)
    data.mask[5, 5] = True  # the only masked element; it lies in the window below
    source = da.from_array(data, chunks=(3, 3))

    window = (slice(4, 7), slice(4, 7))
    computed = source[window].compute()
    reference = data[window]

    assert_eq(computed, reference)
    assert_eq(computed.mask, reference.mask)
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
def test_deterministic_names():
    """Identical slices of identical inputs share a name; different slices do not."""
    data = np.arange(100).reshape(10, 10)
    first = da.from_array(data, chunks=(2, 2))
    second = da.from_array(data, chunks=(2, 2))

    def optimized_name(arr, stop):
        # Name of arr[0:stop, 0:stop] after optimization.
        return arr[0:stop, 0:stop].optimize().name

    assert optimized_name(first, 2) == optimized_name(second, 2)
    assert optimized_name(first, 2) != optimized_name(first, 3)
|
|
153
|
+
|
|
154
|
+
|
|
155
|
+
def test_slice_then_reduction():
    """Summing a slice gives the same total as NumPy."""
    data = np.arange(100).reshape(10, 10)
    source = da.from_array(data, chunks=(2, 2))
    total = source[0:4, 0:4].sum()
    assert_eq(total, data[0:4, 0:4].sum())
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
def test_region_deferred_slice():
    """Pushdown stores the slice in ``_region`` instead of slicing the source eagerly."""
    data = np.arange(10000).reshape(100, 100)
    source = da.from_array(data, chunks=(10, 10))
    rows, cols = slice(12, 18), slice(34, 39)  # both inside a single 10x10 chunk
    sliced = source[rows, cols]

    optimized = sliced.expr.optimize()

    # The requested window is recorded on the expression...
    assert optimized.operand("_region") == (slice(12, 18, None), slice(34, 39, None))
    # ...while the wrapped array keeps its full, unsliced shape.
    assert optimized.array.shape == (100, 100)
    # Output chunks describe only the 6x5 window.
    assert optimized.chunks == ((6,), (5,))

    assert_eq(sliced, data[rows, cols])
|
|
180
|
+
|
|
181
|
+
|
|
182
|
+
def test_region_single_chunk():
    """A slice inside one source chunk becomes one task reading only that window.

    The source only has to span more than one 1000x1000 chunk per axis for the
    slice to sit strictly inside a chunk. A (2000, 3000) int64 array is about
    48 MB. The previous 10000x10000 array was about 800 MB, which had to be
    allocated and tokenized in a unit test.
    """
    data = np.arange(2000 * 3000).reshape(2000, 3000)
    source = da.from_array(data, chunks=(1000, 1000))
    # Rows 1500:1550 lie in row-chunk 1; columns 2300:2350 lie in column-chunk 2.
    sliced = source[1500:1550, 2300:2350]

    optimized = sliced.expr.optimize()
    graph = dict(optimized.__dask_graph__())

    # Exactly one (name, i, j) chunk key should remain.
    task_keys = [k for k in graph if isinstance(k, tuple) and len(k) == 3]
    assert len(task_keys) == 1

    # The task should read the 50x50 window directly, not go through a full
    # 1000x1000 chunk, so the chunk size should not appear in the graph.
    # NOTE(review): this is a substring heuristic on str(graph); a hash token
    # that happens to contain "1000" would make it fail spuriously.
    graph_text = str(graph)
    assert "1000" not in graph_text, "Should slice directly, not via full chunk"

    assert_eq(sliced, data[1500:1550, 2300:2350])
|
|
202
|
+
|
|
203
|
+
|
|
204
|
+
def test_region_multiple_chunks():
    """A slice crossing chunk boundaries still yields one task per touched chunk."""
    data = np.arange(10000).reshape(100, 100)
    source = da.from_array(data, chunks=(10, 10))
    # Rows 15:25 hit row-chunks 1 and 2; columns 35:45 hit column-chunks 3 and 4.
    sliced = source[15:25, 35:45]

    graph = dict(sliced.expr.optimize().__dask_graph__())
    chunk_keys = [key for key in graph if isinstance(key, tuple) and len(key) == 3]
    assert len(chunk_keys) == 2 * 2

    assert_eq(sliced, data[15:25, 35:45])
|
|
221
|
+
|
|
222
|
+
|
|
223
|
+
def test_region_zarr_deferred(tmp_path):
    """Slicing a zarr-backed array keeps the zarr array, not loaded NumPy data, in the graph."""
    zarr = pytest.importorskip("zarr")

    store_path = str(tmp_path / "test.zarr")
    store = zarr.open(
        store_path,
        mode="w",
        shape=(10000, 10000),
        dtype="float64",
        chunks=(1000, 1000),
    )
    window = (slice(1500, 1550), slice(2300, 2350))
    store[window] = np.arange(2500).reshape(50, 50)

    sliced = da.from_zarr(store_path)[window]
    graph = dict(sliced.expr.optimize().__dask_graph__())

    values = list(graph.values())
    zarr_arrays = [v for v in values if isinstance(v, zarr.Array)]
    numpy_arrays = [v for v in values if isinstance(v, np.ndarray)]

    assert len(zarr_arrays) == 1, "Graph should contain the zarr array"
    assert len(numpy_arrays) == 0, "Graph should not contain numpy arrays (data not loaded)"
    # The embedded zarr array is the whole store, not a pre-sliced copy.
    assert zarr_arrays[0].shape == (10000, 10000)

    assert_eq(sliced, store[window])
|
|
255
|
+
|
|
256
|
+
|
|
257
|
+
def test_integer_indexing_pushdown():
    """Integer indices are pushed down as regions so only the needed data is read."""
    data = np.arange(100).reshape(10, 10)
    source = da.from_array(data, chunks=(5, 5))

    # Scalar lookup: the optimized graph should hold 2 tasks (FromArray + extract).
    point = source[3, 7]
    optimized = point.optimize()
    assert len(optimized.__dask_graph__()) == 2

    # The wrapped FromArray reads a 1x1 region at (3, 7) of the untouched array.
    inner = optimized.expr.array
    assert inner.operand("_region") == (slice(3, 4), slice(7, 8))
    assert inner.array.shape == (10, 10)

    assert_eq(point, data[3, 7])

    # Mixed slice + integer indexing, integer in either position.
    assert_eq(source[:3, 5], data[:3, 5])
    assert_eq(source[5, 2:8], data[5, 2:8])
|
|
280
|
+
|
|
281
|
+
|
|
282
|
+
# ============================================================
|
|
283
|
+
# Slice through reduction tests
|
|
284
|
+
# ============================================================
|
|
285
|
+
|
|
286
|
+
|
|
287
|
+
def test_slice_through_reduction_optimization():
    """Slicing a reduction simplifies to reducing a slice.

    ``x.sum(axis=0)[:5]`` is expected to simplify to ``x[:, :5].sum(axis=0)``.
    """
    ones = da.ones((100, 100), chunks=(10, 10))

    slice_after = ones.sum(axis=0)[:5]  # reduce everything, then slice
    slice_before = ones[:, :5].sum(axis=0)  # slice first, then reduce

    # Simplify both sides, since slices also simplify through ones().
    assert slice_after.expr.simplify()._name == slice_before.expr.simplify()._name
|
|
303
|
+
|
|
304
|
+
|
|
305
|
+
def test_slice_through_reduction_reduces_tasks():
    """Slicing after a reduction shrinks the optimized graph.

    With a (10, 10)-chunked from_array source, ``sum(axis=0)[:5]`` should need
    fewer tasks than the full ``sum(axis=0)``.
    """
    data = np.arange(10000).reshape(100, 100)
    source = da.from_array(data, chunks=(10, 10))

    def optimized_task_count(arr):
        return len(arr.optimize().__dask_graph__())

    full_tasks = optimized_task_count(source.sum(axis=0))
    sliced = source.sum(axis=0)[:5]
    sliced_tasks = optimized_task_count(sliced)

    # [:5] lies in the first column-chunk, so far fewer tasks are needed.
    assert sliced_tasks < full_tasks
    assert_eq(sliced, data.sum(axis=0)[:5])
|
|
327
|
+
|
|
328
|
+
|
|
329
|
+
def test_slice_through_reduction_axis1():
    """``x.sum(axis=1)[:5]`` simplifies to ``x[:5, :].sum(axis=1)``."""
    ones = da.ones((100, 100), chunks=(10, 10))
    reduced_then_sliced = ones.sum(axis=1)[:5]
    sliced_then_reduced = ones[:5, :].sum(axis=1)
    assert (
        reduced_then_sliced.expr.simplify()._name
        == sliced_then_reduced.expr.simplify()._name
    )
|
|
338
|
+
|
|
339
|
+
|
|
340
|
+
def test_slice_through_reduction_3d():
    """Slicing a 3-D reduction maps output axes back onto input axes."""
    cube = da.ones((20, 20, 20), chunks=(5, 5, 5))
    # Summing axis 1 turns input axes (0, 2) into output axes (0, 1), so the
    # output slice [:3, :4] corresponds to the input slice [:3, :, :4].
    reduced_then_sliced = cube.sum(axis=1)[:3, :4]
    sliced_then_reduced = cube[:3, :, :4].sum(axis=1)
    assert (
        reduced_then_sliced.expr.simplify()._name
        == sliced_then_reduced.expr.simplify()._name
    )
|
|
350
|
+
|
|
351
|
+
|
|
352
|
+
def test_slice_through_reduction_multiple_axes():
    """Slicing a reduction over several axes targets the surviving input axis."""
    cube = da.ones((20, 20, 20), chunks=(5, 5, 5))
    # Summing axes 0 and 2 leaves only input axis 1 as output axis 0.
    reduced_then_sliced = cube.sum(axis=(0, 2))[:5]
    sliced_then_reduced = cube[:, :5, :].sum(axis=(0, 2))
    assert (
        reduced_then_sliced.expr.simplify()._name
        == sliced_then_reduced.expr.simplify()._name
    )
|
|
362
|
+
|
|
363
|
+
|
|
364
|
+
def test_slice_through_reduction_correctness():
    """Slice-through-reduction results match NumPy."""
    data = np.arange(10000).reshape(100, 100)
    source = da.from_array(data, chunks=(10, 10))

    cases = [(0, slice(None, 5)), (1, slice(None, 5)), (0, slice(10, 20))]
    for axis, window in cases:
        assert_eq(source.sum(axis=axis)[window], data.sum(axis=axis)[window])
|
|
373
|
+
|
|
374
|
+
|
|
375
|
+
def test_slice_through_reduction_integer_index():
    """Integer-indexing a reduction also shrinks the optimized graph.

    The integer index is expected to be pushed through the reduction (as a
    size-1 slice) instead of computing the full reduction first.
    """
    data = np.arange(10000).reshape(100, 100)
    source = da.from_array(data, chunks=(10, 10))

    baseline_tasks = len(source.sum(axis=0).optimize().__dask_graph__())

    picked = source.sum(axis=0)[5]
    picked_tasks = len(picked.optimize().__dask_graph__())

    assert picked_tasks < baseline_tasks
    assert_eq(picked, data.sum(axis=0)[5])
|
|
393
|
+
|
|
394
|
+
|
|
395
|
+
# =============================================================================
|
|
396
|
+
# Slice through creation expressions (ones, zeros, full, empty)
|
|
397
|
+
# =============================================================================
|
|
398
|
+
|
|
399
|
+
|
|
400
|
+
def test_slice_ones_returns_smaller_ones():
    """Simplifying a slice of ones() yields a smaller Ones expression."""
    from dask_array.creation import Ones

    sliced = da.ones((100, 100), chunks=(10, 10))[:15, :25]
    simplified = sliced.expr.simplify()

    # The slice is folded into a new Ones rather than wrapping the original.
    assert isinstance(simplified, Ones)
    assert simplified.shape == (15, 25)
|
|
411
|
+
|
|
412
|
+
|
|
413
|
+
def test_slice_zeros_returns_smaller_zeros():
    """Simplifying a slice of zeros() yields a smaller Zeros expression."""
    from dask_array.creation import Zeros

    sliced = da.zeros((100, 100), chunks=(10, 10))[:15, :25]
    simplified = sliced.expr.simplify()

    assert isinstance(simplified, Zeros)
    assert simplified.shape == (15, 25)
|
|
423
|
+
|
|
424
|
+
|
|
425
|
+
def test_slice_full_returns_smaller_full():
    """Simplifying a slice of full() yields a smaller Full expression."""
    from dask_array.creation import Full

    sliced = da.full((100, 100), 42, chunks=(10, 10))[:15, :25]
    simplified = sliced.expr.simplify()

    assert isinstance(simplified, Full)
    assert simplified.shape == (15, 25)
    # The computed values confirm the fill value 42 carried over.
    assert_eq(sliced, np.full((15, 25), 42))
|
|
437
|
+
|
|
438
|
+
|
|
439
|
+
def test_slice_creation_correctness():
    """Slices of ones/zeros/full compute the expected constant arrays."""
    window = (slice(None, 15), slice(None, 25))
    cases = [
        (da.ones((100, 100), chunks=10), np.ones((15, 25))),
        (da.zeros((100, 100), chunks=10), np.zeros((15, 25))),
        (da.full((100, 100), 7.5, chunks=10), np.full((15, 25), 7.5)),
    ]
    for lazy, expected in cases:
        assert_eq(lazy[window], expected)
|
|
444
|
+
|
|
445
|
+
|
|
446
|
+
def test_slice_creation_preserves_dtype():
    """Slicing ones() keeps the requested dtype."""
    requested = np.dtype("int32")
    sliced = da.ones((100, 100), chunks=10, dtype="int32")[:15, :25]
    assert sliced.dtype == requested
    assert_eq(sliced, np.ones((15, 25), dtype="int32"))
|
|
451
|
+
|
|
452
|
+
|
|
453
|
+
# =============================================================================
|
|
454
|
+
# Slice through Concatenate
|
|
455
|
+
# =============================================================================
|
|
456
|
+
|
|
457
|
+
|
|
458
|
+
def test_slice_through_concat_same_axis_first_array():
    """A slice that stays inside the first concatenated input reduces to slicing that input."""
    first = da.ones((10, 5), chunks=5)
    second = da.ones((10, 5), chunks=5)
    # Rows :5 of the concatenation all come from ``first``.
    combined = da.concatenate([first, second], axis=0)[:5]
    assert combined.expr.simplify()._name == first[:5].expr.simplify()._name
|
|
466
|
+
|
|
467
|
+
|
|
468
|
+
def test_slice_through_concat_same_axis_spans_arrays():
    """A slice across an input boundary keeps only the overlapping pieces."""
    parts = [da.ones((10, 5), chunks=5) for _ in range(3)]
    # Rows 5:15 are the tail of parts[0] (rows 5:10) plus the head of parts[1] (rows 0:5).
    result = da.concatenate(parts, axis=0)[5:15]
    expected = da.concatenate([parts[0][5:], parts[1][:5]], axis=0)
    assert result.expr.simplify()._name == expected.expr.simplify()._name
|
|
478
|
+
|
|
479
|
+
|
|
480
|
+
def test_slice_through_concat_different_axis():
    """A slice on a non-concatenation axis is applied to every input."""
    left = da.ones((10, 20), chunks=5)
    right = da.ones((10, 20), chunks=5)
    cols = slice(None, 5)
    result = da.concatenate([left, right], axis=0)[:, cols]
    expected = da.concatenate([left[:, cols], right[:, cols]], axis=0)
    assert result.expr.simplify()._name == expected.expr.simplify()._name
|
|
488
|
+
|
|
489
|
+
|
|
490
|
+
def test_slice_through_concat_correctness():
    """Slices of a concatenation compute the same values as NumPy."""
    top = np.arange(20).reshape(4, 5)
    bottom = np.arange(20, 40).reshape(4, 5)
    lazy = da.concatenate(
        [da.from_array(top, chunks=2), da.from_array(bottom, chunks=2)], axis=0
    )
    eager = np.concatenate([top, bottom], axis=0)

    # Along the concat axis, along the other axis, and straddling both inputs.
    for index in [slice(None, 3), (slice(None), slice(None, 3)), slice(2, 6)]:
        assert_eq(lazy[index], eager[index])
|
|
508
|
+
|
|
509
|
+
|
|
510
|
+
def test_slice_through_concat_reduces_tasks():
    """Slicing a concatenation produces a smaller optimized graph."""
    combined = da.concatenate(
        [da.ones((100, 100), chunks=10), da.ones((100, 100), chunks=10)], axis=0
    )

    def n_tasks(arr):
        return len(arr.optimize().__dask_graph__())

    full_tasks = n_tasks(combined)
    # The first 5 rows come only from the first input.
    head_tasks = n_tasks(combined[:5])

    assert head_tasks < full_tasks
|
|
521
|
+
|
|
522
|
+
|
|
523
|
+
# =============================================================================
|
|
524
|
+
# Slice through Stack
|
|
525
|
+
# =============================================================================
|
|
526
|
+
|
|
527
|
+
|
|
528
|
+
def test_slice_through_stack_selects_subset():
    """Slicing the new stack axis keeps only the selected inputs."""
    inputs = [da.ones((10, 5), chunks=5) for _ in range(3)]
    # stack -> shape (3, 10, 5); [:1] keeps only the first input.
    result = da.stack(inputs, axis=0)[:1]
    expected = da.stack(inputs[:1], axis=0)
    assert result.expr.simplify()._name == expected.expr.simplify()._name
|
|
538
|
+
|
|
539
|
+
|
|
540
|
+
def test_slice_through_stack_other_axis():
    """Slicing the original axes of a stack is applied to each input."""
    first = da.ones((10, 20), chunks=5)
    second = da.ones((10, 20), chunks=5)
    # stack -> shape (2, 10, 20); [:, :5, :10] becomes [:5, :10] on each input.
    result = da.stack([first, second], axis=0)[:, :5, :10]
    expected = da.stack([arr[:5, :10] for arr in (first, second)], axis=0)
    assert result.expr.simplify()._name == expected.expr.simplify()._name
|
|
549
|
+
|
|
550
|
+
|
|
551
|
+
def test_slice_through_stack_mixed():
    """Slicing the stack axis and an original axis at the same time."""
    inputs = [da.ones((10, 20), chunks=5) for _ in range(3)]
    # stack -> shape (3, 10, 20); [:2, :5] keeps two inputs, each sliced to [:5].
    result = da.stack(inputs, axis=0)[:2, :5]
    expected = da.stack([arr[:5] for arr in inputs[:2]], axis=0)
    assert result.expr.simplify()._name == expected.expr.simplify()._name
|
|
561
|
+
|
|
562
|
+
|
|
563
|
+
def test_slice_through_stack_correctness():
    """Slices of a stack compute the same values as NumPy."""
    blocks = [np.arange(start, start + 20).reshape(4, 5) for start in (0, 20, 40)]
    lazy = da.stack([da.from_array(b, chunks=2) for b in blocks], axis=0)
    eager = np.stack(blocks, axis=0)

    # Along the stack axis, then along the original axes.
    for index in [slice(None, 2), (slice(None), slice(None, 2), slice(None, 3))]:
        assert_eq(lazy[index], eager[index])
|
|
579
|
+
|
|
580
|
+
|
|
581
|
+
def test_slice_through_stack_reduces_tasks():
    """Keeping one stacked input yields a smaller optimized graph."""
    stacked = da.stack(
        [da.ones((100, 100), chunks=10) for _ in range(3)],
        axis=0,
    )

    def task_count(arr):
        # Size of the graph after optimization.
        return len(arr.optimize().__dask_graph__())

    whole = task_count(stacked)
    # [:1] keeps only the first of the three inputs.
    first_only = task_count(stacked[:1])

    assert first_only < whole
|
|
593
|
+
|
|
594
|
+
|
|
595
|
+
# =============================================================================
|
|
596
|
+
# Slice through BroadcastTo
|
|
597
|
+
# =============================================================================
|
|
598
|
+
|
|
599
|
+
|
|
600
|
+
def test_slice_through_broadcast_to_new_dim():
    """Slicing an axis that ``broadcast_to`` introduced only shrinks the target."""
    vec = da.ones((10,), chunks=5)

    # (10,) -> (20, 10): axis 0 is new, so [:5, :] just narrows the target shape.
    got = da.broadcast_to(vec, (20, 10))[:5, :].expr.simplify()
    want = da.broadcast_to(vec, (5, 10)).expr.simplify()

    assert got._name == want._name
|
|
608
|
+
|
|
609
|
+
|
|
610
|
+
def test_slice_through_broadcast_to_existing_dim():
    """Slicing an axis the broadcast input already has moves into the input."""
    vec = da.ones((10,), chunks=5)

    # (10,) -> (20, 10): axis 1 comes from the input, so [:5] applies to it.
    got = da.broadcast_to(vec, (20, 10))[:, :5].expr.simplify()
    want = da.broadcast_to(vec[:5], (20, 5)).expr.simplify()

    assert got._name == want._name
|
|
618
|
+
|
|
619
|
+
|
|
620
|
+
def test_slice_through_broadcast_to_both_dims():
    """Slice a new axis and an input axis at the same time."""
    vec = da.ones((10,), chunks=5)

    got = da.broadcast_to(vec, (20, 10))[:5, :3].expr.simplify()
    # [:3] moves into the input; [:5] on the new axis only narrows the target.
    want = da.broadcast_to(vec[:3], (5, 3)).expr.simplify()

    assert got._name == want._name
|
|
627
|
+
|
|
628
|
+
|
|
629
|
+
def test_slice_through_broadcast_to_broadcasted_dim():
    """Slice an axis that had length 1 in the input before broadcasting."""
    src = da.ones((1, 10), chunks=(1, 5))

    # (1, 10) -> (20, 10). Axis 0 was stretched from length 1, so [:5] cannot
    # move into the input and only narrows the target. [:3] on axis 1 does
    # move into the input.
    got = da.broadcast_to(src, (20, 10))[:5, :3].expr.simplify()
    want = da.broadcast_to(src[:, :3], (5, 3)).expr.simplify()

    assert got._name == want._name
|
|
638
|
+
|
|
639
|
+
|
|
640
|
+
def test_slice_through_broadcast_to_correctness():
    """Slicing a broadcast array produces the same values as NumPy."""
    dense = np.arange(10)
    lazy = da.from_array(dense, chunks=5)

    # Broadcast (10,) up to (20, 10), then slice both axes.
    assert_eq(
        da.broadcast_to(lazy, (20, 10))[:5, :3],
        np.broadcast_to(dense, (20, 10))[:5, :3],
    )
|
|
649
|
+
|
|
650
|
+
|
|
651
|
+
def test_slice_through_broadcast_to_reduces_tasks():
    """A small window of a broadcast array needs fewer tasks than the whole."""
    wide = da.broadcast_to(da.ones((100,), chunks=10), (100, 100))

    def task_count(arr):
        # Size of the graph after optimization.
        return len(arr.optimize().__dask_graph__())

    whole = task_count(wide)
    window = task_count(wide[:5, :5])

    assert window < whole
|
|
661
|
+
|
|
662
|
+
|
|
663
|
+
# --- Shuffle (take) through Elemwise Tests ---
|
|
664
|
+
|
|
665
|
+
|
|
666
|
+
def test_shuffle_pushes_through_elemwise_add():
    """Fancy-indexing a sum becomes a sum of fancy-indexed operands.

    ``(x + y)[idx]`` should simplify to the same expression as
    ``x[idx] + y[idx]`` for ``idx = [1, 3, 5, 7, 9]``.
    """
    lhs = da.arange(20, chunks=5)
    rhs = da.arange(20, chunks=5)
    picks = [1, 3, 5, 7, 9]

    taken = (lhs + rhs)[picks]
    pushed = lhs[picks] + rhs[picks]

    # Both spellings must collapse to one expression after simplification.
    assert taken.expr.simplify()._name == pushed.expr.simplify()._name

    # The values must also agree with NumPy.
    dense = np.arange(20)
    assert_eq(taken, (dense + dense)[picks])
|
|
682
|
+
|
|
683
|
+
|
|
684
|
+
def test_shuffle_pushes_through_elemwise_mul():
    """``(x * y)[idx]`` simplifies like ``x[idx] * y[idx]`` for ``idx = [2, 4, 6, 8]``."""
    lhs = da.arange(30, chunks=10)
    rhs = da.arange(30, chunks=10)
    picks = [2, 4, 6, 8]

    taken = (lhs * rhs)[picks]
    pushed = lhs[picks] * rhs[picks]

    assert taken.expr.simplify()._name == pushed.expr.simplify()._name
    assert_eq(taken, pushed)
|
|
695
|
+
|
|
696
|
+
|
|
697
|
+
def test_shuffle_pushes_through_elemwise_2d():
    """Row selection (axis 0) on a 2-D sum is pushed onto both operands."""
    lhs, rhs = (da.ones((10, 8), chunks=(5, 4)) for _ in range(2))
    rows = [0, 2, 4, 6]

    taken = (lhs + rhs)[rows, :]
    pushed = lhs[rows, :] + rhs[rows, :]

    assert taken.expr.simplify()._name == pushed.expr.simplify()._name
    assert_eq(taken, pushed)
|
|
708
|
+
|
|
709
|
+
|
|
710
|
+
def test_shuffle_pushes_through_elemwise_scalar():
    """A take on ``x + scalar`` moves onto ``x`` and leaves the scalar alone."""
    arr = da.arange(20, chunks=5)
    picks = [1, 5, 9, 13]

    taken = (arr + 1)[picks]
    pushed = arr[picks] + 1

    assert taken.expr.simplify()._name == pushed.expr.simplify()._name
    assert_eq(taken, pushed)
|
|
720
|
+
|
|
721
|
+
|
|
722
|
+
def test_shuffle_pushes_through_unary_elemwise():
    """A take on a unary elemwise op (negation) moves onto its operand."""
    arr = da.arange(20, chunks=5)
    picks = [2, 4, 6, 8]

    taken = (-arr)[picks]
    pushed = -(arr[picks])

    assert taken.expr.simplify()._name == pushed.expr.simplify()._name
    assert_eq(taken, pushed)
|
|
732
|
+
|
|
733
|
+
|
|
734
|
+
def test_shuffle_through_elemwise_reduces_work():
    """Optimizing a take through elemwise must not grow the task graph.

    The chosen indices hit every one of the ten input chunks. This test
    therefore only guards against the pushdown *adding* tasks; it uses a
    non-strict comparison and does not require a strict reduction.
    """
    x = da.ones((100,), chunks=10)
    y = da.ones((100,), chunks=10)

    # Take 10 of 100 elements, one from each length-10 chunk.
    indices = list(range(0, 100, 10))  # [0, 10, 20, ..., 90]
    result = (x + y)[indices]

    unopt_tasks = len(result.__dask_graph__())
    opt_tasks = len(result.optimize().__dask_graph__())

    # Non-strict on purpose: the optimized graph must be no larger than the
    # unoptimized one.
    assert opt_tasks <= unopt_tasks
|
|
749
|
+
|
|
750
|
+
|
|
751
|
+
def test_shuffle_through_elemwise_with_broadcast_2d():
    """Take through elemwise where one operand broadcasts via a size-1 axis.

    ``(a * y2d)[[5]]`` should simplify to ``a[[5]] * y2d``. ``y2d`` has shape
    ``(1, 20)``, so only ``a`` needs to be shuffled along axis 0.
    """
    grid = da.from_array(np.arange(200).reshape(10, 20), chunks=(4, 5))
    row_vec = da.from_array(np.arange(20).reshape(1, 20), chunks=(1, 20))

    taken = (grid * row_vec)[[5]]
    pushed = grid[[5]] * row_vec

    assert taken.expr.simplify()._name == pushed.expr.simplify()._name
    assert_eq(taken, pushed)
|
|
764
|
+
|
|
765
|
+
|
|
766
|
+
def test_shuffle_through_elemwise_with_broadcast_1d():
    """Take through elemwise where the other operand is 1-D.

    ``(a * y1d)[[5]]`` should simplify to ``a[[5]] * y1d``: the 1-D operand has
    no axis 0 to index, so only the 2-D input is shuffled.
    """
    grid = da.from_array(np.arange(200).reshape(10, 20), chunks=(4, 5))
    vec = da.from_array(np.arange(20), chunks=20)

    taken = (grid * vec)[[5]]
    pushed = grid[[5]] * vec

    assert taken.expr.simplify()._name == pushed.expr.simplify()._name
    assert_eq(taken, pushed)
|
|
779
|
+
|
|
780
|
+
|
|
781
|
+
# --- Shuffle through Transpose Tests ---
|
|
782
|
+
|
|
783
|
+
|
|
784
|
+
def test_shuffle_pushes_through_transpose():
    """x.T[[1, 3]] should optimize to x[:, [1, 3]].T.

    Taking rows of the transposed ``(5, 4)`` view is equivalent to taking
    columns of the original ``(4, 5)`` array and then transposing.
    """
    x = da.arange(20, chunks=5).reshape((4, 5))

    indices = [1, 3]
    result = x.T[indices, :]  # Take rows 1, 3 from transposed (5, 4)
    expected = x[:, indices].T  # Take cols 1, 3 from (4, 5), then transpose

    assert result.expr.simplify()._name == expected.expr.simplify()._name
    assert_eq(result, expected)
|
|
794
|
+
|
|
795
|
+
|
|
796
|
+
def test_shuffle_pushes_through_transpose_axis1():
    """``x.T[:, [0, 2]]`` should simplify like ``x[[0, 2], :].T``."""
    mat = da.arange(20, chunks=5).reshape((4, 5))
    picks = [0, 2]

    # Columns 0 and 2 of the (5, 4) transpose are rows 0 and 2 of the original.
    taken = mat.T[:, picks]
    pushed = mat[picks, :].T

    assert taken.expr.simplify()._name == pushed.expr.simplify()._name
    assert_eq(taken, pushed)
|
|
806
|
+
|
|
807
|
+
|
|
808
|
+
def test_shuffle_pushes_through_transpose_3d():
    """A take on axis 0 of a reversed 3-D transpose maps to axis 2 of the input."""
    cube = da.ones((2, 3, 4), chunks=2)
    picks = [0, 2]
    order = (2, 1, 0)

    # (2, 3, 4) -> (4, 3, 2); take along the new axis 0 ...
    taken = cube.transpose(order)[picks, :, :]
    # ... which is axis 2 of the original, taken before transposing.
    pushed = cube[:, :, picks].transpose(order)

    assert taken.expr.simplify()._name == pushed.expr.simplify()._name
    assert_eq(taken, pushed)
|
|
820
|
+
|
|
821
|
+
|
|
822
|
+
# --- Shuffle through Concatenate/Stack Tests ---
|
|
823
|
+
|
|
824
|
+
|
|
825
|
+
def test_shuffle_pushes_through_concatenate():
    """A take on an axis other than the concatenation axis reaches every input."""
    left = da.arange(20, chunks=5).reshape((4, 5))
    right = da.arange(20, 40, chunks=5).reshape((4, 5))
    rows = [0, 2]

    # Concatenating along axis 1 gives (4, 10); select rows 0 and 2.
    taken = da.concatenate([left, right], axis=1)[rows, :]
    pushed = da.concatenate([left[rows, :], right[rows, :]], axis=1)

    assert taken.expr.simplify()._name == pushed.expr.simplify()._name
    assert_eq(taken, pushed)
|
|
838
|
+
|
|
839
|
+
|
|
840
|
+
def test_shuffle_pushes_through_stack():
    """A take on an axis other than the stacking axis reaches every input."""
    left = da.arange(12, chunks=4).reshape((3, 4))
    right = da.arange(12, 24, chunks=4).reshape((3, 4))
    picks = [0, 2]

    # Stacking on axis 0 gives (2, 3, 4); axis 1 of the stack is axis 0 of each input.
    taken = da.stack([left, right], axis=0)[:, picks, :]
    pushed = da.stack([left[picks, :], right[picks, :]], axis=0)

    assert taken.expr.simplify()._name == pushed.expr.simplify()._name
    assert_eq(taken, pushed)
|
|
853
|
+
|
|
854
|
+
|
|
855
|
+
# --- Shuffle through Blockwise Tests ---
|
|
856
|
+
|
|
857
|
+
|
|
858
|
+
def test_shuffle_pushes_through_blockwise():
    """A take moves below a blockwise op that does not adjust chunks on that axis."""
    from dask_array._blockwise import Blockwise

    # Plain map_blocks yields a generic Blockwise without adjust_chunks.
    base = da.ones((4, 6), chunks=(2, 3))
    doubled = base.map_blocks(lambda b: b * 2)

    rows = [0, 2]
    taken = doubled[rows, :]
    # Reference: take first, then apply the same block function.
    reference = base[rows, :].map_blocks(lambda b: b * 2)

    # After simplification, the Blockwise node is the root.
    simplified = taken.expr.simplify()
    assert isinstance(simplified, Blockwise)

    # Values agree with the take-first version.
    assert_eq(taken, reference)
|
|
878
|
+
|
|
879
|
+
|
|
880
|
+
def test_shuffle_does_not_push_through_blockwise_adjust_chunks():
    """The take stays on top when adjust_chunks changes the taken axis."""
    from dask_array._shuffle import Shuffle

    base = da.ones((8, 6), chunks=(2, 3))
    # Passing explicit chunks=(1, 3) fixes each output block's size
    # (adjust_chunks), giving a (4, 6) result with chunks (1, 3).
    shrunk = base.map_blocks(lambda b: b * 2, chunks=(1, 3))

    rows = [0, 2]  # a strict subset of axis 0
    taken = shrunk[rows, :]

    # Axis 0 has adjust_chunks, so the Shuffle must remain the root node.
    simplified = taken.expr.simplify()
    assert isinstance(simplified, Shuffle)

    # The result is still correct.
    assert_eq(taken, shrunk.compute()[rows, :])
|
|
899
|
+
|
|
900
|
+
|
|
901
|
+
# --- ExpandDims (None indexing) Pushdown Tests ---
|
|
902
|
+
|
|
903
|
+
|
|
904
|
+
def test_none_slice_pushes_through_elemwise():
    """With a leading None, the slicing moves below the add and expand-dims stays on top.

    ``(x + y)[None, :5, :]`` should simplify like
    ``(x[:5, :] + y[:5, :])[None, :, :]``.
    """
    lhs = da.ones((10, 10), chunks=5)
    rhs = da.ones((10, 10), chunks=5)

    got = (lhs + rhs)[None, :5, :]
    want = (lhs[:5, :] + rhs[:5, :])[None, :, :]

    # Same expression after simplification ...
    assert got.expr.simplify()._name == want.expr.simplify()._name
    # ... and the same values.
    assert_eq(got, want)
|
|
918
|
+
|
|
919
|
+
|
|
920
|
+
def test_none_slice_multiple_nones():
    """Several Nones mixed with slices still push the slices below the add.

    ``(x + y)[None, :2, None, :3]`` should simplify like
    ``(x[:2, :3] + y[:2, :3])[None, :, None, :]``.
    """
    lhs = da.arange(20, chunks=5).reshape((4, 5))
    rhs = da.ones((4, 5), chunks=(4, 5))

    got = (lhs + rhs)[None, :2, None, :3]
    want = (lhs[:2, :3] + rhs[:2, :3])[None, :, None, :]

    assert got.expr.simplify()._name == want.expr.simplify()._name
    assert_eq(got, want)
|
|
934
|
+
|
|
935
|
+
|
|
936
|
+
def test_none_slice_no_slicing():
    """Pure dimension insertion (only None, no slicing) simplifies to ExpandDims."""
    from dask_array.manipulation._expand import ExpandDims

    lhs = da.ones((10, 10), chunks=5)
    rhs = da.ones((10, 10), chunks=5)

    # Add a new leading axis; every existing axis is taken whole.
    expanded = (lhs + rhs)[None, :, :]

    # ExpandDims is used rather than Reshape, for fusion compatibility.
    simplified = expanded.expr.simplify()
    assert isinstance(simplified, ExpandDims)

    dense = np.ones((10, 10))
    assert_eq(expanded, (dense + dense)[None, :, :])
|
|
954
|
+
|
|
955
|
+
|
|
956
|
+
def test_none_slice_through_transpose():
    """Slicing with a leading None pushes through a transpose.

    ``x.T[None, :3, :2]`` should simplify like ``x[:2, :3].T[None, :, :]``.
    """
    mat = da.arange(20, chunks=5).reshape((4, 5))

    got = mat.T[None, :3, :2]
    want = mat[:2, :3].T[None, :, :]

    assert got.expr.simplify()._name == want.expr.simplify()._name
    assert_eq(got, want)
|