dask-array 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dask_array/__init__.py +228 -0
- dask_array/_backends.py +76 -0
- dask_array/_backends_array.py +99 -0
- dask_array/_blockwise.py +1410 -0
- dask_array/_broadcast.py +272 -0
- dask_array/_chunk.py +445 -0
- dask_array/_chunk_types.py +54 -0
- dask_array/_collection.py +1644 -0
- dask_array/_concatenate.py +331 -0
- dask_array/_core_utils.py +1365 -0
- dask_array/_dispatch.py +141 -0
- dask_array/_einsum.py +277 -0
- dask_array/_expr.py +544 -0
- dask_array/_expr_flow.py +586 -0
- dask_array/_gufunc.py +805 -0
- dask_array/_histogram.py +617 -0
- dask_array/_map_blocks.py +652 -0
- dask_array/_new_collection.py +10 -0
- dask_array/_numpy_compat.py +135 -0
- dask_array/_overlap.py +1159 -0
- dask_array/_rechunk.py +1050 -0
- dask_array/_reshape.py +710 -0
- dask_array/_routines.py +102 -0
- dask_array/_shuffle.py +448 -0
- dask_array/_stack.py +264 -0
- dask_array/_svg.py +291 -0
- dask_array/_templates.py +29 -0
- dask_array/_test_utils.py +257 -0
- dask_array/_ufunc.py +385 -0
- dask_array/_utils.py +349 -0
- dask_array/_visualize.py +223 -0
- dask_array/_xarray.py +337 -0
- dask_array/core/__init__.py +34 -0
- dask_array/core/_blockwise_funcs.py +312 -0
- dask_array/core/_conversion.py +422 -0
- dask_array/core/_from_graph.py +97 -0
- dask_array/creation/__init__.py +71 -0
- dask_array/creation/_arange.py +121 -0
- dask_array/creation/_diag.py +116 -0
- dask_array/creation/_diagonal.py +241 -0
- dask_array/creation/_eye.py +103 -0
- dask_array/creation/_linspace.py +102 -0
- dask_array/creation/_mesh.py +134 -0
- dask_array/creation/_ones_zeros.py +454 -0
- dask_array/creation/_pad.py +270 -0
- dask_array/creation/_repeat.py +55 -0
- dask_array/creation/_tile.py +36 -0
- dask_array/creation/_tri.py +28 -0
- dask_array/creation/_utils.py +296 -0
- dask_array/fft.py +320 -0
- dask_array/io/__init__.py +39 -0
- dask_array/io/_base.py +10 -0
- dask_array/io/_from_array.py +257 -0
- dask_array/io/_from_delayed.py +95 -0
- dask_array/io/_from_graph.py +54 -0
- dask_array/io/_from_npy_stack.py +67 -0
- dask_array/io/_store.py +336 -0
- dask_array/io/_tiledb.py +159 -0
- dask_array/io/_to_npy_stack.py +65 -0
- dask_array/io/_zarr.py +449 -0
- dask_array/linalg/__init__.py +39 -0
- dask_array/linalg/_cholesky.py +234 -0
- dask_array/linalg/_lu.py +300 -0
- dask_array/linalg/_norm.py +94 -0
- dask_array/linalg/_qr.py +601 -0
- dask_array/linalg/_solve.py +349 -0
- dask_array/linalg/_svd.py +394 -0
- dask_array/linalg/_tensordot.py +334 -0
- dask_array/linalg/_utils.py +74 -0
- dask_array/manipulation/__init__.py +45 -0
- dask_array/manipulation/_expand.py +321 -0
- dask_array/manipulation/_flip.py +92 -0
- dask_array/manipulation/_roll.py +78 -0
- dask_array/manipulation/_transpose.py +309 -0
- dask_array/random/__init__.py +125 -0
- dask_array/random/_choice.py +181 -0
- dask_array/random/_expr.py +256 -0
- dask_array/random/_generator.py +441 -0
- dask_array/random/_random_state.py +259 -0
- dask_array/random/_utils.py +84 -0
- dask_array/reductions/__init__.py +84 -0
- dask_array/reductions/_arg_reduction.py +130 -0
- dask_array/reductions/_common.py +1082 -0
- dask_array/reductions/_cumulative.py +522 -0
- dask_array/reductions/_percentile.py +261 -0
- dask_array/reductions/_reduction.py +725 -0
- dask_array/reductions/_trace.py +56 -0
- dask_array/routines/__init__.py +133 -0
- dask_array/routines/_apply.py +84 -0
- dask_array/routines/_bincount.py +112 -0
- dask_array/routines/_broadcast.py +111 -0
- dask_array/routines/_coarsen.py +115 -0
- dask_array/routines/_diff.py +79 -0
- dask_array/routines/_gradient.py +158 -0
- dask_array/routines/_indexing.py +65 -0
- dask_array/routines/_insert_delete.py +132 -0
- dask_array/routines/_misc.py +122 -0
- dask_array/routines/_nonzero.py +72 -0
- dask_array/routines/_search.py +123 -0
- dask_array/routines/_select.py +113 -0
- dask_array/routines/_statistics.py +171 -0
- dask_array/routines/_topk.py +82 -0
- dask_array/routines/_triangular.py +74 -0
- dask_array/routines/_unique.py +232 -0
- dask_array/routines/_where.py +62 -0
- dask_array/slicing/__init__.py +67 -0
- dask_array/slicing/_basic.py +550 -0
- dask_array/slicing/_blocks.py +138 -0
- dask_array/slicing/_bool_index.py +145 -0
- dask_array/slicing/_setitem.py +329 -0
- dask_array/slicing/_squeeze.py +101 -0
- dask_array/slicing/_utils.py +1133 -0
- dask_array/slicing/_vindex.py +282 -0
- dask_array/stacking/__init__.py +15 -0
- dask_array/stacking/_block.py +83 -0
- dask_array/stacking/_simple.py +58 -0
- dask_array/templates/array.html.j2 +48 -0
- dask_array/tests/__init__.py +0 -0
- dask_array/tests/conftest.py +22 -0
- dask_array/tests/test_api.py +40 -0
- dask_array/tests/test_binary_op_chunks.py +107 -0
- dask_array/tests/test_coarse_slice_through_blockwise.py +362 -0
- dask_array/tests/test_collection.py +799 -0
- dask_array/tests/test_creation.py +1102 -0
- dask_array/tests/test_expr_flow.py +143 -0
- dask_array/tests/test_linalg.py +1130 -0
- dask_array/tests/test_map_blocks_multi_output.py +104 -0
- dask_array/tests/test_rechunk_pushdown.py +214 -0
- dask_array/tests/test_reductions.py +1091 -0
- dask_array/tests/test_routines.py +2853 -0
- dask_array/tests/test_shuffle_chunks.py +67 -0
- dask_array/tests/test_slice_pushdown.py +968 -0
- dask_array/tests/test_slice_through_blockwise.py +678 -0
- dask_array/tests/test_slice_through_overlap.py +366 -0
- dask_array/tests/test_slice_through_reshape.py +272 -0
- dask_array/tests/test_slicing.py +839 -0
- dask_array/tests/test_transpose_slice_pushdown.py +208 -0
- dask_array/tests/test_visualize.py +94 -0
- dask_array/tests/test_xarray.py +193 -0
- dask_array-0.1.0.dist-info/METADATA +48 -0
- dask_array-0.1.0.dist-info/RECORD +144 -0
- dask_array-0.1.0.dist-info/WHEEL +4 -0
- dask_array-0.1.0.dist-info/entry_points.txt +2 -0
- dask_array-0.1.0.dist-info/licenses/LICENSE +29 -0
dask_array/_overlap.py
ADDED
|
@@ -0,0 +1,1159 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import functools
|
|
4
|
+
import warnings
|
|
5
|
+
from functools import reduce
|
|
6
|
+
from numbers import Integral, Number
|
|
7
|
+
from operator import mul
|
|
8
|
+
|
|
9
|
+
import numpy as np
|
|
10
|
+
from tlz import concat, get, partial
|
|
11
|
+
from tlz.curried import map
|
|
12
|
+
|
|
13
|
+
from dask_array._new_collection import new_collection
|
|
14
|
+
from dask_array import _chunk as chunk
|
|
15
|
+
from dask_array._collection import Array, concatenate
|
|
16
|
+
from dask_array._expr import ArrayExpr, unify_chunks_expr
|
|
17
|
+
from dask_array._map_blocks import map_blocks
|
|
18
|
+
from dask_array.creation import empty_like, full_like, repeat
|
|
19
|
+
from dask_array._shuffle import _calculate_new_chunksizes
|
|
20
|
+
from dask_array._numpy_compat import normalize_axis_tuple
|
|
21
|
+
from dask_array._utils import compute_meta, meta_from_array
|
|
22
|
+
from dask.layers import ArrayOverlapLayer
|
|
23
|
+
from dask.utils import derived_from, ensure_dict
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def _overlap_internal_chunks(original_chunks, axes):
|
|
27
|
+
"""Get new chunks for array with overlap."""
|
|
28
|
+
chunks = []
|
|
29
|
+
for i, bds in enumerate(original_chunks):
|
|
30
|
+
depth = axes.get(i, 0)
|
|
31
|
+
if isinstance(depth, tuple):
|
|
32
|
+
left_depth = depth[0]
|
|
33
|
+
right_depth = depth[1]
|
|
34
|
+
else:
|
|
35
|
+
left_depth = depth
|
|
36
|
+
right_depth = depth
|
|
37
|
+
|
|
38
|
+
if len(bds) == 1:
|
|
39
|
+
chunks.append(bds)
|
|
40
|
+
else:
|
|
41
|
+
left = [bds[0] + right_depth]
|
|
42
|
+
right = [bds[-1] + left_depth]
|
|
43
|
+
mid = []
|
|
44
|
+
for bd in bds[1:-1]:
|
|
45
|
+
mid.append(bd + left_depth + right_depth)
|
|
46
|
+
chunks.append(left + mid + right)
|
|
47
|
+
return chunks
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def overlap_internal(x, axes):
    """Share boundaries between neighboring blocks.

    Parameters
    ----------
    x : da.Array
        A dask array.
    axes : dict
        Number of cells shared with each neighbor, keyed by axis. For
        example ``{0: 2, 2: 5}`` shares two cells along axis 0 and five
        cells along axis 2.

    Returns
    -------
    da.Array
        A collection wrapping an ``OverlapInternal`` expression.
    """
    expr = OverlapInternal(x, axes)
    return new_collection(expr)
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
class OverlapInternal(ArrayExpr):
    """Low-level expression that shares boundary cells between blocks.

    This is an implementation detail of the overlap pipeline; the
    user-facing ``map_overlap`` operation is represented by ``MapOverlap``.
    """

    _parameters = ["array", "axes"]

    @functools.cached_property
    def _meta(self):
        # The meta is taken straight from the input array.
        return meta_from_array(self.array)

    @functools.cached_property
    def chunks(self):
        grown = _overlap_internal_chunks(self.array.chunks, self.axes)
        return tuple(tuple(axis_chunks) for axis_chunks in grown)

    @functools.cached_property
    def _name(self) -> str:
        return f"overlap-{super()._name}"

    def _layer(self) -> dict:
        source = self.array
        # The layer token is everything after the leading "overlap-" prefix.
        _, _, token = self._name.partition("-")
        layer = ArrayOverlapLayer(
            name=source.name,
            axes=self.axes,
            chunks=source.chunks,
            numblocks=source.numblocks,
            token=token,
        )
        return ensure_dict(layer)
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
class MapOverlap(ArrayExpr):
    """Logical expression for the full map_overlap operation.

    This captures the user's intent: apply ``func`` with overlap
    depth/boundary, optionally trimming the result. Keeping the whole
    operation as one expression makes slice pushdown straightforward,
    because its semantics are known.

    Note: new_axis/drop_axis cases are handled by _map_overlap_direct instead.

    The expression is lowered to the full pipeline during _lower():
    rechunk -> boundaries -> overlap_internal -> map_blocks -> trim
    """

    _parameters = [
        "arrays",  # tuple of input ArrayExpr
        "func",  # callable
        "depth",  # list of dicts (one per array)
        "boundary",  # list of dicts (one per array)
        "trim_output",  # bool
        "allow_rechunk",  # bool
        "kwargs",  # dict for map_blocks kwargs
    ]
    _defaults = {
        "trim_output": True,
        "allow_rechunk": True,
        "kwargs": None,
    }

    @functools.cached_property
    def _meta(self):
        """Output meta: explicit ``meta``, else ``dtype``, else inferred."""
        kwargs = self._kwargs

        meta = kwargs.get("meta")
        if meta is not None:
            return meta_from_array(meta)

        dtype = kwargs.get("dtype")
        if dtype is not None:
            return np.empty((0,) * self.ndim, dtype=dtype)

        # Best effort: try calling ``func`` on the inputs to infer the meta.
        # Any failure deliberately falls through to the primary input's meta.
        try:
            arr_collections = [new_collection(a) for a in self.arrays]
            meta = compute_meta(self.func, None, *arr_collections)
        except Exception:
            meta = None
        if meta is not None:
            return meta

        # Default to primary (highest-rank) array's meta
        return meta_from_array(self._get_primary_array())

    @property
    def _kwargs(self):
        return self.kwargs if self.kwargs is not None else {}

    def _get_primary_index(self):
        """Get the index of the primary array (highest rank, first if tied)."""
        return max(range(len(self.arrays)), key=lambda i: (self.arrays[i].ndim, -i))

    def _get_primary_array(self):
        """Get the primary array (highest rank, first if tied) for shape/chunk info."""
        return self.arrays[self._get_primary_index()]

    @functools.cached_property
    def shape(self):
        # Output shape = input shape (no new_axis/drop_axis in this expr)
        return self._get_primary_array().shape

    @functools.cached_property
    def chunks(self):
        # If allow_rechunk, the input is rechunked to ensure minimum chunk size >= depth
        primary_idx = self._get_primary_index()
        primary = self.arrays[primary_idx]
        if self.allow_rechunk:
            return _get_overlap_rechunked_chunks(new_collection(primary), self.depth[primary_idx])
        return primary.chunks

    @functools.cached_property
    def _name(self) -> str:
        return f"map-overlap-{super()._name}"

    def _simplify_up(self, parent, dependents):
        """Push a slice through MapOverlap.

        For a slice on MapOverlap:

        - Non-overlap axes: push slice directly to inputs.
        - Overlap axes: expand slice by depth, push to inputs, leave trim at top.

        Returns ``None`` (no rewrite) whenever the rewrite could change the
        result or make the lowered pipeline fail:

        - ``trim_output`` is False. Output blocks then still contain their
          overlap, so output positions do not correspond to input positions.
        - The inputs differ in shape or depth. A single index, built from the
          first input, is applied to all of them.
        - The index has ``None``, integers, non-slices, non-unit steps, more
          entries than dimensions, an empty selection, or an unknown size.
        - The sliced inputs would be too small for the overlap depth, or,
          without rechunking, would have chunks smaller than the depth.
        """
        from dask_array.slicing import SliceSlicesIntegers

        if not isinstance(parent, SliceSlicesIntegers):
            return None
        if not self.trim_output:
            return None

        arrays = self.arrays
        first = arrays[0]
        depth = self.depth[0]
        # Every input is sliced with the same index, so they must agree.
        if any(a.shape != first.shape for a in arrays[1:]):
            return None
        if any(d != depth for d in self.depth[1:]):
            return None

        index = parent.index
        ndim = first.ndim

        # Don't handle None (newaxis) or integers (dimension reduction)
        if any(idx is None for idx in index):
            return None
        if any(isinstance(idx, Integral) for idx in index):
            return None
        if len(index) > ndim:
            return None

        # Pad index to full length
        full_index = list(index) + [slice(None)] * (ndim - len(index))

        output_trim_index = []
        needs_trim = False

        for axis in range(ndim):
            idx = full_index[axis]
            d = depth.get(axis, 0)

            # Get actual depth (handle tuple for asymmetric overlap)
            if isinstance(d, tuple):
                left_depth, right_depth = d
            else:
                left_depth = right_depth = d
            max_depth = max(left_depth, right_depth)

            if not isinstance(idx, slice):
                return None  # Unexpected index type

            if idx == slice(None):
                output_trim_index.append(slice(None))
                continue

            dim_size = first.shape[axis]
            if not isinstance(dim_size, Integral):
                return None  # Unknown (NaN) size: cannot normalize the slice
            start, stop, step = idx.indices(dim_size)

            if step != 1:
                return None  # Don't handle non-unit steps
            if stop <= start:
                # Empty selection: pushing it down would feed an empty array
                # into the overlap pipeline.
                return None

            if max_depth == 0:
                # No overlap on this axis - push directly
                output_trim_index.append(slice(None))
                continue

            # Expand the slice so the kept region still sees its overlap
            # neighbours, clipped to the array extent.
            expanded_start = max(0, start - left_depth)
            expanded_stop = min(dim_size, stop + right_depth)
            full_index[axis] = slice(expanded_start, expanded_stop)

            # Trim slice that recovers exactly the requested region
            trim_start = start - expanded_start
            output_trim_index.append(slice(trim_start, trim_start + (stop - start)))
            needs_trim = True

        # Slice all input arrays
        new_arrays = tuple(new_collection(arr)[tuple(full_index)].expr for arr in arrays)

        # The lowered overlap() requires each axis to be at least ``depth``
        # long, and without rechunking every chunk must be at least ``depth``.
        # Slicing can violate both even when the original input did not.
        for arr in new_arrays:
            for axis, axis_chunks in enumerate(arr.chunks):
                d = depth.get(axis, 0)
                needed = max(d) if isinstance(d, tuple) else d
                if not needed:
                    continue
                if sum(axis_chunks) < needed:
                    return None
                if not self.allow_rechunk and min(axis_chunks) < needed:
                    return None

        # Create new MapOverlap with sliced inputs
        new_expr = MapOverlap(
            arrays=new_arrays,
            func=self.func,
            depth=self.depth,
            boundary=self.boundary,
            trim_output=self.trim_output,
            allow_rechunk=self.allow_rechunk,
            kwargs=self.kwargs,
        )

        if needs_trim:
            # Apply trim slice to output
            return SliceSlicesIntegers(new_expr, tuple(output_trim_index), parent.allow_getitem_optimization)
        return new_expr

    def _lower(self):
        """Expand to the full overlap pipeline.

        This expands to: rechunk -> boundaries -> overlap_internal -> map_blocks -> trim
        """
        # Apply overlap to each input array
        overlapped = []
        for arr, d, b in zip(self.arrays, self.depth, self.boundary):
            arr_coll = new_collection(arr)
            overlapped_arr = overlap(arr_coll, depth=d, boundary=b, allow_rechunk=self.allow_rechunk)
            overlapped.append(overlapped_arr.expr)

        # Build map_blocks expression
        result = map_blocks(self.func, *[new_collection(a) for a in overlapped], **self._kwargs)

        if self.trim_output:
            # Trim with the primary (highest-rank, first if tied) input's
            # settings; overlap() keeps each input's rank unchanged.
            i = self._get_primary_index()
            trim_depth = dict(self.depth[i])
            trim_boundary = dict(self.boundary[i])
            result = trim_internal(result, trim_depth, trim_boundary)

        return result.expr
