dask-array 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (144) hide show
  1. dask_array/__init__.py +228 -0
  2. dask_array/_backends.py +76 -0
  3. dask_array/_backends_array.py +99 -0
  4. dask_array/_blockwise.py +1410 -0
  5. dask_array/_broadcast.py +272 -0
  6. dask_array/_chunk.py +445 -0
  7. dask_array/_chunk_types.py +54 -0
  8. dask_array/_collection.py +1644 -0
  9. dask_array/_concatenate.py +331 -0
  10. dask_array/_core_utils.py +1365 -0
  11. dask_array/_dispatch.py +141 -0
  12. dask_array/_einsum.py +277 -0
  13. dask_array/_expr.py +544 -0
  14. dask_array/_expr_flow.py +586 -0
  15. dask_array/_gufunc.py +805 -0
  16. dask_array/_histogram.py +617 -0
  17. dask_array/_map_blocks.py +652 -0
  18. dask_array/_new_collection.py +10 -0
  19. dask_array/_numpy_compat.py +135 -0
  20. dask_array/_overlap.py +1159 -0
  21. dask_array/_rechunk.py +1050 -0
  22. dask_array/_reshape.py +710 -0
  23. dask_array/_routines.py +102 -0
  24. dask_array/_shuffle.py +448 -0
  25. dask_array/_stack.py +264 -0
  26. dask_array/_svg.py +291 -0
  27. dask_array/_templates.py +29 -0
  28. dask_array/_test_utils.py +257 -0
  29. dask_array/_ufunc.py +385 -0
  30. dask_array/_utils.py +349 -0
  31. dask_array/_visualize.py +223 -0
  32. dask_array/_xarray.py +337 -0
  33. dask_array/core/__init__.py +34 -0
  34. dask_array/core/_blockwise_funcs.py +312 -0
  35. dask_array/core/_conversion.py +422 -0
  36. dask_array/core/_from_graph.py +97 -0
  37. dask_array/creation/__init__.py +71 -0
  38. dask_array/creation/_arange.py +121 -0
  39. dask_array/creation/_diag.py +116 -0
  40. dask_array/creation/_diagonal.py +241 -0
  41. dask_array/creation/_eye.py +103 -0
  42. dask_array/creation/_linspace.py +102 -0
  43. dask_array/creation/_mesh.py +134 -0
  44. dask_array/creation/_ones_zeros.py +454 -0
  45. dask_array/creation/_pad.py +270 -0
  46. dask_array/creation/_repeat.py +55 -0
  47. dask_array/creation/_tile.py +36 -0
  48. dask_array/creation/_tri.py +28 -0
  49. dask_array/creation/_utils.py +296 -0
  50. dask_array/fft.py +320 -0
  51. dask_array/io/__init__.py +39 -0
  52. dask_array/io/_base.py +10 -0
  53. dask_array/io/_from_array.py +257 -0
  54. dask_array/io/_from_delayed.py +95 -0
  55. dask_array/io/_from_graph.py +54 -0
  56. dask_array/io/_from_npy_stack.py +67 -0
  57. dask_array/io/_store.py +336 -0
  58. dask_array/io/_tiledb.py +159 -0
  59. dask_array/io/_to_npy_stack.py +65 -0
  60. dask_array/io/_zarr.py +449 -0
  61. dask_array/linalg/__init__.py +39 -0
  62. dask_array/linalg/_cholesky.py +234 -0
  63. dask_array/linalg/_lu.py +300 -0
  64. dask_array/linalg/_norm.py +94 -0
  65. dask_array/linalg/_qr.py +601 -0
  66. dask_array/linalg/_solve.py +349 -0
  67. dask_array/linalg/_svd.py +394 -0
  68. dask_array/linalg/_tensordot.py +334 -0
  69. dask_array/linalg/_utils.py +74 -0
  70. dask_array/manipulation/__init__.py +45 -0
  71. dask_array/manipulation/_expand.py +321 -0
  72. dask_array/manipulation/_flip.py +92 -0
  73. dask_array/manipulation/_roll.py +78 -0
  74. dask_array/manipulation/_transpose.py +309 -0
  75. dask_array/random/__init__.py +125 -0
  76. dask_array/random/_choice.py +181 -0
  77. dask_array/random/_expr.py +256 -0
  78. dask_array/random/_generator.py +441 -0
  79. dask_array/random/_random_state.py +259 -0
  80. dask_array/random/_utils.py +84 -0
  81. dask_array/reductions/__init__.py +84 -0
  82. dask_array/reductions/_arg_reduction.py +130 -0
  83. dask_array/reductions/_common.py +1082 -0
  84. dask_array/reductions/_cumulative.py +522 -0
  85. dask_array/reductions/_percentile.py +261 -0
  86. dask_array/reductions/_reduction.py +725 -0
  87. dask_array/reductions/_trace.py +56 -0
  88. dask_array/routines/__init__.py +133 -0
  89. dask_array/routines/_apply.py +84 -0
  90. dask_array/routines/_bincount.py +112 -0
  91. dask_array/routines/_broadcast.py +111 -0
  92. dask_array/routines/_coarsen.py +115 -0
  93. dask_array/routines/_diff.py +79 -0
  94. dask_array/routines/_gradient.py +158 -0
  95. dask_array/routines/_indexing.py +65 -0
  96. dask_array/routines/_insert_delete.py +132 -0
  97. dask_array/routines/_misc.py +122 -0
  98. dask_array/routines/_nonzero.py +72 -0
  99. dask_array/routines/_search.py +123 -0
  100. dask_array/routines/_select.py +113 -0
  101. dask_array/routines/_statistics.py +171 -0
  102. dask_array/routines/_topk.py +82 -0
  103. dask_array/routines/_triangular.py +74 -0
  104. dask_array/routines/_unique.py +232 -0
  105. dask_array/routines/_where.py +62 -0
  106. dask_array/slicing/__init__.py +67 -0
  107. dask_array/slicing/_basic.py +550 -0
  108. dask_array/slicing/_blocks.py +138 -0
  109. dask_array/slicing/_bool_index.py +145 -0
  110. dask_array/slicing/_setitem.py +329 -0
  111. dask_array/slicing/_squeeze.py +101 -0
  112. dask_array/slicing/_utils.py +1133 -0
  113. dask_array/slicing/_vindex.py +282 -0
  114. dask_array/stacking/__init__.py +15 -0
  115. dask_array/stacking/_block.py +83 -0
  116. dask_array/stacking/_simple.py +58 -0
  117. dask_array/templates/array.html.j2 +48 -0
  118. dask_array/tests/__init__.py +0 -0
  119. dask_array/tests/conftest.py +22 -0
  120. dask_array/tests/test_api.py +40 -0
  121. dask_array/tests/test_binary_op_chunks.py +107 -0
  122. dask_array/tests/test_coarse_slice_through_blockwise.py +362 -0
  123. dask_array/tests/test_collection.py +799 -0
  124. dask_array/tests/test_creation.py +1102 -0
  125. dask_array/tests/test_expr_flow.py +143 -0
  126. dask_array/tests/test_linalg.py +1130 -0
  127. dask_array/tests/test_map_blocks_multi_output.py +104 -0
  128. dask_array/tests/test_rechunk_pushdown.py +214 -0
  129. dask_array/tests/test_reductions.py +1091 -0
  130. dask_array/tests/test_routines.py +2853 -0
  131. dask_array/tests/test_shuffle_chunks.py +67 -0
  132. dask_array/tests/test_slice_pushdown.py +968 -0
  133. dask_array/tests/test_slice_through_blockwise.py +678 -0
  134. dask_array/tests/test_slice_through_overlap.py +366 -0
  135. dask_array/tests/test_slice_through_reshape.py +272 -0
  136. dask_array/tests/test_slicing.py +839 -0
  137. dask_array/tests/test_transpose_slice_pushdown.py +208 -0
  138. dask_array/tests/test_visualize.py +94 -0
  139. dask_array/tests/test_xarray.py +193 -0
  140. dask_array-0.1.0.dist-info/METADATA +48 -0
  141. dask_array-0.1.0.dist-info/RECORD +144 -0
  142. dask_array-0.1.0.dist-info/WHEEL +4 -0
  143. dask_array-0.1.0.dist-info/entry_points.txt +2 -0
  144. dask_array-0.1.0.dist-info/licenses/LICENSE +29 -0
