dask-array 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dask_array/__init__.py +228 -0
- dask_array/_backends.py +76 -0
- dask_array/_backends_array.py +99 -0
- dask_array/_blockwise.py +1410 -0
- dask_array/_broadcast.py +272 -0
- dask_array/_chunk.py +445 -0
- dask_array/_chunk_types.py +54 -0
- dask_array/_collection.py +1644 -0
- dask_array/_concatenate.py +331 -0
- dask_array/_core_utils.py +1365 -0
- dask_array/_dispatch.py +141 -0
- dask_array/_einsum.py +277 -0
- dask_array/_expr.py +544 -0
- dask_array/_expr_flow.py +586 -0
- dask_array/_gufunc.py +805 -0
- dask_array/_histogram.py +617 -0
- dask_array/_map_blocks.py +652 -0
- dask_array/_new_collection.py +10 -0
- dask_array/_numpy_compat.py +135 -0
- dask_array/_overlap.py +1159 -0
- dask_array/_rechunk.py +1050 -0
- dask_array/_reshape.py +710 -0
- dask_array/_routines.py +102 -0
- dask_array/_shuffle.py +448 -0
- dask_array/_stack.py +264 -0
- dask_array/_svg.py +291 -0
- dask_array/_templates.py +29 -0
- dask_array/_test_utils.py +257 -0
- dask_array/_ufunc.py +385 -0
- dask_array/_utils.py +349 -0
- dask_array/_visualize.py +223 -0
- dask_array/_xarray.py +337 -0
- dask_array/core/__init__.py +34 -0
- dask_array/core/_blockwise_funcs.py +312 -0
- dask_array/core/_conversion.py +422 -0
- dask_array/core/_from_graph.py +97 -0
- dask_array/creation/__init__.py +71 -0
- dask_array/creation/_arange.py +121 -0
- dask_array/creation/_diag.py +116 -0
- dask_array/creation/_diagonal.py +241 -0
- dask_array/creation/_eye.py +103 -0
- dask_array/creation/_linspace.py +102 -0
- dask_array/creation/_mesh.py +134 -0
- dask_array/creation/_ones_zeros.py +454 -0
- dask_array/creation/_pad.py +270 -0
- dask_array/creation/_repeat.py +55 -0
- dask_array/creation/_tile.py +36 -0
- dask_array/creation/_tri.py +28 -0
- dask_array/creation/_utils.py +296 -0
- dask_array/fft.py +320 -0
- dask_array/io/__init__.py +39 -0
- dask_array/io/_base.py +10 -0
- dask_array/io/_from_array.py +257 -0
- dask_array/io/_from_delayed.py +95 -0
- dask_array/io/_from_graph.py +54 -0
- dask_array/io/_from_npy_stack.py +67 -0
- dask_array/io/_store.py +336 -0
- dask_array/io/_tiledb.py +159 -0
- dask_array/io/_to_npy_stack.py +65 -0
- dask_array/io/_zarr.py +449 -0
- dask_array/linalg/__init__.py +39 -0
- dask_array/linalg/_cholesky.py +234 -0
- dask_array/linalg/_lu.py +300 -0
- dask_array/linalg/_norm.py +94 -0
- dask_array/linalg/_qr.py +601 -0
- dask_array/linalg/_solve.py +349 -0
- dask_array/linalg/_svd.py +394 -0
- dask_array/linalg/_tensordot.py +334 -0
- dask_array/linalg/_utils.py +74 -0
- dask_array/manipulation/__init__.py +45 -0
- dask_array/manipulation/_expand.py +321 -0
- dask_array/manipulation/_flip.py +92 -0
- dask_array/manipulation/_roll.py +78 -0
- dask_array/manipulation/_transpose.py +309 -0
- dask_array/random/__init__.py +125 -0
- dask_array/random/_choice.py +181 -0
- dask_array/random/_expr.py +256 -0
- dask_array/random/_generator.py +441 -0
- dask_array/random/_random_state.py +259 -0
- dask_array/random/_utils.py +84 -0
- dask_array/reductions/__init__.py +84 -0
- dask_array/reductions/_arg_reduction.py +130 -0
- dask_array/reductions/_common.py +1082 -0
- dask_array/reductions/_cumulative.py +522 -0
- dask_array/reductions/_percentile.py +261 -0
- dask_array/reductions/_reduction.py +725 -0
- dask_array/reductions/_trace.py +56 -0
- dask_array/routines/__init__.py +133 -0
- dask_array/routines/_apply.py +84 -0
- dask_array/routines/_bincount.py +112 -0
- dask_array/routines/_broadcast.py +111 -0
- dask_array/routines/_coarsen.py +115 -0
- dask_array/routines/_diff.py +79 -0
- dask_array/routines/_gradient.py +158 -0
- dask_array/routines/_indexing.py +65 -0
- dask_array/routines/_insert_delete.py +132 -0
- dask_array/routines/_misc.py +122 -0
- dask_array/routines/_nonzero.py +72 -0
- dask_array/routines/_search.py +123 -0
- dask_array/routines/_select.py +113 -0
- dask_array/routines/_statistics.py +171 -0
- dask_array/routines/_topk.py +82 -0
- dask_array/routines/_triangular.py +74 -0
- dask_array/routines/_unique.py +232 -0
- dask_array/routines/_where.py +62 -0
- dask_array/slicing/__init__.py +67 -0
- dask_array/slicing/_basic.py +550 -0
- dask_array/slicing/_blocks.py +138 -0
- dask_array/slicing/_bool_index.py +145 -0
- dask_array/slicing/_setitem.py +329 -0
- dask_array/slicing/_squeeze.py +101 -0
- dask_array/slicing/_utils.py +1133 -0
- dask_array/slicing/_vindex.py +282 -0
- dask_array/stacking/__init__.py +15 -0
- dask_array/stacking/_block.py +83 -0
- dask_array/stacking/_simple.py +58 -0
- dask_array/templates/array.html.j2 +48 -0
- dask_array/tests/__init__.py +0 -0
- dask_array/tests/conftest.py +22 -0
- dask_array/tests/test_api.py +40 -0
- dask_array/tests/test_binary_op_chunks.py +107 -0
- dask_array/tests/test_coarse_slice_through_blockwise.py +362 -0
- dask_array/tests/test_collection.py +799 -0
- dask_array/tests/test_creation.py +1102 -0
- dask_array/tests/test_expr_flow.py +143 -0
- dask_array/tests/test_linalg.py +1130 -0
- dask_array/tests/test_map_blocks_multi_output.py +104 -0
- dask_array/tests/test_rechunk_pushdown.py +214 -0
- dask_array/tests/test_reductions.py +1091 -0
- dask_array/tests/test_routines.py +2853 -0
- dask_array/tests/test_shuffle_chunks.py +67 -0
- dask_array/tests/test_slice_pushdown.py +968 -0
- dask_array/tests/test_slice_through_blockwise.py +678 -0
- dask_array/tests/test_slice_through_overlap.py +366 -0
- dask_array/tests/test_slice_through_reshape.py +272 -0
- dask_array/tests/test_slicing.py +839 -0
- dask_array/tests/test_transpose_slice_pushdown.py +208 -0
- dask_array/tests/test_visualize.py +94 -0
- dask_array/tests/test_xarray.py +193 -0
- dask_array-0.1.0.dist-info/METADATA +48 -0
- dask_array-0.1.0.dist-info/RECORD +144 -0
- dask_array-0.1.0.dist-info/WHEEL +4 -0
- dask_array-0.1.0.dist-info/entry_points.txt +2 -0
- dask_array-0.1.0.dist-info/licenses/LICENSE +29 -0
dask_array/_broadcast.py
ADDED
|
@@ -0,0 +1,272 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import functools
|
|
4
|
+
from itertools import product
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
|
|
8
|
+
from dask._task_spec import Task, TaskRef
|
|
9
|
+
from dask_array._expr import ArrayExpr
|
|
10
|
+
from dask_array._core_utils import normalize_chunks
|
|
11
|
+
from dask_array._utils import meta_from_array
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class BroadcastTo(ArrayExpr):
    """Broadcast an array to a new shape.

    Operands (see ``_parameters``): the input expression ``array``, the
    target ``_shape``, the target ``_chunks`` (one tuple of block sizes per
    output dimension) and an optional ``_meta_override``.
    """

    _parameters = ["array", "_shape", "_chunks", "_meta_override"]
    _defaults = {"_meta_override": None}

    @functools.cached_property
    def _name(self):
        """Key prefix used for the tasks of this expression."""
        return f"broadcast_to-{self.deterministic_token}"

    @functools.cached_property
    def _meta(self):
        """Meta of the result.

        The ``_meta_override`` operand wins when it has an ``ndim`` equal to
        the target dimensionality; otherwise the meta is derived from the
        input's meta with the target ``ndim``.
        """
        meta_override = self.operand("_meta_override")
        # Only use meta_override if it has the correct ndim
        if meta_override is not None and hasattr(meta_override, "ndim") and meta_override.ndim == len(self._shape):
            return meta_override
        return meta_from_array(self.array._meta, ndim=len(self._shape))

    @functools.cached_property
    def chunks(self):
        """Chunks of the broadcast result (the ``_chunks`` operand)."""
        return self._chunks

    def _layer(self) -> dict:
        """Build the task graph: one ``np.broadcast_to`` task per output block."""
        x = self.array
        ndim_new = len(self._shape) - x.ndim

        dsk = {}
        for ec in product(*(enumerate(bds) for bds in self._chunks)):
            # ``ec`` is empty for a 0-d target; ``zip(*())`` could not be
            # unpacked into two names, so handle that case explicitly.
            new_index, chunk_shape = zip(*ec) if ec else ((), ())
            # An input axis stored as a single block of length 1 feeds every
            # output block along that axis, so it is always read at index 0.
            old_index = tuple(0 if bd == (1,) else i for bd, i in zip(x.chunks, new_index[ndim_new:]))
            old_key = (x._name,) + old_index
            new_key = (self._name,) + new_index
            dsk[new_key] = Task(new_key, np.broadcast_to, TaskRef(old_key), chunk_shape)

        return dsk

    def _simplify_up(self, parent, dependents):
        """Allow slice and shuffle operations to push through BroadcastTo."""
        from dask_array._shuffle import Shuffle
        from dask_array.slicing import SliceSlicesIntegers

        if isinstance(parent, SliceSlicesIntegers):
            return self._accept_slice(parent)
        if isinstance(parent, Shuffle):
            return self._accept_shuffle(parent)
        return None

    def _accept_shuffle(self, shuffle_expr):
        """Accept a shuffle being pushed through BroadcastTo.

        - Shuffle on a new dimension (added by broadcast): can't push
          through, returns ``None``.
        - Shuffle on a dimension broadcast from size 1: the data is constant
          along that axis, so only shape/chunks of the broadcast change.
        - Shuffle on a dimension with real data: push through to the input.

        The returned expression stands in for ``shuffle_expr``, so along the
        shuffled axis it takes its chunks (and size) from ``shuffle_expr``
        rather than from this pre-shuffle broadcast.
        """
        from dask_array._shuffle import Shuffle

        axis = shuffle_expr.axis
        ndim_new = len(self._shape) - self.array.ndim

        # Shuffle on a new dimension (added by broadcast) - can't push through
        if axis < ndim_new:
            return None

        # Map to input axis
        input_axis = axis - ndim_new
        input_size = self.array.shape[input_axis]

        # Chunks/shape the replacement must expose along the shuffled axis.
        # NOTE(review): relies on ``shuffle_expr.chunks`` describing the
        # shuffled output, as for every other array expression used here.
        axis_chunks = tuple(shuffle_expr.chunks[axis])
        old_chunks = tuple(self._chunks)
        old_shape = tuple(self._shape)
        new_chunks = old_chunks[:axis] + (axis_chunks,) + old_chunks[axis + 1 :]
        new_shape = old_shape[:axis] + (sum(axis_chunks),) + old_shape[axis + 1 :]

        # If input dimension is size 1 (broadcasted), the values do not
        # change; only re-broadcast when the shuffle regrouped the axis.
        if input_size == 1:
            if new_chunks == old_chunks:
                return self
            return BroadcastTo(self.array, new_shape, new_chunks, self._meta)

        # Push shuffle through to input
        shuffled_input = Shuffle(
            self.array,
            shuffle_expr.indexer,
            input_axis,
            shuffle_expr.operand("name"),
        )
        return BroadcastTo(shuffled_input, new_shape, new_chunks, self._meta)

    def _accept_slice(self, slice_expr):
        """Accept a slice being pushed through BroadcastTo.

        For broadcast_to(x, shape)[slices]:
        - Dimensions added by broadcast (new dims): affect output shape only
        - Dimensions from input with size > 1: push slice to input
        - Dimensions from input with size == 1: affect output shape only

        Returns ``None`` (no rewrite) unless every index entry is a slice
        with step 1.
        """
        from dask_array._new_collection import new_collection

        input_arr = self.array
        output_shape = self._shape
        index = slice_expr.index

        # Pad index to full length
        full_index = index + (slice(None),) * (len(output_shape) - len(index))

        # For now, only handle simple slices (no integers, None, arrays, ...)
        if not all(isinstance(idx, slice) for idx in full_index):
            return None

        ndim_new = len(output_shape) - input_arr.ndim

        # Per output dimension: (start, length) of the selection, with the
        # length clamped at 0 so reversed bounds give an empty axis.
        spans = []
        input_slices = []
        for out_dim, idx in enumerate(full_index):
            start, stop, step = idx.indices(output_shape[out_dim])
            if step != 1:
                return None
            spans.append((start, max(0, stop - start)))

            # Check if this dimension maps to input
            if out_dim >= ndim_new:
                in_size = input_arr.shape[out_dim - ndim_new]
                # Broadcasted from size 1 - can't push, just take full slice;
                # real dimension - push the slice.
                input_slices.append(slice(None) if in_size == 1 else idx)

        # Slice the input array
        sliced_input = new_collection(input_arr)
        if input_slices:
            sliced_input = sliced_input[tuple(input_slices)]

        # Compute new chunks for the output:
        # - real input dimensions use the sliced input's chunks
        # - new and size-1 dimensions re-cut this expression's old chunks
        new_chunks = []
        for out_dim, old_chunk in enumerate(self._chunks):
            in_dim = out_dim - ndim_new
            if in_dim >= 0 and input_arr.shape[in_dim] != 1:
                new_chunks.append(sliced_input.expr.chunks[in_dim])
            else:
                start, length = spans[out_dim]
                new_chunks.append(self._slice_chunks(old_chunk, start, length))

        # Compute meta for the new broadcast
        new_meta = meta_from_array(sliced_input.expr._meta)

        # Create new BroadcastTo
        return BroadcastTo(
            sliced_input.expr,
            tuple(length for _, length in spans),
            tuple(new_chunks),
            new_meta,
        )

    def _slice_chunks(self, chunks, start, length):
        """Compute new chunks after slicing.

        Parameters
        ----------
        chunks : tuple of int
            Block sizes along one axis.
        start : int
            First selected position along that axis.
        length : int
            Number of selected positions; zero or negative means empty.

        Returns
        -------
        tuple of int
            Sizes of the parts of each block that fall inside
            ``[start, start + length)``; ``(0,)`` for an empty selection.
        """
        if length <= 0:
            return (0,)

        stop = start + length
        result = []
        pos = 0
        for chunk_size in chunks:
            chunk_start, pos = pos, pos + chunk_size

            if pos <= start:
                continue
            if chunk_start >= stop:
                break

            # Overlap with the slice
            overlap_size = min(pos, stop) - max(chunk_start, start)
            if overlap_size > 0:
                result.append(overlap_size)

        return tuple(result)
