zea 0.0.2__py3-none-any.whl → 0.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
zea/__init__.py CHANGED
@@ -7,7 +7,7 @@ from . import log
7
7
 
8
8
  # dynamically add __version__ attribute (see pyproject.toml)
9
9
  # __version__ = __import__("importlib.metadata").metadata.version(__package__)
10
- __version__ = "0.0.2"
10
+ __version__ = "0.0.4"
11
11
 
12
12
 
13
13
  def setup():
zea/agent/selection.py CHANGED
@@ -92,6 +92,7 @@ class GreedyEntropy(LinesActionModel):
92
92
  mean: float = 0,
93
93
  std_dev: float = 1,
94
94
  num_lines_to_update: int = 5,
95
+ entropy_sigma: float = 1.0,
95
96
  ):
96
97
  """Initialize the GreedyEntropy action selection model.
97
98
 
@@ -104,6 +105,8 @@ class GreedyEntropy(LinesActionModel):
104
105
  std_dev (float, optional): The standard deviation of the RBF. Defaults to 1.
105
106
  num_lines_to_update (int, optional): The number of lines around the selected line
106
107
  to update. Must be odd.
108
+ entropy_sigma (float, optional): The standard deviation of the Gaussian
109
+ Mixture components used to approximate the posterior.
107
110
  """
108
111
  super().__init__(n_actions, n_possible_actions, img_width, img_height)
109
112
 
@@ -124,7 +127,7 @@ class GreedyEntropy(LinesActionModel):
124
127
  self.num_lines_to_update,
125
128
  )
126
129
  self.upside_down_gaussian = upside_down_gaussian(points_to_evaluate)
127
- self.entropy_sigma = 1
130
+ self.entropy_sigma = entropy_sigma
128
131
 
129
132
  @staticmethod
