zea 0.0.6__py3-none-any.whl → 0.0.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (79)
  1. zea/__init__.py +54 -19
  2. zea/agent/__init__.py +12 -12
  3. zea/agent/masks.py +2 -1
  4. zea/backend/tensorflow/dataloader.py +2 -5
  5. zea/beamform/beamformer.py +100 -50
  6. zea/beamform/lens_correction.py +9 -2
  7. zea/beamform/pfield.py +9 -2
  8. zea/beamform/pixelgrid.py +1 -1
  9. zea/config.py +34 -25
  10. zea/data/__init__.py +22 -25
  11. zea/data/augmentations.py +221 -28
  12. zea/data/convert/__init__.py +1 -6
  13. zea/data/convert/__main__.py +123 -0
  14. zea/data/convert/camus.py +101 -40
  15. zea/data/convert/echonet.py +187 -86
  16. zea/data/convert/echonetlvh/README.md +2 -3
  17. zea/data/convert/echonetlvh/{convert_raw_to_usbmd.py → __init__.py} +174 -103
  18. zea/data/convert/echonetlvh/manual_rejections.txt +73 -0
  19. zea/data/convert/echonetlvh/precompute_crop.py +43 -64
  20. zea/data/convert/picmus.py +37 -40
  21. zea/data/convert/utils.py +86 -0
  22. zea/data/convert/{matlab.py → verasonics.py} +44 -65
  23. zea/data/data_format.py +155 -34
  24. zea/data/dataloader.py +12 -7
  25. zea/data/datasets.py +112 -71
  26. zea/data/file.py +184 -73
  27. zea/data/file_operations.py +496 -0
  28. zea/data/layers.py +3 -3
  29. zea/data/preset_utils.py +1 -1
  30. zea/datapaths.py +16 -4
  31. zea/display.py +14 -13
  32. zea/interface.py +14 -16
  33. zea/internal/_generate_keras_ops.py +6 -7
  34. zea/internal/cache.py +2 -49
  35. zea/internal/checks.py +6 -12
  36. zea/internal/config/validation.py +1 -2
  37. zea/internal/core.py +69 -6
  38. zea/internal/device.py +6 -2
  39. zea/internal/dummy_scan.py +330 -0
  40. zea/internal/operators.py +118 -2
  41. zea/internal/parameters.py +101 -70
  42. zea/internal/setup_zea.py +5 -6
  43. zea/internal/utils.py +282 -0
  44. zea/io_lib.py +322 -146
  45. zea/keras_ops.py +74 -4
  46. zea/log.py +9 -7
  47. zea/metrics.py +15 -7
  48. zea/models/__init__.py +31 -21
  49. zea/models/base.py +30 -14
  50. zea/models/carotid_segmenter.py +19 -4
  51. zea/models/diffusion.py +235 -23
  52. zea/models/echonet.py +22 -8
  53. zea/models/echonetlvh.py +31 -7
  54. zea/models/lpips.py +19 -2
  55. zea/models/lv_segmentation.py +30 -11
  56. zea/models/preset_utils.py +5 -5
  57. zea/models/regional_quality.py +30 -10
  58. zea/models/taesd.py +21 -5
  59. zea/models/unet.py +15 -1
  60. zea/ops.py +770 -336
  61. zea/probes.py +6 -6
  62. zea/scan.py +121 -51
  63. zea/simulator.py +24 -21
  64. zea/tensor_ops.py +477 -353
  65. zea/tools/fit_scan_cone.py +90 -160
  66. zea/tools/hf.py +1 -1
  67. zea/tools/selection_tool.py +47 -86
  68. zea/tracking/__init__.py +16 -0
  69. zea/tracking/base.py +94 -0
  70. zea/tracking/lucas_kanade.py +474 -0
  71. zea/tracking/segmentation.py +110 -0
  72. zea/utils.py +101 -480
  73. zea/visualize.py +177 -39
  74. {zea-0.0.6.dist-info → zea-0.0.8.dist-info}/METADATA +6 -2
  75. zea-0.0.8.dist-info/RECORD +122 -0
  76. zea-0.0.6.dist-info/RECORD +0 -112
  77. {zea-0.0.6.dist-info → zea-0.0.8.dist-info}/WHEEL +0 -0
  78. {zea-0.0.6.dist-info → zea-0.0.8.dist-info}/entry_points.txt +0 -0
  79. {zea-0.0.6.dist-info → zea-0.0.8.dist-info}/licenses/LICENSE +0 -0
