dask-array 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dask_array/__init__.py +228 -0
- dask_array/_backends.py +76 -0
- dask_array/_backends_array.py +99 -0
- dask_array/_blockwise.py +1410 -0
- dask_array/_broadcast.py +272 -0
- dask_array/_chunk.py +445 -0
- dask_array/_chunk_types.py +54 -0
- dask_array/_collection.py +1644 -0
- dask_array/_concatenate.py +331 -0
- dask_array/_core_utils.py +1365 -0
- dask_array/_dispatch.py +141 -0
- dask_array/_einsum.py +277 -0
- dask_array/_expr.py +544 -0
- dask_array/_expr_flow.py +586 -0
- dask_array/_gufunc.py +805 -0
- dask_array/_histogram.py +617 -0
- dask_array/_map_blocks.py +652 -0
- dask_array/_new_collection.py +10 -0
- dask_array/_numpy_compat.py +135 -0
- dask_array/_overlap.py +1159 -0
- dask_array/_rechunk.py +1050 -0
- dask_array/_reshape.py +710 -0
- dask_array/_routines.py +102 -0
- dask_array/_shuffle.py +448 -0
- dask_array/_stack.py +264 -0
- dask_array/_svg.py +291 -0
- dask_array/_templates.py +29 -0
- dask_array/_test_utils.py +257 -0
- dask_array/_ufunc.py +385 -0
- dask_array/_utils.py +349 -0
- dask_array/_visualize.py +223 -0
- dask_array/_xarray.py +337 -0
- dask_array/core/__init__.py +34 -0
- dask_array/core/_blockwise_funcs.py +312 -0
- dask_array/core/_conversion.py +422 -0
- dask_array/core/_from_graph.py +97 -0
- dask_array/creation/__init__.py +71 -0
- dask_array/creation/_arange.py +121 -0
- dask_array/creation/_diag.py +116 -0
- dask_array/creation/_diagonal.py +241 -0
- dask_array/creation/_eye.py +103 -0
- dask_array/creation/_linspace.py +102 -0
- dask_array/creation/_mesh.py +134 -0
- dask_array/creation/_ones_zeros.py +454 -0
- dask_array/creation/_pad.py +270 -0
- dask_array/creation/_repeat.py +55 -0
- dask_array/creation/_tile.py +36 -0
- dask_array/creation/_tri.py +28 -0
- dask_array/creation/_utils.py +296 -0
- dask_array/fft.py +320 -0
- dask_array/io/__init__.py +39 -0
- dask_array/io/_base.py +10 -0
- dask_array/io/_from_array.py +257 -0
- dask_array/io/_from_delayed.py +95 -0
- dask_array/io/_from_graph.py +54 -0
- dask_array/io/_from_npy_stack.py +67 -0
- dask_array/io/_store.py +336 -0
- dask_array/io/_tiledb.py +159 -0
- dask_array/io/_to_npy_stack.py +65 -0
- dask_array/io/_zarr.py +449 -0
- dask_array/linalg/__init__.py +39 -0
- dask_array/linalg/_cholesky.py +234 -0
- dask_array/linalg/_lu.py +300 -0
- dask_array/linalg/_norm.py +94 -0
- dask_array/linalg/_qr.py +601 -0
- dask_array/linalg/_solve.py +349 -0
- dask_array/linalg/_svd.py +394 -0
- dask_array/linalg/_tensordot.py +334 -0
- dask_array/linalg/_utils.py +74 -0
- dask_array/manipulation/__init__.py +45 -0
- dask_array/manipulation/_expand.py +321 -0
- dask_array/manipulation/_flip.py +92 -0
- dask_array/manipulation/_roll.py +78 -0
- dask_array/manipulation/_transpose.py +309 -0
- dask_array/random/__init__.py +125 -0
- dask_array/random/_choice.py +181 -0
- dask_array/random/_expr.py +256 -0
- dask_array/random/_generator.py +441 -0
- dask_array/random/_random_state.py +259 -0
- dask_array/random/_utils.py +84 -0
- dask_array/reductions/__init__.py +84 -0
- dask_array/reductions/_arg_reduction.py +130 -0
- dask_array/reductions/_common.py +1082 -0
- dask_array/reductions/_cumulative.py +522 -0
- dask_array/reductions/_percentile.py +261 -0
- dask_array/reductions/_reduction.py +725 -0
- dask_array/reductions/_trace.py +56 -0
- dask_array/routines/__init__.py +133 -0
- dask_array/routines/_apply.py +84 -0
- dask_array/routines/_bincount.py +112 -0
- dask_array/routines/_broadcast.py +111 -0
- dask_array/routines/_coarsen.py +115 -0
- dask_array/routines/_diff.py +79 -0
- dask_array/routines/_gradient.py +158 -0
- dask_array/routines/_indexing.py +65 -0
- dask_array/routines/_insert_delete.py +132 -0
- dask_array/routines/_misc.py +122 -0
- dask_array/routines/_nonzero.py +72 -0
- dask_array/routines/_search.py +123 -0
- dask_array/routines/_select.py +113 -0
- dask_array/routines/_statistics.py +171 -0
- dask_array/routines/_topk.py +82 -0
- dask_array/routines/_triangular.py +74 -0
- dask_array/routines/_unique.py +232 -0
- dask_array/routines/_where.py +62 -0
- dask_array/slicing/__init__.py +67 -0
- dask_array/slicing/_basic.py +550 -0
- dask_array/slicing/_blocks.py +138 -0
- dask_array/slicing/_bool_index.py +145 -0
- dask_array/slicing/_setitem.py +329 -0
- dask_array/slicing/_squeeze.py +101 -0
- dask_array/slicing/_utils.py +1133 -0
- dask_array/slicing/_vindex.py +282 -0
- dask_array/stacking/__init__.py +15 -0
- dask_array/stacking/_block.py +83 -0
- dask_array/stacking/_simple.py +58 -0
- dask_array/templates/array.html.j2 +48 -0
- dask_array/tests/__init__.py +0 -0
- dask_array/tests/conftest.py +22 -0
- dask_array/tests/test_api.py +40 -0
- dask_array/tests/test_binary_op_chunks.py +107 -0
- dask_array/tests/test_coarse_slice_through_blockwise.py +362 -0
- dask_array/tests/test_collection.py +799 -0
- dask_array/tests/test_creation.py +1102 -0
- dask_array/tests/test_expr_flow.py +143 -0
- dask_array/tests/test_linalg.py +1130 -0
- dask_array/tests/test_map_blocks_multi_output.py +104 -0
- dask_array/tests/test_rechunk_pushdown.py +214 -0
- dask_array/tests/test_reductions.py +1091 -0
- dask_array/tests/test_routines.py +2853 -0
- dask_array/tests/test_shuffle_chunks.py +67 -0
- dask_array/tests/test_slice_pushdown.py +968 -0
- dask_array/tests/test_slice_through_blockwise.py +678 -0
- dask_array/tests/test_slice_through_overlap.py +366 -0
- dask_array/tests/test_slice_through_reshape.py +272 -0
- dask_array/tests/test_slicing.py +839 -0
- dask_array/tests/test_transpose_slice_pushdown.py +208 -0
- dask_array/tests/test_visualize.py +94 -0
- dask_array/tests/test_xarray.py +193 -0
- dask_array-0.1.0.dist-info/METADATA +48 -0
- dask_array-0.1.0.dist-info/RECORD +144 -0
- dask_array-0.1.0.dist-info/WHEEL +4 -0
- dask_array-0.1.0.dist-info/entry_points.txt +2 -0
- dask_array-0.1.0.dist-info/licenses/LICENSE +29 -0
|
@@ -0,0 +1,725 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import builtins
|
|
4
|
+
import math
|
|
5
|
+
import warnings
|
|
6
|
+
from functools import partial
|
|
7
|
+
from itertools import product
|
|
8
|
+
from numbers import Integral
|
|
9
|
+
|
|
10
|
+
import numpy as np
|
|
11
|
+
from tlz import compose, get, partition_all
|
|
12
|
+
|
|
13
|
+
from dask import config
|
|
14
|
+
from dask_array._new_collection import new_collection
|
|
15
|
+
from dask_array._expr import ArrayExpr
|
|
16
|
+
from dask_array._utils import compute_meta
|
|
17
|
+
from dask_array._core_utils import _concatenate2
|
|
18
|
+
from dask_array._numpy_compat import ComplexWarning
|
|
19
|
+
from dask_array._utils import is_arraylike, validate_axis
|
|
20
|
+
from dask.blockwise import lol_tuples
|
|
21
|
+
from dask.tokenize import _tokenize_deterministic
|
|
22
|
+
from dask.utils import cached_property, funcname, getargspec, is_series_like
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class Reduction(ArrayExpr):
    """Logical reduction expression that captures reduction intent.

    This expression represents a reduction operation conceptually,
    without immediately materializing the physical Blockwise + PartialReduce
    cascade. The physical implementation is deferred to ``_lower()``.

    Operands mirror the arguments of :func:`reduction`. ``axis`` is expected to
    be a normalized tuple of axes (``reduction`` normalizes it with
    ``validate_axis`` before constructing this node).
    """

    _parameters = [
        "array",
        "chunk",
        "aggregate",
        "axis",
        "keepdims",
        "dtype",
        "split_every",
        "combine",
        "name",
        "concatenate",
        "output_size",
        "meta",
        "weights",
    ]
    _defaults = {
        "axis": None,
        "keepdims": False,
        "dtype": None,
        "split_every": None,
        "combine": None,
        "name": None,
        "concatenate": True,
        "output_size": 1,
        "meta": None,
        "weights": None,
    }

    def __dask_tokenize__(self):
        # ``name`` and ``meta`` are deliberately left out of the token; ``name``
        # still reaches the key through the prefix chosen in ``_name``.
        if not self._determ_token:
            self._determ_token = _tokenize_deterministic(
                self.chunk,
                self.aggregate,
                self.array,
                self.axis,
                self.keepdims,
                self.operand("dtype"),
                self.split_every,
                self.combine,
                self.concatenate,
                self.output_size,
                self.weights,
            )
        return self._determ_token

    @cached_property
    def _name(self):
        prefix = self.operand("name") or funcname(self.chunk)
        return f"{prefix}-{self.deterministic_token}"

    @cached_property
    def name(self):
        """Return the name of the final lowered expression.

        This ensures that Array.name matches the task keys in the graph.
        """
        return self.lower_completely().name

    @cached_property
    def chunks(self):
        """Output chunks after reduction.

        With ``keepdims`` every reduced axis becomes a single chunk of size
        ``output_size``; otherwise reduced axes are dropped.
        """
        axis = self.axis
        if self.keepdims:
            return tuple((self.output_size,) if i in axis else c for i, c in enumerate(self.array.chunks))
        else:
            return tuple(c for i, c in enumerate(self.array.chunks) if i not in axis)

    @cached_property
    def dtype(self):
        if self.operand("dtype") is not None:
            return np.dtype(self.operand("dtype"))
        return self.array.dtype

    @property
    def _meta(self):
        # Compute a minimal metadata array with correct dtype and ndim.
        # NOTE(review): this is always a plain NumPy array, independent of the
        # input's array type; the lowered PartialReduce cascade derives its own.
        dtype = self.dtype
        ndim = len(self.chunks)
        return np.empty((0,) * ndim, dtype=dtype)

    def _layer(self):
        """Generate the task layer by lowering first.

        Reduction should always be lowered before graph generation,
        but we need to support direct _layer() calls for is_dask_collection().
        """
        return self.lower_completely()._layer()

    def _simplify_up(self, parent, dependents):
        """Allow slice operations to push through Reduction."""
        from dask_array.slicing import SliceSlicesIntegers

        if isinstance(parent, SliceSlicesIntegers):
            return self._accept_slice(parent)
        return None

    def _accept_slice(self, slice_expr):
        """Accept a slice being pushed through this Reduction."""
        reduced_axes = set(self.axis)

        def make_result(sliced_input, input_index):
            # Weights were broadcast to the input's shape, so they take the
            # exact same input-space index.
            sliced_weights = None
            if self.weights is not None:
                sliced_weights = new_collection(self.weights)[input_index].expr

            return Reduction(
                sliced_input.expr,
                self.chunk,
                self.aggregate,
                self.axis,
                self.keepdims,
                self.operand("dtype"),
                self.split_every,
                self.combine,
                self.operand("name"),
                self.concatenate,
                self.output_size,
                self.meta,
                sliced_weights,
            )

        return _accept_slice_impl(slice_expr, self.array, reduced_axes, self.keepdims, make_result)

    def _lower(self):
        """Lower to a Blockwise + PartialReduce cascade.

        ``chunk`` is mapped over every block with ``keepdims=True`` (reduced
        axes shrink to ``output_size`` per block), then
        ``_build_tree_reduce_expr`` stacks the combine/aggregate
        PartialReduce layers on top.
        """
        from dask_array._collection import blockwise

        axis = self.axis
        dtype = self.operand("dtype")
        if dtype is None:
            # ``reduction`` always supplies a dtype; this fallback only applies
            # to Reduction nodes constructed directly without one.
            dtype = float
        name = self.operand("name")
        output_size = self.output_size

        # Forward the output dtype to reduction functions that accept it
        chunk_func = self.chunk
        if "dtype" in getargspec(chunk_func).args:
            chunk_func = partial(chunk_func, dtype=dtype)

        aggregate_func = self.aggregate
        if "dtype" in getargspec(aggregate_func).args:
            aggregate_func = partial(aggregate_func, dtype=dtype)

        # Build args for blockwise; weights share the input's index so their
        # blocks line up one-to-one with the input blocks.
        inds = tuple(range(self.array.ndim))
        args = (self.array, inds)

        if self.weights is not None:
            args += (self.weights, inds)

        # Create Blockwise for per-chunk reduction
        adjust_chunks = {i: output_size for i in axis}
        tmp = blockwise(
            chunk_func,
            inds,
            *args,
            axis=axis,
            keepdims=True,
            token=name,
            dtype=dtype,
            adjust_chunks=adjust_chunks,
        )

        # Compute reduced_meta for PartialReduce
        if self.meta is None and hasattr(self.array, "_meta"):
            meta_args = (chunk_func, self.array.dtype, self.array._meta)
            try:
                reduced_meta = compute_meta(*meta_args, axis=axis, keepdims=True, computing_meta=True)
            except TypeError:
                # ``chunk`` does not accept ``computing_meta``; retry without it.
                # The retry is wrapped separately because a ValueError raised by
                # it would otherwise escape the sibling ``except ValueError``.
                try:
                    reduced_meta = compute_meta(*meta_args, axis=axis, keepdims=True)
                except ValueError:
                    reduced_meta = None
            except ValueError:
                reduced_meta = None
        else:
            reduced_meta = self.meta

        # Build tree reduction with PartialReduce
        result = _build_tree_reduce_expr(
            tmp.expr,
            aggregate_func,
            axis,
            self.keepdims,
            dtype,
            self.split_every,
            self.combine,
            name,
            self.concatenate,
            reduced_meta,
        )

        # Override final chunks for output_size != 1
        if self.keepdims and output_size != 1:
            from dask_array._expr import ChunksOverride

            final_chunks = tuple((output_size,) if i in axis else c for i, c in enumerate(result.chunks))
            result = ChunksOverride(result, final_chunks)

        return result
