dask-array 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dask_array/__init__.py +228 -0
- dask_array/_backends.py +76 -0
- dask_array/_backends_array.py +99 -0
- dask_array/_blockwise.py +1410 -0
- dask_array/_broadcast.py +272 -0
- dask_array/_chunk.py +445 -0
- dask_array/_chunk_types.py +54 -0
- dask_array/_collection.py +1644 -0
- dask_array/_concatenate.py +331 -0
- dask_array/_core_utils.py +1365 -0
- dask_array/_dispatch.py +141 -0
- dask_array/_einsum.py +277 -0
- dask_array/_expr.py +544 -0
- dask_array/_expr_flow.py +586 -0
- dask_array/_gufunc.py +805 -0
- dask_array/_histogram.py +617 -0
- dask_array/_map_blocks.py +652 -0
- dask_array/_new_collection.py +10 -0
- dask_array/_numpy_compat.py +135 -0
- dask_array/_overlap.py +1159 -0
- dask_array/_rechunk.py +1050 -0
- dask_array/_reshape.py +710 -0
- dask_array/_routines.py +102 -0
- dask_array/_shuffle.py +448 -0
- dask_array/_stack.py +264 -0
- dask_array/_svg.py +291 -0
- dask_array/_templates.py +29 -0
- dask_array/_test_utils.py +257 -0
- dask_array/_ufunc.py +385 -0
- dask_array/_utils.py +349 -0
- dask_array/_visualize.py +223 -0
- dask_array/_xarray.py +337 -0
- dask_array/core/__init__.py +34 -0
- dask_array/core/_blockwise_funcs.py +312 -0
- dask_array/core/_conversion.py +422 -0
- dask_array/core/_from_graph.py +97 -0
- dask_array/creation/__init__.py +71 -0
- dask_array/creation/_arange.py +121 -0
- dask_array/creation/_diag.py +116 -0
- dask_array/creation/_diagonal.py +241 -0
- dask_array/creation/_eye.py +103 -0
- dask_array/creation/_linspace.py +102 -0
- dask_array/creation/_mesh.py +134 -0
- dask_array/creation/_ones_zeros.py +454 -0
- dask_array/creation/_pad.py +270 -0
- dask_array/creation/_repeat.py +55 -0
- dask_array/creation/_tile.py +36 -0
- dask_array/creation/_tri.py +28 -0
- dask_array/creation/_utils.py +296 -0
- dask_array/fft.py +320 -0
- dask_array/io/__init__.py +39 -0
- dask_array/io/_base.py +10 -0
- dask_array/io/_from_array.py +257 -0
- dask_array/io/_from_delayed.py +95 -0
- dask_array/io/_from_graph.py +54 -0
- dask_array/io/_from_npy_stack.py +67 -0
- dask_array/io/_store.py +336 -0
- dask_array/io/_tiledb.py +159 -0
- dask_array/io/_to_npy_stack.py +65 -0
- dask_array/io/_zarr.py +449 -0
- dask_array/linalg/__init__.py +39 -0
- dask_array/linalg/_cholesky.py +234 -0
- dask_array/linalg/_lu.py +300 -0
- dask_array/linalg/_norm.py +94 -0
- dask_array/linalg/_qr.py +601 -0
- dask_array/linalg/_solve.py +349 -0
- dask_array/linalg/_svd.py +394 -0
- dask_array/linalg/_tensordot.py +334 -0
- dask_array/linalg/_utils.py +74 -0
- dask_array/manipulation/__init__.py +45 -0
- dask_array/manipulation/_expand.py +321 -0
- dask_array/manipulation/_flip.py +92 -0
- dask_array/manipulation/_roll.py +78 -0
- dask_array/manipulation/_transpose.py +309 -0
- dask_array/random/__init__.py +125 -0
- dask_array/random/_choice.py +181 -0
- dask_array/random/_expr.py +256 -0
- dask_array/random/_generator.py +441 -0
- dask_array/random/_random_state.py +259 -0
- dask_array/random/_utils.py +84 -0
- dask_array/reductions/__init__.py +84 -0
- dask_array/reductions/_arg_reduction.py +130 -0
- dask_array/reductions/_common.py +1082 -0
- dask_array/reductions/_cumulative.py +522 -0
- dask_array/reductions/_percentile.py +261 -0
- dask_array/reductions/_reduction.py +725 -0
- dask_array/reductions/_trace.py +56 -0
- dask_array/routines/__init__.py +133 -0
- dask_array/routines/_apply.py +84 -0
- dask_array/routines/_bincount.py +112 -0
- dask_array/routines/_broadcast.py +111 -0
- dask_array/routines/_coarsen.py +115 -0
- dask_array/routines/_diff.py +79 -0
- dask_array/routines/_gradient.py +158 -0
- dask_array/routines/_indexing.py +65 -0
- dask_array/routines/_insert_delete.py +132 -0
- dask_array/routines/_misc.py +122 -0
- dask_array/routines/_nonzero.py +72 -0
- dask_array/routines/_search.py +123 -0
- dask_array/routines/_select.py +113 -0
- dask_array/routines/_statistics.py +171 -0
- dask_array/routines/_topk.py +82 -0
- dask_array/routines/_triangular.py +74 -0
- dask_array/routines/_unique.py +232 -0
- dask_array/routines/_where.py +62 -0
- dask_array/slicing/__init__.py +67 -0
- dask_array/slicing/_basic.py +550 -0
- dask_array/slicing/_blocks.py +138 -0
- dask_array/slicing/_bool_index.py +145 -0
- dask_array/slicing/_setitem.py +329 -0
- dask_array/slicing/_squeeze.py +101 -0
- dask_array/slicing/_utils.py +1133 -0
- dask_array/slicing/_vindex.py +282 -0
- dask_array/stacking/__init__.py +15 -0
- dask_array/stacking/_block.py +83 -0
- dask_array/stacking/_simple.py +58 -0
- dask_array/templates/array.html.j2 +48 -0
- dask_array/tests/__init__.py +0 -0
- dask_array/tests/conftest.py +22 -0
- dask_array/tests/test_api.py +40 -0
- dask_array/tests/test_binary_op_chunks.py +107 -0
- dask_array/tests/test_coarse_slice_through_blockwise.py +362 -0
- dask_array/tests/test_collection.py +799 -0
- dask_array/tests/test_creation.py +1102 -0
- dask_array/tests/test_expr_flow.py +143 -0
- dask_array/tests/test_linalg.py +1130 -0
- dask_array/tests/test_map_blocks_multi_output.py +104 -0
- dask_array/tests/test_rechunk_pushdown.py +214 -0
- dask_array/tests/test_reductions.py +1091 -0
- dask_array/tests/test_routines.py +2853 -0
- dask_array/tests/test_shuffle_chunks.py +67 -0
- dask_array/tests/test_slice_pushdown.py +968 -0
- dask_array/tests/test_slice_through_blockwise.py +678 -0
- dask_array/tests/test_slice_through_overlap.py +366 -0
- dask_array/tests/test_slice_through_reshape.py +272 -0
- dask_array/tests/test_slicing.py +839 -0
- dask_array/tests/test_transpose_slice_pushdown.py +208 -0
- dask_array/tests/test_visualize.py +94 -0
- dask_array/tests/test_xarray.py +193 -0
- dask_array-0.1.0.dist-info/METADATA +48 -0
- dask_array-0.1.0.dist-info/RECORD +144 -0
- dask_array-0.1.0.dist-info/WHEEL +4 -0
- dask_array-0.1.0.dist-info/entry_points.txt +2 -0
- dask_array-0.1.0.dist-info/licenses/LICENSE +29 -0
|
@@ -0,0 +1,678 @@
|
|
|
1
|
+
"""Tests for slice pushdown through Blockwise expressions.
|
|
2
|
+
|
|
3
|
+
These tests explore when slice pushdown is safe and correct for different
|
|
4
|
+
Blockwise configurations.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
import numpy as np
|
|
10
|
+
import pytest
|
|
11
|
+
|
|
12
|
+
import dask_array as da
|
|
13
|
+
from dask_array._test_utils import assert_eq
|
|
14
|
+
|
|
15
|
+
# =============================================================================
|
|
16
|
+
# Case 1: Standard Blockwise (reduction chunk step)
|
|
17
|
+
# - out_ind matches input indices
|
|
18
|
+
# - No new_axes, no adjust_chunks
|
|
19
|
+
# - Slice should push through directly
|
|
20
|
+
# =============================================================================
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def test_slice_through_reduction_blockwise():
    """A column slice of ``sum(axis=0)`` simplifies like slicing the input first.

    ``ones.sum(axis=0)[:5]`` and ``ones[:, :5].sum(axis=0)`` must simplify to
    the same expression: the slice moves through the Blockwise chunk step of
    the reduction.
    """
    ones = da.ones((100, 100), chunks=(10, 10))

    sliced_after = ones.sum(axis=0)[:5]
    sliced_before = ones[:, :5].sum(axis=0)

    lhs_name = sliced_after.expr.simplify()._name
    rhs_name = sliced_before.expr.simplify()._name
    assert lhs_name == rhs_name
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def test_slice_through_reduction_blockwise_axis1():
    """A row slice of ``sum(axis=1)`` simplifies like slicing rows first."""
    ones = da.ones((100, 100), chunks=(10, 10))

    sliced_after = ones.sum(axis=1)[:5]
    sliced_before = ones[:5, :].sum(axis=1)

    lhs_name = sliced_after.expr.simplify()._name
    rhs_name = sliced_before.expr.simplify()._name
    assert lhs_name == rhs_name
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
# =============================================================================
|
|
45
|
+
# Case 2: Elemwise operations
|
|
46
|
+
# - Already handled by _pushdown_through_elemwise
|
|
47
|
+
# - Included here for completeness
|
|
48
|
+
# =============================================================================
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def test_slice_through_elemwise_add():
    """Slicing ``x + y`` simplifies to adding the sliced operands."""
    lhs = da.ones((100, 100), chunks=(10, 10))
    rhs = da.ones((100, 100), chunks=(10, 10))
    # Same key as writing ``[:5, :10]`` directly.
    window = (slice(None, 5), slice(None, 10))

    sliced_sum = (lhs + rhs)[window]
    sum_of_slices = lhs[window] + rhs[window]

    assert sliced_sum.expr.simplify()._name == sum_of_slices.expr.simplify()._name
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def test_slice_through_elemwise_unary():
    """Slicing ``sin(x)`` simplifies to ``sin`` of the sliced input."""
    src = da.ones((100, 100), chunks=(10, 10))
    # Same key as writing ``[:5, :10]`` directly.
    window = (slice(None, 5), slice(None, 10))

    outer = da.sin(src)[window]
    inner = da.sin(src[window])

    assert outer.expr.simplify()._name == inner.expr.simplify()._name
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
# =============================================================================
|
|
73
|
+
# Case 3: Broadcasting
|
|
74
|
+
# - Smaller input has fewer indices
|
|
75
|
+
# - Need to only slice dimensions that exist in the smaller input
|
|
76
|
+
# =============================================================================
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def test_slice_through_broadcast_row():
    """A row vector broadcast against a matrix receives only the column slice.

    ``(x + r)[:3, :4]`` should simplify to ``x[:3, :4] + r[:4]``.
    """
    grid_np = np.arange(100).reshape(10, 10)
    row_np = np.arange(10)

    grid = da.from_array(grid_np, chunks=(5, 5))
    row = da.from_array(row_np, chunks=5)

    actual = (grid + row)[:3, :4]
    # The reference is simplified too, since slices are pushed into the
    # from_array inputs on both sides.
    reference = grid[:3, :4] + row[:4]

    assert actual.expr.simplify()._name == reference.expr.simplify()._name
    assert_eq(actual, grid_np[:3, :4] + row_np[:4])
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
def test_slice_through_broadcast_column():
    """A column vector broadcast against a matrix receives only the row slice.

    ``(x + c)[:3, :4]`` should simplify to ``x[:3, :4] + c[:3, :]``. The
    column's length-1 trailing axis is kept whole.
    """
    grid_np = np.arange(100).reshape(10, 10)
    col_np = np.arange(10).reshape(10, 1)

    grid = da.from_array(grid_np, chunks=(5, 5))
    col = da.from_array(col_np, chunks=(5, 1))

    actual = (grid + col)[:3, :4]
    reference = grid[:3, :4] + col[:3, :]

    assert actual.expr.simplify()._name == reference.expr.simplify()._name
    assert_eq(actual, grid_np[:3, :4] + col_np[:3, :])
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
def test_slice_through_broadcast_scalar():
    """Adding a Python scalar does not block slice pushdown.

    ``(x + 5)[:3, :4]`` should simplify to ``x[:3, :4] + 5``.
    """
    grid_np = np.arange(100).reshape(10, 10)
    grid = da.from_array(grid_np, chunks=(5, 5))

    actual = (grid + 5)[:3, :4]
    reference = grid[:3, :4] + 5

    assert actual.expr.simplify()._name == reference.expr.simplify()._name
    assert_eq(actual, grid_np[:3, :4] + 5)
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
def test_slice_through_broadcast_size_one_dims():
    """Slice through Elemwise where inputs have size-1 dims that broadcast.

    When inputs have different size-1 dimensions that broadcast together,
    slicing the output should preserve those size-1 dimensions rather than
    applying the output slice to them.

    This test covers the case where:
    - Input a has shape (1, M, 1) with size-1 dims at positions 0 and 2
    - Input b has shape (1, 1, N) with size-1 dims at positions 0 and 1
    - Output broadcasts to (1, M, N)
    - Slicing output[:, m1:m2, n1:n2] should produce:
      a[:, m1:m2, :] + b[:, :, n1:n2] (preserving size-1 dims)
    """
    # Create inputs with size-1 dims in different positions
    a_np = np.arange(20).reshape(1, 20, 1)
    b_np = np.arange(30).reshape(1, 1, 30)

    a = da.from_array(a_np, chunks=(1, 10, 1))
    b = da.from_array(b_np, chunks=(1, 1, 15))

    # Output broadcasts to (1, 20, 30)
    result = a + b
    assert result.shape == (1, 20, 30)

    sliced = result[:, 5:10, 10:20]
    assert sliced.shape == (1, 5, 10)

    expected = (a_np + b_np)[:, 5:10, 10:20]

    # Simplify used to fail here. A bare ``is not None`` check would accept
    # any returned expression, so compute the simplified expression itself.
    simplified = sliced.expr.simplify()
    assert_eq(da.Array(simplified), expected)

    # Verify the values through the normal collection path as well
    assert_eq(sliced, expected)
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
def test_slice_through_where_with_broadcast():
    """Slice through where() with broadcast condition.

    Regression test for xarray integration - slicing through Where
    with broadcast inputs was failing due to incorrect size-1 handling.
    """
    # Broadcast condition from size-1 dims
    cond = (
        da.ones((10, 1, 1), dtype=bool, chunks=(5, 1, 1))
        & da.ones((1, 20, 1), dtype=bool, chunks=(1, 10, 1))
        & da.ones((1, 1, 30), dtype=bool, chunks=(1, 1, 15))
    )

    result = da.where(cond, da.ones((10, 20, 30), chunks=(5, 10, 15)), np.nan)
    sliced = result[:, 5:15, 10:25]
    expected = np.ones((10, 10, 15))

    # Simplify used to fail here. Check that the simplified expression itself
    # computes the right values instead of discarding it.
    simplified = sliced.expr.simplify()
    assert_eq(da.Array(simplified), expected)

    assert_eq(sliced, expected)
|
|
183
|
+
|
|
184
|
+
|
|
185
|
+
def test_slice_through_shuffle_non_shuffle_axis():
    """Slicing only axes that a Shuffle does not reorder moves below the Shuffle."""
    data = np.arange(100 * 50 * 60).reshape(100, 50, 60)
    # Single-element chunks along the shuffled axis 0.
    src = da.from_array(data, chunks=(1, 25, 30))

    # First half in order, second half reversed. This is a non-identity fancy
    # index, so the Shuffle it creates is not simplified away.
    order = [*range(50), *range(99, 49, -1)]
    window = (slice(None), slice(10, 20), slice(30, 40))

    actual = src[order, :, :][window]
    # Pushed down: slice the input, then shuffle.
    reference = src[window][order, :, :]

    assert actual.expr.simplify()._name == reference.expr.simplify()._name
    assert_eq(actual, data[order, :, :][window])
|
|
201
|
+
|
|
202
|
+
|
|
203
|
+
def test_slice_through_shuffle_on_shuffle_axis():
    """A slice on the shuffled axis moves below an identity-like Shuffle.

    This mirrors xarray's unstack pattern, where the shuffle indexer maps
    contiguous ranges (identity-like, possibly padded).
    """
    from dask_array._new_collection import new_collection
    from dask_array._shuffle import _shuffle

    data = np.arange(100 * 50).reshape(100, 50)
    src = da.from_array(data, chunks=(1, 25))

    def identity_shuffle(expr, n_groups):
        # One single-element group per row, as xarray emits when
        # restructuring a time dimension.
        groups = [[k] for k in range(n_groups)]
        return new_collection(_shuffle(expr, groups, axis=0, name="shuffle"))

    actual = identity_shuffle(src.expr, 100)[20:40, :]
    # Pushing the slice down narrows the input to rows 20:40 and re-bases the
    # indexer to 0..19.
    reference = identity_shuffle(src[20:40, :].expr, 20)

    assert actual.expr.simplify()._name == reference.expr.simplify()._name
    assert_eq(actual, data[20:40, :])
|
|
227
|
+
|
|
228
|
+
|
|
229
|
+
def test_slice_through_grouped_shuffle_on_shuffle_axis():
    """Slicing the output of a permuting fancy index yields correct values.

    Checked through the default path and through an expression optimized
    with ``fuse=False``.
    """
    data = np.arange(8)
    src = da.from_array(data, chunks=4)
    perm = np.array([6, 5, 2, 4, 1, 3, 0, 7])

    actual = src[perm][1:4]
    want = data[perm][1:4]

    assert_eq(actual, want)
    unfused = da.Array(actual.expr.optimize(fuse=False))
    assert_eq(unfused, want)
|
|
238
|
+
|
|
239
|
+
|
|
240
|
+
# =============================================================================
|
|
241
|
+
# Case 4: new_axes - Blockwise adds dimensions
|
|
242
|
+
# - Slice on a new axis doesn't correspond to input
|
|
243
|
+
# - Should NOT push through (or handle specially)
|
|
244
|
+
# =============================================================================
|
|
245
|
+
|
|
246
|
+
|
|
247
|
+
def test_slice_new_axis_not_pushed():
    """Slice the original axes of a result that gained a new axis.

    ``map_blocks`` appends a length-1 axis 2. The slice restricts only axes
    0 and 1 and takes axis 2 whole. The values must match the NumPy
    equivalent.
    """
    arr = np.arange(100).reshape(10, 10)
    x = da.from_array(arr, chunks=(5, 5))

    # map_blocks that appends a new trailing axis
    y = da.map_blocks(lambda b: b[..., np.newaxis], x, new_axis=2, dtype=arr.dtype)

    # Restrict the original axes; ``:`` leaves the new axis untouched
    result = y[:3, :4, :]
    expected = arr[:3, :4, np.newaxis]

    assert_eq(result, expected)
|
|
260
|
+
|
|
261
|
+
|
|
262
|
+
def test_slice_symbolic_new_axis_not_pushed():
    """Slice a ``new_axes`` dimension declared through ``blockwise`` indices.

    The output index ``"az"`` introduces axis ``z`` (length 5), which the
    input ``"a"`` lacks. Slicing ``z`` must give correct values both with
    and without ``optimize(fuse=False)``.
    """
    vec = np.arange(6)
    src = da.from_array(vec, chunks=3)

    def widen(block):
        # Repeat each element across a new trailing axis of length 5.
        return np.broadcast_to(block[:, None], (block.shape[0], 5))

    wide = da.blockwise(
        widen,
        "az",
        src,
        "a",
        new_axes={"z": 5},
        dtype=src.dtype,
    )
    actual = wide[:, :2]
    want = np.broadcast_to(vec[:, None], (6, 5))[:, :2]

    assert_eq(actual, want)
    assert_eq(da.Array(actual.expr.optimize(fuse=False)), want)
|
|
279
|
+
|
|
280
|
+
|
|
281
|
+
def test_slice_only_new_axis():
    """Restricting only the newly added axis yields correct values."""
    data = np.arange(100).reshape(10, 10)
    src = da.from_array(data, chunks=(5, 5))

    def triple_last(block):
        # Append a trailing axis of length 3 holding copies of each element.
        return np.repeat(block[..., np.newaxis], 3, axis=2)

    tripled = da.map_blocks(
        triple_last,
        src,
        new_axis=2,
        chunks=(5, 5, 3),
        dtype=data.dtype,
    )

    # The axis-2 slice has no counterpart in the 2-D input.
    actual = tripled[:, :, :2]
    assert_eq(actual, triple_last(data)[:, :, :2])
|
|
300
|
+
|
|
301
|
+
|
|
302
|
+
# =============================================================================
|
|
303
|
+
# Case 5: drop_axis / contraction
|
|
304
|
+
# - Input has more dimensions than output
|
|
305
|
+
# - Some input indices don't appear in output
|
|
306
|
+
# =============================================================================
|
|
307
|
+
|
|
308
|
+
|
|
309
|
+
def test_slice_through_drop_axis():
    """Slice the output of a ``map_blocks`` call that drops axis 0."""
    data = np.arange(100).reshape(10, 10)
    src = da.from_array(data, chunks=(5, 5))

    col_sums = da.map_blocks(
        lambda blk: blk.sum(axis=0), src, drop_axis=0, dtype=data.dtype
    )

    # col_sums has shape (10,); its [:5] corresponds to src[:, :5].
    assert_eq(col_sums[:5], data.sum(axis=0)[:5])
|
|
322
|
+
|
|
323
|
+
|
|
324
|
+
def test_slice_through_drop_axis_1():
    """Slice the output of a ``map_blocks`` call that drops axis 1."""
    data = np.arange(100).reshape(10, 10)
    src = da.from_array(data, chunks=(5, 5))

    row_sums = da.map_blocks(
        lambda blk: blk.sum(axis=1), src, drop_axis=1, dtype=data.dtype
    )

    # row_sums has shape (10,); its [:5] corresponds to src[:5, :].
    assert_eq(row_sums[:5], data.sum(axis=1)[:5])
|
|
337
|
+
|
|
338
|
+
|
|
339
|
+
# =============================================================================
|
|
340
|
+
# Case 6: adjust_chunks
|
|
341
|
+
# - Chunk sizes change in the output
|
|
342
|
+
# - Slice indices may not map correctly
|
|
343
|
+
# =============================================================================
|
|
344
|
+
|
|
345
|
+
|
|
346
|
+
def test_slice_adjust_chunks():
    """Slice after a ``map_blocks`` call that enlarges chunks along axis 0."""
    data = np.arange(100).reshape(10, 10)
    src = da.from_array(data, chunks=(5, 5))

    def double_rows(block):
        # Every row appears twice, so each (5, 5) block becomes (10, 5).
        return np.repeat(block, 2, axis=0)

    # Resulting array has shape (20, 10).
    doubled = da.map_blocks(double_rows, src, chunks=(10, 5), dtype=data.dtype)

    assert_eq(doubled[:5, :5], double_rows(data)[:5, :5])
|
|
367
|
+
|
|
368
|
+
|
|
369
|
+
# =============================================================================
|
|
370
|
+
# Case 7: Multiple inputs with different shapes
|
|
371
|
+
# - Inputs align via broadcasting
|
|
372
|
+
# - Need to map slice to each input appropriately
|
|
373
|
+
# =============================================================================
|
|
374
|
+
|
|
375
|
+
|
|
376
|
+
def test_slice_multiple_inputs_same_shape():
    """Two equally shaped operands each receive the full output slice."""
    first_np = np.arange(100).reshape(10, 10)
    second_np = np.arange(100, 200).reshape(10, 10)

    first = da.from_array(first_np, chunks=(5, 5))
    second = da.from_array(second_np, chunks=(5, 5))

    # (first + second)[:3, :4] -> first[:3, :4] + second[:3, :4]
    actual = (first + second)[:3, :4]
    reference = first[:3, :4] + second[:3, :4]

    assert actual.expr.simplify()._name == reference.expr.simplify()._name
    assert_eq(actual, first_np[:3, :4] + second_np[:3, :4])
|
|
390
|
+
|
|
391
|
+
|
|
392
|
+
def test_slice_multiple_inputs_broadcast():
    """A broadcast 1-D operand receives only the slice of its own axis."""
    grid_np = np.arange(100).reshape(10, 10)
    vec_np = np.arange(10)

    grid = da.from_array(grid_np, chunks=(5, 5))
    vec = da.from_array(vec_np, chunks=5)

    # (grid * vec)[:3, :4] -> grid[:3, :4] * vec[:4]
    actual = (grid * vec)[:3, :4]
    reference = grid[:3, :4] * vec[:4]

    assert actual.expr.simplify()._name == reference.expr.simplify()._name
    assert_eq(actual, grid_np[:3, :4] * vec_np[:4])
|
|
406
|
+
|
|
407
|
+
|
|
408
|
+
# =============================================================================
|
|
409
|
+
# Correctness tests - verify computed values
|
|
410
|
+
# =============================================================================
|
|
411
|
+
|
|
412
|
+
|
|
413
|
+
@pytest.mark.parametrize(
    "shape,chunks,axis,slice_",
    [
        ((100, 100), (10, 10), 0, slice(5)),
        ((100, 100), (10, 10), 1, slice(5)),
        ((100, 100), (10, 10), 0, slice(10, 20)),
        ((100, 100), (10, 10), 1, slice(10, 20)),
        ((50, 50, 50), (10, 10, 10), 0, slice(5)),
        ((50, 50, 50), (10, 10, 10), 1, slice(5)),
        ((50, 50, 50), (10, 10, 10), 2, slice(5)),
    ],
)
def test_slice_through_reduction_correctness(shape, chunks, axis, slice_):
    """Verify slice-through-reduction produces correct values.

    ``slice_`` is applied to the first axis of the reduced output. All other
    output axes are taken whole.
    """
    # Seeded generator so a failing case reproduces with identical data.
    arr = np.random.default_rng(0).random(shape)
    x = da.from_array(arr, chunks=chunks)

    out_ndim = len(shape) - 1  # reduction removes one axis
    index = (slice_,) + (slice(None),) * (out_ndim - 1)

    result = x.sum(axis=axis)[index]
    expected = arr.sum(axis=axis)[index]

    assert_eq(result, expected)
|
|
439
|
+
|
|
440
|
+
|
|
441
|
+
# =============================================================================
|
|
442
|
+
# Verify optimization is/isn't applied
|
|
443
|
+
# =============================================================================
|
|
444
|
+
|
|
445
|
+
|
|
446
|
+
def test_optimization_applied_to_reduction():
    """Simplify rewrites Slice(Reduction(...)) into Reduction(Slice(...))."""
    from dask_array.reductions._reduction import Reduction
    from dask_array.slicing import SliceSlicesIntegers

    ones = da.ones((100, 100), chunks=(10, 10))
    head = ones.sum(axis=0)[:5]

    # Unsimplified, the outermost node is the slice itself.
    assert isinstance(head.expr, SliceSlicesIntegers)

    # Simplified, the reduction is outermost and the slice sits beneath it.
    top = head.expr.simplify()
    assert not isinstance(top, SliceSlicesIntegers)
    assert isinstance(top, Reduction)
|
|
461
|
+
|
|
462
|
+
|
|
463
|
+
def test_optimization_pushes_through_new_axes_when_safe():
    """A slice that leaves the new axis whole moves below ``map_blocks``."""
    from dask_array.slicing import SliceSlicesIntegers

    ones = da.ones((20, 20), chunks=(5, 5))
    expanded = da.map_blocks(
        lambda blk: blk[..., np.newaxis], ones, new_axis=2, dtype=float
    )
    # Axis 2 (the new axis) is not restricted, so pushdown is allowed.
    head = expanded[:5, :5, :]

    assert not isinstance(head.expr.simplify(), SliceSlicesIntegers)
    assert_eq(head, np.ones((20, 20))[:5, :5, np.newaxis])
|
|
475
|
+
|
|
476
|
+
|
|
477
|
+
def test_optimization_not_applied_slicing_new_axes():
    """A slice that restricts the new axis stays above ``map_blocks``."""
    from dask_array.slicing import SliceSlicesIntegers

    ones = da.ones((20, 20), chunks=(5, 5))
    # Append a trailing axis of length 3.
    tripled = da.map_blocks(
        lambda blk: np.repeat(blk[..., np.newaxis], 3, axis=2),
        ones,
        new_axis=2,
        chunks=(5, 5, 3),
        dtype=float,
    )
    # The ``:2`` on axis 2 has no counterpart in the input, so the slice must
    # remain the outermost node.
    head = tripled[:5, :5, :2]

    assert isinstance(head.expr.simplify(), SliceSlicesIntegers)
|
|
495
|
+
|
|
496
|
+
|
|
497
|
+
def test_optimization_reduces_tasks():
    """A sliced reduction over ``from_array`` optimizes to fewer tasks."""
    src = da.from_array(np.ones((100, 100)), chunks=(10, 10))

    def n_tasks(arr):
        # Task count of the optimized graph.
        return len(arr.optimize().__dask_graph__())

    whole = n_tasks(src.sum(axis=0))
    # [:5] lies within a single column of chunks.
    part = n_tasks(src.sum(axis=0)[:5])

    assert part < whole
|
|
510
|
+
|
|
511
|
+
|
|
512
|
+
# =============================================================================
|
|
513
|
+
# Case 8: Tensordot / Matmul
|
|
514
|
+
# - adjust_chunks only affects contracted dimension
|
|
515
|
+
# - Slices on non-contracted dimensions can push through
|
|
516
|
+
# =============================================================================
|
|
517
|
+
|
|
518
|
+
|
|
519
|
+
@pytest.mark.filterwarnings("ignore::dask.array.core.PerformanceWarning")
def test_slice_through_tensordot_correctness():
    """Verify slice through tensordot produces correct values."""
    # Seeded generator so a failure reproduces with identical data.
    arr = np.random.default_rng(0).random((100, 100))
    x = da.from_array(arr, chunks=(10, 10))

    result = x.dot(x.T)[:5, :5]
    expected = arr.dot(arr.T)[:5, :5]

    assert_eq(result, expected)
|
|
529
|
+
|
|
530
|
+
|
|
531
|
+
@pytest.mark.filterwarnings("ignore::dask.array.core.PerformanceWarning")
def test_slice_through_matmul_correctness():
    """Verify slice through matmul produces correct values."""
    # Seeded generator so a failure reproduces with identical data.
    rng = np.random.default_rng(0)
    arr1 = rng.random((100, 50))
    arr2 = rng.random((50, 100))
    x = da.from_array(arr1, chunks=(10, 10))
    y = da.from_array(arr2, chunks=(10, 10))

    result = (x @ y)[:5, :5]
    expected = (arr1 @ arr2)[:5, :5]

    assert_eq(result, expected)
|
|
543
|
+
|
|
544
|
+
|
|
545
|
+
@pytest.mark.filterwarnings("ignore::dask.array.core.PerformanceWarning")
def test_slice_through_matmul_expression_structure():
    """``(x @ y)[a:b, c:d]`` simplifies to ``x[a:b, :] @ y[:, c:d]``."""
    lhs = da.ones((100, 50), chunks=(10, 10))
    rhs = da.ones((50, 100), chunks=(10, 10))

    # Unequal row/column bounds catch a slice routed to the wrong operand.
    actual = (lhs @ rhs)[:15, :25]
    reference = lhs[:15, :] @ rhs[:, :25]

    assert actual.expr.simplify()._name == reference.expr.simplify()._name
|
|
557
|
+
|
|
558
|
+
|
|
559
|
+
@pytest.mark.filterwarnings("ignore::dask.array.core.PerformanceWarning")
def test_slice_through_tensordot_reduces_tasks():
    """Slicing ``x.dot(x.T)`` down to one output chunk shrinks the graph.

    ``x.dot(x.T)[:5, :5]`` should compute only the needed submatrix rather
    than the whole product followed by a slice.
    """
    mat = da.ones((100, 100), chunks=(10, 10))

    whole = mat.dot(mat.T)
    corner = mat.dot(mat.T)[:5, :5]

    whole_tasks = len(whole.optimize().__dask_graph__())
    corner_tasks = len(corner.optimize().__dask_graph__())

    # The full product has 10 x 10 = 100 output chunks; the corner has 1.
    # Require at least a 5x reduction to leave headroom.
    assert corner_tasks < whole_tasks / 5
|
|
579
|
+
|
|
580
|
+
|
|
581
|
+
# =============================================================================
|
|
582
|
+
# Regression tests
|
|
583
|
+
# =============================================================================
|
|
584
|
+
|
|
585
|
+
|
|
586
|
+
def test_integer_index_on_size_one_dim_through_elemwise():
    """Integer indexing on size-1 dims must remove the dimension.

    Regression test: when Elemwise._accept_slice pushed integer indices
    through size-1 dimensions, it was incorrectly converting them to
    slice(None), keeping the dimension instead of removing it.
    """
    # Seeded generator so a failure reproduces with identical data.
    data = np.random.default_rng(0).standard_normal((8, 9, 10))
    arr = da.from_array(data, chunks=(8, 9, 10))
    shuffled = da.shuffle(arr, [[0]], axis=2)  # -> (8, 9, 1)

    # Elemwise on top of shuffle
    cond = da.from_array(np.array([True]), chunks=(1,))
    elemwise = da.where(cond, shuffled, np.nan)

    # Integer index should remove the dimension
    indexed = elemwise[:, :, 0]
    assert indexed.shape == (8, 9)
    assert indexed.compute().shape == (8, 9)

    # cond is all True and the shuffle keeps only element 0 of axis 2, so
    # the values must be data[:, :, 0].
    assert_eq(indexed, data[:, :, 0])
|
|
604
|
+
|
|
605
|
+
|
|
606
|
+
def test_integer_index_through_elemwise_broadcast():
    """An integer index on a non-broadcast axis drops that axis."""
    thin = da.ones((10, 1, 20), chunks=(5, 1, 10))  # size-1 middle axis
    full = da.ones((10, 15, 20), chunks=(5, 5, 10))

    picked = (thin + full)[:, :, 0]

    # Axis 2 is consumed by the integer index.
    assert picked.shape == (10, 15)
    assert_eq(picked, np.ones((10, 15)) * 2)
|
|
617
|
+
|
|
618
|
+
|
|
619
|
+
# =============================================================================
|
|
620
|
+
# Regression tests for empty slice handling
|
|
621
|
+
# =============================================================================
|
|
622
|
+
|
|
623
|
+
|
|
624
|
+
def test_empty_slice_through_elemwise_broadcast():
    """An empty slice on a broadcast result stays empty.

    Regression test: empty slices such as ``[:0]`` on broadcast dimensions
    used to be replaced with ``[:]``, producing non-empty output.
    """
    zero_scalar = da.from_array(np.float32(0.0), chunks=-1)
    zero_cell = da.from_array(np.array([[0.0]], dtype="float32"), chunks=-1)

    # () + (1, 1) broadcasts to (1, 1).
    total = zero_scalar + zero_cell
    assert total.shape == (1, 1)

    # Row 0, then nothing from axis 1: the result must have shape (0,).
    empty = total[0, :0]
    assert empty.shape == (0,)
    assert empty.compute().shape == (0,)
|
|
641
|
+
|
|
642
|
+
|
|
643
|
+
def test_integer_index_out_of_bounds_on_broadcast_dim():
    """An integer index beyond one input's size works on a broadcast axis.

    Regression test: integer indices such as ``[1]`` on size-1 broadcast
    dimensions used to be applied to that input directly, raising
    IndexError.
    """
    zero = da.from_array(np.float32(0.0), chunks=-1)
    pair = da.from_array(np.array([[0.0, 1.0]], dtype="float32"), chunks=-1)  # (1, 2)
    unit = da.from_array(np.zeros((1, 1, 1, 1), dtype="float32"), chunks=-1)

    # () + (1, 2) + (1, 1, 1, 1) -> (1, 1, 1, 2)
    total = zero + pair + unit
    assert total.shape == (1, 1, 1, 2)

    # Index 1 on the last axis is in range for ``total``, but ``unit`` only
    # has a broadcast length-1 axis there.
    element = total[0, 0, 0, 1]
    assert element.shape == ()
    assert element.compute() == 1.0  # pair[0, 1]
|
|
662
|
+
|
|
663
|
+
|
|
664
|
+
def test_empty_slice_not_pushed_through_reduction():
    """An empty slice taken after a reduction is not pushed through it.

    Regression test: pushing empty slices through reductions produced invalid
    task graphs, because the reduction machinery does not handle empty
    non-reduced dimensions.
    """
    src = da.from_array(np.zeros((1, 2, 1, 1), dtype="float32"), chunks=-1)
    mins = da.nanmin(src, axis=(1, 2, 3))  # shape (1,)

    # Dropping the last element of a length-1 array leaves nothing.
    empty = mins[:-1]
    assert empty.shape == (0,)
    computed = empty.compute()
    assert computed.shape == (0,)
|