dask-array 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dask_array/__init__.py +228 -0
- dask_array/_backends.py +76 -0
- dask_array/_backends_array.py +99 -0
- dask_array/_blockwise.py +1410 -0
- dask_array/_broadcast.py +272 -0
- dask_array/_chunk.py +445 -0
- dask_array/_chunk_types.py +54 -0
- dask_array/_collection.py +1644 -0
- dask_array/_concatenate.py +331 -0
- dask_array/_core_utils.py +1365 -0
- dask_array/_dispatch.py +141 -0
- dask_array/_einsum.py +277 -0
- dask_array/_expr.py +544 -0
- dask_array/_expr_flow.py +586 -0
- dask_array/_gufunc.py +805 -0
- dask_array/_histogram.py +617 -0
- dask_array/_map_blocks.py +652 -0
- dask_array/_new_collection.py +10 -0
- dask_array/_numpy_compat.py +135 -0
- dask_array/_overlap.py +1159 -0
- dask_array/_rechunk.py +1050 -0
- dask_array/_reshape.py +710 -0
- dask_array/_routines.py +102 -0
- dask_array/_shuffle.py +448 -0
- dask_array/_stack.py +264 -0
- dask_array/_svg.py +291 -0
- dask_array/_templates.py +29 -0
- dask_array/_test_utils.py +257 -0
- dask_array/_ufunc.py +385 -0
- dask_array/_utils.py +349 -0
- dask_array/_visualize.py +223 -0
- dask_array/_xarray.py +337 -0
- dask_array/core/__init__.py +34 -0
- dask_array/core/_blockwise_funcs.py +312 -0
- dask_array/core/_conversion.py +422 -0
- dask_array/core/_from_graph.py +97 -0
- dask_array/creation/__init__.py +71 -0
- dask_array/creation/_arange.py +121 -0
- dask_array/creation/_diag.py +116 -0
- dask_array/creation/_diagonal.py +241 -0
- dask_array/creation/_eye.py +103 -0
- dask_array/creation/_linspace.py +102 -0
- dask_array/creation/_mesh.py +134 -0
- dask_array/creation/_ones_zeros.py +454 -0
- dask_array/creation/_pad.py +270 -0
- dask_array/creation/_repeat.py +55 -0
- dask_array/creation/_tile.py +36 -0
- dask_array/creation/_tri.py +28 -0
- dask_array/creation/_utils.py +296 -0
- dask_array/fft.py +320 -0
- dask_array/io/__init__.py +39 -0
- dask_array/io/_base.py +10 -0
- dask_array/io/_from_array.py +257 -0
- dask_array/io/_from_delayed.py +95 -0
- dask_array/io/_from_graph.py +54 -0
- dask_array/io/_from_npy_stack.py +67 -0
- dask_array/io/_store.py +336 -0
- dask_array/io/_tiledb.py +159 -0
- dask_array/io/_to_npy_stack.py +65 -0
- dask_array/io/_zarr.py +449 -0
- dask_array/linalg/__init__.py +39 -0
- dask_array/linalg/_cholesky.py +234 -0
- dask_array/linalg/_lu.py +300 -0
- dask_array/linalg/_norm.py +94 -0
- dask_array/linalg/_qr.py +601 -0
- dask_array/linalg/_solve.py +349 -0
- dask_array/linalg/_svd.py +394 -0
- dask_array/linalg/_tensordot.py +334 -0
- dask_array/linalg/_utils.py +74 -0
- dask_array/manipulation/__init__.py +45 -0
- dask_array/manipulation/_expand.py +321 -0
- dask_array/manipulation/_flip.py +92 -0
- dask_array/manipulation/_roll.py +78 -0
- dask_array/manipulation/_transpose.py +309 -0
- dask_array/random/__init__.py +125 -0
- dask_array/random/_choice.py +181 -0
- dask_array/random/_expr.py +256 -0
- dask_array/random/_generator.py +441 -0
- dask_array/random/_random_state.py +259 -0
- dask_array/random/_utils.py +84 -0
- dask_array/reductions/__init__.py +84 -0
- dask_array/reductions/_arg_reduction.py +130 -0
- dask_array/reductions/_common.py +1082 -0
- dask_array/reductions/_cumulative.py +522 -0
- dask_array/reductions/_percentile.py +261 -0
- dask_array/reductions/_reduction.py +725 -0
- dask_array/reductions/_trace.py +56 -0
- dask_array/routines/__init__.py +133 -0
- dask_array/routines/_apply.py +84 -0
- dask_array/routines/_bincount.py +112 -0
- dask_array/routines/_broadcast.py +111 -0
- dask_array/routines/_coarsen.py +115 -0
- dask_array/routines/_diff.py +79 -0
- dask_array/routines/_gradient.py +158 -0
- dask_array/routines/_indexing.py +65 -0
- dask_array/routines/_insert_delete.py +132 -0
- dask_array/routines/_misc.py +122 -0
- dask_array/routines/_nonzero.py +72 -0
- dask_array/routines/_search.py +123 -0
- dask_array/routines/_select.py +113 -0
- dask_array/routines/_statistics.py +171 -0
- dask_array/routines/_topk.py +82 -0
- dask_array/routines/_triangular.py +74 -0
- dask_array/routines/_unique.py +232 -0
- dask_array/routines/_where.py +62 -0
- dask_array/slicing/__init__.py +67 -0
- dask_array/slicing/_basic.py +550 -0
- dask_array/slicing/_blocks.py +138 -0
- dask_array/slicing/_bool_index.py +145 -0
- dask_array/slicing/_setitem.py +329 -0
- dask_array/slicing/_squeeze.py +101 -0
- dask_array/slicing/_utils.py +1133 -0
- dask_array/slicing/_vindex.py +282 -0
- dask_array/stacking/__init__.py +15 -0
- dask_array/stacking/_block.py +83 -0
- dask_array/stacking/_simple.py +58 -0
- dask_array/templates/array.html.j2 +48 -0
- dask_array/tests/__init__.py +0 -0
- dask_array/tests/conftest.py +22 -0
- dask_array/tests/test_api.py +40 -0
- dask_array/tests/test_binary_op_chunks.py +107 -0
- dask_array/tests/test_coarse_slice_through_blockwise.py +362 -0
- dask_array/tests/test_collection.py +799 -0
- dask_array/tests/test_creation.py +1102 -0
- dask_array/tests/test_expr_flow.py +143 -0
- dask_array/tests/test_linalg.py +1130 -0
- dask_array/tests/test_map_blocks_multi_output.py +104 -0
- dask_array/tests/test_rechunk_pushdown.py +214 -0
- dask_array/tests/test_reductions.py +1091 -0
- dask_array/tests/test_routines.py +2853 -0
- dask_array/tests/test_shuffle_chunks.py +67 -0
- dask_array/tests/test_slice_pushdown.py +968 -0
- dask_array/tests/test_slice_through_blockwise.py +678 -0
- dask_array/tests/test_slice_through_overlap.py +366 -0
- dask_array/tests/test_slice_through_reshape.py +272 -0
- dask_array/tests/test_slicing.py +839 -0
- dask_array/tests/test_transpose_slice_pushdown.py +208 -0
- dask_array/tests/test_visualize.py +94 -0
- dask_array/tests/test_xarray.py +193 -0
- dask_array-0.1.0.dist-info/METADATA +48 -0
- dask_array-0.1.0.dist-info/RECORD +144 -0
- dask_array-0.1.0.dist-info/WHEEL +4 -0
- dask_array-0.1.0.dist-info/entry_points.txt +2 -0
- dask_array-0.1.0.dist-info/licenses/LICENSE +29 -0
dask_array/_blockwise.py
ADDED
|
@@ -0,0 +1,1410 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import numbers
|
|
4
|
+
from collections.abc import Iterable
|
|
5
|
+
from itertools import product
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
import tlz as toolz
|
|
9
|
+
|
|
10
|
+
from dask import is_dask_collection
|
|
11
|
+
from dask._task_spec import Task, TaskRef
|
|
12
|
+
from dask_array._expr import ArrayExpr, unify_chunks_expr
|
|
13
|
+
from dask_array._utils import compute_meta
|
|
14
|
+
from dask_array._core_utils import (
|
|
15
|
+
_elemwise_handle_where,
|
|
16
|
+
_enforce_dtype,
|
|
17
|
+
apply_infer_dtype,
|
|
18
|
+
broadcast_shapes,
|
|
19
|
+
is_scalar_for_elemwise,
|
|
20
|
+
normalize_arg,
|
|
21
|
+
)
|
|
22
|
+
from dask_array._utils import meta_from_array
|
|
23
|
+
from dask.blockwise import blockwise as core_blockwise
|
|
24
|
+
from dask.delayed import unpack_collections
|
|
25
|
+
from dask.layers import ArrayBlockwiseDep
|
|
26
|
+
from dask.tokenize import _tokenize_deterministic
|
|
27
|
+
from dask.utils import SerializableLock, cached_property, funcname
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class Blockwise(ArrayExpr):
    """Apply ``func`` block-by-block across one or more array inputs.

    Operands are the named ``_parameters`` followed by a flat sequence of
    ``arg, index`` pairs (exposed as :attr:`args`). ``index`` is ``None`` for
    literal (non-array) arguments; otherwise it labels each dimension of
    ``arg`` with the same symbols ``out_ind`` uses for the output dimensions.
    Labels that appear in an input but not in ``out_ind``/``new_axes`` are
    contracted dimensions.
    """

    _parameters = [
        "func",
        "out_ind",
        "name",
        "token",
        "dtype",
        "adjust_chunks",
        "new_axes",
        "align_arrays",
        "concatenate",
        "_meta_provided",
        "kwargs",
    ]
    _defaults = {
        "name": None,
        "token": None,
        "dtype": None,
        "adjust_chunks": None,
        "new_axes": None,
        "align_arrays": True,
        "concatenate": None,
        "_meta_provided": None,
        "kwargs": None,
    }

    @cached_property
    def args(self):
        """Flat ``(arg, index, arg, index, ...)`` operands after the named parameters."""
        return self.operands[len(self._parameters) :]

    @cached_property
    def _meta_provided(self):
        # We catch recursion errors if key starts with _meta, so define
        # explicitly here
        return self.operand("_meta_provided")

    @cached_property
    def _meta(self):
        """Meta object for the output (or a tuple of metas for multi-output funcs).

        A user-provided meta is normalized through ``meta_from_array``;
        otherwise the meta is inferred by running ``func`` via
        ``compute_meta``, falling back to an empty meta of the explicit dtype
        when inference fails.
        """
        provided = self._meta_provided
        if provided is None:
            # kwargs defaults to None; treat that as "no keyword arguments".
            meta = compute_meta(
                self.func,
                self.operand("dtype"),
                *self.args[::2],
                **(self.kwargs or {}),
            )
            if meta is None:
                # compute_meta failed (e.g., function has assertions on shapes).
                # Fall back to the explicitly provided dtype; use operand() to
                # avoid recursion since the dtype property may depend on _meta.
                meta = meta_from_array(None, ndim=self.ndim, dtype=self.operand("dtype"))
            return meta

        # Tuple metas come from multi-output functions (e.g., apply_gufunc).
        if isinstance(provided, (tuple, list)):
            return tuple(
                meta_from_array(m, ndim=m.ndim, dtype=getattr(m, "dtype", None))
                for m in provided
            )
        # getattr for dtype since some metas (e.g., DataFrame) don't have .dtype
        return meta_from_array(
            provided,
            ndim=self.ndim,
            dtype=getattr(provided, "dtype", None),
        )

    @cached_property
    def _input_chunkss(self) -> dict:
        """Chunking per index label implied by the array inputs.

        This is the chunking *before* ``new_axes`` and ``adjust_chunks`` are
        applied. With ``align_arrays`` it is the unified chunking; otherwise,
        for each label, the input chunking with the most blocks is used so that
        broadcasting yields the expected number of blocks for consistent inputs.
        """
        if self.align_arrays:
            chunkss, _, _ = unify_chunks_expr(*self.args)
            return dict(chunkss)

        chunkss: dict = {}
        for arg, ind in toolz.partition(2, self.args):
            if ind is None:
                continue
            for c, i in zip(arg.chunks, ind):
                if i not in chunkss or len(c) > len(chunkss[i]):
                    chunkss[i] = c
        return chunkss

    @cached_property
    def chunks(self):
        """Output chunks: input chunking per ``out_ind`` label, plus new axes,
        with ``adjust_chunks`` applied.

        Raises
        ------
        ValueError
            If a tuple/list ``adjust_chunks`` entry has the wrong number of blocks.
        NotImplementedError
            If an ``adjust_chunks`` entry is not callable, int, or tuple/list.
        """
        # Copy so adding new axes never mutates the cached input chunking.
        chunkss = dict(self._input_chunkss)

        # new_axes defaults to None; treat that as "no new axes".
        for label, sizes in (self.new_axes or {}).items():
            chunkss[label] = sizes if isinstance(sizes, tuple) else (sizes,)

        chunks = [chunkss[label] for label in self.out_ind]

        adjust_chunks = self.adjust_chunks
        if adjust_chunks:
            for i, label in enumerate(self.out_ind):
                if label not in adjust_chunks:
                    continue
                adjust = adjust_chunks[label]
                if callable(adjust):
                    chunks[i] = tuple(map(adjust, chunks[i]))
                elif isinstance(adjust, numbers.Integral):
                    chunks[i] = (adjust,) * len(chunks[i])
                elif isinstance(adjust, (tuple, list)):
                    if len(adjust) != len(chunks[i]):
                        raise ValueError(
                            f"Dimension {i} has {len(chunks[i])} blocks, adjust_chunks "
                            f"specified with {len(adjust)} blocks"
                        )
                    chunks[i] = tuple(adjust)
                else:
                    raise NotImplementedError("adjust_chunks values must be callable, int, or tuple")
        return tuple(map(tuple, chunks))

    @cached_property
    def dtype(self):
        # Cache the base-class dtype on first access.
        return super().dtype

    @property
    def _is_blockwise_fusable(self):
        """Whether this Blockwise can participate in blockwise fusion."""
        # Blockwise with concatenate requires special handling not yet implemented
        if self.concatenate:
            return False
        # Blockwise with Delayed operands can't be fused because FusedBlockwise
        # doesn't properly track them as external dependencies
        from dask.delayed import Delayed

        if any(isinstance(op, Delayed) for op in self.operands):
            return False

        # Contracted dimensions (in an input but not in the output) can only
        # be fused when they consist of a single block.
        out_idx_set = set(self.out_ind)
        if self.new_axes:
            out_idx_set |= set(self.new_axes.keys())
        for arr, ind in toolz.partition(2, self.args):
            if ind is not None and hasattr(arr, "numblocks"):
                for dim, label in enumerate(ind):
                    if label not in out_idx_set and arr.numblocks[dim] > 1:
                        return False
        return True

    def _idx_to_block(self, block_id: tuple[int, ...]) -> dict:
        """Map index labels to block coordinates for output block ``block_id``.

        Labels from ``new_axes`` are mapped to block 0.
        """
        idx_to_block = {label: block_id[dim] for dim, label in enumerate(self.out_ind)}
        # new_axes defaults to None; treat that as "no new axes".
        for label in self.new_axes or {}:
            idx_to_block[label] = 0
        return idx_to_block

    def _dep_block_id(self, arr, ind, idx_to_block: dict) -> tuple[int, ...]:
        """Compute block_id for a dependency, applying modulo for broadcasting."""
        return _compute_block_id(ind, idx_to_block, arr.numblocks)

    def _task(self, key, block_id: tuple[int, ...]):
        """Build the task computing output block ``block_id`` under ``key``.

        Raises
        ------
        NotImplementedError
            If ``concatenate`` is set (not supported for fusion).
        """
        if self.concatenate:
            raise NotImplementedError("Blockwise with concatenate not supported for fusion")

        idx_to_block = self._idx_to_block(block_id)

        args = []
        for arr, ind in toolz.partition(2, self.args):
            if ind is None:
                # Literal argument: passed through unchanged.
                args.append(arr)
            elif isinstance(arr, ArrayBlockwiseDep):
                # Graph-time values indexed directly by block coordinates.
                numblocks = tuple(len(c) for c in arr.chunks)
                args.append(arr[_compute_block_id(ind, idx_to_block, numblocks)])
            else:
                input_block_id = self._dep_block_id(arr, ind, idx_to_block)
                args.append(TaskRef((arr._name, *input_block_id)))

        # kwargs defaults to None; treat that as "no keyword arguments".
        return Task(key, self.func, *args, **(self.kwargs or {}))

    def _input_block_id(self, dep, block_id: tuple[int, ...]) -> tuple[int, ...]:
        """Map output block_id to input block_id for a dependency.

        Uses the first argument whose ``_name`` matches ``dep``; returns
        ``block_id`` unchanged when no argument matches.
        """
        idx_to_block = self._idx_to_block(block_id)
        for arr, ind in toolz.partition(2, self.args):
            if ind is not None and hasattr(arr, "_name") and arr._name == dep._name:
                return self._dep_block_id(arr, ind, idx_to_block)
        return block_id

    def _all_input_block_ids(self, block_id: tuple[int, ...]) -> dict:
        """Return all input block_ids for dependencies, keyed by dependency name.

        Handles case where same dependency appears multiple times with
        different index mappings (e.g., da.dot(x, x)): each occurrence
        contributes one entry to that name's list.
        """
        idx_to_block = self._idx_to_block(block_id)
        result: dict = {}
        for arr, ind in toolz.partition(2, self.args):
            if ind is not None and hasattr(arr, "_name"):
                result.setdefault(arr._name, []).append(self._dep_block_id(arr, ind, idx_to_block))
        return result

    def __dask_tokenize__(self):
        """Deterministic token; computed once and cached in ``_determ_token``."""
        if not self._determ_token:
            # Non-serializable locks in kwargs are tokenized by their id().
            kwargs_token = {}
            for k, v in (self.kwargs or {}).items():
                if k == "lock" and v and not isinstance(v, (bool, SerializableLock)):
                    kwargs_token[k] = ("lock-id", id(v))
                else:
                    kwargs_token[k] = v

            self._determ_token = _tokenize_deterministic(
                self.func,
                self.out_ind,
                self.dtype,
                self.adjust_chunks,
                self.new_axes,
                self.align_arrays,
                self.concatenate,
                *self.args,
                **kwargs_token,
            )
        return self._determ_token

    @cached_property
    def _name(self):
        # Always include deterministic_token suffix to ensure:
        # 1. Different expressions with same user-provided name are distinguishable
        # 2. lower_completely can detect when operands change (via name change)
        user_name = self.operand("name") if "name" in self._parameters else None
        prefix = user_name or self.token or funcname(self.func).strip("_")
        return f"{prefix}-{self.deterministic_token}"

    def _layer(self):
        """Materialize this expression's graph via ``dask.blockwise.blockwise``.

        Graphs of any dask collections found inside literal args or kwargs are
        merged into the result.
        """
        numblocks = {}
        dependencies = []
        argindsstr = []

        for arg, ind in toolz.partition(2, self.args):
            if ind is None:
                # Literal argument (not an array) - normalize it and collect
                # any dask collections nested inside it.
                arg = normalize_arg(arg)
                arg, collections = unpack_collections(arg)
                dependencies.extend(collections)
            else:
                if hasattr(arg, "ndim") and hasattr(ind, "__len__") and arg.ndim != len(ind):
                    raise ValueError(f"Index string {ind} does not match array dimension {arg.ndim}")
                # TODO(expr): this class is a confusing crutch to pass arguments to the
                # graph, we should write them directly into the graph
                if not isinstance(arg, ArrayBlockwiseDep):
                    numblocks[arg.name] = arg.numblocks
                    arg = arg.name
            argindsstr.extend((arg, ind))

        # Normalize keyword arguments (kwargs defaults to None).
        kwargs2 = {}
        for k, v in (self.kwargs or {}).items():
            v = normalize_arg(v)
            v, collections = unpack_collections(v)
            dependencies.extend(collections)
            kwargs2[k] = v

        # TODO(expr): Highlevelgraph :(
        graph = core_blockwise(
            self.func,
            self._name,
            self.out_ind,
            *argindsstr,
            numblocks=numblocks,
            dependencies=dependencies,
            new_axes=self.new_axes,
            concatenate=self.concatenate,
            **kwargs2,
        )
        result = dict(graph)
        # Merge in dependency graphs (from delayed objects, etc.)
        for dep in dependencies:
            if is_dask_collection(dep):
                result.update(dep.__dask_graph__())
        return result

    def _lower(self):
        """Rechunk inputs to their unified chunking when ``align_arrays`` is set.

        Returns a new expression of the same type with rechunked inputs, or
        ``None`` when no rechunking is needed.
        """
        if not self.align_arrays:
            return None
        _, arrays, changed = unify_chunks_expr(*self.args)
        if not changed:
            return None
        args = []
        for ind, arr in zip(self.args[1::2], arrays):
            args.extend([arr, ind])
        return type(self)(*self.operands[: len(self._parameters)], *args)

    def _simplify_up(self, parent, dependents):
        """Allow slice and shuffle operations to push through Blockwise."""
        from dask_array._shuffle import Shuffle
        from dask_array.slicing import SliceSlicesIntegers

        if isinstance(parent, SliceSlicesIntegers):
            return self._accept_slice(parent)
        if isinstance(parent, Shuffle):
            return self._accept_shuffle(parent)
        return None

    def _dim_pushdown_mode(self, arg, dim, label):
        """Decide how an output-axis operation maps onto ``arg``'s dimension ``dim``.

        Returns
        -------
        "slice"
            The input dimension lines up with the output dimension and may be
            indexed with output coordinates.
        "broadcast"
            The input dimension is a size-1 dimension broadcast against a
            larger output (``align_arrays`` treats ``(1,)`` chunks as a
            broadcast sentinel); it must be left whole.
        None
            Pushdown is unsafe. This happens with ``align_arrays=False`` when
            the input's chunking differs from the output's: blocks are then
            paired positionally, and slicing would change that pairing.
        """
        arg_chunks = tuple(arg.chunks[dim])
        ref_chunks = tuple(self._input_chunkss[label])
        if self.align_arrays:
            if arg_chunks == (1,) and ref_chunks != (1,):
                return "broadcast"
            return "slice"
        return "slice" if arg_chunks == ref_chunks else None

    def _accept_shuffle(self, shuffle_expr):
        """Accept a shuffle being pushed through Blockwise.

        Push shuffle through when shuffle axis is not modified by blockwise.
        Inputs that broadcast along the shuffled axis are left unshuffled.
        Returns ``None`` when pushdown is not possible.
        """
        from dask_array._shuffle import Shuffle

        axis = shuffle_expr.axis
        out_ind = self.out_ind

        # Index label of the shuffle axis.
        shuffle_ind = out_ind[axis]

        # Can't push through if shuffle axis is a new axis or has adjusted chunks.
        new_axes = getattr(self, "new_axes", None)
        if new_axes and shuffle_ind in new_axes:
            return None
        adjust_chunks = getattr(self, "adjust_chunks", None)
        if adjust_chunks and shuffle_ind in adjust_chunks:
            return None

        new_args = []
        for arr, ind in toolz.partition(2, self.args):
            if ind is None or shuffle_ind not in ind:
                # Literal argument, or an input without the shuffle dimension.
                new_args.extend([arr, ind])
                continue
            if not hasattr(arr, "_meta"):
                # Non-array deps (e.g., block_id/block_info values) can't be shuffled.
                return None
            input_axis = ind.index(shuffle_ind)
            mode = self._dim_pushdown_mode(arr, input_axis, shuffle_ind)
            if mode is None:
                return None
            if mode == "broadcast":
                # A broadcast size-1 input is identical for every position.
                new_args.extend([arr, ind])
                continue
            shuffled = Shuffle(arr, shuffle_expr.indexer, input_axis, shuffle_expr.operand("name"))
            new_args.extend([shuffled, ind])

        return Blockwise(
            self.func,
            self.out_ind,
            self.operand("name"),
            self.operand("token"),
            self.operand("dtype"),
            self.operand("adjust_chunks"),
            self.operand("new_axes"),
            self.operand("align_arrays"),
            self.operand("concatenate"),
            self.operand("_meta_provided"),
            self.operand("kwargs"),
            *new_args,
        )

    def _accept_slice(self, slice_expr):
        """Accept a slice being pushed through this Blockwise.

        This optimization is safe when:
        - The blockwise doesn't adjust chunk sizes on sliced dimensions
        - The blockwise doesn't add new axes on sliced dimensions
        - The slice uses only slices or integers (no newaxis)

        When adjust_chunks is present, we do "coarse" optimization:
        - Calculate which output blocks the slice needs
        - Select only the corresponding input blocks
        - Wrap with adjusted output slice if needed

        Returns ``None`` when the slice cannot be pushed through.
        """
        from numbers import Integral

        from dask_array._new_collection import new_collection

        out_ind = self.out_ind
        index = slice_expr.index

        # Don't handle None/newaxis
        if any(idx is None for idx in index):
            return None

        # Pad index to full output length
        full_index = index + (slice(None),) * (len(out_ind) - len(index))

        # Output axes with non-trivial slices, and their index labels.
        sliced_axes = {i for i, idx in enumerate(full_index) if isinstance(idx, Integral) or idx != slice(None)}
        sliced_indices = {out_ind[axis] for axis in sliced_axes if axis < len(out_ind)}

        # Use getattr since subclasses may define as class attribute or property
        adjust_chunks = getattr(self, "adjust_chunks", None)
        needs_coarse = bool(adjust_chunks) and bool(sliced_indices & set(adjust_chunks.keys()))

        # Don't handle if blockwise adds new axes and we're slicing those axes
        new_axes = getattr(self, "new_axes", None)
        if new_axes and sliced_indices & set(new_axes.keys()):
            return None

        if needs_coarse:
            return self._accept_slice_coarse(slice_expr, full_index, adjust_chunks)

        # Convert integers to size-1 slices for pushdown. ``(idx + 1) or None``
        # keeps idx == -1 selecting the last element (slice(-1, 0) is empty).
        slice_index = tuple(
            slice(idx, (idx + 1) or None) if isinstance(idx, Integral) else idx for idx in full_index
        )
        has_integers = any(isinstance(idx, Integral) for idx in full_index)

        if "array" in type(self)._parameters:
            # Subclasses with a single "array" parameter: slice it and substitute.
            arg_slices = []
            for dim_idx in range(self.array.ndim):
                try:
                    out_pos = out_ind.index(dim_idx)
                except ValueError:
                    arg_slices.append(slice(None))
                else:
                    arg_slices.append(slice_index[out_pos])

            sliced_input = new_collection(self.array)[tuple(arg_slices)]
            result = self.substitute_parameters({"array": sliced_input.expr})
        else:
            # Base Blockwise with multiple inputs in args.
            new_args = []
            for arg, arg_ind in toolz.partition(2, self.args):
                if arg_ind is None:
                    new_args.extend([arg, arg_ind])
                    continue
                if not hasattr(arg, "_meta"):
                    # Non-array args (e.g., ArrayValuesDep for block_info) can't be sliced
                    return None

                arg_slices = []
                for dim, label in enumerate(arg_ind):
                    try:
                        out_pos = out_ind.index(label)
                    except ValueError:
                        arg_slices.append(slice(None))  # Contracted dimension
                        continue
                    idx = slice_index[out_pos]
                    if idx == slice(None):
                        arg_slices.append(idx)
                        continue
                    mode = self._dim_pushdown_mode(arg, dim, label)
                    if mode is None:
                        return None
                    # Broadcast size-1 dims stay whole; others take the output slice.
                    arg_slices.append(slice(None) if mode == "broadcast" else idx)

                sliced_arg = new_collection(arg)[tuple(arg_slices)]
                new_args.extend([sliced_arg.expr, arg_ind])

            result = Blockwise(
                self.func,
                self.out_ind,
                self.operand("name"),
                self.operand("token"),
                self.operand("dtype"),
                self.operand("adjust_chunks"),
                self.operand("new_axes"),
                self.operand("align_arrays"),
                self.operand("concatenate"),
                self.operand("_meta_provided"),
                self.operand("kwargs"),
                *new_args,
            )

        # If we converted integers to slices, extract with [0] to drop those dimensions
        if has_integers:
            from dask_array.slicing import SliceSlicesIntegers

            extract_index = tuple(0 if isinstance(idx, Integral) else slice(None) for idx in full_index)
            return SliceSlicesIntegers(result, extract_index, slice_expr.allow_getitem_optimization)

        return result

    def _accept_slice_coarse(self, slice_expr, full_index, adjust_chunks):
        """Coarse slice optimization for blockwise with adjust_chunks.

        When chunk sizes change between input and output, we can't push the
        exact slice through. But we CAN select only the input blocks that
        contribute to the needed output blocks.

        Algorithm:
        1. For each adjusted axis, find which OUTPUT blocks the slice needs
        2. Map to corresponding INPUT blocks (same block indices for blockwise),
           using the input chunking the blocks are actually numbered by
           (the unified chunking when ``align_arrays`` is set)
        3. Create block-aligned input slices (broadcast size-1 inputs stay whole)
        4. Wrap output with adjusted slice if original doesn't align to blocks

        Returns ``None`` when the slice cannot be pushed through.
        """
        from dask.utils import cached_cumsum

        from dask_array._new_collection import new_collection

        from numbers import Integral

        def find_block_range(cumsum, start, stop):
            """Find (first_block, last_block) indices for range [start, stop)."""
            # First block containing element at 'start'
            first_block = np.searchsorted(cumsum[1:], start, side="right")
            # Last block containing element before 'stop'
            last_block = np.searchsorted(cumsum[1:], stop - 1, side="right") if stop > start else first_block - 1
            if first_block >= len(cumsum) - 1:
                return None, None  # Out of bounds
            return int(first_block), int(last_block)

        out_ind = self.out_ind
        out_chunks = self.chunks

        # Per output axis: (first_block, last_block), or None meaning all blocks.
        block_ranges = []
        # Per output axis: slice/int to apply to the coarse result.
        output_adjustments = []

        for axis, idx in enumerate(full_index):
            chunks = out_chunks[axis]
            dim_size = sum(chunks)
            cumsum = np.array(list(cached_cumsum(chunks, initial_zero=True)))

            if idx == slice(None):
                block_ranges.append(None)
                output_adjustments.append(slice(None))
            elif isinstance(idx, Integral):
                pos_idx = idx if idx >= 0 else idx + dim_size
                first, last = find_block_range(cumsum, pos_idx, pos_idx + 1)
                if first is None:
                    return None  # Out of bounds
                block_ranges.append((first, last))
                output_adjustments.append(pos_idx - cumsum[first])
            elif isinstance(idx, slice):
                start, stop, step = idx.indices(dim_size)
                if step != 1:
                    return None  # Non-unit step not supported

                first, last = find_block_range(cumsum, start, stop)
                if first is None:
                    block_ranges.append((0, -1))  # Empty
                    output_adjustments.append(slice(0, 0))
                else:
                    block_ranges.append((first, last))
                    coarse_start = cumsum[first]
                    coarse_end = cumsum[last + 1]
                    adj_start = start - coarse_start
                    adj_stop = stop - coarse_start
                    if adj_start == 0 and adj_stop == coarse_end - coarse_start:
                        output_adjustments.append(slice(None))
                    else:
                        output_adjustments.append(slice(adj_start, adj_stop))
            else:
                return None

        # Map output block ranges to input slices
        new_args = []
        for arg, arg_ind in toolz.partition(2, self.args):
            if arg_ind is None:
                new_args.extend([arg, arg_ind])
                continue
            if not hasattr(arg, "_meta"):
                # Non-array args (e.g., ArrayValuesDep for block_info) can't be sliced
                return None

            arg_slices = []
            for dim, label in enumerate(arg_ind):
                try:
                    out_pos = out_ind.index(label)
                except ValueError:
                    arg_slices.append(slice(None))  # Contracted dimension
                    continue

                br = block_ranges[out_pos]
                if br is None:
                    arg_slices.append(slice(None))
                    continue

                mode = self._dim_pushdown_mode(arg, dim, label)
                if mode is None:
                    return None
                first, last = br
                if last < first:  # Empty
                    arg_slices.append(slice(0, 0))
                elif mode == "broadcast":
                    arg_slices.append(slice(None))
                else:
                    # Block indices refer to the input chunking that blocks are
                    # numbered by, which may differ from arg.chunks before
                    # chunks are unified.
                    in_cumsum = cached_cumsum(tuple(self._input_chunkss[label]), initial_zero=True)
                    arg_slices.append(slice(in_cumsum[first], in_cumsum[last + 1]))

            sliced_arg = new_collection(arg)[tuple(arg_slices)]
            new_args.extend([sliced_arg.expr, arg_ind])

        # Slice adjust_chunks tuples/lists to match the new block ranges
        new_adjust_chunks = self.operand("adjust_chunks")
        if new_adjust_chunks:
            new_adjust_chunks = dict(new_adjust_chunks)  # Copy
            for axis, br in enumerate(block_ranges):
                if br is None:
                    continue
                first, last = br
                if last < first:
                    continue
                label = out_ind[axis]
                if label in new_adjust_chunks:
                    val = new_adjust_chunks[label]
                    if isinstance(val, (tuple, list)):
                        # Slice the tuple to match the selected blocks
                        new_adjust_chunks[label] = val[first : last + 1]

        # Build the new Blockwise with coarse-sliced inputs
        result = Blockwise(
            self.func,
            self.out_ind,
            self.operand("name"),
            self.operand("token"),
            self.operand("dtype"),
            new_adjust_chunks,
            self.operand("new_axes"),
            self.operand("align_arrays"),
            self.operand("concatenate"),
            self.operand("_meta_provided"),
            self.operand("kwargs"),
            *new_args,
        )

        # Apply the residual (within-block) slice if the request isn't block-aligned.
        if any(adj != slice(None) for adj in output_adjustments):
            from dask_array.slicing import SliceSlicesIntegers

            return SliceSlicesIntegers(result, tuple(output_adjustments), slice_expr.allow_getitem_optimization)

        return result
