dask-array 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dask_array/__init__.py +228 -0
- dask_array/_backends.py +76 -0
- dask_array/_backends_array.py +99 -0
- dask_array/_blockwise.py +1410 -0
- dask_array/_broadcast.py +272 -0
- dask_array/_chunk.py +445 -0
- dask_array/_chunk_types.py +54 -0
- dask_array/_collection.py +1644 -0
- dask_array/_concatenate.py +331 -0
- dask_array/_core_utils.py +1365 -0
- dask_array/_dispatch.py +141 -0
- dask_array/_einsum.py +277 -0
- dask_array/_expr.py +544 -0
- dask_array/_expr_flow.py +586 -0
- dask_array/_gufunc.py +805 -0
- dask_array/_histogram.py +617 -0
- dask_array/_map_blocks.py +652 -0
- dask_array/_new_collection.py +10 -0
- dask_array/_numpy_compat.py +135 -0
- dask_array/_overlap.py +1159 -0
- dask_array/_rechunk.py +1050 -0
- dask_array/_reshape.py +710 -0
- dask_array/_routines.py +102 -0
- dask_array/_shuffle.py +448 -0
- dask_array/_stack.py +264 -0
- dask_array/_svg.py +291 -0
- dask_array/_templates.py +29 -0
- dask_array/_test_utils.py +257 -0
- dask_array/_ufunc.py +385 -0
- dask_array/_utils.py +349 -0
- dask_array/_visualize.py +223 -0
- dask_array/_xarray.py +337 -0
- dask_array/core/__init__.py +34 -0
- dask_array/core/_blockwise_funcs.py +312 -0
- dask_array/core/_conversion.py +422 -0
- dask_array/core/_from_graph.py +97 -0
- dask_array/creation/__init__.py +71 -0
- dask_array/creation/_arange.py +121 -0
- dask_array/creation/_diag.py +116 -0
- dask_array/creation/_diagonal.py +241 -0
- dask_array/creation/_eye.py +103 -0
- dask_array/creation/_linspace.py +102 -0
- dask_array/creation/_mesh.py +134 -0
- dask_array/creation/_ones_zeros.py +454 -0
- dask_array/creation/_pad.py +270 -0
- dask_array/creation/_repeat.py +55 -0
- dask_array/creation/_tile.py +36 -0
- dask_array/creation/_tri.py +28 -0
- dask_array/creation/_utils.py +296 -0
- dask_array/fft.py +320 -0
- dask_array/io/__init__.py +39 -0
- dask_array/io/_base.py +10 -0
- dask_array/io/_from_array.py +257 -0
- dask_array/io/_from_delayed.py +95 -0
- dask_array/io/_from_graph.py +54 -0
- dask_array/io/_from_npy_stack.py +67 -0
- dask_array/io/_store.py +336 -0
- dask_array/io/_tiledb.py +159 -0
- dask_array/io/_to_npy_stack.py +65 -0
- dask_array/io/_zarr.py +449 -0
- dask_array/linalg/__init__.py +39 -0
- dask_array/linalg/_cholesky.py +234 -0
- dask_array/linalg/_lu.py +300 -0
- dask_array/linalg/_norm.py +94 -0
- dask_array/linalg/_qr.py +601 -0
- dask_array/linalg/_solve.py +349 -0
- dask_array/linalg/_svd.py +394 -0
- dask_array/linalg/_tensordot.py +334 -0
- dask_array/linalg/_utils.py +74 -0
- dask_array/manipulation/__init__.py +45 -0
- dask_array/manipulation/_expand.py +321 -0
- dask_array/manipulation/_flip.py +92 -0
- dask_array/manipulation/_roll.py +78 -0
- dask_array/manipulation/_transpose.py +309 -0
- dask_array/random/__init__.py +125 -0
- dask_array/random/_choice.py +181 -0
- dask_array/random/_expr.py +256 -0
- dask_array/random/_generator.py +441 -0
- dask_array/random/_random_state.py +259 -0
- dask_array/random/_utils.py +84 -0
- dask_array/reductions/__init__.py +84 -0
- dask_array/reductions/_arg_reduction.py +130 -0
- dask_array/reductions/_common.py +1082 -0
- dask_array/reductions/_cumulative.py +522 -0
- dask_array/reductions/_percentile.py +261 -0
- dask_array/reductions/_reduction.py +725 -0
- dask_array/reductions/_trace.py +56 -0
- dask_array/routines/__init__.py +133 -0
- dask_array/routines/_apply.py +84 -0
- dask_array/routines/_bincount.py +112 -0
- dask_array/routines/_broadcast.py +111 -0
- dask_array/routines/_coarsen.py +115 -0
- dask_array/routines/_diff.py +79 -0
- dask_array/routines/_gradient.py +158 -0
- dask_array/routines/_indexing.py +65 -0
- dask_array/routines/_insert_delete.py +132 -0
- dask_array/routines/_misc.py +122 -0
- dask_array/routines/_nonzero.py +72 -0
- dask_array/routines/_search.py +123 -0
- dask_array/routines/_select.py +113 -0
- dask_array/routines/_statistics.py +171 -0
- dask_array/routines/_topk.py +82 -0
- dask_array/routines/_triangular.py +74 -0
- dask_array/routines/_unique.py +232 -0
- dask_array/routines/_where.py +62 -0
- dask_array/slicing/__init__.py +67 -0
- dask_array/slicing/_basic.py +550 -0
- dask_array/slicing/_blocks.py +138 -0
- dask_array/slicing/_bool_index.py +145 -0
- dask_array/slicing/_setitem.py +329 -0
- dask_array/slicing/_squeeze.py +101 -0
- dask_array/slicing/_utils.py +1133 -0
- dask_array/slicing/_vindex.py +282 -0
- dask_array/stacking/__init__.py +15 -0
- dask_array/stacking/_block.py +83 -0
- dask_array/stacking/_simple.py +58 -0
- dask_array/templates/array.html.j2 +48 -0
- dask_array/tests/__init__.py +0 -0
- dask_array/tests/conftest.py +22 -0
- dask_array/tests/test_api.py +40 -0
- dask_array/tests/test_binary_op_chunks.py +107 -0
- dask_array/tests/test_coarse_slice_through_blockwise.py +362 -0
- dask_array/tests/test_collection.py +799 -0
- dask_array/tests/test_creation.py +1102 -0
- dask_array/tests/test_expr_flow.py +143 -0
- dask_array/tests/test_linalg.py +1130 -0
- dask_array/tests/test_map_blocks_multi_output.py +104 -0
- dask_array/tests/test_rechunk_pushdown.py +214 -0
- dask_array/tests/test_reductions.py +1091 -0
- dask_array/tests/test_routines.py +2853 -0
- dask_array/tests/test_shuffle_chunks.py +67 -0
- dask_array/tests/test_slice_pushdown.py +968 -0
- dask_array/tests/test_slice_through_blockwise.py +678 -0
- dask_array/tests/test_slice_through_overlap.py +366 -0
- dask_array/tests/test_slice_through_reshape.py +272 -0
- dask_array/tests/test_slicing.py +839 -0
- dask_array/tests/test_transpose_slice_pushdown.py +208 -0
- dask_array/tests/test_visualize.py +94 -0
- dask_array/tests/test_xarray.py +193 -0
- dask_array-0.1.0.dist-info/METADATA +48 -0
- dask_array-0.1.0.dist-info/RECORD +144 -0
- dask_array-0.1.0.dist-info/WHEEL +4 -0
- dask_array-0.1.0.dist-info/entry_points.txt +2 -0
- dask_array-0.1.0.dist-info/licenses/LICENSE +29 -0
|
@@ -0,0 +1,138 @@
|
|
|
1
|
+
"""Block indexing expression for x.blocks[...] access."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import functools
|
|
6
|
+
import math
|
|
7
|
+
from itertools import product
|
|
8
|
+
from numbers import Number
|
|
9
|
+
|
|
10
|
+
import numpy as np
|
|
11
|
+
|
|
12
|
+
from dask_array._new_collection import new_collection
|
|
13
|
+
from dask._task_spec import Alias
|
|
14
|
+
from dask_array._expr import ArrayExpr
|
|
15
|
+
from dask_array.slicing._utils import normalize_index
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class BlockView:
    """An array-like interface to the blocks of an array.

    ``BlockView`` exposes the blocks (chunks) of a dask array as if they
    were the elements of an array. NumPy-style indexing of a ``BlockView``
    returns the selected blocks as a new dask array.

    The view can be indexed like a NumPy array whose shape is the number of
    blocks along each dimension (available as ``array.blocks.shape``; the
    total number of blocks is ``array.blocks.size``). The dimensionality of
    the output array matches the dimension of this array, even if integer
    indices are passed. Slicing with ``np.newaxis`` or multiple lists is
    not supported.

    Parameters
    ----------
    array
        The collection whose blocks are exposed. It must provide
        ``numblocks`` and ``expr`` attributes.
    """

    __slots__ = ("_array",)

    def __init__(self, array):
        self._array = array

    def __getitem__(self, index):
        """Return the blocks selected by ``index`` as a new collection."""
        return new_collection(blocks_getitem(self._array.expr, index))

    def __eq__(self, other):
        # Matched on the class name rather than with isinstance so that any
        # BlockView type (including the legacy one) can compare equal.
        if hasattr(other, "_array") and type(other).__name__ == "BlockView":
            return self._array is other._array
        return NotImplemented

    def __hash__(self):
        # Defining __eq__ alone would make instances unhashable. Equality is
        # identity of the wrapped array, so the hash is derived from that
        # same identity to keep "equal implies equal hash".
        return hash(id(self._array))

    @property
    def size(self):
        """The total number of blocks in the array."""
        return math.prod(self.shape)

    @property
    def shape(self):
        """The number of blocks per axis (the wrapped array's ``numblocks``)."""
        return self._array.numblocks

    def ravel(self):
        """Return a flattened list of all the blocks in the array in C order."""
        return [self[idx] for idx in np.ndindex(self.shape)]
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
class Blocks(ArrayExpr):
    """Expression backing block-based indexing (``x.blocks[...]``).

    Blocks of the source array are addressed by block position instead of
    element position. The stored index only ever holds slices or array-like
    selectors (never bare integers), so the result keeps the dimensionality
    of the source.

    Parameters
    ----------
    array : ArrayExpr
        The source array expression
    index : tuple
        Normalized block indices (after converting integers to length-1 slices)
    """

    _parameters = ["array", "index"]

    @functools.cached_property
    def _name(self):
        return f"blocks-{self.deterministic_token}"

    @functools.cached_property
    def _meta(self):
        # Selecting whole blocks does not change the element type.
        return self.array._meta

    @functools.cached_property
    def chunks(self):
        """Chunk sizes obtained by applying each selector to the source chunks."""
        selected = []
        for sizes, selector in zip(self.array.chunks, self.index):
            picked = np.array(sizes)[selector]
            selected.append(tuple(picked.tolist()))
        return tuple(selected)

    def _layer(self) -> dict:
        """Generate the task graph layer.

        Every output block is an ``Alias`` of the input block it was
        selected from.
        """
        # For each axis: output block position -> source block position.
        source_positions = [
            np.arange(nblocks)[selector]
            for nblocks, selector in zip(self.array.numblocks, self.index)
        ]
        positions_per_axis = [range(len(sizes)) for sizes in self.chunks]

        layer = {}
        for position in product(*positions_per_axis):
            source = tuple(
                int(lookup[p]) for lookup, p in zip(source_positions, position)
            )
            target_key = (self._name,) + position
            source_key = (self.array._name,) + source
            layer[target_key] = Alias(target_key, source_key)
        return layer
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
def blocks_getitem(array, index):
    """Build the ``Blocks`` expression for a block-index lookup.

    Parameters
    ----------
    array : ArrayExpr
        The source array expression
    index : tuple
        The block index (may contain integers, slices, or lists). A
        non-tuple index is treated as a one-element tuple.

    Returns
    -------
    Blocks
        The blocks expression

    Raises
    ------
    ValueError
        If more than one list/ndarray selector is given, or if the index
        contains ``None`` (``np.newaxis``).
    """
    index = index if isinstance(index, tuple) else (index,)

    fancy_count = 0
    for ind in index:
        if isinstance(ind, (np.ndarray, list)):
            fancy_count += 1
    if fancy_count > 1:
        raise ValueError("Can only slice with a single list")

    # Identity test on purpose: comparing an ndarray selector to None with
    # ``==`` would broadcast instead of answering yes/no.
    for ind in index:
        if ind is None:
            raise ValueError("Slicing with np.newaxis or None is not supported")

    # Normalize against the block grid rather than the element shape.
    normalized = normalize_index(index, array.numblocks)

    # Integers become length-1 slices so that no dimension is dropped.
    block_index = tuple(
        k if not isinstance(k, Number) else slice(k, k + 1) for k in normalized
    )
    return Blocks(array, block_index)
|
|
@@ -0,0 +1,145 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import functools
|
|
4
|
+
import warnings
|
|
5
|
+
from operator import getitem
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
|
|
9
|
+
from dask._task_spec import Alias
|
|
10
|
+
from dask_array._expr import ArrayExpr
|
|
11
|
+
from dask_array._utils import meta_from_array
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def getitem_variadic(x, *index):
    """Index ``x`` with the positional arguments packed into one tuple.

    Used as the per-block function for boolean indexing, where the
    indexers arrive as separate arguments.
    """
    return getitem(x, index)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def slice_with_bool_dask_array(x, index):
    """Slice x with one or more dask arrays of bools.

    This is a helper function of :meth:`Array.__getitem__`.

    Parameters
    ----------
    x: Array
    index: tuple with as many elements as x.ndim, among which there are
           one or more Array's with dtype=bool

    Returns
    -------
    tuple of (sliced x, new index)

    where the new index is the input index with every boolean Array
    replaced by ``slice(None)``. It is returned as a list when a single
    full-dimensional mask is used and as a tuple otherwise.

    Note: The sliced x will have nan chunks on the sliced axes.
    """
    from dask_array._collection import (
        Array,
        blockwise,
        elemwise,
        new_collection,
    )
    from dask_array._expr import ChunksOverride

    def is_bool_mask(ind):
        return isinstance(ind, Array) and ind.dtype == bool

    out_index = [slice(None) if is_bool_mask(ind) else ind for ind in index]

    # Case 1: Full-dimensional boolean mask
    if len(index) == 1 and index[0].ndim == x.ndim:
        has_unknown_shape = np.isnan(x.shape).any() or np.isnan(index[0].shape).any()
        if not has_unknown_shape:
            # Flatten both operands first so the result is in flat C order.
            x = x.ravel()
            index = tuple(i.ravel() for i in index)
        elif x.ndim > 1:
            warnings.warn(
                "When slicing a Dask array of unknown chunks with a boolean mask "
                "Dask array, the output array may have a different ordering "
                "compared to the equivalent NumPy operation. This will raise an "
                "error in a future release of Dask.",
                stacklevel=3,
            )
        # Use elemwise to apply getitem across blocks
        y = elemwise(getitem, x, index[0], dtype=x.dtype)
        # Accessing chunks eagerly makes incompatible chunking between x and
        # the mask fail here, matching the legacy behavior.
        _ = y.chunks
        flattened = BooleanIndexFlattened(y.expr)
        return new_collection(flattened), out_index

    # Case 2: 1D boolean arrays on specific dimensions
    for ind in index:
        if is_bool_mask(ind) and ind.ndim != 1:
            raise NotImplementedError(
                "Slicing with dask.array of bools only permitted when "
                "the indexer has only one dimension or when "
                "it has the same dimension as the sliced "
                "array"
            )

    # Build the flat (argument, blockwise-index) list and remember which
    # dimensions are filtered by a boolean mask.
    bool_dims = []
    blockwise_args = []
    position = 0
    for dim, ind in enumerate(index):
        if is_bool_mask(ind):
            bool_dims.append(dim)
            blockwise_args.extend((ind, tuple(range(position, position + ind.ndim))))
            position += x.ndim
        else:
            blockwise_args.extend((slice(None), None))
            position += 1

    x_axes = tuple(range(x.ndim))
    out = blockwise(
        getitem_variadic,
        x_axes,
        x,
        x_axes,
        *blockwise_args,
        dtype=x.dtype,
    )

    # The number of selected elements is unknown, so the chunk sizes along
    # every boolean-indexed dimension are replaced by nan.
    new_chunks = tuple(
        (np.nan,) * len(sizes) if dim in bool_dims else sizes
        for dim, sizes in enumerate(out.chunks)
    )
    overridden = ChunksOverride(out.expr, new_chunks)
    return new_collection(overridden), tuple(out_index)
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
class BooleanIndexFlattened(ArrayExpr):
    """Present the blocks of a full-dimensional boolean index result as 1-D.

    Every block of the wrapped expression becomes one block of a
    one-dimensional array; block sizes are unknown (nan).
    """

    _parameters = ["array"]

    @functools.cached_property
    def _name(self):
        return f"getitem-{self.deterministic_token}"

    @functools.cached_property
    def _meta(self):
        return meta_from_array(self.array._meta, ndim=1)

    @functools.cached_property
    def chunks(self):
        # One output block per input block: multiply the per-axis counts.
        block_count = 1
        for per_axis in self.array.numblocks:
            block_count *= per_axis
        return (tuple(np.nan for _ in range(block_count)),)

    def _layer(self) -> dict:
        from dask.base import flatten

        # Walk the wrapped array's keys in flattened order and alias each
        # one to the next position along the single output axis.
        layer = {}
        for position, source_key in enumerate(flatten(self.array.__dask_keys__())):
            out_key = (self._name, position)
            layer[out_key] = Alias(out_key, source_key)
        return layer
|
|
@@ -0,0 +1,329 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import functools
|
|
4
|
+
import math
|
|
5
|
+
from itertools import product
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
|
|
9
|
+
from dask._task_spec import Alias, List, Task, TaskRef
|
|
10
|
+
from dask_array._expr import ArrayExpr
|
|
11
|
+
from dask_array._utils import meta_from_array
|
|
12
|
+
from dask_array._core_utils import concatenate3 as concatenate_shaped
|
|
13
|
+
from dask_array.slicing._utils import parse_assignment_indices, setitem
|
|
14
|
+
from dask.base import is_dask_collection
|
|
15
|
+
from dask.core import flatten
|
|
16
|
+
from dask.utils import cached_cumsum
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def setitem_array_expr(out_name, array, indices, value):
    """Array-expr version of setitem_array that generates Task objects directly.

    This function creates a new dask graph that assigns values to each block
    that is touched by the indices, leaving other blocks unchanged.

    Parameters
    ----------
    out_name : str
        Name used as the first element of every output block key.
    array
        Collection being assigned to. Its ``shape``, ``chunks`` and
        ``__dask_keys__()`` are read.
    indices
        Assignment indices; they are handed to ``parse_assignment_indices``
        together with ``array.shape`` before use.
    value
        Collection holding the values to assign. It is sliced per block and
        concatenated into a single chunk, so it must support indexing and
        the dask collection protocol.

    Returns
    -------
    dict
        Graph mapping each output block key to either an ``Alias`` of the
        input block (untouched blocks) or a ``Task`` calling ``setitem``.
        The graphs of the sliced value and of any dask-collection indices
        are merged into the same dict.

    Raises
    ------
    ValueError
        If the shape of ``value`` cannot be broadcast to the shape implied
        by ``indices``.
    """
    array_shape = array.shape
    value_shape = value.shape
    value_ndim = len(value_shape)

    # Reformat input indices
    indices, implied_shape, reverse, implied_shape_positions = parse_assignment_indices(indices, array_shape)

    # Empty slices can only be assigned size 1 values
    if 0 in implied_shape and value_shape and max(value_shape) > 1:
        raise ValueError(
            f"shape mismatch: value array of shape {value_shape} "
            "could not be broadcast to indexing result "
            f"of shape {tuple(implied_shape)}"
        )

    # Set variables needed when creating the part of the assignment value
    offset = len(implied_shape) - value_ndim
    if offset >= 0:
        # The value has no more dimensions than the indexed region: align
        # it with the trailing dimensions of the implied shape.
        array_common_shape = implied_shape[offset:]
        value_common_shape = value_shape
        value_offset = 0
        reverse = [i - offset for i in reverse if i >= offset]
    else:
        # The value has extra leading dimensions, which must all be size 1.
        value_offset = -offset
        array_common_shape = implied_shape
        value_common_shape = value_shape[value_offset:]
        offset = 0
        if value_shape[:value_offset] != (1,) * value_offset:
            raise ValueError(
                f"could not broadcast input array from shape {value_shape} into shape {tuple(implied_shape)}"
            )

    # For each common dimension: slice(None) when the value is broadcast
    # (size 1), otherwise None as a placeholder that is filled in per block.
    base_value_indices = []
    non_broadcast_dimensions = []

    for i, (a, b, j) in enumerate(zip(array_common_shape, value_common_shape, implied_shape_positions)):
        index = indices[j]
        if is_dask_collection(index) and index.dtype == bool:
            if math.isnan(b) or b <= index.size:
                base_value_indices.append(None)
                non_broadcast_dimensions.append(i)
            else:
                raise ValueError(
                    f"shape mismatch: value array dimension size of {b} is "
                    "greater than corresponding boolean index size of "
                    f"{index.size}"
                )
            continue

        if b == 1:
            base_value_indices.append(slice(None))
        elif a == b:
            base_value_indices.append(None)
            non_broadcast_dimensions.append(i)
        elif math.isnan(a):
            base_value_indices.append(None)
            non_broadcast_dimensions.append(i)
        else:
            raise ValueError(
                f"shape mismatch: value array of shape {value_shape} "
                "could not be broadcast to indexing result of shape "
                f"{tuple(implied_shape)}"
            )

    # Translate chunks tuple to array locations
    chunks = array.chunks
    cumdims = [cached_cumsum(bds, initial_zero=True) for bds in chunks]
    array_locations = [[(s, s + dim) for s, dim in zip(starts, shapes)] for starts, shapes in zip(cumdims, chunks)]
    array_locations = product(*array_locations)

    in_keys = list(flatten(array.__dask_keys__()))

    # Build graph with Task objects
    dsk = {}
    out_name_tuple = (out_name,)

    # Helper closures for index handling (simplified from legacy)
    def block_index_from_1d_index(index, loc0, loc1, is_bool):
        # Part of a 1-d index that falls in [loc0, loc1), relative to loc0.
        if is_bool:
            return index[loc0:loc1]
        elif is_dask_collection(index):
            # Out-of-block entries are mapped to loc1, i.e. one past the
            # end of the block once loc0 is subtracted.
            i = np.where((loc0 <= index) & (index < loc1), index, loc1)
            return i - loc0
        else:
            i = np.where((loc0 <= index) & (index < loc1))[0]
            return index[i] - loc0

    def block_index_shape_from_1d_bool_index(index, loc0, loc1):
        # Number of True entries inside the block.
        return np.sum(index[loc0:loc1])

    def n_preceding_from_1d_bool_index(index, loc0):
        # Number of True entries before the block starts.
        return np.sum(index[:loc0])

    def value_indices_from_1d_int_index(index, vsize, loc0, loc1):
        # Positions in the value that correspond to index entries inside
        # the block [loc0, loc1).
        if is_dask_collection(index):
            if np.isnan(index.size):
                i = np.where((loc0 <= index) & (index < loc1), True, False)
                i = concatenate_array_chunks_expr(i)
                i._chunks = ((vsize,),)
            else:
                i = np.where((loc0 <= index) & (index < loc1))[0]
                i = concatenate_array_chunks_expr(i)
        else:
            i = np.where((loc0 <= index) & (index < loc1))[0]
        return i

    for in_key, locations in zip(in_keys, array_locations):
        block_indices = []
        block_indices_shape = []
        block_preceding_sizes = []
        overlaps = True
        dim_1d_int_index = None

        for dim, (index, (loc0, loc1)) in enumerate(zip(indices, locations)):
            integer_index = isinstance(index, int)
            if isinstance(index, slice):
                # Restrict the slice to this block, in block-local coordinates.
                stop = loc1 - loc0
                if index.stop < loc1:
                    stop -= loc1 - index.stop

                start = index.start - loc0
                if start < 0:
                    # The slice starts before this block: find the first
                    # in-block position that lies on the step grid.
                    start %= index.step

                if start >= stop:
                    overlaps = False
                    break

                step = index.step
                block_index = slice(start, stop, step)
                # Ceiling division: elements selected inside this block.
                block_index_size, rem = divmod(stop - start, step)
                if rem:
                    block_index_size += 1

                # Ceiling division: elements selected before this block.
                pre = index.indices(loc0)
                n_preceding, rem = divmod(pre[1] - pre[0], step)
                if rem:
                    n_preceding += 1
            elif integer_index:
                if not loc0 <= index < loc1:
                    overlaps = False
                    break
                block_index = index - loc0
            else:
                is_bool = index.dtype == bool
                block_index = block_index_from_1d_index(index, loc0, loc1, is_bool)
                if is_bool:
                    block_index_size = block_index_shape_from_1d_bool_index(index, loc0, loc1)
                    n_preceding = n_preceding_from_1d_bool_index(index, loc0)
                else:
                    block_index_size = None
                    n_preceding = None
                    dim_1d_int_index = dim
                    loc0_loc1 = loc0, loc1

                if not is_dask_collection(index) and not block_index.size:
                    overlaps = False
                    break

            block_indices.append(block_index)
            if not integer_index:
                block_indices_shape.append(block_index_size)
                block_preceding_sizes.append(n_preceding)

        out_key = out_name_tuple + in_key[1:]

        if not overlaps:
            # Untouched block: pass the input block straight through.
            dsk[out_key] = Alias(out_key, in_key)
            continue

        # Build value indices for this block
        value_indices = base_value_indices[:]
        for i in non_broadcast_dimensions:
            j = i + offset
            if j == dim_1d_int_index:
                value_indices[i] = value_indices_from_1d_int_index(
                    indices[j], value_shape[i + value_offset], *loc0_loc1
                )
            else:
                start = block_preceding_sizes[j]
                value_indices[i] = slice(start, start + block_indices_shape[j])

        # Mirror the value slice along every dimension listed in `reverse`.
        for i in reverse:
            size = value_common_shape[i]
            start, stop, step = value_indices[i].indices(size)
            size -= 1
            start = size - start
            stop = size - stop
            if stop < 0:
                stop = None
            value_indices[i] = slice(start, stop, -1)

        if value_ndim > len(indices):
            value_indices.insert(0, Ellipsis)

        # Get the value slice and concatenate to single chunk
        v = value[tuple(value_indices)]
        v = concatenate_array_chunks_expr(v)
        v_key = next(flatten(v.__dask_keys__()))

        # Merge value's graph into dsk
        dsk.update(dict(v.__dask_graph__()))

        # Convert block_indices to use TaskRef for any dask keys
        task_block_indices = []
        for idx in block_indices:
            if is_dask_collection(idx):
                idx = concatenate_array_chunks_expr(idx)
                idx_key = next(flatten(idx.__dask_keys__()))
                dsk.update(dict(idx.__dask_graph__()))
                task_block_indices.append(TaskRef(idx_key))
            else:
                task_block_indices.append(idx)

        # Create Task with proper TaskRef wrappers
        dsk[out_key] = Task(
            out_key,
            setitem,
            TaskRef(in_key),
            TaskRef(v_key),
            List(*task_block_indices),
        )

    return dsk
|
|
245
|
+
|
|
246
|
+
|
|
247
|
+
class SetItem(ArrayExpr):
    """Expression representing ``array[index] = value``.

    The result has the same chunks as ``array``; the graph is produced by
    ``setitem_array_expr``.
    """

    _parameters = ["array", "index", "value"]

    @functools.cached_property
    def _name(self):
        return f"setitem-{self.deterministic_token}"

    @functools.cached_property
    def _meta(self):
        source = self.array
        meta = meta_from_array(source._meta, ndim=source.ndim)
        # A scalar meta is wrapped so that the meta is always an array.
        return np.array(meta) if np.isscalar(meta) else meta

    @property
    def chunks(self):
        # Assignment never changes the chunking of the target.
        return self.array.chunks

    def _layer(self) -> dict:
        from dask_array._collection import Array

        # setitem_array_expr works on collections, so wrap the expressions.
        # A value without ``_meta`` is passed through unchanged.
        target = Array(self.array)
        value = self.value
        if hasattr(value, "_meta"):
            value = Array(value)

        return setitem_array_expr(self._name, target, self.index, value)
|
|
275
|
+
|
|
276
|
+
|
|
277
|
+
class ConcatenateArrayChunks(ArrayExpr):
    """Concatenate all chunks of an array into a single chunk.

    This is an array-expr version of dask.array.slicing.concatenate_array_chunks.

    Parameters
    ----------
    array : ArrayExpr
        The expression whose blocks are merged into one block.
    """

    _parameters = ["array"]

    @functools.cached_property
    def _name(self):
        return f"concatenate-shaped-{self.deterministic_token}"

    @functools.cached_property
    def _meta(self):
        return meta_from_array(self.array._meta, ndim=self.array.ndim)

    @functools.cached_property
    def chunks(self):
        """One chunk per dimension, spanning that dimension's full extent."""
        # A 0-d input yields ``()`` here, which agrees with ``_meta``
        # (ndim 0) and with the key ``(name,)`` emitted by ``_layer``.
        # Reporting a 1-d ``((1,),)`` instead would make the expected block
        # key ``(name, 0)``, which the layer never defines.
        return tuple((s,) for s in self.array.shape)

    def _layer(self) -> dict:
        from dask.base import flatten

        # Get all keys from the input array as TaskRefs
        keys = [TaskRef(k) for k in flatten(self.array.__dask_keys__())]
        # Output key has ndim indices, all 0 since we have a single chunk
        out_key = (self._name,) + (0,) * self.array.ndim

        return {
            out_key: Task(
                out_key,
                concatenate_shaped,
                List(*keys),
                self.array.numblocks,
            )
        }
|
|
317
|
+
|
|
318
|
+
|
|
319
|
+
def concatenate_array_chunks_expr(x):
    """Return ``x`` with all of its chunks merged into a single chunk.

    Array-expr version of dask.array.slicing.concatenate_array_chunks.
    A collection that already has exactly one partition is returned as is.
    """
    from dask_array._new_collection import new_collection

    already_single = x.npartitions == 1
    return x if already_single else new_collection(ConcatenateArrayChunks(x.expr))
|