dask-array 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (144) hide show
  1. dask_array/__init__.py +228 -0
  2. dask_array/_backends.py +76 -0
  3. dask_array/_backends_array.py +99 -0
  4. dask_array/_blockwise.py +1410 -0
  5. dask_array/_broadcast.py +272 -0
  6. dask_array/_chunk.py +445 -0
  7. dask_array/_chunk_types.py +54 -0
  8. dask_array/_collection.py +1644 -0
  9. dask_array/_concatenate.py +331 -0
  10. dask_array/_core_utils.py +1365 -0
  11. dask_array/_dispatch.py +141 -0
  12. dask_array/_einsum.py +277 -0
  13. dask_array/_expr.py +544 -0
  14. dask_array/_expr_flow.py +586 -0
  15. dask_array/_gufunc.py +805 -0
  16. dask_array/_histogram.py +617 -0
  17. dask_array/_map_blocks.py +652 -0
  18. dask_array/_new_collection.py +10 -0
  19. dask_array/_numpy_compat.py +135 -0
  20. dask_array/_overlap.py +1159 -0
  21. dask_array/_rechunk.py +1050 -0
  22. dask_array/_reshape.py +710 -0
  23. dask_array/_routines.py +102 -0
  24. dask_array/_shuffle.py +448 -0
  25. dask_array/_stack.py +264 -0
  26. dask_array/_svg.py +291 -0
  27. dask_array/_templates.py +29 -0
  28. dask_array/_test_utils.py +257 -0
  29. dask_array/_ufunc.py +385 -0
  30. dask_array/_utils.py +349 -0
  31. dask_array/_visualize.py +223 -0
  32. dask_array/_xarray.py +337 -0
  33. dask_array/core/__init__.py +34 -0
  34. dask_array/core/_blockwise_funcs.py +312 -0
  35. dask_array/core/_conversion.py +422 -0
  36. dask_array/core/_from_graph.py +97 -0
  37. dask_array/creation/__init__.py +71 -0
  38. dask_array/creation/_arange.py +121 -0
  39. dask_array/creation/_diag.py +116 -0
  40. dask_array/creation/_diagonal.py +241 -0
  41. dask_array/creation/_eye.py +103 -0
  42. dask_array/creation/_linspace.py +102 -0
  43. dask_array/creation/_mesh.py +134 -0
  44. dask_array/creation/_ones_zeros.py +454 -0
  45. dask_array/creation/_pad.py +270 -0
  46. dask_array/creation/_repeat.py +55 -0
  47. dask_array/creation/_tile.py +36 -0
  48. dask_array/creation/_tri.py +28 -0
  49. dask_array/creation/_utils.py +296 -0
  50. dask_array/fft.py +320 -0
  51. dask_array/io/__init__.py +39 -0
  52. dask_array/io/_base.py +10 -0
  53. dask_array/io/_from_array.py +257 -0
  54. dask_array/io/_from_delayed.py +95 -0
  55. dask_array/io/_from_graph.py +54 -0
  56. dask_array/io/_from_npy_stack.py +67 -0
  57. dask_array/io/_store.py +336 -0
  58. dask_array/io/_tiledb.py +159 -0
  59. dask_array/io/_to_npy_stack.py +65 -0
  60. dask_array/io/_zarr.py +449 -0
  61. dask_array/linalg/__init__.py +39 -0
  62. dask_array/linalg/_cholesky.py +234 -0
  63. dask_array/linalg/_lu.py +300 -0
  64. dask_array/linalg/_norm.py +94 -0
  65. dask_array/linalg/_qr.py +601 -0
  66. dask_array/linalg/_solve.py +349 -0
  67. dask_array/linalg/_svd.py +394 -0
  68. dask_array/linalg/_tensordot.py +334 -0
  69. dask_array/linalg/_utils.py +74 -0
  70. dask_array/manipulation/__init__.py +45 -0
  71. dask_array/manipulation/_expand.py +321 -0
  72. dask_array/manipulation/_flip.py +92 -0
  73. dask_array/manipulation/_roll.py +78 -0
  74. dask_array/manipulation/_transpose.py +309 -0
  75. dask_array/random/__init__.py +125 -0
  76. dask_array/random/_choice.py +181 -0
  77. dask_array/random/_expr.py +256 -0
  78. dask_array/random/_generator.py +441 -0
  79. dask_array/random/_random_state.py +259 -0
  80. dask_array/random/_utils.py +84 -0
  81. dask_array/reductions/__init__.py +84 -0
  82. dask_array/reductions/_arg_reduction.py +130 -0
  83. dask_array/reductions/_common.py +1082 -0
  84. dask_array/reductions/_cumulative.py +522 -0
  85. dask_array/reductions/_percentile.py +261 -0
  86. dask_array/reductions/_reduction.py +725 -0
  87. dask_array/reductions/_trace.py +56 -0
  88. dask_array/routines/__init__.py +133 -0
  89. dask_array/routines/_apply.py +84 -0
  90. dask_array/routines/_bincount.py +112 -0
  91. dask_array/routines/_broadcast.py +111 -0
  92. dask_array/routines/_coarsen.py +115 -0
  93. dask_array/routines/_diff.py +79 -0
  94. dask_array/routines/_gradient.py +158 -0
  95. dask_array/routines/_indexing.py +65 -0
  96. dask_array/routines/_insert_delete.py +132 -0
  97. dask_array/routines/_misc.py +122 -0
  98. dask_array/routines/_nonzero.py +72 -0
  99. dask_array/routines/_search.py +123 -0
  100. dask_array/routines/_select.py +113 -0
  101. dask_array/routines/_statistics.py +171 -0
  102. dask_array/routines/_topk.py +82 -0
  103. dask_array/routines/_triangular.py +74 -0
  104. dask_array/routines/_unique.py +232 -0
  105. dask_array/routines/_where.py +62 -0
  106. dask_array/slicing/__init__.py +67 -0
  107. dask_array/slicing/_basic.py +550 -0
  108. dask_array/slicing/_blocks.py +138 -0
  109. dask_array/slicing/_bool_index.py +145 -0
  110. dask_array/slicing/_setitem.py +329 -0
  111. dask_array/slicing/_squeeze.py +101 -0
  112. dask_array/slicing/_utils.py +1133 -0
  113. dask_array/slicing/_vindex.py +282 -0
  114. dask_array/stacking/__init__.py +15 -0
  115. dask_array/stacking/_block.py +83 -0
  116. dask_array/stacking/_simple.py +58 -0
  117. dask_array/templates/array.html.j2 +48 -0
  118. dask_array/tests/__init__.py +0 -0
  119. dask_array/tests/conftest.py +22 -0
  120. dask_array/tests/test_api.py +40 -0
  121. dask_array/tests/test_binary_op_chunks.py +107 -0
  122. dask_array/tests/test_coarse_slice_through_blockwise.py +362 -0
  123. dask_array/tests/test_collection.py +799 -0
  124. dask_array/tests/test_creation.py +1102 -0
  125. dask_array/tests/test_expr_flow.py +143 -0
  126. dask_array/tests/test_linalg.py +1130 -0
  127. dask_array/tests/test_map_blocks_multi_output.py +104 -0
  128. dask_array/tests/test_rechunk_pushdown.py +214 -0
  129. dask_array/tests/test_reductions.py +1091 -0
  130. dask_array/tests/test_routines.py +2853 -0
  131. dask_array/tests/test_shuffle_chunks.py +67 -0
  132. dask_array/tests/test_slice_pushdown.py +968 -0
  133. dask_array/tests/test_slice_through_blockwise.py +678 -0
  134. dask_array/tests/test_slice_through_overlap.py +366 -0
  135. dask_array/tests/test_slice_through_reshape.py +272 -0
  136. dask_array/tests/test_slicing.py +839 -0
  137. dask_array/tests/test_transpose_slice_pushdown.py +208 -0
  138. dask_array/tests/test_visualize.py +94 -0
  139. dask_array/tests/test_xarray.py +193 -0
  140. dask_array-0.1.0.dist-info/METADATA +48 -0
  141. dask_array-0.1.0.dist-info/RECORD +144 -0
  142. dask_array-0.1.0.dist-info/WHEEL +4 -0
  143. dask_array-0.1.0.dist-info/entry_points.txt +2 -0
  144. dask_array-0.1.0.dist-info/licenses/LICENSE +29 -0
