dask-array 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (144)
  1. dask_array/__init__.py +228 -0
  2. dask_array/_backends.py +76 -0
  3. dask_array/_backends_array.py +99 -0
  4. dask_array/_blockwise.py +1410 -0
  5. dask_array/_broadcast.py +272 -0
  6. dask_array/_chunk.py +445 -0
  7. dask_array/_chunk_types.py +54 -0
  8. dask_array/_collection.py +1644 -0
  9. dask_array/_concatenate.py +331 -0
  10. dask_array/_core_utils.py +1365 -0
  11. dask_array/_dispatch.py +141 -0
  12. dask_array/_einsum.py +277 -0
  13. dask_array/_expr.py +544 -0
  14. dask_array/_expr_flow.py +586 -0
  15. dask_array/_gufunc.py +805 -0
  16. dask_array/_histogram.py +617 -0
  17. dask_array/_map_blocks.py +652 -0
  18. dask_array/_new_collection.py +10 -0
  19. dask_array/_numpy_compat.py +135 -0
  20. dask_array/_overlap.py +1159 -0
  21. dask_array/_rechunk.py +1050 -0
  22. dask_array/_reshape.py +710 -0
  23. dask_array/_routines.py +102 -0
  24. dask_array/_shuffle.py +448 -0
  25. dask_array/_stack.py +264 -0
  26. dask_array/_svg.py +291 -0
  27. dask_array/_templates.py +29 -0
  28. dask_array/_test_utils.py +257 -0
  29. dask_array/_ufunc.py +385 -0
  30. dask_array/_utils.py +349 -0
  31. dask_array/_visualize.py +223 -0
  32. dask_array/_xarray.py +337 -0
  33. dask_array/core/__init__.py +34 -0
  34. dask_array/core/_blockwise_funcs.py +312 -0
  35. dask_array/core/_conversion.py +422 -0
  36. dask_array/core/_from_graph.py +97 -0
  37. dask_array/creation/__init__.py +71 -0
  38. dask_array/creation/_arange.py +121 -0
  39. dask_array/creation/_diag.py +116 -0
  40. dask_array/creation/_diagonal.py +241 -0
  41. dask_array/creation/_eye.py +103 -0
  42. dask_array/creation/_linspace.py +102 -0
  43. dask_array/creation/_mesh.py +134 -0
  44. dask_array/creation/_ones_zeros.py +454 -0
  45. dask_array/creation/_pad.py +270 -0
  46. dask_array/creation/_repeat.py +55 -0
  47. dask_array/creation/_tile.py +36 -0
  48. dask_array/creation/_tri.py +28 -0
  49. dask_array/creation/_utils.py +296 -0
  50. dask_array/fft.py +320 -0
  51. dask_array/io/__init__.py +39 -0
  52. dask_array/io/_base.py +10 -0
  53. dask_array/io/_from_array.py +257 -0
  54. dask_array/io/_from_delayed.py +95 -0
  55. dask_array/io/_from_graph.py +54 -0
  56. dask_array/io/_from_npy_stack.py +67 -0
  57. dask_array/io/_store.py +336 -0
  58. dask_array/io/_tiledb.py +159 -0
  59. dask_array/io/_to_npy_stack.py +65 -0
  60. dask_array/io/_zarr.py +449 -0
  61. dask_array/linalg/__init__.py +39 -0
  62. dask_array/linalg/_cholesky.py +234 -0
  63. dask_array/linalg/_lu.py +300 -0
  64. dask_array/linalg/_norm.py +94 -0
  65. dask_array/linalg/_qr.py +601 -0
  66. dask_array/linalg/_solve.py +349 -0
  67. dask_array/linalg/_svd.py +394 -0
  68. dask_array/linalg/_tensordot.py +334 -0
  69. dask_array/linalg/_utils.py +74 -0
  70. dask_array/manipulation/__init__.py +45 -0
  71. dask_array/manipulation/_expand.py +321 -0
  72. dask_array/manipulation/_flip.py +92 -0
  73. dask_array/manipulation/_roll.py +78 -0
  74. dask_array/manipulation/_transpose.py +309 -0
  75. dask_array/random/__init__.py +125 -0
  76. dask_array/random/_choice.py +181 -0
  77. dask_array/random/_expr.py +256 -0
  78. dask_array/random/_generator.py +441 -0
  79. dask_array/random/_random_state.py +259 -0
  80. dask_array/random/_utils.py +84 -0
  81. dask_array/reductions/__init__.py +84 -0
  82. dask_array/reductions/_arg_reduction.py +130 -0
  83. dask_array/reductions/_common.py +1082 -0
  84. dask_array/reductions/_cumulative.py +522 -0
  85. dask_array/reductions/_percentile.py +261 -0
  86. dask_array/reductions/_reduction.py +725 -0
  87. dask_array/reductions/_trace.py +56 -0
  88. dask_array/routines/__init__.py +133 -0
  89. dask_array/routines/_apply.py +84 -0
  90. dask_array/routines/_bincount.py +112 -0
  91. dask_array/routines/_broadcast.py +111 -0
  92. dask_array/routines/_coarsen.py +115 -0
  93. dask_array/routines/_diff.py +79 -0
  94. dask_array/routines/_gradient.py +158 -0
  95. dask_array/routines/_indexing.py +65 -0
  96. dask_array/routines/_insert_delete.py +132 -0
  97. dask_array/routines/_misc.py +122 -0
  98. dask_array/routines/_nonzero.py +72 -0
  99. dask_array/routines/_search.py +123 -0
  100. dask_array/routines/_select.py +113 -0
  101. dask_array/routines/_statistics.py +171 -0
  102. dask_array/routines/_topk.py +82 -0
  103. dask_array/routines/_triangular.py +74 -0
  104. dask_array/routines/_unique.py +232 -0
  105. dask_array/routines/_where.py +62 -0
  106. dask_array/slicing/__init__.py +67 -0
  107. dask_array/slicing/_basic.py +550 -0
  108. dask_array/slicing/_blocks.py +138 -0
  109. dask_array/slicing/_bool_index.py +145 -0
  110. dask_array/slicing/_setitem.py +329 -0
  111. dask_array/slicing/_squeeze.py +101 -0
  112. dask_array/slicing/_utils.py +1133 -0
  113. dask_array/slicing/_vindex.py +282 -0
  114. dask_array/stacking/__init__.py +15 -0
  115. dask_array/stacking/_block.py +83 -0
  116. dask_array/stacking/_simple.py +58 -0
  117. dask_array/templates/array.html.j2 +48 -0
  118. dask_array/tests/__init__.py +0 -0
  119. dask_array/tests/conftest.py +22 -0
  120. dask_array/tests/test_api.py +40 -0
  121. dask_array/tests/test_binary_op_chunks.py +107 -0
  122. dask_array/tests/test_coarse_slice_through_blockwise.py +362 -0
  123. dask_array/tests/test_collection.py +799 -0
  124. dask_array/tests/test_creation.py +1102 -0
  125. dask_array/tests/test_expr_flow.py +143 -0
  126. dask_array/tests/test_linalg.py +1130 -0
  127. dask_array/tests/test_map_blocks_multi_output.py +104 -0
  128. dask_array/tests/test_rechunk_pushdown.py +214 -0
  129. dask_array/tests/test_reductions.py +1091 -0
  130. dask_array/tests/test_routines.py +2853 -0
  131. dask_array/tests/test_shuffle_chunks.py +67 -0
  132. dask_array/tests/test_slice_pushdown.py +968 -0
  133. dask_array/tests/test_slice_through_blockwise.py +678 -0
  134. dask_array/tests/test_slice_through_overlap.py +366 -0
  135. dask_array/tests/test_slice_through_reshape.py +272 -0
  136. dask_array/tests/test_slicing.py +839 -0
  137. dask_array/tests/test_transpose_slice_pushdown.py +208 -0
  138. dask_array/tests/test_visualize.py +94 -0
  139. dask_array/tests/test_xarray.py +193 -0
  140. dask_array-0.1.0.dist-info/METADATA +48 -0
  141. dask_array-0.1.0.dist-info/RECORD +144 -0
  142. dask_array-0.1.0.dist-info/WHEEL +4 -0
  143. dask_array-0.1.0.dist-info/entry_points.txt +2 -0
  144. dask_array-0.1.0.dist-info/licenses/LICENSE +29 -0
