dask-array 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dask_array/__init__.py +228 -0
- dask_array/_backends.py +76 -0
- dask_array/_backends_array.py +99 -0
- dask_array/_blockwise.py +1410 -0
- dask_array/_broadcast.py +272 -0
- dask_array/_chunk.py +445 -0
- dask_array/_chunk_types.py +54 -0
- dask_array/_collection.py +1644 -0
- dask_array/_concatenate.py +331 -0
- dask_array/_core_utils.py +1365 -0
- dask_array/_dispatch.py +141 -0
- dask_array/_einsum.py +277 -0
- dask_array/_expr.py +544 -0
- dask_array/_expr_flow.py +586 -0
- dask_array/_gufunc.py +805 -0
- dask_array/_histogram.py +617 -0
- dask_array/_map_blocks.py +652 -0
- dask_array/_new_collection.py +10 -0
- dask_array/_numpy_compat.py +135 -0
- dask_array/_overlap.py +1159 -0
- dask_array/_rechunk.py +1050 -0
- dask_array/_reshape.py +710 -0
- dask_array/_routines.py +102 -0
- dask_array/_shuffle.py +448 -0
- dask_array/_stack.py +264 -0
- dask_array/_svg.py +291 -0
- dask_array/_templates.py +29 -0
- dask_array/_test_utils.py +257 -0
- dask_array/_ufunc.py +385 -0
- dask_array/_utils.py +349 -0
- dask_array/_visualize.py +223 -0
- dask_array/_xarray.py +337 -0
- dask_array/core/__init__.py +34 -0
- dask_array/core/_blockwise_funcs.py +312 -0
- dask_array/core/_conversion.py +422 -0
- dask_array/core/_from_graph.py +97 -0
- dask_array/creation/__init__.py +71 -0
- dask_array/creation/_arange.py +121 -0
- dask_array/creation/_diag.py +116 -0
- dask_array/creation/_diagonal.py +241 -0
- dask_array/creation/_eye.py +103 -0
- dask_array/creation/_linspace.py +102 -0
- dask_array/creation/_mesh.py +134 -0
- dask_array/creation/_ones_zeros.py +454 -0
- dask_array/creation/_pad.py +270 -0
- dask_array/creation/_repeat.py +55 -0
- dask_array/creation/_tile.py +36 -0
- dask_array/creation/_tri.py +28 -0
- dask_array/creation/_utils.py +296 -0
- dask_array/fft.py +320 -0
- dask_array/io/__init__.py +39 -0
- dask_array/io/_base.py +10 -0
- dask_array/io/_from_array.py +257 -0
- dask_array/io/_from_delayed.py +95 -0
- dask_array/io/_from_graph.py +54 -0
- dask_array/io/_from_npy_stack.py +67 -0
- dask_array/io/_store.py +336 -0
- dask_array/io/_tiledb.py +159 -0
- dask_array/io/_to_npy_stack.py +65 -0
- dask_array/io/_zarr.py +449 -0
- dask_array/linalg/__init__.py +39 -0
- dask_array/linalg/_cholesky.py +234 -0
- dask_array/linalg/_lu.py +300 -0
- dask_array/linalg/_norm.py +94 -0
- dask_array/linalg/_qr.py +601 -0
- dask_array/linalg/_solve.py +349 -0
- dask_array/linalg/_svd.py +394 -0
- dask_array/linalg/_tensordot.py +334 -0
- dask_array/linalg/_utils.py +74 -0
- dask_array/manipulation/__init__.py +45 -0
- dask_array/manipulation/_expand.py +321 -0
- dask_array/manipulation/_flip.py +92 -0
- dask_array/manipulation/_roll.py +78 -0
- dask_array/manipulation/_transpose.py +309 -0
- dask_array/random/__init__.py +125 -0
- dask_array/random/_choice.py +181 -0
- dask_array/random/_expr.py +256 -0
- dask_array/random/_generator.py +441 -0
- dask_array/random/_random_state.py +259 -0
- dask_array/random/_utils.py +84 -0
- dask_array/reductions/__init__.py +84 -0
- dask_array/reductions/_arg_reduction.py +130 -0
- dask_array/reductions/_common.py +1082 -0
- dask_array/reductions/_cumulative.py +522 -0
- dask_array/reductions/_percentile.py +261 -0
- dask_array/reductions/_reduction.py +725 -0
- dask_array/reductions/_trace.py +56 -0
- dask_array/routines/__init__.py +133 -0
- dask_array/routines/_apply.py +84 -0
- dask_array/routines/_bincount.py +112 -0
- dask_array/routines/_broadcast.py +111 -0
- dask_array/routines/_coarsen.py +115 -0
- dask_array/routines/_diff.py +79 -0
- dask_array/routines/_gradient.py +158 -0
- dask_array/routines/_indexing.py +65 -0
- dask_array/routines/_insert_delete.py +132 -0
- dask_array/routines/_misc.py +122 -0
- dask_array/routines/_nonzero.py +72 -0
- dask_array/routines/_search.py +123 -0
- dask_array/routines/_select.py +113 -0
- dask_array/routines/_statistics.py +171 -0
- dask_array/routines/_topk.py +82 -0
- dask_array/routines/_triangular.py +74 -0
- dask_array/routines/_unique.py +232 -0
- dask_array/routines/_where.py +62 -0
- dask_array/slicing/__init__.py +67 -0
- dask_array/slicing/_basic.py +550 -0
- dask_array/slicing/_blocks.py +138 -0
- dask_array/slicing/_bool_index.py +145 -0
- dask_array/slicing/_setitem.py +329 -0
- dask_array/slicing/_squeeze.py +101 -0
- dask_array/slicing/_utils.py +1133 -0
- dask_array/slicing/_vindex.py +282 -0
- dask_array/stacking/__init__.py +15 -0
- dask_array/stacking/_block.py +83 -0
- dask_array/stacking/_simple.py +58 -0
- dask_array/templates/array.html.j2 +48 -0
- dask_array/tests/__init__.py +0 -0
- dask_array/tests/conftest.py +22 -0
- dask_array/tests/test_api.py +40 -0
- dask_array/tests/test_binary_op_chunks.py +107 -0
- dask_array/tests/test_coarse_slice_through_blockwise.py +362 -0
- dask_array/tests/test_collection.py +799 -0
- dask_array/tests/test_creation.py +1102 -0
- dask_array/tests/test_expr_flow.py +143 -0
- dask_array/tests/test_linalg.py +1130 -0
- dask_array/tests/test_map_blocks_multi_output.py +104 -0
- dask_array/tests/test_rechunk_pushdown.py +214 -0
- dask_array/tests/test_reductions.py +1091 -0
- dask_array/tests/test_routines.py +2853 -0
- dask_array/tests/test_shuffle_chunks.py +67 -0
- dask_array/tests/test_slice_pushdown.py +968 -0
- dask_array/tests/test_slice_through_blockwise.py +678 -0
- dask_array/tests/test_slice_through_overlap.py +366 -0
- dask_array/tests/test_slice_through_reshape.py +272 -0
- dask_array/tests/test_slicing.py +839 -0
- dask_array/tests/test_transpose_slice_pushdown.py +208 -0
- dask_array/tests/test_visualize.py +94 -0
- dask_array/tests/test_xarray.py +193 -0
- dask_array-0.1.0.dist-info/METADATA +48 -0
- dask_array-0.1.0.dist-info/RECORD +144 -0
- dask_array-0.1.0.dist-info/WHEEL +4 -0
- dask_array-0.1.0.dist-info/entry_points.txt +2 -0
- dask_array-0.1.0.dist-info/licenses/LICENSE +29 -0
|
@@ -0,0 +1,799 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import operator
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
import pytest
|
|
7
|
+
|
|
8
|
+
import dask_array as da
|
|
9
|
+
from dask import is_dask_collection
|
|
10
|
+
from dask_array._test_utils import assert_eq
|
|
11
|
+
from dask_array._collection import Array
|
|
12
|
+
from dask_array._rechunk import Rechunk
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
@pytest.fixture()
def arr():
    """Provide a 10x10 random array requested with chunks of (5, 6)."""
    shape = (10, 10)
    return da.random.random(shape, chunks=(5, 6))
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
@pytest.mark.parametrize(
    "op",
    [
        # forward operators
        "__add__",
        "__sub__",
        "__mul__",
        "__truediv__",
        "__floordiv__",
        "__pow__",
        # reflected operators
        "__radd__",
        "__rsub__",
        "__rmul__",
        "__rtruediv__",
        "__rfloordiv__",
        "__rpow__",
    ],
)
def test_arithmetic_ops(arr, op):
    """Each binary dunder called with the scalar 2 must agree with the
    same dunder called on the computed array."""
    scalar = 2
    lazy_method = getattr(arr, op)
    result = lazy_method(scalar)
    eager_method = getattr(arr.compute(), op)
    assert_eq(result, eager_method(scalar))
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def test_rechunk(arr):
    """Rechunking to (7, 3) must not change the array's values."""
    rechunked = arr.rechunk((7, 3))
    assert_eq(rechunked, arr.compute())
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def test_blockwise():
    """Exercise ``da.blockwise`` with a scalar operand, with identically
    chunked operands, and with differently chunked operands."""
    shape = (10, 10)

    # Array plus a literal; the literal is passed with a ``None`` index.
    base = da.random.random(shape, chunks=(5, 5))
    plus_scalar = da.blockwise(
        operator.add, "ij", base, "ij", 100, None, dtype=base.dtype
    )
    assert_eq(plus_scalar, base.compute() + 100)

    # Same array on both sides: the optimized expression must hold no Rechunk.
    base = da.random.random(shape, chunks=(5, 5))
    doubled = da.blockwise(
        operator.add, "ij", base, "ij", base, "ij", dtype=base.dtype
    )
    rechunk_ops = list(doubled.expr.optimize().find_operations(Rechunk))
    assert len(rechunk_ops) == 0
    assert_eq(doubled, base.compute() * 2)

    # align: operands chunked differently, so at least one Rechunk is expected.
    left = da.random.random(shape, chunks=(5, 5))
    right = da.random.random(shape, chunks=(7, 3))
    summed = da.blockwise(
        operator.add, "ij", left, "ij", right, "ij", dtype=left.dtype
    )
    rechunk_ops = list(summed.expr.optimize().find_operations(Rechunk))
    assert len(rechunk_ops) > 0
    assert_eq(summed, left.compute() + right.compute())
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
@pytest.mark.parametrize(
    "func",
    ["min", "max", "sum", "prod", "mean", "any", "all"],
)
def test_reductions(arr, func):
    """Reduction methods along axis 0 must match the same method on the
    computed array."""
    # var and std need __array_function__
    lazy_reduce = getattr(arr, func)
    result = lazy_reduce(axis=0)
    eager_reduce = getattr(arr.compute(), func)
    assert_eq(result, eager_reduce(axis=0))
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
@pytest.mark.parametrize(
    "func",
    [
        "sum",
        "mean",
        "any",
        "all",
        "max",
        "min",
        # NaN-aware variants
        "nanmin",
        "nanmax",
        "nanmean",
        "nansum",
        "nanprod",
    ],
)
def test_reductions_toplevel(arr, func):
    """Module-level reductions along axis 0 must match the NumPy function
    of the same name applied to the computed array."""
    # var and std need __array_function__
    dask_reduce = getattr(da, func)
    result = dask_reduce(arr, axis=0)
    numpy_reduce = getattr(np, func)
    assert_eq(result, numpy_reduce(arr.compute(), axis=0))
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
def test_from_array():
    """``from_array`` must keep the values and report the requested chunks."""
    data = np.random.random((10, 10))
    wrapped = da.from_array(data, chunks=(5, 5))
    assert_eq(wrapped, data)
    expected_chunks = ((5, 5), (5, 5))
    assert wrapped.chunks == expected_chunks
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
def test_from_graph_same_key_prefix_different_layers():
    """Two ``from_graph`` arrays that share the task key ``("x", 0)`` but
    differ in payload and in the final positional argument must remain
    distinct expressions and compute their own values."""
    from dask_array.core import from_graph

    def build(value, label):
        # Only the stored value and the trailing string differ between calls.
        return from_graph(
            {("x", 0): np.array([value])},
            np.empty((0,), dtype=int),
            ((1,),),
            [("x", 0)],
            label,
        )

    a = build(1, "a")
    b = build(2, "b")

    assert a.expr is not b.expr
    for collection, value in ((a, 1), (b, 2)):
        assert_eq(collection, np.array([value]))
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
def test_from_graph_tracks_expression_dependencies():
    """A ``from_graph`` array built on top of another array (passed via
    ``dependencies``) must yield a graph with no dangling dependencies
    after slicing and optimizing with ``fuse=True``."""
    from dask._task_spec import DependenciesMapping, Task, TaskRef
    from dask_array.core import from_graph

    x = da.from_array(np.arange(6), chunks=(3,)).rechunk((2,))
    name = "plus-one"
    n_blocks = len(x.chunks[0])

    # One task per block of x, each adding 1 to the matching block.
    layer = {}
    for i in range(n_blocks):
        key = (name, i)
        layer[key] = Task(key, operator.add, TaskRef((x.name, i)), 1)

    y = from_graph(
        layer,
        np.empty((0,), dtype=x.dtype),
        x.chunks,
        [(name, i) for i in range(n_blocks)],
        name,
        dependencies=[x],
    )
    optimized = da.Array(y[:4].expr.optimize(fuse=True))
    graph = optimized.__dask_graph__()

    # Collect every dependency that is referenced but absent from the graph.
    missing = []
    for deps in DependenciesMapping(graph).values():
        missing.extend(dep for dep in deps if dep not in graph)

    assert not missing
    assert_eq(optimized, np.arange(4) + 1)
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
@pytest.mark.xfail(
    reason="Requires dask core to recognize 'dask_array' module in is_dask_collection"
)
def test_is_dask_collection_doesnt_materialize():
    """``is_dask_collection`` must accept an Array subclass whose
    ``__dask_graph__`` raises, while a direct call to that method still
    raises ``NotImplementedError``."""

    class ArrayTest(Array):
        # Any attempt to build the graph fails loudly.
        def __dask_graph__(self):
            raise NotImplementedError

    source = da.random.random((10, 10), chunks=(5, 5))
    guarded = ArrayTest(source.expr)
    assert is_dask_collection(guarded)
    with pytest.raises(NotImplementedError):
        guarded.__dask_graph__()
