dask-array 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (144) hide show
  1. dask_array/__init__.py +228 -0
  2. dask_array/_backends.py +76 -0
  3. dask_array/_backends_array.py +99 -0
  4. dask_array/_blockwise.py +1410 -0
  5. dask_array/_broadcast.py +272 -0
  6. dask_array/_chunk.py +445 -0
  7. dask_array/_chunk_types.py +54 -0
  8. dask_array/_collection.py +1644 -0
  9. dask_array/_concatenate.py +331 -0
  10. dask_array/_core_utils.py +1365 -0
  11. dask_array/_dispatch.py +141 -0
  12. dask_array/_einsum.py +277 -0
  13. dask_array/_expr.py +544 -0
  14. dask_array/_expr_flow.py +586 -0
  15. dask_array/_gufunc.py +805 -0
  16. dask_array/_histogram.py +617 -0
  17. dask_array/_map_blocks.py +652 -0
  18. dask_array/_new_collection.py +10 -0
  19. dask_array/_numpy_compat.py +135 -0
  20. dask_array/_overlap.py +1159 -0
  21. dask_array/_rechunk.py +1050 -0
  22. dask_array/_reshape.py +710 -0
  23. dask_array/_routines.py +102 -0
  24. dask_array/_shuffle.py +448 -0
  25. dask_array/_stack.py +264 -0
  26. dask_array/_svg.py +291 -0
  27. dask_array/_templates.py +29 -0
  28. dask_array/_test_utils.py +257 -0
  29. dask_array/_ufunc.py +385 -0
  30. dask_array/_utils.py +349 -0
  31. dask_array/_visualize.py +223 -0
  32. dask_array/_xarray.py +337 -0
  33. dask_array/core/__init__.py +34 -0
  34. dask_array/core/_blockwise_funcs.py +312 -0
  35. dask_array/core/_conversion.py +422 -0
  36. dask_array/core/_from_graph.py +97 -0
  37. dask_array/creation/__init__.py +71 -0
  38. dask_array/creation/_arange.py +121 -0
  39. dask_array/creation/_diag.py +116 -0
  40. dask_array/creation/_diagonal.py +241 -0
  41. dask_array/creation/_eye.py +103 -0
  42. dask_array/creation/_linspace.py +102 -0
  43. dask_array/creation/_mesh.py +134 -0
  44. dask_array/creation/_ones_zeros.py +454 -0
  45. dask_array/creation/_pad.py +270 -0
  46. dask_array/creation/_repeat.py +55 -0
  47. dask_array/creation/_tile.py +36 -0
  48. dask_array/creation/_tri.py +28 -0
  49. dask_array/creation/_utils.py +296 -0
  50. dask_array/fft.py +320 -0
  51. dask_array/io/__init__.py +39 -0
  52. dask_array/io/_base.py +10 -0
  53. dask_array/io/_from_array.py +257 -0
  54. dask_array/io/_from_delayed.py +95 -0
  55. dask_array/io/_from_graph.py +54 -0
  56. dask_array/io/_from_npy_stack.py +67 -0
  57. dask_array/io/_store.py +336 -0
  58. dask_array/io/_tiledb.py +159 -0
  59. dask_array/io/_to_npy_stack.py +65 -0
  60. dask_array/io/_zarr.py +449 -0
  61. dask_array/linalg/__init__.py +39 -0
  62. dask_array/linalg/_cholesky.py +234 -0
  63. dask_array/linalg/_lu.py +300 -0
  64. dask_array/linalg/_norm.py +94 -0
  65. dask_array/linalg/_qr.py +601 -0
  66. dask_array/linalg/_solve.py +349 -0
  67. dask_array/linalg/_svd.py +394 -0
  68. dask_array/linalg/_tensordot.py +334 -0
  69. dask_array/linalg/_utils.py +74 -0
  70. dask_array/manipulation/__init__.py +45 -0
  71. dask_array/manipulation/_expand.py +321 -0
  72. dask_array/manipulation/_flip.py +92 -0
  73. dask_array/manipulation/_roll.py +78 -0
  74. dask_array/manipulation/_transpose.py +309 -0
  75. dask_array/random/__init__.py +125 -0
  76. dask_array/random/_choice.py +181 -0
  77. dask_array/random/_expr.py +256 -0
  78. dask_array/random/_generator.py +441 -0
  79. dask_array/random/_random_state.py +259 -0
  80. dask_array/random/_utils.py +84 -0
  81. dask_array/reductions/__init__.py +84 -0
  82. dask_array/reductions/_arg_reduction.py +130 -0
  83. dask_array/reductions/_common.py +1082 -0
  84. dask_array/reductions/_cumulative.py +522 -0
  85. dask_array/reductions/_percentile.py +261 -0
  86. dask_array/reductions/_reduction.py +725 -0
  87. dask_array/reductions/_trace.py +56 -0
  88. dask_array/routines/__init__.py +133 -0
  89. dask_array/routines/_apply.py +84 -0
  90. dask_array/routines/_bincount.py +112 -0
  91. dask_array/routines/_broadcast.py +111 -0
  92. dask_array/routines/_coarsen.py +115 -0
  93. dask_array/routines/_diff.py +79 -0
  94. dask_array/routines/_gradient.py +158 -0
  95. dask_array/routines/_indexing.py +65 -0
  96. dask_array/routines/_insert_delete.py +132 -0
  97. dask_array/routines/_misc.py +122 -0
  98. dask_array/routines/_nonzero.py +72 -0
  99. dask_array/routines/_search.py +123 -0
  100. dask_array/routines/_select.py +113 -0
  101. dask_array/routines/_statistics.py +171 -0
  102. dask_array/routines/_topk.py +82 -0
  103. dask_array/routines/_triangular.py +74 -0
  104. dask_array/routines/_unique.py +232 -0
  105. dask_array/routines/_where.py +62 -0
  106. dask_array/slicing/__init__.py +67 -0
  107. dask_array/slicing/_basic.py +550 -0
  108. dask_array/slicing/_blocks.py +138 -0
  109. dask_array/slicing/_bool_index.py +145 -0
  110. dask_array/slicing/_setitem.py +329 -0
  111. dask_array/slicing/_squeeze.py +101 -0
  112. dask_array/slicing/_utils.py +1133 -0
  113. dask_array/slicing/_vindex.py +282 -0
  114. dask_array/stacking/__init__.py +15 -0
  115. dask_array/stacking/_block.py +83 -0
  116. dask_array/stacking/_simple.py +58 -0
  117. dask_array/templates/array.html.j2 +48 -0
  118. dask_array/tests/__init__.py +0 -0
  119. dask_array/tests/conftest.py +22 -0
  120. dask_array/tests/test_api.py +40 -0
  121. dask_array/tests/test_binary_op_chunks.py +107 -0
  122. dask_array/tests/test_coarse_slice_through_blockwise.py +362 -0
  123. dask_array/tests/test_collection.py +799 -0
  124. dask_array/tests/test_creation.py +1102 -0
  125. dask_array/tests/test_expr_flow.py +143 -0
  126. dask_array/tests/test_linalg.py +1130 -0
  127. dask_array/tests/test_map_blocks_multi_output.py +104 -0
  128. dask_array/tests/test_rechunk_pushdown.py +214 -0
  129. dask_array/tests/test_reductions.py +1091 -0
  130. dask_array/tests/test_routines.py +2853 -0
  131. dask_array/tests/test_shuffle_chunks.py +67 -0
  132. dask_array/tests/test_slice_pushdown.py +968 -0
  133. dask_array/tests/test_slice_through_blockwise.py +678 -0
  134. dask_array/tests/test_slice_through_overlap.py +366 -0
  135. dask_array/tests/test_slice_through_reshape.py +272 -0
  136. dask_array/tests/test_slicing.py +839 -0
  137. dask_array/tests/test_transpose_slice_pushdown.py +208 -0
  138. dask_array/tests/test_visualize.py +94 -0
  139. dask_array/tests/test_xarray.py +193 -0
  140. dask_array-0.1.0.dist-info/METADATA +48 -0
  141. dask_array-0.1.0.dist-info/RECORD +144 -0
  142. dask_array-0.1.0.dist-info/WHEEL +4 -0
  143. dask_array-0.1.0.dist-info/entry_points.txt +2 -0
  144. dask_array-0.1.0.dist-info/licenses/LICENSE +29 -0
