dask-array 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dask_array/__init__.py +228 -0
- dask_array/_backends.py +76 -0
- dask_array/_backends_array.py +99 -0
- dask_array/_blockwise.py +1410 -0
- dask_array/_broadcast.py +272 -0
- dask_array/_chunk.py +445 -0
- dask_array/_chunk_types.py +54 -0
- dask_array/_collection.py +1644 -0
- dask_array/_concatenate.py +331 -0
- dask_array/_core_utils.py +1365 -0
- dask_array/_dispatch.py +141 -0
- dask_array/_einsum.py +277 -0
- dask_array/_expr.py +544 -0
- dask_array/_expr_flow.py +586 -0
- dask_array/_gufunc.py +805 -0
- dask_array/_histogram.py +617 -0
- dask_array/_map_blocks.py +652 -0
- dask_array/_new_collection.py +10 -0
- dask_array/_numpy_compat.py +135 -0
- dask_array/_overlap.py +1159 -0
- dask_array/_rechunk.py +1050 -0
- dask_array/_reshape.py +710 -0
- dask_array/_routines.py +102 -0
- dask_array/_shuffle.py +448 -0
- dask_array/_stack.py +264 -0
- dask_array/_svg.py +291 -0
- dask_array/_templates.py +29 -0
- dask_array/_test_utils.py +257 -0
- dask_array/_ufunc.py +385 -0
- dask_array/_utils.py +349 -0
- dask_array/_visualize.py +223 -0
- dask_array/_xarray.py +337 -0
- dask_array/core/__init__.py +34 -0
- dask_array/core/_blockwise_funcs.py +312 -0
- dask_array/core/_conversion.py +422 -0
- dask_array/core/_from_graph.py +97 -0
- dask_array/creation/__init__.py +71 -0
- dask_array/creation/_arange.py +121 -0
- dask_array/creation/_diag.py +116 -0
- dask_array/creation/_diagonal.py +241 -0
- dask_array/creation/_eye.py +103 -0
- dask_array/creation/_linspace.py +102 -0
- dask_array/creation/_mesh.py +134 -0
- dask_array/creation/_ones_zeros.py +454 -0
- dask_array/creation/_pad.py +270 -0
- dask_array/creation/_repeat.py +55 -0
- dask_array/creation/_tile.py +36 -0
- dask_array/creation/_tri.py +28 -0
- dask_array/creation/_utils.py +296 -0
- dask_array/fft.py +320 -0
- dask_array/io/__init__.py +39 -0
- dask_array/io/_base.py +10 -0
- dask_array/io/_from_array.py +257 -0
- dask_array/io/_from_delayed.py +95 -0
- dask_array/io/_from_graph.py +54 -0
- dask_array/io/_from_npy_stack.py +67 -0
- dask_array/io/_store.py +336 -0
- dask_array/io/_tiledb.py +159 -0
- dask_array/io/_to_npy_stack.py +65 -0
- dask_array/io/_zarr.py +449 -0
- dask_array/linalg/__init__.py +39 -0
- dask_array/linalg/_cholesky.py +234 -0
- dask_array/linalg/_lu.py +300 -0
- dask_array/linalg/_norm.py +94 -0
- dask_array/linalg/_qr.py +601 -0
- dask_array/linalg/_solve.py +349 -0
- dask_array/linalg/_svd.py +394 -0
- dask_array/linalg/_tensordot.py +334 -0
- dask_array/linalg/_utils.py +74 -0
- dask_array/manipulation/__init__.py +45 -0
- dask_array/manipulation/_expand.py +321 -0
- dask_array/manipulation/_flip.py +92 -0
- dask_array/manipulation/_roll.py +78 -0
- dask_array/manipulation/_transpose.py +309 -0
- dask_array/random/__init__.py +125 -0
- dask_array/random/_choice.py +181 -0
- dask_array/random/_expr.py +256 -0
- dask_array/random/_generator.py +441 -0
- dask_array/random/_random_state.py +259 -0
- dask_array/random/_utils.py +84 -0
- dask_array/reductions/__init__.py +84 -0
- dask_array/reductions/_arg_reduction.py +130 -0
- dask_array/reductions/_common.py +1082 -0
- dask_array/reductions/_cumulative.py +522 -0
- dask_array/reductions/_percentile.py +261 -0
- dask_array/reductions/_reduction.py +725 -0
- dask_array/reductions/_trace.py +56 -0
- dask_array/routines/__init__.py +133 -0
- dask_array/routines/_apply.py +84 -0
- dask_array/routines/_bincount.py +112 -0
- dask_array/routines/_broadcast.py +111 -0
- dask_array/routines/_coarsen.py +115 -0
- dask_array/routines/_diff.py +79 -0
- dask_array/routines/_gradient.py +158 -0
- dask_array/routines/_indexing.py +65 -0
- dask_array/routines/_insert_delete.py +132 -0
- dask_array/routines/_misc.py +122 -0
- dask_array/routines/_nonzero.py +72 -0
- dask_array/routines/_search.py +123 -0
- dask_array/routines/_select.py +113 -0
- dask_array/routines/_statistics.py +171 -0
- dask_array/routines/_topk.py +82 -0
- dask_array/routines/_triangular.py +74 -0
- dask_array/routines/_unique.py +232 -0
- dask_array/routines/_where.py +62 -0
- dask_array/slicing/__init__.py +67 -0
- dask_array/slicing/_basic.py +550 -0
- dask_array/slicing/_blocks.py +138 -0
- dask_array/slicing/_bool_index.py +145 -0
- dask_array/slicing/_setitem.py +329 -0
- dask_array/slicing/_squeeze.py +101 -0
- dask_array/slicing/_utils.py +1133 -0
- dask_array/slicing/_vindex.py +282 -0
- dask_array/stacking/__init__.py +15 -0
- dask_array/stacking/_block.py +83 -0
- dask_array/stacking/_simple.py +58 -0
- dask_array/templates/array.html.j2 +48 -0
- dask_array/tests/__init__.py +0 -0
- dask_array/tests/conftest.py +22 -0
- dask_array/tests/test_api.py +40 -0
- dask_array/tests/test_binary_op_chunks.py +107 -0
- dask_array/tests/test_coarse_slice_through_blockwise.py +362 -0
- dask_array/tests/test_collection.py +799 -0
- dask_array/tests/test_creation.py +1102 -0
- dask_array/tests/test_expr_flow.py +143 -0
- dask_array/tests/test_linalg.py +1130 -0
- dask_array/tests/test_map_blocks_multi_output.py +104 -0
- dask_array/tests/test_rechunk_pushdown.py +214 -0
- dask_array/tests/test_reductions.py +1091 -0
- dask_array/tests/test_routines.py +2853 -0
- dask_array/tests/test_shuffle_chunks.py +67 -0
- dask_array/tests/test_slice_pushdown.py +968 -0
- dask_array/tests/test_slice_through_blockwise.py +678 -0
- dask_array/tests/test_slice_through_overlap.py +366 -0
- dask_array/tests/test_slice_through_reshape.py +272 -0
- dask_array/tests/test_slicing.py +839 -0
- dask_array/tests/test_transpose_slice_pushdown.py +208 -0
- dask_array/tests/test_visualize.py +94 -0
- dask_array/tests/test_xarray.py +193 -0
- dask_array-0.1.0.dist-info/METADATA +48 -0
- dask_array-0.1.0.dist-info/RECORD +144 -0
- dask_array-0.1.0.dist-info/WHEEL +4 -0
- dask_array-0.1.0.dist-info/entry_points.txt +2 -0
- dask_array-0.1.0.dist-info/licenses/LICENSE +29 -0
dask_array/_expr.py
ADDED
|
@@ -0,0 +1,544 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import functools
|
|
4
|
+
import math
|
|
5
|
+
import re
|
|
6
|
+
import warnings
|
|
7
|
+
from functools import cached_property, reduce
|
|
8
|
+
from itertools import product
|
|
9
|
+
from operator import mul
|
|
10
|
+
|
|
11
|
+
import numpy as np
|
|
12
|
+
import toolz
|
|
13
|
+
|
|
14
|
+
from dask._expr import FinalizeCompute, SingletonExpr
|
|
15
|
+
from dask._task_spec import List, Task, TaskRef
|
|
16
|
+
from dask_array._core_utils import (
|
|
17
|
+
PerformanceWarning,
|
|
18
|
+
T_IntOrNaN,
|
|
19
|
+
common_blockdim,
|
|
20
|
+
unknown_chunk_message,
|
|
21
|
+
)
|
|
22
|
+
from dask.blockwise import broadcast_dimensions
|
|
23
|
+
from dask.layers import ArrayBlockwiseDep
|
|
24
|
+
from dask.utils import cached_cumsum, funcname
|
|
25
|
+
|
|
26
|
+
# Matches default "<... at 0x...>" reprs, such as "<function foo at 0x7f...>"
# or "<Foo object at 0x...>"; _clean_header collapses these to "...".
_OBJECT_AT_PATTERN: re.Pattern[str] = re.compile(r"<.+? at 0x[0-9a-fA-F]+>")
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def _collect_cached_property_names(cls):
    """Return the names of every ``functools.cached_property`` visible on *cls*.

    The whole MRO is scanned, so cached properties inherited from base
    classes are reported alongside those defined on *cls* itself.
    """
    return frozenset(
        attr_name
        for klass in cls.__mro__
        for attr_name, attr in vars(klass).items()
        if isinstance(attr, functools.cached_property)
    )
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def _simplify_repr(op):
    """Return a compact stand-in for *op* suitable for tree_repr headers.

    - numpy arrays become the placeholder ``"<array>"``;
    - numpy dtypes become their string name;
    - callables become their function name (via ``funcname``);
    - objects whose repr contains ``" object at 0x"`` become ``"<TypeName>"``;
    - anything else is returned unchanged.
    """
    if isinstance(op, np.ndarray):
        simplified = "<array>"
    elif isinstance(op, np.dtype):
        simplified = str(op)
    elif callable(op):
        simplified = funcname(op)
    elif " object at 0x" in repr(op):
        # Default object reprs embed a memory address; show only the type.
        simplified = f"<{type(op).__name__}>"
    else:
        simplified = op
    return simplified
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def _clean_header(header):
    """Scrub memory-address reprs out of a tree_repr header string.

    Every ``<... at 0x...>`` fragment (for example ``<function foo at 0x...>``
    or ``<X object at 0x...>``) matched by ``_OBJECT_AT_PATTERN`` is replaced
    with ``"..."``.
    """
    cleaned = _OBJECT_AT_PATTERN.sub("...", header)
    return cleaned
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
class ArrayExpr(SingletonExpr):
    """Base class for dask array expressions.

    Derives array metadata (``shape``, ``ndim``, ``chunksize``, ``numblocks``,
    ``size``) from ``chunks`` and ``dtype`` from ``_meta``, and provides output
    key generation, pickling, pretty-printing and the ``optimize``/``fuse``/
    ``rechunk`` entry points. ``chunks`` comes from a ``"chunks"`` parameter
    unless a subclass overrides it; ``_meta`` is not defined in this class.
    """

    # Whether this expression can be fused with other blockwise operations.
    # Override to True in subclasses that support fusion (Blockwise, Random, etc.)
    _is_blockwise_fusable = False

    # Pre-computed set of cached_property names for efficient serialization.
    # Recomputed for every subclass by ``__init_subclass__``.
    _cached_property_names: frozenset[str] = frozenset()

    def _all_input_block_ids(self, block_id):
        """Return all input block_ids for dependencies.

        Returns a dict mapping dep._name to a list of block_ids.
        This handles the case where the same dependency is used multiple
        times with different index mappings (e.g., da.dot(x, x)).

        Subclasses like Blockwise override this to iterate over all args.
        """
        result = {}
        for dep in self.dependencies():
            result.setdefault(dep._name, []).append(self._input_block_id(dep, block_id))
        return result

    def _input_block_id(self, dep, block_id):
        """Map output block_id to input block_id for a dependency.

        Default implementation returns the same block_id.
        Subclasses override for transformations like transpose.
        """
        return block_id

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # Walk the MRO once per class so __reduce__ can use the cached set.
        cls._cached_property_names = _collect_cached_property_names(cls)

    def __reduce__(self):
        """Pickle as ``Expr._reconstruct(type, *operands, token, cache)``.

        Raises
        ------
        RuntimeError
            If the ``dask-expr-no-serialize`` config option is set.
        """
        import dask
        from dask._expr import Expr

        if dask.config.get("dask-expr-no-serialize", False):
            raise RuntimeError(f"Serializing a {type(self)} object")
        cache = {}
        if type(self)._pickle_functools_cache:
            # Ship already-computed cached_property values with the operands;
            # presumably Expr._reconstruct restores them on unpickling.
            for k in type(self)._cached_property_names:
                if k in self.__dict__:
                    cache[k] = self.__dict__[k]
        return Expr._reconstruct, (
            type(self),
            *self.operands,
            self.deterministic_token,
            cache,
        )

    def _operands_for_repr(self):
        return []

    def _tree_repr_lines(self, indent=0, recursive=True):
        """Return the expression tree as a list of indented text lines."""
        header = funcname(type(self)) + ":"
        lines = []
        for i, op in enumerate(self.operands):
            if isinstance(op, ArrayExpr):
                # Child expressions get their own (indented) lines.
                if recursive:
                    lines.extend(op._tree_repr_lines(2))
            else:
                # Only non-expression operands are spelled out in the header.
                op = _simplify_repr(op)
                header = self._tree_repr_argument_construction(i, op, header)

        header = _clean_header(header)
        lines = [header] + lines
        lines = [" " * indent + line for line in lines]
        return lines

    def _table(self, color=True):
        """Display expression tree as a formatted table.

        Requires the `rich` library to be installed.
        """
        from dask_array._visualize import expr_table

        return expr_table(self, color=color)

    def _repr_html_(self):
        """Jupyter notebook display using rich table.

        Falls back to an HTML-escaped ``<pre>`` block of the plain tree repr
        when the table cannot be built.
        """
        import html

        try:
            return self._table()._repr_html_()
        except (ImportError, NotImplementedError):
            # Escape the text: headers contain placeholders such as "<array>"
            # or "<Foo>" that a browser would otherwise drop as unknown tags.
            return "<pre>" + html.escape("\n".join(self._tree_repr_lines())) + "</pre>"

    def __repr__(self):
        """Return rich table representation if available, else simple repr."""
        try:
            return repr(self._table())
        except Exception:
            # Deliberately broad: repr() must never raise.
            return str(self)

    def pprint(self):
        """Pretty print the expression tree using rich table if available."""
        try:
            self._table().print()
        except (ImportError, NotImplementedError):
            for line in self._tree_repr_lines():
                print(line)

    @cached_property
    def shape(self) -> tuple[T_IntOrNaN, ...]:
        # Length of each axis is the sum of its chunk sizes (NaN if unknown).
        return tuple(cached_cumsum(c, initial_zero=True)[-1] for c in self.chunks)

    @cached_property
    def ndim(self):
        return len(self.shape)

    @cached_property
    def chunksize(self) -> tuple[T_IntOrNaN, ...]:
        # Largest chunk along each axis.
        return tuple(max(c) for c in self.chunks)

    @cached_property
    def dtype(self):
        # Multi-output expressions carry a tuple of metas; use the first.
        if isinstance(self._meta, tuple):
            dtype = self._meta[0].dtype
        else:
            dtype = self._meta.dtype
        return dtype

    @cached_property
    def chunks(self):
        if "chunks" in self._parameters:
            return self.operand("chunks")
        raise NotImplementedError("Subclass must implement 'chunks'")

    @cached_property
    def numblocks(self):
        return tuple(map(len, self.chunks))

    @cached_property
    def size(self) -> T_IntOrNaN:
        """Number of elements in array"""
        return reduce(mul, self.shape, 1)

    @property
    def name(self):
        return self._name

    def __len__(self):
        """Return the length of the first axis.

        Raises
        ------
        TypeError
            For 0-d arrays.
        ValueError
            If the first axis has unknown (NaN) chunk sizes.
        """
        if not self.chunks:
            raise TypeError("len() of unsized object")
        if np.isnan(self.chunks[0]).any():
            msg = f"Cannot call len() on object with unknown chunk size.{unknown_chunk_message}"
            raise ValueError(msg)
        return int(sum(self.chunks[0]))

    @functools.cached_property
    def _cached_keys(self):
        """Nested ``List`` of ``TaskRef``s for the fully-lowered expression's keys."""
        out = self.lower_completely()

        name, chunks, numblocks = out.name, out.chunks, out.numblocks

        def keys(*args):
            # 0-d arrays have a single key ``(name,)``.
            if not chunks:
                return List(TaskRef((name,)))
            ind = len(args)
            if ind + 1 == len(numblocks):
                # Innermost axis: emit the actual key references.
                result = List(*(TaskRef((name,) + args + (i,)) for i in range(numblocks[ind])))
            else:
                result = List(*(keys(*(args + (i,))) for i in range(numblocks[ind])))
            return result

        return keys()

    def __dask_keys__(self):
        """Return the output keys as nested Python lists (one level per axis)."""
        key_refs = self._cached_keys

        def unwrap(task):
            if isinstance(task, List):
                return [unwrap(t) for t in task.args]
            return task.key

        return unwrap(key_refs)

    def __hash__(self):
        return hash(self._name)

    def optimize(self, fuse: bool = True):
        """Simplify and fully lower the expression, then optionally fuse it."""
        expr = self.simplify().lower_completely()
        if fuse:
            expr = expr.fuse()
        return expr

    def fuse(self):
        from dask_array._blockwise import optimize_blockwise_fusion_array

        return optimize_blockwise_fusion_array(self)

    def rechunk(
        self,
        chunks="auto",
        threshold=None,
        block_size_limit=None,
        balance=False,
        method=None,
    ):
        """Return an expression rechunked to ``chunks``.

        Parameters
        ----------
        chunks : int, str, tuple, list or dict, default "auto"
            Target chunking. In a dict, axes that are missing or map to
            ``None`` keep their current chunks; in a sequence with one entry
            per axis, ``None`` entries do the same. Sequences whose length
            differs from ``ndim`` are handed to ``normalize_chunks`` intact
            (e.g. a flat ``(3, 5, 7)`` for a 1-d array, which dask's
            ``normalize_chunks`` treats as explicit chunks for that axis).
        threshold, block_size_limit, balance, method
            Forwarded to ``Rechunk``; ``block_size_limit`` is also the
            ``limit`` used when normalizing ``chunks``.

        Returns
        -------
        ArrayExpr
            ``self`` if every axis has length 0 or the resolved chunks already
            match (and ``balance`` is false), otherwise a ``Rechunk``.
        """
        if self.ndim > 0 and all(s == 0 for s in self.shape):
            return self

        from dask_array._rechunk import Rechunk
        from dask_array._core_utils import normalize_chunks
        from dask_array._utils import validate_axis

        # Pre-resolve chunks to check for no-op and avoid singleton caching issues
        resolved_chunks = chunks
        if isinstance(chunks, dict):
            normalized_dict = {validate_axis(k, self.ndim): v for k, v in chunks.items()}
            resolved_chunks = tuple(
                normalized_dict[i] if normalized_dict.get(i) is not None else self.chunks[i]
                for i in range(self.ndim)
            )
        if isinstance(resolved_chunks, (tuple, list)):
            if len(resolved_chunks) == self.ndim:
                # One entry per axis: ``None`` means "keep this axis as is".
                resolved_chunks = tuple(
                    old if new is None else new for new, old in zip(resolved_chunks, self.chunks)
                )
            else:
                # Don't zip here: zip would silently truncate the sequence
                # (e.g. turn explicit 1-d chunks (3, 5, 7) into (3,)).
                resolved_chunks = tuple(resolved_chunks)
        resolved_chunks = normalize_chunks(
            resolved_chunks,
            self.shape,
            limit=block_size_limit,
            dtype=self.dtype,
            previous_chunks=self.chunks,
        )

        # No-op rechunk: if chunks already match, return self
        if not balance and resolved_chunks == self.chunks:
            return self

        result = Rechunk(self, resolved_chunks, threshold, block_size_limit, balance, method)
        # Evaluated only for its side effect: computing chunks here makes
        # incompatible targets fail now rather than later.
        result.chunks
        return result

    def finalize_compute(self):
        return FinalizeComputeArray(self)