|
|
172
|
+
|
|
173
|
+
|
|
174
|
+
def test_astype():
    """Casting a random integer array to float64 must match NumPy's cast."""
    ints = da.random.randint(1, 100, (10, 10), chunks=(5, 5))
    target = np.float64
    result = ints.astype(target)
    assert_eq(result, ints.compute().astype(target))
|
|
179
|
+
|
|
180
|
+
|
|
181
|
+
def test_stack_promote_type():
    """Stacking an int32 array with a float32 array must match ``np.stack``."""
    int_np = np.arange(10, dtype="i4")
    float_np = np.arange(10, dtype="f4")
    int_da = da.from_array(int_np, chunks=5)
    float_da = da.from_array(float_np, chunks=5)
    stacked = da.stack([int_da, float_da])
    assert_eq(stacked, np.stack([int_np, float_np]))
|
|
188
|
+
|
|
189
|
+
|
|
190
|
+
def test_field_access():
    """Indexing a structured array by one field name or by a list of
    field names must match NumPy."""
    records = np.array([(1, 1.0), (2, 2.0)], dtype=[("a", "i4"), ("b", "f4")])
    lazy = da.from_array(records, chunks=(1,))
    for key in ("a", ["b", "a"]):
        assert_eq(lazy[key], records[key])