|
|
653
|
+
|
|
654
|
+
|
|
655
|
+
class Elemwise(Blockwise):
    """Element-wise operation over broadcast-compatible operands.

    Operands are stored as the named parameters ``op``, ``dtype``, ``name``,
    ``where``, ``out`` and ``_user_kwargs``, followed by the positional
    element-wise arguments (see :attr:`elemwise_args`). Array arguments use
    right-aligned broadcasting (see :attr:`out_ind`). Scalar arguments are
    passed through to ``op`` unchanged. When ``where`` is not ``True``,
    ``where`` and ``out`` are appended to every task's arguments and the
    operation is wrapped by ``_elemwise_handle_where``.
    """

    _parameters = ["op", "dtype", "name", "where", "out", "_user_kwargs"]
    _defaults = {
        "dtype": None,
        "name": None,
        "where": True,
        "out": None,
        "_user_kwargs": None,
    }
    align_arrays = True
    new_axes: dict = {}
    adjust_chunks = None
    concatenate = None

    @property
    def user_kwargs(self):
        """Keyword arguments supplied by the user (never ``None``)."""
        return self.operand("_user_kwargs") or {}

    @cached_property
    def _meta(self):
        # When where is not True, _info[0] wraps _elemwise_handle_where, which
        # expects args to end with (where, out)
        args = list(self.elemwise_args)
        if self.where is not True:
            args.extend([self.where, self.out])
        return compute_meta(self._info[0], self.dtype, *args, **self.kwargs)

    @property
    def elemwise_args(self):
        """Positional arguments of ``op``: every operand after the named parameters."""
        return self.operands[len(self._parameters) :]

    def dependencies(self):
        """Return expression dependencies.

        When where is True (the default), 'out' is not actually used in
        the computation - it's just a placeholder for _handle_out to
        replace the expression. Exclude it from dependencies to avoid
        fusion issues, UNLESS out is also an input (e.g., np.sin(x, out=x)).
        """
        deps = super().dependencies()
        if self.where is True and self.out is not None:
            out_name = getattr(self.out, "_name", None)
            # Only exclude if out is not also an input argument
            input_names = {getattr(a, "_name", None) for a in self.elemwise_args if hasattr(a, "_name")}
            if out_name and out_name not in input_names:
                deps = [d for d in deps if d._name != out_name]
        return deps

    @property
    def out_ind(self):
        """Output index symbols: ``(ndim - 1, ..., 1, 0)`` for the broadcast shape.

        The broadcast shape covers all element-wise arguments plus ``where``
        and ``out`` when they are array expressions.
        """
        shapes = []
        for arg in self.elemwise_args:
            shape = getattr(arg, "shape", ())
            if any(is_dask_collection(x) for x in shape):
                # Want to exclude Delayed shapes and dd.Scalar
                shape = ()
            shapes.append(shape)
        if isinstance(self.where, ArrayExpr):
            shapes.append(self.where.shape)
        if isinstance(self.out, ArrayExpr):
            shapes.append(self.out.shape)

        shapes = [s if isinstance(s, Iterable) else () for s in shapes]
        out_ndim = len(broadcast_shapes(*shapes))  # Raises ValueError if dimensions mismatch
        return tuple(range(out_ndim))[::-1]

    @cached_property
    def _info(self):
        """Return ``(task_function, dtype, internal_kwargs)``.

        ``task_function`` is ``op``, possibly wrapped for ``where`` handling
        and/or dtype enforcement. ``internal_kwargs`` carries the wrapped
        callables and target dtype for those wrappers.
        """
        if self.operand("dtype") is not None:
            need_enforce_dtype = True
            dtype = np.dtype(self.operand("dtype"))
        else:
            # We follow NumPy's rules for dtype promotion, which special cases
            # scalars and 0d ndarrays (which it considers equivalent) by using
            # their values to compute the result dtype:
            # https://github.com/numpy/numpy/issues/6240
            # We don't inspect the values of 0d dask arrays, because these could
            # hold potentially very expensive calculations. Instead, we treat
            # them just like other arrays, and if necessary cast the result of op
            # to match.
            vals = [
                (np.empty((1,) * max(1, a.ndim), dtype=a.dtype) if not is_scalar_for_elemwise(a) else a)
                for a in self.elemwise_args
            ]
            try:
                dtype = apply_infer_dtype(self.op, vals, self.user_kwargs, "elemwise", suggest_dtype=False)
            except Exception as err:
                # Keep the original failure attached for debugging.
                raise NotImplementedError from err
            need_enforce_dtype = any(not is_scalar_for_elemwise(a) and a.ndim == 0 for a in self.elemwise_args)

        blockwise_kwargs = {}
        op = self.op
        if self.where is not True:
            blockwise_kwargs["elemwise_where_function"] = op
            op = _elemwise_handle_where

        if need_enforce_dtype:
            blockwise_kwargs.update(
                {
                    "enforce_dtype": dtype,
                    "enforce_dtype_function": op,
                }
            )
            op = _enforce_dtype

        return op, dtype, blockwise_kwargs

    @property
    def func(self):
        return self._info[0]

    @property
    def dtype(self):
        return self._info[1]

    @property
    def kwargs(self):
        # Merge user kwargs with internal kwargs (dtype enforcement, where handling)
        return {**self.user_kwargs, **self._info[2]}

    @property
    def token(self):
        return funcname(self.op).strip("_")

    @property
    def args(self):
        # for Blockwise rather than Elemwise
        # When where is an array, append [where, out] for _elemwise_handle_where
        extra_args = []
        if self.where is not True:
            extra_args.append(self.where)
            extra_args.append(self.out)
        return tuple(
            toolz.concat(
                (
                    a,
                    (tuple(range(a.ndim)[::-1]) if not is_scalar_for_elemwise(a) else None),
                )
                for a in self.elemwise_args + extra_args
            )
        )

    def _lower(self):
        # Override Blockwise._lower to handle Elemwise's different operand structure.
        # Elemwise stores just arrays in operands, but args generates (array, indices) pairs.
        # After unifying chunks, we only pass the unified arrays (not indices) to the constructor.
        if self.align_arrays:
            _, arrays, changed = unify_chunks_expr(*self.args)
            if changed:
                # Only pass the unified arrays, not the indices
                # When where is an array, the last two arrays are where and out
                if self.where is not True:
                    new_elemwise_args = arrays[:-2]
                    new_where = arrays[-2]
                    new_out = arrays[-1]
                else:
                    new_elemwise_args = arrays
                    new_where = True
                    new_out = None
                return Elemwise(
                    self.op,
                    self.operand("dtype"),
                    self.operand("name"),
                    new_where,
                    new_out,
                    self.operand("_user_kwargs"),
                    *new_elemwise_args,
                )

    def _task(self, key, block_id: tuple[int, ...]) -> Task:
        """Generate task for a specific output block.

        Parameters
        ----------
        key : tuple
            The output key for this task (e.g., ('add-abc123', 0, 1))
        block_id : tuple[int, ...]
            The block coordinates (e.g., (0, 1) for block at row 0, col 1)

        Returns
        -------
        Task
            A Task object that computes this block
        """
        args = []

        # Process elemwise_args
        for arg in self.elemwise_args:
            if is_scalar_for_elemwise(arg):
                args.append(arg)
            else:
                # Array argument - compute block_id adjusted for broadcasting
                # For broadcasting: use 0 for dimensions where array has 1 block
                arg_block_id = self._broadcast_block_id(arg, block_id)
                args.append(TaskRef((arg.name, *arg_block_id)))

        # Handle where/out arrays if present
        if self.where is not True:
            if is_scalar_for_elemwise(self.where):
                args.append(self.where)
            else:
                where_block_id = self._broadcast_block_id(self.where, block_id)
                args.append(TaskRef((self.where.name, *where_block_id)))

            if self.out is None or is_scalar_for_elemwise(self.out):
                args.append(self.out)
            else:
                out_block_id = self._broadcast_block_id(self.out, block_id)
                args.append(TaskRef((self.out.name, *out_block_id)))

        if self.kwargs:
            return Task(key, self.func, *args, **self.kwargs)
        else:
            return Task(key, self.func, *args)

    def _broadcast_block_id(self, arr, block_id: tuple[int, ...]) -> tuple[int, ...]:
        """Adjust block_id for broadcasting."""
        return _broadcast_block_id(arr.numblocks, block_id)

    def _input_block_id(self, dep, block_id: tuple[int, ...]) -> tuple[int, ...]:
        """Map output block_id to input block_id for a dependency.

        For Elemwise, this handles broadcasting - same block_id adjusted
        for arrays with fewer dimensions or single-block dimensions.
        """
        return self._broadcast_block_id(dep, block_id)

    def _accept_slice(self, slice_expr):
        """Accept a slice being pushed through this Elemwise.

        Returns a new Elemwise with the slice pushed to each array input,
        handling broadcasting appropriately. Array-valued ``where`` and
        ``out`` operands also participate in broadcasting (see ``out_ind``),
        so they are sliced the same way as the inputs; scalar operands are
        passed through unchanged.
        """
        from numbers import Integral

        from dask_array._new_collection import new_collection

        out_ind = self.out_ind
        index = slice_expr.index

        # Pad index to full length
        full_index = index + (slice(None),) * (len(out_ind) - len(index))

        def slice_operand(arg):
            """Slice one array operand so it lines up with the sliced output."""
            # arg has indices tuple(range(arg.ndim)[::-1]); broadcasting aligns
            # them with the trailing entries of out_ind.
            arg_ind = tuple(range(arg.ndim)[::-1])
            arg_shape = arg.shape

            arg_slices = []
            for i, dim_idx in enumerate(arg_ind):
                try:
                    out_pos = out_ind.index(dim_idx)
                except ValueError:
                    # Index not in output (shouldn't happen for elemwise)
                    arg_slices.append(slice(None))
                    continue

                out_slice = full_index[out_pos]
                if arg_shape[i] != 1:
                    arg_slices.append(out_slice)
                # Handle size-1 (broadcast) dimensions specially:
                # - For slices: use slice(None) to preserve broadcast semantics,
                #   EXCEPT for empty output slices (like [:0]) which must be preserved
                # - For integers: use 0 instead of the original index (which may be
                #   out of bounds for the size-1 input)
                elif isinstance(out_slice, slice):
                    start, stop, step = out_slice.indices(self.shape[out_pos])
                    if len(range(start, stop, step)) == 0:
                        # Empty output slice - preserve it
                        arg_slices.append(out_slice)
                    else:
                        arg_slices.append(slice(None))
                elif isinstance(out_slice, Integral):
                    # Integer index on broadcast dim - use 0
                    arg_slices.append(0)
                else:
                    arg_slices.append(out_slice)

            return new_collection(arg)[tuple(arg_slices)].expr

        new_args = [arg if is_scalar_for_elemwise(arg) else slice_operand(arg) for arg in self.elemwise_args]

        # where/out contribute to the output shape, so leaving them unsliced
        # would make their shapes disagree with the sliced inputs.
        new_where = slice_operand(self.where) if isinstance(self.where, ArrayExpr) else self.where
        new_out = slice_operand(self.out) if isinstance(self.out, ArrayExpr) else self.out

        return Elemwise(
            self.op,
            self.operand("dtype"),
            self.operand("name"),
            new_where,
            new_out,
            self.operand("_user_kwargs"),
            *new_args,
        )

    def _accept_shuffle(self, shuffle_expr):
        """Accept a shuffle being pushed through this Elemwise.

        Push shuffle through by shuffling each input array on the corresponding
        axis, accounting for broadcasting. Inputs that broadcast on the shuffle
        axis (size-1 or fewer dimensions) are not shuffled.
        """
        from dask_array._shuffle import Shuffle

        axis = shuffle_expr.axis
        indexer = shuffle_expr.indexer
        name = shuffle_expr.operand("name")
        output_ndim = len(self.shape)

        def get_input_axis(arg):
            """Get the corresponding axis in input for the output shuffle axis.

            Returns the input axis, or None if the input broadcasts on this axis.
            For broadcasting, input axes are aligned to the right of output axes.
            """
            if is_scalar_for_elemwise(arg):
                return None
            # Input axis = output axis - (dimensions added by broadcasting)
            input_axis = axis - (output_ndim - arg.ndim)
            if input_axis < 0:
                # This input doesn't have the shuffle axis (broadcasts on it)
                return None
            if arg.shape[input_axis] == 1:
                # Size-1 dimensions broadcast, don't shuffle
                return None
            return input_axis

        # Shuffle each array input on its corresponding axis
        new_args = []
        for arg in self.elemwise_args:
            input_axis = get_input_axis(arg)
            if input_axis is not None:
                new_args.append(Shuffle(arg, indexer, input_axis, name))
            else:
                new_args.append(arg)

        # Shuffle where/out if they are arrays
        new_where = self.where
        input_axis = get_input_axis(new_where) if hasattr(new_where, "ndim") else None
        if input_axis is not None:
            new_where = Shuffle(new_where, indexer, input_axis, name)

        new_out = self.out
        input_axis = get_input_axis(new_out) if hasattr(new_out, "ndim") else None
        if input_axis is not None:
            new_out = Shuffle(new_out, indexer, input_axis, name)

        return Elemwise(
            self.op,
            self.operand("dtype"),
            self.operand("name"),
            new_where,
            new_out,
            self.operand("_user_kwargs"),
            *new_args,
        )
