zea 0.0.7__py3-none-any.whl → 0.0.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43) hide show
  1. zea/__init__.py +1 -1
  2. zea/backend/tensorflow/dataloader.py +0 -4
  3. zea/beamform/pixelgrid.py +1 -1
  4. zea/data/__init__.py +0 -9
  5. zea/data/augmentations.py +221 -28
  6. zea/data/convert/__init__.py +1 -6
  7. zea/data/convert/__main__.py +123 -0
  8. zea/data/convert/camus.py +99 -39
  9. zea/data/convert/echonet.py +183 -82
  10. zea/data/convert/echonetlvh/README.md +2 -3
  11. zea/data/convert/echonetlvh/{convert_raw_to_usbmd.py → __init__.py} +173 -102
  12. zea/data/convert/echonetlvh/manual_rejections.txt +73 -0
  13. zea/data/convert/echonetlvh/precompute_crop.py +43 -64
  14. zea/data/convert/picmus.py +37 -40
  15. zea/data/convert/utils.py +86 -0
  16. zea/data/convert/{matlab.py → verasonics.py} +33 -61
  17. zea/data/data_format.py +124 -4
  18. zea/data/dataloader.py +12 -7
  19. zea/data/datasets.py +109 -70
  20. zea/data/file.py +91 -82
  21. zea/data/file_operations.py +496 -0
  22. zea/data/preset_utils.py +1 -1
  23. zea/display.py +7 -8
  24. zea/internal/checks.py +6 -12
  25. zea/internal/operators.py +4 -0
  26. zea/io_lib.py +108 -160
  27. zea/models/__init__.py +1 -1
  28. zea/models/diffusion.py +62 -11
  29. zea/models/lv_segmentation.py +2 -0
  30. zea/ops.py +398 -158
  31. zea/scan.py +18 -8
  32. zea/tensor_ops.py +82 -62
  33. zea/tools/fit_scan_cone.py +90 -160
  34. zea/tracking/__init__.py +16 -0
  35. zea/tracking/base.py +94 -0
  36. zea/tracking/lucas_kanade.py +474 -0
  37. zea/tracking/segmentation.py +110 -0
  38. zea/utils.py +11 -2
  39. {zea-0.0.7.dist-info → zea-0.0.8.dist-info}/METADATA +3 -1
  40. {zea-0.0.7.dist-info → zea-0.0.8.dist-info}/RECORD +43 -35
  41. {zea-0.0.7.dist-info → zea-0.0.8.dist-info}/WHEEL +0 -0
  42. {zea-0.0.7.dist-info → zea-0.0.8.dist-info}/entry_points.txt +0 -0
  43. {zea-0.0.7.dist-info → zea-0.0.8.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,474 @@
1
+ """Lucas-Kanade optical flow tracker.
2
+
3
+ .. seealso::
4
+ A tutorial notebook where this model is used:
5
+ :doc:`../notebooks/models/speckle_tracking_example`.
6
+
7
+ """
8
+
9
+ from typing import Tuple
10
+
11
+ from keras import ops
12
+
13
+ from zea.tensor_ops import gaussian_filter, translate
14
+
15
+ from .base import BaseTracker
16
+
17
+
18
class LucasKanadeTracker(BaseTracker):
    """Lucas-Kanade optical flow tracker.

    Implements pyramidal Lucas-Kanade optical flow tracking.

    Args:
        win_size: Window size (height, width) for 2D or (depth, height, width) for 3D.
        max_level: Number of pyramid levels (0 means no pyramid).
        max_iterations: Maximum iterations per pyramid level.
        epsilon: Convergence threshold for iterative solver.
        **kwargs: Additional parameters.

    Example:
        .. doctest::

            >>> from zea.tracking import LucasKanadeTracker
            >>> import numpy as np

            >>> tracker = LucasKanadeTracker(win_size=(32, 32), max_level=3)
            >>> frame1 = np.random.rand(100, 100).astype("float32")
            >>> frame2 = np.random.rand(100, 100).astype("float32")
            >>> points = np.array([[50.5, 55.2], [60.1, 65.8]], dtype="float32")
            >>> new_points = tracker.track(frame1, frame2, points)
            >>> new_points.shape
            (2, 2)
    """

    def __init__(
        self,
        win_size: Tuple[int, ...] = (32, 32),
        max_level: int = 3,
        max_iterations: int = 30,
        epsilon: float = 0.01,
        **kwargs,
    ):
        """Initialize custom Lucas-Kanade tracker."""
        # Dimensionality (2 or 3) is inferred from the window size tuple.
        self.ndim = len(win_size)

        super().__init__(ndim=self.ndim, **kwargs)

        self.win_size = win_size
        self.max_level = max_level
        self.max_iterations = max_iterations
        self.epsilon = epsilon

        # Half-window extents used when extracting patches around a point.
        self.half_win = tuple(w // 2 for w in win_size)

    def track(
        self,
        prev_frame,
        next_frame,
        points,
    ) -> Tuple:
        """
        Track points using custom pyramidal Lucas-Kanade.

        Args:
            prev_frame: Previous frame/volume (tensor), shape (H, W) for 2D or (D, H, W) for 3D.
            next_frame: Next frame/volume (tensor), shape (H, W) for 2D or (D, H, W) for 3D.
            points: Points to track (tensor), shape (N, ndim) in (y, x) or (z, y, x) format.

        Returns:
            new_points: Tracked points as tensor, shape (N, ndim).
        """
        if self.ndim not in [2, 3]:
            raise NotImplementedError(f"Only 2D and 3D tracking supported, got {self.ndim}D")

        # Normalize frames to [0, 1] so intensity differences are comparable.
        prev_norm = translate(prev_frame, range_to=(0, 1))
        next_norm = translate(next_frame, range_to=(0, 1))

        # Build pyramids (coarsest level first, see _build_pyramid).
        if self.max_level > 0:
            prev_pyr = self._build_pyramid(prev_norm, self.max_level + 1)
            next_pyr = self._build_pyramid(next_norm, self.max_level + 1)
        else:
            prev_pyr = [prev_norm]
            next_pyr = [next_norm]

        n_levels = len(prev_pyr)
        n_points = int(points.shape[0])

        # Start at coarsest level: point coordinates shrink by 2 per level.
        scale = 2 ** (n_levels - 1)
        curr_points = points / scale
        flows = ops.zeros((n_points, self.ndim), dtype="float32")

        # Track through pyramid levels, refining the flow at each resolution.
        for level in range(n_levels):
            prev_img = prev_pyr[level]
            next_img = next_pyr[level]

            # Track each point independently.
            new_flows = []

            for i in range(n_points):
                pt = curr_points[i]
                flow_guess = flows[i]

                flow = self._track_point(prev_img, next_img, pt, flow_guess)
                new_flows.append(flow)

            flows = ops.stack(new_flows)

            # Scale for next (finer) level, if any.
            if level < n_levels - 1:
                flows = flows * 2.0
                curr_points = curr_points * 2.0

        # Final points at full resolution.
        new_points = points + flows

        return new_points

    def _build_pyramid(self, image, n_levels: int) -> list:
        """Build Gaussian pyramid, returned coarsest-to-finest."""
        pyramid = [image]
        for _ in range(1, n_levels):
            curr = pyramid[-1]
            shape = ops.shape(curr)

            # Stop early once the smallest dimension would become too small.
            if self.ndim == 2:
                h, w = shape[0], shape[1]
                min_size = ops.minimum(h, w)
                if min_size < 4:
                    break
            else:  # 3D
                d, h, w = shape[0], shape[1], shape[2]
                min_size = ops.minimum(ops.minimum(d, h), w)
                if min_size < 4:
                    break

            # Anti-alias before downsampling (sigma ~0.849 matches a 2x reduce).
            blurred = gaussian_filter(curr, sigma=0.849, mode="reflect")

            # Downsample by 2x using map_coordinates (linear interpolation).
            if self.ndim == 2:
                new_h, new_w = h // 2, w // 2
                # Create downsampled coordinate grid
                y_coords = ops.linspace(0, h - 1, new_h)
                x_coords = ops.linspace(0, w - 1, new_w)
                grid_y, grid_x = ops.meshgrid(y_coords, x_coords, indexing="ij")
                coords = ops.stack([grid_y, grid_x], axis=0)
                downsampled = ops.image.map_coordinates(blurred, coords, order=1)
            else:  # 3D
                new_d, new_h, new_w = d // 2, h // 2, w // 2
                # Create downsampled coordinate grid
                z_coords = ops.linspace(0, d - 1, new_d)
                y_coords = ops.linspace(0, h - 1, new_h)
                x_coords = ops.linspace(0, w - 1, new_w)
                grid_z, grid_y, grid_x = ops.meshgrid(z_coords, y_coords, x_coords, indexing="ij")
                coords = ops.stack([grid_z, grid_y, grid_x], axis=0)
                downsampled = ops.image.map_coordinates(blurred, coords, order=1)

            pyramid.append(downsampled)
        # Reverse so level 0 is the coarsest, matching the loop in track().
        return pyramid[::-1]

    def _track_point(
        self,
        prev_img,
        next_img,
        point,
        flow_guess,
    ):
        """Track a single point using iterative Lucas-Kanade.

        Uses the template (previous frame) gradients for the structure tensor,
        so the tensor is computed once and reused across iterations.
        """
        # Extract template window; bail out if the point is near the border.
        template, valid_template = self._extract_window(prev_img, point)
        if not valid_template:
            return flow_guess

        # Compute template gradients (Sobel) - returns tensors
        gradients = self._sobel_gradients(template)

        # Flatten gradients and accumulate the structure tensor entries.
        if self.ndim == 2:
            Iy, Ix = gradients
            Ix_flat = ops.reshape(Ix, [-1])
            Iy_flat = ops.reshape(Iy, [-1])

            # Structure tensor 2D components
            IxIx = ops.sum(Ix_flat * Ix_flat)
            IxIy = ops.sum(Ix_flat * Iy_flat)
            IyIy = ops.sum(Iy_flat * Iy_flat)

        else:  # 3D
            Iz, Iy, Ix = gradients
            Ix_flat = ops.reshape(Ix, [-1])
            Iy_flat = ops.reshape(Iy, [-1])
            Iz_flat = ops.reshape(Iz, [-1])

            # Structure tensor 3D components
            IxIx = ops.sum(Ix_flat * Ix_flat)
            IxIy = ops.sum(Ix_flat * Iy_flat)
            IxIz = ops.sum(Ix_flat * Iz_flat)
            IyIy = ops.sum(Iy_flat * Iy_flat)
            IyIz = ops.sum(Iy_flat * Iz_flat)
            IzIz = ops.sum(Iz_flat * Iz_flat)

        # Iterative refinement (keep as tensors)
        flow = flow_guess

        for _ in range(self.max_iterations):
            # Extract warped window from next image at the current estimate.
            warped_pt = point + flow
            warped, valid_warped = self._extract_window(next_img, warped_pt)

            if not valid_warped:
                break

            # Image difference (template minus warped patch).
            diff = template - warped
            diff_flat = ops.reshape(diff, [-1])

            # Solve the normal equations for the flow update.
            if self.ndim == 2:
                # Build structure tensor matrix (2x2)
                structure = ops.stack(
                    [
                        ops.stack([IxIx, IxIy]),
                        ops.stack([IxIy, IyIy]),
                    ],
                    axis=0,
                )
                # Add regularization to diagonal to keep the system well-posed.
                structure = structure + ops.eye(2, dtype=structure.dtype) * 1e-5

                # Right-hand side vector
                b_x = ops.sum(Ix_flat * diff_flat)
                b_y = ops.sum(Iy_flat * diff_flat)
                rhs = ops.reshape(ops.stack([b_x, b_y]), (2, 1))

                # Solve: structure * delta_xy = rhs.
                # Use a direct solve rather than an explicit inverse for
                # numerical stability.
                delta_xy = ops.linalg.solve(structure, rhs)
                delta_xy = ops.reshape(delta_xy, (2,))

                # Reorder to (y, x)
                delta = ops.stack([delta_xy[1], delta_xy[0]])

            else:  # 3D
                # Build structure tensor matrix (3x3)
                structure = ops.stack(
                    [
                        ops.stack([IxIx, IxIy, IxIz]),
                        ops.stack([IxIy, IyIy, IyIz]),
                        ops.stack([IxIz, IyIz, IzIz]),
                    ],
                    axis=0,
                )
                # Add regularization to diagonal to keep the system well-posed.
                structure = structure + ops.eye(3, dtype=structure.dtype) * 1e-5

                # Right-hand side vector
                b_x = ops.sum(Ix_flat * diff_flat)
                b_y = ops.sum(Iy_flat * diff_flat)
                b_z = ops.sum(Iz_flat * diff_flat)
                rhs = ops.reshape(ops.stack([b_x, b_y, b_z]), (3, 1))

                # Solve: structure * delta_xyz = rhs (direct solve, no inverse).
                delta_xyz = ops.linalg.solve(structure, rhs)
                delta_xyz = ops.reshape(delta_xyz, (3,))

                # Reorder to (z, y, x)
                delta = ops.stack([delta_xyz[2], delta_xyz[1], delta_xyz[0]])

            # Update flow
            flow = flow + delta

            # Check convergence
            delta_norm = ops.sqrt(ops.sum(delta * delta))
            if delta_norm < self.epsilon:
                break

        return flow

    def _extract_window(self, image, point):
        """Extract window around point with subpixel interpolation."""
        if self.ndim == 2:
            return self._extract_window_2d(image, point)
        elif self.ndim == 3:
            return self._extract_window_3d(image, point)
        else:
            raise ValueError(f"Unsupported ndim: {self.ndim}")

    def _extract_window_2d(self, image, point):
        """Extract 2D window with bilinear interpolation using map_coordinates.

        Returns:
            (window, valid): the patch and a flag that is False when the window
            would fall outside the image bounds.
        """
        hy, hx = self.half_win
        h, w = ops.shape(image)[0], ops.shape(image)[1]

        py, px = point[0], point[1]

        # Bounds check: reject points whose window would leave the image.
        if ops.any(
            ops.stack(
                [
                    py < hy + 1,
                    py >= ops.cast(h, py.dtype) - hy - 1,
                    px < hx + 1,
                    px >= ops.cast(w, px.dtype) - hx - 1,
                ]
            )
        ):
            return ops.zeros((2 * hy + 1, 2 * hx + 1), dtype="float32"), False

        # Create coordinate grid for the window, centered at the point.
        y_coords = ops.arange(2 * hy + 1, dtype="float32") + py - hy
        x_coords = ops.arange(2 * hx + 1, dtype="float32") + px - hx
        grid_y, grid_x = ops.meshgrid(y_coords, x_coords, indexing="ij")

        # Stack coordinates for map_coordinates
        coords = ops.stack([grid_y, grid_x], axis=0)

        # Extract window using bilinear interpolation
        window = ops.image.map_coordinates(image, coords, order=1)

        return window, True

    def _extract_window_3d(self, image, point):
        """Extract 3D window with trilinear interpolation using map_coordinates.

        Returns:
            (window, valid): the patch and a flag that is False when the window
            would fall outside the volume bounds.
        """
        hz, hy, hx = self.half_win
        d, h, w = ops.shape(image)[0], ops.shape(image)[1], ops.shape(image)[2]

        pz, py, px = point[0], point[1], point[2]

        # Bounds check: reject points whose window would leave the volume.
        if ops.any(
            ops.stack(
                [
                    pz < hz + 1,
                    pz >= ops.cast(d, pz.dtype) - hz - 1,
                    py < hy + 1,
                    py >= ops.cast(h, py.dtype) - hy - 1,
                    px < hx + 1,
                    px >= ops.cast(w, px.dtype) - hx - 1,
                ]
            )
        ):
            return ops.zeros((2 * hz + 1, 2 * hy + 1, 2 * hx + 1), dtype="float32"), False

        # Create coordinate grid for the window, centered at the point.
        z_coords = ops.arange(2 * hz + 1, dtype="float32") + pz - hz
        y_coords = ops.arange(2 * hy + 1, dtype="float32") + py - hy
        x_coords = ops.arange(2 * hx + 1, dtype="float32") + px - hx
        grid_z, grid_y, grid_x = ops.meshgrid(z_coords, y_coords, x_coords, indexing="ij")

        # Stack coordinates for map_coordinates
        coords = ops.stack([grid_z, grid_y, grid_x], axis=0)

        # Extract window using trilinear interpolation
        window = ops.image.map_coordinates(image, coords, order=1)

        return window, True

    def _sobel_gradients(self, image):
        """Compute Sobel gradients for 2D or 3D."""
        if self.ndim == 2:
            return self._sobel_gradients_2d(image)
        elif self.ndim == 3:
            return self._sobel_gradients_3d(image)
        else:
            raise ValueError(f"Unsupported ndim: {self.ndim}")

    def _sobel_gradients_2d(self, image):
        """Compute 2D Sobel gradients using keras.ops.

        Returns:
            (Iy, Ix): gradients along rows and columns, same shape as image.
        """
        # Standard Sobel kernels (normalized by 8 so responses stay in range).
        sobel_y = ops.convert_to_tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]], dtype="float32") / 8.0
        sobel_x = ops.convert_to_tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype="float32") / 8.0

        h, w = ops.shape(image)[0], ops.shape(image)[1]

        # Reflect-pad by 1 so the valid convolution preserves shape.
        padded = ops.pad(image, [[1, 1], [1, 1]], mode="reflect")

        # Reshape for conv: image needs (batch, height, width, channels)
        img_4d = ops.reshape(padded, [1, h + 2, w + 2, 1])
        sobel_y_4d = ops.reshape(sobel_y, [3, 3, 1, 1])
        sobel_x_4d = ops.reshape(sobel_x, [3, 3, 1, 1])

        Iy_4d = ops.conv(img_4d, sobel_y_4d, padding="valid")
        Ix_4d = ops.conv(img_4d, sobel_x_4d, padding="valid")

        # Reshape back to 2D
        Iy = ops.reshape(Iy_4d, [h, w])
        Ix = ops.reshape(Ix_4d, [h, w])

        return Iy, Ix

    def _sobel_gradients_3d(self, image):
        """Compute 3D Sobel gradients using keras.ops.

        Returns:
            (Iz, Iy, Ix): gradients along depth, rows, and columns.
        """
        # 3D Sobel kernels (separable: smooth in 2 dims, gradient in 1 dim)
        # Gradient in z-direction
        sobel_z = (
            ops.convert_to_tensor(
                [
                    [[-1, -2, -1], [-2, -4, -2], [-1, -2, -1]],
                    [[0, 0, 0], [0, 0, 0], [0, 0, 0]],
                    [[1, 2, 1], [2, 4, 2], [1, 2, 1]],
                ],
                dtype="float32",
            )
            / 32.0
        )

        # Gradient in y-direction
        sobel_y = (
            ops.convert_to_tensor(
                [
                    [[-1, -2, -1], [0, 0, 0], [1, 2, 1]],
                    [[-2, -4, -2], [0, 0, 0], [2, 4, 2]],
                    [[-1, -2, -1], [0, 0, 0], [1, 2, 1]],
                ],
                dtype="float32",
            )
            / 32.0
        )

        # Gradient in x-direction
        sobel_x = (
            ops.convert_to_tensor(
                [
                    [[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]],
                    [[-2, 0, 2], [-4, 0, 4], [-2, 0, 2]],
                    [[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]],
                ],
                dtype="float32",
            )
            / 32.0
        )

        d, h, w = ops.shape(image)[0], ops.shape(image)[1], ops.shape(image)[2]

        # Reflect-pad by 1 so the valid convolution preserves shape.
        padded = ops.pad(image, [[1, 1], [1, 1], [1, 1]], mode="reflect")

        # Reshape for conv: image needs (batch, depth, height, width, channels)
        img_5d = ops.reshape(padded, [1, d + 2, h + 2, w + 2, 1])
        sobel_z_5d = ops.reshape(sobel_z, [3, 3, 3, 1, 1])
        sobel_y_5d = ops.reshape(sobel_y, [3, 3, 3, 1, 1])
        sobel_x_5d = ops.reshape(sobel_x, [3, 3, 3, 1, 1])

        # Apply 3D convolution with 'valid' padding (we pre-padded)
        Iz_5d = ops.conv(img_5d, sobel_z_5d, padding="valid")
        Iy_5d = ops.conv(img_5d, sobel_y_5d, padding="valid")
        Ix_5d = ops.conv(img_5d, sobel_x_5d, padding="valid")

        # Reshape back to 3D
        Iz = ops.reshape(Iz_5d, [d, h, w])
        Iy = ops.reshape(Iy_5d, [d, h, w])
        Ix = ops.reshape(Ix_5d, [d, h, w])

        return (Iz, Iy, Ix)

    def __repr__(self):
        """String representation."""
        return (
            f"LucasKanadeTracker(win_size={self.win_size}, max_level={self.max_level}, "
            f"max_iterations={self.max_iterations}, epsilon={self.epsilon})"
        )
