dask-array 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (144) hide show
  1. dask_array/__init__.py +228 -0
  2. dask_array/_backends.py +76 -0
  3. dask_array/_backends_array.py +99 -0
  4. dask_array/_blockwise.py +1410 -0
  5. dask_array/_broadcast.py +272 -0
  6. dask_array/_chunk.py +445 -0
  7. dask_array/_chunk_types.py +54 -0
  8. dask_array/_collection.py +1644 -0
  9. dask_array/_concatenate.py +331 -0
  10. dask_array/_core_utils.py +1365 -0
  11. dask_array/_dispatch.py +141 -0
  12. dask_array/_einsum.py +277 -0
  13. dask_array/_expr.py +544 -0
  14. dask_array/_expr_flow.py +586 -0
  15. dask_array/_gufunc.py +805 -0
  16. dask_array/_histogram.py +617 -0
  17. dask_array/_map_blocks.py +652 -0
  18. dask_array/_new_collection.py +10 -0
  19. dask_array/_numpy_compat.py +135 -0
  20. dask_array/_overlap.py +1159 -0
  21. dask_array/_rechunk.py +1050 -0
  22. dask_array/_reshape.py +710 -0
  23. dask_array/_routines.py +102 -0
  24. dask_array/_shuffle.py +448 -0
  25. dask_array/_stack.py +264 -0
  26. dask_array/_svg.py +291 -0
  27. dask_array/_templates.py +29 -0
  28. dask_array/_test_utils.py +257 -0
  29. dask_array/_ufunc.py +385 -0
  30. dask_array/_utils.py +349 -0
  31. dask_array/_visualize.py +223 -0
  32. dask_array/_xarray.py +337 -0
  33. dask_array/core/__init__.py +34 -0
  34. dask_array/core/_blockwise_funcs.py +312 -0
  35. dask_array/core/_conversion.py +422 -0
  36. dask_array/core/_from_graph.py +97 -0
  37. dask_array/creation/__init__.py +71 -0
  38. dask_array/creation/_arange.py +121 -0
  39. dask_array/creation/_diag.py +116 -0
  40. dask_array/creation/_diagonal.py +241 -0
  41. dask_array/creation/_eye.py +103 -0
  42. dask_array/creation/_linspace.py +102 -0
  43. dask_array/creation/_mesh.py +134 -0
  44. dask_array/creation/_ones_zeros.py +454 -0
  45. dask_array/creation/_pad.py +270 -0
  46. dask_array/creation/_repeat.py +55 -0
  47. dask_array/creation/_tile.py +36 -0
  48. dask_array/creation/_tri.py +28 -0
  49. dask_array/creation/_utils.py +296 -0
  50. dask_array/fft.py +320 -0
  51. dask_array/io/__init__.py +39 -0
  52. dask_array/io/_base.py +10 -0
  53. dask_array/io/_from_array.py +257 -0
  54. dask_array/io/_from_delayed.py +95 -0
  55. dask_array/io/_from_graph.py +54 -0
  56. dask_array/io/_from_npy_stack.py +67 -0
  57. dask_array/io/_store.py +336 -0
  58. dask_array/io/_tiledb.py +159 -0
  59. dask_array/io/_to_npy_stack.py +65 -0
  60. dask_array/io/_zarr.py +449 -0
  61. dask_array/linalg/__init__.py +39 -0
  62. dask_array/linalg/_cholesky.py +234 -0
  63. dask_array/linalg/_lu.py +300 -0
  64. dask_array/linalg/_norm.py +94 -0
  65. dask_array/linalg/_qr.py +601 -0
  66. dask_array/linalg/_solve.py +349 -0
  67. dask_array/linalg/_svd.py +394 -0
  68. dask_array/linalg/_tensordot.py +334 -0
  69. dask_array/linalg/_utils.py +74 -0
  70. dask_array/manipulation/__init__.py +45 -0
  71. dask_array/manipulation/_expand.py +321 -0
  72. dask_array/manipulation/_flip.py +92 -0
  73. dask_array/manipulation/_roll.py +78 -0
  74. dask_array/manipulation/_transpose.py +309 -0
  75. dask_array/random/__init__.py +125 -0
  76. dask_array/random/_choice.py +181 -0
  77. dask_array/random/_expr.py +256 -0
  78. dask_array/random/_generator.py +441 -0
  79. dask_array/random/_random_state.py +259 -0
  80. dask_array/random/_utils.py +84 -0
  81. dask_array/reductions/__init__.py +84 -0
  82. dask_array/reductions/_arg_reduction.py +130 -0
  83. dask_array/reductions/_common.py +1082 -0
  84. dask_array/reductions/_cumulative.py +522 -0
  85. dask_array/reductions/_percentile.py +261 -0
  86. dask_array/reductions/_reduction.py +725 -0
  87. dask_array/reductions/_trace.py +56 -0
  88. dask_array/routines/__init__.py +133 -0
  89. dask_array/routines/_apply.py +84 -0
  90. dask_array/routines/_bincount.py +112 -0
  91. dask_array/routines/_broadcast.py +111 -0
  92. dask_array/routines/_coarsen.py +115 -0
  93. dask_array/routines/_diff.py +79 -0
  94. dask_array/routines/_gradient.py +158 -0
  95. dask_array/routines/_indexing.py +65 -0
  96. dask_array/routines/_insert_delete.py +132 -0
  97. dask_array/routines/_misc.py +122 -0
  98. dask_array/routines/_nonzero.py +72 -0
  99. dask_array/routines/_search.py +123 -0
  100. dask_array/routines/_select.py +113 -0
  101. dask_array/routines/_statistics.py +171 -0
  102. dask_array/routines/_topk.py +82 -0
  103. dask_array/routines/_triangular.py +74 -0
  104. dask_array/routines/_unique.py +232 -0
  105. dask_array/routines/_where.py +62 -0
  106. dask_array/slicing/__init__.py +67 -0
  107. dask_array/slicing/_basic.py +550 -0
  108. dask_array/slicing/_blocks.py +138 -0
  109. dask_array/slicing/_bool_index.py +145 -0
  110. dask_array/slicing/_setitem.py +329 -0
  111. dask_array/slicing/_squeeze.py +101 -0
  112. dask_array/slicing/_utils.py +1133 -0
  113. dask_array/slicing/_vindex.py +282 -0
  114. dask_array/stacking/__init__.py +15 -0
  115. dask_array/stacking/_block.py +83 -0
  116. dask_array/stacking/_simple.py +58 -0
  117. dask_array/templates/array.html.j2 +48 -0
  118. dask_array/tests/__init__.py +0 -0
  119. dask_array/tests/conftest.py +22 -0
  120. dask_array/tests/test_api.py +40 -0
  121. dask_array/tests/test_binary_op_chunks.py +107 -0
  122. dask_array/tests/test_coarse_slice_through_blockwise.py +362 -0
  123. dask_array/tests/test_collection.py +799 -0
  124. dask_array/tests/test_creation.py +1102 -0
  125. dask_array/tests/test_expr_flow.py +143 -0
  126. dask_array/tests/test_linalg.py +1130 -0
  127. dask_array/tests/test_map_blocks_multi_output.py +104 -0
  128. dask_array/tests/test_rechunk_pushdown.py +214 -0
  129. dask_array/tests/test_reductions.py +1091 -0
  130. dask_array/tests/test_routines.py +2853 -0
  131. dask_array/tests/test_shuffle_chunks.py +67 -0
  132. dask_array/tests/test_slice_pushdown.py +968 -0
  133. dask_array/tests/test_slice_through_blockwise.py +678 -0
  134. dask_array/tests/test_slice_through_overlap.py +366 -0
  135. dask_array/tests/test_slice_through_reshape.py +272 -0
  136. dask_array/tests/test_slicing.py +839 -0
  137. dask_array/tests/test_transpose_slice_pushdown.py +208 -0
  138. dask_array/tests/test_visualize.py +94 -0
  139. dask_array/tests/test_xarray.py +193 -0
  140. dask_array-0.1.0.dist-info/METADATA +48 -0
  141. dask_array-0.1.0.dist-info/RECORD +144 -0
  142. dask_array-0.1.0.dist-info/WHEEL +4 -0
  143. dask_array-0.1.0.dist-info/entry_points.txt +2 -0
  144. dask_array-0.1.0.dist-info/licenses/LICENSE +29 -0