|
|
298
|
+
|
|
299
|
+
|
|
300
|
+
def coarse_blockdim(blockdims):
    """Find the coarsest block dimension from a set of block dimensions.

    Prefers the chunking with the fewest blocks, which results in larger
    chunk sizes and fewer tasks. The finer-grained inputs will be rechunked
    to match.

    Unlike common_blockdim which finds the finest common divisor, this
    function prefers larger chunks to minimize task overhead. However, if
    the chunk boundaries don't align (one chunking's boundaries aren't a
    subset of another's), falls back to common_blockdim behavior.

    Single-chunk chunkings (length-1 tuples) are treated as unconstrained:
    they are ignored whenever a multi-chunk candidate exists and are only
    returned when every input is single-chunk.

    Parameters
    ----------
    blockdims : set of tuples
        Set of chunk tuples for a single dimension

    Returns
    -------
    tuple
        The preferred chunk tuple (fewest blocks if alignable, otherwise
        finest common divisor)

    Raises
    ------
    ValueError
        If unknown (NaN) chunkings have differing numbers of blocks, or if
        the multi-chunk candidates do not sum to the same total.

    Examples
    --------
    >>> coarse_blockdim({(12, 12, 12, 12), (6, 6, 6, 6, 6, 6, 6, 6)})  # prefer fewer chunks
    (12, 12, 12, 12)
    >>> coarse_blockdim({(10,), (5, 5)})  # single-chunk inputs don't constrain the result
    (5, 5)
    >>> coarse_blockdim({(4, 6), (6, 4)})  # incompatible - use common divisor
    (4, 2, 4)
    """
    if not any(blockdims):
        return ()

    # Handle unknown chunks - same logic as common_blockdim
    unknown_dims = [d for d in blockdims if np.isnan(sum(d))]
    if unknown_dims:
        all_lengths = {len(d) for d in blockdims}
        if len(all_lengths) > 1:
            raise ValueError(
                "Chunks are unknown or misaligned along dimensions with missing values.\n\n"
                "A possible solution:\n  x.compute_chunk_sizes()"
            )
        return toolz.first(unknown_dims)

    # Ignore single-chunk candidates - they can always be rechunked to match
    non_trivial_dims = {d for d in blockdims if len(d) > 1}

    if len(non_trivial_dims) == 0:
        # All are single-chunk; pick the one with the largest chunk
        return max(blockdims, key=toolz.first)

    if len(non_trivial_dims) == 1:
        # Only one non-trivial, use it
        return toolz.first(non_trivial_dims)

    # Multiple non-trivial dimensions - verify they have the same total size
    if len(set(map(sum, non_trivial_dims))) > 1:
        raise ValueError("Chunks do not add up to same value", blockdims)

    # Find the coarsest chunking (fewest blocks)
    coarsest = min(non_trivial_dims, key=len)

    # Check if all other chunkings have boundaries that align with the coarsest
    # i.e., the coarsest boundaries are a subset of each other chunking's boundaries
    coarsest_boundaries = set(np.cumsum(coarsest[:-1]))

    for chunks in non_trivial_dims:
        if chunks == coarsest:
            continue
        other_boundaries = set(np.cumsum(chunks[:-1]))
        if not coarsest_boundaries.issubset(other_boundaries):
            # Boundaries don't align - fall back to common_blockdim
            return common_blockdim(blockdims)

    # All boundaries align with the coarsest, so use it
    return coarsest