|
|
303
|
+
|
|
304
|
+
|
|
305
|
+
def trim_overlap(x, depth, boundary=None):
    """Remove the overlap margins from every block of ``x``.

    Pairs naturally with ``map_overlap``, which can leave excess data on
    each block.

    See also
    --------
    dask.array.overlap.map_overlap

    """
    # trim_internal expects the depth in its per-axis (coerced) form.
    per_axis_depth = coerce_depth(x.ndim, depth)
    return trim_internal(x, axes=per_axis_depth, boundary=boundary)
|
|
320
|
+
|
|
321
|
+
|
|
322
|
+
def trim_internal(x, axes, boundary=None):
    """Trim the overlap margins from each block.

    Pairs naturally with the overlap operation, which may leave excess data
    on each block.

    Parameters
    ----------
    x : da.Array
        Array whose blocks carry overlap margins.
    axes : dict
        Overlap depth per axis (int or ``(left, right)`` tuple).
    boundary : optional
        Boundary kind(s), coerced with ``coerce_boundary``. On axes whose
        boundary is ``"none"``, the outermost edges of the array carry no
        margin.

    See also
    --------
    dask.array.chunk.trim
    dask.array.map_blocks
    """
    boundary = coerce_boundary(x.ndim, boundary)

    trimmed_chunks = []
    for axis, sizes in enumerate(x.chunks):
        kind = boundary.get(axis, "none")
        depth = axes.get(axis, 0)
        is_pair = isinstance(depth, tuple)
        last = len(sizes) - 1

        axis_sizes = []
        for j, size in enumerate(sizes):
            if kind != "none":
                # Padding was added on both sides of every block.
                size -= sum(depth) if is_pair else depth * 2
            else:
                # The outer edges of the array were not padded.
                if j != 0:
                    size -= depth[0] if is_pair else depth
                if j != last:
                    size -= depth[1] if is_pair else depth
            axis_sizes.append(size)
        trimmed_chunks.append(tuple(axis_sizes))

    return map_blocks(
        partial(_trim, axes=axes, boundary=boundary),
        x,
        chunks=tuple(trimmed_chunks),
        dtype=x.dtype,
        meta=x._meta,
    )