|
|
1016
|
+
|
|
1017
|
+
|
|
1018
|
+
def _broadcast_block_id(numblocks: tuple[int, ...], block_id: tuple[int, ...]) -> tuple[int, ...]:
|
|
1019
|
+
"""Adjust block_id for broadcasting.
|
|
1020
|
+
|
|
1021
|
+
When an array has fewer dimensions or single-block dimensions,
|
|
1022
|
+
we need to adjust the block indices accordingly.
|
|
1023
|
+
"""
|
|
1024
|
+
out_ndim = len(block_id)
|
|
1025
|
+
arr_ndim = len(numblocks)
|
|
1026
|
+
|
|
1027
|
+
# Handle dimension mismatch (broadcasting adds leading dims)
|
|
1028
|
+
offset = out_ndim - arr_ndim
|
|
1029
|
+
|
|
1030
|
+
result = []
|
|
1031
|
+
for i, nb in enumerate(numblocks):
|
|
1032
|
+
out_idx = offset + i
|
|
1033
|
+
if nb == 1:
|
|
1034
|
+
# Single block in this dimension - always use 0
|
|
1035
|
+
result.append(0)
|
|
1036
|
+
else:
|
|
1037
|
+
result.append(block_id[out_idx])
|
|
1038
|
+
return tuple(result)
|
|
1039
|
+
|
|
1040
|
+
|
|
1041
|
+
def _compute_block_id(ind: tuple, idx_to_block: dict, numblocks: tuple[int, ...]) -> tuple[int, ...]:
|
|
1042
|
+
"""Compute block_id for a dependency given symbolic indices.
|
|
1043
|
+
|
|
1044
|
+
Maps symbolic indices to block coordinates using idx_to_block mapping.
|
|
1045
|
+
Handles contracted dimensions (indices in input but not output) by using
|
|
1046
|
+
block 0 when the dimension has only 1 block.
|
|
1047
|
+
"""
|
|
1048
|
+
result = []
|
|
1049
|
+
for dim, i in enumerate(ind):
|
|
1050
|
+
if i in idx_to_block:
|
|
1051
|
+
result.append(idx_to_block[i] % numblocks[dim])
|
|
1052
|
+
elif numblocks[dim] == 1:
|
|
1053
|
+
# Contracted dimension with single block - use block 0
|
|
1054
|
+
result.append(0)
|
|
1055
|
+
else:
|
|
1056
|
+
raise ValueError(
|
|
1057
|
+
f"Cannot determine block for index {i}: not in output indices "
|
|
1058
|
+
f"and input has {numblocks[dim]} blocks in dimension {dim}"
|
|
1059
|
+
)
|
|
1060
|
+
return tuple(result)
|
|
1061
|
+
|
|
1062
|
+
|
|
1063
|
+
def is_fusable_blockwise(expr):
    """Report whether ``expr`` may be merged into a fused blockwise group.

    An expression opts in through a truthy ``_is_blockwise_fusable``
    attribute (Blockwise without concatenate, BroadcastTrick, and Random
    set it). Objects without that attribute, including ``None``, are not
    fusable.
    """
    try:
        return expr._is_blockwise_fusable
    except AttributeError:
        return False