|
|
219
|
+
|
|
220
|
+
|
|
221
|
+
def broadcast_to(x, shape, chunks=None, meta=None):
    """Broadcast an array to a new shape.

    Parameters
    ----------
    x : array_like
        The array to broadcast.
    shape : tuple
        The shape of the desired array.
    chunks : tuple, optional
        If provided, then the result will use these chunks instead of the same
        chunks as the source array.
    meta : empty ndarray, optional
        empty ndarray created with same NumPy backend, ndim and dtype as the
        Dask Array being created

    Returns
    -------
    Array

    Raises
    ------
    ValueError
        If ``x`` cannot be broadcast to ``shape``, or if ``chunks`` would
        re-chunk an existing dimension that is not a single block of size 1.
    """
    from dask_array._collection import asarray
    from dask_array._new_collection import new_collection

    x = asarray(x)
    shape = tuple(shape)
    meta = meta_from_array(x._meta) if meta is None else meta

    # Nothing to do when neither the shape nor the chunking changes.
    if x.shape == shape and (chunks is None or chunks == x.chunks):
        return x

    # Number of leading dimensions the broadcast adds in front of x's own.
    n_leading = len(shape) - x.ndim
    incompatible = n_leading < 0
    if not incompatible:
        for target, current in zip(shape[n_leading:], x.shape):
            if current != 1 and target != current:
                incompatible = True
                break
    if incompatible:
        raise ValueError(f"cannot broadcast shape {x.shape} to shape {shape}")

    if chunks is None:
        # New dimensions become a single block; existing dimensions keep
        # their blocks unless they are stretched from a size <= 1.
        leading = [(size,) for size in shape[:n_leading]]
        trailing = [
            blocks if current > 1 else (target,)
            for blocks, current, target in zip(x.chunks, x.shape, shape[n_leading:])
        ]
        chunks = tuple(leading) + tuple(trailing)
    else:
        chunks = normalize_chunks(chunks, shape, dtype=x.dtype, previous_chunks=x.chunks)
        for old_bd, new_bd in zip(x.chunks, chunks[n_leading:]):
            if old_bd == new_bd or old_bd == (1,):
                continue
            raise ValueError(
                f"cannot broadcast chunks {x.chunks} to chunks {chunks}: "
                "new chunks must either be along a new "
                "dimension or a dimension of size 1"
            )

    expr = BroadcastTo(x.expr, shape, chunks, meta)
    return new_collection(expr)
