dask-array 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (144)
  1. dask_array/__init__.py +228 -0
  2. dask_array/_backends.py +76 -0
  3. dask_array/_backends_array.py +99 -0
  4. dask_array/_blockwise.py +1410 -0
  5. dask_array/_broadcast.py +272 -0
  6. dask_array/_chunk.py +445 -0
  7. dask_array/_chunk_types.py +54 -0
  8. dask_array/_collection.py +1644 -0
  9. dask_array/_concatenate.py +331 -0
  10. dask_array/_core_utils.py +1365 -0
  11. dask_array/_dispatch.py +141 -0
  12. dask_array/_einsum.py +277 -0
  13. dask_array/_expr.py +544 -0
  14. dask_array/_expr_flow.py +586 -0
  15. dask_array/_gufunc.py +805 -0
  16. dask_array/_histogram.py +617 -0
  17. dask_array/_map_blocks.py +652 -0
  18. dask_array/_new_collection.py +10 -0
  19. dask_array/_numpy_compat.py +135 -0
  20. dask_array/_overlap.py +1159 -0
  21. dask_array/_rechunk.py +1050 -0
  22. dask_array/_reshape.py +710 -0
  23. dask_array/_routines.py +102 -0
  24. dask_array/_shuffle.py +448 -0
  25. dask_array/_stack.py +264 -0
  26. dask_array/_svg.py +291 -0
  27. dask_array/_templates.py +29 -0
  28. dask_array/_test_utils.py +257 -0
  29. dask_array/_ufunc.py +385 -0
  30. dask_array/_utils.py +349 -0
  31. dask_array/_visualize.py +223 -0
  32. dask_array/_xarray.py +337 -0
  33. dask_array/core/__init__.py +34 -0
  34. dask_array/core/_blockwise_funcs.py +312 -0
  35. dask_array/core/_conversion.py +422 -0
  36. dask_array/core/_from_graph.py +97 -0
  37. dask_array/creation/__init__.py +71 -0
  38. dask_array/creation/_arange.py +121 -0
  39. dask_array/creation/_diag.py +116 -0
  40. dask_array/creation/_diagonal.py +241 -0
  41. dask_array/creation/_eye.py +103 -0
  42. dask_array/creation/_linspace.py +102 -0
  43. dask_array/creation/_mesh.py +134 -0
  44. dask_array/creation/_ones_zeros.py +454 -0
  45. dask_array/creation/_pad.py +270 -0
  46. dask_array/creation/_repeat.py +55 -0
  47. dask_array/creation/_tile.py +36 -0
  48. dask_array/creation/_tri.py +28 -0
  49. dask_array/creation/_utils.py +296 -0
  50. dask_array/fft.py +320 -0
  51. dask_array/io/__init__.py +39 -0
  52. dask_array/io/_base.py +10 -0
  53. dask_array/io/_from_array.py +257 -0
  54. dask_array/io/_from_delayed.py +95 -0
  55. dask_array/io/_from_graph.py +54 -0
  56. dask_array/io/_from_npy_stack.py +67 -0
  57. dask_array/io/_store.py +336 -0
  58. dask_array/io/_tiledb.py +159 -0
  59. dask_array/io/_to_npy_stack.py +65 -0
  60. dask_array/io/_zarr.py +449 -0
  61. dask_array/linalg/__init__.py +39 -0
  62. dask_array/linalg/_cholesky.py +234 -0
  63. dask_array/linalg/_lu.py +300 -0
  64. dask_array/linalg/_norm.py +94 -0
  65. dask_array/linalg/_qr.py +601 -0
  66. dask_array/linalg/_solve.py +349 -0
  67. dask_array/linalg/_svd.py +394 -0
  68. dask_array/linalg/_tensordot.py +334 -0
  69. dask_array/linalg/_utils.py +74 -0
  70. dask_array/manipulation/__init__.py +45 -0
  71. dask_array/manipulation/_expand.py +321 -0
  72. dask_array/manipulation/_flip.py +92 -0
  73. dask_array/manipulation/_roll.py +78 -0
  74. dask_array/manipulation/_transpose.py +309 -0
  75. dask_array/random/__init__.py +125 -0
  76. dask_array/random/_choice.py +181 -0
  77. dask_array/random/_expr.py +256 -0
  78. dask_array/random/_generator.py +441 -0
  79. dask_array/random/_random_state.py +259 -0
  80. dask_array/random/_utils.py +84 -0
  81. dask_array/reductions/__init__.py +84 -0
  82. dask_array/reductions/_arg_reduction.py +130 -0
  83. dask_array/reductions/_common.py +1082 -0
  84. dask_array/reductions/_cumulative.py +522 -0
  85. dask_array/reductions/_percentile.py +261 -0
  86. dask_array/reductions/_reduction.py +725 -0
  87. dask_array/reductions/_trace.py +56 -0
  88. dask_array/routines/__init__.py +133 -0
  89. dask_array/routines/_apply.py +84 -0
  90. dask_array/routines/_bincount.py +112 -0
  91. dask_array/routines/_broadcast.py +111 -0
  92. dask_array/routines/_coarsen.py +115 -0
  93. dask_array/routines/_diff.py +79 -0
  94. dask_array/routines/_gradient.py +158 -0
  95. dask_array/routines/_indexing.py +65 -0
  96. dask_array/routines/_insert_delete.py +132 -0
  97. dask_array/routines/_misc.py +122 -0
  98. dask_array/routines/_nonzero.py +72 -0
  99. dask_array/routines/_search.py +123 -0
  100. dask_array/routines/_select.py +113 -0
  101. dask_array/routines/_statistics.py +171 -0
  102. dask_array/routines/_topk.py +82 -0
  103. dask_array/routines/_triangular.py +74 -0
  104. dask_array/routines/_unique.py +232 -0
  105. dask_array/routines/_where.py +62 -0
  106. dask_array/slicing/__init__.py +67 -0
  107. dask_array/slicing/_basic.py +550 -0
  108. dask_array/slicing/_blocks.py +138 -0
  109. dask_array/slicing/_bool_index.py +145 -0
  110. dask_array/slicing/_setitem.py +329 -0
  111. dask_array/slicing/_squeeze.py +101 -0
  112. dask_array/slicing/_utils.py +1133 -0
  113. dask_array/slicing/_vindex.py +282 -0
  114. dask_array/stacking/__init__.py +15 -0
  115. dask_array/stacking/_block.py +83 -0
  116. dask_array/stacking/_simple.py +58 -0
  117. dask_array/templates/array.html.j2 +48 -0
  118. dask_array/tests/__init__.py +0 -0
  119. dask_array/tests/conftest.py +22 -0
  120. dask_array/tests/test_api.py +40 -0
  121. dask_array/tests/test_binary_op_chunks.py +107 -0
  122. dask_array/tests/test_coarse_slice_through_blockwise.py +362 -0
  123. dask_array/tests/test_collection.py +799 -0
  124. dask_array/tests/test_creation.py +1102 -0
  125. dask_array/tests/test_expr_flow.py +143 -0
  126. dask_array/tests/test_linalg.py +1130 -0
  127. dask_array/tests/test_map_blocks_multi_output.py +104 -0
  128. dask_array/tests/test_rechunk_pushdown.py +214 -0
  129. dask_array/tests/test_reductions.py +1091 -0
  130. dask_array/tests/test_routines.py +2853 -0
  131. dask_array/tests/test_shuffle_chunks.py +67 -0
  132. dask_array/tests/test_slice_pushdown.py +968 -0
  133. dask_array/tests/test_slice_through_blockwise.py +678 -0
  134. dask_array/tests/test_slice_through_overlap.py +366 -0
  135. dask_array/tests/test_slice_through_reshape.py +272 -0
  136. dask_array/tests/test_slicing.py +839 -0
  137. dask_array/tests/test_transpose_slice_pushdown.py +208 -0
  138. dask_array/tests/test_visualize.py +94 -0
  139. dask_array/tests/test_xarray.py +193 -0
  140. dask_array-0.1.0.dist-info/METADATA +48 -0
  141. dask_array-0.1.0.dist-info/RECORD +144 -0
  142. dask_array-0.1.0.dist-info/WHEEL +4 -0
  143. dask_array-0.1.0.dist-info/entry_points.txt +2 -0
  144. dask_array-0.1.0.dist-info/licenses/LICENSE +29 -0