# Alias for internal use
is_fusable_elemwise = is_fusable_blockwise
|
|
1074
|
+
|
|
1075
|
+
|
|
1076
|
+
def _symbolic_mapping(expr, parent_mapping):
    """Describe how root output dimensions reach each dependency of ``expr``.

    ``parent_mapping`` is a tuple whose entry ``k`` names the root output
    dimension that drives ``expr``'s block axis ``k``. The result maps each
    dependency ``_name`` to a list of such tuples, one per way ``expr`` uses
    that dependency. For example ``(0, 1)`` means the dependency's block is
    ``(root_dim_0, root_dim_1)``, and ``(2, 1)`` means
    ``(root_dim_2, root_dim_1)``. Comparing these tuples detects conflicting
    access patterns symbolically, without sampling concrete block ids.
    """
    from dask_array.manipulation._transpose import Transpose

    mappings = {}

    if isinstance(expr, Transpose):
        # Transpose permutes dimensions: output[i] comes from input[axes[i]],
        # so input axis j is driven by whatever drives output axis inverse[j].
        inverse = expr._inverse_axes
        permuted = tuple(parent_mapping[inverse[j]] for j in range(len(inverse)))
        source = expr.array
        if hasattr(source, "_name"):
            mappings[source._name] = [permuted]
        return mappings

    if hasattr(expr, "out_ind") and hasattr(expr, "args"):
        # Blockwise: each (array, indices) pair selects symbols from out_ind.
        # Symbols beyond the parent mapping keep their own position; symbols
        # not in out_ind fall back to themselves.
        symbol_to_root = {
            symbol: (parent_mapping[pos] if pos < len(parent_mapping) else pos)
            for pos, symbol in enumerate(expr.out_ind)
        }
        for arr, ind in toolz.partition(2, expr.args):
            if ind is None or not hasattr(arr, "_name"):
                continue
            derived = tuple(symbol_to_root.get(symbol, symbol) for symbol in ind)
            mappings.setdefault(arr._name, []).append(derived)
        return mappings

    # For other expression types (e.g., Random), dependencies whose ndim
    # matches the parent inherit the parent mapping unchanged.
    for dep in expr.dependencies():
        if hasattr(dep, "_name") and dep.ndim == len(parent_mapping):
            mappings[dep._name] = [parent_mapping]
    return mappings
