dask-array 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (144) hide show
  1. dask_array/__init__.py +228 -0
  2. dask_array/_backends.py +76 -0
  3. dask_array/_backends_array.py +99 -0
  4. dask_array/_blockwise.py +1410 -0
  5. dask_array/_broadcast.py +272 -0
  6. dask_array/_chunk.py +445 -0
  7. dask_array/_chunk_types.py +54 -0
  8. dask_array/_collection.py +1644 -0
  9. dask_array/_concatenate.py +331 -0
  10. dask_array/_core_utils.py +1365 -0
  11. dask_array/_dispatch.py +141 -0
  12. dask_array/_einsum.py +277 -0
  13. dask_array/_expr.py +544 -0
  14. dask_array/_expr_flow.py +586 -0
  15. dask_array/_gufunc.py +805 -0
  16. dask_array/_histogram.py +617 -0
  17. dask_array/_map_blocks.py +652 -0
  18. dask_array/_new_collection.py +10 -0
  19. dask_array/_numpy_compat.py +135 -0
  20. dask_array/_overlap.py +1159 -0
  21. dask_array/_rechunk.py +1050 -0
  22. dask_array/_reshape.py +710 -0
  23. dask_array/_routines.py +102 -0
  24. dask_array/_shuffle.py +448 -0
  25. dask_array/_stack.py +264 -0
  26. dask_array/_svg.py +291 -0
  27. dask_array/_templates.py +29 -0
  28. dask_array/_test_utils.py +257 -0
  29. dask_array/_ufunc.py +385 -0
  30. dask_array/_utils.py +349 -0
  31. dask_array/_visualize.py +223 -0
  32. dask_array/_xarray.py +337 -0
  33. dask_array/core/__init__.py +34 -0
  34. dask_array/core/_blockwise_funcs.py +312 -0
  35. dask_array/core/_conversion.py +422 -0
  36. dask_array/core/_from_graph.py +97 -0
  37. dask_array/creation/__init__.py +71 -0
  38. dask_array/creation/_arange.py +121 -0
  39. dask_array/creation/_diag.py +116 -0
  40. dask_array/creation/_diagonal.py +241 -0
  41. dask_array/creation/_eye.py +103 -0
  42. dask_array/creation/_linspace.py +102 -0
  43. dask_array/creation/_mesh.py +134 -0
  44. dask_array/creation/_ones_zeros.py +454 -0
  45. dask_array/creation/_pad.py +270 -0
  46. dask_array/creation/_repeat.py +55 -0
  47. dask_array/creation/_tile.py +36 -0
  48. dask_array/creation/_tri.py +28 -0
  49. dask_array/creation/_utils.py +296 -0
  50. dask_array/fft.py +320 -0
  51. dask_array/io/__init__.py +39 -0
  52. dask_array/io/_base.py +10 -0
  53. dask_array/io/_from_array.py +257 -0
  54. dask_array/io/_from_delayed.py +95 -0
  55. dask_array/io/_from_graph.py +54 -0
  56. dask_array/io/_from_npy_stack.py +67 -0
  57. dask_array/io/_store.py +336 -0
  58. dask_array/io/_tiledb.py +159 -0
  59. dask_array/io/_to_npy_stack.py +65 -0
  60. dask_array/io/_zarr.py +449 -0
  61. dask_array/linalg/__init__.py +39 -0
  62. dask_array/linalg/_cholesky.py +234 -0
  63. dask_array/linalg/_lu.py +300 -0
  64. dask_array/linalg/_norm.py +94 -0
  65. dask_array/linalg/_qr.py +601 -0
  66. dask_array/linalg/_solve.py +349 -0
  67. dask_array/linalg/_svd.py +394 -0
  68. dask_array/linalg/_tensordot.py +334 -0
  69. dask_array/linalg/_utils.py +74 -0
  70. dask_array/manipulation/__init__.py +45 -0
  71. dask_array/manipulation/_expand.py +321 -0
  72. dask_array/manipulation/_flip.py +92 -0
  73. dask_array/manipulation/_roll.py +78 -0
  74. dask_array/manipulation/_transpose.py +309 -0
  75. dask_array/random/__init__.py +125 -0
  76. dask_array/random/_choice.py +181 -0
  77. dask_array/random/_expr.py +256 -0
  78. dask_array/random/_generator.py +441 -0
  79. dask_array/random/_random_state.py +259 -0
  80. dask_array/random/_utils.py +84 -0
  81. dask_array/reductions/__init__.py +84 -0
  82. dask_array/reductions/_arg_reduction.py +130 -0
  83. dask_array/reductions/_common.py +1082 -0
  84. dask_array/reductions/_cumulative.py +522 -0
  85. dask_array/reductions/_percentile.py +261 -0
  86. dask_array/reductions/_reduction.py +725 -0
  87. dask_array/reductions/_trace.py +56 -0
  88. dask_array/routines/__init__.py +133 -0
  89. dask_array/routines/_apply.py +84 -0
  90. dask_array/routines/_bincount.py +112 -0
  91. dask_array/routines/_broadcast.py +111 -0
  92. dask_array/routines/_coarsen.py +115 -0
  93. dask_array/routines/_diff.py +79 -0
  94. dask_array/routines/_gradient.py +158 -0
  95. dask_array/routines/_indexing.py +65 -0
  96. dask_array/routines/_insert_delete.py +132 -0
  97. dask_array/routines/_misc.py +122 -0
  98. dask_array/routines/_nonzero.py +72 -0
  99. dask_array/routines/_search.py +123 -0
  100. dask_array/routines/_select.py +113 -0
  101. dask_array/routines/_statistics.py +171 -0
  102. dask_array/routines/_topk.py +82 -0
  103. dask_array/routines/_triangular.py +74 -0
  104. dask_array/routines/_unique.py +232 -0
  105. dask_array/routines/_where.py +62 -0
  106. dask_array/slicing/__init__.py +67 -0
  107. dask_array/slicing/_basic.py +550 -0
  108. dask_array/slicing/_blocks.py +138 -0
  109. dask_array/slicing/_bool_index.py +145 -0
  110. dask_array/slicing/_setitem.py +329 -0
  111. dask_array/slicing/_squeeze.py +101 -0
  112. dask_array/slicing/_utils.py +1133 -0
  113. dask_array/slicing/_vindex.py +282 -0
  114. dask_array/stacking/__init__.py +15 -0
  115. dask_array/stacking/_block.py +83 -0
  116. dask_array/stacking/_simple.py +58 -0
  117. dask_array/templates/array.html.j2 +48 -0
  118. dask_array/tests/__init__.py +0 -0
  119. dask_array/tests/conftest.py +22 -0
  120. dask_array/tests/test_api.py +40 -0
  121. dask_array/tests/test_binary_op_chunks.py +107 -0
  122. dask_array/tests/test_coarse_slice_through_blockwise.py +362 -0
  123. dask_array/tests/test_collection.py +799 -0
  124. dask_array/tests/test_creation.py +1102 -0
  125. dask_array/tests/test_expr_flow.py +143 -0
  126. dask_array/tests/test_linalg.py +1130 -0
  127. dask_array/tests/test_map_blocks_multi_output.py +104 -0
  128. dask_array/tests/test_rechunk_pushdown.py +214 -0
  129. dask_array/tests/test_reductions.py +1091 -0
  130. dask_array/tests/test_routines.py +2853 -0
  131. dask_array/tests/test_shuffle_chunks.py +67 -0
  132. dask_array/tests/test_slice_pushdown.py +968 -0
  133. dask_array/tests/test_slice_through_blockwise.py +678 -0
  134. dask_array/tests/test_slice_through_overlap.py +366 -0
  135. dask_array/tests/test_slice_through_reshape.py +272 -0
  136. dask_array/tests/test_slicing.py +839 -0
  137. dask_array/tests/test_transpose_slice_pushdown.py +208 -0
  138. dask_array/tests/test_visualize.py +94 -0
  139. dask_array/tests/test_xarray.py +193 -0
  140. dask_array-0.1.0.dist-info/METADATA +48 -0
  141. dask_array-0.1.0.dist-info/RECORD +144 -0
  142. dask_array-0.1.0.dist-info/WHEEL +4 -0
  143. dask_array-0.1.0.dist-info/entry_points.txt +2 -0
  144. dask_array-0.1.0.dist-info/licenses/LICENSE +29 -0
