dask-array 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (144)
  1. dask_array/__init__.py +228 -0
  2. dask_array/_backends.py +76 -0
  3. dask_array/_backends_array.py +99 -0
  4. dask_array/_blockwise.py +1410 -0
  5. dask_array/_broadcast.py +272 -0
  6. dask_array/_chunk.py +445 -0
  7. dask_array/_chunk_types.py +54 -0
  8. dask_array/_collection.py +1644 -0
  9. dask_array/_concatenate.py +331 -0
  10. dask_array/_core_utils.py +1365 -0
  11. dask_array/_dispatch.py +141 -0
  12. dask_array/_einsum.py +277 -0
  13. dask_array/_expr.py +544 -0
  14. dask_array/_expr_flow.py +586 -0
  15. dask_array/_gufunc.py +805 -0
  16. dask_array/_histogram.py +617 -0
  17. dask_array/_map_blocks.py +652 -0
  18. dask_array/_new_collection.py +10 -0
  19. dask_array/_numpy_compat.py +135 -0
  20. dask_array/_overlap.py +1159 -0
  21. dask_array/_rechunk.py +1050 -0
  22. dask_array/_reshape.py +710 -0
  23. dask_array/_routines.py +102 -0
  24. dask_array/_shuffle.py +448 -0
  25. dask_array/_stack.py +264 -0
  26. dask_array/_svg.py +291 -0
  27. dask_array/_templates.py +29 -0
  28. dask_array/_test_utils.py +257 -0
  29. dask_array/_ufunc.py +385 -0
  30. dask_array/_utils.py +349 -0
  31. dask_array/_visualize.py +223 -0
  32. dask_array/_xarray.py +337 -0
  33. dask_array/core/__init__.py +34 -0
  34. dask_array/core/_blockwise_funcs.py +312 -0
  35. dask_array/core/_conversion.py +422 -0
  36. dask_array/core/_from_graph.py +97 -0
  37. dask_array/creation/__init__.py +71 -0
  38. dask_array/creation/_arange.py +121 -0
  39. dask_array/creation/_diag.py +116 -0
  40. dask_array/creation/_diagonal.py +241 -0
  41. dask_array/creation/_eye.py +103 -0
  42. dask_array/creation/_linspace.py +102 -0
  43. dask_array/creation/_mesh.py +134 -0
  44. dask_array/creation/_ones_zeros.py +454 -0
  45. dask_array/creation/_pad.py +270 -0
  46. dask_array/creation/_repeat.py +55 -0
  47. dask_array/creation/_tile.py +36 -0
  48. dask_array/creation/_tri.py +28 -0
  49. dask_array/creation/_utils.py +296 -0
  50. dask_array/fft.py +320 -0
  51. dask_array/io/__init__.py +39 -0
  52. dask_array/io/_base.py +10 -0
  53. dask_array/io/_from_array.py +257 -0
  54. dask_array/io/_from_delayed.py +95 -0
  55. dask_array/io/_from_graph.py +54 -0
  56. dask_array/io/_from_npy_stack.py +67 -0
  57. dask_array/io/_store.py +336 -0
  58. dask_array/io/_tiledb.py +159 -0
  59. dask_array/io/_to_npy_stack.py +65 -0
  60. dask_array/io/_zarr.py +449 -0
  61. dask_array/linalg/__init__.py +39 -0
  62. dask_array/linalg/_cholesky.py +234 -0
  63. dask_array/linalg/_lu.py +300 -0
  64. dask_array/linalg/_norm.py +94 -0
  65. dask_array/linalg/_qr.py +601 -0
  66. dask_array/linalg/_solve.py +349 -0
  67. dask_array/linalg/_svd.py +394 -0
  68. dask_array/linalg/_tensordot.py +334 -0
  69. dask_array/linalg/_utils.py +74 -0
  70. dask_array/manipulation/__init__.py +45 -0
  71. dask_array/manipulation/_expand.py +321 -0
  72. dask_array/manipulation/_flip.py +92 -0
  73. dask_array/manipulation/_roll.py +78 -0
  74. dask_array/manipulation/_transpose.py +309 -0
  75. dask_array/random/__init__.py +125 -0
  76. dask_array/random/_choice.py +181 -0
  77. dask_array/random/_expr.py +256 -0
  78. dask_array/random/_generator.py +441 -0
  79. dask_array/random/_random_state.py +259 -0
  80. dask_array/random/_utils.py +84 -0
  81. dask_array/reductions/__init__.py +84 -0
  82. dask_array/reductions/_arg_reduction.py +130 -0
  83. dask_array/reductions/_common.py +1082 -0
  84. dask_array/reductions/_cumulative.py +522 -0
  85. dask_array/reductions/_percentile.py +261 -0
  86. dask_array/reductions/_reduction.py +725 -0
  87. dask_array/reductions/_trace.py +56 -0
  88. dask_array/routines/__init__.py +133 -0
  89. dask_array/routines/_apply.py +84 -0
  90. dask_array/routines/_bincount.py +112 -0
  91. dask_array/routines/_broadcast.py +111 -0
  92. dask_array/routines/_coarsen.py +115 -0
  93. dask_array/routines/_diff.py +79 -0
  94. dask_array/routines/_gradient.py +158 -0
  95. dask_array/routines/_indexing.py +65 -0
  96. dask_array/routines/_insert_delete.py +132 -0
  97. dask_array/routines/_misc.py +122 -0
  98. dask_array/routines/_nonzero.py +72 -0
  99. dask_array/routines/_search.py +123 -0
  100. dask_array/routines/_select.py +113 -0
  101. dask_array/routines/_statistics.py +171 -0
  102. dask_array/routines/_topk.py +82 -0
  103. dask_array/routines/_triangular.py +74 -0
  104. dask_array/routines/_unique.py +232 -0
  105. dask_array/routines/_where.py +62 -0
  106. dask_array/slicing/__init__.py +67 -0
  107. dask_array/slicing/_basic.py +550 -0
  108. dask_array/slicing/_blocks.py +138 -0
  109. dask_array/slicing/_bool_index.py +145 -0
  110. dask_array/slicing/_setitem.py +329 -0
  111. dask_array/slicing/_squeeze.py +101 -0
  112. dask_array/slicing/_utils.py +1133 -0
  113. dask_array/slicing/_vindex.py +282 -0
  114. dask_array/stacking/__init__.py +15 -0
  115. dask_array/stacking/_block.py +83 -0
  116. dask_array/stacking/_simple.py +58 -0
  117. dask_array/templates/array.html.j2 +48 -0
  118. dask_array/tests/__init__.py +0 -0
  119. dask_array/tests/conftest.py +22 -0
  120. dask_array/tests/test_api.py +40 -0
  121. dask_array/tests/test_binary_op_chunks.py +107 -0
  122. dask_array/tests/test_coarse_slice_through_blockwise.py +362 -0
  123. dask_array/tests/test_collection.py +799 -0
  124. dask_array/tests/test_creation.py +1102 -0
  125. dask_array/tests/test_expr_flow.py +143 -0
  126. dask_array/tests/test_linalg.py +1130 -0
  127. dask_array/tests/test_map_blocks_multi_output.py +104 -0
  128. dask_array/tests/test_rechunk_pushdown.py +214 -0
  129. dask_array/tests/test_reductions.py +1091 -0
  130. dask_array/tests/test_routines.py +2853 -0
  131. dask_array/tests/test_shuffle_chunks.py +67 -0
  132. dask_array/tests/test_slice_pushdown.py +968 -0
  133. dask_array/tests/test_slice_through_blockwise.py +678 -0
  134. dask_array/tests/test_slice_through_overlap.py +366 -0
  135. dask_array/tests/test_slice_through_reshape.py +272 -0
  136. dask_array/tests/test_slicing.py +839 -0
  137. dask_array/tests/test_transpose_slice_pushdown.py +208 -0
  138. dask_array/tests/test_visualize.py +94 -0
  139. dask_array/tests/test_xarray.py +193 -0
  140. dask_array-0.1.0.dist-info/METADATA +48 -0
  141. dask_array-0.1.0.dist-info/RECORD +144 -0
  142. dask_array-0.1.0.dist-info/WHEEL +4 -0
  143. dask_array-0.1.0.dist-info/entry_points.txt +2 -0
  144. dask_array-0.1.0.dist-info/licenses/LICENSE +29 -0
