dask-array 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dask_array/__init__.py +228 -0
- dask_array/_backends.py +76 -0
- dask_array/_backends_array.py +99 -0
- dask_array/_blockwise.py +1410 -0
- dask_array/_broadcast.py +272 -0
- dask_array/_chunk.py +445 -0
- dask_array/_chunk_types.py +54 -0
- dask_array/_collection.py +1644 -0
- dask_array/_concatenate.py +331 -0
- dask_array/_core_utils.py +1365 -0
- dask_array/_dispatch.py +141 -0
- dask_array/_einsum.py +277 -0
- dask_array/_expr.py +544 -0
- dask_array/_expr_flow.py +586 -0
- dask_array/_gufunc.py +805 -0
- dask_array/_histogram.py +617 -0
- dask_array/_map_blocks.py +652 -0
- dask_array/_new_collection.py +10 -0
- dask_array/_numpy_compat.py +135 -0
- dask_array/_overlap.py +1159 -0
- dask_array/_rechunk.py +1050 -0
- dask_array/_reshape.py +710 -0
- dask_array/_routines.py +102 -0
- dask_array/_shuffle.py +448 -0
- dask_array/_stack.py +264 -0
- dask_array/_svg.py +291 -0
- dask_array/_templates.py +29 -0
- dask_array/_test_utils.py +257 -0
- dask_array/_ufunc.py +385 -0
- dask_array/_utils.py +349 -0
- dask_array/_visualize.py +223 -0
- dask_array/_xarray.py +337 -0
- dask_array/core/__init__.py +34 -0
- dask_array/core/_blockwise_funcs.py +312 -0
- dask_array/core/_conversion.py +422 -0
- dask_array/core/_from_graph.py +97 -0
- dask_array/creation/__init__.py +71 -0
- dask_array/creation/_arange.py +121 -0
- dask_array/creation/_diag.py +116 -0
- dask_array/creation/_diagonal.py +241 -0
- dask_array/creation/_eye.py +103 -0
- dask_array/creation/_linspace.py +102 -0
- dask_array/creation/_mesh.py +134 -0
- dask_array/creation/_ones_zeros.py +454 -0
- dask_array/creation/_pad.py +270 -0
- dask_array/creation/_repeat.py +55 -0
- dask_array/creation/_tile.py +36 -0
- dask_array/creation/_tri.py +28 -0
- dask_array/creation/_utils.py +296 -0
- dask_array/fft.py +320 -0
- dask_array/io/__init__.py +39 -0
- dask_array/io/_base.py +10 -0
- dask_array/io/_from_array.py +257 -0
- dask_array/io/_from_delayed.py +95 -0
- dask_array/io/_from_graph.py +54 -0
- dask_array/io/_from_npy_stack.py +67 -0
- dask_array/io/_store.py +336 -0
- dask_array/io/_tiledb.py +159 -0
- dask_array/io/_to_npy_stack.py +65 -0
- dask_array/io/_zarr.py +449 -0
- dask_array/linalg/__init__.py +39 -0
- dask_array/linalg/_cholesky.py +234 -0
- dask_array/linalg/_lu.py +300 -0
- dask_array/linalg/_norm.py +94 -0
- dask_array/linalg/_qr.py +601 -0
- dask_array/linalg/_solve.py +349 -0
- dask_array/linalg/_svd.py +394 -0
- dask_array/linalg/_tensordot.py +334 -0
- dask_array/linalg/_utils.py +74 -0
- dask_array/manipulation/__init__.py +45 -0
- dask_array/manipulation/_expand.py +321 -0
- dask_array/manipulation/_flip.py +92 -0
- dask_array/manipulation/_roll.py +78 -0
- dask_array/manipulation/_transpose.py +309 -0
- dask_array/random/__init__.py +125 -0
- dask_array/random/_choice.py +181 -0
- dask_array/random/_expr.py +256 -0
- dask_array/random/_generator.py +441 -0
- dask_array/random/_random_state.py +259 -0
- dask_array/random/_utils.py +84 -0
- dask_array/reductions/__init__.py +84 -0
- dask_array/reductions/_arg_reduction.py +130 -0
- dask_array/reductions/_common.py +1082 -0
- dask_array/reductions/_cumulative.py +522 -0
- dask_array/reductions/_percentile.py +261 -0
- dask_array/reductions/_reduction.py +725 -0
- dask_array/reductions/_trace.py +56 -0
- dask_array/routines/__init__.py +133 -0
- dask_array/routines/_apply.py +84 -0
- dask_array/routines/_bincount.py +112 -0
- dask_array/routines/_broadcast.py +111 -0
- dask_array/routines/_coarsen.py +115 -0
- dask_array/routines/_diff.py +79 -0
- dask_array/routines/_gradient.py +158 -0
- dask_array/routines/_indexing.py +65 -0
- dask_array/routines/_insert_delete.py +132 -0
- dask_array/routines/_misc.py +122 -0
- dask_array/routines/_nonzero.py +72 -0
- dask_array/routines/_search.py +123 -0
- dask_array/routines/_select.py +113 -0
- dask_array/routines/_statistics.py +171 -0
- dask_array/routines/_topk.py +82 -0
- dask_array/routines/_triangular.py +74 -0
- dask_array/routines/_unique.py +232 -0
- dask_array/routines/_where.py +62 -0
- dask_array/slicing/__init__.py +67 -0
- dask_array/slicing/_basic.py +550 -0
- dask_array/slicing/_blocks.py +138 -0
- dask_array/slicing/_bool_index.py +145 -0
- dask_array/slicing/_setitem.py +329 -0
- dask_array/slicing/_squeeze.py +101 -0
- dask_array/slicing/_utils.py +1133 -0
- dask_array/slicing/_vindex.py +282 -0
- dask_array/stacking/__init__.py +15 -0
- dask_array/stacking/_block.py +83 -0
- dask_array/stacking/_simple.py +58 -0
- dask_array/templates/array.html.j2 +48 -0
- dask_array/tests/__init__.py +0 -0
- dask_array/tests/conftest.py +22 -0
- dask_array/tests/test_api.py +40 -0
- dask_array/tests/test_binary_op_chunks.py +107 -0
- dask_array/tests/test_coarse_slice_through_blockwise.py +362 -0
- dask_array/tests/test_collection.py +799 -0
- dask_array/tests/test_creation.py +1102 -0
- dask_array/tests/test_expr_flow.py +143 -0
- dask_array/tests/test_linalg.py +1130 -0
- dask_array/tests/test_map_blocks_multi_output.py +104 -0
- dask_array/tests/test_rechunk_pushdown.py +214 -0
- dask_array/tests/test_reductions.py +1091 -0
- dask_array/tests/test_routines.py +2853 -0
- dask_array/tests/test_shuffle_chunks.py +67 -0
- dask_array/tests/test_slice_pushdown.py +968 -0
- dask_array/tests/test_slice_through_blockwise.py +678 -0
- dask_array/tests/test_slice_through_overlap.py +366 -0
- dask_array/tests/test_slice_through_reshape.py +272 -0
- dask_array/tests/test_slicing.py +839 -0
- dask_array/tests/test_transpose_slice_pushdown.py +208 -0
- dask_array/tests/test_visualize.py +94 -0
- dask_array/tests/test_xarray.py +193 -0
- dask_array-0.1.0.dist-info/METADATA +48 -0
- dask_array-0.1.0.dist-info/RECORD +144 -0
- dask_array-0.1.0.dist-info/WHEEL +4 -0
- dask_array-0.1.0.dist-info/entry_points.txt +2 -0
- dask_array-0.1.0.dist-info/licenses/LICENSE +29 -0
dask_array/_rechunk.py
ADDED
|
@@ -0,0 +1,1050 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import heapq
|
|
4
|
+
import itertools
|
|
5
|
+
import math
|
|
6
|
+
import operator
|
|
7
|
+
from functools import reduce
|
|
8
|
+
from itertools import chain, product
|
|
9
|
+
from operator import add, itemgetter, mul
|
|
10
|
+
from warnings import warn
|
|
11
|
+
|
|
12
|
+
import numpy as np
|
|
13
|
+
import toolz
|
|
14
|
+
from tlz import accumulate
|
|
15
|
+
|
|
16
|
+
from dask import config
|
|
17
|
+
from dask._task_spec import Alias, List, Task, TaskRef
|
|
18
|
+
from dask.base import tokenize
|
|
19
|
+
from dask.utils import cached_property, parse_bytes
|
|
20
|
+
|
|
21
|
+
from dask_array._expr import ArrayExpr
|
|
22
|
+
from dask_array._core_utils import concatenate3, normalize_chunks
|
|
23
|
+
from dask_array._utils import validate_axis
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
# ============================================================================
|
|
27
|
+
# Rechunk planning utilities (copied from dask.array.rechunk)
|
|
28
|
+
# ============================================================================
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def cumdims_label(chunks, const):
    """Return, per dimension, the cumulative chunk boundaries tagged with *const*.

    Every dimension's chunk sizes are turned into running offsets starting
    at 0, and each offset is paired with the label ``const``.

    >>> cumdims_label(((5, 3, 3), (2, 2, 1)), 'n')  # doctest: +NORMALIZE_WHITESPACE
    [(('n', 0), ('n', 5), ('n', 8), ('n', 11)),
     (('n', 0), ('n', 2), ('n', 4), ('n', 5))]
    """
    labelled = []
    for sizes in chunks:
        offsets = accumulate(add, (0,) + sizes)
        labelled.append(tuple((const, offset) for offset in offsets))
    return labelled
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def _breakpoints(cumold, cumnew):
|
|
42
|
+
"""
|
|
43
|
+
>>> new = cumdims_label(((2, 3), (2, 2, 1)), 'n')
|
|
44
|
+
>>> old = cumdims_label(((2, 2, 1), (5,)), 'o')
|
|
45
|
+
|
|
46
|
+
>>> _breakpoints(new[0], old[0])
|
|
47
|
+
(('n', 0), ('o', 0), ('n', 2), ('o', 2), ('o', 4), ('n', 5), ('o', 5))
|
|
48
|
+
>>> _breakpoints(new[1], old[1])
|
|
49
|
+
(('n', 0), ('o', 0), ('n', 2), ('n', 4), ('n', 5), ('o', 5))
|
|
50
|
+
"""
|
|
51
|
+
return tuple(sorted(cumold + cumnew, key=itemgetter(1)))
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def _intersect_1d(breaks):
    """
    Internal utility to intersect chunks for 1d after preprocessing.

    >>> new = cumdims_label(((2, 3), (2, 2, 1)), 'n')
    >>> old = cumdims_label(((2, 2, 1), (5,)), 'o')

    >>> _intersect_1d(_breakpoints(old[0], new[0]))  # doctest: +NORMALIZE_WHITESPACE
    [[(0, slice(0, 2, None))],
     [(1, slice(0, 2, None)), (2, slice(0, 1, None))]]
    >>> _intersect_1d(_breakpoints(old[1], new[1]))  # doctest: +NORMALIZE_WHITESPACE
    [[(0, slice(0, 2, None))],
     [(0, slice(2, 4, None))],
     [(0, slice(4, 5, None))]]

    Parameters
    ----------
    breaks: list of tuples
        Each tuple is ('o', 8) or ('n', 8): an 'o' (old) or 'n' (new)
        label together with a cumulative sum, i.e. a breakpoint
        (a position along the chunking axis).
        The list of pairs is already ordered by breakpoint.
        Note that an 'o' pair always occurs BEFORE
        an 'n' pair if both share the same breakpoint.
        The 'o' and 'n' labels are used to build, for each new block,
        the slices of the old blocks that make it up.

    Returns
    -------
    list of lists
        One inner list per new chunk, holding ``(old_chunk_index, slice)``
        pairs; each slice is relative to the start of that old chunk.
    """
    # Old-chunk boundaries; k boundaries delimit k - 1 old chunks.
    o_pairs = [pair for pair in breaks if pair[0] == "o"]
    last_old_chunk_idx = len(o_pairs) - 2
    # Position of the last old boundary, i.e. the end of the axis.
    last_o_br = o_pairs[-1][1]

    start = 0  # start of the current slice, relative to its old chunk
    last_end = 0
    old_idx = 0  # index of the old chunk currently being sliced
    last_o_end = 0
    ret = []  # one entry per new chunk
    ret_next = []  # slices accumulated for the new chunk being built
    for idx in range(1, len(breaks)):
        label, br = breaks[idx]
        last_label, last_br = breaks[idx - 1]
        if last_label == "n":
            # A new-chunk boundary was just passed: keep slicing the same old
            # chunk from where we stopped, and flush the finished new chunk.
            start = last_end
            if ret_next:
                ret.append(ret_next)
                ret_next = []
        else:
            # An old-chunk boundary was just passed: slicing restarts at 0.
            start = 0
        end = br - last_br + start
        last_end = end
        if br == last_br:
            # Zero-width interval between consecutive breakpoints.
            if label == "o":
                old_idx += 1
                last_o_end = end
            if label == "n" and last_label == "n":
                # Zero-size new chunk. At the end of the axis it gets an empty
                # slice of the last old chunk; elsewhere it falls through and
                # gets an empty slice of the current old chunk below.
                if br == last_o_br:
                    slc = slice(last_o_end, last_o_end)
                    ret_next.append((last_old_chunk_idx, slc))
                    continue
            else:
                continue
        ret_next.append((old_idx, slice(start, end)))
        if label == "o":
            old_idx += 1
            start = 0
            last_o_end = end

    if ret_next:
        ret.append(ret_next)

    return ret
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
def old_to_new(old_chunks, new_chunks):
    """Helper to build old_chunks to new_chunks.

    Handles missing values, as long as the dimension with the missing chunk values
    is unchanged.

    For each dimension the result holds one list per new chunk, made of
    ``(old_chunk_index, slice)`` pairs. Dimensions whose old chunks contain
    NaN (unknown sizes) map every old chunk one-to-one onto a new chunk,
    using an open-ended slice for the unknown sizes.

    Examples
    --------
    >>> old = ((10, 10, 10, 10, 10), )
    >>> new = ((25, 5, 20), )
    >>> old_to_new(old, new)  # doctest: +NORMALIZE_WHITESPACE
    [[[(0, slice(0, 10, None)), (1, slice(0, 10, None)), (2, slice(0, 5, None))],
      [(2, slice(5, 10, None))],
      [(3, slice(0, 10, None)), (4, slice(0, 10, None))]]]
    """
    known, unknown = [], []
    for axis, sizes in enumerate(old_chunks):
        if any(math.isnan(size) for size in sizes):
            unknown.append(axis)
        else:
            known.append(axis)

    old_labelled = cumdims_label([old_chunks[axis] for axis in known], "o")
    new_labelled = cumdims_label([new_chunks[axis] for axis in known], "n")

    result = [None] * len(old_chunks)
    for axis, old_bounds, new_bounds in zip(known, old_labelled, new_labelled):
        result[axis] = _intersect_1d(_breakpoints(old_bounds, new_bounds))

    for axis in unknown:
        result[axis] = [
            [(index, slice(0, None if math.isnan(size) else size))]
            for index, size in enumerate(old_chunks[axis])
        ]

    assert all(entry is not None for entry in result)
    return result
