zea 0.0.2__py3-none-any.whl → 0.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
zea/__init__.py CHANGED
@@ -7,7 +7,7 @@ from . import log
7
7
 
8
8
  # dynamically add __version__ attribute (see pyproject.toml)
9
9
  # __version__ = __import__("importlib.metadata").metadata.version(__package__)
10
- __version__ = "0.0.2"
10
+ __version__ = "0.0.3"
11
11
 
12
12
 
13
13
  def setup():
zea/agent/selection.py CHANGED
@@ -92,6 +92,7 @@ class GreedyEntropy(LinesActionModel):
92
92
  mean: float = 0,
93
93
  std_dev: float = 1,
94
94
  num_lines_to_update: int = 5,
95
+ entropy_sigma: float = 1.0,
95
96
  ):
96
97
  """Initialize the GreedyEntropy action selection model.
97
98
 
@@ -104,6 +105,8 @@ class GreedyEntropy(LinesActionModel):
104
105
  std_dev (float, optional): The standard deviation of the RBF. Defaults to 1.
105
106
  num_lines_to_update (int, optional): The number of lines around the selected line
106
107
  to update. Must be odd.
108
+ entropy_sigma (float, optional): The standard deviation of the Gaussian
109
+ Mixture components used to approximate the posterior.
107
110
  """
108
111
  super().__init__(n_actions, n_possible_actions, img_width, img_height)
109
112
 
@@ -124,7 +127,7 @@ class GreedyEntropy(LinesActionModel):
124
127
  self.num_lines_to_update,
125
128
  )
126
129
  self.upside_down_gaussian = upside_down_gaussian(points_to_evaluate)
127
- self.entropy_sigma = 1
130
+ self.entropy_sigma = entropy_sigma
128
131
 
129
132
  @staticmethod