@@ -0,0 +1,725 @@
1
+ from __future__ import annotations
2
+
3
+ import builtins
4
+ import math
5
+ import warnings
6
+ from functools import partial
7
+ from itertools import product
8
+ from numbers import Integral
9
+
10
+ import numpy as np
11
+ from tlz import compose, get, partition_all
12
+
13
+ from dask import config
14
+ from dask_array._new_collection import new_collection
15
+ from dask_array._expr import ArrayExpr
16
+ from dask_array._utils import compute_meta
17
+ from dask_array._core_utils import _concatenate2
18
+ from dask_array._numpy_compat import ComplexWarning
19
+ from dask_array._utils import is_arraylike, validate_axis
20
+ from dask.blockwise import lol_tuples
21
+ from dask.tokenize import _tokenize_deterministic
22
+ from dask.utils import cached_property, funcname, getargspec, is_series_like
23
+
24
+
25
class Reduction(ArrayExpr):
    """Logical reduction expression that captures reduction intent.

    This expression represents a reduction operation conceptually,
    without immediately materializing the physical Blockwise + PartialReduce
    cascade. The physical implementation is deferred to ``_lower()``.

    Operands mirror the arguments of :func:`reduction`. When built by
    ``reduction()``, ``axis`` is a validated tuple of ints, ``dtype`` is not
    None, and ``weights`` (if any) is an expression already broadcast to the
    shape of ``array``.
    """

    _parameters = [
        "array",
        "chunk",
        "aggregate",
        "axis",
        "keepdims",
        "dtype",
        "split_every",
        "combine",
        "name",
        "concatenate",
        "output_size",
        "meta",
        "weights",
    ]
    _defaults = {
        "axis": None,
        "keepdims": False,
        "dtype": None,
        "split_every": None,
        "combine": None,
        "name": None,
        "concatenate": True,
        "output_size": 1,
        "meta": None,
        "weights": None,
    }

    def __dask_tokenize__(self):
        # The token covers every operand except ``name`` and ``meta``;
        # ``name`` only contributes the key prefix in ``_name``.
        if not self._determ_token:
            self._determ_token = _tokenize_deterministic(
                self.chunk,
                self.aggregate,
                self.array,
                self.axis,
                self.keepdims,
                self.operand("dtype"),
                self.split_every,
                self.combine,
                self.concatenate,
                self.output_size,
                self.weights,
            )
        return self._determ_token

    @cached_property
    def _name(self):
        prefix = self.operand("name") or funcname(self.chunk)
        return f"{prefix}-{self.deterministic_token}"

    @cached_property
    def name(self):
        """Return the name of the final lowered expression.

        This ensures that Array.name matches the task keys in the graph.
        """
        return self.lower_completely().name

    @cached_property
    def chunks(self):
        """Output chunks after reduction."""
        axis = self.axis
        if self.keepdims:
            # Each reduced axis collapses to a single chunk of ``output_size``.
            return tuple((self.output_size,) if i in axis else c for i, c in enumerate(self.array.chunks))
        else:
            return tuple(c for i, c in enumerate(self.array.chunks) if i not in axis)

    @cached_property
    def dtype(self):
        if self.operand("dtype") is not None:
            return np.dtype(self.operand("dtype"))
        return self.array.dtype

    @property
    def _meta(self):
        # Compute a minimal metadata array with correct dtype and ndim.
        # NOTE(review): this is always a NumPy ndarray, regardless of the
        # input's meta type (e.g. masked arrays) -- confirm this is intended.
        dtype = self.dtype
        ndim = len(self.chunks)
        return np.empty((0,) * ndim, dtype=dtype)

    def _layer(self):
        """Generate the task layer by lowering first.

        Reduction should always be lowered before graph generation,
        but we need to support direct _layer() calls for is_dask_collection().
        """
        return self.lower_completely()._layer()

    def _simplify_up(self, parent, dependents):
        """Allow slice operations to push through Reduction."""
        from dask_array.slicing import SliceSlicesIntegers

        if isinstance(parent, SliceSlicesIntegers):
            return self._accept_slice(parent)
        return None

    def _accept_slice(self, slice_expr):
        """Accept a slice being pushed through this Reduction."""
        reduced_axes = set(self.axis)

        def make_result(sliced_input, input_index):
            # Weights share the input's shape (``reduction()`` broadcasts
            # them), so the same input index applies to them.
            sliced_weights = None
            if self.weights is not None:
                sliced_weights = new_collection(self.weights)[input_index].expr

            return Reduction(
                sliced_input.expr,
                self.chunk,
                self.aggregate,
                self.axis,
                self.keepdims,
                self.operand("dtype"),
                self.split_every,
                self.combine,
                self.operand("name"),
                self.concatenate,
                self.output_size,
                self.meta,
                sliced_weights,
            )

        return _accept_slice_impl(slice_expr, self.array, reduced_axes, self.keepdims, make_result)

    def _lower(self):
        """Lower to a Blockwise per-chunk step followed by a PartialReduce cascade.

        Returns
        -------
        ArrayExpr
            The final ``PartialReduce`` built by ``_build_tree_reduce_expr``,
            wrapped in ``ChunksOverride`` when ``keepdims`` is True and
            ``output_size != 1``.
        """
        from dask_array._collection import blockwise

        axis = self.axis
        dtype = self.operand("dtype")
        if dtype is None:
            # NOTE(review): the ``dtype`` property falls back to the input
            # dtype instead; ``reduction()`` always supplies a dtype, so this
            # fallback only matters for direct ``Reduction(...)`` construction.
            dtype = float
        name = self.operand("name")
        output_size = self.output_size

        # Forward the output dtype to chunk/aggregate functions that accept it.
        chunk_func = self.chunk
        if "dtype" in getargspec(chunk_func).args:
            chunk_func = partial(chunk_func, dtype=dtype)

        aggregate_func = self.aggregate
        if "dtype" in getargspec(aggregate_func).args:
            aggregate_func = partial(aggregate_func, dtype=dtype)

        # Build args for blockwise; weights share the input's index.
        inds = tuple(range(self.array.ndim))
        args = (self.array, inds)

        if self.weights is not None:
            args += (self.weights, inds)

        # Per-chunk reduction: every block shrinks to ``output_size`` along
        # each reduced axis.
        adjust_chunks = {i: output_size for i in axis}
        tmp = blockwise(
            chunk_func,
            inds,
            *args,
            axis=axis,
            keepdims=True,
            token=name,
            dtype=dtype,
            adjust_chunks=adjust_chunks,
        )

        # Meta of the per-chunk output, forwarded to every PartialReduce layer.
        if self.meta is None and hasattr(self.array, "_meta"):
            try:
                try:
                    reduced_meta = compute_meta(
                        chunk_func, self.array.dtype, self.array._meta, axis=axis, keepdims=True, computing_meta=True
                    )
                except TypeError:
                    # Retry without ``computing_meta`` (presumably the chunk
                    # function rejects that keyword).
                    reduced_meta = compute_meta(
                        chunk_func, self.array.dtype, self.array._meta, axis=axis, keepdims=True
                    )
            except ValueError:
                # Wraps both attempts: an exception raised inside the
                # ``except TypeError`` handler is not seen by sibling
                # ``except`` clauses, so the retry's ValueError must be
                # caught at this outer level.
                reduced_meta = None
        else:
            reduced_meta = self.meta

        # Build tree reduction with PartialReduce
        result = _build_tree_reduce_expr(
            tmp.expr,
            aggregate_func,
            axis,
            self.keepdims,
            dtype,
            self.split_every,
            self.combine,
            name,
            self.concatenate,
            reduced_meta,
        )

        # PartialReduce always emits size-1 chunks along reduced axes, so
        # patch the final chunks when the aggregate returns ``output_size`` items.
        if self.keepdims and output_size != 1:
            from dask_array._expr import ChunksOverride

            final_chunks = tuple((output_size,) if i in axis else c for i, c in enumerate(result.chunks))
            result = ChunksOverride(result, final_chunks)

        return result
