zea 0.0.6__py3-none-any.whl → 0.0.8__py3-none-any.whl

This diff shows the contents of two publicly available versions of this package, as released to one of the supported registries. The information in this diff is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
Files changed (79)
  1. zea/__init__.py +54 -19
  2. zea/agent/__init__.py +12 -12
  3. zea/agent/masks.py +2 -1
  4. zea/backend/tensorflow/dataloader.py +2 -5
  5. zea/beamform/beamformer.py +100 -50
  6. zea/beamform/lens_correction.py +9 -2
  7. zea/beamform/pfield.py +9 -2
  8. zea/beamform/pixelgrid.py +1 -1
  9. zea/config.py +34 -25
  10. zea/data/__init__.py +22 -25
  11. zea/data/augmentations.py +221 -28
  12. zea/data/convert/__init__.py +1 -6
  13. zea/data/convert/__main__.py +123 -0
  14. zea/data/convert/camus.py +101 -40
  15. zea/data/convert/echonet.py +187 -86
  16. zea/data/convert/echonetlvh/README.md +2 -3
  17. zea/data/convert/echonetlvh/{convert_raw_to_usbmd.py → __init__.py} +174 -103
  18. zea/data/convert/echonetlvh/manual_rejections.txt +73 -0
  19. zea/data/convert/echonetlvh/precompute_crop.py +43 -64
  20. zea/data/convert/picmus.py +37 -40
  21. zea/data/convert/utils.py +86 -0
  22. zea/data/convert/{matlab.py → verasonics.py} +44 -65
  23. zea/data/data_format.py +155 -34
  24. zea/data/dataloader.py +12 -7
  25. zea/data/datasets.py +112 -71
  26. zea/data/file.py +184 -73
  27. zea/data/file_operations.py +496 -0
  28. zea/data/layers.py +3 -3
  29. zea/data/preset_utils.py +1 -1
  30. zea/datapaths.py +16 -4
  31. zea/display.py +14 -13
  32. zea/interface.py +14 -16
  33. zea/internal/_generate_keras_ops.py +6 -7
  34. zea/internal/cache.py +2 -49
  35. zea/internal/checks.py +6 -12
  36. zea/internal/config/validation.py +1 -2
  37. zea/internal/core.py +69 -6
  38. zea/internal/device.py +6 -2
  39. zea/internal/dummy_scan.py +330 -0
  40. zea/internal/operators.py +118 -2
  41. zea/internal/parameters.py +101 -70
  42. zea/internal/setup_zea.py +5 -6
  43. zea/internal/utils.py +282 -0
  44. zea/io_lib.py +322 -146
  45. zea/keras_ops.py +74 -4
  46. zea/log.py +9 -7
  47. zea/metrics.py +15 -7
  48. zea/models/__init__.py +31 -21
  49. zea/models/base.py +30 -14
  50. zea/models/carotid_segmenter.py +19 -4
  51. zea/models/diffusion.py +235 -23
  52. zea/models/echonet.py +22 -8
  53. zea/models/echonetlvh.py +31 -7
  54. zea/models/lpips.py +19 -2
  55. zea/models/lv_segmentation.py +30 -11
  56. zea/models/preset_utils.py +5 -5
  57. zea/models/regional_quality.py +30 -10
  58. zea/models/taesd.py +21 -5
  59. zea/models/unet.py +15 -1
  60. zea/ops.py +770 -336
  61. zea/probes.py +6 -6
  62. zea/scan.py +121 -51
  63. zea/simulator.py +24 -21
  64. zea/tensor_ops.py +477 -353
  65. zea/tools/fit_scan_cone.py +90 -160
  66. zea/tools/hf.py +1 -1
  67. zea/tools/selection_tool.py +47 -86
  68. zea/tracking/__init__.py +16 -0
  69. zea/tracking/base.py +94 -0
  70. zea/tracking/lucas_kanade.py +474 -0
  71. zea/tracking/segmentation.py +110 -0
  72. zea/utils.py +101 -480
  73. zea/visualize.py +177 -39
  74. {zea-0.0.6.dist-info → zea-0.0.8.dist-info}/METADATA +6 -2
  75. zea-0.0.8.dist-info/RECORD +122 -0
  76. zea-0.0.6.dist-info/RECORD +0 -112
  77. {zea-0.0.6.dist-info → zea-0.0.8.dist-info}/WHEEL +0 -0
  78. {zea-0.0.6.dist-info → zea-0.0.8.dist-info}/entry_points.txt +0 -0
  79. {zea-0.0.6.dist-info → zea-0.0.8.dist-info}/licenses/LICENSE +0 -0