dask_array/_overlap.py ADDED
@@ -0,0 +1,1159 @@
1
+ from __future__ import annotations
2
+
3
+ import functools
4
+ import warnings
5
+ from functools import reduce
6
+ from numbers import Integral, Number
7
+ from operator import mul
8
+
9
+ import numpy as np
10
+ from tlz import concat, get, partial
11
+ from tlz.curried import map
12
+
13
+ from dask_array._new_collection import new_collection
14
+ from dask_array import _chunk as chunk
15
+ from dask_array._collection import Array, concatenate
16
+ from dask_array._expr import ArrayExpr, unify_chunks_expr
17
+ from dask_array._map_blocks import map_blocks
18
+ from dask_array.creation import empty_like, full_like, repeat
19
+ from dask_array._shuffle import _calculate_new_chunksizes
20
+ from dask_array._numpy_compat import normalize_axis_tuple
21
+ from dask_array._utils import compute_meta, meta_from_array
22
+ from dask.layers import ArrayOverlapLayer
23
+ from dask.utils import derived_from, ensure_dict
24
+
25
+
26
+ def _overlap_internal_chunks(original_chunks, axes):
27
+ """Get new chunks for array with overlap."""
28
+ chunks = []
29
+ for i, bds in enumerate(original_chunks):
30
+ depth = axes.get(i, 0)
31
+ if isinstance(depth, tuple):
32
+ left_depth = depth[0]
33
+ right_depth = depth[1]
34
+ else:
35
+ left_depth = depth
36
+ right_depth = depth
37
+
38
+ if len(bds) == 1:
39
+ chunks.append(bds)
40
+ else:
41
+ left = [bds[0] + right_depth]
42
+ right = [bds[-1] + left_depth]
43
+ mid = []
44
+ for bd in bds[1:-1]:
45
+ mid.append(bd + left_depth + right_depth)
46
+ chunks.append(left + mid + right)
47
+ return chunks
48
+
49
+
50
def overlap_internal(x, axes):
    """Share boundaries between neighboring blocks.

    Parameters
    ----------
    x: da.Array
        A dask array
    axes: dict
        The size of the shared boundary per axis, e.g. ``{0: 2, 2: 5}``
        shares two cells along axis 0 and five cells along axis 2.

    Returns
    -------
    The collection built from an ``OverlapInternal`` expression over ``x``.
    """
    expr = OverlapInternal(x, axes)
    return new_collection(expr)
65
+
66
+
67
class OverlapInternal(ArrayExpr):
    """Low-level expression that makes neighbouring blocks share boundaries.

    This is an implementation detail; the user-facing operation is
    expressed by MapOverlap.
    """

    _parameters = ["array", "axes"]

    @functools.cached_property
    def _meta(self):
        # The overlapped array has the same meta as its input.
        return meta_from_array(self.array)

    @functools.cached_property
    def chunks(self):
        grown = _overlap_internal_chunks(self.array.chunks, self.axes)
        return tuple(tuple(axis_chunks) for axis_chunks in grown)

    @functools.cached_property
    def _name(self) -> str:
        return f"overlap-{super()._name}"

    def _layer(self) -> dict:
        source = self.array
        # The layer token is this expression's name without its first
        # dash-separated component (the "overlap" prefix added in ``_name``).
        token = self._name.partition("-")[2]
        layer = ArrayOverlapLayer(
            name=source.name,
            axes=self.axes,
            chunks=source.chunks,
            numblocks=source.numblocks,
            token=token,
        )
        return ensure_dict(layer)