@@ -0,0 +1,110 @@
1
+ """Segmentation-based tracker using contour matching.
2
+
3
+ .. seealso::
4
+ A tutorial notebook where this model is used:
5
+ :doc:`../notebooks/models/speckle_tracking_example`.
6
+
7
+ """
8
+
9
+ from keras import ops
10
+
11
+ from zea.tensor_ops import find_contour
12
+
13
+ from .base import BaseTracker
14
+
15
+
16
class SegmentationTracker(BaseTracker):
    """Segmentation-based tracker.

    This tracker segments each frame independently and finds the closest points
    on the segmented contour to the previous frame's points.

    Args:
        model: Segmentation model with a `call` method.
        preprocess_fn: Optional preprocessing function to apply to frames before segmentation.
        postprocess_fn: Optional postprocessing function to apply to segmentation output, which
            should return a binary mask of the target structure.

    """

    def __init__(
        self,
        model,
        preprocess_fn: callable = None,
        postprocess_fn: callable = None,
    ):
        """Initialize segmentation-based tracker."""
        super().__init__(ndim=2)
        self.model = model
        # Default preprocessing is the identity; postprocessing is mandatory
        # because the raw model output format is unknown to this tracker.
        self.preprocess_fn = preprocess_fn if preprocess_fn is not None else (lambda frame: frame)
        self.postprocess_fn = postprocess_fn

        if self.postprocess_fn is None:
            raise ValueError("A postprocess_fn must be provided to extract binary masks.")

    def track(
        self,
        prev_frame,
        next_frame,
        points,
    ):
        """
        Track points by segmenting next_frame and finding closest contour points.

        Args:
            prev_frame: Previous frame (not used, kept for interface compatibility).
            next_frame: Next frame to segment, shape (H, W).
            points: Points from previous frame, shape (N, 2) in (row, col) format.

        Returns:
            new_points: Closest points on next frame's contour, shape (N, 2).
        """
        frame_shape = ops.shape(next_frame)

        # Segment the new frame: preprocess -> model -> binary mask.
        prediction = self.model.call(self.preprocess_fn(next_frame))
        binary_mask = self.postprocess_fn(prediction, frame_shape)

        contour = find_contour(binary_mask)

        # If segmentation produced no contour, keep the points where they were.
        if ops.shape(contour)[0] > 0:
            return self._find_closest_points(points, contour)
        return points

    def _find_closest_points(self, query_points, target_points):
        """Find closest target points to each query point.

        Args:
            query_points: Points to match, shape (N, 2).
            target_points: Points to match to, shape (M, 2).

        Returns:
            Closest target points, shape (N, 2).
        """
        # Broadcast (N, 1, 2) against (1, M, 2) to get all pairwise offsets.
        offsets = ops.expand_dims(query_points, axis=1) - ops.expand_dims(target_points, axis=0)

        # Squared Euclidean distances, shape (N, M); no sqrt needed for argmin.
        dist_sq = ops.sum(offsets * offsets, axis=2)

        nearest_idx = ops.argmin(dist_sq, axis=1)

        return ops.take(target_points, nearest_idx, axis=0)

    def __repr__(self):
        """String representation."""
        return f"SegmentationTracker(model={self.model.__class__.__name__})"