dask_array/_rechunk.py ADDED
@@ -0,0 +1,1050 @@
1
+ from __future__ import annotations
2
+
3
+ import heapq
4
+ import itertools
5
+ import math
6
+ import operator
7
+ from functools import reduce
8
+ from itertools import chain, product
9
+ from operator import add, itemgetter, mul
10
+ from warnings import warn
11
+
12
+ import numpy as np
13
+ import toolz
14
+ from tlz import accumulate
15
+
16
+ from dask import config
17
+ from dask._task_spec import Alias, List, Task, TaskRef
18
+ from dask.base import tokenize
19
+ from dask.utils import cached_property, parse_bytes
20
+
21
+ from dask_array._expr import ArrayExpr
22
+ from dask_array._core_utils import concatenate3, normalize_chunks
23
+ from dask_array._utils import validate_axis
24
+
25
+
26
+ # ============================================================================
27
+ # Rechunk planning utilities (copied from dask.array.rechunk)
28
+ # ============================================================================
29
+
30
+
31
def cumdims_label(chunks, const):
    """Label the cumulative block boundaries of every dimension.

    For each dimension in *chunks* the running boundaries ``0, c0, c0 + c1,
    ...`` are computed and each one is paired with *const*.

    >>> cumdims_label(((5, 3, 3), (2, 2, 1)), 'n')  # doctest: +NORMALIZE_WHITESPACE
    [(('n', 0), ('n', 5), ('n', 8), ('n', 11)),
     (('n', 0), ('n', 2), ('n', 4), ('n', 5))]
    """
    labelled_dims = []
    for block_sizes in chunks:
        # Prepending 0 makes the first boundary the start of the axis.
        boundaries = itertools.accumulate((0,) + block_sizes)
        labelled_dims.append(tuple((const, edge) for edge in boundaries))
    return labelled_dims
39
+
40
+
41
+ def _breakpoints(cumold, cumnew):
42
+ """
43
+ >>> new = cumdims_label(((2, 3), (2, 2, 1)), 'n')
44
+ >>> old = cumdims_label(((2, 2, 1), (5,)), 'o')
45
+
46
+ >>> _breakpoints(new[0], old[0])
47
+ (('n', 0), ('o', 0), ('n', 2), ('o', 2), ('o', 4), ('n', 5), ('o', 5))
48
+ >>> _breakpoints(new[1], old[1])
49
+ (('n', 0), ('o', 0), ('n', 2), ('n', 4), ('n', 5), ('o', 5))
50
+ """
51
+ return tuple(sorted(cumold + cumnew, key=itemgetter(1)))
52
+
53
+
54
def _intersect_1d(breaks):
    """
    Internal utility to intersect chunks for 1d after preprocessing.

    >>> new = cumdims_label(((2, 3), (2, 2, 1)), 'n')
    >>> old = cumdims_label(((2, 2, 1), (5,)), 'o')

    >>> _intersect_1d(_breakpoints(old[0], new[0]))  # doctest: +NORMALIZE_WHITESPACE
    [[(0, slice(0, 2, None))],
     [(1, slice(0, 2, None)), (2, slice(0, 1, None))]]
    >>> _intersect_1d(_breakpoints(old[1], new[1]))  # doctest: +NORMALIZE_WHITESPACE
    [[(0, slice(0, 2, None))],
     [(0, slice(2, 4, None))],
     [(0, slice(4, 5, None))]]

    Parameters
    ----------
    breaks: list of tuples
        Each tuple is ('o', 8) or ('n', 8): a label saying whether the
        boundary belongs to the old ('o') or the new ('n') chunking, paired
        with a cumulative sum, i.e. a breakpoint (a position along the
        chunking axis). The list of pairs is already ordered by breakpoint.
        An 'o' pair must occur BEFORE an 'n' pair when both share the same
        breakpoint, as produced by ``_breakpoints(old, new)``.

    Returns
    -------
    list of lists
        One list per new chunk, holding the ``(old_chunk_index, slice)``
        pieces whose concatenation forms that new chunk.
    """
    # Old-chunk boundaries, used to locate the last old chunk.
    o_pairs = [pair for pair in breaks if pair[0] == "o"]
    # There is one more boundary than chunks, hence "- 2" for the last index.
    last_old_chunk_idx = len(o_pairs) - 2
    # Position of the final old boundary (the end of the axis).
    last_o_br = o_pairs[-1][1]

    start = 0  # start of the current slice within the current old chunk
    last_end = 0  # end of the slice computed in the previous step
    old_idx = 0  # index of the old chunk currently being sliced
    last_o_end = 0  # slice end recorded at the most recent 'o' boundary
    ret = []  # one list of pieces per completed new chunk
    ret_next = []  # pieces of the new chunk currently being assembled
    for idx in range(1, len(breaks)):
        label, br = breaks[idx]
        last_label, last_br = breaks[idx - 1]
        if last_label == "n":
            # The previous boundary closed a new chunk: keep slicing the
            # current old chunk from where the last piece ended, and flush
            # the pieces collected for the finished new chunk.
            start = last_end
            if ret_next:
                ret.append(ret_next)
                ret_next = []
        else:
            # The previous boundary began a fresh old chunk.
            start = 0
        end = br - last_br + start
        last_end = end
        if br == last_br:
            # Zero-width interval between two boundaries at the same position.
            if label == "o":
                old_idx += 1
                last_o_end = end
            if label == "n" and last_label == "n":
                if br == last_o_br:
                    # Zero-size new chunk at the end of the axis: take an
                    # empty slice of the last old chunk.
                    slc = slice(last_o_end, last_o_end)
                    ret_next.append((last_old_chunk_idx, slc))
                    continue
                # Otherwise fall through: a zero-size new chunk inside the
                # axis gets an empty slice of the current old chunk below.
            else:
                continue
        ret_next.append((old_idx, slice(start, end)))
        if label == "o":
            # Crossed an old boundary: subsequent pieces come from the next
            # old chunk.
            old_idx += 1
            start = 0
            last_o_end = end

    if ret_next:
        ret.append(ret_next)

    return ret