|
|
1119
|
+
|
|
1120
|
+
|
|
1121
|
+
def _remove_conflicting_exprs(group):
|
|
1122
|
+
"""Remove expressions accessed with conflicting block patterns.
|
|
1123
|
+
|
|
1124
|
+
When the same expression is accessed via multiple paths with different
|
|
1125
|
+
index transformations (e.g., a + a.T), we can't fuse it - each output
|
|
1126
|
+
block would need different source blocks from the same expression.
|
|
1127
|
+
|
|
1128
|
+
Uses symbolic analysis: traces how root output dimensions map to each
|
|
1129
|
+
expression's block dimensions through the expression tree. If the same
|
|
1130
|
+
expression is reached via paths with different symbolic mappings, it's
|
|
1131
|
+
a conflict.
|
|
1132
|
+
|
|
1133
|
+
Also removes expressions that become unreachable after conflict removal.
|
|
1134
|
+
"""
|
|
1135
|
+
if len(group) <= 1:
|
|
1136
|
+
return group
|
|
1137
|
+
|
|
1138
|
+
expr_names = {e._name for e in group}
|
|
1139
|
+
expr_map = {e._name: e for e in group}
|
|
1140
|
+
root = group[0]
|
|
1141
|
+
|
|
1142
|
+
# Symbolic mapping: tuple of root dimension indices for each expression
|
|
1143
|
+
# (0, 1) means "root dim 0 for position 0, root dim 1 for position 1"
|
|
1144
|
+
symbolic_mappings = {root._name: tuple(range(root.ndim))}
|
|
1145
|
+
conflicts = set()
|
|
1146
|
+
|
|
1147
|
+
for expr in group:
|
|
1148
|
+
if expr._name not in symbolic_mappings:
|
|
1149
|
+
continue
|
|
1150
|
+
my_mapping = symbolic_mappings[expr._name]
|
|
1151
|
+
|
|
1152
|
+
# Get symbolic mappings for all dependencies
|
|
1153
|
+
dep_mappings = _symbolic_mapping(expr, my_mapping)
|
|
1154
|
+
|
|
1155
|
+
for dep_name, mappings_list in dep_mappings.items():
|
|
1156
|
+
if dep_name not in expr_names:
|
|
1157
|
+
continue
|
|
1158
|
+
|
|
1159
|
+
for dep_mapping in mappings_list:
|
|
1160
|
+
if dep_name in symbolic_mappings:
|
|
1161
|
+
if symbolic_mappings[dep_name] != dep_mapping:
|
|
1162
|
+
conflicts.add(dep_name)
|
|
1163
|
+
else:
|
|
1164
|
+
symbolic_mappings[dep_name] = dep_mapping
|
|
1165
|
+
|
|
1166
|
+
if not conflicts:
|
|
1167
|
+
return group
|
|
1168
|
+
|
|
1169
|
+
# Remove conflicts and find reachable expressions
|
|
1170
|
+
remaining = {e._name for e in group if e._name not in conflicts}
|
|
1171
|
+
reachable = {root._name}
|
|
1172
|
+
stack = [root]
|
|
1173
|
+
|
|
1174
|
+
while stack:
|
|
1175
|
+
expr = stack.pop()
|
|
1176
|
+
for dep in expr.dependencies():
|
|
1177
|
+
if dep._name in remaining and dep._name not in reachable:
|
|
1178
|
+
reachable.add(dep._name)
|
|
1179
|
+
stack.append(expr_map[dep._name])
|
|
1180
|
+
|
|
1181
|
+
return [e for e in group if e._name in reachable]
|
|
1182
|
+
|
|
1183
|
+
|
|
1184
|
+
def optimize_blockwise_fusion_array(expr):
    """Traverse the expression graph and apply fusion.

    Finds groups of consecutive fusable Blockwise operations and fuses them
    into single FusedBlockwise expressions.

    Each call to the inner ``_fusion_pass`` fuses at most one group and
    substitutes it into the graph. The outer loop re-runs the pass until it
    reports completion or the expression name stops changing.

    Parameters
    ----------
    expr :
        Root expression of the graph to optimize.

    Returns
    -------
    The (possibly) rewritten expression.
    """
    from collections import defaultdict

    def _fusion_pass(expr):
        # Returns (new_expr, done). ``done`` is True when no further roots
        # remain to try after this pass.

        # Build dependency graph of fusable operations
        seen = set()
        stack = [expr]
        dependents = defaultdict(set)  # name -> set of dependent names
        dependencies = {}  # name -> set of dependency names
        expr_mapping = {}  # name -> expr

        while stack:
            node = stack.pop()

            if node._name in seen:
                continue
            seen.add(node._name)

            if is_fusable_elemwise(node):
                dependencies[node._name] = set()
                if node._name not in dependents:
                    dependents[node._name] = set()
                expr_mapping[node._name] = node

            for operand in node.dependencies():
                stack.append(operand)
                if is_fusable_elemwise(operand):
                    # Only fusable nodes record fusable dependencies, but every
                    # node (fusable or not) is recorded as a dependent so that
                    # non-fusable consumers block fusion of shared operands.
                    if node._name in dependencies:
                        dependencies[node._name].add(operand._name)
                    dependents[operand._name].add(node._name)
                    expr_mapping[operand._name] = operand
                    expr_mapping[node._name] = node

        # Find roots - Elemwise nodes with no Elemwise dependents
        roots = [
            expr_mapping[k]
            for k, v in dependents.items()
            if v == set() or all(not is_fusable_elemwise(expr_mapping.get(_name)) for _name in v)
        ]

        while roots:
            root = roots.pop()
            seen_in_group = set()
            stack = [root]
            group = []

            # Depth-first walk from the root; group[0] is always the root.
            while stack:
                node = stack.pop()

                if node._name in seen_in_group:
                    continue
                seen_in_group.add(node._name)

                group.append(node)
                for dep_name in dependencies.get(node._name, set()):
                    dep = expr_mapping[dep_name]

                    stack_names = {s._name for s in stack}
                    group_names = {g._name for g in group}

                    # Check if all dependents of dep are in our group or stack
                    dep_dependents = dependents.get(dep_name, set())
                    if dep_dependents <= (stack_names | group_names | {node._name}):
                        # dep can be fused into this group
                        stack.append(dep)
                    elif dependencies.get(dep._name) and dep._name not in [r._name for r in roots]:
                        # Can't fuse dep, but may be able to use as new root
                        roots.append(dep)

            # Replace fusable sub-group
            if len(group) > 1:
                # Check for conflicting block patterns before fusing
                group = _remove_conflicting_exprs(group)
                if len(group) > 1:
                    fused = FusedBlockwise(tuple(group))
                    # Only the group root is substituted; the other members
                    # are consumed inside the fused expression.
                    new_expr = expr.substitute(group[0], fused)
                    return new_expr, not roots

        # No fusable groups found
        return expr, True

    # Iterate until no more fusion is possible
    while True:
        original_name = expr._name
        expr, done = _fusion_pass(expr)
        if done or expr._name == original_name:
            break

    return expr
