zea 0.0.6__py3-none-any.whl → 0.0.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- zea/__init__.py +54 -19
- zea/agent/__init__.py +12 -12
- zea/agent/masks.py +2 -1
- zea/backend/tensorflow/dataloader.py +2 -5
- zea/beamform/beamformer.py +100 -50
- zea/beamform/lens_correction.py +9 -2
- zea/beamform/pfield.py +9 -2
- zea/beamform/pixelgrid.py +1 -1
- zea/config.py +34 -25
- zea/data/__init__.py +22 -25
- zea/data/augmentations.py +221 -28
- zea/data/convert/__init__.py +1 -6
- zea/data/convert/__main__.py +123 -0
- zea/data/convert/camus.py +101 -40
- zea/data/convert/echonet.py +187 -86
- zea/data/convert/echonetlvh/README.md +2 -3
- zea/data/convert/echonetlvh/{convert_raw_to_usbmd.py → __init__.py} +174 -103
- zea/data/convert/echonetlvh/manual_rejections.txt +73 -0
- zea/data/convert/echonetlvh/precompute_crop.py +43 -64
- zea/data/convert/picmus.py +37 -40
- zea/data/convert/utils.py +86 -0
- zea/data/convert/{matlab.py → verasonics.py} +44 -65
- zea/data/data_format.py +155 -34
- zea/data/dataloader.py +12 -7
- zea/data/datasets.py +112 -71
- zea/data/file.py +184 -73
- zea/data/file_operations.py +496 -0
- zea/data/layers.py +3 -3
- zea/data/preset_utils.py +1 -1
- zea/datapaths.py +16 -4
- zea/display.py +14 -13
- zea/interface.py +14 -16
- zea/internal/_generate_keras_ops.py +6 -7
- zea/internal/cache.py +2 -49
- zea/internal/checks.py +6 -12
- zea/internal/config/validation.py +1 -2
- zea/internal/core.py +69 -6
- zea/internal/device.py +6 -2
- zea/internal/dummy_scan.py +330 -0
- zea/internal/operators.py +118 -2
- zea/internal/parameters.py +101 -70
- zea/internal/setup_zea.py +5 -6
- zea/internal/utils.py +282 -0
- zea/io_lib.py +322 -146
- zea/keras_ops.py +74 -4
- zea/log.py +9 -7
- zea/metrics.py +15 -7
- zea/models/__init__.py +31 -21
- zea/models/base.py +30 -14
- zea/models/carotid_segmenter.py +19 -4
- zea/models/diffusion.py +235 -23
- zea/models/echonet.py +22 -8
- zea/models/echonetlvh.py +31 -7
- zea/models/lpips.py +19 -2
- zea/models/lv_segmentation.py +30 -11
- zea/models/preset_utils.py +5 -5
- zea/models/regional_quality.py +30 -10
- zea/models/taesd.py +21 -5
- zea/models/unet.py +15 -1
- zea/ops.py +770 -336
- zea/probes.py +6 -6
- zea/scan.py +121 -51
- zea/simulator.py +24 -21
- zea/tensor_ops.py +477 -353
- zea/tools/fit_scan_cone.py +90 -160
- zea/tools/hf.py +1 -1
- zea/tools/selection_tool.py +47 -86
- zea/tracking/__init__.py +16 -0
- zea/tracking/base.py +94 -0
- zea/tracking/lucas_kanade.py +474 -0
- zea/tracking/segmentation.py +110 -0
- zea/utils.py +101 -480
- zea/visualize.py +177 -39
- {zea-0.0.6.dist-info → zea-0.0.8.dist-info}/METADATA +6 -2
- zea-0.0.8.dist-info/RECORD +122 -0
- zea-0.0.6.dist-info/RECORD +0 -112
- {zea-0.0.6.dist-info → zea-0.0.8.dist-info}/WHEEL +0 -0
- {zea-0.0.6.dist-info → zea-0.0.8.dist-info}/entry_points.txt +0 -0
- {zea-0.0.6.dist-info → zea-0.0.8.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,474 @@
|
|
|
1
|
+
"""Lucas-Kanade optical flow tracker.
|
|
2
|
+
|
|
3
|
+
.. seealso::
|
|
4
|
+
A tutorial notebook where this model is used:
|
|
5
|
+
:doc:`../notebooks/models/speckle_tracking_example`.
|
|
6
|
+
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from typing import Tuple
|
|
10
|
+
|
|
11
|
+
from keras import ops
|
|
12
|
+
|
|
13
|
+
from zea.tensor_ops import gaussian_filter, translate
|
|
14
|
+
|
|
15
|
+
from .base import BaseTracker
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class LucasKanadeTracker(BaseTracker):
    """Lucas-Kanade optical flow tracker.

    Implements pyramidal Lucas-Kanade optical flow tracking for 2D frames or
    3D volumes. Dimensionality is inferred from the length of ``win_size``.

    Args:
        win_size: Window size (height, width) for 2D or (depth, height, width) for 3D.
        max_level: Number of pyramid levels (0 means no pyramid).
        max_iterations: Maximum iterations per pyramid level.
        epsilon: Convergence threshold for iterative solver.
        **kwargs: Additional parameters.

    Example:
        .. doctest::

            >>> from zea.tracking import LucasKanadeTracker
            >>> import numpy as np

            >>> tracker = LucasKanadeTracker(win_size=(32, 32), max_level=3)
            >>> frame1 = np.random.rand(100, 100).astype("float32")
            >>> frame2 = np.random.rand(100, 100).astype("float32")
            >>> points = np.array([[50.5, 55.2], [60.1, 65.8]], dtype="float32")
            >>> new_points = tracker.track(frame1, frame2, points)
            >>> new_points.shape
            (2, 2)
    """

    def __init__(
        self,
        win_size: Tuple[int, ...] = (32, 32),
        max_level: int = 3,
        max_iterations: int = 30,
        epsilon: float = 0.01,
        **kwargs,
    ):
        """Initialize custom Lucas-Kanade tracker."""
        # Dimensionality (2 or 3) is derived from the window tuple length.
        self.ndim = len(win_size)

        super().__init__(ndim=self.ndim, **kwargs)

        self.win_size = win_size
        self.max_level = max_level
        self.max_iterations = max_iterations
        self.epsilon = epsilon

        # Half-window extents; windows extracted later have size 2*half + 1 per axis.
        self.half_win = tuple(w // 2 for w in win_size)

    def track(
        self,
        prev_frame,
        next_frame,
        points,
    ) -> Tuple:
        """
        Track points using custom pyramidal Lucas-Kanade.

        Args:
            prev_frame: Previous frame/volume (tensor), shape (H, W) for 2D or (D, H, W) for 3D.
            next_frame: Next frame/volume (tensor), shape (H, W) for 2D or (D, H, W) for 3D.
            points: Points to track (tensor), shape (N, ndim) in (y, x) or (z, y, x) format.

        Returns:
            new_points: Tracked points as tensor, shape (N, ndim).

        Raises:
            NotImplementedError: If ``ndim`` is neither 2 nor 3.
        """
        if self.ndim not in [2, 3]:
            raise NotImplementedError(f"Only 2D and 3D tracking supported, got {self.ndim}D")

        # Normalize frames to [0, 1] so the convergence threshold is
        # comparable across inputs with different intensity ranges.
        prev_norm = translate(prev_frame, range_to=(0, 1))
        next_norm = translate(next_frame, range_to=(0, 1))

        # Build pyramids (max_level + 1 images including the full-resolution one).
        if self.max_level > 0:
            prev_pyr = self._build_pyramid(prev_norm, self.max_level + 1)
            next_pyr = self._build_pyramid(next_norm, self.max_level + 1)
        else:
            prev_pyr = [prev_norm]
            next_pyr = [next_norm]

        n_levels = len(prev_pyr)
        n_points = int(points.shape[0])

        # Start at coarsest level: pyramids are returned coarsest-first,
        # so points must be scaled down by 2^(n_levels - 1).
        scale = 2 ** (n_levels - 1)
        curr_points = points / scale
        flows = ops.zeros((n_points, self.ndim), dtype="float32")

        # Track through pyramid levels, refining the flow at each resolution.
        for level in range(n_levels):
            prev_img = prev_pyr[level]
            next_img = next_pyr[level]

            # Track each point independently, seeded with the flow from the
            # coarser level (already doubled to this level's scale).
            new_flows = []

            for i in range(n_points):
                pt = curr_points[i]
                flow_guess = flows[i]

                flow = self._track_point(prev_img, next_img, pt, flow_guess)
                new_flows.append(flow)

            flows = ops.stack(new_flows)

            # Scale for next level (if not at finest): one pyramid step doubles
            # both the point coordinates and the estimated displacement.
            if level < n_levels - 1:
                flows = flows * 2.0
                curr_points = curr_points * 2.0

        # Final points at full resolution (flows are now in full-res pixels).
        new_points = points + flows

        return new_points

    def _build_pyramid(self, image, n_levels: int) -> list:
        """Build Gaussian pyramid.

        Returns the pyramid ordered coarsest-first (reversed), with at most
        ``n_levels`` images; construction stops early once any axis would
        shrink below 4 pixels.
        """
        pyramid = [image]
        for _ in range(1, n_levels):
            curr = pyramid[-1]
            shape = ops.shape(curr)

            # Check minimum size based on dimensionality.
            # NOTE(review): the Python-level `if min_size < 4` comparison on a
            # tensor assumes eager execution — confirm this is never traced.
            if self.ndim == 2:
                h, w = shape[0], shape[1]
                min_size = ops.minimum(h, w)
                if min_size < 4:
                    break
            else:  # 3D
                d, h, w = shape[0], shape[1], shape[2]
                min_size = ops.minimum(ops.minimum(d, h), w)
                if min_size < 4:
                    break

            # Anti-alias blur before decimation; sigma≈0.849 corresponds to the
            # standard 5-tap Gaussian used for pyramid reduction.
            blurred = gaussian_filter(curr, sigma=0.849, mode="reflect")

            # Downsample by 2x using map_coordinates
            if self.ndim == 2:
                new_h, new_w = h // 2, w // 2
                # Create downsampled coordinate grid spanning the full extent.
                y_coords = ops.linspace(0, h - 1, new_h)
                x_coords = ops.linspace(0, w - 1, new_w)
                grid_y, grid_x = ops.meshgrid(y_coords, x_coords, indexing="ij")
                coords = ops.stack([grid_y, grid_x], axis=0)
                downsampled = ops.image.map_coordinates(blurred, coords, order=1)
            else:  # 3D
                new_d, new_h, new_w = d // 2, h // 2, w // 2
                # Create downsampled coordinate grid
                z_coords = ops.linspace(0, d - 1, new_d)
                y_coords = ops.linspace(0, h - 1, new_h)
                x_coords = ops.linspace(0, w - 1, new_w)
                grid_z, grid_y, grid_x = ops.meshgrid(z_coords, y_coords, x_coords, indexing="ij")
                coords = ops.stack([grid_z, grid_y, grid_x], axis=0)
                downsampled = ops.image.map_coordinates(blurred, coords, order=1)

            pyramid.append(downsampled)
        # Reverse so callers iterate coarsest -> finest.
        return pyramid[::-1]

    def _track_point(
        self,
        prev_img,
        next_img,
        point,
        flow_guess,
    ):
        """Track a single point using iterative Lucas-Kanade.

        Solves ``G * d = b`` where ``G`` is the structure tensor of the
        template window and ``b`` accumulates gradient-weighted image
        differences; iterates until the update norm drops below ``epsilon``.
        Returns ``flow_guess`` unchanged if the window falls out of bounds.
        """
        # Extract template window around the point in the previous image.
        template, valid_template = self._extract_window(prev_img, point)
        if not valid_template:
            return flow_guess

        # Compute template gradients (Sobel) - returns tensors.
        # Gradients are computed once on the template (inverse-compositional
        # style), so the structure tensor below is loop-invariant.
        gradients = self._sobel_gradients(template)

        # Flatten gradients for 2D or 3D
        if self.ndim == 2:
            Iy, Ix = gradients
            Ix_flat = ops.reshape(Ix, [-1])
            Iy_flat = ops.reshape(Iy, [-1])

            # Structure tensor 2D components (sums over the window)
            IxIx = ops.sum(Ix_flat * Ix_flat)
            IxIy = ops.sum(Ix_flat * Iy_flat)
            IyIy = ops.sum(Iy_flat * Iy_flat)

        else:  # 3D
            Iz, Iy, Ix = gradients
            Ix_flat = ops.reshape(Ix, [-1])
            Iy_flat = ops.reshape(Iy, [-1])
            Iz_flat = ops.reshape(Iz, [-1])

            # Structure tensor 3D components (sums over the window)
            IxIx = ops.sum(Ix_flat * Ix_flat)
            IxIy = ops.sum(Ix_flat * Iy_flat)
            IxIz = ops.sum(Ix_flat * Iz_flat)
            IyIy = ops.sum(Iy_flat * Iy_flat)
            IyIz = ops.sum(Iy_flat * Iz_flat)
            IzIz = ops.sum(Iz_flat * Iz_flat)

        # Iterative refinement (keep as tensors)
        flow = flow_guess

        for iteration in range(self.max_iterations):
            # Extract warped window from next image at the current estimate.
            warped_pt = point + flow
            warped, valid_warped = self._extract_window(next_img, warped_pt)

            if not valid_warped:
                break

            # Image difference between template and warped window.
            diff = template - warped
            diff_flat = ops.reshape(diff, [-1])

            # Solve for flow update
            if self.ndim == 2:
                # Build structure tensor matrix (2x2)
                structure = ops.stack(
                    [
                        ops.stack([IxIx, IxIy]),
                        ops.stack([IxIy, IyIy]),
                    ],
                    axis=0,
                )
                # Add regularization to diagonal to keep the inverse stable
                # in low-texture windows.
                structure = structure + ops.eye(2, dtype=structure.dtype) * 1e-5

                # Right-hand side vector
                b_x = ops.sum(Ix_flat * diff_flat)
                b_y = ops.sum(Iy_flat * diff_flat)
                rhs = ops.reshape(ops.stack([b_x, b_y]), (2, 1))

                # Solve: structure * delta_xy = rhs (solution is in (x, y) order)
                delta_xy = ops.matmul(ops.linalg.inv(structure), rhs)
                delta_xy = ops.reshape(delta_xy, (2,))

                # Reorder to (y, x) to match the point/flow convention.
                delta = ops.stack([delta_xy[1], delta_xy[0]])

            else:  # 3D
                # Build structure tensor matrix (3x3)
                structure = ops.stack(
                    [
                        ops.stack([IxIx, IxIy, IxIz]),
                        ops.stack([IxIy, IyIy, IyIz]),
                        ops.stack([IxIz, IyIz, IzIz]),
                    ],
                    axis=0,
                )
                # Add regularization to diagonal to keep the inverse stable
                # in low-texture windows.
                structure = structure + ops.eye(3, dtype=structure.dtype) * 1e-5

                # Right-hand side vector
                b_x = ops.sum(Ix_flat * diff_flat)
                b_y = ops.sum(Iy_flat * diff_flat)
                b_z = ops.sum(Iz_flat * diff_flat)
                rhs = ops.reshape(ops.stack([b_x, b_y, b_z]), (3, 1))

                # Solve: structure * delta_xyz = rhs (solution is in (x, y, z) order)
                delta_xyz = ops.matmul(ops.linalg.inv(structure), rhs)
                delta_xyz = ops.reshape(delta_xyz, (3,))

                # Reorder to (z, y, x) to match the point/flow convention.
                delta = ops.stack([delta_xyz[2], delta_xyz[1], delta_xyz[0]])

            # Update flow
            flow = flow + delta

            # Check convergence.
            # NOTE(review): Python-level comparison on a tensor assumes eager
            # execution — confirm this path is never traced/compiled.
            delta_norm = ops.sqrt(ops.sum(delta * delta))
            if delta_norm < self.epsilon:
                break

        return flow

    def _extract_window(self, image, point):
        """Extract window around point with subpixel interpolation.

        Dispatches to the 2D or 3D implementation; returns ``(window, valid)``
        where ``valid`` is False if the window would fall outside the image.
        """
        if self.ndim == 2:
            return self._extract_window_2d(image, point)
        elif self.ndim == 3:
            return self._extract_window_3d(image, point)
        else:
            raise ValueError(f"Unsupported ndim: {self.ndim}")

    def _extract_window_2d(self, image, point):
        """Extract 2D window with bilinear interpolation using map_coordinates."""
        hy, hx = self.half_win
        h, w = ops.shape(image)[0], ops.shape(image)[1]

        py, px = point[0], point[1]

        # Bounds check: require a 1-pixel margin around the window so the
        # bilinear lookup never samples outside the image.
        if ops.any(
            ops.stack(
                [
                    py < hy + 1,
                    py >= ops.cast(h, py.dtype) - hy - 1,
                    px < hx + 1,
                    px >= ops.cast(w, px.dtype) - hx - 1,
                ]
            )
        ):
            # Out of bounds: return a zero window and a False validity flag.
            return ops.zeros((2 * hy + 1, 2 * hx + 1), dtype="float32"), False

        # Create coordinate grid for the window
        # Grid centered at point location
        y_coords = ops.arange(2 * hy + 1, dtype="float32") + py - hy
        x_coords = ops.arange(2 * hx + 1, dtype="float32") + px - hx
        grid_y, grid_x = ops.meshgrid(y_coords, x_coords, indexing="ij")

        # Stack coordinates for map_coordinates
        coords = ops.stack([grid_y, grid_x], axis=0)

        # Extract window using bilinear interpolation
        window = ops.image.map_coordinates(image, coords, order=1)

        return window, True

    def _extract_window_3d(self, image, point):
        """Extract 3D window with trilinear interpolation using map_coordinates."""
        hz, hy, hx = self.half_win
        d, h, w = ops.shape(image)[0], ops.shape(image)[1], ops.shape(image)[2]

        pz, py, px = point[0], point[1], point[2]

        # Bounds check: require a 1-voxel margin around the window so the
        # trilinear lookup never samples outside the volume.
        if ops.any(
            ops.stack(
                [
                    pz < hz + 1,
                    pz >= ops.cast(d, pz.dtype) - hz - 1,
                    py < hy + 1,
                    py >= ops.cast(h, py.dtype) - hy - 1,
                    px < hx + 1,
                    px >= ops.cast(w, px.dtype) - hx - 1,
                ]
            )
        ):
            # Out of bounds: return a zero window and a False validity flag.
            return ops.zeros((2 * hz + 1, 2 * hy + 1, 2 * hx + 1), dtype="float32"), False

        # Create coordinate grid for the window
        # Grid centered at point location
        z_coords = ops.arange(2 * hz + 1, dtype="float32") + pz - hz
        y_coords = ops.arange(2 * hy + 1, dtype="float32") + py - hy
        x_coords = ops.arange(2 * hx + 1, dtype="float32") + px - hx
        grid_z, grid_y, grid_x = ops.meshgrid(z_coords, y_coords, x_coords, indexing="ij")

        # Stack coordinates for map_coordinates
        coords = ops.stack([grid_z, grid_y, grid_x], axis=0)

        # Extract window using trilinear interpolation
        window = ops.image.map_coordinates(image, coords, order=1)

        return window, True

    def _sobel_gradients(self, image):
        """Compute Sobel gradients for 2D or 3D.

        Returns ``(Iy, Ix)`` for 2D or ``(Iz, Iy, Ix)`` for 3D, each with the
        same shape as ``image``.
        """
        if self.ndim == 2:
            return self._sobel_gradients_2d(image)
        elif self.ndim == 3:
            return self._sobel_gradients_3d(image)
        else:
            raise ValueError(f"Unsupported ndim: {self.ndim}")

    def _sobel_gradients_2d(self, image):
        """Compute 2D Sobel gradients using keras.ops."""
        # Standard Sobel kernels, normalized by 8 (sum of absolute weights)
        # so gradient magnitudes stay in the image's intensity scale.
        sobel_y = ops.convert_to_tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]], dtype="float32") / 8.0
        sobel_x = ops.convert_to_tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype="float32") / 8.0

        h, w = ops.shape(image)[0], ops.shape(image)[1]

        # Reflect-pad by 1 so the "valid" convolution below returns h x w.
        padded = ops.pad(image, [[1, 1], [1, 1]], mode="reflect")

        # Reshape for conv: image needs (batch, height, width, channels)
        img_4d = ops.reshape(padded, [1, h + 2, w + 2, 1])
        sobel_y_4d = ops.reshape(sobel_y, [3, 3, 1, 1])
        sobel_x_4d = ops.reshape(sobel_x, [3, 3, 1, 1])

        Iy_4d = ops.conv(img_4d, sobel_y_4d, padding="valid")
        Ix_4d = ops.conv(img_4d, sobel_x_4d, padding="valid")

        # Reshape back to 2D
        Iy = ops.reshape(Iy_4d, [h, w])
        Ix = ops.reshape(Ix_4d, [h, w])

        return Iy, Ix

    def _sobel_gradients_3d(self, image):
        """Compute 3D Sobel gradients using keras.ops."""
        # 3D Sobel kernels (separable: smooth in 2 dims, gradient in 1 dim),
        # normalized by 32 (sum of absolute weights).
        # Gradient in z-direction
        sobel_z = (
            ops.convert_to_tensor(
                [
                    [[-1, -2, -1], [-2, -4, -2], [-1, -2, -1]],
                    [[0, 0, 0], [0, 0, 0], [0, 0, 0]],
                    [[1, 2, 1], [2, 4, 2], [1, 2, 1]],
                ],
                dtype="float32",
            )
            / 32.0
        )

        # Gradient in y-direction
        sobel_y = (
            ops.convert_to_tensor(
                [
                    [[-1, -2, -1], [0, 0, 0], [1, 2, 1]],
                    [[-2, -4, -2], [0, 0, 0], [2, 4, 2]],
                    [[-1, -2, -1], [0, 0, 0], [1, 2, 1]],
                ],
                dtype="float32",
            )
            / 32.0
        )

        # Gradient in x-direction
        sobel_x = (
            ops.convert_to_tensor(
                [
                    [[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]],
                    [[-2, 0, 2], [-4, 0, 4], [-2, 0, 2]],
                    [[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]],
                ],
                dtype="float32",
            )
            / 32.0
        )

        d, h, w = ops.shape(image)[0], ops.shape(image)[1], ops.shape(image)[2]

        # Reflect-pad by 1 so the "valid" convolution below returns d x h x w.
        padded = ops.pad(image, [[1, 1], [1, 1], [1, 1]], mode="reflect")

        # Reshape for conv: image needs (batch, depth, height, width, channels)
        img_5d = ops.reshape(padded, [1, d + 2, h + 2, w + 2, 1])
        sobel_z_5d = ops.reshape(sobel_z, [3, 3, 3, 1, 1])
        sobel_y_5d = ops.reshape(sobel_y, [3, 3, 3, 1, 1])
        sobel_x_5d = ops.reshape(sobel_x, [3, 3, 3, 1, 1])

        # Apply 3D convolution with 'valid' padding (we pre-padded)
        Iz_5d = ops.conv(img_5d, sobel_z_5d, padding="valid")
        Iy_5d = ops.conv(img_5d, sobel_y_5d, padding="valid")
        Ix_5d = ops.conv(img_5d, sobel_x_5d, padding="valid")

        # Reshape back to 3D
        Iz = ops.reshape(Iz_5d, [d, h, w])
        Iy = ops.reshape(Iy_5d, [d, h, w])
        Ix = ops.reshape(Ix_5d, [d, h, w])

        return (Iz, Iy, Ix)

    def __repr__(self):
        """String representation."""
        return (
            f"LucasKanadeTracker(win_size={self.win_size}, max_level={self.max_level}, "
            f"max_iterations={self.max_iterations}, epsilon={self.epsilon})"
        )
|
|
@@ -0,0 +1,110 @@
|
|
|
1
|
+
"""Segmentation-based tracker using contour matching.
|
|
2
|
+
|
|
3
|
+
.. seealso::
|
|
4
|
+
A tutorial notebook where this model is used:
|
|
5
|
+
:doc:`../notebooks/models/speckle_tracking_example`.
|
|
6
|
+
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from keras import ops
|
|
10
|
+
|
|
11
|
+
from zea.tensor_ops import find_contour
|
|
12
|
+
|
|
13
|
+
from .base import BaseTracker
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class SegmentationTracker(BaseTracker):
    """Segmentation-based tracker.

    Rather than following image texture, this tracker segments each new frame
    independently and snaps every previous point to the nearest location on
    the resulting contour.

    Args:
        model: Segmentation model with a `call` method.
        preprocess_fn: Optional preprocessing function to apply to frames before segmentation.
        postprocess_fn: Optional postprocessing function to apply to segmentation output, which
            should return a binary mask of the target structure.
    """

    def __init__(
        self,
        model,
        preprocess_fn: callable = None,
        postprocess_fn: callable = None,
    ):
        """Initialize segmentation-based tracker."""
        super().__init__(ndim=2)
        self.model = model
        # Fall back to the identity transform when no preprocessing is given.
        self.preprocess_fn = preprocess_fn if preprocess_fn is not None else (lambda frame: frame)
        self.postprocess_fn = postprocess_fn

        # Postprocessing is mandatory: it is what turns raw model output
        # into the binary mask the contour search needs.
        if self.postprocess_fn is None:
            raise ValueError("A postprocess_fn must be provided to extract binary masks.")

    def track(
        self,
        prev_frame,
        next_frame,
        points,
    ):
        """
        Track points by segmenting next_frame and finding closest contour points.

        Args:
            prev_frame: Previous frame (unused; present for interface compatibility).
            next_frame: Next frame to segment, shape (H, W).
            points: Points from previous frame, shape (N, 2) in (row, col) format.

        Returns:
            new_points: Closest points on next frame's contour, shape (N, 2).
        """
        # Remember the input resolution so postprocessing can restore it.
        frame_shape = ops.shape(next_frame)

        model_output = self.model.call(self.preprocess_fn(next_frame))
        binary_mask = self.postprocess_fn(model_output, frame_shape)

        contour = find_contour(binary_mask)

        # If segmentation produced no contour, leave the points untouched.
        if ops.shape(contour)[0] > 0:
            return self._find_closest_points(points, contour)
        return points

    def _find_closest_points(self, query_points, target_points):
        """Find closest target points to each query point.

        Args:
            query_points: Points to match, shape (N, 2).
            target_points: Points to match to, shape (M, 2).

        Returns:
            Closest target points, shape (N, 2).
        """
        # Broadcast to (N, M, 2) pairwise differences:
        # (N, 1, 2) minus (1, M, 2).
        gaps = ops.expand_dims(query_points, axis=1) - ops.expand_dims(target_points, axis=0)

        # Squared Euclidean distance matrix, shape (N, M); no sqrt needed
        # since argmin is invariant under monotone transforms.
        dist_sq = ops.sum(gaps * gaps, axis=2)

        nearest_idx = ops.argmin(dist_sq, axis=1)
        return ops.take(target_points, nearest_idx, axis=0)

    def __repr__(self):
        """String representation."""
        return f"SegmentationTracker(model={self.model.__class__.__name__})"
|