dask-array 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dask_array/__init__.py +228 -0
- dask_array/_backends.py +76 -0
- dask_array/_backends_array.py +99 -0
- dask_array/_blockwise.py +1410 -0
- dask_array/_broadcast.py +272 -0
- dask_array/_chunk.py +445 -0
- dask_array/_chunk_types.py +54 -0
- dask_array/_collection.py +1644 -0
- dask_array/_concatenate.py +331 -0
- dask_array/_core_utils.py +1365 -0
- dask_array/_dispatch.py +141 -0
- dask_array/_einsum.py +277 -0
- dask_array/_expr.py +544 -0
- dask_array/_expr_flow.py +586 -0
- dask_array/_gufunc.py +805 -0
- dask_array/_histogram.py +617 -0
- dask_array/_map_blocks.py +652 -0
- dask_array/_new_collection.py +10 -0
- dask_array/_numpy_compat.py +135 -0
- dask_array/_overlap.py +1159 -0
- dask_array/_rechunk.py +1050 -0
- dask_array/_reshape.py +710 -0
- dask_array/_routines.py +102 -0
- dask_array/_shuffle.py +448 -0
- dask_array/_stack.py +264 -0
- dask_array/_svg.py +291 -0
- dask_array/_templates.py +29 -0
- dask_array/_test_utils.py +257 -0
- dask_array/_ufunc.py +385 -0
- dask_array/_utils.py +349 -0
- dask_array/_visualize.py +223 -0
- dask_array/_xarray.py +337 -0
- dask_array/core/__init__.py +34 -0
- dask_array/core/_blockwise_funcs.py +312 -0
- dask_array/core/_conversion.py +422 -0
- dask_array/core/_from_graph.py +97 -0
- dask_array/creation/__init__.py +71 -0
- dask_array/creation/_arange.py +121 -0
- dask_array/creation/_diag.py +116 -0
- dask_array/creation/_diagonal.py +241 -0
- dask_array/creation/_eye.py +103 -0
- dask_array/creation/_linspace.py +102 -0
- dask_array/creation/_mesh.py +134 -0
- dask_array/creation/_ones_zeros.py +454 -0
- dask_array/creation/_pad.py +270 -0
- dask_array/creation/_repeat.py +55 -0
- dask_array/creation/_tile.py +36 -0
- dask_array/creation/_tri.py +28 -0
- dask_array/creation/_utils.py +296 -0
- dask_array/fft.py +320 -0
- dask_array/io/__init__.py +39 -0
- dask_array/io/_base.py +10 -0
- dask_array/io/_from_array.py +257 -0
- dask_array/io/_from_delayed.py +95 -0
- dask_array/io/_from_graph.py +54 -0
- dask_array/io/_from_npy_stack.py +67 -0
- dask_array/io/_store.py +336 -0
- dask_array/io/_tiledb.py +159 -0
- dask_array/io/_to_npy_stack.py +65 -0
- dask_array/io/_zarr.py +449 -0
- dask_array/linalg/__init__.py +39 -0
- dask_array/linalg/_cholesky.py +234 -0
- dask_array/linalg/_lu.py +300 -0
- dask_array/linalg/_norm.py +94 -0
- dask_array/linalg/_qr.py +601 -0
- dask_array/linalg/_solve.py +349 -0
- dask_array/linalg/_svd.py +394 -0
- dask_array/linalg/_tensordot.py +334 -0
- dask_array/linalg/_utils.py +74 -0
- dask_array/manipulation/__init__.py +45 -0
- dask_array/manipulation/_expand.py +321 -0
- dask_array/manipulation/_flip.py +92 -0
- dask_array/manipulation/_roll.py +78 -0
- dask_array/manipulation/_transpose.py +309 -0
- dask_array/random/__init__.py +125 -0
- dask_array/random/_choice.py +181 -0
- dask_array/random/_expr.py +256 -0
- dask_array/random/_generator.py +441 -0
- dask_array/random/_random_state.py +259 -0
- dask_array/random/_utils.py +84 -0
- dask_array/reductions/__init__.py +84 -0
- dask_array/reductions/_arg_reduction.py +130 -0
- dask_array/reductions/_common.py +1082 -0
- dask_array/reductions/_cumulative.py +522 -0
- dask_array/reductions/_percentile.py +261 -0
- dask_array/reductions/_reduction.py +725 -0
- dask_array/reductions/_trace.py +56 -0
- dask_array/routines/__init__.py +133 -0
- dask_array/routines/_apply.py +84 -0
- dask_array/routines/_bincount.py +112 -0
- dask_array/routines/_broadcast.py +111 -0
- dask_array/routines/_coarsen.py +115 -0
- dask_array/routines/_diff.py +79 -0
- dask_array/routines/_gradient.py +158 -0
- dask_array/routines/_indexing.py +65 -0
- dask_array/routines/_insert_delete.py +132 -0
- dask_array/routines/_misc.py +122 -0
- dask_array/routines/_nonzero.py +72 -0
- dask_array/routines/_search.py +123 -0
- dask_array/routines/_select.py +113 -0
- dask_array/routines/_statistics.py +171 -0
- dask_array/routines/_topk.py +82 -0
- dask_array/routines/_triangular.py +74 -0
- dask_array/routines/_unique.py +232 -0
- dask_array/routines/_where.py +62 -0
- dask_array/slicing/__init__.py +67 -0
- dask_array/slicing/_basic.py +550 -0
- dask_array/slicing/_blocks.py +138 -0
- dask_array/slicing/_bool_index.py +145 -0
- dask_array/slicing/_setitem.py +329 -0
- dask_array/slicing/_squeeze.py +101 -0
- dask_array/slicing/_utils.py +1133 -0
- dask_array/slicing/_vindex.py +282 -0
- dask_array/stacking/__init__.py +15 -0
- dask_array/stacking/_block.py +83 -0
- dask_array/stacking/_simple.py +58 -0
- dask_array/templates/array.html.j2 +48 -0
- dask_array/tests/__init__.py +0 -0
- dask_array/tests/conftest.py +22 -0
- dask_array/tests/test_api.py +40 -0
- dask_array/tests/test_binary_op_chunks.py +107 -0
- dask_array/tests/test_coarse_slice_through_blockwise.py +362 -0
- dask_array/tests/test_collection.py +799 -0
- dask_array/tests/test_creation.py +1102 -0
- dask_array/tests/test_expr_flow.py +143 -0
- dask_array/tests/test_linalg.py +1130 -0
- dask_array/tests/test_map_blocks_multi_output.py +104 -0
- dask_array/tests/test_rechunk_pushdown.py +214 -0
- dask_array/tests/test_reductions.py +1091 -0
- dask_array/tests/test_routines.py +2853 -0
- dask_array/tests/test_shuffle_chunks.py +67 -0
- dask_array/tests/test_slice_pushdown.py +968 -0
- dask_array/tests/test_slice_through_blockwise.py +678 -0
- dask_array/tests/test_slice_through_overlap.py +366 -0
- dask_array/tests/test_slice_through_reshape.py +272 -0
- dask_array/tests/test_slicing.py +839 -0
- dask_array/tests/test_transpose_slice_pushdown.py +208 -0
- dask_array/tests/test_visualize.py +94 -0
- dask_array/tests/test_xarray.py +193 -0
- dask_array-0.1.0.dist-info/METADATA +48 -0
- dask_array-0.1.0.dist-info/RECORD +144 -0
- dask_array-0.1.0.dist-info/WHEEL +4 -0
- dask_array-0.1.0.dist-info/entry_points.txt +2 -0
- dask_array-0.1.0.dist-info/licenses/LICENSE +29 -0
dask_array/_routines.py
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
1
|
+
"""Re-exports from routines submodules and other locations.
|
|
2
|
+
|
|
3
|
+
This module maintains backward compatibility by re-exporting all routines
|
|
4
|
+
from their new locations.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
import numpy as np
|
|
10
|
+
|
|
11
|
+
from dask_array._core_utils import implements
|
|
12
|
+
|
|
13
|
+
# Re-exports from _blockwise
|
|
14
|
+
from dask_array._blockwise import outer # noqa: F401
|
|
15
|
+
|
|
16
|
+
# Re-exports from _ufunc
|
|
17
|
+
from dask_array._ufunc import ( # noqa: F401
|
|
18
|
+
allclose,
|
|
19
|
+
around,
|
|
20
|
+
isclose,
|
|
21
|
+
isnull,
|
|
22
|
+
notnull,
|
|
23
|
+
round,
|
|
24
|
+
)
|
|
25
|
+
|
|
26
|
+
# Re-exports from routines submodules
|
|
27
|
+
from dask_array.routines._apply import ( # noqa: F401
|
|
28
|
+
apply_along_axis,
|
|
29
|
+
apply_over_axes,
|
|
30
|
+
)
|
|
31
|
+
from dask_array.routines._bincount import bincount # noqa: F401
|
|
32
|
+
from dask_array.routines._broadcast import ( # noqa: F401
|
|
33
|
+
broadcast_arrays,
|
|
34
|
+
unify_chunks,
|
|
35
|
+
)
|
|
36
|
+
from dask_array.routines._coarsen import ( # noqa: F401
|
|
37
|
+
Coarsen,
|
|
38
|
+
aligned_coarsen_chunks,
|
|
39
|
+
coarsen,
|
|
40
|
+
)
|
|
41
|
+
from dask_array.routines._gradient import gradient # noqa: F401
|
|
42
|
+
from dask_array.routines._indexing import ( # noqa: F401
|
|
43
|
+
ravel_multi_index,
|
|
44
|
+
unravel_index,
|
|
45
|
+
)
|
|
46
|
+
from dask_array.routines._insert_delete import ( # noqa: F401
|
|
47
|
+
append,
|
|
48
|
+
delete,
|
|
49
|
+
ediff1d,
|
|
50
|
+
insert,
|
|
51
|
+
)
|
|
52
|
+
from dask_array.routines._misc import ( # noqa: F401
|
|
53
|
+
compress,
|
|
54
|
+
ndim,
|
|
55
|
+
result_type,
|
|
56
|
+
shape,
|
|
57
|
+
take,
|
|
58
|
+
)
|
|
59
|
+
from dask_array.routines._nonzero import ( # noqa: F401
|
|
60
|
+
argwhere,
|
|
61
|
+
count_nonzero,
|
|
62
|
+
flatnonzero,
|
|
63
|
+
isnonzero,
|
|
64
|
+
nonzero,
|
|
65
|
+
)
|
|
66
|
+
from dask_array.routines._search import ( # noqa: F401
|
|
67
|
+
isin,
|
|
68
|
+
searchsorted,
|
|
69
|
+
)
|
|
70
|
+
from dask_array.routines._select import ( # noqa: F401
|
|
71
|
+
choose,
|
|
72
|
+
digitize,
|
|
73
|
+
extract,
|
|
74
|
+
piecewise,
|
|
75
|
+
select,
|
|
76
|
+
)
|
|
77
|
+
from dask_array.routines._statistics import ( # noqa: F401
|
|
78
|
+
average,
|
|
79
|
+
corrcoef,
|
|
80
|
+
cov,
|
|
81
|
+
)
|
|
82
|
+
from dask_array.routines._topk import argtopk, topk # noqa: F401
|
|
83
|
+
from dask_array.routines._triangular import ( # noqa: F401
|
|
84
|
+
tril,
|
|
85
|
+
tril_indices,
|
|
86
|
+
tril_indices_from,
|
|
87
|
+
triu,
|
|
88
|
+
triu_indices,
|
|
89
|
+
triu_indices_from,
|
|
90
|
+
)
|
|
91
|
+
from dask_array.routines._unique import union1d, unique # noqa: F401
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
def ptp(a, axis=None):
    """Return the peak-to-peak spread (maximum minus minimum) along ``axis``.

    ``a`` must provide ``max`` and ``min`` methods that accept an ``axis``
    keyword; ``axis`` is forwarded unchanged to both of them.
    """
    upper = a.max(axis=axis)
    lower = a.min(axis=axis)
    return upper - lower
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
@implements(np.iscomplexobj)
def iscomplexobj(x):
    """Return True when the scalar type of ``x.dtype`` is a complex floating type."""
    scalar_type = x.dtype.type
    return issubclass(scalar_type, np.complexfloating)
|
dask_array/_shuffle.py
ADDED
|
@@ -0,0 +1,448 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import copy
|
|
4
|
+
import functools
|
|
5
|
+
import math
|
|
6
|
+
from functools import reduce
|
|
7
|
+
from itertools import count, product
|
|
8
|
+
from operator import mul
|
|
9
|
+
from typing import Literal
|
|
10
|
+
|
|
11
|
+
import numpy as np
|
|
12
|
+
|
|
13
|
+
from dask import config
|
|
14
|
+
from dask._task_spec import DataNode, List, Task, TaskRef
|
|
15
|
+
from dask_array._expr import ArrayExpr
|
|
16
|
+
from dask_array._chunk import getitem
|
|
17
|
+
from dask_array._dispatch import concatenate_lookup, take_lookup
|
|
18
|
+
from dask.base import tokenize
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def _calculate_new_chunksizes(input_chunks, new_chunks, changeable_dimensions: set, maximum_chunk: int):
    """Split chunks on the changeable axes until the largest block fits the budget.

    The product of the per-axis maxima of ``new_chunks`` is compared against
    ``maximum_chunk``. While it is larger, every changeable axis is re-derived
    from ``input_chunks`` by splitting chunks that exceed a per-axis limit
    (scaled by the ``array.chunk-size-tolerance`` setting) into near-equal parts.

    Parameters
    ----------
    input_chunks
        Per-axis chunk sizes that are split; never modified.
    new_chunks : list
        Current per-axis chunk sizes. Updated in place and returned.
    changeable_dimensions : set
        Axes that may still be split. Axes are removed in place once splitting
        no longer changes them or all of their chunks have size 1.
    maximum_chunk : int
        Element budget for the largest block; values below 1 are treated as 1.

    Returns
    -------
    list
        The same ``new_chunks`` object that was passed in.
    """
    tolerance = config.get("array.chunk-size-tolerance")
    maximum_chunk = max(maximum_chunk, 1)

    # Keep going until the size increase has been spread over the changeable
    # axes, or until none of them can be changed any further.
    while changeable_dimensions:
        remaining_dims = len(changeable_dimensions)
        growth = reduce(mul, map(max, new_chunks)) / maximum_chunk
        if growth <= 1:
            break

        # Every changeable axis takes an equal share of the growth factor.
        per_dim_divisor = growth ** (1 / remaining_dims)

        for dim in tuple(changeable_dimensions):
            size_limit = max(new_chunks[dim]) / per_dim_divisor
            split_threshold = tolerance * size_limit

            resized = []
            for size in input_chunks[dim]:
                if size > split_threshold:
                    # Never create more pieces than elements, so every piece
                    # keeps a size of at least 1.
                    pieces = min(math.ceil(size / size_limit), size)
                    base, extra = divmod(size, pieces)
                    # The leftover elements go to the first ``extra`` pieces.
                    resized.extend([base + 1] * extra + [base] * (pieces - extra))
                else:
                    resized.append(size)
            resized = tuple(resized)

            if resized == new_chunks[dim] or max(resized) == 1:
                changeable_dimensions.remove(dim)

            new_chunks[dim] = resized
    return new_chunks
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def _rechunk_other_dimensions(x, longest_group: int, axis: int, chunks: Literal["auto"]):
    """Split chunks on the non-shuffle axes when a shuffle group is too long.

    Parameters
    ----------
    x
        Array to rechunk; must expose ``chunks`` and ``rechunk``.
    longest_group : int
        Length of the longest group in the shuffle indexer.
    axis : int
        Axis that is shuffled; its chunks are left as they are.
    chunks : "auto"
        Only ``"auto"`` is accepted (checked with an ``assert``).

    Returns
    -------
    ``x`` itself when the longest group stays within the largest chunk on
    ``axis`` times the ``array.chunk-size-tolerance`` setting, otherwise the
    result of ``x.rechunk`` with the other axes split.
    """
    assert chunks == "auto", "Only auto is supported for now"
    tolerance = config.get("array.chunk-size-tolerance")
    input_chunks = x.chunks

    # Groups that stay below the threshold need no rechunking.
    if longest_group <= max(input_chunks[axis]) * tolerance:
        return x

    other_axes = set(range(len(input_chunks))) - {axis}

    # For sizing purposes the shuffle axis is treated as a single chunk that
    # is as long as the longest group.
    target = list(input_chunks)
    target[axis] = (longest_group,)

    # Element count of the largest chunk in the input.
    largest_input_chunk = reduce(mul, map(max, input_chunks))

    target = _calculate_new_chunksizes(input_chunks, target, other_axes, largest_input_chunk)
    # The shuffle axis keeps its original chunking.
    target[axis] = input_chunks[axis]
    return x.rechunk(tuple(target))
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
def _validate_indexer(chunks, indexer, axis):
|
|
85
|
+
if not isinstance(indexer, list) or not all(isinstance(i, list) for i in indexer):
|
|
86
|
+
raise ValueError("indexer must be a list of lists of positional indices")
|
|
87
|
+
|
|
88
|
+
if not axis <= len(chunks):
|
|
89
|
+
raise ValueError(f"Axis {axis} is out of bounds for array with {len(chunks)} axes")
|
|
90
|
+
|
|
91
|
+
if max(map(max, indexer)) >= sum(chunks[axis]):
|
|
92
|
+
raise IndexError(f"Indexer contains out of bounds index. Dimension only has {sum(chunks[axis])} elements.")
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
def shuffle(x, indexer: list[list[int]], axis: int, chunks: Literal["auto"] = "auto"):
    """
    Reorder one dimension of a Dask Array based on an indexer.

    The indexer is a list of positional groups. All elements of a group end up
    in the same chunk along ``axis``; a chunk may hold several groups so that the
    array does not become fragmented.

    The algorithm tries to balance the chunk sizes as much as possible, ideally
    keeping the number of chunks consistent or at least manageable.

    Parameters
    ----------
    x: dask array
        Array to be shuffled.
    indexer: list[list[int]]
        Determines which elements along the dimension end up in the same chunk.
        Multiple groups can share a chunk to avoid fragmentation, but each group
        ends up in exactly one chunk.
    axis: int
        The axis to shuffle along.
    chunks: "auto"
        Hint on how to rechunk if single groups become too large. The default
        splits chunks along the other dimensions evenly to keep the chunk size
        consistent. The rechunking is done in a way that ensures that no
        all-to-all network communication is necessary: chunks are only split
        and never combined with other chunks.

    Examples
    --------
    >>> import dask_array as da
    >>> import numpy as np
    >>> arr = np.array([[1, 2, 3, 4, 5, 6, 7, 8], [9, 10, 11, 12, 13, 14, 15, 16]])
    >>> x = da.from_array(arr, chunks=(2, 4))

    Separate the elements into different groups.

    >>> y = x.shuffle([[6, 5, 2], [4, 1], [3, 0, 7]], axis=1)

    The shuffle algorithm combines the first 2 groups into a single chunk to
    keep the number of chunks small.

    The tolerance for increasing the chunk size is controlled by the
    configuration "array.chunk-size-tolerance". The default value is 1.25.

    >>> y.chunks
    ((2,), (5, 3))

    The array was reordered along axis 1 according to the positional indexer
    that was given.

    >>> y.compute()
    array([[ 7,  6,  3,  5,  2,  4,  1,  8],
           [15, 14, 11, 13, 10, 12,  9, 16]])
    """
    from dask_array._new_collection import new_collection

    # Unknown (NaN) dimensions make positional shuffling impossible.
    if np.isnan(x.shape).any():
        from dask_array._core_utils import unknown_chunk_message

        raise ValueError(f"Shuffling only allowed with known chunk sizes. {unknown_chunk_message}")
    assert isinstance(axis, int), "axis must be an integer"
    _validate_indexer(x.chunks, indexer, axis)

    longest_group = max(map(len, indexer))
    rechunked = _rechunk_other_dimensions(x, longest_group, axis, chunks)

    shuffled_expr = _shuffle(rechunked.expr, indexer, axis, "shuffle")
    return new_collection(shuffled_expr)
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
def _shuffle(x, indexer, axis, name):
|
|
167
|
+
if len(indexer) == len(x.chunks[axis]):
|
|
168
|
+
# check if the array is already shuffled the way we want
|
|
169
|
+
ctr = 0
|
|
170
|
+
for idx, c in zip(indexer, x.chunks[axis]):
|
|
171
|
+
if idx != list(range(ctr, ctr + c)):
|
|
172
|
+
break
|
|
173
|
+
ctr += c
|
|
174
|
+
else:
|
|
175
|
+
return x
|
|
176
|
+
return Shuffle(x, indexer, axis, name)
|
|
177
|
+
|
|
178
|
+
|
|
179
|
+
class Shuffle(ArrayExpr):
    """Array expression that reorders one axis according to a grouped indexer.

    Operands (see ``_parameters``):

    array
        Input array expression.
    indexer
        List of lists of positions along ``axis``. Consecutive groups are
        packed into output chunks by ``_new_chunks``.
    axis
        Axis that is reordered.
    name
        Prefix of the output key name (see ``_name``).
    """

    _parameters = ["array", "indexer", "axis", "name"]

    @functools.cached_property
    def _meta(self):
        # A shuffle only moves elements around, so the input meta is reused.
        return self.array._meta

    @functools.cached_property
    def _name(self):
        return f"{self.operand('name')}-{self.deterministic_token}"

    @functools.cached_property
    def chunks(self):
        """Input chunks, with the shuffle axis replaced by the new chunk lengths."""
        output_chunks = []
        for i, c in enumerate(self.array.chunks):
            if i == self.axis:
                output_chunks.append(tuple(map(len, self._new_chunks)))
            else:
                output_chunks.append(c)
        return tuple(output_chunks)

    @functools.cached_property
    def _chunk_size_limit(self):
        """Max input chunk size on the shuffle axis."""
        return max(self.array.chunks[self.axis])

    @functools.cached_property
    def _new_chunks(self):
        """Pack the indexer groups into output chunks along the shuffle axis.

        Returns a list of lists of input positions, one inner list per output
        chunk. No chunk is longer than ``_chunk_size_limit``: groups are
        appended to the current chunk while they fit, and a group that is
        longer than the limit on its own is cut into limit-sized pieces.
        """
        current_chunk, new_chunks = [], []
        limit = self._chunk_size_limit
        # Work on a deep copy so the stored indexer operand is never aliased
        # by the lists handed out here.
        for idx in copy.deepcopy(self.indexer):
            if len(idx) > limit:
                # Flush the current chunk first, then split the oversized
                # group into limit-sized pieces.
                if current_chunk:
                    new_chunks.append(current_chunk)
                    current_chunk = []
                for i in range(0, len(idx), limit):
                    new_chunks.append(idx[i : i + limit])
            elif len(current_chunk) + len(idx) > limit and len(current_chunk) > 0:
                # The group does not fit any more: start a new chunk with it.
                new_chunks.append(current_chunk)
                current_chunk = idx.copy()
            else:
                # Reaching this branch means the extended chunk has at most
                # ``limit`` entries, so no overflow check is needed afterwards.
                current_chunk.extend(idx)
        if len(current_chunk) > 0:
            new_chunks.append(current_chunk)
        return new_chunks

    def _simplify_down(self):
        """Push shuffle through various operations using _accept_shuffle pattern."""
        # Delegate to the child when it knows how to absorb this shuffle.
        if hasattr(self.array, "_accept_shuffle"):
            return self.array._accept_shuffle(self)
        return None

    def _simplify_up(self, parent, dependents):
        """Allow slice operations to push through Shuffle."""
        from dask_array.slicing import SliceSlicesIntegers

        if isinstance(parent, SliceSlicesIntegers):
            return self._accept_slice(parent)
        return None

    def _accept_slice(self, slice_expr):
        """Accept a slice being pushed through Shuffle.

        Shuffle reorganizes data along a single axis. We can push slices through:
        1. Non-shuffle axes: directly push through
        2. Shuffle axis (step=1): if input indices are contiguous, slice input
           and adjust the indexer

        Integer indices on axes in front of the shuffle axis remove those
        dimensions from the sliced input, so the shuffle axis of the new
        expression is shifted left by the number of such integers.

        Example (non-shuffle axis):
            Slice(Shuffle(x, axis=0), [:, 10:20])
            -> Shuffle(Slice(x, [:, 10:20]), axis=0)

        Example (shuffle axis, contiguous):
            Slice(Shuffle(x, axis=0), [100:200, :])
            -> Shuffle(Slice(x, [input_start:input_stop, :]), adjusted_indexer, axis=0)

        Example (integer in front of the shuffle axis):
            Slice(Shuffle(x, axis=1), [3, :])
            -> Shuffle(Slice(x, [3, :]), axis=0)

        Returns the replacement expression, or None when the slice cannot be
        pushed through.
        """
        from numbers import Integral

        from dask_array._new_collection import new_collection

        axis = self.axis
        index = slice_expr.index
        indexer = self.indexer

        # Pad index to full length
        full_index = list(index) + [slice(None)] * (len(self.shape) - len(index))

        # Work out where the shuffle axis lands in the sliced input. Only
        # slices (keep the dimension) and integers (drop it) are understood;
        # anything else in front of the axis makes us decline the pushdown.
        leading = full_index[:axis]
        if not all(isinstance(ix, (slice, Integral)) for ix in leading):
            return None
        new_axis = axis - sum(isinstance(ix, Integral) for ix in leading)

        # Check if we're slicing on the shuffle axis
        axis_slice = full_index[axis]

        if axis_slice == slice(None):
            # Not slicing shuffle axis - push through directly
            sliced_input = new_collection(self.array)[tuple(full_index)]
            return Shuffle(
                sliced_input.expr,
                indexer,
                new_axis,
                self.operand("name"),
            )

        # Slicing on shuffle axis - check if we can handle it
        if not isinstance(axis_slice, slice):
            return None  # Integer indexing removes the dimension

        # Only handle step=1 slices
        if axis_slice.step is not None and axis_slice.step != 1:
            return None

        # Normalize slice bounds
        axis_size = self.shape[axis]
        start, stop, _ = axis_slice.indices(axis_size)
        if start >= stop:
            return None  # Empty slice

        # Slice by output positions. The indexer is grouped into runs, not
        # one entry per output element.
        new_indexer = []
        offset = 0
        for chunk in indexer:
            chunk_start = offset
            chunk_stop = offset + len(chunk)
            offset = chunk_stop

            # Overlap of this group with the requested output range.
            take_start = max(start, chunk_start)
            take_stop = min(stop, chunk_stop)
            if take_start < take_stop:
                new_indexer.append(chunk[take_start - chunk_start : take_stop - chunk_start])

        # Find all input indices needed
        input_indices = set()
        for chunk in new_indexer:
            input_indices.update(chunk)

        if not input_indices:
            return None  # No indices

        input_min = min(input_indices)
        input_max = max(input_indices)

        # Check if input indices are contiguous
        if len(input_indices) != input_max - input_min + 1:
            return None  # Non-contiguous, can't use simple slice

        # Adjust indexer: subtract input_min from each index
        adjusted_indexer = [[idx - input_min for idx in chunk] for chunk in new_indexer]

        # Build slice for input array (positions here refer to the unsliced
        # input, hence the original ``axis``).
        input_slice = list(full_index)
        input_slice[axis] = slice(input_min, input_max + 1)

        sliced_input = new_collection(self.array)[tuple(input_slice)]

        return Shuffle(
            sliced_input.expr,
            adjusted_indexer,
            new_axis,
            self.operand("name"),
        )

    def _layer(self) -> dict:
        """Build the task graph for the shuffle.

        For every output chunk along the shuffle axis and every block position
        on the other axes, one ``_getitem`` task per contributing input chunk
        takes the needed positions. When several input chunks contribute, a
        ``concatenate_arrays`` task joins the pieces and restores the requested
        order; with a single contributor that task is re-keyed to be the
        output block directly.
        """
        chunks = self.array.chunks
        axis = self.axis

        # Exclusive end position of every input chunk on the shuffle axis.
        chunk_boundaries = np.cumsum(chunks[axis])

        # Get existing chunk tuple locations
        chunk_tuples = list(product(*(range(len(c)) for i, c in enumerate(chunks) if i != axis)))

        intermediates: dict = dict()
        merges: dict = dict()
        # Smallest dtype that can hold the largest chunk length on the axis.
        dtype = np.min_scalar_type(max(*chunks[axis], self._chunk_size_limit))
        split_name = f"shuffle-split-{self.deterministic_token}"
        slices = [slice(None)] * len(chunks)
        split_name_suffixes = count()
        sorter_name = "shuffle-sorter-"
        taker_name = "shuffle-taker-"

        # Block index -> key of the corresponding input block.
        old_blocks = {
            old_index: (self.array._name,) + old_index for old_index in np.ndindex(tuple([len(c) for c in chunks]))
        }

        for new_chunk_idx, new_chunk_taker in enumerate(self._new_chunks):
            new_chunk_taker = np.array(new_chunk_taker)
            sorter = np.argsort(new_chunk_taker).astype(dtype)
            sorter_key = sorter_name + tokenize(sorter)
            # low level fusion can't deal with arrays on first position
            merges[sorter_key] = DataNode(sorter_key, (1, sorter))

            # Sorted positions are grouped by the input chunk they fall into;
            # ``taker_boundary`` holds the start of each group plus the end.
            sorted_array = new_chunk_taker[sorter]
            source_chunk_nr, taker_boundary_ = np.unique(
                np.searchsorted(chunk_boundaries, sorted_array, side="right"),
                return_index=True,
            )
            taker_boundary: list[int] = taker_boundary_.tolist()
            taker_boundary.append(len(new_chunk_taker))

            taker_cache: dict = {}
            for chunk_tuple in chunk_tuples:
                merge_keys = []

                for c, b_start, b_end in zip(source_chunk_nr, taker_boundary[:-1], taker_boundary[1:]):
                    # insert our axis chunk id into the chunk_tuple
                    chunk_key = convert_key(chunk_tuple, c, axis)
                    name = (split_name, next(split_name_suffixes))
                    this_slice = slices.copy()

                    # Cache the takers to allow de-duplication when serializing
                    # Ugly!
                    if c in taker_cache:
                        taker_key = taker_cache[c]
                    else:
                        # Convert positions on the axis into offsets inside
                        # input chunk ``c``.
                        this_slice[axis] = (
                            sorted_array[b_start:b_end] - (chunk_boundaries[c - 1] if c > 0 else 0)
                        ).astype(dtype)
                        if len(source_chunk_nr) == 1:
                            # Single source chunk: no merge task follows, so
                            # the requested order is applied in the taker.
                            this_slice[axis] = this_slice[axis][np.argsort(sorter)]

                        taker_key = taker_name + tokenize(this_slice)
                        # low level fusion can't deal with arrays on first position
                        intermediates[taker_key] = DataNode(taker_key, (1, tuple(this_slice)))
                        taker_cache[c] = taker_key

                    intermediates[name] = Task(
                        name,
                        _getitem,
                        TaskRef(old_blocks[chunk_key]),
                        TaskRef(taker_key),
                    )
                    merge_keys.append(name)

                merge_suffix = convert_key(chunk_tuple, new_chunk_idx, axis)
                out_name_merge = (self._name,) + merge_suffix
                if len(merge_keys) > 1:
                    merges[out_name_merge] = Task(
                        out_name_merge,
                        concatenate_arrays,
                        List(*(TaskRef(m) for m in merge_keys)),
                        TaskRef(sorter_key),
                        axis,
                    )
                elif len(merge_keys) == 1:
                    # Promote the only split task to be the output block.
                    t = intermediates.pop(merge_keys[0])
                    t.key = out_name_merge
                    merges[out_name_merge] = t
                else:
                    raise NotImplementedError

        return {**merges, **intermediates}
|
|
431
|
+
|
|
432
|
+
|
|
433
|
+
def _getitem(obj, index):
    """Apply ``getitem`` to ``obj`` using the second element of ``index``.

    ``Shuffle._layer`` stores takers as ``(1, taker)`` pairs; only the taker
    part is used for the selection.
    """
    selection = index[1]
    return getitem(obj, selection)
|
|
435
|
+
|
|
436
|
+
|
|
437
|
+
def concatenate_arrays(arrs, sorter, axis):
    """Concatenate ``arrs`` along ``axis`` and reorder the result.

    The concatenation function is dispatched on the type of the first array.
    The joined array is then taken along ``axis`` with the argsort of
    ``sorter[1]`` (``sorter`` arrives as a ``(1, array)`` pair).
    """
    concatenate = concatenate_lookup.dispatch(type(arrs[0]))
    joined = concatenate(arrs, axis=axis)
    order = np.argsort(sorter[1])
    return take_lookup(joined, order, axis=axis)
|
|
443
|
+
|
|
444
|
+
|
|
445
|
+
def convert_key(key, chunk, axis):
    """Return ``key`` as a tuple with ``chunk`` inserted at position ``axis``.

    ``chunk`` is converted with ``int`` so that NumPy integers become plain
    Python ints in the resulting key.
    """
    parts = tuple(key)
    # Normalize np.int64 to Python int
    return (*parts[:axis], int(chunk), *parts[axis:])
|