zea 0.0.7__py3-none-any.whl → 0.0.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. zea/__init__.py +1 -1
  2. zea/backend/tensorflow/dataloader.py +0 -4
  3. zea/beamform/pixelgrid.py +1 -1
  4. zea/data/__init__.py +0 -9
  5. zea/data/augmentations.py +221 -28
  6. zea/data/convert/__init__.py +1 -6
  7. zea/data/convert/__main__.py +123 -0
  8. zea/data/convert/camus.py +99 -39
  9. zea/data/convert/echonet.py +183 -82
  10. zea/data/convert/echonetlvh/README.md +2 -3
  11. zea/data/convert/echonetlvh/{convert_raw_to_usbmd.py → __init__.py} +173 -102
  12. zea/data/convert/echonetlvh/manual_rejections.txt +73 -0
  13. zea/data/convert/echonetlvh/precompute_crop.py +43 -64
  14. zea/data/convert/picmus.py +37 -40
  15. zea/data/convert/utils.py +86 -0
  16. zea/data/convert/{matlab.py → verasonics.py} +33 -61
  17. zea/data/data_format.py +124 -4
  18. zea/data/dataloader.py +12 -7
  19. zea/data/datasets.py +109 -70
  20. zea/data/file.py +91 -82
  21. zea/data/file_operations.py +496 -0
  22. zea/data/preset_utils.py +1 -1
  23. zea/display.py +7 -8
  24. zea/internal/checks.py +6 -12
  25. zea/internal/operators.py +4 -0
  26. zea/io_lib.py +108 -160
  27. zea/models/__init__.py +1 -1
  28. zea/models/diffusion.py +62 -11
  29. zea/models/lv_segmentation.py +2 -0
  30. zea/ops.py +398 -158
  31. zea/scan.py +18 -8
  32. zea/tensor_ops.py +82 -62
  33. zea/tools/fit_scan_cone.py +90 -160
  34. zea/tracking/__init__.py +16 -0
  35. zea/tracking/base.py +94 -0
  36. zea/tracking/lucas_kanade.py +474 -0
  37. zea/tracking/segmentation.py +110 -0
  38. zea/utils.py +11 -2
  39. {zea-0.0.7.dist-info → zea-0.0.8.dist-info}/METADATA +3 -1
  40. {zea-0.0.7.dist-info → zea-0.0.8.dist-info}/RECORD +43 -35
  41. {zea-0.0.7.dist-info → zea-0.0.8.dist-info}/WHEEL +0 -0
  42. {zea-0.0.7.dist-info → zea-0.0.8.dist-info}/entry_points.txt +0 -0
  43. {zea-0.0.7.dist-info → zea-0.0.8.dist-info}/licenses/LICENSE +0 -0
zea/__init__.py CHANGED
@@ -7,7 +7,7 @@ from . import log
7
7
 
8
8
  # dynamically add __version__ attribute (see pyproject.toml)
9
9
  # __version__ = __import__("importlib.metadata").metadata.version(__package__)
10
- __version__ = "0.0.7"
10
+ __version__ = "0.0.8"
11
11
 
12
12
 
13
13
  def _bootstrap_backend():