@@ -0,0 +1,272 @@
1
+ from __future__ import annotations
2
+
3
+ import functools
4
+ from itertools import product
5
+
6
+ import numpy as np
7
+
8
+ from dask._task_spec import Task, TaskRef
9
+ from dask_array._expr import ArrayExpr
10
+ from dask_array._core_utils import normalize_chunks
11
+ from dask_array._utils import meta_from_array
12
+
13
+
14
class BroadcastTo(ArrayExpr):
    """Broadcast an array expression to a new shape.

    Operands
    --------
    array : ArrayExpr
        The expression being broadcast.
    _shape : tuple
        The target shape.
    _chunks : tuple of tuples
        The target chunks. Along output dimensions that come from the input
        and are not broadcast from a single ``(1,)`` chunk, these must equal
        the input's chunks, because ``_layer`` maps output blocks 1:1 onto
        input blocks there.
    _meta_override : array-like, optional
        Meta to report, used only when its ``ndim`` matches the output.
    """

    _parameters = ["array", "_shape", "_chunks", "_meta_override"]
    _defaults = {"_meta_override": None}

    @functools.cached_property
    def _name(self):
        return f"broadcast_to-{self.deterministic_token}"

    @functools.cached_property
    def _meta(self):
        meta_override = self.operand("_meta_override")
        # Only use meta_override if it has the correct ndim; otherwise derive
        # a meta with the output's ndim from the input's meta.
        if (
            meta_override is not None
            and hasattr(meta_override, "ndim")
            and meta_override.ndim == len(self._shape)
        ):
            return meta_override
        return meta_from_array(self.array._meta, ndim=len(self._shape))

    @functools.cached_property
    def chunks(self):
        return self._chunks

    def _layer(self) -> dict:
        """Build one ``np.broadcast_to`` task per output block."""
        x = self.array
        ndim_new = len(self._shape) - x.ndim

        dsk = {}
        for ec in product(*(enumerate(bds) for bds in self._chunks)):
            new_index, chunk_shape = zip(*ec)
            # Input dims stored as a single size-1 chunk are broadcast: every
            # output block along them reads input block 0. Other input dims
            # share block indices with the output (leading new dims skipped).
            old_index = tuple(
                0 if bd == (1,) else i for bd, i in zip(x.chunks, new_index[ndim_new:])
            )
            old_key = (x._name,) + old_index
            new_key = (self._name,) + new_index
            dsk[new_key] = Task(new_key, np.broadcast_to, TaskRef(old_key), chunk_shape)

        return dsk

    def _simplify_up(self, parent, dependents):
        """Allow slice and shuffle operations to push through BroadcastTo."""
        from dask_array._shuffle import Shuffle
        from dask_array.slicing import SliceSlicesIntegers

        if isinstance(parent, SliceSlicesIntegers):
            return self._accept_slice(parent)
        if isinstance(parent, Shuffle):
            return self._accept_shuffle(parent)
        return None

    def _accept_shuffle(self, shuffle_expr):
        """Accept a shuffle being pushed through BroadcastTo.

        - Shuffle on a new dimension (added by broadcast): can't push through.
        - Shuffle on a dimension broadcast from size 1: the values are
          unchanged, so only the chunking the shuffle produces is adopted.
        - Shuffle on a dimension with real data: push through to the input
          and follow the shuffled input's chunking along that dimension.
        """
        from dask_array._shuffle import Shuffle

        axis = shuffle_expr.axis
        ndim_new = len(self._shape) - self.array.ndim

        # Shuffle on a new dimension (added by broadcast) - can't push through
        if axis < ndim_new:
            return None

        # Map to input axis
        input_axis = axis - ndim_new

        if self.array.shape[input_axis] == 1:
            # Every value along this axis is identical, so reordering is a
            # no-op on values; keep the chunk layout the shuffle would produce.
            axis_chunks = tuple(shuffle_expr.chunks[axis])
            if axis_chunks == tuple(self._chunks[axis]):
                return self
            return self._with_axis_chunks(self.array, axis, axis_chunks)

        # Push shuffle through to input
        shuffled_input = Shuffle(
            self.array,
            shuffle_expr.indexer,
            input_axis,
            shuffle_expr.operand("name"),
        )
        # _layer requires output chunks to match input chunks along real
        # dims, so adopt the shuffled input's chunks along this axis.
        return self._with_axis_chunks(shuffled_input, axis, shuffled_input.chunks[input_axis])

    def _with_axis_chunks(self, array, axis, axis_chunks):
        """Rebuild this broadcast on ``array`` with new chunks along ``axis``."""
        axis_chunks = tuple(axis_chunks)
        chunks = tuple(self._chunks)
        shape = tuple(self._shape)
        new_chunks = chunks[:axis] + (axis_chunks,) + chunks[axis + 1 :]
        new_shape = shape[:axis] + (sum(axis_chunks),) + shape[axis + 1 :]
        return BroadcastTo(array, new_shape, new_chunks, self._meta)

    def _accept_slice(self, slice_expr):
        """Accept a slice being pushed through BroadcastTo.

        For broadcast_to(x, shape)[slices]:
        - Dimensions added by broadcast (new dims): affect output shape only
        - Dimensions from input with size > 1: push slice to input
        - Dimensions from input with size == 1: affect output shape only

        Only plain step-1 slices are handled; anything else (integers,
        ``None``, Ellipsis, arrays, non-unit steps) returns ``None``.
        """
        from dask_array._new_collection import new_collection

        input_arr = self.array
        output_shape = self._shape
        index = slice_expr.index

        # An index longer than the output ndim is not something we rewrite
        if len(index) > len(output_shape):
            return None

        # Pad index to full length
        full_index = index + (slice(None),) * (len(output_shape) - len(index))
        if not all(isinstance(idx, slice) for idx in full_index):
            return None

        ndim_new = len(output_shape) - input_arr.ndim

        new_output_shape = []
        bounds = []  # (start, length) of the selection along each output dim
        input_slices = []
        for out_dim, idx in enumerate(full_index):
            start, stop, step = idx.indices(output_shape[out_dim])
            if step != 1:
                return None
            length = max(0, stop - start)
            new_output_shape.append(length)
            bounds.append((start, length))

            # Dimensions that exist in the input get an input-side slice
            if out_dim >= ndim_new:
                in_dim = out_dim - ndim_new
                if input_arr.shape[in_dim] == 1:
                    # Broadcasted from size 1 - can't push, just take full slice
                    input_slices.append(slice(None))
                else:
                    # Real dimension - push the slice
                    input_slices.append(idx)

        # Slice the input array
        if input_slices:
            sliced_input = new_collection(input_arr)[tuple(input_slices)]
        else:
            sliced_input = new_collection(input_arr)
        sliced_expr = sliced_input.expr

        # Real input dims take the sliced input's chunks (required by
        # _layer); new and size-1 dims re-slice the broadcast's own chunks.
        new_chunks = []
        for out_dim, old_chunk in enumerate(self._chunks):
            in_dim = out_dim - ndim_new
            if in_dim >= 0 and input_arr.shape[in_dim] != 1:
                new_chunks.append(sliced_expr.chunks[in_dim])
            else:
                start, length = bounds[out_dim]
                new_chunks.append(self._slice_chunks(old_chunk, start, length))

        # Slicing changes neither dtype nor ndim of the output, so the
        # existing meta remains valid and any user-provided meta is kept.
        return BroadcastTo(
            sliced_expr,
            tuple(new_output_shape),
            tuple(new_chunks),
            self._meta,
        )

    def _slice_chunks(self, chunks, start, length):
        """Return the chunk sizes covering ``[start, start + length)`` of ``chunks``.

        Empty (or negative-length) selections return ``(0,)`` so the chunks
        always describe a dimension of the selected size.
        """
        if length <= 0:
            return (0,)

        stop = start + length
        result = []
        chunk_start = 0
        for chunk_size in chunks:
            chunk_end = chunk_start + chunk_size
            # Size of the overlap between this chunk and the selection
            overlap = min(chunk_end, stop) - max(chunk_start, start)
            if overlap > 0:
                result.append(overlap)
            if chunk_end >= stop:
                break
            chunk_start = chunk_end

        return tuple(result)