dask_array/_xarray.py ADDED
@@ -0,0 +1,337 @@
1
+ """
2
+ xarray ChunkManager integration for dask-array expressions.
3
+
4
+ This module registers a ChunkManagerEntrypoint under the entry point name
5
+ "dask" — the same name used by xarray's built-in DaskManager. We *must*
6
+ replace the built-in rather than coexist alongside it because:
7
+
8
+ 1. ``dask_array.Array`` is a dask collection (implements ``__dask_graph__``)
9
+ and a duck array, so xarray's built-in ``DaskManager.is_chunked_array``
10
+ recognises it via ``is_duck_dask_array``.
11
+ 2. If two managers both claim the same array type, xarray's
12
+ ``get_chunked_array_type`` raises
13
+ ``"Multiple ChunkManagers recognise type ..."``.
14
+ 3. Therefore only one "dask"-flavoured manager can be active at a time.
15
+
16
+ Because both xarray and dask-array register an entry point named "dask",
17
+ the winner of ``importlib.metadata.entry_points()`` iteration is
18
+ non-deterministic (it depends on filesystem enumeration order). To make
19
+ the result reproducible, ``_ensure_registered`` mutates the cached dict
20
+ returned by ``list_chunkmanagers()`` at import time so that our manager
21
+ is always the one stored under the "dask" key.
22
+ """
23
+
24
+ from __future__ import annotations
25
+
26
+ from collections.abc import Callable, Iterable, Mapping, Sequence
27
+ from typing import TYPE_CHECKING, Any
28
+
29
+ import numpy as np
30
+ from xarray.namedarray.parallelcompat import ChunkManagerEntrypoint
31
+
32
+ if TYPE_CHECKING:
33
+ from dask_array._collection import Array
34
+
35
+
36
class DaskArrayExprManager(ChunkManagerEntrypoint["Array"]):
    """
    Chunk manager that exposes ``dask_array.Array`` to xarray.

    Each method forwards to the corresponding ``dask_array`` (or ``dask``)
    routine. Those routines are imported inside the method bodies rather than
    at module scope (presumably to avoid import cycles -- TODO confirm).
    """

    # Concrete array class this manager recognises; assigned in ``__init__``.
    array_cls: type[Array]
    available: bool = True

    def __init__(self) -> None:
        # The module-level import of ``Array`` is TYPE_CHECKING-only, so the
        # real class has to be fetched at runtime here.
        from dask_array._collection import Array as _array_type

        self.array_cls = _array_type

    def is_chunked_array(self, data: Any) -> bool:
        """Return True when ``data`` is an instance of ``self.array_cls``."""
        return isinstance(data, self.array_cls)

    def chunks(self, data: Array) -> tuple[tuple[int, ...], ...]:
        """Return the ``chunks`` attribute of ``data``."""
        return data.chunks

    def normalize_chunks(
        self,
        chunks: Any,
        shape: tuple[int, ...] | None = None,
        limit: int | None = None,
        dtype: np.dtype[Any] | None = None,
        previous_chunks: tuple[tuple[int, ...], ...] | None = None,
    ) -> tuple[tuple[int, ...], ...]:
        """Forward to ``dask_array._core_utils.normalize_chunks``."""
        from dask_array._core_utils import normalize_chunks as _normalize

        return _normalize(
            chunks,
            shape=shape,
            limit=limit,
            dtype=dtype,
            previous_chunks=previous_chunks,
        )

    def from_array(
        self,
        data: Any,
        chunks: Any,
        **kwargs: Any,
    ) -> Array:
        """Wrap ``data`` with ``dask_array.from_array``.

        When ``data`` is xarray's ``ImplicitToExplicitIndexingAdapter`` the
        ``meta`` keyword is forced to ``np.ndarray`` before forwarding.
        """
        import dask_array as da
        from xarray.core.indexing import ImplicitToExplicitIndexingAdapter

        lazily_indexed = isinstance(data, ImplicitToExplicitIndexingAdapter)
        if lazily_indexed:
            # Lazily loaded backend arrays should use NumPy arrays for meta.
            kwargs.update(meta=np.ndarray)

        return da.from_array(data, chunks, **kwargs)

    def rechunk(
        self,
        data: Array,
        chunks: Any,
        **kwargs: Any,
    ) -> Array:
        """Forward to ``data.rechunk``."""
        return data.rechunk(chunks, **kwargs)

    def compute(
        self,
        *data: Array | Any,
        **kwargs: Any,
    ) -> tuple[np.ndarray[Any, Any], ...]:
        """Forward all positional and keyword arguments to ``dask.compute``."""
        from dask import compute as _compute

        return _compute(*data, **kwargs)

    def persist(
        self,
        *data: Array | Any,
        **kwargs: Any,
    ) -> tuple[Array | Any, ...]:
        """Forward all positional and keyword arguments to ``dask.persist``."""
        from dask import persist as _persist

        return _persist(*data, **kwargs)

    @property
    def array_api(self) -> Any:
        """The ``dask_array`` package itself."""
        import dask_array

        return dask_array

    def reduction(
        self,
        arr: Array,
        func: Callable[..., Any],
        combine_func: Callable[..., Any] | None = None,
        aggregate_func: Callable[..., Any] | None = None,
        axis: int | Sequence[int] | None = None,
        dtype: np.dtype[Any] | None = None,
        keepdims: bool = False,
    ) -> Array:
        """Forward to ``dask_array.reduction``.

        ``func``, ``combine_func`` and ``aggregate_func`` are passed as the
        ``chunk``, ``combine`` and ``aggregate`` keywords respectively.
        """
        from dask_array import reduction as _reduction

        return _reduction(
            arr,
            chunk=func,
            combine=combine_func,
            aggregate=aggregate_func,
            axis=axis,
            dtype=dtype,
            keepdims=keepdims,
        )

    def scan(
        self,
        func: Callable[..., Any],
        binop: Callable[..., Any],
        ident: float,
        arr: Array,
        axis: int | None = None,
        dtype: np.dtype[Any] | None = None,
        **kwargs: Any,
    ) -> Array:
        """Forward to ``dask_array.cumreduction``."""
        from dask_array import cumreduction as _cumreduction

        return _cumreduction(
            func,
            binop,
            ident,
            arr,
            axis=axis,
            dtype=dtype,
            **kwargs,
        )

    def apply_gufunc(
        self,
        func: Callable[..., Any],
        signature: str,
        *args: Any,
        axes: Sequence[tuple[int, ...]] | None = None,
        axis: int | None = None,
        keepdims: bool = False,
        output_dtypes: Sequence[np.dtype[Any]] | None = None,
        output_sizes: dict[str, int] | None = None,
        vectorize: bool | None = None,
        allow_rechunk: bool = False,
        meta: tuple[np.ndarray[Any, Any], ...] | None = None,
        **kwargs: Any,
    ) -> Any:
        """Forward to ``dask_array.apply_gufunc`` with all options unchanged."""
        from dask_array import apply_gufunc as _apply_gufunc

        return _apply_gufunc(
            func,
            signature,
            *args,
            axes=axes,
            axis=axis,
            keepdims=keepdims,
            output_dtypes=output_dtypes,
            output_sizes=output_sizes,
            vectorize=vectorize,
            allow_rechunk=allow_rechunk,
            meta=meta,
            **kwargs,
        )

    def map_blocks(
        self,
        func: Callable[..., Any],
        *args: Any,
        dtype: np.dtype[Any] | None = None,
        chunks: tuple[int, ...] | None = None,
        drop_axis: int | Sequence[int] | None = None,
        new_axis: int | Sequence[int] | None = None,
        **kwargs: Any,
    ) -> Any:
        """Forward to ``dask_array.map_blocks`` with all options unchanged."""
        from dask_array import map_blocks as _map_blocks

        return _map_blocks(
            func,
            *args,
            dtype=dtype,
            chunks=chunks,
            drop_axis=drop_axis,
            new_axis=new_axis,
            **kwargs,
        )

    def map_blocks_multi_output(
        self,
        func: Callable[..., Any],
        input_exprs: Sequence[Any],
        input_indices: Sequence[Iterable[Any]],
        shared_indices: Iterable[Any],
        block_specs: Mapping[tuple[int, ...], Any],
        outputs: Sequence[Mapping[str, Any]],
        *,
        token: str,
    ) -> list[Array]:
        """Forward to ``dask_array._map_blocks.map_blocks_multi_output``."""
        from dask_array._map_blocks import (
            map_blocks_multi_output as _map_blocks_multi_output,
        )

        return _map_blocks_multi_output(
            func,
            input_exprs,
            input_indices,
            shared_indices,
            block_specs,
            outputs,
            token=token,
        )

    def blockwise(
        self,
        func: Callable[..., Any],
        out_ind: Iterable[Any],
        *args: Any,
        name: str | None = None,
        token: Any | None = None,
        dtype: np.dtype[Any] | None = None,
        adjust_chunks: dict[Any, Callable[..., Any]] | None = None,
        new_axes: dict[Any, int] | None = None,
        align_arrays: bool = True,
        concatenate: bool | None = None,
        meta: tuple[np.ndarray[Any, Any], ...] | None = None,
        **kwargs: Any,
    ) -> Array:
        """Forward to ``dask_array.blockwise`` with all options unchanged."""
        from dask_array import blockwise as _blockwise

        return _blockwise(
            func,
            out_ind,
            *args,
            name=name,
            token=token,
            dtype=dtype,
            adjust_chunks=adjust_chunks,
            new_axes=new_axes,
            align_arrays=align_arrays,
            concatenate=concatenate,
            meta=meta,
            **kwargs,
        )

    def unify_chunks(
        self,
        *args: Any,
        **kwargs: Any,
    ) -> tuple[dict[str, tuple[tuple[int, ...], ...]], list[Array]]:
        """Forward to ``dask_array.unify_chunks``."""
        from dask_array import unify_chunks as _unify_chunks

        return _unify_chunks(*args, **kwargs)

    def store(
        self,
        sources: Array | Sequence[Array],
        targets: Any,
        **kwargs: Any,
    ) -> Any:
        """Forward to ``dask_array.store``, passing both arguments by keyword."""
        from dask_array import store as _store

        return _store(
            sources=sources,
            targets=targets,
            **kwargs,
        )

    def shuffle(
        self,
        x: Array,
        indexer: list[list[int]],
        axis: int,
        chunks: Any,
    ) -> Array:
        """Forward to ``dask_array.shuffle`` using automatic chunking.

        ``chunks`` may be ``None`` or ``"auto"``; any other value raises
        ``NotImplementedError``.
        """
        from dask_array import shuffle as _shuffle

        # None is treated as a request for "auto"; nothing else is accepted.
        if chunks is not None and chunks != "auto":
            raise NotImplementedError("Only chunks='auto' is supported at present.")
        return _shuffle(x, indexer, axis, chunks="auto")

    def get_auto_chunk_size(self) -> int:
        """Return dask's ``array.chunk-size`` config value parsed by ``parse_bytes``."""
        from dask import config as dask_config
        from dask.utils import parse_bytes

        configured = dask_config.get("array.chunk-size")
        return parse_bytes(configured)