zea/probes.py CHANGED
@@ -16,14 +16,14 @@ Example usage
16
16
 
17
17
  We can initialize a generic probe with the following code:
18
18
 
19
- .. code-block:: python
19
+ .. doctest::
20
20
 
21
- import zea
21
+ >>> from zea import Probe
22
22
 
23
- probe = zea.Probe.from_name("generic")
24
- print(probe.get_parameters())
25
-
26
- """
23
+ >>> probe = Probe.from_name("generic")
24
+ >>> print(probe.get_parameters())
25
+ {'probe_geometry': None, 'center_frequency': None, 'sampling_frequency': None, 'xlims': None, 'zlims': None}
26
+ """ # noqa: E501
27
27
 
28
28
  import numpy as np
29
29
 
zea/scan.py CHANGED
@@ -43,38 +43,38 @@ Comparison to ``zea.Config`` and ``zea.Probe``
43
43
  Example Usage
44
44
  ^^^^^^^^^^^^^
45
45
 
46
- .. code-block:: python
47
-
48
- from zea import Config, Probe, Scan
49
-
50
- # Initialize Scan from a Probe's parameters
51
- probe = Probe.from_name("verasonics_l11_4v")
52
- scan = Scan(**probe.get_parameters(), grid_size_z=256)
53
-
54
- # Or initialize from a Config object
55
- config = Config.from_hf("zeahub/configs", "config_picmus_rf.yaml", repo_type="dataset")
56
- scan = Scan(**config.scan, n_tx=11)
57
-
58
- # Or manually specify parameters
59
- scan = Scan(
60
- grid_size_x=128,
61
- grid_size_z=256,
62
- xlims=(-0.02, 0.02),
63
- zlims=(0.0, 0.06),
64
- center_frequency=6.25e6,
65
- sound_speed=1540.0,
66
- sampling_frequency=25e6,
67
- n_el=128,
68
- n_tx=11,
69
- )
70
-
71
- # Access a derived property (computed lazily)
72
- grid = scan.grid # shape: (grid_size_z, grid_size_x, 3)
73
-
74
- # Select a subset of transmit events
75
- scan.set_transmits(3) # Use 3 evenly spaced transmits
76
- scan.set_transmits([0, 2, 4]) # Use specific transmit indices
77
- scan.set_transmits("all") # Use all transmits
46
+ .. doctest::
47
+
48
+ >>> from zea import Config, Probe, Scan
49
+
50
+ >>> # Initialize Scan from a Probe's parameters
51
+ >>> probe = Probe.from_name("verasonics_l11_4v")
52
+ >>> scan = Scan(**probe.get_parameters(), grid_size_z=256, n_tx=11)
53
+
54
+ >>> # Or initialize from a Config object
55
+ >>> config = Config.from_hf("zeahub/configs", "config_picmus_rf.yaml", repo_type="dataset")
56
+ >>> scan = Scan(n_tx=11, **config.scan)
57
+
58
+ >>> # Or manually specify parameters
59
+ >>> scan = Scan(
60
+ ... grid_size_x=128,
61
+ ... grid_size_z=256,
62
+ ... xlims=(-0.02, 0.02),
63
+ ... zlims=(0.0, 0.06),
64
+ ... center_frequency=6.25e6,
65
+ ... sound_speed=1540.0,
66
+ ... sampling_frequency=25e6,
67
+ ... n_el=128,
68
+ ... n_tx=11,
69
+ ... )
70
+
71
+ >>> # Access a derived property (computed lazily)
72
+ >>> grid = scan.grid # shape: (grid_size_z, grid_size_x, 3)
73
+
74
+ >>> # Select a subset of transmit events
75
+ >>> _ = scan.set_transmits(3) # Use 3 evenly spaced transmits
76
+ >>> _ = scan.set_transmits([0, 2, 4]) # Use specific transmit indices
77
+ >>> _ = scan.set_transmits("all") # Use all transmits
78
78
 