98
+
99
+
100
class MapOverlap(ArrayExpr):
    """Logical expression for the full map_overlap operation.

    This records the user's intent: apply ``func`` to blocks extended by
    ``depth`` with the given ``boundary`` handling, optionally trimming the
    result.  Keeping the operation as one expression lets a slice be pushed
    through it before it is lowered.

    Note: new_axis/drop_axis cases are handled by _map_overlap_direct instead.

    The expression is lowered to the full pipeline during _lower():
    rechunk -> boundaries -> overlap_internal -> map_blocks -> trim
    """

    _parameters = [
        "arrays",  # tuple of input ArrayExpr
        "func",  # callable
        "depth",  # list of dicts (one per array)
        "boundary",  # list of dicts (one per array)
        "trim_output",  # bool
        "allow_rechunk",  # bool
        "kwargs",  # dict for map_blocks kwargs
    ]
    _defaults = {
        "trim_output": True,
        "allow_rechunk": True,
        "kwargs": None,
    }

    @functools.cached_property
    def _meta(self):
        # An explicit ``meta`` wins over everything else.
        meta = self._kwargs.get("meta")
        if meta is not None:
            return meta_from_array(meta)

        # Next best: an explicit dtype.
        dtype = self._kwargs.get("dtype")
        if dtype is not None:
            return np.empty((0,) * self.ndim, dtype=dtype)

        # Otherwise try to infer the meta by calling the function on the
        # input collections.  This is deliberately best-effort: any failure
        # of the user function falls through to the default below.
        try:
            arr_collections = [new_collection(a) for a in self.arrays]
            meta = compute_meta(self.func, None, *arr_collections)
            if meta is not None:
                return meta
        except Exception:
            pass

        # Default to primary (highest-rank) array's meta
        return meta_from_array(self._get_primary_array())

    @property
    def _kwargs(self):
        """The map_blocks keyword arguments, never None."""
        return self.kwargs if self.kwargs is not None else {}

    def _get_primary_array(self):
        """Get the primary array (highest rank, first if tied) for shape/chunk info."""
        return max(enumerate(self.arrays), key=lambda x: (x[1].ndim, -x[0]))[1]

    def _get_primary_index(self):
        """Get the index of the primary array (highest rank, first if tied)."""
        return max(enumerate(self.arrays), key=lambda x: (x[1].ndim, -x[0]))[0]

    @functools.cached_property
    def shape(self):
        # Output shape = input shape (no new_axis/drop_axis in this expr)
        return self._get_primary_array().shape

    @functools.cached_property
    def chunks(self):
        # If allow_rechunk, the input is rechunked to ensure minimum chunk size >= depth
        primary = self._get_primary_array()
        primary_idx = self._get_primary_index()
        if self.allow_rechunk:
            return _get_overlap_rechunked_chunks(new_collection(primary), self.depth[primary_idx])
        return primary.chunks

    @functools.cached_property
    def _name(self) -> str:
        return f"map-overlap-{super()._name}"

    def _simplify_up(self, parent, dependents):
        """Push a slice through MapOverlap.

        For a slice on MapOverlap:
        - Non-overlap axes: push slice directly to inputs
        - Overlap axes: expand slice by depth, push to inputs, leave trim at top

        Returns None (no rewrite) whenever the pushdown cannot be shown to
        be equivalent: non-slice parents, ``None``/integer indices, non-unit
        steps, inputs that differ in rank, extent or depth, and periodic
        boundaries whose halo would have to wrap around the full array.
        """
        from dask_array.slicing import SliceSlicesIntegers

        if not isinstance(parent, SliceSlicesIntegers):
            return None

        index = parent.index

        # Don't handle None (newaxis) or integers (dimension reduction)
        if any(idx is None or isinstance(idx, Integral) for idx in index):
            return None

        # Use the same primary array as ``shape``/``chunks`` so that the
        # index arithmetic agrees with the shape this expression reports.
        primary_idx = self._get_primary_index()
        primary = self.arrays[primary_idx]
        depth = self.depth[primary_idx]
        ndim = primary.ndim

        # One index is pushed into every input, expanded by one depth; that
        # is only valid when all inputs share rank and overlap depth.
        if any(arr.ndim != ndim for arr in self.arrays):
            return None
        if any(d != depth for d in self.depth):
            return None

        # Pad index to full length
        full_index = list(index) + [slice(None)] * (ndim - len(index))

        output_trim_index = []
        needs_trim = False

        for axis in range(ndim):
            idx = full_index[axis]

            if not isinstance(idx, slice):
                return None  # Unexpected index type

            if idx == slice(None):
                output_trim_index.append(slice(None))
                continue

            # Normalize the slice
            dim_size = self.shape[axis]
            start, stop, step = idx.indices(dim_size)

            if step != 1:
                return None  # Don't handle non-unit steps

            # The normalized bounds are only meaningful for inputs that have
            # the same extent along this axis.
            if any(arr.shape[axis] != dim_size for arr in self.arrays):
                return None

            # Get actual depth (handle tuple for asymmetric overlap)
            d = depth.get(axis, 0)
            if isinstance(d, tuple):
                left_depth, right_depth = d
            else:
                left_depth = right_depth = d

            if max(left_depth, right_depth) == 0:
                # No overlap on this axis - push directly
                output_trim_index.append(slice(None))
                continue

            # A periodic boundary fills the halo at one end of the array
            # with data from the other end.  If the expanded slice touches
            # an array edge, the sliced input would wrap around the slice
            # instead of the full array, so leave the slice where it is.
            clamped = start - left_depth < 0 or stop + right_depth > dim_size
            if clamped:
                kinds = [dict(b).get(axis, "none") for b in self.boundary]
                if any(isinstance(k, str) and k == "periodic" for k in kinds):
                    return None

            # Expand slice by overlap depth for input, respecting array bounds
            expanded_start = max(0, start - left_depth)
            expanded_stop = min(dim_size, stop + right_depth)
            full_index[axis] = slice(expanded_start, expanded_stop)

            # Compute trim slice to get original result
            trim_start = start - expanded_start
            trim_stop = trim_start + (stop - start)
            output_trim_index.append(slice(trim_start, trim_stop))
            needs_trim = True

        # Slice all input arrays
        pushed_index = tuple(full_index)
        new_arrays = [new_collection(arr)[pushed_index].expr for arr in self.arrays]

        # Create new MapOverlap with sliced inputs
        new_expr = MapOverlap(
            arrays=tuple(new_arrays),
            func=self.func,
            depth=self.depth,
            boundary=self.boundary,
            trim_output=self.trim_output,
            allow_rechunk=self.allow_rechunk,
            kwargs=self.kwargs,
        )

        if needs_trim:
            # Apply trim slice to output
            return SliceSlicesIntegers(new_expr, tuple(output_trim_index), parent.allow_getitem_optimization)
        return new_expr

    def _lower(self):
        """Expand to the full overlap pipeline.

        This expands to: rechunk -> boundaries -> overlap_internal -> map_blocks -> trim
        """
        # Apply overlap to each input array
        overlapped = []
        for arr, d, b in zip(self.arrays, self.depth, self.boundary):
            arr_coll = new_collection(arr)
            overlapped_arr = overlap(arr_coll, depth=d, boundary=b, allow_rechunk=self.allow_rechunk)
            overlapped.append(overlapped_arr.expr)

        # Build map_blocks expression
        result = map_blocks(self.func, *[new_collection(a) for a in overlapped], **self._kwargs)

        if self.trim_output:
            # Find highest-rank array (first if tied) for trim settings
            i = sorted(enumerate(overlapped), key=lambda v: (v[1].ndim, -v[0]))[-1][0]
            trim_depth = dict(self.depth[i])
            trim_boundary = dict(self.boundary[i])
            result = trim_internal(result, trim_depth, trim_boundary)

        return result.expr
303
+
304
+
305
def trim_overlap(x, depth, boundary=None):
    """Trim sides from each block.

    Useful after a ``map_overlap``-style operation that leaves excess data
    on each block.  ``depth`` is coerced to a per-axis mapping with
    ``coerce_depth`` and the trimming itself is done by ``trim_internal``.

    See also
    --------
    dask.array.overlap.map_overlap

    """
    return trim_internal(x, axes=coerce_depth(x.ndim, depth), boundary=boundary)
320
+
321
+
322
def trim_internal(x, axes, boundary=None):
    """Trim sides from each block.

    This couples well with the overlap operation, which may leave excess
    data on each block.  The chunk sizes of the result are worked out here
    and the per-block slicing is delegated to ``_trim`` via ``map_blocks``.

    See also
    --------
    dask.array.chunk.trim
    dask.array.map_blocks
    """
    boundary = coerce_boundary(x.ndim, boundary)

    trimmed_chunks = []
    for axis, sizes in enumerate(x.chunks):
        kind = boundary.get(axis, "none")
        depth = axes.get(axis, 0)
        last = len(sizes) - 1

        trimmed = []
        for pos, size in enumerate(sizes):
            if kind != "none":
                # With a real boundary every block loses both sides.
                if isinstance(depth, tuple):
                    size = size - sum(depth)
                else:
                    size = size - depth * 2
            elif isinstance(depth, tuple):
                # With boundary "none" the outer edges are left untouched.
                if pos != 0:
                    size = size - depth[0]
                if pos != last:
                    size = size - depth[1]
            else:
                if pos != 0:
                    size = size - depth
                if pos != last:
                    size = size - depth
            trimmed.append(size)
        trimmed_chunks.append(tuple(trimmed))

    return map_blocks(
        partial(_trim, axes=axes, boundary=boundary),
        x,
        chunks=tuple(trimmed_chunks),
        dtype=x.dtype,
        meta=x._meta,
    )