@@ -155,10 +155,6 @@ def make_dataloader(
155
155
  Mimics the native TF function ``tf.keras.utils.image_dataset_from_directory``
156
156
  but for .hdf5 files.
157
157
 
158
- Saves a dataset_info.yaml file in the directory with information about the dataset.
159
- This file is used to load the dataset later on, which speeds up the initial loading
160
- of the dataset for very large datasets.
161
-
162
158
  Does the following in order to load a dataset:
163
159
 
164
160
  - Find all .hdf5 files in the director(ies)
zea/beamform/pixelgrid.py CHANGED
@@ -45,7 +45,7 @@ def cartesian_pixel_grid(xlims, zlims, grid_size_x=None, grid_size_z=None, dx=No
45
45
  ValueError: Either grid_size_x and grid_size_z or dx and dz must be defined.
46
46
 
47
47
  Returns:
48
- grid (np.ndarray): Pixel grid of size (grid_size_z, nx, 3) in
48
+ grid (np.ndarray): Pixel grid of size (grid_size_z, grid_size_x, 3) in
49
49
  Cartesian coordinates (x, y, z)
50
50
  """
51
51
  assert (bool(grid_size_x) and bool(grid_size_z)) ^ (bool(dx) and bool(dz)), (
zea/data/__init__.py CHANGED
@@ -38,15 +38,6 @@ Examples usage
38
38
  ... files.append(file) # process each file as needed
39
39
  >>> dataset.close()
40
40
 
41
- Subpackage layout
42
- -----------------
43
-
44
- - ``file.py``: Implements :class:`zea.File` and related file utilities.
45
- - ``datasets.py``: Implements :class:`zea.Dataset` and folder management.
46
- - ``dataloader.py``: Data loading utilities for batching and shuffling.
47
- - ``data_format.py``: Data validation and example dataset generation.
48
- - ``convert/``: Data conversion tools (e.g., from external formats).
49
-
50
41
  """ # noqa: E501
51
42
 
52
43
  from .convert.camus import sitk_load
zea/data/augmentations.py CHANGED
@@ -30,7 +30,7 @@ class RandomCircleInclusion(layers.Layer):
30
30
 
31
31
  def __init__(
32
32
  self,
33
- radius: int,
33
+ radius: int | tuple[int, int],
34
34
  fill_value: float = 1.0,
35
35
  circle_axes: tuple[int, int] = (1, 2),
36
36
  with_batch_dim=True,
@@ -38,25 +38,70 @@ class RandomCircleInclusion(layers.Layer):
38
38
  recovery_threshold=0.1,
39
39
  randomize_location_across_batch=True,
40
40
  seed=None,
41
+ width_range: tuple[int, int] = None,
42
+ height_range: tuple[int, int] = None,
41
43
  **kwargs,
42
44
  ):
43
45
  """
44
46
  Initialize RandomCircleInclusion.
45
47
 
46
48
  Args:
47
- radius (int): Radius of the circle to include.
49
+ radius (int or tuple[int, int]): Radius of the circle/ellipse to include.
48
50
  fill_value (float): Value to fill inside the circle.
49
- circle_axes (tuple[int, int]): Axes along which to draw the circle (height, width).
51
+ circle_axes (tuple[int, int]): Axes along which to draw the circle
52
+ (height, width).
50
53
  with_batch_dim (bool): Whether input has a batch dimension.
51
54
  return_centers (bool): Whether to return circle centers along with images.
52
55
  recovery_threshold (float): Threshold for considering a pixel as recovered.
53
- randomize_location_across_batch (bool): If True, randomize circle location
54
- per batch element.
56
+ randomize_location_across_batch (bool): If True (and with_batch_dim=True),
57
+ each batch element gets a different random center. If False, all batch
58
+ elements share the same center.
55
59
  seed (Any): Optional random seed for reproducibility.
60
+ width_range (tuple[int, int], optional): Range (min, max) for circle
61
+ center x (width axis).
62
+ height_range (tuple[int, int], optional): Range (min, max) for circle
63
+ center y (height axis).
56
64
  **kwargs: Additional keyword arguments for the parent Layer.
65
+
66
+ Example:
67
+ .. doctest::
68
+
69
+ >>> from zea.data.augmentations import RandomCircleInclusion
70
+ >>> from keras import ops
71
+
72
+ >>> layer = RandomCircleInclusion(
73
+ ... radius=5,
74
+ ... circle_axes=(1, 2),
75
+ ... with_batch_dim=True,
76
+ ... )
77
+ >>> image = ops.zeros((1, 28, 28), dtype="float32")
78
+ >>> out = layer(image) # doctest: +SKIP
79
+
57
80
  """
58
81
  super().__init__(**kwargs)
59
- self.radius = radius
82
+
83
+ # Validate randomize_location_across_batch only makes sense with batch dim
84
+ if not with_batch_dim and not randomize_location_across_batch:
85
+ raise ValueError(
86
+ "randomize_location_across_batch=False is only applicable when "
87
+ "with_batch_dim=True. When with_batch_dim=False, there is no batch "
88
+ "to randomize across."
89
+ )
90
+ # Convert radius to tuple if int, else validate tuple
91
+ if isinstance(radius, int):
92
+ if radius <= 0:
93
+ raise ValueError(f"radius must be a positive integer, got {radius}.")
94
+ self.radius = (radius, radius)
95
+ elif isinstance(radius, tuple) and len(radius) == 2:
96
+ rx, ry = radius
97
+ if not all(isinstance(r, int) for r in (rx, ry)):
98
+ raise TypeError(f"radius tuple must contain two integers, got {radius!r}.")
99
+ if rx <= 0 or ry <= 0:
100
+ raise ValueError(f"radius components must be positive, got {radius!r}.")
101
+ self.radius = (rx, ry)
102
+ else:
103
+ raise TypeError("radius must be an int or a tuple of two ints")
104
+
60
105
  self.fill_value = fill_value
61
106
  self.circle_axes = circle_axes
62
107
  self.with_batch_dim = with_batch_dim
@@ -64,6 +109,8 @@ class RandomCircleInclusion(layers.Layer):
64
109
  self.recovery_threshold = recovery_threshold
65
110
  self.randomize_location_across_batch = randomize_location_across_batch
66
111
  self.seed = seed
112
+ self.width_range = width_range
113
+ self.height_range = height_range
67
114
  self._axis1 = None
68
115
  self._axis2 = None
69
116
  self._perm = None
@@ -116,6 +163,43 @@ class RandomCircleInclusion(layers.Layer):
116
163
  self._static_w = int(permuted_shape[-1])
117
164
  self._static_shape = tuple(permuted_shape)
118
165
 
166
+ # Validate that ellipse can fit within image bounds
167
+ rx, ry = self.radius
168
+ min_required_width = 2 * rx + 1
169
+ min_required_height = 2 * ry + 1
170
+
171
+ if self._static_w < min_required_width:
172
+ raise ValueError(
173
+ f"Image width ({self._static_w}) is too small for radius {rx}. "
174
+ f"Minimum required width: {min_required_width}"
175
+ )
176
+ if self._static_h < min_required_height:
177
+ raise ValueError(
178
+ f"Image height ({self._static_h}) is too small for radius {ry}. "
179
+ f"Minimum required height: {min_required_height}"
180
+ )
181
+
182
+ # Validate width_range and height_range if provided
183
+ if self.width_range is not None:
184
+ min_x, max_x = self.width_range
185
+ if min_x >= max_x:
186
+ raise ValueError(f"width_range must have min < max, got {self.width_range}")
187
+ if min_x < rx or max_x > self._static_w - rx:
188
+ raise ValueError(
189
+ f"width_range {self.width_range} would place circle outside image bounds. "
190
+ f"Valid range: [{rx}, {self._static_w - rx})"
191
+ )
192
+
193
+ if self.height_range is not None:
194
+ min_y, max_y = self.height_range
195
+ if min_y >= max_y:
196
+ raise ValueError(f"height_range must have min < max, got {self.height_range}")
197
+ if min_y < ry or max_y > self._static_h - ry:
198
+ raise ValueError(
199
+ f"height_range {self.height_range} would place circle outside image bounds. "
200
+ f"Valid range: [{ry}, {self._static_h - ry})"
201
+ )
202
+
119
203
  super().build(input_shape)
120
204
 
121
205
  def compute_output_shape(self, input_shape):
@@ -165,7 +249,7 @@ class RandomCircleInclusion(layers.Layer):
165
249
  centers (Tensor): Tensor of shape (batch, 2) with circle centers.
166
250
  h (int): Height of the image.
167
251
  w (int): Width of the image.
168
- radius (int): Radius of the circle.
252
+ radius (tuple[int, int]): Radii of the ellipse (rx, ry).
169
253
  dtype (str or dtype): Data type for the mask.
170
254
 
171
255
  Returns:
@@ -176,12 +260,12 @@ class RandomCircleInclusion(layers.Layer):
176
260
  Y, X = ops.meshgrid(Y, X, indexing="ij")
177
261
  Y = ops.expand_dims(Y, 0) # (1, h, w)
178
262
  X = ops.expand_dims(X, 0) # (1, h, w)
179
- # cx = ops.cast(centers[:, 0], "float32")[:, None, None]
180
- # cy = ops.cast(centers[:, 1], "float32")[:, None, None]
181
263
  cx = centers[:, 0][:, None, None]
182
264
  cy = centers[:, 1][:, None, None]
183
- dist2 = (X - cx) ** 2 + (Y - cy) ** 2
184
- mask = ops.cast(dist2 <= radius**2, dtype)
265
+ rx, ry = radius
266
+ # Ellipse equation: ((X-cx)/rx)^2 + ((Y-cy)/ry)^2 <= 1
267
+ dist = ((X - cx) / rx) ** 2 + ((Y - cy) / ry) ** 2
268
+ mask = ops.cast(dist <= 1, dtype)
185
269
  return mask
186
270
 
187
271
  def call(self, x, seed=None):
@@ -197,9 +281,17 @@ class RandomCircleInclusion(layers.Layer):
197
281
  centers if return_centers is True.
198
282
  """
199
283
  if keras.backend.backend() == "jax" and not is_jax_prng_key(seed):
200
- raise NotImplementedError(
201
- "jax.random.key() is not supported, please use jax.random.PRNGKey()"
202
- )
284
+ if isinstance(seed, keras.random.SeedGenerator):
285
+ raise ValueError(
286
+ "When using JAX backend, please provide a jax.random.PRNGKey as seed, "
287
+ "instead of keras.random.SeedGenerator."
288
+ )
289
+ else:
290
+ raise TypeError(
291
+ f"When using JAX backend, seed must be a JAX PRNG key (created with "
292
+ f"jax.random.PRNGKey()), but got {type(seed)}. Note: jax.random.key() "
293
+ f"keys are not currently supported."
294
+ )
203
295
  seed = seed if seed is not None else self.seed
204
296
 
205
297
  if self.with_batch_dim:
@@ -209,22 +301,33 @@ class RandomCircleInclusion(layers.Layer):
209
301
  imgs, centers = ops.map(lambda arg: self._call(arg, seed), x)
210
302
  else:
211
303
  raise NotImplementedError(
212
- "You cannot fix circle locations across while using"
304
+ "You cannot fix circle locations across batch while using "
213
305
  + "RandomCircleInclusion as a dataset augmentation, "
214
306
  + "since samples in a batch are handled independently."
215
307
  )
216
308
  else:
309
+ batch_size = ops.shape(x)[0]
217
310
  if self.randomize_location_across_batch:
218
- batch_size = ops.shape(x)[0]
219
311
  seeds = split_seed(seed, batch_size)
220
- if all(seed is seeds[0] for seed in seeds):
312
+ if all(s is seeds[0] for s in seeds):
221
313
  imgs, centers = ops.map(lambda arg: self._call(arg, seeds[0]), x)
222
314
  else:
223
315
  imgs, centers = ops.map(
224
316
  lambda args: self._call(args[0], args[1]), (x, seeds)
225
317
  )
226
318
  else:
227
- imgs, centers = ops.map(lambda arg: self._call(arg, seed), x)
319
+ # Generate one random center that will be used for all batch elements
320
+ img0, center0 = self._call(x[0], seed)
321
+
322
+ # Apply the same center to all batch elements
323
+ imgs_list, centers_list = [img0], [center0]
324
+ for i in range(1, batch_size):
325
+ img_aug, center_out = self._call_with_fixed_center(x[i], center0)
326
+ imgs_list.append(img_aug)
327
+ centers_list.append(center_out)
328
+
329
+ imgs = ops.stack(imgs_list, axis=0)
330
+ centers = ops.stack(centers_list, axis=0)
228
331
  else:
229
332
  imgs, centers = self._call(x, seed)
230
333
 
@@ -248,17 +351,28 @@ class RandomCircleInclusion(layers.Layer):
248
351
  flat, flat_batch_size, h, w = self._flatten_batch_and_other_dims(x)
249
352
 
250
353
  def _draw_circle_2d(img2d):
354
+ rx, ry = self.radius
355
+ # Determine allowed ranges for center
356
+ if self.width_range is not None:
357
+ min_x, max_x = self.width_range
358
+ else:
359
+ min_x, max_x = rx, w - rx
360
+ if self.height_range is not None:
361
+ min_y, max_y = self.height_range
362
+ else:
363
+ min_y, max_y = ry, h - ry
364
+ # Ensure the ellipse fits within the allowed region
251
365
  cx = ops.cast(
252
- keras.random.uniform((), self.radius, w - self.radius, seed=seed),
366
+ keras.random.uniform((), min_x, max_x, seed=seed),
253
367
  "int32",
254
368
  )
255
369
  new_seed, _ = split_seed(seed, 2) # ensure that cx and cy are independent
256
370
  cy = ops.cast(
257
- keras.random.uniform((), self.radius, h - self.radius, seed=new_seed),
371
+ keras.random.uniform((), min_y, max_y, seed=new_seed),
258
372
  "int32",
259
373
  )
260
374
  mask = self._make_circle_mask(
261
- ops.stack([cx, cy])[None, :], h, w, self.radius, img2d.dtype
375
+ ops.stack([cx, cy])[None, :], h, w, (rx, ry), img2d.dtype
262
376
  )[0]
263
377
  img_aug = img2d * (1 - mask) + self.fill_value * mask
264
378
  center = ops.stack([cx, cy])
@@ -271,6 +385,67 @@ class RandomCircleInclusion(layers.Layer):
271
385
  centers = ops.reshape(centers, centers_shape)
272
386
  return (aug_imgs, centers)
273
387
 
388
+ def _apply_circle_mask(self, flat, center, h, w):
389
+ """Apply circle mask to flattened image data.
390
+
391
+ Args:
392
+ flat (Tensor): Flattened image data of shape (flat_batch, h, w).
393
+ center (Tensor): Center coordinates, either (2,) or (flat_batch, 2).
394
+ h (int): Height of images.
395
+ w (int): Width of images.
396
+
397
+ Returns:
398
+ Tensor: Augmented images with circle applied.
399
+ """
400
+ rx, ry = self.radius
401
+
402
+ # Ensure center has batch dimension for broadcasting
403
+ if len(center.shape) == 1:
404
+ # Single center (2,) -> broadcast to all slices
405
+ center_batched = ops.tile(ops.reshape(center, [1, 2]), [flat.shape[0], 1])
406
+ else:
407
+ # Already batched (flat_batch, 2)
408
+ center_batched = center
409
+
410
+ # Create masks for all slices using vectorized_map or broadcasting
411
+ masks = self._make_circle_mask(center_batched, h, w, (rx, ry), flat.dtype)
412
+
413
+ # Apply masks
414
+ aug_imgs = flat * (1 - masks) + self.fill_value * masks
415
+ return aug_imgs
416
+
417
+ def _call_with_fixed_center(self, x, fixed_center):
418
+ """Apply augmentation using a pre-determined center.
419
+
420
+ Args:
421
+ x (Tensor): Input image tensor.
422
+ fixed_center (Tensor): Pre-determined center coordinates, either (2,)
423
+ for a single center or (flat_batch, 2) for per-slice centers.
424
+
425
+ Returns:
426
+ tuple: (augmented image, center coordinates).
427
+ """
428
+ x = self._permute_axes_to_circle_last(x)
429
+ flat, flat_batch_size, h, w = self._flatten_batch_and_other_dims(x)
430
+
431
+ # Apply circle mask with fixed center (handles both single and batched centers)
432
+ aug_imgs = self._apply_circle_mask(flat, fixed_center, h, w)
433
+ aug_imgs = ops.reshape(aug_imgs, x.shape)
434
+ aug_imgs = ops.transpose(aug_imgs, axes=self._inv_perm)
435
+
436
+ # Return centers matching the expected shape
437
+ if len(fixed_center.shape) == 1:
438
+ # Single center (2,) -> broadcast to match flat_batch_size
439
+ if flat_batch_size == 1:
440
+ centers = fixed_center
441
+ else:
442
+ centers = ops.tile(ops.reshape(fixed_center, [1, 2]), [flat_batch_size, 1])
443
+ else:
444
+ # Already batched centers (flat_batch, 2)
445
+ centers = fixed_center
446
+
447
+ return (aug_imgs, centers)
448
+
274
449
  def get_config(self):
275
450
  """
276
451
  Get layer configuration for serialization.
@@ -285,6 +460,8 @@ class RandomCircleInclusion(layers.Layer):
285
460
  "fill_value": self.fill_value,
286
461
  "circle_axes": self.circle_axes,
287
462
  "return_centers": self.return_centers,
463
+ "width_range": self.width_range,
464
+ "height_range": self.height_range,
288
465
  }
289
466
  )