79
79
  """
80
80
 
@@ -83,7 +83,11 @@ from keras import ops
83
83
 
84
84
  from zea import log
85
85
  from zea.beamform.pfield import compute_pfield
86
- from zea.beamform.pixelgrid import cartesian_pixel_grid, check_for_aliasing, polar_pixel_grid
86
+ from zea.beamform.pixelgrid import (
87
+ cartesian_pixel_grid,
88
+ check_for_aliasing,
89
+ polar_pixel_grid,
90
+ )
87
91
  from zea.display import (
88
92
  compute_scan_convert_2d_coordinates,
89
93
  compute_scan_convert_3d_coordinates,
@@ -130,6 +134,15 @@ class Scan(Parameters):
130
134
  demodulation_frequency (float, optional): Demodulation frequency in Hz.
131
135
  time_to_next_transmit (np.ndarray): The time between subsequent
132
136
  transmit events of shape (n_frames, n_tx).
137
+ tgc_gain_curve (np.ndarray): Time gain compensation (TGC) curve of shape (n_ax,).
138
+ waveforms_one_way (np.ndarray): The one-way transmit waveforms of shape
139
+ (n_waveforms, n_samples).
140
+ waveforms_two_way (np.ndarray): The two-way transmit waveforms of shape
141
+ (n_waveforms, n_samples).
142
+ tx_waveform_indices (np.ndarray): Indices of the waveform used for each
143
+ transmit event of shape (n_tx,).
144
+ t_peak (np.ndarray, optional): The time of the peak of the pulse of every transmit waveform
145
+ of shape (n_waveforms,).
133
146
  pixels_per_wavelength (int, optional): Number of pixels per wavelength.
134
147
  Defaults to 4.
135
148
  element_width (float, optional): Width of each transducer element in meters.
@@ -140,7 +153,6 @@ class Scan(Parameters):
140
153
  apply_lens_correction (bool, optional): Whether to apply lens correction to
141
154
  delays. Defaults to False.
142
155
  lens_thickness (float, optional): Thickness of the lens in meters.
143
- Defaults to None.
144
156
  f_number (float, optional): F-number of the transducer. Defaults to 1.0.
145
157
  theta_range (tuple, optional): Range of theta angles for 3D imaging.
146
158
  phi_range (tuple, optional): Range of phi angles for 3D imaging.
@@ -160,6 +172,7 @@ class Scan(Parameters):
160
172
  Can be "cartesian" or "polar". Defaults to "cartesian".
161
173
  dynamic_range (tuple, optional): Dynamic range for image display.
162
174
  Defined in dB as (min_dB, max_dB). Defaults to (-60, 0).
175
+
163
176
  """
164
177
 
165
178
  VALID_PARAMS = {
@@ -200,6 +213,11 @@ class Scan(Parameters):
200
213
  "focus_distances": {"type": np.ndarray},
201
214
  "initial_times": {"type": np.ndarray},
202
215
  "time_to_next_transmit": {"type": np.ndarray},
216
+ "tgc_gain_curve": {"type": np.ndarray},
217
+ "waveforms_one_way": {"type": np.ndarray, "default": None},
218
+ "waveforms_two_way": {"type": np.ndarray, "default": None},
219
+ "tx_waveform_indices": {"type": np.ndarray},
220
+ "t_peak": {"type": np.ndarray},
203
221
  # scan conversion parameters
204
222
  "theta_range": {"type": (tuple, list)},
205
223
  "phi_range": {"type": (tuple, list)},
@@ -209,8 +227,9 @@ class Scan(Parameters):
209
227
  }
210
228
 
211
229
  def __init__(self, **kwargs):
