dask-array 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dask_array/__init__.py +228 -0
- dask_array/_backends.py +76 -0
- dask_array/_backends_array.py +99 -0
- dask_array/_blockwise.py +1410 -0
- dask_array/_broadcast.py +272 -0
- dask_array/_chunk.py +445 -0
- dask_array/_chunk_types.py +54 -0
- dask_array/_collection.py +1644 -0
- dask_array/_concatenate.py +331 -0
- dask_array/_core_utils.py +1365 -0
- dask_array/_dispatch.py +141 -0
- dask_array/_einsum.py +277 -0
- dask_array/_expr.py +544 -0
- dask_array/_expr_flow.py +586 -0
- dask_array/_gufunc.py +805 -0
- dask_array/_histogram.py +617 -0
- dask_array/_map_blocks.py +652 -0
- dask_array/_new_collection.py +10 -0
- dask_array/_numpy_compat.py +135 -0
- dask_array/_overlap.py +1159 -0
- dask_array/_rechunk.py +1050 -0
- dask_array/_reshape.py +710 -0
- dask_array/_routines.py +102 -0
- dask_array/_shuffle.py +448 -0
- dask_array/_stack.py +264 -0
- dask_array/_svg.py +291 -0
- dask_array/_templates.py +29 -0
- dask_array/_test_utils.py +257 -0
- dask_array/_ufunc.py +385 -0
- dask_array/_utils.py +349 -0
- dask_array/_visualize.py +223 -0
- dask_array/_xarray.py +337 -0
- dask_array/core/__init__.py +34 -0
- dask_array/core/_blockwise_funcs.py +312 -0
- dask_array/core/_conversion.py +422 -0
- dask_array/core/_from_graph.py +97 -0
- dask_array/creation/__init__.py +71 -0
- dask_array/creation/_arange.py +121 -0
- dask_array/creation/_diag.py +116 -0
- dask_array/creation/_diagonal.py +241 -0
- dask_array/creation/_eye.py +103 -0
- dask_array/creation/_linspace.py +102 -0
- dask_array/creation/_mesh.py +134 -0
- dask_array/creation/_ones_zeros.py +454 -0
- dask_array/creation/_pad.py +270 -0
- dask_array/creation/_repeat.py +55 -0
- dask_array/creation/_tile.py +36 -0
- dask_array/creation/_tri.py +28 -0
- dask_array/creation/_utils.py +296 -0
- dask_array/fft.py +320 -0
- dask_array/io/__init__.py +39 -0
- dask_array/io/_base.py +10 -0
- dask_array/io/_from_array.py +257 -0
- dask_array/io/_from_delayed.py +95 -0
- dask_array/io/_from_graph.py +54 -0
- dask_array/io/_from_npy_stack.py +67 -0
- dask_array/io/_store.py +336 -0
- dask_array/io/_tiledb.py +159 -0
- dask_array/io/_to_npy_stack.py +65 -0
- dask_array/io/_zarr.py +449 -0
- dask_array/linalg/__init__.py +39 -0
- dask_array/linalg/_cholesky.py +234 -0
- dask_array/linalg/_lu.py +300 -0
- dask_array/linalg/_norm.py +94 -0
- dask_array/linalg/_qr.py +601 -0
- dask_array/linalg/_solve.py +349 -0
- dask_array/linalg/_svd.py +394 -0
- dask_array/linalg/_tensordot.py +334 -0
- dask_array/linalg/_utils.py +74 -0
- dask_array/manipulation/__init__.py +45 -0
- dask_array/manipulation/_expand.py +321 -0
- dask_array/manipulation/_flip.py +92 -0
- dask_array/manipulation/_roll.py +78 -0
- dask_array/manipulation/_transpose.py +309 -0
- dask_array/random/__init__.py +125 -0
- dask_array/random/_choice.py +181 -0
- dask_array/random/_expr.py +256 -0
- dask_array/random/_generator.py +441 -0
- dask_array/random/_random_state.py +259 -0
- dask_array/random/_utils.py +84 -0
- dask_array/reductions/__init__.py +84 -0
- dask_array/reductions/_arg_reduction.py +130 -0
- dask_array/reductions/_common.py +1082 -0
- dask_array/reductions/_cumulative.py +522 -0
- dask_array/reductions/_percentile.py +261 -0
- dask_array/reductions/_reduction.py +725 -0
- dask_array/reductions/_trace.py +56 -0
- dask_array/routines/__init__.py +133 -0
- dask_array/routines/_apply.py +84 -0
- dask_array/routines/_bincount.py +112 -0
- dask_array/routines/_broadcast.py +111 -0
- dask_array/routines/_coarsen.py +115 -0
- dask_array/routines/_diff.py +79 -0
- dask_array/routines/_gradient.py +158 -0
- dask_array/routines/_indexing.py +65 -0
- dask_array/routines/_insert_delete.py +132 -0
- dask_array/routines/_misc.py +122 -0
- dask_array/routines/_nonzero.py +72 -0
- dask_array/routines/_search.py +123 -0
- dask_array/routines/_select.py +113 -0
- dask_array/routines/_statistics.py +171 -0
- dask_array/routines/_topk.py +82 -0
- dask_array/routines/_triangular.py +74 -0
- dask_array/routines/_unique.py +232 -0
- dask_array/routines/_where.py +62 -0
- dask_array/slicing/__init__.py +67 -0
- dask_array/slicing/_basic.py +550 -0
- dask_array/slicing/_blocks.py +138 -0
- dask_array/slicing/_bool_index.py +145 -0
- dask_array/slicing/_setitem.py +329 -0
- dask_array/slicing/_squeeze.py +101 -0
- dask_array/slicing/_utils.py +1133 -0
- dask_array/slicing/_vindex.py +282 -0
- dask_array/stacking/__init__.py +15 -0
- dask_array/stacking/_block.py +83 -0
- dask_array/stacking/_simple.py +58 -0
- dask_array/templates/array.html.j2 +48 -0
- dask_array/tests/__init__.py +0 -0
- dask_array/tests/conftest.py +22 -0
- dask_array/tests/test_api.py +40 -0
- dask_array/tests/test_binary_op_chunks.py +107 -0
- dask_array/tests/test_coarse_slice_through_blockwise.py +362 -0
- dask_array/tests/test_collection.py +799 -0
- dask_array/tests/test_creation.py +1102 -0
- dask_array/tests/test_expr_flow.py +143 -0
- dask_array/tests/test_linalg.py +1130 -0
- dask_array/tests/test_map_blocks_multi_output.py +104 -0
- dask_array/tests/test_rechunk_pushdown.py +214 -0
- dask_array/tests/test_reductions.py +1091 -0
- dask_array/tests/test_routines.py +2853 -0
- dask_array/tests/test_shuffle_chunks.py +67 -0
- dask_array/tests/test_slice_pushdown.py +968 -0
- dask_array/tests/test_slice_through_blockwise.py +678 -0
- dask_array/tests/test_slice_through_overlap.py +366 -0
- dask_array/tests/test_slice_through_reshape.py +272 -0
- dask_array/tests/test_slicing.py +839 -0
- dask_array/tests/test_transpose_slice_pushdown.py +208 -0
- dask_array/tests/test_visualize.py +94 -0
- dask_array/tests/test_xarray.py +193 -0
- dask_array-0.1.0.dist-info/METADATA +48 -0
- dask_array-0.1.0.dist-info/RECORD +144 -0
- dask_array-0.1.0.dist-info/WHEEL +4 -0
- dask_array-0.1.0.dist-info/entry_points.txt +2 -0
- dask_array-0.1.0.dist-info/licenses/LICENSE +29 -0
|
@@ -0,0 +1,282 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import functools
|
|
4
|
+
import math
|
|
5
|
+
from collections import defaultdict
|
|
6
|
+
from itertools import product
|
|
7
|
+
from numbers import Number
|
|
8
|
+
from operator import mul
|
|
9
|
+
|
|
10
|
+
import numpy as np
|
|
11
|
+
|
|
12
|
+
from dask._task_spec import List, Task, TaskRef
|
|
13
|
+
from dask_array._expr import ArrayExpr
|
|
14
|
+
from dask_array._utils import meta_from_array
|
|
15
|
+
from dask_array.slicing._utils import replace_ellipsis
|
|
16
|
+
from dask_array._core_utils import (
|
|
17
|
+
_get_axis,
|
|
18
|
+
_vindex_merge,
|
|
19
|
+
_vindex_slice_and_transpose,
|
|
20
|
+
interleave_none,
|
|
21
|
+
keyname,
|
|
22
|
+
)
|
|
23
|
+
from dask.utils import cached_cumsum, cached_max
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def _numpy_vindex(indexer, arr):
|
|
27
|
+
"""Helper for vindex with single-block arrays indexed by dask arrays."""
|
|
28
|
+
return arr[indexer]
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def _vindex(x, *indexes):
    """Point wise indexing with broadcasting.

    >>> x = np.arange(56).reshape((7, 8))
    >>> x
    array([[ 0,  1,  2,  3,  4,  5,  6,  7],
           [ 8,  9, 10, 11, 12, 13, 14, 15],
           [16, 17, 18, 19, 20, 21, 22, 23],
           [24, 25, 26, 27, 28, 29, 30, 31],
           [32, 33, 34, 35, 36, 37, 38, 39],
           [40, 41, 42, 43, 44, 45, 46, 47],
           [48, 49, 50, 51, 52, 53, 54, 55]])

    >>> from dask_array._collection import from_array
    >>> d = from_array(x, chunks=(3, 4))
    >>> result = _vindex(d, [0, 1, 6, 0], [0, 1, 0, 7])
    >>> result.compute()
    array([ 0,  9, 48,  7])

    Parameters
    ----------
    x
        Array to index; must expose ``ndim``, ``shape`` and ``__getitem__``.
    *indexes
        One entry per axis (an ``Ellipsis`` is expanded first).  Numbers and
        slices are applied as ordinary indexing; anything else is converted
        with ``np.array`` and treated as a point-wise index array.

    Returns
    -------
    ``x`` after the ordinary indexing step, passed through
    ``_vindex_array`` when at least one index array was given.

    Raises
    ------
    IndexError
        If an index array is boolean or holds entries outside
        ``[-size, size)`` for its axis.
    """

    indexes = replace_ellipsis(x.ndim, indexes)

    # Split the key: numbers/slices are applied right away with ordinary
    # indexing, array-like entries are kept for the point-wise step.
    nonfancy_indexes = []
    reduced_indexes = []
    for ind in indexes:
        if isinstance(ind, Number):
            # A number drops its axis, so it has no slot in reduced_indexes.
            nonfancy_indexes.append(ind)
        elif isinstance(ind, slice):
            nonfancy_indexes.append(ind)
            reduced_indexes.append(slice(None))
        else:
            nonfancy_indexes.append(slice(None))
            reduced_indexes.append(ind)

    nonfancy_indexes = tuple(nonfancy_indexes)
    reduced_indexes = tuple(reduced_indexes)

    x = x[nonfancy_indexes]

    array_indexes = {}
    for i, (ind, size) in enumerate(zip(reduced_indexes, x.shape)):
        if isinstance(ind, slice):
            continue
        # Work on a private copy: the caller's array must not be modified.
        ind = np.array(ind, copy=True)
        if ind.dtype.kind == "b":
            raise IndexError("vindex does not support indexing with boolean arrays")
        if ((ind >= size) | (ind < -size)).any():
            raise IndexError(
                f"vindex key has entries out of bounds for indexing along axis {i} of size {size}: {ind!r}"
            )
        if ind.dtype.kind in "iu":
            # Widen narrow integer dtypes before wrapping negatives.  With
            # NumPy 2 promotion rules (NEP 50) ``ind %= size`` is evaluated in
            # ``ind``'s own dtype and raises OverflowError when ``size`` does
            # not fit (e.g. uint8 indices into an axis of length 300).  The
            # bounds check above guarantees every entry fits in ``np.intp``.
            ind = ind.astype(np.intp, copy=False)
        # Map negative entries onto their non-negative equivalents.
        ind %= size
        array_indexes[i] = ind

    if array_indexes:
        x = _vindex_array(x, array_indexes)

    return x
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
def _compute_indexer(index, chunks_along_axis):
|
|
90
|
+
"""Compute Shuffle indexer by grouping consecutive indices from same input chunk.
|
|
91
|
+
|
|
92
|
+
Returns a list of lists, where each inner list contains indices that come from
|
|
93
|
+
a contiguous run accessing the same input chunk. This preserves locality for
|
|
94
|
+
patterns like np.repeat.
|
|
95
|
+
"""
|
|
96
|
+
chunk_boundaries = np.cumsum((0,) + chunks_along_axis)
|
|
97
|
+
input_chunk_ids = np.searchsorted(chunk_boundaries[1:], index, side="right")
|
|
98
|
+
changes = np.concatenate([[0], np.where(np.diff(input_chunk_ids) != 0)[0] + 1, [len(index)]])
|
|
99
|
+
return [index[changes[i] : changes[i + 1]].tolist() for i in range(len(changes) - 1)]
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
def _vindex_array(x, dict_indexes):
    """Point wise indexing with only NumPy Arrays.

    ``dict_indexes`` maps an axis number of ``x`` to the index array used on
    that axis.  The index arrays are broadcast against each other and three
    cases are distinguished:

    * exactly one indexed axis and at least one point: delegate to
      ``_shuffle`` and reshape the indexed axis to the broadcast shape;
    * several indexed axes and at least one point: build a ``VIndexArray``
      and reshape its leading axis to the broadcast shape;
    * no points at all: return an ``empty`` array whose leading axis has
      length zero, reshaped the same way.

    Raises
    ------
    IndexError
        If the index arrays cannot be broadcast together.
    """
    from dask_array._new_collection import new_collection
    from dask_array.creation import empty

    index_arrays = dict_indexes.values()
    try:
        broadcast_shape = np.broadcast_shapes(*(a.shape for a in index_arrays))
    except ValueError as e:
        shapes_str = " ".join(str(a.shape) for a in index_arrays)
        raise IndexError(
            f"shape mismatch: indexing arrays could not be broadcast together with shapes {shapes_str}"
        ) from e
    npoints = math.prod(broadcast_shape)

    if npoints > 0:
        if len(dict_indexes) == 1:
            # Single-axis case: delegate to Shuffle for optimization hooks
            from dask_array._shuffle import _shuffle

            axis, only_index = next(iter(dict_indexes.items()))
            indexer = _compute_indexer(only_index.ravel(), x.chunks[axis])
            shuffled = new_collection(_shuffle(x.expr, indexer, axis, "vindex-"))

            # Shuffle keeps axis in place; swap that one axis for the
            # dimensions of broadcast_shape.
            target_shape = list(shuffled.shape)
            target_shape[axis : axis + 1] = list(broadcast_shape)
            return shuffled.reshape(tuple(target_shape))

        points_first = new_collection(VIndexArray(x.expr, dict_indexes, broadcast_shape, npoints))
        return points_first.reshape(broadcast_shape + points_first.shape[1:])

    # output has zero dimension - just create a new zero-shape array:
    # a length-0 leading axis followed by the chunks of the non-indexed axes.
    untouched = tuple(c for i, c in enumerate(x.chunks) if i not in dict_indexes)
    chunks = ((0,),) + untouched
    shape = tuple(sum(c) for c in chunks)

    points_first = empty(shape, chunks=chunks, dtype=x.dtype)
    return points_first.reshape(broadcast_shape + points_first.shape[1:])
