dask-array 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (144)
  1. dask_array/__init__.py +228 -0
  2. dask_array/_backends.py +76 -0
  3. dask_array/_backends_array.py +99 -0
  4. dask_array/_blockwise.py +1410 -0
  5. dask_array/_broadcast.py +272 -0
  6. dask_array/_chunk.py +445 -0
  7. dask_array/_chunk_types.py +54 -0
  8. dask_array/_collection.py +1644 -0
  9. dask_array/_concatenate.py +331 -0
  10. dask_array/_core_utils.py +1365 -0
  11. dask_array/_dispatch.py +141 -0
  12. dask_array/_einsum.py +277 -0
  13. dask_array/_expr.py +544 -0
  14. dask_array/_expr_flow.py +586 -0
  15. dask_array/_gufunc.py +805 -0
  16. dask_array/_histogram.py +617 -0
  17. dask_array/_map_blocks.py +652 -0
  18. dask_array/_new_collection.py +10 -0
  19. dask_array/_numpy_compat.py +135 -0
  20. dask_array/_overlap.py +1159 -0
  21. dask_array/_rechunk.py +1050 -0
  22. dask_array/_reshape.py +710 -0
  23. dask_array/_routines.py +102 -0
  24. dask_array/_shuffle.py +448 -0
  25. dask_array/_stack.py +264 -0
  26. dask_array/_svg.py +291 -0
  27. dask_array/_templates.py +29 -0
  28. dask_array/_test_utils.py +257 -0
  29. dask_array/_ufunc.py +385 -0
  30. dask_array/_utils.py +349 -0
  31. dask_array/_visualize.py +223 -0
  32. dask_array/_xarray.py +337 -0
  33. dask_array/core/__init__.py +34 -0
  34. dask_array/core/_blockwise_funcs.py +312 -0
  35. dask_array/core/_conversion.py +422 -0
  36. dask_array/core/_from_graph.py +97 -0
  37. dask_array/creation/__init__.py +71 -0
  38. dask_array/creation/_arange.py +121 -0
  39. dask_array/creation/_diag.py +116 -0
  40. dask_array/creation/_diagonal.py +241 -0
  41. dask_array/creation/_eye.py +103 -0
  42. dask_array/creation/_linspace.py +102 -0
  43. dask_array/creation/_mesh.py +134 -0
  44. dask_array/creation/_ones_zeros.py +454 -0
  45. dask_array/creation/_pad.py +270 -0
  46. dask_array/creation/_repeat.py +55 -0
  47. dask_array/creation/_tile.py +36 -0
  48. dask_array/creation/_tri.py +28 -0
  49. dask_array/creation/_utils.py +296 -0
  50. dask_array/fft.py +320 -0
  51. dask_array/io/__init__.py +39 -0
  52. dask_array/io/_base.py +10 -0
  53. dask_array/io/_from_array.py +257 -0
  54. dask_array/io/_from_delayed.py +95 -0
  55. dask_array/io/_from_graph.py +54 -0
  56. dask_array/io/_from_npy_stack.py +67 -0
  57. dask_array/io/_store.py +336 -0
  58. dask_array/io/_tiledb.py +159 -0
  59. dask_array/io/_to_npy_stack.py +65 -0
  60. dask_array/io/_zarr.py +449 -0
  61. dask_array/linalg/__init__.py +39 -0
  62. dask_array/linalg/_cholesky.py +234 -0
  63. dask_array/linalg/_lu.py +300 -0
  64. dask_array/linalg/_norm.py +94 -0
  65. dask_array/linalg/_qr.py +601 -0
  66. dask_array/linalg/_solve.py +349 -0
  67. dask_array/linalg/_svd.py +394 -0
  68. dask_array/linalg/_tensordot.py +334 -0
  69. dask_array/linalg/_utils.py +74 -0
  70. dask_array/manipulation/__init__.py +45 -0
  71. dask_array/manipulation/_expand.py +321 -0
  72. dask_array/manipulation/_flip.py +92 -0
  73. dask_array/manipulation/_roll.py +78 -0
  74. dask_array/manipulation/_transpose.py +309 -0
  75. dask_array/random/__init__.py +125 -0
  76. dask_array/random/_choice.py +181 -0
  77. dask_array/random/_expr.py +256 -0
  78. dask_array/random/_generator.py +441 -0
  79. dask_array/random/_random_state.py +259 -0
  80. dask_array/random/_utils.py +84 -0
  81. dask_array/reductions/__init__.py +84 -0
  82. dask_array/reductions/_arg_reduction.py +130 -0
  83. dask_array/reductions/_common.py +1082 -0
  84. dask_array/reductions/_cumulative.py +522 -0
  85. dask_array/reductions/_percentile.py +261 -0
  86. dask_array/reductions/_reduction.py +725 -0
  87. dask_array/reductions/_trace.py +56 -0
  88. dask_array/routines/__init__.py +133 -0
  89. dask_array/routines/_apply.py +84 -0
  90. dask_array/routines/_bincount.py +112 -0
  91. dask_array/routines/_broadcast.py +111 -0
  92. dask_array/routines/_coarsen.py +115 -0
  93. dask_array/routines/_diff.py +79 -0
  94. dask_array/routines/_gradient.py +158 -0
  95. dask_array/routines/_indexing.py +65 -0
  96. dask_array/routines/_insert_delete.py +132 -0
  97. dask_array/routines/_misc.py +122 -0
  98. dask_array/routines/_nonzero.py +72 -0
  99. dask_array/routines/_search.py +123 -0
  100. dask_array/routines/_select.py +113 -0
  101. dask_array/routines/_statistics.py +171 -0
  102. dask_array/routines/_topk.py +82 -0
  103. dask_array/routines/_triangular.py +74 -0
  104. dask_array/routines/_unique.py +232 -0
  105. dask_array/routines/_where.py +62 -0
  106. dask_array/slicing/__init__.py +67 -0
  107. dask_array/slicing/_basic.py +550 -0
  108. dask_array/slicing/_blocks.py +138 -0
  109. dask_array/slicing/_bool_index.py +145 -0
  110. dask_array/slicing/_setitem.py +329 -0
  111. dask_array/slicing/_squeeze.py +101 -0
  112. dask_array/slicing/_utils.py +1133 -0
  113. dask_array/slicing/_vindex.py +282 -0
  114. dask_array/stacking/__init__.py +15 -0
  115. dask_array/stacking/_block.py +83 -0
  116. dask_array/stacking/_simple.py +58 -0
  117. dask_array/templates/array.html.j2 +48 -0
  118. dask_array/tests/__init__.py +0 -0
  119. dask_array/tests/conftest.py +22 -0
  120. dask_array/tests/test_api.py +40 -0
  121. dask_array/tests/test_binary_op_chunks.py +107 -0
  122. dask_array/tests/test_coarse_slice_through_blockwise.py +362 -0
  123. dask_array/tests/test_collection.py +799 -0
  124. dask_array/tests/test_creation.py +1102 -0
  125. dask_array/tests/test_expr_flow.py +143 -0
  126. dask_array/tests/test_linalg.py +1130 -0
  127. dask_array/tests/test_map_blocks_multi_output.py +104 -0
  128. dask_array/tests/test_rechunk_pushdown.py +214 -0
  129. dask_array/tests/test_reductions.py +1091 -0
  130. dask_array/tests/test_routines.py +2853 -0
  131. dask_array/tests/test_shuffle_chunks.py +67 -0
  132. dask_array/tests/test_slice_pushdown.py +968 -0
  133. dask_array/tests/test_slice_through_blockwise.py +678 -0
  134. dask_array/tests/test_slice_through_overlap.py +366 -0
  135. dask_array/tests/test_slice_through_reshape.py +272 -0
  136. dask_array/tests/test_slicing.py +839 -0
  137. dask_array/tests/test_transpose_slice_pushdown.py +208 -0
  138. dask_array/tests/test_visualize.py +94 -0
  139. dask_array/tests/test_xarray.py +193 -0
  140. dask_array-0.1.0.dist-info/METADATA +48 -0
  141. dask_array-0.1.0.dist-info/RECORD +144 -0
  142. dask_array-0.1.0.dist-info/WHEEL +4 -0
  143. dask_array-0.1.0.dist-info/entry_points.txt +2 -0
  144. dask_array-0.1.0.dist-info/licenses/LICENSE +29 -0