230
+
231
+
232
def reduction(
    x,
    chunk,
    aggregate,
    axis=None,
    keepdims=False,
    dtype=None,
    split_every=None,
    combine=None,
    name=None,
    out=None,
    concatenate=True,
    output_size=1,
    meta=None,
    weights=None,
):
    """Generic block-wise reduction of an array.

    The reduction proceeds in stages: ``chunk`` runs on every block of ``x``
    independently, ``combine`` optionally merges groups of those partial
    results (a tree reduction), and ``aggregate`` produces the final output.

    Parameters
    ----------
    x : Array
        Array to reduce along one or more axes. Inputs that are not dask
        Arrays are converted with ``asanyarray`` first.
    chunk : callable(x_chunk, [weights_chunk=None], axis, keepdims)
        First function run when the graph is computed; applied in parallel to
        every original block of ``x``. See *Function Parameters* below.
    combine : callable(x_chunk, axis, keepdims), optional
        Intermediate function for the recursive aggregation (see
        ``split_every``). Defaults to ``aggregate``. It is not called at all
        when the reduction needs fewer than three steps.
    aggregate : callable(x_chunk, axis, keepdims)
        Last function run; produces the final output. Always called, even when
        ``x`` has a single block along the reduced axes.
    axis : int or sequence of ints, optional
        Axis or axes to reduce over. Defaults to all axes.
    keepdims : bool, optional
        If True, reduced axes are kept with length ``output_size``; otherwise
        they are removed.
    dtype : np.dtype
        Data type of the output. Required: passing ``None`` raises
        ``ValueError``.
    split_every : int >= 2 or dict(axis: int), optional
        Depth control for the recursive aggregation. When it is greater than
        or equal to the number of input blocks, the reduction takes two steps:
        one ``chunk`` call per block and a single ``aggregate`` call. When it
        is smaller, ``combine`` steps are inserted so that no ``combine`` or
        ``aggregate`` call receives more than ``split_every`` inputs. The
        aggregation depth is then
        :math:`log_{split_every}(input chunks along reduced axes)`. Low values
        can reduce cache size and network transfers at the cost of more CPU
        and a larger graph.

        If omitted, the ``split_every`` key of :mod:`dask.config` is used
        (16 when unset).
    name : str, optional
        Prefix for the keys of intermediate and output tasks. Defaults to the
        function names.
    out : Array, optional
        Existing dask array whose contents are replaced by the result. Unlike
        NumPy, this gives no performance benefit; it only preserves references
        to an already existing Array.
    concatenate : bool, optional
        If True (default), outputs of ``chunk``/``combine`` are concatenated
        into one array before being passed to ``combine``/``aggregate``. If
        False, those functions receive either a list of the raw outputs of
        the previous step or a single output, and must concatenate it
        themselves. This is useful when ``chunk``/``combine`` do not return
        arrays.
    output_size : int >= 1, optional
        Length of the ``aggregate`` output along each reduced axis. Ignored
        when ``keepdims`` is False.
    meta : array-like, optional
        Metadata of the per-block ``chunk`` output (computed with
        ``keepdims=True``), forwarded to each tree-reduction level. Inferred
        from ``chunk`` when omitted.
    weights : array_like, optional
        Weights applied while reducing ``x``. They are broadcast to the shape
        of ``x``, so their shape must be compatible: for ``x`` of shape
        ``(3, 4)``, the shapes ``(3, 4)``, ``(4,)``, ``(3, 1)``, ``(1, 1)``,
        ``(1)`` and ``()`` are all accepted.

    Returns
    -------
    dask array
        The reduced array; when ``out`` is given, the value returned by
        ``_handle_out(out, result)``.

    Raises
    ------
    ValueError
        If ``dtype`` is None, or if ``weights`` cannot be broadcast to the
        shape of ``x``.

    **Function Parameters**

    x_chunk: numpy.ndarray
        One input block. For ``chunk`` this is an original block of ``x``;
        for ``combine`` and ``aggregate`` it is the concatenation of the
        outputs of the previous ``chunk`` or ``combine`` step, or a list of
        those raw outputs when ``concatenate=False``.
    weights_chunk: numpy.ndarray, optional
        Only passed to ``chunk``: the weights block matching ``x_chunk`` in
        shape. The parameter may be omitted when no ``weights`` are given.
        If present it must come immediately after ``x_chunk`` and must have
        a default value for calls made without ``weights``.
    axis: tuple
        Normalized tuple of axes to reduce over, e.g. ``(0, )``; scalar,
        negative and None axes have already been normalized away. Some NumPy
        reductions only accept a single int axis and must be wrapped to cope.
    keepdims: bool
        Whether the function should keep or drop the reduced axes.

    """
    from dask_array._collection import Array

    # Wrap non-dask inputs (NumPy arrays, lists, ...) before anything else.
    if not isinstance(x, Array):
        from dask_array.core._conversion import asanyarray

        x = asanyarray(x)

    # Normalize ``axis`` into a validated tuple.
    if axis is None:
        axis = tuple(range(x.ndim))
    elif isinstance(axis, Integral):
        axis = (axis,)
    axis = validate_axis(axis, x.ndim)

    if dtype is None:
        raise ValueError("Must specify dtype")

    if is_series_like(x):
        x = x.values

    # Weights are broadcast to the full shape of ``x`` so they can be
    # processed block-for-block alongside it.
    weights_expr = None
    if weights is not None:
        from dask_array._broadcast import broadcast_to
        from dask_array.core._conversion import asanyarray

        wgt = asanyarray(weights)
        try:
            broadcast_wgt = broadcast_to(wgt, x.shape)
        except ValueError:
            raise ValueError(f"Weights with shape {wgt.shape} are not broadcastable to x with shape {x.shape}")
        weights_expr = broadcast_wgt.expr

    # Operands are positional and must follow Reduction._parameters order.
    operands = (
        x.expr,
        chunk,
        aggregate,
        axis,
        keepdims,
        dtype,
        split_every,
        combine,
        name,
        concatenate,
        output_size,
        meta,
        weights_expr,
    )
    result = new_collection(Reduction(*operands))

    if out is None:
        return result

    from dask_array.core._blockwise_funcs import _handle_out

    return _handle_out(out, result)