319
+
320
+
321
def _ensure_registered() -> None:
    """Install a DaskArrayExprManager under xarray's "dask" chunk-manager key.

    The mapping returned by ``list_chunkmanagers()`` is inspected and, unless
    its "dask" entry is already a ``DaskArrayExprManager``, that entry is
    overwritten with a fresh instance. This only has a lasting effect if
    ``list_chunkmanagers`` hands back the same cached dict on later calls
    (presumably it does -- verify against the installed xarray version).

    Does nothing when ``list_chunkmanagers`` cannot be imported.
    """
    try:
        from xarray.namedarray.parallelcompat import (
            list_chunkmanagers as _list_chunkmanagers,
        )
    except ImportError:
        return

    registry = _list_chunkmanagers()
    current = registry.get("dask")
    if isinstance(current, DaskArrayExprManager):
        # Already ours; leave the existing instance in place.
        return
    registry["dask"] = DaskArrayExprManager()
@@ -0,0 +1,34 @@
1
+ """Core array types and wrapping functions.
2
+
3
+ This module re-exports the core Array class and conversion functions.
4
+ """
5
+
6
+ from dask_array.core._blockwise_funcs import blockwise, elemwise
7
+ from dask_array.core._conversion import (
8
+ array,
9
+ asanyarray,
10
+ asarray,
11
+ from_array,
12
+ )
13
+ from dask_array.core._from_graph import from_graph
14
+
15
+
16
def __getattr__(name):
    """Module-level attribute hook that resolves ``Array`` on demand.

    ``Array`` is imported from ``dask_array._collection`` only when it is
    requested (lazily, to avoid circular imports). Every other name raises
    ``AttributeError``.
    """
    if name != "Array":
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

    from dask_array._collection import Array

    return Array