|
|
378
|
+
|
|
379
|
+
|
|
380
|
+
def unify_chunks_expr(*args, warn=True):
    """Unify chunking across array expressions that share index labels.

    This is the implementation that expects the inputs to be expressions;
    the public-facing variant needs to sanitize the inputs.

    Parameters
    ----------
    *args
        Alternating ``array, index`` pairs. ``index`` is a tuple of
        dimension labels, ``()`` for scalars, or ``None`` for literals.
        Scalars, literals and ``ArrayBlockwiseDep`` inputs are passed
        through untouched. As with ``toolz.partition(2, args)``, an
        unpaired trailing argument is dropped.
    warn : bool, default True
        Emit a ``PerformanceWarning`` when unification multiplies the number
        of blocks by 10x or more relative to the largest input.

    Returns
    -------
    chunkss : dict
        Mapping from index label to the unified chunk tuple.
    arrays : list
        The inputs, rechunked where necessary, in their original order.
    changed : bool
        Whether any input was rechunked.
    """
    # TODO(expr): This should probably be a dedicated expression
    if not args:
        return {}, [], False
    arginds = list(zip(args[0::2], args[1::2]))
    arrays, inds = zip(*arginds)
    if all(ind is None for ind in inds):
        return {}, list(arrays), False
    if all(ind == inds[0] for ind in inds) and all(a.chunks == arrays[0].chunks for a in arrays):
        # Fast path: identical indices and chunks, nothing to unify.
        # Return a list, like every other path does.
        return dict(zip(inds[0], arrays[0].chunks)), list(arrays), False

    nameinds = []
    blockdim_dict = dict()
    max_parts = 0
    for a, ind in arginds:
        # Skip scalars (empty tuple index), literals (None), and ArrayBlockwiseDep
        if ind is not None and ind != () and not isinstance(a, ArrayBlockwiseDep):
            nameinds.append((a.name, ind))
            blockdim_dict[a.name] = a.chunks
            max_parts = max(max_parts, math.prod(a.numblocks))
        else:
            nameinds.append((a, ind))

    chunkss = broadcast_dimensions(nameinds, blockdim_dict, consolidate=coarse_blockdim)
    nparts = math.prod(map(len, chunkss.values())) if chunkss else 0

    if warn and nparts and nparts >= max_parts * 10:
        warnings.warn(
            f"Increasing number of chunks by factor of {int(nparts / max_parts)}",
            PerformanceWarning,
            stacklevel=3,
        )

    arrays = []
    changed = False
    for a, i in arginds:
        if i is None or i == () or isinstance(a, ArrayBlockwiseDep):
            pass  # Skip scalars, literals, ArrayBlockwiseDep
        else:
            # Broadcast (length <= 1) axes keep their own size; ``None`` keeps
            # the current chunks when the unified chunks are unknown.
            chunks = tuple(
                (chunkss[j] if a.shape[n] > 1 else (a.shape[n],) if not np.isnan(sum(chunkss[j])) else None)
                for n, j in enumerate(i)
            )
            if chunks != a.chunks and all(a.chunks):
                # Skip rechunking known chunks to unknown - can't rechunk to nan sizes
                target_has_nan = any(c is not None and np.isnan(sum(c)) for c in chunks)
                source_is_known = not any(np.isnan(sum(c)) for c in a.chunks)
                if not (target_has_nan and source_is_known):
                    a = a.rechunk(chunks)
                    changed = True
        arrays.append(a)
    return chunkss, arrays, changed