zea/data/file.py CHANGED
@@ -2,10 +2,11 @@
2
2
 
3
3
  import enum
4
4
  from pathlib import Path
5
- from typing import List
5
+ from typing import List, Tuple, Union
6
6
 
7
7
  import h5py
8
8
  import numpy as np
9
+ from keras.utils import pad_sequences
9
10
 
10
11
  from zea import log
11
12
  from zea.data.preset_utils import HF_PREFIX, _hf_resolve_path
@@ -15,9 +16,10 @@ from zea.internal.checks import (
15
16
  _REQUIRED_SCAN_KEYS,
16
17
  get_check,
17
18
  )
19
+ from zea.internal.core import DataTypes
20
+ from zea.internal.utils import reduce_to_signature
18
21
  from zea.probes import Probe
19
22
  from zea.scan import Scan
20
- from zea.utils import reduce_to_signature
21
23
 
22
24
 
23
25
  def assert_key(file: h5py.File, key: str):
@@ -110,64 +112,10 @@ class File(h5py.File):
110
112
  else:
111
113
  raise NotImplementedError
112
114
 
113
- @staticmethod
114
- def _prepare_indices(indices):
115
- """Prepare the indices for loading data from hdf5 files.
116
- Options:
117
- - str("all")
118
- - int -> single frame
119
- - list of ints -> indexes first axis (frames)
120
- - list of list, ranges or slices -> indexes multiple axes
121
-
122
- Returns:
123
- indices (tuple): A tuple of indices / slices to use for indexing.
124
- """
125
- _value_error_msg = (
126
- f"Invalid value for indices: {indices}. "
127
- "Indices can be a 'all', int or a List[int, tuple, list, slice, range]."
128
- )
129
-
130
- # Check all options that only index the first axis
131
- if isinstance(indices, str):
132
- if indices == "all":
133
- return slice(None)
134
- else:
135
- raise ValueError(_value_error_msg)
136
-
137
- if isinstance(indices, range):
138
- return list(indices)
139
-
140
- if isinstance(indices, (int, slice, np.integer)):
141
- return indices
142
-
143
- # At this point, indices should be a list or tuple
144
- assert isinstance(indices, (list, tuple, np.ndarray)), _value_error_msg
145
-
146
- assert all(
147
- isinstance(idx, (list, tuple, int, slice, range, np.ndarray, np.integer))
148
- for idx in indices
149
- ), _value_error_msg
150
-
151
- # Convert ranges to lists
152
- processed_indices = [list(idx) if isinstance(idx, range) else idx for idx in indices]
153
-
154
- # Check if items are list-like and cast to tuple (needed for hdf5)
155
- if any(isinstance(idx, (list, tuple, slice)) for idx in processed_indices):
156
- processed_indices = tuple(processed_indices)
157
-
158
- return processed_indices
159
-
160
115
  def load_scan(self, event=None):
161
116
  """Alias for get_scan_parameters."""
162
117
  return self.get_scan_parameters(event)
163
118
 
164
- @staticmethod
165
- def check_data(data, key):
166
- """Check the data for a given key. For example, will check if the shape matches
167
- the data type (such as raw_data, ...)"""
168
- if key in _DATA_TYPES:
169
- get_check(key)(data, with_batch_dim=None)
170
-
171
119
  def format_key(self, key):
172
120
  """Format the key to match the data type."""
173
121
  # TODO: support events
@@ -205,7 +153,7 @@ class File(h5py.File):
205
153
  def load_transmits(self, key, selected_transmits):