|
|
230
|
+
|
|
231
|
+
|
|
232
|
+
def reduction(
    x,
    chunk,
    aggregate,
    axis=None,
    keepdims=False,
    dtype=None,
    split_every=None,
    combine=None,
    name=None,
    out=None,
    concatenate=True,
    output_size=1,
    meta=None,
    weights=None,
):
    """General version of reductions

    Parameters
    ----------
    x: Array
        Data being reduced along one or more axes. Non-dask inputs are
        converted with ``asanyarray``; series-like inputs are unwrapped through
        their ``.values`` first.
    chunk: callable(x_chunk, [weights_chunk=None], axis, keepdims)
        First function to be executed when resolving the dask graph.
        This function is applied in parallel to all original chunks of x.
        See below for function parameters.
    combine: callable(x_chunk, axis, keepdims), optional
        Function used for intermediate recursive aggregation (see
        split_every below). If omitted, it defaults to aggregate.
        If the reduction can be performed in fewer than 3 steps, it will not
        be invoked at all.
    aggregate: callable(x_chunk, axis, keepdims)
        Last function to be executed when resolving the dask graph,
        producing the final output. It is always invoked, even when the reduced
        Array counts a single chunk along the reduced axes.
    axis: int or sequence of ints, optional
        Axis or axes to aggregate upon. If omitted, aggregate along all axes.
    keepdims: boolean, optional
        Whether the reduction function should preserve the reduced axes,
        leaving them at size ``output_size``, or remove them.
    dtype: np.dtype
        data type of output. This argument was previously optional, but
        leaving as ``None`` will now raise an exception.
    split_every: int >= 2 or dict(axis: int), optional
        Determines the depth of the recursive aggregation. If set to or more
        than the number of input chunks, the aggregation will be performed in
        two steps, one ``chunk`` function per input chunk and a single
        ``aggregate`` function at the end. If set to less than that, an
        intermediate ``combine`` function will be used, so that any one
        ``combine`` or ``aggregate`` function has no more than ``split_every``
        inputs. The depth of the aggregation graph will be
        :math:`log_{split_every}(input chunks along reduced axes)`. Setting to
        a low value can reduce cache size and network transfers, at the cost of
        more CPU and a larger dask graph.

        Omit to let dask heuristically decide a good default. A default can
        also be set globally with the ``split_every`` key in
        :mod:`dask.config`.
    name: str, optional
        Prefix of the keys of the intermediate and output nodes. If omitted it
        defaults to the function names.
    out: Array, optional
        Another dask array whose contents will be replaced. Omit to create a
        new one. Note that, unlike in numpy, this setting gives no performance
        benefits whatsoever, but can still be useful if one needs to preserve
        the references to a previously existing Array.
    concatenate: bool, optional
        If True (the default), the outputs of the ``chunk``/``combine``
        functions are concatenated into a single np.array before being passed
        to the ``combine``/``aggregate`` functions. If False, the input of
        ``combine`` and ``aggregate`` will be either a list of the raw outputs
        of the previous step or a single output, and the function will have to
        concatenate it itself. It can be useful to set this to False if the
        chunk and/or combine steps do not produce np.arrays.
    output_size: int >= 1, optional
        Size of the output of the ``aggregate`` function along the reduced
        axes. Ignored if keepdims is False.
    meta: array-like, optional
        Metadata used in place of the metadata that would otherwise be
        inferred by calling ``chunk`` on the input's ``_meta``.
    weights : array_like, optional
        Weights to be used in the reduction of `x`. Will be
        automatically broadcast to the shape of `x`, and so must have
        a compatible shape. For instance, if `x` has shape ``(3, 4)``
        then acceptable shapes for `weights` are ``(3, 4)``, ``(4,)``,
        ``(3, 1)``, ``(1, 1)``, ``(1)``, and ``()``.

    Returns
    -------
    dask array

    Raises
    ------
    ValueError
        If ``dtype`` is None, or ``weights`` cannot be broadcast to ``x``.

    **Function Parameters**

    x_chunk: numpy.ndarray
        Individual input chunk. For ``chunk`` functions, it is one of the
        original chunks of x. For ``combine`` and ``aggregate`` functions, it's
        the concatenation of the outputs produced by the previous ``chunk`` or
        ``combine`` functions. If concatenate=False, it's a list of the raw
        outputs from the previous functions.
    weights_chunk: numpy.ndarray, optional
        Only applicable to the ``chunk`` function. Weights, with the
        same shape as `x_chunk`, to be applied during the reduction of
        the individual input chunk. If ``weights`` have not been
        provided then the function may omit this parameter. When
        `weights_chunk` is included then it must occur immediately
        after the `x_chunk` parameter, and must also have a default
        value for cases when ``weights`` are not provided.
    axis: tuple
        Normalized list of axes to reduce upon, e.g. ``(0, )``
        Scalar, negative, and None axes have been normalized away.
        Note that some numpy reduction functions cannot reduce along multiple
        axes at once and strictly require an int in input. Such functions have
        to be wrapped to cope.
    keepdims: bool
        Whether the reduction function should preserve the reduced axes or
        remove them.

    """
    from dask_array._collection import Array
    from dask_array.core._conversion import asanyarray

    # Unwrap series-like inputs *before* conversion: once ``x`` is an Array it
    # no longer looks like a series, so a later check could never fire.
    if is_series_like(x):
        x = x.values

    # Convert non-dask arrays to dask arrays
    if not isinstance(x, Array):
        x = asanyarray(x)

    if axis is None:
        axis = tuple(range(x.ndim))
    if isinstance(axis, Integral):
        axis = (axis,)
    axis = validate_axis(axis, x.ndim)

    if dtype is None:
        raise ValueError("Must specify dtype")

    # Handle weights broadcasting
    weights_expr = None
    if weights is not None:
        from dask_array._broadcast import broadcast_to

        wgt = asanyarray(weights)
        try:
            wgt = broadcast_to(wgt, x.shape)
        except ValueError as err:
            raise ValueError(f"Weights with shape {wgt.shape} are not broadcastable to x with shape {x.shape}") from err
        weights_expr = wgt.expr

    # Create the Reduction expression
    result = new_collection(
        Reduction(
            x.expr,
            chunk,
            aggregate,
            axis,
            keepdims,
            dtype,
            split_every,
            combine,
            name,
            concatenate,
            output_size,
            meta,
            weights_expr,
        )
    )

    # Handle out= parameter
    if out is not None:
        from dask_array.core._blockwise_funcs import _handle_out

        return _handle_out(out, result)
    return result
