zea 0.0.6__py3-none-any.whl → 0.0.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (79)
  1. zea/__init__.py +54 -19
  2. zea/agent/__init__.py +12 -12
  3. zea/agent/masks.py +2 -1
  4. zea/backend/tensorflow/dataloader.py +2 -5
  5. zea/beamform/beamformer.py +100 -50
  6. zea/beamform/lens_correction.py +9 -2
  7. zea/beamform/pfield.py +9 -2
  8. zea/beamform/pixelgrid.py +1 -1
  9. zea/config.py +34 -25
  10. zea/data/__init__.py +22 -25
  11. zea/data/augmentations.py +221 -28
  12. zea/data/convert/__init__.py +1 -6
  13. zea/data/convert/__main__.py +123 -0
  14. zea/data/convert/camus.py +101 -40
  15. zea/data/convert/echonet.py +187 -86
  16. zea/data/convert/echonetlvh/README.md +2 -3
  17. zea/data/convert/echonetlvh/{convert_raw_to_usbmd.py → __init__.py} +174 -103
  18. zea/data/convert/echonetlvh/manual_rejections.txt +73 -0
  19. zea/data/convert/echonetlvh/precompute_crop.py +43 -64
  20. zea/data/convert/picmus.py +37 -40
  21. zea/data/convert/utils.py +86 -0
  22. zea/data/convert/{matlab.py → verasonics.py} +44 -65
  23. zea/data/data_format.py +155 -34
  24. zea/data/dataloader.py +12 -7
  25. zea/data/datasets.py +112 -71
  26. zea/data/file.py +184 -73
  27. zea/data/file_operations.py +496 -0
  28. zea/data/layers.py +3 -3
  29. zea/data/preset_utils.py +1 -1
  30. zea/datapaths.py +16 -4
  31. zea/display.py +14 -13
  32. zea/interface.py +14 -16
  33. zea/internal/_generate_keras_ops.py +6 -7
  34. zea/internal/cache.py +2 -49
  35. zea/internal/checks.py +6 -12
  36. zea/internal/config/validation.py +1 -2
  37. zea/internal/core.py +69 -6
  38. zea/internal/device.py +6 -2
  39. zea/internal/dummy_scan.py +330 -0
  40. zea/internal/operators.py +118 -2
  41. zea/internal/parameters.py +101 -70
  42. zea/internal/setup_zea.py +5 -6
  43. zea/internal/utils.py +282 -0
  44. zea/io_lib.py +322 -146
  45. zea/keras_ops.py +74 -4
  46. zea/log.py +9 -7
  47. zea/metrics.py +15 -7
  48. zea/models/__init__.py +31 -21
  49. zea/models/base.py +30 -14
  50. zea/models/carotid_segmenter.py +19 -4
  51. zea/models/diffusion.py +235 -23
  52. zea/models/echonet.py +22 -8
  53. zea/models/echonetlvh.py +31 -7
  54. zea/models/lpips.py +19 -2
  55. zea/models/lv_segmentation.py +30 -11
  56. zea/models/preset_utils.py +5 -5
  57. zea/models/regional_quality.py +30 -10
  58. zea/models/taesd.py +21 -5
  59. zea/models/unet.py +15 -1
  60. zea/ops.py +770 -336
  61. zea/probes.py +6 -6
  62. zea/scan.py +121 -51
  63. zea/simulator.py +24 -21
  64. zea/tensor_ops.py +477 -353
  65. zea/tools/fit_scan_cone.py +90 -160
  66. zea/tools/hf.py +1 -1
  67. zea/tools/selection_tool.py +47 -86
  68. zea/tracking/__init__.py +16 -0
  69. zea/tracking/base.py +94 -0
  70. zea/tracking/lucas_kanade.py +474 -0
  71. zea/tracking/segmentation.py +110 -0
  72. zea/utils.py +101 -480
  73. zea/visualize.py +177 -39
  74. {zea-0.0.6.dist-info → zea-0.0.8.dist-info}/METADATA +6 -2
  75. zea-0.0.8.dist-info/RECORD +122 -0
  76. zea-0.0.6.dist-info/RECORD +0 -112
  77. {zea-0.0.6.dist-info → zea-0.0.8.dist-info}/WHEEL +0 -0
  78. {zea-0.0.6.dist-info → zea-0.0.8.dist-info}/entry_points.txt +0 -0
  79. {zea-0.0.6.dist-info → zea-0.0.8.dist-info}/licenses/LICENSE +0 -0
zea/data/convert/camus.py CHANGED
@@ -1,14 +1,19 @@
  """Functionality to convert the camus dataset to the zea format.
- Requires SimpleITK to be installed: pip install SimpleITK.
+
+ .. note::
+ Requires SimpleITK to be installed: `pip install SimpleITK`.
+
+ For more information about the dataset, resort to the following links:
+
+ - The original dataset can be found at `this link <https://humanheart-project.creatis.insa-lyon.fr/database/#collection/6373703d73e9f0047faa1bc8>`_.
+
  """

  from __future__ import annotations

- import argparse
- import importlib.util
  import logging
  import os
- import sys
+ from concurrent.futures import ProcessPoolExecutor
  from pathlib import Path
  from typing import Any, Dict, Tuple
@@ -17,10 +22,11 @@ import scipy
  from skimage.transform import resize
  from tqdm import tqdm

- # from zea.display import transform_sc_image_to_polar
  from zea import log
+ from zea.data.convert.utils import unzip
  from zea.data.data_format import generate_zea_dataset
- from zea.utils import find_first_nonzero_index, translate
+ from zea.internal.utils import find_first_nonzero_index
+ from zea.tensor_ops import translate


  def transform_sc_image_to_polar(image_sc, output_size=None, fit_outline=True):
@@ -123,10 +129,12 @@ def sitk_load(filepath: str | Path) -> Tuple[np.ndarray, Dict[str, Any]]:
  filepath: Path to the image.

  Returns:
- - ([N], H, W), Image array.
+ - Image array of shape (num_frames, height, width).
  - Collection of metadata.
  """
  # Load image and save info
+ import SimpleITK as sitk
+
  image = sitk.ReadImage(str(filepath))

  all_metadata = {}
@@ -147,7 +155,7 @@ def sitk_load(filepath: str | Path) -> Tuple[np.ndarray, Dict[str, Any]]:
  return im_array, metadata


- def convert_camus(source_path, output_path, overwrite=False):
+ def process_camus(source_path, output_path, overwrite=False):
  """Converts the camus database to the zea format.

  Args:
@@ -187,29 +195,22 @@ def convert_camus(source_path, output_path, overwrite=False):
  )


- def get_args():
- """Parse command line arguments."""
- parser = argparse.ArgumentParser()
- parser.add_argument(
- "--source",
- type=str,
- # path to CAMUS_public/database_nifti
- required=True,
- )
- parser.add_argument(
- "--output",
- type=str,
- required=True,
- )
- args = parser.parse_args()
- return args
-
-
  splits = {"train": [1, 401], "val": [401, 451], "test": [451, 501]}


  def get_split(patient_id: int) -> str:
- """Determine the dataset split for a given patient ID."""
+ """
+ Determine which dataset split a patient ID belongs to.
+
+ Args:
+ patient_id: Integer ID of the patient.
+
+ Returns:
+ The split name: "train", "val", or "test".
+
+ Raises:
+ ValueError: If the patient_id does not fall into any defined split range.
+ """
  if splits["train"][0] <= patient_id < splits["train"][1]:
  return "train"
  elif splits["val"][0] <= patient_id < splits["val"][1]:
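As a quick illustration of the half-open `splits` ranges above: patients 1-400 map to "train", 401-450 to "val", 451-500 to "test", and any other ID raises `ValueError`. A minimal check, assuming `get_split` is importable from `zea.data.convert.camus`:

from zea.data.convert.camus import get_split

assert get_split(400) == "train"
assert get_split(401) == "val"
assert get_split(450) == "val"
assert get_split(451) == "test"
assert get_split(500) == "test"
# get_split(501) raises ValueError: no split range covers that ID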
@@ -220,16 +221,58 @@ def get_split(patient_id: int) -> str:
  raise ValueError(f"Did not find split for patient: {patient_id}")


- if __name__ == "__main__":
- if importlib.util.find_spec("SimpleITK") is None:
- log.error("SimpleITK not installed. Please install SimpleITK: `pip install SimpleITK`")
- sys.exit()
- import SimpleITK as sitk
+ def _process_task(task):
+ """
+ Unpack a task tuple and invoke process_camus in a worker process.
+
+ Creates parent directories for the target outputs, calls process_camus
+ with the unpacked paths, and logs then re-raises any exception raised by processing.
+
+ Args:
+ task (tuple): (source_file_str, output_file_str)
+ - source_file_str: filesystem path to the source CAMUS file as a string.
+ - output_file_str: filesystem path for the ZEA output file as a string.
+ """
+ source_file_str, output_file_str = task
+ source_file = Path(source_file_str)
+ output_file = Path(output_file_str)

- args = get_args()
+ # Ensure destination directories exist (safe to call from multiple processes)
+ output_file.parent.mkdir(parents=True, exist_ok=True)
+
+ # Call the real processing function (must be importable in the worker)
+ # If process_camus lives in another module, import it there instead.
+ try:
+ process_camus(source_file, output_file, overwrite=False)
+ except Exception:
+ # Log and re-raise so the main process can handle it
+ log.error("Error processing %s", source_file)
+ raise
+
+
+ def convert_camus(args):
+ """
+ Converts the CAMUS dataset into ZEA HDF5 files across dataset splits.

- camus_source_folder = Path(args.source)
- camus_output_folder = Path(args.output)
+ Processes files found under the CAMUS source folder (after unzipping if needed),
+ assigns each patient to a train/val/test split, creates matching output paths,
+ and executes per-file conversion tasks either serially or in parallel.
+ Ensures output directories do not pre-exist, and logs progress and failures.
+
+ Args:
+ args (argparse.Namespace): An object with attributes:
+
+ - src (str | Path): Path to the CAMUS archive or extracted folder.
+ - dst (str | Path): Root destination folder for ZEA HDF5 outputs;
+ split subfolders will be created.
+ - no_hyperthreading (bool, optional): If True, run tasks serially instead
+ of using a process pool.
+ """
+ camus_source_folder = Path(args.src)
+ camus_output_folder = Path(args.dst)
+
+ # Look for either CAMUS_public.zip or folders database_nifti, database_split
+ camus_source_folder = unzip(camus_source_folder, "camus")

  # check if output folders already exist
  for split in splits:
@@ -238,9 +281,9 @@ if __name__ == "__main__":
  )

  # clone folder structure of source to output using pathlib
- # and run convert_camus() for every hdf5 found in there
  files = list(camus_source_folder.glob("**/*_half_sequence.nii.gz"))
- for source_file in tqdm(files):
+ tasks = []
+ for source_file in files:
  # check if source file in camus database (ignore other files)
  if "database_nifti" not in source_file.parts:
  continue
@@ -250,10 +293,28 @@ if __name__ == "__main__":
  split = get_split(patient_id)

  output_file = camus_output_folder / split / source_file.relative_to(camus_source_folder)
-
  # Replace .nii.gz with .hdf5
  output_file = output_file.with_suffix("").with_suffix(".hdf5")
-
  # make sure folder exists
  output_file.parent.mkdir(parents=True, exist_ok=True)
- convert_camus(source_file, output_file, overwrite=False)
+
+ tasks.append((str(source_file), str(output_file)))
+ if not tasks:
+ log.info("No files found to process.")
+ return
+
+ if getattr(args, "no_hyperthreading", False):
+ log.info("no_hyperthreading is True — running tasks serially (no ProcessPoolExecutor)")
+ for t in tqdm(tasks, desc="Processing files (serial)"):
+ try:
+ _process_task(t)
+ except Exception as e:
+ log.error("Task processing failed: %s", e)
+ log.info("Processing finished for %d files (serial)", len(tasks))
+ return
+
+ # Submit tasks to the process pool and track progress
+ with ProcessPoolExecutor() as exe:
+ for _ in tqdm(exe.map(_process_task, tasks), total=len(tasks), desc="Processing files"):
+ pass
+ log.info("Processing finished for %d files", len(tasks))
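With `get_args()` removed, `convert_camus` now takes a namespace-like object instead of parsing the command line itself; the CLI wiring presumably lives in the new `zea/data/convert/__main__.py` listed above. A minimal sketch of calling it directly, where the paths are placeholders and the `argparse.Namespace` wrapper is an assumption, but the `src`, `dst`, and `no_hyperthreading` attributes match the docstring in this diff:

from argparse import Namespace

from zea.data.convert.camus import convert_camus

args = Namespace(
    src="/data/CAMUS_public.zip",  # placeholder: CAMUS archive or extracted folder
    dst="/data/camus_zea",         # placeholder: output root; split subfolders are created
    no_hyperthreading=True,        # run per-file tasks serially instead of in a process pool
)
convert_camus(args)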
zea/data/convert/echonet.py CHANGED
@@ -1,49 +1,30 @@
  """
- Script to convert the EchoNet database to .npy and zea formats.
- Will segment the images and convert them to polar coordinates.
- """
+ Script to convert the EchoNet database to zea format.

- import os
+ .. note::
+ Will segment the images and convert them to polar coordinates.
+
+ For more information about the dataset, resort to the following links:

- os.environ["KERAS_BACKEND"] = "numpy"
+ - The original dataset can be found at `this link <https://stanfordaimi.azurewebsites.net/datasets/834e1cd1-92f7-4268-9daa-d359198b310a>`_.
+ - The project page is available `here <https://echonet.github.io/>`_.
+
+ """

- import argparse
+ import os
  from concurrent.futures import ProcessPoolExecutor, as_completed
+ from multiprocessing import Value
  from pathlib import Path

  import numpy as np
+ import yaml
  from scipy.interpolate import griddata
  from tqdm import tqdm

- from zea.config import Config
+ from zea import log
  from zea.data import generate_zea_dataset
- from zea.io_lib import load_video
- from zea.utils import translate
-
-
- def get_args():
- """Parse command line arguments."""
- parser = argparse.ArgumentParser(description="Convert EchoNet to zea format")
- parser.add_argument(
- "--source",
- type=str,
- # path to EchoNet-Dynamic/Videos
- required=True,
- )
- parser.add_argument(
- "--output",
- type=str,
- required=True,
- )
- parser.add_argument(
- "--splits",
- type=str,
- default=None,
- )
- parser.add_argument("--output_numpy", type=str, default=None)
- parser.add_argument("--no_hyperthreading", action="store_true")
- args = parser.parse_args()
- return args
+ from zea.data.convert.utils import load_avi, unzip
+ from zea.tensor_ops import translate


  def segment(tensor, number_erasing=0, min_clip=0):
@@ -52,6 +33,8 @@ def segment(tensor, number_erasing=0, min_clip=0):
  Args:
  tensor (ndarray): Input image (sc) with 3 dimensions. (N, 112, 112)
  number_erasing (float, optional): number to fill the background with.
+ min_clip (float, optional): If > 0, values on the computed cone edge will be clipped
+ to be at least this value. Defaults to 0.
  Returns:
  tensor (ndarray): Segmented matrix of same dimensions as input
@@ -264,13 +247,37 @@ def cartesian_to_polar_matrix(


  def find_split_for_file(file_dict, target_file):
- """Function that finds the split for a given file in a dictionary."""
+ """
+ Locate which split contains a given filename.
+
+ Parameters:
+ file_dict (dict): Mapping from split name (e.g., "train", "val", "test", "rejected")
+ to an iterable of filenames.
+ target_file (str): Filename to search for within the split lists.
+
+ Returns:
+ str: The split name that contains `target_file`, or `"rejected"` if the file is not found.
+ """
  for split, files in file_dict.items():
  if target_file in files:
  return split
+ log.warning(f"File {target_file} not found in any split, defaulting to rejected.")
  return "rejected"


+ def count_init(shared_counter):
+ """
+ Initialize the module-level shared counter used by worker processes.
+
+ Parameters:
+ shared_counter (multiprocessing.Value): A process-shared integer Value that
+ will be assigned to the module-global COUNTER for coordinated counting
+ across processes.
+ """
+ global COUNTER
+ COUNTER = shared_counter
+
+
  class H5Processor:
  """
  Stores a few variables and paths to allow for hyperthreading.
@@ -279,7 +286,6 @@ class H5Processor:
  def __init__(
  self,
  path_out_h5,
- path_out=None,
  num_val=500,
  num_test=500,
  range_from=(0, 255),
@@ -287,7 +293,6 @@ class H5Processor:
  splits=None,
  ):
  self.path_out_h5 = Path(path_out_h5)
- self.path_out = Path(path_out) if path_out else None
  self.num_val = num_val
  self.num_test = num_test
  self.range_from = range_from
@@ -297,20 +302,33 @@

  # Ensure train, val, test, rejected paths exist
  for folder in ["train", "val", "test", "rejected"]:
- if self._to_numpy:
- (self.path_out / folder).mkdir(parents=True, exist_ok=True)
  (self.path_out_h5 / folder).mkdir(parents=True, exist_ok=True)

- @property
- def _to_numpy(self):
- return self.path_out is not None
-
- def translate(self, data):
+ def _translate(self, data):
  """Translate the data from the processing range to final range."""
  return translate(data, self._process_range, self.range_to)

  def get_split(self, hdf5_file: str, sequence):
- """Determine the split for a given file."""
+ """
+ Determine the dataset split label for a given file and its image sequence.
+
+ This method checks acceptance based on the first frame of `sequence`.
+ If explicit splits were provided to the processor, it returns the split
+ found for `hdf5_file` (and asserts that the acceptance result matches the split).
+ If no explicit splits are provided, rejected sequences are labeled `"rejected"`.
+ Accepted sequences increment a shared counter and are assigned
+ `"val"`, `"test"`, or `"train"` according to the processor's
+ `num_val` and `num_test` quotas.
+
+ Args:
+ hdf5_file (str): Filename or identifier used to look up an existing split
+ when splits are provided.
+ sequence (array-like): Time-ordered sequence of images; the first frame is
+ used for acceptance checking.
+
+ Returns:
+ str: One of `"train"`, `"val"`, `"test"`, or `"rejected"` indicating the assigned split.
+ """
  # Always check acceptance
  accepted = accept_shape(sequence[0])
@@ -324,25 +342,73 @@ class H5Processor:
  if not accepted:
  return "rejected"

- # This inefficient counter works with hyperthreading
- # TODO: but it is not reproducible!
- val_counter = len(list((self.path_out_h5 / "val").iterdir()))
- test_counter = len(list((self.path_out_h5 / "test").iterdir()))
+ # Increment the hyperthreading counter
+ # Note that some threads will start on subsequent splits
+ # while others are still processing
+ with COUNTER.get_lock():
+ COUNTER.value += 1
+ n = COUNTER.value

  # Determine the split
- if val_counter < self.num_val:
+ if n <= self.num_val:
  return "val"
- elif test_counter < self.num_test:
+ elif n <= self.num_val + self.num_test:
  return "test"
  else:
  return "train"

+ def validate_split_copy(self, split_file):
+ """
+ Validate that a generated split YAML matches the original splits provided to the processor.
+
+ Reads the YAML at `split_file` and compares its `train`, `val`, `test`, and `rejected` lists
+ (or other split keys present in `self.splits`) against `self.splits`; logs confirmation
+ when a split matches and logs which entries are missing or extra when they differ. If the
+ processor was not initialized with `splits`, validation is skipped and a message is logged.
+
+ Args:
+ split_file (str or os.PathLike): Path to the YAML file containing the
+ generated dataset splits.
+ """
+ if self.splits is not None:
+ # Read the split_file and ensure contents of the train, val and split match
+ with open(split_file, "r") as f:
+ new_splits = yaml.safe_load(f)
+ for split in self.splits.keys():
+ if set(new_splits[split]) == set(self.splits[split]):
+ log.info(f"Split {split} copied correctly.")
+ else:
+ # Log which entry is missing or extra in the split_file
+ missing = set(self.splits[split]) - set(new_splits[split])
+ extra = set(new_splits[split]) - set(self.splits[split])
+ if missing:
+ log.warning(f"New dataset split {split} is missing entries: {missing}")
+ if extra:
+ log.warning(f"New dataset split {split} has extra entries: {extra}")
+ else:
+ log.info(
+ "Processor not initialized with a split, not validating if the split was copied."
+ )
+
  def __call__(self, avi_file):
  """
- Processes a single h5 file using the class variables and the filename given.
+ Convert a single AVI file into a zea dataset entry.
+ Loads the AVI, validates and rescales pixel ranges, applies segmentation,
+ assigns a data split (train/val/test/rejected), converts accepted frames
+ to polar coordinates.
+ Constructs and returns the zea dataset descriptor used by
+ generate_zea_dataset; the descriptor always includes `path`, `image_sc`,
+ `probe_name`, and `description`, and includes `image` when the file is accepted.
+
+ Args:
+ avi_file (pathlib.Path): Path to the source .avi file to process.
+
+ Returns:
+ dict: The value returned by generate_zea_dataset containing the dataset
+ entry for the processed file.
  """
  hdf5_file = avi_file.stem + ".hdf5"
- sequence = load_video(avi_file)
+ sequence = load_avi(avi_file)

  assert sequence.min() >= self.range_from[0], f"{sequence.min()} < {self.range_from[0]}"
  assert sequence.max() <= self.range_from[1], f"{sequence.max()} > {self.range_from[1]}"
@@ -356,25 +422,14 @@ class H5Processor:
  accepted = split != "rejected"

  out_h5 = self.path_out_h5 / split / hdf5_file
- if self._to_numpy:
- out_dir = self.path_out / split / avi_file.stem
- out_dir.mkdir(parents=True, exist_ok=True)

  polar_im_set = []
- for i, im in enumerate(sequence):
- if self._to_numpy:
- np.save(out_dir / f"sc{str(i).zfill(3)}.npy", im)
-
+ for _, im in enumerate(sequence):
  if not accepted:
  continue

  polar_im = cartesian_to_polar_matrix(im, interpolation="cubic")
  polar_im = np.clip(polar_im, *self._process_range)
- if self._to_numpy:
- np.save(
- out_dir / f"polar{str(i).zfill(3)}.npy",
- polar_im,
- )
  polar_im_set.append(polar_im)

  if accepted:
@@ -386,53 +441,99 @@

  zea_dataset = {
  "path": out_h5,
- "image_sc": self.translate(sequence),
+ "image_sc": self._translate(sequence),
  "probe_name": "generic",
  "description": "EchoNet dataset converted to zea format",
  }
  if accepted:
- zea_dataset["image"] = self.translate(polar_im_set)
+ zea_dataset["image"] = self._translate(polar_im_set)
  return generate_zea_dataset(**zea_dataset)


- if __name__ == "__main__":
- args = get_args()
+ def convert_echonet(args):
+ """
+ Convert an EchoNet dataset into zea files, organizing results
+ into train/val/test/rejected splits.
+
+ Args:
+ args (argparse.Namespace): An object with the following attributes.
+
+ - src (str|Path): Path to the source archive or directory containing .avi files.
+ Will be unzipped if needed.
+ - dst (str|Path): Destination directory for generated zea files
+ per-split subdirectories (train, val, test, rejected) and a split.yaml
+ are created or updated.
+ - split_path (str|Path|None): If provided, must contain a split.yaml to reproduce
+ an existing split; function asserts the file exists.
+ - no_hyperthreading (bool): When false, processing uses a ProcessPoolExecutor
+ with a shared counter; when true, processing runs sequentially.
+
+ Note:
+ - May unzip the source into a working directory.
+ - Writes zea files into dst.
+ - Writes a split.yaml into dst summarizing produced files per split.
+ - Logs progress and validation results.
+ - Asserts that split.yaml exists at split_path when split reproduction is requested.
+ """
+ # Check if unzip is needed
+ src = unzip(args.src, "echonet")

- if args.splits is not None:
+ if args.split_path is not None:
  # Reproduce a previous split...
- split_yaml_dir = Path(args.splits)
- splits = {"train": None, "val": None, "test": None}
- for split in splits:
- yaml_file = split_yaml_dir / (split + ".yaml")
- assert yaml_file.exists(), f"File {yaml_file} does not exist."
- splits[split] = Config.from_yaml(yaml_file)["file_paths"]
+ yaml_file = Path(args.split_path) / "split.yaml"
+ assert yaml_file.exists(), f"File {yaml_file} does not exist."
+ splits = {"train": None, "val": None, "test": None, "rejected": None}
+ with open(yaml_file, "r") as f:
+ splits = yaml.safe_load(f)
+ log.info(f"Processor initialized with train-val-test split from {yaml_file}.")
  else:
  splits = None

  # List the files that have an entry in path_out_h5 already
  files_done = []
- for _, _, filenames in os.walk(args.output):
+ for _, _, filenames in os.walk(args.dst):
  for filename in filenames:
  files_done.append(filename.replace(".hdf5", ""))

  # List all files of echonet and exclude those already processed
- path_in = Path(args.source)
+ path_in = Path(src)
  h5_files = path_in.glob("*.avi")
  h5_files = [file for file in h5_files if file.stem not in files_done]
- print(f"Files left to process: {len(h5_files)}")
+ log.info(f"Files left to process: {len(h5_files)}")

  # Run the processor
- processor = H5Processor(path_out_h5=args.output, path_out=args.output_numpy, splits=splits)
+ processor = H5Processor(path_out_h5=args.dst, splits=splits)

- print("Starting the conversion process.")
+ log.info("Starting the conversion process.")

  if not args.no_hyperthreading:
- with ProcessPoolExecutor() as executor:
- futures = {executor.submit(processor, file): file for file in h5_files}
- for future in tqdm(as_completed(futures), total=len(h5_files)):
- future.result()
+ shared_counter = Value("i", 0)
+ with ProcessPoolExecutor(initializer=count_init, initargs=(shared_counter,)) as executor:
+ futures = [executor.submit(processor, file) for file in h5_files]
+ for future in tqdm(as_completed(futures), total=len(futures)):
+ try:
+ future.result()
+ except Exception:
+ log.warning("Task raised an exception")
  else:
+ # Initialize global variable for counting
+ count_init(Value("i", 0))
  for file in tqdm(h5_files):
  processor(file)

- print("All tasks are completed.")
+ log.info("All tasks are completed.")
+
+ # Write to yaml split files
+ full_list = {}
+ for split in ["train", "val", "test", "rejected"]:
+ split_dir = Path(args.dst) / split
+
+ # Get only files (skip directories)
+ file_list = [f.name for f in split_dir.iterdir() if f.is_file()]
+ full_list[split] = file_list
+
+ with open(Path(args.dst) / "split.yaml", "w") as f:
+ yaml.dump(full_list, f)
+
+ # Validate that the split was copied correctly
+ processor.validate_split_copy(Path(args.dst) / "split.yaml")
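The substantive change in `H5Processor.get_split` above is replacing the old directory-count heuristic (which the removed TODO noted was not reproducible) with a process-shared counter, so the `num_val` and `num_test` quotas cannot be overshot by concurrent workers. A self-contained sketch of that pattern follows; `assign_split`, `NUM_VAL`, and `NUM_TEST` are illustrative names, not zea APIs:

from concurrent.futures import ProcessPoolExecutor
from multiprocessing import Value

NUM_VAL, NUM_TEST = 3, 2
COUNTER = None


def count_init(shared_counter):
    # Runs once per worker: bind the shared counter to a module global.
    global COUNTER
    COUNTER = shared_counter


def assign_split(_item):
    # Atomically claim the next slot, then map it to a split by quota.
    with COUNTER.get_lock():
        COUNTER.value += 1
        n = COUNTER.value
    if n <= NUM_VAL:
        return "val"
    if n <= NUM_VAL + NUM_TEST:
        return "test"
    return "train"


if __name__ == "__main__":
    shared = Value("i", 0)
    with ProcessPoolExecutor(initializer=count_init, initargs=(shared,)) as pool:
        print(list(pool.map(assign_split, range(10))))
    # Exactly NUM_VAL "val" and NUM_TEST "test" labels are produced; which input
    # gets which label still depends on scheduling order.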
zea/data/convert/echonetlvh/README.md CHANGED
@@ -3,7 +3,6 @@
  vary from file to file. This is done by the `precompute_crop.py` script, which will
  produce as output a .csv file `cone_parameters.csv`. The cone parameters will indicate
  how the video should be cropped in order to bound the cone and remove margins.
- - Next, `convert_raw_to_zea.py` can be run to convert the dataset to zea format,
- with cropping and scan conversion. The measurement locations stored in `MeasurementsList.csv`
+ - Next, `__init__.py` converts the dataset to zea format,
+ with cropping and scan conversion. The original measurement locations stored in `MeasurementsList.csv`
  are also updated to match the new cropping / padding coordinates.
- - You can save the video and measurement plots for a converted video using `examples/echonetlvh/plot_sample.py`.