|
|
172
|
+
|
|
173
|
+
|
|
174
|
+
def intersect_chunks(old_chunks, new_chunks):
    """
    Make dask.array slices as intersection of old and new chunks.

    >>> intersections = intersect_chunks(((4, 4), (2,)),
    ...                                  ((8,), (1, 1)))
    >>> list(intersections)  # doctest: +NORMALIZE_WHITESPACE
    [(((0, slice(0, 4, None)), (0, slice(0, 1, None))),
      ((1, slice(0, 4, None)), (0, slice(0, 1, None)))),
     (((0, slice(0, 4, None)), (0, slice(1, 2, None))),
      ((1, slice(0, 4, None)), (0, slice(1, 2, None))))]

    Parameters
    ----------
    old_chunks : iterable of tuples
        block sizes along each dimension (convert from old_chunks)
    new_chunks: iterable of tuples
        block sizes along each dimension (converts to new_chunks)

    Returns
    -------
    iterator
        Yields, for every new block (last dimension varying fastest), a tuple
        of pieces. Each piece is a tuple holding one ``(old_index, slice)``
        pair per dimension and identifies the part of one old block that
        lands in that new block.
    """
    per_axis = old_to_new(old_chunks, new_chunks)
    new_blocks = product(*per_axis)
    # A single-argument ``chain`` simply re-yields the generator's items.
    return chain(tuple(product(*pieces)) for pieces in new_blocks)