|
|
434
|
+
|
|
435
|
+
|
|
436
|
+
# Import ConcatenateFinalize (used by FinalizeComputeArray._simplify_down below)
|
|
437
|
+
from dask_array._concatenate import ConcatenateFinalize
|
|
438
|
+
|
|
439
|
+
|
|
440
|
+
def _copy_array(x):
    """Return ``x.copy()`` so callers can't mutate data stored in the graph.

    Objects that raise ``AttributeError`` when ``.copy()`` is called (e.g.
    they have no such method) are returned as-is.
    """
    try:
        copied = x.copy()  # numpy, sparse, scipy.sparse
    except AttributeError:
        # No usable .copy(): hand back the original object unchanged.
        copied = x
    return copied
|
|
446
|
+
|
|
447
|
+
|
|
448
|
+
class CopyArray(ArrayExpr):
    """Copy an array to prevent mutation of the underlying data.

    When a single-chunk array is computed, the result might be a reference
    to data stored in the task graph. Each output block of this expression
    is ``_copy_array`` applied to the matching block of ``array``, so
    modifications to the result don't affect the graph.
    """

    _parameters = ["array"]

    @functools.cached_property
    def _name(self):
        return f"copy-{self.deterministic_token}"

    @functools.cached_property
    def _meta(self):
        return self.array._meta

    @functools.cached_property
    def chunks(self):
        # Copying never changes the block structure.
        return self.array.chunks

    @property
    def dtype(self):
        return self.array.dtype

    def _layer(self):
        source = self.array
        layer = {}
        # One copy task per input block, keyed by the same block index.
        for idx in product(*map(range, map(len, source.chunks))):
            out_key = (self._name, *idx)
            layer[out_key] = Task(out_key, _copy_array, TaskRef((source._name, *idx)))
        return layer
