dask-array 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dask_array/__init__.py +228 -0
- dask_array/_backends.py +76 -0
- dask_array/_backends_array.py +99 -0
- dask_array/_blockwise.py +1410 -0
- dask_array/_broadcast.py +272 -0
- dask_array/_chunk.py +445 -0
- dask_array/_chunk_types.py +54 -0
- dask_array/_collection.py +1644 -0
- dask_array/_concatenate.py +331 -0
- dask_array/_core_utils.py +1365 -0
- dask_array/_dispatch.py +141 -0
- dask_array/_einsum.py +277 -0
- dask_array/_expr.py +544 -0
- dask_array/_expr_flow.py +586 -0
- dask_array/_gufunc.py +805 -0
- dask_array/_histogram.py +617 -0
- dask_array/_map_blocks.py +652 -0
- dask_array/_new_collection.py +10 -0
- dask_array/_numpy_compat.py +135 -0
- dask_array/_overlap.py +1159 -0
- dask_array/_rechunk.py +1050 -0
- dask_array/_reshape.py +710 -0
- dask_array/_routines.py +102 -0
- dask_array/_shuffle.py +448 -0
- dask_array/_stack.py +264 -0
- dask_array/_svg.py +291 -0
- dask_array/_templates.py +29 -0
- dask_array/_test_utils.py +257 -0
- dask_array/_ufunc.py +385 -0
- dask_array/_utils.py +349 -0
- dask_array/_visualize.py +223 -0
- dask_array/_xarray.py +337 -0
- dask_array/core/__init__.py +34 -0
- dask_array/core/_blockwise_funcs.py +312 -0
- dask_array/core/_conversion.py +422 -0
- dask_array/core/_from_graph.py +97 -0
- dask_array/creation/__init__.py +71 -0
- dask_array/creation/_arange.py +121 -0
- dask_array/creation/_diag.py +116 -0
- dask_array/creation/_diagonal.py +241 -0
- dask_array/creation/_eye.py +103 -0
- dask_array/creation/_linspace.py +102 -0
- dask_array/creation/_mesh.py +134 -0
- dask_array/creation/_ones_zeros.py +454 -0
- dask_array/creation/_pad.py +270 -0
- dask_array/creation/_repeat.py +55 -0
- dask_array/creation/_tile.py +36 -0
- dask_array/creation/_tri.py +28 -0
- dask_array/creation/_utils.py +296 -0
- dask_array/fft.py +320 -0
- dask_array/io/__init__.py +39 -0
- dask_array/io/_base.py +10 -0
- dask_array/io/_from_array.py +257 -0
- dask_array/io/_from_delayed.py +95 -0
- dask_array/io/_from_graph.py +54 -0
- dask_array/io/_from_npy_stack.py +67 -0
- dask_array/io/_store.py +336 -0
- dask_array/io/_tiledb.py +159 -0
- dask_array/io/_to_npy_stack.py +65 -0
- dask_array/io/_zarr.py +449 -0
- dask_array/linalg/__init__.py +39 -0
- dask_array/linalg/_cholesky.py +234 -0
- dask_array/linalg/_lu.py +300 -0
- dask_array/linalg/_norm.py +94 -0
- dask_array/linalg/_qr.py +601 -0
- dask_array/linalg/_solve.py +349 -0
- dask_array/linalg/_svd.py +394 -0
- dask_array/linalg/_tensordot.py +334 -0
- dask_array/linalg/_utils.py +74 -0
- dask_array/manipulation/__init__.py +45 -0
- dask_array/manipulation/_expand.py +321 -0
- dask_array/manipulation/_flip.py +92 -0
- dask_array/manipulation/_roll.py +78 -0
- dask_array/manipulation/_transpose.py +309 -0
- dask_array/random/__init__.py +125 -0
- dask_array/random/_choice.py +181 -0
- dask_array/random/_expr.py +256 -0
- dask_array/random/_generator.py +441 -0
- dask_array/random/_random_state.py +259 -0
- dask_array/random/_utils.py +84 -0
- dask_array/reductions/__init__.py +84 -0
- dask_array/reductions/_arg_reduction.py +130 -0
- dask_array/reductions/_common.py +1082 -0
- dask_array/reductions/_cumulative.py +522 -0
- dask_array/reductions/_percentile.py +261 -0
- dask_array/reductions/_reduction.py +725 -0
- dask_array/reductions/_trace.py +56 -0
- dask_array/routines/__init__.py +133 -0
- dask_array/routines/_apply.py +84 -0
- dask_array/routines/_bincount.py +112 -0
- dask_array/routines/_broadcast.py +111 -0
- dask_array/routines/_coarsen.py +115 -0
- dask_array/routines/_diff.py +79 -0
- dask_array/routines/_gradient.py +158 -0
- dask_array/routines/_indexing.py +65 -0
- dask_array/routines/_insert_delete.py +132 -0
- dask_array/routines/_misc.py +122 -0
- dask_array/routines/_nonzero.py +72 -0
- dask_array/routines/_search.py +123 -0
- dask_array/routines/_select.py +113 -0
- dask_array/routines/_statistics.py +171 -0
- dask_array/routines/_topk.py +82 -0
- dask_array/routines/_triangular.py +74 -0
- dask_array/routines/_unique.py +232 -0
- dask_array/routines/_where.py +62 -0
- dask_array/slicing/__init__.py +67 -0
- dask_array/slicing/_basic.py +550 -0
- dask_array/slicing/_blocks.py +138 -0
- dask_array/slicing/_bool_index.py +145 -0
- dask_array/slicing/_setitem.py +329 -0
- dask_array/slicing/_squeeze.py +101 -0
- dask_array/slicing/_utils.py +1133 -0
- dask_array/slicing/_vindex.py +282 -0
- dask_array/stacking/__init__.py +15 -0
- dask_array/stacking/_block.py +83 -0
- dask_array/stacking/_simple.py +58 -0
- dask_array/templates/array.html.j2 +48 -0
- dask_array/tests/__init__.py +0 -0
- dask_array/tests/conftest.py +22 -0
- dask_array/tests/test_api.py +40 -0
- dask_array/tests/test_binary_op_chunks.py +107 -0
- dask_array/tests/test_coarse_slice_through_blockwise.py +362 -0
- dask_array/tests/test_collection.py +799 -0
- dask_array/tests/test_creation.py +1102 -0
- dask_array/tests/test_expr_flow.py +143 -0
- dask_array/tests/test_linalg.py +1130 -0
- dask_array/tests/test_map_blocks_multi_output.py +104 -0
- dask_array/tests/test_rechunk_pushdown.py +214 -0
- dask_array/tests/test_reductions.py +1091 -0
- dask_array/tests/test_routines.py +2853 -0
- dask_array/tests/test_shuffle_chunks.py +67 -0
- dask_array/tests/test_slice_pushdown.py +968 -0
- dask_array/tests/test_slice_through_blockwise.py +678 -0
- dask_array/tests/test_slice_through_overlap.py +366 -0
- dask_array/tests/test_slice_through_reshape.py +272 -0
- dask_array/tests/test_slicing.py +839 -0
- dask_array/tests/test_transpose_slice_pushdown.py +208 -0
- dask_array/tests/test_visualize.py +94 -0
- dask_array/tests/test_xarray.py +193 -0
- dask_array-0.1.0.dist-info/METADATA +48 -0
- dask_array-0.1.0.dist-info/RECORD +144 -0
- dask_array-0.1.0.dist-info/WHEEL +4 -0
- dask_array-0.1.0.dist-info/entry_points.txt +2 -0
- dask_array-0.1.0.dist-info/licenses/LICENSE +29 -0
|
@@ -0,0 +1,1082 @@
|
|
|
1
|
+
"""Common reduction functions using the expression-based reduction framework."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import builtins
|
|
6
|
+
import math
|
|
7
|
+
import warnings
|
|
8
|
+
from functools import partial
|
|
9
|
+
from numbers import Integral, Number
|
|
10
|
+
|
|
11
|
+
import numpy as np
|
|
12
|
+
from dask.utils import deepmap, derived_from
|
|
13
|
+
|
|
14
|
+
from dask_array._core_utils import _concatenate2
|
|
15
|
+
from dask_array._dispatch import divide_lookup, numel_lookup, nannumel_lookup
|
|
16
|
+
from dask_array._utils import array_safe, asarray_safe, meta_from_array
|
|
17
|
+
from dask_array import _chunk as chunk
|
|
18
|
+
from dask_array.reductions._reduction import reduction
|
|
19
|
+
from dask_array.reductions._arg_reduction import arg_reduction
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def divide(a, b, dtype=None):
    """Divide ``a`` by ``b`` with the implementation registered for the operand types.

    The operand with the larger ``__array_priority__`` (a missing attribute
    counts as ``-inf``) selects the ``divide_lookup`` implementation; on a tie
    ``a`` wins because ``max`` keeps the first maximal element.
    """

    def _priority(obj):
        return getattr(obj, "__array_priority__", float("-inf"))

    # ``builtins.max`` because this module defines its own ``max`` reduction.
    dominant = builtins.max(a, b, key=_priority)
    impl = divide_lookup.dispatch(type(dominant))
    return impl(a, b, dtype=dtype)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def numel(x, **kwargs):
    """Return the number of elements of ``x``.

    Delegates to the ``numel_lookup`` dispatcher (presumably keyed on the type
    of ``x``); all keyword arguments are forwarded unchanged.
    """
    count = numel_lookup(x, **kwargs)
    return count
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def nannumel(x, **kwargs):
    """Return the number of non-NaN elements of ``x``.

    Delegates to the ``nannumel_lookup`` dispatcher (presumably keyed on the
    type of ``x``); all keyword arguments are forwarded unchanged.
    """
    count = nannumel_lookup(x, **kwargs)
    return count
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
# Simple reductions
|
|
40
|
+
@derived_from(np)
def sum(a, axis=None, dtype=None, keepdims=False, split_every=None, out=None):
    # Without an explicit dtype, use whatever dtype NumPy's own sum produces
    # for a one-element array of the input dtype.
    if dtype is None:
        probe = np.zeros(1, dtype=a.dtype)
        dtype = getattr(probe.sum(), "dtype", object)
    # Summation is associative, so chunk.sum serves both per block and for
    # the final aggregation.
    return reduction(
        a,
        chunk.sum,
        chunk.sum,
        dtype=dtype,
        axis=axis,
        keepdims=keepdims,
        split_every=split_every,
        out=out,
    )
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
@derived_from(np)
def prod(a, axis=None, dtype=None, keepdims=False, split_every=None, out=None):
    # Default result dtype: what NumPy's prod returns for a one-element
    # array of the input dtype (only probed when no dtype was given).
    dt = (
        dtype
        if dtype is not None
        else getattr(np.ones((1,), dtype=a.dtype).prod(), "dtype", object)
    )
    # Products are associative, so chunk.prod is both the block and the
    # aggregation function.
    return reduction(
        a,
        chunk.prod,
        chunk.prod,
        dtype=dt,
        axis=axis,
        keepdims=keepdims,
        split_every=split_every,
        out=out,
    )
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
def chunk_min(x, axis=None, keepdims=None):
    """Per-block ``np.min`` that tolerates empty blocks.

    ``np.min`` raises on zero-size input. For an empty block this returns an
    empty array with ``x``'s dtype and at least ``x.ndim`` dimensions, built
    with ``array_safe`` (which is given ``x``, presumably so the result
    matches its array type).
    """
    if x.size != 0:
        return np.min(x, axis=axis, keepdims=keepdims)
    return array_safe([], x, ndmin=x.ndim, dtype=x.dtype)
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
def chunk_max(x, axis=None, keepdims=None):
    """Per-block ``np.max`` that tolerates empty blocks.

    ``np.max`` raises on zero-size input. For an empty block this returns an
    empty array with ``x``'s dtype and at least ``x.ndim`` dimensions, built
    with ``array_safe`` (which is given ``x``, presumably so the result
    matches its array type).
    """
    if x.size != 0:
        return np.max(x, axis=axis, keepdims=keepdims)
    return array_safe([], x, ndmin=x.ndim, dtype=x.dtype)
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
@derived_from(np)
def min(a, axis=None, keepdims=False, split_every=None, out=None):
    # chunk_min skips empty blocks, so it is used for the per-block and the
    # intermediate combine steps; chunk.min does the final aggregation.
    # The result keeps the input dtype.
    return reduction(
        a,
        chunk_min,
        chunk.min,
        dtype=a.dtype,
        combine=chunk_min,
        axis=axis,
        keepdims=keepdims,
        split_every=split_every,
        out=out,
    )
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
@derived_from(np)
def max(a, axis=None, keepdims=False, split_every=None, out=None):
    # chunk_max skips empty blocks, so it is used for the per-block and the
    # intermediate combine steps; chunk.max does the final aggregation.
    # The result keeps the input dtype.
    return reduction(
        a,
        chunk_max,
        chunk.max,
        dtype=a.dtype,
        combine=chunk_max,
        axis=axis,
        keepdims=keepdims,
        split_every=split_every,
        out=out,
    )
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
@derived_from(np)
def any(a, axis=None, keepdims=False, split_every=None, out=None):
    # Logical OR is associative: chunk.any is both the block and aggregation
    # function, and the result is always boolean.
    return reduction(
        a,
        chunk.any,
        chunk.any,
        dtype="bool",
        axis=axis,
        keepdims=keepdims,
        split_every=split_every,
        out=out,
    )
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
@derived_from(np)
def all(a, axis=None, keepdims=False, split_every=None, out=None):
    # Logical AND is associative: chunk.all is both the block and aggregation
    # function, and the result is always boolean.
    return reduction(
        a,
        chunk.all,
        chunk.all,
        dtype="bool",
        axis=axis,
        keepdims=keepdims,
        split_every=split_every,
        out=out,
    )
|
|
146
|
+
|
|
147
|
+
|
|
148
|
+
# Nan-aware simple reductions
|
|
149
|
+
@derived_from(np)
def nansum(a, axis=None, dtype=None, keepdims=False, split_every=None, out=None):
    # Default result dtype: what chunk.nansum yields for a one-element array
    # of the input dtype (only probed when no dtype was given).
    dt = (
        dtype
        if dtype is not None
        else getattr(chunk.nansum(np.ones((1,), dtype=a.dtype)), "dtype", object)
    )
    # NaNs are dropped per block by chunk.nansum; the partial sums are then
    # added with a plain chunk.sum.
    return reduction(
        a,
        chunk.nansum,
        chunk.sum,
        dtype=dt,
        axis=axis,
        keepdims=keepdims,
        split_every=split_every,
        out=out,
    )
|
|
165
|
+
|
|
166
|
+
|
|
167
|
+
@derived_from(np)
def nanprod(a, axis=None, dtype=None, keepdims=False, split_every=None, out=None):
    # Default result dtype: probe with the same operation that is reduced
    # (chunk.nanprod on a one-element array of the input dtype), mirroring how
    # nansum probes with chunk.nansum.
    if dtype is not None:
        dt = dtype
    else:
        dt = getattr(chunk.nanprod(np.ones((1,), dtype=a.dtype)), "dtype", object)
    # NaNs are treated as 1 per block by chunk.nanprod; the partial products
    # are then multiplied with a plain chunk.prod.
    return reduction(
        a,
        chunk.nanprod,
        chunk.prod,
        axis=axis,
        keepdims=keepdims,
        dtype=dt,
        split_every=split_every,
        out=out,
    )
|
|
183
|
+
|
|
184
|
+
|
|
185
|
+
def _nanmin_skip(x_chunk, axis, keepdims):
|
|
186
|
+
if x_chunk.size > 0:
|
|
187
|
+
with warnings.catch_warnings():
|
|
188
|
+
warnings.filterwarnings("ignore", "All-NaN slice encountered", RuntimeWarning)
|
|
189
|
+
return np.nanmin(x_chunk, axis=axis, keepdims=keepdims)
|
|
190
|
+
else:
|
|
191
|
+
return asarray_safe(np.array([], dtype=x_chunk.dtype), like=meta_from_array(x_chunk))
|
|
192
|
+
|
|
193
|
+
|
|
194
|
+
def _nanmax_skip(x_chunk, axis, keepdims):
|
|
195
|
+
if x_chunk.size > 0:
|
|
196
|
+
with warnings.catch_warnings():
|
|
197
|
+
warnings.filterwarnings("ignore", "All-NaN slice encountered", RuntimeWarning)
|
|
198
|
+
return np.nanmax(x_chunk, axis=axis, keepdims=keepdims)
|
|
199
|
+
else:
|
|
200
|
+
return asarray_safe(np.array([], dtype=x_chunk.dtype), like=meta_from_array(x_chunk))
|
|
201
|
+
|
|
202
|
+
|
|
203
|
+
@derived_from(np)
def nanmin(a, axis=None, keepdims=False, split_every=None, out=None):
    size = a.size
    # A NaN size means the chunk sizes are unknown, so emptiness can't be checked.
    if np.isnan(size):
        from dask_array._core_utils import unknown_chunk_message

        raise ValueError(f"Arrays chunk sizes are unknown. {unknown_chunk_message}")
    # Match NumPy's error for reducing an empty array.
    if size == 0:
        raise ValueError("zero-size array to reduction operation fmin which has no identity")
    # _nanmin_skip handles empty blocks, so it serves both reduction steps.
    return reduction(
        a,
        _nanmin_skip,
        _nanmin_skip,
        dtype=a.dtype,
        axis=axis,
        keepdims=keepdims,
        split_every=split_every,
        out=out,
    )
|
|
221
|
+
|
|
222
|
+
|
|
223
|
+
@derived_from(np)
def nanmax(a, axis=None, keepdims=False, split_every=None, out=None):
    size = a.size
    # A NaN size means the chunk sizes are unknown, so emptiness can't be checked.
    if np.isnan(size):
        from dask_array._core_utils import unknown_chunk_message

        raise ValueError(f"Arrays chunk sizes are unknown. {unknown_chunk_message}")
    # Match NumPy's error for reducing an empty array.
    if size == 0:
        raise ValueError("zero-size array to reduction operation fmax which has no identity")
    # _nanmax_skip handles empty blocks, so it serves both reduction steps.
    return reduction(
        a,
        _nanmax_skip,
        _nanmax_skip,
        dtype=a.dtype,
        axis=axis,
        keepdims=keepdims,
        split_every=split_every,
        out=out,
    )
|
|
241
|
+
|
|
242
|
+
|
|
243
|
+
# Mean implementation
|
|
244
|
+
def mean_chunk(x, sum=chunk.sum, numel=numel, dtype="f8", computing_meta=False, **kwargs):
    """Per-block step of ``mean``: element count and sum of one block.

    ``sum``/``numel`` are swapped for NaN-aware versions by ``nanmean``. Both
    are called with ``dtype`` plus any extra keyword arguments. During meta
    computation the block is returned unchanged.

    Returns a dict ``{"n": count, "total": sum}``.
    """
    if not computing_meta:
        return {
            "n": numel(x, dtype=dtype, **kwargs),
            "total": sum(x, dtype=dtype, **kwargs),
        }
    return x
|
|
250
|
+
|
|
251
|
+
|
|
252
|
+
def mean_combine(
    pairs,
    sum=chunk.sum,
    numel=numel,
    dtype="f8",
    axis=None,
    computing_meta=False,
    **kwargs,
):
    """Intermediate combine step of ``mean``.

    Concatenates the ``"n"`` and ``"total"`` fields of the partial results
    along ``axis`` and sums each, returning a dict of the same form. A single
    dict is wrapped in a list first. During meta computation ``pairs`` are the
    meta arrays themselves and only their summed concatenation is returned.
    ``sum``, ``numel`` and ``dtype`` are accepted but not used here.
    """
    if not isinstance(pairs, list):
        pairs = [pairs]

    def _summed(parts):
        return _concatenate2(parts, axes=axis).sum(axis=axis, **kwargs)

    if computing_meta:
        return _summed(pairs)

    n = _summed(deepmap(lambda pair: pair["n"], pairs))
    total = _summed(deepmap(lambda pair: pair["total"], pairs))
    return {"n": n, "total": total}
|
|
274
|
+
|
|
275
|
+
|
|
276
|
+
def mean_agg(pairs, dtype="f8", axis=None, computing_meta=False, **kwargs):
    """Final step of ``mean``: summed total divided by summed count.

    Both fields are concatenated along ``axis`` and summed with ``dtype``.
    The division runs with divide/invalid floating-point warnings silenced,
    so empty groups (count 0) produce NaN without a warning. During meta
    computation only the summed counts are returned.
    """
    counts = pairs if computing_meta else deepmap(lambda pair: pair["n"], pairs)
    n = np.sum(_concatenate2(counts, axes=axis), axis=axis, dtype=dtype, **kwargs)
    if computing_meta:
        return n

    stacked_totals = _concatenate2(deepmap(lambda pair: pair["total"], pairs), axes=axis)
    total = stacked_totals.sum(axis=axis, dtype=dtype, **kwargs)

    with np.errstate(divide="ignore", invalid="ignore"):
        return divide(total, n, dtype=dtype)
|
|
289
|
+
|
|
290
|
+
|
|
291
|
+
@derived_from(np)
def mean(a, axis=None, dtype=None, keepdims=False, split_every=None, out=None):
    # Default result dtype: object inputs stay object; otherwise use the dtype
    # NumPy's own mean gives for a one-element array of the input dtype.
    if dtype is None:
        if a.dtype == object:
            dt = object
        else:
            probe = np.zeros(shape=(1,), dtype=a.dtype)
            dt = getattr(np.mean(probe), "dtype", object)
    else:
        dt = dtype
    # Blocks yield {"n", "total"} dicts (hence concatenate=False); they are
    # merged by mean_combine and divided out by mean_agg.
    return reduction(
        a,
        mean_chunk,
        mean_agg,
        combine=mean_combine,
        concatenate=False,
        axis=axis,
        keepdims=keepdims,
        dtype=dt,
        split_every=split_every,
        out=out,
    )
|
|
311
|
+
|
|
312
|
+
|
|
313
|
+
@derived_from(np)
def nanmean(a, axis=None, dtype=None, keepdims=False, split_every=None, out=None):
    if dtype is None:
        probe = np.ones(shape=(1,), dtype=a.dtype)
        dt = getattr(np.mean(probe), "dtype", object)
    else:
        dt = dtype
    # Same pipeline as mean, but with NaN-aware sum and count so NaNs are
    # excluded from both the total and the element count.
    nan_aware = dict(sum=chunk.nansum, numel=nannumel)
    return reduction(
        a,
        partial(mean_chunk, **nan_aware),
        mean_agg,
        combine=partial(mean_combine, **nan_aware),
        concatenate=False,
        axis=axis,
        keepdims=keepdims,
        dtype=dt,
        split_every=split_every,
        out=out,
    )
|
|
331
|
+
|
|
332
|
+
|
|
333
|
+
# Moment/variance/std implementation
|
|
334
|
+
def moment_chunk(
    A,
    order=2,
    sum=chunk.sum,
    numel=numel,
    dtype="f8",
    computing_meta=False,
    implicit_complex_dtype=False,
    **kwargs,
):
    """Per-block step of ``moment``/``var``: count, sum and central power sums.

    Parameters
    ----------
    A : array
        One block of the input.
    order : int
        Highest moment order; power sums for orders ``2..order`` are produced.
    sum, numel : callable
        Summation and element-count functions (``nanvar`` passes NaN-aware
        versions).
    dtype : dtype
        Accumulation dtype for the power sums, and for ``total`` unless
        ``implicit_complex_dtype`` is set.
    computing_meta : bool
        When True, ``A`` is returned unchanged.
    implicit_complex_dtype : bool
        When True, ``total`` is summed without forcing ``dtype`` (callers set
        this for complex input with no explicit dtype).
    **kwargs
        Forwarded to ``sum``/``numel`` (presumably ``axis``/``keepdims`` from
        ``reduction`` — confirm).

    Returns
    -------
    dict
        ``{"total": ..., "n": int64 count, "M": ...}`` where ``M[..., j]`` is
        the sum of ``(A - block_mean) ** (j + 2)``; for complex input the
        deviation is replaced by its absolute value first.
    """
    if computing_meta:
        return A
    n = numel(A, **kwargs)

    n = n.astype(np.int64)
    if implicit_complex_dtype:
        total = sum(A, **kwargs)
    else:
        total = sum(A, dtype=dtype, **kwargs)

    # Empty blocks (or all-NaN blocks when numel counts non-NaN elements)
    # have n == 0; silence the resulting 0/0 warnings.
    with np.errstate(divide="ignore", invalid="ignore"):
        u = total / n
    d = A - u
    if np.issubdtype(A.dtype, np.complexfloating):
        d = np.abs(d)
    # One power sum per order 2..order, stacked on a new trailing axis.
    xs = [sum(d**i, dtype=dtype, **kwargs) for i in range(2, order + 1)]
    M = np.stack(xs, axis=-1)
    return {"total": total, "n": n, "M": M}
|
|
362
|
+
|
|
363
|
+
|
|
364
|
+
def _moment_helper(Ms, ns, inner_term, order, sum, axis, kwargs):
|
|
365
|
+
M = Ms[..., order - 2].sum(axis=axis, **kwargs) + sum(ns * inner_term**order, axis=axis, **kwargs)
|
|
366
|
+
for k in range(1, order - 1):
|
|
367
|
+
coeff = math.factorial(order) / (math.factorial(k) * math.factorial(order - k))
|
|
368
|
+
M += coeff * sum(Ms[..., order - k - 2] * inner_term**k, axis=axis, **kwargs)
|
|
369
|
+
return M
|
|
370
|
+
|
|
371
|
+
|
|
372
|
+
def moment_combine(
    pairs,
    order=2,
    ddof=0,
    dtype="f8",
    sum=np.sum,
    axis=None,
    computing_meta=False,
    **kwargs,
):
    """Intermediate combine step of ``moment``/``var``.

    Merges several ``{"total", "n", "M"}`` partial results along ``axis`` into
    one dict of the same form. Each block's power sums are shifted to the
    combined mean with ``_moment_helper``.

    Parameters
    ----------
    pairs : dict or (possibly nested) list of dict
        Partial results; a single dict is wrapped in a list. During meta
        computation these are the meta arrays themselves.
    order : int
        Highest moment order carried in ``"M"``.
    ddof : int
        Not used by this function.
    dtype : dtype
        Dtype for the mean divisions in the non-complex case.
    sum : callable
        Summation used for the cross terms (``nanvar`` passes ``np.nansum``).
    axis : int or tuple of int
        Axes being combined.
    computing_meta : bool
        When True, only the summed counts are returned.
    **kwargs
        Forwarded to the summations. ``dtype`` is overwritten with ``None``
        and ``keepdims`` forced to True (presumably so the output can be
        combined again — confirm).

    Returns
    -------
    dict
        ``{"total": ..., "n": ..., "M": ...}``, or the summed meta when
        ``computing_meta`` is set.
    """
    if not isinstance(pairs, list):
        pairs = [pairs]

    kwargs["dtype"] = None
    kwargs["keepdims"] = True

    ns = deepmap(lambda pair: pair["n"], pairs) if not computing_meta else pairs
    ns = _concatenate2(ns, axes=axis)
    n = ns.sum(axis=axis, **kwargs)

    if computing_meta:
        return n

    totals = _concatenate2(deepmap(lambda pair: pair["total"], pairs), axes=axis)
    Ms = _concatenate2(deepmap(lambda pair: pair["M"], pairs), axes=axis)

    total = total_sum = None  # placeholder removed below
    total = totals.sum(axis=axis, **kwargs)

    # Combined mean and each block mean's offset from it; blocks with no
    # elements give 0/0 here, hence the silenced warnings.
    with np.errstate(divide="ignore", invalid="ignore"):
        if np.issubdtype(total.dtype, np.complexfloating):
            mu = divide(total, n)
            inner_term = np.abs(divide(totals, ns) - mu)
        else:
            mu = divide(total, n, dtype=dtype)
            inner_term = divide(totals, ns, dtype=dtype) - mu

    # Recombine every order 2..order and restack along the trailing axis.
    xs = [_moment_helper(Ms, ns, inner_term, o, sum, axis, kwargs) for o in range(2, order + 1)]
    M = np.stack(xs, axis=-1)
    return {"total": total, "n": n, "M": M}
|
|
411
|
+
|
|
412
|
+
|
|
413
|
+
def moment_agg(
    pairs,
    order=2,
    ddof=0,
    dtype="f8",
    sum=np.sum,
    axis=None,
    computing_meta=False,
    **kwargs,
):
    """Final aggregation step of ``moment``/``var``.

    Merges the partial ``{"total", "n", "M"}`` results along ``axis`` and
    returns the combined ``order``-th central power sum divided by
    ``n - ddof``.

    Parameters
    ----------
    pairs : dict or (possibly nested) list of dict
        Partial results; a single dict is wrapped in a list. During meta
        computation these are the meta arrays themselves.
    order : int
        Moment order to return.
    ddof : int
        Subtracted from the total element count to form the divisor.
    dtype : dtype
        Output dtype; also forwarded to the final summations.
    sum : callable
        Summation used by ``_moment_helper`` for the cross terms.
    axis : int or tuple of int
        Axes being aggregated.
    computing_meta : bool
        When True, only the summed counts are returned.
    **kwargs
        Forwarded to the summations (presumably ``keepdims`` from
        ``reduction`` — confirm).

    Returns
    -------
    array
        The central moment; NaN where ``n - ddof`` is negative.
    """
    if not isinstance(pairs, list):
        pairs = [pairs]

    kwargs["dtype"] = dtype
    # To properly handle ndarrays, the original dimensions need to be kept for
    # part of the calculation.
    keepdim_kw = kwargs.copy()
    keepdim_kw["keepdims"] = True
    keepdim_kw["dtype"] = None

    ns = deepmap(lambda pair: pair["n"], pairs) if not computing_meta else pairs
    ns = _concatenate2(ns, axes=axis)
    n = ns.sum(axis=axis, **keepdim_kw)

    if computing_meta:
        return n

    totals = _concatenate2(deepmap(lambda pair: pair["total"], pairs), axes=axis)
    Ms = _concatenate2(deepmap(lambda pair: pair["M"], pairs), axes=axis)

    # Combined mean, with reduced dims kept so it broadcasts against the
    # per-block totals and counts.
    mu = divide(totals.sum(axis=axis, **keepdim_kw), n)

    with np.errstate(divide="ignore", invalid="ignore"):
        if np.issubdtype(totals.dtype, np.complexfloating):
            inner_term = np.abs(divide(totals, ns) - mu)
        else:
            inner_term = divide(totals, ns, dtype=dtype) - mu

    # Blocks with no elements produced 0/0 above; zero their offset so their
    # shift terms contribute nothing.
    inner_term = np.where(ns == 0, 0, inner_term)
    M = _moment_helper(Ms, ns, inner_term, order, sum, axis, kwargs)

    denominator = n.sum(axis=axis, **kwargs) - ddof

    # taking care of the edge case with empty or all-nans array with ddof > 0
    if isinstance(denominator, Number):
        if denominator < 0:
            denominator = np.nan
    elif denominator is not np.ma.masked:
        denominator[denominator < 0] = np.nan

    return divide(M, denominator, dtype=dtype)
|
|
463
|
+
|
|
464
|
+
|
|
465
|
+
def moment(a, order, axis=None, dtype=None, keepdims=False, ddof=0, split_every=None, out=None):
    """Calculate the nth centralized moment.

    Parameters
    ----------
    a : Array
        Data over which to compute moment
    order : int
        Order of the moment that is returned, must be >= 0. By definition
        order 0 is 1 and order 1 is 0; both are returned as float64 arrays
        without reading the data.
    axis : int, optional
        Axis along which the central moment is computed. The default is to
        compute the moment of the flattened array.
    dtype : data-type, optional
        Type to use in computing the moment. For arrays of integer type the
        default is float64; for arrays of float types it is the same as the
        array type. Not used for ``order < 2``.
    keepdims : bool, optional
        If this is set to True, the axes which are reduced are left in the
        result as dimensions with size one. With this option, the result
        will broadcast correctly against the original array.
    ddof : int, optional
        "Delta Degrees of Freedom": the divisor used in the calculation is
        N - ddof, where N represents the number of elements. By default
        ddof is zero.
    split_every : optional
        Passed through to the underlying reduction.
    out : optional
        Passed through to the underlying reduction (not used for ``order < 2``).

    Returns
    -------
    moment : Array

    Raises
    ------
    ValueError
        If ``order`` is not an integer or is negative.
    """
    if not isinstance(order, Integral) or order < 0:
        raise ValueError("Order must be an integer >= 0")

    if order < 2:
        from dask_array.creation import ones, zeros

        # Reduce only to learn the output shape and chunks. keepdims is passed
        # through so the trivial result has the same shape as the general path.
        reduced = a.sum(axis=axis, keepdims=keepdims)
        if order == 0:
            # When order equals 0, the result is 1, by definition.
            return ones(reduced.shape, chunks=reduced.chunks, dtype="f8", meta=reduced._meta)
        # By definition the first order about the mean is 0.
        return zeros(reduced.shape, chunks=reduced.chunks, dtype="f8", meta=reduced._meta)

    if dtype is not None:
        dt = dtype
    else:
        dt = getattr(np.var(np.ones(shape=(1,), dtype=a.dtype)), "dtype", object)

    # Complex input without an explicit dtype: let the per-block total keep its
    # natural complex dtype rather than forcing ``dt``.
    implicit_complex_dtype = dtype is None and np.iscomplexobj(a)

    return reduction(
        a,
        partial(moment_chunk, order=order, implicit_complex_dtype=implicit_complex_dtype),
        partial(moment_agg, order=order, ddof=ddof),
        axis=axis,
        keepdims=keepdims,
        dtype=dt,
        split_every=split_every,
        out=out,
        concatenate=False,
        combine=partial(moment_combine, order=order),
    )
|
|
526
|
+
|
|
527
|
+
|
|
528
|
+
@derived_from(np)
def var(a, axis=None, dtype=None, keepdims=False, ddof=0, split_every=None, out=None):
    dt = (
        dtype
        if dtype is not None
        else getattr(np.var(np.ones(shape=(1,), dtype=a.dtype)), "dtype", object)
    )
    # Complex input without an explicit dtype: let the per-block total keep its
    # natural complex dtype rather than forcing ``dt``.
    implicit_complex_dtype = dtype is None and np.iscomplexobj(a)
    # Variance is the second central moment: blocks emit {"total", "n", "M"}
    # dicts (hence concatenate=False), merged by moment_combine/moment_agg.
    chunk_step = partial(moment_chunk, implicit_complex_dtype=implicit_complex_dtype)
    agg_step = partial(moment_agg, ddof=ddof)
    return reduction(
        a,
        chunk_step,
        agg_step,
        combine=moment_combine,
        concatenate=False,
        name="var",
        axis=axis,
        keepdims=keepdims,
        dtype=dt,
        split_every=split_every,
        out=out,
    )
|
|
550
|
+
|
|
551
|
+
|
|
552
|
+
@derived_from(np)
def nanvar(a, axis=None, dtype=None, keepdims=False, ddof=0, split_every=None, out=None):
    dt = (
        dtype
        if dtype is not None
        else getattr(np.var(np.ones(shape=(1,), dtype=a.dtype)), "dtype", object)
    )
    implicit_complex_dtype = dtype is None and np.iscomplexobj(a)
    # Same pipeline as var, but blocks sum and count with NaN-aware helpers
    # and the intermediate combine uses np.nansum.
    chunk_step = partial(
        moment_chunk,
        sum=chunk.nansum,
        numel=nannumel,
        implicit_complex_dtype=implicit_complex_dtype,
    )
    agg_step = partial(moment_agg, sum=np.sum, ddof=ddof)
    combine_step = partial(moment_combine, sum=np.nansum)
    return reduction(
        a,
        chunk_step,
        agg_step,
        combine=combine_step,
        concatenate=False,
        axis=axis,
        keepdims=keepdims,
        dtype=dt,
        split_every=split_every,
        out=out,
    )
|
|
578
|
+
|
|
579
|
+
|
|
580
|
+
def _sqrt(a):
|
|
581
|
+
if isinstance(a, np.ma.masked_array) and not a.shape and a.mask.all():
|
|
582
|
+
return np.ma.masked
|
|
583
|
+
return np.sqrt(a)
|
|
584
|
+
|
|
585
|
+
|
|
586
|
+
def safe_sqrt(a):
    """A version of sqrt that properly handles scalar masked arrays.

    Objects exposing ``_elemwise`` (presumably the lazy Array collection) get
    ``_sqrt`` applied element-wise through it; anything else is passed to
    ``_sqrt`` directly.
    """
    if not hasattr(a, "_elemwise"):
        return _sqrt(a)
    return a._elemwise(_sqrt, a)
|
|
591
|
+
|
|
592
|
+
|
|
593
|
+
@derived_from(np)
def std(a, axis=None, dtype=None, keepdims=False, ddof=0, split_every=None, out=None):
    # Standard deviation = sqrt(variance); safe_sqrt keeps fully masked
    # scalar results masked.
    variance = var(
        a,
        axis=axis,
        dtype=dtype,
        keepdims=keepdims,
        ddof=ddof,
        split_every=split_every,
        out=out,
    )
    result = safe_sqrt(variance)
    # Cast to the requested dtype when the result's dtype differs from it.
    if not (dtype and dtype != result.dtype):
        return result
    return result.astype(dtype)
|
|
609
|
+
|
|
610
|
+
|
|
611
|
+
@derived_from(np)
def nanstd(a, axis=None, dtype=None, keepdims=False, ddof=0, split_every=None, out=None):
    # NaN-ignoring standard deviation = sqrt(nanvar); safe_sqrt keeps fully
    # masked scalar results masked.
    variance = nanvar(
        a,
        axis=axis,
        dtype=dtype,
        keepdims=keepdims,
        ddof=ddof,
        split_every=split_every,
        out=out,
    )
    result = safe_sqrt(variance)
    # Cast to the requested dtype when the result's dtype differs from it.
    if not (dtype and dtype != result.dtype):
        return result
    return result.astype(dtype)
|
|
627
|
+
|
|
628
|
+
|
|
629
|
+
# Arg reductions helpers
|
|
630
|
+
def _arg_combine(data, axis, argfunc, keepdims=False):
    """Merge intermediate results from ``arg_*`` functions

    ``data`` carries two parallel fields: ``"vals"`` (candidate extreme
    values) and ``"arg"`` (their indices, made global by ``arg_chunk``).
    It is either a structured array or a ``dict``. ``argfunc`` chooses the
    winning positions in ``vals``, and the matching entries of both fields
    are gathered.

    Parameters
    ----------
    data : structured array or dict
        Intermediate results with ``"vals"`` and ``"arg"`` fields.
    axis : sequence of int
        Reduced axes. If it covers every dimension, or the data is 1-D, the
        reduction is over the flattened data; otherwise only ``axis[0]`` is
        used.
    argfunc : callable
        Arg-reduction (e.g. argmin/argmax) applied to ``vals``.
    keepdims : bool
        Keep the reduced dimension(s) with length one.

    Returns
    -------
    tuple
        ``(arg, vals)`` taken at the positions chosen by ``argfunc``.
    """
    if isinstance(data, dict):
        # Array type doesn't support structured arrays (e.g., CuPy),
        # therefore `data` is stored in a `dict`.
        assert data["vals"].ndim == data["arg"].ndim
        axis = None if len(axis) == data["vals"].ndim or data["vals"].ndim == 1 else axis[0]
    else:
        axis = None if len(axis) == data.ndim or data.ndim == 1 else axis[0]

    vals = data["vals"]
    arg = data["arg"]
    if axis is None:
        # Flattened reduction: argfunc returns flat position(s) into the
        # raveled fields.
        local_args = argfunc(vals, axis=axis, keepdims=keepdims)
        vals = vals.ravel()[local_args]
        arg = arg.ravel()[local_args]
    else:
        # Single-axis reduction: open-mesh indices over the remaining dims,
        # with the chosen positions inserted at ``axis``, select one element
        # per output location.
        local_args = argfunc(vals, axis=axis)
        inds = list(np.ogrid[tuple(map(slice, local_args.shape))])
        inds.insert(axis, local_args)
        inds = tuple(inds)
        vals = vals[inds]
        arg = arg[inds]
        if keepdims:
            # Fancy indexing dropped ``axis``; restore it with length one.
            vals = np.expand_dims(vals, axis)
            arg = np.expand_dims(arg, axis)
    return arg, vals
|
|
657
|
+
|
|
658
|
+
|
|
659
|
+
def arg_chunk(func, argfunc, x, axis, offset_info):
    """Per-block step of arg-reductions: local extreme value plus its global index.

    Parameters
    ----------
    func : callable
        Value reduction (e.g. ``chunk.max``), called with ``keepdims=True``.
    argfunc : callable
        Matching arg-reduction (e.g. ``chunk.argmax``), called with
        ``keepdims=True``.
    x : array
        One block of the input.
    axis : sequence of int
        Reduced axes. The reduction is flattened if it covers every dimension
        or ``x`` is 1-D; otherwise only ``axis[0]`` is used.
    offset_info : tuple or offset
        Flattened reduction: ``(offset, total_shape)``, i.e. the block's
        starting index in each dimension and the full array shape. Otherwise
        an offset added to the local indices (presumably the block's start
        along the reduced axis — confirm against the caller).

    Returns
    -------
    structured array or dict
        Fields ``"vals"`` and ``"arg"``. A ``dict`` is used when the array type
        rejects structured dtypes.
    """
    arg_axis = None if len(axis) == x.ndim or x.ndim == 1 else axis[0]
    vals = func(x, axis=arg_axis, keepdims=True)
    arg = argfunc(x, axis=arg_axis, keepdims=True)
    if x.ndim > 0:
        if arg_axis is None:
            # Turn the block-local flat index into a flat index into the
            # whole array: unravel in block coordinates, shift by the block
            # offset, then ravel in full-array coordinates.
            offset, total_shape = offset_info
            ind = np.unravel_index(arg.ravel()[0], x.shape)
            total_ind = tuple(o + i for (o, i) in zip(offset, ind))
            arg[:] = np.ravel_multi_index(total_ind, total_shape)
        else:
            arg += offset_info

    # Masked results are filled with NumPy's min/max fill value (chosen by
    # whether "min" appears in argfunc's name), giving a plain array.
    if isinstance(vals, np.ma.masked_array):
        if "min" in argfunc.__name__:
            fill_value = np.ma.minimum_fill_value(vals)
        else:
            fill_value = np.ma.maximum_fill_value(vals)
        vals = np.ma.filled(vals, fill_value)

    try:
        result = np.empty_like(vals, shape=vals.shape, dtype=[("vals", vals.dtype), ("arg", arg.dtype)])
    except TypeError:
        # Array type doesn't support structured arrays (e.g., CuPy)
        result = dict()

    result["vals"] = vals
    result["arg"] = arg
    return result
|
|
688
|
+
|
|
689
|
+
|
|
690
|
+
def arg_combine(argfunc, data, axis=None, **kwargs):
    """Intermediate combine step of arg-reductions.

    Reduces ``data`` with ``_arg_combine`` (keeping dims) and repacks the
    winning ``vals``/``arg`` into a structured array, or into a ``dict`` when
    the array type rejects structured dtypes.
    """
    arg, vals = _arg_combine(data, axis, argfunc, keepdims=True)

    try:
        packed = np.empty_like(
            vals,
            shape=vals.shape,
            dtype=[("vals", vals.dtype), ("arg", arg.dtype)],
        )
    except TypeError:
        # Array type doesn't support structured arrays (e.g., CuPy).
        packed = dict()

    packed["vals"] = vals
    packed["arg"] = arg
    return packed
|
|
702
|
+
|
|
703
|
+
|
|
704
|
+
def arg_agg(argfunc, data, axis=None, keepdims=False, **kwargs):
    """Final step of arg-reductions: return only the winning indices."""
    arg, _vals = _arg_combine(data, axis, argfunc, keepdims=keepdims)
    return arg
|
|
706
|
+
|
|
707
|
+
|
|
708
|
+
def nanarg_agg(argfunc, data, axis=None, keepdims=False, **kwargs):
    """Final step of NaN-aware arg-reductions.

    Returns the winning indices. Raises ``ValueError`` if any selected value
    is still NaN, which happens when a reduced slice was all NaN.
    """
    arg, vals = _arg_combine(data, axis, argfunc, keepdims=keepdims)
    if not np.any(np.isnan(vals)):
        return arg
    raise ValueError("All NaN slice encountered")
|
|
713
|
+
|
|
714
|
+
|
|
715
|
+
def _nanargmin(x, axis, **kwargs):
    """``chunk.nanargmin`` that falls back instead of failing.

    If ``chunk.nanargmin`` raises ``ValueError`` (presumably for all-NaN
    slices, as NumPy's ``nanargmin`` does -- confirm for ``chunk``), the call is
    retried with every NaN replaced by ``+inf``, so an index is still produced.
    """
    try:
        return chunk.nanargmin(x, axis, **kwargs)
    except ValueError:
        # +inf can never be chosen over a real minimum.
        nan_free = np.where(np.isnan(x), np.inf, x)
        return chunk.nanargmin(nan_free, axis, **kwargs)
|
|
720
|
+
|
|
721
|
+
|
|
722
|
+
def _nanargmax(x, axis, **kwargs):
    """``chunk.nanargmax`` that falls back instead of failing.

    If ``chunk.nanargmax`` raises ``ValueError`` (presumably for all-NaN
    slices, as NumPy's ``nanargmax`` does -- confirm for ``chunk``), the call is
    retried with every NaN replaced by ``-inf``, so an index is still produced.
    """
    try:
        return chunk.nanargmax(x, axis, **kwargs)
    except ValueError:
        # -inf can never be chosen over a real maximum.
        nan_free = np.where(np.isnan(x), -np.inf, x)
        return chunk.nanargmax(nan_free, axis, **kwargs)
|
|
727
|
+
|
|
728
|
+
|
|
729
|
+
@derived_from(np)
def argmax(a, axis=None, keepdims=False, split_every=None, out=None):
    # Three stages handed to ``arg_reduction``:
    # - per-block: ``arg_chunk`` built from ``chunk.max`` / ``chunk.argmax``;
    # - combine: ``arg_combine``;
    # - aggregate: ``arg_agg``.
    per_block = partial(arg_chunk, chunk.max, chunk.argmax)
    merge = partial(arg_combine, chunk.argmax)
    finish = partial(arg_agg, chunk.argmax)
    return arg_reduction(
        a,
        per_block,
        merge,
        finish,
        axis=axis,
        keepdims=keepdims,
        split_every=split_every,
        out=out,
    )
|
|
741
|
+
|
|
742
|
+
|
|
743
|
+
@derived_from(np)
def argmin(a, axis=None, keepdims=False, split_every=None, out=None):
    # Three stages handed to ``arg_reduction``:
    # - per-block: ``arg_chunk`` built from ``chunk.min`` / ``chunk.argmin``;
    # - combine: ``arg_combine``;
    # - aggregate: ``arg_agg``.
    per_block = partial(arg_chunk, chunk.min, chunk.argmin)
    merge = partial(arg_combine, chunk.argmin)
    finish = partial(arg_agg, chunk.argmin)
    return arg_reduction(
        a,
        per_block,
        merge,
        finish,
        axis=axis,
        keepdims=keepdims,
        split_every=split_every,
        out=out,
    )
|
|
755
|
+
|
|
756
|
+
|
|
757
|
+
@derived_from(np)
def nanargmax(a, axis=None, keepdims=False, split_every=None, out=None):
    # NaN-aware variant of ``argmax``. The final stage is ``nanarg_agg``, which
    # raises ``ValueError`` when a NaN survives into the combined values.
    per_block = partial(arg_chunk, chunk.nanmax, _nanargmax)
    merge = partial(arg_combine, _nanargmax)
    finish = partial(nanarg_agg, _nanargmax)
    return arg_reduction(
        a,
        per_block,
        merge,
        finish,
        axis=axis,
        keepdims=keepdims,
        split_every=split_every,
        out=out,
    )
|
|
769
|
+
|
|
770
|
+
|
|
771
|
+
@derived_from(np)
def nanargmin(a, axis=None, keepdims=False, split_every=None, out=None):
    # NaN-aware variant of ``argmin``. The final stage is ``nanarg_agg``, which
    # raises ``ValueError`` when a NaN survives into the combined values.
    per_block = partial(arg_chunk, chunk.nanmin, _nanargmin)
    merge = partial(arg_combine, _nanargmin)
    finish = partial(nanarg_agg, _nanargmin)
    return arg_reduction(
        a,
        per_block,
        merge,
        finish,
        axis=axis,
        keepdims=keepdims,
        split_every=split_every,
        out=out,
    )
|
|
783
|
+
|
|
784
|
+
|
|
785
|
+
# Median and quantile functions
|
|
786
|
+
from collections.abc import Iterable
|
|
787
|
+
from functools import reduce
|
|
788
|
+
from operator import mul
|
|
789
|
+
|
|
790
|
+
from dask_array._core_utils import handle_out
|
|
791
|
+
|
|
792
|
+
try:
|
|
793
|
+
import numbagg
|
|
794
|
+
except ImportError:
|
|
795
|
+
numbagg = None
|
|
796
|
+
|
|
797
|
+
|
|
798
|
+
@derived_from(np)
def median(a, axis=None, keepdims=False, out=None):
    """
    Reduced axes are merged into a single chunk when they span several chunks,
    and ``numpy.median`` is then applied block-by-block over the remaining
    dimensions.
    """
    if axis is None:
        raise NotImplementedError(
            "The da.median function only works along an axis. The full algorithm is difficult to do in parallel"
        )

    # Normalise to a list of non-negative axis numbers.
    axes = axis if isinstance(axis, Iterable) else (axis,)
    axes = [ax if ax >= 0 else ax + a.ndim for ax in axes]

    # Each reduced axis must live in exactly one chunk for the block-wise call.
    if builtins.any(a.numblocks[ax] > 1 for ax in axes):
        a = a.rechunk({dim: (-1 if dim in axes else "auto") for dim in range(a.ndim)})

    if keepdims:
        dropped = None
        out_chunks = [1 if dim in axes else c for dim, c in enumerate(a.chunks)]
    else:
        dropped = axes
        out_chunks = None

    result = a.map_blocks(
        np.median,
        axis=axes,
        keepdims=keepdims,
        drop_axis=dropped,
        chunks=out_chunks,
    )
    return handle_out(out, result)
|
|
828
|
+
|
|
829
|
+
|
|
830
|
+
@derived_from(np)
def nanmedian(a, axis=None, keepdims=False, out=None):
    """
    Reduced axes are merged into a single chunk (rechunking only when they span
    several chunks), and ``numpy.nanmedian`` -- or ``numbagg.nanmedian`` when
    numbagg >= 0.7 is installed, the dtype is integer/float and ``keepdims`` is
    False -- is then applied block-by-block over the remaining dimensions.
    """
    from packaging.version import Version

    if axis is None:
        raise NotImplementedError(
            "The da.nanmedian function only works along an axis or a subset of axes. "
            "The full algorithm is difficult to do in parallel"
        )

    # Normalise to a list of non-negative axis numbers.
    axes = axis if isinstance(axis, Iterable) else (axis,)
    axes = [ax if ax >= 0 else ax + a.ndim for ax in axes]

    # Each reduced axis must live in exactly one chunk for the block-wise call.
    if builtins.any(a.numblocks[ax] > 1 for ax in axes):
        a = a.rechunk({dim: (-1 if dim in axes else "auto") for dim in range(a.ndim)})

    # numbagg is only chosen when no keepdims argument needs forwarding.
    use_numbagg = (
        numbagg is not None
        and Version(numbagg.__version__).release >= (0, 7, 0)
        and a.dtype.kind in "uif"
        and not keepdims
    )
    if use_numbagg:
        func, kwargs = numbagg.nanmedian, {}
    else:
        func, kwargs = np.nanmedian, {"keepdims": keepdims}

    result = a.map_blocks(
        func,
        axis=axes,
        drop_axis=None if keepdims else axes,
        chunks=([1 if dim in axes else c for dim, c in enumerate(a.chunks)] if keepdims else None),
        **kwargs,
    )
    return handle_out(out, result)
|
|
875
|
+
|
|
876
|
+
|
|
877
|
+
def _get_quantile_chunks(a, q, axis, keepdims):
    """Output chunk spec for a block-wise quantile reduction.

    An iterable ``q`` contributes a leading entry ``len(q)`` for the new
    quantile axis. Reduced axes (those in ``axis``) become ``1`` when
    ``keepdims`` is true and are dropped otherwise; the other axes keep
    ``a.chunks``.
    """
    leading = [len(q)] if isinstance(q, Iterable) else []
    if keepdims:
        remaining = [1 if dim in axis else c for dim, c in enumerate(a.chunks)]
    else:
        remaining = [c for dim, c in enumerate(a.chunks) if dim not in axis]
    return leading + remaining
|
|
883
|
+
|
|
884
|
+
|
|
885
|
+
def _span_indexers(a):
    """Index tuples that enumerate every position of ``a`` except the last axis.

    Returns one tuple per axis from 0 up to (but excluding) the last axis. Taken
    together, the i-th entries of the tuples address the i-th "row" of ``a``
    (positions enumerated in C order), so
    ``arr[tuple(indexers) + (last_axis_idx,)]`` gathers one element per row.
    A 1-D input yields a single tuple ``(0, ..., shape[0] - 1)``.
    """
    dims = a.shape
    # Product of the axes strictly between the first and the last.
    stride = reduce(mul, list(dims)[1:-1]) if len(dims) > 2 else 1
    total = stride * dims[0]

    result = [tuple(np.repeat(np.arange(dims[0]), stride))]
    for dim in dims[1:-1]:
        block = np.repeat(np.arange(dim), stride // dim)
        result.append(tuple(np.tile(block, total // stride)))
        stride //= dim
    return result
|
|
895
|
+
|
|
896
|
+
|
|
897
|
+
def _custom_quantile(a, q, axis=None, method="linear", keepdims=False, **kwargs):
    """NaN-aware quantile with a vectorised fast path.

    The fast path sorts the array and linearly interpolates between
    neighbouring order statistics. It is only used when all of these hold:

    - ``method == "linear"``;
    - no ``weights`` are given;
    - the reduction is over exactly the last axis of an array with at least two
      dimensions;
    - that axis has at most 1000 elements.

    Every other call is delegated to ``np.nanquantile`` together with
    ``**kwargs``.

    Parameters
    ----------
    a : ndarray
    q : scalar or iterable of scalars
        Quantile(s) to compute.
    axis : sequence of int
        Axes to reduce (``len(axis)`` is evaluated when ``method`` is linear).
    method : str
        Interpolation method; only ``"linear"`` can take the fast path.
    keepdims : bool
    **kwargs
        Extra ``np.nanquantile`` keywords (e.g. ``weights``); these are only
        honoured on the ``np.nanquantile`` path.

    Returns
    -------
    ndarray
        For an iterable ``q`` the per-quantile results are stacked along a new
        leading axis; for a scalar ``q`` that axis is squeezed away.
    """
    if (
        method != "linear"
        # The fast path has no notion of weights. Ignoring them would silently
        # return an unweighted result, so let NumPy apply (or reject) them.
        or kwargs.get("weights") is not None
        or len(axis) != 1
        or axis[0] != len(a.shape) - 1
        or len(a.shape) == 1
        or a.shape[-1] > 1000
    ):
        return np.nanquantile(a, q, axis=axis, method=method, keepdims=keepdims, **kwargs)

    # np.sort places NaNs at the end, so each row's valid values form its prefix.
    sorted_arr = np.sort(a, axis=-1)
    indexers = _span_indexers(a)
    nr_quantiles = len(indexers[0])

    is_scalar = not isinstance(q, Iterable)
    if is_scalar:
        q = [q]

    # Per row: (row length - 1) minus the number of NaNs, i.e. the position of
    # the last valid value. This is independent of q, so compute it once.
    last_valid = np.ones(nr_quantiles) * (a.shape[-1] - 1) - np.isnan(sorted_arr).sum(axis=-1).reshape(-1)

    quantiles = []
    reshape_shapes = (1,) + tuple(sorted_arr.shape[:-1]) + ((1,) if keepdims else ())
    for single_q in list(q):
        # Fractional position of the quantile within each row's valid prefix.
        i = last_valid * single_q
        lower_value, higher_value = np.floor(i).astype(int), np.ceil(i).astype(int)

        # Neighbouring order statistics around position i.
        lower = sorted_arr[tuple(indexers) + (tuple(lower_value),)]
        higher = sorted_arr[tuple(indexers) + (tuple(higher_value),)]

        # Linear-interpolation weights. When i is integral both neighbours
        # coincide, so the weight on ``higher`` is forced to 1.
        factor_higher = i - lower_value
        factor_higher = np.where(factor_higher == 0.0, 1.0, factor_higher)
        factor_lower = higher_value - i

        quantiles.append((higher * factor_higher + lower * factor_lower).reshape(*reshape_shapes))

    if is_scalar:
        return quantiles[0].squeeze(axis=0)
    return np.concatenate(quantiles, axis=0)
|
|
929
|
+
|
|
930
|
+
|
|
931
|
+
@derived_from(np)
def quantile(
    a,
    q,
    axis=None,
    out=None,
    overwrite_input=False,
    method="linear",
    keepdims=False,
    *,
    weights=None,
    interpolation=None,
):
    """
    This works by automatically chunking the reduced axes to a single chunk if necessary
    and then calling ``numpy.quantile`` function across the remaining dimensions.
    ``overwrite_input`` is accepted for NumPy compatibility but ignored. ``weights``,
    when given, is passed as-is to every block's ``numpy.quantile`` call
    (requires NumPy >= 2.0).
    """
    if interpolation is not None:
        warnings.warn(
            "The `interpolation` argument to quantile was renamed to `method`.",
            FutureWarning,
            stacklevel=2,
        )
        if method != "linear":
            raise TypeError("Cannot pass interpolation and method keywords!")
        method = interpolation

    if axis is None:
        # A full reduction is only supported when the array is a single block.
        if builtins.any(n_blocks > 1 for n_blocks in a.numblocks):
            raise NotImplementedError(
                "The da.quantile function only works along an axis. The full algorithm is difficult to do in parallel"
            )
        axis = tuple(range(len(a.shape)))

    if not isinstance(axis, Iterable):
        axis = (axis,)

    axis = [ax + a.ndim if ax < 0 else ax for ax in axis]

    # rechunk if reduced axes are not contained in a single chunk
    if builtins.any(a.numblocks[ax] > 1 for ax in axis):
        a = a.rechunk({ax: -1 if ax in axis else "auto" for ax in range(a.ndim)})

    # Forward ``weights`` only when given. NumPy < 2.0 has no ``weights``
    # keyword, so always passing ``weights=None`` would break every block there.
    kwargs = {}
    if weights is not None:
        kwargs["weights"] = weights

    result = a.map_blocks(
        np.quantile,
        q=q,
        method=method,
        axis=axis,
        keepdims=keepdims,
        drop_axis=axis if not keepdims else None,
        new_axis=0 if isinstance(q, Iterable) else None,
        chunks=_get_quantile_chunks(a, q, axis, keepdims),
        # Probe with the same ``method`` so the declared dtype matches what the
        # blocks actually return (e.g. ``method="lower"`` on integer input).
        dtype=np.quantile(np.array([0], dtype=a.dtype), q, method=method).dtype,
        **kwargs,
    )

    result = handle_out(out, result)
    return result
|
|
1001
|
+
|
|
1002
|
+
|
|
1003
|
+
@derived_from(np)
def nanquantile(
    a,
    q,
    axis=None,
    out=None,
    overwrite_input=False,
    method="linear",
    keepdims=False,
    *,
    weights=None,
    interpolation=None,
):
    """
    This works by automatically chunking the reduced axes to a single chunk
    and then computing a NaN-aware quantile across the remaining dimensions,
    using ``numbagg.nanquantile`` when available (numbagg >= 0.8, integer/float
    dtype, linear method, no weights, no keepdims) and otherwise a helper that
    defers to ``numpy.nanquantile`` outside its fast path.
    ``overwrite_input`` is accepted for NumPy compatibility but ignored.
    """
    from packaging.version import Version

    if interpolation is not None:
        warnings.warn(
            "The `interpolation` argument to nanquantile was renamed to `method`.",
            FutureWarning,
            stacklevel=2,
        )
        if method != "linear":
            raise TypeError("Cannot pass interpolation and method keywords!")
        method = interpolation

    if axis is None:
        # A full reduction is only supported when the array is a single block.
        if builtins.any(n_blocks > 1 for n_blocks in a.numblocks):
            raise NotImplementedError(
                "The da.nanquantile function only works along an axis. "
                "The full algorithm is difficult to do in parallel"
            )
        axis = tuple(range(len(a.shape)))

    if not isinstance(axis, Iterable):
        axis = (axis,)

    axis = [ax + a.ndim if ax < 0 else ax for ax in axis]

    # rechunk if reduced axes are not contained in a single chunk
    if builtins.any(a.numblocks[ax] > 1 for ax in axis):
        a = a.rechunk({ax: -1 if ax in axis else "auto" for ax in range(a.ndim)})

    if (
        numbagg is not None
        and Version(numbagg.__version__).release >= (0, 8, 0)
        and a.dtype.kind in "uif"
        and weights is None
        and method == "linear"
        and not keepdims
    ):
        func = numbagg.nanquantile
        kwargs = {"quantiles": q}
    else:
        func = _custom_quantile
        kwargs = {
            "q": q,
            "method": method,
            "keepdims": keepdims,
        }
        # Only forward ``weights`` when given (NumPy >= 2.0 supports it).
        if weights is not None:
            kwargs["weights"] = weights

    result = a.map_blocks(
        func,
        axis=axis,
        drop_axis=axis if not keepdims else None,
        new_axis=0 if isinstance(q, Iterable) else None,
        chunks=_get_quantile_chunks(a, q, axis, keepdims),
        # Probe with the same ``method`` so the declared dtype matches what the
        # blocks actually return (e.g. ``method="lower"`` on integer input).
        dtype=np.nanquantile(np.array([0], dtype=a.dtype), q, method=method).dtype,
        **kwargs,
    )

    result = handle_out(out, result)
    return result
|