|
|
195
|
+
|
|
196
|
+
|
|
197
|
+
def test_field_access_with_shape():
    """Field access must match NumPy when the fields carry sub-array shapes."""
    dtype = [("col1", ("f4", (3, 2))), ("col2", ("f4", 3))]
    data = np.ones((100, 50), dtype=dtype)
    lazy = da.from_array(data, 10)
    selections = ("col1", ["col1"], "col2", ["col1", "col2"])
    for key in selections:
        assert_eq(lazy[key], data[key])
|
|
205
|
+
|
|
206
|
+
|
|
207
|
+
# =============================================================================
|
|
208
|
+
# Optimization tests (ported from dask-expr prototype)
|
|
209
|
+
# =============================================================================
|
|
210
|
+
|
|
211
|
+
|
|
212
|
+
def test_transpose_optimize():
    """A double transpose must optimize to the same expression name as the
    untransposed array, and composed permutations must keep the values."""
    data = np.random.random((10, 20))
    lazy = da.from_array(data, chunks=(2, 5))

    # Transposing twice should cancel out.
    twice = lazy.T.T
    assert twice.expr.optimize()._name == lazy.expr.optimize()._name
    assert_eq(twice, data)

    # Explicit axes: (2, 0, 1) followed by (1, 2, 0) composes to (0, 1, 2).
    cube = da.from_array(np.random.random((3, 4, 5)), chunks=(1, 2, 3))
    round_trip = cube.transpose((2, 0, 1)).transpose((1, 2, 0))
    assert_eq(round_trip, cube)