|
|
482
|
+
|
|
483
|
+
|
|
484
|
+
class FinalizeComputeArray(FinalizeCompute, ArrayExpr):
    """Final compute step that collapses ``arr`` into a single block.

    Simplifies to one of:

    - ``CopyArray`` when ``arr`` already has exactly one block,
    - ``ConcatenateFinalize`` when any axis length is unknown (NaN),
    - a ``Rechunk`` to one chunk per axis (``method="tasks"``) otherwise.
    """

    _parameters = ["arr"]

    @cached_property
    def chunks(self):
        # Each dimension has a single chunk with the full size
        return tuple((extent,) for extent in self.arr.shape)

    def _simplify_down(self):
        arr = self.arr

        # Single-chunk array: wrap with CopyArray so mutating the computed
        # result cannot affect graph-stored data used by later computes.
        if all(count == 1 for count in arr.numblocks):
            return CopyArray(arr)

        # Rechunking requires known shapes; concatenate blocks instead.
        if any(np.isnan(extent) for extent in arr.shape):
            return ConcatenateFinalize(arr)

        from dask_array._rechunk import Rechunk

        return Rechunk(arr, (-1,) * arr.ndim, method="tasks")
|
|
509
|
+
|
|
510
|
+
|
|
511
|
+
class ChunksOverride(ArrayExpr):
    """Override chunks metadata for an array expression.

    Each output block is an ``Alias`` of the input block with the same block
    index, so the underlying computation is preserved while the reported
    ``chunks`` change. Useful when the actual output chunk sizes differ from
    what the expression system infers (e.g., boolean indexing produces
    unknown chunk sizes).
    """

    _parameters = ["array", "_chunks"]

    @functools.cached_property
    def _name(self):
        return f"chunks-override-{self.deterministic_token}"

    @functools.cached_property
    def _meta(self):
        return self.array._meta

    @functools.cached_property
    def chunks(self):
        return self._chunks

    def _layer(self) -> dict:
        from dask._task_spec import Alias

        # NOTE(review): assumes ``_chunks`` has the same number of blocks per
        # axis as ``array``; otherwise aliases would point at missing keys.
        source_name = self.array._name
        block_indices = product(*(range(len(axis_chunks)) for axis_chunks in self._chunks))
        return {
            (self._name, *idx): Alias((self._name, *idx), (source_name, *idx))
            for idx in block_indices
        }
|