219
+
220
+
221
def broadcast_to(x, shape, chunks=None, meta=None):
    """Broadcast an array to a new shape.

    Parameters
    ----------
    x : array_like
        The array to broadcast; it is converted with ``asarray`` first.
    shape : tuple
        The shape of the desired array.
    chunks : tuple, optional
        If provided, then the result will use these chunks instead of the same
        chunks as the source array. Only new dimensions, or dimensions whose
        chunks in ``x`` are exactly ``(1,)``, may be re-chunked.
    meta : empty ndarray, optional
        empty ndarray created with same NumPy backend, ndim and dtype as the
        Dask Array being created

    Returns
    -------
    Array
        ``x`` itself when neither shape nor chunks change, otherwise a new
        broadcast collection.

    Raises
    ------
    ValueError
        If ``x.shape`` cannot be broadcast to ``shape``, or if ``chunks``
        changes the chunking of a dimension taken from ``x``.
    """
    from dask_array._collection import asarray
    from dask_array._new_collection import new_collection

    arr = asarray(x)
    target = tuple(shape)

    if meta is None:
        meta = meta_from_array(arr._meta)

    # Nothing to do when both the shape and the chunking are unchanged.
    if arr.shape == target:
        if chunks is None or chunks == arr.chunks:
            return arr

    n_leading = len(target) - arr.ndim
    trailing = target[n_leading:]
    # Size-1 input dims stretch to anything; all others must match exactly.
    compatible = n_leading >= 0 and all(
        old == 1 or new == old for new, old in zip(trailing, arr.shape)
    )
    if not compatible:
        raise ValueError(f"cannot broadcast shape {arr.shape} to shape {target}")

    if chunks is not None:
        chunks = normalize_chunks(chunks, target, dtype=arr.dtype, previous_chunks=arr.chunks)
        for old_bd, new_bd in zip(arr.chunks, chunks[n_leading:]):
            if old_bd == (1,) or old_bd == new_bd:
                continue
            raise ValueError(
                f"cannot broadcast chunks {arr.chunks} to chunks {chunks}: "
                "new chunks must either be along a new "
                "dimension or a dimension of size 1"
            )
    else:
        # New leading dims become a single chunk; existing dims keep their
        # chunks unless they are being stretched from size <= 1.
        out_chunks = [(size,) for size in target[:n_leading]]
        for bd, old, new in zip(arr.chunks, arr.shape, trailing):
            out_chunks.append(bd if old > 1 else (new,))
        chunks = tuple(out_chunks)

    return new_collection(BroadcastTo(arr.expr, target, chunks, meta))