|
|
142
|
+
|
|
143
|
+
|
|
144
|
+
class VIndexArray(ArrayExpr):
    """Point-wise vectorized indexing with broadcasting.

    Used for multi-axis fancy indexing where indices broadcast together.
    Single-axis cases delegate to Shuffle for optimization hooks.

    Operands: ``array`` (the expression being indexed), ``dict_indexes``
    (axis number -> NumPy index array), ``broadcast_shape`` (the common shape
    of the index arrays) and ``npoints`` (number of selected points).  The
    expression has one leading "points" axis followed by the axes of
    ``array`` that are not indexed.
    """

    _parameters = ["array", "dict_indexes", "broadcast_shape", "npoints"]

    @functools.cached_property
    def _name(self):
        # Name of the merge tasks, i.e. of the output blocks.
        return f"vindex-merge-{self.deterministic_token}"

    @functools.cached_property
    def _meta(self):
        return meta_from_array(self.array._meta, ndim=len(self.chunks))

    @functools.cached_property
    def _axes(self):
        """Axes that have array indexing."""
        indexed = self.dict_indexes
        return [axis for axis in range(self.array.ndim) if axis in indexed]

    def _subset_to_indexed_axes(self, iterable):
        """Yield the elements of ``iterable`` whose position is an indexed axis."""
        yield from (elem for pos, elem in enumerate(iterable) if pos in self._axes)

    @functools.cached_property
    def _max_chunk_point_dimensions(self):
        # Product of the largest chunk size of every indexed axis; used as the
        # chunk size of the output's points axis.
        largest = (cached_max(c) for c in self._subset_to_indexed_axes(self.array.chunks))
        return functools.reduce(mul, largest)

    @functools.cached_property
    def chunks(self):
        point_chunk = self._max_chunk_point_dimensions
        kept = [c for axis, c in enumerate(self.array.chunks) if axis not in self._axes]

        # Points axis: as many full chunks as fit, plus one shorter trailing
        # chunk when the points do not divide evenly.
        n_full, leftover = divmod(self.npoints, point_chunk)
        if self.npoints > 0:
            points_axis = (point_chunk,) * n_full
            if leftover > 0:
                points_axis += (leftover,)
        else:
            points_axis = (0,)
        return (points_axis, *kept)

    def _layer(self) -> dict:
        """Build the task graph.

        Points are grouped by (output block, input block per indexed axis).
        For every combination of blocks along the non-indexed axes, each group
        becomes one ``vindex-slice-`` task calling
        ``_vindex_slice_and_transpose`` on the matching input block, and each
        output block becomes one ``_vindex_merge`` task over its groups.
        """
        source = self.array
        indexed_axes = self._axes
        index_arrays = list(self.dict_indexes.values())
        point_chunk = self._max_chunk_point_dimensions

        # Chunk boundaries (starting at 0) of each indexed axis.
        edges = tuple(
            np.array(cached_cumsum(c, initial_zero=True)) for c in self._subset_to_indexed_axes(source.chunks)
        )
        axis = _get_axis(tuple(i if i in indexed_axes else None for i in range(source.ndim)))

        # Input block holding each point, per indexed axis ...
        block_idxs = tuple(np.searchsorted(e, ind, side="right") - 1 for e, ind in zip(edges, index_arrays))
        # ... and the point's offset inside that block.
        inblock_idxs = []
        for ind, blk, e in zip(index_arrays, block_idxs, edges):
            # Convert unsigned integers to signed to avoid float promotion in subtraction
            if ind.dtype.kind == "u":
                ind = ind.astype(np.int64)
            offsets = ind - e[blk]
            if len(offsets) > 0:
                # Store the offsets in the smallest dtype that holds the maximum.
                compact = np.min_scalar_type(np.max(offsets, axis=None))
                offsets = offsets.astype(compact, copy=False)
            inblock_idxs.append(offsets)

        inblock_idxs = np.broadcast_arrays(*inblock_idxs)  # type: ignore[assignment]

        n_chunks, _ = divmod(self.npoints, point_chunk)

        # Block coordinates along the non-indexed axes; indexed axes are
        # represented by a None placeholder.
        other_blocks = product(
            *[[None] if i in indexed_axes else range(len(c)) for i, c in enumerate(source.chunks)]
        )
        full_slices = [None if i in indexed_axes else slice(None, None) for i in range(source.ndim)]

        # The output is constructed as a new dimension and then reshaped
        outinds = np.arange(self.npoints).reshape(self.broadcast_shape)
        outblocks, outblock_idx = np.divmod(outinds, point_chunk)

        # Fold (output block, input blocks) into one integer per point so that
        # sorting by it makes every group contiguous.
        ravel_shape = (
            n_chunks + 1,
            *self._subset_to_indexed_axes(source.numblocks),
        )
        keys = np.ravel_multi_index([outblocks, *block_idxs], ravel_shape)
        order = np.argsort(keys, axis=None)
        sorted_keys = keys.flat[order]
        sorted_inblock_idxs = [offsets.flat[order] for offsets in inblock_idxs]
        sorted_outblock_idx = outblock_idx.flat[order].astype(np.min_scalar_type(point_chunk), copy=False)

        # Start of every run of equal keys, plus one-past-the-end.
        is_boundary = np.concatenate([[True], sorted_keys[1:] != sorted_keys[:-1], [True]])
        (key_bounds,) = is_boundary.nonzero()

        slice_name = f"vindex-slice-{self.deterministic_token}"
        dsk = {}

        for okey in other_blocks:
            merge_inputs = defaultdict(list)
            merge_indexer = defaultdict(list)
            for group, (start, stop) in enumerate(zip(key_bounds[:-1], key_bounds[1:], strict=True)):
                run = slice(start, stop)
                outblock, *input_blocks = np.unravel_index(sorted_keys[start], ravel_shape)
                inblock = [offsets[run] for offsets in sorted_inblock_idxs]

                slice_key = keyname(slice_name, group, okey)
                dsk[slice_key] = Task(
                    slice_key,
                    _vindex_slice_and_transpose,
                    TaskRef((source._name,) + interleave_none(okey, input_blocks)),
                    interleave_none(full_slices, inblock),
                    axis,
                )
                merge_inputs[outblock].append(TaskRef(slice_key))
                merge_indexer[outblock].append(sorted_outblock_idx[run])

            for outblock, pieces in merge_inputs.items():
                merge_key = keyname(self._name, outblock, okey)
                dsk[merge_key] = Task(
                    merge_key,
                    _vindex_merge,
                    merge_indexer[outblock],
                    List(*pieces),
                )

        return dsk

    def __dask_keys__(self):
        # Override to return 1D keys since we reshape after
        blocks_per_axis = tuple(len(c) for c in self.chunks)
        return [(self._name, *block_id) for block_id in np.ndindex(blocks_per_axis)]
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
"""Stacking and concatenation functions."""
|
|
2
|
+
|
|
3
|
+
from dask_array._concatenate import concatenate
|
|
4
|
+
from dask_array._stack import stack
|
|
5
|
+
from dask_array.stacking._block import block
|
|
6
|
+
from dask_array.stacking._simple import dstack, hstack, vstack
|
|
7
|
+
|
|
8
|
+
__all__ = [
|
|
9
|
+
"stack",
|
|
10
|
+
"concatenate",
|
|
11
|
+
"block",
|
|
12
|
+
"vstack",
|
|
13
|
+
"hstack",
|
|
14
|
+
"dstack",
|
|
15
|
+
]
|
|
@@ -0,0 +1,83 @@
|
|
|
1
|
+
"""Block operation."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def block(arrays, allow_unknown_chunksizes=False):
    """
    Assemble an nd-array from nested lists of blocks.

    Blocks in the innermost lists are concatenated along the last
    dimension (-1), then these are concatenated along the second-last
    dimension (-2), and so on until the outermost list is reached.

    ``allow_unknown_chunksizes`` is passed on to every ``concatenate`` call.

    Raises
    ------
    TypeError
        If a tuple appears anywhere in the nesting.
    ValueError
        If the lists are nested to different depths, or if any list is empty.

    See Also
    --------
    numpy.block
    """
    # Import here to avoid circular imports
    from dask_array._collection import asanyarray, concatenate
    from dask_array._numpy_compat import _Recurser

    def atleast_nd(x, ndim):
        # Prepend length-1 axes until ``x`` has at least ``ndim`` dimensions.
        x = asanyarray(x)
        missing = max(ndim - x.ndim, 0)
        if missing:
            return x[(None,) * missing + (Ellipsis,)]
        return x

    def format_index(index):
        # Render a nesting path as e.g. "arrays[0][1]" for error messages.
        return "arrays" + "".join(f"[{i}]" for i in index)

    def join_level(xs, axis):
        return concatenate(list(xs), axis=axis, allow_unknown_chunksizes=allow_unknown_chunksizes)

    def next_level(axis):
        return {"axis": axis + 1}

    # Only exact ``list`` instances are treated as nesting.
    rec = _Recurser(recurse_if=lambda x: type(x) is list)

    # Ensure that the lists are all matched in depth
    list_ndim = None
    any_empty = False
    for index, value, entering in rec.walk(arrays):
        if type(value) is tuple:
            raise TypeError(
                f"{format_index(index)} is a tuple. "
                "Only lists can be used to arrange blocks, and np.block does "
                "not allow implicit conversion from tuple to ndarray."
            )
        if entering:
            if len(value) != 0:
                # Non-empty list: its elements are visited separately.
                continue
            curr_depth = len(index) + 1
            any_empty = True
        else:
            curr_depth = len(index)

        if list_ndim is not None and list_ndim != curr_depth:
            raise ValueError(
                f"List depths are mismatched. First element was at depth {list_ndim}, "
                f"but there is an element at depth {curr_depth} ({format_index(index)})"
            )
        list_ndim = curr_depth

    # Do this here so we catch depth mismatches first
    if any_empty:
        raise ValueError("Lists cannot be empty")

    # Convert all the arrays to ndarrays
    arrays = rec.map_reduce(arrays, f_map=asanyarray, f_reduce=list)

    # Determine the maximum dimension of the elements
    elem_ndim = rec.map_reduce(arrays, f_map=lambda xi: xi.ndim, f_reduce=max)
    ndim = max(list_ndim, elem_ndim)

    # First axis to concatenate along
    first_axis = ndim - list_ndim

    # Make all the elements the same dimension
    arrays = rec.map_reduce(arrays, f_map=lambda xi: atleast_nd(xi, ndim), f_reduce=list)

    # Concatenate innermost lists on the right, outermost on the left
    return rec.map_reduce(
        arrays,
        f_reduce=join_level,
        f_kwargs=next_level,
        axis=first_axis,
    )
|
|
@@ -0,0 +1,58 @@
|
|
|
1
|
+
"""Simple stacking operations: vstack, hstack, dstack."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def vstack(tup, allow_unknown_chunksizes=False):
    """Stack arrays in sequence vertically (row wise).

    Every element of ``tup`` is passed through ``atleast_2d`` and the results
    are concatenated along axis 0.  ``allow_unknown_chunksizes`` is forwarded
    to ``concatenate``.

    Raises
    ------
    NotImplementedError
        If ``tup`` is itself an ``Array`` rather than a sequence of arrays.

    See Also
    --------
    numpy.vstack
    """
    # Import here to avoid circular imports
    from dask_array._collection import Array, concatenate
    from dask_array.manipulation._expand import atleast_2d

    if isinstance(tup, Array):
        raise NotImplementedError("``vstack`` expects a sequence of arrays as the first argument")

    promoted = tuple(map(atleast_2d, tup))
    return concatenate(promoted, axis=0, allow_unknown_chunksizes=allow_unknown_chunksizes)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def hstack(tup, allow_unknown_chunksizes=False):
    """Stack arrays in sequence horizontally (column wise).

    The arrays are concatenated along axis 0 when every one of them is
    one-dimensional, and along axis 1 otherwise.

    Parameters
    ----------
    tup
        Iterable of arrays; each element must expose ``ndim``.  Generators
        and other one-shot iterators are accepted.
    allow_unknown_chunksizes
        Forwarded to ``concatenate``.

    Raises
    ------
    NotImplementedError
        If ``tup`` is itself an ``Array`` rather than a sequence of arrays.

    See Also
    --------
    numpy.hstack
    """
    # Import here to avoid circular imports
    from dask_array._collection import Array, concatenate

    if isinstance(tup, Array):
        raise NotImplementedError("``hstack`` expects a sequence of arrays as the first argument")

    # Materialise once: the ndim check below iterates over the input, which
    # would leave a generator exhausted by the time concatenate reads it.
    tup = tuple(tup)
    axis = 0 if all(x.ndim == 1 for x in tup) else 1
    return concatenate(tup, axis=axis, allow_unknown_chunksizes=allow_unknown_chunksizes)
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def dstack(tup, allow_unknown_chunksizes=False):
    """Stack arrays in sequence depth wise (along third axis).

    Every element of ``tup`` is passed through ``atleast_3d`` and the results
    are concatenated along axis 2.  ``allow_unknown_chunksizes`` is forwarded
    to ``concatenate``.

    Raises
    ------
    NotImplementedError
        If ``tup`` is itself an ``Array`` rather than a sequence of arrays.

    See Also
    --------
    numpy.dstack
    """
    # Import here to avoid circular imports
    from dask_array._collection import Array, concatenate
    from dask_array.manipulation._expand import atleast_3d

    if isinstance(tup, Array):
        raise NotImplementedError("``dstack`` expects a sequence of arrays as the first argument")

    promoted = tuple(map(atleast_3d, tup))
    return concatenate(promoted, axis=2, allow_unknown_chunksizes=allow_unknown_chunksizes)
