zea 0.0.2__py3-none-any.whl → 0.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- zea/__init__.py +1 -1
- zea/agent/selection.py +23 -12
- zea/backend/__init__.py +5 -3
- zea/beamform/__init__.py +1 -1
- zea/beamform/beamformer.py +5 -4
- zea/beamform/pfield.py +6 -3
- zea/beamform/pixelgrid.py +29 -28
- zea/config.py +25 -5
- zea/data/data_format.py +16 -14
- zea/data/file.py +3 -18
- zea/data/preset_utils.py +59 -45
- zea/datapaths.py +3 -3
- zea/display.py +2 -2
- zea/interface.py +1 -1
- zea/internal/checks.py +17 -8
- zea/internal/config/parameters.py +2 -2
- zea/internal/config/validation.py +2 -2
- zea/internal/core.py +10 -41
- zea/internal/parameters.py +204 -163
- zea/internal/registry.py +16 -8
- zea/io_lib.py +2 -7
- zea/models/base.py +1 -0
- zea/ops.py +159 -113
- zea/probes.py +23 -5
- zea/scan.py +97 -68
- zea/tensor_ops.py +3 -3
- zea/tools/selection_tool.py +1 -1
- zea/visualize.py +5 -3
- {zea-0.0.2.dist-info → zea-0.0.3.dist-info}/METADATA +13 -7
- {zea-0.0.2.dist-info → zea-0.0.3.dist-info}/RECORD +33 -33
- {zea-0.0.2.dist-info → zea-0.0.3.dist-info}/LICENSE +0 -0
- {zea-0.0.2.dist-info → zea-0.0.3.dist-info}/WHEEL +0 -0
- {zea-0.0.2.dist-info → zea-0.0.3.dist-info}/entry_points.txt +0 -0
zea/__init__.py
CHANGED
zea/agent/selection.py
CHANGED
|
@@ -92,6 +92,7 @@ class GreedyEntropy(LinesActionModel):
|
|
|
92
92
|
mean: float = 0,
|
|
93
93
|
std_dev: float = 1,
|
|
94
94
|
num_lines_to_update: int = 5,
|
|
95
|
+
entropy_sigma: float = 1.0,
|
|
95
96
|
):
|
|
96
97
|
"""Initialize the GreedyEntropy action selection model.
|
|
97
98
|
|
|
@@ -104,6 +105,8 @@ class GreedyEntropy(LinesActionModel):
|
|
|
104
105
|
std_dev (float, optional): The standard deviation of the RBF. Defaults to 1.
|
|
105
106
|
num_lines_to_update (int, optional): The number of lines around the selected line
|
|
106
107
|
to update. Must be odd.
|
|
108
|
+
entropy_sigma (float, optional): The standard deviation of the Gaussian
|
|
109
|
+
Mixture components used to approximate the posterior.
|
|
107
110
|
"""
|
|
108
111
|
super().__init__(n_actions, n_possible_actions, img_width, img_height)
|
|
109
112
|
|
|
@@ -124,7 +127,7 @@ class GreedyEntropy(LinesActionModel):
|
|
|
124
127
|
self.num_lines_to_update,
|
|
125
128
|
)
|
|
126
129
|
self.upside_down_gaussian = upside_down_gaussian(points_to_evaluate)
|
|
127
|
-
self.entropy_sigma =
|
|
130
|
+
self.entropy_sigma = entropy_sigma
|
|
128
131
|
|
|
129
132
|
@staticmethod
|
|
130
133
|
def compute_pairwise_pixel_gaussian_error(
|
|
@@ -157,15 +160,18 @@ class GreedyEntropy(LinesActionModel):
|
|
|
157
160
|
# This way we can just sum across the height axis and get the entropy
|
|
158
161
|
# for each pixel in a given line
|
|
159
162
|
batch_size, n_particles, _, height, _ = gaussian_error_per_pixel_i_j.shape
|
|
160
|
-
gaussian_error_per_pixel_stacked = ops.
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
163
|
+
gaussian_error_per_pixel_stacked = ops.transpose(
|
|
164
|
+
ops.reshape(
|
|
165
|
+
ops.transpose(gaussian_error_per_pixel_i_j, (0, 1, 2, 4, 3)),
|
|
166
|
+
[
|
|
167
|
+
batch_size,
|
|
168
|
+
n_particles,
|
|
169
|
+
n_particles,
|
|
170
|
+
n_possible_actions,
|
|
171
|
+
height * stack_n_cols,
|
|
172
|
+
],
|
|
173
|
+
),
|
|
174
|
+
(0, 1, 2, 4, 3),
|
|
169
175
|
)
|
|
170
176
|
# [n_particles, n_particles, batch, height, width]
|
|
171
177
|
return gaussian_error_per_pixel_stacked
|
|
@@ -428,10 +434,15 @@ class CovarianceSamplingLines(LinesActionModel):
|
|
|
428
434
|
generation. Defaults to None.
|
|
429
435
|
|
|
430
436
|
Returns:
|
|
431
|
-
Tensor
|
|
437
|
+
Tuple[Tensor, Tensor]:
|
|
438
|
+
- Newly selected lines as k-hot vectors, shaped (batch_size, n_possible_actions)
|
|
439
|
+
- Masks of shape (batch_size, img_height, img_width)
|
|
432
440
|
"""
|
|
433
441
|
batch_size, n_particles, rows, _ = ops.shape(particles)
|
|
434
442
|
|
|
443
|
+
# [batch_size, rows, cols, n_particles]
|
|
444
|
+
particles = ops.transpose(particles, (0, 2, 3, 1))
|
|
445
|
+
|
|
435
446
|
# [batch_size, rows * stack_n_cols, n_possible_actions, n_particles]
|
|
436
447
|
shape = [
|
|
437
448
|
batch_size,
|
|
@@ -441,7 +452,7 @@ class CovarianceSamplingLines(LinesActionModel):
|
|
|
441
452
|
]
|
|
442
453
|
particles = ops.reshape(particles, shape)
|
|
443
454
|
|
|
444
|
-
# [batch_size, rows, n_possible_actions, n_possible_actions]
|
|
455
|
+
# [batch_size, rows * stack_n_cols, n_possible_actions, n_possible_actions]
|
|
445
456
|
cov_matrix = tensor_ops.batch_cov(particles)
|
|
446
457
|
|
|
447
458
|
# Sum over the row dimension [batch_size, n_possible_actions, n_possible_actions]
|
zea/backend/__init__.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
"""Backend
|
|
1
|
+
"""Backend-specific utilities.
|
|
2
2
|
|
|
3
3
|
This subpackage provides backend-specific utilities for the ``zea`` library. Most backend logic is handled by Keras 3, but a few features require custom wrappers to ensure compatibility and performance across JAX, TensorFlow, and PyTorch.
|
|
4
4
|
|
|
@@ -9,7 +9,7 @@ Key Features
|
|
|
9
9
|
------------
|
|
10
10
|
|
|
11
11
|
- **JIT Compilation** (:func:`zea.backend.jit`):
|
|
12
|
-
Provides a unified interface for just-in-time (JIT) compilation of functions, dispatching to the appropriate backend (JAX or TensorFlow) as needed. This enables accelerated execution of computationally intensive routines.
|
|
12
|
+
Provides a unified interface for just-in-time (JIT) compilation of functions, dispatching to the appropriate backend (JAX or TensorFlow) as needed. This enables accelerated execution of computationally intensive routines. Note that jit compilation is not yet supported when using the `torch` backend.
|
|
13
13
|
|
|
14
14
|
- **Automatic Differentiation** (:class:`zea.backend.AutoGrad`):
|
|
15
15
|
Offers a backend-agnostic wrapper for automatic differentiation, allowing gradient computation regardless of the underlying ML library.
|
|
@@ -108,7 +108,9 @@ def _jit_compile(func, jax=True, tensorflow=True, **kwargs):
|
|
|
108
108
|
return func
|
|
109
109
|
else:
|
|
110
110
|
log.warning(
|
|
111
|
-
f"
|
|
111
|
+
f"JIT compilation not currently supported for backend {backend}. "
|
|
112
|
+
"Supported backends are 'tensorflow' and 'jax'."
|
|
112
113
|
)
|
|
114
|
+
log.warning("Initialize zea.Pipeline with jit_options=None to suppress this warning.")
|
|
113
115
|
log.warning("Falling back to non-compiled mode.")
|
|
114
116
|
return func
|
zea/beamform/__init__.py
CHANGED
zea/beamform/beamformer.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
"""Main beamforming functions for ultrasound imaging."""
|
|
2
2
|
|
|
3
|
+
import keras
|
|
3
4
|
import numpy as np
|
|
4
5
|
from keras import ops
|
|
5
6
|
|
|
@@ -58,7 +59,7 @@ def tof_correction(
|
|
|
58
59
|
demodulation_frequency,
|
|
59
60
|
fnum,
|
|
60
61
|
angles,
|
|
61
|
-
|
|
62
|
+
focus_distances,
|
|
62
63
|
apply_phase_rotation=False,
|
|
63
64
|
apply_lens_correction=False,
|
|
64
65
|
lens_thickness=1e-3,
|
|
@@ -84,7 +85,7 @@ def tof_correction(
|
|
|
84
85
|
fnum (int, optional): Focus number. Defaults to 1.
|
|
85
86
|
angles (ops.Tensor): The angles of the plane waves in radians of shape
|
|
86
87
|
`(n_tx,)`
|
|
87
|
-
|
|
88
|
+
focus_distances (ops.Tensor): The focus distance of shape `(n_tx,)`
|
|
88
89
|
apply_phase_rotation (bool, optional): Whether to apply phase rotation to
|
|
89
90
|
time-of-flights. Defaults to False.
|
|
90
91
|
apply_lens_correction (bool, optional): Whether to apply lens correction to
|
|
@@ -133,7 +134,7 @@ def tof_correction(
|
|
|
133
134
|
sound_speed,
|
|
134
135
|
n_tx,
|
|
135
136
|
n_el,
|
|
136
|
-
|
|
137
|
+
focus_distances,
|
|
137
138
|
angles,
|
|
138
139
|
lens_thickness=lens_thickness,
|
|
139
140
|
lens_sound_speed=lens_sound_speed,
|
|
@@ -487,7 +488,7 @@ def fnumber_mask(flatgrid, probe_geometry, f_number, fnum_window_fn):
|
|
|
487
488
|
# The f-number is fnum = z/aperture = 1/(2 * tan(alpha))
|
|
488
489
|
# Rearranging gives us alpha = arctan(1/(2 * fnum))
|
|
489
490
|
# We can use this to compute the maximum angle alpha that is allowed
|
|
490
|
-
max_alpha = ops.arctan(1 / (2 * f_number))
|
|
491
|
+
max_alpha = ops.arctan(1 / (2 * f_number + keras.backend.epsilon()))
|
|
491
492
|
|
|
492
493
|
normalized_angle = alpha / max_alpha
|
|
493
494
|
mask = fnum_window_fn(normalized_angle)
|
zea/beamform/pfield.py
CHANGED
|
@@ -60,7 +60,8 @@ def compute_pfield(
|
|
|
60
60
|
n_el (int): Number of elements in the probe.
|
|
61
61
|
probe_geometry (array): Geometry of the probe elements.
|
|
62
62
|
tx_apodizations (array): Transmit apodization values.
|
|
63
|
-
grid (array): Grid points where the pressure field is computed
|
|
63
|
+
grid (array): Grid points where the pressure field is computed
|
|
64
|
+
of shape (grid_size_z, grid_size_x, 3).
|
|
64
65
|
t0_delays (array): Transmit delays for each transmit event.
|
|
65
66
|
frequency_step (int, optional): Frequency step. Default is 4.
|
|
66
67
|
Higher is faster but less accurate.
|
|
@@ -78,7 +79,8 @@ def compute_pfield(
|
|
|
78
79
|
verbose (bool, optional): Whether to print progress.
|
|
79
80
|
|
|
80
81
|
Returns:
|
|
81
|
-
ops.array: The (normalized) pressure field (across tx events)
|
|
82
|
+
ops.array: The (normalized) pressure field (across tx events)
|
|
83
|
+
of shape (n_tx, grid_size_z, grid_size_x).
|
|
82
84
|
"""
|
|
83
85
|
# medium params
|
|
84
86
|
alpha_db = 0 # currently we ignore attenuation in the compounding
|
|
@@ -293,7 +295,8 @@ def normalize_pressure_field(pfield, alpha: float = 1.0, percentile: float = 10.
|
|
|
293
295
|
Normalize the input array of intensities by zeroing out values below a given percentile.
|
|
294
296
|
|
|
295
297
|
Args:
|
|
296
|
-
pfield (array): The unnormalized pressure field array
|
|
298
|
+
pfield (array): The unnormalized pressure field array
|
|
299
|
+
of shape (n_tx, grid_size_z, grid_size_x).
|
|
297
300
|
alpha (float, optional): Exponent to 'sharpen or smooth' the weighting.
|
|
298
301
|
Higher values result in sharper weighting. Default is 1.0.
|
|
299
302
|
percentile (int, optional): minimum percentile threshold to keep in the weighting.
|
zea/beamform/pixelgrid.py
CHANGED
|
@@ -16,52 +16,52 @@ def check_for_aliasing(scan):
|
|
|
16
16
|
depth = scan.zlims[1] - scan.zlims[0]
|
|
17
17
|
wvln = scan.wavelength
|
|
18
18
|
|
|
19
|
-
if width / scan.
|
|
19
|
+
if width / scan.grid_size_x > wvln / 2:
|
|
20
20
|
log.warning(
|
|
21
|
-
f"width/
|
|
22
|
-
f"Consider either increasing scan.
|
|
23
|
-
"increasing scan.pixels_per_wavelength to 2 or more."
|
|
21
|
+
f"width/grid_size_x = {width / scan.grid_size_x:.7f} < wavelength/2 = {wvln / 2}. "
|
|
22
|
+
f"Consider either increasing scan.grid_size_x to {int(np.ceil(width / (wvln / 2)))} "
|
|
23
|
+
"or more, or increasing scan.pixels_per_wavelength to 2 or more."
|
|
24
24
|
)
|
|
25
|
-
if depth / scan.
|
|
25
|
+
if depth / scan.grid_size_z > wvln / 2:
|
|
26
26
|
log.warning(
|
|
27
|
-
f"depth/
|
|
28
|
-
f"Consider either increasing scan.
|
|
29
|
-
"increasing scan.pixels_per_wavelength to 2 or more."
|
|
27
|
+
f"depth/grid_size_z = {depth / scan.grid_size_z:.7f} < wavelength/2 = {wvln / 2:.7f}. "
|
|
28
|
+
f"Consider either increasing scan.grid_size_z to {int(np.ceil(depth / (wvln / 2)))} "
|
|
29
|
+
"or more, or increasing scan.pixels_per_wavelength to 2 or more."
|
|
30
30
|
)
|
|
31
31
|
|
|
32
32
|
|
|
33
|
-
def cartesian_pixel_grid(xlims, zlims,
|
|
33
|
+
def cartesian_pixel_grid(xlims, zlims, grid_size_x=None, grid_size_z=None, dx=None, dz=None):
|
|
34
34
|
"""Generate a Cartesian pixel grid based on input parameters.
|
|
35
35
|
|
|
36
36
|
Args:
|
|
37
37
|
xlims (tuple): Azimuthal limits of pixel grid ([xmin, xmax])
|
|
38
38
|
zlims (tuple): Depth limits of pixel grid ([zmin, zmax])
|
|
39
|
-
|
|
40
|
-
|
|
39
|
+
grid_size_x (int): Number of azimuthal pixels, overrides dx and dz parameters
|
|
40
|
+
grid_size_z (int): Number of depth pixels, overrides dx and dz parameters
|
|
41
41
|
dx (float): Pixel spacing in azimuth
|
|
42
42
|
dz (float): Pixel spacing in depth
|
|
43
43
|
|
|
44
44
|
Raises:
|
|
45
|
-
ValueError: Either
|
|
45
|
+
ValueError: Either grid_size_x and grid_size_z or dx and dz must be defined.
|
|
46
46
|
|
|
47
47
|
Returns:
|
|
48
|
-
grid (np.ndarray): Pixel grid of size (
|
|
48
|
+
grid (np.ndarray): Pixel grid of size (grid_size_z, nx, 3) in
|
|
49
49
|
Cartesian coordinates (x, y, z)
|
|
50
50
|
"""
|
|
51
|
-
assert (bool(
|
|
52
|
-
"Either
|
|
51
|
+
assert (bool(grid_size_x) and bool(grid_size_z)) ^ (bool(dx) and bool(dz)), (
|
|
52
|
+
"Either grid_size_x and grid_size_z or dx and dz must be defined."
|
|
53
53
|
)
|
|
54
54
|
|
|
55
55
|
# Determine the grid spacing
|
|
56
|
-
if
|
|
57
|
-
x = np.linspace(xlims[0], xlims[1] + eps,
|
|
58
|
-
z = np.linspace(zlims[0], zlims[1] + eps,
|
|
56
|
+
if grid_size_x is not None and grid_size_z is not None:
|
|
57
|
+
x = np.linspace(xlims[0], xlims[1] + eps, grid_size_x)
|
|
58
|
+
z = np.linspace(zlims[0], zlims[1] + eps, grid_size_z)
|
|
59
59
|
elif dx is not None and dz is not None:
|
|
60
60
|
sign = np.sign(xlims[1] - xlims[0])
|
|
61
61
|
x = np.arange(xlims[0], xlims[1] + eps, sign * dx)
|
|
62
62
|
z = np.arange(zlims[0], zlims[1] + eps, sign * dz)
|
|
63
63
|
else:
|
|
64
|
-
raise ValueError("Either
|
|
64
|
+
raise ValueError("Either grid_size_x and grid_size_z or dx and dz must be defined.")
|
|
65
65
|
|
|
66
66
|
# Create the pixel grid
|
|
67
67
|
z_grid, x_grid = np.meshgrid(z, x, indexing="ij")
|
|
@@ -102,29 +102,30 @@ def radial_pixel_grid(rlims, dr, oris, dirs):
|
|
|
102
102
|
return grid
|
|
103
103
|
|
|
104
104
|
|
|
105
|
-
def polar_pixel_grid(polar_limits, zlims,
|
|
105
|
+
def polar_pixel_grid(polar_limits, zlims, num_radial_pixels: int, num_polar_pixels: int):
|
|
106
106
|
"""Generate a polar grid.
|
|
107
107
|
|
|
108
108
|
Uses radial_pixel_grid but based on parameters that are present in the scan class.
|
|
109
109
|
|
|
110
110
|
Args:
|
|
111
|
-
polar_limits (tuple):
|
|
111
|
+
polar_limits (tuple): Polar limits of pixel grid ([polar_min, polar_max])
|
|
112
112
|
zlims (tuple): Depth limits of pixel grid ([zmin, zmax])
|
|
113
|
-
|
|
114
|
-
|
|
113
|
+
num_radial_pixels (int, optional): Number of depth pixels.
|
|
114
|
+
num_polar_pixels (int, optional): Number of polar pixels.
|
|
115
115
|
|
|
116
116
|
Returns:
|
|
117
|
-
grid (np.ndarray): Pixel grid of size (
|
|
117
|
+
grid (np.ndarray): Pixel grid of size (num_radial_pixels, num_polar_pixels, 3)
|
|
118
|
+
in Cartesian coordinates (x, y, z)
|
|
118
119
|
"""
|
|
119
120
|
assert len(polar_limits) == 2, "polar_limits must be a tuple of length 2."
|
|
120
121
|
assert len(zlims) == 2, "zlims must be a tuple of length 2."
|
|
121
122
|
|
|
122
|
-
dr = (zlims[1] - zlims[0]) /
|
|
123
|
+
dr = (zlims[1] - zlims[0]) / num_radial_pixels
|
|
123
124
|
|
|
124
125
|
oris = np.array([0, 0, 0])
|
|
125
|
-
oris = np.tile(oris, (
|
|
126
|
-
dirs_az = np.linspace(*polar_limits,
|
|
126
|
+
oris = np.tile(oris, (num_polar_pixels, 1))
|
|
127
|
+
dirs_az = np.linspace(*polar_limits, num_polar_pixels)
|
|
127
128
|
|
|
128
|
-
dirs_el = np.zeros(
|
|
129
|
+
dirs_el = np.zeros(num_polar_pixels)
|
|
129
130
|
dirs = np.vstack((dirs_az, dirs_el)).T
|
|
130
131
|
return radial_pixel_grid(zlims, dr, oris, dirs).transpose(1, 0, 2)
|
zea/config.py
CHANGED
|
@@ -47,8 +47,9 @@ import yaml
|
|
|
47
47
|
from huggingface_hub import hf_hub_download
|
|
48
48
|
|
|
49
49
|
from zea import log
|
|
50
|
+
from zea.data.preset_utils import HF_PREFIX, _hf_resolve_path
|
|
50
51
|
from zea.internal.config.validation import config_schema
|
|
51
|
-
from zea.internal.core import
|
|
52
|
+
from zea.internal.core import dict_to_tensor
|
|
52
53
|
|
|
53
54
|
|
|
54
55
|
class Config(dict):
|
|
@@ -430,8 +431,22 @@ class Config(dict):
|
|
|
430
431
|
v._recursive_setattr(set_key, set_value)
|
|
431
432
|
|
|
432
433
|
@classmethod
|
|
433
|
-
def
|
|
434
|
-
"""Load config object from
|
|
434
|
+
def from_path(cls, path, **kwargs):
|
|
435
|
+
"""Load config object from a file path.
|
|
436
|
+
|
|
437
|
+
Args:
|
|
438
|
+
path (str or Path): The path to the config file.
|
|
439
|
+
Can be a string or a Path object. Additionally can be a string with
|
|
440
|
+
the prefix 'hf://', in which case it will be resolved to a
|
|
441
|
+
huggingface path.
|
|
442
|
+
|
|
443
|
+
Returns:
|
|
444
|
+
Config: config object.
|
|
445
|
+
"""
|
|
446
|
+
if str(path).startswith(HF_PREFIX):
|
|
447
|
+
path = _hf_resolve_path(str(path))
|
|
448
|
+
if isinstance(path, str):
|
|
449
|
+
path = Path(path)
|
|
435
450
|
return _load_config_from_yaml(path, config_class=cls, **kwargs)
|
|
436
451
|
|
|
437
452
|
@classmethod
|
|
@@ -460,9 +475,14 @@ class Config(dict):
|
|
|
460
475
|
local_path = hf_hub_download(repo_id, path, **kwargs)
|
|
461
476
|
return _load_config_from_yaml(local_path, config_class=cls)
|
|
462
477
|
|
|
463
|
-
|
|
478
|
+
@classmethod
|
|
479
|
+
def from_yaml(cls, path, **kwargs):
|
|
480
|
+
"""Load config object from yaml file."""
|
|
481
|
+
return cls.from_path(path, **kwargs)
|
|
482
|
+
|
|
483
|
+
def to_tensor(self, keep_as_is=None):
|
|
464
484
|
"""Convert the attributes in the object to keras tensors"""
|
|
465
|
-
return
|
|
485
|
+
return dict_to_tensor(self.serialize(), keep_as_is=keep_as_is)
|
|
466
486
|
|
|
467
487
|
|
|
468
488
|
def check_config(config: Union[dict, Config], verbose: bool = False):
|
zea/data/data_format.py
CHANGED
|
@@ -42,8 +42,8 @@ def generate_example_dataset(
|
|
|
42
42
|
sound_speed=1540,
|
|
43
43
|
center_frequency=7e6,
|
|
44
44
|
sampling_frequency=40e6,
|
|
45
|
-
|
|
46
|
-
|
|
45
|
+
grid_size_z=512,
|
|
46
|
+
grid_size_x=512,
|
|
47
47
|
):
|
|
48
48
|
"""Generates an example dataset that contains all the necessary fields.
|
|
49
49
|
Note: This dataset does not contain actual data, but is filled with random
|
|
@@ -60,7 +60,7 @@ def generate_example_dataset(
|
|
|
60
60
|
# creating some fake raw and image data
|
|
61
61
|
raw_data = np.ones((n_frames, n_tx, n_ax, n_el, n_ch))
|
|
62
62
|
# image data is in dB
|
|
63
|
-
image = np.ones((n_frames,
|
|
63
|
+
image = np.ones((n_frames, grid_size_z, grid_size_x)) * -40
|
|
64
64
|
|
|
65
65
|
# creating some fake scan parameters
|
|
66
66
|
t0_delays = np.zeros((n_tx, n_el), dtype=np.float32)
|
|
@@ -82,8 +82,8 @@ def generate_example_dataset(
|
|
|
82
82
|
|
|
83
83
|
if add_optional_dtypes:
|
|
84
84
|
aligned_data = np.ones((n_frames, n_tx, n_ax, n_el, n_ch))
|
|
85
|
-
envelope_data = np.ones((n_frames,
|
|
86
|
-
beamformed_data = np.ones((n_frames,
|
|
85
|
+
envelope_data = np.ones((n_frames, grid_size_z, grid_size_x))
|
|
86
|
+
beamformed_data = np.ones((n_frames, grid_size_z, grid_size_x, n_ch))
|
|
87
87
|
image_sc = np.ones_like(image)
|
|
88
88
|
else:
|
|
89
89
|
aligned_data = None
|
|
@@ -123,10 +123,11 @@ def validate_input_data(raw_data, aligned_data, envelope_data, beamformed_data,
|
|
|
123
123
|
aligned_data (np.ndarray): The aligned data of the ultrasound measurement of
|
|
124
124
|
shape (n_frames, n_tx, n_ax, n_el, n_ch).
|
|
125
125
|
envelope_data (np.ndarray): The envelope data of the ultrasound measurement of
|
|
126
|
-
shape (n_frames,
|
|
126
|
+
shape (n_frames, grid_size_z, grid_size_x).
|
|
127
127
|
beamformed_data (np.ndarray): The beamformed data of the ultrasound measurement of
|
|
128
|
-
shape (n_frames,
|
|
129
|
-
image (np.ndarray): The ultrasound images to be saved
|
|
128
|
+
shape (n_frames, grid_size_z, grid_size_x).
|
|
129
|
+
image (np.ndarray): The ultrasound images to be saved
|
|
130
|
+
of shape (n_frames, grid_size_z, grid_size_x).
|
|
130
131
|
image_sc (np.ndarray): The scan converted ultrasound images to be saved
|
|
131
132
|
of shape (n_frames, output_size_z, output_size_x).
|
|
132
133
|
"""
|
|
@@ -254,7 +255,7 @@ def _write_datasets(
|
|
|
254
255
|
group_name=data_group_name,
|
|
255
256
|
name="envelope_data",
|
|
256
257
|
data=_convert_datatype(envelope_data),
|
|
257
|
-
description="The envelope_data of shape (n_frames,
|
|
258
|
+
description="The envelope_data of shape (n_frames, grid_size_z, grid_size_x).",
|
|
258
259
|
unit="unitless",
|
|
259
260
|
)
|
|
260
261
|
|
|
@@ -262,7 +263,7 @@ def _write_datasets(
|
|
|
262
263
|
group_name=data_group_name,
|
|
263
264
|
name="beamformed_data",
|
|
264
265
|
data=_convert_datatype(beamformed_data),
|
|
265
|
-
description="The beamformed_data of shape (n_frames,
|
|
266
|
+
description="The beamformed_data of shape (n_frames, grid_size_z, grid_size_x).",
|
|
266
267
|
unit="unitless",
|
|
267
268
|
)
|
|
268
269
|
|
|
@@ -271,7 +272,7 @@ def _write_datasets(
|
|
|
271
272
|
name="image",
|
|
272
273
|
data=_convert_datatype(image),
|
|
273
274
|
unit="unitless",
|
|
274
|
-
description="The images of shape [n_frames,
|
|
275
|
+
description="The images of shape [n_frames, grid_size_z, grid_size_x]",
|
|
275
276
|
)
|
|
276
277
|
|
|
277
278
|
_add_dataset(
|
|
@@ -546,10 +547,11 @@ def generate_zea_dataset(
|
|
|
546
547
|
aligned_data (np.ndarray): The aligned data of the ultrasound measurement of
|
|
547
548
|
shape (n_frames, n_tx, n_ax, n_el, n_ch).
|
|
548
549
|
envelope_data (np.ndarray): The envelope data of the ultrasound measurement of
|
|
549
|
-
shape (n_frames,
|
|
550
|
+
shape (n_frames, grid_size_z, grid_size_x).
|
|
550
551
|
beamformed_data (np.ndarray): The beamformed data of the ultrasound measurement of
|
|
551
|
-
shape (n_frames,
|
|
552
|
-
image (np.ndarray): The ultrasound images to be saved
|
|
552
|
+
shape (n_frames, grid_size_z, grid_size_x, n_ch).
|
|
553
|
+
image (np.ndarray): The ultrasound images to be saved
|
|
554
|
+
of shape (n_frames, grid_size_z, grid_size_x).
|
|
553
555
|
image_sc (np.ndarray): The scan converted ultrasound images to be saved
|
|
554
556
|
of shape (n_frames, output_size_z, output_size_x).
|
|
555
557
|
probe_geometry (np.ndarray): The probe geometry of shape (n_el, 3).
|
zea/data/file.py
CHANGED
|
@@ -366,9 +366,7 @@ class File(h5py.File):
|
|
|
366
366
|
"""
|
|
367
367
|
file_scan_parameters = self.get_parameters(event)
|
|
368
368
|
|
|
369
|
-
probe_parameters = reduce_to_signature(
|
|
370
|
-
Probe.from_name("generic").__init__, file_scan_parameters
|
|
371
|
-
)
|
|
369
|
+
probe_parameters = reduce_to_signature(Probe.__init__, file_scan_parameters)
|
|
372
370
|
return probe_parameters
|
|
373
371
|
|
|
374
372
|
def probe(self, event=None) -> Probe:
|
|
@@ -387,21 +385,8 @@ class File(h5py.File):
|
|
|
387
385
|
Returns:
|
|
388
386
|
Probe: The probe object.
|
|
389
387
|
"""
|
|
390
|
-
|
|
391
|
-
|
|
392
|
-
return Probe.from_name(self.probe_name, **probe_parameters)
|
|
393
|
-
else:
|
|
394
|
-
probe = Probe.from_name(self.probe_name)
|
|
395
|
-
|
|
396
|
-
probe_geometry = probe_parameters.get("probe_geometry", None)
|
|
397
|
-
if not np.allclose(probe_geometry, probe.probe_geometry):
|
|
398
|
-
probe.probe_geometry = probe_geometry
|
|
399
|
-
log.warning(
|
|
400
|
-
"The probe geometry in the data file does not "
|
|
401
|
-
"match the probe geometry of the probe. The probe "
|
|
402
|
-
"geometry has been updated to match the data file."
|
|
403
|
-
)
|
|
404
|
-
return probe
|
|
388
|
+
probe_parameters_file = self.get_probe_parameters(event)
|
|
389
|
+
return Probe.from_parameters(self.probe_name, probe_parameters_file)
|
|
405
390
|
|
|
406
391
|
def recursively_load_dict_contents_from_group(self, path: str, squeeze: bool = False) -> dict:
|
|
407
392
|
"""Load dict from contents of group
|
zea/data/preset_utils.py
CHANGED
|
@@ -50,60 +50,74 @@ def _hf_download(repo_id, filename, cache_dir=HF_DATASETS_DIR):
|
|
|
50
50
|
)
|
|
51
51
|
|
|
52
52
|
|
|
53
|
-
def
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
53
|
+
def _get_snapshot_dir_from_downloaded_file(downloaded_file_path: str | Path) -> Path:
|
|
54
|
+
"""Extract the snapshot directory from a downloaded file's path.
|
|
55
|
+
|
|
56
|
+
HF Hub downloads to: cache_dir/datasets--org--repo/snapshots/{hash}/path/to/filename
|
|
57
|
+
This navigates up to find the {hash} directory (the snapshot directory).
|
|
58
|
+
"""
|
|
59
|
+
file_path = Path(downloaded_file_path)
|
|
60
|
+
|
|
61
|
+
# Navigate up the path until we find the snapshots directory
|
|
62
|
+
current = file_path.parent
|
|
63
|
+
while current.name != "snapshots" and current.parent != current:
|
|
64
|
+
current = current.parent
|
|
65
|
+
|
|
66
|
+
if current.name == "snapshots":
|
|
67
|
+
# Return the snapshot hash directory (first subdirectory of snapshots)
|
|
68
|
+
snapshot_dirs = [d for d in current.iterdir() if d.is_dir()]
|
|
69
|
+
if snapshot_dirs:
|
|
70
|
+
# Return the most recent snapshot directory
|
|
71
|
+
return max(snapshot_dirs, key=lambda p: p.stat().st_mtime)
|
|
72
|
+
|
|
73
|
+
raise FileNotFoundError(f"Could not find snapshot directory for {downloaded_file_path}")
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def _download_files_in_path(
|
|
77
|
+
repo_id: str, files: list, path_filter: str = None, cache_dir=HF_DATASETS_DIR
|
|
78
|
+
) -> list[str]:
|
|
79
|
+
"""Download all files matching the path filter."""
|
|
80
|
+
downloaded_files = []
|
|
81
|
+
for f in files:
|
|
82
|
+
if path_filter is None or f.startswith(path_filter):
|
|
83
|
+
downloaded_path = _hf_download(repo_id, f, cache_dir)
|
|
84
|
+
downloaded_files.append(downloaded_path)
|
|
85
|
+
|
|
86
|
+
return downloaded_files
|
|
72
87
|
|
|
73
88
|
|
|
74
89
|
def _hf_resolve_path(hf_path: str, cache_dir=HF_DATASETS_DIR):
|
|
75
|
-
"""
|
|
76
|
-
|
|
90
|
+
"""Resolve a Hugging Face path to a local cache directory path.
|
|
91
|
+
|
|
92
|
+
Downloads files from a HuggingFace dataset repository and returns
|
|
93
|
+
the local path where they are cached. Handles:
|
|
94
|
+
- hf://org/repo/subdir/ - Downloads all files in subdirectory
|
|
95
|
+
- hf://org/repo/file.h5 - Downloads specific file
|
|
96
|
+
- hf://org/repo - Downloads all files in repo
|
|
77
97
|
"""
|
|
78
98
|
repo_id, subpath = _hf_parse_path(hf_path)
|
|
79
99
|
files = _hf_list_files(repo_id)
|
|
80
|
-
snapshot_dir = _hf_get_snapshot_dir(repo_id, cache_dir)
|
|
81
|
-
|
|
82
|
-
def is_h5(f):
|
|
83
|
-
return f.endswith(".h5") or f.endswith(".hdf5")
|
|
84
100
|
|
|
85
101
|
if subpath:
|
|
86
|
-
# Directory
|
|
102
|
+
# Directory case
|
|
87
103
|
if any(f.startswith(subpath + "/") for f in files):
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
# File
|
|
96
|
-
elif
|
|
97
|
-
_hf_download(repo_id, subpath, cache_dir)
|
|
98
|
-
|
|
99
|
-
if not local_file.exists():
|
|
100
|
-
raise FileNotFoundError(f"File {local_file} not found after download.")
|
|
101
|
-
return local_file
|
|
104
|
+
downloaded_files = _download_files_in_path(repo_id, files, subpath + "/", cache_dir)
|
|
105
|
+
if not downloaded_files:
|
|
106
|
+
raise FileNotFoundError(f"No files found in directory {subpath}")
|
|
107
|
+
|
|
108
|
+
snapshot_dir = _get_snapshot_dir_from_downloaded_file(downloaded_files[0])
|
|
109
|
+
return snapshot_dir / subpath
|
|
110
|
+
|
|
111
|
+
# File case
|
|
112
|
+
elif subpath in files:
|
|
113
|
+
downloaded_file = _hf_download(repo_id, subpath, cache_dir)
|
|
114
|
+
return Path(downloaded_file)
|
|
102
115
|
else:
|
|
103
116
|
raise FileNotFoundError(f"{subpath} not found in {repo_id}")
|
|
104
117
|
else:
|
|
105
|
-
# All
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
118
|
+
# All files in repo
|
|
119
|
+
downloaded_files = _download_files_in_path(repo_id, files, None, cache_dir)
|
|
120
|
+
if not downloaded_files:
|
|
121
|
+
raise FileNotFoundError(f"No files found in repository {repo_id}")
|
|
122
|
+
|
|
123
|
+
return _get_snapshot_dir_from_downloaded_file(downloaded_files[0])
|
zea/datapaths.py
CHANGED
|
@@ -257,13 +257,13 @@ def set_data_paths(
|
|
|
257
257
|
|
|
258
258
|
Returns:
|
|
259
259
|
dict: Absolute paths to location of data. Stores the following parameters:
|
|
260
|
-
``data_root``, ``
|
|
260
|
+
``data_root``, ``zea_root``, ``output``, ``system``, ``username``, ``hostname``
|
|
261
261
|
|
|
262
262
|
"""
|
|
263
263
|
username = getpass.getuser()
|
|
264
264
|
system = platform.system().lower()
|
|
265
265
|
hostname = socket.gethostname()
|
|
266
|
-
|
|
266
|
+
zea_root = importlib.resources.files("zea")
|
|
267
267
|
|
|
268
268
|
# If user_config is None, use the default users.yaml file
|
|
269
269
|
if isinstance(user_config, type(None)):
|
|
@@ -304,7 +304,7 @@ def set_data_paths(
|
|
|
304
304
|
|
|
305
305
|
data_path = {
|
|
306
306
|
"data_root": Path(data_root),
|
|
307
|
-
"
|
|
307
|
+
"zea_root": zea_root,
|
|
308
308
|
"output": Path(output),
|
|
309
309
|
"system": system,
|
|
310
310
|
"username": username,
|
zea/display.py
CHANGED
|
@@ -130,7 +130,7 @@ def scan_convert_2d(
|
|
|
130
130
|
|
|
131
131
|
Returns:
|
|
132
132
|
ndarray: The scan-converted 2D ultrasound image in Cartesian coordinates.
|
|
133
|
-
Has dimensions (
|
|
133
|
+
Has dimensions (grid_size_z, grid_size_x). Coordinates outside the input image
|
|
134
134
|
ranges are filled with NaNs.
|
|
135
135
|
parameters (dict): A dictionary containing information about the scan conversion.
|
|
136
136
|
Contains the resolution, x, and z limits, rho and theta ranges.
|
|
@@ -269,7 +269,7 @@ def scan_convert_3d(
|
|
|
269
269
|
|
|
270
270
|
Returns:
|
|
271
271
|
ndarray: The scan-converted 3D ultrasound image in Cartesian coordinates.
|
|
272
|
-
Has dimensions (
|
|
272
|
+
Has dimensions (grid_size_z, grid_size_x, n_y). Coordinates outside the input image
|
|
273
273
|
ranges are filled with NaNs.
|
|
274
274
|
parameters (dict): A dictionary containing information about the scan conversion.
|
|
275
275
|
Contains the resolution, x, y, and z limits, rho, theta, and phi ranges.
|