zea 0.0.4__py3-none-any.whl → 0.0.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- zea/__init__.py +8 -7
- zea/__main__.py +8 -26
- zea/agent/selection.py +166 -0
- zea/backend/__init__.py +89 -0
- zea/backend/jax/__init__.py +14 -51
- zea/backend/tensorflow/__init__.py +0 -49
- zea/backend/torch/__init__.py +27 -62
- zea/data/__main__.py +6 -3
- zea/data/file.py +19 -74
- zea/data/layers.py +2 -3
- zea/display.py +1 -5
- zea/doppler.py +75 -0
- zea/internal/_generate_keras_ops.py +125 -0
- zea/internal/core.py +10 -3
- zea/internal/device.py +33 -16
- zea/internal/notebooks.py +39 -0
- zea/internal/operators.py +10 -0
- zea/internal/parameters.py +75 -19
- zea/internal/registry.py +1 -1
- zea/internal/viewer.py +24 -24
- zea/io_lib.py +60 -62
- zea/keras_ops.py +1989 -0
- zea/metrics.py +357 -65
- zea/models/__init__.py +6 -3
- zea/models/deeplabv3.py +131 -0
- zea/models/diffusion.py +18 -18
- zea/models/echonetlvh.py +279 -0
- zea/models/lv_segmentation.py +79 -0
- zea/models/presets.py +50 -0
- zea/models/regional_quality.py +122 -0
- zea/ops.py +52 -56
- zea/scan.py +10 -3
- zea/tensor_ops.py +251 -0
- zea/tools/fit_scan_cone.py +2 -2
- zea/tools/selection_tool.py +28 -9
- {zea-0.0.4.dist-info → zea-0.0.6.dist-info}/METADATA +10 -3
- {zea-0.0.4.dist-info → zea-0.0.6.dist-info}/RECORD +40 -33
- {zea-0.0.4.dist-info → zea-0.0.6.dist-info}/WHEEL +1 -1
- zea/internal/convert.py +0 -150
- {zea-0.0.4.dist-info → zea-0.0.6.dist-info}/entry_points.txt +0 -0
- {zea-0.0.4.dist-info → zea-0.0.6.dist-info/licenses}/LICENSE +0 -0
zea/models/diffusion.py
CHANGED
|
@@ -728,15 +728,15 @@ class DiffusionModel(DeepGenerativeModel):
|
|
|
728
728
|
)
|
|
729
729
|
return next_noisy_images
|
|
730
730
|
|
|
731
|
-
def start_track_progress(self, diffusion_steps):
|
|
731
|
+
def start_track_progress(self, diffusion_steps, initial_step=0):
|
|
732
732
|
"""Initialize the progress tracking for the diffusion process.
|
|
733
|
-
|
|
734
733
|
For diffusion animation we keep track of the diffusion progress.
|
|
735
734
|
For large number of steps, we do not store all the images due to memory constraints.
|
|
736
735
|
"""
|
|
737
736
|
self.track_progress = []
|
|
738
|
-
|
|
739
|
-
|
|
737
|
+
remaining = max(1, diffusion_steps - int(initial_step))
|
|
738
|
+
if remaining > 50:
|
|
739
|
+
self.track_progress_interval = remaining // 50
|
|
740
740
|
else:
|
|
741
741
|
self.track_progress_interval = 1
|
|
742
742
|
|
|
@@ -823,7 +823,8 @@ class DPS(DiffusionGuidance):
|
|
|
823
823
|
omega,
|
|
824
824
|
**kwargs,
|
|
825
825
|
):
|
|
826
|
-
"""
|
|
826
|
+
"""
|
|
827
|
+
Compute measurement error for diffusion posterior sampling.
|
|
827
828
|
|
|
828
829
|
Args:
|
|
829
830
|
noisy_images: Noisy images.
|
|
@@ -849,20 +850,19 @@ class DPS(DiffusionGuidance):
|
|
|
849
850
|
return measurement_error, (pred_noises, pred_images)
|
|
850
851
|
|
|
851
852
|
def __call__(self, noisy_images, **kwargs):
|
|
852
|
-
"""
|
|
853
|
-
|
|
854
|
-
Returns a function with the following signature:
|
|
855
|
-
(
|
|
856
|
-
noisy_images,
|
|
857
|
-
measurement,
|
|
858
|
-
operator,
|
|
859
|
-
noise_rates,
|
|
860
|
-
signal_rates,
|
|
861
|
-
omega,
|
|
862
|
-
**operator_kwargs,
|
|
863
|
-
) -> gradients, (error, (pred_noises, pred_images))
|
|
853
|
+
"""
|
|
854
|
+
Call the gradient function.
|
|
864
855
|
|
|
865
|
-
|
|
856
|
+
Args:
|
|
857
|
+
noisy_images: Noisy images.
|
|
858
|
+
measurement: Target measurement.
|
|
859
|
+
operator: Forward operator.
|
|
860
|
+
noise_rates: Current noise rates.
|
|
861
|
+
signal_rates: Current signal rates.
|
|
862
|
+
omega: Weight for the measurement error.
|
|
863
|
+
**kwargs: Additional arguments for the operator.
|
|
866
864
|
|
|
865
|
+
Returns:
|
|
866
|
+
Tuple of (gradients, (measurement_error, (pred_noises, pred_images)))
|
|
867
867
|
"""
|
|
868
868
|
return self.gradient_fn(noisy_images, **kwargs)
|
zea/models/echonetlvh.py
ADDED
|
@@ -0,0 +1,279 @@
|
|
|
1
|
+
"""EchoNetLVH model for segmentation of PLAX view cardiac ultrasound. For more details see https://echonet.github.io/lvh/index.html."""
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
from keras import ops
|
|
5
|
+
|
|
6
|
+
from zea.internal.registry import model_registry
|
|
7
|
+
from zea.models.base import BaseModel
|
|
8
|
+
from zea.models.deeplabv3 import DeeplabV3Plus
|
|
9
|
+
from zea.models.preset_utils import register_presets
|
|
10
|
+
from zea.models.presets import echonet_lvh_presets
|
|
11
|
+
from zea.utils import translate
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
@model_registry(name="echonetlvh")
class EchoNetLVH(BaseModel):
    """
    EchoNet Left Ventricular Hypertrophy (LVH) model for echocardiogram analysis.

    This model performs semantic segmentation on echocardiogram images to identify
    key anatomical landmarks for measuring left ventricular wall thickness:
    - LVPWd_1: Left Ventricular Posterior Wall point 1
    - LVPWd_2: Left Ventricular Posterior Wall point 2
    - IVSd_1: Interventricular Septum point 1
    - IVSd_2: Interventricular Septum point 2

    The model outputs 4-channel logits corresponding to heatmaps for each landmark.
    Inputs are resized to the 224x224 resolution of the underlying DeeplabV3Plus
    network, and the resulting logits are resized back to the input resolution.

    For more information, see the original project page at https://echonet.github.io/lvh/index.html
    """

    def __init__(self, **kwargs):
        """
        Initialize the EchoNetLVH model.

        Args:
            **kwargs: Additional keyword arguments passed to BaseModel
        """
        super().__init__(**kwargs)

        # Pre-computed (row, col) coordinate grid at the 224x224 network resolution;
        # used by `expected_coordinate` when no explicit grid is supplied.
        self.coordinate_grid = self._make_coordinate_grid((224, 224))

        # Initialize the underlying segmentation network
        self.network = DeeplabV3Plus(image_shape=(224, 224, 3), num_classes=4)

    @staticmethod
    def _make_coordinate_grid(shape):
        """Return a float32 tensor of shape [H, W, 2] whose entry [i, j] is (i, j)."""
        grid = np.stack(np.indices(shape), axis=-1).astype("float32")
        return ops.convert_to_tensor(grid)

    def call(self, inputs):
        """
        Forward pass of the model.

        Args:
            inputs (Tensor): Input images of shape [B, H, W, C]. They should
                be scan converted, with pixel values in range [0, 255].

        Returns:
            Tensor: Logits of shape [B, H, W, 4] with 4 channels for each landmark

        Raises:
            ValueError: If `inputs` is not a rank-4 tensor.
        """
        input_shape = ops.shape(inputs)
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if len(input_shape) != 4:
            raise ValueError(
                f"EchoNetLVH expects inputs of shape [B, H, W, C], got shape {input_shape}."
            )

        # Remember the spatial size so the logits can be mapped back onto the input.
        original_size = input_shape[1:3]

        # The network operates at a fixed 224x224 resolution.
        inputs_resized = ops.image.resize(inputs, size=(224, 224))
        logits = self.network(inputs_resized)

        # Resize logits back to original input dimensions
        return ops.image.resize(logits, original_size)

    def extract_key_points_as_indices(self, logits):
        """
        Extract key point coordinates from logits using center-of-mass calculation.

        Args:
            logits (Tensor): Model output logits of shape [B, H, W, 4]

        Returns:
            Tensor: Key point coordinates of shape [B, 4, 2]. `expected_coordinate`
                already returns points in (x, y) order; the extra flip applied here
                reverses that, so each point is returned in (row, col) = (y, x)
                index order.
                NOTE(review): an earlier docstring claimed (x, y) here; confirm which
                order downstream callers expect.
        """
        # Coordinate grid matching the spatial size of these logits.
        input_shape = ops.shape(logits)[1:3]
        input_space_coordinate_grid = self._make_coordinate_grid(input_shape)

        # [B, H, W, 4] -> [B, 4, H, W]: each mapped element is a stack of 4 heatmaps,
        # which `expected_coordinate` treats as its leading (batch) axis.
        logits_batchified = ops.transpose(logits, (0, 3, 1, 2))

        return ops.flip(
            ops.vectorized_map(
                lambda logit: self.expected_coordinate(logit, input_space_coordinate_grid),
                logits_batchified,
            ),
            axis=-1,  # (x, y) from expected_coordinate -> (y, x) / (row, col)
        )

    def expected_coordinate(self, mask, coordinate_grid=None):
        """
        Compute the expected coordinate (center-of-mass) of a heatmap.

        This implements a differentiable version of taking the max of a heatmap
        by computing the weighted average of coordinates.

        Reference: https://arxiv.org/pdf/1711.08229

        Args:
            mask (Tensor): Heatmap of shape [B, H, W]
            coordinate_grid (Tensor, optional): Grid of coordinates of shape
                [H, W, 2] in (row, col) order. If None, uses self.coordinate_grid

        Returns:
            Tensor: Expected coordinates of shape [B, 2] in (x, y) format. A heatmap
                that is zero everywhere (after clipping negatives) yields (0, 0)
                instead of NaN.
        """
        if coordinate_grid is None:
            coordinate_grid = self.coordinate_grid

        # Negative responses carry no probability mass.
        mask_clipped = ops.clip(mask, 0, None)

        # Rescale so the peak is 1. The denominator is replaced by 1 when the max is 0;
        # an unguarded division would give 0/0 = NaN for an all-zero heatmap.
        mask_max = ops.max(mask_clipped)
        safe_max = ops.where(mask_max > 0, mask_max, ops.ones_like(mask_max))
        mask_normed = mask_clipped / safe_max

        # Per-sample normalization to a probability distribution over (H, W).
        # Samples with zero total mass are left as all zeros, and the denominator is
        # made safe *before* dividing so no NaN appears in values or gradients.
        mask_sum = ops.sum(mask_normed, axis=(1, 2), keepdims=True)
        safe_sum = ops.where(mask_sum > 0, mask_sum, ops.ones_like(mask_sum))
        coordinate_probabilities = mask_normed / safe_sum

        # Add dimension for broadcasting with coordinate grid
        coordinate_probabilities = ops.expand_dims(coordinate_probabilities, axis=-1)

        # Compute weighted average of coordinates -> (row, col) per sample
        expected_coordinate = ops.sum(
            ops.expand_dims(coordinate_grid, axis=0) * coordinate_probabilities,
            axis=(1, 2),
        )

        # Flip to convert from (y, x) to (x, y) format for euclidean distance calculation
        return ops.flip(expected_coordinate, axis=-1)

    def overlay_labels_on_image(self, image, label, alpha=0.5):
        """
        Overlay predicted heatmaps and connecting lines on the input image.

        Args:
            image (Tensor): Input image of shape [H, W] or [H, W, C], values in [0, 1]
            label (Tensor): Predicted logits of shape [H, W, 4]
            alpha (float): Blending factor for overlay (0=transparent, 1=opaque)

        Returns:
            ndarray: Image with overlaid heatmaps and measurements of shape [H, W, 3]

        Raises:
            ImportError: If OpenCV (`cv2`) is not installed.
        """
        try:
            import cv2

        except ImportError as exc:
            raise ImportError(
                "OpenCV is required for `EchoNetLVH.overlay_labels_on_image`. "
                "Please install it with 'pip install opencv-python' or "
                "'pip install opencv-python-headless'."
            ) from exc

        # Color scheme for each landmark
        overlay_colors = np.array(
            [
                [1, 1, 0],  # Yellow (LVPWd_X1)
                [1, 0, 1],  # Magenta (LVPWd_X2)
                [0, 1, 1],  # Cyan (IVSd_X1)
                [0, 1, 0],  # Green (IVSd_X2)
            ],
        )

        # Convert to numpy and ensure RGB format
        image = ops.convert_to_numpy(image)
        label = ops.convert_to_numpy(label)
        # Work in floating point so the in-place normalization below cannot truncate.
        if not np.issubdtype(label.dtype, np.floating):
            label = label.astype(np.float32)

        if image.ndim == 2:
            image = np.stack([image] * 3, axis=-1)
        elif image.shape[-1] == 1:
            image = np.repeat(image, 3, axis=-1)
        else:
            image = image.copy()

        # Normalize each channel to [0, 1] for proper visualization
        label = np.clip(label, 0, None)
        for ch in range(label.shape[-1]):
            max_val = np.max(label[..., ch])
            if max_val > 0:
                label[..., ch] = label[..., ch] / max_val

        # Coordinate grid matching the label's spatial size (not just 224x224).
        coordinate_grid = self._make_coordinate_grid(label.shape[:2])

        # Initialize overlay and tracking variables
        overlay = np.zeros_like(image, dtype=np.float32)
        centers = []

        # Process each landmark channel
        for ch in range(4):
            # Square the mask to enhance peak responses
            mask = label[..., ch] ** 2
            color = overlay_colors[ch]

            center = None
            # An empty heatmap has no meaningful center of mass; leave it as None.
            if np.any(mask > 0):
                center_coords = self.expected_coordinate(
                    ops.expand_dims(mask, axis=0), coordinate_grid
                )
                center_x = float(ops.convert_to_numpy(center_coords[0, 0]))
                center_y = float(ops.convert_to_numpy(center_coords[0, 1]))

                # Bounds check before conversion to int
                if 0 <= center_x < image.shape[1] and 0 <= center_y < image.shape[0]:
                    center = (int(center_x), int(center_y))

            if center is not None:
                # Blend heatmap with overlay
                mask_alpha = mask * alpha
                for c in range(3):
                    overlay[..., c] += mask_alpha * color[c]

            # Always record an entry (possibly None) so centers[i] stays aligned with
            # landmark channel i when pairing consecutive landmarks below.
            centers.append(center)

        # Draw connecting lines between consecutive landmarks
        for i in range(3):
            pt1, pt2 = centers[i], centers[i + 1]
            if pt1 is None or pt2 is None:
                continue
            line_color = tuple(int(x) for x in overlay_colors[i])

            # Rasterize the line into a binary mask
            line_mask = np.zeros(image.shape[:2], dtype=np.uint8)
            cv2.line(line_mask, pt1, pt2, color=1, thickness=2)
            line_pixels = line_mask.astype(bool)

            # Apply line to overlay
            for c in range(3):
                overlay[..., c][line_pixels] = line_color[c] * alpha

        # Blend overlay with original image where the overlay is visible
        overlay = np.clip(overlay, 0, 1)
        out = image.astype(np.float32)
        blend_mask = np.any(overlay > 0.02, axis=-1)
        out[blend_mask] = (1 - alpha) * out[blend_mask] + overlay[blend_mask]

        return np.clip(out, 0, 1)

    def visualize_logits(self, images, logits):
        """
        Create visualization of model predictions overlaid on input images.

        Args:
            images (Tensor): Input images of shape [B, H, W, C], values in [0, 255]
            logits (Tensor): Model predictions of shape [B, H, W, 4]

        Returns:
            Tensor: Images with overlaid predictions of shape [B, H, W, 3]
        """
        # Store original dimensions for final output
        original_size = ops.shape(images)[1:3]

        # Resize to standard processing size
        images_resized = ops.image.resize(images, size=(224, 224), interpolation="nearest")
        logits_resized = ops.image.resize(logits, size=(224, 224), interpolation="nearest")

        # Normalize images to [0, 1] range
        images_clipped = ops.clip(images_resized, 0, 255)
        images = translate(images_clipped, range_from=(0, 255), range_to=(0, 1))

        # Generate overlays for each image in the batch
        images_with_overlay = [
            self.overlay_labels_on_image(img, logit_heatmap)
            for img, logit_heatmap in zip(images, logits_resized)
        ]

        # Stack results and resize back to original dimensions
        images_with_overlay = np.stack(images_with_overlay, axis=0)
        return ops.image.resize(images_with_overlay, original_size)
