dask-array 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (144)
  1. dask_array/__init__.py +228 -0
  2. dask_array/_backends.py +76 -0
  3. dask_array/_backends_array.py +99 -0
  4. dask_array/_blockwise.py +1410 -0
  5. dask_array/_broadcast.py +272 -0
  6. dask_array/_chunk.py +445 -0
  7. dask_array/_chunk_types.py +54 -0
  8. dask_array/_collection.py +1644 -0
  9. dask_array/_concatenate.py +331 -0
  10. dask_array/_core_utils.py +1365 -0
  11. dask_array/_dispatch.py +141 -0
  12. dask_array/_einsum.py +277 -0
  13. dask_array/_expr.py +544 -0
  14. dask_array/_expr_flow.py +586 -0
  15. dask_array/_gufunc.py +805 -0
  16. dask_array/_histogram.py +617 -0
  17. dask_array/_map_blocks.py +652 -0
  18. dask_array/_new_collection.py +10 -0
  19. dask_array/_numpy_compat.py +135 -0
  20. dask_array/_overlap.py +1159 -0
  21. dask_array/_rechunk.py +1050 -0
  22. dask_array/_reshape.py +710 -0
  23. dask_array/_routines.py +102 -0
  24. dask_array/_shuffle.py +448 -0
  25. dask_array/_stack.py +264 -0
  26. dask_array/_svg.py +291 -0
  27. dask_array/_templates.py +29 -0
  28. dask_array/_test_utils.py +257 -0
  29. dask_array/_ufunc.py +385 -0
  30. dask_array/_utils.py +349 -0
  31. dask_array/_visualize.py +223 -0
  32. dask_array/_xarray.py +337 -0
  33. dask_array/core/__init__.py +34 -0
  34. dask_array/core/_blockwise_funcs.py +312 -0
  35. dask_array/core/_conversion.py +422 -0
  36. dask_array/core/_from_graph.py +97 -0
  37. dask_array/creation/__init__.py +71 -0
  38. dask_array/creation/_arange.py +121 -0
  39. dask_array/creation/_diag.py +116 -0
  40. dask_array/creation/_diagonal.py +241 -0
  41. dask_array/creation/_eye.py +103 -0
  42. dask_array/creation/_linspace.py +102 -0
  43. dask_array/creation/_mesh.py +134 -0
  44. dask_array/creation/_ones_zeros.py +454 -0
  45. dask_array/creation/_pad.py +270 -0
  46. dask_array/creation/_repeat.py +55 -0
  47. dask_array/creation/_tile.py +36 -0
  48. dask_array/creation/_tri.py +28 -0
  49. dask_array/creation/_utils.py +296 -0
  50. dask_array/fft.py +320 -0
  51. dask_array/io/__init__.py +39 -0
  52. dask_array/io/_base.py +10 -0
  53. dask_array/io/_from_array.py +257 -0
  54. dask_array/io/_from_delayed.py +95 -0
  55. dask_array/io/_from_graph.py +54 -0
  56. dask_array/io/_from_npy_stack.py +67 -0
  57. dask_array/io/_store.py +336 -0
  58. dask_array/io/_tiledb.py +159 -0
  59. dask_array/io/_to_npy_stack.py +65 -0
  60. dask_array/io/_zarr.py +449 -0
  61. dask_array/linalg/__init__.py +39 -0
  62. dask_array/linalg/_cholesky.py +234 -0
  63. dask_array/linalg/_lu.py +300 -0
  64. dask_array/linalg/_norm.py +94 -0
  65. dask_array/linalg/_qr.py +601 -0
  66. dask_array/linalg/_solve.py +349 -0
  67. dask_array/linalg/_svd.py +394 -0
  68. dask_array/linalg/_tensordot.py +334 -0
  69. dask_array/linalg/_utils.py +74 -0
  70. dask_array/manipulation/__init__.py +45 -0
  71. dask_array/manipulation/_expand.py +321 -0
  72. dask_array/manipulation/_flip.py +92 -0
  73. dask_array/manipulation/_roll.py +78 -0
  74. dask_array/manipulation/_transpose.py +309 -0
  75. dask_array/random/__init__.py +125 -0
  76. dask_array/random/_choice.py +181 -0
  77. dask_array/random/_expr.py +256 -0
  78. dask_array/random/_generator.py +441 -0
  79. dask_array/random/_random_state.py +259 -0
  80. dask_array/random/_utils.py +84 -0
  81. dask_array/reductions/__init__.py +84 -0
  82. dask_array/reductions/_arg_reduction.py +130 -0
  83. dask_array/reductions/_common.py +1082 -0
  84. dask_array/reductions/_cumulative.py +522 -0
  85. dask_array/reductions/_percentile.py +261 -0
  86. dask_array/reductions/_reduction.py +725 -0
  87. dask_array/reductions/_trace.py +56 -0
  88. dask_array/routines/__init__.py +133 -0
  89. dask_array/routines/_apply.py +84 -0
  90. dask_array/routines/_bincount.py +112 -0
  91. dask_array/routines/_broadcast.py +111 -0
  92. dask_array/routines/_coarsen.py +115 -0
  93. dask_array/routines/_diff.py +79 -0
  94. dask_array/routines/_gradient.py +158 -0
  95. dask_array/routines/_indexing.py +65 -0
  96. dask_array/routines/_insert_delete.py +132 -0
  97. dask_array/routines/_misc.py +122 -0
  98. dask_array/routines/_nonzero.py +72 -0
  99. dask_array/routines/_search.py +123 -0
  100. dask_array/routines/_select.py +113 -0
  101. dask_array/routines/_statistics.py +171 -0
  102. dask_array/routines/_topk.py +82 -0
  103. dask_array/routines/_triangular.py +74 -0
  104. dask_array/routines/_unique.py +232 -0
  105. dask_array/routines/_where.py +62 -0
  106. dask_array/slicing/__init__.py +67 -0
  107. dask_array/slicing/_basic.py +550 -0
  108. dask_array/slicing/_blocks.py +138 -0
  109. dask_array/slicing/_bool_index.py +145 -0
  110. dask_array/slicing/_setitem.py +329 -0
  111. dask_array/slicing/_squeeze.py +101 -0
  112. dask_array/slicing/_utils.py +1133 -0
  113. dask_array/slicing/_vindex.py +282 -0
  114. dask_array/stacking/__init__.py +15 -0
  115. dask_array/stacking/_block.py +83 -0
  116. dask_array/stacking/_simple.py +58 -0
  117. dask_array/templates/array.html.j2 +48 -0
  118. dask_array/tests/__init__.py +0 -0
  119. dask_array/tests/conftest.py +22 -0
  120. dask_array/tests/test_api.py +40 -0
  121. dask_array/tests/test_binary_op_chunks.py +107 -0
  122. dask_array/tests/test_coarse_slice_through_blockwise.py +362 -0
  123. dask_array/tests/test_collection.py +799 -0
  124. dask_array/tests/test_creation.py +1102 -0
  125. dask_array/tests/test_expr_flow.py +143 -0
  126. dask_array/tests/test_linalg.py +1130 -0
  127. dask_array/tests/test_map_blocks_multi_output.py +104 -0
  128. dask_array/tests/test_rechunk_pushdown.py +214 -0
  129. dask_array/tests/test_reductions.py +1091 -0
  130. dask_array/tests/test_routines.py +2853 -0
  131. dask_array/tests/test_shuffle_chunks.py +67 -0
  132. dask_array/tests/test_slice_pushdown.py +968 -0
  133. dask_array/tests/test_slice_through_blockwise.py +678 -0
  134. dask_array/tests/test_slice_through_overlap.py +366 -0
  135. dask_array/tests/test_slice_through_reshape.py +272 -0
  136. dask_array/tests/test_slicing.py +839 -0
  137. dask_array/tests/test_transpose_slice_pushdown.py +208 -0
  138. dask_array/tests/test_visualize.py +94 -0
  139. dask_array/tests/test_xarray.py +193 -0
  140. dask_array-0.1.0.dist-info/METADATA +48 -0
  141. dask_array-0.1.0.dist-info/RECORD +144 -0
  142. dask_array-0.1.0.dist-info/WHEEL +4 -0
  143. dask_array-0.1.0.dist-info/entry_points.txt +2 -0
  144. dask_array-0.1.0.dist-info/licenses/LICENSE +29 -0
