dask-array 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (144) hide show
  1. dask_array/__init__.py +228 -0
  2. dask_array/_backends.py +76 -0
  3. dask_array/_backends_array.py +99 -0
  4. dask_array/_blockwise.py +1410 -0
  5. dask_array/_broadcast.py +272 -0
  6. dask_array/_chunk.py +445 -0
  7. dask_array/_chunk_types.py +54 -0
  8. dask_array/_collection.py +1644 -0
  9. dask_array/_concatenate.py +331 -0
  10. dask_array/_core_utils.py +1365 -0
  11. dask_array/_dispatch.py +141 -0
  12. dask_array/_einsum.py +277 -0
  13. dask_array/_expr.py +544 -0
  14. dask_array/_expr_flow.py +586 -0
  15. dask_array/_gufunc.py +805 -0
  16. dask_array/_histogram.py +617 -0
  17. dask_array/_map_blocks.py +652 -0
  18. dask_array/_new_collection.py +10 -0
  19. dask_array/_numpy_compat.py +135 -0
  20. dask_array/_overlap.py +1159 -0
  21. dask_array/_rechunk.py +1050 -0
  22. dask_array/_reshape.py +710 -0
  23. dask_array/_routines.py +102 -0
  24. dask_array/_shuffle.py +448 -0
  25. dask_array/_stack.py +264 -0
  26. dask_array/_svg.py +291 -0
  27. dask_array/_templates.py +29 -0
  28. dask_array/_test_utils.py +257 -0
  29. dask_array/_ufunc.py +385 -0
  30. dask_array/_utils.py +349 -0
  31. dask_array/_visualize.py +223 -0
  32. dask_array/_xarray.py +337 -0
  33. dask_array/core/__init__.py +34 -0
  34. dask_array/core/_blockwise_funcs.py +312 -0
  35. dask_array/core/_conversion.py +422 -0
  36. dask_array/core/_from_graph.py +97 -0
  37. dask_array/creation/__init__.py +71 -0
  38. dask_array/creation/_arange.py +121 -0
  39. dask_array/creation/_diag.py +116 -0
  40. dask_array/creation/_diagonal.py +241 -0
  41. dask_array/creation/_eye.py +103 -0
  42. dask_array/creation/_linspace.py +102 -0
  43. dask_array/creation/_mesh.py +134 -0
  44. dask_array/creation/_ones_zeros.py +454 -0
  45. dask_array/creation/_pad.py +270 -0
  46. dask_array/creation/_repeat.py +55 -0
  47. dask_array/creation/_tile.py +36 -0
  48. dask_array/creation/_tri.py +28 -0
  49. dask_array/creation/_utils.py +296 -0
  50. dask_array/fft.py +320 -0
  51. dask_array/io/__init__.py +39 -0
  52. dask_array/io/_base.py +10 -0
  53. dask_array/io/_from_array.py +257 -0
  54. dask_array/io/_from_delayed.py +95 -0
  55. dask_array/io/_from_graph.py +54 -0
  56. dask_array/io/_from_npy_stack.py +67 -0
  57. dask_array/io/_store.py +336 -0
  58. dask_array/io/_tiledb.py +159 -0
  59. dask_array/io/_to_npy_stack.py +65 -0
  60. dask_array/io/_zarr.py +449 -0
  61. dask_array/linalg/__init__.py +39 -0
  62. dask_array/linalg/_cholesky.py +234 -0
  63. dask_array/linalg/_lu.py +300 -0
  64. dask_array/linalg/_norm.py +94 -0
  65. dask_array/linalg/_qr.py +601 -0
  66. dask_array/linalg/_solve.py +349 -0
  67. dask_array/linalg/_svd.py +394 -0
  68. dask_array/linalg/_tensordot.py +334 -0
  69. dask_array/linalg/_utils.py +74 -0
  70. dask_array/manipulation/__init__.py +45 -0
  71. dask_array/manipulation/_expand.py +321 -0
  72. dask_array/manipulation/_flip.py +92 -0
  73. dask_array/manipulation/_roll.py +78 -0
  74. dask_array/manipulation/_transpose.py +309 -0
  75. dask_array/random/__init__.py +125 -0
  76. dask_array/random/_choice.py +181 -0
  77. dask_array/random/_expr.py +256 -0
  78. dask_array/random/_generator.py +441 -0
  79. dask_array/random/_random_state.py +259 -0
  80. dask_array/random/_utils.py +84 -0
  81. dask_array/reductions/__init__.py +84 -0
  82. dask_array/reductions/_arg_reduction.py +130 -0
  83. dask_array/reductions/_common.py +1082 -0
  84. dask_array/reductions/_cumulative.py +522 -0
  85. dask_array/reductions/_percentile.py +261 -0
  86. dask_array/reductions/_reduction.py +725 -0
  87. dask_array/reductions/_trace.py +56 -0
  88. dask_array/routines/__init__.py +133 -0
  89. dask_array/routines/_apply.py +84 -0
  90. dask_array/routines/_bincount.py +112 -0
  91. dask_array/routines/_broadcast.py +111 -0
  92. dask_array/routines/_coarsen.py +115 -0
  93. dask_array/routines/_diff.py +79 -0
  94. dask_array/routines/_gradient.py +158 -0
  95. dask_array/routines/_indexing.py +65 -0
  96. dask_array/routines/_insert_delete.py +132 -0
  97. dask_array/routines/_misc.py +122 -0
  98. dask_array/routines/_nonzero.py +72 -0
  99. dask_array/routines/_search.py +123 -0
  100. dask_array/routines/_select.py +113 -0
  101. dask_array/routines/_statistics.py +171 -0
  102. dask_array/routines/_topk.py +82 -0
  103. dask_array/routines/_triangular.py +74 -0
  104. dask_array/routines/_unique.py +232 -0
  105. dask_array/routines/_where.py +62 -0
  106. dask_array/slicing/__init__.py +67 -0
  107. dask_array/slicing/_basic.py +550 -0
  108. dask_array/slicing/_blocks.py +138 -0
  109. dask_array/slicing/_bool_index.py +145 -0
  110. dask_array/slicing/_setitem.py +329 -0
  111. dask_array/slicing/_squeeze.py +101 -0
  112. dask_array/slicing/_utils.py +1133 -0
  113. dask_array/slicing/_vindex.py +282 -0
  114. dask_array/stacking/__init__.py +15 -0
  115. dask_array/stacking/_block.py +83 -0
  116. dask_array/stacking/_simple.py +58 -0
  117. dask_array/templates/array.html.j2 +48 -0
  118. dask_array/tests/__init__.py +0 -0
  119. dask_array/tests/conftest.py +22 -0
  120. dask_array/tests/test_api.py +40 -0
  121. dask_array/tests/test_binary_op_chunks.py +107 -0
  122. dask_array/tests/test_coarse_slice_through_blockwise.py +362 -0
  123. dask_array/tests/test_collection.py +799 -0
  124. dask_array/tests/test_creation.py +1102 -0
  125. dask_array/tests/test_expr_flow.py +143 -0
  126. dask_array/tests/test_linalg.py +1130 -0
  127. dask_array/tests/test_map_blocks_multi_output.py +104 -0
  128. dask_array/tests/test_rechunk_pushdown.py +214 -0
  129. dask_array/tests/test_reductions.py +1091 -0
  130. dask_array/tests/test_routines.py +2853 -0
  131. dask_array/tests/test_shuffle_chunks.py +67 -0
  132. dask_array/tests/test_slice_pushdown.py +968 -0
  133. dask_array/tests/test_slice_through_blockwise.py +678 -0
  134. dask_array/tests/test_slice_through_overlap.py +366 -0
  135. dask_array/tests/test_slice_through_reshape.py +272 -0
  136. dask_array/tests/test_slicing.py +839 -0
  137. dask_array/tests/test_transpose_slice_pushdown.py +208 -0
  138. dask_array/tests/test_visualize.py +94 -0
  139. dask_array/tests/test_xarray.py +193 -0
  140. dask_array-0.1.0.dist-info/METADATA +48 -0
  141. dask_array-0.1.0.dist-info/RECORD +144 -0
  142. dask_array-0.1.0.dist-info/WHEEL +4 -0
  143. dask_array-0.1.0.dist-info/entry_points.txt +2 -0
  144. dask_array-0.1.0.dist-info/licenses/LICENSE +29 -0