130
133
  def compute_pairwise_pixel_gaussian_error(
@@ -152,51 +155,58 @@ class GreedyEntropy(LinesActionModel):
152
155
  # TODO: I think we only need to compute the lower triangular
153
156
  # of this matrix, since it's symmetric
154
157
  squared_l2_error_matrices = (particles[:, :, None, ...] - particles[:, None, :, ...]) ** 2
155
- gaussian_error_per_pixel_i_j = ops.exp((squared_l2_error_matrices) / (2 * entropy_sigma**2))
158
+ gaussian_error_per_pixel_i_j = ops.exp(
159
+ -(squared_l2_error_matrices) / (2 * entropy_sigma**2)
160
+ )
156
161
  # Vertically stack all columns corresponding with the same line
157
162
  # This way we can just sum across the height axis and get the entropy
158
163
  # for each pixel in a given line
159
164
  batch_size, n_particles, _, height, _ = gaussian_error_per_pixel_i_j.shape
160
- gaussian_error_per_pixel_stacked = ops.reshape(
161
- gaussian_error_per_pixel_i_j,
162
- [
163
- batch_size,
164
- n_particles,
165
- n_particles,
166
- height * stack_n_cols,
167
- n_possible_actions,
168
- ],
165
+ gaussian_error_per_pixel_stacked = ops.transpose(
166
+ ops.reshape(
167
+ ops.transpose(gaussian_error_per_pixel_i_j, (0, 1, 2, 4, 3)),
168
+ [
169
+ batch_size,
170
+ n_particles,
171
+ n_particles,
172
+ n_possible_actions,
173
+ height * stack_n_cols,
174
+ ],
175
+ ),
176
+ (0, 1, 2, 4, 3),
169
177
  )
170
178
  # [n_particles, n_particles, batch, height, width]
171
179
  return gaussian_error_per_pixel_stacked
172
180
 
173
- def compute_gmm_entropy_per_line(self, particles):
174
- """Compute the entropy for each line using a Gaussian Mixture Model.
175
-
181
+ def compute_pixelwise_entropy(self, particles):
182
+ """
176
183
  This function computes the entropy for each line using a Gaussian Mixture Model
177
184
  approximation of the posterior distribution.
178
- For more details see Section 4 here: https://arxiv.org/abs/2406.14388
185
+ For more details see Section VI. B here: https://arxiv.org/pdf/2410.13310
179
186
 
180
187
  Args:
181
188
  particles (Tensor): Particles of shape (batch_size, n_particles, height, width)
182
189
 
183
190
  Returns:
184
- Tensor: batch of entropies per line, of shape (batch, n_possible_actions)
191
+ Tensor: batch of entropies per pixel, of shape (batch, height, width)
185
192
  """
186
- gaussian_error_per_pixel_stacked = GreedyEntropy.compute_pairwise_pixel_gaussian_error(
193
+ n_particles = ops.shape(particles)[1]
194
+ gaussian_error_per_pixel_stacked = self.compute_pairwise_pixel_gaussian_error(
187
195
  particles,
188
196
  self.stack_n_cols,
189
197
  self.n_possible_actions,
190
198
  self.entropy_sigma,
191
199
  )
192
- gaussian_error_per_line = ops.sum(gaussian_error_per_pixel_stacked, axis=3)
193
200
  # sum out first dimension of (n_particles x n_particles) error matrix
194
- # [n_particles, batch, n_possible_actions]
195
- entropy_per_line_i = ops.sum(gaussian_error_per_line, axis=1)
201
+ # [n_particles, batch, height, width]
202
+ pixelwise_entropy_sum_j = ops.sum(
203
+ (1 / n_particles) * gaussian_error_per_pixel_stacked, axis=1
204
+ )
205
+ log_pixelwise_entropy_sum_j = ops.log(pixelwise_entropy_sum_j)
196
206
  # sum out second dimension of (n_particles x n_particles) error matrix
197
- # [batch, n_possible_actions]
198
- entropy_per_line = ops.sum(entropy_per_line_i, axis=1)
199
- return entropy_per_line
207
+ # [batch, height, width]
208
+ pixelwise_entropy = -ops.sum((1 / n_particles) * log_pixelwise_entropy_sum_j, axis=1)
209
+ return pixelwise_entropy
200
210
 
201
211
  def select_line_and_reweight_entropy(self, entropy_per_line):
202
212
  """Select the line with maximum entropy and reweight the entropies.
@@ -254,17 +264,19 @@ class GreedyEntropy(LinesActionModel):
254
264
  particles (Tensor): Particles of shape (batch_size, n_particles, height, width)
255
265
 
256
266
  Returns:
257
- Tuple[Tensor, Tensor]:
267
+ Tuple[Tensor, Tensor]:
258
268
  - Newly selected lines as k-hot vectors, shaped (batch_size, n_possible_actions)
259
- - Masks of shape (batch_size, img_height, img_width)
269
+ - Masks of shape (batch_size, img_height, img_width)
260
270
  """
261
- entropy_per_line = self.compute_gmm_entropy_per_line(particles)
271
+
272
+ pixelwise_entropy = self.compute_pixelwise_entropy(particles)
273
+ linewise_entropy = ops.sum(pixelwise_entropy, axis=1)
262
274
 
263
275
  # Greedily select best line, reweight entropies, and repeat
264
276
  all_selected_lines = []
265
277
  for _ in range(self.n_actions):
266
- max_entropy_line, entropy_per_line = ops.vectorized_map(
267
- self.select_line_and_reweight_entropy, entropy_per_line
278
+ max_entropy_line, linewise_entropy = ops.vectorized_map(
279
+ self.select_line_and_reweight_entropy, linewise_entropy
268
280
  )
269
281
  all_selected_lines.append(max_entropy_line)
270
282
 
@@ -428,10 +440,15 @@ class CovarianceSamplingLines(LinesActionModel):
428
440
  generation. Defaults to None.
429
441
 
430
442
  Returns:
431
- Tensor: The mask of shape (batch_size, img_size, img_size)
443
+ Tuple[Tensor, Tensor]:
444
+ - Newly selected lines as k-hot vectors, shaped (batch_size, n_possible_actions)
445
+ - Masks of shape (batch_size, img_height, img_width)
432
446
  """
433
447
  batch_size, n_particles, rows, _ = ops.shape(particles)
434
448
 
449
+ # [batch_size, rows, cols, n_particles]
450
+ particles = ops.transpose(particles, (0, 2, 3, 1))
451
+
435
452
  # [batch_size, rows * stack_n_cols, n_possible_actions, n_particles]
436
453
  shape = [
437
454
  batch_size,
@@ -441,7 +458,7 @@ class CovarianceSamplingLines(LinesActionModel):
441
458
  ]
442
459
  particles = ops.reshape(particles, shape)
443
460
 
444
- # [batch_size, rows, n_possible_actions, n_possible_actions]
461
+ # [batch_size, rows * stack_n_cols, n_possible_actions, n_possible_actions]
445
462
  cov_matrix = tensor_ops.batch_cov(particles)
446
463
 
447
464
  # Sum over the row dimension [batch_size, n_possible_actions, n_possible_actions]
zea/backend/__init__.py CHANGED
@@ -1,4 +1,4 @@
1
- """Backend subpackage for ``zea``.
1
+ """Backend-specific utilities.
2
2
 
3
3
  This subpackage provides backend-specific utilities for the ``zea`` library. Most backend logic is handled by Keras 3, but a few features require custom wrappers to ensure compatibility and performance across JAX, TensorFlow, and PyTorch.
4
4
 
@@ -9,7 +9,7 @@ Key Features
9
9
  ------------
10
10
 
11
11
  - **JIT Compilation** (:func:`zea.backend.jit`):
12
- Provides a unified interface for just-in-time (JIT) compilation of functions, dispatching to the appropriate backend (JAX or TensorFlow) as needed. This enables accelerated execution of computationally intensive routines.
12
+ Provides a unified interface for just-in-time (JIT) compilation of functions, dispatching to the appropriate backend (JAX or TensorFlow) as needed. This enables accelerated execution of computationally intensive routines. Note that jit compilation is not yet supported when using the `torch` backend.
13
13
 
14
14
  - **Automatic Differentiation** (:class:`zea.backend.AutoGrad`):
15
15
  Offers a backend-agnostic wrapper for automatic differentiation, allowing gradient computation regardless of the underlying ML library.
@@ -108,7 +108,9 @@ def _jit_compile(func, jax=True, tensorflow=True, **kwargs):
108
108
  return func
109
109
  else:
110
110
  log.warning(
111
- f"Unsupported backend: {backend}. Supported backends are 'tensorflow' and 'jax'."
111
+ f"JIT compilation not currently supported for backend {backend}. "
112
+ "Supported backends are 'tensorflow' and 'jax'."
112
113
  )
114
+ log.warning("Initialize zea.Pipeline with jit_options=None to suppress this warning.")
113
115
  log.warning("Falling back to non-compiled mode.")
114
116
  return func
zea/beamform/__init__.py CHANGED
@@ -17,4 +17,4 @@ see the pipeline example notebook: :doc:`../notebooks/pipeline/zea_pipeline_exam
17
17
 
18
18
  """
19
19
 
20
- from . import beamformer, delays, lens_correction, pfield, pixelgrid
20
+ from . import beamformer, delays, lens_correction, pfield, phantoms, pixelgrid
@@ -1,5 +1,6 @@
1
1
  """Main beamforming functions for ultrasound imaging."""
2
2
 
3
+ import keras
3
4
  import numpy as np
4
5
  from keras import ops
5
6
 
@@ -58,7 +59,7 @@ def tof_correction(
58
59
  demodulation_frequency,
59
60
  fnum,
60
61
  angles,
61
- vfocus,
62
+ focus_distances,
62
63
  apply_phase_rotation=False,
63
64
  apply_lens_correction=False,
64
65
  lens_thickness=1e-3,
@@ -84,7 +85,7 @@ def tof_correction(
84
85
  fnum (int, optional): Focus number. Defaults to 1.
85
86
  angles (ops.Tensor): The angles of the plane waves in radians of shape
86
87
  `(n_tx,)`
87
- vfocus (ops.Tensor): The focus distance of shape `(n_tx,)`
88
+ focus_distances (ops.Tensor): The focus distance of shape `(n_tx,)`
88
89
  apply_phase_rotation (bool, optional): Whether to apply phase rotation to
89
90
  time-of-flights. Defaults to False.
90
91
  apply_lens_correction (bool, optional): Whether to apply lens correction to
@@ -133,7 +134,7 @@ def tof_correction(
133
134
  sound_speed,
134
135
  n_tx,
135
136
  n_el,
136
- vfocus,
137
+ focus_distances,
137
138
  angles,
138
139
  lens_thickness=lens_thickness,
139
140
  lens_sound_speed=lens_sound_speed,
@@ -487,7 +488,7 @@ def fnumber_mask(flatgrid, probe_geometry, f_number, fnum_window_fn):
487
488
  # The f-number is fnum = z/aperture = 1/(2 * tan(alpha))
488
489
  # Rearranging gives us alpha = arctan(1/(2 * fnum))
489
490
  # We can use this to compute the maximum angle alpha that is allowed
490
- max_alpha = ops.arctan(1 / (2 * f_number))
491
+ max_alpha = ops.arctan(1 / (2 * f_number + keras.backend.epsilon()))
491
492
 
492
493
  normalized_angle = alpha / max_alpha
493
494
  mask = fnum_window_fn(normalized_angle)
zea/beamform/pfield.py CHANGED
@@ -60,7 +60,8 @@ def compute_pfield(
60
60
  n_el (int): Number of elements in the probe.
61
61
  probe_geometry (array): Geometry of the probe elements.
62
62
  tx_apodizations (array): Transmit apodization values.
63
- grid (array): Grid points where the pressure field is computed of shape (Nz, Nx, 3).
63
+ grid (array): Grid points where the pressure field is computed
64
+ of shape (grid_size_z, grid_size_x, 3).
64
65
  t0_delays (array): Transmit delays for each transmit event.
65
66
  frequency_step (int, optional): Frequency step. Default is 4.
66
67
  Higher is faster but less accurate.
@@ -78,7 +79,8 @@ def compute_pfield(
78
79
  verbose (bool, optional): Whether to print progress.
79
80
 
80
81
  Returns:
81
- ops.array: The (normalized) pressure field (across tx events) of shape (n_tx, Nz, Nx).
82
+ ops.array: The (normalized) pressure field (across tx events)
83
+ of shape (n_tx, grid_size_z, grid_size_x).
82
84
  """
83
85
  # medium params
84
86
  alpha_db = 0 # currently we ignore attenuation in the compounding
@@ -293,7 +295,8 @@ def normalize_pressure_field(pfield, alpha: float = 1.0, percentile: float = 10.
293
295
  Normalize the input array of intensities by zeroing out values below a given percentile.
294
296
 
295
297
  Args:
296
- pfield (array): The unnormalized pressure field array of shape (n_tx, Nz, Nx).
298
+ pfield (array): The unnormalized pressure field array
299
+ of shape (n_tx, grid_size_z, grid_size_x).
297
300
  alpha (float, optional): Exponent to 'sharpen or smooth' the weighting.
298
301
  Higher values result in sharper weighting. Default is 1.0.
299
302
  percentile (int, optional): minimum percentile threshold to keep in the weighting.
zea/beamform/pixelgrid.py CHANGED
@@ -16,52 +16,52 @@ def check_for_aliasing(scan):
16
16
  depth = scan.zlims[1] - scan.zlims[0]
17
17
  wvln = scan.wavelength
18
18
 
19
- if width / scan.Nx > wvln / 2:
19
+ if width / scan.grid_size_x > wvln / 2:
20
20
  log.warning(
21
- f"width/Nx = {width / scan.Nx:.7f} < wavelength/2 = {wvln / 2}. "
22
- f"Consider either increasing scan.Nx to {int(np.ceil(width / (wvln / 2)))} or more, or "
23
- "increasing scan.pixels_per_wavelength to 2 or more."
21
+ f"width/grid_size_x = {width / scan.grid_size_x:.7f} < wavelength/2 = {wvln / 2}. "
22
+ f"Consider either increasing scan.grid_size_x to {int(np.ceil(width / (wvln / 2)))} "
23
+ "or more, or increasing scan.pixels_per_wavelength to 2 or more."
24
24
  )
25
- if depth / scan.Nz > wvln / 2:
25
+ if depth / scan.grid_size_z > wvln / 2:
26
26
  log.warning(
27
- f"depth/Nz = {depth / scan.Nz:.7f} < wavelength/2 = {wvln / 2:.7f}. "
28
- f"Consider either increasing scan.Nz to {int(np.ceil(depth / (wvln / 2)))} or more, or "
29
- "increasing scan.pixels_per_wavelength to 2 or more."
27
+ f"depth/grid_size_z = {depth / scan.grid_size_z:.7f} < wavelength/2 = {wvln / 2:.7f}. "
28
+ f"Consider either increasing scan.grid_size_z to {int(np.ceil(depth / (wvln / 2)))} "
29
+ "or more, or increasing scan.pixels_per_wavelength to 2 or more."
30
30
  )
31
31
 
32
32
 
33
- def cartesian_pixel_grid(xlims, zlims, Nx=None, Nz=None, dx=None, dz=None):
33
+ def cartesian_pixel_grid(xlims, zlims, grid_size_x=None, grid_size_z=None, dx=None, dz=None):
34
34
  """Generate a Cartesian pixel grid based on input parameters.
35
35
 
36
36
  Args:
37
37
  xlims (tuple): Azimuthal limits of pixel grid ([xmin, xmax])
38
38
  zlims (tuple): Depth limits of pixel grid ([zmin, zmax])
39
- Nx (int): Number of azimuthal pixels, overrides dx and dz parameters
40
- Nz (int): Number of depth pixels, overrides dx and dz parameters
39
+ grid_size_x (int): Number of azimuthal pixels, overrides dx and dz parameters
40
+ grid_size_z (int): Number of depth pixels, overrides dx and dz parameters
41
41
  dx (float): Pixel spacing in azimuth
42
42
  dz (float): Pixel spacing in depth
43
43
 
44
44
  Raises:
45
- ValueError: Either Nx and Nz or dx and dz must be defined.
45
+ ValueError: Either grid_size_x and grid_size_z or dx and dz must be defined.
46
46
 
47
47
  Returns:
48
- grid (np.ndarray): Pixel grid of size (nz, nx, 3) in
48
+ grid (np.ndarray): Pixel grid of size (grid_size_z, nx, 3) in
49
49
  Cartesian coordinates (x, y, z)
50
50
  """
51
- assert (bool(Nx) and bool(Nz)) ^ (bool(dx) and bool(dz)), (
52
- "Either Nx and Nz or dx and dz must be defined."
51
+ assert (bool(grid_size_x) and bool(grid_size_z)) ^ (bool(dx) and bool(dz)), (
52
+ "Either grid_size_x and grid_size_z or dx and dz must be defined."
53
53
  )
54
54
 
55
55
  # Determine the grid spacing
56
- if Nx is not None and Nz is not None:
57
- x = np.linspace(xlims[0], xlims[1] + eps, Nx)
58
- z = np.linspace(zlims[0], zlims[1] + eps, Nz)
56
+ if grid_size_x is not None and grid_size_z is not None:
57
+ x = np.linspace(xlims[0], xlims[1] + eps, grid_size_x)
58
+ z = np.linspace(zlims[0], zlims[1] + eps, grid_size_z)
59
59
  elif dx is not None and dz is not None:
60
60
  sign = np.sign(xlims[1] - xlims[0])
61
61
  x = np.arange(xlims[0], xlims[1] + eps, sign * dx)
62
62
  z = np.arange(zlims[0], zlims[1] + eps, sign * dz)
63
63
  else:
64
- raise ValueError("Either Nx and Nz or dx and dz must be defined.")
64
+ raise ValueError("Either grid_size_x and grid_size_z or dx and dz must be defined.")
65
65
 
66
66
  # Create the pixel grid
67
67
  z_grid, x_grid = np.meshgrid(z, x, indexing="ij")
@@ -102,29 +102,30 @@ def radial_pixel_grid(rlims, dr, oris, dirs):
102
102
  return grid
103
103
 
104
104
 
105
- def polar_pixel_grid(polar_limits, zlims, Nz: int, Nr: int):
105
+ def polar_pixel_grid(polar_limits, zlims, num_radial_pixels: int, num_polar_pixels: int):
106
106
  """Generate a polar grid.
107
107
 
108
108
  Uses radial_pixel_grid but based on parameters that are present in the scan class.
109
109
 
110
110
  Args:
111
- polar_limits (tuple): Azimuthal limits of pixel grid ([azimuth_min, azimuth_max])
111
+ polar_limits (tuple): Polar limits of pixel grid ([polar_min, polar_max])
112
112
  zlims (tuple): Depth limits of pixel grid ([zmin, zmax])
113
- Nz (int, optional): Number of depth pixels.
114
- Nr (int, optional): Number of azimuthal pixels.
113
+ num_radial_pixels (int, optional): Number of depth pixels.
114
+ num_polar_pixels (int, optional): Number of polar pixels.
115
115
 
116
116
  Returns:
117
- grid (np.ndarray): Pixel grid of size (Nz, Nr, 3) in Cartesian coordinates (x, y, z)
117
+ grid (np.ndarray): Pixel grid of size (num_radial_pixels, num_polar_pixels, 3)
118
+ in Cartesian coordinates (x, y, z)
118
119
  """
119
120
  assert len(polar_limits) == 2, "polar_limits must be a tuple of length 2."
120
121
  assert len(zlims) == 2, "zlims must be a tuple of length 2."
121
122
 
122
- dr = (zlims[1] - zlims[0]) / Nz
123
+ dr = (zlims[1] - zlims[0]) / num_radial_pixels
123
124
 
124
125
  oris = np.array([0, 0, 0])
125
- oris = np.tile(oris, (Nr, 1))
126
- dirs_az = np.linspace(*polar_limits, Nr)
126
+ oris = np.tile(oris, (num_polar_pixels, 1))
127
+ dirs_az = np.linspace(*polar_limits, num_polar_pixels)
127
128
 
128
- dirs_el = np.zeros(Nr)
129
+ dirs_el = np.zeros(num_polar_pixels)
129
130
  dirs = np.vstack((dirs_az, dirs_el)).T
130
131
  return radial_pixel_grid(zlims, dr, oris, dirs).transpose(1, 0, 2)
zea/config.py CHANGED
@@ -47,8 +47,9 @@ import yaml
47
47
  from huggingface_hub import hf_hub_download
48
48
 
49
49
  from zea import log
50
+ from zea.data.preset_utils import HF_PREFIX, _hf_resolve_path
50
51
  from zea.internal.config.validation import config_schema
51
- from zea.internal.core import object_to_tensor
52
+ from zea.internal.core import dict_to_tensor
52
53
 
53
54
 
54
55
  class Config(dict):
@@ -430,8 +431,22 @@ class Config(dict):
430
431
  v._recursive_setattr(set_key, set_value)
431
432
 
432
433
  @classmethod
433
- def from_yaml(cls, path, **kwargs):
434
- """Load config object from yaml file"""
434
+ def from_path(cls, path, **kwargs):
435
+ """Load config object from a file path.
436
+
437
+ Args:
438
+ path (str or Path): The path to the config file.
439
+ Can be a string or a Path object. Additionally can be a string with
440
+ the prefix 'hf://', in which case it will be resolved to a
441
+ huggingface path.
442
+
443
+ Returns:
444
+ Config: config object.
445
+ """
446
+ if str(path).startswith(HF_PREFIX):
447
+ path = _hf_resolve_path(str(path))
448
+ if isinstance(path, str):
449
+ path = Path(path)
435
450
  return _load_config_from_yaml(path, config_class=cls, **kwargs)
436
451
 
437
452
  @classmethod
@@ -460,9 +475,14 @@ class Config(dict):
460
475
  local_path = hf_hub_download(repo_id, path, **kwargs)
461
476
  return _load_config_from_yaml(local_path, config_class=cls)
462
477
 
463
- def to_tensor(self):
478
+ @classmethod
479
+ def from_yaml(cls, path, **kwargs):
480
+ """Load config object from yaml file."""
481
+ return cls.from_path(path, **kwargs)
482
+
483
+ def to_tensor(self, keep_as_is=None):
464
484
  """Convert the attributes in the object to keras tensors"""
465
- return object_to_tensor(self)
485
+ return dict_to_tensor(self.serialize(), keep_as_is=keep_as_is)
466
486
 
467
487
 
468
488
  def check_config(config: Union[dict, Config], verbose: bool = False):
zea/data/data_format.py CHANGED
@@ -42,8 +42,8 @@ def generate_example_dataset(
42
42
  sound_speed=1540,
43
43
  center_frequency=7e6,
44
44
  sampling_frequency=40e6,
45
- n_z=512,
46
- n_x=512,
45
+ grid_size_z=512,
46
+ grid_size_x=512,
47
47
  ):
48
48
  """Generates an example dataset that contains all the necessary fields.
49
49
  Note: This dataset does not contain actual data, but is filled with random
@@ -60,7 +60,7 @@ def generate_example_dataset(
60
60
  # creating some fake raw and image data
61
61
  raw_data = np.ones((n_frames, n_tx, n_ax, n_el, n_ch))
62
62
  # image data is in dB
63
- image = np.ones((n_frames, n_z, n_x)) * -40
63
+ image = np.ones((n_frames, grid_size_z, grid_size_x)) * -40
64
64
 
65
65
  # creating some fake scan parameters
66
66
  t0_delays = np.zeros((n_tx, n_el), dtype=np.float32)
@@ -82,8 +82,8 @@ def generate_example_dataset(
82
82
 
83
83
  if add_optional_dtypes:
84
84
  aligned_data = np.ones((n_frames, n_tx, n_ax, n_el, n_ch))
85
- envelope_data = np.ones((n_frames, n_z, n_x))
86
- beamformed_data = np.ones((n_frames, n_z, n_x, n_ch))
85
+ envelope_data = np.ones((n_frames, grid_size_z, grid_size_x))
86
+ beamformed_data = np.ones((n_frames, grid_size_z, grid_size_x, n_ch))
87
87
  image_sc = np.ones_like(image)
88
88
  else:
89
89
  aligned_data = None
@@ -123,10 +123,11 @@ def validate_input_data(raw_data, aligned_data, envelope_data, beamformed_data,
123
123
  aligned_data (np.ndarray): The aligned data of the ultrasound measurement of
124
124
  shape (n_frames, n_tx, n_ax, n_el, n_ch).
125
125
  envelope_data (np.ndarray): The envelope data of the ultrasound measurement of
126
- shape (n_frames, n_z, n_x).
126
+ shape (n_frames, grid_size_z, grid_size_x).
127
127
  beamformed_data (np.ndarray): The beamformed data of the ultrasound measurement of
128
- shape (n_frames, n_z, n_x).
129
- image (np.ndarray): The ultrasound images to be saved of shape (n_frames, n_z, n_x).
128
+ shape (n_frames, grid_size_z, grid_size_x).
129
+ image (np.ndarray): The ultrasound images to be saved
130
+ of shape (n_frames, grid_size_z, grid_size_x).
130
131
  image_sc (np.ndarray): The scan converted ultrasound images to be saved
131
132
  of shape (n_frames, output_size_z, output_size_x).
132
133
  """
@@ -254,7 +255,7 @@ def _write_datasets(
254
255
  group_name=data_group_name,
255
256
  name="envelope_data",
256
257
  data=_convert_datatype(envelope_data),
257
- description="The envelope_data of shape (n_frames, n_z, n_x).",
258
+ description="The envelope_data of shape (n_frames, grid_size_z, grid_size_x).",
258
259
  unit="unitless",
259
260
  )
260
261
 
@@ -262,7 +263,7 @@ def _write_datasets(
262
263
  group_name=data_group_name,
263
264
  name="beamformed_data",
264
265
  data=_convert_datatype(beamformed_data),
265
- description="The beamformed_data of shape (n_frames, n_z, n_x).",
266
+ description="The beamformed_data of shape (n_frames, grid_size_z, grid_size_x).",
266
267
  unit="unitless",
267
268
  )
268
269
 
@@ -271,7 +272,7 @@ def _write_datasets(
271
272
  name="image",
272
273
  data=_convert_datatype(image),
273
274
  unit="unitless",
274
- description="The images of shape [n_frames, n_z, n_x]",
275
+ description="The images of shape [n_frames, grid_size_z, grid_size_x]",
275
276
  )
276
277
 
277
278
  _add_dataset(
@@ -467,32 +468,34 @@ def _write_datasets(
467
468
  ),
468
469
  unit="-",
469
470
  )
470
- n_waveforms = len(waveforms_one_way)
471
- for n, waveform_1way, waveform_2way in zip(
472
- range(n_waveforms), waveforms_one_way, waveforms_two_way
473
- ):
474
- _add_dataset(
475
- group_name=scan_group_name + "/waveforms_one_way",
476
- name=f"waveform_{str(n).zfill(3)}",
477
- data=waveform_1way,
478
- description=(
479
- "One-way waveform as simulated by the Verasonics system, "
480
- "sampled at 250MHz. This is the waveform after being filtered "
481
- "by the tranducer bandwidth once."
482
- ),
483
- unit="V",
484
- )
485
- _add_dataset(
486
- group_name=scan_group_name + "/waveforms_two_way",
487
- name=f"waveform_{str(n).zfill(3)}",
488
- data=waveform_2way,
489
- description=(
490
- "Two-way waveform as simulated by the Verasonics system, "
491
- "sampled at 250MHz. This is the waveform after being filtered "
492
- "by the tranducer bandwidth twice."
493
- ),
494
- unit="V",
495
- )
471
+
472
+ if waveforms_one_way is not None:
473
+ for n in range(len(waveforms_one_way)):
474
+ _add_dataset(
475
+ group_name=scan_group_name + "/waveforms_one_way",
476
+ name=f"waveform_{str(n).zfill(3)}",
477
+ data=waveforms_one_way[n],
478
+ description=(
479
+ "One-way waveform as simulated by the Verasonics system, "
480
+ "sampled at 250MHz. This is the waveform after being filtered "
481
+ "by the tranducer bandwidth once."
482
+ ),
483
+ unit="V",
484
+ )
485
+
486
+ if waveforms_two_way is not None:
487
+ for n in range(len(waveforms_two_way)):
488
+ _add_dataset(
489
+ group_name=scan_group_name + "/waveforms_two_way",
490
+ name=f"waveform_{str(n).zfill(3)}",
491
+ data=waveforms_two_way[n],
492
+ description=(
493
+ "Two-way waveform as simulated by the Verasonics system, "
494
+ "sampled at 250MHz. This is the waveform after being filtered "
495
+ "by the tranducer bandwidth twice."
496
+ ),
497
+ unit="V",
498
+ )
496
499
 
497
500
  # Add additional elements
498
501
  if additional_elements is not None:
@@ -546,10 +549,11 @@ def generate_zea_dataset(
546
549
  aligned_data (np.ndarray): The aligned data of the ultrasound measurement of
547
550
  shape (n_frames, n_tx, n_ax, n_el, n_ch).
548
551
  envelope_data (np.ndarray): The envelope data of the ultrasound measurement of
549
- shape (n_frames, n_z, n_x).
552
+ shape (n_frames, grid_size_z, grid_size_x).
550
553
  beamformed_data (np.ndarray): The beamformed data of the ultrasound measurement of
551
- shape (n_frames, n_z, n_x, n_ch).
552
- image (np.ndarray): The ultrasound images to be saved of shape (n_frames, n_z, n_x).
554
+ shape (n_frames, grid_size_z, grid_size_x, n_ch).
555
+ image (np.ndarray): The ultrasound images to be saved
556
+ of shape (n_frames, grid_size_z, grid_size_x).
553
557
  image_sc (np.ndarray): The scan converted ultrasound images to be saved
554
558
  of shape (n_frames, output_size_z, output_size_x).
555
559
  probe_geometry (np.ndarray): The probe geometry of shape (n_el, 3).
zea/data/file.py CHANGED
@@ -366,9 +366,7 @@ class File(h5py.File):
366
366
  """
367
367
  file_scan_parameters = self.get_parameters(event)
368
368
 
369
- probe_parameters = reduce_to_signature(
370
- Probe.from_name("generic").__init__, file_scan_parameters
371
- )
369
+ probe_parameters = reduce_to_signature(Probe.__init__, file_scan_parameters)
372
370
  return probe_parameters
373
371
 
374
372
  def probe(self, event=None) -> Probe:
@@ -387,21 +385,8 @@ class File(h5py.File):
387
385
  Returns:
388
386
  Probe: The probe object.
389
387
  """
390
- probe_parameters = self.get_probe_parameters(event)
391
- if self.probe_name == "generic":
392
- return Probe.from_name(self.probe_name, **probe_parameters)
393
- else:
394
- probe = Probe.from_name(self.probe_name)
395
-
396
- probe_geometry = probe_parameters.get("probe_geometry", None)
397
- if not np.allclose(probe_geometry, probe.probe_geometry):
398
- probe.probe_geometry = probe_geometry
399
- log.warning(
400
- "The probe geometry in the data file does not "
401
- "match the probe geometry of the probe. The probe "
402
- "geometry has been updated to match the data file."
403
- )
404
- return probe
388
+ probe_parameters_file = self.get_probe_parameters(event)
389
+ return Probe.from_parameters(self.probe_name, probe_parameters_file)
405
390
 
406
391
  def recursively_load_dict_contents_from_group(self, path: str, squeeze: bool = False) -> dict:
407
392
  """Load dict from contents of group