290
467
  return cfg
@@ -293,7 +470,8 @@ class RandomCircleInclusion(layers.Layer):
293
470
  self, images, centers, recovery_threshold, fill_value=None
294
471
  ):
295
472
  """
296
- Evaluate the percentage of the true circle that has been recovered in the images.
473
+ Evaluate the percentage of the true circle that has been recovered in the images,
474
+ and return a mask of the detected part of the circle.
297
475
 
298
476
  Args:
299
477
  images (Tensor): Tensor of images (any shape, with circle axes as specified).
@@ -302,8 +480,12 @@ class RandomCircleInclusion(layers.Layer):
302
480
  fill_value (float, optional): Optionally override fill_value for cases
303
481
  where image range has changed.
304
482
 
305
- Returns:
306
- Tensor: Percentage recovered for each circle (shape: [num_circles]).
483
+ Returns:
484
+ Tuple[Tensor, Tensor]:
485
+ - percent_recovered: [batch] - average recovery percentage per batch element,
486
+ averaged across all non-batch dimensions (e.g., frames, slices)
487
+ - recovered_masks: [batch, flat_batch, h, w] or [batch, h, w] or [flat_batch, h, w]
488
+ depending on input shape - binary mask of detected circle regions
307
489
  """
308
490
  fill_value = fill_value or self.fill_value
