dask-array 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (144)
  1. dask_array/__init__.py +228 -0
  2. dask_array/_backends.py +76 -0
  3. dask_array/_backends_array.py +99 -0
  4. dask_array/_blockwise.py +1410 -0
  5. dask_array/_broadcast.py +272 -0
  6. dask_array/_chunk.py +445 -0
  7. dask_array/_chunk_types.py +54 -0
  8. dask_array/_collection.py +1644 -0
  9. dask_array/_concatenate.py +331 -0
  10. dask_array/_core_utils.py +1365 -0
  11. dask_array/_dispatch.py +141 -0
  12. dask_array/_einsum.py +277 -0
  13. dask_array/_expr.py +544 -0
  14. dask_array/_expr_flow.py +586 -0
  15. dask_array/_gufunc.py +805 -0
  16. dask_array/_histogram.py +617 -0
  17. dask_array/_map_blocks.py +652 -0
  18. dask_array/_new_collection.py +10 -0
  19. dask_array/_numpy_compat.py +135 -0
  20. dask_array/_overlap.py +1159 -0
  21. dask_array/_rechunk.py +1050 -0
  22. dask_array/_reshape.py +710 -0
  23. dask_array/_routines.py +102 -0
  24. dask_array/_shuffle.py +448 -0
  25. dask_array/_stack.py +264 -0
  26. dask_array/_svg.py +291 -0
  27. dask_array/_templates.py +29 -0
  28. dask_array/_test_utils.py +257 -0
  29. dask_array/_ufunc.py +385 -0
  30. dask_array/_utils.py +349 -0
  31. dask_array/_visualize.py +223 -0
  32. dask_array/_xarray.py +337 -0
  33. dask_array/core/__init__.py +34 -0
  34. dask_array/core/_blockwise_funcs.py +312 -0
  35. dask_array/core/_conversion.py +422 -0
  36. dask_array/core/_from_graph.py +97 -0
  37. dask_array/creation/__init__.py +71 -0
  38. dask_array/creation/_arange.py +121 -0
  39. dask_array/creation/_diag.py +116 -0
  40. dask_array/creation/_diagonal.py +241 -0
  41. dask_array/creation/_eye.py +103 -0
  42. dask_array/creation/_linspace.py +102 -0
  43. dask_array/creation/_mesh.py +134 -0
  44. dask_array/creation/_ones_zeros.py +454 -0
  45. dask_array/creation/_pad.py +270 -0
  46. dask_array/creation/_repeat.py +55 -0
  47. dask_array/creation/_tile.py +36 -0
  48. dask_array/creation/_tri.py +28 -0
  49. dask_array/creation/_utils.py +296 -0
  50. dask_array/fft.py +320 -0
  51. dask_array/io/__init__.py +39 -0
  52. dask_array/io/_base.py +10 -0
  53. dask_array/io/_from_array.py +257 -0
  54. dask_array/io/_from_delayed.py +95 -0
  55. dask_array/io/_from_graph.py +54 -0
  56. dask_array/io/_from_npy_stack.py +67 -0
  57. dask_array/io/_store.py +336 -0
  58. dask_array/io/_tiledb.py +159 -0
  59. dask_array/io/_to_npy_stack.py +65 -0
  60. dask_array/io/_zarr.py +449 -0
  61. dask_array/linalg/__init__.py +39 -0
  62. dask_array/linalg/_cholesky.py +234 -0
  63. dask_array/linalg/_lu.py +300 -0
  64. dask_array/linalg/_norm.py +94 -0
  65. dask_array/linalg/_qr.py +601 -0
  66. dask_array/linalg/_solve.py +349 -0
  67. dask_array/linalg/_svd.py +394 -0
  68. dask_array/linalg/_tensordot.py +334 -0
  69. dask_array/linalg/_utils.py +74 -0
  70. dask_array/manipulation/__init__.py +45 -0
  71. dask_array/manipulation/_expand.py +321 -0
  72. dask_array/manipulation/_flip.py +92 -0
  73. dask_array/manipulation/_roll.py +78 -0
  74. dask_array/manipulation/_transpose.py +309 -0
  75. dask_array/random/__init__.py +125 -0
  76. dask_array/random/_choice.py +181 -0
  77. dask_array/random/_expr.py +256 -0
  78. dask_array/random/_generator.py +441 -0
  79. dask_array/random/_random_state.py +259 -0
  80. dask_array/random/_utils.py +84 -0
  81. dask_array/reductions/__init__.py +84 -0
  82. dask_array/reductions/_arg_reduction.py +130 -0
  83. dask_array/reductions/_common.py +1082 -0
  84. dask_array/reductions/_cumulative.py +522 -0
  85. dask_array/reductions/_percentile.py +261 -0
  86. dask_array/reductions/_reduction.py +725 -0
  87. dask_array/reductions/_trace.py +56 -0
  88. dask_array/routines/__init__.py +133 -0
  89. dask_array/routines/_apply.py +84 -0
  90. dask_array/routines/_bincount.py +112 -0
  91. dask_array/routines/_broadcast.py +111 -0
  92. dask_array/routines/_coarsen.py +115 -0
  93. dask_array/routines/_diff.py +79 -0
  94. dask_array/routines/_gradient.py +158 -0
  95. dask_array/routines/_indexing.py +65 -0
  96. dask_array/routines/_insert_delete.py +132 -0
  97. dask_array/routines/_misc.py +122 -0
  98. dask_array/routines/_nonzero.py +72 -0
  99. dask_array/routines/_search.py +123 -0
  100. dask_array/routines/_select.py +113 -0
  101. dask_array/routines/_statistics.py +171 -0
  102. dask_array/routines/_topk.py +82 -0
  103. dask_array/routines/_triangular.py +74 -0
  104. dask_array/routines/_unique.py +232 -0
  105. dask_array/routines/_where.py +62 -0
  106. dask_array/slicing/__init__.py +67 -0
  107. dask_array/slicing/_basic.py +550 -0
  108. dask_array/slicing/_blocks.py +138 -0
  109. dask_array/slicing/_bool_index.py +145 -0
  110. dask_array/slicing/_setitem.py +329 -0
  111. dask_array/slicing/_squeeze.py +101 -0
  112. dask_array/slicing/_utils.py +1133 -0
  113. dask_array/slicing/_vindex.py +282 -0
  114. dask_array/stacking/__init__.py +15 -0
  115. dask_array/stacking/_block.py +83 -0
  116. dask_array/stacking/_simple.py +58 -0
  117. dask_array/templates/array.html.j2 +48 -0
  118. dask_array/tests/__init__.py +0 -0
  119. dask_array/tests/conftest.py +22 -0
  120. dask_array/tests/test_api.py +40 -0
  121. dask_array/tests/test_binary_op_chunks.py +107 -0
  122. dask_array/tests/test_coarse_slice_through_blockwise.py +362 -0
  123. dask_array/tests/test_collection.py +799 -0
  124. dask_array/tests/test_creation.py +1102 -0
  125. dask_array/tests/test_expr_flow.py +143 -0
  126. dask_array/tests/test_linalg.py +1130 -0
  127. dask_array/tests/test_map_blocks_multi_output.py +104 -0
  128. dask_array/tests/test_rechunk_pushdown.py +214 -0
  129. dask_array/tests/test_reductions.py +1091 -0
  130. dask_array/tests/test_routines.py +2853 -0
  131. dask_array/tests/test_shuffle_chunks.py +67 -0
  132. dask_array/tests/test_slice_pushdown.py +968 -0
  133. dask_array/tests/test_slice_through_blockwise.py +678 -0
  134. dask_array/tests/test_slice_through_overlap.py +366 -0
  135. dask_array/tests/test_slice_through_reshape.py +272 -0
  136. dask_array/tests/test_slicing.py +839 -0
  137. dask_array/tests/test_transpose_slice_pushdown.py +208 -0
  138. dask_array/tests/test_visualize.py +94 -0
  139. dask_array/tests/test_xarray.py +193 -0
  140. dask_array-0.1.0.dist-info/METADATA +48 -0
  141. dask_array-0.1.0.dist-info/RECORD +144 -0
  142. dask_array-0.1.0.dist-info/WHEEL +4 -0
  143. dask_array-0.1.0.dist-info/entry_points.txt +2 -0
  144. dask_array-0.1.0.dist-info/licenses/LICENSE +29 -0