206
154
  """Load raw_data or aligned_data for a given list of transmits.
207
155
  Args:
208
- data_type (str): The type of data to load. Options are 'raw_data' and 'aligned_data'.
156
+ key (str): The type of data to load. Options are 'raw_data' and 'aligned_data'.
209
157
  selected_transmits (list, np.ndarray): The transmits to load.
210
158
  """
211
159
  key = self.format_key(key)
@@ -213,22 +161,51 @@ class File(h5py.File):
213
161
  assert data_type in ["raw_data", "aligned_data"], (
214
162
  f"Cannot load transmits for {data_type}. Only raw_data and aligned_data are supported."
215
163
  )
216
- indices = [slice(None), np.array(selected_transmits)]
164
+ # First axis: all frames, second axis: selected transmits
165
+ indices = (slice(None), np.array(selected_transmits))
217
166
  return self.load_data(key, indices)
218
167
 
219
- def load_data(self, data_type, indices: str | int | List[int] = "all"):
168
+ def load_data(
169
+ self,
170
+ data_type,
171
+ indices: Tuple[Union[list, slice, int], ...] | List[int] | int | None = None,
172
+ ):
220
173
  """Load data from the file.
221
174
 
175
+ .. include:: ../common/file_indexing.rst
176
+
177
+ .. doctest::
178
+
179
+ >>> from zea import File
180
+
181
+ >>> path_to_file = (
182
+ ... "hf://zeahub/picmus/database/experiments/contrast_speckle/"
183
+ ... "contrast_speckle_expe_dataset_iq/contrast_speckle_expe_dataset_iq.hdf5"
184
+ ... )
185
+
186
+ >>> with File(path_to_file, mode="r") as file:
187
+ ... # data has shape (n_frames, n_tx, n_el, n_ax, n_ch)
188
+ ... data = file.load_data("raw_data")
189
+ ... data.shape
190
+ ... # load first frame only
191
+ ... data = file.load_data("raw_data", indices=0)
192
+ ... data.shape
193
+ ... # load frame 0 and transmits 0, 2 and 4
194
+ ... data = file.load_data("raw_data", indices=(0, [0, 2, 4]))
195
+ ... data.shape
196
+ (1, 75, 832, 128, 2)
197
+ (75, 832, 128, 2)
198
+ (3, 832, 128, 2)
199
+
222
200
  Args:
223
201
  data_type (str): The type of data to load. Options are 'raw_data', 'aligned_data',
224
202
  'beamformed_data', 'envelope_data', 'image' and 'image_sc'.
225
- indices (str, int, list, optional): The indices to load. Defaults to "all" in
226
- which case all frames are loaded. If an int is provided, it will be used
227
- as a single index. If a list is provided, it will be used as a list of
228
- indices.
203
+ indices (optional): The indices to load. Defaults to `None` in
204
+ which case all data is loaded.
229
205
  """
230
206
  key = self.format_key(data_type)
231
- indices = self._prepare_indices(indices)
207
+ if indices is None or (isinstance(indices, str) and indices == "all"):
208
+ indices = slice(None)
232
209
 
233
210
  if self._simple_index(key):
234
211
  data = self[key]
@@ -238,7 +215,6 @@ class File(h5py.File):
238
215
  raise ValueError(
239
216
  f"Invalid indices {indices} for key {key}. {key} has shape {data.shape}."
240
217
  ) from exc
241
- self.check_data(data, key)
242
218
  elif self.events_have_same_shape(key):
243
219
  raise NotImplementedError
244
220
  else:
@@ -337,7 +313,7 @@ class File(h5py.File):
337
313
  Returns:
338
314
  Scan: The scan object.
339
315
  """
340
- return Scan.merge(self.get_scan_parameters(event), kwargs, safe=safe)
316
+ return Scan.merge(_reformat_waveforms(self.get_scan_parameters(event)), kwargs, safe=safe)
341
317
 
342
318
  def get_probe_parameters(self, event=None) -> dict:
343
319
  """Returns a dictionary of probe parameters to initialize a probe
@@ -389,6 +365,21 @@ class File(h5py.File):
389
365
  ans[key] = self.recursively_load_dict_contents_from_group(path + "/" + key + "/")
390
366
  return ans
391
367
 