|
|
225
|
+
|
|
226
|
+
|
|
227
|
+
def test_rechunk_optimize():
    """Two rechunks in a row must optimize to the same expression name as
    a single rechunk to the final chunking."""
    data = np.random.random((10, 10))
    lazy = da.from_array(data, chunks=(4, 4))

    final_chunks = (5, 2)
    chained = lazy.rechunk((2, 5)).rechunk(final_chunks)
    direct = lazy.rechunk(final_chunks)

    assert chained.expr.optimize()._name == direct.expr.optimize()._name
    assert_eq(chained, data)
|
|
238
|
+
|
|
239
|
+
|
|
240
|
+
def test_slicing_optimize_identity():
    """A full slice ``[:]`` must optimize to the unsliced expression's name."""
    data = np.random.random((10, 20))
    lazy = da.from_array(data, chunks=(2, 5))

    everything = lazy[:]
    # The right-hand side is deliberately the raw (unoptimized) expression name.
    assert everything.expr.optimize()._name == lazy.expr._name
    assert_eq(everything, data)
|
|
248
|
+
|
|
249
|
+
|
|
250
|
+
def test_slicing_optimize_fusion():
    """A slice applied on top of a slice must optimize to the same
    expression name as the equivalent single slice."""
    data = np.random.random((10, 20))
    lazy = da.from_array(data, chunks=(2, 5))

    # [5:, 4] followed by [::2] is equivalent to [5::2, 4].
    two_step = lazy[5:, 4][::2]
    one_step = lazy[5::2, 4]
    assert two_step.expr.optimize()._name == one_step.expr.optimize()._name
    assert_eq(two_step, data[5::2, 4])
|
|
260
|
+
|
|
261
|
+
|
|
262
|
+
def test_slicing_pushdown_elemwise():
    """Slicing the result of ``array + 1`` must optimize to the same
    expression name as adding 1 to the sliced array."""
    data = np.random.random((10, 20))
    lazy = da.from_array(data, chunks=(2, 5))

    # Range slice: (lazy + 1)[:5] versus lazy[:5] + 1.
    sliced_after = (lazy + 1)[:5]
    sliced_before = lazy[:5] + 1
    assert sliced_after.expr.optimize()._name == sliced_before.expr.optimize()._name
    assert_eq(sliced_after, (data + 1)[:5])

    # Integer index, which drops a dimension.
    indexed_after = (lazy + 1)[5]
    indexed_before = lazy[5] + 1
    assert indexed_after.expr.optimize()._name == indexed_before.expr.optimize()._name
    assert_eq(indexed_after, (data + 1)[5])
|
|
278
|
+
|
|
279
|
+
|
|
280
|
+
def test_slicing_pushdown_elemwise_broadcast():
    """Slice pushdown through an elemwise add whose second operand is
    broadcast along axis 0."""
    mat_np = np.random.random((10, 20))
    vec_np = np.random.random((20,))  # broadcasts on axis 0
    mat = da.from_array(mat_np, chunks=(2, 5))
    vec = da.from_array(vec_np, chunks=(5,))

    # Slicing axis 0 touches only the matrix; the vector is left unsliced.
    rows_after = (mat + vec)[:5]
    rows_before = mat[:5] + vec
    assert rows_after.expr.simplify()._name == rows_before.expr.simplify()._name
    assert_eq(rows_after, (mat_np + vec_np)[:5])

    # Slicing axis 1 applies to both operands.
    cols_after = (mat + vec)[:, ::2]
    cols_before = mat[:, ::2] + vec[::2]
    assert cols_after.expr.simplify()._name == cols_before.expr.simplify()._name
    assert_eq(cols_after, (mat_np + vec_np)[:, ::2])
|
|
299
|
+
|
|
300
|
+
|
|
301
|
+
def test_slicing_pushdown_transpose():
    """Slicing a transposed array must optimize to the same expression
    name as transposing the correspondingly sliced array."""
    data = np.random.random((10, 20))
    lazy = da.from_array(data, chunks=(2, 5))

    # lazy.T[5:] versus lazy[:, 5:].T
    sliced_after = lazy.T[5:]
    sliced_before = lazy[:, 5:].T
    assert sliced_after.expr.optimize()._name == sliced_before.expr.optimize()._name
    assert_eq(sliced_after, data.T[5:])
