Simple-Track 2.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- simple_track-2.0.0.dist-info/METADATA +218 -0
- simple_track-2.0.0.dist-info/RECORD +17 -0
- simple_track-2.0.0.dist-info/WHEEL +5 -0
- simple_track-2.0.0.dist-info/entry_points.txt +2 -0
- simple_track-2.0.0.dist-info/licenses/LICENSE +373 -0
- simple_track-2.0.0.dist-info/top_level.txt +1 -0
- simpletrack/__init__.py +1 -0
- simpletrack/exceptions.py +51 -0
- simpletrack/feature.py +322 -0
- simpletrack/flow_solver.py +589 -0
- simpletrack/frame.py +521 -0
- simpletrack/frame_output.py +295 -0
- simpletrack/frame_tracker.py +962 -0
- simpletrack/load.py +170 -0
- simpletrack/run_simple_track.py +12 -0
- simpletrack/track.py +281 -0
- simpletrack/utils.py +145 -0
|
@@ -0,0 +1,589 @@
|
|
|
1
|
+
import itertools
|
|
2
|
+
import warnings
|
|
3
|
+
from collections.abc import Iterable
|
|
4
|
+
from typing import Union
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
import scipy.ndimage as ndimage
|
|
8
|
+
from numpy.typing import NDArray
|
|
9
|
+
from scipy.interpolate import LinearNDInterpolator, RectBivariateSpline
|
|
10
|
+
from scipy.signal.windows import tukey
|
|
11
|
+
from skimage.registration import phase_cross_correlation
|
|
12
|
+
|
|
13
|
+
from simpletrack.exceptions import ArrayError
|
|
14
|
+
from simpletrack.frame import Frame
|
|
15
|
+
from simpletrack.utils import check_arrays
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class FlowSolver:
    """
    Class containing functionality for deriving a flow field that maps features from
    an older Frame to a more recent Frame. This is done by subdividing the input
    fields into half-overlapping subdomains, deriving the flow in each subdomain
    using phase cross-correlation, and then stitching together the subdomain flows
    using interpolation.
    """

    def __init__(
        self,
        subdomain_size: int = None,
        min_fractional_coverage: float = 0.01,
        subdomain_tolerance: int = 3,
        overlap_threshold: float = 0.6,
        apply_tukey_filtering: bool = True,
    ) -> None:
        """
        Class containing functionality for deriving flow between two input frames

        Args:
            subdomain_size (int or array-like of 2 ints, optional):
                Size in pixels of individual subdomains to run fft for (dy, dx)
                displacement. An int gives a square subdomain; a length-2
                array-like gives (y, x) sizes. Each size must be even, and half of
                it must divide the corresponding (y, x) length of the field, since
                subdomains overlap by half.
                Defaults to domain size // 5 (derived on first call to analyse_flow)
            min_fractional_coverage (float, optional):
                Minimum fractional cover of objects required for fft to obtain
                (dy, dx) displacement
                Defaults to 0.01.
            subdomain_tolerance (int, optional):
                Maximum difference in displacement values between adjacent
                subdomains (to remove spurious values)
                Defaults to 3.
            overlap_threshold (float, optional):
                Minimum fraction of overlap between features for use in flow_solver
                Defaults to 0.6.
            apply_tukey_filtering (bool, optional):
                Apply a 2D Tukey window to each subdomain before phase
                cross-correlation
                Defaults to True.

        Raises:
            TypeError: If subdomain_size is a float
        """
        if isinstance(subdomain_size, (float, np.floating)):
            raise TypeError("Expected int or array-like, got float")

        if subdomain_size is None:
            self.subdomain_shape = None
        elif isinstance(subdomain_size, (int, np.integer)):
            self.subdomain_shape = np.array([subdomain_size, subdomain_size], dtype=int)
        else:
            # A subdomain shape is a 1D (y, x) pair, matching the shape=(2,)
            # validation applied wherever the subdomain shape is consumed.
            self.subdomain_shape = check_arrays(subdomain_size, shape=(2,)).astype(int)
        self.min_fractional_coverage = min_fractional_coverage
        self.subdomain_tolerance = subdomain_tolerance
        self.overlap_threshold = overlap_threshold
        self.apply_tukey_filtering = apply_tukey_filtering

    def analyse_flow(
        self, prev_field: Union[Frame, NDArray], current_field: Union[Frame, NDArray]
    ) -> tuple[NDArray, NDArray]:
        """
        Analyses previous field and current field to identify flow field. Uses phase
        cross correlation over a series of overlapping subdomains to stitch together
        a full field, where the flow is constant within each subdomain. Subdomain size
        is controlled by FlowSolver init, but is estimated from inputs if not
        provided.

        Input feature fields must be of the same size and contain sufficient feature
        coverage, as determined by min_fractional_coverage in init. Otherwise, solver
        will return (None, None) for the flow fields.

        Args:
            prev_field (Union[Frame, NDArray]):
                Feature field from previous timestep
            current_field (Union[Frame, NDArray]):
                Feature field from current timestep

        Raises:
            TypeError: If the inputs are not both Frames or both arrays
            ValueError: If the subdomain shape cannot tile the field exactly

        Returns:
            tuple[NDArray, NDArray]: y_flow, x_flow (or None, None)
        """
        if isinstance(prev_field, Frame) and isinstance(current_field, Frame):
            prev_features = prev_field.feature_field
            current_features = current_field.feature_field
        elif isinstance(prev_field, np.ndarray) and isinstance(
            current_field, np.ndarray
        ):
            prev_features = prev_field
            current_features = current_field
        else:
            raise TypeError(
                "prev_field and current_field must both be of type Frame or NDArray"
            )

        # Check input fields are same shape
        prev_features, current_features = check_arrays(
            prev_features, current_features, equal_shape=True, ndim=2, dtype=int
        )

        # Determine a subdomain size if not provided.
        # NOTE: the derived shape is stored on the instance and reused by later calls.
        if self.subdomain_shape is None:
            self.subdomain_shape = self.get_subdomain_shape(prev_features.shape)

        # Check inputs, don't proceed if not validated
        prev_features, current_features = self._check_inputs(
            prev_features, current_features
        )
        if prev_features is None:
            return None, None

        # Get indices of subdomain bounds. This validates that the subdomain shape
        # tiles the domain exactly, so it runs before the containment arrays are
        # sized from that shape.
        y_subdomain_bounds, x_subdomain_bounds = self.get_overlapping_subdomain_idxs(
            prev_features.shape, self.subdomain_shape
        )

        # Initialise containing arrays for holding subdomain dy, dx
        subdomain_dy, subdomain_dx = self.get_subdomain_containment_arrays(
            prev_features.shape, self.subdomain_shape
        )

        # Subdomains start on multiples of half the subdomain shape, so dividing a
        # subdomain's start index by this step gives its containment-array index.
        subdomain_step = np.asarray(self.subdomain_shape, dtype=int) // 2

        for y_bounds, x_bounds in self.subdomain_iter(
            y_subdomain_bounds, x_subdomain_bounds
        ):
            # Construct subdomain mask from bounds
            subdomain_mask = (
                slice(y_bounds[0], y_bounds[1]),
                slice(x_bounds[0], x_bounds[1]),
            )

            dy, dx = self.derive_subdomain_flow(
                field1=prev_features[subdomain_mask],
                field2=current_features[subdomain_mask],
                tukey_filtering=self.apply_tukey_filtering,
            )

            dy_idx = int(y_bounds[0] // subdomain_step[0])
            dx_idx = int(x_bounds[0] // subdomain_step[1])
            subdomain_dy[dy_idx, dx_idx] = dy
            subdomain_dx[dy_idx, dx_idx] = dx

        # Check neighbouring subdomain values vary within acceptable tolerance
        subdomain_dy = self.check_subdomain_variability(subdomain_dy)
        subdomain_dx = self.check_subdomain_variability(subdomain_dx)

        # Finally, interpolate values between subdomains. Interior bounds (excluding
        # the domain edges) are the centres of the overlapping subdomains.
        interior_y_subdom_bounds = y_subdomain_bounds[1:-1]
        interior_x_subdom_bounds = x_subdomain_bounds[1:-1]
        y_flow = self.interpolate_subdomain_flows(
            interior_y_subdom_bounds,
            interior_x_subdom_bounds,
            subdomain_dy,
            prev_features.shape,
        )
        x_flow = self.interpolate_subdomain_flows(
            interior_y_subdom_bounds,
            interior_x_subdom_bounds,
            subdomain_dx,
            prev_features.shape,
        )
        return y_flow, x_flow

    def get_subdomain_containment_arrays(
        self, full_domain_shape: NDArray, subdomain_shape: NDArray
    ) -> tuple[NDArray, NDArray]:
        """
        Return NaN-filled arrays with the correct shape for containing subdomain
        flow values. Subdomains overlap by half, so along each dimension there is
        one value per half-subdomain step, minus one.
        E.g., for a 100x200 domain and a 20x20 subdomain size, the half-step is
        10x10, giving 10x20 steps and containing arrays of shape 9x19.

        Args:
            full_domain_shape (NDArray): Shape of the full domain
            subdomain_shape (NDArray): Shape of requested subdomain

        Raises:
            ValueError: If the subdomain shape is smaller than 2 in any dimension

        Returns:
            tuple[NDArray, NDArray]: containing arrays for subdomain dy and dx
        """
        full_domain_shape, subdomain_shape = check_arrays(
            full_domain_shape, subdomain_shape, shape=(2,), dtype=int, non_negative=True
        )

        step = np.asarray(subdomain_shape, dtype=int) // 2
        if np.any(step < 1):
            raise ValueError(
                f"Subdomain shape ({subdomain_shape}) must be at least 2 in each "
                "dimension"
            )

        containing_shape = [
            int(n_steps) - 1
            for n_steps in np.asarray(full_domain_shape, dtype=int) // step
        ]
        subdomain_dy = np.full(shape=containing_shape, fill_value=np.nan)
        subdomain_dx = np.full(shape=containing_shape, fill_value=np.nan)
        return subdomain_dy, subdomain_dx

    def subdomain_iter(
        self, y_subdomain_bounds: tuple, x_subdomain_bounds: tuple
    ) -> Iterable:
        """
        Produces iterable of subdomain bounds with stride of 2 indices
        between inputs. Each returned set is an iterable of tuples defining
        start and end bounds in y and x direction

        Args:
            y_subdomain_bounds (tuple): half-step subdomain bound indices in y
            x_subdomain_bounds (tuple): half-step subdomain bound indices in x

        Returns:
            Iterable: ((y_start, y_stop), (x_start, x_stop))
        """
        y_subdomain_bounds, x_subdomain_bounds = check_arrays(
            y_subdomain_bounds, x_subdomain_bounds, dtype=int, ndim=1, non_negative=True
        )

        # Combine these idxs pairwise with stride 2 to define the subdomain bounds
        # E.g., for subdomain size of 20 in y with bounds [0, 10, 20, 30...]
        # and for subdomain size of 30 in x with bounds [0, 15, 30, 45...]
        # this operation then produces ((0, 20), (10, 30), ...) for y
        # and ((0, 30), (15, 45)...) for x
        y_subdomain_bounds_tuple = pairwise_with_stride(y_subdomain_bounds, 2)
        x_subdomain_bounds_tuple = pairwise_with_stride(x_subdomain_bounds, 2)

        # Finally, get all combinations of the y and x subdomain bounds
        # E.g., for example above, produces ( ((0, 20), (0, 30)), ((0, 20), (15, 45))..)
        return itertools.product(y_subdomain_bounds_tuple, x_subdomain_bounds_tuple)

    def check_subdomain_variability(self, subdomain_vals: NDArray) -> NDArray:
        """
        Check variability in neighbouring subdomains to ensure no large
        discrepancies. If the mean of the (non-NaN) 8-neighbourhood departs from
        the local value by more than self.subdomain_tolerance, set the
        local value to np.nan. The input array is not modified.

        Args:
            subdomain_vals (NDArray):
                Flow values derived in subdomains

        Returns:
            NDArray: Copy of input values with outliers replaced by nan
        """
        # Check input array is 2D; copy so the caller's array is never mutated
        subdomain_vals = np.array(
            check_arrays(subdomain_vals, ndim=2, dtype=float), dtype=float
        )

        # Setup footprint for performing filter check on neighbouring points
        # footprint excludes current index
        footprint = np.ones((3, 3))
        footprint[1, 1] = 0

        # nanmean of an all-NaN neighbourhood emits a RuntimeWarning; silence it
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", category=RuntimeWarning)
            nbhood_mean = ndimage.generic_filter(
                input=subdomain_vals,
                function=np.nanmean,  # Apply nanmean to each nbhood
                footprint=footprint,  # Determines how to sample points in the nbhood
                # "constant" mode fills out-of-bounds points with cval; NaN means
                # they don't contribute to nanmean
                mode="constant",
                cval=np.nan,
            )

        # Comparisons with NaN are False, so values without valid neighbours are kept
        invalid_tolerance = (
            np.abs(nbhood_mean - subdomain_vals) > self.subdomain_tolerance
        )
        subdomain_vals[invalid_tolerance] = np.nan
        return subdomain_vals

    def get_subdomain_shape(self, feature_field_shape) -> NDArray:
        """
        Derive a default subdomain shape as one fifth of the feature field shape.

        TODO: this heuristic is arbitrary and fails for many field shapes; consider
        searching for a nearby subdomain shape that fits instead of raising.

        Args:
            feature_field_shape: (y, x) shape of the feature field

        Raises:
            ValueError: If the derived subdomain shape cannot tile the field

        Returns:
            NDArray: (y, x) subdomain shape
        """
        sd_shape = np.array(feature_field_shape) // 5
        if not self.check_subdomain_size_fits_in_full_domain(
            feature_field_shape, sd_shape
        ):
            raise ValueError(
                f"Subdomain shape ({sd_shape}) cannot fit ({feature_field_shape})"
            )
        return sd_shape

    def check_subdomain_size_fits_in_full_domain(
        self, feature_field_shape: NDArray, subdomain_shape: NDArray
    ) -> bool:
        """
        Determines whether the subdomain shape is suitable for the feature field
        shape. Subdomains overlap by half, so a subdomain shape is suitable only if
        it is positive, even, no larger than the field, and half of it divides the
        feature field shape exactly in each dimension.

        Args:
            feature_field_shape (NDArray):
                1D array describing shape of feature field
            subdomain_shape (NDArray):
                1D array describing shape of requested subdomain

        Returns:
            bool: True if subdomain shape fits exactly in feature field shape,
                False otherwise
        """
        # Check inputs, only except errors related to contents of inputs
        try:
            feature_field_shape, subdomain_shape = check_arrays(
                feature_field_shape,
                subdomain_shape,
                dtype=int,
                shape=(2,),
                non_negative=True,
            )
        except ArrayError:
            return False

        feature_field_shape = np.asarray(feature_field_shape, dtype=int)
        subdomain_shape = np.asarray(subdomain_shape, dtype=int)

        # Subdomains must be non-empty and no larger than the field itself
        if np.any(subdomain_shape <= 0) or np.any(subdomain_shape > feature_field_shape):
            return False

        # Subdomain shape/2 must be an integer to define the half-overlap step
        if not np.all(subdomain_shape % 2 == 0):
            return False

        # The half-subdomain step must tile the field exactly in each dimension
        half_shape = subdomain_shape // 2
        return bool(np.all(feature_field_shape % half_shape == 0))

    def get_overlapping_subdomain_idxs(
        self, feature_field_shape: NDArray, subdomain_shape: NDArray
    ) -> tuple[NDArray, NDArray]:
        """
        Get indices of subdomain bounds of the requested shape that will fit into
        the requested feature field. These subdomains overlap halfway, meaning that
        the requirement for an exact fit is that HALF the subdomain shape must fit.
        Returns the half-step bound indices in y and x, from 0 to the domain edge
        inclusive.

        Args:
            feature_field_shape (NDArray):
                Shape of the feature field to subdivide
            subdomain_shape (NDArray):
                Requested shape of the subdomain

        Raises:
            ValueError: If requested subdomain cannot fit exactly into input field

        Returns:
            tuple[NDArray, NDArray]: y and x subdomain bound indices
        """
        feature_field_shape, subdomain_shape = check_arrays(
            feature_field_shape,
            subdomain_shape,
            dtype=int,
            shape=(2,),
            non_negative=True,
        )

        if not self.check_subdomain_size_fits_in_full_domain(
            feature_field_shape, subdomain_shape
        ):
            raise ValueError(
                "Could not fit exact number of subdomains in feature_field "
                f"(input feature field dim size: {feature_field_shape}, "
                f"requested subdomain shape: {subdomain_shape})"
            )

        # Now, get idxs of subdomain bounds to iterate over (with half overlap)
        step = np.asarray(subdomain_shape, dtype=int) // 2
        y_subdomain_idxs = np.arange(
            start=0, stop=feature_field_shape[0] + step[0], step=step[0]
        )
        x_subdomain_idxs = np.arange(
            start=0, stop=feature_field_shape[1] + step[1], step=step[1]
        )
        return y_subdomain_idxs, x_subdomain_idxs

    def derive_subdomain_flow(
        self, field1: NDArray, field2: NDArray, tukey_filtering: bool = True
    ) -> tuple[float, float]:
        """
        Uses FFT to identify most likely dy, dx motion vectors that translate field1
        to field2. This is largely handled by
        skimage.registration.phase_cross_correlation but with additional
        pre-processing to avoid spurious correlations. E.g., if the tukey_filtering
        flag is enabled, applies a filter to the edges of each field that tapers to
        zero, which avoids spectral leakage.

        Args:
            field1 (NDArray):
                Previous timestep binary field
            field2 (NDArray):
                Current timestep binary field
            tukey_filtering (bool, optional):
                Whether to apply tukey filter to input fields to prevent wrap-around
                disparities occurring during FFT transformations.
                Defaults to True.

        Raises:
            TypeError: If tukey_filtering is not a bool

        Returns:
            tuple[float, float]: (dy, dx) motion vectors for subdomain flow
        """
        # Check inputs are equally shaped 2D arrays containing ints
        field1, field2 = check_arrays(
            field1, field2, ndim=2, equal_shape=True, dtype=int, non_negative=True
        )
        if not isinstance(tukey_filtering, bool):
            raise TypeError(
                f"Expected tukey_filtering type bool, got {type(tukey_filtering)}"
            )

        # Filter inputs if flagged
        if tukey_filtering:
            domain_filter = self.get_2d_tukey_window(field1.shape)
            field1 = field1 * domain_filter
            field2 = field2 * domain_filter

        # Subtracting the mean from binary fields before cross correlation centres each
        # field around zero and improves accuracy of correlation peak
        m1 = field1 - np.mean(field1)
        m2 = field2 - np.mean(field2)

        # Since image registration finds the vector that translates the second arg to
        # the first, need to reverse input order
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", category=UserWarning)
            cross_corr = phase_cross_correlation(
                m2,
                m1,
                space="real",
                overlap_ratio=self.overlap_threshold,
                normalization=None,
                upsample_factor=1,
                disambiguate=False,
            )

        # cross_corr[1] is the registration error, currently unused
        dy, dx = cross_corr[0]
        return dy, dx

    def get_2d_tukey_window(self, arr_shape: NDArray) -> NDArray:
        """
        Creates a 2D tukey filter for an array of the requested input shape.
        Uses scipy.signal.windows.tukey to create a 1D filter for each dimension,
        then constructs the 2D filter using outer product of each 1D filter.

        A Tukey window is a tapered cosine combination of a rectangular and
        Hann window which provides a good balance between preventing spectral
        leakage and maintaining good frequency resolution.

        Args:
            arr_shape (NDArray):
                Shape with which to create the Tukey filter

        Returns:
            NDArray: 2D filter array of requested shape
        """
        # Checks that the arr_shape input describes dimensions of a 2D array
        arr_shape = check_arrays(arr_shape, shape=(2,), dtype=int, non_negative=True)

        # Alpha sets the degree to which the filter resembles either a rectangular
        # window (alpha=0) or a Hann window (alpha=1). We want to retain a lot of
        # points if the subdomain is small, but for larger subdomains we can apply
        # more smoothing.
        filters = [
            tukey(dim_size, alpha=max(0.1, 10.0 / dim_size)) for dim_size in arr_shape
        ]

        return np.outer(*filters)

    def interpolate_subdomain_flows(
        self,
        y_subdomain_bounds: tuple,
        x_subdomain_bounds: tuple,
        subdomain_flows: NDArray,
        full_domain_shape: NDArray,
    ) -> NDArray:
        """
        Takes the subdomain flows found from derive_subdomain_flow and stitches them
        together using 2D spline interpolation via RectBivariateSpline (cubic where
        there are enough points, lower degree otherwise).

        Args:
            y_subdomain_bounds (tuple):
                y positions (subdomain centres) of the subdomain flow values
            x_subdomain_bounds (tuple):
                x positions (subdomain centres) of the subdomain flow values
            subdomain_flows (NDArray):
                subdomain flows, shape (len(y_subdomain_bounds),
                len(x_subdomain_bounds)); may contain NaNs
            full_domain_shape (NDArray):
                Shape of full domain to create flow field for

        Returns:
            NDArray: Interpolated flow field for the full domain
        """
        # Check inputs
        y_subdomain_bounds, x_subdomain_bounds = check_arrays(
            y_subdomain_bounds, x_subdomain_bounds, ndim=1, non_negative=True
        )
        full_domain_shape = check_arrays(full_domain_shape, shape=(2,), dtype=int)
        subdomain_flows = check_arrays(subdomain_flows, ndim=2)
        subdomain_flows = self._fill_nans(subdomain_flows)

        # If there are fewer than 4 nonzero subdomain flow elements, return 0 flow field
        if np.count_nonzero(subdomain_flows) < 4:
            return np.zeros(full_domain_shape, dtype=int)

        # RectBivariateSpline replaces the deprecated interp2d(kind="cubic").
        # It expects data on an (x, y) grid rather than (y, x), hence the transposed
        # input and output. A spline of degree k needs more than k points per axis,
        # so the degree is reduced on small subdomain grids.
        # https://scipy.github.io/devdocs/tutorial/interpolate/interp_transition_guide.html
        kx = min(3, len(x_subdomain_bounds) - 1)
        ky = min(3, len(y_subdomain_bounds) - 1)
        spline = RectBivariateSpline(
            x_subdomain_bounds, y_subdomain_bounds, subdomain_flows.T, kx=kx, ky=ky
        )
        y_range = range(full_domain_shape[0])
        x_range = range(full_domain_shape[1])
        return spline(x_range, y_range).T

    def _check_inputs(
        self, arr1: NDArray, arr2: NDArray
    ) -> tuple[NDArray, NDArray]:
        """
        Check both feature fields contain enough feature pixels to run optical flow.

        Args:
            arr1 (NDArray): Previous feature field
            arr2 (NDArray): Current feature field

        Returns:
            tuple[NDArray, NDArray]: (arr1, arr2) unchanged if valid, else
                (None, None)
        """
        # Count feature pixels (nonzero values, whether binary or labelled)
        n_pixels1 = np.count_nonzero(arr1)
        n_pixels2 = np.count_nonzero(arr2)

        # Check both fields have features
        if not n_pixels1 and not n_pixels2:
            print("No features detected in both fields. Skipping optical flow.")
            return None, None

        # If there are too few features, don't proceed with optical flow.
        # TODO: confirm intended basis - threshold is currently a fraction of the
        # area of a single subdomain, not of the full domain.
        subdomain_count = np.prod(self.subdomain_shape)
        min_feature_coverage = subdomain_count * self.min_fractional_coverage
        if n_pixels1 < min_feature_coverage or n_pixels2 < min_feature_coverage:
            print(f"Threshold for running optical flow: {self.min_fractional_coverage}")
            print(f"Number of pixels above threshold in arr1: {n_pixels1}")
            print(f"Number of pixels above threshold in arr2: {n_pixels2}")
            print("Number of features in arr1 and/or arr2 less than threshold. ")
            print("Skipping optical flow")
            return None, None

        return arr1, arr2

    def _fill_nans(self, arr: NDArray) -> NDArray:
        """
        Replace NaNs in the input field with values linearly interpolated from
        neighbouring grid points, or 0 if this is not possible (points outside the
        convex hull of valid points, no valid points, or valid points that cannot
        be triangulated, e.g. a single row of subdomains).

        Args:
            arr (NDArray): Input array, potentially containing NaNs

        Returns:
            NDArray: Float output array with NaNs filled
        """
        from scipy.spatial import QhullError

        arr = np.asarray(arr, dtype=float)
        valid_mask = ~np.isnan(arr)

        # Nothing to fill
        if valid_mask.all():
            return arr.copy()
        # Nothing to interpolate from
        if not valid_mask.any():
            return np.zeros(arr.shape, dtype=float)

        coords = np.nonzero(valid_mask)
        non_nan_values = arr[valid_mask]
        try:
            interpolator = LinearNDInterpolator(coords, non_nan_values, fill_value=0)
        except QhullError:
            # Too few or collinear valid points to triangulate: zero-fill the NaNs
            return np.where(valid_mask, arr, 0.0)

        grid_points = list(np.ndindex(arr.shape))
        return interpolator(grid_points).reshape(arr.shape)
|
|
568
|
+
|
|
569
|
+
|
|
570
|
+
def pairwise_with_stride(input_iter: Iterable, stride: int) -> Iterable:
    """
    Similar to itertools.pairwise but with step between elements

    pairwise_with_stride('ABCDEFG', 1) → AB BC CD DE EF FG
    pairwise_with_stride('ABCDEFG', 2) → AC BD CE DF EG

    Args:
        input_iter (Iterable): Elements to pair up (any finite iterable)
        stride (int): Distance between the two elements of each pair; any
            integer type (including numpy integers) is accepted

    Raises:
        TypeError: If stride is not an integer
        ValueError: If stride is negative

    Returns:
        Iterable: Iterator of (element, element `stride` positions later) tuples;
            empty if stride is not smaller than the number of elements
    """
    import operator

    try:
        stride = operator.index(stride)
    except TypeError:
        raise TypeError(f"Expected int, got {type(stride)}") from None
    if stride < 0:
        # A negative stride would silently wrap around via negative indexing
        raise ValueError(f"Expected non-negative stride, got {stride}")

    items = list(input_iter)
    return zip(items, items[stride:])
|