|
|
405
|
+
|
|
406
|
+
|
|
407
|
+
def _tree_reduce(
    x,
    aggregate,
    axis,
    keepdims,
    dtype,
    split_every=None,
    combine=None,
    name=None,
    concatenate=True,
    reduced_meta=None,
):
    """Run only the tree-aggregation stage of a reduction, returning an Array.

    ``x`` is expected to already hold per-block reduced results; this wraps
    the PartialReduce cascade built by ``_build_tree_reduce_expr`` in a
    collection. Lower level, users should use ``reduction`` or
    ``arg_reduction`` directly.
    """
    cascade = _build_tree_reduce_expr(
        x,
        aggregate,
        axis,
        keepdims,
        dtype,
        split_every=split_every,
        combine=combine,
        name=name,
        concatenate=concatenate,
        reduced_meta=reduced_meta,
    )
    return new_collection(cascade)
|
|
428
|
+
|
|
429
|
+
|
|
430
|
+
def _integer_root(value, degree):
    """Return the largest integer ``r`` with ``r ** degree <= value`` (0 if value < 1).

    Computed exactly: a plain ``int(value ** (1 / degree))`` truncates float
    roots such as ``64 ** (1 / 3) == 3.9999999999999996`` down to 3.
    """
    if value < 1:
        return 0
    root = int(round(value ** (1.0 / degree)))
    while root > 0 and root**degree > value:
        root -= 1
    while (root + 1) ** degree <= value:
        root += 1
    return root