|
|
311
|
+
|
|
312
|
+
|
|
313
|
+
def test_rechunk_pushdown_transpose():
    """After optimization, a rechunk of a transposed array must have a
    ``Transpose`` node at the top level."""
    data = np.random.random((10, 20))
    lazy = da.from_array(data, chunks=(2, 5))

    rechunked = lazy.T.rechunk((10, 5))
    top = rechunked.expr.optimize()
    # The rechunk is expected below the transpose, not above it.
    top_kind = type(top).__name__
    assert top_kind == "Transpose"
    assert_eq(rechunked, data.T)
|
|
325
|
+
|
|
326
|
+
|
|
327
|
+
def test_rechunk_pushdown_elemwise():
    """After optimization, a rechunk of ``array + 1`` must have an
    ``Elemwise`` node at the top level."""
    data = np.random.random((10, 10))
    lazy = da.from_array(data, chunks=(4, 4))

    rechunked = (lazy + 1).rechunk((5, 5))
    top = rechunked.expr.optimize()
    # The rechunk is expected below the elemwise op, not above it.
    top_kind = type(top).__name__
    assert top_kind == "Elemwise"
    assert_eq(rechunked, data + 1)
|
|
339
|
+
|
|
340
|
+
|
|
341
|
+
def test_rechunk_pushdown_elemwise_broadcast():
    """Rechunk pushdown through an elemwise add of a 1-D and a 2-D array."""
    vec_np = np.random.random((10,))
    vec = da.from_array(vec_np)
    mat_np = np.random.random((10, 10))
    mat = da.from_array(mat_np)

    rechunked = (vec + mat).rechunk((5, 2))
    # Equivalent form with the rechunk applied to each input separately.
    pushed = vec.rechunk((2,)) + mat.rechunk((5, 2))
    assert rechunked.expr.simplify()._name == pushed.expr.simplify()._name

    # Fully optimized, the top-level node must be the elemwise op.
    top = rechunked.expr.optimize()
    top_kind = type(top).__name__
    assert top_kind == "Elemwise"
    assert_eq(rechunked, vec_np + mat_np)
|
|
358
|
+
|
|
359
|
+
|
|
360
|
+
# =============================================================================
|
|
361
|
+
# Optimization correctness and safety tests
|
|
362
|
+
# =============================================================================
|
|
363
|
+
|
|
364
|
+
|
|
365
|
+
def test_optimization_correctness_various_chains():
    """Several chained operations must give the same values as the
    equivalent NumPy operations on the computed array."""
    np.random.seed(42)  # seeds NumPy's global RNG
    lazy = da.random.random((15, 25), chunks=(3, 7))
    eager = lazy.compute()

    # (dask expression, expected NumPy value) pairs, checked in order.
    cases = [
        (lazy.T.T, eager),
        (lazy.T[5:].T, eager[:, 5:]),
        ((lazy + 1).rechunk((5, 5))[:10], (eager + 1)[:10]),
        (lazy.rechunk((5, 5)).rechunk((3, 3)), eager),
        (lazy[::2, 1:][::2], eager[::2, 1:][::2]),
        ((lazy * 2)[:, 10:][5:], (eager * 2)[:, 10:][5:]),
    ]
    for actual, expected in cases:
        assert_eq(actual, expected)
|
|
378
|
+
|
|
379
|
+
|
|
380
|
+
def test_optimize_empty_array():
    """Elemwise plus slicing on an array with a zero-length axis must
    keep the expected shape and compare equal to an empty NumPy array."""
    empty = da.zeros((0, 10), chunks=(1, 5))
    sliced = (empty + 1)[:, :5]
    expected_shape = (0, 5)
    assert sliced.shape == expected_shape
    assert_eq(sliced, np.zeros(expected_shape))
|
|
386
|
+
|
|
387
|
+
|
|
388
|
+
def test_optimize_3d_transpose():
    """Two 3-D permutations that compose to the identity must not leave a
    non-trivial ``Transpose`` at the top of the optimized expression."""
    np.random.seed(42)  # seeds NumPy's global RNG
    cube = da.random.random((4, 5, 6), chunks=2)

    # (2, 0, 1) followed by (1, 2, 0) composes to (0, 1, 2).
    round_trip = cube.transpose((2, 0, 1)).transpose((1, 2, 0))
    top = round_trip.expr.optimize()
    # A surviving top-level Transpose is tolerated only with identity axes.
    if type(top).__name__ == "Transpose":
        assert top.axes == tuple(range(3))
    assert_eq(round_trip, cube)
|
|
399
|
+
|
|
400
|
+
|
|
401
|
+
def test_optimize_scalar_in_elemwise():
    """Elemwise operations with a scalar operand must stay correct when
    followed by a slice or by a rechunk."""
    np.random.seed(42)  # seeds NumPy's global RNG
    lazy = da.random.random((10, 10), chunks=5)
    eager = lazy.compute()

    # Scalar on the left of the add, then a slice.
    sliced = (5 + lazy)[:5]
    assert_eq(sliced, (5 + eager)[:5])

    # Multiplication by a scalar, then a rechunk.
    rechunked = (lazy * 2).rechunk((5, 5))
    assert_eq(rechunked, eager * 2)
|
|
414
|
+
|
|
415
|
+
|
|
416
|
+
def test_chunks_preserved_after_optimization():
    """The reported ``chunks`` must be as expected for three operation
    chains built on a 20x20 array with (4, 5) chunks."""
    lazy = da.random.random((20, 20), chunks=(4, 5))

    # Transpose, then rechunk.
    transposed = lazy.T.rechunk((10, 10))
    assert transposed.chunks == ((10, 10), (10, 10))

    # Elemwise, then slice: the last row chunk is cut from 4 down to 2.
    sliced = (lazy + 1)[:10, :15]
    assert sliced.chunks == ((4, 4, 2), (5, 5, 5))

    # Slice, then rechunk.
    rechunked = lazy[:12, :8].rechunk((6, 4))
    assert rechunked.chunks == ((6, 6), (4, 4))