dask_array/_expr.py ADDED
@@ -0,0 +1,544 @@
1
+ from __future__ import annotations
2
+
3
+ import functools
4
+ import math
5
+ import re
6
+ import warnings
7
+ from functools import cached_property, reduce
8
+ from itertools import product
9
+ from operator import mul
10
+
11
+ import numpy as np
12
+ import toolz
13
+
14
+ from dask._expr import FinalizeCompute, SingletonExpr
15
+ from dask._task_spec import List, Task, TaskRef
16
+ from dask_array._core_utils import (
17
+ PerformanceWarning,
18
+ T_IntOrNaN,
19
+ common_blockdim,
20
+ unknown_chunk_message,
21
+ )
22
+ from dask.blockwise import broadcast_dimensions
23
+ from dask.layers import ArrayBlockwiseDep
24
+ from dask.utils import cached_cumsum, funcname
25
+
26
+ _OBJECT_AT_PATTERN = re.compile(r"<.+? at 0x[0-9a-fA-F]+>")
27
+
28
+
29
+ def _collect_cached_property_names(cls):
30
+ """Collect all cached_property names from a class and its parents."""
31
+ names = set()
32
+ for parent in cls.__mro__:
33
+ for k, v in parent.__dict__.items():
34
+ if isinstance(v, functools.cached_property):
35
+ names.add(k)
36
+ return frozenset(names)
37
+
38
+
39
+ def _simplify_repr(op):
40
+ """Simplify operand representation for tree_repr display."""
41
+ if isinstance(op, np.ndarray):
42
+ return "<array>"
43
+ if isinstance(op, np.dtype):
44
+ return str(op)
45
+ if callable(op):
46
+ return funcname(op)
47
+ # Simplify objects that show "object at 0x..." in repr
48
+ r = repr(op)
49
+ if " object at 0x" in r:
50
+ return f"<{type(op).__name__}>"
51
+ return op
52
+
53
+
54
def _clean_header(header):
    """Return ``header`` with every ``_OBJECT_AT_PATTERN`` match replaced by ``"..."``.

    This strips leftover verbose fragments such as
    ``<function foo at 0x...>`` or ``<X object at 0x...>``.
    """
    cleaned, _num_replaced = _OBJECT_AT_PATTERN.subn("...", header)
    return cleaned
