zea 0.0.6__py3-none-any.whl → 0.0.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61) hide show
  1. zea/__init__.py +54 -19
  2. zea/agent/__init__.py +12 -12
  3. zea/agent/masks.py +2 -1
  4. zea/backend/tensorflow/dataloader.py +2 -1
  5. zea/beamform/beamformer.py +100 -50
  6. zea/beamform/lens_correction.py +9 -2
  7. zea/beamform/pfield.py +9 -2
  8. zea/config.py +34 -25
  9. zea/data/__init__.py +22 -16
  10. zea/data/convert/camus.py +2 -1
  11. zea/data/convert/echonet.py +4 -4
  12. zea/data/convert/echonetlvh/convert_raw_to_usbmd.py +1 -1
  13. zea/data/convert/matlab.py +11 -4
  14. zea/data/data_format.py +31 -30
  15. zea/data/datasets.py +7 -5
  16. zea/data/file.py +104 -2
  17. zea/data/layers.py +3 -3
  18. zea/datapaths.py +16 -4
  19. zea/display.py +7 -5
  20. zea/interface.py +14 -16
  21. zea/internal/_generate_keras_ops.py +6 -7
  22. zea/internal/cache.py +2 -49
  23. zea/internal/config/validation.py +1 -2
  24. zea/internal/core.py +69 -6
  25. zea/internal/device.py +6 -2
  26. zea/internal/dummy_scan.py +330 -0
  27. zea/internal/operators.py +114 -2
  28. zea/internal/parameters.py +101 -70
  29. zea/internal/setup_zea.py +5 -6
  30. zea/internal/utils.py +282 -0
  31. zea/io_lib.py +247 -19
  32. zea/keras_ops.py +74 -4
  33. zea/log.py +9 -7
  34. zea/metrics.py +15 -7
  35. zea/models/__init__.py +30 -20
  36. zea/models/base.py +30 -14
  37. zea/models/carotid_segmenter.py +19 -4
  38. zea/models/diffusion.py +173 -12
  39. zea/models/echonet.py +22 -8
  40. zea/models/echonetlvh.py +31 -7
  41. zea/models/lpips.py +19 -2
  42. zea/models/lv_segmentation.py +28 -11
  43. zea/models/preset_utils.py +5 -5
  44. zea/models/regional_quality.py +30 -10
  45. zea/models/taesd.py +21 -5
  46. zea/models/unet.py +15 -1
  47. zea/ops.py +390 -196
  48. zea/probes.py +6 -6
  49. zea/scan.py +109 -49
  50. zea/simulator.py +24 -21
  51. zea/tensor_ops.py +406 -302
  52. zea/tools/hf.py +1 -1
  53. zea/tools/selection_tool.py +47 -86
  54. zea/utils.py +92 -480
  55. zea/visualize.py +177 -39
  56. {zea-0.0.6.dist-info → zea-0.0.7.dist-info}/METADATA +4 -2
  57. zea-0.0.7.dist-info/RECORD +114 -0
  58. zea-0.0.6.dist-info/RECORD +0 -112
  59. {zea-0.0.6.dist-info → zea-0.0.7.dist-info}/WHEEL +0 -0
  60. {zea-0.0.6.dist-info → zea-0.0.7.dist-info}/entry_points.txt +0 -0
  61. {zea-0.0.6.dist-info → zea-0.0.7.dist-info}/licenses/LICENSE +0 -0
zea/__init__.py CHANGED
@@ -7,36 +7,71 @@ from . import log
7
7
 
8
8
  # dynamically add __version__ attribute (see pyproject.toml)
9
9
  # __version__ = __import__("importlib.metadata").metadata.version(__package__)
10
- __version__ = "0.0.6"
10
+ __version__ = "0.0.7"
11
11
 
12
12
 
13
13
  def _bootstrap_backend():
14
14
  """Setup function to initialize the zea package."""
15
15
 
16
16
  def _check_backend_installed():
17
- """Assert that at least one ML backend (torch, tensorflow, jax) is installed.
18
- If not, raise an AssertionError with a helpful install message.
19
- """
20
-
21
- ml_backends = ["torch", "tensorflow", "jax"]
22
- for backend in ml_backends:
23
- if importlib.util.find_spec(backend) is not None:
24
- return
17
+ """Verify that the required ML backend is installed.
25
18
 
26
- backend_env = os.environ.get("KERAS_BACKEND", "numpy")
27
- install_guide_urls = {
19
+ Raises ImportError if:
20
+ 1. No ML backend (torch, tensorflow, jax) is installed
21
+ 2. KERAS_BACKEND points to a backend that is not installed
22
+ """
23
+ ML_BACKENDS = ["torch", "tensorflow", "jax"]
24
+ INSTALL_URLS = {
28
25
  "torch": "https://pytorch.org/get-started/locally/",
29
26
  "tensorflow": "https://www.tensorflow.org/install",
30
27
  "jax": "https://docs.jax.dev/en/latest/installation.html",
31
28
  }