@@ -0,0 +1,282 @@
1
+ from __future__ import annotations
2
+
3
+ import functools
4
+ import math
5
+ from collections import defaultdict
6
+ from itertools import product
7
+ from numbers import Number
8
+ from operator import mul
9
+
10
+ import numpy as np
11
+
12
+ from dask._task_spec import List, Task, TaskRef
13
+ from dask_array._expr import ArrayExpr
14
+ from dask_array._utils import meta_from_array
15
+ from dask_array.slicing._utils import replace_ellipsis
16
+ from dask_array._core_utils import (
17
+ _get_axis,
18
+ _vindex_merge,
19
+ _vindex_slice_and_transpose,
20
+ interleave_none,
21
+ keyname,
22
+ )
23
+ from dask.utils import cached_cumsum, cached_max
24
+
25
+
26
+ def _numpy_vindex(indexer, arr):
27
+ """Helper for vindex with single-block arrays indexed by dask arrays."""
28
+ return arr[indexer]
29
+
30
+
31
def _vindex(x, *indexes):
    """Point wise indexing with broadcasting.

    Each index is either a number, a slice, or something array-like.
    Numbers and slices are applied first through ordinary indexing; the
    array-like indexes are then validated, wrapped into the valid range and
    handed to ``_vindex_array``.

    >>> x = np.arange(56).reshape((7, 8))
    >>> x
    array([[ 0,  1,  2,  3,  4,  5,  6,  7],
           [ 8,  9, 10, 11, 12, 13, 14, 15],
           [16, 17, 18, 19, 20, 21, 22, 23],
           [24, 25, 26, 27, 28, 29, 30, 31],
           [32, 33, 34, 35, 36, 37, 38, 39],
           [40, 41, 42, 43, 44, 45, 46, 47],
           [48, 49, 50, 51, 52, 53, 54, 55]])

    >>> from dask_array._collection import from_array
    >>> d = from_array(x, chunks=(3, 4))
    >>> result = _vindex(d, [0, 1, 6, 0], [0, 1, 0, 7])
    >>> result.compute()
    array([ 0,  9, 48,  7])

    Raises
    ------
    IndexError
        If an array index is boolean or has entries outside
        ``[-size, size)`` for its axis.
    """
    indexes = replace_ellipsis(x.ndim, indexes)

    # Split each index into a basic part (applied right away) and a part
    # that lines up with the axes remaining after the basic indexing.
    basic_part = []
    remaining_part = []
    for entry in indexes:
        if isinstance(entry, Number):
            # A number removes its axis, so it has no counterpart afterwards.
            basic_part.append(entry)
            continue
        if isinstance(entry, slice):
            basic_part.append(entry)
            remaining_part.append(slice(None))
            continue
        # Anything else is an array-like index: keep the axis whole for now.
        basic_part.append(slice(None))
        remaining_part.append(entry)

    x = x[tuple(basic_part)]

    array_indexes = {}
    for i, (ind, size) in enumerate(zip(tuple(remaining_part), x.shape)):
        if isinstance(ind, slice):
            continue
        # Copy so the in-place modulo below never touches the caller's data.
        ind = np.array(ind, copy=True)
        if ind.dtype.kind == "b":
            raise IndexError("vindex does not support indexing with boolean arrays")
        outside = (ind >= size) | (ind < -size)
        if outside.any():
            raise IndexError(
                f"vindex key has entries out of bounds for indexing along axis {i} of size {size}: {ind!r}"
            )
        # Map negative positions onto their non-negative equivalents.
        ind %= size
        array_indexes[i] = ind

    if not array_indexes:
        return x
    return _vindex_array(x, array_indexes)