@@ -0,0 +1,1410 @@
1
+ from __future__ import annotations
2
+
3
+ import numbers
4
+ from collections.abc import Iterable
5
+ from itertools import product
6
+
7
+ import numpy as np
8
+ import tlz as toolz
9
+
10
+ from dask import is_dask_collection
11
+ from dask._task_spec import Task, TaskRef
12
+ from dask_array._expr import ArrayExpr, unify_chunks_expr
13
+ from dask_array._utils import compute_meta
14
+ from dask_array._core_utils import (
15
+ _elemwise_handle_where,
16
+ _enforce_dtype,
17
+ apply_infer_dtype,
18
+ broadcast_shapes,
19
+ is_scalar_for_elemwise,
20
+ normalize_arg,
21
+ )
22
+ from dask_array._utils import meta_from_array
23
+ from dask.blockwise import blockwise as core_blockwise
24
+ from dask.delayed import unpack_collections
25
+ from dask.layers import ArrayBlockwiseDep
26
+ from dask.tokenize import _tokenize_deterministic
27
+ from dask.utils import SerializableLock, cached_property, funcname
28
+
29
+
30
class Blockwise(ArrayExpr):
    """Array expression that applies ``func`` block-by-block to its inputs.

    Operands are the named ``_parameters`` followed by ``args``: a flat
    sequence of alternating ``(argument, index)`` pairs.  An ``index`` of
    ``None`` marks a literal argument that is passed to ``func`` unchanged;
    otherwise ``index`` is a tuple of symbolic dimension labels that is
    matched against ``out_ind`` to decide which input block feeds which
    output block.

    ``kwargs`` and ``new_axes`` may be left at their default of ``None``;
    both are treated as empty mappings.
    """

    _parameters = [
        "func",
        "out_ind",
        "name",
        "token",
        "dtype",
        "adjust_chunks",
        "new_axes",
        "align_arrays",
        "concatenate",
        "_meta_provided",
        "kwargs",
    ]
    _defaults = {
        "name": None,
        "token": None,
        "dtype": None,
        "adjust_chunks": None,
        "new_axes": None,
        "align_arrays": True,
        "concatenate": None,
        "_meta_provided": None,
        "kwargs": None,
    }

    @cached_property
    def args(self):
        """Alternating ``(argument, index)`` operands after the named parameters."""
        return self.operands[len(self._parameters) :]

    @cached_property
    def _meta_provided(self):
        # We catch recursion errors if key starts with _meta, so define
        # explicitly here
        return self.operand("_meta_provided")

    @cached_property
    def _meta(self):
        """Empty array-like describing the output (a tuple for multi-output funcs).

        Uses the explicitly provided meta when there is one; otherwise runs
        ``compute_meta`` on ``func`` with the argument operands, falling back
        to an empty meta of the requested ``dtype`` if that fails.
        """
        provided = self._meta_provided
        if provided is not None:
            if isinstance(provided, (tuple, list)):
                # Multi-output functions (e.g. from apply_gufunc) carry one meta per output
                return tuple(
                    meta_from_array(m, ndim=m.ndim, dtype=getattr(m, "dtype", None))
                    for m in provided
                )
            # getattr: some metas (e.g. DataFrame) don't have .dtype
            return meta_from_array(
                provided,
                ndim=self.ndim,
                dtype=getattr(provided, "dtype", None),
            )

        meta = compute_meta(
            self.func,
            self.operand("dtype"),
            *self.args[::2],
            **(self.kwargs or {}),
        )
        if meta is None:
            # compute_meta failed (e.g., function has assertions on shapes).
            # Fall back to a default meta based on the explicitly provided dtype
            # (use operand to avoid recursion since dtype property may depend on _meta)
            meta = meta_from_array(None, ndim=self.ndim, dtype=self.operand("dtype"))
        return meta

    @cached_property
    def chunks(self):
        """Output chunks: one tuple of block sizes per label in ``out_ind``.

        Input chunks are unified (or, without ``align_arrays``, the input with
        the most blocks per label wins), ``new_axes`` are added, and
        ``adjust_chunks`` is applied per output label.
        """
        if self.align_arrays:
            unified, _, _ = unify_chunks_expr(*self.args)
            # Copy so adding new axes never mutates the mapping we were handed
            chunkss = dict(unified)
        else:
            chunkss = {}
            # For each dimension, use the input chunking that has the most blocks;
            # this will ensure that broadcasting works as expected, and in
            # particular the number of blocks should be correct if the inputs are
            # consistent.
            for arg, ind in toolz.partition(2, self.args):
                if ind is None:
                    continue
                for c, i in zip(arg.chunks, ind):
                    if i not in chunkss or len(c) > len(chunkss[i]):
                        chunkss[i] = c

        for label, sizes in (self.new_axes or {}).items():
            chunkss[label] = sizes if isinstance(sizes, tuple) else (sizes,)

        chunks = [chunkss[i] for i in self.out_ind]
        adjust_chunks = self.adjust_chunks
        if adjust_chunks:
            for i, ind in enumerate(self.out_ind):
                if ind not in adjust_chunks:
                    continue
                adjust = adjust_chunks[ind]
                if callable(adjust):
                    chunks[i] = tuple(map(adjust, chunks[i]))
                elif isinstance(adjust, numbers.Integral):
                    chunks[i] = (adjust,) * len(chunks[i])
                elif isinstance(adjust, (tuple, list)):
                    if len(adjust) != len(chunks[i]):
                        raise ValueError(
                            f"Dimension {i} has {len(chunks[i])} blocks, adjust_chunks "
                            f"specified with {len(adjust)} blocks"
                        )
                    chunks[i] = tuple(adjust)
                else:
                    raise NotImplementedError("adjust_chunks values must be callable, int, or tuple")
        return tuple(tuple(c) for c in chunks)

    @cached_property
    def dtype(self):
        # Cache the base-class dtype (which may be derived from _meta)
        return super().dtype

    @property
    def _is_blockwise_fusable(self):
        """Whether this expression may be fused with neighbouring blockwise ops."""
        # Blockwise with concatenate requires special handling not yet implemented
        if self.concatenate:
            return False
        # Blockwise with Delayed operands can't be fused because FusedBlockwise
        # doesn't properly track them as external dependencies
        from dask.delayed import Delayed

        if any(isinstance(op, Delayed) for op in self.operands):
            return False

        # Contracted dimensions (in an input but not in the output) can only be
        # fused when they consist of a single block.
        out_idx_set = set(self.out_ind) | set(self.new_axes or ())
        for arr, ind in toolz.partition(2, self.args):
            if ind is None or not hasattr(arr, "numblocks"):
                continue
            for dim, label in enumerate(ind):
                if label not in out_idx_set and arr.numblocks[dim] > 1:
                    return False
        return True

    def _idx_to_block(self, block_id: tuple[int, ...]) -> dict:
        """Map symbolic indices to output block coordinates."""
        idx_to_block = {idx: block_id[dim] for dim, idx in enumerate(self.out_ind)}
        for idx in self.new_axes or ():
            idx_to_block[idx] = 0
        return idx_to_block

    def _dep_block_id(self, arr, ind, idx_to_block: dict) -> tuple[int, ...]:
        """Compute block_id for a dependency, applying modulo for broadcasting."""
        return _compute_block_id(ind, idx_to_block, arr.numblocks)

    def _task(self, key, block_id: tuple[int, ...]):
        """Generate the task computing output block ``block_id``.

        Raises
        ------
        NotImplementedError
            If ``concatenate`` is set (not supported for fusion).
        """
        if self.concatenate:
            raise NotImplementedError("Blockwise with concatenate not supported for fusion")

        idx_to_block = self._idx_to_block(block_id)

        args = []
        for arr, ind in toolz.partition(2, self.args):
            if ind is None:
                # Literal argument, passed through unchanged
                args.append(arr)
            elif isinstance(arr, ArrayBlockwiseDep):
                # Per-block values are looked up directly rather than referenced
                numblocks = tuple(len(c) for c in arr.chunks)
                args.append(arr[_compute_block_id(ind, idx_to_block, numblocks)])
            else:
                input_block_id = self._dep_block_id(arr, ind, idx_to_block)
                args.append(TaskRef((arr._name, *input_block_id)))

        return Task(key, self.func, *args, **(self.kwargs or {}))

    def _input_block_id(self, dep, block_id: tuple[int, ...]) -> tuple[int, ...]:
        """Map output block_id to input block_id for a dependency."""
        idx_to_block = self._idx_to_block(block_id)
        for arr, ind in toolz.partition(2, self.args):
            if ind is not None and hasattr(arr, "_name") and arr._name == dep._name:
                return self._dep_block_id(arr, ind, idx_to_block)
        return block_id

    def _all_input_block_ids(self, block_id: tuple[int, ...]) -> dict:
        """Return all input block_ids for dependencies.

        Handles case where same dependency appears multiple times with
        different index mappings (e.g., da.dot(x, x)).
        """
        idx_to_block = self._idx_to_block(block_id)
        result: dict = {}
        for arr, ind in toolz.partition(2, self.args):
            if ind is not None and hasattr(arr, "_name"):
                result.setdefault(arr._name, []).append(self._dep_block_id(arr, ind, idx_to_block))
        return result

    def __dask_tokenize__(self):
        """Deterministic token over the function, layout parameters, args and kwargs."""
        if not self._determ_token:
            # Handle non-serializable locks in kwargs by using their id()
            kwargs_token = {}
            for k, v in (self.kwargs or {}).items():
                if k == "lock" and v and not isinstance(v, (bool, SerializableLock)):
                    kwargs_token[k] = ("lock-id", id(v))
                else:
                    kwargs_token[k] = v

            self._determ_token = _tokenize_deterministic(
                self.func,
                self.out_ind,
                self.dtype,
                self.adjust_chunks,
                self.new_axes,
                self.align_arrays,
                self.concatenate,
                *self.args,
                **kwargs_token,
            )
        return self._determ_token

    @cached_property
    def _name(self):
        # Always include deterministic_token suffix to ensure:
        # 1. Different expressions with same user-provided name are distinguishable
        # 2. lower_completely can detect when operands change (via name change)
        prefix = (
            self.operand("name")
            if "name" in self._parameters and self.operand("name")
            else (self.token or funcname(self.func).strip("_"))
        )
        return f"{prefix}-{self.deterministic_token}"

    def _layer(self):
        """Build the task graph dict for this expression via ``dask.blockwise``.

        Graphs of collections found inside literal args/kwargs (e.g. delayed
        objects) are merged into the result.

        Raises
        ------
        ValueError
            If an array's ``ndim`` does not match the length of its index.
        """
        numblocks = {}
        dependencies = []
        argindsstr = []

        for arg, ind in toolz.partition(2, self.args):
            if ind is None:
                # Literal argument (not an array) - normalize it
                arg, collections = unpack_collections(normalize_arg(arg))
                dependencies.extend(collections)
            else:
                if hasattr(arg, "ndim") and hasattr(ind, "__len__") and arg.ndim != len(ind):
                    raise ValueError(f"Index string {ind} does not match array dimension {arg.ndim}")
                # TODO(expr): this class is a confusing crutch to pass arguments to the
                # graph, we should write them directly into the graph
                if not isinstance(arg, ArrayBlockwiseDep):
                    numblocks[arg.name] = arg.numblocks
                    arg = arg.name
            argindsstr.extend((arg, ind))

        # Normalize keyword arguments
        kwargs2 = {}
        for k, v in (self.kwargs or {}).items():
            v, collections = unpack_collections(normalize_arg(v))
            dependencies.extend(collections)
            kwargs2[k] = v

        # TODO(expr): Highlevelgraph :(
        graph = core_blockwise(
            self.func,
            self._name,
            self.out_ind,
            *argindsstr,
            numblocks=numblocks,
            dependencies=dependencies,
            new_axes=self.new_axes,
            concatenate=self.concatenate,
            **kwargs2,
        )
        result = dict(graph)
        # Merge in dependency graphs (from delayed objects, etc.)
        for dep in dependencies:
            if is_dask_collection(dep):
                result.update(dep.__dask_graph__())
        return result

    def _lower(self):
        """Swap in the arrays returned by ``unify_chunks_expr`` when it reports a change."""
        if not self.align_arrays:
            return None
        _, arrays, changed = unify_chunks_expr(*self.args)
        if not changed:
            return None
        args = []
        for idx, arr in zip(self.args[1::2], arrays):
            args.extend([arr, idx])
        return type(self)(*self.operands[: len(self._parameters)], *args)

    def _simplify_up(self, parent, dependents):
        """Allow slice and shuffle operations to push through Blockwise."""
        from dask_array._shuffle import Shuffle
        from dask_array.slicing import SliceSlicesIntegers

        if isinstance(parent, SliceSlicesIntegers):
            return self._accept_slice(parent)
        if isinstance(parent, Shuffle):
            return self._accept_shuffle(parent)
        return None

    def _rebuild(self, new_args, adjust_chunks):
        """Return a base ``Blockwise`` with this expression's parameters, ``new_args`` and ``adjust_chunks``."""
        return Blockwise(
            self.func,
            self.out_ind,
            self.operand("name"),
            self.operand("token"),
            self.operand("dtype"),
            adjust_chunks,
            self.operand("new_axes"),
            self.operand("align_arrays"),
            self.operand("concatenate"),
            self.operand("_meta_provided"),
            self.operand("kwargs"),
            *new_args,
        )

    def _is_broadcast_dim(self, arr, in_axis, out_axis):
        """True when ``arr`` is length 1 along ``in_axis`` but output axis ``out_axis`` is not.

        Such an input is broadcast against the output (block lookup uses the
        modulo in ``_dep_block_id``), so output coordinates must not be used
        to slice or shuffle it along that axis.
        """
        return tuple(arr.chunks[in_axis]) == (1,) and tuple(self.chunks[out_axis]) != (1,)

    def _accept_shuffle(self, shuffle_expr):
        """Accept a shuffle being pushed through Blockwise.

        Push shuffle through when shuffle axis is not modified by blockwise.
        Returns ``None`` when the pushdown is not possible.
        """
        from dask_array._shuffle import Shuffle

        axis = shuffle_expr.axis
        # Index label for the shuffle axis
        shuffle_ind = self.out_ind[axis]

        # Can't push through if shuffle axis is a new axis or has adjusted chunks
        new_axes = getattr(self, "new_axes", None)
        if new_axes and shuffle_ind in new_axes:
            return None
        adjust_chunks = getattr(self, "adjust_chunks", None)
        if adjust_chunks and shuffle_ind in adjust_chunks:
            return None

        new_args = []
        for arr, ind in toolz.partition(2, self.args):
            if ind is None or shuffle_ind not in ind:
                # Literal, or an input without the shuffled dimension
                new_args.extend([arr, ind])
                continue
            if isinstance(arr, ArrayBlockwiseDep):
                # Per-block dependency values cannot be shuffled
                return None
            input_axis = ind.index(shuffle_ind)
            if self._is_broadcast_dim(arr, input_axis, axis):
                # Constant along the axis: reordering would be a no-op
                new_args.extend([arr, ind])
                continue
            shuffled = Shuffle(arr, shuffle_expr.indexer, input_axis, shuffle_expr.operand("name"))
            new_args.extend([shuffled, ind])

        return self._rebuild(new_args, self.operand("adjust_chunks"))

    def _accept_slice(self, slice_expr):
        """Accept a slice being pushed through this Blockwise.

        This optimization is safe when:
        - The blockwise doesn't adjust chunk sizes on sliced dimensions
        - The blockwise doesn't add new axes on sliced dimensions
        - The slice uses only slices or integers (no newaxis)

        When adjust_chunks is present on a sliced dimension, the "coarse"
        optimization (``_accept_slice_coarse``) is used instead.  Returns
        ``None`` when the pushdown is not possible.
        """
        from dask_array._new_collection import new_collection

        out_ind = self.out_ind
        index = slice_expr.index

        # Don't handle None/newaxis
        if any(idx is None for idx in index):
            return None

        # Pad index to full output length
        full_index = index + (slice(None),) * (len(out_ind) - len(index))

        # Labels of output axes that carry a non-trivial slice or an integer
        sliced_indices = {
            out_ind[axis]
            for axis, idx in enumerate(full_index)
            if axis < len(out_ind) and (isinstance(idx, numbers.Integral) or idx != slice(None))
        }

        # Don't handle if blockwise adds new axes and we're slicing those axes
        new_axes = getattr(self, "new_axes", None)
        if new_axes and sliced_indices & set(new_axes):
            return None

        # Use getattr since subclasses may define as class attribute or property
        adjust_chunks = getattr(self, "adjust_chunks", None)
        if adjust_chunks and sliced_indices & set(adjust_chunks):
            return self._accept_slice_coarse(slice_expr, full_index, adjust_chunks)

        # Integers become length-1 slices so every input keeps its rank.
        # `(idx + 1) or None` keeps idx == -1 from becoming the empty slice(-1, 0).
        slice_index = tuple(
            slice(idx, (idx + 1) or None) if isinstance(idx, numbers.Integral) else idx
            for idx in full_index
        )
        has_integers = any(isinstance(idx, numbers.Integral) for idx in full_index)

        if "array" in type(self)._parameters:
            # Subclasses with a single "array" parameter: output axes map to
            # input dimensions by position
            arg_slices = tuple(
                slice_index[out_ind.index(dim)] if dim in out_ind else slice(None)
                for dim in range(self.array.ndim)
            )
            sliced_input = new_collection(self.array)[arg_slices]
            result = self.substitute_parameters({"array": sliced_input.expr})
        else:
            # Base Blockwise with (arg, index) pairs in args
            new_args = []
            for arg, arg_ind in toolz.partition(2, self.args):
                if arg_ind is None:
                    new_args.extend([arg, arg_ind])
                    continue
                if isinstance(arg, ArrayBlockwiseDep):
                    # Per-block dependency values (e.g. block_info) can't be sliced
                    return None
                arg_slices = []
                for dim, label in enumerate(arg_ind):
                    if label not in out_ind:
                        arg_slices.append(slice(None))  # contracted dimension
                        continue
                    out_pos = out_ind.index(label)
                    if self._is_broadcast_dim(arg, dim, out_pos):
                        arg_slices.append(slice(None))
                    else:
                        arg_slices.append(slice_index[out_pos])
                sliced_arg = new_collection(arg)[tuple(arg_slices)]
                new_args.extend([sliced_arg.expr, arg_ind])
            result = self._rebuild(new_args, self.operand("adjust_chunks"))

        # If we converted integers to slices, extract with [0] to drop those dimensions
        if has_integers:
            from dask_array.slicing import SliceSlicesIntegers

            extract_index = tuple(0 if isinstance(idx, numbers.Integral) else slice(None) for idx in full_index)
            return SliceSlicesIntegers(result, extract_index, slice_expr.allow_getitem_optimization)

        return result

    def _accept_slice_coarse(self, slice_expr, full_index, adjust_chunks):
        """Coarse slice optimization for blockwise with adjust_chunks.

        When chunk sizes change between input and output, we can't push the
        exact slice through. But we CAN select only the input blocks that
        contribute to the needed output blocks.

        Algorithm:
        1. For each sliced axis, find which OUTPUT blocks the slice needs
        2. Map to corresponding INPUT blocks (same block indices for blockwise)
        3. Create block-aligned input slices
        4. Wrap output with adjusted slice if original doesn't align to blocks

        Returns ``None`` (no rewrite) for non-unit steps, empty selections,
        out-of-range integers, unknown chunk sizes, inputs whose blocks don't
        line up with the output, and subclasses storing their input in an
        ``array`` parameter.
        """
        from dask.utils import cached_cumsum
        from dask_array._new_collection import new_collection

        # The rebuild below uses (arg, index) pairs from self.args, which such
        # subclasses don't have.
        if "array" in type(self)._parameters:
            return None

        def find_block_range(cumsum, start, stop):
            """Return (first, last) block indices covering [start, stop); stop > start."""
            block_ends = cumsum[1:]
            # Element e lives in the block whose end is the first one > e
            first = int(np.searchsorted(block_ends, start, side="right"))
            if first >= len(block_ends):
                return None, None  # Out of bounds
            last = int(np.searchsorted(block_ends, stop - 1, side="right"))
            return first, last

        out_ind = self.out_ind
        out_chunks = self.chunks

        # For each output axis, compute block range and output adjustment
        block_ranges = []  # (first_block, last_block) per axis, None = all blocks
        output_adjustments = []  # Index to apply to the coarse output per axis

        for axis, idx in enumerate(full_index):
            chunks = out_chunks[axis]
            if not isinstance(idx, numbers.Integral) and idx == slice(None):
                block_ranges.append(None)
                output_adjustments.append(slice(None))
                continue
            if not isinstance(idx, (numbers.Integral, slice)):
                return None

            dim_size = sum(chunks)
            if np.isnan(dim_size):
                return None  # Unknown chunk sizes
            cumsum = np.array(list(cached_cumsum(chunks, initial_zero=True)))

            if isinstance(idx, numbers.Integral):
                pos_idx = idx + dim_size if idx < 0 else idx
                if not 0 <= pos_idx < dim_size:
                    return None  # Out of bounds; let regular slicing report it
                first, last = find_block_range(cumsum, pos_idx, pos_idx + 1)
                if first is None:
                    return None
                block_ranges.append((first, last))
                output_adjustments.append(int(pos_idx - cumsum[first]))
            else:
                start, stop, step = idx.indices(dim_size)
                if step != 1 or stop <= start:
                    return None  # Non-unit steps and empty selections aren't supported
                first, last = find_block_range(cumsum, start, stop)
                if first is None:
                    return None
                block_ranges.append((first, last))
                coarse_start = int(cumsum[first])
                coarse_end = int(cumsum[last + 1])
                adj_start = start - coarse_start
                adj_stop = stop - coarse_start
                if adj_start == 0 and adj_stop == coarse_end - coarse_start:
                    output_adjustments.append(slice(None))
                else:
                    output_adjustments.append(slice(adj_start, adj_stop))

        # Map output block ranges to input slices
        new_args = []
        for arg, arg_ind in toolz.partition(2, self.args):
            if arg_ind is None:
                new_args.extend([arg, arg_ind])
                continue
            if isinstance(arg, ArrayBlockwiseDep) or not hasattr(arg, "_meta"):
                # Non-array args (e.g., ArrayValuesDep for block_info) can't be sliced
                return None
            arg_slices = []
            for dim, label in enumerate(arg_ind):
                if label not in out_ind:
                    arg_slices.append(slice(None))  # Contracted dimension
                    continue
                out_pos = out_ind.index(label)
                br = block_ranges[out_pos]
                if br is None or self._is_broadcast_dim(arg, dim, out_pos):
                    arg_slices.append(slice(None))
                    continue
                in_chunks = arg.chunks[dim]
                if len(in_chunks) != len(out_chunks[out_pos]):
                    # Inputs not yet unified: output block numbers don't address input blocks
                    return None
                first, last = br
                in_cumsum = cached_cumsum(in_chunks, initial_zero=True)
                arg_slices.append(slice(in_cumsum[first], in_cumsum[last + 1]))
            sliced_arg = new_collection(arg)[tuple(arg_slices)]
            new_args.extend([sliced_arg.expr, arg_ind])

        # Slice adjust_chunks tuples/lists to match the new block ranges
        new_adjust_chunks = self.operand("adjust_chunks")
        if new_adjust_chunks:
            new_adjust_chunks = dict(new_adjust_chunks)  # Copy
            for axis, br in enumerate(block_ranges):
                if br is None:
                    continue
                ind = out_ind[axis]
                val = new_adjust_chunks.get(ind)
                if isinstance(val, (tuple, list)):
                    first, last = br
                    new_adjust_chunks[ind] = val[first : last + 1]

        # Build the new Blockwise with coarse-sliced inputs
        result = self._rebuild(new_args, new_adjust_chunks)

        if any(adj != slice(None) for adj in output_adjustments):
            from dask_array.slicing import SliceSlicesIntegers

            return SliceSlicesIntegers(result, tuple(output_adjustments), slice_expr.allow_getitem_optimization)

        return result
