dask-array 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (144)
  1. dask_array/__init__.py +228 -0
  2. dask_array/_backends.py +76 -0
  3. dask_array/_backends_array.py +99 -0
  4. dask_array/_blockwise.py +1410 -0
  5. dask_array/_broadcast.py +272 -0
  6. dask_array/_chunk.py +445 -0
  7. dask_array/_chunk_types.py +54 -0
  8. dask_array/_collection.py +1644 -0
  9. dask_array/_concatenate.py +331 -0
  10. dask_array/_core_utils.py +1365 -0
  11. dask_array/_dispatch.py +141 -0
  12. dask_array/_einsum.py +277 -0
  13. dask_array/_expr.py +544 -0
  14. dask_array/_expr_flow.py +586 -0
  15. dask_array/_gufunc.py +805 -0
  16. dask_array/_histogram.py +617 -0
  17. dask_array/_map_blocks.py +652 -0
  18. dask_array/_new_collection.py +10 -0
  19. dask_array/_numpy_compat.py +135 -0
  20. dask_array/_overlap.py +1159 -0
  21. dask_array/_rechunk.py +1050 -0
  22. dask_array/_reshape.py +710 -0
  23. dask_array/_routines.py +102 -0
  24. dask_array/_shuffle.py +448 -0
  25. dask_array/_stack.py +264 -0
  26. dask_array/_svg.py +291 -0
  27. dask_array/_templates.py +29 -0
  28. dask_array/_test_utils.py +257 -0
  29. dask_array/_ufunc.py +385 -0
  30. dask_array/_utils.py +349 -0
  31. dask_array/_visualize.py +223 -0
  32. dask_array/_xarray.py +337 -0
  33. dask_array/core/__init__.py +34 -0
  34. dask_array/core/_blockwise_funcs.py +312 -0
  35. dask_array/core/_conversion.py +422 -0
  36. dask_array/core/_from_graph.py +97 -0
  37. dask_array/creation/__init__.py +71 -0
  38. dask_array/creation/_arange.py +121 -0
  39. dask_array/creation/_diag.py +116 -0
  40. dask_array/creation/_diagonal.py +241 -0
  41. dask_array/creation/_eye.py +103 -0
  42. dask_array/creation/_linspace.py +102 -0
  43. dask_array/creation/_mesh.py +134 -0
  44. dask_array/creation/_ones_zeros.py +454 -0
  45. dask_array/creation/_pad.py +270 -0
  46. dask_array/creation/_repeat.py +55 -0
  47. dask_array/creation/_tile.py +36 -0
  48. dask_array/creation/_tri.py +28 -0
  49. dask_array/creation/_utils.py +296 -0
  50. dask_array/fft.py +320 -0
  51. dask_array/io/__init__.py +39 -0
  52. dask_array/io/_base.py +10 -0
  53. dask_array/io/_from_array.py +257 -0
  54. dask_array/io/_from_delayed.py +95 -0
  55. dask_array/io/_from_graph.py +54 -0
  56. dask_array/io/_from_npy_stack.py +67 -0
  57. dask_array/io/_store.py +336 -0
  58. dask_array/io/_tiledb.py +159 -0
  59. dask_array/io/_to_npy_stack.py +65 -0
  60. dask_array/io/_zarr.py +449 -0
  61. dask_array/linalg/__init__.py +39 -0
  62. dask_array/linalg/_cholesky.py +234 -0
  63. dask_array/linalg/_lu.py +300 -0
  64. dask_array/linalg/_norm.py +94 -0
  65. dask_array/linalg/_qr.py +601 -0
  66. dask_array/linalg/_solve.py +349 -0
  67. dask_array/linalg/_svd.py +394 -0
  68. dask_array/linalg/_tensordot.py +334 -0
  69. dask_array/linalg/_utils.py +74 -0
  70. dask_array/manipulation/__init__.py +45 -0
  71. dask_array/manipulation/_expand.py +321 -0
  72. dask_array/manipulation/_flip.py +92 -0
  73. dask_array/manipulation/_roll.py +78 -0
  74. dask_array/manipulation/_transpose.py +309 -0
  75. dask_array/random/__init__.py +125 -0
  76. dask_array/random/_choice.py +181 -0
  77. dask_array/random/_expr.py +256 -0
  78. dask_array/random/_generator.py +441 -0
  79. dask_array/random/_random_state.py +259 -0
  80. dask_array/random/_utils.py +84 -0
  81. dask_array/reductions/__init__.py +84 -0
  82. dask_array/reductions/_arg_reduction.py +130 -0
  83. dask_array/reductions/_common.py +1082 -0
  84. dask_array/reductions/_cumulative.py +522 -0
  85. dask_array/reductions/_percentile.py +261 -0
  86. dask_array/reductions/_reduction.py +725 -0
  87. dask_array/reductions/_trace.py +56 -0
  88. dask_array/routines/__init__.py +133 -0
  89. dask_array/routines/_apply.py +84 -0
  90. dask_array/routines/_bincount.py +112 -0
  91. dask_array/routines/_broadcast.py +111 -0
  92. dask_array/routines/_coarsen.py +115 -0
  93. dask_array/routines/_diff.py +79 -0
  94. dask_array/routines/_gradient.py +158 -0
  95. dask_array/routines/_indexing.py +65 -0
  96. dask_array/routines/_insert_delete.py +132 -0
  97. dask_array/routines/_misc.py +122 -0
  98. dask_array/routines/_nonzero.py +72 -0
  99. dask_array/routines/_search.py +123 -0
  100. dask_array/routines/_select.py +113 -0
  101. dask_array/routines/_statistics.py +171 -0
  102. dask_array/routines/_topk.py +82 -0
  103. dask_array/routines/_triangular.py +74 -0
  104. dask_array/routines/_unique.py +232 -0
  105. dask_array/routines/_where.py +62 -0
  106. dask_array/slicing/__init__.py +67 -0
  107. dask_array/slicing/_basic.py +550 -0
  108. dask_array/slicing/_blocks.py +138 -0
  109. dask_array/slicing/_bool_index.py +145 -0
  110. dask_array/slicing/_setitem.py +329 -0
  111. dask_array/slicing/_squeeze.py +101 -0
  112. dask_array/slicing/_utils.py +1133 -0
  113. dask_array/slicing/_vindex.py +282 -0
  114. dask_array/stacking/__init__.py +15 -0
  115. dask_array/stacking/_block.py +83 -0
  116. dask_array/stacking/_simple.py +58 -0
  117. dask_array/templates/array.html.j2 +48 -0
  118. dask_array/tests/__init__.py +0 -0
  119. dask_array/tests/conftest.py +22 -0
  120. dask_array/tests/test_api.py +40 -0
  121. dask_array/tests/test_binary_op_chunks.py +107 -0
  122. dask_array/tests/test_coarse_slice_through_blockwise.py +362 -0
  123. dask_array/tests/test_collection.py +799 -0
  124. dask_array/tests/test_creation.py +1102 -0
  125. dask_array/tests/test_expr_flow.py +143 -0
  126. dask_array/tests/test_linalg.py +1130 -0
  127. dask_array/tests/test_map_blocks_multi_output.py +104 -0
  128. dask_array/tests/test_rechunk_pushdown.py +214 -0
  129. dask_array/tests/test_reductions.py +1091 -0
  130. dask_array/tests/test_routines.py +2853 -0
  131. dask_array/tests/test_shuffle_chunks.py +67 -0
  132. dask_array/tests/test_slice_pushdown.py +968 -0
  133. dask_array/tests/test_slice_through_blockwise.py +678 -0
  134. dask_array/tests/test_slice_through_overlap.py +366 -0
  135. dask_array/tests/test_slice_through_reshape.py +272 -0
  136. dask_array/tests/test_slicing.py +839 -0
  137. dask_array/tests/test_transpose_slice_pushdown.py +208 -0
  138. dask_array/tests/test_visualize.py +94 -0
  139. dask_array/tests/test_xarray.py +193 -0
  140. dask_array-0.1.0.dist-info/METADATA +48 -0
  141. dask_array-0.1.0.dist-info/RECORD +144 -0
  142. dask_array-0.1.0.dist-info/WHEEL +4 -0
  143. dask_array-0.1.0.dist-info/entry_points.txt +2 -0
  144. dask_array-0.1.0.dist-info/licenses/LICENSE +29 -0