|
|
365
|
+
|
|
366
|
+
|
|
367
|
+
def _trim(x, axes, boundary, _overlap_trim_info):
|
|
368
|
+
"""Similar to dask.array.chunk.trim but requires one to specify the
|
|
369
|
+
boundary condition.
|
|
370
|
+
|
|
371
|
+
``axes``, and ``boundary`` are assumed to have been coerced.
|
|
372
|
+
|
|
373
|
+
"""
|
|
374
|
+
chunk_location = _overlap_trim_info[0]
|
|
375
|
+
num_chunks = _overlap_trim_info[1]
|
|
376
|
+
axes = [axes.get(i, 0) for i in range(x.ndim)]
|
|
377
|
+
axes_front = (ax[0] if isinstance(ax, tuple) else ax for ax in axes)
|
|
378
|
+
axes_back = (
|
|
379
|
+
(-ax[1] if isinstance(ax, tuple) and ax[1] else -ax if isinstance(ax, Integral) and ax else None) for ax in axes
|
|
380
|
+
)
|
|
381
|
+
|
|
382
|
+
trim_front = (
|
|
383
|
+
0 if (chunk_location == 0 and boundary.get(i, "none") == "none") else ax
|
|
384
|
+
for i, (chunk_location, ax) in enumerate(zip(chunk_location, axes_front))
|
|
385
|
+
)
|
|
386
|
+
trim_back = (
|
|
387
|
+
(None if (chunk_location == chunks - 1 and boundary.get(i, "none") == "none") else ax)
|
|
388
|
+
for i, (chunks, chunk_location, ax) in enumerate(zip(num_chunks, chunk_location, axes_back))
|
|
389
|
+
)
|
|
390
|
+
ind = tuple(slice(front, back) for front, back in zip(trim_front, trim_back))
|
|
391
|
+
return x[ind]
|
|
392
|
+
|
|
393
|
+
|
|
394
|
+
def periodic(x, axis, depth):
    """Wrap ``depth`` cells from each end of ``x`` around to the other end.

    The last ``depth`` cells along ``axis`` are prepended and the first
    ``depth`` cells are appended, giving periodic boundary conditions for
    overlap.
    """
    before = (slice(None),) * axis
    after = (slice(None),) * (x.ndim - axis - 1)

    head = x[before + (slice(0, depth),) + after]
    tail = x[before + (slice(-depth, None),) + after]

    head, tail = _remove_overlap_boundaries(head, tail, axis, depth)

    return concatenate([tail, x, head], axis=axis)