365
+
366
+
367
+ def _trim(x, axes, boundary, _overlap_trim_info):
368
+ """Similar to dask.array.chunk.trim but requires one to specify the
369
+ boundary condition.
370
+
371
+ ``axes``, and ``boundary`` are assumed to have been coerced.
372
+
373
+ """
374
+ chunk_location = _overlap_trim_info[0]
375
+ num_chunks = _overlap_trim_info[1]
376
+ axes = [axes.get(i, 0) for i in range(x.ndim)]
377
+ axes_front = (ax[0] if isinstance(ax, tuple) else ax for ax in axes)
378
+ axes_back = (
379
+ (-ax[1] if isinstance(ax, tuple) and ax[1] else -ax if isinstance(ax, Integral) and ax else None) for ax in axes
380
+ )
381
+
382
+ trim_front = (
383
+ 0 if (chunk_location == 0 and boundary.get(i, "none") == "none") else ax
384
+ for i, (chunk_location, ax) in enumerate(zip(chunk_location, axes_front))
385
+ )
386
+ trim_back = (
387
+ (None if (chunk_location == chunks - 1 and boundary.get(i, "none") == "none") else ax)
388
+ for i, (chunks, chunk_location, ax) in enumerate(zip(num_chunks, chunk_location, axes_back))
389
+ )
390
+ ind = tuple(slice(front, back) for front, back in zip(trim_front, trim_back))
391
+ return x[ind]
392
+
393
+
394
def periodic(x, axis, depth):
    """Copy a slice of an array around to its other side.

    The last ``depth`` cells along ``axis`` are placed in front of ``x`` and
    the first ``depth`` cells behind it.  Useful to create periodic boundary
    conditions for overlap.
    """
    everything = slice(None, None, None)
    lead = (everything,) * axis
    tail = (everything,) * (x.ndim - axis - 1)

    head = x[lead + (slice(0, depth),) + tail]
    foot = x[lead + (slice(-depth, None),) + tail]

    head, foot = _remove_overlap_boundaries(head, foot, axis, depth)

    return concatenate([foot, x, head], axis=axis)
410
+
411
+
412
def reflect(x, axis, depth):
    """Reflect boundaries of array on the same side.

    The first ``depth`` cells along ``axis`` are mirrored in front of ``x``
    and the last ``depth`` cells are mirrored behind it.  This is the
    converse of ``periodic``.
    """
    everything = slice(None, None, None)
    lead = (everything,) * axis
    tail = (everything,) * (x.ndim - axis - 1)

    if depth == 1:
        front_pick = slice(0, 1)
    else:
        front_pick = slice(depth - 1, None, -1)
    back_pick = slice(-1, -depth - 1, -1)

    front = x[lead + (front_pick,) + tail]
    back = x[lead + (back_pick,) + tail]

    front, back = _remove_overlap_boundaries(front, back, axis, depth)

    return concatenate([front, x, back], axis=axis)
436
+
437
+
438
def nearest(x, axis, depth):
    """Extend each boundary value outwards.

    The first and the last cell along ``axis`` are each repeated ``depth``
    times and attached to their own side of ``x`` (the behaviour image
    filters commonly call ``mode="nearest"``).
    """
    everything = slice(None, None, None)
    lead = (everything,) * axis
    tail = (everything,) * (x.ndim - axis - 1)

    first_cell = x[lead + (slice(0, 1),) + tail]
    last_cell = x[lead + (slice(-1, -2, -1),) + tail]

    front = repeat(first_cell, depth, axis=axis)
    back = repeat(last_cell, depth, axis=axis)

    front, back = _remove_overlap_boundaries(front, back, axis, depth)

    return concatenate([front, x, back], axis=axis)
453
+
454
+
455
def constant(x, axis, depth, value):
    """Add a slab filled with ``value``, ``depth`` thick, to both ends of ``axis``."""
    pad_chunks = list(x.chunks)
    pad_chunks[axis] = (depth,)
    pad_shape = tuple(sum(axis_chunks) for axis_chunks in pad_chunks)

    pad = full_like(
        x,
        value,
        shape=pad_shape,
        chunks=tuple(pad_chunks),
        dtype=x.dtype,
    )

    # The same padding block is used on both sides.
    return concatenate([pad, x, pad], axis=axis)
469
+
470
+
471
+ def _remove_overlap_boundaries(l, r, axis, depth):
472
+ lchunks = list(l.chunks)
473
+ lchunks[axis] = (depth,)
474
+ rchunks = list(r.chunks)
475
+ rchunks[axis] = (depth,)
476
+
477
+ l = l.rechunk(tuple(lchunks))
478
+ r = r.rechunk(tuple(rchunks))
479
+ return l, r
480
+
481
+
482
def boundaries(x, depth=None, kind=None):
    """Add boundary conditions to an array before overlapping.

    ``depth`` and ``kind`` may each be a per-axis dict or a single value
    that is applied to every axis.  Axes with depth 0 or kind ``"none"``
    are left alone; ``"periodic"``, ``"reflect"`` and ``"nearest"`` select
    the matching helper and any other kind listed for the axis is used as a
    constant fill value.

    See Also
    --------
    periodic
    constant
    """
    if not isinstance(kind, dict):
        kind = dict.fromkeys(range(x.ndim), kind)
    if not isinstance(depth, dict):
        depth = dict.fromkeys(range(x.ndim), depth)

    for axis in range(x.ndim):
        axis_depth = depth.get(axis, 0)
        if axis_depth == 0:
            continue

        axis_kind = kind.get(axis, "none")
        if axis_kind == "none":
            continue

        if axis_kind == "periodic":
            x = periodic(x, axis, axis_depth)
        elif axis_kind == "reflect":
            x = reflect(x, axis, axis_depth)
        elif axis_kind == "nearest":
            x = nearest(x, axis, axis_depth)
        elif axis in kind:
            x = constant(x, axis, axis_depth, kind[axis])

    return x
513
+
514
+
515
def ensure_minimum_chunksize(size, chunks):
    """Determine new chunks to ensure that every chunk >= size

    Chunks smaller than ``size`` are merged with their neighbours; the total
    length of the axis is preserved.

    Parameters
    ----------
    size: int
        The minimum size of any chunk.
    chunks: tuple
        Chunks along one axis, e.g. ``(3, 3, 2)``

    Returns
    -------
    tuple
        ``chunks`` itself when it already satisfies the minimum, otherwise a
        new tuple of chunk sizes.

    Raises
    ------
    ValueError
        If the whole axis is shorter than ``size``.

    Examples
    --------
    >>> ensure_minimum_chunksize(10, (20, 20, 1))
    (20, 11, 10)
    >>> ensure_minimum_chunksize(3, (1, 1, 3))
    (5,)

    See Also
    --------
    overlap
    """
    if size <= min(chunks):
        return chunks

    # add too-small chunks to chunks before them
    output = []
    new = 0  # length accumulated for the chunk currently being built
    for c in chunks:
        if c < size:
            if new > size + (size - c):
                # The pending chunk is large enough to donate the missing
                # cells: shrink it and start a new chunk of exactly ``size``.
                output.append(new - (size - c))
                new = size
            else:
                new += c
        if new >= size:
            output.append(new)
            new = 0
        if c >= size:
            new += c
    if new >= size:
        output.append(new)
    elif len(output) >= 1:
        # Leftover cells that cannot form a chunk join the last chunk.
        output[-1] += new
    else:
        raise ValueError(f"The overlapping depth {size} is larger than your array {sum(chunks)}.")

    return tuple(output)
