zea 0.0.7__py3-none-any.whl → 0.0.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- zea/__init__.py +1 -1
- zea/backend/tensorflow/dataloader.py +0 -4
- zea/beamform/pixelgrid.py +1 -1
- zea/data/__init__.py +0 -9
- zea/data/augmentations.py +221 -28
- zea/data/convert/__init__.py +1 -6
- zea/data/convert/__main__.py +123 -0
- zea/data/convert/camus.py +99 -39
- zea/data/convert/echonet.py +183 -82
- zea/data/convert/echonetlvh/README.md +2 -3
- zea/data/convert/echonetlvh/{convert_raw_to_usbmd.py → __init__.py} +173 -102
- zea/data/convert/echonetlvh/manual_rejections.txt +73 -0
- zea/data/convert/echonetlvh/precompute_crop.py +43 -64
- zea/data/convert/picmus.py +37 -40
- zea/data/convert/utils.py +86 -0
- zea/data/convert/{matlab.py → verasonics.py} +33 -61
- zea/data/data_format.py +124 -4
- zea/data/dataloader.py +12 -7
- zea/data/datasets.py +109 -70
- zea/data/file.py +91 -82
- zea/data/file_operations.py +496 -0
- zea/data/preset_utils.py +1 -1
- zea/display.py +7 -8
- zea/internal/checks.py +6 -12
- zea/internal/operators.py +4 -0
- zea/io_lib.py +108 -160
- zea/models/__init__.py +1 -1
- zea/models/diffusion.py +62 -11
- zea/models/lv_segmentation.py +2 -0
- zea/ops.py +398 -158
- zea/scan.py +18 -8
- zea/tensor_ops.py +82 -62
- zea/tools/fit_scan_cone.py +90 -160
- zea/tracking/__init__.py +16 -0
- zea/tracking/base.py +94 -0
- zea/tracking/lucas_kanade.py +474 -0
- zea/tracking/segmentation.py +110 -0
- zea/utils.py +11 -2
- {zea-0.0.7.dist-info → zea-0.0.8.dist-info}/METADATA +3 -1
- {zea-0.0.7.dist-info → zea-0.0.8.dist-info}/RECORD +43 -35
- {zea-0.0.7.dist-info → zea-0.0.8.dist-info}/WHEEL +0 -0
- {zea-0.0.7.dist-info → zea-0.0.8.dist-info}/entry_points.txt +0 -0
- {zea-0.0.7.dist-info → zea-0.0.8.dist-info}/licenses/LICENSE +0 -0
zea/__init__.py
CHANGED
|
@@ -155,10 +155,6 @@ def make_dataloader(
|
|
|
155
155
|
Mimics the native TF function ``tf.keras.utils.image_dataset_from_directory``
|
|
156
156
|
but for .hdf5 files.
|
|
157
157
|
|
|
158
|
-
Saves a dataset_info.yaml file in the directory with information about the dataset.
|
|
159
|
-
This file is used to load the dataset later on, which speeds up the initial loading
|
|
160
|
-
of the dataset for very large datasets.
|
|
161
|
-
|
|
162
158
|
Does the following in order to load a dataset:
|
|
163
159
|
|
|
164
160
|
- Find all .hdf5 files in the director(ies)
|
zea/beamform/pixelgrid.py
CHANGED
|
@@ -45,7 +45,7 @@ def cartesian_pixel_grid(xlims, zlims, grid_size_x=None, grid_size_z=None, dx=No
|
|
|
45
45
|
ValueError: Either grid_size_x and grid_size_z or dx and dz must be defined.
|
|
46
46
|
|
|
47
47
|
Returns:
|
|
48
|
-
grid (np.ndarray): Pixel grid of size (grid_size_z,
|
|
48
|
+
grid (np.ndarray): Pixel grid of size (grid_size_z, grid_size_x, 3) in
|
|
49
49
|
Cartesian coordinates (x, y, z)
|
|
50
50
|
"""
|
|
51
51
|
assert (bool(grid_size_x) and bool(grid_size_z)) ^ (bool(dx) and bool(dz)), (
|
zea/data/__init__.py
CHANGED
|
@@ -38,15 +38,6 @@ Examples usage
|
|
|
38
38
|
... files.append(file) # process each file as needed
|
|
39
39
|
>>> dataset.close()
|
|
40
40
|
|
|
41
|
-
Subpackage layout
|
|
42
|
-
-----------------
|
|
43
|
-
|
|
44
|
-
- ``file.py``: Implements :class:`zea.File` and related file utilities.
|
|
45
|
-
- ``datasets.py``: Implements :class:`zea.Dataset` and folder management.
|
|
46
|
-
- ``dataloader.py``: Data loading utilities for batching and shuffling.
|
|
47
|
-
- ``data_format.py``: Data validation and example dataset generation.
|
|
48
|
-
- ``convert/``: Data conversion tools (e.g., from external formats).
|
|
49
|
-
|
|
50
41
|
""" # noqa: E501
|
|
51
42
|
|
|
52
43
|
from .convert.camus import sitk_load
|
zea/data/augmentations.py
CHANGED
|
@@ -30,7 +30,7 @@ class RandomCircleInclusion(layers.Layer):
|
|
|
30
30
|
|
|
31
31
|
def __init__(
|
|
32
32
|
self,
|
|
33
|
-
radius: int,
|
|
33
|
+
radius: int | tuple[int, int],
|
|
34
34
|
fill_value: float = 1.0,
|
|
35
35
|
circle_axes: tuple[int, int] = (1, 2),
|
|
36
36
|
with_batch_dim=True,
|
|
@@ -38,25 +38,70 @@ class RandomCircleInclusion(layers.Layer):
|
|
|
38
38
|
recovery_threshold=0.1,
|
|
39
39
|
randomize_location_across_batch=True,
|
|
40
40
|
seed=None,
|
|
41
|
+
width_range: tuple[int, int] = None,
|
|
42
|
+
height_range: tuple[int, int] = None,
|
|
41
43
|
**kwargs,
|
|
42
44
|
):
|
|
43
45
|
"""
|
|
44
46
|
Initialize RandomCircleInclusion.
|
|
45
47
|
|
|
46
48
|
Args:
|
|
47
|
-
radius (int): Radius of the circle to include.
|
|
49
|
+
radius (int or tuple[int, int]): Radius of the circle, or the two radii (rx, ry) of the ellipse, to include.
|
|
48
50
|
fill_value (float): Value to fill inside the circle.
|
|
49
|
-
circle_axes (tuple[int, int]): Axes along which to draw the circle
|
|
51
|
+
circle_axes (tuple[int, int]): Axes along which to draw the circle
|
|
52
|
+
(height, width).
|
|
50
53
|
with_batch_dim (bool): Whether input has a batch dimension.
|
|
51
54
|
return_centers (bool): Whether to return circle centers along with images.
|
|
52
55
|
recovery_threshold (float): Threshold for considering a pixel as recovered.
|
|
53
|
-
randomize_location_across_batch (bool): If True
|
|
54
|
-
|
|
56
|
+
randomize_location_across_batch (bool): If True (and with_batch_dim=True),
|
|
57
|
+
each batch element gets a different random center. If False, all batch
|
|
58
|
+
elements share the same center.
|
|
55
59
|
seed (Any): Optional random seed for reproducibility.
|
|
60
|
+
width_range (tuple[int, int], optional): Range (min, max) for circle
|
|
61
|
+
center x (width axis).
|
|
62
|
+
height_range (tuple[int, int], optional): Range (min, max) for circle
|
|
63
|
+
center y (height axis).
|
|
56
64
|
**kwargs: Additional keyword arguments for the parent Layer.
|
|
65
|
+
|
|
66
|
+
Example:
|
|
67
|
+
.. doctest::
|
|
68
|
+
|
|
69
|
+
>>> from zea.data.augmentations import RandomCircleInclusion
|
|
70
|
+
>>> from keras import ops
|
|
71
|
+
|
|
72
|
+
>>> layer = RandomCircleInclusion(
|
|
73
|
+
... radius=5,
|
|
74
|
+
... circle_axes=(1, 2),
|
|
75
|
+
... with_batch_dim=True,
|
|
76
|
+
... )
|
|
77
|
+
>>> image = ops.zeros((1, 28, 28), dtype="float32")
|
|
78
|
+
>>> out = layer(image) # doctest: +SKIP
|
|
79
|
+
|
|
57
80
|
"""
|
|
58
81
|
super().__init__(**kwargs)
|
|
59
|
-
|
|
82
|
+
|
|
83
|
+
# randomize_location_across_batch=False only makes sense when there is a batch dim
|
|
84
|
+
if not with_batch_dim and not randomize_location_across_batch:
|
|
85
|
+
raise ValueError(
|
|
86
|
+
"randomize_location_across_batch=False is only applicable when "
|
|
87
|
+
"with_batch_dim=True. When with_batch_dim=False, there is no batch "
|
|
88
|
+
"to randomize across."
|
|
89
|
+
)
|
|
90
|
+
# Convert radius to tuple if int, else validate tuple
|
|
91
|
+
if isinstance(radius, int):
|
|
92
|
+
if radius <= 0:
|
|
93
|
+
raise ValueError(f"radius must be a positive integer, got {radius}.")
|
|
94
|
+
self.radius = (radius, radius)
|
|
95
|
+
elif isinstance(radius, tuple) and len(radius) == 2:
|
|
96
|
+
rx, ry = radius
|
|
97
|
+
if not all(isinstance(r, int) for r in (rx, ry)):
|
|
98
|
+
raise TypeError(f"radius tuple must contain two integers, got {radius!r}.")
|
|
99
|
+
if rx <= 0 or ry <= 0:
|
|
100
|
+
raise ValueError(f"radius components must be positive, got {radius!r}.")
|
|
101
|
+
self.radius = (rx, ry)
|
|
102
|
+
else:
|
|
103
|
+
raise TypeError("radius must be an int or a tuple of two ints")
|
|
104
|
+
|
|
60
105
|
self.fill_value = fill_value
|
|
61
106
|
self.circle_axes = circle_axes
|
|
62
107
|
self.with_batch_dim = with_batch_dim
|
|
@@ -64,6 +109,8 @@ class RandomCircleInclusion(layers.Layer):
|
|
|
64
109
|
self.recovery_threshold = recovery_threshold
|
|
65
110
|
self.randomize_location_across_batch = randomize_location_across_batch
|
|
66
111
|
self.seed = seed
|
|
112
|
+
self.width_range = width_range
|
|
113
|
+
self.height_range = height_range
|
|
67
114
|
self._axis1 = None
|
|
68
115
|
self._axis2 = None
|
|
69
116
|
self._perm = None
|
|
@@ -116,6 +163,43 @@ class RandomCircleInclusion(layers.Layer):
|
|
|
116
163
|
self._static_w = int(permuted_shape[-1])
|
|
117
164
|
self._static_shape = tuple(permuted_shape)
|
|
118
165
|
|
|
166
|
+
# Validate that ellipse can fit within image bounds
|
|
167
|
+
rx, ry = self.radius
|
|
168
|
+
min_required_width = 2 * rx + 1
|
|
169
|
+
min_required_height = 2 * ry + 1
|
|
170
|
+
|
|
171
|
+
if self._static_w < min_required_width:
|
|
172
|
+
raise ValueError(
|
|
173
|
+
f"Image width ({self._static_w}) is too small for radius {rx}. "
|
|
174
|
+
f"Minimum required width: {min_required_width}"
|
|
175
|
+
)
|
|
176
|
+
if self._static_h < min_required_height:
|
|
177
|
+
raise ValueError(
|
|
178
|
+
f"Image height ({self._static_h}) is too small for radius {ry}. "
|
|
179
|
+
f"Minimum required height: {min_required_height}"
|
|
180
|
+
)
|
|
181
|
+
|
|
182
|
+
# Validate width_range and height_range if provided
|
|
183
|
+
if self.width_range is not None:
|
|
184
|
+
min_x, max_x = self.width_range
|
|
185
|
+
if min_x >= max_x:
|
|
186
|
+
raise ValueError(f"width_range must have min < max, got {self.width_range}")
|
|
187
|
+
if min_x < rx or max_x > self._static_w - rx:
|
|
188
|
+
raise ValueError(
|
|
189
|
+
f"width_range {self.width_range} would place circle outside image bounds. "
|
|
190
|
+
f"Valid range: [{rx}, {self._static_w - rx})"
|
|
191
|
+
)
|
|
192
|
+
|
|
193
|
+
if self.height_range is not None:
|
|
194
|
+
min_y, max_y = self.height_range
|
|
195
|
+
if min_y >= max_y:
|
|
196
|
+
raise ValueError(f"height_range must have min < max, got {self.height_range}")
|
|
197
|
+
if min_y < ry or max_y > self._static_h - ry:
|
|
198
|
+
raise ValueError(
|
|
199
|
+
f"height_range {self.height_range} would place circle outside image bounds. "
|
|
200
|
+
f"Valid range: [{ry}, {self._static_h - ry})"
|
|
201
|
+
)
|
|
202
|
+
|
|
119
203
|
super().build(input_shape)
|
|
120
204
|
|
|
121
205
|
def compute_output_shape(self, input_shape):
|
|
@@ -165,7 +249,7 @@ class RandomCircleInclusion(layers.Layer):
|
|
|
165
249
|
centers (Tensor): Tensor of shape (batch, 2) with circle centers.
|
|
166
250
|
h (int): Height of the image.
|
|
167
251
|
w (int): Width of the image.
|
|
168
|
-
radius (int):
|
|
252
|
+
radius (tuple[int, int]): Radii of the ellipse (rx, ry).
|
|
169
253
|
dtype (str or dtype): Data type for the mask.
|
|
170
254
|
|
|
171
255
|
Returns:
|
|
@@ -176,12 +260,12 @@ class RandomCircleInclusion(layers.Layer):
|
|
|
176
260
|
Y, X = ops.meshgrid(Y, X, indexing="ij")
|
|
177
261
|
Y = ops.expand_dims(Y, 0) # (1, h, w)
|
|
178
262
|
X = ops.expand_dims(X, 0) # (1, h, w)
|
|
179
|
-
# cx = ops.cast(centers[:, 0], "float32")[:, None, None]
|
|
180
|
-
# cy = ops.cast(centers[:, 1], "float32")[:, None, None]
|
|
181
263
|
cx = centers[:, 0][:, None, None]
|
|
182
264
|
cy = centers[:, 1][:, None, None]
|
|
183
|
-
|
|
184
|
-
|
|
265
|
+
rx, ry = radius
|
|
266
|
+
# Ellipse equation: ((X-cx)/rx)^2 + ((Y-cy)/ry)^2 <= 1
|
|
267
|
+
dist = ((X - cx) / rx) ** 2 + ((Y - cy) / ry) ** 2
|
|
268
|
+
mask = ops.cast(dist <= 1, dtype)
|
|
185
269
|
return mask
|
|
186
270
|
|
|
187
271
|
def call(self, x, seed=None):
|
|
@@ -197,9 +281,17 @@ class RandomCircleInclusion(layers.Layer):
|
|
|
197
281
|
centers if return_centers is True.
|
|
198
282
|
"""
|
|
199
283
|
if keras.backend.backend() == "jax" and not is_jax_prng_key(seed):
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
|
|
284
|
+
if isinstance(seed, keras.random.SeedGenerator):
|
|
285
|
+
raise ValueError(
|
|
286
|
+
"When using JAX backend, please provide a jax.random.PRNGKey as seed, "
|
|
287
|
+
"instead of keras.random.SeedGenerator."
|
|
288
|
+
)
|
|
289
|
+
else:
|
|
290
|
+
raise TypeError(
|
|
291
|
+
f"When using JAX backend, seed must be a JAX PRNG key (created with "
|
|
292
|
+
f"jax.random.PRNGKey()), but got {type(seed)}. Note: jax.random.key() "
|
|
293
|
+
f"keys are not currently supported."
|
|
294
|
+
)
|
|
203
295
|
seed = seed if seed is not None else self.seed
|
|
204
296
|
|
|
205
297
|
if self.with_batch_dim:
|
|
@@ -209,22 +301,33 @@ class RandomCircleInclusion(layers.Layer):
|
|
|
209
301
|
imgs, centers = ops.map(lambda arg: self._call(arg, seed), x)
|
|
210
302
|
else:
|
|
211
303
|
raise NotImplementedError(
|
|
212
|
-
"You cannot fix circle locations across while using"
|
|
304
|
+
"You cannot fix circle locations across batch while using "
|
|
213
305
|
+ "RandomCircleInclusion as a dataset augmentation, "
|
|
214
306
|
+ "since samples in a batch are handled independently."
|
|
215
307
|
)
|
|
216
308
|
else:
|
|
309
|
+
batch_size = ops.shape(x)[0]
|
|
217
310
|
if self.randomize_location_across_batch:
|
|
218
|
-
batch_size = ops.shape(x)[0]
|
|
219
311
|
seeds = split_seed(seed, batch_size)
|
|
220
|
-
if all(
|
|
312
|
+
if all(s is seeds[0] for s in seeds):
|
|
221
313
|
imgs, centers = ops.map(lambda arg: self._call(arg, seeds[0]), x)
|
|
222
314
|
else:
|
|
223
315
|
imgs, centers = ops.map(
|
|
224
316
|
lambda args: self._call(args[0], args[1]), (x, seeds)
|
|
225
317
|
)
|
|
226
318
|
else:
|
|
227
|
-
|
|
319
|
+
# Generate one random center that will be used for all batch elements
|
|
320
|
+
img0, center0 = self._call(x[0], seed)
|
|
321
|
+
|
|
322
|
+
# Apply the same center to all batch elements
|
|
323
|
+
imgs_list, centers_list = [img0], [center0]
|
|
324
|
+
for i in range(1, batch_size):
|
|
325
|
+
img_aug, center_out = self._call_with_fixed_center(x[i], center0)
|
|
326
|
+
imgs_list.append(img_aug)
|
|
327
|
+
centers_list.append(center_out)
|
|
328
|
+
|
|
329
|
+
imgs = ops.stack(imgs_list, axis=0)
|
|
330
|
+
centers = ops.stack(centers_list, axis=0)
|
|
228
331
|
else:
|
|
229
332
|
imgs, centers = self._call(x, seed)
|
|
230
333
|
|
|
@@ -248,17 +351,28 @@ class RandomCircleInclusion(layers.Layer):
|
|
|
248
351
|
flat, flat_batch_size, h, w = self._flatten_batch_and_other_dims(x)
|
|
249
352
|
|
|
250
353
|
def _draw_circle_2d(img2d):
|
|
354
|
+
rx, ry = self.radius
|
|
355
|
+
# Determine allowed ranges for center
|
|
356
|
+
if self.width_range is not None:
|
|
357
|
+
min_x, max_x = self.width_range
|
|
358
|
+
else:
|
|
359
|
+
min_x, max_x = rx, w - rx
|
|
360
|
+
if self.height_range is not None:
|
|
361
|
+
min_y, max_y = self.height_range
|
|
362
|
+
else:
|
|
363
|
+
min_y, max_y = ry, h - ry
|
|
364
|
+
# Draw the center coordinates from the allowed ranges (uniform sample, cast to int32)
|
|
251
365
|
cx = ops.cast(
|
|
252
|
-
keras.random.uniform((),
|
|
366
|
+
keras.random.uniform((), min_x, max_x, seed=seed),
|
|
253
367
|
"int32",
|
|
254
368
|
)
|
|
255
369
|
new_seed, _ = split_seed(seed, 2) # ensure that cx and cy are independent
|
|
256
370
|
cy = ops.cast(
|
|
257
|
-
keras.random.uniform((),
|
|
371
|
+
keras.random.uniform((), min_y, max_y, seed=new_seed),
|
|
258
372
|
"int32",
|
|
259
373
|
)
|
|
260
374
|
mask = self._make_circle_mask(
|
|
261
|
-
ops.stack([cx, cy])[None, :], h, w,
|
|
375
|
+
ops.stack([cx, cy])[None, :], h, w, (rx, ry), img2d.dtype
|
|
262
376
|
)[0]
|
|
263
377
|
img_aug = img2d * (1 - mask) + self.fill_value * mask
|
|
264
378
|
center = ops.stack([cx, cy])
|
|
@@ -271,6 +385,67 @@ class RandomCircleInclusion(layers.Layer):
|
|
|
271
385
|
centers = ops.reshape(centers, centers_shape)
|
|
272
386
|
return (aug_imgs, centers)
|
|
273
387
|
|
|
388
|
+
def _apply_circle_mask(self, flat, center, h, w):
|
|
389
|
+
"""Apply circle mask to flattened image data.
|
|
390
|
+
|
|
391
|
+
Args:
|
|
392
|
+
flat (Tensor): Flattened image data of shape (flat_batch, h, w).
|
|
393
|
+
center (Tensor): Center coordinates, either (2,) or (flat_batch, 2).
|
|
394
|
+
h (int): Height of images.
|
|
395
|
+
w (int): Width of images.
|
|
396
|
+
|
|
397
|
+
Returns:
|
|
398
|
+
Tensor: Augmented images with circle applied.
|
|
399
|
+
"""
|
|
400
|
+
rx, ry = self.radius
|
|
401
|
+
|
|
402
|
+
# Ensure center has batch dimension for broadcasting
|
|
403
|
+
if len(center.shape) == 1:
|
|
404
|
+
# Single center (2,) -> broadcast to all slices
|
|
405
|
+
center_batched = ops.tile(ops.reshape(center, [1, 2]), [flat.shape[0], 1])
|
|
406
|
+
else:
|
|
407
|
+
# Already batched (flat_batch, 2)
|
|
408
|
+
center_batched = center
|
|
409
|
+
|
|
410
|
+
# Create one mask per slice; _make_circle_mask broadcasts the centers over the (h, w) grid
|
|
411
|
+
masks = self._make_circle_mask(center_batched, h, w, (rx, ry), flat.dtype)
|
|
412
|
+
|
|
413
|
+
# Apply masks
|
|
414
|
+
aug_imgs = flat * (1 - masks) + self.fill_value * masks
|
|
415
|
+
return aug_imgs
|
|
416
|
+
|
|
417
|
+
def _call_with_fixed_center(self, x, fixed_center):
|
|
418
|
+
"""Apply augmentation using a pre-determined center.
|
|
419
|
+
|
|
420
|
+
Args:
|
|
421
|
+
x (Tensor): Input image tensor.
|
|
422
|
+
fixed_center (Tensor): Pre-determined center coordinates, either (2,)
|
|
423
|
+
for a single center or (flat_batch, 2) for per-slice centers.
|
|
424
|
+
|
|
425
|
+
Returns:
|
|
426
|
+
tuple: (augmented image, center coordinates).
|
|
427
|
+
"""
|
|
428
|
+
x = self._permute_axes_to_circle_last(x)
|
|
429
|
+
flat, flat_batch_size, h, w = self._flatten_batch_and_other_dims(x)
|
|
430
|
+
|
|
431
|
+
# Apply circle mask with fixed center (handles both single and batched centers)
|
|
432
|
+
aug_imgs = self._apply_circle_mask(flat, fixed_center, h, w)
|
|
433
|
+
aug_imgs = ops.reshape(aug_imgs, x.shape)
|
|
434
|
+
aug_imgs = ops.transpose(aug_imgs, axes=self._inv_perm)
|
|
435
|
+
|
|
436
|
+
# Return centers matching the expected shape
|
|
437
|
+
if len(fixed_center.shape) == 1:
|
|
438
|
+
# Single center (2,) -> broadcast to match flat_batch_size
|
|
439
|
+
if flat_batch_size == 1:
|
|
440
|
+
centers = fixed_center
|
|
441
|
+
else:
|
|
442
|
+
centers = ops.tile(ops.reshape(fixed_center, [1, 2]), [flat_batch_size, 1])
|
|
443
|
+
else:
|
|
444
|
+
# Already batched centers (flat_batch, 2)
|
|
445
|
+
centers = fixed_center
|
|
446
|
+
|
|
447
|
+
return (aug_imgs, centers)
|
|
448
|
+
|
|
274
449
|
def get_config(self):
|
|
275
450
|
"""
|
|
276
451
|
Get layer configuration for serialization.
|
|
@@ -285,6 +460,8 @@ class RandomCircleInclusion(layers.Layer):
|
|
|
285
460
|
"fill_value": self.fill_value,
|
|
286
461
|
"circle_axes": self.circle_axes,
|
|
287
462
|
"return_centers": self.return_centers,
|
|
463
|
+
"width_range": self.width_range,
|
|
464
|
+
"height_range": self.height_range,
|
|
288
465
|
}
|
|
289
466
|
)
|
|
290
467
|
return cfg
|
|
@@ -293,7 +470,8 @@ class RandomCircleInclusion(layers.Layer):
|
|
|
293
470
|
self, images, centers, recovery_threshold, fill_value=None
|
|
294
471
|
):
|
|
295
472
|
"""
|
|
296
|
-
Evaluate the percentage of the true circle that has been recovered in the images
|
|
473
|
+
Evaluate the percentage of the true circle that has been recovered in the images,
|
|
474
|
+
and return a mask of the detected part of the circle.
|
|
297
475
|
|
|
298
476
|
Args:
|
|
299
477
|
images (Tensor): Tensor of images (any shape, with circle axes as specified).
|
|
@@ -302,8 +480,12 @@ class RandomCircleInclusion(layers.Layer):
|
|
|
302
480
|
fill_value (float, optional): Optionally override fill_value for cases
|
|
303
481
|
where image range has changed.
|
|
304
482
|
|
|
305
|
-
|
|
306
|
-
Tensor
|
|
483
|
+
Returns:
|
|
484
|
+
Tuple[Tensor, Tensor]:
|
|
485
|
+
- percent_recovered: [batch] - average recovery percentage per batch element,
|
|
486
|
+
averaged across all non-batch dimensions (e.g., frames, slices)
|
|
487
|
+
- recovered_masks: [batch, flat_batch, h, w] or [batch, h, w] or [flat_batch, h, w]
|
|
488
|
+
depending on input shape - binary mask of detected circle regions
|
|
307
489
|
"""
|
|
308
490
|
fill_value = fill_value or self.fill_value
|
|
309
491
|
|
|
@@ -318,12 +500,23 @@ class RandomCircleInclusion(layers.Layer):
|
|
|
318
500
|
recovered_sum = ops.sum(recovered, axis=[1, 2])
|
|
319
501
|
mask_sum = ops.sum(mask, axis=[1, 2])
|
|
320
502
|
percent_recovered = recovered_sum / (mask_sum + 1e-8)
|
|
321
|
-
|
|
503
|
+
# recovered_mask: binary mask of detected part of the circle
|
|
504
|
+
recovered_mask = ops.cast(recovered > 0, flat_image.dtype)
|
|
505
|
+
return percent_recovered, recovered_mask
|
|
322
506
|
|
|
323
507
|
if self.with_batch_dim:
|
|
324
|
-
|
|
508
|
+
results = ops.vectorized_map(
|
|
325
509
|
lambda args: _evaluate_recovered_circle_accuracy(args[0], args[1]),
|
|
326
510
|
(images, centers),
|
|
327
|
-
)
|
|
511
|
+
)
|
|
512
|
+
percent_recovered, recovered_masks = results
|
|
513
|
+
# If there are multiple circles per batch element (e.g., multiple frames/slices),
|
|
514
|
+
# take the mean across all non-batch dimensions to get one value per batch element
|
|
515
|
+
if len(percent_recovered.shape) > 1:
|
|
516
|
+
# Average over all axes except the batch dimension (axis 0)
|
|
517
|
+
axes_to_reduce = tuple(range(1, len(percent_recovered.shape)))
|
|
518
|
+
percent_recovered = ops.mean(percent_recovered, axis=axes_to_reduce)
|
|
519
|
+
return percent_recovered, recovered_masks
|
|
328
520
|
else:
|
|
329
|
-
|
|
521
|
+
percent_recovered, recovered_mask = _evaluate_recovered_circle_accuracy(images, centers)
|
|
522
|
+
return percent_recovered, recovered_mask
|
zea/data/convert/__init__.py
CHANGED
|
@@ -1,6 +1 @@
|
|
|
1
|
-
"""
|
|
2
|
-
|
|
3
|
-
from .camus import convert_camus
|
|
4
|
-
from .images import convert_image_dataset
|
|
5
|
-
from .matlab import zea_from_matlab_raw
|
|
6
|
-
from .picmus import convert_picmus
|
|
1
|
+
"""Data conversion of datasets to the ``zea`` data format."""
|
|
@@ -0,0 +1,123 @@
|
|
|
1
|
+
import argparse
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
def get_parser():
|
|
5
|
+
"""
|
|
6
|
+
Build the command-line argument parser for converting raw datasets to a zea dataset.
|
|
7
|
+
|
|
8
|
+
Returns:
|
|
9
|
+
argparse.ArgumentParser: Parser; calling ``parse_args()`` yields a namespace with these attributes:
|
|
10
|
+
dataset (str): One of "echonet", "echonetlvh", "camus", "picmus", "verasonics".
|
|
11
|
+
src (str): Source folder path.
|
|
12
|
+
dst (str): Destination folder path.
|
|
13
|
+
split_path (str|None): Optional path to a split.yaml to copy dataset splits.
|
|
14
|
+
no_hyperthreading (bool): Disable hyperthreading for multiprocessing.
|
|
15
|
+
frames (list[str]): Verasonics frames spec (e.g., ["all"], integers, or ranges like "4-8").
|
|
16
|
+
no_rejection (bool): EchonetLVH flag to skip manual_rejections.txt filtering.
|
|
17
|
+
batch (str|None): EchonetLVH Batch directory to process (e.g., "Batch2").
|
|
18
|
+
convert_measurements (bool): EchonetLVH flag to convert only measurements CSV.
|
|
19
|
+
convert_images (bool): EchonetLVH flag to convert only image files.
|
|
20
|
+
max_files (int|None): EchonetLVH maximum number of files to process.
|
|
21
|
+
force (bool): EchonetLVH flag to force recomputation even if parameters exist.
|
|
22
|
+
"""
|
|
23
|
+
parser = argparse.ArgumentParser(description="Convert raw data to a zea dataset.")
|
|
24
|
+
parser.add_argument(
|
|
25
|
+
"dataset",
|
|
26
|
+
choices=["echonet", "echonetlvh", "camus", "picmus", "verasonics"],
|
|
27
|
+
help="Raw dataset to convert",
|
|
28
|
+
)
|
|
29
|
+
parser.add_argument("src", type=str, help="Source folder path")
|
|
30
|
+
parser.add_argument("dst", type=str, help="Destination folder path")
|
|
31
|
+
parser.add_argument(
|
|
32
|
+
"--split_path",
|
|
33
|
+
type=str,
|
|
34
|
+
help="Path to the split.yaml file containing the dataset split if a split should be copied",
|
|
35
|
+
)
|
|
36
|
+
parser.add_argument(
|
|
37
|
+
"--no_hyperthreading",
|
|
38
|
+
action="store_true",
|
|
39
|
+
help="Disable hyperthreading for multiprocessing",
|
|
40
|
+
)
|
|
41
|
+
# Dataset specific arguments:
|
|
42
|
+
|
|
43
|
+
# verasonics
|
|
44
|
+
parser.add_argument(
|
|
45
|
+
"--frames",
|
|
46
|
+
default=["all"],
|
|
47
|
+
type=str,
|
|
48
|
+
nargs="+",
|
|
49
|
+
help="verasonics: The frames to add to the file. This can be a list of integers, a range "
|
|
50
|
+
"of integers (e.g. 4-8), or 'all'.",
|
|
51
|
+
)
|
|
52
|
+
# ECHONET_LVH
|
|
53
|
+
parser.add_argument(
|
|
54
|
+
"--no_rejection",
|
|
55
|
+
action="store_true",
|
|
56
|
+
help="EchonetLVH: Do not reject sequences in manual_rejections.txt",
|
|
57
|
+
)
|
|
58
|
+
|
|
59
|
+
parser.add_argument(
|
|
60
|
+
"--batch",
|
|
61
|
+
type=str,
|
|
62
|
+
default=None,
|
|
63
|
+
help="EchonetLVH: Specify which BatchX directory to process, e.g. --batch=Batch2",
|
|
64
|
+
)
|
|
65
|
+
parser.add_argument(
|
|
66
|
+
"--convert_measurements",
|
|
67
|
+
action="store_true",
|
|
68
|
+
help="EchonetLVH: Only convert measurements CSV file",
|
|
69
|
+
)
|
|
70
|
+
parser.add_argument(
|
|
71
|
+
"--convert_images", action="store_true", help="EchonetLVH: Only convert image files"
|
|
72
|
+
)
|
|
73
|
+
parser.add_argument(
|
|
74
|
+
"--max_files",
|
|
75
|
+
type=int,
|
|
76
|
+
default=None,
|
|
77
|
+
help="EchonetLVH: Maximum number of files to process (for testing)",
|
|
78
|
+
)
|
|
79
|
+
parser.add_argument(
|
|
80
|
+
"--force",
|
|
81
|
+
action="store_true",
|
|
82
|
+
help="EchonetLVH: Force recomputation even if parameters already exist",
|
|
83
|
+
)
|
|
84
|
+
return parser
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def main():
|
|
88
|
+
"""
|
|
89
|
+
Parse command-line arguments and dispatch to the selected dataset conversion routine.
|
|
90
|
+
|
|
91
|
+
This function obtains CLI arguments via get_parser().parse_args() and calls the corresponding converter
|
|
92
|
+
(convert_echonet, convert_echonetlvh, convert_camus, convert_picmus, or convert_verasonics)
|
|
93
|
+
based on args.dataset.
|
|
94
|
+
Raises a ValueError if args.dataset is not one of the supported choices.
|
|
95
|
+
"""
|
|
96
|
+
parser = get_parser()
|
|
97
|
+
args = parser.parse_args()
|
|
98
|
+
if args.dataset == "echonet":
|
|
99
|
+
from zea.data.convert.echonet import convert_echonet
|
|
100
|
+
|
|
101
|
+
convert_echonet(args)
|
|
102
|
+
elif args.dataset == "echonetlvh":
|
|
103
|
+
from zea.data.convert.echonetlvh import convert_echonetlvh
|
|
104
|
+
|
|
105
|
+
convert_echonetlvh(args)
|
|
106
|
+
elif args.dataset == "camus":
|
|
107
|
+
from zea.data.convert.camus import convert_camus
|
|
108
|
+
|
|
109
|
+
convert_camus(args)
|
|
110
|
+
elif args.dataset == "picmus":
|
|
111
|
+
from zea.data.convert.picmus import convert_picmus
|
|
112
|
+
|
|
113
|
+
convert_picmus(args)
|
|
114
|
+
elif args.dataset == "verasonics":
|
|
115
|
+
from zea.data.convert.verasonics import convert_verasonics
|
|
116
|
+
|
|
117
|
+
convert_verasonics(args)
|
|
118
|
+
else:
|
|
119
|
+
raise ValueError(f"Unknown dataset: {args.dataset}")
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
if __name__ == "__main__":
|
|
123
|
+
main()
|