87
+
88
+
89
+ def _compute_indexer(index, chunks_along_axis):
90
+ """Compute Shuffle indexer by grouping consecutive indices from same input chunk.
91
+
92
+ Returns a list of lists, where each inner list contains indices that come from
93
+ a contiguous run accessing the same input chunk. This preserves locality for
94
+ patterns like np.repeat.
95
+ """
96
+ chunk_boundaries = np.cumsum((0,) + chunks_along_axis)
97
+ input_chunk_ids = np.searchsorted(chunk_boundaries[1:], index, side="right")
98
+ changes = np.concatenate([[0], np.where(np.diff(input_chunk_ids) != 0)[0] + 1, [len(index)]])
99
+ return [index[changes[i] : changes[i + 1]].tolist() for i in range(len(changes) - 1)]
100
+
101
+
102
def _vindex_array(x, dict_indexes):
    """Point wise indexing with only NumPy Arrays.

    ``dict_indexes`` maps an axis number of ``x`` to the index array used
    along that axis. The index arrays are broadcast against each other.

    Three cases are handled:

    * no points selected: an empty array of the right shape is created;
    * exactly one indexed axis: the work is delegated to ``_shuffle``;
    * several indexed axes: a ``VIndexArray`` expression is built.

    Raises
    ------
    IndexError
        If the index arrays cannot be broadcast together.
    """
    from dask_array._new_collection import new_collection
    from dask_array.creation import empty

    index_arrays = list(dict_indexes.values())
    try:
        broadcast_shape = np.broadcast_shapes(*(arr.shape for arr in index_arrays))
    except ValueError as e:
        shapes_str = " ".join(str(a.shape) for a in index_arrays)
        raise IndexError(
            f"shape mismatch: indexing arrays could not be broadcast together with shapes {shapes_str}"
        ) from e
    npoints = math.prod(broadcast_shape)

    if npoints <= 0:
        # output has zero dimension - just create a new zero-shape array
        untouched = [c for i, c in enumerate(x.chunks) if i not in dict_indexes]
        chunks = ((0,), *untouched)
        shape = tuple(sum(c) for c in chunks)
        result_1d = empty(shape, chunks=chunks, dtype=x.dtype)
        return result_1d.reshape(broadcast_shape + result_1d.shape[1:])

    if len(dict_indexes) == 1:
        # Single-axis case: delegate to Shuffle for optimization hooks
        from dask_array._shuffle import _shuffle

        ((axis, index_array),) = dict_indexes.items()
        indexer = _compute_indexer(index_array.ravel(), x.chunks[axis])

        shuffled = new_collection(_shuffle(x.expr, indexer, axis, "vindex-"))
        # Shuffle keeps axis in place; reshape for broadcast_shape along that axis
        out_shape = list(shuffled.shape)
        out_shape[axis : axis + 1] = list(broadcast_shape)
        return shuffled.reshape(tuple(out_shape))

    # Several indexed axes: the expression yields the points along a leading
    # 1-D axis, which is then unfolded into the broadcast shape.
    result_1d = new_collection(VIndexArray(x.expr, dict_indexes, broadcast_shape, npoints))
    return result_1d.reshape(broadcast_shape + result_1d.shape[1:])