|
|
410
|
+
|
|
411
|
+
|
|
412
|
+
def reflect(x, axis, depth):
    """Mirror ``depth`` cells at each end of ``x`` along ``axis``.

    The first ``depth`` cells, reversed, are prepended and the last
    ``depth`` cells, reversed, are appended. This is the converse of
    ``periodic``.
    """
    before = (slice(None),) * axis
    after = (slice(None),) * (x.ndim - axis - 1)

    # A depth of 1 is expressed as a forward slice of the first cell.
    front_sel = slice(0, 1) if depth == 1 else slice(depth - 1, None, -1)
    back_sel = slice(-1, -depth - 1, -1)

    front = x[before + (front_sel,) + after]
    back = x[before + (back_sel,) + after]

    front, back = _remove_overlap_boundaries(front, back, axis, depth)

    return concatenate([front, x, back], axis=axis)
|
|
436
|
+
|
|
437
|
+
|
|
438
|
+
def nearest(x, axis, depth):
    """Extend each edge of ``x`` by repeating its outermost value.

    The first and last cells along ``axis`` are each repeated ``depth``
    times and attached to their own side. This mimics what
    skimage.filters.gaussian_filter(... mode="nearest") does.
    """
    before = (slice(None),) * axis
    after = (slice(None),) * (x.ndim - axis - 1)

    first_cell = x[before + (slice(0, 1),) + after]
    front = repeat(first_cell, depth, axis=axis)
    last_cell = x[before + (slice(-1, -2, -1),) + after]
    back = repeat(last_cell, depth, axis=axis)

    front, back = _remove_overlap_boundaries(front, back, axis, depth)

    return concatenate([front, x, back], axis=axis)
|
|
453
|
+
|
|
454
|
+
|
|
455
|
+
def constant(x, axis, depth, value):
    """Pad both sides of ``x`` along ``axis`` with ``depth`` cells of ``value``."""
    # The pad matches x's chunking on every other axis and is a single
    # ``depth``-wide chunk along ``axis``.
    pad_chunks = list(x.chunks)
    pad_chunks[axis] = (depth,)
    pad_shape = tuple(sum(axis_chunks) for axis_chunks in pad_chunks)

    pad = full_like(
        x,
        value,
        shape=pad_shape,
        chunks=tuple(pad_chunks),
        dtype=x.dtype,
    )

    return concatenate([pad, x, pad], axis=axis)
|
|
469
|
+
|
|
470
|
+
|
|
471
|
+
def _remove_overlap_boundaries(l, r, axis, depth):
|
|
472
|
+
lchunks = list(l.chunks)
|
|
473
|
+
lchunks[axis] = (depth,)
|
|
474
|
+
rchunks = list(r.chunks)
|
|
475
|
+
rchunks[axis] = (depth,)
|
|
476
|
+
|
|
477
|
+
l = l.rechunk(tuple(lchunks))
|
|
478
|
+
r = r.rechunk(tuple(rchunks))
|
|
479
|
+
return l, r
|
|
480
|
+
|
|
481
|
+
|
|
482
|
+
def boundaries(x, depth=None, kind=None):
    """Add boundary conditions to an array before overlapping.

    Parameters
    ----------
    x : da.Array
        Array to pad.
    depth : int or dict
        Padding width per axis; a non-dict value applies to every axis.
    kind : str, scalar or dict
        Boundary kind per axis: ``"none"``, ``"periodic"``, ``"reflect"``,
        ``"nearest"``, or any other value, which is used as a constant fill.
        A non-dict value applies to every axis.

    See Also
    --------
    periodic
    constant
    """
    kinds = kind if isinstance(kind, dict) else dict.fromkeys(range(x.ndim), kind)
    widths = depth if isinstance(depth, dict) else dict.fromkeys(range(x.ndim), depth)

    for axis in range(x.ndim):
        width = widths.get(axis, 0)
        if width == 0:
            continue

        this_kind = kinds.get(axis, "none")
        if this_kind == "none":
            continue

        if this_kind == "periodic":
            x = periodic(x, axis, width)
        elif this_kind == "reflect":
            x = reflect(x, axis, width)
        elif this_kind == "nearest":
            x = nearest(x, axis, width)
        elif axis in kinds:
            # Any other explicitly given value is a constant fill.
            x = constant(x, axis, width, kinds[axis])

    return x
|
|
513
|
+
|
|
514
|
+
|
|
515
|
+
def ensure_minimum_chunksize(size, chunks):
    """Determine new chunks to ensure that every chunk >= size

    Chunks smaller than ``size`` are combined with the chunks before them.
    Cells are taken from a large preceding chunk when it can spare them, so
    that every resulting chunk is at least ``size``. The total length along
    the axis is unchanged.

    Parameters
    ----------
    size: int
        The minimum size of any chunk.
    chunks: tuple
        Chunks along one axis, e.g. ``(3, 3, 2)``

    Returns
    -------
    tuple
        New chunks along the axis. ``chunks`` itself is returned unchanged
        when every chunk is already at least ``size``.

    Raises
    ------
    ValueError
        If the total length of ``chunks`` is smaller than ``size``.

    Examples
    --------
    >>> ensure_minimum_chunksize(10, (20, 20, 1))
    (20, 11, 10)
    >>> ensure_minimum_chunksize(3, (1, 1, 3))
    (5,)

    See Also
    --------
    overlap
    """
    if size <= min(chunks):
        return chunks

    # add too-small chunks to chunks before them
    output = []
    new = 0
    for c in chunks:
        if c < size:
            if new > size + (size - c):
                # The pending chunk can spare (size - c) cells and still
                # stay >= size, so top this small chunk up from it.
                output.append(new - (size - c))
                new = size
            else:
                new += c
        if new >= size:
            output.append(new)
            new = 0
        if c >= size:
            # Defer large chunks so a following small chunk can borrow from them.
            new += c
    if new >= size:
        output.append(new)
    elif len(output) >= 1:
        # Fold the short remainder into the last emitted chunk.
        output[-1] += new
    else:
        raise ValueError(f"The overlapping depth {size} is larger than your array {sum(chunks)}.")

    return tuple(output)