|
|
@@ -0,0 +1,48 @@
|
|
|
1
|
+
<style>
|
|
2
|
+
.dask-array-repr .dask-table-header { color: var(--jp-ui-font-color2, #78716c); }
|
|
3
|
+
.dask-array-repr .dask-table-label { color: var(--jp-ui-font-color3, #a8a29e); }
|
|
4
|
+
.dask-array-repr .dask-table-data { color: var(--jp-ui-font-color1, #1c1917); }
|
|
5
|
+
.dask-array-repr .dask-table-border { border-top: 1px solid var(--jp-border-color2, #e7e5e4); }
|
|
6
|
+
</style>
|
|
7
|
+
<details class="dask-array-repr" style="font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, 'Helvetica Neue', Arial, sans-serif;">
|
|
8
|
+
<summary style="cursor: pointer; list-style: none;">
|
|
9
|
+
<table style="border-collapse: separate; border-spacing: 0; display: inline-table;">
|
|
10
|
+
<tr>
|
|
11
|
+
<td style="vertical-align: top;">
|
|
12
|
+
<table style="border-collapse: collapse; font-size: 14px;">
|
|
13
|
+
<thead>
|
|
14
|
+
<tr>
|
|
15
|
+
<td style="padding: 6px 12px;"></td>
|
|
16
|
+
<th class="dask-table-header" style="padding: 6px 12px; text-align: right; font-weight: 400;">Array</th>
|
|
17
|
+
<th class="dask-table-header" style="padding: 6px 12px; text-align: right; font-weight: 400;">Chunk</th>
|
|
18
|
+
</tr>
|
|
19
|
+
</thead>
|
|
20
|
+
<tbody>
|
|
21
|
+
{% if nbytes %}
|
|
22
|
+
<tr>
|
|
23
|
+
<th class="dask-table-label dask-table-border" style="padding: 6px 12px; text-align: left; font-weight: 400;">Bytes</th>
|
|
24
|
+
<td class="dask-table-data dask-table-border" style="padding: 6px 12px; text-align: right; font-weight: 600;">{{ nbytes }}</td>
|
|
25
|
+
<td class="dask-table-data dask-table-border" style="padding: 6px 12px; text-align: right; font-weight: 600;">{{ cbytes }}</td>
|
|
26
|
+
</tr>
|
|
27
|
+
{% endif %}
|
|
28
|
+
<tr>
|
|
29
|
+
<th class="dask-table-label dask-table-border" style="padding: 6px 12px; text-align: left; font-weight: 400;">Shape</th>
|
|
30
|
+
<td class="dask-table-data dask-table-border" style="padding: 6px 12px; text-align: right; font-weight: 600;">{{ array.shape }}</td>
|
|
31
|
+
<td class="dask-table-data dask-table-border" style="padding: 6px 12px; text-align: right; font-weight: 600;">{{ array.chunksize }}</td>
|
|
32
|
+
</tr>
|
|
33
|
+
<tr>
|
|
34
|
+
<th class="dask-table-label dask-table-border" style="padding: 6px 12px; text-align: left; font-weight: 400;">Nodes</th>
|
|
35
|
+
<td class="dask-table-data dask-table-border" style="padding: 6px 12px; text-align: right; font-weight: 600;">{{ n_expr }}</td>
|
|
36
|
+
<td class="dask-table-border" style="padding: 6px 12px;"></td>
|
|
37
|
+
</tr>
|
|
38
|
+
</tbody>
|
|
39
|
+
</table>
|
|
40
|
+
</td>
|
|
41
|
+
<td style="vertical-align: middle; padding-left: 24px;">
|
|
42
|
+
{{ grid }}
|
|
43
|
+
</td>
|
|
44
|
+
</tr>
|
|
45
|
+
</table>
|
|
46
|
+
</summary>
|
|
47
|
+
<div style="padding: 12px 0;">{{ expr_flow }}</div>
|
|
48
|
+
</details>
|
|
File without changes
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import pytest
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def pytest_addoption(parser):
    """Register the ``--runslow`` command line flag (off by default)."""
    runslow_settings = {
        "action": "store_true",
        "default": False,
        "help": "run tests marked slow",
    }
    parser.addoption("--runslow", **runslow_settings)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def pytest_collection_modifyitems(config, items):
    """Mark tests carrying the ``slow`` keyword as skipped unless ``--runslow`` was given."""
    run_slow = config.getoption("--runslow")
    if not run_slow:
        skip_slow = pytest.mark.skip(reason="need --runslow option to run")
        for test_item in items:
            if "slow" not in test_item.keywords:
                continue
            test_item.add_marker(skip_slow)
|
|
@@ -0,0 +1,40 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
|
|
5
|
+
import dask_array as da
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def test_top_level_compatibility_exports():
    """The top-level namespace re-exports NumPy constants/dtypes and core helpers."""
    assert da.newaxis is None
    assert np.isnan(da.nan)
    # Constants that compare equal to their NumPy counterparts.
    for constant in ("inf", "pi"):
        assert getattr(da, constant) == getattr(np, constant)
    # Scalar types must be the very same objects as NumPy's.
    for scalar_type in ("float64", "int64"):
        assert getattr(da, scalar_type) is getattr(np, scalar_type)

    for helper in (
        "compute",
        "optimize",
        "register_chunk_type",
        "to_hdf5",
        "from_tiledb",
        "to_tiledb",
    ):
        assert callable(getattr(da, helper))
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def test_top_level_optimize_collection():
    """``da.optimize`` returns an ``Array`` that still computes the same values."""
    lazy = da.arange(6, chunks=3) + 1

    optimized = da.optimize(lazy)

    assert isinstance(optimized, da.Array)
    computed = optimized.compute()
    np.testing.assert_array_equal(computed, np.arange(6) + 1)
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def test_random_star_exports_legacy_wrappers():
    """A star import of ``dask_array.random`` exposes the function-style wrappers."""
    # ``import *`` is only legal at module level, hence the exec into a dict.
    star_namespace = {}
    exec("from dask_array.random import *", star_namespace)

    for wrapper in ("normal", "random", "randint", "standard_normal"):
        assert callable(star_namespace[wrapper])
|
|
@@ -0,0 +1,107 @@
|
|
|
1
|
+
"""Tests for coarse_blockdim: preferring larger chunks in binary operations.
|
|
2
|
+
|
|
3
|
+
When combining arrays with different chunk granularities, we prefer coarser
|
|
4
|
+
chunks (fewer blocks) when boundaries align. This reduces task overhead.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
import math
|
|
10
|
+
|
|
11
|
+
import numpy as np
|
|
12
|
+
import pytest
|
|
13
|
+
|
|
14
|
+
import dask_array as da
|
|
15
|
+
from dask_array._test_utils import assert_eq
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def total_chunks(arr):
    """Return the product of ``arr.numblocks``, i.e. the total number of chunks."""
    blocks_per_axis = arr.numblocks
    return math.prod(blocks_per_axis)
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class TestCoarseChunkPreference:
    """Binary operations keep the coarser chunking when chunk boundaries align."""

    def test_shuffle_indexed_array(self):
        """Main use case: xarray groupby pattern.

        Subtracting a fancy-indexed (shuffled) array from an array with nice
        chunks must not blow up the chunk count of the result.
        """
        # Original data: 10 chunks of size 12
        arr = da.random.random((120, 20, 30), chunks=(12, 20, 30))

        # Aggregated data indexed to match original shape
        n_groups = 4
        group_means = da.random.random((n_groups, 20, 30), chunks=(1, 20, 30))
        group_of_row = np.tile(np.arange(n_groups), 30)
        indexed_mean = group_means[group_of_row, ...]

        result = arr - indexed_mean

        # Should preserve arr's chunk count, not explode to 120
        chunk_budget = total_chunks(arr) * 2
        assert total_chunks(result) <= chunk_budget
        assert_eq(result, arr.compute() - indexed_mean.compute())

    def test_aligned_1d(self):
        """1D: (20,20) + (10,10,10,10) -> (20,20)"""
        two_blocks = da.ones(40, chunks=20)
        four_blocks = da.ones(40, chunks=10)

        total = two_blocks + four_blocks

        assert total.chunks == ((20, 20),)

    def test_aligned_2d(self):
        """2D: the coarser chunking wins along both axes."""
        shape = (40, 40)
        coarse = da.ones(shape, chunks=(20, 20))
        fine = da.ones(shape, chunks=(10, 10))

        total = coarse + fine

        assert total.chunks == ((20, 20), (20, 20))
        assert total_chunks(total) == 4  # not 16

    def test_multiples_align(self):
        """Chunk sizes that are multiples align: (30,30) + (10,...) -> (30,30)"""
        coarse = da.ones(60, chunks=30)
        fine = da.ones(60, chunks=10)

        total = coarse + fine

        assert total.chunks == ((30, 30),)
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
class TestFallbackToCommonBlockdim:
    """When chunk boundaries do not line up, neither input's chunking is reused."""

    def test_misaligned_boundaries(self):
        """(15,15) vs (10,20): boundary 15 not in {10}, must subdivide."""
        left = da.ones(30, chunks=(15, 15))
        right = da.ones(30, chunks=(10, 20))

        total = left + right

        # Neither input's chunks work; uses finest common divisor
        assert total.chunks != ((15, 15),)
        assert total.chunks != ((10, 20),)

    def test_non_divisible(self):
        """(12,12) vs (8,8,8): boundary 12 not in {8,16}, must subdivide."""
        left = da.ones(24, chunks=12)
        right = da.ones(24, chunks=8)

        total = left + right

        # The 12-chunked input has two blocks; the result must have more.
        (blocks_along_axis,) = (len(total.chunks[0]),)
        assert blocks_along_axis > 2

    def test_classic_uneven(self):
        """(4,6) vs (6,4): different boundaries, uses (4,2,4)."""
        left = da.arange(10, chunks=((4, 6),))
        right = da.ones(10, chunks=((6, 4),))

        total = left + right

        assert total.chunks == ((4, 2, 4),)
|