def _levels_needed(nblocks, split):
    """Return the smallest ``d >= 0`` such that ``split ** d >= nblocks``.

    Integer replacement for ``ceil(log(nblocks, split))``, whose floating-point
    rounding can overshoot by one level on exact powers. ``split`` must be > 1.
    """
    levels = 0
    reach = 1
    while reach < nblocks:
        reach *= split
        levels += 1
    return levels


def _build_tree_reduce_expr(
    x,
    aggregate,
    axis,
    keepdims,
    dtype,
    split_every,
    combine,
    name,
    concatenate,
    reduced_meta,
):
    """Build tree reduction cascade of PartialReduce expressions.

    Shared implementation used by both ``Reduction._lower`` and ``_tree_reduce``.

    Intermediate levels apply ``combine`` (or ``aggregate`` when ``combine``
    is None) with ``keepdims=True``; the last level applies ``aggregate`` with
    the requested ``keepdims``.

    Raises
    ------
    ValueError
        If ``split_every`` is neither an int nor a dict, or a per-axis value
        is below 1.
    """
    # Normalize split_every into a per-axis fan-in dict
    split_every = split_every or config.get("split_every", 16)
    if isinstance(split_every, dict):
        split_every = {k: split_every.get(k, 2) for k in axis}
    elif isinstance(split_every, Integral):
        # Spread the total fan-in across the reduced axes: the largest n with
        # n ** len(axis) <= split_every, but never below 2.
        n = builtins.max(_integer_root(split_every, len(axis) or 1), 2)
        split_every = dict.fromkeys(axis, n)
    else:
        raise ValueError("split_every must be a int or a dict")

    # Compute tree depth: enough levels for the axis that needs the most
    depth = 1
    for i, nblocks in enumerate(x.numblocks):
        if i not in split_every or split_every[i] == 1:
            # A fan-in of 1 never shrinks the axis, so it does not set the depth
            continue
        if split_every[i] < 1:
            raise ValueError(f"split_every values must be >= 1, got {split_every[i]!r} for axis {i}")
        depth = builtins.max(depth, _levels_needed(nblocks, split_every[i]))

    # Build combine function
    func = partial(combine or aggregate, axis=axis, keepdims=True)
    if concatenate:
        func = compose(func, partial(_concatenate2, axes=sorted(axis)))

    # Build intermediate PartialReduce layers
    for _ in range(depth - 1):
        x = PartialReduce(
            x,
            func,
            split_every,
            True,
            dtype=dtype,
            name=(name or funcname(combine or aggregate)) + "-partial",
            reduced_meta=reduced_meta,
        )

    # Build final aggregate function
    agg_func = partial(aggregate, axis=axis, keepdims=keepdims)
    if concatenate:
        agg_func = compose(agg_func, partial(_concatenate2, axes=sorted(axis)))

    # Final aggregation layer
    return PartialReduce(
        x,
        agg_func,
        split_every,
        keepdims=keepdims,
        dtype=dtype,
        name=(name or funcname(aggregate)) + "-aggregate",
        reduced_meta=reduced_meta,
    )