@@ -0,0 +1,1082 @@
1
+ """Common reduction functions using the expression-based reduction framework."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import builtins
6
+ import math
7
+ import warnings
8
+ from functools import partial
9
+ from numbers import Integral, Number
10
+
11
+ import numpy as np
12
+ from dask.utils import deepmap, derived_from
13
+
14
+ from dask_array._core_utils import _concatenate2
15
+ from dask_array._dispatch import divide_lookup, numel_lookup, nannumel_lookup
16
+ from dask_array._utils import array_safe, asarray_safe, meta_from_array
17
+ from dask_array import _chunk as chunk
18
+ from dask_array.reductions._reduction import reduction
19
+ from dask_array.reductions._arg_reduction import arg_reduction
20
+
21
+
22
def divide(a, b, dtype=None):
    """Divide ``a`` by ``b`` with the backend-specific implementation.

    The implementation is looked up in ``divide_lookup`` for the type of
    whichever operand has the larger ``__array_priority__``. An operand
    without that attribute ranks lowest, and ties go to ``a`` because
    ``builtins.max`` returns the first maximal item.
    """

    def _priority(operand):
        return getattr(operand, "__array_priority__", float("-inf"))

    winner = builtins.max(a, b, key=_priority)
    backend_divide = divide_lookup.dispatch(type(winner))
    return backend_divide(a, b, dtype=dtype)
27
+
28
+
29
def numel(x, **kwargs):
    """Return the number of elements in ``x``.

    Dispatches on the type of ``x`` through ``numel_lookup``. Keyword
    arguments are forwarded unchanged to the registered implementation.
    """
    count = numel_lookup(x, **kwargs)
    return count
32
+
33
+
34
def nannumel(x, **kwargs):
    """Return the number of non-NaN elements in ``x``.

    Dispatches on the type of ``x`` through ``nannumel_lookup``. Keyword
    arguments are forwarded unchanged to the registered implementation.
    """
    count = nannumel_lookup(x, **kwargs)
    return count
37
+
38
+
39
+ # Simple reductions
40
@derived_from(np)
def sum(a, axis=None, dtype=None, keepdims=False, split_every=None, out=None):
    # Without an explicit dtype, use whatever NumPy yields when summing a
    # one-element array of ``a.dtype``, or ``object`` if the result has no dtype.
    result_dtype = (
        dtype
        if dtype is not None
        else getattr(np.zeros(1, dtype=a.dtype).sum(), "dtype", object)
    )
    # ``chunk.sum`` serves both as the per-block kernel and as the aggregator.
    return reduction(
        a,
        chunk.sum,
        chunk.sum,
        out=out,
        split_every=split_every,
        dtype=result_dtype,
        keepdims=keepdims,
        axis=axis,
    )
54
+
55
+
56
@derived_from(np)
def prod(a, axis=None, dtype=None, keepdims=False, split_every=None, out=None):
    # Default dtype: what NumPy gives for the product of a one-element array
    # of ``a.dtype``, or ``object`` if that result carries no dtype.
    if dtype is None:
        probe = np.ones((1,), dtype=a.dtype).prod()
        result_dtype = getattr(probe, "dtype", object)
    else:
        result_dtype = dtype
    # ``chunk.prod`` is used per block and again to aggregate partial products.
    return reduction(
        a,
        chunk.prod,
        chunk.prod,
        out=out,
        split_every=split_every,
        dtype=result_dtype,
        keepdims=keepdims,
        axis=axis,
    )
72
+
73
+
74
def chunk_min(x, axis=None, keepdims=None):
    """Minimum of a single chunk that tolerates empty input.

    NumPy's ``min`` cannot reduce a size-0 array, so an empty chunk instead
    yields an empty array of ``x.dtype``. It is built via ``array_safe`` with
    ``ndmin=x.ndim``.
    """
    if x.size != 0:
        return np.min(x, axis=axis, keepdims=keepdims)
    return array_safe([], x, ndmin=x.ndim, dtype=x.dtype)
80
+
81
+
82
def chunk_max(x, axis=None, keepdims=None):
    """Maximum of a single chunk that tolerates empty input.

    NumPy's ``max`` cannot reduce a size-0 array, so an empty chunk instead
    yields an empty array of ``x.dtype``. It is built via ``array_safe`` with
    ``ndmin=x.ndim``.
    """
    if x.size != 0:
        return np.max(x, axis=axis, keepdims=keepdims)
    return array_safe([], x, ndmin=x.ndim, dtype=x.dtype)
88
+
89
+
90
@derived_from(np)
def min(a, axis=None, keepdims=False, split_every=None, out=None):
    # ``chunk_min`` reduces each block and also merges intermediates, so empty
    # blocks are tolerated at every level. The last step uses ``chunk.min``.
    # The result keeps the input dtype.
    reduction_kwargs = dict(
        combine=chunk_min,
        axis=axis,
        keepdims=keepdims,
        dtype=a.dtype,
        split_every=split_every,
        out=out,
    )
    return reduction(a, chunk_min, chunk.min, **reduction_kwargs)
103
+
104
+
105
@derived_from(np)
def max(a, axis=None, keepdims=False, split_every=None, out=None):
    # ``chunk_max`` reduces each block and also merges intermediates, so empty
    # blocks are tolerated at every level. The last step uses ``chunk.max``.
    # The result keeps the input dtype.
    reduction_kwargs = dict(
        combine=chunk_max,
        axis=axis,
        keepdims=keepdims,
        dtype=a.dtype,
        split_every=split_every,
        out=out,
    )
    return reduction(a, chunk_max, chunk.max, **reduction_kwargs)
118
+
119
+
120
@derived_from(np)
def any(a, axis=None, keepdims=False, split_every=None, out=None):
    # ``chunk.any`` is used both per block and to aggregate partial results.
    # The output is always boolean.
    kernel = chunk.any
    return reduction(
        a,
        kernel,
        kernel,
        dtype="bool",
        axis=axis,
        keepdims=keepdims,
        split_every=split_every,
        out=out,
    )
132
+
133
+
134
@derived_from(np)
def all(a, axis=None, keepdims=False, split_every=None, out=None):
    # ``chunk.all`` is used both per block and to aggregate partial results.
    # The output is always boolean.
    kernel = chunk.all
    return reduction(
        a,
        kernel,
        kernel,
        dtype="bool",
        axis=axis,
        keepdims=keepdims,
        split_every=split_every,
        out=out,
    )
146
+
147
+
148
+ # Nan-aware simple reductions
149
@derived_from(np)
def nansum(a, axis=None, dtype=None, keepdims=False, split_every=None, out=None):
    # Default dtype is probed from the NaN-aware sum of a one-element array.
    result_dtype = (
        dtype
        if dtype is not None
        else getattr(chunk.nansum(np.ones((1,), dtype=a.dtype)), "dtype", object)
    )
    # NaNs are dropped per block by ``chunk.nansum``. The partial sums are then
    # aggregated with a plain ``chunk.sum``.
    return reduction(
        a,
        chunk.nansum,
        chunk.sum,
        out=out,
        split_every=split_every,
        dtype=result_dtype,
        keepdims=keepdims,
        axis=axis,
    )
165
+
166
+
167
@derived_from(np)
def nanprod(a, axis=None, dtype=None, keepdims=False, split_every=None, out=None):
    # Probe the default dtype with the same NaN-aware *product* kernel that runs
    # on each block, so the advertised dtype matches what the blocks produce.
    if dtype is not None:
        dt = dtype
    else:
        dt = getattr(chunk.nanprod(np.ones((1,), dtype=a.dtype)), "dtype", object)
    # NaNs are dropped per block by ``chunk.nanprod``. The partial products are
    # then aggregated with a plain ``chunk.prod``.
    return reduction(
        a,
        chunk.nanprod,
        chunk.prod,
        axis=axis,
        keepdims=keepdims,
        dtype=dt,
        split_every=split_every,
        out=out,
    )
183
+
184
+
185
def _nanmin_skip(x_chunk, axis, keepdims):
    """NaN-ignoring minimum of one chunk that tolerates empty chunks.

    An empty chunk yields an empty 1-D array of the chunk's dtype, converted
    with ``asarray_safe`` to match the chunk's array type. Otherwise
    ``np.nanmin`` is applied with NumPy's "All-NaN slice encountered"
    RuntimeWarning silenced.
    """
    if x_chunk.size == 0:
        empty = np.array([], dtype=x_chunk.dtype)
        return asarray_safe(empty, like=meta_from_array(x_chunk))
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", "All-NaN slice encountered", RuntimeWarning)
        return np.nanmin(x_chunk, axis=axis, keepdims=keepdims)
192
+
193
+
194
def _nanmax_skip(x_chunk, axis, keepdims):
    """NaN-ignoring maximum of one chunk that tolerates empty chunks.

    An empty chunk yields an empty 1-D array of the chunk's dtype, converted
    with ``asarray_safe`` to match the chunk's array type. Otherwise
    ``np.nanmax`` is applied with NumPy's "All-NaN slice encountered"
    RuntimeWarning silenced.
    """
    if x_chunk.size == 0:
        empty = np.array([], dtype=x_chunk.dtype)
        return asarray_safe(empty, like=meta_from_array(x_chunk))
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", "All-NaN slice encountered", RuntimeWarning)
        return np.nanmax(x_chunk, axis=axis, keepdims=keepdims)
201
+
202
+
203
@derived_from(np)
def nanmin(a, axis=None, keepdims=False, split_every=None, out=None):
    # Fail before building the graph. Unknown chunk sizes make ``a.size``
    # NaN, and a zero-size array has no minimum.
    if np.isnan(a.size):
        from dask_array._core_utils import unknown_chunk_message

        raise ValueError(f"Arrays chunk sizes are unknown. {unknown_chunk_message}")
    if a.size == 0:
        raise ValueError("zero-size array to reduction operation fmin which has no identity")
    # The NaN-skipping kernel serves as both the per-block and the final step.
    kernel = _nanmin_skip
    return reduction(
        a,
        kernel,
        kernel,
        dtype=a.dtype,
        axis=axis,
        keepdims=keepdims,
        split_every=split_every,
        out=out,
    )
221
+
222
+
223
@derived_from(np)
def nanmax(a, axis=None, keepdims=False, split_every=None, out=None):
    # Fail before building the graph. Unknown chunk sizes make ``a.size``
    # NaN, and a zero-size array has no maximum.
    if np.isnan(a.size):
        from dask_array._core_utils import unknown_chunk_message

        raise ValueError(f"Arrays chunk sizes are unknown. {unknown_chunk_message}")
    if a.size == 0:
        raise ValueError("zero-size array to reduction operation fmax which has no identity")
    # The NaN-skipping kernel serves as both the per-block and the final step.
    kernel = _nanmax_skip
    return reduction(
        a,
        kernel,
        kernel,
        dtype=a.dtype,
        axis=axis,
        keepdims=keepdims,
        split_every=split_every,
        out=out,
    )
241
+
242
+
243
+ # Mean implementation
244
def mean_chunk(x, sum=chunk.sum, numel=numel, dtype="f8", computing_meta=False, **kwargs):
    """Per-block step of ``mean``: return the element count and the sum.

    Parameters
    ----------
    x : array
        One block of the input.
    sum, numel : callable
        Summation and counting kernels. The NaN-aware variants swap in
        ``chunk.nansum`` / ``nannumel``.
    dtype : dtype-like
        Accumulation dtype passed to both kernels.
    computing_meta : bool
        When True, ``x`` is returned unchanged.

    Returns
    -------
    dict
        ``{"n": count, "total": sum}`` (or ``x`` when computing meta).
    """
    if not computing_meta:
        count = numel(x, dtype=dtype, **kwargs)
        subtotal = sum(x, dtype=dtype, **kwargs)
        return {"n": count, "total": subtotal}
    return x
250
+
251
+
252
def mean_combine(
    pairs,
    sum=chunk.sum,
    numel=numel,
    dtype="f8",
    axis=None,
    computing_meta=False,
    **kwargs,
):
    """Merge several ``{"n", "total"}`` partial results for ``mean``.

    Counts and totals are concatenated along ``axis`` and summed with the
    arrays' own ``.sum`` method. The ``sum``, ``numel`` and ``dtype``
    parameters are accepted (callers bind them with ``partial``) but are not
    used here. When ``computing_meta`` is True, only the summed counts are
    returned.
    """
    if not isinstance(pairs, list):
        pairs = [pairs]

    counts = pairs if computing_meta else deepmap(lambda pair: pair["n"], pairs)
    n = _concatenate2(counts, axes=axis).sum(axis=axis, **kwargs)
    if computing_meta:
        return n

    subtotals = deepmap(lambda pair: pair["total"], pairs)
    total = _concatenate2(subtotals, axes=axis).sum(axis=axis, **kwargs)
    return {"n": n, "total": total}
274
+
275
+
276
def mean_agg(pairs, dtype="f8", axis=None, computing_meta=False, **kwargs):
    """Final step of ``mean``: total count and total sum, then divide.

    Division by zero and invalid-operation floating-point errors are
    suppressed for the final division. When ``computing_meta`` is True, only
    the summed counts are returned.
    """
    counts = pairs if computing_meta else deepmap(lambda pair: pair["n"], pairs)
    n = np.sum(_concatenate2(counts, axes=axis), axis=axis, dtype=dtype, **kwargs)
    if computing_meta:
        return n

    subtotals = deepmap(lambda pair: pair["total"], pairs)
    total = _concatenate2(subtotals, axes=axis).sum(axis=axis, dtype=dtype, **kwargs)

    with np.errstate(divide="ignore", invalid="ignore"):
        return divide(total, n, dtype=dtype)
289
+
290
+
291
@derived_from(np)
def mean(a, axis=None, dtype=None, keepdims=False, split_every=None, out=None):
    # Output dtype precedence:
    #   1. the caller's ``dtype``
    #   2. ``object`` for object input
    #   3. whatever ``np.mean`` returns for a one-element array of ``a.dtype``
    if dtype is None:
        if a.dtype == object:
            dt = object
        else:
            dt = getattr(np.mean(np.zeros(shape=(1,), dtype=a.dtype)), "dtype", object)
    else:
        dt = dtype
    # Blocks emit {"n": ..., "total": ...} dicts, hence ``concatenate=False``.
    return reduction(
        a,
        mean_chunk,
        mean_agg,
        combine=mean_combine,
        axis=axis,
        keepdims=keepdims,
        dtype=dt,
        split_every=split_every,
        out=out,
        concatenate=False,
    )
311
+
312
+
313
@derived_from(np)
def nanmean(a, axis=None, dtype=None, keepdims=False, split_every=None, out=None):
    # Same pipeline as ``mean``, with NaN-aware sum and count kernels bound
    # into the chunk and combine steps.
    dt = dtype if dtype is not None else getattr(
        np.mean(np.ones(shape=(1,), dtype=a.dtype)), "dtype", object
    )
    nan_kernels = dict(sum=chunk.nansum, numel=nannumel)
    return reduction(
        a,
        partial(mean_chunk, **nan_kernels),
        mean_agg,
        combine=partial(mean_combine, **nan_kernels),
        axis=axis,
        keepdims=keepdims,
        dtype=dt,
        split_every=split_every,
        out=out,
        concatenate=False,
    )
331
+
332
+
333
+ # Moment/variance/std implementation
334
def moment_chunk(
    A,
    order=2,
    sum=chunk.sum,
    numel=numel,
    dtype="f8",
    computing_meta=False,
    implicit_complex_dtype=False,
    **kwargs,
):
    """Per-block step of ``moment``/``var``.

    Parameters
    ----------
    A : array
        One block of the input.
    order : int
        Highest central-moment order to accumulate.
    sum, numel : callable
        Summation and counting kernels. The NaN-aware variants swap in
        ``chunk.nansum`` / ``nannumel``.
    dtype : dtype-like
        Accumulation dtype. It is not applied to the block total when
        ``implicit_complex_dtype`` is True.
    computing_meta : bool
        When True, ``A`` is returned unchanged.
    implicit_complex_dtype : bool
        Set by callers for complex input without an explicit dtype. The
        block total is then summed without forcing ``dtype``.

    Returns
    -------
    dict
        ``{"total": ..., "n": ..., "M": ...}``. ``M`` stacks the sums of
        ``|A - mean|**k`` (plain ``A - mean`` for real input) for
        k = 2..order along the last axis.
    """
    if computing_meta:
        return A

    n = numel(A, **kwargs).astype(np.int64)
    total = sum(A, **kwargs) if implicit_complex_dtype else sum(A, dtype=dtype, **kwargs)

    # Empty / all-NaN blocks give 0/0 here; the NaN is expected and handled
    # downstream, so the floating-point warnings are silenced.
    with np.errstate(divide="ignore", invalid="ignore"):
        block_mean = total / n
    deviations = A - block_mean
    if np.issubdtype(A.dtype, np.complexfloating):
        deviations = np.abs(deviations)
    powers = [sum(deviations**power, dtype=dtype, **kwargs) for power in range(2, order + 1)]
    stacked = np.stack(powers, axis=-1)
    return {"total": total, "n": n, "M": stacked}
362
+
363
+
364
def _moment_helper(Ms, ns, inner_term, order, sum, axis, kwargs):
    """Merge per-block central moments of one ``order`` along ``axis``.

    Implements the pairwise-update formula::

        M_p = sum_i [ M_p,i + n_i * d_i**p + sum_{k=1}^{p-2} C(p, k) * M_{p-k},i * d_i**k ]

    where ``d_i`` is ``inner_term`` (block mean minus combined mean) and
    ``Ms[..., j]`` holds the block moments of order ``j + 2``. The first term
    uses the array's own ``.sum``. The shift terms use the supplied ``sum``.
    ``kwargs`` is forwarded to every reduction.
    """
    result = Ms[..., order - 2].sum(axis=axis, **kwargs) + sum(
        ns * inner_term**order, axis=axis, **kwargs
    )
    for k in range(1, order - 1):
        # Binomial coefficient C(order, k), kept as a float.
        weight = float(math.comb(order, k))
        result += weight * sum(Ms[..., order - k - 2] * inner_term**k, axis=axis, **kwargs)
    return result
370
+
371
+
372
def moment_combine(
    pairs,
    order=2,
    ddof=0,
    dtype="f8",
    sum=np.sum,
    axis=None,
    computing_meta=False,
    **kwargs,
):
    """Merge partial moment statistics from several blocks into one.

    Parameters
    ----------
    pairs : dict or list
        Partial results ``{"total": ..., "n": ..., "M": ...}``, possibly
        nested as handled by ``deepmap`` / ``_concatenate2``. A single dict
        is wrapped in a list.
    order : int
        Highest moment order stored in ``M``.
    ddof : int
        Unused in this step.
    dtype : dtype-like
        Dtype for the real-valued mean computations.
    sum : callable
        Reduction for the shift terms in ``_moment_helper``. The NaN-aware
        variant passes ``np.nansum``.
    axis : tuple of int
        Axes being reduced.
    computing_meta : bool
        If True, only the reduced count is computed and returned.

    Returns
    -------
    dict
        ``{"total": ..., "n": ..., "M": ...}`` with the reduced axes kept
        as size-1 dimensions.
    """
    if not isinstance(pairs, list):
        pairs = [pairs]

    kwargs["dtype"] = None
    kwargs["keepdims"] = True

    ns = deepmap(lambda pair: pair["n"], pairs) if not computing_meta else pairs
    ns = _concatenate2(ns, axes=axis)
    n = ns.sum(axis=axis, **kwargs)

    if computing_meta:
        return n

    totals = _concatenate2(deepmap(lambda pair: pair["total"], pairs), axes=axis)
    Ms = _concatenate2(deepmap(lambda pair: pair["M"], pairs), axes=axis)

    total = totals.sum(axis=axis, **kwargs)

    with np.errstate(divide="ignore", invalid="ignore"):
        if np.issubdtype(total.dtype, np.complexfloating):
            mu = divide(total, n)
            inner_term = np.abs(divide(totals, ns) - mu)
        else:
            mu = divide(total, n, dtype=dtype)
            inner_term = divide(totals, ns, dtype=dtype) - mu

    # Empty (or all-NaN) blocks have n == 0, which makes their mean 0/0 = NaN.
    # Left alone, ``n * NaN`` in the shift terms would poison the merged M when
    # ``sum`` is np.sum. Such blocks contribute nothing, so zero their shift,
    # as ``moment_agg`` does.
    inner_term = np.where(ns == 0, 0, inner_term)

    xs = [_moment_helper(Ms, ns, inner_term, o, sum, axis, kwargs) for o in range(2, order + 1)]
    M = np.stack(xs, axis=-1)
    return {"total": total, "n": n, "M": M}
411
+
412
+
413
def moment_agg(
    pairs,
    order=2,
    ddof=0,
    dtype="f8",
    sum=np.sum,
    axis=None,
    computing_meta=False,
    **kwargs,
):
    """Final step of ``moment``/``var``: turn partial statistics into a moment.

    Parameters
    ----------
    pairs : dict or list
        Partial results ``{"total": ..., "n": ..., "M": ...}`` from
        ``moment_chunk`` / ``moment_combine``. A single dict is wrapped in a list.
    order : int
        Moment order to return.
    ddof : int
        Delta degrees of freedom subtracted from the element count.
    dtype : dtype-like
        Dtype for the final sums and the division.
    sum : callable
        Reduction for the shift terms in ``_moment_helper``.
    axis : tuple of int
        Axes being reduced.
    computing_meta : bool
        If True, only the count (with reduced axes kept) is computed and returned.

    Returns
    -------
    array
        ``M / (N - ddof)``, where negative denominators are replaced by NaN.
    """
    if not isinstance(pairs, list):
        pairs = [pairs]

    # Final reductions (M and the denominator) use the requested dtype.
    kwargs["dtype"] = dtype
    # To properly handle ndarrays, the original dimensions need to be kept for
    # part of the calculation.
    keepdim_kw = kwargs.copy()
    keepdim_kw["keepdims"] = True
    keepdim_kw["dtype"] = None

    ns = deepmap(lambda pair: pair["n"], pairs) if not computing_meta else pairs
    ns = _concatenate2(ns, axes=axis)
    # Total count with reduced axes kept, so it broadcasts against ``ns``/``totals``.
    n = ns.sum(axis=axis, **keepdim_kw)

    if computing_meta:
        return n

    totals = _concatenate2(deepmap(lambda pair: pair["total"], pairs), axes=axis)
    Ms = _concatenate2(deepmap(lambda pair: pair["M"], pairs), axes=axis)

    # Combined mean. NOTE(review): this division sits outside the errstate
    # block below, so n == 0 may emit a RuntimeWarning; confirm whether intended.
    mu = divide(totals.sum(axis=axis, **keepdim_kw), n)

    with np.errstate(divide="ignore", invalid="ignore"):
        # Shift of each block mean from the combined mean (magnitude for complex).
        if np.issubdtype(totals.dtype, np.complexfloating):
            inner_term = np.abs(divide(totals, ns) - mu)
        else:
            inner_term = divide(totals, ns, dtype=dtype) - mu

    # Blocks with no elements have a 0/0 (NaN) mean; they contribute nothing.
    inner_term = np.where(ns == 0, 0, inner_term)
    M = _moment_helper(Ms, ns, inner_term, order, sum, axis, kwargs)

    # Collapse the kept dimensions of ``n`` to the final shape, then apply ddof.
    denominator = n.sum(axis=axis, **kwargs) - ddof

    # taking care of the edge case with empty or all-nans array with ddof > 0
    if isinstance(denominator, Number):
        if denominator < 0:
            denominator = np.nan
    elif denominator is not np.ma.masked:
        # In-place: mark slices with fewer elements than ddof as NaN.
        denominator[denominator < 0] = np.nan

    return divide(M, denominator, dtype=dtype)
463
+
464
+
465
def moment(a, order, axis=None, dtype=None, keepdims=False, ddof=0, split_every=None, out=None):
    """Calculate the nth centralized moment.

    Parameters
    ----------
    a : Array
        Data over which to compute moment
    order : int
        Order of the moment that is returned, must be >= 0. By definition,
        order 0 yields an array of ones and order 1 an array of zeros.
    axis : int, optional
        Axis along which the central moment is computed. The default is to
        compute the moment of the flattened array.
    dtype : data-type, optional
        Type to use in computing the moment. For arrays of integer type the
        default is float64; for arrays of float types it is the same as the
        array type.
    keepdims : bool, optional
        If this is set to True, the axes which are reduced are left in the
        result as dimensions with size one. With this option, the result
        will broadcast correctly against the original array.
    ddof : int, optional
        "Delta Degrees of Freedom": the divisor used in the calculation is
        N - ddof, where N represents the number of elements. By default
        ddof is zero.
    split_every : int or dict, optional
        Forwarded to ``reduction`` to control the tree reduction.
    out : Array, optional
        Output array, handled via ``handle_out`` (orders 0 and 1) or by
        ``reduction`` (order >= 2).

    Returns
    -------
    moment : Array

    Raises
    ------
    ValueError
        If ``order`` is not an integer >= 0.
    """
    if not isinstance(order, Integral) or order < 0:
        raise ValueError("Order must be an integer >= 0")

    if order < 2:
        from dask_array._core_utils import handle_out
        from dask_array.creation import ones, zeros

        # Reduce only to learn the output shape and chunks. Forward keepdims
        # so the trivial result has the same shape as the general path.
        reduced = a.sum(axis=axis, keepdims=keepdims)
        # By definition the 0th central moment is 1 and the 1st is 0.
        filler = ones if order == 0 else zeros
        result = filler(reduced.shape, chunks=reduced.chunks, dtype="f8", meta=reduced._meta)
        return handle_out(out, result)

    if dtype is not None:
        dt = dtype
    else:
        dt = getattr(np.var(np.ones(shape=(1,), dtype=a.dtype)), "dtype", object)

    # Complex input without an explicit dtype keeps complex block totals
    # (see ``moment_chunk``).
    implicit_complex_dtype = dtype is None and np.iscomplexobj(a)

    return reduction(
        a,
        partial(moment_chunk, order=order, implicit_complex_dtype=implicit_complex_dtype),
        partial(moment_agg, order=order, ddof=ddof),
        axis=axis,
        keepdims=keepdims,
        dtype=dt,
        split_every=split_every,
        out=out,
        concatenate=False,
        combine=partial(moment_combine, order=order),
    )
526
+
527
+
528
@derived_from(np)
def var(a, axis=None, dtype=None, keepdims=False, ddof=0, split_every=None, out=None):
    # Variance is the order-2 central moment.
    # Default dtype: what np.var gives for a one-element array of ``a.dtype``.
    dt = dtype if dtype is not None else getattr(
        np.var(np.ones(shape=(1,), dtype=a.dtype)), "dtype", object
    )
    implicit_complex_dtype = dtype is None and np.iscomplexobj(a)

    chunk_step = partial(moment_chunk, implicit_complex_dtype=implicit_complex_dtype)
    agg_step = partial(moment_agg, ddof=ddof)
    return reduction(
        a,
        chunk_step,
        agg_step,
        combine=moment_combine,
        axis=axis,
        keepdims=keepdims,
        dtype=dt,
        split_every=split_every,
        name="var",
        out=out,
        concatenate=False,
    )
550
+
551
+
552
@derived_from(np)
def nanvar(a, axis=None, dtype=None, keepdims=False, ddof=0, split_every=None, out=None):
    # Same pipeline as ``var``, but NaN-aware. Per-block sums and counts ignore
    # NaN, and the combine step uses ``np.nansum`` for its shift terms.
    dt = dtype if dtype is not None else getattr(
        np.var(np.ones(shape=(1,), dtype=a.dtype)), "dtype", object
    )
    implicit_complex_dtype = dtype is None and np.iscomplexobj(a)

    chunk_step = partial(
        moment_chunk,
        sum=chunk.nansum,
        numel=nannumel,
        implicit_complex_dtype=implicit_complex_dtype,
    )
    agg_step = partial(moment_agg, sum=np.sum, ddof=ddof)
    combine_step = partial(moment_combine, sum=np.nansum)
    return reduction(
        a,
        chunk_step,
        agg_step,
        combine=combine_step,
        axis=axis,
        keepdims=keepdims,
        dtype=dt,
        split_every=split_every,
        out=out,
        concatenate=False,
    )
578
+
579
+
580
def _sqrt(a):
    """Square root that keeps a fully-masked 0-d masked array masked.

    A 0-d ``np.ma.masked_array`` whose mask is set returns the
    ``np.ma.masked`` constant. Everything else goes through ``np.sqrt``.
    """
    fully_masked_scalar = (
        isinstance(a, np.ma.masked_array) and not a.shape and a.mask.all()
    )
    if fully_masked_scalar:
        return np.ma.masked
    return np.sqrt(a)
584
+
585
+
586
def safe_sqrt(a):
    """A version of sqrt that properly handles scalar masked arrays.

    Objects exposing ``_elemwise`` (lazy arrays) get ``_sqrt`` mapped
    elementwise through that method. Anything else is passed to ``_sqrt``
    directly.
    """
    if not hasattr(a, "_elemwise"):
        return _sqrt(a)
    return a._elemwise(_sqrt, a)
591
+
592
+
593
@derived_from(np)
def std(a, axis=None, dtype=None, keepdims=False, ddof=0, split_every=None, out=None):
    from dask_array._core_utils import handle_out

    # Compute the variance *without* ``out``. If ``out`` were handed to ``var``,
    # it would receive the variance rather than the standard deviation.
    variance = var(
        a,
        axis=axis,
        dtype=dtype,
        keepdims=keepdims,
        ddof=ddof,
        split_every=split_every,
    )
    result = safe_sqrt(variance)
    # sqrt may promote the dtype; honour an explicitly requested one.
    if dtype and dtype != result.dtype:
        result = result.astype(dtype)
    return handle_out(out, result)
609
+
610
+
611
@derived_from(np)
def nanstd(a, axis=None, dtype=None, keepdims=False, ddof=0, split_every=None, out=None):
    from dask_array._core_utils import handle_out

    # Compute the NaN-aware variance *without* ``out``. If ``out`` were handed to
    # ``nanvar``, it would receive the variance rather than the standard deviation.
    variance = nanvar(
        a,
        axis=axis,
        dtype=dtype,
        keepdims=keepdims,
        ddof=ddof,
        split_every=split_every,
    )
    result = safe_sqrt(variance)
    # sqrt may promote the dtype; honour an explicitly requested one.
    if dtype and dtype != result.dtype:
        result = result.astype(dtype)
    return handle_out(out, result)
627
+
628
+
629
+ # Arg reductions helpers
630
def _arg_combine(data, axis, argfunc, keepdims=False):
    """Merge intermediate results from ``arg_*`` functions

    Parameters
    ----------
    data : structured array or dict
        Holds ``"vals"`` (candidate extreme values) and ``"arg"`` (their
        indices, as produced by ``arg_chunk``) with identical dimensionality.
    axis : sequence of int
        Reduced axes. When they cover every dimension, or the data is 1-D,
        the selection is made over the flattened data. Otherwise only
        ``axis[0]`` is used.
    argfunc : callable
        ``argmin``/``argmax``-like function that picks the winning entry.
    keepdims : bool
        Whether the reduced axis is kept as a size-1 dimension.

    Returns
    -------
    tuple
        ``(arg, vals)``: the selected indices and their values.
    """
    if isinstance(data, dict):
        # Array type doesn't support structured arrays (e.g., CuPy),
        # therefore `data` is stored in a `dict`.
        assert data["vals"].ndim == data["arg"].ndim
        axis = None if len(axis) == data["vals"].ndim or data["vals"].ndim == 1 else axis[0]
    else:
        axis = None if len(axis) == data.ndim or data.ndim == 1 else axis[0]

    vals = data["vals"]
    arg = data["arg"]
    if axis is None:
        # Flattened selection: the winning positions index the raveled data.
        local_args = argfunc(vals, axis=axis, keepdims=keepdims)
        vals = vals.ravel()[local_args]
        arg = arg.ravel()[local_args]
    else:
        # Per-axis selection: build open index grids for the remaining
        # dimensions and splice the winners in at ``axis``. Fancy indexing
        # then picks one entry per remaining position.
        local_args = argfunc(vals, axis=axis)
        inds = list(np.ogrid[tuple(map(slice, local_args.shape))])
        inds.insert(axis, local_args)
        inds = tuple(inds)
        vals = vals[inds]
        arg = arg[inds]
        if keepdims:
            vals = np.expand_dims(vals, axis)
            arg = np.expand_dims(arg, axis)
    return arg, vals
657
+
658
+
659
def arg_chunk(func, argfunc, x, axis, offset_info):
    """Compute one chunk's extreme value and its index in the whole array.

    Parameters
    ----------
    func : callable
        Value reduction (e.g. ``chunk.max``), called with ``keepdims=True``.
    argfunc : callable
        Matching index reduction (e.g. ``chunk.argmax``).
    x : array
        The chunk.
    axis : sequence of int
        Reduced axes. All axes, or a 1-D chunk, means a flattened reduction.
    offset_info : tuple or int-like
        For a flattened reduction, ``(offset, total_shape)`` locating this
        chunk in the full array. Otherwise it is added to the chunk-local
        indices (presumably the chunk's start along ``axis[0]``; confirm
        with ``arg_reduction``).

    Returns
    -------
    structured array or dict
        Fields ``"vals"`` and ``"arg"``. A dict is used when the array type
        rejects structured dtypes.
    """
    arg_axis = None if len(axis) == x.ndim or x.ndim == 1 else axis[0]
    vals = func(x, axis=arg_axis, keepdims=True)
    arg = argfunc(x, axis=arg_axis, keepdims=True)
    if x.ndim > 0:
        if arg_axis is None:
            # Convert the chunk-local flat index into a flat index of the
            # full array: unravel locally, shift by the chunk offset, re-ravel.
            offset, total_shape = offset_info
            ind = np.unravel_index(arg.ravel()[0], x.shape)
            total_ind = tuple(o + i for (o, i) in zip(offset, ind))
            arg[:] = np.ravel_multi_index(total_ind, total_shape)
        else:
            arg += offset_info

    if isinstance(vals, np.ma.masked_array):
        # Replace masked values with NumPy's comparison fill values before
        # packing: ``minimum_fill_value`` for min-like argfuncs, otherwise
        # ``maximum_fill_value``.
        if "min" in argfunc.__name__:
            fill_value = np.ma.minimum_fill_value(vals)
        else:
            fill_value = np.ma.maximum_fill_value(vals)
        vals = np.ma.filled(vals, fill_value)

    try:
        result = np.empty_like(vals, shape=vals.shape, dtype=[("vals", vals.dtype), ("arg", arg.dtype)])
    except TypeError:
        # Array type doesn't support structured arrays (e.g., CuPy)
        result = dict()

    result["vals"] = vals
    result["arg"] = arg
    return result
688
+
689
+
690
def arg_combine(argfunc, data, axis=None, **kwargs):
    """Intermediate step for ``arg*`` reductions.

    Selects the winners among ``data`` (keeping the reduced axis) and packs
    them back into the same ``"vals"``/``"arg"`` container that ``arg_chunk``
    emits. A dict is used if the array type rejects structured dtypes.
    """
    arg, vals = _arg_combine(data, axis, argfunc, keepdims=True)

    fields = [("vals", vals.dtype), ("arg", arg.dtype)]
    try:
        packed = np.empty_like(vals, shape=vals.shape, dtype=fields)
    except TypeError:
        # Array type doesn't support structured arrays (e.g., CuPy).
        packed = {}

    packed["vals"] = vals
    packed["arg"] = arg
    return packed
702
+
703
+
704
def arg_agg(argfunc, data, axis=None, keepdims=False, **kwargs):
    """Final step for ``arg*`` reductions: return only the winning indices."""
    winning_args, _winning_vals = _arg_combine(data, axis, argfunc, keepdims=keepdims)
    return winning_args
706
+
707
+
708
def nanarg_agg(argfunc, data, axis=None, keepdims=False, **kwargs):
    """Final step for ``nanarg*`` reductions.

    Returns the winning indices. Raises ``ValueError`` if any selected value
    is still NaN.
    """
    winning_args, winning_vals = _arg_combine(data, axis, argfunc, keepdims=keepdims)
    if not np.any(np.isnan(winning_vals)):
        return winning_args
    raise ValueError("All NaN slice encountered")
713
+
714
+
715
def _nanargmin(x, axis, **kwargs):
    """``chunk.nanargmin`` that retries with NaN replaced by +inf on ValueError.

    The retry presumably covers all-NaN slices, for which the NaN-aware
    argmin refuses to pick an index.
    """
    try:
        result = chunk.nanargmin(x, axis, **kwargs)
    except ValueError:
        filled = np.where(np.isnan(x), np.inf, x)
        result = chunk.nanargmin(filled, axis, **kwargs)
    return result
720
+
721
+
722
def _nanargmax(x, axis, **kwargs):
    """``chunk.nanargmax`` that retries with NaN replaced by -inf on ValueError.

    The retry presumably covers all-NaN slices, for which the NaN-aware
    argmax refuses to pick an index.
    """
    try:
        result = chunk.nanargmax(x, axis, **kwargs)
    except ValueError:
        filled = np.where(np.isnan(x), -np.inf, x)
        result = chunk.nanargmax(filled, axis, **kwargs)
    return result
727
+
728
+
729
@derived_from(np)
def argmax(a, axis=None, keepdims=False, split_every=None, out=None):
    # Bind each stage to the max/argmax kernels it needs.
    per_chunk = partial(arg_chunk, chunk.max, chunk.argmax)
    merge = partial(arg_combine, chunk.argmax)
    finish = partial(arg_agg, chunk.argmax)
    return arg_reduction(
        a,
        per_chunk,
        merge,
        finish,
        axis=axis,
        keepdims=keepdims,
        split_every=split_every,
        out=out,
    )
741
+
742
+
743
@derived_from(np)
def argmin(a, axis=None, keepdims=False, split_every=None, out=None):
    # Bind each stage to the min/argmin kernels it needs.
    per_chunk = partial(arg_chunk, chunk.min, chunk.argmin)
    merge = partial(arg_combine, chunk.argmin)
    finish = partial(arg_agg, chunk.argmin)
    return arg_reduction(
        a,
        per_chunk,
        merge,
        finish,
        axis=axis,
        keepdims=keepdims,
        split_every=split_every,
        out=out,
    )
755
+
756
+
757
@derived_from(np)
def nanargmax(a, axis=None, keepdims=False, split_every=None, out=None):
    # NaN-aware kernels at every stage. The final step raises on all-NaN slices.
    per_chunk = partial(arg_chunk, chunk.nanmax, _nanargmax)
    merge = partial(arg_combine, _nanargmax)
    finish = partial(nanarg_agg, _nanargmax)
    return arg_reduction(
        a,
        per_chunk,
        merge,
        finish,
        axis=axis,
        keepdims=keepdims,
        split_every=split_every,
        out=out,
    )
769
+
770
+
771
@derived_from(np)
def nanargmin(a, axis=None, keepdims=False, split_every=None, out=None):
    # NaN-aware kernels at every stage. The final step raises on all-NaN slices.
    per_chunk = partial(arg_chunk, chunk.nanmin, _nanargmin)
    merge = partial(arg_combine, _nanargmin)
    finish = partial(nanarg_agg, _nanargmin)
    return arg_reduction(
        a,
        per_chunk,
        merge,
        finish,
        axis=axis,
        keepdims=keepdims,
        split_every=split_every,
        out=out,
    )
783
+
784
+
785
+ # Median and quantile functions
786
+ from collections.abc import Iterable
787
+ from functools import reduce
788
+ from operator import mul
789
+
790
+ from dask_array._core_utils import handle_out
791
+
792
+ try:
793
+ import numbagg
794
+ except ImportError:
795
+ numbagg = None
796
+
797
+
798
@derived_from(np)
def median(a, axis=None, keepdims=False, out=None):
    """
    This works by automatically chunking the reduced axes to a single chunk if necessary
    and then calling ``numpy.median`` function across the remaining dimensions
    """
    if axis is None:
        raise NotImplementedError(
            "The da.median function only works along an axis. The full algorithm is difficult to do in parallel"
        )

    # Normalise to a list of non-negative axis indices.
    axes = axis if isinstance(axis, Iterable) else (axis,)
    axes = [ax if ax >= 0 else ax + a.ndim for ax in axes]

    # Each reduced axis must sit in a single block so blocks are independent.
    if builtins.any(a.numblocks[ax] > 1 for ax in axes):
        a = a.rechunk({dim: (-1 if dim in axes else "auto") for dim in range(a.ndim)})

    if keepdims:
        drop_axis = None
        out_chunks = [1 if ax in axes else c for ax, c in enumerate(a.chunks)]
    else:
        drop_axis = axes
        out_chunks = None

    result = a.map_blocks(
        np.median,
        axis=axes,
        keepdims=keepdims,
        drop_axis=drop_axis,
        chunks=out_chunks,
    )
    return handle_out(out, result)
828
+
829
+
830
@derived_from(np)
def nanmedian(a, axis=None, keepdims=False, out=None):
    """
    This works by automatically chunking the reduced axes to a single chunk
    and then calling ``numpy.nanmedian`` function across the remaining dimensions
    """
    from packaging.version import Version

    if axis is None:
        raise NotImplementedError(
            "The da.nanmedian function only works along an axis or a subset of axes. "
            "The full algorithm is difficult to do in parallel"
        )

    # Normalise to a list of non-negative axis indices.
    axes = axis if isinstance(axis, Iterable) else (axis,)
    axes = [ax if ax >= 0 else ax + a.ndim for ax in axes]

    # Each reduced axis must sit in a single block so blocks are independent.
    if builtins.any(a.numblocks[ax] > 1 for ax in axes):
        a = a.rechunk({dim: (-1 if dim in axes else "auto") for dim in range(a.ndim)})

    # numbagg >= 0.7 provides its own nanmedian. It is used only for
    # unsigned/signed integer and float dtypes when keepdims is False,
    # and it receives no keepdims argument.
    use_numbagg = (
        numbagg is not None
        and Version(numbagg.__version__).release >= (0, 7, 0)
        and a.dtype.kind in "uif"
        and not keepdims
    )
    if use_numbagg:
        func, func_kwargs = numbagg.nanmedian, {}
    else:
        func, func_kwargs = np.nanmedian, {"keepdims": keepdims}

    if keepdims:
        drop_axis = None
        out_chunks = [1 if ax in axes else c for ax, c in enumerate(a.chunks)]
    else:
        drop_axis = axes
        out_chunks = None

    result = a.map_blocks(
        func,
        axis=axes,
        drop_axis=drop_axis,
        chunks=out_chunks,
        **func_kwargs,
    )
    return handle_out(out, result)
875
+
876
+
877
def _get_quantile_chunks(a, q, axis, keepdims):
    """Chunk structure of a blockwise quantile reduction of ``a``.

    An iterable ``q`` contributes a leading dimension holding ``len(q)``
    entries. Reduced dimensions (those in ``axis``) become a single ``1``
    chunk when ``keepdims`` is true and are dropped otherwise; all other
    dimensions keep their chunks from ``a.chunks``. Returns a list.
    """
    chunks = [len(q)] if isinstance(q, Iterable) else []
    for dim, dim_chunks in enumerate(a.chunks):
        if dim not in axis:
            chunks.append(dim_chunks)
        elif keepdims:
            chunks.append(1)
    return chunks
883
+
884
+
885
def _span_indexers(a):
    """Fancy-index tuples enumerating every position along the leading axes.

    For an array with ``ndim >= 2`` this returns ``ndim - 1`` tuples, one per
    leading axis (all axes except the last). Position ``k`` of tuple ``d`` is
    the axis-``d`` coordinate of the ``k``-th leading position in C
    (row-major) order, so ``a[tuple(indexers) + (cols,)]`` picks one element
    per row of the last axis. For a 1-D array a single tuple ``(0, ..., n-1)``
    is returned.

    Zero-length leading dimensions yield empty tuples. The previous
    hand-rolled repeat/tile arithmetic divided by zero in that case.
    """
    if a.ndim <= 1:
        # Matches the historical 1-D behaviour (and IndexError for 0-D input).
        return [tuple(np.arange(a.shape[0]))]
    # np.indices produces the coordinate grid of the leading axes; flattening
    # each coordinate array gives C-order enumeration of the positions.
    grid = np.indices(a.shape[:-1])
    flat = grid.reshape(grid.shape[0], -1)
    return [tuple(coords) for coords in flat]
895
+
896
+
897
def _custom_quantile(a, q, axis=None, method="linear", keepdims=False, **kwargs):
    """NaN-ignoring quantile of one block, with a fast path for the last axis.

    The fast path applies when all of the following hold:

    * ``method`` is ``"linear"``;
    * no extra keyword arguments are given;
    * ``axis`` is a one-element sequence naming the last dimension;
    * ``a`` has at least two dimensions;
    * the last dimension has between 1 and 1000 elements.

    In that case the block is sorted along the last axis (``np.sort`` places
    NaNs last) and each quantile is linearly interpolated between the two
    neighbouring non-NaN sorted values. Every other case is delegated to
    ``np.nanquantile`` with the same arguments.

    Parameters
    ----------
    a : numpy.ndarray
        The block to reduce.
    q : float or iterable of float
        Quantile(s) to compute.
    axis : sequence of int, int or None
        Axes to reduce over.
    method : str
        Interpolation method, as in ``np.nanquantile``.
    keepdims : bool
        Keep the reduced axis with length one.
    **kwargs
        Forwarded to ``np.nanquantile`` (e.g. ``weights``). The fast path
        cannot honour them, so their presence forces the NumPy fallback.

    Returns
    -------
    numpy.ndarray
        For an iterable ``q`` the results are stacked along a new leading axis.
    """
    axes = list(axis) if isinstance(axis, Iterable) else None
    n = a.shape[-1] if a.ndim else 0
    use_fast_path = (
        method == "linear"
        and not kwargs
        and axes is not None
        and len(axes) == 1
        and a.ndim > 1
        and axes[0] == a.ndim - 1
        and 0 < n <= 1000
    )
    if not use_fast_path:
        return np.nanquantile(a, q, axis=axis, method=method, keepdims=keepdims, **kwargs)

    is_scalar = not isinstance(q, Iterable)
    q_values = [q] if is_scalar else list(q)

    sorted_arr = np.sort(a, axis=-1)
    # One row per position along the leading axes, in C order.
    rows = sorted_arr.reshape(-1, n)
    row_index = np.arange(rows.shape[0])
    # Index of the last non-NaN value in each row (NaNs are sorted last).
    last_valid = (n - 1.0) - np.isnan(rows).sum(axis=-1)

    out_shape = (1,) + tuple(sorted_arr.shape[:-1]) + ((1,) if keepdims else ())
    quantiles = []
    for single_q in q_values:
        position = last_valid * single_q
        lower_idx = np.floor(position).astype(int)
        upper_idx = np.ceil(position).astype(int)

        lower = rows[row_index, lower_idx]
        upper = rows[row_index, upper_idx]

        weight_upper = position - lower_idx
        # Exact hit (lower_idx == upper_idx): take the value itself.
        weight_upper = np.where(weight_upper == 0.0, 1.0, weight_upper)
        weight_lower = upper_idx - position

        quantiles.append((upper * weight_upper + lower * weight_lower).reshape(out_shape))

    if is_scalar:
        return quantiles[0].squeeze(axis=0)
    return np.concatenate(quantiles, axis=0)
929
+
930
+
931
@derived_from(np)
def quantile(
    a,
    q,
    axis=None,
    out=None,
    overwrite_input=False,
    method="linear",
    keepdims=False,
    *,
    weights=None,
    interpolation=None,
):
    """
    This works by automatically chunking the reduced axes to a single chunk if necessary
    and then calling ``numpy.quantile`` function across the remaining dimensions.

    ``weights`` is forwarded to ``numpy.quantile`` only when it is given (NumPy
    only accepts it from 2.0 on). ``overwrite_input`` is accepted for signature
    compatibility and ignored.
    """
    if interpolation is not None:
        warnings.warn(
            "The `interpolation` argument to quantile was renamed to `method`.",
            FutureWarning,
            stacklevel=2,
        )
        if method != "linear":
            raise TypeError("Cannot pass interpolation and method keywords!")
        method = interpolation

    if axis is None:
        # A reduction over all axes is only supported for single-block input.
        if builtins.any(n_blocks > 1 for n_blocks in a.numblocks):
            raise NotImplementedError(
                "The da.quantile function only works along an axis. The full algorithm is difficult to do in parallel"
            )
        axis = tuple(range(len(a.shape)))

    if not isinstance(axis, Iterable):
        axis = (axis,)

    axis = [ax + a.ndim if ax < 0 else ax for ax in axis]

    # numpy.quantile needs every reduced axis inside a single block.
    if builtins.any(a.numblocks[ax] > 1 for ax in axis):
        a = a.rechunk({ax: -1 if ax in axis else "auto" for ax in range(a.ndim)})

    # Only forward ``weights`` when supplied: passing ``weights=None`` breaks
    # NumPy < 2.0, whose quantile has no such keyword.
    kwargs = {}
    if weights is not None:
        # NOTE(review): ``weights`` is presumably handed unchanged to every
        # block by map_blocks, which only lines up with single-block inputs —
        # TODO confirm intended usage.
        kwargs["weights"] = weights

    # Probe the result dtype with the same ``method`` the blocks use: discrete
    # methods such as "lower" or "nearest" keep integer dtypes, whereas
    # "linear" promotes them to float.
    out_dtype = np.quantile(np.array([0], dtype=a.dtype), q, method=method).dtype

    result = a.map_blocks(
        np.quantile,
        q=q,
        method=method,
        axis=axis,
        keepdims=keepdims,
        drop_axis=axis if not keepdims else None,
        new_axis=0 if isinstance(q, Iterable) else None,
        chunks=_get_quantile_chunks(a, q, axis, keepdims),
        dtype=out_dtype,
        **kwargs,
    )

    return handle_out(out, result)
1001
+
1002
+
1003
@derived_from(np)
def nanquantile(
    a,
    q,
    axis=None,
    out=None,
    overwrite_input=False,
    method="linear",
    keepdims=False,
    *,
    weights=None,
    interpolation=None,
):
    """
    This works by automatically chunking the reduced axes to a single chunk
    and then calling ``numpy.nanquantile`` function across the remaining dimensions.

    ``overwrite_input`` is accepted for signature compatibility and ignored.
    """
    from packaging.version import Version

    if interpolation is not None:
        warnings.warn(
            "The `interpolation` argument to nanquantile was renamed to `method`.",
            FutureWarning,
            stacklevel=2,
        )
        if method != "linear":
            raise TypeError("Cannot pass interpolation and method keywords!")
        method = interpolation

    if axis is None:
        # A reduction over all axes is only supported for single-block input.
        if builtins.any(n_blocks > 1 for n_blocks in a.numblocks):
            raise NotImplementedError(
                "The da.nanquantile function only works along an axis. "
                "The full algorithm is difficult to do in parallel"
            )
        axis = tuple(range(len(a.shape)))

    if not isinstance(axis, Iterable):
        axis = (axis,)

    axis = [ax + a.ndim if ax < 0 else ax for ax in axis]

    # The per-block reducer needs every reduced axis inside a single block.
    if builtins.any(a.numblocks[ax] > 1 for ax in axis):
        a = a.rechunk({ax: -1 if ax in axis else "auto" for ax in range(a.ndim)})

    if (
        numbagg is not None
        and Version(numbagg.__version__).release >= (0, 8, 0)
        and a.dtype.kind in "uif"
        and weights is None
        and method == "linear"
        and not keepdims
    ):
        func = numbagg.nanquantile
        kwargs = {"quantiles": q}
    else:
        func = _custom_quantile
        kwargs = {
            "q": q,
            "method": method,
            "keepdims": keepdims,
        }
        # Only forward ``weights`` when supplied (NumPy >= 2.0 keyword).
        if weights is not None:
            kwargs["weights"] = weights

    # Probe the result dtype with the same ``method`` the blocks use: discrete
    # methods such as "lower" or "nearest" keep integer dtypes, whereas
    # "linear" promotes them to float.
    out_dtype = np.nanquantile(np.array([0], dtype=a.dtype), q, method=method).dtype

    result = a.map_blocks(
        func,
        axis=axis,
        drop_axis=axis if not keepdims else None,
        new_axis=0 if isinstance(q, Iterable) else None,
        chunks=_get_quantile_chunks(a, q, axis, keepdims),
        dtype=out_dtype,
        **kwargs,
    )

    return handle_out(out, result)