368
+ def has_key(self, key: str) -> bool:
369
+ """Check if the file has a specific key.
370
+
371
+ Args:
372
+ key (str): The key to check.
373
+
374
+ Returns:
375
+ bool: True if the key exists, False otherwise.
376
+ """
377
+ try:
378
+ key = self.format_key(key)
379
+ except AssertionError:
380
+ return False
381
+ return True
382
+
392
383
  @classmethod
393
384
  def get_shape(cls, path: str, key: str) -> tuple:
394
385
  """Get the shape of a key in a file.
@@ -454,10 +445,67 @@ class File(h5py.File):
454
445
  _print_hdf5_attrs(self)
455
446
 
456
447
 
448
def load_file_all_data_types(
    path,
    indices: Tuple[Union[list, slice, int], ...] | List[int] | int | None = None,
    scan_kwargs: dict = None,
):
    """Load every data type stored in a zea data file (hdf5 file).

    Returns all data types together with a scan object containing the parameters
    of the acquisition and a probe object containing the parameters of the probe.

    Additionally, it can load a specific subset of frames / transmits.

    .. include:: ../common/file_indexing.rst

    Args:
        path (str, pathlike): The path to the hdf5 file.
        indices (optional): The indices to load. Defaults to None in
            which case all frames are loaded. The same indices are applied to
            every data type present in the file.
        scan_kwargs (Config, dict, optional): Additional keyword arguments
            to pass to the Scan object. These will override the parameters from the file
            if they are present in the file. Defaults to None. The mapping passed in
            is not modified.

    Returns:
        (dict): A dictionary with all data types as keys and the corresponding data as values.
            Data types that are not present in the file map to ``None``.
        (Scan): A scan object containing the parameters of the acquisition.
        (Probe): A probe object containing the parameters of the probe.
    """
    # Copy so that adding "selected_transmits" below never leaks back into the
    # caller's mapping (the previous version mutated it in place).
    scan_kwargs = {} if scan_kwargs is None else dict(scan_kwargs)

    # Data types whose second axis is the transmit axis; mirrors the check in load_file.
    transmit_data_types = ("raw_data", "aligned_data")

    data_dict = {}
    loaded_transmit_data = False

    with File(path, mode="r") as file:
        # Load the probe object from the file
        probe = file.probe()

        for data_type in DataTypes:
            key = data_type.value
            if not file.has_key(key):
                data_dict[key] = None
                continue

            # Load the desired frames from the file
            data_dict[key] = file.load_data(key, indices=indices)
            if key in transmit_data_types:
                loaded_transmit_data = True

        # The second index only selects transmits for data that has an n_tx axis.
        # Only then must the scan parameters be restricted to the same transmits;
        # for other data types indices[1] indexes a different axis.
        if loaded_transmit_data and isinstance(indices, tuple) and len(indices) > 1:
            scan_kwargs["selected_transmits"] = indices[1]

        scan = file.scan(**scan_kwargs)

    return data_dict, scan, probe
503
+
504
+
457
505
  def load_file(
458
506
  path,
459
507
  data_type="raw_data",
460
- indices: str | int | List[int] = "all",
508
+ indices: Tuple[Union[list, slice, int], ...] | List[int] | int | None = None,
461
509
  scan_kwargs: dict = None,
462
510
  ):
463
511
  """Loads a zea data files (h5py file).