653
+
654
+
655
class Elemwise(Blockwise):
    """Apply ``op`` element-wise to broadcast array (and scalar) operands.

    Operands are the named ``_parameters`` followed by the positional
    element-wise arguments (``elemwise_args``). Array arguments are
    right-aligned for broadcasting: an argument of ndim ``k`` uses the
    symbolic indices ``tuple(range(k)[::-1])``. When ``where`` is not
    ``True`` the op is wrapped in ``_elemwise_handle_where`` and
    ``where``/``out`` are appended as trailing arguments.
    """

    _parameters = ["op", "dtype", "name", "where", "out", "_user_kwargs"]
    _defaults = {
        "dtype": None,
        "name": None,
        "where": True,
        "out": None,
        "_user_kwargs": None,
    }
    align_arrays = True
    new_axes: dict = {}
    adjust_chunks = None
    concatenate = None

    @property
    def user_kwargs(self):
        """Keyword arguments forwarded to ``op`` (an empty dict when unset)."""
        return self.operand("_user_kwargs") or {}

    @cached_property
    def _meta(self):
        # When where is not True, _info[0] is _elemwise_handle_where which
        # expects args to end with (where, out)
        args = list(self.elemwise_args)
        if self.where is not True:
            args.extend([self.where, self.out])
        return compute_meta(self._info[0], self.dtype, *args, **self.kwargs)

    @property
    def elemwise_args(self):
        """Positional operands that follow the named ``_parameters``."""
        return self.operands[len(self._parameters) :]

    def dependencies(self):
        """Return expression dependencies.

        When where is True (the default), 'out' is not actually used in
        the computation - it's just a placeholder for _handle_out to
        replace the expression. Exclude it from dependencies to avoid
        fusion issues, UNLESS out is also an input (e.g., np.sin(x, out=x)).
        """
        deps = super().dependencies()
        if self.where is True and self.out is not None:
            out_name = getattr(self.out, "_name", None)
            # Only exclude if out is not also an input argument
            input_names = {getattr(a, "_name", None) for a in self.elemwise_args if hasattr(a, "_name")}
            if out_name and out_name not in input_names:
                deps = [d for d in deps if d._name != out_name]
        return deps

    @property
    def out_ind(self):
        """Symbolic output indices ``(ndim - 1, ..., 1, 0)`` of the broadcast result.

        Raises ValueError (from ``broadcast_shapes``) if the operand shapes
        cannot be broadcast together.
        """
        shapes = []
        for arg in self.elemwise_args:
            shape = getattr(arg, "shape", ())
            if any(is_dask_collection(x) for x in shape):
                # Want to exclude Delayed shapes and dd.Scalar
                shape = ()
            shapes.append(shape)
        if isinstance(self.where, ArrayExpr):
            shapes.append(self.where.shape)
        if isinstance(self.out, ArrayExpr):
            shapes.append(self.out.shape)

        shapes = [s if isinstance(s, Iterable) else () for s in shapes]
        out_ndim = len(broadcast_shapes(*shapes))  # Raises ValueError if dimensions mismatch
        return tuple(range(out_ndim))[::-1]

    @cached_property
    def _info(self):
        """Return ``(task_function, output_dtype, internal_kwargs)``.

        The task function is ``op`` itself, possibly wrapped by
        ``_elemwise_handle_where`` (when ``where`` is not True) and then by
        ``_enforce_dtype`` (when the result dtype must be cast).

        Raises
        ------
        NotImplementedError
            If no explicit dtype was given and it cannot be inferred from a
            trial call of ``op``.
        """
        if self.operand("dtype") is not None:
            need_enforce_dtype = True
            dtype = np.dtype(self.operand("dtype"))
        else:
            # We follow NumPy's rules for dtype promotion, which special cases
            # scalars and 0d ndarrays (which it considers equivalent) by using
            # their values to compute the result dtype:
            # https://github.com/numpy/numpy/issues/6240
            # We don't inspect the values of 0d dask arrays, because these could
            # hold potentially very expensive calculations. Instead, we treat
            # them just like other arrays, and if necessary cast the result of op
            # to match.
            vals = [
                (np.empty((1,) * max(1, a.ndim), dtype=a.dtype) if not is_scalar_for_elemwise(a) else a)
                for a in self.elemwise_args
            ]
            try:
                dtype = apply_infer_dtype(self.op, vals, self.user_kwargs, "elemwise", suggest_dtype=False)
            except Exception as err:
                # Keep the original failure attached so the root cause is visible.
                raise NotImplementedError(
                    "Could not infer the output dtype of this element-wise operation; pass dtype= explicitly"
                ) from err
            need_enforce_dtype = any(not is_scalar_for_elemwise(a) and a.ndim == 0 for a in self.elemwise_args)

        blockwise_kwargs = {}
        op = self.op
        if self.where is not True:
            blockwise_kwargs["elemwise_where_function"] = op
            op = _elemwise_handle_where

        if need_enforce_dtype:
            blockwise_kwargs.update(
                {
                    "enforce_dtype": dtype,
                    "enforce_dtype_function": op,
                }
            )
            op = _enforce_dtype

        return op, dtype, blockwise_kwargs

    @property
    def func(self):
        return self._info[0]

    @property
    def dtype(self):
        return self._info[1]

    @property
    def kwargs(self):
        # Merge user kwargs with internal kwargs (dtype enforcement, where handling)
        return {**self.user_kwargs, **self._info[2]}

    @property
    def token(self):
        return funcname(self.op).strip("_")

    @property
    def args(self):
        # for Blockwise rather than Elemwise
        # When where is an array, append [where, out] for _elemwise_handle_where
        extra_args = []
        if self.where is not True:
            extra_args.append(self.where)
            extra_args.append(self.out)
        return tuple(
            toolz.concat(
                (
                    a,
                    (tuple(range(a.ndim)[::-1]) if not is_scalar_for_elemwise(a) else None),
                )
                for a in self.elemwise_args + extra_args
            )
        )

    def _lower(self):
        # Override Blockwise._lower to handle Elemwise's different operand structure.
        # Elemwise stores just arrays in operands, but args generates (array, indices) pairs.
        # After unifying chunks, we only pass the unified arrays (not indices) to the constructor.
        if self.align_arrays:
            _, arrays, changed = unify_chunks_expr(*self.args)
            if changed:
                # Only pass the unified arrays, not the indices
                # When where is an array, the last two arrays are where and out
                if self.where is not True:
                    new_elemwise_args = arrays[:-2]
                    new_where = arrays[-2]
                    new_out = arrays[-1]
                else:
                    new_elemwise_args = arrays
                    new_where = True
                    new_out = None
                return Elemwise(
                    self.op,
                    self.operand("dtype"),
                    self.operand("name"),
                    new_where,
                    new_out,
                    self.operand("_user_kwargs"),
                    *new_elemwise_args,
                )

    def _task(self, key, block_id: tuple[int, ...]) -> Task:
        """Generate task for a specific output block.

        Parameters
        ----------
        key : tuple
            The output key for this task (e.g., ('add-abc123', 0, 1))
        block_id : tuple[int, ...]
            The block coordinates (e.g., (0, 1) for block at row 0, col 1)

        Returns
        -------
        Task
            A Task object that computes this block
        """
        args = []

        # Process elemwise_args
        for arg in self.elemwise_args:
            if is_scalar_for_elemwise(arg):
                args.append(arg)
            else:
                # Array argument - compute block_id adjusted for broadcasting
                # For broadcasting: use 0 for dimensions where array has 1 block
                arg_block_id = self._broadcast_block_id(arg, block_id)
                args.append(TaskRef((arg.name, *arg_block_id)))

        # Handle where/out arrays if present
        if self.where is not True:
            if is_scalar_for_elemwise(self.where):
                args.append(self.where)
            else:
                where_block_id = self._broadcast_block_id(self.where, block_id)
                args.append(TaskRef((self.where.name, *where_block_id)))

            if self.out is None or is_scalar_for_elemwise(self.out):
                args.append(self.out)
            else:
                out_block_id = self._broadcast_block_id(self.out, block_id)
                args.append(TaskRef((self.out.name, *out_block_id)))

        # `kwargs` merges two dicts on every access; evaluate it once.
        kwargs = self.kwargs
        if kwargs:
            return Task(key, self.func, *args, **kwargs)
        return Task(key, self.func, *args)

    def _broadcast_block_id(self, arr, block_id: tuple[int, ...]) -> tuple[int, ...]:
        """Adjust block_id for broadcasting."""
        # Resolves to the module-level helper of the same name.
        return _broadcast_block_id(arr.numblocks, block_id)

    def _input_block_id(self, dep, block_id: tuple[int, ...]) -> tuple[int, ...]:
        """Map output block_id to input block_id for a dependency.

        For Elemwise, this handles broadcasting - same block_id adjusted
        for arrays with fewer dimensions or single-block dimensions.
        """
        return self._broadcast_block_id(dep, block_id)

    def _slice_elemwise_input(self, arg, full_index, out_ind):
        """Return ``arg`` sliced by the part of ``full_index`` that applies to it.

        ``arg`` is right-aligned against the output (symbolic indices
        ``tuple(range(arg.ndim)[::-1])``). Size-1 (broadcast) dimensions keep
        their broadcast semantics: slices become ``slice(None)`` unless the
        output slice is empty, and integers become ``0`` (the original index
        may be out of bounds for a size-1 input).
        """
        from numbers import Integral

        from dask_array._new_collection import new_collection

        arg_ind = tuple(range(arg.ndim)[::-1])
        arg_shape = arg.shape

        arg_slices = []
        for i, dim_idx in enumerate(arg_ind):
            try:
                out_pos = out_ind.index(dim_idx)
            except ValueError:
                # Index not in output (shouldn't happen for elemwise)
                arg_slices.append(slice(None))
                continue

            out_slice = full_index[out_pos]
            if arg_shape[i] != 1:
                arg_slices.append(out_slice)
            elif isinstance(out_slice, slice):
                start, stop, step = out_slice.indices(self.shape[out_pos])
                if len(range(start, stop, step)) == 0:
                    # Empty output slice - preserve it
                    arg_slices.append(out_slice)
                else:
                    arg_slices.append(slice(None))
            elif isinstance(out_slice, Integral):
                # Integer index on broadcast dim - use 0
                arg_slices.append(0)
            else:
                arg_slices.append(out_slice)

        return new_collection(arg)[tuple(arg_slices)].expr

    def _accept_slice(self, slice_expr):
        """Accept a slice being pushed through this Elemwise.

        Returns a new Elemwise with the slice pushed to each array input,
        handling broadcasting appropriately. Array-valued ``where`` and
        ``out`` operands are sliced the same way; they take part in the
        output broadcast (see ``out_ind``), so leaving them unsliced would
        produce mismatched shapes.
        """
        out_ind = self.out_ind
        index = slice_expr.index

        # Pad index to full length
        full_index = index + (slice(None),) * (len(out_ind) - len(index))

        new_args = [
            arg if is_scalar_for_elemwise(arg) else self._slice_elemwise_input(arg, full_index, out_ind)
            for arg in self.elemwise_args
        ]

        new_where = self.where
        if isinstance(new_where, ArrayExpr):
            new_where = self._slice_elemwise_input(new_where, full_index, out_ind)

        new_out = self.out
        if isinstance(new_out, ArrayExpr):
            new_out = self._slice_elemwise_input(new_out, full_index, out_ind)

        return Elemwise(
            self.op,
            self.operand("dtype"),
            self.operand("name"),
            new_where,
            new_out,
            self.operand("_user_kwargs"),
            *new_args,
        )

    def _accept_shuffle(self, shuffle_expr):
        """Accept a shuffle being pushed through this Elemwise.

        Push shuffle through by shuffling each input array on the corresponding
        axis, accounting for broadcasting. Inputs that broadcast on the shuffle
        axis (size-1 or fewer dimensions) are not shuffled.
        """
        from dask_array._shuffle import Shuffle

        axis = shuffle_expr.axis
        indexer = shuffle_expr.indexer
        name = shuffle_expr.operand("name")
        output_ndim = len(self.shape)

        def get_input_axis(arg):
            """Get the corresponding axis in input for the output shuffle axis.

            Returns the input axis, or None if the input broadcasts on this axis.
            For broadcasting, input axes are aligned to the right of output axes.
            """
            if is_scalar_for_elemwise(arg):
                return None
            # Input axis = output axis - (dimensions added by broadcasting)
            input_axis = axis - (output_ndim - arg.ndim)
            if input_axis < 0:
                # This input doesn't have the shuffle axis (broadcasts on it)
                return None
            if arg.shape[input_axis] == 1:
                # Size-1 dimensions broadcast, don't shuffle
                return None
            return input_axis

        # Shuffle each array input on its corresponding axis
        new_args = []
        for arg in self.elemwise_args:
            input_axis = get_input_axis(arg)
            if input_axis is not None:
                new_args.append(Shuffle(arg, indexer, input_axis, name))
            else:
                new_args.append(arg)

        # Shuffle where/out if they are arrays
        new_where = self.where
        input_axis = get_input_axis(new_where) if hasattr(new_where, "ndim") else None
        if input_axis is not None:
            new_where = Shuffle(new_where, indexer, input_axis, name)

        new_out = self.out
        input_axis = get_input_axis(new_out) if hasattr(new_out, "ndim") else None
        if input_axis is not None:
            new_out = Shuffle(new_out, indexer, input_axis, name)

        return Elemwise(
            self.op,
            self.operand("dtype"),
            self.operand("name"),
            new_where,
            new_out,
            self.operand("_user_kwargs"),
            *new_args,
        )
