dask-array 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dask_array/__init__.py +228 -0
- dask_array/_backends.py +76 -0
- dask_array/_backends_array.py +99 -0
- dask_array/_blockwise.py +1410 -0
- dask_array/_broadcast.py +272 -0
- dask_array/_chunk.py +445 -0
- dask_array/_chunk_types.py +54 -0
- dask_array/_collection.py +1644 -0
- dask_array/_concatenate.py +331 -0
- dask_array/_core_utils.py +1365 -0
- dask_array/_dispatch.py +141 -0
- dask_array/_einsum.py +277 -0
- dask_array/_expr.py +544 -0
- dask_array/_expr_flow.py +586 -0
- dask_array/_gufunc.py +805 -0
- dask_array/_histogram.py +617 -0
- dask_array/_map_blocks.py +652 -0
- dask_array/_new_collection.py +10 -0
- dask_array/_numpy_compat.py +135 -0
- dask_array/_overlap.py +1159 -0
- dask_array/_rechunk.py +1050 -0
- dask_array/_reshape.py +710 -0
- dask_array/_routines.py +102 -0
- dask_array/_shuffle.py +448 -0
- dask_array/_stack.py +264 -0
- dask_array/_svg.py +291 -0
- dask_array/_templates.py +29 -0
- dask_array/_test_utils.py +257 -0
- dask_array/_ufunc.py +385 -0
- dask_array/_utils.py +349 -0
- dask_array/_visualize.py +223 -0
- dask_array/_xarray.py +337 -0
- dask_array/core/__init__.py +34 -0
- dask_array/core/_blockwise_funcs.py +312 -0
- dask_array/core/_conversion.py +422 -0
- dask_array/core/_from_graph.py +97 -0
- dask_array/creation/__init__.py +71 -0
- dask_array/creation/_arange.py +121 -0
- dask_array/creation/_diag.py +116 -0
- dask_array/creation/_diagonal.py +241 -0
- dask_array/creation/_eye.py +103 -0
- dask_array/creation/_linspace.py +102 -0
- dask_array/creation/_mesh.py +134 -0
- dask_array/creation/_ones_zeros.py +454 -0
- dask_array/creation/_pad.py +270 -0
- dask_array/creation/_repeat.py +55 -0
- dask_array/creation/_tile.py +36 -0
- dask_array/creation/_tri.py +28 -0
- dask_array/creation/_utils.py +296 -0
- dask_array/fft.py +320 -0
- dask_array/io/__init__.py +39 -0
- dask_array/io/_base.py +10 -0
- dask_array/io/_from_array.py +257 -0
- dask_array/io/_from_delayed.py +95 -0
- dask_array/io/_from_graph.py +54 -0
- dask_array/io/_from_npy_stack.py +67 -0
- dask_array/io/_store.py +336 -0
- dask_array/io/_tiledb.py +159 -0
- dask_array/io/_to_npy_stack.py +65 -0
- dask_array/io/_zarr.py +449 -0
- dask_array/linalg/__init__.py +39 -0
- dask_array/linalg/_cholesky.py +234 -0
- dask_array/linalg/_lu.py +300 -0
- dask_array/linalg/_norm.py +94 -0
- dask_array/linalg/_qr.py +601 -0
- dask_array/linalg/_solve.py +349 -0
- dask_array/linalg/_svd.py +394 -0
- dask_array/linalg/_tensordot.py +334 -0
- dask_array/linalg/_utils.py +74 -0
- dask_array/manipulation/__init__.py +45 -0
- dask_array/manipulation/_expand.py +321 -0
- dask_array/manipulation/_flip.py +92 -0
- dask_array/manipulation/_roll.py +78 -0
- dask_array/manipulation/_transpose.py +309 -0
- dask_array/random/__init__.py +125 -0
- dask_array/random/_choice.py +181 -0
- dask_array/random/_expr.py +256 -0
- dask_array/random/_generator.py +441 -0
- dask_array/random/_random_state.py +259 -0
- dask_array/random/_utils.py +84 -0
- dask_array/reductions/__init__.py +84 -0
- dask_array/reductions/_arg_reduction.py +130 -0
- dask_array/reductions/_common.py +1082 -0
- dask_array/reductions/_cumulative.py +522 -0
- dask_array/reductions/_percentile.py +261 -0
- dask_array/reductions/_reduction.py +725 -0
- dask_array/reductions/_trace.py +56 -0
- dask_array/routines/__init__.py +133 -0
- dask_array/routines/_apply.py +84 -0
- dask_array/routines/_bincount.py +112 -0
- dask_array/routines/_broadcast.py +111 -0
- dask_array/routines/_coarsen.py +115 -0
- dask_array/routines/_diff.py +79 -0
- dask_array/routines/_gradient.py +158 -0
- dask_array/routines/_indexing.py +65 -0
- dask_array/routines/_insert_delete.py +132 -0
- dask_array/routines/_misc.py +122 -0
- dask_array/routines/_nonzero.py +72 -0
- dask_array/routines/_search.py +123 -0
- dask_array/routines/_select.py +113 -0
- dask_array/routines/_statistics.py +171 -0
- dask_array/routines/_topk.py +82 -0
- dask_array/routines/_triangular.py +74 -0
- dask_array/routines/_unique.py +232 -0
- dask_array/routines/_where.py +62 -0
- dask_array/slicing/__init__.py +67 -0
- dask_array/slicing/_basic.py +550 -0
- dask_array/slicing/_blocks.py +138 -0
- dask_array/slicing/_bool_index.py +145 -0
- dask_array/slicing/_setitem.py +329 -0
- dask_array/slicing/_squeeze.py +101 -0
- dask_array/slicing/_utils.py +1133 -0
- dask_array/slicing/_vindex.py +282 -0
- dask_array/stacking/__init__.py +15 -0
- dask_array/stacking/_block.py +83 -0
- dask_array/stacking/_simple.py +58 -0
- dask_array/templates/array.html.j2 +48 -0
- dask_array/tests/__init__.py +0 -0
- dask_array/tests/conftest.py +22 -0
- dask_array/tests/test_api.py +40 -0
- dask_array/tests/test_binary_op_chunks.py +107 -0
- dask_array/tests/test_coarse_slice_through_blockwise.py +362 -0
- dask_array/tests/test_collection.py +799 -0
- dask_array/tests/test_creation.py +1102 -0
- dask_array/tests/test_expr_flow.py +143 -0
- dask_array/tests/test_linalg.py +1130 -0
- dask_array/tests/test_map_blocks_multi_output.py +104 -0
- dask_array/tests/test_rechunk_pushdown.py +214 -0
- dask_array/tests/test_reductions.py +1091 -0
- dask_array/tests/test_routines.py +2853 -0
- dask_array/tests/test_shuffle_chunks.py +67 -0
- dask_array/tests/test_slice_pushdown.py +968 -0
- dask_array/tests/test_slice_through_blockwise.py +678 -0
- dask_array/tests/test_slice_through_overlap.py +366 -0
- dask_array/tests/test_slice_through_reshape.py +272 -0
- dask_array/tests/test_slicing.py +839 -0
- dask_array/tests/test_transpose_slice_pushdown.py +208 -0
- dask_array/tests/test_visualize.py +94 -0
- dask_array/tests/test_xarray.py +193 -0
- dask_array-0.1.0.dist-info/METADATA +48 -0
- dask_array-0.1.0.dist-info/RECORD +144 -0
- dask_array-0.1.0.dist-info/WHEEL +4 -0
- dask_array-0.1.0.dist-info/entry_points.txt +2 -0
- dask_array-0.1.0.dist-info/licenses/LICENSE +29 -0
|
@@ -0,0 +1,62 @@
|
|
|
1
|
+
"""Where operation."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def where(condition, x=None, y=None):
    """Choose elements from ``x`` or ``y`` depending on ``condition``.

    Parameters
    ----------
    condition : array_like, bool
        Where True, yield x, otherwise yield y.
    x, y : array_like
        Values from which to choose. x, y and condition need to be
        broadcastable to some shape. Either both or neither must be given.

    Returns
    -------
    out : Array
        An array with elements from x where condition is True,
        and elements from y elsewhere. When neither x nor y is given,
        the result of ``nonzero(condition)`` is returned instead.

    Raises
    ------
    ValueError
        If exactly one of ``x`` and ``y`` is given.

    See Also
    --------
    numpy.where

    Examples
    --------
    >>> import dask_array as da
    >>> x = da.arange(10, chunks=5)
    >>> da.where(x < 5, x, 10 * x).compute()  # doctest: +NORMALIZE_WHITESPACE
    array([ 0,  1,  2,  3,  4, 50, 60, 70, 80, 90])
    """
    # Imported lazily to avoid circular dependencies
    from dask_array.core import asarray
    from dask_array.core._blockwise_funcs import elemwise

    x_missing = x is None
    y_missing = y is None
    if x_missing != y_missing:
        raise ValueError("either both or neither of x and y should be given")
    if x_missing and y_missing:
        # Single-argument form: indices of the nonzero elements
        from dask_array._routines import nonzero

        return nonzero(condition)

    if not np.isscalar(condition):
        # General case: apply np.where block by block
        return elemwise(np.where, condition, x, y)

    # Scalar condition: the whole result is x or y, so skip elemwise and
    # only broadcast/cast the selected operand.
    from dask_array._broadcast import broadcast_to
    from dask_array._core_utils import broadcast_shapes
    from dask_array.routines._misc import result_type

    dtype = result_type(x, y)
    x = asarray(x)
    y = asarray(y)
    shape = broadcast_shapes(x.shape, y.shape)
    chosen = x if condition else y
    return broadcast_to(chosen, shape).astype(dtype)
|
|
@@ -0,0 +1,67 @@
|
|
|
1
|
+
"""Slicing operations for dask array expressions."""
|
|
2
|
+
|
|
3
|
+
from dask_array.slicing._basic import (
|
|
4
|
+
ArrayOffsetDep,
|
|
5
|
+
Slice,
|
|
6
|
+
SliceSlicesIntegers,
|
|
7
|
+
TakeUnknownOneChunk,
|
|
8
|
+
normalize_index,
|
|
9
|
+
slice_array,
|
|
10
|
+
slice_slices_and_integers,
|
|
11
|
+
slice_with_int_dask_array,
|
|
12
|
+
slice_with_int_dask_array_on_axis,
|
|
13
|
+
slice_with_newaxes,
|
|
14
|
+
slice_wrap_lists,
|
|
15
|
+
take,
|
|
16
|
+
)
|
|
17
|
+
from dask_array.slicing._bool_index import (
|
|
18
|
+
BooleanIndexFlattened,
|
|
19
|
+
getitem_variadic,
|
|
20
|
+
slice_with_bool_dask_array,
|
|
21
|
+
)
|
|
22
|
+
from dask_array.slicing._setitem import (
|
|
23
|
+
ConcatenateArrayChunks,
|
|
24
|
+
SetItem,
|
|
25
|
+
concatenate_array_chunks_expr,
|
|
26
|
+
setitem_array_expr,
|
|
27
|
+
)
|
|
28
|
+
from dask_array.slicing._squeeze import Squeeze, squeeze
|
|
29
|
+
from dask_array.slicing._vindex import (
|
|
30
|
+
VIndexArray,
|
|
31
|
+
_numpy_vindex,
|
|
32
|
+
_vindex,
|
|
33
|
+
_vindex_array,
|
|
34
|
+
)
|
|
35
|
+
|
|
36
|
+
# Names re-exported by ``dask_array.slicing`` (grouped by defining submodule).
__all__ = [
    # Basic slicing (dask_array.slicing._basic)
    "ArrayOffsetDep",
    "Slice",
    "SliceSlicesIntegers",
    "TakeUnknownOneChunk",
    "normalize_index",
    "slice_array",
    "slice_slices_and_integers",
    "slice_with_int_dask_array",
    "slice_with_int_dask_array_on_axis",
    "slice_with_newaxes",
    "slice_wrap_lists",
    "take",
    # Boolean indexing (dask_array.slicing._bool_index)
    "BooleanIndexFlattened",
    "getitem_variadic",
    "slice_with_bool_dask_array",
    # Setitem (dask_array.slicing._setitem)
    "ConcatenateArrayChunks",
    "SetItem",
    "concatenate_array_chunks_expr",
    "setitem_array_expr",
    # Squeeze (dask_array.slicing._squeeze)
    "Squeeze",
    "squeeze",
    # Vindex (dask_array.slicing._vindex)
    "VIndexArray",
    "_numpy_vindex",
    "_vindex",
    "_vindex_array",
]
|
|
@@ -0,0 +1,550 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import functools
|
|
4
|
+
from itertools import product
|
|
5
|
+
from numbers import Integral
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
from toolz import pluck
|
|
9
|
+
|
|
10
|
+
from dask._task_spec import Alias, Task, TaskRef
|
|
11
|
+
from dask_array._expr import ArrayExpr
|
|
12
|
+
from dask_array._chunk import getitem
|
|
13
|
+
from dask_array._utils import meta_from_array
|
|
14
|
+
from dask_array.slicing._utils import (
|
|
15
|
+
_slice_1d,
|
|
16
|
+
check_index,
|
|
17
|
+
fuse_slice,
|
|
18
|
+
new_blockdim,
|
|
19
|
+
normalize_slice,
|
|
20
|
+
posify_index,
|
|
21
|
+
replace_ellipsis,
|
|
22
|
+
sanitize_index,
|
|
23
|
+
)
|
|
24
|
+
from dask.layers import ArrayBlockwiseDep
|
|
25
|
+
from dask.utils import cached_cumsum, is_arraylike
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def _compute_sliced_chunks(chunks, slc, dim_size):
|
|
29
|
+
"""Compute chunk sizes for the sliced region of a dimension."""
|
|
30
|
+
if slc == slice(None):
|
|
31
|
+
return chunks
|
|
32
|
+
|
|
33
|
+
start, stop, step = slc.indices(dim_size)
|
|
34
|
+
|
|
35
|
+
# Handle step == -1 (flip) specially - preserve chunks in reverse order
|
|
36
|
+
if step == -1:
|
|
37
|
+
# Check if this is a full flip (equivalent to slice(None, None, -1))
|
|
38
|
+
if start == dim_size - 1 and stop == -1:
|
|
39
|
+
# Full flip: reverse the chunks
|
|
40
|
+
return chunks[::-1]
|
|
41
|
+
else:
|
|
42
|
+
# Partial negative step: fall back to single chunk
|
|
43
|
+
new_size = len(range(start, stop, step))
|
|
44
|
+
return (new_size,)
|
|
45
|
+
|
|
46
|
+
if step != 1:
|
|
47
|
+
# For non-unit step (other than -1), fall back to single chunk
|
|
48
|
+
new_size = len(range(start, stop, step))
|
|
49
|
+
return (new_size,)
|
|
50
|
+
|
|
51
|
+
# Handle empty slice - return single chunk of size 0
|
|
52
|
+
if start >= stop:
|
|
53
|
+
return (0,)
|
|
54
|
+
|
|
55
|
+
# Find chunks that overlap with [start, stop)
|
|
56
|
+
result = []
|
|
57
|
+
pos = 0
|
|
58
|
+
for chunk_size in chunks:
|
|
59
|
+
chunk_start = pos
|
|
60
|
+
chunk_end = pos + chunk_size
|
|
61
|
+
pos = chunk_end
|
|
62
|
+
|
|
63
|
+
# Skip chunks entirely before the slice
|
|
64
|
+
if chunk_end <= start:
|
|
65
|
+
continue
|
|
66
|
+
# Stop at chunks entirely after the slice
|
|
67
|
+
if chunk_start >= stop:
|
|
68
|
+
break
|
|
69
|
+
|
|
70
|
+
# Compute the portion of this chunk included in the slice
|
|
71
|
+
included_start = max(chunk_start, start)
|
|
72
|
+
included_end = min(chunk_end, stop)
|
|
73
|
+
result.append(included_end - included_start)
|
|
74
|
+
|
|
75
|
+
return tuple(result) if result else (0,)
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def slice_with_int_dask_array(x, index):
    """Slice x with at most one 1D dask array of ints.

    This is a helper function of :meth:`Array.__getitem__`.

    Parameters
    ----------
    x: Array
    index: tuple with as many elements as x.ndim, among which there may be
           Arrays with an integer dtype

    Returns
    -------
    tuple of (sliced x, new index)

    where the new index is the same as the input, but with slice(None)
    put in place of the original slicer where a 1D filter has been applied,
    and one less element where a zero-dimensional filter has been applied.

    Raises
    ------
    NotImplementedError
        If more than one entry is a fancy index, or an integer dask array
        indexer has more than one dimension.
    """
    from dask_array._collection import Array

    assert len(index) == x.ndim

    def _is_fancy(entry):
        # lists/tuples and non-scalar numpy or dask arrays count as fancy
        if isinstance(entry, (tuple, list)):
            return True
        return isinstance(entry, (np.ndarray, Array)) and entry.ndim > 0

    if sum(_is_fancy(entry) for entry in index) > 1:
        raise NotImplementedError("Don't yet support nd fancy indexing")

    remaining = []
    n_dropped = 0
    for in_axis, idx in enumerate(index):
        if not (isinstance(idx, Array) and idx.dtype.kind in "iu"):
            # not an integer dask array: leave it for later slicing stages
            remaining.append(idx)
            continue

        # Axes removed by earlier 0-d indexers shift the position within x
        out_axis = in_axis - n_dropped
        if idx.ndim == 0:
            # Promote to 1D, take along the axis, then drop that axis again
            idx = idx[np.newaxis]
            x = slice_with_int_dask_array_on_axis(x, idx, out_axis)
            x = x[tuple(0 if i == out_axis else slice(None) for i in range(x.ndim))]
            n_dropped += 1
        elif idx.ndim == 1:
            x = slice_with_int_dask_array_on_axis(x, idx, out_axis)
            remaining.append(slice(None))
        else:
            raise NotImplementedError(
                "Slicing with dask.array of ints only permitted when the indexer has zero or one dimensions"
            )
    return x, tuple(remaining)
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
def normalize_index(idx, shape):
    """Normalize slicing indexes

    1.  Replaces ellipses with many full slices
    2.  Adds full slices to end of index
    3.  Checks bounding conditions
    4.  Replace multidimensional numpy arrays with dask arrays
    5.  Replaces numpy arrays with lists
    6.  Posify's integers and lists
    7.  Normalizes slices to canonical form

    Parameters
    ----------
    idx : index or tuple of indexes
        A non-tuple is treated as a one-element tuple.
    shape : tuple
        Shape of the array being indexed.

    Returns
    -------
    tuple
        The normalized index.

    Raises
    ------
    IndexError
        If there are more non-``None`` index entries than dimensions.

    Examples
    --------
    >>> normalize_index(1, (10,))
    (1,)
    >>> normalize_index(-1, (10,))
    (9,)
    >>> normalize_index([-1], (10,))
    (array([9]),)
    >>> normalize_index(slice(-3, 10, 1), (10,))
    (slice(7, None, None),)
    >>> normalize_index((Ellipsis, None), (10,))
    (slice(None, None, None), None)
    >>> normalize_index(np.array([[True, False], [False, True], [True, True]]), (3, 2))
    (dask.array<array, shape=(3, 2), dtype=bool, chunksize=(3, 2), chunktype=numpy.ndarray>,)
    """
    from dask_array._collection import Array, from_array

    if not isinstance(idx, tuple):
        idx = (idx,)

    ndim = len(shape)

    # A leading non-dask array covering the full (>1D) shape becomes a dask array
    if len(idx) > 0 and ndim > 1:
        first = idx[0]
        if is_arraylike(first) and not isinstance(first, Array) and first.shape == shape:
            idx = (from_array(first), *idx[1:])

    idx = replace_ellipsis(ndim, idx)

    # Count the array dimensions consumed by the index entries
    n_sliced_dims = 0
    for entry in idx:
        if hasattr(entry, "ndim") and entry.ndim >= 1:
            n_sliced_dims += entry.ndim
        elif entry is not None:
            n_sliced_dims += 1

    # Pad with full slices for the dimensions that were not indexed
    idx = idx + (slice(None),) * (ndim - n_sliced_dims)
    if sum(entry is not None for entry in idx) > ndim:
        raise IndexError("Too many indices for array")

    # Shape aligned with idx: None wherever idx inserts a new axis.
    # The check above guarantees the iterator is not exhausted.
    sizes = iter(shape)
    none_shape = [None if entry is None else next(sizes) for entry in idx]

    for axis, (entry, size) in enumerate(zip(idx, none_shape)):
        if size is not None:
            check_index(axis, entry, size)

    idx = tuple(sanitize_index(entry) for entry in idx)
    idx = tuple(normalize_slice(entry, size) for entry, size in zip(idx, none_shape))
    return posify_index(none_shape, idx)
|
|
195
|
+
|
|
196
|
+
|
|
197
|
+
def slice_with_int_dask_array_on_axis(x, idx, axis):
    """Slice a ND dask array with a 1D dask array of ints along the given
    axis.

    This is a helper function of :func:`slice_with_int_dask_array`.

    Parameters
    ----------
    x : Array
        Array to slice.
    idx : Array
        1D dask array of integer positions.
    axis : int
        Axis of ``x`` to index; must satisfy ``0 <= axis < x.ndim``.

    Raises
    ------
    NotImplementedError
        If ``x`` has unknown chunk sizes along ``axis``.
    """
    from dask_array import _chunk as chunk
    from dask_array._collection import blockwise

    assert 0 <= axis < x.ndim

    if np.isnan(x.chunks[axis]).any():
        raise NotImplementedError("Slicing an array with unknown chunks with a dask.array of ints is not supported")

    ndim = x.ndim
    x_axes = tuple(range(ndim))
    idx_axes = (ndim,)  # arbitrary index not already in x_axes
    offset_axes = (axis,)

    # Offset at which each chunk starts along axis,
    # e.g. chunks=(..., (5, 3, 4), ...) -> offset=[0, 5, 8]
    chunk_starts = np.roll(np.cumsum(np.asarray(x.chunks[axis], like=x._meta)), 1)
    chunk_starts[0] = 0
    # ArrayOffsetDep needs 1D chunks matching x.chunks[axis], not full x.chunks
    offset = ArrayOffsetDep((x.chunks[axis],), chunk_starts)

    before, after = x_axes[:axis], x_axes[axis + 1 :]
    p_axes = x_axes[: axis + 1] + idx_axes + after
    y_axes = before + idx_axes + after

    # Cartesian product of every chunk of x vs every chunk of idx
    partial = blockwise(
        chunk.slice_with_int_dask_array,
        p_axes,
        x,
        x_axes,
        idx,
        idx_axes,
        offset,
        offset_axes,
        x_size=x.shape[axis],
        axis=axis,
        dtype=x.dtype,
        meta=x._meta,
    )

    # Aggregate on the chunks of x along axis
    return blockwise(
        chunk.slice_with_int_dask_array_aggregate,
        y_axes,
        idx,
        idx_axes,
        partial,
        p_axes,
        concatenate=True,
        x_chunks=x.chunks[axis],
        axis=axis,
        dtype=x.dtype,
        meta=x._meta,
    )
|
|
255
|
+
|
|
256
|
+
|
|
257
|
+
class ArrayOffsetDep(ArrayBlockwiseDep):
    """1D BlockwiseDep that provides chunk offset values.

    Parameters
    ----------
    chunks : tuple of tuple of int
        Chunk structure forwarded to ``ArrayBlockwiseDep``.
    values : numpy.ndarray or dict
        Lookup from block position along the single axis to its offset.
    """

    def __init__(self, chunks: tuple[tuple[int, ...], ...], values: np.ndarray | dict):
        super().__init__(chunks)
        # offsets, looked up by the first block coordinate in __getitem__
        self.values = values

    def __getitem__(self, idx: tuple):
        """Return the offset for the block whose first coordinate is ``idx[0]``."""
        block = idx[0]
        return self.values[block]
|
|
266
|
+
|
|
267
|
+
|
|
268
|
+
def slice_array(x, index):
    """Slice ``x`` with ``index``, dispatching through the slicing helpers.

    slice_with_newaxes : handle None/newaxis case
    slice_wrap_lists : handle fancy indexing with lists
    slice_slices_and_integers : handle everything else

    Returns ``x.expr`` unchanged when every entry of ``index`` is a full
    slice; otherwise the result of ``slice_with_newaxes``.
    """
    full = slice(None, None, None)

    # Nothing to do if every entry is a full slice
    if all(isinstance(ind, slice) and ind == full for ind in index):
        return x.expr

    # Add in missing colons at the end as needed.  x[5] -> x[5, :, :]
    n_real = sum(ind is not None for ind in index)
    index += (full,) * (len(x.chunks) - n_real)
    return slice_with_newaxes(x, index)
|
|
283
|
+
|
|
284
|
+
|
|
285
|
+
def slice_with_newaxes(x, index):
    """
    Handle indexing with Nones

    Removes the ``None`` entries from ``index``, delegates the remaining
    index to slice_wrap_lists, and wraps the result in ExpandDims when any
    ``None`` was present.
    """
    from dask_array.manipulation._expand import ExpandDims

    # Index without the None entries
    stripped = tuple(ind for ind in index if ind is not None)

    # Output position of each new axis: its position in ``index`` minus the
    # number of integer entries before it (each of those drops an axis).
    new_axes = [
        pos - sum(isinstance(ind, Integral) for ind in index[:pos])
        for pos, ind in enumerate(index)
        if ind is None
    ]

    # Pass down and do work; the getitem optimization is only allowed when
    # no new axes are inserted.
    x = slice_wrap_lists(x, stripped, not new_axes)

    if not new_axes:
        return x
    return ExpandDims(x, tuple(new_axes))
|
|
309
|
+
|
|
310
|
+
|
|
311
|
+
def slice_wrap_lists(x, index, allow_getitem_optimization):
    """
    Fancy indexing along blocked array dasks

    Handles index of type list.  Calls slice_slices_and_integers for the rest

    Raises
    ------
    IndexError
        If ``index`` does not have one entry per dimension of ``x``.
    NotImplementedError
        If more than one entry is a non-scalar array.

    See Also
    --------

    take : handle slicing with lists ("fancy" indexing)
    slice_slices_and_integers : handle slicing with slices and integers
    """
    assert all(isinstance(i, (slice, list, Integral)) or is_arraylike(i) for i in index)
    if len(x.chunks) != len(index):
        raise IndexError("Too many indices for array")

    full = slice(None, None, None)

    # Positions of the non-scalar array indexers; at most one is supported
    where_list = [pos for pos, ind in enumerate(index) if is_arraylike(ind) and ind.ndim > 0]
    if len(where_list) > 1:
        raise NotImplementedError("Don't yet support nd fancy indexing")

    # An empty array indexer is treated as a zero-length slice
    if where_list and not index[where_list[0]].size:
        index = list(index)
        index[where_list.pop()] = slice(0, 0, 1)
        index = tuple(index)

    # No lists, hooray! just use slice_slices_and_integers
    if not where_list:
        return slice_slices_and_integers(x, index, allow_getitem_optimization)

    axis = where_list[0]

    # Only arrays and full slices: a plain take is enough
    if all(is_arraylike(i) or i == full for i in index):
        return take(x, index[axis], axis=axis)

    # Mixed case.  Apply the slices/integers first (arrays replaced by full
    # slices), then take along the axis, shifted left by the integer
    # entries before it because each of those drops a dimension.
    index_without_list = tuple(full if is_arraylike(i) else i for i in index)
    x = slice_slices_and_integers(
        x,
        index_without_list,
        allow_getitem_optimization=False,
    )
    axis2 = axis - sum(1 for ind in index[:axis] if isinstance(ind, Integral))
    return take(x, index[axis], axis=axis2)
|
|
361
|
+
|
|
362
|
+
|
|
363
|
+
def slice_slices_and_integers(x, index, allow_getitem_optimization=False):
    """Build a ``SliceSlicesIntegers`` expression for ``x[index]``.

    ``index`` must contain only slices and integers.

    Raises
    ------
    ValueError
        If a dimension of unknown (NaN) size is indexed with anything other
        than a full slice.
    """
    from dask_array._core_utils import unknown_chunk_message

    full = slice(None, None, None)
    # Dimension sizes are the totals of the chunk sizes
    shape = tuple(cached_cumsum(dim_chunks, initial_zero=True)[-1] for dim_chunks in x.chunks)

    if any(np.isnan(size) and ind != full for size, ind in zip(shape, index)):
        raise ValueError(f"Arrays chunk sizes are unknown: {shape}{unknown_chunk_message}")

    assert all(isinstance(ind, (slice, Integral)) for ind in index)
    return SliceSlicesIntegers(x, index, allow_getitem_optimization)
|
|
373
|
+
|
|
374
|
+
|
|
375
|
+
def take(x, index, axis=0):
    """Select the positions ``index`` of ``x`` along ``axis``.

    Dispatch, as implemented below:

    * unknown chunk sizes along ``axis``: ``TakeUnknownOneChunk`` if there
      is a single chunk, otherwise ``ValueError``;
    * ``index`` is a dask collection: ``slice_with_int_dask_array_on_axis``;
    * ``index`` equals ``arange(size of axis)``: ``x`` is returned as is;
    * otherwise: a shuffle driven by ``_compute_indexer``.
    """
    from dask.base import is_dask_collection

    axis_chunks = x.chunks[axis]

    if np.isnan(axis_chunks).any():
        if len(axis_chunks) == 1:
            return TakeUnknownOneChunk(x, index, axis)

        from dask_array._core_utils import unknown_chunk_message

        raise ValueError(f"Array chunk size or shape is unknown. {unknown_chunk_message}")

    from dask_array._shuffle import _shuffle
    from dask_array._utils import arange_safe, asarray_safe

    # Use is_dask_collection to catch both array-expr and legacy dask Arrays;
    # such indexers go through the lazy blockwise approach.
    if is_dask_collection(index):
        return slice_with_int_dask_array_on_axis(x, index, axis)

    # No-op check, only for non-dask indexers (dask array comparison
    # triggers warnings)
    identity = arange_safe(np.sum(axis_chunks), like=index)
    if len(index) == len(identity) and np.abs(index - identity).sum() == 0:
        return x

    index = asarray_safe(index, like=index)

    # Compute indexer by grouping consecutive indices from same input chunk
    from dask_array.slicing._vindex import _compute_indexer

    indexer = _compute_indexer(index, axis_chunks)
    return _shuffle(x, indexer, axis, "getitem-")
|
|
406
|
+
|
|
407
|
+
|
|
408
|
+
class Slice(ArrayExpr):
    """Base class of the slicing expressions in this module.

    Supplies the ``getitem-`` key name and the meta computation; subclasses
    are expected to provide ``chunks``.
    """

    @functools.cached_property
    def _name(self):
        # Key prefix shared by all slicing expressions
        return f"getitem-{self.deterministic_token}"

    @functools.cached_property
    def _meta(self):
        # Meta with the dimensionality of the sliced result
        source_meta = self.array._meta
        ndim = len(self.chunks)
        if source_meta is None:
            # no meta on the input: build one from its dtype alone
            meta = meta_from_array(None, ndim=ndim, dtype=self.array.dtype)
        else:
            meta = meta_from_array(source_meta, ndim=ndim)
        # A scalar meta is wrapped in an array
        return np.array(meta) if np.isscalar(meta) else meta
|
|
422
|
+
|
|
423
|
+
|
|
424
|
+
class SliceSlicesIntegers(Slice):
    """Expression that indexes an array with only slices and integers.

    Operands: ``array`` (the input expression), ``index`` (one slice or
    integer per input dimension) and ``allow_getitem_optimization`` (when
    true, blocks selected by full slices only are aliased to the input
    block instead of going through ``getitem``).
    """

    _parameters = ["array", "index", "allow_getitem_optimization"]

    def _simplify_down(self):
        """Fuse nested slices, or let the input expression absorb this slice.

        Returns a replacement expression, or ``None`` if nothing applies.
        """
        # Slice(Slice(x)) -> single Slice with fused indices
        if isinstance(self.array, SliceSlicesIntegers):
            inner = self.array
            try:
                fused = fuse_slice(inner.index, self.index)
                normalized = tuple(
                    normalize_slice(entry, size) if isinstance(entry, slice) else entry
                    for entry, size in zip(fused, inner.array.shape)
                )
                return SliceSlicesIntegers(inner.array, normalized, self.allow_getitem_optimization)
            except NotImplementedError:
                # Skip fusion for unsupported slicing patterns (e.g., negative step)
                pass

        # Inputs implementing _accept_slice (operations like Elemwise,
        # Transpose, Blockwise, PartialReduce, ExpandDims) may rewrite
        # themselves around this slice.
        if hasattr(self.array, "_accept_slice"):
            replacement = self.array._accept_slice(self)
            if replacement is not None:
                return replacement
        return None

    def _slice_chunks(self, chunks, start, length):
        """Compute new chunks after slicing ``length`` elements from ``start``.

        Returns the non-empty overlaps of each chunk with
        ``[start, start + length)``, or ``(0,)`` if there are none.
        """
        stop = start + length
        overlaps = []
        position = 0
        for size in chunks:
            begin = position
            position += size
            end = position

            if end <= start:
                continue
            if begin >= stop:
                break

            overlap = min(stop, end) - max(start, begin)
            if overlap > 0:
                overlaps.append(overlap)

        return tuple(overlaps) if overlaps else (0,)

    @functools.cached_property
    def chunks(self):
        # Integer indices drop their dimension; the others get new block sizes
        return tuple(
            tuple(new_blockdim(size, dim_chunks, ind))
            for size, ind, dim_chunks in zip(self.array.shape, self.index, self.array.chunks)
            if not isinstance(ind, Integral)
        )

    def _layer(self) -> dict:
        """Build the task graph: one ``getitem`` (or alias) per output block."""
        # For each dimension, a sorted list of (blocknum, slice) pairs
        per_dim = [
            sorted(_slice_1d(size, dim_chunks, ind).items())
            for size, dim_chunks, ind in zip(self.array.shape, self.array.chunks, self.index)
        ]

        # (in_name, 1, 1, 2), (in_name, 1, 1, 4), (in_name, 2, 1, 2), ...
        in_keys = product([self.array._name], *[pluck(0, pairs) for pairs in per_dim])

        # (out_name, 0, 0, 0), (out_name, 0, 0, 1), (out_name, 0, 1, 0), ...
        # Integer-indexed dimensions vanish from the output key; negative
        # steps number the output blocks in reverse.
        out_ranges = []
        for pairs, ind in zip(per_dim, self.index):
            if isinstance(ind, Integral):
                continue
            numbering = range(len(pairs))
            if ind.step and ind.step < 0:
                numbering = numbering[::-1]
            out_ranges.append(numbering)
        out_keys = product([self._name], *out_ranges)

        block_indexers = product(*[pluck(1, pairs) for pairs in per_dim])

        full = slice(None, None, None)
        graph = {}
        for out_key, in_key, slices in zip(out_keys, in_keys, block_indexers):
            if self.allow_getitem_optimization and all(sl == full for sl in slices):
                # whole block selected: point at the input block directly
                graph[out_key] = Alias(out_key, in_key)
            else:
                graph[out_key] = Task(out_key, getitem, TaskRef(in_key), slices)
        return graph
|
|
510
|
+
|
|
511
|
+
|
|
512
|
+
def _compose_slices(outer_slice, inner_slice, dim_size):
|
|
513
|
+
"""Compose two slices: inner_slice is relative to outer_slice's result."""
|
|
514
|
+
# Get the range of the outer slice
|
|
515
|
+
outer_start, outer_stop, outer_step = outer_slice.indices(dim_size)
|
|
516
|
+
outer_len = len(range(outer_start, outer_stop, outer_step))
|
|
517
|
+
|
|
518
|
+
# Get the range of the inner slice relative to outer's result
|
|
519
|
+
inner_start, inner_stop, inner_step = inner_slice.indices(outer_len)
|
|
520
|
+
|
|
521
|
+
# Compose: offset inner by outer_start
|
|
522
|
+
if outer_step != 1 or inner_step != 1:
|
|
523
|
+
new_start = outer_start + inner_start * outer_step
|
|
524
|
+
new_stop = outer_start + inner_stop * outer_step
|
|
525
|
+
new_step = outer_step * inner_step
|
|
526
|
+
else:
|
|
527
|
+
new_start = outer_start + inner_start
|
|
528
|
+
new_stop = outer_start + inner_stop
|
|
529
|
+
new_step = 1
|
|
530
|
+
|
|
531
|
+
return slice(new_start, new_stop, new_step if new_step != 1 else None)
|
|
532
|
+
|
|
533
|
+
|
|
534
|
+
class TakeUnknownOneChunk(Slice):
    """Fancy-index ``array`` with ``index`` along ``axis``, block by block.

    Created by ``take`` when the chunk size along ``axis`` is unknown and
    there is a single chunk along that axis.
    """

    _parameters = ["array", "index", "axis"]

    @functools.cached_property
    def chunks(self):
        # Chunk structure is reported unchanged from the input.
        # NOTE(review): the size along ``axis`` presumably becomes
        # len(index) after the take — confirm whether it should be set here.
        return self.array.chunks

    def _layer(self) -> dict:
        """Build one ``getitem`` task per input block."""
        # Full slices everywhere except the indexed axis
        selector = [slice(None)] * len(self.array.chunks)
        selector[self.axis] = list(self.index)
        selector = tuple(selector)

        # NOTE(review): uses ``self.array.name`` where SliceSlicesIntegers
        # uses ``self.array._name`` — confirm both refer to the same key.
        graph = {}
        for block in product(*(range(len(dim_chunks)) for dim_chunks in self.array.chunks)):
            key = (self._name,) + block
            graph[key] = Task(key, getitem, TaskRef((self.array.name,) + block), selector)
        return graph
|