32
- guide_url = install_guide_urls.get(backend_env, "https://keras.io/getting_started/")
33
- raise ImportError(
34
- "No ML backend (torch, tensorflow, jax) installed in current environment. "
35
- f"Please install at least one ML backend before importing {__package__} or "
36
- f"any other library. Current KERAS_BACKEND is set to '{backend_env}', "
37
- f"please install it first, see: {guide_url}. One simple alternative is to "
38
- f"install with default backend: `pip install {__package__}[jax]`."
39
- )
29
+ KERAS_DEFAULT_BACKEND = "tensorflow"
30
+ DOCS_URL = "https://zea.readthedocs.io/en/latest/installation.html"
31
+
32
+ # Determine which backend Keras will try to use
33
+ backend_env = os.environ.get("KERAS_BACKEND")
34
+ effective_backend = backend_env or KERAS_DEFAULT_BACKEND
35
+
36
+ # Find all installed ML backends
37
+ installed_backends = [
38
+ backend for backend in ML_BACKENDS if importlib.util.find_spec(backend) is not None
39
+ ]
40
+
41
+ # Error if no backends are installed
42
+ if not installed_backends:
43
+ if backend_env:
44
+ backend_status = f"KERAS_BACKEND is set to '{backend_env}'"
45
+ else:
46
+ backend_status = f"KERAS_BACKEND is not set (defaults to '{KERAS_DEFAULT_BACKEND}')"
47
+ install_url = INSTALL_URLS.get(effective_backend, "https://keras.io/getting_started/")
48
+ raise ImportError(
49
+ f"No ML backend (torch, tensorflow, jax) installed in current "
50
+ f"environment. Please install at least one ML backend before importing "
51
+ f"{__package__}. {backend_status}, please install it first, see: "
52
+ f"{install_url}. One simple alternative is to install with default "
53
+ f"backend: `pip install {__package__}[jax]`. For more information, "
54
+ f"see: {DOCS_URL}"
55
+ )
56
+
57
+ # Error if the effective backend is not installed
58
+ # (skip numpy which doesn't need installation)
59
+ if effective_backend not in ["numpy"] and effective_backend not in installed_backends:
60
+ if backend_env:
61
+ backend_status = f"KERAS_BACKEND environment variable is set to '{backend_env}'"
62
+ else:
63
+ backend_status = (
64
+ f"KERAS_BACKEND is not set, which defaults to '{KERAS_DEFAULT_BACKEND}'"
65
+ )
66
+ install_url = INSTALL_URLS.get(effective_backend, "https://keras.io/getting_started/")
67
+ raise ImportError(
68
+ f"{backend_status}, but this backend is not installed. "
69
+ f"Installed backends: {', '.join(installed_backends)}. "
70
+ f"Please either install '{effective_backend}' (see: {install_url}) "
71
+ f"or set KERAS_BACKEND to one of the installed backends "
72
+ f"(e.g., export KERAS_BACKEND={installed_backends[0]}). "
73
+ f"For more information, see: {DOCS_URL}"
74
+ )
40
75
 
41
76
  _check_backend_installed()
42
77
 
zea/agent/__init__.py CHANGED
@@ -7,21 +7,21 @@ For a practical example, see :doc:`../notebooks/agent/agent_example`.
7
7
  Example usage
8
8
  ^^^^^^^^^^^^^
9
9
 
10
- .. code-block:: python
10
+ .. doctest::
11
11
 
12
- import zea
13
- import numpy as np
12
+ >>> import zea
13
+ >>> import numpy as np
14
14
 
15
- agent = zea.agent.selection.GreedyEntropy(
16
- n_actions=7,
17
- n_possible_actions=112,
18
- img_width=112,
19
- img_height=112,
20
- )
15
+ >>> agent = zea.agent.selection.GreedyEntropy(
16
+ ... n_actions=7,
17
+ ... n_possible_actions=112,
18
+ ... img_width=112,
19
+ ... img_height=112,
20
+ ... )
21
21
 