1016
+
1017
+
1018
def _broadcast_block_id(numblocks: tuple[int, ...], block_id: tuple[int, ...]) -> tuple[int, ...]:
    """Translate an output block index into the block index of a broadcast input.

    The input's axes are right-aligned against the output, so leading output
    axes the input does not have are ignored. Any input axis split into a
    single block always maps to block 0.
    """
    lead = len(block_id) - len(numblocks)
    return tuple(
        0 if n_blocks == 1 else block_id[lead + axis]
        for axis, n_blocks in enumerate(numblocks)
    )
1039
+
1040
+
1041
def _compute_block_id(ind: tuple, idx_to_block: dict, numblocks: tuple[int, ...]) -> tuple[int, ...]:
    """Resolve a dependency's block coordinates from its symbolic indices.

    Each symbol in ``ind`` that appears in ``idx_to_block`` takes that output
    block, wrapped modulo the dependency's block count on that axis. A symbol
    absent from the output (a contracted axis) is only resolvable when the
    dependency has a single block there, in which case block 0 is used.
    Otherwise a ValueError is raised.
    """

    def _block_for(axis, symbol):
        if symbol in idx_to_block:
            return idx_to_block[symbol] % numblocks[axis]
        if numblocks[axis] == 1:
            return 0
        raise ValueError(
            f"Cannot determine block for index {symbol}: not in output indices "
            f"and input has {numblocks[axis]} blocks in dimension {axis}"
        )

    return tuple(_block_for(axis, symbol) for axis, symbol in enumerate(ind))