|
dask_array/_chunk.py
ADDED
|
@@ -0,0 +1,445 @@
|
|
|
1
|
+
"""A set of NumPy functions to apply per chunk"""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import itertools
|
|
6
|
+
from collections.abc import Container, Iterable, Sequence
|
|
7
|
+
from functools import wraps
|
|
8
|
+
from numbers import Integral
|
|
9
|
+
|
|
10
|
+
import numpy as np
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def concat(seqs):
    """Lazily chain together the iterables yielded by ``seqs``.

    ``seqs`` itself may be a generator, and any member may be infinite (in
    which case the members after it are never reached).

    >>> list(concat([[], [1], [2, 3]]))
    [1, 2, 3]

    See also:
        itertools.chain.from_iterable  equivalent
    """
    chain_all = itertools.chain.from_iterable
    return chain_all(seqs)
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def flatten(seq, container=list):
    """Yield the leaves of nested ``container`` instances, depth first.

    Only items that are instances of ``container`` are descended into; a
    string passed as ``seq`` is yielded whole.

    >>> list(flatten([1]))
    [1]

    >>> list(flatten([[1, 2], [1, 2]]))
    [1, 2, 1, 2]

    >>> list(flatten([[[1], [2]], [[1], [2]]]))
    [1, 2, 1, 2]

    >>> list(flatten(((1, 2), (1, 2))))  # Don't flatten tuples
    [(1, 2), (1, 2)]

    >>> list(flatten((1, 2, [3, 4])))  # support heterogeneous
    [1, 2, 3, 4]
    """
    if isinstance(seq, str):
        yield seq
        return

    for element in seq:
        if not isinstance(element, container):
            yield element
            continue
        yield from flatten(element, container=container)
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def astype(x, astype_dtype=None, **kwargs):
    """Return ``x.astype(astype_dtype, **kwargs)``."""
    converted = x.astype(astype_dtype, **kwargs)
    return converted
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
def view(x, dtype, order="C"):
    """Reinterpret the memory of ``x`` as ``dtype``.

    With ``order == "C"`` the array is made C-contiguous and viewed
    directly; for any other ``order`` it is made Fortran-contiguous and the
    view is taken on its transpose, then transposed back.
    """
    c_order = order == "C"
    make_contiguous = np.ascontiguousarray if c_order else np.asfortranarray

    # Prefer the ``like=`` form; fall back to the plain call when it raises
    # TypeError.
    try:
        x = make_contiguous(x, like=x)
    except TypeError:
        x = make_contiguous(x)

    if c_order:
        return x.view(dtype)
    return x.T.view(dtype).T
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def trim(x, axes=None):
    """Trim boundaries off of array.

    Parameters
    ----------
    x : array
        Array to trim.
    axes : int, dict, sequence of int or None
        Number of elements to drop from *both* ends of each axis. An
        integer applies to every axis; a dict maps axis -> amount (missing
        axes are not trimmed); a sequence gives one amount per axis.
        ``None`` (the default) trims nothing.

    >>> x = np.arange(24).reshape((4, 6))
    >>> trim(x, axes={0: 0, 1: 1})
    array([[ 1,  2,  3,  4],
           [ 7,  8,  9, 10],
           [13, 14, 15, 16],
           [19, 20, 21, 22]])

    >>> trim(x, axes={0: 1, 1: 1})
    array([[ 7,  8,  9, 10],
           [13, 14, 15, 16]])
    """
    if axes is None:
        # The default previously failed with "NoneType is not iterable";
        # treat it as a zero-width trim on every axis.
        axes = 0
    if isinstance(axes, Integral):
        axes = [axes] * x.ndim
    if isinstance(axes, dict):
        axes = [axes.get(i, 0) for i in range(x.ndim)]

    # ``-0`` would produce an empty slice, hence ``None`` for untrimmed axes.
    return x[tuple(slice(ax, -ax if ax else None) for ax in axes)]
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
def coarsen(reduction, x, axes, trim_excess=False, **kwargs):
    """Coarsen array by applying reduction to fixed size neighborhoods.

    Parameters
    ----------
    reduction: function
        Function like np.sum, np.mean, etc...
    x: np.ndarray
        Array to be coarsened
    axes: dict
        Mapping of axis to coarsening factor. Axes that are not listed are
        left at full resolution. The mapping is not modified.
    trim_excess: bool
        If True, drop the trailing elements of each axis that do not fill a
        whole neighborhood before reshaping.
    **kwargs
        Forwarded to ``reduction``.

    Examples
    --------
    >>> x = np.array([1, 2, 3, 4, 5, 6])
    >>> coarsen(np.sum, x, {0: 2})
    array([ 3,  7, 11])
    >>> coarsen(np.max, x, {0: 3})
    array([3, 6])
    """
    # Fill in a factor of 1 for unlisted axes on a private copy, so the
    # caller's mapping is left untouched and can be reused.
    factors = {i: axes.get(i, 1) for i in range(x.ndim)}

    if trim_excess:
        ind = tuple(
            slice(0, -(d % factors[i])) if d % factors[i] else slice(None, None)
            for i, d in enumerate(x.shape)
        )
        x = x[ind]

    # (10, 10) -> (5, 2, 5, 2): each axis is split into (blocks, factor)
    newshape = []
    for i in range(x.ndim):
        newshape += [x.shape[i] // factors[i], factors[i]]

    # Reduce over the odd positions, i.e. the "factor" half of every pair.
    return reduction(x.reshape(tuple(newshape)), axis=tuple(range(1, x.ndim * 2, 2)), **kwargs)
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
def keepdims_wrapper(a_callable):
    """
    A wrapper for functions that don't provide keepdims to ensure that they do.

    The returned callable has the signature
    ``(x, axis=None, keepdims=None, *args, **kwargs)``. It calls
    ``a_callable(x, *args, axis=axis, **kwargs)`` and, when ``keepdims`` is
    truthy, re-inserts a length-1 dimension for every reduced axis.
    """

    @wraps(a_callable)
    def keepdims_wrapped_callable(x, axis=None, keepdims=None, *args, **kwargs):
        r = a_callable(x, *args, axis=axis, **kwargs)

        if not keepdims:
            return r

        axes = axis

        if axes is None:
            axes = range(x.ndim)

        if not isinstance(axes, (Container, Iterable, Sequence)):
            axes = [axes]

        # Normalize negative axes; without this ``axis=-1`` matched no
        # position below and indexing ``r`` failed with too many indices.
        ndim = x.ndim
        reduced = {ax % ndim if ndim else ax for ax in axes}

        # ``None`` re-inserts a reduced axis, a full slice keeps the others.
        r_slice = tuple(None if each_axis in reduced else slice(None) for each_axis in range(ndim))

        return r[r_slice]

    return keepdims_wrapped_callable
|
|
169
|
+
|
|
170
|
+
|
|
171
|
+
# Per-chunk reduction kernels, exposed under short names.
# Wrap NumPy functions to ensure they provide keepdims: only the arg*
# reductions below go through ``keepdims_wrapper``; every other name is bound
# directly to the NumPy function.
# NOTE: inside this module these assignments shadow the builtins ``sum``,
# ``min``, ``max``, ``any`` and ``all``.
sum = np.sum
prod = np.prod
min = np.min
max = np.max
argmin = keepdims_wrapper(np.argmin)
nanargmin = keepdims_wrapper(np.nanargmin)
argmax = keepdims_wrapper(np.argmax)
nanargmax = keepdims_wrapper(np.nanargmax)
any = np.any
all = np.all
nansum = np.nansum
nanprod = np.nanprod

nancumprod = np.nancumprod
nancumsum = np.nancumsum

nanmin = np.nanmin
nanmax = np.nanmax
mean = np.mean
nanmean = np.nanmean

var = np.var
nanvar = np.nanvar

std = np.std
nanstd = np.nanstd
|
|
198
|
+
|
|
199
|
+
|
|
200
|
+
def topk(a, k, axis, keepdims):
    """Chunk and combine step of topk.

    Keep the ``k`` largest elements of ``a`` along ``axis[0]``, or the
    ``-k`` smallest when ``k`` is negative. ``a`` is returned unchanged when
    it holds no more than ``abs(k)`` elements along that axis. The kept
    elements are not sorted among themselves (the parent function sorts).
    """
    assert keepdims is True
    ax = axis[0]
    if a.shape[ax] <= abs(k):
        return a

    partitioned = np.partition(a, -k, axis=ax)

    # Largest-k live at the tail of the partition, smallest-k at the head.
    if k > 0:
        window = slice(-k, None)
    else:
        window = slice(-k)

    selector = []
    for dim in range(partitioned.ndim):
        selector.append(window if dim == ax else slice(None))
    return partitioned[tuple(selector)]
|
|
216
|
+
|
|
217
|
+
|
|
218
|
+
def topk_aggregate(a, k, axis, keepdims):
    """Final aggregation step of topk.

    Runs ``topk`` once more, then sorts the survivors along ``axis[0]``:
    ascending for negative ``k``, descending otherwise.
    """
    assert keepdims is True
    selected = topk(a, k, axis, keepdims)
    ax = axis[0]
    ordered = np.sort(selected, axis=ax)
    if k < 0:
        return ordered

    # np.sort is ascending; reverse along the reduced axis for the largest-k.
    flip = tuple(slice(None, None, -1) if dim == ax else slice(None) for dim in range(ordered.ndim))
    return ordered[flip]
|
|
230
|
+
|
|
231
|
+
|
|
232
|
+
def argtopk_preprocess(a, idx):
    """Preparatory step for argtopk.

    Pair the data with its original indices as the tuple ``(a, idx)``.
    """
    paired = (a, idx)
    return paired
|
|
238
|
+
|
|
239
|
+
|
|
240
|
+
def argtopk(a_plus_idx, k, axis, keepdims):
    """Chunk and combine function of argtopk

    Extract the indices of the k largest elements from a on the given axis.
    If k is negative, extract the indices of the -k smallest elements instead.
    Note that, unlike in the parent function, the returned elements
    are not sorted internally.

    Parameters
    ----------
    a_plus_idx : tuple or list
        Either a single ``(a, idx)`` pair, or a (possibly nested) list of
        such pairs which are concatenated along ``axis[0]`` first.
    k : int
        Number of elements to keep (largest if positive, smallest if
        negative).
    axis : sequence of int
        Only ``axis[0]`` is used.
    keepdims : bool
        Must be ``True``.

    Returns
    -------
    tuple
        ``(a, idx)`` restricted to the selected elements. The result is
        always a pair, also when nothing needs to be dropped.
    """
    assert keepdims is True
    axis = axis[0]

    if isinstance(a_plus_idx, list):
        a_plus_idx = list(flatten(a_plus_idx))
        a = np.concatenate([ai for ai, _ in a_plus_idx], axis)
        idx = np.concatenate([np.broadcast_to(idxi, ai.shape) for ai, idxi in a_plus_idx], axis)
    else:
        a, idx = a_plus_idx

    if abs(k) >= a.shape[axis]:
        # Return the (concatenated) pair, not the input: for list input the
        # input is a list of pairs, which callers that unpack the result as
        # ``a, idx`` would misread.
        return a, idx

    idx2 = np.argpartition(a, -k, axis=axis)
    k_slice = slice(-k, None) if k > 0 else slice(-k)
    idx2 = idx2[tuple(k_slice if i == axis else slice(None) for i in range(a.ndim))]

    return np.take_along_axis(a, idx2, axis), np.take_along_axis(idx, idx2, axis)
|
|
266
|
+
|
|
267
|
+
|
|
268
|
+
def argtopk_aggregate(a_plus_idx, k, axis, keepdims):
    """Final aggregation step of argtopk.

    Runs ``argtopk`` once more, orders the surviving indices by their data
    values along ``axis[0]`` (ascending for negative ``k``, descending
    otherwise) and returns only the indices.
    """
    assert keepdims is True

    # A single-element container is unwrapped before the final selection.
    if len(a_plus_idx) <= 1:
        a_plus_idx = a_plus_idx[0]

    values, positions = argtopk(a_plus_idx, k, axis, keepdims)
    ax = axis[0]
    order = np.argsort(values, axis=ax)
    positions = np.take_along_axis(positions, order, ax)
    if k < 0:
        return positions

    flip = tuple(slice(None, None, -1) if dim == ax else slice(None) for dim in range(positions.ndim))
    return positions[flip]
|
|
283
|
+
|
|
284
|
+
|
|
285
|
+
def getitem(obj, index):
    """Return ``obj[index]``, copying small views of large arrays.

    When the selection does not own its data and the source holds at least
    twice as many elements, the selection is copied. This avoids excess
    memory usage when extracting a small portion from a large array.
    For more information, see
    https://numpy.org/doc/stable/reference/arrays.indexing.html#basic-slicing-and-indexing.

    Parameters
    ----------
    obj: ndarray, string, tuple, list
        Object to get item from.
    index: int, list[int], slice()
        Desired selection to extract from obj.

    Returns
    -------
    Selection obj[index]

    Raises
    ------
    ValueError
        If indexing raises ``IndexError``.
    """
    try:
        selection = obj[index]
    except IndexError as e:
        raise ValueError("Array chunk size or shape is unknown. Possible solution with x.compute_chunk_sizes()") from e

    # Objects without ``flags`` / ``size`` / ``copy`` are returned as they are.
    try:
        shares_memory = not selection.flags.owndata
        if shares_memory and obj.size >= 2 * selection.size:
            selection = selection.copy()
    except AttributeError:
        pass

    return selection
|
|
318
|
+
|
|
319
|
+
|
|
320
|
+
def arange(start, stop, step, length, dtype, like=None):
    """Arange wrapper for chunk generation.

    Builds ``np.arange(start, stop, step, dtype=dtype)`` (using ``like=``
    when given and accepted) and drops the last element if the result is
    longer than ``length``.
    """
    res = None
    if like is not None:
        try:
            res = np.arange(start, stop, step, dtype=dtype, like=like)
        except TypeError:
            # ``like=`` not accepted; fall back to the plain call below.
            res = None
    if res is None:
        res = np.arange(start, stop, step, dtype=dtype)

    if len(res) > length:
        return res[:-1]
    return res
|
|
333
|
+
|
|
334
|
+
|
|
335
|
+
def linspace(start, stop, num, endpoint=True, dtype=None):
    """Linspace wrapper for chunk generation; forwards to ``np.linspace``."""
    options = {"endpoint": endpoint, "dtype": dtype}
    return np.linspace(start, stop, num, **options)
|
|
338
|
+
|
|
339
|
+
|
|
340
|
+
def slice_with_int_dask_array(x, idx, offset, x_size, axis):
    """Chunk function of `slice_with_int_dask_array_on_axis`.
    Slice one chunk of x by one chunk of idx.

    Parameters
    ----------
    x: ndarray, any dtype, any shape
        i-th chunk of x
    idx: ndarray, ndim=1, dtype=any integer
        j-th chunk of idx (cartesian product with the chunks of x)
    offset: ndarray, shape=(1, ), dtype=int64
        Index of the first element along axis of the current chunk of x
    x_size: int
        Total size of the x da.Array along axis
    axis: int
        normalized axis to take elements from (0 <= axis < x.ndim)

    Returns
    -------
    x sliced along axis, using only the elements of idx that fall inside the
    current chunk.
    """
    from dask_array._utils import meta_from_array

    # asarray_safe functionality - convert to appropriate array type
    # Best effort only: AttributeError/TypeError from the namespace lookup
    # or conversion are ignored and idx is left as it was.
    if hasattr(x, "__array_namespace__"):
        try:
            xp = x.__array_namespace__()
            idx = xp.asarray(idx)
        except (AttributeError, TypeError):
            pass

    # Runs unconditionally, i.e. also after the namespace conversion above.
    idx = np.asarray(idx)
    if not np.issubdtype(idx.dtype, np.integer):
        # NOTE(review): a non-integer idx is replaced by the meta of x, not
        # converted. Presumably this targets empty index chunks whose dtype
        # defaulted to float -- confirm, since non-empty float indices would
        # be discarded here rather than rejected.
        idx = meta_from_array(x)

    # Needed when idx is unsigned
    idx = idx.astype(np.int64)

    # Normalize negative indices
    idx = np.where(idx < 0, idx + x_size, idx)

    # A chunk of the offset dask Array is a numpy array with shape (1, ).
    # It indicates the index of the first element along axis of the current
    # chunk of x.
    idx = idx - offset

    # Drop elements of idx that do not fall inside the current chunk of x
    idx_filter = (idx >= 0) & (idx < x.shape[axis])
    idx = idx[idx_filter]

    # np.take does not support slice indices
    # return np.take(x, idx, axis)
    # Index with idx on ``axis`` and a full slice on every other axis.
    return x[tuple(idx if i == axis else slice(None) for i in range(x.ndim))]
|
|
394
|
+
|
|
395
|
+
|
|
396
|
+
def slice_with_int_dask_array_aggregate(idx, chunk_outputs, x_chunks, axis):
    """Final aggregation function of `slice_with_int_dask_array_on_axis`.

    Reorders the concatenated per-chunk selections produced by
    `slice_with_int_dask_array` so they follow the order of ``idx``.

    There is no combine function, as a recursive aggregation (e.g. with
    split_every) would not give any benefit.

    Parameters
    ----------
    idx: ndarray, ndim=1, dtype=any integer
        j-th chunk of idx
    chunk_outputs: ndarray
        concatenation along axis of the outputs of `slice_with_int_dask_array`
        for all chunks of x and the j-th chunk of idx
    x_chunks: tuple
        dask chunks of the x da.Array along axis, e.g. ``(3, 3, 2)``
    axis: int
        normalized axis to take elements from (0 <= axis < x.ndim)

    Returns
    -------
    Selection from all chunks of x for the j-th chunk of idx, in the correct
    order
    """
    # Needed when idx is unsigned
    idx = idx.astype(np.int64)

    # Normalize negative indices
    total_size = sum(x_chunks)
    idx = np.where(idx < 0, idx + total_size, idx)

    # Build, one chunk of x at a time, the position inside ``chunk_outputs``
    # of every requested index.
    # FIXME: this could probably be reimplemented with a faster search-based
    #        algorithm
    positions = np.zeros_like(idx)
    lower = 0  # first index of x covered by the current chunk
    emitted = 0  # elements contributed to chunk_outputs by earlier chunks
    for width in x_chunks:
        upper = lower + width
        in_chunk = (idx >= lower) & (idx < upper)
        running = np.cumsum(in_chunk)
        positions += np.where(in_chunk, running - 1 + emitted, 0)
        lower = upper
        if running.size > 0:
            emitted += running[-1]

    # np.take does not support slice indices
    # return np.take(chunk_outputs, positions, axis)
    selector = tuple(positions if dim == axis else slice(None) for dim in range(chunk_outputs.ndim))
    return chunk_outputs[selector]
|