58
+
59
+
60
class ArrayExpr(SingletonExpr):
    """Base class for the array expressions defined in this module.

    Provides chunk-derived metadata (``shape``, ``ndim``, ``numblocks`` ...),
    key generation, display helpers, pickling support and the ``rechunk`` /
    ``optimize`` entry points shared by subclasses.
    """

    # Whether this expression can be fused with other blockwise operations.
    # Override to True in subclasses that support fusion (Blockwise, Random, etc.)
    _is_blockwise_fusable = False

    def _all_input_block_ids(self, block_id):
        """Return all input block_ids for dependencies.

        Returns a dict mapping dep._name to a list of block_ids.
        This handles the case where the same dependency is used multiple
        times with different index mappings (e.g., da.dot(x, x)).

        Subclasses like Blockwise override this to iterate over all args.
        """
        result = {}
        for dep in self.dependencies():
            dep_block_id = self._input_block_id(dep, block_id)
            result.setdefault(dep._name, []).append(dep_block_id)
        return result

    def _input_block_id(self, dep, block_id):
        """Map output block_id to input block_id for a dependency.

        Default implementation returns the same block_id.
        Subclasses override for transformations like transpose.
        """
        return block_id

    # Pre-computed set of cached_property names for efficient serialization
    _cached_property_names: frozenset[str] = frozenset()

    def __init_subclass__(cls, **kwargs):
        """Record the cached_property names of each subclass at definition time."""
        super().__init_subclass__(**kwargs)
        cls._cached_property_names = _collect_cached_property_names(cls)

    def __reduce__(self):
        """Pickle support: rebuild through ``Expr._reconstruct``.

        Raises
        ------
        RuntimeError
            If the ``dask-expr-no-serialize`` config option is truthy.
        """
        import dask
        from dask._expr import Expr

        if dask.config.get("dask-expr-no-serialize", False):
            raise RuntimeError(f"Serializing a {type(self)} object")
        cache = {}
        # ``_pickle_functools_cache`` is not defined in this class; it is
        # presumably provided by a parent class.
        if type(self)._pickle_functools_cache:
            # Ship only the cached_property values that were already computed.
            instance_dict = self.__dict__
            for k in type(self)._cached_property_names:
                if k in instance_dict:
                    cache[k] = instance_dict[k]
        return Expr._reconstruct, (
            type(self),
            *self.operands,
            self.deterministic_token,
            cache,
        )

    def _operands_for_repr(self):
        """Return the operands to show in the repr (none by default)."""
        return []

    def _tree_repr_lines(self, indent=0, recursive=True):
        """Return the text lines describing this expression tree.

        Parameters
        ----------
        indent : int
            Number of spaces prepended to every returned line.
        recursive : bool
            If True, the lines of ``ArrayExpr`` operands are included.
        """
        header = funcname(type(self)) + ":"
        lines = []
        for i, op in enumerate(self.operands):
            if isinstance(op, ArrayExpr):
                if recursive:
                    lines.extend(op._tree_repr_lines(2))
            else:
                # Non-expression operands are folded into the header line.
                op = _simplify_repr(op)
                header = self._tree_repr_argument_construction(i, op, header)

        header = _clean_header(header)
        lines = [header] + lines
        lines = [" " * indent + line for line in lines]
        return lines

    def _table(self, color=True):
        """Display expression tree as a formatted table.

        Requires the `rich` library to be installed.
        """
        from dask_array._visualize import expr_table

        return expr_table(self, color=color)

    def _repr_html_(self):
        """Jupyter notebook display using rich table.

        Falls back to the plain tree repr inside ``<pre>`` when the table
        cannot be built. The fallback text is HTML-escaped, because it can
        contain placeholders such as ``<array>`` that a browser would
        otherwise parse as tags and drop.
        """
        try:
            return self._table()._repr_html_()
        except (ImportError, NotImplementedError):
            import html

            text = "\n".join(self._tree_repr_lines())
            return f"<pre>{html.escape(text)}</pre>"

    def __repr__(self):
        """Return rich table representation if available, else simple repr."""
        try:
            return repr(self._table())
        except Exception:
            # Deliberately broad: repr() must never fail. (ImportError and
            # NotImplementedError are already subclasses of Exception.)
            return str(self)

    def pprint(self):
        """Pretty print the expression tree using rich table if available."""
        try:
            self._table().print()
        except (ImportError, NotImplementedError):
            for line in self._tree_repr_lines():
                print(line)

    @cached_property
    def shape(self) -> tuple[T_IntOrNaN, ...]:
        """Array shape: the final cumulative sum of each dimension's chunks."""
        return tuple(cached_cumsum(c, initial_zero=True)[-1] for c in self.chunks)

    @cached_property
    def ndim(self):
        """Number of dimensions."""
        return len(self.shape)

    @cached_property
    def chunksize(self) -> tuple[T_IntOrNaN, ...]:
        """Largest chunk size along each dimension."""
        return tuple(max(c) for c in self.chunks)

    @cached_property
    def dtype(self):
        """dtype taken from ``_meta`` (from its first element if it is a tuple)."""
        if isinstance(self._meta, tuple):
            dtype = self._meta[0].dtype
        else:
            dtype = self._meta.dtype
        return dtype

    @cached_property
    def chunks(self):
        """Chunks of the array.

        Uses the ``chunks`` operand when the expression declares one.

        Raises
        ------
        NotImplementedError
            If ``chunks`` is not one of ``_parameters``.
        """
        if "chunks" in self._parameters:
            return self.operand("chunks")
        raise NotImplementedError("Subclass must implement 'chunks'")

    @cached_property
    def numblocks(self):
        """Number of chunks along each dimension."""
        return tuple(map(len, self.chunks))

    @cached_property
    def size(self) -> T_IntOrNaN:
        """Number of elements in array"""
        return reduce(mul, self.shape, 1)

    @property
    def name(self):
        """Alias for ``_name``."""
        return self._name

    def __len__(self):
        """Length of the first dimension.

        Raises
        ------
        TypeError
            If the array has no dimensions.
        ValueError
            If the chunk sizes of the first dimension are unknown (NaN).
        """
        if not self.chunks:
            raise TypeError("len() of unsized object")
        if np.isnan(self.chunks[0]).any():
            msg = f"Cannot call len() on object with unknown chunk size.{unknown_chunk_message}"
            raise ValueError(msg)
        return int(sum(self.chunks[0]))

    @functools.cached_property
    def _cached_keys(self):
        """Nested ``List`` of ``TaskRef`` objects for the fully lowered expression."""
        out = self.lower_completely()

        name, chunks, numblocks = out.name, out.chunks, out.numblocks

        def keys(*args):
            # ``args`` is the block-index prefix built so far.
            if not chunks:
                return List(TaskRef((name,)))
            ind = len(args)
            if ind + 1 == len(numblocks):
                # Last dimension: emit the references themselves.
                result = List(*(TaskRef((name,) + args + (i,)) for i in range(numblocks[ind])))
            else:
                result = List(*(keys(*(args + (i,))) for i in range(numblocks[ind])))
            return result

        return keys()

    def __dask_keys__(self):
        """Return the output keys as nested plain lists."""
        key_refs = self._cached_keys

        def unwrap(task):
            if isinstance(task, List):
                return [unwrap(t) for t in task.args]
            return task.key

        return unwrap(key_refs)

    def __hash__(self):
        return hash(self._name)

    def optimize(self, fuse: bool = True):
        """Simplify and fully lower the expression, optionally fusing it.

        Parameters
        ----------
        fuse : bool
            If True, ``fuse()`` is applied to the lowered expression.
        """
        expr = self.simplify().lower_completely()
        if fuse:
            expr = expr.fuse()
        return expr

    def fuse(self):
        """Apply ``optimize_blockwise_fusion_array`` to this expression."""
        from dask_array._blockwise import optimize_blockwise_fusion_array

        return optimize_blockwise_fusion_array(self)

    def rechunk(
        self,
        chunks="auto",
        threshold=None,
        block_size_limit=None,
        balance=False,
        method=None,
    ):
        """Return an expression with the requested chunking.

        ``chunks`` may be a dict keyed by axis, a tuple/list with one entry
        per dimension (``None`` keeps the current chunks of that dimension),
        or any other value accepted by ``normalize_chunks``. ``self`` is
        returned unchanged when every dimension has size 0 or when the
        resolved chunks already match and ``balance`` is false.
        """
        if self.ndim > 0 and all(s == 0 for s in self.shape):
            return self

        from dask_array._rechunk import Rechunk
        from dask_array._core_utils import normalize_chunks
        from dask_array._utils import validate_axis

        # Pre-resolve chunks to check for no-op and avoid singleton caching issues
        resolved_chunks = chunks
        if isinstance(chunks, dict):
            normalized_dict = {validate_axis(k, self.ndim): v for k, v in chunks.items()}
            resolved_chunks = tuple(
                (normalized_dict[i] if i in normalized_dict and normalized_dict[i] is not None else self.chunks[i])
                for i in range(self.ndim)
            )
        if isinstance(resolved_chunks, (tuple, list)):
            resolved_chunks = tuple(lc if lc is not None else rc for lc, rc in zip(resolved_chunks, self.chunks))
        resolved_chunks = normalize_chunks(
            resolved_chunks,
            self.shape,
            limit=block_size_limit,
            dtype=self.dtype,
            previous_chunks=self.chunks,
        )

        # No-op rechunk: if chunks already match, return self
        if not balance and resolved_chunks == self.chunks:
            return self

        result = Rechunk(self, resolved_chunks, threshold, block_size_limit, balance, method)
        # Ensure that chunks are compatible
        result.chunks
        return result

    def finalize_compute(self):
        """Wrap this expression in ``FinalizeComputeArray``."""
        return FinalizeComputeArray(self)