1061
+
1062
+
1063
def is_fusable_blockwise(expr):
    """Report whether ``expr`` may take part in blockwise fusion.

    Fusability is opt-in through the ``_is_blockwise_fusable`` attribute
    (set by e.g. Blockwise without concatenate, BroadcastTrick and Random).
    The attribute's value is returned as-is; objects lacking it (including
    ``None``) yield ``False``.
    """
    try:
        return expr._is_blockwise_fusable
    except AttributeError:
        return False
1070
+
1071
+
1072
# Alias for internal use: ``is_fusable_elemwise`` is the very same function
# object as ``is_fusable_blockwise``; the fusion pass in this module calls it
# under this name.
is_fusable_elemwise = is_fusable_blockwise
1074
+
1075
+
1076
def _symbolic_mapping(expr, parent_mapping):
    """Describe, per dependency, which root output axis drives each block axis.

    ``parent_mapping`` is a tuple giving, for each block axis of ``expr``, the
    root output axis it follows; e.g. ``(2, 1)`` means block ``(root[2],
    root[1])``. The result maps each dependency name to the list of such
    tuples under which ``expr`` reads it, so conflicting accesses can be
    detected symbolically without sampling blocks.

    Three cases are handled:

    * Transpose: the parent mapping is permuted by ``_inverse_axes``.
    * Anything exposing ``out_ind`` and ``args`` (Blockwise-like): each
      ``(array, indices)`` pair is translated through ``out_ind``; symbols
      not in ``out_ind`` map to themselves.
    * Anything else: every dependency with the same ndim inherits the
      parent mapping unchanged.
    """
    from dask_array.manipulation._transpose import Transpose

    if isinstance(expr, Transpose):
        # output[i] comes from input[axes[i]], so input axis j follows the
        # parent's axis inverse_axes[j].
        inverse = expr._inverse_axes
        permuted = tuple(parent_mapping[inverse[pos]] for pos in range(len(inverse)))
        source = expr.array
        if not hasattr(source, "_name"):
            return {}
        return {source._name: [permuted]}

    if hasattr(expr, "out_ind") and hasattr(expr, "args"):
        root_of = {
            symbol: (parent_mapping[axis] if axis < len(parent_mapping) else axis)
            for axis, symbol in enumerate(expr.out_ind)
        }
        mappings = {}
        for operand, operand_ind in toolz.partition(2, expr.args):
            if operand_ind is None or not hasattr(operand, "_name"):
                continue
            translated = tuple(root_of.get(symbol, symbol) for symbol in operand_ind)
            mappings.setdefault(operand._name, []).append(translated)
        return mappings

    # e.g. Random: same-rank dependencies share the parent's block layout.
    return {
        dep._name: [parent_mapping]
        for dep in expr.dependencies()
        if hasattr(dep, "_name") and dep.ndim == len(parent_mapping)
    }