|
|
562
|
+
|
|
563
|
+
|
|
564
|
+
def _get_overlap_rechunked_chunks(x, depth2):
    """Return chunks for ``x`` in which every chunk is at least its axis depth.

    Parameters
    ----------
    x : da.Array
        Array whose ``chunks`` are adjusted.
    depth2 : dict
        Overlap depth keyed by axis (int or ``(left, right)`` tuple; the
        larger side is used). Axes missing from the dict are treated as depth
        0 and keep their chunks.

    Returns
    -------
    tuple
        One chunk tuple per axis of ``x``.
    """
    # Look depths up by axis instead of zipping ``depth2.values()`` with
    # ``x.chunks``: that relied on the dict holding every axis in axis order,
    # and would silently misalign or drop axes otherwise.
    new_chunks = []
    for axis, axis_chunks in enumerate(x.chunks):
        d = depth2.get(axis, 0)
        size = max(d) if isinstance(d, tuple) else d
        # rechunk if new chunks are needed to fit depth in every chunk
        new_chunks.append(ensure_minimum_chunksize(size, axis_chunks))
    return tuple(new_chunks)
|
|
568
|
+
|
|
569
|
+
|
|
570
|
+
def overlap(x, depth, boundary, *, allow_rechunk=True):
    """Share boundaries between neighboring blocks

    Parameters
    ----------

    x: da.Array
        A dask array
    depth: dict
        The size of the shared boundary per axis
    boundary: dict
        The boundary condition on each axis. Options are 'reflect', 'periodic',
        'nearest', 'none', or an array value. Such a value will fill the
        boundary with that value.
    allow_rechunk: bool, keyword only
        Allows rechunking, otherwise chunk sizes need to match and core
        dimensions are to consist only of one chunk.

    The depth input says how many cells neighboring blocks share:
    ``{0: 2, 2: 5}`` shares two cells along axis 0 and five along axis 2.
    Axes missing from this input are not overlapped.

    Any axis with chunks smaller than its depth is rechunked if possible,
    provided ``allow_rechunk`` is True (recommended). Otherwise a
    ``ValueError`` is raised.

    Examples
    --------
    >>> import numpy as np
    >>> import dask_array as da

    >>> x = np.arange(64).reshape((8, 8))
    >>> d = da.from_array(x, chunks=(4, 4))
    >>> d.chunks
    ((4, 4), (4, 4))

    >>> g = da.overlap.overlap(d, depth={0: 2, 1: 1},
    ...                        boundary={0: 100, 1: 'reflect'})
    >>> g.chunks
    ((8, 8), (6, 6))

    >>> np.array(g)
    array([[100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100],
           [100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100],
           [  0,   0,   1,   2,   3,   4,   3,   4,   5,   6,   7,   7],
           [  8,   8,   9,  10,  11,  12,  11,  12,  13,  14,  15,  15],
           [ 16,  16,  17,  18,  19,  20,  19,  20,  21,  22,  23,  23],
           [ 24,  24,  25,  26,  27,  28,  27,  28,  29,  30,  31,  31],
           [ 32,  32,  33,  34,  35,  36,  35,  36,  37,  38,  39,  39],
           [ 40,  40,  41,  42,  43,  44,  43,  44,  45,  46,  47,  47],
           [ 16,  16,  17,  18,  19,  20,  19,  20,  21,  22,  23,  23],
           [ 24,  24,  25,  26,  27,  28,  27,  28,  29,  30,  31,  31],
           [ 32,  32,  33,  34,  35,  36,  35,  36,  37,  38,  39,  39],
           [ 40,  40,  41,  42,  43,  44,  43,  44,  45,  46,  47,  47],
           [ 48,  48,  49,  50,  51,  52,  51,  52,  53,  54,  55,  55],
           [ 56,  56,  57,  58,  59,  60,  59,  60,  61,  62,  63,  63],
           [100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100],
           [100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100, 100]])
    """
    depth_by_axis = coerce_depth(x.ndim, depth)
    boundary_by_axis = coerce_boundary(x.ndim, boundary)

    depths = [max(d) if isinstance(d, tuple) else d for d in depth_by_axis.values()]
    if allow_rechunk:
        # The rechunk is a no-op when the chunks already fit the depth.
        rechunked = x.rechunk(_get_overlap_rechunked_chunks(x, depth_by_axis))
    else:
        if any(min(c) < d for d, c in zip(depths, x.chunks)):
            raise ValueError(
                "Overlap depth is larger than smallest chunksize.\n"
                "Please set allow_rechunk=True to rechunk automatically.\n"
                f"Overlap depths required: {depths}\n"
                f"Input chunks: {x.chunks}\n"
            )
        rechunked = x

    padded = boundaries(rechunked, depth_by_axis, boundary_by_axis)
    shared = overlap_internal(padded, depth_by_axis)

    # Where a boundary was added, the outermost blocks consist of padding
    # plus borrowed neighbours (2 * depth in total), which is dropped here.
    outer_trim = {
        axis: d * 2 if boundary_by_axis.get(axis, "none") != "none" else 0
        for axis, d in depth_by_axis.items()
    }
    return chunk.trim(shared, outer_trim)
|
|
652
|
+
|
|
653
|
+
|
|
654
|
+
def add_dummy_padding(x, depth, boundary):
    """
    Pad every axis whose boundary is 'none' with uninitialised cells.

    Used to simplify trimming arrays which use 'none'. Each pad is
    ``depth`` cells wide on both sides and is folded into the adjacent
    outer chunk, so the number of blocks is unchanged.

    >>> import dask_array as da
    >>> x = da.arange(6, chunks=3)
    >>> add_dummy_padding(x, {0: 1}, {0: 'none'}).compute()  # doctest: +NORMALIZE_WHITESPACE
    array([..., 0, 1, 2, 3, 4, 5, ...])
    """
    for axis, kind in boundary.items():
        width = depth.get(axis, 0)
        if not (kind == "none" and width > 0):
            continue

        pad_shape = list(x.shape)
        pad_shape[axis] = width
        pad_chunks = list(x.chunks)
        pad_chunks[axis] = (width,)
        pad = empty_like(
            getattr(x, "_meta", x),
            shape=pad_shape,
            chunks=pad_chunks,
            dtype=x.dtype,
        )

        # Grow the first and last chunk along this axis by the pad width.
        target = list(x.chunks)
        axis_sizes = list(target[axis])
        axis_sizes[0] += width
        axis_sizes[-1] += width
        target[axis] = tuple(axis_sizes)

        x = concatenate([pad, x, pad], axis=axis).rechunk(target)
    return x