22
- # (batch, samples, height, width)
23
- particles = np.random.rand(1, 10, 112, 112)
24
- lines, mask = agent.sample(particles)
22
+ >>> # (batch, samples, height, width)
23
+ >>> particles = np.random.rand(1, 10, 112, 112)
24
+ >>> lines, mask = agent.sample(particles) # doctest: +SKIP
25
25
  """
26
26
 
27
27
  from . import masks, selection
zea/agent/masks.py CHANGED
@@ -91,7 +91,8 @@ def random_uniform_lines(
91
91
 
92
92
  def _assert_equal_spacing(n_actions, n_possible_actions):
93
93
  assert n_possible_actions % n_actions == 0, (
94
- "Number of actions must divide evenly into possible actions to use equispaced sampling."
94
+ "Number of actions must divide evenly into possible actions to use equispaced sampling. "
95
+ "If you do not care about equal spacing, set `assert_equal_spacing=False`."
95
96
  )
96
97
 
97
98
 
@@ -12,7 +12,8 @@ from keras.src.trainers.data_adapters import TFDatasetAdapter
12
12
 
13
13
  from zea.data.dataloader import H5Generator
14
14
  from zea.data.layers import Resizer
15
- from zea.utils import find_methods_with_return_type, translate
15
+ from zea.internal.utils import find_methods_with_return_type
16
+ from zea.tensor_ops import translate
16
17
 
17
18
  METHODS_THAT_RETURN_DATASET = find_methods_with_return_type(tf.data.Dataset, "DatasetV2")
18
19
 
@@ -5,7 +5,7 @@ import numpy as np
5
5
  from keras import ops
6
6
 
7
7
  from zea.beamform.lens_correction import calculate_lens_corrected_delays
8
- from zea.tensor_ops import safe_vectorize
8
+ from zea.tensor_ops import vmap
9
9
 
10
10
 
11
11
  def fnum_window_fn_rect(normalized_angle):
@@ -57,10 +57,11 @@ def tof_correction(
57
57
  initial_times,
58
58
  sampling_frequency,
59
59
  demodulation_frequency,
60
- fnum,
61
- angles,
60
+ f_number,
61
+ polar_angles,
62
62
  focus_distances,
63
- apply_phase_rotation=False,
63
+ t_peak,
64
+ tx_waveform_indices,
64
65
  apply_lens_correction=False,
65
66
  lens_thickness=1e-3,
66
67
  lens_sound_speed=1000,
@@ -73,21 +74,19 @@ def tof_correction(
73
74
  flatgrid (ops.Tensor): Pixel locations x, y, z of shape `(n_pix, 3)`
74
75
  t0_delays (ops.Tensor): Times at which the elements fire shifted such
75
76
  that the first element fires at t=0 of shape `(n_tx, n_el)`
76
- tx_apodizations (ops.Tensor): Transmit apodizations of shape
77
- `(n_tx, n_el)`
77
+ tx_apodizations (ops.Tensor): Transmit apodizations of shape `(n_tx, n_el)`
78
78
  sound_speed (float): Speed-of-sound.
79
- probe_geometry (ops.Tensor): Element positions x, y, z of shape
80
- (num_samples, 3)
81
- initial_times (ops.Tensor): Time-ofsampling_frequencyet per transmission of shape
82
- `(n_tx,)`.
79
+ probe_geometry (ops.Tensor): Element positions x, y, z of shape (n_el, 3)
80
+ initial_times (Tensor): The probe transmit time offsets of shape `(n_tx,)`.
83
81
  sampling_frequency (float): Sampling frequency.
84
82
  demodulation_frequency (float): Demodulation frequency.
85
- fnum (int, optional): Focus number. Defaults to 1.
86
- angles (ops.Tensor): The angles of the plane waves in radians of shape
87
- `(n_tx,)`
83
+ f_number (float): Focus number (ratio of focal depth to aperture size).
84
+ polar_angles (ops.Tensor): The angles of the waves in radians of shape `(n_tx,)`
88
85
  focus_distances (ops.Tensor): The focus distance of shape `(n_tx,)`
89
- apply_phase_rotation (bool, optional): Whether to apply phase rotation to
90
- time-of-flights. Defaults to False.
86
+ t_peak (ops.Tensor): Time of the peak of the pulse in seconds.
87
+ Shape `(n_waveforms,)`.
88
+ tx_waveform_indices (ops.Tensor): The indices of the waveform used for each
89
+ transmit of shape `(n_tx,)`.
91
90
  apply_lens_correction (bool, optional): Whether to apply lens correction to
92
91
  time-of-flights. This makes it slower, but more accurate in the near-field.
93
92
  Defaults to False.
@@ -101,14 +100,12 @@ def tof_correction(
101
100
 
102
101
  Returns:
103
102
  (ops.Tensor): time-of-flight corrected data
104
- with shape: `(n_tx, n_pix, n_el, num_rf_iq_channels)`.
103
+ with shape: `(n_tx, n_pix, n_el, n_ch)`.
105
104
  """