zea/utils.py CHANGED
@@ -12,14 +12,23 @@ import yaml
12
12
  from zea import log
13
13
 
14
14
 
15
- def map_negative_indices(indices: list, length: int):
15
def canonicalize_axis(axis, num_dims) -> int:
    """Canonicalize an axis in [-num_dims, num_dims) to [0, num_dims)."""
    if -num_dims <= axis < num_dims:
        # Negative axes count from the end, mirroring NumPy semantics.
        return axis + num_dims if axis < 0 else axis
    raise ValueError(f"axis {axis} is out of bounds for array of dimension {num_dims}")
22
+
23
+
24
def map_negative_indices(indices: list, num_dims: int):
    """Maps negative indices for array indexing to positive indices.

    Example:
        >>> from zea.utils import map_negative_indices
        >>> map_negative_indices([-1, -2], 5)
        [4, 3]
    """
    mapped = []
    for idx in indices:
        # Same canonicalization as canonicalize_axis, applied per index.
        if not -num_dims <= idx < num_dims:
            raise ValueError(f"axis {idx} is out of bounds for array of dimension {num_dims}")
        mapped.append(idx + num_dims if idx < 0 else idx)
    return mapped
23
32
 
24
33
 
25
34
  def print_clear_line():
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: zea
3
- Version: 0.0.7
3
+ Version: 0.0.8
4
4
  Summary: A Toolbox for Cognitive Ultrasound Imaging. Provides a set of tools for processing of ultrasound data, all built in your favorite machine learning framework.
