zea 0.0.1__py3-none-any.whl → 0.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
zea/__init__.py CHANGED
@@ -7,7 +7,7 @@ from . import log
7
7
 
8
8
  # dynamically add __version__ attribute (see pyproject.toml)
9
9
  # __version__ = __import__("importlib.metadata").metadata.version(__package__)
10
- __version__ = "0.0.1"
10
+ __version__ = "0.0.3"
11
11
 
12
12
 
13
13
  def setup():
zea/agent/selection.py CHANGED
@@ -92,6 +92,7 @@ class GreedyEntropy(LinesActionModel):
92
92
  mean: float = 0,
93
93
  std_dev: float = 1,
94
94
  num_lines_to_update: int = 5,
95
+ entropy_sigma: float = 1.0,
95
96
  ):
96
97
  """Initialize the GreedyEntropy action selection model.
97
98
 
@@ -104,6 +105,8 @@ class GreedyEntropy(LinesActionModel):
104
105
  std_dev (float, optional): The standard deviation of the RBF. Defaults to 1.
105
106
  num_lines_to_update (int, optional): The number of lines around the selected line
106
107
  to update. Must be odd.
108
+ entropy_sigma (float, optional): The standard deviation of the Gaussian
109
+ Mixture components used to approximate the posterior.
107
110
  """
108
111
  super().__init__(n_actions, n_possible_actions, img_width, img_height)
109
112
 
@@ -124,7 +127,7 @@ class GreedyEntropy(LinesActionModel):
124
127
  self.num_lines_to_update,
125
128
  )
126
129
  self.upside_down_gaussian = upside_down_gaussian(points_to_evaluate)
127
- self.entropy_sigma = 1
130
+ self.entropy_sigma = entropy_sigma
128
131
 
129
132
  @staticmethod