@@ -0,0 +1,652 @@
1
+ from __future__ import annotations
2
+
3
+ from functools import cached_property
4
+ from itertools import product
5
+ from numbers import Number
6
+ import operator
7
+
8
+ import numpy as np
9
+ from toolz import concat
10
+
11
+ from dask_array._collection import Array, blockwise
12
+ from dask_array._core_utils import _pass_extra_kwargs, apply_and_enforce, apply_infer_dtype
13
+ from dask_array._utils import compute_meta
14
+ from dask_array._expr import ArrayExpr
15
+ from dask_array._new_collection import new_collection
16
+ from dask._task_spec import Task, TaskRef
17
+ from dask.layers import ArrayBlockIdDep, ArrayBlockwiseDep, ArrayValuesDep
18
+ from dask.utils import cached_cumsum, funcname, has_keyword
19
+
20
+
21
def _block_info_dict(argpairs, out, out_ind, drop_axis):
    """Build the ``block_info`` mapping injected into ``map_blocks`` functions.

    Parameters
    ----------
    argpairs : list of (object, tuple or None)
        ``(argument, index)`` pairs exactly as passed to ``blockwise``.
    out : Array
        The output array; its chunk structure defines the output block ids.
    out_ind : tuple
        Output index labels, aligned with ``out.chunks``.
    drop_axis : list of int
        Normalised dropped axes.  When non-empty, input axes whose labels are
        absent from ``out_ind`` are concatenated before ``func`` runs, so they
        are described as a single chunk spanning the whole axis.

    Returns
    -------
    dict
        Maps each output block id to ``{arg_position: info, None: out_info}``.
        Only :class:`Array` arguments receive an ``arg_position`` entry.
    """
    starts = {}
    num_chunks = {}
    shapes = {}

    for i, (arg, in_ind) in enumerate(argpairs):
        # Other indexed arguments (e.g. ArrayBlockwiseDep) carry chunk
        # structure but no ``shape``, so they cannot be described here.
        if in_ind is None or not isinstance(arg, Array):
            continue
        shapes[i] = arg.shape
        if drop_axis:
            # Dropped axes are concatenated, so treat them as a single chunk.
            starts[i] = [
                (cached_cumsum(arg.chunks[j], initial_zero=True) if ind in out_ind else [0, arg.shape[j]])
                for j, ind in enumerate(in_ind)
            ]
            num_chunks[i] = tuple(len(s) - 1 for s in starts[i])
        else:
            starts[i] = [cached_cumsum(c, initial_zero=True) for c in arg.chunks]
            num_chunks[i] = arg.numblocks
    out_starts = [cached_cumsum(c, initial_zero=True) for c in out.chunks]

    block_info_dict = {}
    for block_id in product(*(range(len(c)) for c in out.chunks)):
        # Position of this output chunk, keyed by axis label.
        location = dict(zip(out_ind, block_id))
        info = {}
        for i, shape in shapes.items():
            # Broadcasting is not tracked explicitly: any input dimension with
            # a single chunk is treated as broadcast and pinned to chunk 0.
            arr_k = tuple(
                location.get(ind, 0) if num_chunks[i][j] > 1 else 0 for j, ind in enumerate(argpairs[i][1])
            )
            info[i] = {
                "shape": shape,
                "num-chunks": num_chunks[i],
                "array-location": [(starts[i][ij][j], starts[i][ij][j + 1]) for ij, j in enumerate(arr_k)],
                "chunk-location": arr_k,
            }

        info[None] = {
            "shape": out.shape,
            "num-chunks": out.numblocks,
            "array-location": [(out_starts[ij][j], out_starts[ij][j + 1]) for ij, j in enumerate(block_id)],
            "chunk-location": block_id,
            "chunk-shape": tuple(out.chunks[ij][j] for ij, j in enumerate(block_id)),
            "dtype": out.dtype,
        }
        block_info_dict[block_id] = info
    return block_info_dict