@@ -0,0 +1,62 @@
1
+ """Where operation."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import numpy as np
6
+
7
+
8
def where(condition, x=None, y=None):
    """Return elements chosen from x or y depending on condition.

    Called with ``condition`` alone, this instead returns
    ``nonzero(condition)``, i.e. the indices of its non-zero elements.

    Parameters
    ----------
    condition : array_like, bool
        Where True, yield x, otherwise yield y.
    x, y : array_like, optional
        Values from which to choose. x, y and condition need to be
        broadcastable to some shape. Give both or neither.

    Returns
    -------
    out : Array
        An array with elements from x where condition is True,
        and elements from y elsewhere.

    Raises
    ------
    ValueError
        If exactly one of x and y is given.

    See Also
    --------
    numpy.where

    Examples
    --------
    >>> import dask_array as da
    >>> x = da.arange(10, chunks=5)
    >>> da.where(x < 5, x, 10 * x).compute() # doctest: +NORMALIZE_WHITESPACE
    array([ 0, 1, 2, 3, 4, 50, 60, 70, 80, 90])
    """
    # Imported lazily to avoid circular imports at module load time.
    from dask_array.core import asarray
    from dask_array.core._blockwise_funcs import elemwise

    x_missing = x is None
    y_missing = y is None
    if x_missing != y_missing:
        raise ValueError("either both or neither of x and y should be given")
    if x_missing:
        # Single-argument form: indices of the non-zero entries of condition.
        from dask_array._routines import nonzero

        return nonzero(condition)

    if not np.isscalar(condition):
        # General case: run np.where blockwise over the broadcast inputs.
        return elemwise(np.where, condition, x, y)

    # A scalar condition picks one operand wholesale, so skip elemwise and
    # just broadcast the chosen operand to the common shape/dtype.
    from dask_array._broadcast import broadcast_to
    from dask_array._core_utils import broadcast_shapes
    from dask_array.routines._misc import result_type

    dtype = result_type(x, y)
    x, y = asarray(x), asarray(y)
    shape = broadcast_shapes(x.shape, y.shape)
    chosen = x if condition else y
    return broadcast_to(chosen, shape).astype(dtype)
@@ -0,0 +1,67 @@
1
+ """Slicing operations for dask array expressions."""
2
+
3
+ from dask_array.slicing._basic import (
4
+ ArrayOffsetDep,
5
+ Slice,
6
+ SliceSlicesIntegers,
7
+ TakeUnknownOneChunk,
8
+ normalize_index,
9
+ slice_array,
10
+ slice_slices_and_integers,
11
+ slice_with_int_dask_array,
12
+ slice_with_int_dask_array_on_axis,
13
+ slice_with_newaxes,
14
+ slice_wrap_lists,
15
+ take,
16
+ )
17
+ from dask_array.slicing._bool_index import (
18
+ BooleanIndexFlattened,
19
+ getitem_variadic,
20
+ slice_with_bool_dask_array,
21
+ )
22
+ from dask_array.slicing._setitem import (
23
+ ConcatenateArrayChunks,
24
+ SetItem,
25
+ concatenate_array_chunks_expr,
26
+ setitem_array_expr,
27
+ )
28
+ from dask_array.slicing._squeeze import Squeeze, squeeze
29
+ from dask_array.slicing._vindex import (
30
+ VIndexArray,
31
+ _numpy_vindex,
32
+ _vindex,
33
+ _vindex_array,
34
+ )
35
+
36
# Public surface of ``dask_array.slicing``: every name imported above is
# re-exported, grouped by the submodule that defines it.
# NOTE(review): the underscore-prefixed vindex helpers are listed as well, so
# ``from dask_array.slicing import *`` exposes them; presumably intentional
# for internal callers -- confirm.
__all__ = [
    # Basic slicing
    "ArrayOffsetDep",
    "Slice",
    "SliceSlicesIntegers",
    "TakeUnknownOneChunk",
    "normalize_index",
    "slice_array",
    "slice_slices_and_integers",
    "slice_with_int_dask_array",
    "slice_with_int_dask_array_on_axis",
    "slice_with_newaxes",
    "slice_wrap_lists",
    "take",
    # Boolean indexing
    "BooleanIndexFlattened",
    "getitem_variadic",
    "slice_with_bool_dask_array",
    # Setitem
    "ConcatenateArrayChunks",
    "SetItem",
    "concatenate_array_chunks_expr",
    "setitem_array_expr",
    # Squeeze
    "Squeeze",
    "squeeze",
    # Vindex
    "VIndexArray",
    "_numpy_vindex",
    "_vindex",
    "_vindex_array",
]
@@ -0,0 +1,550 @@
1
+ from __future__ import annotations
2
+
3
+ import functools
4
+ from itertools import product
5
+ from numbers import Integral
6
+
7
+ import numpy as np
8
+ from toolz import pluck
9
+
10
+ from dask._task_spec import Alias, Task, TaskRef
11
+ from dask_array._expr import ArrayExpr
12
+ from dask_array._chunk import getitem
13
+ from dask_array._utils import meta_from_array
14
+ from dask_array.slicing._utils import (
15
+ _slice_1d,
16
+ check_index,
17
+ fuse_slice,
18
+ new_blockdim,
19
+ normalize_slice,
20
+ posify_index,
21
+ replace_ellipsis,
22
+ sanitize_index,
23
+ )
24
+ from dask.layers import ArrayBlockwiseDep
25
+ from dask.utils import cached_cumsum, is_arraylike
26
+
27
+
28
+ def _compute_sliced_chunks(chunks, slc, dim_size):
29
+ """Compute chunk sizes for the sliced region of a dimension."""
30
+ if slc == slice(None):
31
+ return chunks
32
+
33
+ start, stop, step = slc.indices(dim_size)
34
+
35
+ # Handle step == -1 (flip) specially - preserve chunks in reverse order
36
+ if step == -1:
37
+ # Check if this is a full flip (equivalent to slice(None, None, -1))
38
+ if start == dim_size - 1 and stop == -1:
39
+ # Full flip: reverse the chunks
40
+ return chunks[::-1]
41
+ else:
42
+ # Partial negative step: fall back to single chunk
43
+ new_size = len(range(start, stop, step))
44
+ return (new_size,)
45
+
46
+ if step != 1:
47
+ # For non-unit step (other than -1), fall back to single chunk
48
+ new_size = len(range(start, stop, step))
49
+ return (new_size,)
50
+
51
+ # Handle empty slice - return single chunk of size 0
52
+ if start >= stop:
53
+ return (0,)
54
+
55
+ # Find chunks that overlap with [start, stop)
56
+ result = []
57
+ pos = 0
58
+ for chunk_size in chunks:
59
+ chunk_start = pos
60
+ chunk_end = pos + chunk_size
61
+ pos = chunk_end
62
+
63
+ # Skip chunks entirely before the slice
64
+ if chunk_end <= start:
65
+ continue
66
+ # Stop at chunks entirely after the slice
67
+ if chunk_start >= stop:
68
+ break
69
+
70
+ # Compute the portion of this chunk included in the slice
71
+ included_start = max(chunk_start, start)
72
+ included_end = min(chunk_end, stop)
73
+ result.append(included_end - included_start)
74
+
75
+ return tuple(result) if result else (0,)
76
+
77
+
78
def slice_with_int_dask_array(x, index):
    """Slice ``x`` with at most one 1-D dask array of integers.

    This is a helper function of :meth:`Array.__getitem__`.

    Parameters
    ----------
    x : Array
        Array to slice.
    index : tuple
        One entry per dimension of ``x``. Entries that are integer dask
        Arrays (0-d or 1-d) are applied here; all other entries are passed
        through untouched.

    Returns
    -------
    tuple of (sliced x, new index)
        The new index equals the input except that each 1-D integer dask
        array is replaced by ``slice(None)`` (it has already been applied)
        and each 0-D integer dask array is removed (its axis was dropped).

    Raises
    ------
    NotImplementedError
        If more than one fancy (list/tuple/array) entry is present, or an
        integer dask array entry has more than one dimension.
    """
    from dask_array._collection import Array

    assert len(index) == x.ndim

    def _is_fancy(entry):
        if isinstance(entry, (tuple, list)):
            return True
        return isinstance(entry, (np.ndarray, Array)) and entry.ndim > 0

    if sum(map(_is_fancy, index)) > 1:
        raise NotImplementedError("Don't yet support nd fancy indexing")

    remaining = []
    n_dropped = 0
    for position, entry in enumerate(index):
        if not (isinstance(entry, Array) and entry.dtype.kind in "iu"):
            remaining.append(entry)
            continue

        # Axes dropped earlier in this loop shift the target axis left.
        axis = position - n_dropped
        if entry.ndim == 0:
            # Treat a 0-d indexer as length 1, then drop that axis.
            x = slice_with_int_dask_array_on_axis(x, entry[np.newaxis], axis)
            x = x[tuple(0 if dim == axis else slice(None) for dim in range(x.ndim))]
            n_dropped += 1
        elif entry.ndim == 1:
            x = slice_with_int_dask_array_on_axis(x, entry, axis)
            remaining.append(slice(None))
        else:
            raise NotImplementedError(
                "Slicing with dask.array of ints only permitted when the indexer has zero or one dimensions"
            )
    return x, tuple(remaining)
126
+
127
+
128
def normalize_index(idx, shape):
    """Normalize slicing indexes

    Steps, in the order the code performs them:

    1. If the first entry is a non-dask array-like whose shape equals
       ``shape`` (and ``shape`` has more than one dimension), wrap it with
       ``from_array``
    2. Replaces ellipses with many full slices
    3. Adds full slices to end of index
    4. Checks each entry against its dimension (``check_index``)
    5. Sanitizes each entry with ``sanitize_index`` (as the examples show,
       lists come back as numpy arrays)
    6. Normalizes slices to canonical form
    7. Makes negative integers and integer arrays non-negative
       (``posify_index``)

    Parameters
    ----------
    idx : object or tuple
        The raw index; a non-tuple is treated as a 1-tuple.
    shape : tuple of int
        Shape of the array being indexed.

    Returns
    -------
    tuple
        Normalized index. ``None`` (newaxis) entries are kept in place.

    Raises
    ------
    IndexError
        If there are more non-``None`` entries than dimensions. Other
        errors may come from ``check_index``.

    Examples
    --------
    >>> normalize_index(1, (10,))
    (1,)
    >>> normalize_index(-1, (10,))
    (9,)
    >>> normalize_index([-1], (10,))
    (array([9]),)
    >>> normalize_index(slice(-3, 10, 1), (10,))
    (slice(7, None, None),)
    >>> normalize_index((Ellipsis, None), (10,))
    (slice(None, None, None), None)
    >>> normalize_index(np.array([[True, False], [False, True], [True, True]]), (3, 2))
    (dask.array<array, shape=(3, 2), dtype=bool, chunksize=(3, 2), chunktype=numpy.ndarray>,)
    """
    from dask_array._collection import Array, from_array

    if not isinstance(idx, tuple):
        idx = (idx,)

    # if a > 1D numpy.array is provided, cast it to a dask array
    if len(idx) > 0 and len(shape) > 1:
        i = idx[0]
        if is_arraylike(i) and not isinstance(i, Array) and i.shape == shape:
            idx = (from_array(i), *idx[1:])

    idx = replace_ellipsis(len(shape), idx)
    # Count the dimensions the index consumes: an n-d array entry consumes n,
    # None consumes none, anything else consumes one.
    n_sliced_dims = 0
    for i in idx:
        if hasattr(i, "ndim") and i.ndim >= 1:
            n_sliced_dims += i.ndim
        elif i is None:
            continue
        else:
            n_sliced_dims += 1

    # Pad with full slices for the dimensions not mentioned (no-op if negative).
    idx = idx + (slice(None),) * (len(shape) - n_sliced_dims)
    if len([i for i in idx if i is not None]) > len(shape):
        raise IndexError("Too many indices for array")

    # Per index entry: the size of the dimension it applies to, or None for
    # newaxis entries.
    # NOTE(review): a multi-dimensional array entry advances ``i`` by one here
    # although it was counted as ``ndim`` dimensions above -- confirm intended.
    none_shape = []
    i = 0
    for ind in idx:
        if ind is not None:
            none_shape.append(shape[i])
            i += 1
        else:
            none_shape.append(None)

    for axis, (i, d) in enumerate(zip(idx, none_shape)):
        if d is not None:
            check_index(axis, i, d)
    idx = tuple(map(sanitize_index, idx))
    idx = tuple(map(normalize_slice, idx, none_shape))
    idx = posify_index(none_shape, idx)
    return idx
195
+
196
+
197
def slice_with_int_dask_array_on_axis(x, idx, axis):
    """Slice an N-D dask array with a 1-D dask array of ints along one axis.

    This is a helper function of :func:`slice_with_int_dask_array`.

    The work is done in two blockwise steps. First, every chunk of ``x``
    along ``axis`` is paired with every chunk of ``idx``. Then the partial
    results are combined along ``axis``.

    Parameters
    ----------
    x : Array
        Array to slice. Its chunk sizes along ``axis`` must be known.
    idx : Array
        1-D integer indexer.
    axis : int
        Axis of ``x`` to index; ``0 <= axis < x.ndim``.

    Returns
    -------
    Array
        ``x`` with ``axis`` replaced by the block structure of ``idx``.

    Raises
    ------
    NotImplementedError
        If any chunk size of ``x`` along ``axis`` is unknown (NaN).
    """
    from dask_array import _chunk as chunk
    from dask_array._collection import blockwise

    assert 0 <= axis < x.ndim

    if np.isnan(x.chunks[axis]).any():
        raise NotImplementedError("Slicing an array with unknown chunks with a dask.array of ints is not supported")
    x_axes = tuple(range(x.ndim))
    idx_axes = (x.ndim,)  # arbitrary index not already in x_axes
    offset_axes = (axis,)

    # Calculate the offset at which each chunk starts along axis
    # e.g. chunks=(..., (5, 3, 4), ...) -> offset=[0, 5, 8]
    offset = np.roll(np.cumsum(np.asarray(x.chunks[axis], like=x._meta)), 1)
    offset[0] = 0
    # ArrayOffsetDep needs 1D chunks matching x.chunks[axis], not full x.chunks
    offset = ArrayOffsetDep((x.chunks[axis],), offset)

    # p keeps ``axis`` and inserts the idx axis right after it;
    # y replaces ``axis`` with the idx axis.
    p_axes = x_axes[: axis + 1] + idx_axes + x_axes[axis + 1 :]
    y_axes = x_axes[:axis] + idx_axes + x_axes[axis + 1 :]

    # Calculate the cartesian product of every chunk of x vs every chunk of idx
    p = blockwise(
        chunk.slice_with_int_dask_array,
        p_axes,
        x,
        x_axes,
        idx,
        idx_axes,
        offset,
        offset_axes,
        x_size=x.shape[axis],
        axis=axis,
        dtype=x.dtype,
        meta=x._meta,
    )

    # Aggregate on the chunks of x along axis: ``axis`` is in p_axes but not in
    # y_axes, so blockwise contracts it, and concatenate=True hands the
    # aggregate function the p blocks concatenated along that axis.
    y = blockwise(
        chunk.slice_with_int_dask_array_aggregate,
        y_axes,
        idx,
        idx_axes,
        p,
        p_axes,
        concatenate=True,
        x_chunks=x.chunks[axis],
        axis=axis,
        dtype=x.dtype,
        meta=x._meta,
    )
    return y
255
+
256
+
257
class ArrayOffsetDep(ArrayBlockwiseDep):
    """1-D blockwise dependency that returns a stored value for each block.

    ``values`` is looked up by the first component of the block index. In
    this module, ``slice_with_int_dask_array_on_axis`` constructs it with the
    start offset of every chunk along the sliced axis.
    """

    def __init__(self, chunks: tuple[tuple[int, ...], ...], values: np.ndarray | dict):
        super().__init__(chunks)
        self.values = values

    def __getitem__(self, idx: tuple):
        block_number = idx[0]
        return self.values[block_number]
266
+
267
+
268
def slice_array(x, index):
    """Build the expression for ``x[index]`` from an already-normalized index.

    The work is delegated down a chain of helpers:

    * ``slice_with_newaxes`` -- handles ``None``/``np.newaxis`` entries
    * ``slice_wrap_lists`` -- handles fancy indexing with integer arrays
    * ``slice_slices_and_integers`` -- handles plain slices and integers

    Parameters
    ----------
    x : Array
        Array being indexed.
    index : sequence
        Index entries; missing trailing dimensions are padded with full
        slices. The caller's object is never modified.

    Returns
    -------
    ``x.expr`` unchanged when every entry is ``slice(None)``; otherwise the
    expression returned by ``slice_with_newaxes``.
    """
    full = slice(None, None, None)
    if all(isinstance(entry, slice) and entry == full for entry in index):
        # Nothing to do: every dimension is taken in full.
        return x.expr

    # Add in missing colons at the end as needed. x[5] -> x[5, :, :]
    # Build a new tuple instead of ``index += ...`` so that a list passed in by
    # the caller is not extended in place.
    not_none_count = sum(entry is not None for entry in index)
    missing = len(x.chunks) - not_none_count
    index = tuple(index) + (full,) * missing
    return slice_with_newaxes(x, index)
283
+
284
+
285
def slice_with_newaxes(x, index):
    """Apply an index that may contain ``None`` (newaxis) entries.

    ``None`` entries are removed and the rest of the index is handed to
    ``slice_wrap_lists``. If any ``None`` was present, the result is wrapped
    in ``ExpandDims`` at the matching output positions. The getitem
    optimization is only allowed when no new axes are inserted.
    """
    from dask_array.manipulation._expand import ExpandDims

    stripped = tuple(entry for entry in index if entry is not None)

    # Output position of each new axis: its position in ``index`` minus the
    # number of integer entries before it (each integer drops a dimension).
    new_axes = [
        pos - sum(isinstance(entry, Integral) for entry in index[:pos])
        for pos, marker in enumerate(index)
        if marker is None
    ]

    result = slice_wrap_lists(x, stripped, not new_axes)
    return ExpandDims(result, tuple(new_axes)) if new_axes else result
309
+
310
+
311
def slice_wrap_lists(x, index, allow_getitem_optimization):
    """Apply an index that may contain one integer-array ("fancy") entry.

    ``index`` holds one entry per dimension of ``x`` (no ``None``).
    Array-like entries with ``ndim > 0`` are fancy indexes and are applied
    with :func:`take`. Slices and integers are applied with
    :func:`slice_slices_and_integers`.

    Parameters
    ----------
    x : Array
        Array being indexed.
    index : tuple
        One entry per dimension of ``x``.
    allow_getitem_optimization : bool
        Forwarded to ``slice_slices_and_integers`` when no fancy entry is
        present; forced to False in the mixed case.

    Returns
    -------
    The sliced array expression.

    Raises
    ------
    IndexError
        If ``index`` does not have exactly one entry per dimension.
    NotImplementedError
        If more than one fancy entry is present.

    See Also
    --------
    take : handle slicing with lists ("fancy" indexing)
    slice_slices_and_integers : handle slicing with slices and integers
    """
    assert all(isinstance(i, (slice, list, Integral)) or is_arraylike(i) for i in index)
    if not len(x.chunks) == len(index):
        raise IndexError("Too many indices for array")

    # Do we have more than one list in the index?
    # NOTE(review): plain Python lists pass the assert above but are not
    # detected here (only array-likes are); callers presumably convert lists
    # to arrays beforehand -- confirm.
    where_list = [i for i, ind in enumerate(index) if is_arraylike(ind) and ind.ndim > 0]
    if len(where_list) > 1:
        raise NotImplementedError("Don't yet support nd fancy indexing")
    # Is the single list an empty list? In this case just treat it as a zero
    # length slice
    if where_list and not index[where_list[0]].size:
        index = list(index)
        index[where_list.pop()] = slice(0, 0, 1)
        index = tuple(index)

    # No lists, hooray! just use slice_slices_and_integers
    if not where_list:
        return slice_slices_and_integers(x, index, allow_getitem_optimization)

    # Replace all lists with full slices [3, 1, 0] -> slice(None, None, None)
    index_without_list = tuple(slice(None, None, None) if is_arraylike(i) else i for i in index)

    # lists and full slices. Just use take
    if all(is_arraylike(i) or i == slice(None, None, None) for i in index):
        axis = where_list[0]
        x = take(x, index[where_list[0]], axis=axis)
    # Mixed case. Both slices/integers and lists. slice/integer then take
    else:
        x = slice_slices_and_integers(
            x,
            index_without_list,
            allow_getitem_optimization=False,
        )
        axis = where_list[0]
        # Integer entries before the fancy axis were dropped by the slice
        # above, so the fancy axis moved left by that many positions.
        axis2 = axis - sum(1 for i, ind in enumerate(index) if i < axis and isinstance(ind, Integral))
        x = take(x, index[axis], axis=axis2)

    return x
361
+
362
+
363
def slice_slices_and_integers(x, index, allow_getitem_optimization=False):
    """Build a :class:`SliceSlicesIntegers` expression for ``x[index]``.

    ``index`` must contain only slices and integers, one per dimension.
    A dimension whose total size is unknown (NaN) may only be indexed with
    ``slice(None)``.

    Raises
    ------
    ValueError
        If a dimension of unknown size is indexed with anything other than a
        full slice.
    """
    from dask_array._core_utils import unknown_chunk_message

    # Total length of each dimension (NaN when a chunk size is unknown).
    shape = tuple(cached_cumsum(dim_chunks, initial_zero=True)[-1] for dim_chunks in x.chunks)

    everything = slice(None, None, None)
    touches_unknown = any(np.isnan(extent) and ind != everything for extent, ind in zip(shape, index))
    if touches_unknown:
        raise ValueError(f"Arrays chunk sizes are unknown: {shape}{unknown_chunk_message}")
    assert all(isinstance(ind, (slice, Integral)) for ind in index)
    return SliceSlicesIntegers(x, index, allow_getitem_optimization)
373
+
374
+
375
def take(x, index, axis=0):
    """Select the entries ``index`` of ``x`` along ``axis``.

    Parameters
    ----------
    x : Array
        Array to select from.
    index : array_like of int or dask collection
        Positions to take along ``axis``.
    axis : int, default 0
        Axis to take along.

    Returns
    -------
    ``x`` itself when ``index`` is a concrete ``0, 1, ..., n - 1`` covering
    the whole axis. Otherwise a new expression: a blockwise selection for a
    dask-collection index, a shuffle for a concrete index, or
    ``TakeUnknownOneChunk`` when the axis has unknown size but a single chunk.

    Raises
    ------
    ValueError
        If the chunk sizes along ``axis`` are unknown and there is more than
        one chunk there.
    """
    from dask.base import is_dask_collection

    if not np.isnan(x.chunks[axis]).any():
        from dask_array._shuffle import _shuffle
        from dask_array._utils import arange_safe, asarray_safe

        # No-op check only for numpy arrays (dask array comparison triggers warnings)
        # Use is_dask_collection to catch both array-expr and legacy dask Arrays
        if not is_dask_collection(index):
            arange = arange_safe(np.sum(x.chunks[axis]), like=index)
            if len(index) == len(arange) and np.abs(index - arange).sum() == 0:
                return x

        # If index is a dask collection, use lazy blockwise approach
        if is_dask_collection(index):
            return slice_with_int_dask_array_on_axis(x, index, axis)

        index = asarray_safe(index, like=index)

        # Compute indexer by grouping consecutive indices from same input chunk
        from dask_array.slicing._vindex import _compute_indexer

        indexer = _compute_indexer(index, x.chunks[axis])
        return _shuffle(x, indexer, axis, "getitem-")
    elif len(x.chunks[axis]) == 1:
        # Unknown size but a single chunk: each block can apply the whole index.
        return TakeUnknownOneChunk(x, index, axis)
    else:
        from dask_array._core_utils import unknown_chunk_message

        raise ValueError(f"Array chunk size or shape is unknown. {unknown_chunk_message}")
406
+
407
+
408
class Slice(ArrayExpr):
    """Common base for the slicing expressions defined in this module.

    Supplies the task-name prefix ``getitem-`` and a meta object whose
    dimensionality matches ``self.chunks``. Subclasses define ``chunks``
    and ``_layer``.
    """

    @functools.cached_property
    def _name(self):
        token = self.deterministic_token
        return f"getitem-{token}"

    @functools.cached_property
    def _meta(self):
        source_meta = self.array._meta
        ndim = len(self.chunks)
        if source_meta is not None:
            meta = meta_from_array(source_meta, ndim=ndim)
        else:
            # No input meta available: build one from the input dtype.
            meta = meta_from_array(None, ndim=ndim, dtype=self.array.dtype)
        # A scalar meta is promoted to a 0-d numpy array.
        return np.array(meta) if np.isscalar(meta) else meta
422
+
423
+
424
class SliceSlicesIntegers(Slice):
    """Slice an array expression with a tuple of slices and integers.

    Operands (see ``_parameters``):

    array
        Expression being sliced.
    index
        Tuple of ``slice``/integer entries, one per dimension of ``array``.
        Integer entries drop their dimension from the output.
    allow_getitem_optimization
        When true, output blocks whose slices are all ``slice(None)`` become
        aliases of the input block instead of ``getitem`` tasks.
    """

    _parameters = ["array", "index", "allow_getitem_optimization"]

    def _simplify_down(self):
        """Return a cheaper equivalent expression, or None if none applies."""
        # Slice(Slice(x)) -> single Slice with fused indices
        if isinstance(self.array, SliceSlicesIntegers):
            try:
                fused = fuse_slice(self.array.index, self.index)
                # ``fused`` is paired with the innermost array's shape, i.e. it
                # is treated as indexing ``self.array.array`` directly.
                normalized = tuple(
                    normalize_slice(idx, dim) if isinstance(idx, slice) else idx
                    for idx, dim in zip(fused, self.array.array.shape)
                )
                return SliceSlicesIntegers(self.array.array, normalized, self.allow_getitem_optimization)
            except NotImplementedError:
                # Skip fusion for unsupported slicing patterns (e.g., negative step)
                pass

        # Check if the array implements _accept_slice (for operations like Elemwise,
        # Transpose, Blockwise, PartialReduce, ExpandDims that use the simplify_up pattern).
        if hasattr(self.array, "_accept_slice"):
            result = self.array._accept_slice(self)
            if result is not None:
                return result

    def _slice_chunks(self, chunks, start, length):
        """Compute new chunks after slicing."""
        # Returns the sizes of the chunks overlapping [start, start + length),
        # clipped to that window; zero-size overlaps are skipped, and (0,) is
        # returned when nothing overlaps.
        # NOTE(review): no caller of this method is visible in this module --
        # possibly dead code; confirm before relying on or removing it.
        result = []
        cumsum = 0
        for chunk_size in chunks:
            chunk_start = cumsum
            chunk_end = cumsum + chunk_size
            cumsum = chunk_end

            if chunk_end <= start:
                continue
            if chunk_start >= start + length:
                break

            overlap_start = max(start, chunk_start)
            overlap_end = min(start + length, chunk_end)
            overlap_size = overlap_end - overlap_start
            if overlap_size > 0:
                result.append(overlap_size)

        return tuple(result) if result else (0,)

    @functools.cached_property
    def chunks(self):
        # Output chunks: one new_blockdim per sliced dimension; dimensions
        # indexed by an integer are dropped.
        new_blockdims = [
            new_blockdim(d, db, i)
            for d, i, db in zip(self.array.shape, self.index, self.array.chunks)
            if not isinstance(i, Integral)
        ]
        return tuple(map(tuple, new_blockdims))

    def _layer(self) -> dict:
        # Get a list (for each dimension) of dicts{blocknum: slice()}
        block_slices = list(map(_slice_1d, self.array.shape, self.array.chunks, self.index))
        sorted_block_slices = [sorted(i.items()) for i in block_slices]

        # (in_name, 1, 1, 2), (in_name, 1, 1, 4), (in_name, 2, 1, 2), ...
        in_names = list(product([self.array._name], *[pluck(0, s) for s in sorted_block_slices]))

        # (out_name, 0, 0, 0), (out_name, 0, 0, 1), (out_name, 0, 1, 0), ...
        # For a negative step the output block numbers run backwards, so the
        # last selected input block becomes output block 0.
        out_names = list(
            product(
                [self._name],
                *[
                    range(len(d))[::-1] if i.step and i.step < 0 else range(len(d))
                    for d, i in zip(block_slices, self.index)
                    if not isinstance(i, Integral)
                ],
            )
        )

        all_slices = list(product(*[pluck(1, s) for s in sorted_block_slices]))

        # An Alias is emitted only when the optimization is allowed AND the
        # block is taken whole; otherwise a getitem Task is emitted.
        dsk_out = {
            out_name: (
                Task(out_name, getitem, TaskRef(in_name), slices)
                if not self.allow_getitem_optimization or not all(sl == slice(None, None, None) for sl in slices)
                else Alias(out_name, in_name)
            )
            for out_name, in_name, slices in zip(out_names, in_names, all_slices)
        }
        return dsk_out
510
+
511
+
512
+ def _compose_slices(outer_slice, inner_slice, dim_size):
513
+ """Compose two slices: inner_slice is relative to outer_slice's result."""
514
+ # Get the range of the outer slice
515
+ outer_start, outer_stop, outer_step = outer_slice.indices(dim_size)
516
+ outer_len = len(range(outer_start, outer_stop, outer_step))
517
+
518
+ # Get the range of the inner slice relative to outer's result
519
+ inner_start, inner_stop, inner_step = inner_slice.indices(outer_len)
520
+
521
+ # Compose: offset inner by outer_start
522
+ if outer_step != 1 or inner_step != 1:
523
+ new_start = outer_start + inner_start * outer_step
524
+ new_stop = outer_start + inner_stop * outer_step
525
+ new_step = outer_step * inner_step
526
+ else:
527
+ new_start = outer_start + inner_start
528
+ new_stop = outer_start + inner_stop
529
+ new_step = 1
530
+
531
+ return slice(new_start, new_stop, new_step if new_step != 1 else None)
532
+
533
+
534
class TakeUnknownOneChunk(Slice):
    """Take ``index`` along ``axis`` when that axis is a single chunk.

    :func:`take` builds this when the chunk sizes along ``axis`` are unknown
    but there is exactly one chunk there, so each block can apply the full
    index list locally with ``getitem``.
    """

    _parameters = ["array", "index", "axis"]

    @functools.cached_property
    def chunks(self):
        # Chunks pass through unchanged. In the case take() uses this for,
        # the size along ``axis`` is unknown and therefore stays unknown.
        return self.array.chunks

    def _layer(self) -> dict:
        # Same selection for every block: full slices, except the index list on ``axis``.
        slices = [slice(None)] * len(self.array.chunks)
        slices[self.axis] = list(self.index)
        sl = tuple(slices)
        # Key input blocks by ``_name``, like the other layers in this module.
        in_name = self.array._name
        block_ids = product(*(range(len(c)) for c in self.array.chunks))
        return {
            (self._name,) + ct: Task((self._name,) + ct, getitem, TaskRef((in_name,) + ct), sl)
            for ct in block_ids
        }