|
|
689
|
+
|
|
690
|
+
|
|
691
|
+
def _map_overlap_direct(func, args, depth, boundary, trim, allow_rechunk, kwargs):
    """Direct implementation of map_overlap without MapOverlap.

    Used for cases with new_axis/drop_axis (or explicit ``chunks``) where
    MapOverlap doesn't apply. Each input is overlapped eagerly, ``func`` is
    applied with ``map_blocks`` and, when ``trim`` is true, the overlap is
    trimmed from the output. Trimming uses the settings of the highest-rank
    input, realigned to the output axes.

    Parameters
    ----------
    func : callable
        Function applied to every overlapped block.
    args : sequence of Array
        Input arrays.
    depth, boundary : list of dict
        One ``{axis: value}`` dict per entry of ``args``, as produced by
        ``coerce_depth`` / ``coerce_boundary`` in ``map_overlap``.
    trim : bool
        Whether to trim the overlap from the result.
    allow_rechunk : bool
        Forwarded to ``overlap``.
    kwargs : dict
        Keyword arguments forwarded to ``map_blocks``. ``drop_axis`` and
        ``new_axis`` are also read here to realign the trim settings.

    Returns
    -------
    Array
    """
    # Apply overlap to each input array
    overlapped = [
        overlap(x, depth=d, boundary=b, allow_rechunk=allow_rechunk)
        for x, d, b in zip(args, depth, boundary)
    ]

    # Apply the function via map_blocks
    result = map_blocks(func, *overlapped, **kwargs)

    if not trim:
        return result

    # Trim with the settings of the highest-rank array; ties go to the first.
    i = sorted(enumerate(overlapped), key=lambda v: (v[1].ndim, -v[0]))[-1][0]
    trim_depth = dict(depth[i])
    trim_boundary = dict(boundary[i])

    # Dropped axes disappear from the output: keep only the surviving axes
    # and relabel them 0..k-1.
    drop_axis = kwargs.get("drop_axis")
    if drop_axis is not None:
        if isinstance(drop_axis, Number):
            drop_axis = [drop_axis]
        ndim_out = max(a.ndim for a in overlapped)
        drop_axis = [d % ndim_out for d in drop_axis]
        kept_axes = tuple(ax for ax in range(overlapped[i].ndim) if ax not in drop_axis)
        trim_depth = {n: trim_depth[ax] for n, ax in enumerate(kept_axes)}
        trim_boundary = {n: trim_boundary[ax] for n, ax in enumerate(kept_axes)}

    # New axes carry no overlap: insert a zero-depth, 'none'-boundary entry
    # at each new position and shift the later axes up by one.
    new_axis = kwargs.get("new_axis")
    if new_axis is not None:
        if isinstance(new_axis, Number):
            new_axis = [new_axis]
        ndim_in = max(a.ndim for a in overlapped)
        # Only negative positions are wrapped. Non-negative positions are
        # output-axis indices and may legitimately be >= the input ndim
        # (e.g. appending a trailing axis); wrapping them would move the new
        # axis to the front.
        new_axis = [ax % ndim_in if ax < 0 else ax for ax in new_axis]

        # Insert in ascending order so every position refers to the final
        # output layout. Rebuilding the dicts (instead of shifting keys in
        # place) guarantees a shifted axis never overwrites its neighbour.
        for axis in sorted(new_axis):
            trim_depth = {(k + 1 if k >= axis else k): v for k, v in trim_depth.items()}
            trim_boundary = {(k + 1 if k >= axis else k): v for k, v in trim_boundary.items()}
            trim_depth[axis] = 0
            trim_boundary[axis] = "none"

    return trim_internal(result, trim_depth, trim_boundary)