def map_blocks(
    func,
    *args,
    name=None,
    token=None,
    dtype=None,
    chunks=None,
    drop_axis=None,
    new_axis=None,
    enforce_ndim=False,
    meta=None,
    **kwargs,
):
    """Map a function across all blocks of a dask array.

    Note that ``map_blocks`` will attempt to automatically determine the output
    array type by calling ``func`` on 0-d versions of the inputs. Please refer to
    the ``meta`` keyword argument below if you expect that the function will not
    succeed when operating on 0-d arrays.

    Parameters
    ----------
    func : callable
        Function to apply to every block in the array.
        If ``func`` accepts ``block_info=`` or ``block_id=``
        as keyword arguments, these will be passed dictionaries
        containing information about input and output chunks/arrays
        during computation. See examples for details.
    args : dask arrays or other objects
    dtype : np.dtype, optional
        The ``dtype`` of the output array. It is recommended to provide this.
        If neither ``dtype`` nor ``meta`` is provided, it will be inferred by
        applying the function to a small set of fake data.
    chunks : tuple, optional
        Chunk shape of resulting blocks if the function does not preserve
        shape. If not provided, the resulting array is assumed to have the same
        block structure as the first input array.
    drop_axis : number or iterable, optional
        Dimensions lost by the function.
    new_axis : number or iterable, optional
        New dimensions created by the function. Note that these are applied
        after ``drop_axis`` (if present). The size of each chunk along this
        dimension will be set to 1. Please specify ``chunks`` if the individual
        chunks have a different size.
    enforce_ndim : bool, default False
        Whether to enforce at runtime that the dimensionality of the array
        produced by ``func`` actually matches that of the array returned by
        ``map_blocks``.
        If True, this will raise an error when there is a mismatch.
    token : string, optional
        The key prefix to use for the output array; the full key name is
        generated from it. If not provided, will be determined from the
        function name. Ignored for naming when ``name`` is given.
    name : string, optional
        The key name to use for the output array. Note that this fully
        specifies the output key name, and must be unique. If not provided,
        will be determined by a hash of the arguments.
    meta : array-like, optional
        The ``meta`` of the output array, when specified is expected to be an
        array of the same type and dtype of that returned when calling ``.compute()``
        on the array returned by this function. When not provided, ``meta`` will be
        inferred by applying the function to a small set of fake data, usually a
        0-d array. It's important to ensure that ``func`` can successfully complete
        computation without raising exceptions when 0-d is passed to it, providing
        ``meta`` will be required otherwise. If the output type is known beforehand
        (e.g., ``np.ndarray``, ``cupy.ndarray``), an empty array of that type and
        dtype can be passed, for example: ``meta=np.array((), dtype=np.int32)``.
    **kwargs :
        Other keyword arguments to pass to function. Values must be constants
        (not dask.arrays)

    See Also
    --------
    dask.array.map_overlap : Generalized operation with overlap between neighbors.
    dask.array.blockwise : Generalized operation with control over block alignment.

    Examples
    --------
    >>> import dask_array as da
    >>> x = da.arange(6, chunks=3)

    >>> x.map_blocks(lambda x: x * 2).compute()
    array([ 0, 2, 4, 6, 8, 10])

    The ``da.map_blocks`` function can also accept multiple arrays.

    >>> d = da.arange(5, chunks=2)
    >>> e = da.arange(5, chunks=2)

    >>> f = da.map_blocks(lambda a, b: a + b**2, d, e)
    >>> f.compute()
    array([ 0, 2, 6, 12, 20])

    If the function changes shape of the blocks then you must provide chunks
    explicitly.

    >>> y = x.map_blocks(lambda x: x[::2], chunks=((2, 2),))

    You have a bit of freedom in specifying chunks. If all of the output chunk
    sizes are the same, you can provide just that chunk size as a single tuple.

    >>> a = da.arange(18, chunks=(6,))
    >>> b = a.map_blocks(lambda x: x[:3], chunks=(3,))

    If the function changes the dimension of the blocks you must specify the
    created or destroyed dimensions.

    >>> b = a.map_blocks(lambda x: x[None, :, None], chunks=(1, 6, 1),
    ...                  new_axis=[0, 2])

    If ``chunks`` is specified but ``new_axis`` is not, then it is inferred to
    add the necessary number of axes on the left.

    Note that ``map_blocks()`` will concatenate chunks along axes specified by
    the keyword parameter ``drop_axis`` prior to applying the function.
    This is illustrated in the figure below:

    .. image:: /images/map_blocks_drop_axis.png

    Due to memory-size-constraints, it is often not advisable to use ``drop_axis``
    on an axis that is chunked. In that case, it is better not to use
    ``map_blocks`` but rather
    ``dask.array.reduction(..., axis=dropped_axes, concatenate=False)`` which
    maintains a leaner memory footprint while it drops any axis.

    ``map_blocks`` aligns blocks by block positions without regard to shape. In
    the following example we have two arrays with the same number of blocks but
    with different shape and chunk sizes.

    >>> x = da.arange(1000, chunks=(100,))
    >>> y = da.arange(100, chunks=(10,))

    The relevant attribute to match is numblocks.

    >>> x.numblocks
    (10,)
    >>> y.numblocks
    (10,)

    If these match (up to broadcasting rules) then we can map arbitrary
    functions across blocks

    >>> def func(a, b):
    ...     return np.array([a.max(), b.max()])

    >>> da.map_blocks(func, x, y, chunks=(2,), dtype='i8')
    dask.array<func, shape=(20,), dtype=int64, chunksize=(2,), chunktype=numpy.ndarray>

    >>> _.compute()
    array([ 99, 9, 199, 19, 299, 29, 399, 39, 499, 49, 599, 59, 699,
            69, 799, 79, 899, 89, 999, 99])

    Your block function can get information about where it is in the array by
    accepting a special ``block_info`` or ``block_id`` keyword argument.
    During computation, they will contain information about each of the input
    and output chunks (and dask arrays) relevant to each call of ``func``.

    >>> def func(block_info=None):
    ...     pass

    This will receive the following information:

    >>> block_info  # doctest: +SKIP
    {0: {'shape': (1000,),
         'num-chunks': (10,),
         'chunk-location': (4,),
         'array-location': [(400, 500)]},
     None: {'shape': (1000,),
            'num-chunks': (10,),
            'chunk-location': (4,),
            'array-location': [(400, 500)],
            'chunk-shape': (100,),
            'dtype': dtype('float64')}}

    The keys to the ``block_info`` dictionary indicate which is the input and
    output Dask array:

    - **Input Dask array(s):** ``block_info[0]`` refers to the first input Dask array.
      The dictionary key is ``0`` because that is the argument index corresponding
      to the first input Dask array.
      In cases where multiple Dask arrays have been passed as input to the function,
      you can access them with the number corresponding to the input argument,
      eg: ``block_info[1]``, ``block_info[2]``, etc.
      (Note that if you pass multiple Dask arrays as input to map_blocks,
      the arrays must match each other by having matching numbers of chunks,
      along corresponding dimensions up to broadcasting rules.)
    - **Output Dask array:** ``block_info[None]`` refers to the output Dask array,
      and contains information about the output chunks.
      The output chunk shape and dtype may be different than the input chunks.

    For each dask array, ``block_info`` describes:

    - ``shape``: the shape of the full Dask array,
    - ``num-chunks``: the number of chunks of the full array in each dimension,
    - ``chunk-location``: the chunk location (for example the fourth chunk over
      in the first dimension), and
    - ``array-location``: the array location within the full Dask array
      (for example the slice corresponding to ``40:50``).

    In addition to these, there are two extra parameters described by
    ``block_info`` for the output array (in ``block_info[None]``):

    - ``chunk-shape``: the output chunk shape, and
    - ``dtype``: the output dtype.

    These features can be combined to synthesize an array from scratch, for
    example:

    >>> def func(block_info=None):
    ...     loc = block_info[None]['array-location'][0]
    ...     return np.arange(loc[0], loc[1])

    >>> da.map_blocks(func, chunks=((4, 4),), dtype=np.float64)
    dask.array<func, shape=(8,), dtype=float64, chunksize=(4,), chunktype=numpy.ndarray>

    >>> _.compute()
    array([0, 1, 2, 3, 4, 5, 6, 7])

    ``block_id`` is similar to ``block_info`` but contains only the ``chunk-location``:

    >>> def func(block_id=None):
    ...     pass

    This will receive the following information:

    >>> block_id  # doctest: +SKIP
    (4, 3)

    You may specify the exact key name of the resulting array in the graph with
    the optional ``name`` keyword argument (or only its prefix with ``token``).

    >>> x.map_blocks(lambda x: x + 1, name='increment')
    dask.array<increment, shape=(1000,), dtype=int64, chunksize=(100,), chunktype=numpy.ndarray>

    For functions that may not handle 0-d arrays, it's also possible to specify
    ``meta`` with an empty array matching the type of the expected result. In
    the example below, ``func`` will result in an ``IndexError`` when computing
    ``meta``:

    >>> rng = da.random.default_rng()
    >>> da.map_blocks(lambda x: x[2], rng.random(5), meta=np.array(()))
    dask.array<lambda, shape=(5,), dtype=float64, chunksize=(5,), chunktype=numpy.ndarray>

    Similarly, it's possible to specify a non-NumPy array to ``meta``, and provide
    a ``dtype``:

    >>> import cupy  # doctest: +SKIP
    >>> rng = da.random.default_rng(cupy.random.default_rng())  # doctest: +SKIP
    >>> dt = np.float32
    >>> da.map_blocks(lambda x: x[2], rng.random(5, dtype=dt), meta=cupy.array((), dtype=dt))  # doctest: +SKIP
    dask.array<lambda, shape=(5,), dtype=float32, chunksize=(5,), chunktype=cupy.ndarray>
    """
    if drop_axis is None:
        drop_axis = []

    if not callable(func):
        msg = (
            "First argument must be callable function, not %s\n"
            "Usage: da.map_blocks(function, x)\n"
            " or: da.map_blocks(function, x, y, z)"
        )
        raise TypeError(msg % type(func).__name__)

    # ``name`` is used verbatim as the output key name; ``token`` is only a
    # prefix from which blockwise generates the name.  ``token`` must not be
    # folded into ``name``: that would make it an exact name, and two calls
    # sharing a token would then produce colliding keys.
    token_prefix = f"{name or token or funcname(func)}"
    new_axes = {}

    if isinstance(drop_axis, Number):
        drop_axis = [drop_axis]
    if isinstance(new_axis, Number):
        new_axis = [new_axis]

    def get_argpair(a):
        # Arrays and ArrayBlockwiseDep objects get reversed axis labels so
        # trailing dimensions line up (numpy-style broadcasting); any other
        # argument is passed through unindexed.
        if isinstance(a, Array):
            return (a, tuple(range(a.ndim))[::-1])
        if isinstance(a, ArrayBlockwiseDep):
            return (a, tuple(range(len(a.numblocks)))[::-1])
        return (a, None)

    argpairs = [get_argpair(a) for a in args]
    arrs = [a for a in args if isinstance(a, Array)]
    out_ind = tuple(range(max(a.ndim for a in arrs)))[::-1] if arrs else ()

    if dtype is None and meta is None:
        # Best effort: ``func`` may legitimately fail on fake 0-d inputs.
        try:
            meta = compute_meta(func, dtype, *args, **kwargs)
        except Exception:
            pass
        # Infer only when the caller supplied neither ``dtype`` nor ``meta``;
        # an explicit value must win (``func`` may not run on fake data).
        dtype = apply_infer_dtype(func, args, kwargs, "map_blocks")

    if drop_axis:
        ndim_out = len(out_ind)
        if any(i < -ndim_out or i >= ndim_out for i in drop_axis):
            raise ValueError(f"drop_axis out of range (drop_axis={drop_axis}, but output is {ndim_out}d).")
        drop_axis = [i % ndim_out for i in drop_axis]
        out_ind = tuple(x for i, x in enumerate(out_ind) if i not in drop_axis)
    if new_axis is None and chunks is not None and len(out_ind) < len(chunks):
        # More chunk dimensions than output dimensions: add them on the left.
        new_axis = range(len(chunks) - len(out_ind))
    if new_axis:
        out_ind = list(out_ind)
        for ax in sorted(new_axis):
            # Label larger than every input label, so it never aliases one.
            n = len(out_ind) + len(drop_axis)
            out_ind.insert(ax, n)
            new_axes[n] = chunks[ax] if chunks is not None else 1
        out_ind = tuple(out_ind)
        if max(new_axis) > max(out_ind):
            raise ValueError("New_axis values do not fill in all dimensions")

    # Synthesize meta from dtype only now, once the final output
    # dimensionality (after drop_axis/new_axis) is known.
    if meta is None and dtype is not None:
        meta = np.empty((0,) * len(out_ind), dtype=dtype)

    if chunks is not None:
        if len(chunks) != len(out_ind):
            raise ValueError(f"Provided chunks have {len(chunks)} dims; expected {len(out_ind)} dims")
        adjust_chunks = dict(zip(out_ind, chunks))
    else:
        adjust_chunks = None

    # Concatenation is only needed when some input index is contracted
    # (absent from the output); leaving it off otherwise enables fusion.
    out_ind_set = set(out_ind)
    needs_concatenate = any(i not in out_ind_set for _, ind in argpairs if ind is not None for i in ind)

    if enforce_ndim:
        block_func = apply_and_enforce
        enforce_kwargs = {"expected_ndim": len(out_ind), "_func": func}
    else:
        block_func = func
        enforce_kwargs = {}

    out = blockwise(
        block_func,
        out_ind,
        *concat(argpairs),
        **enforce_kwargs,
        name=name,
        token=token_prefix,
        new_axes=new_axes,
        dtype=dtype,
        concatenate=needs_concatenate,
        align_arrays=False,
        adjust_chunks=adjust_chunks,
        meta=meta,
        **kwargs,
    )

    extra_argpairs = []
    extra_names = []

    # If func has block_id as an argument, construct an object to inject it.
    if has_keyword(func, "block_id"):
        extra_argpairs.append((ArrayBlockIdDep(out.chunks), out_ind))
        extra_names.append("block_id")

    if has_keyword(func, "_overlap_trim_info"):
        # Internal for map overlap to reduce size of graph
        num_chunks = out.numblocks
        block_id_dict = {block_id: (block_id, num_chunks) for block_id in product(*(range(len(c)) for c in out.chunks))}
        extra_argpairs.append((ArrayValuesDep(out.chunks, block_id_dict), out_ind))
        extra_names.append("_overlap_trim_info")

    # If func has block_info as an argument, inject a per-block info dict.
    if has_keyword(func, "block_info"):
        block_info_dict = _block_info_dict(argpairs, out, out_ind, drop_axis)
        extra_argpairs.append((ArrayValuesDep(out.chunks, block_info_dict), out_ind))
        extra_names.append("block_info")

    if extra_argpairs:
        # Rebuild the layer so block_info/block_id reach ``func``.  Wrap
        # ``block_func`` (not ``func``) so ``enforce_ndim`` still applies, and
        # keep honouring an explicit ``name``.
        out = blockwise(
            _pass_extra_kwargs,
            out_ind,
            block_func,
            None,
            tuple(extra_names),
            None,
            *concat(extra_argpairs),
            *concat(argpairs),
            **enforce_kwargs,
            name=name,
            token=token_prefix,
            dtype=out.dtype,
            concatenate=needs_concatenate,
            align_arrays=False,
            adjust_chunks=dict(zip(out_ind, out.chunks)),
            meta=meta,
            **kwargs,
        )

    # If output is DataFrame-like, create a DataFrame expression directly
    # instead of returning an Array with DataFrame blocks
    from dask.utils import is_dataframe_like, is_index_like, is_series_like

    if meta is not None and (is_dataframe_like(meta) or is_series_like(meta) or is_index_like(meta)):
        try:
            from dask.dataframe.dask_expr._array import MapBlocksToDataFrame
            from dask.dataframe.dask_expr._collection import new_collection as new_dd_collection
        except ImportError:
            # dask.dataframe not available: keep the Array of DataFrame blocks.
            return out

        if extra_argpairs:
            # Function wrapped with block_info/block_id injection
            actual_func = _pass_extra_kwargs
            expr_args = [func, None, tuple(extra_names), None]
            for dep, ind in extra_argpairs:
                expr_args.extend([dep, ind])
        else:
            actual_func = func
            expr_args = []
        for arg, ind in argpairs:
            expr_args.extend([arg.expr if isinstance(arg, Array) else arg, ind])

        return new_dd_collection(
            MapBlocksToDataFrame(
                actual_func,
                meta,
                token_prefix,
                out_ind,
                tuple(expr_args),
                kwargs or None,
            )
        )

    return out