298
+
299
+
300
def coarse_blockdim(blockdims):
    """Find the coarsest block dimension from a set of block dimensions.

    Among the multi-chunk entries, prefers the chunking with the fewest
    blocks, which results in larger chunk sizes and fewer tasks. The
    finer-grained inputs will be rechunked to match.

    Unlike common_blockdim which finds the finest common divisor, this
    function prefers larger chunks to minimize task overhead. However, if
    the chunk boundaries don't align (the coarsest chunking's boundaries
    aren't a subset of another's), falls back to common_blockdim behavior.

    Single-chunk entries (tuples of length 1) are set aside before the
    comparison, so they never win over a multi-chunk entry.

    Parameters
    ----------
    blockdims : set of tuples
        Set of chunk tuples for a single dimension

    Returns
    -------
    tuple
        The preferred chunk tuple (fewest blocks if alignable, otherwise
        the result of common_blockdim)

    Raises
    ------
    ValueError
        If some chunks are unknown (NaN) and the entries differ in length,
        or if the multi-chunk entries do not all sum to the same total.

    Examples
    --------
    >>> coarse_blockdim({(12, 12), (6, 6, 6, 6)})  # aligned - prefer fewer chunks
    (12, 12)
    >>> coarse_blockdim({(10,), (5, 5)})  # single-chunk entry is set aside
    (5, 5)
    >>> coarse_blockdim({(4, 6), (6, 4)})  # incompatible - use common divisor
    (4, 2, 4)
    """
    if not any(blockdims):
        return ()

    # Handle unknown chunks - same logic as common_blockdim
    unknown_dims = [d for d in blockdims if np.isnan(sum(d))]
    if unknown_dims:
        all_lengths = {len(d) for d in blockdims}
        if len(all_lengths) > 1:
            raise ValueError(
                "Chunks are unknown or misaligned along dimensions with missing values.\n\n"
                "A possible solution:\n  x.compute_chunk_sizes()"
            )
        return toolz.first(unknown_dims)

    # Set aside single-chunk entries (one block spanning the whole
    # dimension) - they don't constrain chunking
    non_trivial_dims = {d for d in blockdims if len(d) > 1}

    if len(non_trivial_dims) == 0:
        # All entries are single-chunk: pick the one with the largest chunk
        return max(blockdims, key=toolz.first)

    if len(non_trivial_dims) == 1:
        # Only one non-trivial, use it
        return toolz.first(non_trivial_dims)

    # Multiple non-trivial dimensions - verify they have the same total size
    if len(set(map(sum, non_trivial_dims))) > 1:
        raise ValueError("Chunks do not add up to same value", blockdims)

    # Find the coarsest chunking (fewest blocks)
    coarsest = min(non_trivial_dims, key=len)

    # Check if all other chunkings have boundaries that align with the coarsest
    # i.e., the coarsest boundaries are a subset of each other chunking's boundaries
    coarsest_boundaries = set(np.cumsum(coarsest[:-1]))

    for chunks in non_trivial_dims:
        if chunks == coarsest:
            continue
        other_boundaries = set(np.cumsum(chunks[:-1]))
        if not coarsest_boundaries.issubset(other_boundaries):
            # Boundaries don't align - fall back to common_blockdim
            return common_blockdim(blockdims)

    # All boundaries align with the coarsest, so use it
    return coarsest