|
|
740
|
+
|
|
741
|
+
|
|
742
|
+
def map_overlap(
    func,
    *args,
    depth=None,
    boundary=None,
    trim=True,
    align_arrays=True,
    allow_rechunk=True,
    **kwargs,
):
    """Map a function over blocks of arrays with some overlap

    We share neighboring zones between blocks of the array, map a
    function, and then trim away the neighboring strips. If depth is
    larger than any chunk along a particular axis, then the array is
    rechunked (or, with ``allow_rechunk=False``, an error is raised).

    Note that this function will attempt to automatically determine the output
    array type before computing it, please refer to the ``meta`` keyword argument
    in ``map_blocks`` if you expect that the function will not succeed when
    operating on 0-d arrays.

    Parameters
    ----------
    func : function
        The function to apply to each extended block.
        If multiple arrays are provided, then the function should expect to
        receive chunks of each array in the same order.
    args : dask arrays
    depth : int, tuple, dict or list, keyword only
        The number of elements that each block should share with its neighbors
        If a tuple or dict then this can be different per axis.
        If a list then each element of that list must be an int, tuple or dict
        defining depth for the corresponding array in `args`.
        Asymmetric depths may be specified using a dict value of (-/+) tuples.
        Note that asymmetric depths are currently only supported when
        ``boundary`` is 'none'.
        The default value is 0.
    boundary : str, tuple, dict or list, keyword only
        How to handle the boundaries.
        Values include 'reflect', 'periodic', 'nearest', 'none',
        or any constant value like 0 or np.nan.
        If a list then each element must be a str, tuple or dict defining the
        boundary for the corresponding array in `args`.
        The default value is 'none'.
    trim : bool, keyword only
        Whether or not to trim ``depth`` elements from each block after
        calling the map function.
        Set this to False if your mapping function already does this for you
    align_arrays : bool, keyword only
        Whether or not to align chunks along equally sized dimensions when
        multiple arrays are provided. This allows for larger chunks in some
        arrays to be broken into smaller ones that match chunk sizes in other
        arrays such that they are compatible for block function mapping. If
        this is false, then an error will be thrown if arrays do not already
        have the same number of blocks in each dimension.
    allow_rechunk : bool, keyword only
        Whether arrays may be rechunked when a chunk is smaller than the
        overlap depth along its axis. If False, a ``ValueError`` is raised
        instead whenever any chunk is smaller than the requested depth.
    **kwargs:
        Other keyword arguments valid in ``map_blocks``

    Examples
    --------
    >>> import numpy as np
    >>> import dask_array as da

    >>> x = np.array([1, 1, 2, 3, 3, 3, 2, 1, 1])
    >>> x = da.from_array(x, chunks=5)
    >>> def derivative(x):
    ...     return x - np.roll(x, 1)

    >>> y = x.map_overlap(derivative, depth=1, boundary=0)
    >>> y.compute()
    array([ 1,  0,  1,  1,  0,  0, -1, -1,  0])

    >>> x = np.arange(16).reshape((4, 4))
    >>> d = da.from_array(x, chunks=(2, 2))
    >>> d.map_overlap(lambda x: x + x.size, depth=1, boundary='reflect').compute()
    array([[16, 17, 18, 19],
           [20, 21, 22, 23],
           [24, 25, 26, 27],
           [28, 29, 30, 31]])

    >>> func = lambda x: x + x.size
    >>> depth = {0: 1, 1: 1}
    >>> boundary = {0: 'reflect', 1: 'none'}
    >>> d.map_overlap(func, depth, boundary).compute()  # doctest: +NORMALIZE_WHITESPACE
    array([[12,  13,  14,  15],
           [16,  17,  18,  19],
           [20,  21,  22,  23],
           [24,  25,  26,  27]])

    The ``da.map_overlap`` function can also accept multiple arrays.

    >>> func = lambda x, y: x + y
    >>> x = da.arange(8).reshape(2, 4).rechunk((1, 2))
    >>> y = da.arange(4).rechunk(2)
    >>> da.map_overlap(func, x, y, depth=1, boundary='reflect').compute() # doctest: +NORMALIZE_WHITESPACE
    array([[ 0,  2,  4,  6],
           [ 4,  6,  8,  10]])

    When multiple arrays are given, they do not need to have the
    same number of dimensions but they must broadcast together.
    Arrays are aligned block by block (just as in ``da.map_blocks``)
    so the blocks must have a common chunk size. This common chunking
    is determined automatically as long as ``align_arrays`` is True.

    >>> x = da.arange(8, chunks=4)
    >>> y = da.arange(8, chunks=2)
    >>> r = da.map_overlap(func, x, y, depth=1, boundary='reflect', align_arrays=True)
    >>> len(r.to_delayed())
    4

    >>> da.map_overlap(func, x, y, depth=1, boundary='reflect', align_arrays=False).compute()
    Traceback (most recent call last):
        ...
    ValueError: Shapes do not align {'.0': {2, 4}}

    Note also that this function is equivalent to ``map_blocks``
    by default. A non-zero ``depth`` must be defined for any
    overlap to appear in the arrays provided to ``func``.

    >>> func = lambda x: x.sum()
    >>> x = da.ones(10, dtype='int')
    >>> block_args = dict(chunks=(), drop_axis=0)
    >>> da.map_blocks(func, x, **block_args).compute()
    np.int64(10)
    >>> da.map_overlap(func, x, **block_args, boundary='reflect').compute()
    np.int64(10)
    >>> da.map_overlap(func, x, **block_args, depth=1, boundary='reflect').compute()
    np.int64(12)

    For functions that may not handle 0-d arrays, it's also possible to specify
    ``meta`` with an empty array matching the type of the expected result. In
    the example below, ``func`` will result in an ``IndexError`` when computing
    ``meta``:

    >>> x = np.arange(16).reshape((4, 4))
    >>> d = da.from_array(x, chunks=(2, 2))
    >>> y = d.map_overlap(lambda x: x + x[2], depth=1, boundary='reflect', meta=np.array(()))
    >>> y
    dask.array<_trim, shape=(4, 4), dtype=float64, chunksize=(2, 2), chunktype=numpy.ndarray>
    >>> y.compute()
    array([[ 4,  6,  8, 10],
           [ 8, 10, 12, 14],
           [20, 22, 24, 26],
           [24, 26, 28, 30]])

    Similarly, it's possible to specify a non-NumPy array to ``meta``:

    >>> import cupy  # doctest: +SKIP
    >>> x = cupy.arange(16).reshape((4, 4))  # doctest: +SKIP
    >>> d = da.from_array(x, chunks=(2, 2))  # doctest: +SKIP
    >>> y = d.map_overlap(lambda x: x + x[2], depth=1, boundary='reflect', meta=cupy.array(()))  # doctest: +SKIP
    >>> y  # doctest: +SKIP
    dask.array<_trim, shape=(4, 4), dtype=float64, chunksize=(2, 2), chunktype=cupy.ndarray>
    >>> y.compute()  # doctest: +SKIP
    array([[ 4,  6,  8, 10],
           [ 8, 10, 12, 14],
           [20, 22, 24, 26],
           [24, 26, 28, 30]])
    """
    # Look for invocation using deprecated single-array signature
    # map_overlap(x, func, depth, boundary=None, trim=True, **kwargs)
    # ``args`` must be checked for emptiness first: map_overlap(array) with no
    # function would otherwise raise a bare IndexError here instead of the
    # descriptive TypeError below.
    if isinstance(func, Array) and args and callable(args[0]):
        warnings.warn(
            "The use of map_overlap(array, func, **kwargs) is deprecated since dask 2.17.0 "
            "and will be an error in a future release. To silence this warning, use the syntax "
            "map_overlap(func, array0,[ array1, ...,] **kwargs) instead.",
            FutureWarning,
        )
        sig = ["func", "depth", "boundary", "trim"]
        depth = get(sig.index("depth"), args, depth)
        boundary = get(sig.index("boundary"), args, boundary)
        trim = get(sig.index("trim"), args, trim)
        func, args = args[0], [func]

    if not callable(func):
        raise TypeError(
            f"First argument must be callable function, not {type(func).__name__}\n"
            "Usage: da.map_overlap(function, x)\n"
            "   or: da.map_overlap(function, x, y, z)"
        )
    if not all(isinstance(x, Array) for x in args):
        raise TypeError(
            f"All variadic arguments must be arrays, not {[type(x).__name__ for x in args]}\n"
            "Usage: da.map_overlap(function, x)\n"
            "   or: da.map_overlap(function, x, y, z)"
        )

    # Coerce depth and boundary arguments to lists of individual
    # specifications for each array argument
    def coerce(xs, arg, fn):
        if not isinstance(arg, list):
            arg = [arg] * len(xs)
        return [fn(x.ndim, a) for x, a in zip(xs, arg)]

    depth = coerce(args, depth, coerce_depth)
    boundary = coerce(args, boundary, coerce_boundary)

    # Align chunks in each array to a common size
    if align_arrays:
        # Reverse unification order to allow block broadcasting
        inds = [list(reversed(range(x.ndim))) for x in args]
        args = [a.expr for a in args]
        _, args, _ = unify_chunks_expr(*list(concat(zip(args, inds))))
        args = [new_collection(a) for a in args]

    # Escape to map_blocks if depth is zero (a more efficient computation)
    if all(all(depth_val == 0 for depth_val in d.values()) for d in depth):
        return map_blocks(func, *args, **kwargs)

    for i, x in enumerate(args):
        for j in range(x.ndim):
            if isinstance(depth[i][j], tuple) and boundary[i][j] != "none":
                raise NotImplementedError(
                    "Asymmetric overlap is currently only implemented "
                    "for boundary='none', however boundary for dimension "
                    f"{j} in array argument {i} is {boundary[i][j]}"
                )

    def assert_int_chunksize(xs):
        assert all(type(c) is int for x in xs for cc in x.chunks for c in cc)

    assert_int_chunksize(args)

    # Validate chunk sizes if rechunking is not allowed
    if not allow_rechunk:
        for x, d in zip(args, depth):
            depths = [max(dd) if isinstance(dd, tuple) else dd for dd in d.values()]
            original_chunks_too_small = any(min(c) < dd for dd, c in zip(depths, x.chunks))
            if original_chunks_too_small:
                raise ValueError(
                    "Overlap depth is larger than smallest chunksize.\n"
                    "Please set allow_rechunk=True to rechunk automatically.\n"
                    f"Overlap depths required: {depths}\n"
                    f"Input chunks: {x.chunks}\n"
                )

    # Fall back to direct implementation for complex cases:
    # - new_axis/drop_axis: change dimensionality
    # - explicit chunks: change output shape/chunks
    if kwargs.get("new_axis") is not None or kwargs.get("drop_axis") is not None or kwargs.get("chunks") is not None:
        return _map_overlap_direct(func, args, depth, boundary, trim, allow_rechunk, kwargs)

    # Create the logical MapOverlap
    # It will be lowered to the full pipeline during optimization
    return new_collection(
        MapOverlap(
            arrays=tuple(a.expr for a in args),
            func=func,
            depth=depth,
            boundary=boundary,
            trim_output=trim,
            allow_rechunk=allow_rechunk,
            kwargs=kwargs if kwargs else None,
        )
    )
|
|
1001
|
+
|
|
1002
|
+
|
|
1003
|
+
def coerce_depth(ndim, depth):
    """Normalise a depth specification to an ``{axis: depth}`` dict.

    ``depth`` may be ``None`` (0 on every axis), an integer applied to every
    axis, a tuple with one entry per axis, or a dict keyed by axis. For tuple
    and dict input, axes that are not given default to 0 and only axes in
    ``range(ndim)`` are kept. The values are then cast by
    ``coerce_depth_type``.
    """
    fallback = 0
    if depth is None:
        depth = fallback

    if isinstance(depth, Integral):
        per_axis = dict.fromkeys(range(ndim), depth)
    elif isinstance(depth, tuple):
        per_axis = dict(enumerate(depth))
    elif isinstance(depth, dict):
        per_axis = depth
    else:
        # Unrecognised specifications are handed over unchanged.
        return coerce_depth_type(ndim, depth)

    normalised = {ax: per_axis.get(ax, fallback) for ax in range(ndim)}
    return coerce_depth_type(ndim, normalised)