|
|
494
|
+
|
|
495
|
+
|
|
496
|
+
def _accept_slice_impl(slice_expr, input_array, reduced_axes, keepdims, make_result):
    """Shared implementation for slice pushdown through reductions.

    Rewrites ``reduction(x)[index]`` as ``reduction(x[input_index])``, plus a
    trailing ``[0]``-style extraction when integers were involved. Here
    ``input_index`` applies the entries for non-reduced axes to the input and
    takes every reduced axis whole.

    Parameters
    ----------
    slice_expr : SliceSlicesIntegers
        The slice expression being pushed through
    input_array : ArrayExpr
        The input array to the reduction
    reduced_axes : set
        Set of axes being reduced
    keepdims : bool
        Whether the reduction keeps dimensions
    make_result : callable(sliced_input, input_index) -> expr
        Factory function to create the result expression

    Returns
    -------
    expr or None
        The transformed expression, or None if the slice cannot be pushed
        through. Pushdown is declined in three cases:

        - the index contains None/newaxis;
        - a keepdims index selects part of a reduced axis;
        - the slice would empty a non-reduced axis.
    """
    from dask_array.slicing import SliceSlicesIntegers

    index = slice_expr.index

    # Don't handle None/newaxis
    if any(idx is None for idx in index):
        return None

    input_ndim = input_array.ndim

    if keepdims:
        # With keepdims, output has same ndim as input
        full_index = index + (slice(None),) * (input_ndim - len(index))
        # A reduced axis survives in the output, but every output position
        # along it aggregates the *whole* input extent of that axis. For
        # example, ``x.sum(axis=0, keepdims=True)[0]`` is not
        # ``x[0:1].sum(axis=0, keepdims=True)[0]``. Selecting part of a reduced
        # axis therefore cannot be expressed by slicing the input; only indexes
        # that leave every reduced axis untouched are pushed down.
        touches_reduced_axis = any(
            not (isinstance(full_index[ax], slice) and full_index[ax] == slice(None)) for ax in reduced_axes
        )
        if touches_reduced_axis:
            return None
    else:
        # Without keepdims, reduced axes are removed from output
        out_axis = [i for i in range(input_ndim) if i not in reduced_axes]
        output_ndim = len(out_axis)
        full_index = index + (slice(None),) * (output_ndim - len(index))

    # Convert integers to size-1 slices to preserve dimensions. ``idx + 1 or
    # None`` keeps a negative index such as -1 from becoming the empty
    # ``slice(-1, 0)``.
    slice_index = tuple(slice(idx, idx + 1 or None) if isinstance(idx, Integral) else idx for idx in full_index)
    has_integers = any(isinstance(idx, Integral) for idx in full_index)

    # Build input index mapping output axes to input axes
    if keepdims:
        # Reduced axes are whole here, so the index maps 1:1 onto the input
        input_index = slice_index
    else:
        input_index = []
        out_pos = 0
        for in_ax in range(input_ndim):
            if in_ax in reduced_axes:
                input_index.append(slice(None))
            else:
                input_index.append(slice_index[out_pos])
                out_pos += 1
        input_index = tuple(input_index)

    # Apply the slice to the input
    sliced_input = new_collection(input_array)[input_index]

    # Don't push slice through if it would create empty arrays on non-reduced axes
    for ax in range(input_ndim):
        if ax not in reduced_axes and sliced_input.shape[ax] == 0:
            return None

    result = make_result(sliced_input, input_index)

    # If we converted integers to slices, extract with [0] to restore dimensions
    if has_integers:
        extract_index = tuple(0 if isinstance(idx, Integral) else slice(None) for idx in full_index)
        return SliceSlicesIntegers(result, extract_index, slice_expr.allow_getitem_optimization)

    return result