23
+
24
+
25
+ __all__ = [
26
+ "Array",
27
+ "from_array",
28
+ "from_graph",
29
+ "asarray",
30
+ "asanyarray",
31
+ "array",
32
+ "blockwise",
33
+ "elemwise",
34
+ ]
@@ -0,0 +1,312 @@
1
+ """Blockwise and elemwise function wrappers."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import numpy as np
6
+ import toolz
7
+
8
+ from dask_array._new_collection import new_collection
9
+ from dask_array._blockwise import Blockwise, Elemwise
10
+ from dask_array._core_utils import is_scalar_for_elemwise
11
+
12
+
13
def blockwise(
    func,
    out_ind,
    *args,
    name=None,
    token=None,
    dtype=None,
    adjust_chunks=None,
    new_axes=None,
    align_arrays=True,
    concatenate=None,
    meta=None,
    **kwargs,
):
    """Tensor operation: Generalized inner and outer products

    A broad class of blocked algorithms and patterns can be specified with a
    concise multi-index notation.  The ``blockwise`` function applies an in-memory
    function across multiple blocks of multiple inputs in a variety of ways.
    Many dask.array operations are special cases of blockwise including
    elementwise, broadcasting, reductions, tensordot, and transpose.

    Parameters
    ----------
    func : callable
        Function to apply to individual tuples of blocks
    out_ind : iterable
        Block pattern of the output, something like 'ijk' or (1, 2, 3)
    *args : sequence of Array, index pairs
        You may also pass literal arguments, accompanied by None index
        e.g. (x, 'ij', y, 'jk', z, 'i', some_literal, None)
    **kwargs : dict
        Extra keyword arguments to pass to function
    name, token : optional
        Passed through unchanged to the underlying ``Blockwise`` expression.
    dtype : np.dtype
        Datatype of resulting array.
    concatenate : bool, keyword only
        If true concatenate arrays along dummy indices, else provide lists
    adjust_chunks : dict
        Dictionary mapping index to function to be applied to chunk sizes
    new_axes : dict, keyword only
        New indexes and their dimension lengths
    align_arrays: bool
        Whether or not to align chunks along equally sized dimensions when
        multiple arrays are provided.  This allows for larger chunks in some
        arrays to be broken into smaller ones that match chunk sizes in other
        arrays such that they are compatible for block function mapping. If
        this is false, then an error will be thrown if arrays do not already
        have the same number of blocks in each dimension.
    meta : optional
        Passed through unchanged to the underlying ``Blockwise`` expression.

    Raises
    ------
    ValueError
        If the positional arguments after ``out_ind`` do not form complete
        (array, index) pairs, if ``out_ind`` contains repeated elements, or
        if ``out_ind`` uses an index that appears neither in any input index
        nor in ``new_axes``.

    Examples
    --------
    2D embarrassingly parallel operation from two arrays, x, and y.

    >>> import operator, numpy as np, dask.array as da
    >>> x = da.from_array([[1, 2],
    ...                    [3, 4]], chunks=(1, 2))
    >>> y = da.from_array([[10, 20],
    ...                    [0, 0]])
    >>> z = blockwise(operator.add, 'ij', x, 'ij', y, 'ij', dtype='f8')
    >>> z.compute()
    array([[11, 22],
           [ 3,  4]])

    Outer product multiplying a by b, two 1-d vectors

    >>> a = da.from_array([0, 1, 2], chunks=1)
    >>> b = da.from_array([10, 50, 100], chunks=1)
    >>> z = blockwise(np.outer, 'ij', a, 'i', b, 'j', dtype='f8')
    >>> z.compute()
    array([[  0,   0,   0],
           [ 10,  50, 100],
           [ 20, 100, 200]])

    z = x.T

    >>> z = blockwise(np.transpose, 'ji', x, 'ij', dtype=x.dtype)
    >>> z.compute()
    array([[1, 3],
           [2, 4]])

    The transpose case above is illustrative because it does transposition
    both on each in-memory block by calling ``np.transpose`` and on the order
    of the blocks themselves, by switching the order of the index ``ij -> ji``.

    We can compose these same patterns with more variables and more complex
    in-memory functions

    z = X + Y.T

    >>> z = blockwise(lambda x, y: x + y.T, 'ij', x, 'ij', y, 'ji', dtype='f8')
    >>> z.compute()
    array([[11,  2],
           [23,  4]])

    Any index, like ``i`` missing from the output index is interpreted as a
    contraction (note that this differs from Einstein convention; repeated
    indices do not imply contraction.)  In the case of a contraction the passed
    function should expect an iterable of blocks on any array that holds that
    index.  To receive arrays concatenated along contracted dimensions instead
    pass ``concatenate=True``.

    Inner product multiplying a by b, two 1-d vectors

    >>> def sequence_dot(a_blocks, b_blocks):
    ...     result = 0
    ...     for a, b in zip(a_blocks, b_blocks):
    ...         result += a.dot(b)
    ...     return result

    >>> z = blockwise(sequence_dot, '', a, 'i', b, 'i', dtype='f8')
    >>> z.compute()
    250

    Add new single-chunk dimensions with the ``new_axes=`` keyword, including
    the length of the new dimension.  New dimensions will always be in a single
    chunk.

    >>> def f(a):
    ...     return a[:, None] * np.ones((1, 5))

    >>> z = blockwise(f, 'az', a, 'a', new_axes={'z': 5}, dtype=a.dtype)

    New dimensions can also be multi-chunk by specifying a tuple of chunk
    sizes.  This has limited utility as is (because the chunks are all the
    same), but the resulting graph can be modified to achieve more useful
    results (see ``da.map_blocks``).

    >>> z = blockwise(f, 'az', a, 'a', new_axes={'z': (5, 5)}, dtype=x.dtype)
    >>> z.chunks
    ((1, 1, 1), (5, 5))

    If the applied function changes the size of each chunk you can specify this
    with a ``adjust_chunks={...}`` dictionary holding a function for each index
    that modifies the dimension size in that index.

    >>> def double(x):
    ...     return np.concatenate([x, x])

    >>> y = blockwise(double, 'ij', x, 'ij',
    ...               adjust_chunks={'i': lambda n: 2 * n}, dtype=x.dtype)
    >>> y.chunks
    ((2, 2), (2,))

    Include literals by indexing with None

    >>> z = blockwise(operator.add, 'ij', x, 'ij', 1234, None, dtype=x.dtype)
    >>> z.compute()
    array([[1235, 1236],
           [1237, 1238]])
    """
    # Positional arguments must come in (value, index) pairs.  Without this
    # check ``toolz.partition(2, args)`` below would silently discard a
    # trailing unpaired argument instead of reporting the mistake.
    if len(args) % 2:
        raise ValueError(
            "blockwise expects (array, index) pairs after out_ind; "
            f"got an odd number of positional arguments ({len(args)})"
        )

    new_axes = new_axes or {}

    # Normalize dtype to numpy dtype (handles cases like dtype=float)
    if dtype is not None:
        dtype = np.dtype(dtype)

    # Input Validation
    if len(set(out_ind)) != len(out_ind):
        raise ValueError(
            "Repeated elements not allowed in output index",
            [k for k, v in toolz.frequencies(out_ind).items() if v > 1],
        )
    # Every output index must come from some input index or from new_axes.
    input_indices = {a for ind in args[1::2] if ind is not None for a in ind}
    new = set(out_ind) - input_indices - set(new_axes)
    if new:
        raise ValueError("Unknown dimension", new)

    # Imported lazily, and only once the cheap validation above has passed.
    from dask_array.core import asanyarray

    # Convert scalars with empty tuple index to 0-d dask arrays
    # This mirrors what traditional unify_chunks does
    normalized_args = []
    for arg, ind in toolz.partition(2, args):
        if ind == () and not hasattr(arg, "chunks"):
            arg = asanyarray(arg)
        normalized_args.extend([arg, ind])

    return new_collection(
        Blockwise(
            func,
            out_ind,
            name,
            token,
            dtype,
            adjust_chunks,
            new_axes,
            align_arrays,
            concatenate,
            meta,
            kwargs,
            *normalized_args,
        )
    )
205
+
206
+
207
def elemwise(op, *args, out=None, where=True, dtype=None, name=None, **kwargs):
    """Apply a ufunc-like callable elementwise, block by block.

    Broadcasting follows the same rules as numpy ufuncs.

    Parameters
    ----------
    op : callable
        The function to apply. Should accept parameters the way a numpy
        ufunc does.
    *args : Any
        Operands for `op`. Lists and tuples are first turned into numpy
        arrays; every operand that ``is_scalar_for_elemwise`` does not
        classify as a scalar is then converted with ``asanyarray``. Scalars
        are forwarded untouched.
    out : dask array, optional
        When given, the result replaces the contents of this array and the
        same array object is returned.
    where : array_like, optional
        Boolean mask selecting where the ufunc is applied, mirroring the
        ``where`` argument of numpy ufuncs (see e.g. ``numpy.add``). ``True``
        is kept as is, ``False`` and ``None`` both become ``False``, and
        anything else is converted with ``asanyarray``.
    dtype : dtype, optional
        If provided, overrides the output array dtype.
    name : str, optional
        A unique key name to use when building the backing dask graph. If not
        provided, one will be automatically generated based on the input
        arguments.
    **kwargs : dict
        Additional keyword arguments to pass to `op`.

    Examples
    --------
    >>> elemwise(add, x, y)  # doctest: +SKIP
    >>> elemwise(sin, x)  # doctest: +SKIP
    >>> elemwise(sin, x, out=dask_array)  # doctest: +SKIP

    See Also
    --------
    blockwise
    """
    # Lazy import to avoid circular dependency
    from dask_array.core import asanyarray

    # Normalize where: only a non-True value needs any rewriting.
    if where is not True:
        if where is False or where is None:
            where = False
        else:
            where = asanyarray(where)

    out = _normalize_out(out)

    # First pass: plain Python sequences become numpy arrays.
    as_arrays = []
    for operand in args:
        if isinstance(operand, (list, tuple)):
            operand = np.asarray(operand)
        as_arrays.append(operand)

    # Second pass: wrap everything except scalars as dask arrays.
    # Scalars are kept as-is to preserve proper dtype behavior
    # (e.g., 2.0 * float32_array = float32)
    operands = []
    for operand in as_arrays:
        if not is_scalar_for_elemwise(operand):
            operand = asanyarray(operand)
        operands.append(operand)

    # An empty kwargs mapping is represented as None on the expression.
    user_kwargs = dict(kwargs) or None

    expr = Elemwise(op, dtype, name, where, out, user_kwargs, *operands)
    return _handle_out(out, new_collection(expr))
275
+
276
+
277
def _normalize_out(out):
    """Reduce an ``out`` argument to either ``None`` or a single dask Array.

    A tuple is unwrapped first: an empty tuple means "no out", a one-element
    tuple yields its element, and a longer tuple is rejected. Whatever is
    left must be ``None`` or an ``Array``.

    Raises
    ------
    NotImplementedError
        For tuples with more than one element, or for any non-``None`` value
        that is not an ``Array``.
    """
    from dask_array._collection import Array

    if isinstance(out, tuple):
        n_out = len(out)
        if n_out > 1:
            raise NotImplementedError("The out parameter is not fully supported")
        out = out[0] if n_out == 1 else None

    if out is not None and not isinstance(out, Array):
        raise NotImplementedError(
            f"The out parameter is not fully supported. Received type {type(out).__name__}, expected Dask Array"
        )
    return out
293
+
294
+
295
def _handle_out(out, result):
    """Write ``result`` into ``out`` when ``out`` is a dask Array.

    If ``out`` is not an ``Array``, ``result`` is returned unchanged.
    Otherwise the shapes must match; ``out`` then takes over ``result``'s
    expression and ``out`` itself is returned.

    Raises
    ------
    ValueError
        If ``out`` is an ``Array`` whose shape differs from ``result``'s.
    """
    from dask_array._collection import Array

    if not isinstance(out, Array):
        return result

    if out.shape != result.shape:
        raise ValueError(
            f"Mismatched shapes between result and out parameter. out={out.shape}, result={result.shape}"
        )

    # "In-place" update: swap the expression held by the out array.
    out._expr = result._expr
    return out