5
5
  License-File: LICENSE
6
6
  Keywords: ultrasound,machine learning,beamforming
@@ -106,6 +106,8 @@ Description-Content-Type: text/markdown
106
106
  [![License](https://img.shields.io/github/license/tue-bmd/zea)](https://github.com/tue-bmd/zea/blob/main/LICENSE)
107
107
  [![codecov](https://codecov.io/gh/tue-bmd/zea/branch/main/graph/badge.svg)](https://codecov.io/gh/tue-bmd/zea)
108
108
  [![status](https://joss.theoj.org/papers/fa923917ca41761fe0623ca6c350017d/status.svg)](https://joss.theoj.org/papers/fa923917ca41761fe0623ca6c350017d)
109
+ [![arXiv](https://img.shields.io/badge/arXiv-B31B1B?style=flat&logo=arXiv&logoColor=white)](https://arxiv.org/abs/2512.01433)
110
+ [![Hugging Face](https://img.shields.io/badge/Hugging%20Face-FFD21E?logo=huggingface&logoColor=black)](https://huggingface.co/zeahub)
109
111
  [![GitHub stars](https://img.shields.io/github/stars/tue-bmd/zea?style=social)](https://github.com/tue-bmd/zea/stargazers)
110
112
 
111
113
  Welcome to the `zea` package: *A Toolbox for Cognitive Ultrasound Imaging.*