106
105
 
107
106
  assert len(data.shape) == 4, (
108
107
  "The input data should have 4 dimensions, "
109
- f"namely num_transmits, num_elements, num_samples, "
110
- f"num_rf_iq_channels, got {len(data.shape)} dimensions: ."
111
- f"{data.shape}"
108
+ f"namely n_tx, n_ax, n_el, n_ch, got {len(data.shape)} dimensions: {data.shape}"
112
109
  )
113
110
 
114
111
  n_tx, n_ax, n_el, _ = ops.shape(data)
@@ -135,19 +132,33 @@ def tof_correction(
135
132
  n_tx,
136
133
  n_el,
137
134
  focus_distances,
138
- angles,
135
+ polar_angles,
136
+ t_peak=t_peak,
137
+ tx_waveform_indices=tx_waveform_indices,
139
138
  lens_thickness=lens_thickness,
140
139
  lens_sound_speed=lens_sound_speed,
141
140
  )
142
141
 
143
142
  n_pix = ops.shape(flatgrid)[0]
144
143
  mask = ops.cond(
145
- fnum == 0,
144
+ f_number == 0,
146
145
  lambda: ops.ones((n_pix, n_el, 1)),
147
- lambda: fnumber_mask(flatgrid, probe_geometry, fnum, fnum_window_fn=fnum_window_fn),
146
+ lambda: fnumber_mask(flatgrid, probe_geometry, f_number, fnum_window_fn=fnum_window_fn),
148
147
  )
149
148
 
150
149
  def _apply_delays(data_tx, txdel):
150
+ """Applies the delays to TOF correct a single transmit.
151
+
152
+ Args:
153
+ data_tx (ops.Tensor): The RF/IQ data for a single transmit of shape
154
+ `(n_ax, n_el, n_ch)`.
155
+ txdel (ops.Tensor): The transmit delays for a single transmit in samples
156
+ (not in seconds) of shape `(n_pix, 1)`.
157
+
158
+ Returns:
159
+ ops.Tensor: The time-of-flight corrected data of shape
160
+ `(n_pix, n_el, n_ch)`.
161
+ """
151
162
  # data_tx is of shape (num_elements, num_samples, 1 or 2)
152
163
 
153
164
  # Take receive delays and add the transmit delays for this transmit