1119
+
1120
+
1121
def _remove_conflicting_exprs(group):
    """Drop members of a fusion group that are read with incompatible block layouts.

    If one expression is reached along several paths that index its blocks
    differently (e.g. ``a + a.T``), a single fused output block would need
    different source blocks of it, so it cannot be fused. Starting from the
    root (``group[0]``) with the identity layout, each member's symbolic
    layout (see ``_symbolic_mapping``) is propagated to its in-group
    dependencies; any dependency that receives two different layouts is a
    conflict. Conflicting members are removed, along with anything no longer
    reachable from the root through the remaining members. Relative order of
    the surviving members is preserved, and ``group`` is returned unchanged
    (same object) when there is nothing to remove.
    """
    if len(group) <= 1:
        return group

    by_name = {member._name: member for member in group}
    root = group[0]
    layout_of = {root._name: tuple(range(root.ndim))}
    conflicting = set()

    for member in group:
        own_layout = layout_of.get(member._name)
        if own_layout is None:
            continue
        for dep_name, candidates in _symbolic_mapping(member, own_layout).items():
            if dep_name not in by_name:
                continue
            for candidate in candidates:
                # First layout seen wins; any different later one is a conflict.
                if layout_of.setdefault(dep_name, candidate) != candidate:
                    conflicting.add(dep_name)

    if not conflicting:
        return group

    survivors = set(by_name) - conflicting
    reachable = {root._name}
    frontier = [root]
    while frontier:
        current = frontier.pop()
        for dep in current.dependencies():
            name = dep._name
            if name in survivors and name not in reachable:
                reachable.add(name)
                frontier.append(by_name[name])

    return [member for member in group if member._name in reachable]