|
|
570
|
+
|
|
571
|
+
|
|
572
|
+
class PartialReduce(ArrayExpr):
    """One level of a tree reduction.

    Along each axis listed in ``split_every``, consecutive groups of up to
    ``split_every[axis]`` input blocks are gathered (as a nested list of block
    keys built by ``lol_tuples``) and passed to ``func``, producing one output
    block per group. Axes not in ``split_every`` keep their blocks unchanged.
    With ``keepdims=False`` the reduced axes are dropped from the output.
    """

    _parameters = [
        "array",
        "func",
        "split_every",
        "keepdims",
        "dtype",
        "name",
        "reduced_meta",
    ]
    _defaults = {
        "keepdims": False,
        "dtype": None,
        "name": None,
        "reduced_meta": None,
    }

    def __dask_tokenize__(self):
        if not self._determ_token:
            # TODO: Is there an actual need to overwrite this?
            # NOTE(review): ``name`` and ``reduced_meta`` are not part of the
            # token; ``name`` still reaches the key through ``_name``'s prefix.
            self._determ_token = _tokenize_deterministic(
                self.func, self.array, self.split_every, self.keepdims, self.dtype
            )
        return self._determ_token

    @cached_property
    def _name(self):
        return (self.operand("name") or funcname(self.func)) + "-" + self.deterministic_token

    @cached_property
    def dtype(self):
        # Use the explicitly passed dtype parameter instead of inferring from meta
        if self.operand("dtype") is not None:
            return np.dtype(self.operand("dtype"))
        return super().dtype

    @cached_property
    def chunks(self):
        # Along each reduced axis, every group of ``split_every[i]`` input
        # chunks becomes one output chunk of size 1; other axes are unchanged.
        chunks = [
            (tuple(1 for p in partition_all(self.split_every[i], c)) if i in self.split_every else c)
            for (i, c) in enumerate(self.array.chunks)
        ]

        if not self.keepdims:
            # Drop the reduced axes from the output chunks
            out_axis = [i for i in range(self.array.ndim) if i not in self.split_every]
            getter = lambda k: get(out_axis, k)
            chunks = list(getter(chunks))
        return tuple(chunks)

    def _layer(self):
        x = self.array
        # Per input axis: groups of block indices feeding one output block
        # (groups of size 1 on axes that are not reduced).
        parts = [list(partition_all(self.split_every.get(i, 1), range(n))) for (i, n) in enumerate(x.numblocks)]
        keys = product(*map(range, map(len, parts)))
        if not self.keepdims:
            # Output keys omit the reduced axes
            out_axis = [i for i in range(x.ndim) if i not in self.split_every]
            getter = lambda k: get(out_axis, k)
            keys = map(getter, keys)
        dsk = {}
        for k, p in zip(keys, product(*parts)):
            # ``free``: fixed block index on non-reduced axes;
            # ``dummy``: the group of block indices to gather on reduced axes.
            free = {i: j[0] for (i, j) in enumerate(p) if len(j) == 1 and i not in self.split_every}
            dummy = dict(i for i in enumerate(p) if i[0] in self.split_every)
            g = lol_tuples((x.name,), range(x.ndim), free, dummy)
            dsk[(self._name,) + k] = (self.func, g)

        return dsk

    @property
    def _meta(self):
        # Start from the input meta; replaced below when reduced_meta is known.
        meta = self.array._meta
        original_dtype = getattr(self.reduced_meta, "dtype", None) or getattr(meta, "dtype", None)

        if self.reduced_meta is not None:
            try:
                meta = self.func(self.reduced_meta, computing_meta=True)
            except TypeError:
                # No computing_meta kwarg, try without it
                try:
                    meta = self.func(self.reduced_meta)
                except ValueError as e:
                    # NOTE(review): other ValueErrors are swallowed here and
                    # ``meta`` stays as the input array's ``_meta``.
                    if "zero-size array to reduction operation" in str(e):
                        meta = self.reduced_meta
                except IndexError:
                    meta = self.reduced_meta
            except (ValueError, IndexError):
                # Can't compute on empty array (ufunc, argtopk, etc.)
                meta = self.reduced_meta

        # Ensure meta is array-like (func can return Python scalars for object dtype)
        if not is_arraylike(meta) and meta is not None:
            meta = np.array(meta, dtype=original_dtype or object)

        # Reshape meta to match output dimensions
        if is_arraylike(meta) and meta.ndim != len(self.chunks):
            if len(self.chunks) == 0:
                # 0D output - reduce to scalar
                # NOTE(review): ``sum()`` may change the dtype (NumPy upcasts
                # small ints); only corrected below when an explicit dtype is set.
                try:
                    meta = meta.sum()
                    if not hasattr(meta, "dtype"):
                        meta = np.array(meta, dtype=original_dtype)
                except TypeError:
                    # dtype doesn't support sum (e.g., datetime64)
                    meta = np.empty((), dtype=meta.dtype)
            else:
                target_shape = (0,) * len(self.chunks)
                # Use np.prod(shape) for array-likes that don't expose .size
                meta_size = getattr(meta, "size", None)
                if meta_size is None:
                    meta_size = np.prod(meta.shape)
                if meta_size != 0:
                    # Can't reshape non-empty array to empty shape (e.g., scalar)
                    meta = np.empty(target_shape, dtype=meta.dtype)
                else:
                    meta = meta.reshape(target_shape)

        # Ensure meta has the correct dtype if dtype is explicitly specified
        if self.operand("dtype") is not None and hasattr(meta, "dtype"):
            target_dtype = np.dtype(self.operand("dtype"))
            if meta.dtype != target_dtype:
                with warnings.catch_warnings():
                    # Suppress ComplexWarning when casting complex to real (e.g., var)
                    warnings.filterwarnings("ignore", category=ComplexWarning)
                    meta = meta.astype(target_dtype)

        # Convert MaskedConstant (np.ma.masked) to a proper MaskedArray
        # since the singleton cannot be tokenized
        if isinstance(meta, np.ma.core.MaskedConstant):
            meta = np.ma.array(meta, ndmin=0)

        return meta

    def _simplify_up(self, parent, dependents):
        """Allow slice operations to push through PartialReduce."""
        from dask_array.slicing import SliceSlicesIntegers

        if isinstance(parent, SliceSlicesIntegers):
            return self._accept_slice(parent)
        return None

    def _accept_slice(self, slice_expr):
        """Accept a slice being pushed through this PartialReduce."""
        # The axes this level reduces are exactly the keys of split_every
        reduced_axes = set(self.split_every.keys())

        def make_result(sliced_input, input_index):
            # Same PartialReduce parameters, applied to the sliced input
            return PartialReduce(
                sliced_input.expr,
                self.func,
                self.split_every,
                self.keepdims,
                self.operand("dtype"),
                self.operand("name"),
                self.reduced_meta,
            )

        return _accept_slice_impl(slice_expr, self.array, reduced_axes, self.keepdims, make_result)
|