@@ -164,22 +175,22 @@ def tof_correction(
164
175
  # Apply the mask
165
176
  tof_tx = tof_tx * mask
166
177
 
167
- # Phase correction
178
+ # Apply phase rotation if using IQ data
179
+ # This is needed because interpolating the IQ data without phase rotation
180
+ # is not equivalent to interpolating the RF data and then IQ demodulating
181
+ # See the docstring from complex_rotate for more details
182
+ apply_phase_rotation = data_tx.shape[-1] == 2
168
183
  if apply_phase_rotation:
169
- tshift = delays[:, :] / sampling_frequency
170
- tdemod = flatgrid[:, None, 2] * 2 / sound_speed
171
- theta = 2 * np.pi * demodulation_frequency * (tshift - tdemod)
172
- tof_tx = _complex_rotate(tof_tx, theta)
184
+ total_delay_seconds = delays[:, :] / sampling_frequency
185
+ theta = 2 * np.pi * demodulation_frequency * total_delay_seconds
186
+ tof_tx = complex_rotate(tof_tx, theta)
173
187
  return tof_tx
174
188
 
175
189
  # Reshape to (n_tx, n_pix, 1)
176
190
  txdel = ops.moveaxis(txdel, 1, 0)
177
191
  txdel = txdel[..., None]
178
192
 
179
- return safe_vectorize(
180
- _apply_delays,
181
- signature="(n_samples,n_el,n_ch),(n_pix,1)->(n_pix,n_el,n_ch)",
182
- )(data, txdel)
193
+ return vmap(_apply_delays)(data, txdel)
183
194
 
184
195
 
185
196
  def calculate_delays(
@@ -194,6 +205,8 @@ def calculate_delays(
194
205
  n_el,
195
206
  focus_distances,
196
207
  polar_angles,
208
+ t_peak,
209
+ tx_waveform_indices,
197
210
  **kwargs,
198
211
  ):
199
212
  """Calculates the delays in samples to every pixel in the grid.
@@ -225,12 +238,18 @@ def calculate_delays(
225
238
  assume plane wave transmission.
226
239
  polar_angles (Tensor): The polar angles of the plane waves in radians
227
240
  of shape `(n_tx,)`.
241
+ t_peak (Tensor): Time of the peak of the pulse in seconds of shape
242
+ `(n_waveforms,)`.
243
+ tx_waveform_indices (Tensor): The indices of the waveform used for each
244
+ transmit of shape `(n_tx,)`.
245
+
228
246
 
229
247
  Returns:
230
- transmit_delays (Tensor): The tensor of transmit delays to every pixel,
231
- shape `(n_pix, n_tx)`.
248
+ transmit_delays (Tensor): The tensor of transmit delays to every pixel
249
+ in samples (not in seconds), of shape `(n_pix, n_tx)`.
232
250
  receive_delays (Tensor): The tensor of receive delays from every pixel
233
- back to the transducer element, shape `(n_pix, n_el)`.
251
+ back to the transducer element in samples (not in seconds), of shape
252
+ `(n_pix, n_el)`.
234
253
  """
235
254
 
236
255
  def _tx_distances(polar_angles, t0_delays, tx_apodizations, focus_distances):
@@ -244,10 +263,7 @@ def calculate_delays(
244
263
  sound_speed,
245
264
  )
246
265
 
247
- tx_distances = safe_vectorize(
248
- _tx_distances,
249
- signature="(),(n_el),(n_el),()->(n_pix)",
250
- )(polar_angles, t0_delays, tx_apodizations, focus_distances)
266
+ tx_distances = vmap(_tx_distances)(polar_angles, t0_delays, tx_apodizations, focus_distances)
251
267
  tx_distances = ops.transpose(tx_distances, (1, 0))
252
268
  # tx_distances shape is now (n_pix, n_tx)
253
269
 
@@ -255,14 +271,18 @@ def calculate_delays(
255
271
  def _rx_distances(probe_geometry):
256
272
  return distance_Rx(grid, probe_geometry)
257
273
 
258
- rx_distances = safe_vectorize(_rx_distances, signature="(3)->(n_pix)")(probe_geometry)
274
+ rx_distances = vmap(_rx_distances)(probe_geometry)
259
275
  rx_distances = ops.transpose(rx_distances, (1, 0))
260
276
  # rx_distances shape is now (n_pix, n_el)
261
277
 
262
278
  # Compute the delays [in samples] from the distances
263
279
  # The units here are ([m]/[m/s]-[s])*[1/s] resulting in a unitless quantity
264
280
  # TODO: Add pulse width to transmit delays
265
- tx_delays = (tx_distances / sound_speed - initial_times[None]) * sampling_frequency
281
+ tx_delays = (
282
+ tx_distances / sound_speed
283
+ - initial_times[None]
284
+ + ops.take(t_peak, tx_waveform_indices)[None]
285
+ ) * sampling_frequency
266
286
  rx_delays = (rx_distances / sound_speed) * sampling_frequency
267
287
 
268
288
  return tx_delays, rx_delays
@@ -288,11 +308,11 @@ def apply_delays(data, delays, clip_min: int = -1, clip_max: int = -1):
288
308
 
289
309
  Returns:
290
310
  ops.Tensor: The samples received by each transducer element corresponding to the
291
- reflections of each pixel in the image of shape `(n_el, n_pix, n_ch)`.
311
+ reflections of each pixel in the image of shape `(n_pix, n_el, n_ch)`.
292
312
  """
293
313
 
294
314
  # Add a dummy channel dimension to the delays tensor to ensure it has the
295
- # same number of dimensions as the data. The new shape is (1, n_el, n_pix)
315
+ # same number of dimensions as the data. The new shape is (n_pix, n_el, 1)
296
316
  delays = delays[..., None]
297
317
 
298
318
  # Get the integer values above and below the exact delay values
@@ -318,7 +338,7 @@ def apply_delays(data, delays, clip_min: int = -1, clip_max: int = -1):
318
338
 
319
339
  # Gather pixel values
320
340
  # Here we extract for each transducer element the sample containing the
321
- # reflection from each pixel. These are of shape `(n_el, n_pix, n_ch)`.
341
+ # reflection from each pixel. These are of shape `(n_pix, n_el, n_ch)`.
322
342
  data0 = ops.take_along_axis(data, d0, 0)
323
343
  data1 = ops.take_along_axis(data, d1, 0)
324
344
 
@@ -332,7 +352,7 @@ def apply_delays(data, delays, clip_min: int = -1, clip_max: int = -1):
332
352
  return reflection_samples
333
353
 
334
354
 
335
- def _complex_rotate(iq, theta):
355
+ def complex_rotate(iq, theta):
336
356
  """Performs a simple phase rotation of I and Q component.
337
357
 
338
358
  Args:
@@ -341,11 +361,41 @@ def _complex_rotate(iq, theta):
341
361
 
342
362
  Returns:
343
363
  Tensor: The rotated tensor of shape `(..., 2)`.
364
+
365
+ .. dropdown:: Explanation
366
+
367
+ The IQ data is related to the RF data as follows:
368
+
369
+ .. math::
370
+
371
+ x(t) &= I(t)\\cos(\\omega_c t) + Q(t)\\cos(\\omega_c t + \\pi/2)\\\\
372
+ &= I(t)\\cos(\\omega_c t) - Q(t)\\sin(\\omega_c t)
373
+
374
+
375
+ If we want to delay the RF data `x(t)` by `Δt` we can substitute in
376
+ :math:`t=t+\\Delta t`. We also define :math:`I'(t) = I(t + \\Delta t)`,
377
+ :math:`Q'(t) = Q(t + \\Delta t)`, and :math:`\\theta=\\omega_c\\Delta t`.
378
+ This gives us:
379
+
380
+ .. math::
381
+
382
+ x(t + \\Delta t) &= I'(t) \\cos(\\omega_c (t + \\Delta t))
383
+ - Q'(t) \\sin(\\omega_c (t + \\Delta t))\\\\
384
+ &= \\overbrace{(I'(t)\\cos(\\theta)
385
+ - Q'(t)\\sin(\\theta) )}^{I_\\Delta(t)} \\cos(\\omega_c t)\\\\
386
+ &- \\overbrace{(Q'(t)\\cos(\\theta)
387
+ + I'(t)\\sin(\\theta))}^{Q_\\Delta(t)} \\sin(\\omega_c t)
388
+
389
+ This means that to correctly interpolate the IQ data to the new components
390
+ :math:`I_\\Delta(t)` and :math:`Q_\\Delta(t)`, it is not sufficient to just
391
+ interpolate the I- and Q-channels independently. We also need to rotate the
392
+ I- and Q-channels by the angle :math:`\\theta`. This function performs this
393
+ rotation.
344
394
  """
345
- # assert iq.shape[-1] == 2, (
346
- # "The last dimension of the input tensor should be 2, "
347
- # f"got {iq.shape[-1]} dimensions and shape {iq.shape}."
348
- # )
395
+ assert iq.shape[-1] == 2, (
396
+ "The last dimension of the input tensor should be 2, "
397
+ f"got {iq.shape[-1]} dimensions and shape {iq.shape}."
398
+ )
349
399
  # Select i and q channels
350
400
  i = iq[..., 0]
351
401
  q = iq[..., 1]
@@ -485,8 +535,8 @@ def fnumber_mask(flatgrid, probe_geometry, f_number, fnum_window_fn):
485
535
 
486
536
  alpha = ops.arccos(grid_relative_to_probe_z)
487
537
 
488
- # The f-number is fnum = z/aperture = 1/(2 * tan(alpha))
489
- # Rearranging gives us alpha = arctan(1/(2 * fnum))
538
+ # The f-number is f_number = z/aperture = 1/(2 * tan(alpha))
539
+ # Rearranging gives us alpha = arctan(1/(2 * f_number))
490
540
  # We can use this to compute the maximum angle alpha that is allowed
491
541
  max_alpha = ops.arctan(1 / (2 * f_number + keras.backend.epsilon()))
492
542
 
@@ -15,6 +15,8 @@ def calculate_lens_corrected_delays(
15
15
  n_el,
16
16
  focus_distances,
17
17
  polar_angles,
18
+ t_peak,
19
+ tx_waveform_indices,
18
20
  lens_sound_speed=1000,
19
21
  lens_thickness=1e-3,
20
22
  n_iter=2,
@@ -33,6 +35,9 @@ def calculate_lens_corrected_delays(
33
35
  n_el (int): The number of elements.
34
36
  focus_distances (ndarray): The focus distances of shape (n_tx,).
35
37
  polar_angles (ndarray): The polar angles of shape (n_tx,).
38
+ t_peak (ndarray): The time of the peak of the pulse in seconds of shape (n_waveforms,).
39
+ tx_waveform_indices (ndarray): The indices of the waveforms used for each transmit
40
+ of shape (n_tx,).
36
41
  lens_sound_speed (float): The speed of sound in the lens in m/s.
37
42
  lens_thickness (float): The thickness of the lens in meters.
38
43
  n_iter (int): The number of iterations to run the Newton-Raphson method.
@@ -59,9 +64,11 @@ def calculate_lens_corrected_delays(
59
64
  # Add a large offset to elements that are not used in the transmit to
60
65
  # diqualify them from being the closest element
61
66
  apod_offset = ops.where(tx_apodizations[tx] == 0, 10.0, 0)
62
- tx_min = ops.min(rx_delays + t0_delays[tx] + apod_offset, axis=-1) + initial_times[tx]
67
+ tx_min = ops.min(rx_delays + t0_delays[tx] + apod_offset, axis=-1) - initial_times[tx]
63
68
  tx_delays.append(tx_min)
64
- tx_delays = ops.stack(tx_delays, axis=-1)
69
+ tx_delays = ops.stack(tx_delays, axis=-1) + ops.take(t_peak, tx_waveform_indices)[None]
70
+ tx_delays = ops.nan_to_num(tx_delays, nan=0.0, posinf=0.0, neginf=0.0)
71
+ rx_delays = ops.nan_to_num(rx_delays, nan=0.0, posinf=0.0, neginf=0.0)
65
72
 
66
73
  tx_delays *= sampling_frequency
67
74
  rx_delays *= sampling_frequency
zea/beamform/pfield.py CHANGED
@@ -43,7 +43,7 @@ def compute_pfield(
43
43
  grid,
44
44
  t0_delays,
45
45
  frequency_step=4,
46
- db_thresh=-1,
46
+ db_thresh=-1.0,
47
47
  downsample=10,
48
48
  downmix=4,
49
49
  alpha=1,
@@ -65,7 +65,7 @@ def compute_pfield(
65
65
  t0_delays (array): Transmit delays for each transmit event.
66
66
  frequency_step (int, optional): Frequency step. Default is 4.
67
67
  Higher is faster but less accurate.
68
- db_thresh (int, optional): dB threshold. Default is -1.
68
+ db_thresh (float, optional): dB threshold. Default is -1.0.
69
69
  Higher is faster but less accurate.
70
70
  downsample (int, optional): Downsample the grid for faster computation.
71
71
  Default is 10. Higher is faster but less accurate.
@@ -85,6 +85,13 @@ def compute_pfield(
85
85
  # medium params
86
86
  alpha_db = 0 # currently we ignore attenuation in the compounding
87
87
 
88
+ # cast to float32
89
+ sound_speed = ops.cast(sound_speed, "float32")
90
+ center_frequency = ops.cast(center_frequency, "float32")
91
+ bandwidth_percent = ops.cast(bandwidth_percent, "float32")
92
+ alpha_db = ops.cast(alpha_db, "float32")
93
+ db_thresh = ops.cast(db_thresh, "float32")
94
+
88
95
  # probe params
89
96
  center_frequency = center_frequency / downmix # downmixing the frequency
90
97
 
zea/config.py CHANGED
@@ -16,23 +16,30 @@ Features
16
16
  Example Usage
17
17
  ^^^^^^^^^^^^^
18
18
 
19
- .. code-block:: python
19
+ .. doctest::
20
20
 
21
- from zea import Config
21
+ >>> from zea import Config
22
22
 
23
- # Load from YAML
24
- config = Config.from_yaml("config.yaml")
25
- # Load from HuggingFace Hub
26
- config = Config.from_hf("zea/diffusion-echonet-dynamic", "train_config.yaml")
23
+ >>> # Load from YAML
24
+ >>> config = Config.from_yaml("../configs/config_echonet.yaml")
25
+ >>> # Load from HuggingFace Hub
26
+ >>> config = Config.from_hf("zeahub/configs", "config_picmus_rf.yaml", repo_type="dataset")
27
27
 
28
- # Access attributes with dot notation
29
- print(config.model.name)
28
+ >>> # Access attributes with dot notation
29
+ >>> print(config.data.dtype)
30
+ raw_data
30
31
 
31
- # Update recursively
32
- config.update_recursive({"model": {"name": "new_model"}})
32
+ >>> # Update recursively
33
+ >>> config.update_recursive({"data": {"dtype": "raw_data"}})
33
34
 
34
- # Save to YAML
35
- config.save_to_yaml("new_config.yaml")
35
+ >>> # Save to YAML
36
+ >>> config.save_to_yaml("new_config.yaml")
37
+
38
+ .. testcleanup::
39
+
40
+ import os
41
+
42
+ os.remove("new_config.yaml")
36
43
 
37
44
  """
38
45
 
@@ -166,14 +173,14 @@ class Config(dict):
166
173
  each element is updated recursively if it is a Config, otherwise replaced.
167
174
 
168
175
  Example:
176
+ .. doctest::
169
177
 
170
- .. code-block:: python
171
-
172
- config = Config({"a": 1, "b": {"c": 2, "d": 3}})
173
- config.update_recursive({"a": 4, "b": {"c": 5}})
174
- print(config)
175
- # <Config {'a': 4, 'b': {'c': 5, 'd': 3}}>
176
- # Notice how "d" is kept and only "c" is updated.
178
+ >>> from zea import Config
179
+ >>> config = Config({"a": 1, "b": {"c": 2, "d": 3}})
180
+ >>> config.update_recursive({"a": 4, "b": {"c": 5}})
181
+ >>> # Notice how "d" is kept and only "c" is updated.
182
+ >>> print(config)
183
+ <Config {'a': 4, 'b': {'c': 5, 'd': 3}}>
177
184
 
178
185
  Args:
179
186
  dictionary (dict, optional): Dictionary to update from.
@@ -453,12 +460,6 @@ class Config(dict):
453
460
  def from_hf(cls, repo_id, path, **kwargs):
454
461
  """Load config object from huggingface hub.
455
462
 
456
- Example:
457
-
458
- .. code-block:: python
459
-
460
- config = Config.from_hf("zeahub/configs", "config_camus.yaml", repo_type="dataset")
461
-
462
463
  Args:
463
464
  repo_id (str): huggingface hub repo id.
464
465
  For example: "zeahub/configs"
@@ -471,6 +472,14 @@ class Config(dict):
471
472
 
472
473
  Returns:
473
474
  Config: config object.
475
+
476
+ Example:
477
+ .. doctest::
478
+
479
+ >>> from zea import Config
480
+ >>> config = Config.from_hf(
481
+ ... "zeahub/configs", "config_camus.yaml", repo_type="dataset"
482
+ ... )
474
483
  """
475
484
  local_path = hf_hub_download(repo_id, path, **kwargs)
476
485
  return _load_config_from_yaml(local_path, config_class=cls)
zea/data/__init__.py CHANGED
@@ -15,22 +15,28 @@ See the data notebook for a more detailed example: :doc:`../notebooks/data/zea_d
15
15
  Examples usage
16
16
  ^^^^^^^^^^^^^^
17
17
 
18
- .. code-block:: python
19
-
20
- from zea import File, Dataset
21
-
22
- # Open a single zea data file
23
- with File("path/to/file.hdf5", mode="r") as file:
24
- file.summary()
25
- data = file.load_data("raw_data", indices=[0])
26
- scan = file.scan()
27
- probe = file.probe()
28
-
29
- # Work with a dataset (folder or list of files)
30
- dataset = Dataset("path/to/folder", key="raw_data")
31
- for file in dataset:
32
- print(file)
33
- dataset.close()
18
+ .. doctest::
19
+
20
+ >>> from zea import File, Dataset
21
+
22
+ >>> # Work with a single file
23
+ >>> path_to_file = (
24
+ ... "hf://zeahub/picmus/database/experiments/contrast_speckle/"
25
+ ... "contrast_speckle_expe_dataset_iq/contrast_speckle_expe_dataset_iq.hdf5"
26
+ ... )
27
+
28
+ >>> with File(path_to_file, mode="r") as file:
29
+ ... # file.summary()
30
+ ... data = file.load_data("raw_data", indices=[0])
31
+ ... scan = file.scan()
32
+ ... probe = file.probe()
33
+
34
+ >>> # Work with a dataset (folder or list of files)
35
+ >>> dataset = Dataset("hf://zeahub/picmus", key="raw_data")
36
+ >>> files = []
37
+ >>> for file in dataset:
38
+ ... files.append(file) # process each file as needed
39
+ >>> dataset.close()
34
40
 
35
41
  Subpackage layout
36
42
  -----------------
zea/data/convert/camus.py CHANGED
@@ -20,7 +20,8 @@ from tqdm import tqdm
20
20
  # from zea.display import transform_sc_image_to_polar
21
21
  from zea import log
22
22
  from zea.data.data_format import generate_zea_dataset
23
- from zea.utils import find_first_nonzero_index, translate
23
+ from zea.internal.utils import find_first_nonzero_index
24
+ from zea.tensor_ops import translate
24
25
 
25
26
 
26
27
  def transform_sc_image_to_polar(image_sc, output_size=None, fit_outline=True):