562
+
563
+
564
def _get_overlap_rechunked_chunks(x, depth2):
    """Return chunks for ``x`` in which every chunk can hold the overlap depth.

    Parameters
    ----------
    x:
        Object exposing ``chunks`` (one tuple of chunk sizes per axis).
    depth2: dict
        Overlap depth keyed by axis number; a value is either a number or a
        ``(left, right)`` tuple, in which case the larger side is used.
        Axes that are not listed are treated as depth 0.

    Returns
    -------
    tuple
        One chunk tuple per axis of ``x``, as produced by
        ``ensure_minimum_chunksize``.
    """
    new_chunks = []
    # Look the depth up by axis number rather than pairing ``x.chunks`` with
    # ``depth2.values()``: the latter relies on dict insertion order and
    # silently drops axes when the dict does not cover every axis.
    for axis, axis_chunks in enumerate(x.chunks):
        d = depth2.get(axis, 0)
        size = max(d) if isinstance(d, tuple) else d
        # rechunk if new chunks are needed to fit depth in every chunk
        new_chunks.append(ensure_minimum_chunksize(size, axis_chunks))
    return tuple(new_chunks)
568
+
569
+
570
def overlap(x, depth, boundary, *, allow_rechunk=True):
    """Share boundaries between neighboring blocks

    Parameters
    ----------

    x: da.Array
        A dask array
    depth: dict
        The size of the shared boundary per axis
    boundary: dict
        The boundary condition on each axis. Options are 'reflect', 'periodic',
        'nearest', 'none', or an array value.  Such a value will fill the
        boundary with that value.
    allow_rechunk: bool, keyword only
        Allows rechunking, otherwise chunk sizes need to match and core
        dimensions are to consist only of one chunk.

    The depth input informs how many cells to overlap between neighboring
    blocks ``{0: 2, 2: 5}`` means share two cells in 0 axis, 5 cells in 2 axis.
    Axes missing from this input will not be overlapped.

    Any axis containing chunks smaller than depth will be rechunked if
    possible, provided the keyword ``allow_rechunk`` is True (recommended).
    With ``allow_rechunk=False`` a ``ValueError`` is raised instead.

    Examples
    --------
    >>> import numpy as np
    >>> import dask_array as da

    >>> x = np.arange(64).reshape((8, 8))
    >>> d = da.from_array(x, chunks=(4, 4))
    >>> d.chunks
    ((4, 4), (4, 4))

    >>> g = da.overlap.overlap(d, depth={0: 2, 1: 1},
    ...                       boundary={0: 100, 1: 'reflect'})
    >>> g.chunks
    ((8, 8), (6, 6))

    >>> np.array(g)
    array([[100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100],
           [100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100],
           [  0,   0,   1,   2,   3,   4,   3,   4,   5,   6,   7,   7],
           [  8,   8,   9,  10,  11,  12,  11,  12,  13,  14,  15,  15],
           [ 16,  16,  17,  18,  19,  20,  19,  20,  21,  22,  23,  23],
           [ 24,  24,  25,  26,  27,  28,  27,  28,  29,  30,  31,  31],
           [ 32,  32,  33,  34,  35,  36,  35,  36,  37,  38,  39,  39],
           [ 40,  40,  41,  42,  43,  44,  43,  44,  45,  46,  47,  47],
           [ 16,  16,  17,  18,  19,  20,  19,  20,  21,  22,  23,  23],
           [ 24,  24,  25,  26,  27,  28,  27,  28,  29,  30,  31,  31],
           [ 32,  32,  33,  34,  35,  36,  35,  36,  37,  38,  39,  39],
           [ 40,  40,  41,  42,  43,  44,  43,  44,  45,  46,  47,  47],
           [ 48,  48,  49,  50,  51,  52,  51,  52,  53,  54,  55,  55],
           [ 56,  56,  57,  58,  59,  60,  59,  60,  61,  62,  63,  63],
           [100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100],
           [100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100]])
    """
    depth2 = coerce_depth(x.ndim, depth)
    boundary2 = coerce_boundary(x.ndim, boundary)

    # For asymmetric depths the larger side decides the minimum chunk size.
    required = [max(d) if isinstance(d, tuple) else d for d in depth2.values()]

    if allow_rechunk:
        # rechunk if new chunks are needed to fit depth in every chunk;
        # this is a no-op if x.chunks == new_chunks
        x1 = x.rechunk(_get_overlap_rechunked_chunks(x, depth2))
    elif any(min(c) < d for d, c in zip(required, x.chunks)):
        raise ValueError(
            "Overlap depth is larger than smallest chunksize.\n"
            "Please set allow_rechunk=True to rechunk automatically.\n"
            f"Overlap depths required: {required}\n"
            f"Input chunks: {x.chunks}\n"
        )
    else:
        x1 = x

    padded = boundaries(x1, depth2, boundary2)
    shared = overlap_internal(padded, depth2)

    # Axes that received a real boundary are trimmed by twice their depth;
    # axes with boundary "none" are not trimmed here.
    trim = {}
    for axis, axis_depth in depth2.items():
        trim[axis] = axis_depth * 2 if boundary2.get(axis, "none") != "none" else 0

    return chunk.trim(shared, trim)
652
+
653
+
654
def add_dummy_padding(x, depth, boundary):
    """
    Pads an array which has 'none' as the boundary type.
    Used to simplify trimming arrays which use 'none'.

    For every axis whose boundary is 'none' and whose depth is positive, an
    uninitialised block of that depth is attached to both ends and merged
    into the first and last chunk.

    >>> import dask_array as da
    >>> x = da.arange(6, chunks=3)
    >>> add_dummy_padding(x, {0: 1}, {0: 'none'}).compute()  # doctest: +NORMALIZE_WHITESPACE
    array([..., 0, 1, 2, 3, 4, 5, ...])
    """
    for axis, kind in boundary.items():
        pad = depth.get(axis, 0)
        if not (kind == "none" and pad > 0):
            continue

        pad_shape = list(x.shape)
        pad_shape[axis] = pad

        pad_chunks = list(x.chunks)
        pad_chunks[axis] = (pad,)

        filler = empty_like(
            getattr(x, "_meta", x),
            shape=pad_shape,
            chunks=pad_chunks,
            dtype=x.dtype,
        )

        # Target chunking: the padding is absorbed by the outermost chunks.
        target = list(x.chunks)
        along = list(target[axis])
        along[0] += pad
        along[-1] += pad
        target[axis] = tuple(along)

        x = concatenate([filler, x, filler], axis=axis).rechunk(target)
    return x