1182
+
1183
+
1184
def optimize_blockwise_fusion_array(expr):
    """Repeatedly fuse chains of fusable Blockwise expressions.

    Each pass scans the graph for fusable nodes, picks "roots" (fusable nodes
    without fusable dependents), and grows a group downward from a root
    through dependencies whose every dependent is already in, or queued for,
    the group. The first group that still has more than one member after
    conflict pruning is replaced by a FusedBlockwise. Passes repeat until one
    reports completion or leaves the expression name unchanged.
    """
    from collections import defaultdict

    def _scan(top):
        """Record fusable dependency/dependent edges reachable from ``top``."""
        visited = set()
        pending = [top]
        dependents = defaultdict(set)  # name -> names of nodes that use it
        dependencies = {}  # fusable name -> fusable names it uses
        by_name = {}  # name -> expression

        while pending:
            current = pending.pop()
            if current._name in visited:
                continue
            visited.add(current._name)

            if is_fusable_elemwise(current):
                dependencies[current._name] = set()
                if current._name not in dependents:
                    dependents[current._name] = set()
                    by_name[current._name] = current

            for operand in current.dependencies():
                pending.append(operand)
                if not is_fusable_elemwise(operand):
                    continue
                if current._name in dependencies:
                    dependencies[current._name].add(operand._name)
                # Users are recorded even when not fusable themselves, so a node
                # shared with a non-fusable consumer is never absorbed.
                dependents[operand._name].add(current._name)
                by_name[operand._name] = operand
                by_name[current._name] = current

        return dependents, dependencies, by_name

    def _grow_group(root, dependents, dependencies, by_name, roots):
        """Collect the fusable sub-graph below ``root``; may extend ``roots``."""
        members = []
        visited = set()
        frontier = [root]

        while frontier:
            current = frontier.pop()
            if current._name in visited:
                continue
            visited.add(current._name)
            members.append(current)

            for dep_name in dependencies.get(current._name, set()):
                dep = by_name[dep_name]
                claimed = {f._name for f in frontier} | {m._name for m in members} | {current._name}
                if dependents.get(dep_name, set()) <= claimed:
                    # Every consumer of dep ends up in this group: absorb it.
                    frontier.append(dep)
                elif dependencies.get(dep._name) and dep._name not in {r._name for r in roots}:
                    # Can't fuse dep here, but it may head its own group.
                    roots.append(dep)

        return members

    def _fusion_pass(current_expr):
        dependents, dependencies, by_name = _scan(current_expr)

        roots = [
            by_name[name]
            for name, users in dependents.items()
            if not users or not any(is_fusable_elemwise(by_name.get(user)) for user in users)
        ]

        while roots:
            group = _grow_group(roots.pop(), dependents, dependencies, by_name, roots)
            if len(group) <= 1:
                continue
            group = _remove_conflicting_exprs(group)
            if len(group) > 1:
                fused = FusedBlockwise(tuple(group))
                return current_expr.substitute(group[0], fused), not roots

        # No fusable groups found
        return current_expr, True

    while True:
        name_before = expr._name
        expr, finished = _fusion_pass(expr)
        if finished or expr._name == name_before:
            return expr