378
+
379
+
380
def unify_chunks_expr(*args, warn=True):
    """Unify the chunks of several array expressions along shared indices.

    Parameters
    ----------
    *args
        Alternating ``array, index`` pairs. An index of ``None`` marks a
        literal and ``()`` a scalar; both are passed through untouched, as
        are ``ArrayBlockwiseDep`` instances.
    warn : bool
        If True, emit a ``PerformanceWarning`` when unification multiplies
        the number of chunks by 10 or more.

    Returns
    -------
    chunkss : dict
        Mapping from index label to the unified chunks for that label.
    arrays : list
        The inputs, in order, rechunked where necessary.
    changed : bool
        True if any input was rechunked.
    """
    # TODO(expr): This should probably be a dedicated expression
    # This is the implementation that expects the inputs to be expressions, the public facing
    # variant needs to sanitize the inputs
    if not args:
        return {}, [], False
    # Pair up (array, index); a trailing unpaired argument is dropped.
    arginds = list(zip(args[::2], args[1::2]))
    arrays, inds = zip(*arginds)
    if all(ind is None for ind in inds):
        return {}, list(arrays), False
    if all(ind == inds[0] for ind in inds) and all(a.chunks == arrays[0].chunks for a in arrays):
        # Fast path: identical indices and chunks. Return a list, like every
        # other return path of this function.
        return dict(zip(inds[0], arrays[0].chunks)), list(arrays), False

    nameinds = []
    blockdim_dict = dict()
    max_parts = 0
    for a, ind in arginds:
        # Skip scalars (empty tuple index), literals (None), and ArrayBlockwiseDep
        if ind is not None and ind != () and not isinstance(a, ArrayBlockwiseDep):
            nameinds.append((a.name, ind))
            blockdim_dict[a.name] = a.chunks
            max_parts = max(max_parts, math.prod(a.numblocks))
        else:
            nameinds.append((a, ind))

    chunkss = broadcast_dimensions(nameinds, blockdim_dict, consolidate=coarse_blockdim)
    nparts = math.prod(map(len, chunkss.values())) if chunkss else 0

    # ``max_parts`` is checked explicitly so the ratio below cannot divide by zero.
    if warn and nparts and max_parts and nparts >= max_parts * 10:
        warnings.warn(
            f"Increasing number of chunks by factor of {int(nparts / max_parts)}",
            PerformanceWarning,
            stacklevel=3,
        )

    arrays = []
    changed = False
    for a, i in arginds:
        # Scalars, literals and ArrayBlockwiseDep are appended unchanged.
        if not (i is None or i == () or isinstance(a, ArrayBlockwiseDep)):
            chunks = tuple(
                (chunkss[j] if a.shape[n] > 1 else (a.shape[n],) if not np.isnan(sum(chunkss[j])) else None)
                for n, j in enumerate(i)
            )
            if chunks != a.chunks and all(a.chunks):
                # Skip rechunking known chunks to unknown - can't rechunk to nan sizes
                target_has_nan = any(c is not None and np.isnan(sum(c)) for c in chunks)
                source_is_known = not any(np.isnan(sum(c)) for c in a.chunks)
                if not (target_has_nan and source_is_known):
                    a = a.rechunk(chunks)
                    changed = True
        arrays.append(a)
    return chunkss, arrays, changed