|
|
431
|
+
|
|
432
|
+
|
|
433
|
+
def test_pushdown_broadcast_both_arrays():
    """Slice and rechunk pushdown when a (10, 1) array and a (1, 20)
    array broadcast together to (10, 20)."""
    col = da.from_array(np.random.random((10, 1)), chunks=(5, 1))
    row = da.from_array(np.random.random((1, 20)), chunks=(1, 10))
    col_np, row_np = col.compute(), row.compute()

    # Slice pushdown: each input is cut only along its non-broadcast axis.
    sliced = (col + row)[:5, :10]
    top = sliced.expr.optimize()
    assert type(top).__name__ == "Elemwise"
    first, second = top.elemwise_args[0], top.elemwise_args[1]
    assert first.shape == (5, 1)
    assert second.shape == (1, 10)
    assert_eq(sliced, (col_np + row_np)[:5, :10])

    # Rechunk pushdown: each input is rechunked only along its non-broadcast axis.
    rechunked = (col + row).rechunk((2, 5))
    top = rechunked.expr.optimize()
    assert type(top).__name__ == "Elemwise"
    first, second = top.elemwise_args[0], top.elemwise_args[1]
    assert first.chunks == ((2, 2, 2, 2, 2), (1,))
    assert second.chunks == ((1,), (5, 5, 5, 5))
    assert_eq(rechunked, col_np + row_np)
|
|
457
|
+
|
|
458
|
+
|
|
459
|
+
def test_rechunk_pushdown_to_io():
    """A rechunk directly on ``from_array`` must optimize to a bare
    ``FromArray`` named like one created with the target chunks."""
    from dask_array.io import FromArray

    data = np.random.random((10, 10))
    lazy = da.from_array(data, chunks=(4, 4))

    optimized = lazy.rechunk((5, 2)).expr.optimize()
    target_chunks = ((5, 5), (2, 2, 2, 2, 2))
    reference = da.from_array(data, chunks=target_chunks).expr

    # Same node type and same name as the directly constructed reference.
    assert type(optimized) is FromArray
    assert optimized._name == reference._name
|
|
472
|
+
|
|
473
|
+
|
|
474
|
+
def test_rechunk_chain_optimize():
    """Two chained rechunks on ``from_array`` must optimize to a bare
    ``FromArray`` named like one created with the final chunks."""
    from dask_array.io import FromArray

    data = np.random.random((10, 10))
    lazy = da.from_array(data, chunks=(4, 4))

    optimized = lazy.rechunk((2, 5)).rechunk((5, 2)).expr.optimize()
    target_chunks = ((5, 5), (2, 2, 2, 2, 2))
    reference = da.from_array(data, chunks=target_chunks).expr

    # No Rechunk node remains at the top; only the IO node is left.
    assert type(optimized) is FromArray
    assert optimized._name == reference._name
|
|
487
|
+
|
|
488
|
+
|
|
489
|
+
def test_rechunk_transpose_pushdown_to_io():
    """A rechunk after a transpose must optimize to ``Transpose`` over a
    ``FromArray`` that carries the transposed target chunks."""
    from dask_array.io import FromArray
    from dask_array.manipulation._transpose import Transpose

    data = np.random.random((10, 10))
    lazy = da.from_array(data, chunks=(4, 4))

    optimized = lazy.T.rechunk((5, 2)).expr.optimize()
    # Reference: input chunked as (2, 5) per block, then transposed.
    swapped_chunks = ((2, 2, 2, 2, 2), (5, 5))
    reference = da.from_array(data, chunks=swapped_chunks).T.expr

    assert type(optimized) is Transpose
    assert type(optimized.array) is FromArray
    assert optimized._name == reference._name
|
|
504
|
+
|
|
505
|
+
|
|
506
|
+
def test_rechunk_elemwise_pushdown_to_io():
    """A rechunk after ``array + 1`` must optimize to an ``Elemwise``
    whose first argument is a ``FromArray`` with the target chunks."""
    from dask_array._blockwise import Elemwise
    from dask_array.io import FromArray

    data = np.random.random((10, 10))
    lazy = da.from_array(data, chunks=(4, 4))

    optimized = (lazy + 1).rechunk((5, 5)).expr.optimize()

    assert type(optimized) is Elemwise
    io_node = optimized.elemwise_args[0]
    assert type(io_node) is FromArray
    assert io_node.chunks == ((5, 5), (5, 5))
    # The IO node's name must still start with the "array-" prefix.
    assert io_node.name.startswith("array-")
|
|
522
|
+
|
|
523
|
+
|
|
524
|
+
def test_rechunk_pushdown_concatenate_other_axis():
    """Rechunking a concatenation along the non-concatenated axis must
    simplify to the same name as concatenating rechunked inputs."""
    top = da.ones((10, 20), chunks=(5, 10))
    bottom = da.ones((10, 20), chunks=(5, 10))
    joined = da.concatenate([top, bottom], axis=0)  # shape (20, 20)

    # Axis 1 is rechunked; the concatenation runs along axis 0.
    new_chunks = {1: 5}
    rechunked = joined.rechunk(new_chunks)
    pushed = da.concatenate(
        [top.rechunk({1: 5}), bottom.rechunk({1: 5})],
        axis=0,
    )

    assert rechunked.expr.simplify()._name == pushed.expr.simplify()._name
    assert_eq(rechunked, pushed)