529
+
530
+
531
class MapBlocksOutput(ArrayExpr):
    """One array projected out of a shared multi-output block function.

    For every shared block that is needed, a task keyed
    ``(shared_name, *shared_block_id)`` calls
    ``func(block_specs[shared_block_id], *input_blocks)``.  Each block of
    this expression is then ``operator.getitem(shared_result, output_key)``.

    Operands follow ``_parameters``; any operands after them are the input
    expressions whose blocks are passed to ``func``.
    """

    _parameters = [
        "func",
        "output_key",
        "output_indices",
        "chunks",
        "dtype",
        "_meta_provided",
        "name",
        "shared_name",
        "shared_indices",
        "input_indices",
        "block_specs",
    ]

    @property
    def _is_blockwise_fusable(self):
        # Opt out of blockwise fusion.
        return False

    @cached_property
    def input_exprs(self):
        """Input expressions: every operand after the named parameters."""
        n_params = len(self._parameters)
        return self.operands[n_params:]

    @cached_property
    def chunks(self):
        return self.operand("chunks")

    @cached_property
    def _meta(self):
        provided = self.operand("_meta_provided")
        if provided is None:
            # No meta given: an empty array of the right rank and dtype.
            return np.empty((0,) * len(self.chunks), dtype=self.dtype)
        return provided

    @cached_property
    def dtype(self):
        return np.dtype(self.operand("dtype"))

    @cached_property
    def _name(self):
        return self.operand("name")

    def _shared_block_id(self, output_block_id):
        """Translate an output block id to a shared block id.

        Shared dimensions that this output does not carry are pinned to 0.
        """
        position = {dim: idx for dim, idx in zip(self.operand("output_indices"), output_block_id)}
        return tuple(position.get(dim, 0) for dim in self.operand("shared_indices"))

    def _input_block_id(self, input_indices, shared_block_id):
        """Translate a shared block id to the block id of one input."""
        position = dict(zip(self.operand("shared_indices"), shared_block_id))
        return tuple(map(position.__getitem__, input_indices))

    def _layer(self):
        graph = {}
        shared_ids = set()
        shared_name = self.operand("shared_name")
        projection_key = self.operand("output_key")

        # One projection (getitem) task per output block.
        for block_id in product(*(range(len(dim_chunks)) for dim_chunks in self.chunks)):
            sid = self._shared_block_id(block_id)
            shared_ids.add(sid)
            key = (self._name, *block_id)
            graph[key] = Task(key, operator.getitem, TaskRef((shared_name, *sid)), projection_key)

        # One shared task per distinct shared block actually referenced.
        specs = self.operand("block_specs")
        per_input_indices = self.operand("input_indices")
        fn = self.operand("func")
        for sid in shared_ids:
            key = (shared_name, *sid)
            task_args = [specs[sid]]
            task_args.extend(
                TaskRef((expr._name, *self._input_block_id(idx, sid)))
                for expr, idx in zip(self.input_exprs, per_input_indices, strict=True)
            )
            graph[key] = Task(key, fn, *task_args)

        return graph