212
- # Store the current selection state before initialization
213
- selected_transmits_input = kwargs.pop("selected_transmits", None)
230
+ # Ensure that selected_transmits is present and set to None by default
231
+ selected_transmits_input = kwargs.get("selected_transmits", None)
232
+ kwargs["selected_transmits"] = None
214
233
 
215
234
  # Initialize parent class
216
235
  super().__init__(**kwargs)
@@ -240,7 +259,10 @@ class Scan(Parameters):
240
259
  )
241
260
  elif self.grid_type == "cartesian":
242
261
  return cartesian_pixel_grid(
243
- self.xlims, self.zlims, grid_size_z=self.grid_size_z, grid_size_x=self.grid_size_x
262
+ self.xlims,
263
+ self.zlims,
264
+ grid_size_z=self.grid_size_z,
265
+ grid_size_x=self.grid_size_x,
244
266
  )
245
267
  else:
246
268
  raise ValueError(
@@ -301,7 +323,10 @@ class Scan(Parameters):
301
323
  radius * np.cos(-np.pi / 2 + self.polar_limits[1]),
302
324
  )
303
325
  xlims_plane = (self.probe_geometry[0, 0], self.probe_geometry[-1, 0])
304
- xlims = min(xlims_polar[0], xlims_plane[0]), max(xlims_polar[1], xlims_plane[1])
326
+ xlims = (
327
+ min(xlims_polar[0], xlims_plane[0]),
328
+ max(xlims_polar[1], xlims_plane[1]),
329
+ )
305
330
  return xlims
306
331
 
307
332
  @cache_with_dependencies("sound_speed", "sampling_frequency", "n_ax")
@@ -344,6 +369,11 @@ class Scan(Parameters):
344
369
  """The total number of transmits in the full dataset."""
345
370
  return self._params["n_tx"]
346
371
 
372
+ @property
373
+ def n_tx_selected(self):
374
+ """The number of currently selected transmits."""
375
+ return len(self.selected_transmits)
376
+
347
377
  @cache_with_dependencies("selected_transmits")
348
378
  def n_tx(self):
349
379
  """The number of currently selected transmits."""
@@ -387,14 +417,12 @@ class Scan(Parameters):
387
417
  if selection is None or selection == "all":
388
418
  self._selected_transmits = None
389
419
  self._invalidate("selected_transmits")
390
- self._invalidate_dependents("selected_transmits")
391
420
  return self
392
421
 
393
422
  # Handle "center" - use center transmit
394
423
  if selection == "center":
395
424
  self._selected_transmits = [n_tx_total // 2]
396
425
  self._invalidate("selected_transmits")
397
- self._invalidate_dependents("selected_transmits")
398
426
  return self
399
427
 
400
428
  # Handle integer - select evenly spaced transmits
@@ -416,7 +444,6 @@ class Scan(Parameters):
416
444
  self._selected_transmits = list(np.rint(tx_indices).astype(int))
417
445
 
418
446
  self._invalidate("selected_transmits")
419
- self._invalidate_dependents("selected_transmits")
420
447
  return self
421
448
 
422
449
  # Handle slice - convert to list of indices
@@ -436,7 +463,6 @@ class Scan(Parameters):
436
463
  int(i) for i in selection
437
464
  ] # Convert numpy integers to Python ints
438
465
  self._invalidate("selected_transmits")
439
- self._invalidate_dependents("selected_transmits")
440
466
  return self
441
467
 
442
468
  # Aliasing check
@@ -481,7 +507,7 @@ class Scan(Parameters):
481
507
  value = self._params.get("azimuth_angles")
482
508
  if value is None:
483
509
  log.warning("No azimuth angles provided, using zeros")
484
- value = np.zeros(self.n_tx_total)
510
+ return np.zeros(self.n_tx_selected)
485
511
 
486
512
  return value[self.selected_transmits]
487
513
 
@@ -492,7 +518,7 @@ class Scan(Parameters):
492
518
  value = self._params.get("t0_delays")
493
519
  if value is None:
494
520
  log.warning("No transmit delays provided, using zeros")
495
- return np.zeros((self.n_tx_total, self.n_el))
521
+ return np.zeros((self.n_tx_selected, self.n_el))
496
522
 
497
523
  return value[self.selected_transmits]