@@ -0,0 +1,138 @@
1
+ """Block indexing expression for x.blocks[...] access."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import functools
6
+ import math
7
+ from itertools import product
8
+ from numbers import Number
9
+
10
+ import numpy as np
11
+
12
+ from dask_array._new_collection import new_collection
13
+ from dask._task_spec import Alias
14
+ from dask_array._expr import ArrayExpr
15
+ from dask_array.slicing._utils import normalize_index
16
+
17
+
18
class BlockView:
    """Array-like accessor over the blocks (chunks) of a dask array.

    Indexing a ``BlockView`` with NumPy-style indices selects whole blocks
    and returns them as a new dask array. The view behaves like an array
    whose shape is the number of blocks along each axis (available as
    ``array.blocks.shape``).

    The result always has the same number of dimensions as the wrapped
    array, even when integer indices are used. Indexing with ``np.newaxis``
    or with more than one list is not supported.
    """

    __slots__ = ("_array",)

    def __init__(self, array):
        self._array = array

    def __getitem__(self, index):
        # Build the block-selection expression, then wrap it as a collection.
        selection = blocks_getitem(self._array.expr, index)
        return new_collection(selection)

    def __eq__(self, other):
        # Accept any class called "BlockView" that carries an ``_array``
        # attribute (this also covers the legacy implementation).
        if not hasattr(other, "_array") or type(other).__name__ != "BlockView":
            return NotImplemented
        return self._array is other._array

    @property
    def size(self):
        """The total number of blocks in the array."""
        n_blocks_per_axis = self.shape
        return math.prod(n_blocks_per_axis)

    @property
    def shape(self):
        """The number of blocks per axis (the wrapped array's ``numblocks``)."""
        return self._array.numblocks

    def ravel(self):
        """Return every block as a list of arrays, ordered in C order."""
        flat = []
        for block_id in np.ndindex(self.shape):
            flat.append(self[block_id])
        return flat
59
+
60
+
61
class Blocks(ArrayExpr):
    """Expression node backing ``x.blocks[...]``.

    Selects whole blocks of the source array by block position instead of
    by element position. The stored index is expected to be already
    normalized so that integers have been turned into length-1 slices,
    which keeps the number of dimensions unchanged.

    Parameters
    ----------
    array : ArrayExpr
        The source array expression
    index : tuple
        Normalized block indices, one entry per axis
    """

    _parameters = ["array", "index"]

    @functools.cached_property
    def _name(self):
        return "blocks-{}".format(self.deterministic_token)

    @functools.cached_property
    def _meta(self):
        # Selecting blocks does not change the array type or dtype.
        return self.array._meta

    @functools.cached_property
    def chunks(self):
        """Chunk sizes obtained by indexing each axis of the source chunks."""
        per_axis = []
        for axis_chunks, selector in zip(self.array.chunks, self.index):
            picked = np.array(axis_chunks)[selector]
            per_axis.append(tuple(picked.tolist()))
        return tuple(per_axis)

    def _layer(self) -> dict:
        """Build the graph layer: one alias per selected block.

        Every output key simply points at the input block it was selected
        from; no data is copied or transformed.
        """
        # For each axis, position ``p`` of this array maps to input block
        # ``source_positions[axis][p]``.
        source_positions = [
            np.arange(n_blocks)[selector]
            for n_blocks, selector in zip(self.array.numblocks, self.index)
        ]
        out_ranges = [range(len(axis_chunks)) for axis_chunks in self.chunks]

        layer = {}
        for out_pos in product(*out_ranges):
            in_pos = tuple(
                int(mapping[p]) for mapping, p in zip(source_positions, out_pos)
            )
            target = (self._name,) + out_pos
            source = (self.array._name,) + in_pos
            layer[target] = Alias(target, source)
        return layer
107
+
108
+
109
def blocks_getitem(array, index):
    """Build the ``Blocks`` expression for a block-level index.

    Parameters
    ----------
    array : ArrayExpr
        The source array expression
    index : tuple
        The block index (may contain integers, slices, or lists); a
        non-tuple index is treated as a one-element tuple

    Returns
    -------
    Blocks
        The blocks expression

    Raises
    ------
    ValueError
        If more than one list/ndarray indexer is given, or if the index
        contains ``None`` (``np.newaxis``).
    """
    index = index if isinstance(index, tuple) else (index,)

    n_fancy = 0
    for ind in index:
        if isinstance(ind, (np.ndarray, list)):
            n_fancy += 1
    if n_fancy > 1:
        raise ValueError("Can only slice with a single list")

    for ind in index:
        if ind is None:
            raise ValueError("Slicing with np.newaxis or None is not supported")

    # Normalize against the block grid rather than the element shape.
    normalized = normalize_index(index, array.numblocks)

    # An integer would drop its axis; a length-1 slice keeps it.
    keep_dims = []
    for k in normalized:
        if isinstance(k, Number):
            keep_dims.append(slice(k, k + 1))
        else:
            keep_dims.append(k)

    return Blocks(array, tuple(keep_dims))
@@ -0,0 +1,145 @@
1
+ from __future__ import annotations
2
+
3
+ import functools
4
+ import warnings
5
+ from operator import getitem
6
+
7
+ import numpy as np
8
+
9
+ from dask._task_spec import Alias
10
+ from dask_array._expr import ArrayExpr
11
+ from dask_array._utils import meta_from_array
12
+
13
+
14
def getitem_variadic(x, *index):
    """Index ``x`` with the positional arguments packed into one tuple.

    Used as the per-block function for boolean indexing, where each indexer
    arrives as a separate argument.
    """
    return getitem(x, index)
17
+
18
+
19
def slice_with_bool_dask_array(x, index):
    """Apply one or more boolean dask-array indexers to ``x``.

    This is a helper function of :meth:`Array.__getitem__`.

    Parameters
    ----------
    x: Array
    index: tuple with as many elements as x.ndim, among which there are
           one or more Array's with dtype=bool

    Returns
    -------
    tuple of (sliced x, new index)

    where the new index is the input index with every boolean Array
    replaced by ``slice(None)``, since that filter has already been applied.
    It is a list when a full-dimensional mask was used and a tuple otherwise.

    Raises
    ------
    NotImplementedError
        If a boolean Array indexer is neither 1-d nor a single mask with the
        same number of dimensions as ``x``.

    Note: The sliced x will have nan chunks on the sliced axes.
    """
    from dask_array._collection import (
        Array,
        blockwise,
        elemwise,
        new_collection,
    )
    from dask_array._expr import ChunksOverride

    def _is_bool_array(candidate):
        return isinstance(candidate, Array) and candidate.dtype == bool

    out_index = [slice(None) if _is_bool_array(ind) else ind for ind in index]

    # Case 1: a single mask covering every dimension of x
    if len(index) == 1 and index[0].ndim == x.ndim:
        shapes_unknown = np.isnan(x.shape).any() or np.isnan(index[0].shape).any()
        if not shapes_unknown:
            # With known shapes both operands can be flattened first, which
            # keeps the result in NumPy's element order.
            x = x.ravel()
            index = tuple(mask.ravel() for mask in index)
        elif x.ndim > 1:
            warnings.warn(
                "When slicing a Dask array of unknown chunks with a boolean mask "
                "Dask array, the output array may have a different ordering "
                "compared to the equivalent NumPy operation. This will raise an "
                "error in a future release of Dask.",
                stacklevel=3,
            )
        # Apply the mask block by block
        masked = elemwise(getitem, x, index[0], dtype=x.dtype)
        # Touch ``chunks`` now so incompatible chunking between x and the
        # mask raises here, matching legacy behavior.
        _ = masked.chunks
        flattened = BooleanIndexFlattened(masked.expr)
        return new_collection(flattened), out_index

    # Case 2: 1-d boolean arrays applied along individual dimensions
    for ind in index:
        if _is_bool_array(ind) and ind.ndim != 1:
            raise NotImplementedError(
                "Slicing with dask.array of bools only permitted when "
                "the indexer has only one dimension or when "
                "it has the same dimension as the sliced "
                "array"
            )

    # Build the flat (argument, axes) sequence expected by blockwise and
    # remember which dimensions are filtered by a boolean array.
    bool_dims = []
    flat_args = []
    next_axis = 0
    for dim, ind in enumerate(index):
        if _is_bool_array(ind):
            bool_dims.append(dim)
            flat_args.append(ind)
            flat_args.append(tuple(range(next_axis, next_axis + ind.ndim)))
            next_axis += x.ndim
        else:
            flat_args.append(slice(None))
            flat_args.append(None)
            next_axis += 1

    x_axes = tuple(range(x.ndim))
    out = blockwise(
        getitem_variadic,
        x_axes,
        x,
        x_axes,
        *flat_args,
        dtype=x.dtype,
    )

    # The number of selected elements is unknown until compute time, so
    # filtered dimensions get nan chunk sizes (one per existing chunk).
    new_chunks = tuple(
        (np.nan,) * len(axis_chunks) if dim in bool_dims else axis_chunks
        for dim, axis_chunks in enumerate(out.chunks)
    )
    overridden = ChunksOverride(out.expr, new_chunks)
    return new_collection(overridden), tuple(out_index)
116
+
117
+
118
class BooleanIndexFlattened(ArrayExpr):
    """Present the blocks of a fully masked array as a single 1-d array.

    Each block of the input becomes one chunk of unknown (nan) length in the
    output; block order follows the flattened key order of the input.
    """

    _parameters = ["array"]

    @functools.cached_property
    def _name(self):
        return "getitem-{}".format(self.deterministic_token)

    @functools.cached_property
    def _meta(self):
        return meta_from_array(self.array._meta, ndim=1)

    @functools.cached_property
    def chunks(self):
        # One output chunk per input block, i.e. the product of numblocks.
        n_blocks = 1
        for count in self.array.numblocks:
            n_blocks = n_blocks * count
        return ((np.nan,) * n_blocks,)

    def _layer(self) -> dict:
        from dask.base import flatten

        # Re-key every input block under a flat, 1-d position.
        layer = {}
        for position, source_key in enumerate(flatten(self.array.__dask_keys__())):
            target = (self._name, position)
            layer[target] = Alias(target, source_key)
        return layer
@@ -0,0 +1,329 @@
1
+ from __future__ import annotations
2
+
3
+ import functools
4
+ import math
5
+ from itertools import product
6
+
7
+ import numpy as np
8
+
9
+ from dask._task_spec import Alias, List, Task, TaskRef
10
+ from dask_array._expr import ArrayExpr
11
+ from dask_array._utils import meta_from_array
12
+ from dask_array._core_utils import concatenate3 as concatenate_shaped
13
+ from dask_array.slicing._utils import parse_assignment_indices, setitem
14
+ from dask.base import is_dask_collection
15
+ from dask.core import flatten
16
+ from dask.utils import cached_cumsum
17
+
18
+
19
def setitem_array_expr(out_name, array, indices, value):
    """Array-expr version of setitem_array that generates Task objects directly.

    This function creates a new dask graph that assigns values to each block
    that is touched by the indices, leaving other blocks unchanged.

    Parameters
    ----------
    out_name : str
        Name used as the first element of every output block key.
    array : Array
        The array being assigned to. Its ``shape``, ``chunks`` and
        ``__dask_keys__()`` are read.
    indices
        The assignment indices; they are reformatted by
        ``parse_assignment_indices`` before use.
    value : Array
        The assignment value. It is sliced per block and must expose
        ``shape``.
        NOTE(review): the code calls ``concatenate_array_chunks_expr`` on
        slices of ``value``, so it presumably must be a dask collection --
        confirm against the caller.

    Returns
    -------
    dict
        Graph mapping each output block key to either an ``Alias`` of the
        input block (block untouched by the indices) or a ``Task`` calling
        ``setitem``; the graphs of the value parts and of any dask-array
        indices are merged in as well.

    Raises
    ------
    ValueError
        If the shape of ``value`` cannot be broadcast to the shape implied
        by ``indices``.
    """
    array_shape = array.shape
    value_shape = value.shape
    value_ndim = len(value_shape)

    # Reformat input indices
    indices, implied_shape, reverse, implied_shape_positions = parse_assignment_indices(indices, array_shape)

    # Empty slices can only be assigned size 1 values
    if 0 in implied_shape and value_shape and max(value_shape) > 1:
        raise ValueError(
            f"shape mismatch: value array of shape {value_shape} "
            "could not be broadcast to indexing result "
            f"of shape {tuple(implied_shape)}"
        )

    # Set variables needed when creating the part of the assignment value
    # that applies to each block. ``offset`` is the number of leading implied
    # dimensions that have no counterpart in ``value``.
    offset = len(implied_shape) - value_ndim
    if offset >= 0:
        # The indexed region has at least as many dimensions as the value
        array_common_shape = implied_shape[offset:]
        value_common_shape = value_shape
        value_offset = 0
        reverse = [i - offset for i in reverse if i >= offset]
    else:
        # The value has more dimensions than the indexed region; the extra
        # leading dimensions must all have size 1.
        value_offset = -offset
        array_common_shape = implied_shape
        value_common_shape = value_shape[value_offset:]
        offset = 0
        if value_shape[:value_offset] != (1,) * value_offset:
            raise ValueError(
                f"could not broadcast input array from shape{value_shape} into shape {tuple(implied_shape)}"
            )

    # Per common dimension: ``slice(None)`` where the value is broadcast
    # (size 1), or a ``None`` placeholder that is filled in per block below.
    base_value_indices = []
    # Positions (in the common dimensions) whose placeholder must be filled
    non_broadcast_dimensions = []

    for i, (a, b, j) in enumerate(zip(array_common_shape, value_common_shape, implied_shape_positions)):
        index = indices[j]
        if is_dask_collection(index) and index.dtype == bool:
            # Boolean dask index: the selected size is not known here, so
            # only reject values that are larger than the index itself.
            if math.isnan(b) or b <= index.size:
                base_value_indices.append(None)
                non_broadcast_dimensions.append(i)
            else:
                raise ValueError(
                    f"shape mismatch: value array dimension size of {b} is "
                    "greater then corresponding boolean index size of "
                    f"{index.size}"
                )
            continue

        if b == 1:
            # Size-1 value dimension is broadcast across the whole index
            base_value_indices.append(slice(None))
        elif a == b:
            base_value_indices.append(None)
            non_broadcast_dimensions.append(i)
        elif math.isnan(a):
            # Implied size unknown: cannot be checked here
            base_value_indices.append(None)
            non_broadcast_dimensions.append(i)
        else:
            raise ValueError(
                f"shape mismatch: value array of shape {value_shape} "
                "could not be broadcast to indexing result of shape "
                f"{tuple(implied_shape)}"
            )

    # Translate chunks tuple to array locations: for every block, a tuple of
    # (start, stop) element positions per dimension, in product order.
    chunks = array.chunks
    cumdims = [cached_cumsum(bds, initial_zero=True) for bds in chunks]
    array_locations = [[(s, s + dim) for s, dim in zip(starts, shapes)] for starts, shapes in zip(cumdims, chunks)]
    array_locations = product(*array_locations)

    # Flattened input keys are paired with ``array_locations`` by position
    in_keys = list(flatten(array.__dask_keys__()))

    # Build graph with Task objects
    dsk = {}
    out_name_tuple = (out_name,)

    # Helper closures for index handling (simplified from legacy)
    def block_index_from_1d_index(index, loc0, loc1, is_bool):
        """Return the part of a 1-d ``index`` that falls in block [loc0, loc1),
        expressed relative to the start of the block."""
        if is_bool:
            return index[loc0:loc1]
        elif is_dask_collection(index):
            # Out-of-block entries are mapped to ``loc1``, i.e. to the
            # block-relative position ``loc1 - loc0``, one past the block end.
            i = np.where((loc0 <= index) & (index < loc1), index, loc1)
            return i - loc0
        else:
            i = np.where((loc0 <= index) & (index < loc1))[0]
            return index[i] - loc0

    def block_index_shape_from_1d_bool_index(index, loc0, loc1):
        """Number of True entries of ``index`` inside block [loc0, loc1)."""
        return np.sum(index[loc0:loc1])

    def n_preceding_from_1d_bool_index(index, loc0):
        """Number of True entries of ``index`` before position ``loc0``."""
        return np.sum(index[:loc0])

    def value_indices_from_1d_int_index(index, vsize, loc0, loc1):
        """Return which value elements are assigned into block [loc0, loc1)
        for a 1-d integer ``index``."""
        if is_dask_collection(index):
            if np.isnan(index.size):
                # Unknown index size: use a boolean mask instead of positions
                i = np.where((loc0 <= index) & (index < loc1), True, False)
                i = concatenate_array_chunks_expr(i)
                # NOTE(review): assigns the private ``_chunks`` attribute on
                # the returned collection; whether that takes effect for
                # expression-backed arrays is not visible here -- confirm.
                i._chunks = ((vsize,),)
            else:
                i = np.where((loc0 <= index) & (index < loc1))[0]
                i = concatenate_array_chunks_expr(i)
        else:
            i = np.where((loc0 <= index) & (index < loc1))[0]
        return i

    for in_key, locations in zip(in_keys, array_locations):
        # Per-dimension index into this block, plus (for non-integer
        # indices) the selected size and the number of selected elements
        # that precede this block.
        block_indices = []
        block_indices_shape = []
        block_preceding_sizes = []
        # Assume the block overlaps the indices until shown otherwise
        overlaps = True
        # Dimension, if any, that is indexed by a 1-d integer array
        dim_1d_int_index = None

        for dim, (index, (loc0, loc1)) in enumerate(zip(indices, locations)):
            integer_index = isinstance(index, int)
            if isinstance(index, slice):
                # Clip the slice stop to this block
                stop = loc1 - loc0
                if index.stop < loc1:
                    stop -= loc1 - index.stop
                start = index.start - loc0
                if start < 0:
                    # Slice started in an earlier block: find the first
                    # position inside this block that is on the step grid
                    start %= index.step
                if start >= stop:
                    # This block does not overlap the slice
                    overlaps = False
                    break
                step = index.step
                block_index = slice(start, stop, step)
                # ceil((stop - start) / step): elements selected in this block
                block_index_size, rem = divmod(stop - start, step)
                if rem:
                    block_index_size += 1
                # Elements selected by the slice before this block starts
                pre = index.indices(loc0)
                n_preceding, rem = divmod(pre[1] - pre[0], step)
                if rem:
                    n_preceding += 1
            elif integer_index:
                if not loc0 <= index < loc1:
                    # This block does not contain the integer index
                    overlaps = False
                    break
                block_index = index - loc0
            else:
                # 1-d array index (anything that is neither slice nor int)
                is_bool = index.dtype == bool
                block_index = block_index_from_1d_index(index, loc0, loc1, is_bool)
                if is_bool:
                    block_index_size = block_index_shape_from_1d_bool_index(index, loc0, loc1)
                    n_preceding = n_preceding_from_1d_bool_index(index, loc0)
                else:
                    block_index_size = None
                    n_preceding = None
                    dim_1d_int_index = dim
                    # Only defined on this path; read below when
                    # ``j == dim_1d_int_index``.
                    loc0_loc1 = loc0, loc1

                # For a dask index the overlap cannot be decided here, so the
                # block is treated as overlapping.
                if not is_dask_collection(index) and not block_index.size:
                    overlaps = False
                    break

            block_indices.append(block_index)
            if not integer_index:
                # Integer indices contribute no entry to these two lists
                block_indices_shape.append(block_index_size)
                block_preceding_sizes.append(n_preceding)

        out_key = out_name_tuple + in_key[1:]

        if not overlaps:
            # Untouched block: pass the input block through unchanged
            dsk[out_key] = Alias(out_key, in_key)
            continue

        # Build value indices for this block
        value_indices = base_value_indices[:]
        for i in non_broadcast_dimensions:
            j = i + offset
            if j == dim_1d_int_index:
                value_indices[i] = value_indices_from_1d_int_index(
                    indices[j], value_shape[i + value_offset], *loc0_loc1
                )
            else:
                # Slice or 1-d boolean index: a contiguous run of the value
                start = block_preceding_sizes[j]
                value_indices[i] = slice(start, start + block_indices_shape[j])

        # For the dimensions listed in ``reverse``, mirror the value slice
        # so that it runs backwards (step -1).
        for i in reverse:
            size = value_common_shape[i]
            start, stop, step = value_indices[i].indices(size)
            size -= 1
            start = size - start
            stop = size - stop
            if stop < 0:
                # A negative stop would wrap around; None means "to the start"
                stop = None
            value_indices[i] = slice(start, stop, -1)

        if value_ndim > len(indices):
            # Value has extra leading dimensions: keep them via Ellipsis
            value_indices.insert(0, Ellipsis)

        # Get the value slice and concatenate to single chunk
        v = value[tuple(value_indices)]
        v = concatenate_array_chunks_expr(v)
        v_key = next(flatten(v.__dask_keys__()))

        # Merge value's graph into dsk
        dsk.update(dict(v.__dask_graph__()))

        # Convert block_indices to use TaskRef for any dask keys
        task_block_indices = []
        for idx in block_indices:
            if is_dask_collection(idx):
                idx = concatenate_array_chunks_expr(idx)
                idx_key = next(flatten(idx.__dask_keys__()))
                dsk.update(dict(idx.__dask_graph__()))
                task_block_indices.append(TaskRef(idx_key))
            else:
                task_block_indices.append(idx)

        # Create Task with proper TaskRef wrappers
        dsk[out_key] = Task(
            out_key,
            setitem,
            TaskRef(in_key),
            TaskRef(v_key),
            List(*task_block_indices),
        )

    return dsk
245
+
246
+
247
class SetItem(ArrayExpr):
    """Expression node for assigning ``value`` into ``array`` at ``index``.

    The result has the same chunk structure as the source array.
    """

    _parameters = ["array", "index", "value"]

    @functools.cached_property
    def _name(self):
        return "setitem-{}".format(self.deterministic_token)

    @functools.cached_property
    def _meta(self):
        source = self.array
        meta = meta_from_array(source._meta, ndim=source.ndim)
        # Promote a scalar meta to an array
        return np.array(meta) if np.isscalar(meta) else meta

    @property
    def chunks(self):
        # Assignment never changes the chunking of the target.
        source = self.array
        return source.chunks

    def _layer(self) -> dict:
        from dask_array._collection import Array

        # setitem_array_expr works on collections, so wrap the expressions.
        # A value without ``_meta`` is passed through as-is.
        target = Array(self.array)
        if hasattr(self.value, "_meta"):
            value = Array(self.value)
        else:
            value = self.value

        return setitem_array_expr(self._name, target, self.index, value)
275
+
276
+
277
class ConcatenateArrayChunks(ArrayExpr):
    """Merge every chunk of an array into one chunk.

    This is an array-expr version of dask.array.slicing.concatenate_array_chunks.
    """

    _parameters = ["array"]

    @functools.cached_property
    def _name(self):
        return "concatenate-shaped-{}".format(self.deterministic_token)

    @functools.cached_property
    def _meta(self):
        source = self.array
        return meta_from_array(source._meta, ndim=source.ndim)

    @functools.cached_property
    def chunks(self):
        # One chunk per axis, spanning the full extent of that axis.
        shape = self.array.shape
        if shape:
            return tuple((extent,) for extent in shape)
        # NOTE(review): for a 0-d input this reports 1-d chunks, while
        # ``_layer`` emits a key with no block indices -- confirm whether
        # this branch is ever reached and which form is intended.
        return ((1,),)

    def _layer(self) -> dict:
        from dask.base import flatten

        # Reference every input block, in flattened key order
        block_refs = []
        for key in flatten(self.array.__dask_keys__()):
            block_refs.append(TaskRef(key))

        # Single output block: all block indices are 0
        n_dims = self.array.ndim
        out_key = (self._name,) + tuple(0 for _ in range(n_dims))

        # NOTE(review): ``concatenate_shaped`` is ``concatenate3`` imported
        # under an alias and is called here with the flat block list plus
        # numblocks -- verify that signature in ``_core_utils``.
        task = Task(
            out_key,
            concatenate_shaped,
            List(*block_refs),
            self.array.numblocks,
        )
        return {out_key: task}
317
+
318
+
319
def concatenate_array_chunks_expr(x):
    """Return ``x`` as a single-chunk array.

    Array-expr version of dask.array.slicing.concatenate_array_chunks.
    An array that already has exactly one partition is returned unchanged.
    """
    from dask_array._new_collection import new_collection

    needs_merge = x.npartitions != 1
    if needs_merge:
        x = new_collection(ConcatenateArrayChunks(x.expr))
    return x