434
+
435
+
436
+ # Import ConcatenateFinalize from its module
437
+ from dask_array._concatenate import ConcatenateFinalize
438
+
439
+
440
+ def _copy_array(x):
441
+ """Copy an array to prevent mutation of graph-stored data."""
442
+ try:
443
+ return x.copy() # numpy, sparse, scipy.sparse
444
+ except AttributeError:
445
+ return x # Not an Array API object
446
+
447
+
448
class CopyArray(ArrayExpr):
    """Expression whose blocks are copies of the wrapped array's blocks.

    When a single-chunk array is computed, the result might be a reference
    to data stored in the task graph. Wrapping it in this expression makes
    sure a copy is handed out, so modifying the result leaves the graph
    data untouched.
    """

    _parameters = ["array"]

    @cached_property
    def _name(self):
        return "copy-{}".format(self.deterministic_token)

    @cached_property
    def _meta(self):
        # Metadata is identical to the wrapped array's.
        return self.array._meta

    @cached_property
    def chunks(self):
        # Copying does not change the chunk layout.
        return self.array.chunks

    @property
    def dtype(self):
        return self.array.dtype

    def _layer(self):
        """Build one ``_copy_array`` task per block of the wrapped array."""
        source = self.array
        layer = {}
        for block_id in product(*(range(len(sizes)) for sizes in source.chunks)):
            out_key = (self._name, *block_id)
            in_key = (source._name, *block_id)
            layer[out_key] = Task(out_key, _copy_array, TaskRef(in_key))
        return layer