689
+
690
+
691
def _map_overlap_direct(func, args, depth, boundary, trim, allow_rechunk, kwargs):
    """Direct implementation of map_overlap without MapOverlap.

    Used for cases with new_axis/drop_axis where MapOverlap doesn't apply.

    Parameters
    ----------
    func: callable
        Function applied to every overlapped block via ``map_blocks``.
    args: sequence
        Input arrays.
    depth, boundary: sequence of dict
        Per-array, per-axis overlap depth and boundary kind.
    trim: bool
        Whether to trim the overlap from the result.
    allow_rechunk: bool
        Forwarded to ``overlap``.
    kwargs: dict
        Keyword arguments for ``map_blocks``; ``drop_axis`` and ``new_axis``
        are also read here to remap the trim settings to the output axes.
    """
    # Apply overlap to each input array
    overlapped = [
        overlap(x, depth=d, boundary=b, allow_rechunk=allow_rechunk) for x, d, b in zip(args, depth, boundary)
    ]

    # Apply the function via map_blocks
    result = map_blocks(func, *overlapped, **kwargs)

    if not trim:
        return result

    # Find highest-rank array (first if tied) for trim settings
    i = sorted(enumerate(overlapped), key=lambda v: (v[1].ndim, -v[0]))[-1][0]
    trim_depth = dict(depth[i])
    trim_boundary = dict(boundary[i])

    # Handle drop_axis: renumber the surviving axes consecutively
    drop_axis = kwargs.get("drop_axis")
    if drop_axis is not None:
        if isinstance(drop_axis, Number):
            drop_axis = [drop_axis]
        ndim_out = max(a.ndim for a in overlapped)
        drop_axis = [d % ndim_out for d in drop_axis]
        kept_axes = tuple(ax for ax in range(overlapped[i].ndim) if ax not in drop_axis)
        trim_depth = {n: trim_depth[ax] for n, ax in enumerate(kept_axes)}
        trim_boundary = {n: trim_boundary[ax] for n, ax in enumerate(kept_axes)}

    # Handle new_axis: shift existing axes up and leave the new one untrimmed
    new_axis = kwargs.get("new_axis")
    if new_axis is not None:
        if isinstance(new_axis, Number):
            new_axis = [new_axis]
        ndim_out = max(a.ndim for a in overlapped)
        new_axis = [d % ndim_out for d in new_axis]

        for axis in new_axis:
            # Rebuild the mappings instead of shifting in place: writing
            # ``key + 1`` while walking the keys in ascending order would
            # overwrite entries before they are read, giving every shifted
            # axis the settings of the first one.
            trim_depth = {(ax + 1 if ax >= axis else ax): v for ax, v in trim_depth.items()}
            trim_boundary = {(ax + 1 if ax >= axis else ax): v for ax, v in trim_boundary.items()}
            trim_depth[axis] = 0
            trim_boundary[axis] = "none"

    return trim_internal(result, trim_depth, trim_boundary)