405
+
406
+
407
def _tree_reduce(
    x,
    aggregate,
    axis,
    keepdims,
    dtype,
    split_every=None,
    combine=None,
    name=None,
    concatenate=True,
    reduced_meta=None,
):
    """Wrap the tree-reduction cascade from ``_build_tree_reduce_expr`` in a collection.

    Only the combine/aggregate stage is built; no per-chunk step is applied.
    ``x`` is passed through unchanged, so it presumably must already be an
    expression (``_build_tree_reduce_expr`` reads ``x.numblocks``).

    This is a low-level helper; callers should normally use ``reduction`` or
    ``arg_reduction`` instead.
    """
    tree_expr = _build_tree_reduce_expr(
        x,
        aggregate,
        axis,
        keepdims,
        dtype,
        split_every=split_every,
        combine=combine,
        name=name,
        concatenate=concatenate,
        reduced_meta=reduced_meta,
    )
    return new_collection(tree_expr)
428
+
429
+
430
def _build_tree_reduce_expr(
    x,
    aggregate,
    axis,
    keepdims,
    dtype,
    split_every,
    combine,
    name,
    concatenate,
    reduced_meta,
):
    """Build the tree-reduction cascade of PartialReduce expressions.

    Shared implementation used by both ``Reduction._lower`` and ``_tree_reduce``.

    Parameters
    ----------
    x : ArrayExpr
        Expression holding the per-chunk results to be combined.
    aggregate : callable
        Final aggregation function, called with ``axis`` and ``keepdims``.
    axis : tuple of int
        Axes being reduced.
    keepdims : bool
        Whether the final aggregation keeps the reduced axes. Intermediate
        levels always use ``keepdims=True``.
    dtype : dtype-like
        Output dtype given to every PartialReduce level.
    split_every : int, dict or None
        Maximum fan-in per step. A falsy value falls back to
        ``config.get("split_every", 16)``. An int ``s`` becomes a per-axis
        fan-in of ``max(int(s ** (1 / len(axis))), 2)``. A dict maps axis to
        fan-in, with missing reduced axes defaulting to 2.
    combine : callable or None
        Intermediate combine function; defaults to ``aggregate``.
    name : str or None
        Key-name prefix; defaults to the function names.
    concatenate : bool
        If True, each step's inputs are first merged with ``_concatenate2``
        along the (sorted) reduced axes.
    reduced_meta : array-like or None
        Metadata of the per-chunk results, forwarded to every level.

    Returns
    -------
    PartialReduce
        The final aggregation level.

    Raises
    ------
    ValueError
        If ``split_every`` is neither an int nor a dict; ``math.log`` also
        raises ValueError for non-positive fan-ins.
    """
    # Normalize split_every into a dict {axis: fan-in}
    split_every = split_every or config.get("split_every", 16)
    if isinstance(split_every, dict):
        split_every = {k: split_every.get(k, 2) for k in axis}
    elif isinstance(split_every, Integral):
        # Spread the total fan-in over the reduced axes so that the product of
        # the per-axis fan-ins is roughly ``split_every``.
        n = builtins.max(int(split_every ** (1 / (len(axis) or 1))), 2)
        split_every = dict.fromkeys(axis, n)
    else:
        raise ValueError("split_every must be a int or a dict")

    # Tree depth: the number of levels needed along the most demanding axis.
    depth = 1
    for i, n in enumerate(x.numblocks):
        if i in split_every and split_every[i] != 1:
            k = split_every[i]
            levels = math.ceil(math.log(n, k))
            if isinstance(k, Integral) and k > 1:
                # ``math.log`` is a float ratio and can land just above an
                # integer (``math.log(125, 5) == 3.0000000000000004``), which
                # would add a useless extra level. Snap to the exact smallest
                # ``levels`` with ``k ** levels >= n``.
                while levels > 0 and k ** (levels - 1) >= n:
                    levels -= 1
                while k**levels < n:
                    levels += 1
            depth = int(builtins.max(depth, levels))

    # Build combine function
    func = partial(combine or aggregate, axis=axis, keepdims=True)
    if concatenate:
        func = compose(func, partial(_concatenate2, axes=sorted(axis)))

    # Build intermediate PartialReduce layers (always keepdims=True)
    for _ in range(depth - 1):
        x = PartialReduce(
            x,
            func,
            split_every,
            True,
            dtype=dtype,
            name=(name or funcname(combine or aggregate)) + "-partial",
            reduced_meta=reduced_meta,
        )

    # Build final aggregate function
    agg_func = partial(aggregate, axis=axis, keepdims=keepdims)
    if concatenate:
        agg_func = compose(agg_func, partial(_concatenate2, axes=sorted(axis)))

    # Final aggregation layer
    return PartialReduce(
        x,
        agg_func,
        split_every,
        keepdims=keepdims,
        dtype=dtype,
        name=(name or funcname(aggregate)) + "-aggregate",
        reduced_meta=reduced_meta,
    )