dask_array/_chunk.py ADDED
@@ -0,0 +1,445 @@
1
+ """A set of NumPy functions to apply per chunk"""
2
+
3
+ from __future__ import annotations
4
+
5
+ import itertools
6
+ from collections.abc import Container, Iterable, Sequence
7
+ from functools import wraps
8
+ from numbers import Integral
9
+
10
+ import numpy as np
11
+
12
+
13
def concat(seqs):
    """Lazily chain together zero or more iterables.

    ``seqs`` may itself be a generator, and any member may be infinite (in
    which case later members are never reached). Equivalent to
    ``itertools.chain.from_iterable``.

    >>> list(concat([[], [1], [2, 3]]))
    [1, 2, 3]
    """
    chained = itertools.chain.from_iterable(seqs)
    return chained
29
+
30
+
31
def flatten(seq, container=list):
    """Yield the leaves of arbitrarily nested ``container`` instances.

    Only items that are instances of ``container`` (``list`` by default) are
    descended into; a top-level string is yielded whole.

    >>> list(flatten([1]))
    [1]

    >>> list(flatten([[1, 2], [1, 2]]))
    [1, 2, 1, 2]

    >>> list(flatten([[[1], [2]], [[1], [2]]]))
    [1, 2, 1, 2]

    >>> list(flatten(((1, 2), (1, 2)))) # Don't flatten tuples
    [(1, 2), (1, 2)]

    >>> list(flatten((1, 2, [3, 4]))) # support heterogeneous
    [1, 2, 3, 4]
    """
    if isinstance(seq, str):
        yield seq
        return

    for element in seq:
        if not isinstance(element, container):
            yield element
            continue
        yield from flatten(element, container=container)