|
|
276
|
+
|
|
277
|
+
|
|
278
|
+
# Register model presets
|
|
279
|
+
# Associate the `echonet_lvh_presets` entries with the EchoNetLVH class.
register_presets(echonet_lvh_presets, EchoNetLVH)
|
|
@@ -0,0 +1,79 @@
|
|
|
1
|
+
"""
|
|
2
|
+
The model is the nnU-Net model trained on the augmented CAMUS dataset from the following publication:
|
|
3
|
+
Van De Vyver, Gilles, et al.
|
|
4
|
+
"Generative augmentations for improved cardiac ultrasound segmentation using diffusion models."
|
|
5
|
+
arXiv preprint arXiv:2502.20100 (2025).
|
|
6
|
+
|
|
7
|
+
GitHub original repo: https://github.com/GillesVanDeVyver/EchoGAINS
|
|
8
|
+
|
|
9
|
+
At the time of writing (17 September 2025) and to the best of our knowledge,
|
|
10
|
+
it is the state-of-the-art model for left ventricle segmentation on the CAMUS dataset.
|
|
11
|
+
|
|
12
|
+
The model is originally a PyTorch model converted to ONNX. The model segments the left ventricle and myocardium.
|
|
13
|
+
|
|
14
|
+
Note:
|
|
15
|
+
-----
|
|
16
|
+
To use this model, you must install the `onnxruntime` Python package:
|
|
17
|
+
|
|
18
|
+
pip install onnxruntime
|
|
19
|
+
|
|
20
|
+
This is required for ONNX model inference.
|
|
21
|
+
""" # noqa: E501
|
|
22
|
+
|
|
23
|
+
from keras import ops
|
|
24
|
+
|
|
25
|
+
from zea.internal.registry import model_registry
|
|
26
|
+
from zea.models.base import BaseModel
|
|
27
|
+
from zea.models.preset_utils import get_preset_loader, register_presets
|
|
28
|
+
from zea.models.presets import augmented_camus_seg_presets
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
@model_registry(name="augmented_camus_seg")
class AugmentedCamusSeg(BaseModel):
    """
    nnU-Net based left ventricle and myocardium segmentation model.

    - Trained on the augmented CAMUS dataset.
    - This class loads an ONNX model and provides inference for cardiac ultrasound segmentation tasks.

    """  # noqa: E501

    def call(self, inputs):
        """
        Run inference on the input data using the loaded ONNX model.

        Args:
            inputs (np.ndarray): Input image or batch of images for segmentation.
                Shape: [batch, 1, 256, 256]
                Values are only cast to float32 here; no rescaling is done in this
                method. NOTE(review): any normalization presumably happens inside the
                exported ONNX graph — confirm.

        Returns:
            np.ndarray: Segmentation mask(s) for left ventricle and myocardium.
                Shape: [batch, 3, 256, 256] (logits for background, LV, myocardium)

        Raises:
            ValueError: If model weights are not loaded.
        """
        if not hasattr(self, "onnx_sess"):
            raise ValueError("Model weights not loaded. Please call custom_load_weights() first.")

        # The session exposes a single input and a single output; look up their names.
        input_name = self.onnx_sess.get_inputs()[0].name
        output_name = self.onnx_sess.get_outputs()[0].name

        # ONNX Runtime consumes NumPy arrays, so convert backend tensors first.
        inputs = ops.convert_to_numpy(inputs).astype("float32")
        return self.onnx_sess.run([output_name], {input_name: inputs})[0]

    def custom_load_weights(self, preset, **kwargs):
        """Load the ONNX weights for the segmentation model.

        Args:
            preset: Preset identifier resolved through `get_preset_loader`.
            **kwargs: Accepted for interface compatibility; unused here.

        Raises:
            ImportError: If `onnxruntime` is not installed.
        """
        try:
            import onnxruntime
        except ImportError as exc:
            # Chain the original error so the underlying import failure stays visible.
            raise ImportError(
                "onnxruntime is not installed. Please run "
                "`pip install onnxruntime` to use this model."
            ) from exc
        loader = get_preset_loader(preset)
        filename = loader.get_file("model.onnx")
        self.onnx_sess = onnxruntime.InferenceSession(filename)
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
# Associate the `augmented_camus_seg_presets` entries with the AugmentedCamusSeg class.
register_presets(augmented_camus_seg_presets, AugmentedCamusSeg)
|
zea/models/presets.py
CHANGED
|
@@ -47,6 +47,48 @@ echonet_dynamic_presets = {
|
|
|
47
47
|
},
|
|
48
48
|
}
|
|
49
49
|
|
|
50
|
+
# Presets for the AugmentedCamusSeg ONNX model (LV and myocardium segmentation).
augmented_camus_seg_presets = {
    "augmented_camus_seg": {
        "metadata": {
            "description": (
                "Augmented CAMUS segmentation model for cardiac ultrasound segmentation. "
                "Original paper and code: https://arxiv.org/abs/2502.20100"
            ),
            "params": 33468899,
            "path": "lv_segmentation",
        },
        # Hugging Face location from which the preset files (e.g. model.onnx) are fetched.
        "hf_handle": "hf://zeahub/augmented-camus-segmentation",
    },
}
|
|
63
|
+
|
|
64
|
+
# Presets for the MobileNetv2RegionalQuality ONNX model (regional image quality scoring).
regional_quality_presets = {
    "mobilenetv2_regional_quality": {
        "metadata": {
            "description": (
                "MobileNetV2-based regional myocardial image quality scoring model. "
                "Original GitHub repository and code: https://github.com/GillesVanDeVyver/arqee"
            ),
            "params": 2217064,
            "path": "regional_quality",
        },
        # Hugging Face location from which the preset files (model.onnx and the
        # slope/intercept bias-correction array) are fetched.
        "hf_handle": "hf://zeahub/mobilenetv2-regional-quality",
    }
}
|
|
77
|
+
|
|
78
|
+
# Presets for the EchoNetLVH landmark-heatmap model (PLAX view, 224x224 inputs).
echonet_lvh_presets = {
    "echonetlvh": {
        "metadata": {
            "description": (
                "EchoNetLVH segmentation model for PLAX-view cardiac ultrasound segmentation. "
                "Trained on images of size (224, 224)."
            ),
            # NOTE(review): other presets list real parameter counts; 0 looks like a
            # placeholder — confirm and fill in.
            "params": 0,
            "path": "echonetlvh",
        },
        "hf_handle": "hf://zeahub/echonetlvh",
    },
}
|
|
91
|
+
|
|
50
92
|
lpips_presets = {
|
|
51
93
|
"lpips": {
|
|
52
94
|
"metadata": {
|
|
@@ -83,6 +125,14 @@ diffusion_model_presets = {
|
|
|
83
125
|
},
|
|
84
126
|
"hf_handle": "hf://zeahub/diffusion-echonet-dynamic",
|
|
85
127
|
},
|
|
128
|
+
"diffusion-echonetlvh-3-frame": {
|
|
129
|
+
"metadata": {
|
|
130
|
+
"description": ("3-frame diffusion model trained on EchoNetLVH dataset."),
|
|
131
|
+
"params": 0,
|
|
132
|
+
"path": "diffusion",
|
|
133
|
+
},
|
|
134
|
+
"hf_handle": "hf://zeahub/diffusion-echonetlvh",
|
|
135
|
+
},
|
|
86
136
|
}
|
|
87
137
|
|
|
88
138
|
carotid_segmenter_presets = {
|
|
@@ -0,0 +1,122 @@
|
|
|
1
|
+
"""
|
|
2
|
+
The model is the MobileNetV2-based image quality model from:
|
|
3
|
+
Van De Vyver, et al. "Regional Image Quality Scoring for 2-D Echocardiography Using Deep Learning."
|
|
4
|
+
Ultrasound in Medicine & Biology 51.4 (2025): 638-649.
|
|
5
|
+
|
|
6
|
+
GitHub original repo: https://github.com/GillesVanDeVyver/arqee
|
|
7
|
+
|
|
8
|
+
The model is originally a PyTorch model converted to ONNX. The model predicts the regional image quality of
|
|
9
|
+
the myocardial regions in apical views. It can also be used to get the overall image quality by averaging the
|
|
10
|
+
regional scores.
|
|
11
|
+
|
|
12
|
+
Note:
|
|
13
|
+
-----
|
|
14
|
+
To use this model, you must install the `onnxruntime` Python package:
|
|
15
|
+
|
|
16
|
+
pip install onnxruntime
|
|
17
|
+
|
|
18
|
+
This is required for ONNX model inference.
|
|
19
|
+
""" # noqa: E501
|
|
20
|
+
|
|
21
|
+
import numpy as np
|
|
22
|
+
from keras import ops
|
|
23
|
+
|
|
24
|
+
from zea.internal.registry import model_registry
|
|
25
|
+
from zea.models.base import BaseModel
|
|
26
|
+
from zea.models.preset_utils import get_preset_loader, register_presets
|
|
27
|
+
from zea.models.presets import regional_quality_presets
|
|
28
|
+
|
|
29
|
+
# Visualization colors and label names for regional quality (arqee-inspired).
# Row k of QUALITY_COLORS is the color (components in [0, 1]) for QUALITY_CLASSES[k].
QUALITY_COLORS = np.array(
    [
        [0.929, 0.106, 0.141],  # not visible, red
        [0.957, 0.396, 0.137],  # poor, orange
        [1, 0.984, 0.090],  # ok, yellow
        [0.553, 0.776, 0.098],  # good, light green
        [0.09, 0.407, 0.216],  # excellent, dark green
    ]
)
# Myocardial region names, in the same order as the 8 scores returned by
# MobileNetv2RegionalQuality.call.
REGION_LABELS = [
    "basal_left",
    "mid_left",
    "apical_left",
    "apical_right",
    "mid_right",
    "basal_right",
    "annulus_left",
    "annulus_right",
]
# Quality class names, ordered from worst to best.
QUALITY_CLASSES = ["not visible", "poor", "ok", "good", "excellent"]
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
@model_registry(name="mobilenetv2_regional_quality")
class MobileNetv2RegionalQuality(BaseModel):
    """
    MobileNetV2 based regional image quality scoring model for myocardial regions in apical views.

    This class loads an ONNX model and provides inference for regional image quality scoring tasks.
    """

    def preprocess_input(self, inputs):
        """
        Min-max normalize input image(s) to the [0, 255] range.

        For a batched input of rank 4 ([batch, 1, H, W]), each image is normalized
        independently, so its scores do not depend on the other images in the
        batch. Any other rank is normalized over all of its values. A constant
        image (zero dynamic range) maps to all zeros.

        Args:
            inputs (np.ndarray): Input image(s), any numeric range.

        Returns:
            np.ndarray: float32 normalized image(s) in [0, 255] range.
        """
        inputs = ops.convert_to_numpy(inputs).astype("float32")

        # Reduce over every non-batch axis for batched NCHW input; otherwise globally.
        reduce_axes = tuple(range(1, inputs.ndim)) if inputs.ndim == 4 else None
        min_val = np.min(inputs, axis=reduce_axes, keepdims=True)
        max_val = np.max(inputs, axis=reduce_axes, keepdims=True)
        denom = max_val - min_val

        # Substitute 1 for a zero range so the division is always well-defined,
        # then zero out those (constant) samples explicitly.
        has_range = denom > 0.0
        safe_denom = np.where(has_range, denom, 1.0)
        normalized = (inputs - min_val) / safe_denom * 255.0
        return np.where(has_range, normalized, 0.0).astype(np.float32)

    def call(self, inputs):
        """
        Predict regional image quality scores for input image(s).

        Args:
            inputs (np.ndarray): Input image or batch of images.
                Shape: [batch, 1, 256, 256]

        Returns:
            np.ndarray: Regional quality scores.
                Shape is [batch, 8] with regions in order:
                basal_left, mid_left, apical_left, apical_right,
                mid_right, basal_right, annulus_left, annulus_right

        Raises:
            ValueError: If model weights are not loaded.
        """
        if not hasattr(self, "onnx_sess"):
            raise ValueError("Model weights not loaded. Please call custom_load_weights() first.")
        input_name = self.onnx_sess.get_inputs()[0].name
        output_name = self.onnx_sess.get_outputs()[0].name
        inputs = self.preprocess_input(inputs)

        output = self.onnx_sess.run([output_name], {input_name: inputs})[0]

        # Invert the linear bias model raw = slope * score + intercept loaded in
        # `custom_load_weights`.
        slope = self.slope_intercept[0]
        intercept = self.slope_intercept[1]
        return (output - intercept) / slope

    def custom_load_weights(self, preset, **kwargs):
        """Load ONNX model weights and bias correction for regional image quality scoring.

        Args:
            preset: Preset identifier resolved through `get_preset_loader`.
            **kwargs: Accepted for interface compatibility; unused here.

        Raises:
            ImportError: If `onnxruntime` is not installed.
        """
        try:
            import onnxruntime
        except ImportError as exc:
            # Chain the original error so the underlying import failure stays visible.
            raise ImportError(
                "onnxruntime is not installed. Please run "
                "`pip install onnxruntime` to use this model."
            ) from exc
        loader = get_preset_loader(preset)
        filename = loader.get_file("model.onnx")
        self.onnx_sess = onnxruntime.InferenceSession(filename)
        # Two-element array: [slope, intercept], consumed by `call`.
        filename = loader.get_file("slope_intercept_bias_correction.npy")
        self.slope_intercept = np.load(filename)
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
# Associate the `regional_quality_presets` entries with the MobileNetv2RegionalQuality class.
register_presets(regional_quality_presets, MobileNetv2RegionalQuality)
|