130
133
  def compute_pairwise_pixel_gaussian_error(
@@ -157,15 +160,18 @@ class GreedyEntropy(LinesActionModel):
157
160
  # This way we can just sum across the height axis and get the entropy
158
161
  # for each pixel in a given line
159
162
  batch_size, n_particles, _, height, _ = gaussian_error_per_pixel_i_j.shape
160
- gaussian_error_per_pixel_stacked = ops.reshape(
161
- gaussian_error_per_pixel_i_j,
162
- [
163
- batch_size,
164
- n_particles,
165
- n_particles,
166
- height * stack_n_cols,
167
- n_possible_actions,
168
- ],
163
+ gaussian_error_per_pixel_stacked = ops.transpose(
164
+ ops.reshape(
165
+ ops.transpose(gaussian_error_per_pixel_i_j, (0, 1, 2, 4, 3)),
166
+ [
167
+ batch_size,
168
+ n_particles,
169
+ n_particles,
170
+ n_possible_actions,
171
+ height * stack_n_cols,
172
+ ],
173
+ ),
174
+ (0, 1, 2, 4, 3),
169
175
  )
170
176
  # [n_particles, n_particles, batch, height, width]
171
177
  return gaussian_error_per_pixel_stacked
@@ -428,10 +434,15 @@ class CovarianceSamplingLines(LinesActionModel):
428
434
  generation. Defaults to None.
429
435
 
430
436
  Returns:
431
- Tensor: The mask of shape (batch_size, img_size, img_size)
437
+ Tuple[Tensor, Tensor]:
438
+ - Newly selected lines as k-hot vectors, shaped (batch_size, n_possible_actions)
439
+ - Masks of shape (batch_size, img_height, img_width)
432
440
  """
433
441
  batch_size, n_particles, rows, _ = ops.shape(particles)
434
442
 
443
+ # [batch_size, rows, cols, n_particles]
444
+ particles = ops.transpose(particles, (0, 2, 3, 1))
445
+
435
446
  # [batch_size, rows * stack_n_cols, n_possible_actions, n_particles]
436
447
  shape = [
437
448
  batch_size,
@@ -441,7 +452,7 @@ class CovarianceSamplingLines(LinesActionModel):
441
452
  ]
442
453
  particles = ops.reshape(particles, shape)
443
454
 
444
- # [batch_size, rows, n_possible_actions, n_possible_actions]
455
+ # [batch_size, rows * stack_n_cols, n_possible_actions, n_possible_actions]
445
456
  cov_matrix = tensor_ops.batch_cov(particles)
446
457
 
447
458
  # Sum over the row dimension [batch_size, n_possible_actions, n_possible_actions]
zea/backend/__init__.py CHANGED
@@ -1,4 +1,4 @@
1
- """Backend subpackage for ``zea``.
1
+ """Backend-specific utilities.
2
2
 
3
3
  This subpackage provides backend-specific utilities for the ``zea`` library. Most backend logic is handled by Keras 3, but a few features require custom wrappers to ensure compatibility and performance across JAX, TensorFlow, and PyTorch.
4
4
 
@@ -9,7 +9,7 @@ Key Features
9
9
  ------------
10
10
 
11
11
  - **JIT Compilation** (:func:`zea.backend.jit`):
12
- Provides a unified interface for just-in-time (JIT) compilation of functions, dispatching to the appropriate backend (JAX or TensorFlow) as needed. This enables accelerated execution of computationally intensive routines.
12
+ Provides a unified interface for just-in-time (JIT) compilation of functions, dispatching to the appropriate backend (JAX or TensorFlow) as needed. This enables accelerated execution of computationally intensive routines. Note that jit compilation is not yet supported when using the `torch` backend.
13
13
 
14
14
  - **Automatic Differentiation** (:class:`zea.backend.AutoGrad`):
15
15
  Offers a backend-agnostic wrapper for automatic differentiation, allowing gradient computation regardless of the underlying ML library.
@@ -108,7 +108,9 @@ def _jit_compile(func, jax=True, tensorflow=True, **kwargs):
108
108
  return func
109
109
  else:
110
110
  log.warning(
111
- f"Unsupported backend: {backend}. Supported backends are 'tensorflow' and 'jax'."
111
+ f"JIT compilation not currently supported for backend {backend}. "
112
+ "Supported backends are 'tensorflow' and 'jax'."
112
113
  )
114
+ log.warning("Initialize zea.Pipeline with jit_options=None to suppress this warning.")
113
115
  log.warning("Falling back to non-compiled mode.")
114
116
  return func
zea/beamform/__init__.py CHANGED
@@ -17,4 +17,4 @@ see the pipeline example notebook: :doc:`../notebooks/pipeline/zea_pipeline_exam
17
17
 
18
18
  """
19
19
 
20
- from . import beamformer, delays, lens_correction, pfield, pixelgrid
20
+ from . import beamformer, delays, lens_correction, pfield, phantoms, pixelgrid
@@ -1,5 +1,6 @@
1
1
  """Main beamforming functions for ultrasound imaging."""
2
2
 
3
+ import keras
3
4
  import numpy as np
4
5
  from keras import ops
5
6
 
@@ -58,7 +59,7 @@ def tof_correction(
58
59
  demodulation_frequency,
59
60
  fnum,
60
61
  angles,
61
- vfocus,
62
+ focus_distances,
62
63
  apply_phase_rotation=False,
63
64
  apply_lens_correction=False,
64
65
  lens_thickness=1e-3,
@@ -84,7 +85,7 @@ def tof_correction(
84
85
  fnum (int, optional): Focus number. Defaults to 1.
85
86
  angles (ops.Tensor): The angles of the plane waves in radians of shape
86
87
  `(n_tx,)`
87
- vfocus (ops.Tensor): The focus distance of shape `(n_tx,)`
88
+ focus_distances (ops.Tensor): The focus distance of shape `(n_tx,)`
88
89
  apply_phase_rotation (bool, optional): Whether to apply phase rotation to
89
90
  time-of-flights. Defaults to False.
90
91
  apply_lens_correction (bool, optional): Whether to apply lens correction to
@@ -133,7 +134,7 @@ def tof_correction(
133
134
  sound_speed,
134
135
  n_tx,
135
136
  n_el,
136
- vfocus,
137
+ focus_distances,
137
138
  angles,
138
139
  lens_thickness=lens_thickness,
139
140
  lens_sound_speed=lens_sound_speed,
@@ -487,7 +488,7 @@ def fnumber_mask(flatgrid, probe_geometry, f_number, fnum_window_fn):
487
488
  # The f-number is fnum = z/aperture = 1/(2 * tan(alpha))
488
489
  # Rearranging gives us alpha = arctan(1/(2 * fnum))
489
490
  # We can use this to compute the maximum angle alpha that is allowed
490
- max_alpha = ops.arctan(1 / (2 * f_number))
491
+ max_alpha = ops.arctan(1 / (2 * f_number + keras.backend.epsilon()))
491
492
 
492
493
  normalized_angle = alpha / max_alpha
493
494
  mask = fnum_window_fn(normalized_angle)
zea/beamform/pfield.py CHANGED
@@ -60,7 +60,8 @@ def compute_pfield(
60
60
  n_el (int): Number of elements in the probe.
61
61
  probe_geometry (array): Geometry of the probe elements.
62
62
  tx_apodizations (array): Transmit apodization values.
63
- grid (array): Grid points where the pressure field is computed of shape (Nz, Nx, 3).
63
+ grid (array): Grid points where the pressure field is computed
64
+ of shape (grid_size_z, grid_size_x, 3).
64
65
  t0_delays (array): Transmit delays for each transmit event.
65
66
  frequency_step (int, optional): Frequency step. Default is 4.
66
67
  Higher is faster but less accurate.
@@ -78,7 +79,8 @@ def compute_pfield(
78
79
  verbose (bool, optional): Whether to print progress.
79
80
 
80
81
  Returns:
81
- ops.array: The (normalized) pressure field (across tx events) of shape (n_tx, Nz, Nx).
82
+ ops.array: The (normalized) pressure field (across tx events)
83
+ of shape (n_tx, grid_size_z, grid_size_x).
82
84
  """
83
85
  # medium params
84
86
  alpha_db = 0 # currently we ignore attenuation in the compounding
@@ -293,7 +295,8 @@ def normalize_pressure_field(pfield, alpha: float = 1.0, percentile: float = 10.
293
295
  Normalize the input array of intensities by zeroing out values below a given percentile.
294
296
 
295
297
  Args:
296
- pfield (array): The unnormalized pressure field array of shape (n_tx, Nz, Nx).
298
+ pfield (array): The unnormalized pressure field array
299
+ of shape (n_tx, grid_size_z, grid_size_x).
297
300
  alpha (float, optional): Exponent to 'sharpen or smooth' the weighting.
298
301
  Higher values result in sharper weighting. Default is 1.0.
299
302
  percentile (int, optional): minimum percentile threshold to keep in the weighting.
zea/beamform/pixelgrid.py CHANGED
@@ -16,52 +16,52 @@ def check_for_aliasing(scan):
16
16
  depth = scan.zlims[1] - scan.zlims[0]
17
17
  wvln = scan.wavelength
18
18
 
19
- if width / scan.Nx > wvln / 2:
19
+ if width / scan.grid_size_x > wvln / 2:
20
20
  log.warning(
21
- f"width/Nx = {width / scan.Nx:.7f} < wavelength/2 = {wvln / 2}. "
22
- f"Consider either increasing scan.Nx to {int(np.ceil(width / (wvln / 2)))} or more, or "
23
- "increasing scan.pixels_per_wavelength to 2 or more."
21
+ f"width/grid_size_x = {width / scan.grid_size_x:.7f} < wavelength/2 = {wvln / 2}. "
22
+ f"Consider either increasing scan.grid_size_x to {int(np.ceil(width / (wvln / 2)))} "
23
+ "or more, or increasing scan.pixels_per_wavelength to 2 or more."
24
24
  )
25
- if depth / scan.Nz > wvln / 2:
25
+ if depth / scan.grid_size_z > wvln / 2:
26
26
  log.warning(
27
- f"depth/Nz = {depth / scan.Nz:.7f} < wavelength/2 = {wvln / 2:.7f}. "
28
- f"Consider either increasing scan.Nz to {int(np.ceil(depth / (wvln / 2)))} or more, or "
29
- "increasing scan.pixels_per_wavelength to 2 or more."
27
+ f"depth/grid_size_z = {depth / scan.grid_size_z:.7f} < wavelength/2 = {wvln / 2:.7f}. "
28
+ f"Consider either increasing scan.grid_size_z to {int(np.ceil(depth / (wvln / 2)))} "
29
+ "or more, or increasing scan.pixels_per_wavelength to 2 or more."
30
30
  )
31
31
 
32
32
 
33
- def cartesian_pixel_grid(xlims, zlims, Nx=None, Nz=None, dx=None, dz=None):
33
+ def cartesian_pixel_grid(xlims, zlims, grid_size_x=None, grid_size_z=None, dx=None, dz=None):
34
34
  """Generate a Cartesian pixel grid based on input parameters.
35
35
 
36
36
  Args:
37
37
  xlims (tuple): Azimuthal limits of pixel grid ([xmin, xmax])
38
38
  zlims (tuple): Depth limits of pixel grid ([zmin, zmax])
39
- Nx (int): Number of azimuthal pixels, overrides dx and dz parameters
40
- Nz (int): Number of depth pixels, overrides dx and dz parameters
39
+ grid_size_x (int): Number of azimuthal pixels, overrides dx and dz parameters
40
+ grid_size_z (int): Number of depth pixels, overrides dx and dz parameters
41
41
  dx (float): Pixel spacing in azimuth
42
42
  dz (float): Pixel spacing in depth
43
43
 
44
44
  Raises:
45
- ValueError: Either Nx and Nz or dx and dz must be defined.
45
+ ValueError: Either grid_size_x and grid_size_z or dx and dz must be defined.
46
46
 
47
47
  Returns:
48
- grid (np.ndarray): Pixel grid of size (nz, nx, 3) in
48
+ grid (np.ndarray): Pixel grid of size (grid_size_z, nx, 3) in
49
49
  Cartesian coordinates (x, y, z)
50
50
  """
51
- assert (bool(Nx) and bool(Nz)) ^ (bool(dx) and bool(dz)), (
52
- "Either Nx and Nz or dx and dz must be defined."
51
+ assert (bool(grid_size_x) and bool(grid_size_z)) ^ (bool(dx) and bool(dz)), (
52
+ "Either grid_size_x and grid_size_z or dx and dz must be defined."
53
53
  )
54
54
 
55
55
  # Determine the grid spacing
56
- if Nx is not None and Nz is not None:
57
- x = np.linspace(xlims[0], xlims[1] + eps, Nx)
58
- z = np.linspace(zlims[0], zlims[1] + eps, Nz)
56
+ if grid_size_x is not None and grid_size_z is not None:
57
+ x = np.linspace(xlims[0], xlims[1] + eps, grid_size_x)
58
+ z = np.linspace(zlims[0], zlims[1] + eps, grid_size_z)
59
59
  elif dx is not None and dz is not None:
60
60
  sign = np.sign(xlims[1] - xlims[0])
61
61
  x = np.arange(xlims[0], xlims[1] + eps, sign * dx)
62
62
  z = np.arange(zlims[0], zlims[1] + eps, sign * dz)
63
63
  else:
64
- raise ValueError("Either Nx and Nz or dx and dz must be defined.")
64
+ raise ValueError("Either grid_size_x and grid_size_z or dx and dz must be defined.")
65
65
 
66
66
  # Create the pixel grid
67
67
  z_grid, x_grid = np.meshgrid(z, x, indexing="ij")
@@ -102,29 +102,30 @@ def radial_pixel_grid(rlims, dr, oris, dirs):
102
102
  return grid
103
103
 
104
104
 
105
- def polar_pixel_grid(polar_limits, zlims, Nz: int, Nr: int):
105
+ def polar_pixel_grid(polar_limits, zlims, num_radial_pixels: int, num_polar_pixels: int):
106
106
  """Generate a polar grid.
107
107
 
108
108
  Uses radial_pixel_grid but based on parameters that are present in the scan class.
109
109
 
110
110
  Args:
111
- polar_limits (tuple): Azimuthal limits of pixel grid ([azimuth_min, azimuth_max])
111
+ polar_limits (tuple): Polar limits of pixel grid ([polar_min, polar_max])
112
112
  zlims (tuple): Depth limits of pixel grid ([zmin, zmax])
113
- Nz (int, optional): Number of depth pixels.
114
- Nr (int, optional): Number of azimuthal pixels.
113
+ num_radial_pixels (int, optional): Number of depth pixels.
114
+ num_polar_pixels (int, optional): Number of polar pixels.
115
115
 
116
116
  Returns:
117
- grid (np.ndarray): Pixel grid of size (Nz, Nr, 3) in Cartesian coordinates (x, y, z)
117
+ grid (np.ndarray): Pixel grid of size (num_radial_pixels, num_polar_pixels, 3)
118
+ in Cartesian coordinates (x, y, z)
118
119
  """
119
120
  assert len(polar_limits) == 2, "polar_limits must be a tuple of length 2."
120
121
  assert len(zlims) == 2, "zlims must be a tuple of length 2."
121
122
 
122
- dr = (zlims[1] - zlims[0]) / Nz
123
+ dr = (zlims[1] - zlims[0]) / num_radial_pixels
123
124
 
124
125
  oris = np.array([0, 0, 0])
125
- oris = np.tile(oris, (Nr, 1))
126
- dirs_az = np.linspace(*polar_limits, Nr)
126
+ oris = np.tile(oris, (num_polar_pixels, 1))
127
+ dirs_az = np.linspace(*polar_limits, num_polar_pixels)
127
128
 
128
- dirs_el = np.zeros(Nr)
129
+ dirs_el = np.zeros(num_polar_pixels)
129
130
  dirs = np.vstack((dirs_az, dirs_el)).T
130
131
  return radial_pixel_grid(zlims, dr, oris, dirs).transpose(1, 0, 2)
zea/config.py CHANGED
@@ -47,8 +47,9 @@ import yaml
47
47
  from huggingface_hub import hf_hub_download
48
48
 
49
49
  from zea import log
50
+ from zea.data.preset_utils import HF_PREFIX, _hf_resolve_path
50
51
  from zea.internal.config.validation import config_schema
51
- from zea.internal.core import object_to_tensor
52
+ from zea.internal.core import dict_to_tensor
52
53
 
53
54
 
54
55
  class Config(dict):
@@ -430,8 +431,22 @@ class Config(dict):
430
431
  v._recursive_setattr(set_key, set_value)
431
432
 
432
433
  @classmethod
433
- def from_yaml(cls, path, **kwargs):
434
- """Load config object from yaml file"""
434
+ def from_path(cls, path, **kwargs):
435
+ """Load config object from a file path.
436
+
437
+ Args:
438
+ path (str or Path): The path to the config file.
439
+ Can be a string or a Path object. Additionally can be a string with
440
+ the prefix 'hf://', in which case it will be resolved to a
441
+ huggingface path.
442
+
443
+ Returns:
444
+ Config: config object.
445
+ """
446
+ if str(path).startswith(HF_PREFIX):
447
+ path = _hf_resolve_path(str(path))
448
+ if isinstance(path, str):
449
+ path = Path(path)
435
450
  return _load_config_from_yaml(path, config_class=cls, **kwargs)
436
451
 
437
452
  @classmethod
@@ -460,9 +475,14 @@ class Config(dict):
460
475
  local_path = hf_hub_download(repo_id, path, **kwargs)
461
476
  return _load_config_from_yaml(local_path, config_class=cls)
462
477
 
463
- def to_tensor(self):
478
+ @classmethod
479
+ def from_yaml(cls, path, **kwargs):
480
+ """Load config object from yaml file."""
481
+ return cls.from_path(path, **kwargs)
482
+
483
+ def to_tensor(self, keep_as_is=None):
464
484
  """Convert the attributes in the object to keras tensors"""
465
- return object_to_tensor(self)
485
+ return dict_to_tensor(self.serialize(), keep_as_is=keep_as_is)
466
486
 
467
487
 
468
488
  def check_config(config: Union[dict, Config], verbose: bool = False):
zea/data/data_format.py CHANGED
@@ -42,8 +42,8 @@ def generate_example_dataset(
42
42
  sound_speed=1540,
43
43
  center_frequency=7e6,
44
44
  sampling_frequency=40e6,
45
- n_z=512,
46
- n_x=512,
45
+ grid_size_z=512,
46
+ grid_size_x=512,
47
47
  ):
48
48
  """Generates an example dataset that contains all the necessary fields.
49
49
  Note: This dataset does not contain actual data, but is filled with random
@@ -60,7 +60,7 @@ def generate_example_dataset(
60
60
  # creating some fake raw and image data
61
61
  raw_data = np.ones((n_frames, n_tx, n_ax, n_el, n_ch))
62
62
  # image data is in dB
63
- image = np.ones((n_frames, n_z, n_x)) * -40
63
+ image = np.ones((n_frames, grid_size_z, grid_size_x)) * -40
64
64
 
65
65
  # creating some fake scan parameters
66
66
  t0_delays = np.zeros((n_tx, n_el), dtype=np.float32)
@@ -82,8 +82,8 @@ def generate_example_dataset(
82
82
 
83
83
  if add_optional_dtypes:
84
84
  aligned_data = np.ones((n_frames, n_tx, n_ax, n_el, n_ch))
85
- envelope_data = np.ones((n_frames, n_z, n_x))
86
- beamformed_data = np.ones((n_frames, n_z, n_x, n_ch))
85
+ envelope_data = np.ones((n_frames, grid_size_z, grid_size_x))
86
+ beamformed_data = np.ones((n_frames, grid_size_z, grid_size_x, n_ch))
87
87
  image_sc = np.ones_like(image)
88
88
  else:
89
89
  aligned_data = None
@@ -123,10 +123,11 @@ def validate_input_data(raw_data, aligned_data, envelope_data, beamformed_data,
123
123
  aligned_data (np.ndarray): The aligned data of the ultrasound measurement of
124
124
  shape (n_frames, n_tx, n_ax, n_el, n_ch).
125
125
  envelope_data (np.ndarray): The envelope data of the ultrasound measurement of
126
- shape (n_frames, n_z, n_x).
126
+ shape (n_frames, grid_size_z, grid_size_x).
127
127
  beamformed_data (np.ndarray): The beamformed data of the ultrasound measurement of
128
- shape (n_frames, n_z, n_x).
129
- image (np.ndarray): The ultrasound images to be saved of shape (n_frames, n_z, n_x).
128
+ shape (n_frames, grid_size_z, grid_size_x).
129
+ image (np.ndarray): The ultrasound images to be saved
130
+ of shape (n_frames, grid_size_z, grid_size_x).
130
131
  image_sc (np.ndarray): The scan converted ultrasound images to be saved
131
132
  of shape (n_frames, output_size_z, output_size_x).
132
133
  """
@@ -254,7 +255,7 @@ def _write_datasets(
254
255
  group_name=data_group_name,
255
256
  name="envelope_data",
256
257
  data=_convert_datatype(envelope_data),
257
- description="The envelope_data of shape (n_frames, n_z, n_x).",
258
+ description="The envelope_data of shape (n_frames, grid_size_z, grid_size_x).",
258
259
  unit="unitless",
259
260
  )
260
261
 
@@ -262,7 +263,7 @@ def _write_datasets(
262
263
  group_name=data_group_name,
263
264
  name="beamformed_data",
264
265
  data=_convert_datatype(beamformed_data),
265
- description="The beamformed_data of shape (n_frames, n_z, n_x).",
266
+ description="The beamformed_data of shape (n_frames, grid_size_z, grid_size_x).",
266
267
  unit="unitless",
267
268
  )
268
269
 
@@ -271,7 +272,7 @@ def _write_datasets(
271
272
  name="image",
272
273
  data=_convert_datatype(image),
273
274
  unit="unitless",
274
- description="The images of shape [n_frames, n_z, n_x]",
275
+ description="The images of shape [n_frames, grid_size_z, grid_size_x]",
275
276
  )
276
277
 
277
278
  _add_dataset(
@@ -546,10 +547,11 @@ def generate_zea_dataset(
546
547
  aligned_data (np.ndarray): The aligned data of the ultrasound measurement of
547
548
  shape (n_frames, n_tx, n_ax, n_el, n_ch).
548
549
  envelope_data (np.ndarray): The envelope data of the ultrasound measurement of
549
- shape (n_frames, n_z, n_x).
550
+ shape (n_frames, grid_size_z, grid_size_x).
550
551
  beamformed_data (np.ndarray): The beamformed data of the ultrasound measurement of
551
- shape (n_frames, n_z, n_x, n_ch).
552
- image (np.ndarray): The ultrasound images to be saved of shape (n_frames, n_z, n_x).
552
+ shape (n_frames, grid_size_z, grid_size_x, n_ch).
553
+ image (np.ndarray): The ultrasound images to be saved
554
+ of shape (n_frames, grid_size_z, grid_size_x).
553
555
  image_sc (np.ndarray): The scan converted ultrasound images to be saved
554
556
  of shape (n_frames, output_size_z, output_size_x).
555
557
  probe_geometry (np.ndarray): The probe geometry of shape (n_el, 3).
zea/data/file.py CHANGED
@@ -366,9 +366,7 @@ class File(h5py.File):
366
366
  """
367
367
  file_scan_parameters = self.get_parameters(event)
368
368
 
369
- probe_parameters = reduce_to_signature(
370
- Probe.from_name("generic").__init__, file_scan_parameters
371
- )
369
+ probe_parameters = reduce_to_signature(Probe.__init__, file_scan_parameters)
372
370
  return probe_parameters
373
371
 
374
372
  def probe(self, event=None) -> Probe:
@@ -387,21 +385,8 @@ class File(h5py.File):
387
385
  Returns:
388
386
  Probe: The probe object.
389
387
  """
390
- probe_parameters = self.get_probe_parameters(event)
391
- if self.probe_name == "generic":
392
- return Probe.from_name(self.probe_name, **probe_parameters)
393
- else:
394
- probe = Probe.from_name(self.probe_name)
395
-
396
- probe_geometry = probe_parameters.get("probe_geometry", None)
397
- if not np.allclose(probe_geometry, probe.probe_geometry):
398
- probe.probe_geometry = probe_geometry
399
- log.warning(
400
- "The probe geometry in the data file does not "
401
- "match the probe geometry of the probe. The probe "
402
- "geometry has been updated to match the data file."
403
- )
404
- return probe
388
+ probe_parameters_file = self.get_probe_parameters(event)
389
+ return Probe.from_parameters(self.probe_name, probe_parameters_file)
405
390
 
406
391
  def recursively_load_dict_contents_from_group(self, path: str, squeeze: bool = False) -> dict:
407
392
  """Load dict from contents of group
zea/data/preset_utils.py CHANGED
@@ -50,60 +50,74 @@ def _hf_download(repo_id, filename, cache_dir=HF_DATASETS_DIR):
50
50
  )
51
51
 
52
52
 
53
- def _hf_get_snapshot_dir(repo_id, cache_dir=HF_DATASETS_DIR):
54
- repo_cache_dir = Path(cache_dir) / f"datasets--{repo_id.replace('/', '--')}"
55
- snapshots_dir = repo_cache_dir / "snapshots"
56
- if not snapshots_dir.exists() or not any(snapshots_dir.iterdir()):
57
- # Try to trigger a download to populate the cache
58
- files = _hf_list_files(repo_id)
59
- # Pick the first file (prefer .h5/.hdf5 if possible)
60
- h5_files = [f for f in files if f.endswith(".h5") or f.endswith(".hdf5")]
61
- target_file = h5_files[0] if h5_files else files[0]
62
- _hf_download(repo_id, target_file, cache_dir)
63
- # Now try again
64
- if not snapshots_dir.exists() or not any(snapshots_dir.iterdir()):
65
- raise FileNotFoundError(
66
- f"No snapshots found in Hugging Face cache for {repo_id} after download attempt"
67
- )
68
- snapshot_hashes = sorted(snapshots_dir.iterdir(), key=lambda p: p.stat().st_mtime, reverse=True)
69
- if not snapshot_hashes:
70
- raise FileNotFoundError(f"No snapshot found for {repo_id} in cache.")
71
- return snapshot_hashes[0]
53
+ def _get_snapshot_dir_from_downloaded_file(downloaded_file_path: str | Path) -> Path:
54
+ """Extract the snapshot directory from a downloaded file's path.
55
+
56
+ HF Hub downloads to: cache_dir/datasets--org--repo/snapshots/{hash}/path/to/filename
57
+ This navigates up to find the {hash} directory (the snapshot directory).
58
+ """
59
+ file_path = Path(downloaded_file_path)
60
+
61
+ # Navigate up the path until we find the snapshots directory
62
+ current = file_path.parent
63
+ while current.name != "snapshots" and current.parent != current:
64
+ current = current.parent
65
+
66
+ if current.name == "snapshots":
67
+ # Return the snapshot hash directory (first subdirectory of snapshots)
68
+ snapshot_dirs = [d for d in current.iterdir() if d.is_dir()]
69
+ if snapshot_dirs:
70
+ # Return the most recent snapshot directory
71
+ return max(snapshot_dirs, key=lambda p: p.stat().st_mtime)
72
+
73
+ raise FileNotFoundError(f"Could not find snapshot directory for {downloaded_file_path}")
74
+
75
+
76
+ def _download_files_in_path(
77
+ repo_id: str, files: list, path_filter: str = None, cache_dir=HF_DATASETS_DIR
78
+ ) -> list[str]:
79
+ """Download all files matching the path filter."""
80
+ downloaded_files = []
81
+ for f in files:
82
+ if path_filter is None or f.startswith(path_filter):
83
+ downloaded_path = _hf_download(repo_id, f, cache_dir)
84
+ downloaded_files.append(downloaded_path)
85
+
86
+ return downloaded_files
72
87
 
73
88
 
74
89
  def _hf_resolve_path(hf_path: str, cache_dir=HF_DATASETS_DIR):
75
- """Download a file or directory from Hugging Face Hub to a local cache directory.
76
- Returns the local path to the downloaded file or directory.
90
+ """Resolve a Hugging Face path to a local cache directory path.
91
+
92
+ Downloads files from a HuggingFace dataset repository and returns
93
+ the local path where they are cached. Handles:
94
+ - hf://org/repo/subdir/ - Downloads all files in subdirectory
95
+ - hf://org/repo/file.h5 - Downloads specific file
96
+ - hf://org/repo - Downloads all files in repo
77
97
  """
78
98
  repo_id, subpath = _hf_parse_path(hf_path)
79
99
  files = _hf_list_files(repo_id)
80
- snapshot_dir = _hf_get_snapshot_dir(repo_id, cache_dir)
81
-
82
- def is_h5(f):
83
- return f.endswith(".h5") or f.endswith(".hdf5")
84
100
 
85
101
  if subpath:
86
- # Directory
102
+ # Directory case
87
103
  if any(f.startswith(subpath + "/") for f in files):
88
- local_dir = snapshot_dir / subpath
89
- for f in files:
90
- if f.startswith(subpath + "/") and is_h5(f):
91
- _hf_download(repo_id, f, cache_dir)
92
- if not local_dir.exists():
93
- raise FileNotFoundError(f"Directory {local_dir} not found after download.")
94
- return local_dir
95
- # File
96
- elif any(f == subpath for f in files) and is_h5(subpath):
97
- _hf_download(repo_id, subpath, cache_dir)
98
- local_file = snapshot_dir / subpath
99
- if not local_file.exists():
100
- raise FileNotFoundError(f"File {local_file} not found after download.")
101
- return local_file
104
+ downloaded_files = _download_files_in_path(repo_id, files, subpath + "/", cache_dir)
105
+ if not downloaded_files:
106
+ raise FileNotFoundError(f"No files found in directory {subpath}")
107
+
108
+ snapshot_dir = _get_snapshot_dir_from_downloaded_file(downloaded_files[0])
109
+ return snapshot_dir / subpath
110
+
111
+ # File case
112
+ elif subpath in files:
113
+ downloaded_file = _hf_download(repo_id, subpath, cache_dir)
114
+ return Path(downloaded_file)
102
115
  else:
103
116
  raise FileNotFoundError(f"{subpath} not found in {repo_id}")
104
117
  else:
105
- # All .h5/.hdf5 files in repo
106
- for f in files:
107
- if is_h5(f):
108
- _hf_download(repo_id, f, cache_dir)
109
- return snapshot_dir
118
+ # All files in repo
119
+ downloaded_files = _download_files_in_path(repo_id, files, None, cache_dir)
120
+ if not downloaded_files:
121
+ raise FileNotFoundError(f"No files found in repository {repo_id}")
122
+
123
+ return _get_snapshot_dir_from_downloaded_file(downloaded_files[0])
zea/datapaths.py CHANGED
@@ -257,13 +257,13 @@ def set_data_paths(
257
257
 
258
258
  Returns:
259
259
  dict: Absolute paths to location of data. Stores the following parameters:
260
- ``data_root``, ``repo_root``, ``output``, ``system``, ``username``, ``hostname``
260
+ ``data_root``, ``zea_root``, ``output``, ``system``, ``username``, ``hostname``
261
261
 
262
262
  """
263
263
  username = getpass.getuser()
264
264
  system = platform.system().lower()
265
265
  hostname = socket.gethostname()
266
- repo_root = importlib.resources.files("zea") # ultrasound-toolbox/zea
266
+ zea_root = importlib.resources.files("zea")
267
267
 
268
268
  # If user_config is None, use the default users.yaml file
269
269
  if isinstance(user_config, type(None)):
@@ -304,7 +304,7 @@ def set_data_paths(
304
304
 
305
305
  data_path = {
306
306
  "data_root": Path(data_root),
307
- "repo_root": repo_root,
307
+ "zea_root": zea_root,
308
308
  "output": Path(output),
309
309
  "system": system,
310
310
  "username": username,
zea/display.py CHANGED
@@ -130,7 +130,7 @@ def scan_convert_2d(
130
130
 
131
131
  Returns:
132
132
  ndarray: The scan-converted 2D ultrasound image in Cartesian coordinates.
133
- Has dimensions (n_z, n_x). Coordinates outside the input image
133
+ Has dimensions (grid_size_z, grid_size_x). Coordinates outside the input image
134
134
  ranges are filled with NaNs.
135
135
  parameters (dict): A dictionary containing information about the scan conversion.
136
136
  Contains the resolution, x, and z limits, rho and theta ranges.
@@ -269,7 +269,7 @@ def scan_convert_3d(
269
269
 
270
270
  Returns:
271
271
  ndarray: The scan-converted 3D ultrasound image in Cartesian coordinates.
272
- Has dimensions (n_z, n_x, n_y). Coordinates outside the input image
272
+ Has dimensions (grid_size_z, grid_size_x, n_y). Coordinates outside the input image
273
273
  ranges are filled with NaNs.
274
274
  parameters (dict): A dictionary containing information about the scan conversion.
275
275
  Contains the resolution, x, y, and z limits, rho, theta, and phi ranges.