57
+
58
+
59
def astype(x, astype_dtype=None, **kwargs):
    """Return ``x`` cast to ``astype_dtype``; extra kwargs go to ``x.astype``."""
    convert = x.astype
    return convert(astype_dtype, **kwargs)
62
+
63
+
64
def view(x, dtype, order="C"):
    """Reinterpret the bytes of ``x`` as ``dtype``.

    For ``order="C"`` the array is made C-contiguous and viewed directly; for
    any other order it is made Fortran-contiguous and viewed through its
    transpose, so the last-axis reinterpretation happens along the first axis.
    The ``like=`` keyword is tried first and dropped if unsupported.
    """
    fortran = order != "C"
    make_contiguous = np.asfortranarray if fortran else np.ascontiguousarray
    try:
        x = make_contiguous(x, like=x)
    except TypeError:
        x = make_contiguous(x)

    if fortran:
        return x.T.view(dtype).T
    return x.view(dtype)
78
+
79
+
80
def trim(x, axes=None):
    """Trim boundaries off of array.

    Parameters
    ----------
    x : ndarray
        Array to trim.
    axes : int, dict, sequence of int, or None
        Number of elements removed from *both* ends of each axis. An int
        applies to every axis; a dict maps axis -> amount (missing axes get
        0); a sequence gives one amount per axis. ``None`` (the default)
        trims nothing.

    >>> x = np.arange(24).reshape((4, 6))
    >>> trim(x, axes={0: 0, 1: 1})
    array([[ 1,  2,  3,  4],
           [ 7,  8,  9, 10],
           [13, 14, 15, 16],
           [19, 20, 21, 22]])

    >>> trim(x, axes={0: 1, 1: 1})
    array([[ 7,  8,  9, 10],
           [13, 14, 15, 16]])
    """
    if axes is None:
        # Previously this crashed when iterating None; treat it as "no trim"
        axes = [0] * x.ndim
    elif isinstance(axes, Integral):
        axes = [axes] * x.ndim
    elif isinstance(axes, dict):
        axes = [axes.get(i, 0) for i in range(x.ndim)]

    # A 0 amount must map to an open-ended stop (None), since -0 == 0
    return x[tuple(slice(ax, -ax if ax else None) for ax in axes)]