|
|
539
|
+
|
|
540
|
+
|
|
541
|
+
def test_rechunk_pushdown_concatenate_correctness():
    """Rechunking through a concatenation of two distinct integer arrays
    must keep the values and match the pushed-down structure."""
    first_np = np.arange(20).reshape(4, 5)
    second_np = np.arange(20, 40).reshape(4, 5)
    first = da.from_array(first_np, chunks=(2, 3))
    second = da.from_array(second_np, chunks=(2, 3))

    joined = da.concatenate([first, second], axis=0)  # shape (8, 5)

    # Axis 1 is rechunked; the concatenation runs along axis 0.
    rechunked = joined.rechunk({1: 2})
    pushed = da.concatenate(
        [first.rechunk({1: 2}), second.rechunk({1: 2})],
        axis=0,
    )

    assert rechunked.expr.simplify()._name == pushed.expr.simplify()._name
    assert_eq(rechunked, np.concatenate([first_np, second_np], axis=0))
|
|
557
|
+
|
|
558
|
+
|
|
559
|
+
# --- Fusion regression tests ---
|
|
560
|
+
|
|
561
|
+
|
|
562
|
+
def test_fusion_broadcast_modulo():
    """Fusion must stay correct when one operand broadcasts.

    The broadcast operand has fewer blocks than the other one, so block
    indices have to wrap around (modulo) when the operations are fused.
    Regression test for matmul-like broadcast patterns.
    """
    mat_np = np.arange(6).reshape(2, 3)
    vec_np = np.arange(3)
    # The matrix has two blocks along axis 0; the vector has a single block.
    mat = da.from_array(mat_np, chunks=(1, 3))
    vec = da.from_array(vec_np, chunks=3)

    product = mat * vec  # elemwise with broadcast
    assert_eq(product, np.arange(6).reshape(2, 3) * np.arange(3))

    # The explicitly fused expression must compute the same values.
    fused = product.expr.optimize(fuse=True)
    assert_eq(da.Array(fused), np.arange(6).reshape(2, 3) * np.arange(3))
|
|
580
|
+
|
|
581
|
+
|
|
582
|
+
def test_fusion_same_array_different_indices():
    """Fusion must cope with one array used under two index mappings.

    ``da.dot(x, x)`` consumes ``x`` once with indices 'ij' and once with
    'jk'; fusion has to detect that conflict and leave the conflicting
    expression out.
    """
    lazy = da.from_array(np.arange(9).reshape(3, 3), chunks=(2, 2))
    eager = lazy.compute()

    product = da.dot(lazy, lazy)
    reference = np.dot(eager, eager)
    assert_eq(product, reference)

    # Persisting goes through the conflict path during fusion.
    persisted = product.persist()
    assert_eq(persisted, reference)
|
|
600
|
+
|
|
601
|
+
|
|
602
|
+
def test_fusion_elemwise_with_out_and_where_true():
    """An ``out`` array must not break fusion when ``where`` is left at
    its default.

    With ``where=True`` the ``out`` array only receives the result; it is
    not read by the computation and so should not count as a dependency.
    """
    left = da.from_array(np.arange(4), chunks=2)
    right = da.from_array(np.arange(4, 8), chunks=2)
    out = da.zeros(4, chunks=2)

    # The ufunc must hand back the very object passed as ``out``.
    returned = da.add(left, right, out=out)
    assert returned is out

    assert_eq(returned, np.arange(4) + np.arange(4, 8))
|
|
619
|
+
|
|
620
|
+
|
|
621
|
+
def test_fusion_elemwise_with_out_and_where_array():
    """An ``out`` array must take part in the result when ``where`` is a
    mask.

    With a mask, positions where the mask is False keep the value already
    in ``out``, so ``out`` is a real input of the computation.
    """
    left = da.from_array(np.arange(4), chunks=2)
    right = da.from_array(np.arange(4, 8), chunks=2)
    where = da.from_array(np.array([True, False, True, False]), chunks=2)
    out = da.zeros(4, dtype=int, chunks=2)

    returned = da.add(left, right, where=where, out=out)
    assert returned is out

    # NumPy reference: np.add fills ``expected`` in place at the True positions.
    expected = np.zeros(4, dtype=int)
    numpy_mask = np.array([True, False, True, False])
    np.add(np.arange(4), np.arange(4, 8), where=numpy_mask, out=expected)
    assert_eq(returned, expected)
|
|
644
|
+
|
|
645
|
+
|
|
646
|
+
def test_fusion_out_same_as_input():
    """``out=x`` works when ``x`` is also an input of the operation.

    For an in-place call such as ``np.sin(x, out=x)`` the output array is
    genuinely used as an input, so it must NOT be dropped from the
    dependencies.
    """
    arr = da.from_array(np.array([0.0, 0.5, 1.0, 1.5]), chunks=2)
    # Independent NumPy copy taken before the in-place call below.
    reference = arr.compute().copy()

    # In-place ufunc call: the destination is the input itself.
    in_place = np.sin(arr, out=arr)
    assert in_place is arr

    # Same in-place operation carried out with plain NumPy.
    expected = np.sin(reference, out=reference)
    assert_eq(in_place, expected)