611
+
612
+
613
def map_blocks_multi_output(
    func,
    input_exprs,
    input_indices,
    shared_indices,
    block_specs,
    outputs,
    *,
    token,
):
    """Create arrays projected from one shared block function.

    ``func`` receives ``(block_spec, *input_blocks)`` and returns a mapping.
    Each entry in ``outputs`` must provide ``key``, ``indices``, ``chunks``,
    ``dtype``, and optionally ``meta`` and ``name``.

    Parameters
    ----------
    func : callable
        Shared block function; its result is indexed by each output's ``key``.
    input_exprs : iterable of expressions
        Input expressions; each is lowered completely before use.
    input_indices : iterable of iterables
        Index labels for each input expression, in the same order.
    shared_indices : iterable
        Index labels of the shared block grid.
    block_specs : mapping or iterable of (shared_block_id, spec) pairs
        Per-shared-block first argument passed to ``func``.
    outputs : iterable of dict
        Output descriptions (see above).
    token : str
        Name prefix; shared tasks are named ``f"{token}-shared"``.

    Returns
    -------
    list of Array
        One array per entry of ``outputs``, in order.

    Raises
    ------
    ValueError
        If ``input_indices`` and ``input_exprs`` have different lengths.
    """
    shared_name = f"{token}-shared"
    input_exprs = [expr.lower_completely() for expr in input_exprs]
    # Materialise the shared arguments once.  They used to be re-read for
    # every output, so one-shot iterables (e.g. generators) left every
    # output after the first with empty indices/specs.
    shared_indices = tuple(shared_indices)
    input_indices = tuple(tuple(indices) for indices in input_indices)
    block_specs = dict(block_specs)
    if len(input_indices) != len(input_exprs):
        raise ValueError(
            f"input_indices has {len(input_indices)} entries but {len(input_exprs)} input expressions were given"
        )

    arrays = []
    for output in outputs:
        name = output.get("name") or f"{output['key']}-{token}"
        arrays.append(
            new_collection(
                MapBlocksOutput(
                    func,
                    output["key"],
                    tuple(output["indices"]),
                    tuple(map(tuple, output["chunks"])),
                    np.dtype(output["dtype"]),
                    output.get("meta"),
                    name,
                    shared_name,
                    shared_indices,
                    input_indices,
                    # Each output keeps its own copy, as before.
                    dict(block_specs),
                    *input_exprs,
                )
            )
        )
    return arrays
@@ -0,0 +1,10 @@
1
+ """Minimal module for new_collection to avoid circular imports."""
2
+
3
+ from __future__ import annotations
4
+
5
+
6
def new_collection(expr):
    """Wrap an array expression in an ``Array`` collection.

    ``Array`` is imported at call time rather than at module import time so
    this module stays free of circular imports.
    """
    from dask_array._collection import Array as _Array

    collection = _Array(expr)
    return collection