740
+
741
+
742
def map_overlap(
    func,
    *args,
    depth=None,
    boundary=None,
    trim=True,
    align_arrays=True,
    allow_rechunk=True,
    **kwargs,
):
    """Map a function over blocks of arrays with some overlap

    We share neighboring zones between blocks of the array, map a
    function, and then trim away the neighboring strips. If depth is
    larger than any chunk along a particular axis, then the array is
    rechunked.

    Note that this function will attempt to automatically determine the output
    array type before computing it, please refer to the ``meta`` keyword argument
    in ``map_blocks`` if you expect that the function will not succeed when
    operating on 0-d arrays.

    Parameters
    ----------
    func: function
        The function to apply to each extended block.
        If multiple arrays are provided, then the function should expect to
        receive chunks of each array in the same order.
    args : dask arrays
    depth: int, tuple, dict or list, keyword only
        The number of elements that each block should share with its neighbors
        If a tuple or dict then this can be different per axis.
        If a list then each element of that list must be an int, tuple or dict
        defining depth for the corresponding array in `args`.
        Asymmetric depths may be specified using a dict value of (-/+) tuples.
        Note that asymmetric depths are currently only supported when
        ``boundary`` is 'none'.
        The default value is 0.
    boundary: str, tuple, dict or list, keyword only
        How to handle the boundaries.
        Values include 'reflect', 'periodic', 'nearest', 'none',
        or any constant value like 0 or np.nan.
        If a list then each element must be a str, tuple or dict defining the
        boundary for the corresponding array in `args`.
        The default value is 'none'.
    trim: bool, keyword only
        Whether or not to trim ``depth`` elements from each block after
        calling the map function.
        Set this to False if your mapping function already does this for you
    align_arrays: bool, keyword only
        Whether or not to align chunks along equally sized dimensions when
        multiple arrays are provided.  This allows for larger chunks in some
        arrays to be broken into smaller ones that match chunk sizes in other
        arrays such that they are compatible for block function mapping. If
        this is false, then an error will be thrown if arrays do not already
        have the same number of blocks in each dimension.
    allow_rechunk: bool, keyword only
        Allows rechunking, otherwise chunk sizes need to match and core
        dimensions are to consist only of one chunk.
    **kwargs:
        Other keyword arguments valid in ``map_blocks``

    Raises
    ------
    TypeError
        If ``func`` is not callable, or if any positional argument after
        ``func`` is not an array.
    NotImplementedError
        If an asymmetric (tuple) depth is combined with a boundary other
        than 'none' on the same axis.
    ValueError
        If ``allow_rechunk`` is False and a requested depth exceeds the
        smallest chunk along the corresponding axis.

    Examples
    --------
    >>> import numpy as np
    >>> import dask_array as da

    >>> x = np.array([1, 1, 2, 3, 3, 3, 2, 1, 1])
    >>> x = da.from_array(x, chunks=5)
    >>> def derivative(x):
    ...     return x - np.roll(x, 1)

    >>> y = x.map_overlap(derivative, depth=1, boundary=0)
    >>> y.compute()
    array([ 1,  0,  1,  1,  0,  0, -1, -1,  0])

    >>> x = np.arange(16).reshape((4, 4))
    >>> d = da.from_array(x, chunks=(2, 2))
    >>> d.map_overlap(lambda x: x + x.size, depth=1, boundary='reflect').compute()
    array([[16, 17, 18, 19],
           [20, 21, 22, 23],
           [24, 25, 26, 27],
           [28, 29, 30, 31]])

    >>> func = lambda x: x + x.size
    >>> depth = {0: 1, 1: 1}
    >>> boundary = {0: 'reflect', 1: 'none'}
    >>> d.map_overlap(func, depth, boundary).compute()  # doctest: +NORMALIZE_WHITESPACE
    array([[12, 13, 14, 15],
           [16, 17, 18, 19],
           [20, 21, 22, 23],
           [24, 25, 26, 27]])

    The ``da.map_overlap`` function can also accept multiple arrays.

    >>> func = lambda x, y: x + y
    >>> x = da.arange(8).reshape(2, 4).rechunk((1, 2))
    >>> y = da.arange(4).rechunk(2)
    >>> da.map_overlap(func, x, y, depth=1, boundary='reflect').compute()  # doctest: +NORMALIZE_WHITESPACE
    array([[ 0,  2,  4,  6],
           [ 4,  6,  8, 10]])

    When multiple arrays are given, they do not need to have the
    same number of dimensions but they must broadcast together.
    Arrays are aligned block by block (just as in ``da.map_blocks``)
    so the blocks must have a common chunk size.  This common chunking
    is determined automatically as long as ``align_arrays`` is True.

    >>> x = da.arange(8, chunks=4)
    >>> y = da.arange(8, chunks=2)
    >>> r = da.map_overlap(func, x, y, depth=1, boundary='reflect', align_arrays=True)
    >>> len(r.to_delayed())
    4

    >>> da.map_overlap(func, x, y, depth=1, boundary='reflect', align_arrays=False).compute()
    Traceback (most recent call last):
        ...
    ValueError: Shapes do not align {'.0': {2, 4}}

    Note also that this function is equivalent to ``map_blocks``
    by default.  A non-zero ``depth`` must be defined for any
    overlap to appear in the arrays provided to ``func``.

    >>> func = lambda x: x.sum()
    >>> x = da.ones(10, dtype='int')
    >>> block_args = dict(chunks=(), drop_axis=0)
    >>> da.map_blocks(func, x, **block_args).compute()
    np.int64(10)
    >>> da.map_overlap(func, x, **block_args, boundary='reflect').compute()
    np.int64(10)
    >>> da.map_overlap(func, x, **block_args, depth=1, boundary='reflect').compute()
    np.int64(12)

    For functions that may not handle 0-d arrays, it's also possible to specify
    ``meta`` with an empty array matching the type of the expected result. In
    the example below, ``func`` will result in an ``IndexError`` when computing
    ``meta``:

    >>> x = np.arange(16).reshape((4, 4))
    >>> d = da.from_array(x, chunks=(2, 2))
    >>> y = d.map_overlap(lambda x: x + x[2], depth=1, boundary='reflect', meta=np.array(()))
    >>> y
    dask.array<_trim, shape=(4, 4), dtype=float64, chunksize=(2, 2), chunktype=numpy.ndarray>
    >>> y.compute()
    array([[ 4,  6,  8, 10],
           [ 8, 10, 12, 14],
           [20, 22, 24, 26],
           [24, 26, 28, 30]])

    Similarly, it's possible to specify a non-NumPy array to ``meta``:

    >>> import cupy  # doctest: +SKIP
    >>> x = cupy.arange(16).reshape((4, 4))  # doctest: +SKIP
    >>> d = da.from_array(x, chunks=(2, 2))  # doctest: +SKIP
    >>> y = d.map_overlap(lambda x: x + x[2], depth=1, boundary='reflect', meta=cupy.array(()))  # doctest: +SKIP
    >>> y  # doctest: +SKIP
    dask.array<_trim, shape=(4, 4), dtype=float64, chunksize=(2, 2), chunktype=cupy.ndarray>
    >>> y.compute()  # doctest: +SKIP
    array([[ 4,  6,  8, 10],
           [ 8, 10, 12, 14],
           [20, 22, 24, 26],
           [24, 26, 28, 30]])
    """
    # Look for invocation using deprecated single-array signature
    # map_overlap(x, func, depth, boundary=None, trim=True, **kwargs)
    # ``args`` may be empty (e.g. ``map_overlap(x)``); check that before
    # indexing so the call reaches the descriptive TypeError below instead
    # of failing with a bare IndexError.
    if isinstance(func, Array) and args and callable(args[0]):
        warnings.warn(
            "The use of map_overlap(array, func, **kwargs) is deprecated since dask 2.17.0 "
            "and will be an error in a future release. To silence this warning, use the syntax "
            "map_overlap(func, array0,[ array1, ...,] **kwargs) instead.",
            FutureWarning,
        )
        # In the old signature the positional arguments after the array were
        # (func, depth, boundary, trim); pick each one up by position and
        # fall back to the keyword value when it was not given positionally.
        sig = ["func", "depth", "boundary", "trim"]
        depth = get(sig.index("depth"), args, depth)
        boundary = get(sig.index("boundary"), args, boundary)
        trim = get(sig.index("trim"), args, trim)
        func, args = args[0], [func]

    if not callable(func):
        raise TypeError(
            f"First argument must be callable function, not {type(func).__name__}\n"
            "Usage:   da.map_overlap(function, x)\n"
            "   or:   da.map_overlap(function, x, y, z)"
        )
    if not all(isinstance(x, Array) for x in args):
        raise TypeError(
            f"All variadic arguments must be arrays, not {[type(x).__name__ for x in args]}\n"
            "Usage:   da.map_overlap(function, x)\n"
            "   or:   da.map_overlap(function, x, y, z)"
        )

    # Coerce depth and boundary arguments to lists of individual
    # specifications for each array argument
    def coerce(xs, arg, fn):
        if not isinstance(arg, list):
            arg = [arg] * len(xs)
        return [fn(x.ndim, a) for x, a in zip(xs, arg)]

    depth = coerce(args, depth, coerce_depth)
    boundary = coerce(args, boundary, coerce_boundary)

    # Align chunks in each array to a common size
    if align_arrays:
        # Reverse unification order to allow block broadcasting
        inds = [list(reversed(range(x.ndim))) for x in args]
        args = [a.expr for a in args]
        _, args, _ = unify_chunks_expr(*list(concat(zip(args, inds))))
        args = [new_collection(a) for a in args]

    # Escape to map_blocks if depth is zero (a more efficient computation)
    if all(all(depth_val == 0 for depth_val in d.values()) for d in depth):
        return map_blocks(func, *args, **kwargs)

    for i, x in enumerate(args):
        for j in range(x.ndim):
            if isinstance(depth[i][j], tuple) and boundary[i][j] != "none":
                raise NotImplementedError(
                    "Asymmetric overlap is currently only implemented "
                    "for boundary='none', however boundary for dimension "
                    f"{j} in array argument {i} is {boundary[i][j]}"
                )

    def assert_int_chunksize(xs):
        assert all(type(c) is int for x in xs for cc in x.chunks for c in cc)

    assert_int_chunksize(args)

    # Validate chunk sizes if rechunking is not allowed
    if not allow_rechunk:
        for x, d in zip(args, depth):
            # For asymmetric depths compare against the larger side.
            depths = [max(dd) if isinstance(dd, tuple) else dd for dd in d.values()]
            original_chunks_too_small = any(min(c) < dd for dd, c in zip(depths, x.chunks))
            if original_chunks_too_small:
                raise ValueError(
                    "Overlap depth is larger than smallest chunksize.\n"
                    "Please set allow_rechunk=True to rechunk automatically.\n"
                    f"Overlap depths required: {depths}\n"
                    f"Input chunks: {x.chunks}\n"
                )

    # Fall back to direct implementation for complex cases:
    # - new_axis/drop_axis: change dimensionality
    # - explicit chunks: change output shape/chunks
    if kwargs.get("new_axis") is not None or kwargs.get("drop_axis") is not None or kwargs.get("chunks") is not None:
        return _map_overlap_direct(func, args, depth, boundary, trim, allow_rechunk, kwargs)

    # Create the logical MapOverlap
    # It will be lowered to the full pipeline during optimization
    return new_collection(
        MapOverlap(
            arrays=tuple(a.expr for a in args),
            func=func,
            depth=depth,
            boundary=boundary,
            trim_output=trim,
            allow_rechunk=allow_rechunk,
            kwargs=kwargs if kwargs else None,
        )
    )
1001
+
1002
+
1003
def coerce_depth(ndim, depth):
    """Expand a ``depth`` specification into a dict covering every axis.

    ``None`` is treated as 0, an integer is repeated for all ``ndim`` axes,
    a tuple is matched positionally to the leading axes, and a dict is
    filled in with 0 for any axis in ``range(ndim)`` it does not mention
    (keys outside that range are discarded).  The resulting mapping is
    then handed to ``coerce_depth_type`` and its return value is returned.
    Any other kind of input is passed to ``coerce_depth_type`` unchanged.
    """
    fallback = 0
    spec = fallback if depth is None else depth

    if isinstance(spec, Integral):
        # One scalar depth applies to every axis.
        spec = (spec,) * ndim
    if isinstance(spec, tuple):
        # Pair values with axes positionally; surplus values are dropped.
        spec = {ax: value for ax, value in zip(range(ndim), spec)}
    if isinstance(spec, dict):
        # Build a fresh dict so the caller's mapping is not modified.
        spec = {ax: spec.get(ax, fallback) for ax in range(ndim)}

    return coerce_depth_type(ndim, spec)