142
+
143
+
144
class VIndexArray(ArrayExpr):
    """Point-wise vectorized indexing with broadcasting.

    Used for multi-axis fancy indexing where indices broadcast together.
    Single-axis cases delegate to Shuffle for optimization hooks.

    Operands (see ``_parameters``):

    * ``array`` - the expression being indexed;
    * ``dict_indexes`` - mapping from axis number to the integer index array
      applied along that axis;
    * ``broadcast_shape`` - shape the index arrays broadcast to;
    * ``npoints`` - number of selected points; ``_layer`` reshapes
      ``np.arange(npoints)`` to ``broadcast_shape``.

    The selected points are laid out along a new leading axis of length
    ``npoints``; the non-indexed axes of ``array`` follow with their chunks
    unchanged.
    """

    _parameters = ["array", "dict_indexes", "broadcast_shape", "npoints"]

    @functools.cached_property
    def _name(self):
        return f"vindex-merge-{self.deterministic_token}"

    @functools.cached_property
    def _meta(self):
        return meta_from_array(self.array._meta, ndim=len(self.chunks))

    @functools.cached_property
    def _axes(self):
        """Axes that have array indexing."""
        return [i for i in range(self.array.ndim) if i in self.dict_indexes]

    def _subset_to_indexed_axes(self, iterable):
        """Yield the elements of ``iterable`` whose position is an indexed axis."""
        for i, elem in enumerate(iterable):
            if i in self._axes:
                yield elem

    @functools.cached_property
    def _max_chunk_point_dimensions(self):
        """Product of the largest chunk size of every indexed axis.

        Used as the chunk size of the leading points axis of the output.
        """
        return functools.reduce(mul, map(cached_max, self._subset_to_indexed_axes(self.array.chunks)))

    @functools.cached_property
    def chunks(self):
        """Output chunks: the points axis first, then the non-indexed axes."""
        axes = self._axes
        npoints = self.npoints
        max_chunk_point_dimensions = self._max_chunk_point_dimensions

        chunks = [c for i, c in enumerate(self.array.chunks) if i not in axes]

        # Full-size chunks along the points axis, plus one shorter trailing
        # chunk when npoints is not an exact multiple of the chunk size.
        n_chunks, remainder = divmod(npoints, max_chunk_point_dimensions)
        chunks.insert(
            0,
            (
                (max_chunk_point_dimensions,) * n_chunks + ((remainder,) if remainder > 0 else ())
                if npoints > 0
                else (0,)
            ),
        )
        return tuple(chunks)

    def _layer(self) -> dict:
        """Build the task graph.

        For every combination of blocks along the non-indexed axes, one
        ``_vindex_slice_and_transpose`` task is created per (output block,
        input block) pair that shares points, and one ``_vindex_merge`` task
        per output block combines those pieces.
        """
        dict_indexes = self.dict_indexes
        broadcast_shape = self.broadcast_shape
        npoints = self.npoints
        axes = self._axes

        # Chunk boundaries (with a leading zero) of every indexed axis.
        # NOTE(review): these are in axis order while dict_indexes.values()
        # is in insertion order; the zips below assume both agree - confirm
        # against callers.
        bounds2 = tuple(
            np.array(cached_cumsum(c, initial_zero=True)) for c in self._subset_to_indexed_axes(self.array.chunks)
        )
        axis = _get_axis(tuple(i if i in axes else None for i in range(self.array.ndim)))

        # Now compute indices of each output element within each input block
        block_idxs = tuple(np.searchsorted(b, ind, side="right") - 1 for b, ind in zip(bounds2, dict_indexes.values()))
        starts = (b[i] for i, b in zip(block_idxs, bounds2))
        inblock_idxs = []
        for idx, start in zip(dict_indexes.values(), starts):
            # Convert unsigned integers to signed to avoid float promotion in subtraction
            if idx.dtype.kind == "u":
                idx = idx.astype(np.int64)
            a = idx - start
            # ``size`` rather than ``len``: ``len`` raises TypeError for 0-d
            # index arrays, which hold exactly one point.
            if a.size > 0:
                # Store in-block offsets in the smallest dtype that fits them.
                dtype = np.min_scalar_type(np.max(a, axis=None))
                inblock_idxs.append(a.astype(dtype, copy=False))
            else:
                inblock_idxs.append(a)

        inblock_idxs = np.broadcast_arrays(*inblock_idxs)  # type: ignore[assignment]

        max_chunk_point_dimensions = self._max_chunk_point_dimensions
        n_chunks = npoints // max_chunk_point_dimensions

        # Every block combination along the non-indexed axes; indexed axes
        # are held as None placeholders and filled in per task below.
        other_blocks = product(*[range(len(c)) if i not in axes else [None] for i, c in enumerate(self.array.chunks)])

        full_slices = [slice(None, None) if i not in axes else None for i in range(self.array.ndim)]

        # The output is constructed as a new dimension and then reshaped
        outinds = np.arange(npoints).reshape(broadcast_shape)
        outblocks, outblock_idx = np.divmod(outinds, max_chunk_point_dimensions)

        # Encode (output block, input block per indexed axis) as one integer
        # key so that points sharing both can be grouped with a single sort.
        ravel_shape = (
            n_chunks + 1,
            *self._subset_to_indexed_axes(self.array.numblocks),
        )
        keys = np.ravel_multi_index([outblocks, *block_idxs], ravel_shape)
        sortidx = np.argsort(keys, axis=None)
        sorted_keys = keys.flat[sortidx]
        sorted_inblock_idxs = [_.flat[sortidx] for _ in inblock_idxs]
        sorted_outblock_idx = outblock_idx.flat[sortidx]
        dtype = np.min_scalar_type(max_chunk_point_dimensions)
        sorted_outblock_idx = sorted_outblock_idx.astype(dtype, copy=False)
        # Positions where the sorted key changes delimit the groups.
        flag = np.concatenate([[True], sorted_keys[1:] != sorted_keys[:-1], [True]])
        (key_bounds,) = flag.nonzero()

        slice_name = f"vindex-slice-{self.deterministic_token}"
        dsk = {}

        for okey in other_blocks:
            merge_inputs = defaultdict(list)
            merge_indexer = defaultdict(list)
            for i, (start, stop) in enumerate(zip(key_bounds[:-1], key_bounds[1:], strict=True)):
                slicer = slice(start, stop)
                key = sorted_keys[start]
                outblock, *input_blocks = np.unravel_index(key, ravel_shape)
                inblock = [_[slicer] for _ in sorted_inblock_idxs]
                k = keyname(slice_name, i, okey)
                dsk[k] = Task(
                    k,
                    _vindex_slice_and_transpose,
                    TaskRef((self.array._name,) + interleave_none(okey, input_blocks)),
                    interleave_none(full_slices, inblock),
                    axis,
                )
                merge_inputs[outblock].append(TaskRef(k))
                merge_indexer[outblock].append(sorted_outblock_idx[slicer])

            for i in merge_inputs.keys():
                k = keyname(self._name, i, okey)
                dsk[k] = Task(
                    k,
                    _vindex_merge,
                    merge_indexer[i],
                    List(*merge_inputs[i]),
                )

        return dsk

    def __dask_keys__(self):
        # Override to return 1D keys since we reshape after
        return [(self._name,) + idx for idx in np.ndindex(tuple(len(c) for c in self.chunks))]