125
+
126
+
127
def old_to_new(old_chunks, new_chunks):
    """Map every new chunk to the pieces of old chunks it is built from.

    For each dimension the result holds one list per new chunk, and each of
    those lists contains ``(old_chunk_index, slice)`` pairs.

    Handles missing values, as long as the dimension with the missing chunk
    values is unchanged: such a dimension is mapped one-to-one onto its old
    chunks (with an open-ended slice for NaN-sized chunks).

    Examples
    --------
    >>> old = ((10, 10, 10, 10, 10), )
    >>> new = ((25, 5, 20), )
    >>> old_to_new(old, new)  # doctest: +NORMALIZE_WHITESPACE
    [[[(0, slice(0, 10, None)), (1, slice(0, 10, None)), (2, slice(0, 5, None))],
      [(2, slice(5, 10, None))],
      [(3, slice(0, 10, None)), (4, slice(0, 10, None))]]]
    """
    unknown_flags = [any(math.isnan(size) for size in dim) for dim in old_chunks]
    known_axes = [axis for axis, unknown in enumerate(unknown_flags) if not unknown]

    labelled_old = cumdims_label([old_chunks[axis] for axis in known_axes], "o")
    labelled_new = cumdims_label([new_chunks[axis] for axis in known_axes], "n")

    sliced = [None] * len(old_chunks)
    for axis, old_edges, new_edges in zip(known_axes, labelled_old, labelled_new):
        sliced[axis] = _intersect_1d(_breakpoints(old_edges, new_edges))

    for axis, unknown in enumerate(unknown_flags):
        if not unknown:
            continue
        # Unknown-size dimension: each new chunk is exactly one old chunk.
        sliced[axis] = [
            [(index, slice(0, None if math.isnan(size) else size))]
            for index, size in enumerate(old_chunks[axis])
        ]

    assert all(entry is not None for entry in sliced)
    return sliced
172
+
173
+
174
def intersect_chunks(old_chunks, new_chunks):
    """
    Make dask.array slices as intersection of old and new chunks.

    Yields, for every new block (in C order), a tuple of the old-block
    pieces that compose it; each piece is a per-dimension tuple of
    ``(old_chunk_index, slice)`` pairs.

    >>> intersections = intersect_chunks(((4, 4), (2,)),
    ...                                  ((8,), (1, 1)))
    >>> list(intersections)  # doctest: +NORMALIZE_WHITESPACE
    [(((0, slice(0, 4, None)), (0, slice(0, 1, None))),
      ((1, slice(0, 4, None)), (0, slice(0, 1, None)))),
     (((0, slice(0, 4, None)), (0, slice(1, 2, None))),
      ((1, slice(0, 4, None)), (0, slice(1, 2, None))))]

    Parameters
    ----------
    old_chunks : iterable of tuples
        block sizes along each dimension (convert from old_chunks)
    new_chunks: iterable of tuples
        block sizes along each dimension (converts to new_chunks)
    """
    axis_pieces = old_to_new(old_chunks, new_chunks)
    blocks = (tuple(product(*pieces)) for pieces in product(*axis_pieces))
    # ``chain`` over a single iterable just re-yields its items unchanged.
    return chain(blocks)
196
+
197
+
198
+ def _validate_rechunk(old_chunks, new_chunks):
199
+ """Validates that rechunking an array from ``old_chunks`` to ``new_chunks``
200
+ is possible, raises an error if otherwise.
201
+ """
202
+ assert len(old_chunks) == len(new_chunks)
203
+
204
+ old_shapes = tuple(map(sum, old_chunks))
205
+ new_shapes = tuple(map(sum, new_chunks))
206
+
207
+ for old_shape, old_dim, new_shape, new_dim in zip(old_shapes, old_chunks, new_shapes, new_chunks):
208
+ if old_shape != new_shape:
209
+ if not (math.isnan(old_shape) and math.isnan(new_shape)) or not np.array_equal(
210
+ old_dim, new_dim, equal_nan=True
211
+ ):
212
+ raise ValueError(
213
+ "Chunks must be unchanging along dimensions with missing values.\n\n"
214
+ "A possible solution:\n x.compute_chunk_sizes()"
215
+ )
216
+
217
+
218
+ def _number_of_blocks(chunks):
219
+ return reduce(mul, map(len, chunks))
220
+
221
+
222
+ def _largest_block_size(chunks):
223
+ return reduce(mul, map(max, chunks))
224
+
225
+
226
def estimate_graph_size(old_chunks, new_chunks):
    """Estimate the graph size during a rechunk computation.

    An unchanged dimension contributes its number of chunks. A changed one
    contributes ``len(old) + len(new) - 1``, the largest number of intervals
    the merged old/new boundaries can form. The estimate is the product over
    all dimensions.
    """

    def pieces(old_dim, new_dim):
        if old_dim == new_dim:
            return len(old_dim)
        return len(old_dim) + len(new_dim) - 1

    return reduce(mul, (pieces(old_dim, new_dim) for old_dim, new_dim in zip(old_chunks, new_chunks)))