498
524
 
@@ -502,7 +528,7 @@ class Scan(Parameters):
502
528
  value = self._params.get("tx_apodizations")
503
529
  if value is None:
504
530
  log.warning("No transmit apodizations provided, using ones")
505
- value = np.ones((self.n_tx_total, self.n_el))
531
+ return np.ones((self.n_tx_selected, self.n_el))
506
532
 
507
533
  return value[self.selected_transmits]
508
534
 
@@ -512,7 +538,7 @@ class Scan(Parameters):
512
538
  value = self._params.get("focus_distances")
513
539
  if value is None:
514
540
  log.warning("No focus distances provided, using zeros")
515
- value = np.zeros(self.n_tx_total)
541
+ return np.zeros(self.n_tx_selected)
516
542
 
517
543
  return value[self.selected_transmits]
518
544
 
@@ -522,10 +548,19 @@ class Scan(Parameters):
522
548
  value = self._params.get("initial_times")
523
549
  if value is None:
524
550
  log.warning("No initial times provided, using zeros")
525
- value = np.zeros(self.n_tx_total)
551
+ return np.zeros(self.n_tx_selected)
526
552
 
527
553
  return value[self.selected_transmits]
528
554
 
555
+ @property
556
+ def t_peak(self):
557
+ """The time of the peak of the pulse in seconds of shape (n_waveforms,)."""
558
+ t_peak = self._params.get("t_peak")
559
+ if t_peak is None:
560
+ t_peak = np.array([1 / self.center_frequency])
561
+
562
+ return t_peak
563
+
529
564
  @cache_with_dependencies("selected_transmits")
530
565
  def time_to_next_transmit(self):
531
566
  """The time between subsequent transmit events of shape (n_frames, n_tx)."""
@@ -536,6 +571,23 @@ class Scan(Parameters):
536
571
  selected = self.selected_transmits
537
572
  return value[:, selected]
538
573
 
574
+ @cache_with_dependencies("n_ax")
575
+ def tgc_gain_curve(self):
576
+ """Time gain compensation (TGC) curve of shape (n_ax,)."""
577
+ value = self._params.get("tgc_gain_curve")
578
+ if value is None:
579
+ return np.ones(self.n_ax)
580
+ return value[: self.n_ax]
581
+
582
+ @cache_with_dependencies("selected_transmits")
583
+ def tx_waveform_indices(self):
584
+ """Indices of the waveform used for each transmit event of shape (n_tx,)."""
585
+ value = self._params.get("tx_waveform_indices")
586
+ if value is None:
587
+ return np.zeros(self.n_tx_selected, dtype=int)
588
+
589
+ return value[self.selected_transmits]
590
+
539
591
  @cache_with_dependencies(
540
592
  "sound_speed",
541
593
  "center_frequency",
@@ -545,8 +597,9 @@ class Scan(Parameters):
545
597
  "tx_apodizations",
546
598
  "grid",
547
599
  "t0_delays",
600
+ "pfield_kwargs",
548
601
  )
549
- def pfield(self):
602
+ def pfield(self) -> np.ndarray:
550
603
  """Compute or return the pressure field (pfield) for weighting."""