@@ -0,0 +1,15 @@
1
+ """Stacking and concatenation functions."""
2
+
3
+ from dask_array._concatenate import concatenate
4
+ from dask_array._stack import stack
5
+ from dask_array.stacking._block import block
6
+ from dask_array.stacking._simple import dstack, hstack, vstack
7
+
8
+ __all__ = [
9
+ "stack",
10
+ "concatenate",
11
+ "block",
12
+ "vstack",
13
+ "hstack",
14
+ "dstack",
15
+ ]
@@ -0,0 +1,83 @@
1
+ """Block operation."""
2
+
3
+ from __future__ import annotations
4
+
5
+
6
def block(arrays, allow_unknown_chunksizes=False):
    """
    Assemble an nd-array from nested lists of blocks.

    The innermost lists are joined along the last dimension (-1), the next
    level out along the second-last dimension (-2), and so on until the
    outermost list has been joined.

    Parameters
    ----------
    arrays : nested list of array-likes
        Only ``list`` objects are treated as nesting; tuples are rejected.
    allow_unknown_chunksizes : bool
        Passed through to ``concatenate``.

    Raises
    ------
    TypeError
        If a tuple appears anywhere in the nesting.
    ValueError
        If the lists are nested to different depths or a list is empty.

    See Also
    --------
    numpy.block
    """
    # Import here to avoid circular imports
    from dask_array._collection import asanyarray, concatenate
    from dask_array._numpy_compat import _Recurser

    def atleast_nd(x, ndim):
        # Prepend length-one axes until ``x`` has at least ``ndim`` dimensions.
        x = asanyarray(x)
        missing = max(ndim - x.ndim, 0)
        if missing:
            return x[(None,) * missing + (Ellipsis,)]
        return x

    def format_index(index):
        subscripts = "".join(f"[{i}]" for i in index)
        return "arrays" + subscripts

    def join_blocks(xs, axis):
        return concatenate(list(xs), axis=axis, allow_unknown_chunksizes=allow_unknown_chunksizes)

    rec = _Recurser(recurse_if=lambda x: type(x) is list)

    # Ensure that the lists are all matched in depth
    list_ndim = None
    any_empty = False
    for index, value, entering in rec.walk(arrays):
        if type(value) is tuple:
            raise TypeError(
                f"{format_index(index)} is a tuple. "
                "Only lists can be used to arrange blocks, and np.block does "
                "not allow implicit conversion from tuple to ndarray."
            )
        if entering and len(value) != 0:
            # A non-empty list: its depth is established by its elements.
            continue
        if entering:
            # An empty list counts as one level deeper than where it sits.
            curr_depth = len(index) + 1
            any_empty = True
        else:
            curr_depth = len(index)

        if list_ndim is not None and list_ndim != curr_depth:
            raise ValueError(
                f"List depths are mismatched. First element was at depth {list_ndim}, "
                f"but there is an element at depth {curr_depth} ({format_index(index)})"
            )
        list_ndim = curr_depth

    # Do this here so we catch depth mismatches first
    if any_empty:
        raise ValueError("Lists cannot be empty")

    # Convert all the arrays to ndarrays
    arrays = rec.map_reduce(arrays, f_map=asanyarray, f_reduce=list)

    # Determine the maximum dimension of the elements
    elem_ndim = rec.map_reduce(arrays, f_map=lambda xi: xi.ndim, f_reduce=max)
    ndim = max(list_ndim, elem_ndim)

    # First axis to concatenate along
    first_axis = ndim - list_ndim

    # Make all the elements the same dimension
    arrays = rec.map_reduce(arrays, f_map=lambda xi: atleast_nd(xi, ndim), f_reduce=list)

    # Concatenate innermost lists on the right, outermost on the left
    return rec.map_reduce(
        arrays,
        f_reduce=join_blocks,
        f_kwargs=lambda axis: {"axis": axis + 1},
        axis=first_axis,
    )