|
|
196
|
+
|
|
197
|
+
|
|
198
|
+
def _validate_rechunk(old_chunks, new_chunks):
|
|
199
|
+
"""Validates that rechunking an array from ``old_chunks`` to ``new_chunks``
|
|
200
|
+
is possible, raises an error if otherwise.
|
|
201
|
+
"""
|
|
202
|
+
assert len(old_chunks) == len(new_chunks)
|
|
203
|
+
|
|
204
|
+
old_shapes = tuple(map(sum, old_chunks))
|
|
205
|
+
new_shapes = tuple(map(sum, new_chunks))
|
|
206
|
+
|
|
207
|
+
for old_shape, old_dim, new_shape, new_dim in zip(old_shapes, old_chunks, new_shapes, new_chunks):
|
|
208
|
+
if old_shape != new_shape:
|
|
209
|
+
if not (math.isnan(old_shape) and math.isnan(new_shape)) or not np.array_equal(
|
|
210
|
+
old_dim, new_dim, equal_nan=True
|
|
211
|
+
):
|
|
212
|
+
raise ValueError(
|
|
213
|
+
"Chunks must be unchanging along dimensions with missing values.\n\n"
|
|
214
|
+
"A possible solution:\n x.compute_chunk_sizes()"
|
|
215
|
+
)
|
|
216
|
+
|
|
217
|
+
|
|
218
|
+
def _number_of_blocks(chunks):
|
|
219
|
+
return reduce(mul, map(len, chunks))
|
|
220
|
+
|
|
221
|
+
|
|
222
|
+
def _largest_block_size(chunks):
|
|
223
|
+
return reduce(mul, map(max, chunks))
|
|
224
|
+
|
|
225
|
+
|
|
226
|
+
def estimate_graph_size(old_chunks, new_chunks):
    """Estimate the graph size during a rechunk computation.

    Along a dimension whose chunks are unchanged, the estimate is
    ``len(old)``. Otherwise the old and new boundaries interleave into at
    most ``len(old) + len(new) - 1`` pieces. The estimate is the product
    over all dimensions, which is 1 for 0-d input.
    """
    # math.prod (already used in this module) also handles 0-d input, where
    # reduce() without an initial value raised TypeError.
    return math.prod(
        len(oc) if oc == nc else len(oc) + len(nc) - 1
        for oc, nc in zip(old_chunks, new_chunks)
    )
|
|
233
|
+
|
|
234
|
+
|
|
235
|
+
def divide_to_width(desired_chunks, max_width):
    """Split each chunk in *desired_chunks* as little as possible so that no
    resulting chunk is wider than *max_width*.

    Each chunk ``c`` is cut into ``ceil(c / max_width)`` nearly equal pieces.
    Floor division is used, so any larger pieces come last.
    """
    result = []
    for width in desired_chunks:
        pieces = int(np.ceil(width / max_width))
        remaining = width
        for left in range(pieces, 0, -1):
            part = remaining // left
            result.append(part)
            remaining -= part
        assert remaining == 0
    return tuple(result)
