Simple-Track 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,589 @@
1
+ import itertools
2
+ import warnings
3
+ from collections.abc import Iterable
4
+ from typing import Union
5
+
6
+ import numpy as np
7
+ import scipy.ndimage as ndimage
8
+ from numpy.typing import NDArray
9
+ from scipy.interpolate import LinearNDInterpolator, RectBivariateSpline
10
+ from scipy.signal.windows import tukey
11
+ from skimage.registration import phase_cross_correlation
12
+
13
+ from simpletrack.exceptions import ArrayError
14
+ from simpletrack.frame import Frame
15
+ from simpletrack.utils import check_arrays
16
+
17
+
18
class FlowSolver:
    """
    Class containing functionality for deriving a flow field that maps features from
    an older Frame to a more recent Frame. This is done by subdividing the input
    fields into half-overlapping subdomains, deriving the flow in each subdomain
    using phase cross-correlation, and then stitching together the subdomain flows
    using interpolation.
    """

    def __init__(
        self,
        subdomain_size: Union[int, NDArray, None] = None,
        min_fractional_coverage: float = 0.01,
        subdomain_tolerance: int = 3,
        overlap_threshold: float = 0.6,
        apply_tukey_filtering: bool = True,
    ) -> None:
        """
        Class containing functionality for deriving flow between two input frames.

        Args:
            subdomain_size (int or array-like of two ints, optional):
                Size in pixels of the subdomains used to derive (dy, dx)
                displacement. An int gives square subdomains; a (y, x) pair gives
                rectangular ones. Each dimension must be even, and half of it must
                divide the corresponding (y, x) length of the analysed fields.
                Defaults to None, in which case a size is estimated from the field
                shape on each call to analyse_flow (domain size / 5).
            min_fractional_coverage (float, optional):
                Minimum fractional cover of objects required for fft to obtain
                (dy, dx) displacement.
                Defaults to 0.01.
            subdomain_tolerance (int, optional):
                Maximum difference in displacement values between adjacent squares
                (to remove spurious values).
                Defaults to 3.
            overlap_threshold (float, optional):
                Minimum fraction of overlap between features for use in flow_solver.
                Passed to phase_cross_correlation as overlap_ratio.
                Defaults to 0.6.
            apply_tukey_filtering (bool, optional):
                Apply a 2D Tukey window to each subdomain before phase
                cross-correlation.
                Defaults to True.

        Raises:
            TypeError: If subdomain_size is a float.
        """
        if isinstance(subdomain_size, (int, np.integer)):
            self.subdomain_shape = np.array([subdomain_size, subdomain_size], dtype=int)
        elif isinstance(subdomain_size, (float, np.floating)):
            raise TypeError("Expected int or array-like, got float")
        elif subdomain_size is None:
            self.subdomain_shape = None
        else:
            # An explicit shape is a (y, x) pair, i.e. a 1D array of length 2 -- the
            # same convention every other shape check in this class uses.
            self.subdomain_shape = check_arrays(subdomain_size, shape=(2,)).astype(int)
        self.min_fractional_coverage = min_fractional_coverage
        self.subdomain_tolerance = subdomain_tolerance
        self.overlap_threshold = overlap_threshold
        self.apply_tukey_filtering = apply_tukey_filtering

    def analyse_flow(
        self, prev_field: Union[Frame, NDArray], current_field: Union[Frame, NDArray]
    ) -> tuple[Union[NDArray, None], Union[NDArray, None]]:
        """
        Analyses previous field and current field to identify flow field. Uses phase
        cross correlation over a series of overlapping subdomains to stitch together
        a full field, where the flow is constant within each subdomain. Subdomain size
        is controlled by FlowSolver init, but is estimated from the inputs (for this
        call only) if not provided.

        Input feature fields must be of the same size and contain sufficient feature
        coverage, as determined by min_fractional_coverage in init. Otherwise, solver
        will return None for the flow fields.

        Args:
            prev_field (Union[Frame, NDArray]):
                Feature field from previous timestep
            current_field (Union[Frame, NDArray]):
                Feature field from current timestep

        Raises:
            TypeError: If the inputs are not both Frames or both numpy arrays.
            ValueError: If the subdomain shape cannot tile the fields exactly.

        Returns:
            tuple[NDArray, NDArray]: y_flow, x_flow (or None, None if skipped)
        """
        if isinstance(prev_field, Frame) and isinstance(current_field, Frame):
            prev_features = prev_field.feature_field
            current_features = current_field.feature_field
        elif isinstance(prev_field, np.ndarray) and isinstance(
            current_field, np.ndarray
        ):
            prev_features = prev_field
            current_features = current_field
        else:
            raise TypeError(
                "prev_field and current_field must both be of type Frame or NDArray"
            )

        # Check input fields are same shape
        prev_features, current_features = check_arrays(
            prev_features, current_features, equal_shape=True, ndim=2, dtype=int
        )

        # Use the configured subdomain shape, or estimate one for this call only.
        # Storing the estimate on self would pin every later call to the shape of
        # the first field analysed.
        subdomain_shape = self.subdomain_shape
        if subdomain_shape is None:
            subdomain_shape = self.get_subdomain_shape(prev_features.shape)

        # Check inputs, don't proceed if not validated
        prev_features, current_features = self._check_inputs(
            prev_features, current_features, subdomain_shape=subdomain_shape
        )
        if prev_features is None:
            return None, None

        # Get the subdomain boundary indices. This also validates that the
        # subdomain shape tiles the domain, so do it before allocating anything
        # sized from that shape.
        y_subdomain_bounds, x_subdomain_bounds = self.get_overlapping_subdomain_idxs(
            prev_features.shape, subdomain_shape
        )

        # Initialise containing arrays for holding subdomain dy, dx
        subdomain_dy, subdomain_dx = self.get_subdomain_containment_arrays(
            prev_features.shape, subdomain_shape
        )

        # A new subdomain starts every half subdomain, so a subdomain's start index
        # divided by that step is its position in the containment arrays.
        subdomain_step = subdomain_shape // 2
        subdomain_bounds = self.subdomain_iter(y_subdomain_bounds, x_subdomain_bounds)
        for y_bounds, x_bounds in subdomain_bounds:
            subdomain_mask = (
                slice(y_bounds[0], y_bounds[1]),
                slice(x_bounds[0], x_bounds[1]),
            )

            dy, dx = self.derive_subdomain_flow(
                field1=prev_features[subdomain_mask],
                field2=current_features[subdomain_mask],
                tukey_filtering=self.apply_tukey_filtering,
            )

            dy_idx = int(y_bounds[0] // subdomain_step[0])
            dx_idx = int(x_bounds[0] // subdomain_step[1])
            subdomain_dy[dy_idx, dx_idx] = dy
            subdomain_dx[dy_idx, dx_idx] = dx

        # Check neighbouring subdomain values vary within acceptable tolerance
        subdomain_dy = self.check_subdomain_variability(subdomain_dy)
        subdomain_dx = self.check_subdomain_variability(subdomain_dx)

        # Finally, interpolate values between subdomains. The interior boundaries
        # (all but the first and last) are the subdomain centres, so they serve as
        # the interpolation grid for the per-subdomain values.
        interior_y_subdom_bounds = y_subdomain_bounds[1:-1]
        interior_x_subdom_bounds = x_subdomain_bounds[1:-1]
        y_flow = self.interpolate_subdomain_flows(
            interior_y_subdom_bounds,
            interior_x_subdom_bounds,
            subdomain_dy,
            prev_features.shape,
        )
        x_flow = self.interpolate_subdomain_flows(
            interior_y_subdom_bounds,
            interior_x_subdom_bounds,
            subdomain_dx,
            prev_features.shape,
        )
        return y_flow, x_flow

    def get_subdomain_containment_arrays(
        self, full_domain_shape: NDArray, subdomain_shape: NDArray
    ) -> tuple[NDArray, NDArray]:
        """
        Return NaN-filled arrays with the correct shape for holding subdomain flow
        values. Subdomains overlap by half, so one starts every subdomain_shape / 2
        pixels and a dimension of length N holds 2 * N / subdomain_shape - 1 of them.
        E.g., for a 100x200 domain and a 20x20 subdomain size, the containing arrays
        are 9x19.

        Args:
            full_domain_shape (NDArray): Shape of the full domain
            subdomain_shape (NDArray): Shape of requested subdomain

        Returns:
            tuple[NDArray, NDArray]: containing arrays for subdomain dy and dx
        """
        full_domain_shape, subdomain_shape = check_arrays(
            full_domain_shape, subdomain_shape, shape=(2,), dtype=int, non_negative=True
        )

        # Number of half-subdomain steps, minus the final step that cannot start a
        # complete subdomain.
        containing_shape = tuple(
            int(dim) - 1 for dim in (full_domain_shape * 2) // subdomain_shape
        )
        subdomain_dy = np.full(shape=containing_shape, fill_value=np.nan)
        subdomain_dx = np.full(shape=containing_shape, fill_value=np.nan)
        return subdomain_dy, subdomain_dx

    def subdomain_iter(
        self, y_subdomain_bounds: tuple, x_subdomain_bounds: tuple
    ) -> Iterable:
        """
        Produces iterable of subdomain bounds with stride of 2 indices
        between inputs. Each returned set is a pair of tuples defining
        start and end bounds in y and x direction.

        Args:
            y_subdomain_bounds (tuple): 1D boundary indices along y
            x_subdomain_bounds (tuple): 1D boundary indices along x

        Returns:
            Iterable: ((y_start, y_stop), (x_start, x_stop))
        """
        y_subdomain_bounds, x_subdomain_bounds = check_arrays(
            y_subdomain_bounds, x_subdomain_bounds, dtype=int, ndim=1, non_negative=True
        )

        # Combine these idxs pairwise with stride 2 to define the subdomain bounds
        # E.g., for subdomain size of 20 in y with bounds [0, 10, 20, 30...]
        # and for subdomain size of 30 in x with bounds [0, 15, 30, 45...]
        # this operation then produces ((0, 20), (10, 30), ...) for y
        # and ((0, 30), (15, 45)...) for x
        y_subdomain_bounds_tuple = pairwise_with_stride(y_subdomain_bounds, 2)
        x_subdomain_bounds_tuple = pairwise_with_stride(x_subdomain_bounds, 2)

        # Finally, take the cartesian product of all y and x subdomain bounds
        # E.g., for example above, produces ( ((0, 20), (0, 30)), ((0, 20), (15, 45))..)
        return itertools.product(y_subdomain_bounds_tuple, x_subdomain_bounds_tuple)

    def check_subdomain_variability(self, subdomain_vals: NDArray) -> NDArray:
        """
        Check variability in neighbouring subdomains to ensure no large
        discrepancies. If the NaN-ignoring mean of the 8 neighbouring values
        departs from the local value by more than self.subdomain_tolerance, the
        local value is set to np.nan. The input array is not modified.

        Args:
            subdomain_vals (NDArray):
                Flow values derived in subdomains

        Returns:
            NDArray: Copy of input values with outliers replaced by nan
        """
        # Check input array is 2D, then copy so the caller's array is not mutated
        subdomain_vals = np.array(
            check_arrays(subdomain_vals, ndim=2, dtype=float), dtype=float
        )

        # Setup footprint for performing filter check on neighbouring points
        # footprint excludes current index
        footprint = np.ones((3, 3))
        footprint[1, 1] = 0

        # nanmean of an all-NaN neighbourhood warns; those points are left as-is
        # because any comparison with NaN below is False.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", category=RuntimeWarning)
            nbhood_mean = ndimage.generic_filter(
                input=subdomain_vals,
                function=np.nanmean,  # Apply nanmean to each nbhood
                footprint=footprint,  # Determines how to sample points in the nbhood
                mode="constant",
                # Fill out-of-bounds values with nan so they don't contribute
                cval=np.nan,
            )

        # Check for any values where the nanmean exceeds threshold set in init
        invalid_tolerance = (
            np.abs(nbhood_mean - subdomain_vals) > self.subdomain_tolerance
        )
        subdomain_vals[invalid_tolerance] = np.nan
        return subdomain_vals

    def get_subdomain_shape(self, feature_field_shape) -> NDArray:
        """
        Estimate a subdomain shape for a field when none was configured, currently
        one fifth of the field in each dimension.

        Args:
            feature_field_shape: (y, x) shape of the feature field

        Raises:
            ValueError: If the estimated shape does not tile the field exactly.

        Returns:
            NDArray: Estimated (y, x) subdomain shape
        """
        # TODO: this is entirely arbitrary and won't work in all cases (e.g. when
        # a fifth of a dimension is odd). Consider searching for a shape that fits.
        sd_shape = np.array(feature_field_shape) // 5
        if not self.check_subdomain_size_fits_in_full_domain(
            feature_field_shape, sd_shape
        ):
            raise ValueError(
                f"Subdomain shape ({sd_shape}) cannot fit ({feature_field_shape})"
            )
        return sd_shape

    def check_subdomain_size_fits_in_full_domain(
        self, feature_field_shape: NDArray, subdomain_shape: NDArray
    ) -> bool:
        """
        Determines whether the subdomain shape is suitable for the feature field
        shape. Subdomains overlap by half, so a shape is suitable when each
        dimension is positive, even, no larger than the field, and half of it
        divides the field dimension exactly.

        Args:
            feature_field_shape (NDArray):
                1D array describing shape of feature field
            subdomain_shape (NDArray):
                1D array describing shape of requested subdomain

        Returns:
            bool: True if the half-overlapping subdomains tile the feature field
            exactly, False otherwise (including for malformed inputs)
        """
        # Check inputs, only except errors related to contents of inputs
        try:
            feature_field_shape, subdomain_shape = check_arrays(
                feature_field_shape,
                subdomain_shape,
                dtype=int,
                shape=(2,),
                non_negative=True,
            )
        except ArrayError:
            return False

        # Subdomains must be non-empty and fit inside the field at least once
        if np.any(subdomain_shape <= 0) or np.any(subdomain_shape > feature_field_shape):
            return False

        # The overlap step is subdomain_shape / 2, so it must be an integer...
        if not np.all(subdomain_shape % 2 == 0):
            return False

        # ...and that step must divide the field exactly. (Previously written as
        # `dim % sd_shape / 2`, which parses as `(dim % sd_shape) / 2`.)
        return bool(np.all(feature_field_shape % (subdomain_shape // 2) == 0))

    def get_overlapping_subdomain_idxs(
        self, feature_field_shape: NDArray, subdomain_shape: NDArray
    ) -> tuple[NDArray, NDArray]:
        """
        Get indices of subdomain boundaries of the requested shape that will fit
        into the requested feature field. These subdomains overlap halfway, meaning
        that the requirement for an exact fit is that HALF the subdomain shape must
        fit. Boundaries run from 0 to the field length inclusive in steps of half
        the subdomain shape; consecutive subdomains are formed from boundaries two
        steps apart (see subdomain_iter).

        Args:
            feature_field_shape (NDArray):
                Shape of the feature field to subdivide
            subdomain_shape (NDArray):
                Requested shape of the subdomain

        Raises:
            ValueError: If requested subdomain cannot fit exactly into input field

        Returns:
            tuple[NDArray, NDArray]: 1D boundary indices along y and along x
        """
        feature_field_shape, subdomain_shape = check_arrays(
            feature_field_shape,
            subdomain_shape,
            dtype=int,
            shape=(2,),
            non_negative=True,
        )

        if not self.check_subdomain_size_fits_in_full_domain(
            feature_field_shape, subdomain_shape
        ):
            raise ValueError(
                "Could not fit exact number of subdomains in feauture_field "
                f"(input feature field dim size: {feature_field_shape}, "
                f"requested subdomain shape: {subdomain_shape})"
            )

        # Now, get idxs of subdomain bounds to iterate over (with overlap); the
        # stop is extended by one step so the field length itself is included.
        step = subdomain_shape // 2
        y_subdomain_idxs = np.arange(
            start=0, stop=feature_field_shape[0] + step[0], step=step[0]
        )
        x_subdomain_idxs = np.arange(
            start=0, stop=feature_field_shape[1] + step[1], step=step[1]
        )
        return y_subdomain_idxs, x_subdomain_idxs

    def derive_subdomain_flow(
        self, field1: NDArray, field2: NDArray, tukey_filtering: bool = True
    ) -> tuple[float, float]:
        """
        Uses FFT to identify most likely dy, dx motion vectors that translate field1
        to field2. This is largely handled by
        skimage.registration.phase_cross_correlation but with additional
        pre-processing to avoid spurious correlations. E.g., if the tukey_filtering
        flag is enabled, applies a filter to the edges of each field that tapers to
        zero, which avoids spectral leakage.

        Args:
            field1 (NDArray):
                Previous timestep binary field
            field2 (NDArray):
                Current timestep binary field
            tukey_filtering (bool, optional):
                Whether to apply tukey filter to input fields to prevent wrap-around
                disparities occurring during FFT transformations.
                Defaults to True.

        Raises:
            TypeError: If tukey_filtering is not a bool.

        Returns:
            tuple[float, float]: (dy, dx) motion vector for subdomain flow
        """
        # Check inputs are equally shaped 2D arrays containing ints
        field1, field2 = check_arrays(
            field1, field2, ndim=2, equal_shape=True, dtype=int, non_negative=True
        )
        if not isinstance(tukey_filtering, bool):
            raise TypeError(
                f"Expected tukey_filtering type bool, got {type(tukey_filtering)}"
            )

        # Filter inputs if flagged
        if tukey_filtering:
            domain_filter = self.get_2d_tukey_window(field1.shape)
            field1 = field1 * domain_filter
            field2 = field2 * domain_filter

        # Subtracting the mean from binary fields before cross correlation centres each
        # field around zero and improves accuracy of correlation peak
        m1 = field1 - np.mean(field1)
        m2 = field2 - np.mean(field2)

        # Since image registration finds the vector that translates the second arg to
        # the first, need to reverse input order
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", category=UserWarning)
            cross_corr = phase_cross_correlation(
                m2,
                m1,
                space="real",
                overlap_ratio=self.overlap_threshold,
                normalization=None,
                upsample_factor=1,
                disambiguate=False,
            )

        # cross_corr[0] is the shift; cross_corr[1] (error) is currently unused
        dy, dx = cross_corr[0]
        return dy, dx

    def get_2d_tukey_window(self, arr_shape: NDArray) -> NDArray:
        """
        Creates a 2D tukey filter for an array of the requested input shape.
        Uses scipy.signal.windows.tukey to create a 1D filter for each dimension,
        then constructs the 2D filter using outer product of each 1D filter.

        A Tukey window is a tapered cosine combination of a rectangular and
        Hann window which provides a good balance between preventing spectral
        leakage and maintaining good frequency resolution.

        Args:
            arr_shape (NDArray):
                Shape with which to create the Tukey filter

        Returns:
            NDArray: 2D filter array of requested shape
        """
        # Checks that the arr_shape input describes dimensions of a 2D array
        arr_shape = check_arrays(arr_shape, shape=(2,), dtype=int, non_negative=True)

        # Create a 1D Tukey filter for each dimension. Alpha is the fraction of the
        # window inside the cosine-tapered region (alpha=0 is rectangular, alpha>=1
        # is a Hann window). alpha = 10 / dim_size tapers roughly 10 pixels in total
        # regardless of size; above 100 pixels the 0.1 floor takes over and the
        # taper grows to 10% of the dimension.
        filters = [
            tukey(dim_size, alpha=max(0.1, 10.0 / dim_size)) for dim_size in arr_shape
        ]

        return np.outer(*filters)

    def interpolate_subdomain_flows(
        self,
        y_subdomain_bounds: tuple,
        x_subdomain_bounds: tuple,
        subdomain_flows: NDArray,
        full_domain_shape: NDArray,
    ) -> NDArray:
        """
        Takes the subdomain flows found from derive_subdomain_flow and stitches them
        together using 2d interpolation via RectBivariateSpline. NaNs in the
        subdomain flows are filled first (see _fill_nans).

        Args:
            y_subdomain_bounds (tuple):
                y coordinates of the subdomain flow values
            x_subdomain_bounds (tuple):
                x coordinates of the subdomain flow values
            subdomain_flows (NDArray):
                subdomain flows, shape (len(y_subdomain_bounds), len(x_subdomain_bounds))
            full_domain_shape (NDArray):
                Shape of full domain to create flow field for

        Returns:
            NDArray: Interpolated (float) flow field for the full domain; all zeros
            if fewer than 4 subdomain flows are nonzero
        """
        # Check inputs
        y_subdomain_bounds, x_subdomain_bounds = check_arrays(
            y_subdomain_bounds, x_subdomain_bounds, ndim=1, non_negative=True
        )
        full_domain_shape = check_arrays(full_domain_shape, shape=(2,), dtype=int)
        subdomain_flows = check_arrays(subdomain_flows, ndim=2)
        subdomain_flows = self._fill_nans(subdomain_flows)

        # If there are fewer than 4 nonzero subdomain flow elements, return a zero
        # flow field (float, matching the interpolated path)
        if np.count_nonzero(subdomain_flows) < 4:
            return np.zeros(full_domain_shape, dtype=float)

        # interp2d is deprecated in newer versions of scipy; RectBivariateSpline is
        # the replacement. It expects data on an (x, y) grid rather than (y, x),
        # hence the transposed input and output.
        # https://scipy.github.io/devdocs/tutorial/interpolate/interp_transition_guide.html
        # A spline of degree k needs more than k knots per axis, so use cubic where
        # possible and drop the degree for axes with fewer subdomains.
        kx = max(1, min(3, len(x_subdomain_bounds) - 1))
        ky = max(1, min(3, len(y_subdomain_bounds) - 1))
        spline = RectBivariateSpline(
            x_subdomain_bounds, y_subdomain_bounds, subdomain_flows.T, kx=kx, ky=ky
        )
        y_range = np.arange(full_domain_shape[0])
        x_range = np.arange(full_domain_shape[1])
        return spline(x_range, y_range).T

    def _check_inputs(
        self,
        arr1: NDArray,
        arr2: NDArray,
        subdomain_shape: Union[NDArray, None] = None,
    ) -> tuple[Union[NDArray, None], Union[NDArray, None]]:
        """
        Decide whether there are enough features to run optical flow.

        Args:
            arr1 (NDArray): Previous feature field
            arr2 (NDArray): Current feature field
            subdomain_shape (NDArray, optional): Subdomain shape used to scale the
                coverage threshold. Defaults to self.subdomain_shape.

        Returns:
            tuple: (arr1, arr2) unchanged if they pass, otherwise (None, None)
        """
        if subdomain_shape is None:
            subdomain_shape = self.subdomain_shape

        n_features1 = np.count_nonzero(arr1)
        n_features2 = np.count_nonzero(arr2)

        # Skip if neither field has any features
        if not n_features1 and not n_features2:
            print("No features detected in both fields. Skipping optical flow.")
            return None, None

        # If there are too few feature pixels, don't proceed with optical flow.
        # TODO: the threshold is relative to one subdomain's area, not the full
        # domain; confirm this is the intended check.
        subdomain_count = np.prod(subdomain_shape)
        min_feature_coverage = subdomain_count * self.min_fractional_coverage
        if n_features1 < min_feature_coverage or n_features2 < min_feature_coverage:
            print(f"Threshold for running optical flow: {self.min_fractional_coverage}")
            print(f"Number of pixels above threshold in arr1: {n_features1}")
            print(f"Number of pixels above threshold in arr2: {n_features2}")
            print("Number of features in arr1 and/or arr2 less than threshold. ")
            print("Skipping optical flow")
            return None, None

        return arr1, arr2

    def _fill_nans(self, arr: NDArray) -> NDArray:
        """
        Replace NaNs in the input field with values linearly interpolated from
        neighbouring grid points, or 0 where this is not possible (outside the
        convex hull of valid points, no valid points at all, or valid points that
        cannot be triangulated, e.g. all in one row).

        Args:
            arr (NDArray): Input array, potentially containing NaNs
        Returns:
            NDArray: Output array with NaNs filled (a new array)
        """
        arr = np.asarray(arr, dtype=float)
        valid_mask = ~np.isnan(arr)

        if valid_mask.all():
            return arr.copy()
        if not valid_mask.any():
            return np.zeros(arr.shape)

        coords = np.nonzero(valid_mask)
        try:
            interpolator = LinearNDInterpolator(coords, arr[valid_mask], fill_value=0)
        except (RuntimeError, ValueError):
            # Qhull (QhullError subclasses RuntimeError) cannot triangulate fewer
            # than three non-collinear points: fall back to filling NaNs with 0.
            return np.where(valid_mask, arr, 0.0)
        return interpolator(list(np.ndindex(arr.shape))).reshape(arr.shape)
568
+
569
+
570
def pairwise_with_stride(input_iter: Iterable, stride: int) -> Iterable:
    """
    Similar to itertools.pairwise but with step between elements: each element is
    paired with the element ``stride`` positions after it.

    pairwise_with_stride('ABCDEFG', 1) → AB BC CD DE EF FG
    pairwise_with_stride('ABCDEFG', 2) → AC BD CE DF EG

    Args:
        input_iter (Iterable): Elements to pair up. Any finite iterable is accepted;
            it is materialised into a list first.
        stride (int): Distance between the two elements of each pair. Must be >= 1.

    Raises:
        TypeError: If stride is not an integer.
        ValueError: If stride is less than 1.

    Returns:
        Iterable: Iterator over (element, element_stride_positions_later) tuples.
    """
    if not isinstance(stride, (int, np.integer)):
        raise TypeError(f"Expected int, got {type(stride)}")
    # A non-positive stride would pair elements with themselves or (via negative
    # indexing) with elements from the end of the sequence.
    if stride < 1:
        raise ValueError(f"Expected stride >= 1, got {stride}")

    items = list(input_iter)
    return iter(list(zip(items, items[stride:])))