494
+
495
+
496
def _accept_slice_impl(slice_expr, input_array, reduced_axes, keepdims, make_result):
    """Shared implementation for slice pushdown through reductions.

    Rewrites ``reduce(x)[index]`` as ``reduce(x[input_index])``, followed by
    an extraction with ``0`` on axes that ``index`` selected by integer, so
    that only the needed part of the input is reduced.

    Parameters
    ----------
    slice_expr : SliceSlicesIntegers
        The slice expression being pushed through
    input_array : ArrayExpr
        The input array to the reduction
    reduced_axes : set
        Set of axes being reduced
    keepdims : bool
        Whether the reduction keeps dimensions
    make_result : callable(sliced_input, input_index) -> expr
        Factory function to create the result expression

    Returns
    -------
    expr or None
        The transformed expression, or None if the slice cannot be pushed
        through. That happens for ``None``/newaxis entries, a non-trivial
        index on a kept reduced axis, or a slice that would empty a
        non-reduced input axis.
    """
    from dask_array.slicing import SliceSlicesIntegers

    index = slice_expr.index

    # Don't handle None/newaxis
    if any(idx is None for idx in index):
        return None

    input_ndim = input_array.ndim

    if keepdims:
        # With keepdims, output has same ndim as input
        full_index = index + (slice(None),) * (input_ndim - len(index))
        # A kept reduced axis has length ``output_size`` in the output, and
        # its positions do not correspond to input positions. An index on it
        # therefore cannot be forwarded to the input: ``x.sum(0,
        # keepdims=True)[0]`` must not become ``x[0:1].sum(0, keepdims=True)[0]``.
        # Only push through when every reduced axis is left as ``slice(None)``.
        for ax in reduced_axes:
            idx = full_index[ax]
            if not (isinstance(idx, slice) and idx == slice(None)):
                return None
    else:
        # Without keepdims, reduced axes are removed from output
        out_axis = [i for i in range(input_ndim) if i not in reduced_axes]
        output_ndim = len(out_axis)
        full_index = index + (slice(None),) * (output_ndim - len(index))

    # Convert integers to size-1 slices to preserve dimensions
    slice_index = tuple(slice(idx, idx + 1) if isinstance(idx, Integral) else idx for idx in full_index)
    has_integers = any(isinstance(idx, Integral) for idx in full_index)

    # Build input index mapping output axes to input axes
    if keepdims:
        input_index = slice_index
    else:
        input_index = []
        out_pos = 0
        for in_ax in range(input_ndim):
            if in_ax in reduced_axes:
                # Reduced axes must be read in full.
                input_index.append(slice(None))
            else:
                input_index.append(slice_index[out_pos])
                out_pos += 1
        input_index = tuple(input_index)

    # Apply the slice to the input
    sliced_input = new_collection(input_array)[input_index]

    # Don't push slice through if it would create empty arrays on non-reduced axes
    for ax in range(input_ndim):
        if ax not in reduced_axes and sliced_input.shape[ax] == 0:
            return None

    result = make_result(sliced_input, input_index)

    # If we converted integers to slices, extract with [0] to restore dimensions
    if has_integers:
        extract_index = tuple(0 if isinstance(idx, Integral) else slice(None) for idx in full_index)
        return SliceSlicesIntegers(result, extract_index, slice_expr.allow_getitem_optimization)

    return result