@@ -0,0 +1,58 @@
1
+ """Simple stacking operations: vstack, hstack, dstack."""
2
+
3
+ from __future__ import annotations
4
+
5
+
6
def vstack(tup, allow_unknown_chunksizes=False):
    """Stack arrays in sequence vertically (row wise).

    Every input is passed through ``atleast_2d`` and the results are
    concatenated along axis 0.

    Raises
    ------
    NotImplementedError
        If ``tup`` is itself an ``Array`` rather than a sequence of arrays.

    See Also
    --------
    numpy.vstack
    """
    # Import here to avoid circular imports
    from dask_array._collection import Array, concatenate
    from dask_array.manipulation._expand import atleast_2d

    if isinstance(tup, Array):
        raise NotImplementedError("``vstack`` expects a sequence of arrays as the first argument")

    promoted = tuple(map(atleast_2d, tup))
    return concatenate(promoted, axis=0, allow_unknown_chunksizes=allow_unknown_chunksizes)
22
+
23
+
24
def hstack(tup, allow_unknown_chunksizes=False):
    """Stack arrays in sequence horizontally (column wise).

    The inputs are concatenated along axis 0 when all of them are
    one-dimensional, and along axis 1 otherwise.

    Parameters
    ----------
    tup : iterable of arrays
        Arrays to stack. Each element must expose ``ndim``. One-shot
        iterables such as generators are accepted.
    allow_unknown_chunksizes : bool
        Passed through to ``concatenate``.

    Raises
    ------
    NotImplementedError
        If ``tup`` is itself an ``Array`` rather than a sequence of arrays.

    See Also
    --------
    numpy.hstack
    """
    # Import here to avoid circular imports
    from dask_array._collection import Array, concatenate

    if isinstance(tup, Array):
        raise NotImplementedError("``hstack`` expects a sequence of arrays as the first argument")

    # Materialize first: the dimensionality check below iterates the inputs,
    # which would leave a generator exhausted before concatenate reads it.
    tup = tuple(tup)

    axis = 0 if all(x.ndim == 1 for x in tup) else 1
    return concatenate(tup, axis=axis, allow_unknown_chunksizes=allow_unknown_chunksizes)