def coerce_depth_type(ndim, depth):
    """Cast the depths of axes ``0..ndim-1`` to ``int``, in place.

    Tuple entries (asymmetric depths) are cast element by element. The
    mapping passed in is modified and also returned.
    """
    for ax in range(ndim):
        value = depth[ax]
        depth[ax] = tuple(map(int, value)) if isinstance(value, tuple) else int(value)
    return depth
|
|
1023
|
+
|
|
1024
|
+
|
|
1025
|
+
def coerce_boundary(ndim, boundary):
    """Normalise a boundary specification to an ``{axis: boundary}`` dict.

    ``None`` means ``'none'`` on every axis. A tuple gives one entry per axis
    and a dict is keyed by axis; any other value is applied to every axis.
    Axes that are not given default to ``'none'``, and only axes in
    ``range(ndim)`` appear in the result, which is always a new dict.
    """
    fallback = "none"
    if boundary is None:
        boundary = fallback

    if isinstance(boundary, dict):
        per_axis = boundary
    elif isinstance(boundary, tuple):
        per_axis = dict(enumerate(boundary))
    else:
        per_axis = dict.fromkeys(range(ndim), boundary)

    return {ax: per_axis.get(ax, fallback) for ax in range(ndim)}
|
|
1036
|
+
|
|
1037
|
+
|
|
1038
|
+
@derived_from(np.lib.stride_tricks)
def sliding_window_view(x, window_shape, axis=None, automatic_rechunk=True):
    # The public docstring is derived from numpy's sliding_window_view.
    # ``automatic_rechunk`` is dask-specific: when True, axes without a window
    # may also be rechunked (via ``_calculate_new_chunksizes``) to keep block
    # sizes bounded; when False, only the minimum-chunk-size fix-up is applied.
    window_shape = tuple(window_shape) if np.iterable(window_shape) else (window_shape,)

    window_shape_array = np.array(window_shape)
    if np.any(window_shape_array <= 0):
        raise ValueError("`window_shape` must contain values > 0")

    if axis is None:
        axis = tuple(range(x.ndim))
        if len(window_shape) != len(axis):
            raise ValueError(
                "Since axis is `None`, must provide "
                "window_shape for all dimensions of `x`; "
                f"got {len(window_shape)} window_shape elements "
                f"and `x.ndim` is {x.ndim}."
            )
    else:
        axis = normalize_axis_tuple(axis, x.ndim, allow_duplicate=True)
        if len(window_shape) != len(axis):
            raise ValueError(
                "Must provide matching length window_shape and "
                f"axis; got {len(window_shape)} window_shape "
                f"elements and {len(axis)} axes elements."
            )

    # A window of size w needs w - 1 trailing elements from the next block;
    # repeated axes accumulate their requirements.
    depths = [0] * x.ndim
    for ax, window in zip(axis, window_shape):
        depths[ax] += window - 1

    # Ensure that each chunk is big enough to leave at least a size-1 chunk
    # after windowing (this is only really necessary for the last chunk).
    safe_chunks = [ensure_minimum_chunksize(d + 1, c) for d, c in zip(depths, x.chunks)]
    if automatic_rechunk:
        safe_chunks = [s if d != 0 else c for d, c, s in zip(depths, x.chunks, safe_chunks)]
        # safe chunks is our output chunks, so add the new dimensions
        safe_chunks.extend([(w,) for w in window_shape])
        max_chunk = reduce(mul, map(max, x.chunks))
        new_chunks = _calculate_new_chunksizes(
            x.chunks,
            safe_chunks.copy(),
            {i for i, d in enumerate(depths) if d == 0},
            max_chunk,
        )
        # ``new_chunks`` also describes the trailing window dimensions that
        # ``x`` does not have yet; rechunk only the axes ``x`` actually has
        # instead of relying on ``rechunk`` to ignore the surplus entries.
        x = x.rechunk(tuple(new_chunks[: x.ndim]))
    else:
        x = x.rechunk(tuple(safe_chunks))

    # result.shape = x_shape_trimmed + window_shape,
    # where x_shape_trimmed is x.shape with every entry
    # reduced by one less than the corresponding window size.
    # trim chunks to match x_shape_trimmed
    newchunks = tuple(c[:-1] + (c[-1] - d,) for d, c in zip(depths, x.chunks)) + tuple(
        (window,) for window in window_shape
    )

    return map_overlap(
        np.lib.stride_tricks.sliding_window_view,
        x,
        depth=tuple((0, d) for d in depths),  # Overlap on +ve side only
        boundary="none",
        meta=x._meta,
        new_axis=range(x.ndim, x.ndim + len(axis)),
        chunks=newchunks,
        trim=False,
        align_arrays=False,
        window_shape=window_shape,
        axis=axis,
    )
|
|
1107
|
+
|
|
1108
|
+
|
|
1109
|
+
def _fill_with_last_one(a, b):
|
|
1110
|
+
"""Fill NaN values in b with values from a."""
|
|
1111
|
+
return np.where(~np.isnan(b), b, a)
|
|
1112
|
+
|
|
1113
|
+
|
|
1114
|
+
def _push(array, n=None, axis=-1):
|
|
1115
|
+
"""Apply bottleneck.push to a single chunk."""
|
|
1116
|
+
import bottleneck as bn
|
|
1117
|
+
|
|
1118
|
+
limit = n if n is not None else array.shape[axis]
|
|
1119
|
+
return bn.push(array, limit, axis)
|
|
1120
|
+
|
|
1121
|
+
|
|
1122
|
+
def push(array, n, axis):
    """
    Dask-version of bottleneck.push

    Forward fill NaN values along an axis.

    Parameters
    ----------
    array : dask array
        Input array.
    n : int or None
        Maximum distance, in positions along ``axis``, that a value is carried
        forward. ``None`` fills without limit and ``0`` leaves every NaN in
        place, matching ``bottleneck.push``.
    axis : int
        Axis along which to fill; negative values count from the end.

    .. note::

        Requires bottleneck to be installed.
    """
    import dask_array as da
    from dask._compatibility import import_optional_dependency

    import_optional_dependency("bottleneck", min_version="1.3.7")

    if axis is not None and axis < 0:
        # The reshape below selects the filled dimension with ``i == axis``
        # on non-negative ``enumerate`` indices, so normalise negative axes.
        axis += array.ndim

    # For n >= shape - 1 no fill distance can exceed n, so the unlimited path
    # below is equivalent. n == 0 must take the limited path so that nothing
    # is filled (as in bottleneck), instead of falling through to an
    # unlimited fill.
    if n is not None and 0 <= n < array.shape[axis] - 1:
        arr = da.broadcast_to(
            da.arange(array.shape[axis], chunks=array.chunks[axis], dtype=array.dtype).reshape(
                tuple(size if i == axis else 1 for i, size in enumerate(array.shape))
            ),
            array.shape,
            array.chunks,
        )
        # Index of the most recent non-NaN value at every position.
        valid_arange = da.where(da.notnull(array), arr, np.nan)
        valid_limits = (arr - push(valid_arange, None, axis)) <= n
        # omit the forward fill that violate the limit
        return da.where(valid_limits, push(array, None, axis), np.nan)

    from dask_array.reductions import cumreduction

    return cumreduction(
        func=_push,
        binop=_fill_with_last_one,
        ident=np.nan,
        x=array,
        axis=axis,
        dtype=array.dtype,
    )
|