100
+
101
+
102
def coarsen(reduction, x, axes, trim_excess=False, **kwargs):
    """Coarsen array by applying reduction to fixed size neighborhoods.

    Parameters
    ----------
    reduction: function
        Function like np.sum, np.mean, etc...; must accept ``axis=`` as a
        tuple of axes.
    x: np.ndarray
        Array to be coarsened
    axes: dict
        Mapping of axis to coarsening factor. Axes not listed use a factor
        of 1. The mapping is not modified.
    trim_excess: bool
        If True, drop trailing elements along each axis that do not fill a
        whole neighborhood. Otherwise each axis length must be divisible by
        its factor for the reshape to succeed.
    **kwargs
        Passed through to ``reduction``.

    Examples
    --------
    >>> x = np.array([1, 2, 3, 4, 5, 6])
    >>> coarsen(np.sum, x, {0: 2})
    array([ 3,  7, 11])
    >>> coarsen(np.max, x, {0: 3})
    array([3, 6])
    """
    # Build a local factor table instead of inserting singleton dims into the
    # caller's mapping (the original mutated ``axes`` in place).
    factors = {i: (axes[i] if i in axes else 1) for i in range(x.ndim)}

    if trim_excess:
        ind = tuple(
            slice(0, -(d % factors[i])) if d % factors[i] else slice(None, None)
            for i, d in enumerate(x.shape)
        )
        x = x[ind]

    # (10, 10) -> (5, 2, 5, 2): split each axis into (blocks, factor)
    newshape = tuple(
        size for i in range(x.ndim) for size in (x.shape[i] // factors[i], factors[i])
    )

    # Reduce over the odd (within-neighborhood) axes
    return reduction(x.reshape(newshape), axis=tuple(range(1, x.ndim * 2, 2)), **kwargs)
135
+
136
+
137
def keepdims_wrapper(a_callable):
    """
    A wrapper for functions that don't provide keepdims to ensure that they do.

    The returned callable has signature
    ``(x, axis=None, keepdims=None, *args, **kwargs)``. It calls
    ``a_callable(x, *args, axis=axis, **kwargs)`` and, when ``keepdims`` is
    truthy, re-inserts a length-1 axis at every reduced position (all axes
    when ``axis`` is None). Negative axes are interpreted relative to
    ``x.ndim``.
    """

    @wraps(a_callable)
    def keepdims_wrapped_callable(x, axis=None, keepdims=None, *args, **kwargs):
        r = a_callable(x, *args, axis=axis, **kwargs)

        if not keepdims:
            return r

        if axis is None:
            axes = range(x.ndim)
        elif isinstance(axis, (Container, Iterable, Sequence)):
            axes = axis
        else:
            axes = [axis]

        # Normalize negative axes; otherwise e.g. axis=-1 never matches the
        # non-negative positions below and no dimension is re-inserted.
        reduced = {ax + x.ndim if ax < 0 else ax for ax in axes}

        r_slice = tuple(
            None if each_axis in reduced else slice(None) for each_axis in range(x.ndim)
        )
        return r[r_slice]

    return keepdims_wrapped_callable
169
+
170
+
171
# Per-chunk reductions. The plain NumPy functions are re-exported as-is.
# Note that sum, min, max, any and all shadow the builtins inside this module.
sum, prod, min, max = np.sum, np.prod, np.min, np.max
any, all = np.any, np.all
nansum, nanprod = np.nansum, np.nanprod

nancumprod, nancumsum = np.nancumprod, np.nancumsum

nanmin, nanmax = np.nanmin, np.nanmax
mean, nanmean = np.mean, np.nanmean

var, nanvar = np.var, np.nanvar

std, nanstd = np.std, np.nanstd

# The arg-reductions go through keepdims_wrapper so that they accept a
# keepdims argument.
argmin = keepdims_wrapper(np.argmin)
nanargmin = keepdims_wrapper(np.nanargmin)
argmax = keepdims_wrapper(np.argmax)
nanargmax = keepdims_wrapper(np.nanargmax)
198
+
199
+
200
def topk(a, k, axis, keepdims):
    """Chunk and combine function of topk

    Keep the k largest elements of ``a`` along ``axis[0]`` (the -k smallest
    when k is negative). If |k| already covers the whole axis, ``a`` is
    returned unchanged. The kept elements are not sorted.
    """
    assert keepdims is True
    ax = axis[0]
    if abs(k) >= a.shape[ax]:
        return a

    parted = np.partition(a, -k, axis=ax)
    # Largest values end up at the tail, smallest at the head.
    if k > 0:
        keep = slice(-k, None)
    else:
        keep = slice(-k)
    selector = tuple(keep if d == ax else slice(None) for d in range(parted.ndim))
    return parted[selector]
216
+
217
+
218
def topk_aggregate(a, k, axis, keepdims):
    """Final aggregation function of topk

    Run ``topk`` one last time and then sort: descending for k >= 0 (largest
    first), ascending for k < 0 (smallest first).
    """
    assert keepdims is True
    best = topk(a, k, axis, keepdims)
    ax = axis[0]
    ordered = np.sort(best, axis=ax)
    if k >= 0:
        flip = tuple(slice(None, None, -1) if d == ax else slice(None) for d in range(ordered.ndim))
        return ordered[flip]
    return ordered
230
+
231
+
232
def argtopk_preprocess(a, idx):
    """Preparatory step for argtopk

    Pair each data chunk with its original indices as an ``(a, idx)`` tuple.
    """
    pair = (a, idx)
    return pair
238
+
239
+
240
def argtopk(a_plus_idx, k, axis, keepdims):
    """Chunk and combine function of argtopk

    Extract the indices of the k largest elements from a on the given axis.
    If k is negative, extract the indices of the -k smallest elements instead.
    Note that, unlike in the parent function, the returned elements
    are not sorted internally.

    Parameters
    ----------
    a_plus_idx : tuple or list
        Either one ``(a, idx)`` pair or a (possibly nested) list of such
        pairs. A list is flattened, then its data and index parts are
        concatenated along the axis (indices broadcast to the data shape).
    k : int
        Number of elements to keep; negative selects the smallest.
    axis : tuple
        One-element tuple holding the axis.
    keepdims : bool
        Must be True.

    Returns
    -------
    tuple
        ``(a, idx)``: the selected values and their original indices.
    """
    assert keepdims is True
    axis = axis[0]

    if isinstance(a_plus_idx, list):
        a_plus_idx = list(flatten(a_plus_idx))
        a = np.concatenate([ai for ai, _ in a_plus_idx], axis)
        idx = np.concatenate([np.broadcast_to(idxi, ai.shape) for ai, idxi in a_plus_idx], axis)
    else:
        a, idx = a_plus_idx

    if abs(k) >= a.shape[axis]:
        # Return the concatenated pair rather than the raw input: for list
        # input that would be a list of pairs, which callers such as
        # argtopk_aggregate unpack as ``a, idx``.
        return a, idx

    idx2 = np.argpartition(a, -k, axis=axis)
    k_slice = slice(-k, None) if k > 0 else slice(-k)
    idx2 = idx2[tuple(k_slice if i == axis else slice(None) for i in range(a.ndim))]

    return np.take_along_axis(a, idx2, axis), np.take_along_axis(idx, idx2, axis)
266
+
267
+
268
def argtopk_aggregate(a_plus_idx, k, axis, keepdims):
    """Final aggregation function of argtopk

    Run ``argtopk`` a final time, order the surviving indices by their data
    values (descending for k >= 0, ascending for k < 0), and return only the
    indices.
    """
    assert keepdims is True
    # A single-element input is unwrapped to its only member.
    if len(a_plus_idx) <= 1:
        a_plus_idx = a_plus_idx[0]
    values, positions = argtopk(a_plus_idx, k, axis, keepdims)
    ax = axis[0]
    order = np.argsort(values, axis=ax)
    positions = np.take_along_axis(positions, order, ax)
    if k >= 0:
        flip = tuple(slice(None, None, -1) if d == ax else slice(None) for d in range(positions.ndim))
        return positions[flip]
    return positions
283
+
284
+
285
def getitem(obj, index):
    """Getitem function

    Return ``obj[index]``. When the result is an array view (it does not own
    its data) and is at most half the size of ``obj``, a copy is returned
    instead, so a small selection does not keep a large buffer alive.
    Objects lacking ``flags``/``size`` are returned as-is.
    For more information, see
    https://numpy.org/doc/stable/reference/arrays.indexing.html#basic-slicing-and-indexing.

    Parameters
    ----------
    obj: ndarray, string, tuple, list
        Object to get item from.
    index: int, list[int], slice()
        Desired selection to extract from obj.

    Returns
    -------
    Selection obj[index]

    Raises
    ------
    ValueError
        If indexing raises IndexError, re-raised with a hint about unknown
        chunk sizes.
    """
    try:
        selection = obj[index]
    except IndexError as err:
        raise ValueError("Array chunk size or shape is unknown. Possible solution with x.compute_chunk_sizes()") from err

    try:
        is_view = not selection.flags.owndata
        if is_view and obj.size >= 2 * selection.size:
            selection = selection.copy()
    except AttributeError:
        # Not an ndarray-like result (e.g. list/str element): nothing to copy.
        pass

    return selection
318
+
319
+
320
def arange(start, stop, step, length, dtype, like=None):
    """Arange wrapper for chunk generation.

    Build ``np.arange(start, stop, step, dtype=dtype)`` (passing ``like=``
    when given and supported) and drop the final element if the result is
    longer than the expected ``length``.
    """
    if like is not None:
        try:
            out = np.arange(start, stop, step, dtype=dtype, like=like)
        except TypeError:
            out = np.arange(start, stop, step, dtype=dtype)
    else:
        out = np.arange(start, stop, step, dtype=dtype)

    if len(out) > length:
        out = out[:-1]
    return out
333
+
334
+
335
def linspace(start, stop, num, endpoint=True, dtype=None):
    """Return ``np.linspace`` over ``[start, stop]`` for one chunk."""
    result = np.linspace(start, stop, num, endpoint=endpoint, dtype=dtype)
    return result
338
+
339
+
340
def slice_with_int_dask_array(x, idx, offset, x_size, axis):
    """Chunk function of `slice_with_int_dask_array_on_axis`.
    Slice one chunk of x by one chunk of idx.

    Parameters
    ----------
    x: ndarray, any dtype, any shape
        i-th chunk of x
    idx: ndarray, ndim=1, dtype=any integer
        j-th chunk of idx (cartesian product with the chunks of x)
    offset: ndarray, shape=(1, ), dtype=int64
        Index of the first element along axis of the current chunk of x
    x_size: int
        Total size of the x da.Array along axis
    axis: int
        normalized axis to take elements from (0 <= axis < x.ndim)

    Returns
    -------
    x sliced along axis, using only the elements of idx that fall inside the
    current chunk.
    """
    # Convert idx to x's array type when x exposes an array-API namespace.
    # Only fall back to NumPy when that conversion is unavailable: calling
    # np.asarray afterwards would undo it and fails for array types that
    # refuse implicit conversion to NumPy.
    converted = False
    if hasattr(x, "__array_namespace__"):
        try:
            xp = x.__array_namespace__()
            idx = xp.asarray(idx)
            converted = True
        except (AttributeError, TypeError):
            pass
    if not converted:
        idx = np.asarray(idx)

    if not np.issubdtype(idx.dtype, np.integer):
        # Non-integer idx is replaced by meta_from_array(x), presumably an
        # empty array so nothing is selected (e.g. an empty float index).
        # NOTE(review): this also discards non-empty float indices — confirm.
        from dask_array._utils import meta_from_array

        idx = meta_from_array(x)

    # Needed when idx is unsigned
    idx = idx.astype(np.int64)

    # Normalize negative indices
    idx = np.where(idx < 0, idx + x_size, idx)

    # A chunk of the offset dask Array is a numpy array with shape (1, ).
    # It indicates the index of the first element along axis of the current
    # chunk of x.
    idx = idx - offset

    # Drop elements of idx that do not fall inside the current chunk of x
    idx_filter = (idx >= 0) & (idx < x.shape[axis])
    idx = idx[idx_filter]

    # np.take does not support slice indices
    return x[tuple(idx if i == axis else slice(None) for i in range(x.ndim))]
394
+
395
+
396
def slice_with_int_dask_array_aggregate(idx, chunk_outputs, x_chunks, axis):
    """Final aggregation function of `slice_with_int_dask_array_on_axis`.

    ``chunk_outputs`` holds, chunk by chunk of x, the elements selected by
    `slice_with_int_dask_array` (each chunk's hits in idx order). This
    function computes, for every entry of idx, its position inside that
    concatenation and reorders along ``axis`` to match idx. There is no
    combine function: a recursive aggregation (e.g. with split_every) would
    not help here.

    Parameters
    ----------
    idx: ndarray, ndim=1, dtype=any integer
        j-th chunk of idx
    chunk_outputs: ndarray
        concatenation along axis of the outputs of `slice_with_int_dask_array`
        for all chunks of x and the j-th chunk of idx
    x_chunks: tuple
        dask chunks of the x da.Array along axis, e.g. ``(3, 3, 2)``
    axis: int
        normalized axis to take elements from (0 <= axis < x.ndim)

    Returns
    -------
    Selection from all chunks of x for the j-th chunk of idx, in the correct
    order
    """
    # int64 cast is needed when idx is unsigned; then wrap negative indices
    positions = idx.astype(np.int64)
    total = sum(x_chunks)
    positions = np.where(positions < 0, positions + total, positions)

    # For each chunk of x, the entries of idx falling inside it appear in
    # chunk_outputs, in order, right after everything emitted by earlier chunks.
    # FIXME: this could probably be reimplemented with a faster search-based
    # algorithm
    picks = np.zeros_like(positions)
    lo = 0
    emitted = 0
    for width in x_chunks:
        hi = lo + width
        in_chunk = (positions >= lo) & (positions < hi)
        running = np.cumsum(in_chunk)
        picks += np.where(in_chunk, running - 1 + emitted, 0)
        lo = hi
        if running.size:
            emitted += running[-1]

    # np.take does not support slice indices
    return chunk_outputs[tuple(picks if d == axis else slice(None) for d in range(chunk_outputs.ndim))]