482
+
483
+
484
class FinalizeComputeArray(FinalizeCompute, ArrayExpr):
    """Finalization step that presents ``arr`` as a single-chunk array."""

    _parameters = ["arr"]

    @cached_property
    def chunks(self):
        # One chunk per dimension, spanning the full extent.
        return tuple((extent,) for extent in self.arr.shape)

    def _simplify_down(self):
        """Pick the expression that produces the single-chunk result.

        * every dimension already has one block -> ``CopyArray``, so that
          mutating the result cannot affect graph-stored data used by
          subsequent computes
        * some dimension has unknown (NaN) size -> ``ConcatenateFinalize``,
          because rechunking requires known shapes
        * otherwise -> ``Rechunk`` to ``-1`` along every dimension with
          ``method="tasks"``
        """
        source = self.arr
        if all(count == 1 for count in source.numblocks):
            return CopyArray(source)
        if any(np.isnan(extent) for extent in source.shape):
            return ConcatenateFinalize(source)

        from dask_array._rechunk import Rechunk

        return Rechunk(source, (-1,) * source.ndim, method="tasks")
509
+
510
+
511
class ChunksOverride(ArrayExpr):
    """Override chunks metadata for an array expression.

    The layer consists purely of aliases to the wrapped array's blocks, so
    the underlying computation is preserved while ``chunks`` reports the
    supplied ``_chunks`` value. Useful when the actual output chunk sizes
    differ from what the expression system infers (e.g., boolean indexing
    produces unknown chunk sizes).
    """

    _parameters = ["array", "_chunks"]

    @cached_property
    def _name(self):
        return "chunks-override-{}".format(self.deterministic_token)

    @cached_property
    def _meta(self):
        return self.array._meta

    @cached_property
    def chunks(self):
        # Report the overriding chunks rather than the wrapped array's.
        return self._chunks

    def _layer(self) -> dict:
        """Alias every output block key to the same block of the wrapped array."""
        from dask._task_spec import Alias

        own_name = self._name
        source_name = self.array._name
        block_ids = product(*(range(len(sizes)) for sizes in self._chunks))
        return {
            (own_name, *idx): Alias((own_name, *idx), (source_name, *idx))
            for idx in block_ids
        }