551
604
  pfield = compute_pfield(
552
605
  sound_speed=self.sound_speed,
@@ -563,7 +616,7 @@ class Scan(Parameters):
563
616
 
564
617
  @cache_with_dependencies("pfield")
565
618
  def flat_pfield(self):
566
- """Flattened pfield for weighting."""
619
+ """Flattened pfield for weighting of shape (n_pix, n_tx)."""
567
620
  return self.pfield.reshape(self.n_tx, -1).swapaxes(0, 1)
568
621
 
569
622
  @cache_with_dependencies("zlims")
@@ -596,7 +649,12 @@ class Scan(Parameters):
596
649
  return coords
597
650
 
598
651
  @cache_with_dependencies(
599
- "rho_range", "theta_range", "phi_range", "resolution", "grid_size_z", "grid_size_x"
652
+ "rho_range",
653
+ "theta_range",
654
+ "phi_range",
655
+ "resolution",
656
+ "grid_size_z",
657
+ "grid_size_x",
600
658
  )
601
659
  def coordinates_3d(self):
602
660
  """The coordinates for scan conversion."""
@@ -615,6 +673,17 @@ class Scan(Parameters):
615
673
  otherwise 2D."""
616
674
  return self.coordinates_3d if getattr(self, "phi_range", None) else self.coordinates_2d
617
675
 
676
+ @property
677
+ def pulse_repetition_frequency(self):
678
+ """The pulse repetition frequency (PRF) [Hz]. Assumes a constant PRF."""
679
+ if self.time_to_next_transmit is None:
680
+ log.warning("Time to next transmit is not set, cannot compute PRF")
681
+ return None
682
+
683
+ pulse_repetition_interval = np.mean(self.time_to_next_transmit)
684
+
685
+ return 1 / pulse_repetition_interval
686
+
618
687
  @cache_with_dependencies("time_to_next_transmit")
619
688
  def frames_per_second(self):
620
689
  """The number of frames per second [Hz]. Assumes a constant frame rate.
@@ -644,6 +713,7 @@ class Scan(Parameters):
644
713
  """The width of each transducer element in meters."""
645
714
  value = self._params.get("element_width")
646
715
  if value is None:
716
+ # assume uniform spacing
647
717
  return np.linalg.norm(self.probe_geometry[1] - self.probe_geometry[0])
648
718
  return value
649
719
 
@@ -651,5 +721,5 @@ class Scan(Parameters):
651
721
  if key == "selected_transmits":
652
722
  # If setting selected_transmits, call set_transmits to handle logic
653
723
  self.set_transmits(value)
654
- return
724
+ return super().__setattr__(key, self.selected_transmits)
655
725
  return super().__setattr__(key, value)
zea/simulator.py CHANGED
@@ -14,27 +14,30 @@ Example usage
14
14
  A simple example of simulating RF data with a single scatterer at the center of the probe. For a
15
15
  more in depth example see the notebook: :doc:`../notebooks/data/zea_simulation_example`.
16
16
 
17
- .. code-block:: python
18
-
19
- raw_data = simulate_rf(
20
- scatterer_positions=np.array([[0, 0, 20e-3]]),
21
- scatterer_magnitudes=np.array([1.0]),
22
- probe_geometry=np.stack(
23
- [np.linspace(-20e-3, 20e-3, 64), np.zeros(64), np.zeros(64)], axis=-1
24
- ),
25
- apply_lens_correction=True,
26
- lens_thickness=1e-3,
27
- lens_sound_speed=1000,
28
- sound_speed=1540,
29
- n_ax=1024,
30
- center_frequency=5e6,
31
- sampling_frequency=20e6,
32
- t0_delays=np.zeros((1, 64)),
33
- initial_times=np.zeros(1),
34
- element_width=0.2e-3,
35
- attenuation_coef=0.5,
36
- tx_apodizations=np.ones((1, 64)),
37
- )
17
+ .. doctest::
18
+
19
+ >>> from zea.simulator import simulate_rf
20
+ >>> import numpy as np
21
+
22
+ >>> raw_data = simulate_rf(
23
+ ... scatterer_positions=np.array([[0, 0, 20e-3]]),
24
+ ... scatterer_magnitudes=np.array([1.0]),
25
+ ... probe_geometry=np.stack(
26
+ ... [np.linspace(-20e-3, 20e-3, 64), np.zeros(64), np.zeros(64)], axis=-1
27
+ ... ),
28
+ ... apply_lens_correction=True,
29
+ ... lens_thickness=1e-3,
30
+ ... lens_sound_speed=1000,
31
+ ... sound_speed=1540,
32
+ ... n_ax=1024,
33
+ ... center_frequency=5e6,
34
+ ... sampling_frequency=20e6,
35
+ ... t0_delays=np.zeros((1, 64)),
36
+ ... initial_times=np.zeros(1),
37
+ ... element_width=0.2e-3,
38
+ ... attenuation_coef=0.5,
39
+ ... tx_apodizations=np.ones((1, 64)),
40
+ ... )
38
41
 
39
42
  """
40
43