130
133
  def compute_pairwise_pixel_gaussian_error(
@@ -157,15 +160,18 @@ class GreedyEntropy(LinesActionModel):
157
160
  # This way we can just sum across the height axis and get the entropy
158
161
  # for each pixel in a given line
159
162
  batch_size, n_particles, _, height, _ = gaussian_error_per_pixel_i_j.shape
160
- gaussian_error_per_pixel_stacked = ops.reshape(
161
- gaussian_error_per_pixel_i_j,
162
- [
163
- batch_size,
164
- n_particles,
165
- n_particles,
166
- height * stack_n_cols,
167
- n_possible_actions,
168
- ],
163
+ gaussian_error_per_pixel_stacked = ops.transpose(
164
+ ops.reshape(
165
+ ops.transpose(gaussian_error_per_pixel_i_j, (0, 1, 2, 4, 3)),
166
+ [
167
+ batch_size,
168
+ n_particles,
169
+ n_particles,
170
+ n_possible_actions,
171
+ height * stack_n_cols,
172
+ ],
173
+ ),
174
+ (0, 1, 2, 4, 3),
169
175
  )
170
176
  # [n_particles, n_particles, batch, height, width]
171
177
  return gaussian_error_per_pixel_stacked
@@ -428,10 +434,15 @@ class CovarianceSamplingLines(LinesActionModel):
428
434
  generation. Defaults to None.
429
435
 
430
436
  Returns:
431
- Tensor: The mask of shape (batch_size, img_size, img_size)
437
+ Tuple[Tensor, Tensor]:
438
+ - Newly selected lines as k-hot vectors, shaped (batch_size, n_possible_actions)
439
+ - Masks of shape (batch_size, img_height, img_width)
432
440
  """
433
441
  batch_size, n_particles, rows, _ = ops.shape(particles)
434
442
 
443
+ # [batch_size, rows, cols, n_particles]
444
+ particles = ops.transpose(particles, (0, 2, 3, 1))
445
+
435
446
  # [batch_size, rows * stack_n_cols, n_possible_actions, n_particles]
436
447
  shape = [
437
448
  batch_size,
@@ -441,7 +452,7 @@ class CovarianceSamplingLines(LinesActionModel):
441
452
  ]
442
453
  particles = ops.reshape(particles, shape)
443
454
 
444
- # [batch_size, rows, n_possible_actions, n_possible_actions]
455
+ # [batch_size, rows * stack_n_cols, n_possible_actions, n_possible_actions]
445
456
  cov_matrix = tensor_ops.batch_cov(particles)
446
457
 
447
458
  # Sum over the row dimension [batch_size, n_possible_actions, n_possible_actions]
zea/backend/__init__.py CHANGED
@@ -1,4 +1,4 @@
1
- """Backend subpackage for ``zea``.
1
+ """Backend-specific utilities.
2
2
 
3
3
  This subpackage provides backend-specific utilities for the ``zea`` library. Most backend logic is handled by Keras 3, but a few features require custom wrappers to ensure compatibility and performance across JAX, TensorFlow, and PyTorch.
4
4
 
@@ -9,7 +9,7 @@ Key Features
9
9
  ------------
10
10
 
11
11
  - **JIT Compilation** (:func:`zea.backend.jit`):
12
- Provides a unified interface for just-in-time (JIT) compilation of functions, dispatching to the appropriate backend (JAX or TensorFlow) as needed. This enables accelerated execution of computationally intensive routines.
12
+ Provides a unified interface for just-in-time (JIT) compilation of functions, dispatching to the appropriate backend (JAX or TensorFlow) as needed. This enables accelerated execution of computationally intensive routines. Note that jit compilation is not yet supported when using the `torch` backend.
13
13
 
14
14
  - **Automatic Differentiation** (:class:`zea.backend.AutoGrad`):
15
15
  Offers a backend-agnostic wrapper for automatic differentiation, allowing gradient computation regardless of the underlying ML library.
@@ -108,7 +108,9 @@ def _jit_compile(func, jax=True, tensorflow=True, **kwargs):
108
108
  return func
109
109
  else:
110
110
  log.warning(
111
- f"Unsupported backend: {backend}. Supported backends are 'tensorflow' and 'jax'."
111
+ f"JIT compilation not currently supported for backend {backend}. "
112
+ "Supported backends are 'tensorflow' and 'jax'."
112
113
  )
114
+ log.warning("Initialize zea.Pipeline with jit_options=None to suppress this warning.")
113
115
  log.warning("Falling back to non-compiled mode.")
114
116
  return func
zea/beamform/__init__.py CHANGED
@@ -17,4 +17,4 @@ see the pipeline example notebook: :doc:`../notebooks/pipeline/zea_pipeline_exam
17
17
 
18
18
  """
19
19
 
20
- from . import beamformer, delays, lens_correction, pfield, pixelgrid
20
+ from . import beamformer, delays, lens_correction, pfield, phantoms, pixelgrid
@@ -1,5 +1,6 @@
1
1
  """Main beamforming functions for ultrasound imaging."""
2
2
 
3
+ import keras
3
4
  import numpy as np
4
5
  from keras import ops
5
6
 
@@ -7,6 +8,45 @@ from zea.beamform.lens_correction import calculate_lens_corrected_delays
7
8
  from zea.tensor_ops import safe_vectorize
8
9
 
9
10
 
11
+ def fnum_window_fn_rect(normalized_angle):
12
+ """Rectangular window function for f-number masking."""
13
+ return ops.where(normalized_angle <= 1.0, 1.0, 0.0)
14
+
15
+
16
+ def fnum_window_fn_hann(normalized_angle):
17
+ """Hann window function for f-number masking."""
18
+ # Use a Hann window function to smoothly transition the mask
19
+ return ops.where(
20
+ normalized_angle <= 1.0,
21
+ 0.5 * (1 + ops.cos(np.pi * normalized_angle)),
22
+ 0.0,
23
+ )
24
+
25
+
26
+ def fnum_window_fn_tukey(normalized_angle, alpha=0.5):
27
+ """Tukey window function for f-number masking.
28
+
29
+ Args:
30
+ normalized_angle (ops.Tensor): Normalized angle values in the range [0, 1].
31
+ alpha (float, optional): The alpha parameter for the Tukey window. 0.0 corresponds to a
32
+ rectangular window, 1.0 corresponds to a Hann window. Defaults to 0.5.
33
+ """
34
+ # Use a Tukey window function to smoothly transition the mask
35
+ normalized_angle = ops.clip(ops.abs(normalized_angle), 0.0, 1.0)
36
+
37
+ beta = 1.0 - alpha
38
+
39
+ return ops.where(
40
+ normalized_angle < beta,
41
+ 1.0,
42
+ ops.where(
43
+ normalized_angle < 1.0,
44
+ 0.5 * (1 + ops.cos(np.pi * (normalized_angle - beta) / (ops.abs(alpha) + 1e-6))),
45
+ 0.0,
46
+ ),
47
+ )
48
+
49
+
10
50
  def tof_correction(
11
51
  data,
12
52
  flatgrid,
@@ -19,11 +59,12 @@ def tof_correction(
19
59
  demodulation_frequency,
20
60
  fnum,
21
61
  angles,
22
- vfocus,
62
+ focus_distances,
23
63
  apply_phase_rotation=False,
24
64
  apply_lens_correction=False,
25
65
  lens_thickness=1e-3,
26
66
  lens_sound_speed=1000,
67
+ fnum_window_fn=fnum_window_fn_rect,
27
68
  ):
28
69
  """Time-of-flight correction for a flat grid.
29
70
 
@@ -44,7 +85,7 @@ def tof_correction(
44
85
  fnum (int, optional): Focus number. Defaults to 1.
45
86
  angles (ops.Tensor): The angles of the plane waves in radians of shape
46
87
  `(n_tx,)`
47
- vfocus (ops.Tensor): The focus distance of shape `(n_tx,)`
88
+ focus_distances (ops.Tensor): The focus distance of shape `(n_tx,)`
48
89
  apply_phase_rotation (bool, optional): Whether to apply phase rotation to
49
90
  time-of-flights. Defaults to False.
50
91
  apply_lens_correction (bool, optional): Whether to apply lens correction to
@@ -54,6 +95,9 @@ def tof_correction(
54
95
  lens correction. Defaults to 1e-3.
55
96
  lens_sound_speed (float, optional): Speed of sound in the lens in m/s. Used
56
97
  for lens correction Defaults to 1000.
98
+ fnum_window_fn (callable, optional): F-number function to define the transition from
99
+ straight in front of the element (fn(0.0)) to the largest angle within the f-number cone
100
+ (fn(1.0)). The function should be zero for fn(x>1.0).
57
101
 
58
102
  Returns:
59
103
  (ops.Tensor): time-of-flight corrected data
@@ -90,7 +134,7 @@ def tof_correction(
90
134
  sound_speed,
91
135
  n_tx,
92
136
  n_el,
93
- vfocus,
137
+ focus_distances,
94
138
  angles,
95
139
  lens_thickness=lens_thickness,
96
140
  lens_sound_speed=lens_sound_speed,
@@ -100,7 +144,7 @@ def tof_correction(
100
144
  mask = ops.cond(
101
145
  fnum == 0,
102
146
  lambda: ops.ones((n_pix, n_el, 1)),
103
- lambda: apod_mask(flatgrid, probe_geometry, fnum),
147
+ lambda: fnumber_mask(flatgrid, probe_geometry, fnum, fnum_window_fn=fnum_window_fn),
104
148
  )
105
149
 
106
150
  def _apply_delays(data_tx, txdel):
@@ -408,64 +452,48 @@ def distance_Tx_generic(
408
452
  return dist
409
453
 
410
454
 
411
- def apod_mask(grid, probe_geometry, f_number):
455
+ def fnumber_mask(flatgrid, probe_geometry, f_number, fnum_window_fn):
412
456
  """Apodization mask for the receive beamformer.
413
457
 
414
- Computes a binary mask to disregard pixels outside of the vision cone of a
458
+ Computes a mask to disregard pixels outside of the vision cone of a
415
459
  transducer element. Transducer elements can only accurately measure
416
460
  signals within some range of incidence angles. Waves coming in from the
417
461
  side do not register correctly leading to a worse image.
418
462
 
419
463
  Args:
420
- grid (ops.Tensor): The flattened image grid `(n_pix, 3)`.
464
+ flatgrid (ops.Tensor): The flattened image grid `(n_pix, 3)`.
421
465
  probe_geometry (ops.Tensor): The transducer element positions of shape
422
466
  `(n_el, 3)`.
423
467
  f_number (int): The receive f-number. Set to zero to not use masking and
424
468
  return 1. (The f-number is the ratio between distance from the transducer
425
469
  and the size of the aperture below which transducer elements contribute to
426
470
  the signal for a pixel.).
471
+ fnum_window_fn (callable): F-number function to define the transition from
472
+ straight in front of the element (fn(0.0)) to the largest angle within the f-number cone
473
+ (fn(1.0)). The function should be zero for fn(x>1.0).
474
+
427
475
 
428
476
  Returns:
429
477
  Tensor: Mask of shape `(n_pix, n_el, 1)`
430
478
  """
431
- n_pix = ops.shape(grid)[0]
432
- n_el = ops.shape(probe_geometry)[0]
433
-
434
- # Get the depth of every pixel
435
- z_pixel = grid[:, 2]
436
- # Get the lateral location of each pixel
437
- x_pixel = grid[:, 0]
438
- # Get the lateral location of each element
439
- x_element = ops.cast(probe_geometry[:, 0], dtype="float32")
440
-
441
- # Compute the aperture size for every pixel
442
- # The f-number is by definition f=z/aperture
443
- aperture = z_pixel / f_number
444
-
445
- # Use matrix multiplication to expand aperture tensor, x_pixel tensor, and
446
- # x_element tensor to shape (n_pix, n_el)
447
- ones_aperture = ops.ones(
448
- (1, n_el), dtype=ops.dtype(aperture)
449
- ) # getting error here? pip install -U keras ;)
450
- ones_x_pixel = ops.ones((1, n_el), dtype=ops.dtype(x_pixel))
451
- ones_x_element = ops.ones((n_pix, 1), dtype=ops.dtype(x_element))
452
-
453
- aperture = ops.matmul(aperture[..., None], ones_aperture)
454
- expanded_x_pixel = ops.matmul(x_pixel[..., None], ones_x_pixel)
455
- expanded_x_element = ops.matmul(ones_x_element, x_element[None])
456
-
457
- # Compute the lateral distance between elements and pixels
458
- # Of shape (n_pix, n_el)
459
- distance = ops.abs(expanded_x_pixel - expanded_x_element)
460
-
461
- # Compute binary mask for which the lateral pixel distance is less than
462
- # half
463
- # the aperture i.e. where the pixel lies within the vision cone of the
464
- # element
465
- mask = distance <= aperture / 2
466
- mask = ops.cast(mask, "float32")
467
-
468
- # Add dummy dimension for RF/IQ channel channel
479
+
480
+ grid_relative_to_probe = flatgrid[:, None] - probe_geometry[None]
481
+
482
+ grid_relative_to_probe_norm = ops.linalg.norm(grid_relative_to_probe, axis=-1)
483
+
484
+ grid_relative_to_probe_z = grid_relative_to_probe[..., 2] / (grid_relative_to_probe_norm + 1e-6)
485
+
486
+ alpha = ops.arccos(grid_relative_to_probe_z)
487
+
488
+ # The f-number is fnum = z/aperture = 1/(2 * tan(alpha))
489
+ # Rearranging gives us alpha = arctan(1/(2 * fnum))
490
+ # We can use this to compute the maximum angle alpha that is allowed
491
+ max_alpha = ops.arctan(1 / (2 * f_number + keras.backend.epsilon()))
492
+
493
+ normalized_angle = alpha / max_alpha
494
+ mask = fnum_window_fn(normalized_angle)
495
+
496
+ # Add dummy channel dimension
469
497
  mask = mask[..., None]
470
498
 
471
499
  return mask
zea/beamform/pfield.py CHANGED
@@ -60,7 +60,8 @@ def compute_pfield(
60
60
  n_el (int): Number of elements in the probe.
61
61
  probe_geometry (array): Geometry of the probe elements.
62
62
  tx_apodizations (array): Transmit apodization values.
63
- grid (array): Grid points where the pressure field is computed of shape (Nz, Nx, 3).
63
+ grid (array): Grid points where the pressure field is computed
64
+ of shape (grid_size_z, grid_size_x, 3).
64
65
  t0_delays (array): Transmit delays for each transmit event.
65
66
  frequency_step (int, optional): Frequency step. Default is 4.
66
67
  Higher is faster but less accurate.
@@ -78,7 +79,8 @@ def compute_pfield(
78
79
  verbose (bool, optional): Whether to print progress.
79
80
 
80
81
  Returns:
81
- ops.array: The (normalized) pressure field (across tx events) of shape (n_tx, Nz, Nx).
82
+ ops.array: The (normalized) pressure field (across tx events)
83
+ of shape (n_tx, grid_size_z, grid_size_x).
82
84
  """
83
85
  # medium params
84
86
  alpha_db = 0 # currently we ignore attenuation in the compounding
@@ -293,7 +295,8 @@ def normalize_pressure_field(pfield, alpha: float = 1.0, percentile: float = 10.
293
295
  Normalize the input array of intensities by zeroing out values below a given percentile.
294
296
 
295
297
  Args:
296
- pfield (array): The unnormalized pressure field array of shape (n_tx, Nz, Nx).
298
+ pfield (array): The unnormalized pressure field array
299
+ of shape (n_tx, grid_size_z, grid_size_x).
297
300
  alpha (float, optional): Exponent to 'sharpen or smooth' the weighting.
298
301
  Higher values result in sharper weighting. Default is 1.0.
299
302
  percentile (int, optional): minimum percentile threshold to keep in the weighting.
@@ -0,0 +1,43 @@
1
+ import numpy as np
2
+
3
+
4
+ def fish():
5
+ """Returns a scatterer phantom for ultrasound simulation tests.
6
+
7
+ Returns:
8
+ ndarray: The scatterer positions of shape (n_scat, 3).
9
+ """
10
+ # The size is the height of the fish
11
+ size = 11e-3
12
+ z_offset = 2.0 * size
13
+
14
+ # See https://en.wikipedia.org/wiki/Fish_curve
15
+ def fish_curve(t, size=1):
16
+ x = size * (np.cos(t) - np.sin(t) ** 2 / np.sqrt(2))
17
+ y = size * np.cos(t) * np.sin(t)
18
+ return x, y
19
+
20
+ scat_x, scat_z = fish_curve(np.linspace(0, 2 * np.pi, 100), size=size)
21
+
22
+ scat_x = np.concatenate(
23
+ [
24
+ scat_x,
25
+ np.array([size * 0.7]),
26
+ np.array([size * 1.1]),
27
+ np.array([size * 1.4]),
28
+ np.array([size * 1.2]),
29
+ ]
30
+ )
31
+ scat_y = np.zeros_like(scat_x)
32
+ scat_z = np.concatenate(
33
+ [
34
+ scat_z,
35
+ np.array([-size * 0.1]),
36
+ np.array([-size * 0.25]),
37
+ np.array([-size * 0.6]),
38
+ np.array([-size * 1.0]),
39
+ ]
40
+ )
41
+
42
+ scat_z += z_offset
43
+ return np.stack([scat_x, scat_y, scat_z], axis=1)
zea/beamform/pixelgrid.py CHANGED
@@ -7,101 +7,61 @@ from zea import log
7
7
  eps = 1e-10
8
8
 
9
9
 
10
- def get_grid(
11
- xlims,
12
- zlims,
13
- Nx: int,
14
- Nz: int,
15
- sound_speed,
16
- center_frequency,
17
- pixels_per_wavelength,
18
- verbose=False,
19
- ):
20
- """Creates a pixelgrid based on scan class parameters."""
21
-
22
- if Nx and Nz:
23
- grid = cartesian_pixel_grid(xlims, zlims, Nx=int(Nx), Nz=int(Nz))
24
- else:
25
- wvln = sound_speed / center_frequency
26
- dx = wvln / pixels_per_wavelength
27
- dz = dx
28
- grid = cartesian_pixel_grid(xlims, zlims, dx=dx, dz=dz)
29
- if verbose:
30
- print(
31
- f"Pixelgrid was set automatically to Nx: {grid.shape[1]}, Nz: {grid.shape[0]}, "
32
- f"using {pixels_per_wavelength} pixels per wavelength."
33
- )
34
- return grid
35
-
36
-
37
10
  def check_for_aliasing(scan):
38
11
  """Checks if the scan class parameters will cause spatial aliasing due to a too low pixel
39
12
  density. If so, a warning is printed with a suggestion to increase the pixel density by either
40
13
  increasing the number of pixels, or decreasing the pixel spacing, depending on which parameter
41
14
  was set by the user."""
42
- wvln = scan.sound_speed / scan.center_frequency
43
- dx = wvln / scan.pixels_per_wavelength
44
- dz = dx
45
-
46
15
  width = scan.xlims[1] - scan.xlims[0]
47
16
  depth = scan.zlims[1] - scan.zlims[0]
48
-
49
- if scan.Nx and scan.Nz:
50
- if width / scan.Nx > wvln / 2:
51
- log.warning(
52
- f"width/Nx = {width / scan.Nx:.7f} < wvln/2 = {wvln / 2}. "
53
- f"Consider increasing scan.Nx to {int(np.ceil(width / (wvln / 2)))} or more."
54
- )
55
- if depth / scan.Nz > wvln / 2:
56
- log.warning(
57
- f"depth/Nz = {depth / scan.Nz:.7f} < wvln/2 = {wvln / 2:.7f}. "
58
- f"Consider increasing scan.Nz to {int(np.ceil(depth / (wvln / 2)))} or more."
59
- )
60
- else:
61
- if dx > wvln / 2:
62
- log.warning(
63
- f"dx = {dx:.7f} > wvln/2 = {wvln / 2:.7f}. "
64
- f"Consider increasing scan.pixels_per_wavelength to 2 or more"
65
- )
66
- if dz > wvln / 2:
67
- log.warning(
68
- f"dz = {dz:.7f} > wvln/2 = {wvln / 2:.7f}. "
69
- f"Consider increasing scan.pixels_per_wavelength to 2 or more"
70
- )
71
-
72
-
73
- def cartesian_pixel_grid(xlims, zlims, Nx=None, Nz=None, dx=None, dz=None):
17
+ wvln = scan.wavelength
18
+
19
+ if width / scan.grid_size_x > wvln / 2:
20
+ log.warning(
21
+ f"width/grid_size_x = {width / scan.grid_size_x:.7f} < wavelength/2 = {wvln / 2}. "
22
+ f"Consider either increasing scan.grid_size_x to {int(np.ceil(width / (wvln / 2)))} "
23
+ "or more, or increasing scan.pixels_per_wavelength to 2 or more."
24
+ )
25
+ if depth / scan.grid_size_z > wvln / 2:
26
+ log.warning(
27
+ f"depth/grid_size_z = {depth / scan.grid_size_z:.7f} < wavelength/2 = {wvln / 2:.7f}. "
28
+ f"Consider either increasing scan.grid_size_z to {int(np.ceil(depth / (wvln / 2)))} "
29
+ "or more, or increasing scan.pixels_per_wavelength to 2 or more."
30
+ )
31
+
32
+
33
+ def cartesian_pixel_grid(xlims, zlims, grid_size_x=None, grid_size_z=None, dx=None, dz=None):
74
34
  """Generate a Cartesian pixel grid based on input parameters.
75
35
 
76
36
  Args:
77
37
  xlims (tuple): Azimuthal limits of pixel grid ([xmin, xmax])
78
38
  zlims (tuple): Depth limits of pixel grid ([zmin, zmax])
79
- Nx (int): Number of azimuthal pixels, overrides dx and dz parameters
80
- Nz (int): Number of depth pixels, overrides dx and dz parameters
39
+ grid_size_x (int): Number of azimuthal pixels, overrides dx and dz parameters
40
+ grid_size_z (int): Number of depth pixels, overrides dx and dz parameters
81
41
  dx (float): Pixel spacing in azimuth
82
42
  dz (float): Pixel spacing in depth
83
43
 
84
44
  Raises:
85
- ValueError: Either Nx and Nz or dx and dz must be defined.
45
+ ValueError: Either grid_size_x and grid_size_z or dx and dz must be defined.
86
46
 
87
47
  Returns:
88
- grid (np.ndarray): Pixel grid of size (nz, nx, 3) in
48
+ grid (np.ndarray): Pixel grid of size (grid_size_z, nx, 3) in
89
49
  Cartesian coordinates (x, y, z)
90
50
  """
91
- assert (bool(Nx) and bool(Nz)) ^ (bool(dx) and bool(dz)), (
92
- "Either Nx and Nz or dx and dz must be defined."
51
+ assert (bool(grid_size_x) and bool(grid_size_z)) ^ (bool(dx) and bool(dz)), (
52
+ "Either grid_size_x and grid_size_z or dx and dz must be defined."
93
53
  )
94
54
 
95
55
  # Determine the grid spacing
96
- if Nx is not None and Nz is not None:
97
- x = np.linspace(xlims[0], xlims[1] + eps, Nx)
98
- z = np.linspace(zlims[0], zlims[1] + eps, Nz)
56
+ if grid_size_x is not None and grid_size_z is not None:
57
+ x = np.linspace(xlims[0], xlims[1] + eps, grid_size_x)
58
+ z = np.linspace(zlims[0], zlims[1] + eps, grid_size_z)
99
59
  elif dx is not None and dz is not None:
100
60
  sign = np.sign(xlims[1] - xlims[0])
101
61
  x = np.arange(xlims[0], xlims[1] + eps, sign * dx)
102
62
  z = np.arange(zlims[0], zlims[1] + eps, sign * dz)
103
63
  else:
104
- raise ValueError("Either Nx and Nz or dx and dz must be defined.")
64
+ raise ValueError("Either grid_size_x and grid_size_z or dx and dz must be defined.")
105
65
 
106
66
  # Create the pixel grid
107
67
  z_grid, x_grid = np.meshgrid(z, x, indexing="ij")
@@ -130,7 +90,7 @@ def radial_pixel_grid(rlims, dr, oris, dirs):
130
90
  Cartesian coordinates (x, y, z), with nr being the number of radial pixels.
131
91
  """
132
92
  # Get focusing positions in rho-theta coordinates
133
- r = np.arange(rlims[0], rlims[1] + eps, dr) # Depth rho
93
+ r = np.arange(rlims[0], rlims[1], dr) # Depth rho
134
94
  t = dirs[:, 0] # Use azimuthal angle theta (ignore elevation angle)
135
95
  tt, rr = np.meshgrid(t, r, indexing="ij")
136
96
 
@@ -140,3 +100,32 @@ def radial_pixel_grid(rlims, dr, oris, dirs):
140
100
  yy = 0 * xx
141
101
  grid = np.stack((xx, yy, zz), axis=-1)
142
102
  return grid
103
+
104
+
105
+ def polar_pixel_grid(polar_limits, zlims, num_radial_pixels: int, num_polar_pixels: int):
106
+ """Generate a polar grid.
107
+
108
+ Uses radial_pixel_grid but based on parameters that are present in the scan class.
109
+
110
+ Args:
111
+ polar_limits (tuple): Polar limits of pixel grid ([polar_min, polar_max])
112
+ zlims (tuple): Depth limits of pixel grid ([zmin, zmax])
113
+ num_radial_pixels (int, optional): Number of depth pixels.
114
+ num_polar_pixels (int, optional): Number of polar pixels.
115
+
116
+ Returns:
117
+ grid (np.ndarray): Pixel grid of size (num_radial_pixels, num_polar_pixels, 3)
118
+ in Cartesian coordinates (x, y, z)
119
+ """
120
+ assert len(polar_limits) == 2, "polar_limits must be a tuple of length 2."
121
+ assert len(zlims) == 2, "zlims must be a tuple of length 2."
122
+
123
+ dr = (zlims[1] - zlims[0]) / num_radial_pixels
124
+
125
+ oris = np.array([0, 0, 0])
126
+ oris = np.tile(oris, (num_polar_pixels, 1))
127
+ dirs_az = np.linspace(*polar_limits, num_polar_pixels)
128
+
129
+ dirs_el = np.zeros(num_polar_pixels)
130
+ dirs = np.vstack((dirs_az, dirs_el)).T
131
+ return radial_pixel_grid(zlims, dr, oris, dirs).transpose(1, 0, 2)
zea/config.py CHANGED
@@ -47,8 +47,9 @@ import yaml
47
47
  from huggingface_hub import hf_hub_download
48
48
 
49
49
  from zea import log
50
+ from zea.data.preset_utils import HF_PREFIX, _hf_resolve_path
50
51
  from zea.internal.config.validation import config_schema
51
- from zea.internal.core import object_to_tensor
52
+ from zea.internal.core import dict_to_tensor
52
53
 
53
54
 
54
55
  class Config(dict):
@@ -430,8 +431,22 @@ class Config(dict):
430
431
  v._recursive_setattr(set_key, set_value)
431
432
 
432
433
  @classmethod
433
- def from_yaml(cls, path, **kwargs):
434
- """Load config object from yaml file"""
434
+ def from_path(cls, path, **kwargs):
435
+ """Load config object from a file path.
436
+
437
+ Args:
438
+ path (str or Path): The path to the config file.
439
+ Can be a string or a Path object. Additionally can be a string with
440
+ the prefix 'hf://', in which case it will be resolved to a
441
+ huggingface path.
442
+
443
+ Returns:
444
+ Config: config object.
445
+ """
446
+ if str(path).startswith(HF_PREFIX):
447
+ path = _hf_resolve_path(str(path))
448
+ if isinstance(path, str):
449
+ path = Path(path)
435
450
  return _load_config_from_yaml(path, config_class=cls, **kwargs)
436
451
 
437
452
  @classmethod
@@ -460,9 +475,14 @@ class Config(dict):
460
475
  local_path = hf_hub_download(repo_id, path, **kwargs)
461
476
  return _load_config_from_yaml(local_path, config_class=cls)
462
477
 
463
- def to_tensor(self):
478
+ @classmethod
479
+ def from_yaml(cls, path, **kwargs):
480
+ """Load config object from yaml file."""
481
+ return cls.from_path(path, **kwargs)
482
+
483
+ def to_tensor(self, keep_as_is=None):
464
484
  """Convert the attributes in the object to keras tensors"""
465
- return object_to_tensor(self)
485
+ return dict_to_tensor(self.serialize(), keep_as_is=keep_as_is)
466
486
 
467
487
 
468
488
  def check_config(config: Union[dict, Config], verbose: bool = False):
zea/data/__main__.py ADDED
@@ -0,0 +1,31 @@
1
+ """Command-line interface for copying a zea.Folder to a new location.
2
+
3
+ Usage:
4
+ python -m zea.data <source_folder> <destination_folder> <key>
5
+ """
6
+
7
+ import argparse
8
+
9
+ from zea import Folder
10
+
11
+
12
+ def main():
13
+ parser = argparse.ArgumentParser(description="Copy a zea.Folder to a new location.")
14
+ parser.add_argument("src", help="Source folder path")
15
+ parser.add_argument("dst", help="Destination folder path")
16
+ parser.add_argument("key", help="Key to access in the hdf5 files")
17
+ parser.add_argument(
18
+ "--mode",
19
+ default="a",
20
+ choices=["a", "w", "r+", "x"],
21
+ help="Mode in which to open the destination files (default: 'a')",
22
+ )
23
+
24
+ args = parser.parse_args()
25
+
26
+ src_folder = Folder(args.src, args.key, validate=False)
27
+ src_folder.copy(args.dst, args.key, mode=args.mode)
28
+
29
+
30
+ if __name__ == "__main__":
31
+ main()