|
|
1278
|
+
|
|
1279
|
+
|
|
1280
|
+
class FusedBlockwise(ArrayExpr):
    """Fused blockwise operations for arrays.

    A FusedBlockwise corresponds to the fusion of multiple Blockwise/Elemwise
    expressions into a single Expr object. At graph-materialization time,
    the behavior produces fused tasks that execute all operations together.

    Parameters
    ----------
    exprs : tuple[Expr, ...]
        Group of original Expr objects being fused together. The first
        expression is the "root" (final output); every other member must be
        reachable from it through ``dependencies()``.

    External dependencies (dependencies of group members that are not
    themselves in ``exprs``) are not stored as operands. They are computed
    on demand by :meth:`dependencies`.
    """

    _parameters = ["exprs"]

    @property
    def _meta(self):
        return self.exprs[0]._meta

    @property
    def chunks(self):
        return self.exprs[0].chunks

    @property
    def dtype(self):
        return self.exprs[0].dtype

    def dependencies(self):
        """Return external dependencies not included in the fused group."""
        fused_names = {e._name for e in self.exprs}
        external_deps = []
        seen = set()
        for expr in self.exprs:
            for dep in expr.dependencies():
                if dep._name not in fused_names and dep._name not in seen:
                    external_deps.append(dep)
                    seen.add(dep._name)
        return external_deps

    def _layer(self):
        result = {}
        for block_id in product(*[range(n) for n in self.numblocks]):
            key = (self._name, *block_id)
            result[key] = self._task(key, block_id)
        return result

    def _task(self, key, block_id: tuple[int, ...]) -> Task:
        """Generate a fused task for a specific output block."""
        # Compute block_id for each expression by tracing through dependencies
        # Each expression type (Elemwise, Transpose) has its own block mapping
        expr_block_ids = self._compute_block_ids(block_id)

        # Generate tasks in dependency order (leaves first for Task.fuse)
        internal_tasks = []
        for expr in reversed(self.exprs):
            expr_block_id = expr_block_ids[expr._name]
            subname = (expr._name, *expr_block_id)
            t = expr._task(subname, expr_block_id)
            internal_tasks.append(t)
        return Task.fuse(*internal_tasks, key=key)  # type: ignore[return-value]

    def _compute_block_ids(self, output_block_id: tuple[int, ...]) -> dict:
        """Compute block_id for each expression given the output block_id.

        Walks the fused group outward from the root, using each expression's
        ``_input_block_id`` method to map its block coordinates to those of
        its in-group dependencies.

        The walk follows the dependency edges rather than the order of
        ``self.exprs``. After conflict removal, a member may appear in
        ``self.exprs`` before the only remaining member that depends on it,
        so a single ordered pass could reach it before its block id is known.
        """
        members = {e._name: e for e in self.exprs}
        root = self.exprs[0]
        expr_block_ids = {root._name: output_block_id}

        pending = [root]
        while pending:
            expr = pending.pop()
            my_block_id = expr_block_ids[expr._name]
            for dep in expr.dependencies():
                if dep._name in members and dep._name not in expr_block_ids:
                    expr_block_ids[dep._name] = expr._input_block_id(dep, my_block_id)
                    pending.append(members[dep._name])

        return expr_block_ids

    def __str__(self):
        names = [expr._name.split("-")[0] for expr in self.exprs]
        if len(names) > 4:
            return f"{names[0]}-fused-{names[-1]}"
        return "-".join(names)

    @cached_property
    def _name(self):
        return f"{self}-{self.deterministic_token}"