570
+
571
+
572
class PartialReduce(ArrayExpr):
    """One level of a tree reduction: apply ``func`` to groups of input blocks.

    ``split_every`` is a dict mapping axis to group size. Along each of those
    axes, consecutive input blocks are partitioned into groups of at most
    ``split_every[axis]`` blocks, and each group becomes one output block of
    length 1 on that axis. Other axes keep their chunks. When ``keepdims`` is
    False, the axes in ``split_every`` are dropped from the output.
    """

    _parameters = [
        "array",
        "func",
        "split_every",
        "keepdims",
        "dtype",
        "name",
        "reduced_meta",
    ]
    _defaults = {
        "keepdims": False,
        "dtype": None,
        "name": None,
        "reduced_meta": None,
    }

    def __dask_tokenize__(self):
        # Note: ``name`` and ``reduced_meta`` are not part of the token;
        # ``name`` only contributes the key prefix in ``_name``.
        if not self._determ_token:
            # TODO: Is there an actual need to overwrite this?
            self._determ_token = _tokenize_deterministic(
                self.func, self.array, self.split_every, self.keepdims, self.dtype
            )
        return self._determ_token

    @cached_property
    def _name(self):
        return (self.operand("name") or funcname(self.func)) + "-" + self.deterministic_token

    @cached_property
    def dtype(self):
        # Use the explicitly passed dtype parameter instead of inferring from meta
        if self.operand("dtype") is not None:
            return np.dtype(self.operand("dtype"))
        return super().dtype

    @cached_property
    def chunks(self):
        """Output chunks: one length-1 chunk per block group on reduced axes."""
        chunks = [
            (tuple(1 for p in partition_all(self.split_every[i], c)) if i in self.split_every else c)
            for (i, c) in enumerate(self.array.chunks)
        ]

        if not self.keepdims:
            # Keep only the non-reduced axes, in order.
            out_axis = [i for i in range(self.array.ndim) if i not in self.split_every]
            getter = lambda k: get(out_axis, k)
            chunks = list(getter(chunks))
        return tuple(chunks)

    def _layer(self):
        """Build one task per output block applying ``func`` to its group of input keys."""
        x = self.array
        # Per axis: the block indices grouped by split_every (singleton groups
        # on non-reduced axes).
        parts = [list(partition_all(self.split_every.get(i, 1), range(n))) for (i, n) in enumerate(x.numblocks)]
        keys = product(*map(range, map(len, parts)))
        if not self.keepdims:
            # Drop reduced-axis positions from the output keys.
            out_axis = [i for i in range(x.ndim) if i not in self.split_every]
            getter = lambda k: get(out_axis, k)
            keys = map(getter, keys)
        dsk = {}
        for k, p in zip(keys, product(*parts)):
            # Non-reduced axes pin a single block index; reduced axes carry the
            # whole group, which lol_tuples expands into (nested) lists of input keys.
            free = {i: j[0] for (i, j) in enumerate(p) if len(j) == 1 and i not in self.split_every}
            dummy = dict(i for i in enumerate(p) if i[0] in self.split_every)
            g = lol_tuples((x.name,), range(x.ndim), free, dummy)
            dsk[(self._name,) + k] = (self.func, g)

        return dsk

    @property
    def _meta(self):
        """Metadata for this level's output.

        Starts from the input's meta. If ``reduced_meta`` is set, ``func`` is
        applied to it (first with ``computing_meta=True``, then without). The
        result is then coerced to an array of the output ndim and cast to the
        explicit ``dtype`` if one was given.
        """
        meta = self.array._meta
        original_dtype = getattr(self.reduced_meta, "dtype", None) or getattr(meta, "dtype", None)

        if self.reduced_meta is not None:
            try:
                meta = self.func(self.reduced_meta, computing_meta=True)
            except TypeError:
                # No computing_meta kwarg, try without it
                try:
                    meta = self.func(self.reduced_meta)
                except ValueError as e:
                    # Any other ValueError is swallowed here and ``meta``
                    # stays as the input's meta.
                    if "zero-size array to reduction operation" in str(e):
                        meta = self.reduced_meta
                except IndexError:
                    meta = self.reduced_meta
            except (ValueError, IndexError):
                # Can't compute on empty array (ufunc, argtopk, etc.)
                meta = self.reduced_meta

        # Ensure meta is array-like (func can return Python scalars for object dtype)
        if not is_arraylike(meta) and meta is not None:
            meta = np.array(meta, dtype=original_dtype or object)

        # Reshape meta to match output dimensions
        if is_arraylike(meta) and meta.ndim != len(self.chunks):
            if len(self.chunks) == 0:
                # 0D output - reduce to scalar
                try:
                    meta = meta.sum()
                    if not hasattr(meta, "dtype"):
                        meta = np.array(meta, dtype=original_dtype)
                except TypeError:
                    # dtype doesn't support sum (e.g., datetime64)
                    meta = np.empty((), dtype=meta.dtype)
            else:
                target_shape = (0,) * len(self.chunks)
                # Use np.prod(shape) for array-likes that don't expose .size
                meta_size = getattr(meta, "size", None)
                if meta_size is None:
                    meta_size = np.prod(meta.shape)
                if meta_size != 0:
                    # Can't reshape non-empty array to empty shape (e.g., scalar)
                    meta = np.empty(target_shape, dtype=meta.dtype)
                else:
                    meta = meta.reshape(target_shape)

        # Ensure meta has the correct dtype if dtype is explicitly specified
        if self.operand("dtype") is not None and hasattr(meta, "dtype"):
            target_dtype = np.dtype(self.operand("dtype"))
            if meta.dtype != target_dtype:
                with warnings.catch_warnings():
                    # Suppress ComplexWarning when casting complex to real (e.g., var)
                    warnings.filterwarnings("ignore", category=ComplexWarning)
                    meta = meta.astype(target_dtype)

        # Convert MaskedConstant (np.ma.masked) to a proper MaskedArray
        # since the singleton cannot be tokenized
        if isinstance(meta, np.ma.core.MaskedConstant):
            meta = np.ma.array(meta, ndmin=0)

        return meta

    def _simplify_up(self, parent, dependents):
        """Allow slice operations to push through PartialReduce."""
        from dask_array.slicing import SliceSlicesIntegers

        if isinstance(parent, SliceSlicesIntegers):
            return self._accept_slice(parent)
        return None

    def _accept_slice(self, slice_expr):
        """Accept a slice being pushed through this PartialReduce."""
        # The axes grouped by split_every are the ones this level reduces.
        reduced_axes = set(self.split_every.keys())

        def make_result(sliced_input, input_index):
            return PartialReduce(
                sliced_input.expr,
                self.func,
                self.split_every,
                self.keepdims,
                self.operand("dtype"),
                self.operand("name"),
                self.reduced_meta,
            )

        return _accept_slice_impl(slice_expr, self.array, reduced_axes, self.keepdims, make_result)