309
491
 
@@ -318,12 +500,23 @@ class RandomCircleInclusion(layers.Layer):
318
500
  recovered_sum = ops.sum(recovered, axis=[1, 2])
319
501
  mask_sum = ops.sum(mask, axis=[1, 2])
320
502
  percent_recovered = recovered_sum / (mask_sum + 1e-8)
321
- return percent_recovered
503
+ # recovered_mask: binary mask of detected part of the circle
504
+ recovered_mask = ops.cast(recovered > 0, flat_image.dtype)
505
+ return percent_recovered, recovered_mask
322
506
 
323
507
  if self.with_batch_dim:
324
- return ops.vectorized_map(
508
+ results = ops.vectorized_map(
325
509
  lambda args: _evaluate_recovered_circle_accuracy(args[0], args[1]),
326
510
  (images, centers),
327
- )[..., 0]
511
+ )
512
+ percent_recovered, recovered_masks = results
513
+ # If there are multiple circles per batch element (e.g., multiple frames/slices),
514
+ # take the mean across all non-batch dimensions to get one value per batch element
515
+ if len(percent_recovered.shape) > 1:
516
+ # Average over all axes except the batch dimension (axis 0)
517
+ axes_to_reduce = tuple(range(1, len(percent_recovered.shape)))
518
+ percent_recovered = ops.mean(percent_recovered, axis=axes_to_reduce)
519
+ return percent_recovered, recovered_masks
328
520
  else:
329
- return _evaluate_recovered_circle_accuracy(images, centers)
521
+ percent_recovered, recovered_mask = _evaluate_recovered_circle_accuracy(images, centers)
522
+ return percent_recovered, recovered_mask
@@ -1,6 +1 @@
1
- """All functions in this module are used to convert different datasets to the zea format."""
2
-
3
- from .camus import convert_camus
4
- from .images import convert_image_dataset
5
- from .matlab import zea_from_matlab_raw
6
- from .picmus import convert_picmus
1
+ """Data conversion of datasets to the ``zea`` data format."""
@@ -0,0 +1,123 @@
1
+ import argparse
2
+
3
+
4
+ def get_parser():
5
+ """
6
+ Build and parse command-line arguments for converting raw datasets to a zea dataset.
7
+
8
+ Returns:
9
+ argparse.Namespace: Parsed arguments with the following attributes:
10
+ dataset (str): One of "echonet", "echonetlvh", "camus", "picmus", "verasonics".
11
+ src (str): Source folder path.
12
+ dst (str): Destination folder path.
13
+ split_path (str|None): Optional path to a split.yaml to copy dataset splits.
14
+ no_hyperthreading (bool): Disable hyperthreading for multiprocessing.
15
+ frames (list[str]): Verasonics frames spec (e.g., ["all"], integers, or ranges like "4-8").
16
+ no_rejection (bool): EchonetLVH flag to skip manual_rejections.txt filtering.
17
+ batch (str|None): EchonetLVH Batch directory to process (e.g., "Batch2").
18
+ convert_measurements (bool): EchonetLVH flag to convert only measurements CSV.
19
+ convert_images (bool): EchonetLVH flag to convert only image files.
20
+ max_files (int|None): EchonetLVH maximum number of files to process.
21
+ force (bool): EchonetLVH flag to force recomputation even if parameters exist.
22
+ """
23
+ parser = argparse.ArgumentParser(description="Convert raw data to a zea dataset.")
24
+ parser.add_argument(
25
+ "dataset",
26
+ choices=["echonet", "echonetlvh", "camus", "picmus", "verasonics"],
27
+ help="Raw dataset to convert",
28
+ )
29
+ parser.add_argument("src", type=str, help="Source folder path")
30
+ parser.add_argument("dst", type=str, help="Destination folder path")
31
+ parser.add_argument(
32
+ "--split_path",
33
+ type=str,
34
+ help="Path to the split.yaml file containing the dataset split if a split should be copied",
35
+ )
36
+ parser.add_argument(
37
+ "--no_hyperthreading",
38
+ action="store_true",
39
+ help="Disable hyperthreading for multiprocessing",
40
+ )
41
+ # Dataset specific arguments:
42
+
43
+ # verasonics
44
+ parser.add_argument(
45
+ "--frames",
46
+ default=["all"],
47
+ type=str,
48
+ nargs="+",
49
+ help="verasonics: The frames to add to the file. This can be a list of integers, a range "
50
+ "of integers (e.g. 4-8), or 'all'.",
51
+ )
52
+ # ECHONET_LVH
53
+ parser.add_argument(
54
+ "--no_rejection",
55
+ action="store_true",
56
+ help="EchonetLVH: Do not reject sequences in manual_rejections.txt",
57
+ )
58
+
59
+ parser.add_argument(
60
+ "--batch",
61
+ type=str,
62
+ default=None,
63
+ help="EchonetLVH: Specify which BatchX directory to process, e.g. --batch=Batch2",
64
+ )
65
+ parser.add_argument(
66
+ "--convert_measurements",
67
+ action="store_true",
68
+ help="EchonetLVH: Only convert measurements CSV file",
69
+ )
70
+ parser.add_argument(
71
+ "--convert_images", action="store_true", help="EchonetLVH: Only convert image files"
72
+ )
73
+ parser.add_argument(
74
+ "--max_files",
75
+ type=int,
76
+ default=None,
77
+ help="EchonetLVH: Maximum number of files to process (for testing)",
78
+ )
79
+ parser.add_argument(
80
+ "--force",
81
+ action="store_true",
82
+ help="EchonetLVH: Force recomputation even if parameters already exist",
83
+ )
84
+ return parser
85
+
86
+
87
+ def main():
88
+ """
89
+ Parse command-line arguments and dispatch to the selected dataset conversion routine.
90
+
91
+ This function obtains CLI arguments via get_args() and calls the corresponding converter
92
+ (convert_echonet, convert_echonetlvh, convert_camus, convert_picmus, or convert_verasonics)
93
+ based on args.dataset.
94
+ Raises a ValueError if args.dataset is not one of the supported choices.
95
+ """
96
+ parser = get_parser()
97
+ args = parser.parse_args()
98
+ if args.dataset == "echonet":
99
+ from zea.data.convert.echonet import convert_echonet
100
+
101
+ convert_echonet(args)
102
+ elif args.dataset == "echonetlvh":
103
+ from zea.data.convert.echonetlvh import convert_echonetlvh
104
+
105
+ convert_echonetlvh(args)
106
+ elif args.dataset == "camus":
107
+ from zea.data.convert.camus import convert_camus
108
+
109
+ convert_camus(args)
110
+ elif args.dataset == "picmus":
111
+ from zea.data.convert.picmus import convert_picmus
112
+
113
+ convert_picmus(args)
114
+ elif args.dataset == "verasonics":
115
+ from zea.data.convert.verasonics import convert_verasonics
116
+
117
+ convert_verasonics(args)
118
+ else:
119
+ raise ValueError(f"Unknown dataset: {args.dataset}")
120
+
121
+
122
+ if __name__ == "__main__":
123
+ main()