41
+
42
+
43
def dstack(tup, allow_unknown_chunksizes=False):
    """Stack arrays in sequence depth wise (along third axis).

    Every input is passed through ``atleast_3d`` and the results are
    concatenated along axis 2.

    Raises
    ------
    NotImplementedError
        If ``tup`` is itself an ``Array`` rather than a sequence of arrays.

    See Also
    --------
    numpy.dstack
    """
    # Import here to avoid circular imports
    from dask_array._collection import Array, concatenate
    from dask_array.manipulation._expand import atleast_3d

    if isinstance(tup, Array):
        raise NotImplementedError("``dstack`` expects a sequence of arrays as the first argument")

    promoted = tuple(map(atleast_3d, tup))
    return concatenate(promoted, axis=2, allow_unknown_chunksizes=allow_unknown_chunksizes)
@@ -0,0 +1,48 @@
1
+ <style>
2
+ .dask-array-repr .dask-table-header { color: var(--jp-ui-font-color2, #78716c); }
3
+ .dask-array-repr .dask-table-label { color: var(--jp-ui-font-color3, #a8a29e); }
4
+ .dask-array-repr .dask-table-data { color: var(--jp-ui-font-color1, #1c1917); }
5
+ .dask-array-repr .dask-table-border { border-top: 1px solid var(--jp-border-color2, #e7e5e4); }
6
+ </style>
7
+ <details class="dask-array-repr" style="font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, 'Helvetica Neue', Arial, sans-serif;">
8
+ <summary style="cursor: pointer; list-style: none;">
9
+ <table style="border-collapse: separate; border-spacing: 0; display: inline-table;">
10
+ <tr>
11
+ <td style="vertical-align: top;">
12
+ <table style="border-collapse: collapse; font-size: 14px;">
13
+ <thead>
14
+ <tr>
15
+ <td style="padding: 6px 12px;"></td>
16
+ <th class="dask-table-header" style="padding: 6px 12px; text-align: right; font-weight: 400;">Array</th>
17
+ <th class="dask-table-header" style="padding: 6px 12px; text-align: right; font-weight: 400;">Chunk</th>
18
+ </tr>
19
+ </thead>
20
+ <tbody>
21
+ {% if nbytes %}
22
+ <tr>
23
+ <th class="dask-table-label dask-table-border" style="padding: 6px 12px; text-align: left; font-weight: 400;">Bytes</th>
24
+ <td class="dask-table-data dask-table-border" style="padding: 6px 12px; text-align: right; font-weight: 600;">{{ nbytes }}</td>
25
+ <td class="dask-table-data dask-table-border" style="padding: 6px 12px; text-align: right; font-weight: 600;">{{ cbytes }}</td>
26
+ </tr>
27
+ {% endif %}
28
+ <tr>
29
+ <th class="dask-table-label dask-table-border" style="padding: 6px 12px; text-align: left; font-weight: 400;">Shape</th>
30
+ <td class="dask-table-data dask-table-border" style="padding: 6px 12px; text-align: right; font-weight: 600;">{{ array.shape }}</td>
31
+ <td class="dask-table-data dask-table-border" style="padding: 6px 12px; text-align: right; font-weight: 600;">{{ array.chunksize }}</td>
32
+ </tr>
33
+ <tr>
34
+ <th class="dask-table-label dask-table-border" style="padding: 6px 12px; text-align: left; font-weight: 400;">Nodes</th>
35
+ <td class="dask-table-data dask-table-border" style="padding: 6px 12px; text-align: right; font-weight: 600;">{{ n_expr }}</td>
36
+ <td class="dask-table-border" style="padding: 6px 12px;"></td>
37
+ </tr>
38
+ </tbody>
39
+ </table>
40
+ </td>
41
+ <td style="vertical-align: middle; padding-left: 24px;">
42
+ {{ grid }}
43
+ </td>
44
+ </tr>
45
+ </table>
46
+ </summary>
47
+ <div style="padding: 12px 0;">{{ expr_flow }}</div>
48
+ </details>
File without changes
@@ -0,0 +1,22 @@
1
+ from __future__ import annotations
2
+
3
+ import pytest
4
+
5
+
6
def pytest_addoption(parser):
    """Register the ``--runslow`` command line flag (off by default)."""
    flag_settings = {
        "action": "store_true",
        "default": False,
        "help": "run tests marked slow",
    }
    parser.addoption("--runslow", **flag_settings)
13
+
14
+
15
def pytest_collection_modifyitems(config, items):
    """Mark every collected test carrying the ``slow`` keyword as skipped.

    Nothing is changed when ``--runslow`` was given on the command line.
    """
    run_slow = config.getoption("--runslow")
    if not run_slow:
        skip_slow = pytest.mark.skip(reason="need --runslow option to run")
        slow_items = (item for item in items if "slow" in item.keywords)
        for item in slow_items:
            item.add_marker(skip_slow)
@@ -0,0 +1,40 @@
1
+ from __future__ import annotations
2
+
3
+ import numpy as np
4
+
5
+ import dask_array as da
6
+
7
+
8
def test_top_level_compatibility_exports():
    """The package root re-exports NumPy-style constants, dtypes and helpers."""
    # Constants and dtypes
    assert da.newaxis is None
    assert np.isnan(da.nan)
    assert da.inf == np.inf
    assert da.pi == np.pi
    assert da.float64 is np.float64
    assert da.int64 is np.int64

    # Top-level functions
    expected_callables = (
        "compute",
        "optimize",
        "register_chunk_type",
        "to_hdf5",
        "from_tiledb",
        "to_tiledb",
    )
    for name in expected_callables:
        assert callable(getattr(da, name))
22
+
23
+
24
def test_top_level_optimize_collection():
    """``da.optimize`` returns an ``Array`` that computes to the same values."""
    original = da.arange(6, chunks=3) + 1

    optimized = da.optimize(original)

    assert isinstance(optimized, da.Array)
    expected = np.arange(6) + 1
    np.testing.assert_array_equal(optimized.compute(), expected)
31
+
32
+
33
def test_random_star_exports_legacy_wrappers():
    """A star import from ``dask_array.random`` exposes the legacy functions."""
    # exec is used because ``import *`` is only allowed at module level;
    # the statement is a fixed literal, not external input.
    namespace = {}
    exec("from dask_array.random import *", namespace)

    for name in ("normal", "random", "randint", "standard_normal"):
        assert callable(namespace[name])
@@ -0,0 +1,107 @@
1
+ """Tests for coarse_blockdim: preferring larger chunks in binary operations.
2
+
3
+ When combining arrays with different chunk granularities, we prefer coarser
4
+ chunks (fewer blocks) when boundaries align. This reduces task overhead.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import math
10
+
11
+ import numpy as np
12
+ import pytest
13
+
14
+ import dask_array as da
15
+ from dask_array._test_utils import assert_eq
16
+
17
+
18
def total_chunks(arr):
    """Return the product of ``arr.numblocks``, i.e. the overall chunk count."""
    count = 1
    for blocks_along_axis in arr.numblocks:
        count *= blocks_along_axis
    return count
21
+
22
+
23
class TestCoarseChunkPreference:
    """Binary operations keep the coarser chunking when boundaries align."""

    def test_shuffle_indexed_array(self):
        """Main use case: xarray groupby pattern.

        Combining a nicely chunked array with a shuffle-indexed array
        (which has per-element chunks) should keep the nice chunks.
        """
        # Original data: 10 chunks of size 12
        data = da.random.random((120, 20, 30), chunks=(12, 20, 30))

        # Aggregated data indexed to match original shape
        n_groups = 4
        group_means = da.random.random((n_groups, 20, 30), chunks=(1, 20, 30))
        group_of_row = np.tile(np.arange(n_groups), 30)
        expanded_means = group_means[group_of_row, ...]

        result = data - expanded_means

        # Should preserve the input's chunk count, not explode to 120
        assert total_chunks(result) <= total_chunks(data) * 2
        assert_eq(result, data.compute() - expanded_means.compute())

    def test_aligned_1d(self):
        """1D: (20,20) + (10,10,10,10) -> (20,20)"""
        big = da.ones(40, chunks=20)
        small = da.ones(40, chunks=10)

        result = big + small

        assert result.chunks == ((20, 20),)

    def test_aligned_2d(self):
        """2D: the coarser chunks win along both dimensions."""
        big = da.ones((40, 40), chunks=(20, 20))
        small = da.ones((40, 40), chunks=(10, 10))

        result = big + small

        assert result.chunks == ((20, 20), (20, 20))
        assert total_chunks(result) == 4  # not 16

    def test_multiples_align(self):
        """Chunk sizes that are multiples align: (30,30) + (10,...) -> (30,30)"""
        big = da.ones(60, chunks=30)
        small = da.ones(60, chunks=10)

        result = big + small

        assert result.chunks == ((30, 30),)
74
+
75
+
76
class TestFallbackToCommonBlockdim:
    """Behaviour when the chunk boundaries of the operands do not align."""

    def test_misaligned_boundaries(self):
        """(15,15) vs (10,20): boundary 15 not in {10}, must subdivide."""
        left = da.ones(30, chunks=(15, 15))
        right = da.ones(30, chunks=(10, 20))

        result = left + right

        # The result must use neither operand's chunking as-is.
        rejected = (((15, 15),), ((10, 20),))
        assert result.chunks not in rejected

    def test_non_divisible(self):
        """(12,12) vs (8,8,8): boundary 12 not in {8,16}, must subdivide."""
        left = da.ones(24, chunks=12)
        right = da.ones(24, chunks=8)

        result = left + right

        # More than two chunks along the only axis
        n_result_chunks = len(result.chunks[0])
        assert n_result_chunks > 2

    def test_classic_uneven(self):
        """(4,6) vs (6,4): different boundaries, uses (4,2,4)."""
        left = da.arange(10, chunks=((4, 6),))
        right = da.ones(10, chunks=((6, 4),))

        result = left + right

        assert result.chunks == ((4, 2, 4),)