233
+
234
+
235
def divide_to_width(desired_chunks, max_width):
    """Minimally divide the given chunks so as to make the largest chunk
    width less or equal than *max_width*.

    Each chunk is split into ``ceil(width / max_width)`` near-equal parts
    (earlier parts are never larger than later ones).
    """
    result = []
    for width in desired_chunks:
        parts = int(np.ceil(width / max_width))
        remaining = width
        # Divide what is left by the number of parts still to emit.
        for parts_left in range(parts, 0, -1):
            piece = remaining // parts_left
            result.append(piece)
            remaining -= piece
        assert remaining == 0
    return tuple(result)
248
+
249
+
250
def merge_to_number(desired_chunks, max_number):
    """Minimally merge the given chunks so as to bring the number of
    chunks to at most *max_number*, while minimizing the largest width.

    If *desired_chunks* already has at most *max_number* entries it is
    returned unchanged. Otherwise zero-width chunks are dropped first: they
    contribute nothing to the total, and the merge loop below uses ``0`` to
    mark slots that were merged away, so a genuine zero would trip its
    ``chunks[i] != 0`` assertion or walk ``j`` past the end of the list.
    """
    if len(desired_chunks) <= max_number:
        return desired_chunks

    nonzero = tuple(c for c in desired_chunks if c)
    if len(nonzero) <= max_number:
        # Dropping empty chunks was enough. An all-empty dimension still
        # needs one (empty) chunk rather than ``()``.
        return nonzero or (0,)
    desired_chunks = nonzero

    distinct = set(desired_chunks)
    if len(distinct) == 1:
        # Uniform widths: compute the result directly, using widths that are
        # multiples of the original width and spreading the remainder over
        # the first ``adjust`` chunks.
        w = distinct.pop()
        n = len(desired_chunks)
        total = n * w

        desired_width = total // max_number
        width = w * (desired_width // w)
        adjust = (total - max_number * width) // w

        return (width + w,) * adjust + (width,) * (max_number - adjust)

    nmerges = len(desired_chunks) - max_number

    # Min-heap of candidate merges of neighbouring slots, keyed by the
    # resulting width.
    heap = [(desired_chunks[i] + desired_chunks[i + 1], i, i + 1) for i in range(len(desired_chunks) - 1)]
    heapq.heapify(heap)

    chunks = list(desired_chunks)

    while nmerges > 0:
        # Take the narrowest candidate merge.
        width, i, j = heapq.heappop(heap)
        if chunks[j] == 0:
            # The right neighbour was merged away: pair slot i with the next
            # live slot and retry.
            j += 1
            while chunks[j] == 0:
                j += 1
            heapq.heappush(heap, (chunks[i] + chunks[j], i, j))
            continue
        elif chunks[i] + chunks[j] != width:
            # Stale entry: one side has grown since it was pushed.
            heapq.heappush(heap, (chunks[i] + chunks[j], i, j))
            continue
        # Merge slot i into slot j; 0 marks slot i as removed.
        assert chunks[i] != 0
        chunks[i] = 0
        chunks[j] = width
        nmerges -= 1

    return tuple(filter(None, chunks))
294
+
295
+
296
def find_merge_rechunk(old_chunks, new_chunks, block_size_limit):
    """
    Find an intermediate rechunk that would merge some adjacent blocks
    together in order to get us nearer the *new_chunks* target, without
    violating the *block_size_limit* (in number of elements).

    Returns
    -------
    chunks : tuple
        The intermediate chunking.
    memory_limit_hit : bool
        True if some candidate dimension could not be switched to its target
        chunks without exceeding *block_size_limit*.
    """
    ndim = len(old_chunks)

    old_widths = [max(dim) for dim in old_chunks]
    new_widths = [max(dim) for dim in new_chunks]

    # Per dimension: growth in block count and growth in largest block width
    # if only that dimension were switched to its target chunks.
    count_ratio = {axis: len(new_dim) / len(old_dim) for axis, (old_dim, new_dim) in enumerate(zip(old_chunks, new_chunks))}
    width_ratio = {axis: new_widths[axis] / (old_widths[axis] or 1) for axis in range(ndim)}

    def merge_priority(axis):
        growth = width_ratio[axis]
        if growth == 1:
            # Avoid dividing by log(1) == 0.
            growth = 1 + 1e-9
        if growth > 0:
            return np.log(count_ratio[axis]) / np.log(growth)
        return 0

    # Only dimensions whose block count does not grow are merge candidates.
    candidates = sorted((axis for axis in range(ndim) if count_ratio[axis] <= 1.0), key=merge_priority)

    largest_block_size = reduce(mul, old_widths)

    chunks = list(old_chunks)
    memory_limit_hit = False

    for axis in candidates:
        grown = largest_block_size * new_widths[axis] // (old_widths[axis] or 1)
        if grown <= block_size_limit:
            # The full target chunking of this axis fits within the limit.
            chunks[axis] = new_chunks[axis]
            largest_block_size = grown
            continue

        # Otherwise merge only as far as the limit allows.
        old_width = old_widths[axis]
        width_limit = int(block_size_limit * old_width / largest_block_size)
        partial = divide_to_width(new_chunks[axis], width_limit)
        if len(partial) <= len(old_chunks[axis]):
            # Worth it only if the number of blocks does not increase.
            chunks[axis] = partial
            largest_block_size = largest_block_size * max(partial) // old_width

        memory_limit_hit = True

    assert largest_block_size == _largest_block_size(chunks)
    assert largest_block_size <= block_size_limit
    return tuple(chunks), memory_limit_hit
345
+
346
+
347
def find_split_rechunk(old_chunks, new_chunks, graph_size_limit):
    """
    Find an intermediate rechunk that would split some chunks to
    get us nearer *new_chunks*, without violating the *graph_size_limit*.

    Dimensions are visited in order; the search stops as soon as the
    estimated graph size exceeds *graph_size_limit*.
    """
    chunks = list(old_chunks)

    for axis in range(len(old_chunks)):
        graph_size = estimate_graph_size(chunks, new_chunks)
        if graph_size > graph_size_limit:
            break
        old_dim = old_chunks[axis]
        new_dim = new_chunks[axis]
        if len(old_dim) > len(new_dim):
            # Fewer target chunks: this axis is not a split.
            continue
        # Merge the target chunks down to what the graph budget allows.
        max_number = int(len(old_dim) * graph_size_limit / graph_size)
        merged = merge_to_number(new_dim, max_number)
        assert len(merged) <= max_number
        # Accept only a genuine split: no fewer blocks and no wider ones.
        if len(merged) >= len(old_dim) and max(merged) <= max(old_dim):
            chunks[axis] = merged

    return tuple(chunks)
369
+
370
+
371
def _graph_size_threshold(old_chunks, new_chunks, threshold):
    """Return *threshold* times the combined block count of both chunkings.

    Graph size estimates below this value are considered cheap enough to
    rechunk without intermediate steps.
    """
    total_blocks = _number_of_blocks(old_chunks) + _number_of_blocks(new_chunks)
    return threshold * total_blocks
373
+
374
+
375
def plan_rechunk(old_chunks, new_chunks, itemsize, threshold=None, block_size_limit=None):
    """Plan an iterative rechunking from *old_chunks* to *new_chunks*.
    The plan aims to minimize the rechunk graph size.

    Parameters
    ----------
    itemsize: int
        The item size of the array
    threshold: int
        The graph growth factor under which we don't bother
        introducing an intermediate step
    block_size_limit: int
        The maximum block size (in bytes) we want to produce during an
        intermediate step

    Returns
    -------
    list
        The chunkings to pass through, in order; the last entry is always
        *new_chunks*.
    """
    threshold = threshold or config.get("array.rechunk.threshold")
    block_size_limit = block_size_limit or config.get("array.chunk-size")
    if isinstance(block_size_limit, str):
        block_size_limit = parse_bytes(block_size_limit)

    # Trivial arrays, or unknown sizes in the input, leave nothing to plan.
    trivial = len(new_chunks) <= 1 or not all(new_chunks)
    if trivial or any(any(math.isnan(size) for size in dim) for dim in old_chunks):
        return [new_chunks]

    # From here on the limit is a number of elements, not bytes.
    block_size_limit /= itemsize

    # Never aim below what the input or the output chunking already needs.
    largest_old_block = _largest_block_size(old_chunks)
    largest_new_block = _largest_block_size(new_chunks)
    block_size_limit = max(block_size_limit, largest_old_block, largest_new_block)

    graph_size_threshold = _graph_size_threshold(old_chunks, new_chunks, threshold)

    current_chunks = old_chunks
    steps = []
    first_pass = True

    while True:
        graph_size = estimate_graph_size(current_chunks, new_chunks)
        if graph_size < graph_size_threshold:
            break

        if first_pass:
            candidate = current_chunks
        else:
            # The previous merge pass hit the block size limit: accept a
            # larger graph in exchange for getting nearer the goal and
            # shrinking blocks so the next merge has room.
            candidate = find_split_rechunk(current_chunks, new_chunks, graph_size * threshold)
        candidate, memory_limit_hit = find_merge_rechunk(candidate, new_chunks, block_size_limit)

        if (candidate == current_chunks and not first_pass) or candidate == new_chunks:
            break
        if candidate != current_chunks:
            steps.append(candidate)
        current_chunks = candidate
        if not memory_limit_hit:
            break
        first_pass = False

    return steps + [new_chunks]
432
+
433
+
434
+ def _get_chunks(n, chunksize):
435
+ leftover = n % chunksize
436
+ n_chunks = n // chunksize
437
+
438
+ chunks = [chunksize] * n_chunks
439
+ if leftover:
440
+ chunks.append(leftover)
441
+ return tuple(chunks)
442
+
443
+
444
+ def _balance_chunksizes(chunks: tuple[int, ...]) -> tuple[int, ...]:
445
+ """
446
+ Balance the chunk sizes
447
+
448
+ Parameters
449
+ ----------
450
+ chunks : tuple[int, ...]
451
+ Chunk sizes for Dask array.
452
+
453
+ Returns
454
+ -------
455
+ new_chunks : tuple[int, ...]
456
+ New chunks for Dask array with balanced sizes.
457
+ """
458
+ median_len = np.median(chunks).astype(int)
459
+ n_chunks = len(chunks)
460
+ eps = median_len // 2
461
+ if min(chunks) <= 0.5 * max(chunks):
462
+ n_chunks -= 1
463
+
464
+ new_chunks = [_get_chunks(sum(chunks), chunk_len) for chunk_len in range(median_len - eps, median_len + eps + 1)]
465
+ possible_chunks = [c for c in new_chunks if len(c) == n_chunks]
466
+ if not len(possible_chunks):
467
+ warn("chunk size balancing not possible with given chunks. Try increasing the chunk size.")
468
+ return chunks
469
+
470
+ diffs = [max(c) - min(c) for c in possible_chunks]
471
+ best_chunk_size = np.argmin(diffs)
472
+ return possible_chunks[best_chunk_size]
473
+
474
+
475
def _choose_rechunk_method(old_chunks, new_chunks, threshold=None):
    """Pick the rechunk method to use for *old_chunks* -> *new_chunks*.

    A truthy ``array.rechunk.method`` config value is returned as is.
    Otherwise ``"tasks"`` is used when ``distributed`` cannot be imported or
    ``default_client()`` raises ``ValueError``. If a client is available,
    ``"p2p"`` is returned once the task graph size reaches the threshold
    computed by ``_graph_size_threshold``, and ``"tasks"`` below it.
    """
    configured = config.get("array.rechunk.method", None)
    if configured:
        return configured

    try:
        from distributed import default_client

        default_client()
    except (ImportError, ValueError):
        return "tasks"

    mapping = old_to_new(old_chunks, new_chunks)
    # Number of (old piece -> new chunk) transfers, multiplied over axes.
    graph_size = math.prod(sum(len(pieces) for pieces in axis) for axis in mapping)
    limit = _graph_size_threshold(old_chunks, new_chunks, threshold or config.get("array.rechunk.threshold"))
    if graph_size < limit:
        return "tasks"
    return "p2p"
490
+
491
+
492
+ # ============================================================================
493
+ # Expression classes
494
+ # ============================================================================
495
+
496
+
497
class Rechunk(ArrayExpr):
    """Expression that changes the chunking of ``array``.

    Operands
    --------
    array
        The expression being rechunked.
    _chunks
        The user's chunk specification: a tuple/list (``None`` entries keep
        the current chunks), a dict keyed by axis, or anything accepted by
        ``normalize_chunks``. :attr:`chunks` resolves it to explicit
        per-axis chunk tuples.
    threshold, block_size_limit
        Forwarded to the rechunk planner; ``block_size_limit`` is also the
        ``limit`` used when resolving automatic chunk sizes.
    balance
        If truthy, the resolved chunks are passed through
        ``_balance_chunksizes``.
    method
        ``"p2p"`` lowers to ``P2PRechunk`` and any other value to
        ``TasksRechunk``; ``None`` lets ``_choose_rechunk_method`` decide.
    """

    _parameters = [
        "array",
        "_chunks",
        "threshold",
        "block_size_limit",
        "balance",
        "method",
    ]

    _defaults = {
        "_chunks": "auto",
        "threshold": None,
        "block_size_limit": None,
        "balance": None,
        "method": None,
    }

    @property
    def _meta(self):
        # Rechunking does not change the underlying array type.
        return self.array._meta

    @property
    def _name(self):
        return "rechunk-merge-" + tokenize(*self.operands)

    @cached_property
    def chunks(self):
        """Resolved target chunks: one tuple of block sizes per axis.

        Axes that are ``None`` in a tuple/list spec, or missing or ``None``
        in a dict spec, keep the input's chunks. The result is normalized
        against the input shape, optionally balanced, and checked with
        ``_validate_rechunk``.

        Raises
        ------
        ValueError
            If the resolved chunks do not have one entry per axis, or are
            rejected by ``_validate_rechunk``.
        """
        x = self.array
        chunks = self.operand("_chunks")

        # don't rechunk if array is empty
        if x.ndim > 0 and all(s == 0 for s in x.shape):
            return x.chunks

        if isinstance(chunks, dict):
            # Normalize (possibly negative) axis keys, then keep the current
            # chunks for axes that are absent or explicitly None.
            chunks = {validate_axis(c, x.ndim): v for c, v in chunks.items()}
            for i in range(x.ndim):
                if chunks.get(i) is None:
                    chunks[i] = x.chunks[i]
        if isinstance(chunks, (tuple, list)):
            chunks = tuple(lc if lc is not None else rc for lc, rc in zip(chunks, x.chunks))
        chunks = normalize_chunks(
            chunks,
            x.shape,
            limit=self.block_size_limit,
            dtype=x.dtype,
            previous_chunks=x.chunks,
        )

        if len(chunks) != x.ndim:
            raise ValueError("Provided chunks are not consistent with shape")

        if self.balance:
            chunks = tuple(_balance_chunksizes(chunk) for chunk in chunks)

        _validate_rechunk(x.chunks, chunks)

        return chunks

    def _simplify_down(self):
        # No-op rechunk: if chunks already match, return the original array
        if not self.balance and self.chunks == self.array.chunks:
            return self.array

        from dask_array._blockwise import Elemwise
        from dask_array.manipulation._transpose import Transpose

        # Rechunk(Rechunk(x)) -> single Rechunk to final chunks.
        # Only match Rechunk, not TasksRechunk (which is already lowered).
        # Don't merge if inner has method='p2p' - preserve explicit p2p semantics.
        # The merged node targets the *resolved* outer chunks: the raw outer
        # spec is relative to the inner result (None entries, dicts with
        # missing axes, "auto"), so re-resolving it against ``x`` would fall
        # back to x's chunks instead of the inner rechunk's. Any balancing
        # requested by the outer node is already baked into ``self.chunks``,
        # so balance is not requested again.
        if type(self.array) is Rechunk and self.array.method != "p2p":
            return Rechunk(
                self.array.array,
                self.chunks,
                self.threshold,
                self.block_size_limit,
                None,
                self.method,
            )

        # Rechunk(Transpose) -> Transpose(rechunked input)
        if isinstance(self.array, Transpose):
            return self._pushdown_through_transpose()

        # Rechunk(Elemwise) -> Elemwise(rechunked inputs)
        if isinstance(self.array, Elemwise):
            return self._pushdown_through_elemwise()

        # Rechunk(Concatenate) -> Concatenate(rechunked inputs)
        # Only when the concatenation axis keeps its chunks
        from dask_array._concatenate import Concatenate

        if isinstance(self.array, Concatenate):
            return self._pushdown_through_concatenate()

        # Rechunk(IO) -> IO with new chunks (if IO supports it)
        # Skip if method='p2p' is explicitly requested - user wants distributed shuffle
        if getattr(self.array, "_can_rechunk_pushdown", False) and self.method != "p2p":
            # Keep the same name prefix - the token will change with the new chunks
            return self.array.substitute_parameters({"_chunks": self.chunks})

    def _pushdown_through_transpose(self):
        """Push rechunk through transpose by reordering the resolved chunks.

        ``axes[i]`` is the input axis that becomes output axis ``i``, so the
        target chunks of output axis ``i`` are applied to input axis
        ``axes[i]``.
        """
        from dask_array.manipulation._transpose import Transpose

        transpose = self.array
        axes = transpose.axes
        # Resolved per-axis chunks, so dict / None / "auto" specs and any
        # balancing are all handled the same way.
        chunks = self.chunks

        new_chunks = [None] * len(axes)
        for i, ax in enumerate(axes):
            new_chunks[ax] = chunks[i]

        rechunked_input = transpose.array.rechunk(tuple(new_chunks))
        return Transpose(rechunked_input, axes)

    def _pushdown_through_elemwise(self):
        """Push rechunk through elemwise by rechunking each input.

        Every array argument, including array-valued ``where`` / ``out``, is
        rechunked to the target chunks of the output axes it maps to.
        Size-1 (broadcast) dimensions keep a single chunk.
        """
        from dask_array._blockwise import Elemwise, is_scalar_for_elemwise
        from dask_array._expr import ArrayExpr

        elemwise = self.array
        out_ind = elemwise.out_ind
        # Resolved target chunks, one tuple per output axis. The raw spec
        # cannot be used here: dict specs may use negative keys or omit axes
        # (which must keep their current chunks), and string specs such as
        # "auto" cannot be indexed per axis.
        chunks = self.chunks

        def rechunk_array_arg(arg):
            """Rechunk an array argument to match target output chunks."""
            if is_scalar_for_elemwise(arg):
                return arg
            if not isinstance(arg, ArrayExpr):
                return arg
            # Map output chunks to this input's dimensions
            # arg has indices tuple(range(arg.ndim)[::-1])
            arg_ind = tuple(range(arg.ndim)[::-1])

            # For each dimension of arg, find where its index appears in out_ind
            arg_chunks = []
            for i, dim_idx in enumerate(arg_ind):
                # If this dimension is broadcast (size 1), keep a single chunk
                if arg.shape[i] == 1:
                    arg_chunks.append((1,))
                    continue

                try:
                    out_pos = out_ind.index(dim_idx)
                except ValueError:
                    # Index not in output (shouldn't happen for elemwise)
                    arg_chunks.append(-1)
                else:
                    arg_chunks.append(chunks[out_pos])

            return arg.rechunk(tuple(arg_chunks))

        new_args = [rechunk_array_arg(arg) for arg in elemwise.elemwise_args]

        # Also rechunk where and out if they are arrays
        new_where = elemwise.where
        if isinstance(new_where, ArrayExpr):
            new_where = rechunk_array_arg(new_where)

        new_out = elemwise.out
        if isinstance(new_out, ArrayExpr):
            new_out = rechunk_array_arg(new_out)

        return Elemwise(
            elemwise.op,
            elemwise.operand("dtype"),
            elemwise.operand("name"),
            new_where,
            new_out,
            elemwise.operand("_user_kwargs"),
            *new_args,
        )

    def _pushdown_through_concatenate(self):
        """Push rechunk through concatenate when the concat axis is unchanged.

        Each input is rechunked to the target chunks on every other axis and
        keeps its own chunks along the concatenation axis. Returns None
        (no rewrite) when the concatenation axis itself would change.
        """
        from dask_array._new_collection import new_collection

        concat = self.array
        axis = concat.axis
        arrays = concat.args
        # Resolved chunks: a dict spec whose (possibly negative) key targets
        # the concat axis, or a partial/"auto" spec, is normalized first so
        # the check below sees the real target for the concat axis.
        chunks = self.chunks

        # Only push through if we're not changing the concat axis chunking
        # (redistributing across concat boundaries is too complex)
        if chunks[axis] != concat.chunks[axis]:
            return None

        rechunked_arrays = []
        for arr in arrays:
            arr_chunks = list(chunks)
            arr_chunks[axis] = arr.chunks[axis]
            rechunked_arrays.append(new_collection(arr).rechunk(tuple(arr_chunks)))

        return type(concat)(
            rechunked_arrays[0].expr,
            axis,
            concat._meta,
            *[a.expr for a in rechunked_arrays[1:]],
        )

    def _lower(self):
        if not self.balance and (self.chunks == self.array.chunks):
            return self.array

        method = self.method or _choose_rechunk_method(self.array.chunks, self.chunks, threshold=self.threshold)
        if method == "p2p":
            return P2PRechunk(
                self.array,
                self.chunks,
                self.threshold,
                self.block_size_limit,
                self.balance,
            )
        else:
            return TasksRechunk(self.array, self.chunks, self.threshold, self.block_size_limit)
752
+
753
+
754
class TasksRechunk(Rechunk):
    """Rechunk lowered to an explicit graph of split/merge tasks.

    The transition from the input chunks to ``chunks`` is planned by
    ``plan_rechunk`` as one or more intermediate steps, and each step is
    materialized by ``_compute_rechunk``.
    """

    _parameters = ["array", "_chunks", "threshold", "block_size_limit"]

    @cached_property
    def chunks(self):
        # Target chunks are stored verbatim as the ``_chunks`` operand.
        return self.operand("_chunks")

    def _simplify_down(self):
        # Already a lowered form: the parent's rewrite rules do not apply.
        return None

    def _lower(self):
        # Nothing left to lower.
        return None

    def _layer(self):
        plan = plan_rechunk(
            self.array.chunks,
            self.chunks,
            self.array.dtype.itemsize,
            self.threshold,
            self.block_size_limit,
        )
        graph = {}
        prev_name = self.array.name
        prev_chunks = self.array.chunks
        # Levels count down to 0; the final step (level 0) emits keys under
        # this expression's own name.
        for level, step_chunks in zip(range(len(plan) - 1, -1, -1), plan):
            prev_name, prev_chunks, step_graph = _compute_rechunk(
                prev_name, prev_chunks, step_chunks, level, self.name
            )
            graph.update(step_graph)
        return graph
785
+
786
+
787
+ def _convert_to_task_refs(obj):
788
+ """Recursively convert nested lists of keys to TaskRefs."""
789
+ if isinstance(obj, list):
790
+ return List(*[_convert_to_task_refs(item) for item in obj])
791
+ elif isinstance(obj, tuple):
792
+ # Keys are tuples like (name, i, j, ...)
793
+ return TaskRef(obj)
794
+ else:
795
+ return obj
796
+
797
+
798
def _compute_rechunk(old_name, old_chunks, chunks, level, name):
    """Build the graph for one rechunk step from *old_chunks* to *chunks*.

    Parameters
    ----------
    old_name : str
        Key prefix of the blocks being rechunked.
    old_chunks, chunks : sequence of per-axis chunk-size sequences
        Chunking before and after this step.
    level : int
        Step position counted from the end; 0 marks the final step, whose
        output keys keep *name* unchanged.
    name : str
        Final output name. Its ``"rechunk-merge-"`` prefix is rewritten to
        derive this step's merge and split key prefixes.

    Returns
    -------
    tuple
        ``(merge_name, chunks, graph)``: the output key prefix, the new
        chunks, and a dict holding one ``Alias``/``Task`` per output block
        followed by one ``getitem`` task per partially used input block.
    """
    ndim = len(old_chunks)
    crossed = intersect_chunks(old_chunks, chunks)

    merge_name = name.replace("rechunk-merge-", f"rechunk-merge-{level}-" if level != 0 else "rechunk-merge-")
    split_name = name.replace("rechunk-merge-", f"rechunk-split-{level}-" if level != 0 else "rechunk-split-")
    suffixes = itertools.count()

    merged = {}
    splits = {}

    # One shared key tuple per old block, reused by every reference to it.
    old_blocks = np.empty([len(c) for c in old_chunks], dtype="O")
    for index in np.ndindex(old_blocks.shape):
        old_blocks[index] = (old_name,) + index

    new_indices = itertools.product(*(range(len(c)) for c in chunks))
    for new_idx, pieces in zip(new_indices, crossed):
        out_key = (merge_name, *new_idx)
        # Number of distinct old blocks touched along each axis.
        shape = [len({piece[dim][0] for piece in pieces}) for dim in range(ndim)]

        parts = np.empty(shape, dtype="O")
        flat = parts.flat

        for pos, ind_slices in enumerate(pieces):
            src_index, slices = zip(*ind_slices)
            # A split name is drawn for every piece, sliced or not, so the
            # numbering matches regardless of which pieces need slicing.
            split_key = (split_name, next(suffixes))
            src_key = old_blocks[src_index]
            covers_whole_block = all(
                slc.start == 0 and slc.stop == old_chunks[dim][ind]
                for dim, (slc, ind) in enumerate(zip(slices, src_key[1:]))
            )
            if covers_whole_block:
                flat[pos] = src_key
            else:
                splits[split_key] = Task(
                    split_key,
                    operator.getitem,
                    TaskRef(src_key),
                    slices,
                )
                flat[pos] = split_key

        assert pos == parts.size - 1

        if all(extent == 1 for extent in parts.shape):
            # A single source piece: just point at it.
            merged[out_key] = Alias(out_key, parts.flat[0])
        else:
            merged[out_key] = Task(out_key, concatenate3, _convert_to_task_refs(parts.tolist()))

    return merge_name, chunks, {**merged, **splits}
863
+
864
+
865
class P2PRechunk(ArrayExpr):
    """P2P rechunk expression using distributed shuffle.

    The input is first rechunked with tasks when ``_calculate_prechunking``
    asks for different input chunks. Each independent partial region of the
    old-to-new block mapping is then concatenated directly when it feeds a
    single output block, or routed through ``partial_rechunk`` otherwise.
    """

    _parameters = ["array", "_chunks", "threshold", "block_size_limit", "balance"]
    _defaults = {
        "threshold": None,
        "block_size_limit": None,
        "balance": False,
    }

    @property
    def _meta(self):
        # The output meta is the input's meta.
        return self.array._meta

    @cached_property
    def _name(self):
        # Cached: tokenizing every operand (including the full chunk tuples)
        # is not free and the name is read repeatedly. Operands never change
        # after construction, so caching cannot make the name stale.
        return "rechunk-p2p-" + tokenize(*self.operands)

    @cached_property
    def chunks(self):
        return self.operand("_chunks")

    @cached_property
    def _prechunked_chunks(self):
        """Calculate chunks needed before the p2p shuffle."""
        from distributed.shuffle._rechunk import _calculate_prechunking

        return _calculate_prechunking(
            self.array.chunks,
            self.chunks,
            self.array.dtype,
            self.block_size_limit,
        )

    @cached_property
    def _prechunked_array(self):
        """Return the input array, task-rechunked to the prechunking if needed."""
        prechunked = self._prechunked_chunks
        if prechunked != self.array.chunks:
            return TasksRechunk(
                self.array,
                prechunked,
                self.threshold,
                self.block_size_limit,
            )
        return self.array

    def _simplify_down(self):
        # P2PRechunk is a lowered form - don't apply further simplifications
        return None

    def _lower(self):
        return None

    def _layer(self):
        from distributed.shuffle._rechunk import (
            _split_partials,
            partial_concatenate,
            partial_rechunk,
        )

        import dask

        input_name = self._prechunked_array.name
        input_chunks = self._prechunked_chunks
        chunks = self.chunks
        # Same token as the suffix of ``_name``.
        token = tokenize(*self.operands)
        disk = dask.config.get("distributed.p2p.storage.disk")

        _old_to_new = old_to_new(input_chunks, chunks)

        # Create keepmap (all True - no culling at expression level)
        shape = tuple(len(axis) for axis in chunks)
        keepmap = np.ones(shape, dtype=bool)

        dsk = {}
        for ndpartial in _split_partials(_old_to_new):
            partial_keepmap = keepmap[ndpartial.new]
            output_count = np.sum(partial_keepmap)
            if output_count == 0:
                continue
            elif output_count == 1:
                # Single output chunk - use simple concatenation
                dsk.update(
                    partial_concatenate(
                        input_name=input_name,
                        input_chunks=input_chunks,
                        ndpartial=ndpartial,
                        token=token,
                        keepmap=keepmap,
                        old_to_new=_old_to_new,
                    )
                )
            else:
                # Multiple output chunks - use p2p shuffle
                dsk.update(
                    partial_rechunk(
                        input_name=input_name,
                        input_chunks=input_chunks,
                        chunks=chunks,
                        ndpartial=ndpartial,
                        token=token,
                        disk=disk,
                        keepmap=keepmap,
                    )
                )
        return dsk

    def dependencies(self):
        # The graph reads from the (possibly prechunked) input, not self.array.
        return [self._prechunked_array]
975
+
976
+
977
def rechunk(
    x,
    chunks="auto",
    threshold=None,
    block_size_limit=None,
    balance=False,
    method=None,
):
    """
    Convert blocks in dask array x for new chunks.

    Parameters
    ----------
    x : dask array
        Array to be rechunked.
    chunks : int, tuple, dict or str, optional
        The new block dimensions to create. -1 indicates the full size of the
        corresponding dimension. Default is "auto" which automatically
        determines chunk sizes.
    threshold : int, optional
        The graph growth factor under which we don't bother introducing an
        intermediate step.
    block_size_limit : int, optional
        The maximum block size (in bytes) we want to produce.
        Defaults to the configuration value ``array.chunk-size``.
    balance : bool, default False
        If True, try to make each chunk to be the same size.

        This means ``balance=True`` will remove any small leftover chunks, so
        using ``x.rechunk(chunks=len(x) // N, balance=True)``
        will almost certainly result in ``N`` chunks.
    method : {'tasks', 'p2p'}, optional
        Rechunking method to use. Defaults to the ``array.rechunk.method``
        configuration value, read when this function is called; if that is
        unset as well, a method is chosen automatically during lowering.

    Raises
    ------
    NotImplementedError
        If ``method`` (or the configured default) is not ``'tasks'`` or
        ``'p2p'``.

    Examples
    --------
    >>> import dask_array as da
    >>> x = da.ones((1000, 1000), chunks=(100, 100))

    Specify uniform chunk sizes with a tuple

    >>> y = x.rechunk((1000, 10))

    Or chunk only specific dimensions with a dictionary

    >>> y = x.rechunk({0: 1000})

    Use the value ``-1`` to specify that you want a single chunk along a
    dimension or the value ``"auto"`` to specify that dask can freely rechunk a
    dimension to attain blocks of a uniform block size

    >>> y = x.rechunk({0: -1, 1: 'auto'}, block_size_limit=1e8)

    If a chunk size does not divide the dimension then rechunk will leave any
    unevenness to the last chunk.

    >>> x.rechunk(chunks=(400, -1)).chunks
    ((400, 400, 200), (1000,))

    However if you want more balanced chunks, and don't mind Dask choosing a
    different chunksize for you then you can use the ``balance=True`` option.

    >>> x.rechunk(chunks=(400, -1), balance=True).chunks
    ((500, 500), (1000,))
    """
    import dask
    from dask_array._new_collection import new_collection

    # Capture config value at creation time, not during lowering
    if method is None:
        method = dask.config.get("array.rechunk.method", None)
    # Reject typos up front; otherwise they would only surface (or be
    # silently ignored) when the expression is lowered.
    if method not in (None, "tasks", "p2p"):
        raise NotImplementedError(f"Unknown rechunking method '{method}'")

    return new_collection(x.expr.rechunk(chunks, threshold, block_size_limit, balance, method))