@@ -467,17 +515,15 @@ def load_file(
467
515
 
468
516
  Additionally, it can load a specific subset of frames / transmits.
469
517
 
470
- # TODO: add support for event
518
+ .. include:: ../common/file_indexing.rst
471
519
 
472
520
  Args:
473
521
  path (str, pathlike): The path to the hdf5 file.
474
522
  data_type (str, optional): The type of data to load. Defaults to
475
523
  'raw_data'. Other options are 'aligned_data', 'beamformed_data',
476
524
  'envelope_data', 'image' and 'image_sc'.
477
- indices (str, int, list, optional): The indices to load. Defaults to "all" in
478
- which case all frames are loaded. If an int is provided, it will be used
479
- as a single index. If a list is provided, it will be used as a list of
480
- indices.
525
+ indices (optional): The indices to load. Defaults to None in
526
+ which case all frames are loaded.
481
527
  scan_kwargs (Config, dict, optional): Additional keyword arguments
482
528
  to pass to the Scan object. These will override the parameters from the file
483
529
  if they are present in the file. Defaults to None.
@@ -503,7 +549,6 @@ def load_file(
503
549
  # in that case we also have update scan parameters to match
504
550
  # the number of selected transmits
505
551
  if data_type in ["raw_data", "aligned_data"]:
506
- indices = File._prepare_indices(indices)
507
552
  if isinstance(indices, tuple) and len(indices) > 1:
508
553
  scan_kwargs["selected_transmits"] = indices[1]
509
554
 
@@ -772,3 +817,69 @@ def _assert_unit_and_description_present(hdf5_file, _prefix=""):
772
817
  assert "description" in hdf5_file[key].attrs.keys(), (
773
818
  f"The file {_prefix}/{key} does not have a description attribute."
774
819
  )
820
+
821
+
822
+ def _reformat_waveforms(scan_kwargs: dict) -> dict:
823
+ """Reformat waveforms from dict to array if needed. This is for backwards compatibility and will
824
+ be removed in a future version of zea.
825
+
826
+ Args:
827
+ scan_kwargs (dict): The scan parameters.
828
+
829
+ Returns:
830
+ scan_kwargs (dict): The scan parameters with the keys waveforms_one_way and
831
+ waveforms_two_way reformatted to arrays if they were stored as dicts.
832
+ """
833
+
834
+ # TODO: remove this in a future version of zea
835
+ if "waveforms_one_way" in scan_kwargs and isinstance(scan_kwargs["waveforms_one_way"], dict):
836
+ log.warning(
837
+ "The waveforms_one_way parameter is stored as a dictionary in the file. "
838
+ "Converting to array. This will be deprecated in future versions of zea. "
839
+ "Please update your files to store waveforms as arrays of shape `(n_tx, n_samples)`."
840
+ )
841
+ scan_kwargs["waveforms_one_way"] = _waveforms_dict_to_array(
842
+ scan_kwargs["waveforms_one_way"]
843
+ )
844
+
845
+ if "waveforms_two_way" in scan_kwargs and isinstance(scan_kwargs["waveforms_two_way"], dict):
846
+ log.warning(
847
+ "The waveforms_two_way parameter is stored as a dictionary in the file. "
848
+ "Converting to array. This will be deprecated in future versions of zea. "
849
+ "Please update your files to store waveforms as arrays of shape `(n_tx, n_samples)`."
850
+ )
851
+ scan_kwargs["waveforms_two_way"] = _waveforms_dict_to_array(
852
+ scan_kwargs["waveforms_two_way"]
853
+ )
854
+ return scan_kwargs
855
+
856
+
857
def _waveforms_dict_to_array(waveforms_dict: dict):
    """Stack dict-stored waveforms into a single float32 array.

    Waveforms are ordered by sorted key. They are padded at the end
    (``padding="post"``) with ``pad_sequences``' default fill value up to the
    length of the longest waveform.
    """
    ordered_waveforms = [waveforms_dict[name] for name in sorted(waveforms_dict)]
    return pad_sequences(ordered_waveforms, dtype=np.float32, padding="post")
861
+
862
+
863
def dict_to_sorted_list(dictionary: dict):
    """Return the values of ``dictionary`` ordered by ascending key.

    .. note::

        Only the top level of the dictionary is considered. Nested
        dictionaries among the values are returned as-is, not sorted.

    Example:
        .. doctest::

            >>> from zea.data.file import dict_to_sorted_list
            >>> input_dict = {"number_002": 23, "number_000": 5, "number_001": 1}
            >>> dict_to_sorted_list(input_dict)
            [5, 1, 23]

    Args:
        dictionary (dict): The dictionary to convert. The keys must be sortable.

    Returns:
        list: The values, in ascending key order.
    """
    ordered_keys = sorted(dictionary)
    return [dictionary[k] for k in ordered_keys]