|
|
1372
|
+
|
|
1373
|
+
|
|
1374
|
+
def outer(a, b):
    """
    Compute the outer product of two vectors.

    This docstring was copied from numpy.outer.

    Some inconsistencies with the Dask version may exist.

    Given two vectors, ``a = [a0, a1, ..., aM]`` and
    ``b = [b0, b1, ..., bN]``,
    the outer product is::

      [[a0*b0  a0*b1 ... a0*bN ]
       [a1*b0    .
       [ ...          .
       [aM*b0            aM*bN ]]

    Parameters
    ----------
    a : (M,) array_like
        First input vector. Input is flattened if not already 1-dimensional.
    b : (N,) array_like
        Second input vector. Input is flattened if not already 1-dimensional.

    Returns
    -------
    out : (M, N) ndarray
        ``out[i, j] = a[i] * b[j]``
    """
    from dask_array._collection import asarray, blockwise

    # Flatten both operands to 1-D dask arrays (a first, then b).
    left, right = (asarray(vec).flatten() for vec in (a, b))

    # Infer the result dtype by applying np.outer to zero-valued scalars of
    # each input dtype.
    probes = [arr.dtype.type() for arr in (left, right)]
    result_dtype = np.outer(*probes).dtype

    # Output block (i, j) combines block i of ``left`` with block j of ``right``.
    return blockwise(np.outer, "ij", left, "i", right, "j", dtype=result_dtype)
|