1014
+
1015
+
1016
def coerce_depth_type(ndim, depth):
    """Convert every depth entry for axes ``0..ndim-1`` to built-in ``int``.

    Tuple entries are converted element-wise and stay tuples; every other
    entry is passed through ``int()``.  ``depth`` is updated in place and
    the same object is returned.
    """
    for axis in range(ndim):
        entry = depth[axis]
        depth[axis] = tuple(map(int, entry)) if isinstance(entry, tuple) else int(entry)
    return depth
1023
+
1024
+
1025
def coerce_boundary(ndim, boundary):
    """Expand a ``boundary`` specification into a dict covering every axis.

    ``None`` is treated as ``"none"``.  A value that is neither a tuple nor
    a dict is repeated for all ``ndim`` axes, a tuple is matched
    positionally to the leading axes, and a dict is looked up per axis.
    Axes left unspecified get ``"none"``; keys outside ``range(ndim)`` are
    discarded.  A new dict is always returned.
    """
    fallback = "none"
    if boundary is None:
        boundary = fallback

    if isinstance(boundary, dict):
        per_axis = boundary
    else:
        # Scalars (including strings and lists) are broadcast to every axis;
        # tuples are used positionally.
        values = boundary if isinstance(boundary, tuple) else (boundary,) * ndim
        per_axis = dict(zip(range(ndim), values))

    return {ax: per_axis.get(ax, fallback) for ax in range(ndim)}
1036
+
1037
+
1038
@derived_from(np.lib.stride_tricks)
def sliding_window_view(x, window_shape, axis=None, automatic_rechunk=True):
    # NOTE(review): no docstring is written here on the assumption that
    # ``derived_from`` supplies it from the NumPy original -- confirm.

    # Accept either a single window size or an iterable of sizes.
    if np.iterable(window_shape):
        window_shape = tuple(window_shape)
    else:
        window_shape = (window_shape,)

    if np.any(np.array(window_shape) <= 0):
        raise ValueError("`window_shape` must contain values > 0")

    # Resolve the axes the windows slide over and prepare the message used
    # if the number of window sizes does not match the number of axes.
    if axis is None:
        axis = tuple(range(x.ndim))
        mismatch_message = (
            f"Since axis is `None`, must provide "
            f"window_shape for all dimensions of `x`; "
            f"got {len(window_shape)} window_shape elements "
            f"and `x.ndim` is {x.ndim}."
        )
    else:
        axis = normalize_axis_tuple(axis, x.ndim, allow_duplicate=True)
        mismatch_message = (
            f"Must provide matching length window_shape and "
            f"axis; got {len(window_shape)} window_shape "
            f"elements and {len(axis)} axes elements."
        )
    if len(window_shape) != len(axis):
        raise ValueError(mismatch_message)

    # Overlap needed per input axis; an axis listed more than once
    # accumulates the contribution of each of its windows.
    depths = [0] * x.ndim
    for windowed_axis, size in zip(axis, window_shape):
        depths[windowed_axis] += size - 1

    # Ensure that each chunk is big enough to leave at least a size-1 chunk
    # after windowing (this is only really necessary for the last chunk).
    safe_chunks = [ensure_minimum_chunksize(depth + 1, chunks) for depth, chunks in zip(depths, x.chunks)]
    if automatic_rechunk:
        # Axes without a window keep their current chunking.
        safe_chunks = [
            safe if depth != 0 else current for depth, current, safe in zip(depths, x.chunks, safe_chunks)
        ]
        # safe chunks is our output chunks, so add the new dimensions
        safe_chunks.extend([(size,) for size in window_shape])
        largest_block = reduce(mul, map(max, x.chunks))
        unwindowed_axes = {index for index, depth in enumerate(depths) if depth == 0}
        rechunk_target = _calculate_new_chunksizes(
            x.chunks,
            safe_chunks.copy(),
            unwindowed_axes,
            largest_block,
        )
        x = x.rechunk(tuple(rechunk_target))
    else:
        x = x.rechunk(tuple(safe_chunks))

    # result.shape = x_shape_trimmed + window_shape,
    # where x_shape_trimmed is x.shape with every entry
    # reduced by one less than the corresponding window size.
    # trim chunks to match x_shape_trimmed
    trimmed_chunks = tuple(chunks[:-1] + (chunks[-1] - depth,) for depth, chunks in zip(depths, x.chunks))
    window_chunks = tuple((size,) for size in window_shape)
    newchunks = trimmed_chunks + window_chunks

    return map_overlap(
        np.lib.stride_tricks.sliding_window_view,
        x,
        depth=tuple((0, depth) for depth in depths),  # Overlap on +ve side only
        boundary="none",
        meta=x._meta,
        new_axis=range(x.ndim, x.ndim + len(axis)),
        chunks=newchunks,
        trim=False,
        align_arrays=False,
        window_shape=window_shape,
        axis=axis,
    )
1107
+
1108
+
1109
def _fill_with_last_one(a, b):
    """Return ``b`` with its NaN entries replaced by the matching entries of ``a``."""
    missing = np.isnan(b)
    return np.where(missing, a, b)
1112
+
1113
+
1114
def _push(array, n=None, axis=-1):
    """Apply ``bottleneck.push`` to a single chunk.

    When ``n`` is ``None`` the length of ``array`` along ``axis`` is used
    as the limit passed to ``bottleneck.push``.
    """
    import bottleneck as bn

    if n is None:
        n = array.shape[axis]
    return bn.push(array, n, axis)
1120
+
1121
+
1122
def push(array, n, axis):
    """
    Dask-version of bottleneck.push

    Forward fill NaN values along an axis.

    Parameters
    ----------
    array : dask array
        Array whose NaN entries are filled.
    n : int or None
        Limit on how far a value may be carried forward. The limit is only
        applied when ``0 < n < array.shape[axis] - 1``; for ``None`` or any
        other value the fill is performed without a limit.
    axis : int
        Axis along which to fill. Negative values are accepted in the
        limited-fill path and count from the last axis.

    .. note::

        Requires bottleneck to be installed.
    """
    import dask_array as da
    from dask._compatibility import import_optional_dependency

    import_optional_dependency("bottleneck", min_version="1.3.7")

    if n is not None and 0 < n < array.shape[axis] - 1:
        # The reshape below compares ``axis`` with non-negative enumerate
        # indices, so a negative axis would never match and the arange
        # would be reshaped to all-ones. Normalize it first.
        if axis < 0:
            axis += len(array.shape)
        # Position of every element along ``axis``, broadcast to the full
        # shape and chunking of ``array``.
        arr = da.broadcast_to(
            da.arange(array.shape[axis], chunks=array.chunks[axis], dtype=array.dtype).reshape(
                tuple(size if i == axis else 1 for i, size in enumerate(array.shape))
            ),
            array.shape,
            array.chunks,
        )
        # Keep positions only where data is present, then forward fill them
        # without a limit: each element now knows the position of the last
        # non-null value before it.
        valid_arange = da.where(da.notnull(array), arr, np.nan)
        valid_limits = (arr - push(valid_arange, None, axis)) <= n
        # omit the forward fill that violate the limit
        return da.where(valid_limits, push(array, None, axis), np.nan)

    from dask_array.reductions import cumreduction

    # Unlimited fill: fill within each chunk via ``_push`` and combine
    # across chunks with ``_fill_with_last_one``.
    return cumreduction(
        func=_push,
        binop=_fill_with_last_one,
        ident=np.nan,
        x=array,
        axis=axis,
        dtype=array.dtype,
    )