@@ -0,0 +1,102 @@
1
+ """Re-exports from routines submodules and other locations.
2
+
3
+ This module maintains backward compatibility by re-exporting all routines
4
+ from their new locations.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import numpy as np
10
+
11
+ from dask_array._core_utils import implements
12
+
13
+ # Re-exports from _blockwise
14
+ from dask_array._blockwise import outer # noqa: F401
15
+
16
+ # Re-exports from _ufunc
17
+ from dask_array._ufunc import ( # noqa: F401
18
+ allclose,
19
+ around,
20
+ isclose,
21
+ isnull,
22
+ notnull,
23
+ round,
24
+ )
25
+
26
+ # Re-exports from routines submodules
27
+ from dask_array.routines._apply import ( # noqa: F401
28
+ apply_along_axis,
29
+ apply_over_axes,
30
+ )
31
+ from dask_array.routines._bincount import bincount # noqa: F401
32
+ from dask_array.routines._broadcast import ( # noqa: F401
33
+ broadcast_arrays,
34
+ unify_chunks,
35
+ )
36
+ from dask_array.routines._coarsen import ( # noqa: F401
37
+ Coarsen,
38
+ aligned_coarsen_chunks,
39
+ coarsen,
40
+ )
41
+ from dask_array.routines._gradient import gradient # noqa: F401
42
+ from dask_array.routines._indexing import ( # noqa: F401
43
+ ravel_multi_index,
44
+ unravel_index,
45
+ )
46
+ from dask_array.routines._insert_delete import ( # noqa: F401
47
+ append,
48
+ delete,
49
+ ediff1d,
50
+ insert,
51
+ )
52
+ from dask_array.routines._misc import ( # noqa: F401
53
+ compress,
54
+ ndim,
55
+ result_type,
56
+ shape,
57
+ take,
58
+ )
59
+ from dask_array.routines._nonzero import ( # noqa: F401
60
+ argwhere,
61
+ count_nonzero,
62
+ flatnonzero,
63
+ isnonzero,
64
+ nonzero,
65
+ )
66
+ from dask_array.routines._search import ( # noqa: F401
67
+ isin,
68
+ searchsorted,
69
+ )
70
+ from dask_array.routines._select import ( # noqa: F401
71
+ choose,
72
+ digitize,
73
+ extract,
74
+ piecewise,
75
+ select,
76
+ )
77
+ from dask_array.routines._statistics import ( # noqa: F401
78
+ average,
79
+ corrcoef,
80
+ cov,
81
+ )
82
+ from dask_array.routines._topk import argtopk, topk # noqa: F401
83
+ from dask_array.routines._triangular import ( # noqa: F401
84
+ tril,
85
+ tril_indices,
86
+ tril_indices_from,
87
+ triu,
88
+ triu_indices,
89
+ triu_indices_from,
90
+ )
91
+ from dask_array.routines._unique import union1d, unique # noqa: F401
92
+
93
+
94
def ptp(a, axis=None):
    """Peak to peak (maximum - minimum) value along a given axis.

    ``axis`` is forwarded unchanged to ``a.max`` and ``a.min``; ``None``
    reduces over every axis.
    """
    highest = a.max(axis=axis)
    lowest = a.min(axis=axis)
    return highest - lowest
97
+
98
+
99
@implements(np.iscomplexobj)
def iscomplexobj(x):
    """Check whether the input has a complex dtype.

    Registered through ``implements`` for ``np.iscomplexobj``. Only
    ``x.dtype`` is inspected: the result is whether its scalar type is a
    subclass of ``np.complexfloating``.
    """
    scalar_type = x.dtype.type
    return issubclass(scalar_type, np.complexfloating)
dask_array/_shuffle.py ADDED
@@ -0,0 +1,448 @@
1
+ from __future__ import annotations
2
+
3
+ import copy
4
+ import functools
5
+ import math
6
+ from functools import reduce
7
+ from itertools import count, product
8
+ from operator import mul
9
+ from typing import Literal
10
+
11
+ import numpy as np
12
+
13
+ from dask import config
14
+ from dask._task_spec import DataNode, List, Task, TaskRef
15
+ from dask_array._expr import ArrayExpr
16
+ from dask_array._chunk import getitem
17
+ from dask_array._dispatch import concatenate_lookup, take_lookup
18
+ from dask.base import tokenize
19
+
20
+
21
def _calculate_new_chunksizes(input_chunks, new_chunks, changeable_dimensions: set, maximum_chunk: int):
    """Split chunks along changeable dimensions until blocks fit ``maximum_chunk``.

    Parameters
    ----------
    input_chunks : sequence of tuple of int
        The original chunk sizes per dimension. Splits are always computed
        from these, so chunks are only ever divided, never merged.
    new_chunks : list of tuple of int
        Working chunk sizes per dimension. Mutated in place and returned.
    changeable_dimensions : set of int
        Dimensions that may be split. Mutated in place: a dimension is
        removed once its chunks stop changing or its largest chunk is 1.
    maximum_chunk : int
        Target upper bound on the element count of the largest block
        (product of per-dimension maxima). Values below 1 are treated as 1.

    Returns
    -------
    list of tuple of int
        ``new_chunks`` after splitting.
    """
    tolerance = config.get("array.chunk-size-tolerance")
    target = max(maximum_chunk, 1)

    while changeable_dimensions:
        # How much bigger the largest block is than the target.
        growth = reduce(mul, map(max, new_chunks)) / target
        if growth <= 1:
            break

        # Share the required shrink evenly between all changeable dimensions.
        per_dim_shrink = growth ** (1 / len(changeable_dimensions))

        for dim in list(changeable_dimensions):
            limit = max(new_chunks[dim]) / per_dim_shrink
            split = []
            for size in input_chunks[dim]:
                if size <= tolerance * limit:
                    # Within tolerance: keep the chunk as it is.
                    split.append(size)
                    continue
                # Never split into more pieces than there are elements, so
                # every piece keeps at least one element.
                pieces = min(math.ceil(size / limit), size)
                base, extra = divmod(size, pieces)
                # The first ``extra`` pieces absorb the remainder.
                split.extend(base + 1 if k < extra else base for k in range(pieces))

            if tuple(split) == new_chunks[dim] or max(split) == 1:
                changeable_dimensions.remove(dim)

            new_chunks[dim] = tuple(split)
    return new_chunks
61
+
62
+
63
def _rechunk_other_dimensions(x, longest_group: int, axis: int, chunks: Literal["auto"]):
    """Rechunk other dimensions when shuffle groups are too large.

    If ``longest_group`` fits within the largest chunk along ``axis`` times
    ``array.chunk-size-tolerance``, ``x`` is returned unchanged. Otherwise the
    chunks of every other dimension are split (never merged) so the block
    size stays near that of the largest input block. The chunks along
    ``axis`` itself are left as they are.
    """
    assert chunks == "auto", "Only auto is supported for now"
    tolerance = config.get("array.chunk-size-tolerance")

    original_chunks = x.chunks
    if longest_group <= max(original_chunks[axis]) * tolerance:
        # Groups fit within the tolerated size: no rechunk needed.
        return x

    # Element count of the largest block in the input: the budget to keep.
    largest_block = reduce(mul, map(max, original_chunks))

    # Pretend the shuffle axis becomes a single group-sized chunk and see how
    # much the other dimensions must shrink to compensate.
    candidate = list(original_chunks)
    candidate[axis] = (longest_group,)
    free_dims = {dim for dim in range(len(original_chunks)) if dim != axis}

    candidate = _calculate_new_chunksizes(original_chunks, candidate, free_dims, largest_block)
    candidate[axis] = original_chunks[axis]
    return x.rechunk(tuple(candidate))
82
+
83
+
84
+ def _validate_indexer(chunks, indexer, axis):
85
+ if not isinstance(indexer, list) or not all(isinstance(i, list) for i in indexer):
86
+ raise ValueError("indexer must be a list of lists of positional indices")
87
+
88
+ if not axis <= len(chunks):
89
+ raise ValueError(f"Axis {axis} is out of bounds for array with {len(chunks)} axes")
90
+
91
+ if max(map(max, indexer)) >= sum(chunks[axis]):
92
+ raise IndexError(f"Indexer contains out of bounds index. Dimension only has {sum(chunks[axis])} elements.")
93
+
94
+
95
def shuffle(x, indexer: list[list[int]], axis: int, chunks: Literal["auto"] = "auto"):
    """
    Reorder one dimension of a Dask Array based on an indexer.

    The indexer defines a list of positional groups that will end up in the same chunk
    together. Consecutive groups are packed into the same output chunk as long as that
    chunk does not grow beyond the largest input chunk along ``axis``; this avoids
    fragmenting the array. A single group that is longer than that size is split into
    pieces of that size.

    Parameters
    ----------
    x: dask array
        Array to be shuffled.
    indexer: list[list[int]]
        The indexer that determines which elements along the dimension will end up in the
        same chunk. Indices are non-negative positions along ``axis``.
    axis: int
        The axis to shuffle along. Negative values count from the last axis.
    chunks: "auto"
        Hint on how to rechunk if single groups are becoming too large. Only ``"auto"``
        is supported: if the longest group exceeds the largest chunk along ``axis`` by
        more than the configuration value "array.chunk-size-tolerance" (default 1.25),
        the chunks along the other dimensions are split evenly to keep the chunk size
        consistent. Chunks are only split and never combined with other chunks, so no
        all-to-all network communication is necessary for this rechunk.

    Raises
    ------
    ValueError
        If ``x`` has unknown chunk sizes, or the indexer or axis is invalid.
    TypeError
        If ``axis`` is not an integer.
    NotImplementedError
        If ``chunks`` is anything other than ``"auto"``.
    IndexError
        If the indexer references positions outside the dimension.

    Examples
    --------
    >>> import dask_array as da
    >>> import numpy as np
    >>> arr = np.array([[1, 2, 3, 4, 5, 6, 7, 8], [9, 10, 11, 12, 13, 14, 15, 16]])
    >>> x = da.from_array(arr, chunks=(2, 4))

    Separate the elements in different groups.

    >>> y = x.shuffle([[6, 5, 2], [4, 1], [3, 0, 7]], axis=1)

    The largest input chunk along axis 1 holds 4 elements. The first two groups hold
    5 indices together, and so do the last two, so no groups can share a chunk and
    every group becomes its own chunk.

    >>> y.chunks
    ((2,), (3, 2, 3))

    The array was reordered along axis 1 according to the positional indexer that was given.

    >>> y.compute()
    array([[ 7,  6,  3,  5,  2,  4,  1,  8],
           [15, 14, 11, 13, 10, 12,  9, 16]])
    """
    from dask_array._new_collection import new_collection

    if np.isnan(x.shape).any():
        from dask_array._core_utils import unknown_chunk_message

        raise ValueError(f"Shuffling only allowed with known chunk sizes. {unknown_chunk_message}")
    if not isinstance(axis, (int, np.integer)):
        raise TypeError("axis must be an integer")
    if chunks != "auto":
        raise NotImplementedError("Only chunks='auto' is supported for now")

    # The Shuffle expression compares and inserts ``axis`` positionally, which
    # only works for non-negative Python ints, so normalize here.
    axis = int(axis)
    if axis < 0:
        axis += len(x.chunks)
    _validate_indexer(x.chunks, indexer, axis)

    x = _rechunk_other_dimensions(x, max(map(len, indexer)), axis, chunks)

    name = "shuffle"

    result = _shuffle(x.expr, indexer, axis, name)
    return new_collection(result)
164
+
165
+
166
+ def _shuffle(x, indexer, axis, name):
167
+ if len(indexer) == len(x.chunks[axis]):
168
+ # check if the array is already shuffled the way we want
169
+ ctr = 0
170
+ for idx, c in zip(indexer, x.chunks[axis]):
171
+ if idx != list(range(ctr, ctr + c)):
172
+ break
173
+ ctr += c
174
+ else:
175
+ return x
176
+ return Shuffle(x, indexer, axis, name)
177
+
178
+
179
class Shuffle(ArrayExpr):
    """Reorder one axis of an array according to a grouped positional indexer.

    Operands: ``array`` (input expression), ``indexer`` (list of groups of
    positional indices along ``axis``), ``axis`` (the reordered axis) and
    ``name`` (prefix for the output key name). Output positions along ``axis``
    follow the concatenation of the groups; ``_new_chunks`` decides how the
    groups are packed into output chunks.
    """

    _parameters = ["array", "indexer", "axis", "name"]

    @functools.cached_property
    def _meta(self):
        # Reordering along an axis keeps the block type and dtype of the input.
        return self.array._meta

    @functools.cached_property
    def _name(self):
        return f"{self.operand('name')}-{self.deterministic_token}"

    @functools.cached_property
    def chunks(self):
        output_chunks = []
        for i, c in enumerate(self.array.chunks):
            if i == self.axis:
                output_chunks.append(tuple(map(len, self._new_chunks)))
            else:
                output_chunks.append(c)
        return tuple(output_chunks)

    @functools.cached_property
    def _chunk_size_limit(self):
        """Max input chunk size on the shuffle axis."""
        return max(self.array.chunks[self.axis])

    @functools.cached_property
    def _new_chunks(self):
        """Pack the indexer groups into output chunks along ``axis``.

        Consecutive groups share a chunk as long as the chunk stays within
        ``_chunk_size_limit`` elements. A group longer than the limit is split
        into limit-sized pieces, each becoming its own chunk.
        """
        current_chunk, new_chunks = [], []
        limit = self._chunk_size_limit
        for idx in copy.deepcopy(self.indexer):
            if len(idx) > limit:
                # Flush the chunk being built, then split the oversized group.
                if current_chunk:
                    new_chunks.append(current_chunk)
                    current_chunk = []
                for i in range(0, len(idx), limit):
                    new_chunks.append(idx[i : i + limit])
            elif current_chunk and len(current_chunk) + len(idx) > limit:
                # The group does not fit next to what we have: start a new chunk.
                new_chunks.append(current_chunk)
                current_chunk = idx.copy()
            else:
                # Either the chunk is empty or the group fits. Since
                # len(idx) <= limit here, the chunk never exceeds ``limit``.
                current_chunk.extend(idx)
        if current_chunk:
            new_chunks.append(current_chunk)
        return new_chunks

    def _simplify_down(self):
        """Push shuffle through various operations using _accept_shuffle pattern."""
        # Check if child can accept this shuffle
        if hasattr(self.array, "_accept_shuffle"):
            return self.array._accept_shuffle(self)

    def _simplify_up(self, parent, dependents):
        """Allow slice operations to push through Shuffle."""
        from dask_array.slicing import SliceSlicesIntegers

        if isinstance(parent, SliceSlicesIntegers):
            return self._accept_slice(parent)
        return None

    def _accept_slice(self, slice_expr):
        """Accept a slice being pushed through Shuffle.

        Shuffle reorganizes data along a single axis. We can push slices through:
        1. Non-shuffle axes: directly push through
        2. Shuffle axis (step=1): if input indices are contiguous, slice input
           and adjust the indexer

        Integer indices on non-shuffle axes drop their dimension, so the
        rewritten Shuffle uses the axis position it has in the sliced input.

        Example (non-shuffle axis):
            Slice(Shuffle(x, axis=0), [:, 10:20])
            -> Shuffle(Slice(x, [:, 10:20]), axis=0)

        Example (shuffle axis, contiguous):
            Slice(Shuffle(x, axis=0), [100:200, :])
            -> Shuffle(Slice(x, [input_start:input_stop, :]), adjusted_indexer, axis=0)

        Returns None when the slice cannot be pushed through.
        """
        from dask_array._new_collection import new_collection

        axis = self.axis
        index = slice_expr.index
        indexer = self.indexer

        # Pad index to full length
        full_index = list(index) + [slice(None)] * (len(self.shape) - len(index))

        # Only basic slices and integers are understood below; anything else
        # would break the positional bookkeeping, so leave the graph alone.
        if not all(isinstance(i, (slice, int, np.integer)) for i in full_index):
            return None

        # Each integer index in front of the shuffle axis removes a dimension,
        # shifting the shuffle axis of the sliced input to the left.
        new_axis = axis - sum(not isinstance(i, slice) for i in full_index[:axis])

        axis_slice = full_index[axis]

        # Integer indexing on the shuffle axis removes that dimension.
        if not isinstance(axis_slice, slice):
            return None

        if axis_slice == slice(None):
            # Not slicing shuffle axis - push through directly
            sliced_input = new_collection(self.array)[tuple(full_index)]
            return Shuffle(
                sliced_input.expr,
                indexer,
                new_axis,
                self.operand("name"),
            )

        # Only handle step=1 slices
        if axis_slice.step is not None and axis_slice.step != 1:
            return None

        # Normalize slice bounds
        axis_size = self.shape[axis]
        start, stop, _ = axis_slice.indices(axis_size)
        if start >= stop:
            return None  # Empty slice

        # Slice by output positions. The indexer is grouped into runs, not
        # one entry per output element.
        new_indexer = []
        offset = 0
        for chunk in indexer:
            chunk_start = offset
            chunk_stop = offset + len(chunk)
            offset = chunk_stop

            take_start = max(start, chunk_start)
            take_stop = min(stop, chunk_stop)
            if take_start < take_stop:
                new_indexer.append(chunk[take_start - chunk_start : take_stop - chunk_start])

        # Find all input indices needed
        input_indices = set()
        for chunk in new_indexer:
            input_indices.update(chunk)

        if not input_indices:
            return None  # No indices

        input_min = min(input_indices)
        input_max = max(input_indices)

        # Check if input indices are contiguous
        if len(input_indices) != input_max - input_min + 1:
            return None  # Non-contiguous, can't use simple slice

        # Adjust indexer: subtract input_min from each index
        adjusted_indexer = [[idx - input_min for idx in chunk] for chunk in new_indexer]

        # Build slice for input array
        input_slice = list(full_index)
        input_slice[axis] = slice(input_min, input_max + 1)

        sliced_input = new_collection(self.array)[tuple(input_slice)]

        return Shuffle(
            sliced_input.expr,
            adjusted_indexer,
            new_axis,
            self.operand("name"),
        )

    def _layer(self) -> dict:
        chunks = self.array.chunks
        axis = self.axis

        # Cumulative end positions of the input chunks along the shuffle axis.
        chunk_boundaries = np.cumsum(chunks[axis])

        # Get existing chunk tuple locations
        chunk_tuples = list(product(*(range(len(c)) for i, c in enumerate(chunks) if i != axis)))

        intermediates: dict = dict()
        merges: dict = dict()
        # Smallest integer dtype able to hold the in-chunk positions used below.
        dtype = np.min_scalar_type(max(*chunks[axis], self._chunk_size_limit))
        split_name = f"shuffle-split-{self.deterministic_token}"
        slices = [slice(None)] * len(chunks)
        split_name_suffixes = count()
        sorter_name = "shuffle-sorter-"
        taker_name = "shuffle-taker-"

        old_blocks = {
            old_index: (self.array._name,) + old_index for old_index in np.ndindex(tuple([len(c) for c in chunks]))
        }

        for new_chunk_idx, new_chunk_taker in enumerate(self._new_chunks):
            new_chunk_taker = np.array(new_chunk_taker)
            # Sorting the requested positions groups them by source chunk; the
            # sorter is kept so the merge step can restore the requested order.
            sorter = np.argsort(new_chunk_taker).astype(dtype)
            sorter_key = sorter_name + tokenize(sorter)
            # low level fusion can't deal with arrays on first position
            merges[sorter_key] = DataNode(sorter_key, (1, sorter))

            sorted_array = new_chunk_taker[sorter]
            # Source chunk number of each sorted position and where each
            # source chunk's run starts inside ``sorted_array``.
            source_chunk_nr, taker_boundary_ = np.unique(
                np.searchsorted(chunk_boundaries, sorted_array, side="right"),
                return_index=True,
            )
            taker_boundary: list[int] = taker_boundary_.tolist()
            taker_boundary.append(len(new_chunk_taker))

            taker_cache: dict = {}
            for chunk_tuple in chunk_tuples:
                merge_keys = []

                for c, b_start, b_end in zip(source_chunk_nr, taker_boundary[:-1], taker_boundary[1:]):
                    # insert our axis chunk id into the chunk_tuple
                    chunk_key = convert_key(chunk_tuple, c, axis)
                    name = (split_name, next(split_name_suffixes))
                    this_slice = slices.copy()

                    # Cache the takers to allow de-duplication when serializing
                    # Ugly!
                    if c in taker_cache:
                        taker_key = taker_cache[c]
                    else:
                        # Convert global positions to positions within chunk c.
                        this_slice[axis] = (
                            sorted_array[b_start:b_end] - (chunk_boundaries[c - 1] if c > 0 else 0)
                        ).astype(dtype)
                        if len(source_chunk_nr) == 1:
                            # Single source: take directly in requested order,
                            # so no reordering merge is needed afterwards.
                            this_slice[axis] = this_slice[axis][np.argsort(sorter)]

                        taker_key = taker_name + tokenize(this_slice)
                        # low level fusion can't deal with arrays on first position
                        intermediates[taker_key] = DataNode(taker_key, (1, tuple(this_slice)))
                        taker_cache[c] = taker_key

                    intermediates[name] = Task(
                        name,
                        _getitem,
                        TaskRef(old_blocks[chunk_key]),
                        TaskRef(taker_key),
                    )
                    merge_keys.append(name)

                merge_suffix = convert_key(chunk_tuple, new_chunk_idx, axis)
                out_name_merge = (self._name,) + merge_suffix
                if len(merge_keys) > 1:
                    merges[out_name_merge] = Task(
                        out_name_merge,
                        concatenate_arrays,
                        List(*(TaskRef(m) for m in merge_keys)),
                        TaskRef(sorter_key),
                        axis,
                    )
                elif len(merge_keys) == 1:
                    # A single piece is already the output block: re-key it.
                    t = intermediates.pop(merge_keys[0])
                    t.key = out_name_merge
                    merges[out_name_merge] = t
                else:
                    raise NotImplementedError

        return {**merges, **intermediates}
431
+
432
+
433
def _getitem(obj, index):
    """Index ``obj`` with a taker stored as ``(placeholder, index_tuple)``.

    Takers are stored with a leading dummy element because low-level fusion
    cannot handle arrays in the first position; only element 1 is used.
    """
    real_index = index[1]
    return getitem(obj, real_index)
435
+
436
+
437
def concatenate_arrays(arrs, sorter, axis):
    """Concatenate ``arrs`` along ``axis`` and undo the sort used to build them.

    ``sorter`` is stored as ``(placeholder, permutation)``. The pieces were
    taken in sorted order, so ``argsort`` of the permutation restores the
    requested order. Concatenation is dispatched on the type of the first
    piece.
    """
    concat = concatenate_lookup.dispatch(type(arrs[0]))
    joined = concat(arrs, axis=axis)
    restore_order = np.argsort(sorter[1])
    return take_lookup(joined, restore_order, axis=axis)
443
+
444
+
445
def convert_key(key, chunk, axis):
    """Return ``key`` as a tuple with ``int(chunk)`` inserted at position ``axis``.

    ``int`` normalizes NumPy integers (e.g. ``np.int64``) to Python ints.
    Insertion follows ``list.insert`` semantics for any ``axis`` value.
    """
    items = tuple(key)
    return items[:axis] + (int(chunk),) + items[axis:]