1278
+
1279
+
1280
class FusedBlockwise(ArrayExpr):
    """Fused blockwise operations for arrays.

    A FusedBlockwise stands in for several Blockwise/Elemwise-like
    expressions fused into a single Expr. At graph-materialization time each
    output block becomes one task, built with ``Task.fuse`` from the member
    expressions' per-block tasks.

    Parameters
    ----------
    exprs : tuple[Expr, ...]
        Group of original Expr objects being fused together. The first
        expression is the "root" (final output); the others are expected to
        be reachable from it through ``dependencies()`` within the group.

    External dependencies are not passed as operands. They are derived from
    ``exprs`` by ``dependencies()``.
    """

    _parameters = ["exprs"]

    @property
    def _meta(self):
        return self.exprs[0]._meta

    @property
    def chunks(self):
        return self.exprs[0].chunks

    @property
    def dtype(self):
        return self.exprs[0].dtype

    def dependencies(self):
        """Return external dependencies not included in the fused group.

        Duplicates are dropped; order follows first appearance while walking
        ``exprs`` and each member's ``dependencies()``.
        """
        fused_names = {e._name for e in self.exprs}
        external_deps = []
        seen = set()
        for expr in self.exprs:
            for dep in expr.dependencies():
                if dep._name not in fused_names and dep._name not in seen:
                    external_deps.append(dep)
                    seen.add(dep._name)
        return external_deps

    def _layer(self):
        """Return ``{(name, *block_id): Task}`` for every output block."""
        result = {}
        for block_id in product(*[range(n) for n in self.numblocks]):
            key = (self._name, *block_id)
            result[key] = self._task(key, block_id)
        return result

    def _task(self, key, block_id: tuple[int, ...]) -> Task:
        """Generate a fused task for a specific output block."""
        # Compute block_id for each expression by tracing through dependencies
        # Each expression type (Elemwise, Transpose) has its own block mapping
        expr_block_ids = self._compute_block_ids(block_id)

        # Generate tasks in dependency order (leaves first for Task.fuse)
        internal_tasks = []
        for expr in reversed(self.exprs):
            expr_block_id = expr_block_ids[expr._name]
            subname = (expr._name, *expr_block_id)
            t = expr._task(subname, expr_block_id)
            internal_tasks.append(t)
        return Task.fuse(*internal_tasks, key=key)  # type: ignore[return-value]

    def _compute_block_ids(self, output_block_id: tuple[int, ...]) -> dict:
        """Compute block_id for each expression given the output block_id.

        Starting from the root at ``output_block_id``, each member's
        ``_input_block_id`` maps its own block coordinates to those needed
        from each in-group dependency; the first assignment wins.

        Members are visited in ``exprs`` order. ``exprs`` is not guaranteed
        to be topologically ordered, so a member whose block is not yet known
        is retried on a later sweep instead of failing. When every member is
        resolvable in the first sweep, the result is exactly what a single
        in-order sweep produces.

        Raises
        ------
        KeyError
            If some member cannot be reached from the root inside the group.
        """
        expr_names = {e._name for e in self.exprs}
        expr_block_ids = {self.exprs[0]._name: output_block_id}

        pending = list(self.exprs)
        while pending:
            deferred = []
            for expr in pending:
                # Block ids are tuples (possibly empty), never None.
                my_block_id = expr_block_ids.get(expr._name)
                if my_block_id is None:
                    deferred.append(expr)
                    continue
                for dep in expr.dependencies():
                    if dep._name in expr_names and dep._name not in expr_block_ids:
                        expr_block_ids[dep._name] = expr._input_block_id(dep, my_block_id)
            if len(deferred) == len(pending):
                # No progress: the remaining members are unreachable from the root.
                raise KeyError(
                    "FusedBlockwise members unreachable from the root: "
                    f"{[e._name for e in deferred]}"
                )
            pending = deferred

        return expr_block_ids

    def __str__(self):
        names = [expr._name.split("-")[0] for expr in self.exprs]
        if len(names) > 4:
            return f"{names[0]}-fused-{names[-1]}"
        return "-".join(names)

    @cached_property
    def _name(self):
        return f"{self}-{self.deterministic_token}"
1372
+
1373
+
1374
def outer(a, b):
    """
    Compute the outer product of two vectors.

    This docstring was copied from numpy.outer.

    Some inconsistencies with the Dask version may exist.

    Given two vectors, ``a = [a0, a1, ..., aM]`` and
    ``b = [b0, b1, ..., bN]``,
    the outer product is::

      [[a0*b0  a0*b1 ... a0*bN ]
       [a1*b0    .
       [ ...          .
       [aM*b0            aM*bN ]]

    Parameters
    ----------
    a : (M,) array_like
        First input vector. Input is flattened if not already 1-dimensional.
    b : (N,) array_like
        Second input vector. Input is flattened if not already 1-dimensional.

    Returns
    -------
    out : (M, N) ndarray
        ``out[i, j] = a[i] * b[j]``
    """
    from dask_array._collection import asarray, blockwise

    left, right = (asarray(operand).flatten() for operand in (a, b))

    # Probe NumPy with zero-valued scalars to learn the promoted result dtype.
    result_dtype = np.outer(left.dtype.type(), right.dtype.type()).dtype

    return blockwise(np.outer, "ij", left, "i", right, "j", dtype=result_dtype)