|
|
661
|
+
|
|
662
|
+
|
|
663
|
+
def test_fusion_transpose_conflict():
    """Check the ``a + a.T`` pattern, where one array is read two ways.

    The same array is accessed both directly and transposed, so different
    output blocks need different source blocks from the same expression.
    Fusion has to detect that conflict and still give the right answer.
    """
    square = da.from_array(np.arange(9).reshape(3, 3), chunks=(2, 2))
    square_np = square.compute()

    # Both operands refer to ``square`` but with different index mappings.
    total = square + square.T
    reference = square_np + square_np.T
    assert_eq(total, reference)

    # The explicitly fused expression has to produce the same values.
    fused = total.expr.optimize(fuse=True)
    assert_eq(da.Array(fused), reference)
|
|
681
|
+
|
|
682
|
+
|
|
683
|
+
def test_fusion_chained_transpose():
    """Check fusion when a transpose is chained after an elementwise op.

    ``(a + b).T`` applies one consistent dimension permutation, so there is
    no conflict and the expression should fuse correctly.
    """
    left = da.from_array(np.arange(6).reshape(2, 3), chunks=(1, 2))
    right = da.from_array(np.arange(6, 12).reshape(2, 3), chunks=(1, 2))
    left_np = left.compute()
    right_np = right.compute()

    transposed_sum = (left + right).T
    reference = (left_np + right_np).T
    assert_eq(transposed_sum, reference)

    # The add and the transpose are optimized together with fuse=True.
    fused = transposed_sum.expr.optimize(fuse=True)
    assert_eq(da.Array(fused), reference)
|
|
700
|
+
|
|
701
|
+
|
|
702
|
+
def test_reduction_scalar_aggregate_meta():
    """Regression test: ``reduction`` copes with a scalar-returning aggregate.

    A custom aggregate may return a plain Python scalar rather than an array
    with the expected dimensions; computing the meta must not fail then.
    Previously this raised:
        ValueError: cannot reshape array of size 1 into shape (0,0)
    """

    # Aggregate that ignores its input and returns a Python int, not an array.
    def scalar_agg(x, axis=None, keepdims=False):
        return 42

    cube = da.ones((10, 5, 5), chunks=(5, 5, 5))

    reduced = da.reduction(
        cube,
        chunk=np.sum,
        aggregate=scalar_agg,
        axis=0,
        dtype=float,
    )

    # Reading ``_meta`` is the step that used to raise the ValueError.
    assert reduced._meta.shape == (0, 0)
    assert reduced._meta.dtype == np.float64
|
|
726
|
+
|
|
727
|
+
|
|
728
|
+
def test_fusion_blockwise_contracted_dimensions():
    """Check fusion of a Blockwise whose input has contracted dimensions.

    Contracted dimensions are indices present on the input but absent from
    the output; block lookups during fusion have to account for them.

    Regression test for the xarray integration: groupby operations create a
    Blockwise with out_ind=(2,) (1D output) over a 3D input with
    ind=(0, 1, 2). When that is fused with an Elemwise (out_ind=(0,)), the
    idx_to_block mapping must handle contracted dimensions 0 and 1.

    Previously failed with KeyError: 0 in FusedBlockwise._task().
    """
    from dask_array._blockwise import FusedBlockwise

    # 3D input with a single block along each contracted dimension (0 and 1).
    source = da.from_array(np.ones((1, 1, 3)), chunks=(1, 1, 1))

    out_ind = (2,)  # the output dimension comes from input dimension 2
    in_ind = (0, 1, 2)  # the input carries all three dimensions
    contracted = da.blockwise(
        lambda x: x.mean(axis=(0, 1)),
        out_ind,
        source.expr,
        in_ind,
        dtype=source.dtype,
    )

    # Single-block contracted dimensions leave the Blockwise fusable.
    assert contracted.expr._is_blockwise_fusable

    # Elementwise comparison layered on top; its out_ind is (0,).
    expected = np.array([1.0, 1.0, 1.0])
    close = da.isclose(contracted, expected)

    # The Elemwise (out_ind=(0,)) and the Blockwise (out_ind=(2,)) should
    # collapse into one fused expression.
    fused = close.expr.optimize(fuse=True)
    assert isinstance(fused, FusedBlockwise)

    # The computed values must be correct as well.
    assert_eq(close, np.array([True, True, True]))
|
|
769
|
+
|
|
770
|
+
|
|
771
|
+
def test_fusion_blockwise_multiblock_contracted_prevents_fusion():
    """A Blockwise with multi-block contracted dimensions is not fusable.

    If a contracted dimension (on the input but not the output) spans more
    than one block, every output block would have to reference several input
    blocks along that dimension, so fusion is not possible.
    """
    from dask_array._blockwise import FusedBlockwise

    # 3D input with two blocks along contracted dimension 0.
    source = da.from_array(np.ones((2, 1, 3)), chunks=(1, 1, 1))

    out_ind = (2,)  # output is indexed by input dimension 2
    in_ind = (0, 1, 2)
    contracted = da.blockwise(
        lambda x: x.sum(),
        out_ind,
        source.expr,
        in_ind,
        dtype=source.dtype,
    )

    # The multi-block contracted dimension must rule out fusion.
    assert not contracted.expr._is_blockwise_fusable

    # Elemwise wrapped around the non-fusable Blockwise.
    close = da.isclose(contracted, np.array([1.0, 1.0, 1.0]))

    # Optimizing with fuse=True must not yield a FusedBlockwise at the root.
    optimized = close.expr.optimize(fuse=True)
    assert not isinstance(optimized, FusedBlockwise)
|