|
|
248
|
+
|
|
249
|
+
|
|
250
|
+
def merge_to_number(desired_chunks, max_number):
    """Minimally merge the given chunks so as to drop the number of
    chunks below *max_number*, while minimizing the largest width.

    Only adjacent chunks are combined, so the total width is preserved.

    Parameters
    ----------
    desired_chunks : tuple of int
        Chunk widths along one dimension.
    max_number : int
        Maximum number of chunks allowed in the result.

    Returns
    -------
    tuple of int
        ``desired_chunks`` itself if it already has at most ``max_number``
        chunks; otherwise a merged chunking of (at most) ``max_number`` chunks.

    Raises
    ------
    ValueError
        If merging is required but ``max_number`` is smaller than 1.
    """
    if len(desired_chunks) <= max_number:
        return desired_chunks

    if max_number < 1:
        # Previously surfaced as ZeroDivisionError / IndexError further down.
        raise ValueError(f"max_number must be at least 1, got {max_number}")

    distinct = set(desired_chunks)
    if len(distinct) == 1:
        # Fast path for a homogeneous target, which also keeps the result
        # regular: every output chunk is a multiple of the common width.
        w = distinct.pop()
        n = len(desired_chunks)
        total = n * w

        desired_width = total // max_number
        width = w * (desired_width // w)
        adjust = (total - max_number * width) // w

        return (width + w,) * adjust + (width,) * (max_number - adjust)

    nmerges = len(desired_chunks) - max_number

    # Min-heap of candidate merges (combined width, left index, right index).
    heap = [(desired_chunks[i] + desired_chunks[i + 1], i, i + 1) for i in range(len(desired_chunks) - 1)]
    heapq.heapify(heap)

    chunks = list(desired_chunks)

    while nmerges > 0:
        # Take the narrowest candidate merge.
        width, i, j = heapq.heappop(heap)
        # The right neighbour was merged away: advance to the next live chunk,
        # recompute and retry later.
        if chunks[j] == 0:
            j += 1
            while chunks[j] == 0:
                j += 1
            heapq.heappush(heap, (chunks[i] + chunks[j], i, j))
            continue
        # The stored width is stale (a neighbour grew): recompute and retry.
        elif chunks[i] + chunks[j] != width:
            heapq.heappush(heap, (chunks[i] + chunks[j], i, j))
            continue
        # Merge chunk i into chunk j; 0 marks chunk i as deleted.
        assert chunks[i] != 0
        chunks[i] = 0
        chunks[j] = width
        nmerges -= 1

    return tuple(filter(None, chunks))
|
|
294
|
+
|
|
295
|
+
|
|
296
|
+
def find_merge_rechunk(old_chunks, new_chunks, block_size_limit):
    """
    Find an intermediate rechunk that would merge some adjacent blocks
    together in order to get us nearer the *new_chunks* target, without
    violating the *block_size_limit* (in number of elements).

    Returns
    -------
    chunks : tuple
        The intermediate chunks.
    memory_limit_hit : bool
        True if some candidate dimension could not be switched entirely to
        its *new_chunks* because of *block_size_limit*.
    """
    ndim = len(old_chunks)

    old_largest_width = [max(c) for c in old_chunks]
    new_largest_width = [max(c) for c in new_chunks]

    # Factor by which switching a dimension to its new chunks changes the
    # number of chunks along it (< 1 means fewer chunks).
    graph_size_effect = {dim: len(nc) / len(oc) for dim, (oc, nc) in enumerate(zip(old_chunks, new_chunks))}

    # Factor by which it changes the largest chunk width along that
    # dimension; ``or 1`` avoids dividing by a zero width.
    block_size_effect = {dim: new_largest_width[dim] / (old_largest_width[dim] or 1) for dim in range(ndim)}

    # Only dimensions whose chunk count does not grow are merge candidates.
    merge_candidates = [dim for dim in range(ndim) if graph_size_effect[dim] <= 1.0]

    # Greedy ordering key: log(graph size factor) / log(block size factor).
    # For growing blocks (bse > 1) a more negative key means more graph
    # reduction per unit of block growth. bse == 1 is nudged so the
    # denominator is not log(1) == 0.
    def key(k):
        gse = graph_size_effect[k]
        bse = block_size_effect[k]
        if bse == 1:
            bse = 1 + 1e-9
        return (np.log(gse) / np.log(bse)) if bse > 0 else 0

    sorted_candidates = sorted(merge_candidates, key=key)

    largest_block_size = reduce(mul, old_largest_width)

    chunks = list(old_chunks)
    memory_limit_hit = False

    for dim in sorted_candidates:
        # Largest block size if this dimension switched fully to new_chunks.
        new_largest_block_size = largest_block_size * new_largest_width[dim] // (old_largest_width[dim] or 1)
        if new_largest_block_size <= block_size_limit:
            # Full replacement by the new chunks fits within the limit.
            chunks[dim] = new_chunks[dim]
            largest_block_size = new_largest_block_size
        else:
            # Partial rechunk: divide the new chunks so that the largest
            # block stays within the limit.
            largest_width = old_largest_width[dim]
            chunk_limit = int(block_size_limit * largest_width / largest_block_size)
            c = divide_to_width(new_chunks[dim], chunk_limit)
            if len(c) <= len(old_chunks[dim]):
                # Only accept it if it does not increase the number of chunks.
                chunks[dim] = c
                largest_block_size = largest_block_size * max(c) // largest_width

            memory_limit_hit = True

    # Sanity checks: the tracked size is exact and within the limit.
    assert largest_block_size == _largest_block_size(chunks)
    assert largest_block_size <= block_size_limit
    return tuple(chunks), memory_limit_hit
|
|
345
|
+
|
|
346
|
+
|
|
347
|
+
def find_split_rechunk(old_chunks, new_chunks, graph_size_limit):
    """
    Find an intermediate rechunk that would split some chunks to
    get us nearer *new_chunks*, without violating the *graph_size_limit*.

    Dimensions are visited in order. Each one that the target does not
    coarsen is given a share of the remaining graph-size budget, and its new
    chunks are merged down to fit that share. The result is kept only if it
    has at least as many chunks as before, and none wider than the old widest.
    """
    chunks = list(old_chunks)

    for dim in range(len(old_chunks)):
        current_size = estimate_graph_size(chunks, new_chunks)
        if current_size > graph_size_limit:
            break
        old_dim = old_chunks[dim]
        new_dim = new_chunks[dim]
        # Nothing to gain from splitting when the target has fewer chunks.
        if len(old_dim) > len(new_dim):
            continue
        budget = int(len(old_dim) * graph_size_limit / current_size)
        candidate = merge_to_number(new_dim, budget)
        assert len(candidate) <= budget
        if len(candidate) >= len(old_dim) and max(candidate) <= max(old_dim):
            chunks[dim] = candidate

    return tuple(chunks)
|
|
369
|
+
|
|
370
|
+
|
|
371
|
+
def _graph_size_threshold(old_chunks, new_chunks, threshold):
    """Return the graph-size budget: *threshold* times the combined number
    of blocks in the old and the new chunkings."""
    total_blocks = _number_of_blocks(old_chunks) + _number_of_blocks(new_chunks)
    return threshold * total_blocks
|
|
373
|
+
|
|
374
|
+
|
|
375
|
+
def plan_rechunk(old_chunks, new_chunks, itemsize, threshold=None, block_size_limit=None):
    """Plan an iterative rechunking from *old_chunks* to *new_chunks*.
    The plan aims to minimize the rechunk graph size.

    Parameters
    ----------
    itemsize: int
        The item size of the array
    threshold: int
        The graph growth factor under which we don't bother
        introducing an intermediate step
    block_size_limit: int
        The maximum block size (in bytes) we want to produce during an
        intermediate step

    Returns
    -------
    list
        The chunkings to rechunk through in order; the last element is
        always ``new_chunks``.
    """
    # Falsy values (None, but also 0) fall back to the configured defaults.
    threshold = threshold or config.get("array.rechunk.threshold")
    block_size_limit = block_size_limit or config.get("array.chunk-size")
    if isinstance(block_size_limit, str):
        block_size_limit = parse_bytes(block_size_limit)

    has_nans = (any(math.isnan(y) for y in x) for x in old_chunks)

    # No intermediate steps for <= 1 dimension, a dimension with an empty
    # chunk tuple, or unknown (NaN) old chunk sizes.
    if len(new_chunks) <= 1 or not all(new_chunks) or any(has_nans):
        return [new_chunks]

    # From here on the limit is a number of elements rather than bytes.
    block_size_limit /= itemsize

    largest_old_block = _largest_block_size(old_chunks)
    largest_new_block = _largest_block_size(new_chunks)
    # The limit is raised to at least the largest old and new blocks, which
    # exist regardless of the plan.
    block_size_limit = max([block_size_limit, largest_old_block, largest_new_block])

    graph_size_threshold = _graph_size_threshold(old_chunks, new_chunks, threshold)

    current_chunks = old_chunks
    first_pass = True
    steps = []

    while True:
        graph_size = estimate_graph_size(current_chunks, new_chunks)
        if graph_size < graph_size_threshold:
            break

        if first_pass:
            chunks = current_chunks
        else:
            # A previous merge pass hit block_size_limit: split first
            # (accepting a graph of up to graph_size * threshold) to make
            # room for the following merge.
            chunks = find_split_rechunk(current_chunks, new_chunks, graph_size * threshold)
        chunks, memory_limit_hit = find_merge_rechunk(chunks, new_chunks, block_size_limit)
        # Stop when no progress is made, or when the target is reached.
        if (chunks == current_chunks and not first_pass) or chunks == new_chunks:
            break
        if chunks != current_chunks:
            steps.append(chunks)
        current_chunks = chunks
        # Another split/merge round only helps if the memory limit got in the way.
        if not memory_limit_hit:
            break
        first_pass = False

    return steps + [new_chunks]
|
|
432
|
+
|
|
433
|
+
|
|
434
|
+
def _get_chunks(n, chunksize):
|
|
435
|
+
leftover = n % chunksize
|
|
436
|
+
n_chunks = n // chunksize
|
|
437
|
+
|
|
438
|
+
chunks = [chunksize] * n_chunks
|
|
439
|
+
if leftover:
|
|
440
|
+
chunks.append(leftover)
|
|
441
|
+
return tuple(chunks)
|
|
442
|
+
|
|
443
|
+
|
|
444
|
+
def _balance_chunksizes(chunks: tuple[int, ...]) -> tuple[int, ...]:
|
|
445
|
+
"""
|
|
446
|
+
Balance the chunk sizes
|
|
447
|
+
|
|
448
|
+
Parameters
|
|
449
|
+
----------
|
|
450
|
+
chunks : tuple[int, ...]
|
|
451
|
+
Chunk sizes for Dask array.
|
|
452
|
+
|
|
453
|
+
Returns
|
|
454
|
+
-------
|
|
455
|
+
new_chunks : tuple[int, ...]
|
|
456
|
+
New chunks for Dask array with balanced sizes.
|
|
457
|
+
"""
|
|
458
|
+
median_len = np.median(chunks).astype(int)
|
|
459
|
+
n_chunks = len(chunks)
|
|
460
|
+
eps = median_len // 2
|
|
461
|
+
if min(chunks) <= 0.5 * max(chunks):
|
|
462
|
+
n_chunks -= 1
|
|
463
|
+
|
|
464
|
+
new_chunks = [_get_chunks(sum(chunks), chunk_len) for chunk_len in range(median_len - eps, median_len + eps + 1)]
|
|
465
|
+
possible_chunks = [c for c in new_chunks if len(c) == n_chunks]
|
|
466
|
+
if not len(possible_chunks):
|
|
467
|
+
warn("chunk size balancing not possible with given chunks. Try increasing the chunk size.")
|
|
468
|
+
return chunks
|
|
469
|
+
|
|
470
|
+
diffs = [max(c) - min(c) for c in possible_chunks]
|
|
471
|
+
best_chunk_size = np.argmin(diffs)
|
|
472
|
+
return possible_chunks[best_chunk_size]
|
|
473
|
+
|
|
474
|
+
|
|
475
|
+
def _choose_rechunk_method(old_chunks, new_chunks, threshold=None):
    """Pick ``"tasks"`` or ``"p2p"`` for a rechunk from *old_chunks* to *new_chunks*.

    A truthy ``array.rechunk.method`` config value always wins. If
    ``distributed`` cannot be imported, or ``default_client()`` raises
    ``ValueError``, the answer is ``"tasks"``. Otherwise the answer is
    ``"p2p"`` once the task-based graph size reaches the size threshold.
    """
    configured = config.get("array.rechunk.method", None)
    if configured:
        return configured

    try:
        from distributed import default_client

        default_client()
    except (ImportError, ValueError):
        return "tasks"

    slices_per_axis = old_to_new(old_chunks, new_chunks)
    graph_size = math.prod(sum(len(pieces) for pieces in axis) for axis in slices_per_axis)
    limit = _graph_size_threshold(
        old_chunks, new_chunks, threshold or config.get("array.rechunk.threshold")
    )
    if graph_size < limit:
        return "tasks"
    return "p2p"
|
|
490
|
+
|
|
491
|
+
|
|
492
|
+
# ============================================================================
|
|
493
|
+
# Expression classes
|
|
494
|
+
# ============================================================================
|
|
495
|
+
|
|
496
|
+
|
|
497
|
+
class Rechunk(ArrayExpr):
    """Expression that changes the chunking of ``array``.

    Operands, in order: the input ``array``, the requested ``_chunks``
    specification, and the optional ``threshold``, ``block_size_limit``,
    ``balance`` and ``method`` settings.
    """

    # Operand names, in positional order.
    _parameters = [
        "array",
        "_chunks",
        "threshold",
        "block_size_limit",
        "balance",
        "method",
    ]

    # Default operand values; the chunk spec defaults to the string "auto".
    _defaults = {
        "_chunks": "auto",
        "threshold": None,
        "block_size_limit": None,
        "balance": None,
        "method": None,
    }
|
|
514
|
+
|
|
515
|
+
@property
|
|
516
|
+
def _meta(self):
|
|
517
|
+
return self.array._meta
|
|
518
|
+
|
|
519
|
+
@property
|
|
520
|
+
def _name(self):
|
|
521
|
+
return "rechunk-merge-" + tokenize(*self.operands)
|
|
522
|
+
|
|
523
|
+
    @cached_property
    def chunks(self):
        """Resolve the requested ``_chunks`` spec into the output chunks.

        The ``_chunks`` operand may be a dict keyed by axis or a tuple/list;
        axes missing from a dict, or given as ``None``, keep the input array's
        chunks. The result is then passed through ``normalize_chunks``,
        optionally balanced, and validated against the input chunks.

        Raises
        ------
        ValueError
            If the normalized chunks do not have one entry per axis, or (from
            ``_validate_rechunk``) if the rechunk is not possible.
        """
        x = self.array
        chunks = self.operand("_chunks")

        # don't rechunk if array is empty (every dimension has length 0)
        if x.ndim > 0 and all(s == 0 for s in x.shape):
            return x.chunks

        if isinstance(chunks, dict):
            # Normalize the axis keys (presumably resolving negative axes --
            # see validate_axis). This builds a new dict, so the operand
            # itself is not mutated below.
            chunks = {validate_axis(c, x.ndim): v for c, v in chunks.items()}
            for i in range(x.ndim):
                if i not in chunks:
                    chunks[i] = x.chunks[i]
                elif chunks[i] is None:
                    chunks[i] = x.chunks[i]
        if isinstance(chunks, (tuple, list)):
            # ``None`` entries mean "keep the input's chunks along this axis".
            chunks = tuple(lc if lc is not None else rc for lc, rc in zip(chunks, x.chunks))
        chunks = normalize_chunks(
            chunks,
            x.shape,
            limit=self.block_size_limit,
            dtype=x.dtype,
            previous_chunks=x.chunks,
        )

        if not len(chunks) == x.ndim:
            raise ValueError("Provided chunks are not consistent with shape")

        if self.balance:
            chunks = tuple(_balance_chunksizes(chunk) for chunk in chunks)

        _validate_rechunk(x.chunks, chunks)

        return chunks
|
|
558
|
+
|
|
559
|
+
def _simplify_down(self):
|
|
560
|
+
# No-op rechunk: if chunks already match, return the original array
|
|
561
|
+
if not self.balance and self.chunks == self.array.chunks:
|
|
562
|
+
return self.array
|
|
563
|
+
|
|
564
|
+
from dask_array._blockwise import Elemwise
|
|
565
|
+
from dask_array.manipulation._transpose import Transpose
|
|
566
|
+
|
|
567
|
+
# Rechunk(Rechunk(x)) -> single Rechunk to final chunks
|
|
568
|
+
# Only match Rechunk, not TasksRechunk (which is already lowered)
|
|
569
|
+
# Don't merge if inner has method='p2p' - preserve explicit p2p semantics
|
|
570
|
+
if type(self.array) is Rechunk and self.array.method != "p2p":
|
|
571
|
+
return Rechunk(
|
|
572
|
+
self.array.array,
|
|
573
|
+
self._chunks,
|
|
574
|
+
self.threshold,
|
|
575
|
+
self.block_size_limit,
|
|
576
|
+
self.balance or self.array.balance,
|
|
577
|
+
self.method,
|
|
578
|
+
)
|
|
579
|
+
|
|
580
|
+
# Rechunk(Transpose) -> Transpose(rechunked input)
|
|
581
|
+
if isinstance(self.array, Transpose):
|
|
582
|
+
return self._pushdown_through_transpose()
|
|
583
|
+
|
|
584
|
+
# Rechunk(Elemwise) -> Elemwise(rechunked inputs)
|
|
585
|
+
if isinstance(self.array, Elemwise):
|
|
586
|
+
return self._pushdown_through_elemwise()
|
|
587
|
+
|
|
588
|
+
# Rechunk(Concatenate) -> Concatenate(rechunked inputs)
|
|
589
|
+
# Only for non-concat axes
|
|
590
|
+
from dask_array._concatenate import Concatenate
|
|
591
|
+
|
|
592
|
+
if isinstance(self.array, Concatenate):
|
|
593
|
+
return self._pushdown_through_concatenate()
|
|
594
|
+
|
|
595
|
+
# Rechunk(IO) -> IO with new chunks (if IO supports it)
|
|
596
|
+
# Skip if method='p2p' is explicitly requested - user wants distributed shuffle
|
|
597
|
+
if getattr(self.array, "_can_rechunk_pushdown", False) and self.method != "p2p":
|
|
598
|
+
# Keep the same name prefix - the token will change with the new chunks
|
|
599
|
+
return self.array.substitute_parameters({"_chunks": self.chunks})
|
|
600
|
+
|
|
601
|
+
def _pushdown_through_transpose(self):
|
|
602
|
+
"""Push rechunk through transpose by reordering chunk spec."""
|
|
603
|
+
from dask_array.manipulation._transpose import Transpose
|
|
604
|
+
|
|
605
|
+
transpose = self.array
|
|
606
|
+
axes = transpose.axes
|
|
607
|
+
chunks = self._chunks
|
|
608
|
+
|
|
609
|
+
if isinstance(chunks, tuple):
|
|
610
|
+
# Map output chunks back through transpose axes to get input chunks
|
|
611
|
+
# axes[i] tells us which input axis becomes output axis i
|
|
612
|
+
# So output axis i has chunks[i], which should go to input axis axes[i]
|
|
613
|
+
# We need to invert the permutation: place chunks[i] at position axes[i]
|
|
614
|
+
new_chunks = [None] * len(axes)
|
|
615
|
+
for i, ax in enumerate(axes):
|
|
616
|
+
new_chunks[ax] = chunks[i]
|
|
617
|
+
new_chunks = tuple(new_chunks)
|
|
618
|
+
elif isinstance(chunks, dict):
|
|
619
|
+
# Map dict keys through axes
|
|
620
|
+
new_chunks = {}
|
|
621
|
+
for out_axis, chunk_spec in chunks.items():
|
|
622
|
+
in_axis = axes[out_axis]
|
|
623
|
+
new_chunks[in_axis] = chunk_spec
|
|
624
|
+
else:
|
|
625
|
+
return None
|
|
626
|
+
|
|
627
|
+
rechunked_input = transpose.array.rechunk(new_chunks)
|
|
628
|
+
return Transpose(rechunked_input, axes)
|
|
629
|
+
|
|
630
|
+
def _pushdown_through_elemwise(self):
    """Push this rechunk below an elementwise operation.

    ``Rechunk(Elemwise(op, *args), chunks)`` is rewritten as an ``Elemwise``
    over rechunked operands. Every array operand, including ``where`` and
    ``out`` when they are arrays, is rechunked so that the elementwise
    result has the target chunks.

    The fully normalized target chunks (``self.chunks``) are used instead of
    the raw spec (``self._chunks``):

    * a dict spec leaves unmentioned axes unchanged, so those axes must not
      be collapsed to a single chunk;
    * negative dict keys, int and ``"auto"`` specs need normalization before
      they can be indexed per axis;
    * ``balance=True`` is already reflected in ``self.chunks``.

    Operands are aligned to the output from the right. Operand axis ``i``
    carries index ``arg.ndim - 1 - i`` and takes the chunks of the output
    position that holds the same index in ``out_ind``. Operand axes of size
    1 are broadcast and keep a single ``(1,)`` chunk. ``threshold``,
    ``block_size_limit`` and ``method`` are forwarded to each operand's
    rechunk.

    Returns
    -------
    Elemwise
        The same elementwise operation applied to the rechunked operands.
    """
    from dask_array._blockwise import Elemwise, is_scalar_for_elemwise
    from dask_array._expr import ArrayExpr

    elemwise = self.array
    out_ind = elemwise.out_ind
    target = self.chunks

    def rechunk_array_arg(arg):
        """Rechunk one operand so it lines up with the target output chunks."""
        if is_scalar_for_elemwise(arg) or not isinstance(arg, ArrayExpr):
            return arg

        arg_ind = tuple(range(arg.ndim)[::-1])
        arg_chunks = []
        for i, dim_idx in enumerate(arg_ind):
            if arg.shape[i] == 1:
                # Broadcast axis: its single element is reused, keep one chunk.
                arg_chunks.append((1,))
                continue
            try:
                out_pos = out_ind.index(dim_idx)
            except ValueError:
                # Index absent from the output (not expected for elemwise).
                arg_chunks.append(-1)
            else:
                arg_chunks.append(target[out_pos])

        # ``target`` is already balanced, so balance=False keeps it as-is.
        return arg.rechunk(
            tuple(arg_chunks),
            self.threshold,
            self.block_size_limit,
            False,
            self.method,
        )

    new_args = [rechunk_array_arg(arg) for arg in elemwise.elemwise_args]

    # ``where`` and ``out`` must line up with the result too when they are arrays.
    new_where = elemwise.where
    if isinstance(new_where, ArrayExpr):
        new_where = rechunk_array_arg(new_where)

    new_out = elemwise.out
    if isinstance(new_out, ArrayExpr):
        new_out = rechunk_array_arg(new_out)

    return Elemwise(
        elemwise.op,
        elemwise.operand("dtype"),
        elemwise.operand("name"),
        new_where,
        new_out,
        elemwise.operand("_user_kwargs"),
        *new_args,
    )
|
|
693
|
+
|
|
694
|
+
def _pushdown_through_concatenate(self):
    """Push this rechunk below a concatenate when the concat axis is unchanged.

    Each concatenated input is rechunked on its own. On the non-concat axes
    it gets the target chunks. Along the concat axis it keeps its current
    chunks, whose concatenation is the concatenate's chunks on that axis.

    The fully normalized target chunks (``self.chunks``) are used, for two
    reasons:

    * the concat-axis check works for every spec form, including dict specs
      with negative keys such as ``{-1: n}``;
    * ``"auto"`` and ``balance=True`` are resolved once for the whole
      result. Resolving them separately per input could give inputs with
      mismatched non-concat chunks.

    ``threshold``, ``block_size_limit`` and ``method`` are forwarded to each
    input's rechunk.

    Returns
    -------
    Expression of the same type as ``self.array``, or None
        None when the target chunks along the concat axis differ from the
        current ones. Moving block boundaries across input boundaries is not
        handled.
    """
    concat = self.array
    axis = concat.axis
    target = self.chunks

    # Redistributing data across concatenated inputs is too complex; only
    # push through when the concat-axis chunking stays exactly the same.
    if target[axis] != concat.chunks[axis]:
        return None

    rechunked_arrays = []
    for arr in concat.args:
        arr_chunks = list(target)
        # Along the concat axis each input keeps its own chunks.
        arr_chunks[axis] = arr.chunks[axis]
        # ``target`` is already balanced, so balance=False keeps it as-is.
        rechunked_arrays.append(
            arr.rechunk(
                tuple(arr_chunks),
                self.threshold,
                self.block_size_limit,
                False,
                self.method,
            )
        )

    return type(concat)(
        rechunked_arrays[0],
        axis,
        concat._meta,
        *rechunked_arrays[1:],
    )
|
|
736
|
+
|
|
737
|
+
def _lower(self):
    """Lower this logical rechunk to a concrete tasks- or p2p-based expression.

    The input is returned unchanged when its chunks already equal the target
    and no rebalancing was requested. Otherwise ``self.method`` is used if
    set; if not, ``_choose_rechunk_method`` picks one from the old and new
    chunks and ``self.threshold``. ``"p2p"`` yields a ``P2PRechunk`` and any
    other method a ``TasksRechunk``.
    """
    source = self.array
    if not self.balance and (self.chunks == source.chunks):
        return source

    chosen = self.method or _choose_rechunk_method(source.chunks, self.chunks, threshold=self.threshold)
    if chosen != "p2p":
        return TasksRechunk(source, self.chunks, self.threshold, self.block_size_limit)
    return P2PRechunk(
        source,
        self.chunks,
        self.threshold,
        self.block_size_limit,
        self.balance,
    )
|
|
752
|
+
|
|
753
|
+
|
|
754
|
+
class TasksRechunk(Rechunk):
    """Rechunk lowered to an explicit task graph of slices and concatenations.

    The transition is planned by ``plan_rechunk`` and may pass through
    several intermediate chunkings. Each stage's graph is produced by
    ``_compute_rechunk``, and the stages are merged in order.
    """

    _parameters = ["array", "_chunks", "threshold", "block_size_limit"]

    @cached_property
    def chunks(self):
        # The ``_chunks`` operand is returned verbatim, without re-normalizing.
        return self.operand("_chunks")

    def _simplify_down(self):
        # Already a lowered form: the parent's rewrite rules must not run.
        return None

    def _lower(self):
        # Nothing further to lower to.
        return None

    def _layer(self):
        stages = plan_rechunk(
            self.array.chunks,
            self.chunks,
            self.array.dtype.itemsize,
            self.threshold,
            self.block_size_limit,
        )
        prev_name = self.array.name
        prev_chunks = self.array.chunks
        graph = {}
        # Stages are numbered from the end, so the final stage gets level 0
        # and its keys use this expression's own name.
        levels = range(len(stages) - 1, -1, -1)
        for level, stage_chunks in zip(levels, stages):
            prev_name, prev_chunks, stage_graph = _compute_rechunk(
                prev_name, prev_chunks, stage_chunks, level, self.name
            )
            graph.update(stage_graph)
        return graph
|
|
785
|
+
|
|
786
|
+
|
|
787
|
+
def _convert_to_task_refs(obj):
    """Turn a possibly nested list of graph keys into task-spec references.

    Tuples (graph keys such as ``(name, i, j)``) become ``TaskRef`` objects.
    Lists become ``List`` containers whose items are converted recursively.
    Any other value is returned unchanged.
    """
    if isinstance(obj, tuple):
        return TaskRef(obj)
    if not isinstance(obj, list):
        return obj
    converted = [_convert_to_task_refs(entry) for entry in obj]
    return List(*converted)
|
|
796
|
+
|
|
797
|
+
|
|
798
|
+
def _compute_rechunk(old_name, old_chunks, chunks, level, name):
    """Build the task graph for one rechunk stage from ``old_chunks`` to ``chunks``.

    Each new block is built from the old blocks it overlaps. An old block
    that is fully covered is referenced directly; a partially covered one is
    sliced by an intermediate ``getitem`` task. A new block with a single
    source is an ``Alias`` of it; otherwise it is a ``concatenate3`` task
    over its sources.

    Parameters
    ----------
    old_name : str
        Key prefix of the blocks being rechunked. Input keys are
        ``(old_name, *block_index)``.
    old_chunks : sequence of sequences of int
        Per-axis chunk sizes of the input blocks.
    chunks : sequence of sequences of int
        Per-axis chunk sizes to produce in this stage.
    level : int
        When non-zero, ``-{level}-`` is inserted into the generated key
        names. The caller in this module passes 0 for the last stage, whose
        merge keys then use ``name`` unchanged.
    name : str
        Name of the overall rechunk. Every ``"rechunk-merge-"`` substring is
        rewritten to derive the merge and split key prefixes. Presumably
        ``name`` starts with that prefix; otherwise the replacements are
        no-ops.

    Returns
    -------
    tuple
        ``(merge_name, chunks, graph)``. ``graph`` maps the output keys
        ``(merge_name, *new_block_index)`` and the intermediate split keys
        ``(split_name, n)`` to tasks.
    """
    ndim = len(old_chunks)
    # Presumably one entry per new block, each listing the overlapping old
    # pieces as per-axis (old_block_index, slice) pairs (see the unpacking
    # below). NOTE(review): the zip with ``new_index`` below assumes the same
    # C order as itertools.product; confirm against intersect_chunks.
    crossed = intersect_chunks(old_chunks, chunks)
    x2 = {}  # output (merge) key -> task
    intermediates = {}  # split key -> getitem task slicing an old block

    if level != 0:
        merge_name = name.replace("rechunk-merge-", f"rechunk-merge-{level}-")
        split_name = name.replace("rechunk-merge-", f"rechunk-split-{level}-")
    else:
        # No-op replace: at level 0 the merge keys use ``name`` itself.
        merge_name = name.replace("rechunk-merge-", "rechunk-merge-")
        split_name = name.replace("rechunk-merge-", "rechunk-split-")
    split_name_suffixes = itertools.count()

    # Pre-allocate old block references: old_blocks[idx] == (old_name, *idx).
    old_blocks = np.empty([len(c) for c in old_chunks], dtype="O")
    for index in np.ndindex(old_blocks.shape):
        old_blocks[index] = (old_name,) + index

    # Iterate over all new blocks
    new_index = itertools.product(*(range(len(c)) for c in chunks))

    for new_idx, cross1 in zip(new_index, crossed):
        key = (merge_name,) + new_idx
        # Number of distinct old blocks along each axis that feed this new
        # block, i.e. the shape of the sub-grid to concatenate.
        old_block_indices = [[cr[i][0] for cr in cross1] for i in range(ndim)]
        subdims1 = [len(set(old_block_indices[i])) for i in range(ndim)]

        rec_cat_arg = np.empty(subdims1, dtype="O")
        rec_cat_arg_flat = rec_cat_arg.flat

        # Iterate over the old blocks required to build the new block
        for rec_cat_index, ind_slices in enumerate(cross1):
            old_block_index, slices = zip(*ind_slices)
            # A suffix is drawn for every piece, even when no intermediate
            # task ends up being created for it.
            intermediate_name = (split_name, next(split_name_suffixes))
            # Strips old_name, so this equals old_block_index.
            old_index = old_blocks[old_block_index][1:]
            if all(
                slc.start == 0 and slc.stop == old_chunks[i][ind] for i, (slc, ind) in enumerate(zip(slices, old_index))
            ):
                # No slicing needed - use old block directly
                rec_cat_arg_flat[rec_cat_index] = old_blocks[old_block_index]
            else:
                # Need to slice the old block
                intermediates[intermediate_name] = Task(
                    intermediate_name,
                    operator.getitem,
                    TaskRef(old_blocks[old_block_index]),
                    slices,
                )
                rec_cat_arg_flat[rec_cat_index] = intermediate_name

        # Sanity check: the pieces exactly filled the sub-grid.
        assert rec_cat_index == rec_cat_arg.size - 1

        # New block is formed by concatenation of sliced old blocks
        if all(d == 1 for d in rec_cat_arg.shape):
            # Single source block - alias to it
            source_key = rec_cat_arg.flat[0]
            x2[key] = Alias(key, source_key)
        else:
            # Multiple source blocks - concatenate
            x2[key] = Task(key, concatenate3, _convert_to_task_refs(rec_cat_arg.tolist()))

    # Drop the key array and the exhausted iterator before building the result.
    del old_blocks, new_index

    return merge_name, chunks, {**x2, **intermediates}
|
|
863
|
+
|
|
864
|
+
|
|
865
|
+
class P2PRechunk(ArrayExpr):
    """Rechunk through distributed's peer-to-peer (P2P) shuffle.

    When the chunks returned by distributed's ``_calculate_prechunking``
    differ from the input's chunks, the input is first task-rechunked to
    them. The old-to-new block mapping is then split into independent
    partials:

    * a partial that yields a single output block is built with
      ``partial_concatenate``;
    * a partial that yields several output blocks uses ``partial_rechunk``
      (the P2P shuffle).
    """

    _parameters = ["array", "_chunks", "threshold", "block_size_limit", "balance"]
    _defaults = {
        "threshold": None,
        "block_size_limit": None,
        "balance": False,
    }

    @property
    def _meta(self):
        # Same meta as the input array.
        return self.array._meta

    @property
    def _name(self):
        return "rechunk-p2p-" + tokenize(*self.operands)

    @cached_property
    def chunks(self):
        # Target chunks, taken verbatim from the ``_chunks`` operand.
        return self.operand("_chunks")

    @cached_property
    def _prechunked_chunks(self):
        """Chunks the input must have before the P2P shuffle starts."""
        from distributed.shuffle._rechunk import _calculate_prechunking

        return _calculate_prechunking(
            self.array.chunks,
            self.chunks,
            self.array.dtype,
            self.block_size_limit,
        )

    @cached_property
    def _prechunked_array(self):
        """The input, task-rechunked to ``_prechunked_chunks`` only when they differ."""
        prechunked = self._prechunked_chunks
        if prechunked == self.array.chunks:
            return self.array
        return TasksRechunk(
            self.array,
            prechunked,
            self.threshold,
            self.block_size_limit,
        )

    def _simplify_down(self):
        # Lowered form: no further simplification applies.
        return None

    def _lower(self):
        # Nothing further to lower to.
        return None

    def _layer(self):
        from distributed.shuffle._rechunk import (
            _split_partials,
            partial_concatenate,
            partial_rechunk,
        )

        import dask

        input_name = self._prechunked_array.name
        input_chunks = self._prechunked_chunks
        chunks = self.chunks
        token = tokenize(*self.operands)
        disk = dask.config.get("distributed.p2p.storage.disk")
        mapping = old_to_new(input_chunks, chunks)

        # Culling does not happen at the expression level: keep every output block.
        keepmap = np.ones(tuple(len(axis_chunks) for axis_chunks in chunks), dtype=bool)

        dsk = {}
        for ndpartial in _split_partials(mapping):
            kept = np.sum(keepmap[ndpartial.new])
            if kept == 0:
                continue
            if kept == 1:
                # One output block: concatenating its input pieces is enough.
                layer = partial_concatenate(
                    input_name=input_name,
                    input_chunks=input_chunks,
                    ndpartial=ndpartial,
                    token=token,
                    keepmap=keepmap,
                    old_to_new=mapping,
                )
            else:
                # Several output blocks: exchange the data with a P2P shuffle.
                layer = partial_rechunk(
                    input_name=input_name,
                    input_chunks=input_chunks,
                    chunks=chunks,
                    ndpartial=ndpartial,
                    token=token,
                    disk=disk,
                    keepmap=keepmap,
                )
            dsk.update(layer)
        return dsk

    def dependencies(self):
        return [self._prechunked_array]
|
|
975
|
+
|
|
976
|
+
|
|
977
|
+
def rechunk(
    x,
    chunks="auto",
    threshold=None,
    block_size_limit=None,
    balance=False,
    method=None,
):
    """Return a version of ``x`` whose blocks follow new chunk sizes.

    Parameters
    ----------
    x : dask array
        Array to rechunk.
    chunks : int, tuple, dict or str, optional
        Target block dimensions. ``-1`` means a single chunk spanning the
        whole axis. The default ``"auto"`` lets dask choose chunk sizes.
    threshold : int, optional
        Graph growth factor below which no intermediate rechunk step is
        introduced.
    block_size_limit : int, optional
        Largest block size, in bytes, to produce.
        Defaults to the configuration value ``array.chunk-size``.
    balance : bool, default False
        If True, try to give every chunk the same size.

        Small leftover chunks are therefore removed, so
        ``x.rechunk(chunks=len(x) // N, balance=True)``
        will almost certainly result in ``N`` chunks.
    method : {'tasks', 'p2p'}, optional
        Rechunking algorithm. When omitted, the configuration value
        ``array.rechunk.method`` is read at call time.

    Examples
    --------
    >>> import dask_array as da
    >>> x = da.ones((1000, 1000), chunks=(100, 100))

    A tuple gives uniform chunk sizes per dimension

    >>> y = x.rechunk((1000, 10))

    A dictionary rechunks only the listed dimensions

    >>> y = x.rechunk({0: 1000})

    ``-1`` requests a single chunk along a dimension, and ``"auto"`` lets dask
    freely rechunk a dimension to reach blocks of a uniform size

    >>> y = x.rechunk({0: -1, 1: 'auto'}, block_size_limit=1e8)

    A chunk size that does not divide the dimension leaves the remainder in
    the last chunk.

    >>> x.rechunk(chunks=(400, -1)).chunks
    ((400, 400, 200), (1000,))

    With ``balance=True`` Dask may pick a different chunk size to keep the
    chunks even.

    >>> x.rechunk(chunks=(400, -1), balance=True).chunks
    ((500, 500), (1000,))
    """
    import dask
    from dask_array._new_collection import new_collection

    # Resolve the configured default now, at creation time, rather than
    # during lowering.
    resolved_method = dask.config.get("array.rechunk.method", None) if method is None else method

    expr = x.expr.rechunk(chunks, threshold, block_size_limit, balance, resolved_method)
    return new_collection(expr)
|