zea 0.0.4__py3-none-any.whl → 0.0.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- zea/__init__.py +8 -7
- zea/__main__.py +8 -26
- zea/data/__main__.py +6 -3
- zea/data/file.py +19 -74
- zea/display.py +1 -5
- zea/doppler.py +75 -0
- zea/internal/_generate_keras_ops.py +125 -0
- zea/internal/core.py +10 -3
- zea/internal/device.py +33 -16
- zea/internal/notebooks.py +39 -0
- zea/internal/operators.py +10 -0
- zea/internal/parameters.py +75 -19
- zea/internal/viewer.py +24 -24
- zea/io_lib.py +60 -62
- zea/keras_ops.py +1989 -0
- zea/models/__init__.py +6 -3
- zea/models/deeplabv3.py +131 -0
- zea/models/diffusion.py +4 -4
- zea/models/echonetlvh.py +290 -0
- zea/models/presets.py +14 -0
- zea/ops.py +28 -45
- zea/scan.py +10 -3
- zea/tensor_ops.py +150 -0
- zea/tools/fit_scan_cone.py +2 -2
- zea/tools/selection_tool.py +28 -9
- {zea-0.0.4.dist-info → zea-0.0.5.dist-info}/METADATA +5 -2
- {zea-0.0.4.dist-info → zea-0.0.5.dist-info}/RECORD +30 -25
- zea/internal/convert.py +0 -150
- {zea-0.0.4.dist-info → zea-0.0.5.dist-info}/LICENSE +0 -0
- {zea-0.0.4.dist-info → zea-0.0.5.dist-info}/WHEEL +0 -0
- {zea-0.0.4.dist-info → zea-0.0.5.dist-info}/entry_points.txt +0 -0
zea/models/__init__.py
CHANGED
@@ -4,8 +4,9 @@

 Currently, the following models are available (all inherited from :class:`zea.models.BaseModel`):

-- :class:`zea.models.echonet.EchoNetDynamic`: A model for
+- :class:`zea.models.echonet.EchoNetDynamic`: A model for left ventricle segmentation.
 - :class:`zea.models.carotid_segmenter.CarotidSegmenter`: A model for carotid artery segmentation.
+- :class:`zea.models.echonetlvh.EchoNetLVH`: A model for left ventricle hypertrophy segmentation.
 - :class:`zea.models.unet.UNet`: A simple U-Net implementation.
 - :class:`zea.models.lpips.LPIPS`: A model implementing the perceptual similarity metric.
 - :class:`zea.models.taesd.TinyAutoencoder`: A tiny autoencoder model for image compression.
@@ -16,7 +17,7 @@ To use these models, you can import them directly from the :mod:`zea.models` mod

 .. code-block:: python

-    from zea.models import UNet
+    from zea.models.unet import UNet

     model = UNet.from_preset("unet-echonet-inpainter")

@@ -48,7 +49,7 @@ An example of how to use the :class:`zea.models.diffusion.DiffusionModel` is sho

 .. code-block:: python

-    from zea.models import DiffusionModel
+    from zea.models.diffusion import DiffusionModel

     model = DiffusionModel.from_preset("diffusion-echonet-dynamic")
     samples = model.sample(n_samples=4)

@@ -74,9 +75,11 @@ The following steps are recommended when adding a new model:

 from . import (
     carotid_segmenter,
+    deeplabv3,
     dense,
     diffusion,
     echonet,
+    echonetlvh,
     generative,
     gmm,
     layers,
zea/models/deeplabv3.py
ADDED
@@ -0,0 +1,131 @@
+"""DeepLabV3+ architecture for multi-class segmentation. For more details see https://arxiv.org/abs/1802.02611."""
+
+import keras
+from keras import layers, ops
+
+
+def convolution_block(
+    block_input,
+    num_filters=256,
+    kernel_size=3,
+    dilation_rate=1,
+    use_bias=False,
+):
+    """
+    Create a convolution block with batch normalization and ReLU activation.
+
+    This is a standard building block used throughout the DeepLabV3+ architecture,
+    consisting of Conv2D -> BatchNormalization -> ReLU.
+
+    Args:
+        block_input (Tensor): Input tensor to the convolution block
+        num_filters (int): Number of output filters/channels. Defaults to 256.
+        kernel_size (int): Size of the convolution kernel. Defaults to 3.
+        dilation_rate (int): Dilation rate for dilated convolution. Defaults to 1.
+        use_bias (bool): Whether to use bias in the convolution layer. Defaults to False.
+
+    Returns:
+        Tensor: Output tensor after convolution, batch normalization, and ReLU
+    """
+    x = layers.Conv2D(
+        num_filters,
+        kernel_size=kernel_size,
+        dilation_rate=dilation_rate,
+        padding="same",
+        use_bias=use_bias,
+        kernel_initializer=keras.initializers.HeNormal(),
+    )(block_input)
+    x = layers.BatchNormalization()(x)
+    return ops.nn.relu(x)
+
+
+def DilatedSpatialPyramidPooling(dspp_input):
+    """
+    Implement Atrous Spatial Pyramid Pooling (ASPP) module.
+
+    ASPP captures multi-scale context by applying parallel atrous convolutions
+    with different dilation rates. This helps the model understand objects
+    at multiple scales.
+
+    The module consists of:
+    - Global average pooling branch
+    - 1x1 convolution branch
+    - 3x3 convolutions with dilation rates 6, 12, and 18
+
+    Reference: https://arxiv.org/abs/1706.05587
+
+    Args:
+        dspp_input (Tensor): Input feature tensor from encoder
+
+    Returns:
+        Tensor: Multi-scale feature representation
+    """
+    dims = dspp_input.shape
+    x = layers.AveragePooling2D(pool_size=(dims[-3], dims[-2]))(dspp_input)
+    x = convolution_block(x, kernel_size=1, use_bias=True)
+    out_pool = layers.UpSampling2D(
+        size=(dims[-3] // x.shape[1], dims[-2] // x.shape[2]),
+        interpolation="bilinear",
+    )(x)
+
+    out_1 = convolution_block(dspp_input, kernel_size=1, dilation_rate=1)
+    out_6 = convolution_block(dspp_input, kernel_size=3, dilation_rate=6)
+    out_12 = convolution_block(dspp_input, kernel_size=3, dilation_rate=12)
+    out_18 = convolution_block(dspp_input, kernel_size=3, dilation_rate=18)
+
+    x = layers.Concatenate(axis=-1)([out_pool, out_1, out_6, out_12, out_18])
+    output = convolution_block(x, kernel_size=1)
+    return output
+
+
+def DeeplabV3Plus(image_shape, num_classes, pretrained_weights=None):
+    """
+    Build DeepLabV3+ model for semantic segmentation.
+
+    DeepLabV3+ combines the benefits of spatial pyramid pooling and encoder-decoder
+    architecture. It uses a ResNet50 backbone as encoder, ASPP for multi-scale
+    feature extraction, and a simple decoder for recovering spatial details.
+
+    Architecture:
+        1. Encoder: ResNet50 backbone with atrous convolutions
+        2. ASPP: Multi-scale feature extraction
+        3. Decoder: Simple decoder with skip connections
+        4. Output: Final segmentation prediction
+
+    Reference: https://arxiv.org/abs/1802.02611
+
+    Args:
+        image_shape (tuple): Input image shape as (height, width, channels)
+        num_classes (int): Number of output classes for segmentation
+        pretrained_weights (str, optional): Pretrained weights for ResNet50 backbone.
+            Defaults to None.
+
+    Returns:
+        keras.Model: Complete DeepLabV3+ model
+    """
+    model_input = keras.Input(shape=image_shape)
+    # 3-channel grayscale as repeated single channel for ResNet50
+    model_input_3_channel = ops.concatenate([model_input, model_input, model_input], axis=-1)
+    preprocessed = keras.applications.resnet50.preprocess_input(model_input_3_channel)
+    resnet50 = keras.applications.ResNet50(
+        weights=pretrained_weights, include_top=False, input_tensor=preprocessed
+    )
+    x = resnet50.get_layer("conv4_block6_2_relu").output
+    x = DilatedSpatialPyramidPooling(x)
+
+    input_a = layers.UpSampling2D(
+        size=(image_shape[0] // 4 // x.shape[1], image_shape[1] // 4 // x.shape[2]),
+        interpolation="bilinear",
+    )(x)
+    input_b = resnet50.get_layer("conv2_block3_2_relu").output
+    input_b = convolution_block(input_b, num_filters=48, kernel_size=1)
+
+    x = layers.Concatenate(axis=-1)([input_a, input_b])
+    x = convolution_block(x)
+    x = convolution_block(x)
+    x = layers.UpSampling2D(
+        size=(image_shape[0] // x.shape[1], image_shape[1] // x.shape[2]),
+        interpolation="bilinear",
+    )(x)
+    model_output = layers.Conv2D(num_classes, kernel_size=(1, 1), padding="same")(x)
+    return keras.Model(inputs=model_input, outputs=model_output)
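
A minimal usage sketch of the new builder (not part of the diff): the shapes mirror how EchoNetLVH instantiates it in zea/models/echonetlvh.py, and with pretrained_weights=None the ResNet50 backbone is randomly initialized, so this only checks shapes.

    import numpy as np
    from zea.models.deeplabv3 import DeeplabV3Plus

    # Build the segmentation head on a randomly initialized ResNet50 backbone.
    model = DeeplabV3Plus(image_shape=(224, 224, 3), num_classes=4)
    dummy = np.zeros((1, 224, 224, 3), dtype="float32")
    logits = model(dummy)  # expected shape: (1, 224, 224, 4) per-class logits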
zea/models/diffusion.py
CHANGED
@@ -728,15 +728,15 @@ class DiffusionModel(DeepGenerativeModel):
         )
         return next_noisy_images

-    def start_track_progress(self, diffusion_steps):
+    def start_track_progress(self, diffusion_steps, initial_step=0):
         """Initialize the progress tracking for the diffusion process.
-
         For diffusion animation we keep track of the diffusion progress.
         For large number of steps, we do not store all the images due to memory constraints.
         """
         self.track_progress = []
-        if diffusion_steps > 50:
-            self.track_progress_interval = diffusion_steps // 50
+        remaining = max(1, diffusion_steps - int(initial_step))
+        if remaining > 50:
+            self.track_progress_interval = remaining // 50
         else:
             self.track_progress_interval = 1

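
A small worked sketch of the new initial_step argument (illustrative, assuming a DiffusionModel instance `model`): the tracking interval is now derived from the steps that remain rather than the full schedule.

    model.start_track_progress(diffusion_steps=500)                    # remaining = 500 -> interval = 500 // 50 = 10
    model.start_track_progress(diffusion_steps=500, initial_step=450)  # remaining = 50  -> interval = 1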
zea/models/echonetlvh.py
ADDED
@@ -0,0 +1,290 @@
+"""EchoNetLVH model for segmentation of PLAX view cardiac ultrasound. For more details see https://echonet.github.io/lvh/index.html."""
+
+import numpy as np
+from keras import ops
+
+from zea.internal.registry import model_registry
+from zea.models.base import BaseModel
+from zea.models.deeplabv3 import DeeplabV3Plus
+from zea.models.preset_utils import register_presets
+from zea.models.presets import echonet_lvh_presets
+from zea.utils import translate
+
+
+@model_registry(name="echonetlvh")
+class EchoNetLVH(BaseModel):
+    """
+    EchoNet Left Ventricular Hypertrophy (LVH) model for echocardiogram analysis.
+
+    This model performs semantic segmentation on echocardiogram images to identify
+    key anatomical landmarks for measuring left ventricular wall thickness:
+    - LVPWd_1: Left Ventricular Posterior Wall point 1
+    - LVPWd_2: Left Ventricular Posterior Wall point 2
+    - IVSd_1: Interventricular Septum point 1
+    - IVSd_2: Interventricular Septum point 2
+
+    The model outputs 4-channel logits corresponding to heatmaps for each landmark.
+
+    For more information, see the original project page at https://echonet.github.io/lvh/index.html
+    """
+
+    def __init__(self, **kwargs):
+        """
+        Initialize the EchoNetLVH model.
+
+        Args:
+            **kwargs: Additional keyword arguments passed to BaseModel
+        """
+        super().__init__(**kwargs)
+
+        # Scan conversion constants for echonet processing
+        self.rho_range = (0, 224)  # Radial distance range in pixels
+        self.theta_range = (np.deg2rad(-45), np.deg2rad(45))  # Angular range in radians
+        self.fill_value = -1.0  # Fill value for scan conversion
+        self.resolution = 1.0  # mm per pixel resolution
+
+        # Network input/output dimensions
+        self.n_rho = 224
+        self.n_theta = 224
+        self.output_shape = (224, 224, 4)
+
+        # Pre-computed coordinate grid for efficient processing
+        self.coordinate_grid = ops.stack(
+            ops.cast(ops.convert_to_tensor(np.indices((224, 224))), "float32"), axis=-1
+        )
+
+        # Initialize the underlying segmentation network
+        self.network = DeeplabV3Plus(image_shape=(224, 224, 3), num_classes=4)
+
+    def call(self, inputs):
+        """
+        Forward pass of the model.
+
+        Args:
+            inputs (Tensor): Input images of shape [B, H, W, C]. They should
+                be scan converted, with pixel values in range [0, 255].
+
+        Returns:
+            Tensor: Logits of shape [B, H, W, 4] with 4 channels for each landmark
+        """
+        assert len(ops.shape(inputs)) == 4
+
+        # Store original dimensions for output resizing
+        original_size = ops.shape(inputs)[1:3]
+
+        # Resize to network input size
+        inputs_resized = ops.image.resize(inputs, size=(224, 224))
+
+        # Get network predictions
+        logits = self.network(inputs_resized)
+
+        # Resize logits back to original input dimensions
+        logits_output = ops.image.resize(logits, original_size)
+        return logits_output
+
+    def extract_key_points_as_indices(self, logits):
+        """
+        Extract key point coordinates from logits using center-of-mass calculation.
+
+        Args:
+            logits (Tensor): Model output logits of shape [B, H, W, 4]
+
+        Returns:
+            Tensor: Key point coordinates of shape [B, 4, 2] where each point
+                is in (x, y) format
+        """
+        # Create coordinate grid for the current logit dimensions
+        input_shape = ops.shape(logits)[1:3]
+        input_space_coordinate_grid = ops.stack(
+            ops.cast(ops.convert_to_tensor(np.indices(input_shape)), "float32"), axis=-1
+        )
+
+        # Transpose logits to [B, 4, H, W] for vectorized processing
+        logits_batchified = ops.transpose(logits, (0, 3, 1, 2))
+
+        # Extract expected coordinates for each channel
+        return ops.flip(
+            ops.vectorized_map(
+                lambda logit: self.expected_coordinate(logit, input_space_coordinate_grid),
+                logits_batchified,
+            ),
+            axis=-1,  # Flip to convert from (y, x) to (x, y)
+        )
+
+    def expected_coordinate(self, mask, coordinate_grid=None):
+        """
+        Compute the expected coordinate (center-of-mass) of a heatmap.
+
+        This implements a differentiable version of taking the max of a heatmap
+        by computing the weighted average of coordinates.
+
+        Reference: https://arxiv.org/pdf/1711.08229
+
+        Args:
+            mask (Tensor): Heatmap of shape [B, H, W]
+            coordinate_grid (Tensor, optional): Grid of coordinates. If None,
+                uses self.coordinate_grid
+
+        Returns:
+            Tensor: Expected coordinates of shape [B, 2] in (x, y) format
+        """
+        if coordinate_grid is None:
+            coordinate_grid = self.coordinate_grid
+
+        # Ensure mask values are non-negative and normalized
+        mask_clipped = ops.clip(mask, 0, None)
+        mask_normed = mask_clipped / ops.max(mask_clipped)
+
+        def safe_normalize(m):
+            mask_sum = ops.sum(m)
+            return ops.where(mask_sum > 0, m / mask_sum, m)
+
+        coordinate_probabilities = ops.map(safe_normalize, mask_normed)
+
+        # Add dimension for broadcasting with coordinate grid
+        coordinate_probabilities = ops.expand_dims(coordinate_probabilities, axis=-1)
+
+        # Compute weighted average of coordinates
+        expected_coordinate = ops.sum(
+            ops.expand_dims(coordinate_grid, axis=0) * coordinate_probabilities,
+            axis=(1, 2),
+        )
+
+        # Flip to convert from (y, x) to (x, y) format for euclidean distance calculation
+        return ops.flip(expected_coordinate, axis=-1)
+
+    def overlay_labels_on_image(self, image, label, alpha=0.5):
+        """
+        Overlay predicted heatmaps and connecting lines on the input image.
+
+        Args:
+            image (Tensor): Input image of shape [H, W] or [H, W, C]
+            label (Tensor): Predicted logits of shape [H, W, 4]
+            alpha (float): Blending factor for overlay (0=transparent, 1=opaque)
+
+        Returns:
+            ndarray: Image with overlaid heatmaps and measurements of shape [H, W, 3]
+        """
+        try:
+            import cv2
+
+        except ImportError as exc:
+            raise ImportError(
+                "OpenCV is required for `EchoNetLVH.overlay_labels_on_image`. "
+                "Please install it with 'pip install opencv-python' or "
+                "'pip install opencv-python-headless'."
+            ) from exc
+
+        # Color scheme for each landmark
+        overlay_colors = np.array(
+            [
+                [1, 1, 0],  # Yellow (LVPWd_X1)
+                [1, 0, 1],  # Magenta (LVPWd_X2)
+                [0, 1, 1],  # Cyan (IVSd_X1)
+                [0, 1, 0],  # Green (IVSd_X2)
+            ],
+        )
+
+        # Convert to numpy and ensure RGB format
+        image = ops.convert_to_numpy(image)
+        label = ops.convert_to_numpy(label)
+
+        if image.ndim == 2:
+            image = np.stack([image] * 3, axis=-1)
+        elif image.shape[-1] == 1:
+            image = np.repeat(image, 3, axis=-1)
+        else:
+            image = image.copy()
+
+        # Normalize each channel to [0, 1] for proper visualization
+        label = np.clip(label, 0, None)
+        for ch in range(label.shape[-1]):
+            max_val = np.max(label[..., ch])
+            if max_val > 0:
+                label[..., ch] = label[..., ch] / max_val
+
+        # Initialize overlay and tracking variables
+        overlay = np.zeros_like(image, dtype=np.float32)
+        centers = []
+
+        # Process each landmark channel
+        for ch in range(4):
+            # Square the mask to enhance peak responses
+            mask = label[..., ch] ** 2
+            color = overlay_colors[ch]
+
+            # Find center of mass for this channel
+            center_coords = self.expected_coordinate(ops.expand_dims(mask, axis=0))
+            center_x = ops.convert_to_numpy(center_coords[0, 0])
+            center_y = ops.convert_to_numpy(center_coords[0, 1])
+
+            # Bounds check before conversion to int
+            if 0 <= center_x < image.shape[1] and 0 <= center_y < image.shape[0]:
+                center = (int(center_x), int(center_y))
+            else:
+                center = None
+
+            if center is not None:
+                # Blend heatmap with overlay
+                mask_alpha = mask * alpha
+                for c in range(3):
+                    overlay[..., c] += mask_alpha * color[c]
+            centers.append(center)
+
+        # Draw connecting lines between consecutive landmarks
+        for i in range(3):
+            pt1, pt2 = centers[i], centers[i + 1]
+            if pt1 is not None and pt2 is not None:
+                color = tuple(int(x) for x in overlay_colors[i])
+
+                # Create line mask
+                line_mask = np.zeros(image.shape[:2], dtype=np.uint8)
+                cv2.line(line_mask, pt1, pt2, color=1, thickness=2)
+
+                # Apply line to overlay
+                for c in range(3):
+                    overlay[..., c][line_mask.astype(bool)] = color[c] * alpha
+
+        # Blend overlay with original image
+        overlay = np.clip(overlay, 0, 1)
+        out = image.astype(np.float32)
+        blend_mask = np.any(overlay > 0.02, axis=-1)
+        out[blend_mask] = (1 - alpha) * out[blend_mask] + overlay[blend_mask]
+
+        return np.clip(out, 0, 1)
+
+    def visualize_logits(self, images, logits):
+        """
+        Create visualization of model predictions overlaid on input images.
+
+        Args:
+            images (Tensor): Input images of shape [B, H, W, C]
+            logits (Tensor): Model predictions of shape [B, H, W, 4]
+
+        Returns:
+            Tensor: Images with overlaid predictions of shape [B, H, W, 3]
+        """
+        # Store original dimensions for final output
+        original_size = ops.shape(images)[1:3]
+
+        # Resize to standard processing size
+        images_resized = ops.image.resize(images, size=(224, 224), interpolation="nearest")
+        logits_resized = ops.image.resize(logits, size=(224, 224), interpolation="nearest")
+
+        # Normalize images to [0, 1] range
+        images_clipped = ops.clip(images_resized, 0, 255)
+        images = translate(images_clipped, range_from=(0, 255), range_to=(0, 1))
+
+        # Generate overlays for each image in the batch
+        images_with_overlay = []
+        for img, logit_heatmap in zip(images, logits_resized):
+            overlay = self.overlay_labels_on_image(img, logit_heatmap)
+            images_with_overlay.append(overlay)
+
+        # Stack results and resize back to original dimensions
+        images_with_overlay = np.stack(images_with_overlay, axis=0)
+        return ops.image.resize(images_with_overlay, original_size)
+
+
+# Register model presets
+register_presets(echonet_lvh_presets, EchoNetLVH)
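
A hedged usage sketch of the new class (not part of the diff; random weights, with shapes taken from the docstrings above, and a 3-channel input to match the DeeplabV3Plus(image_shape=(224, 224, 3), ...) backbone):

    import numpy as np
    from zea.models.echonetlvh import EchoNetLVH

    model = EchoNetLVH()
    images = np.random.uniform(0, 255, size=(1, 224, 224, 3)).astype("float32")  # scan-converted, [0, 255]
    logits = model(images)                                   # (1, 224, 224, 4) landmark heatmaps
    keypoints = model.extract_key_points_as_indices(logits)  # (1, 4, 2), (x, y) per landmark
    overlayed = model.visualize_logits(images, logits)       # (1, 224, 224, 3) for inspection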
zea/models/presets.py
CHANGED
@@ -47,6 +47,20 @@ echonet_dynamic_presets = {
     },
 }

+echonet_lvh_presets = {
+    "echonetlvh": {
+        "metadata": {
+            "description": (
+                "EchoNetLVH segmentation model for PLAX-view cardiac ultrasound segmentation. "
+                "Trained on images of size (224, 224)."
+            ),
+            "params": 0,
+            "path": "echonetlvh",
+        },
+        "hf_handle": "hf://zeahub/echonetlvh",
+    },
+}
+
 lpips_presets = {
     "lpips": {
         "metadata": {
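
The new preset entry follows the same layout as the existing ones; the preset key is what `from_preset` resolves (a sketch mirroring the `from_preset` calls documented in zea/models/__init__.py), with weights hosted at hf://zeahub/echonetlvh:

    from zea.models.echonetlvh import EchoNetLVH

    model = EchoNetLVH.from_preset("echonetlvh")  # fetches pretrained weights from hf://zeahub/echonetlvh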
zea/ops.py
CHANGED
@@ -1376,20 +1376,6 @@ class Mean(Operation):
         return kwargs


-@ops_registry("transpose")
-class Transpose(Operation):
-    """Transpose the input data along the specified axes."""
-
-    def __init__(self, axes, **kwargs):
-        super().__init__(**kwargs)
-        self.axes = axes
-
-    def call(self, **kwargs):
-        data = kwargs[self.key]
-        transposed_data = ops.transpose(data, axes=self.axes)
-        return {self.output_key: transposed_data}
-
-
 @ops_registry("simulate_rf")
 class Simulate(Operation):
     """Simulate RF data."""
@@ -1578,19 +1564,6 @@ class PfieldWeighting(Operation):
         return {self.output_key: weighted_data}


-@ops_registry("sum")
-class Sum(Operation):
-    """Sum data along a specific axis."""
-
-    def __init__(self, axis, **kwargs):
-        super().__init__(**kwargs)
-        self.axis = axis
-
-    def call(self, **kwargs):
-        data = kwargs[self.key]
-        return {self.output_key: ops.sum(data, axis=self.axis)}
-
-
 @ops_registry("delay_and_sum")
 class DelayAndSum(Operation):
     """Sums time-delayed signals along channels and transmits."""
@@ -2124,29 +2097,37 @@ class Demodulate(Operation):
 class Lambda(Operation):
     """Use any function as an operation."""

-    def __init__(self, func,
-
-
-
-
-    def call(self, **kwargs):
-        data = kwargs[self.key]
-        data = self.func(data)
-        return {self.output_key: data}
-
+    def __init__(self, func, **kwargs):
+        # Split kwargs into kwargs for partial and __init__
+        op_kwargs = {k: v for k, v in kwargs.items() if k not in func.__code__.co_varnames}
+        func_kwargs = {k: v for k, v in kwargs.items() if k in func.__code__.co_varnames}
+        Lambda._check_if_unary(func, **func_kwargs)

-
-
-    """Clip the input data to a given range."""
+        super().__init__(**op_kwargs)
+        self.func = partial(func, **func_kwargs)

-
-
-
-
+    @staticmethod
+    def _check_if_unary(func, **kwargs):
+        """Checks if the kwargs are sufficient to call the function as a unary operation."""
+        sig = inspect.signature(func)
+        # Remove arguments that are already provided in func_kwargs
+        params = list(sig.parameters.values())
+        remaining = [p for p in params if p.name not in kwargs]
+        # Count required positional arguments (excluding self/cls)
+        required_positional = [
+            p
+            for p in remaining
+            if p.default is p.empty and p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD)
+        ]
+        if len(required_positional) != 1:
+            raise ValueError(
+                f"Partial of {func.__name__} must be callable with exactly one required "
+                f"positional argument, we still need: {required_positional}."
+            )

     def call(self, **kwargs):
         data = kwargs[self.key]
-        data =
+        data = self.func(data)
         return {self.output_key: data}


@@ -2685,6 +2666,7 @@ class AnisotropicDiffusion(Operation):
         return result


+@ops_registry("channels_to_complex")
 class ChannelsToComplex(Operation):
     def call(self, **kwargs):
         data = kwargs[self.key]
@@ -2692,6 +2674,7 @@ class ChannelsToComplex(Operation):
         return {self.output_key: output}


+@ops_registry("complex_to_channels")
 class ComplexToChannels(Operation):
     def __init__(self, axis=-1, **kwargs):
         super().__init__(**kwargs)
zea/scan.py
CHANGED
@@ -149,12 +149,13 @@ class Scan(Parameters):
             Defaults to 0.0.
         attenuation_coef (float, optional): Attenuation coefficient in dB/(MHz*cm).
             Defaults to 0.0.
-        selected_transmits (None, str, int, list, or np.ndarray, optional):
+        selected_transmits (None, str, int, list, slice, or np.ndarray, optional):
             Specifies which transmit events to select.
             - None or "all": Use all transmits.
            - "center": Use only the center transmit.
             - int: Select this many evenly spaced transmits.
             - list/array: Use these specific transmit indices.
+            - slice: Use transmits specified by the slice (e.g., slice(0, 10, 2)).
         grid_type (str, optional): Type of grid to use for beamforming.
             Can be "cartesian" or "polar". Defaults to "cartesian".
         dynamic_range (tuple, optional): Dynamic range for image display.
@@ -171,13 +172,14 @@
         "pixels_per_wavelength": {"type": int, "default": 4},
         "pfield_kwargs": {"type": dict, "default": {}},
         "apply_lens_correction": {"type": bool, "default": False},
-        "lens_sound_speed": {"type":
+        "lens_sound_speed": {"type": float},
         "lens_thickness": {"type": float},
         "grid_type": {"type": str, "default": "cartesian"},
         "polar_limits": {"type": (tuple, list)},
         "dynamic_range": {"type": (tuple, list), "default": DEFAULT_DYNAMIC_RANGE},
+        "selected_transmits": {"type": (type(None), str, int, list, slice, np.ndarray)},
         # acquisition parameters
-        "sound_speed": {"type":
+        "sound_speed": {"type": float, "default": 1540.0},
         "sampling_frequency": {"type": float},
         "center_frequency": {"type": float},
         "n_el": {"type": int},
@@ -359,6 +361,7 @@
         - "center": Use only the center transmit
         - int: Select this many evenly spaced transmits
         - list/array: Use these specific transmit indices
+        - slice: Use transmits specified by the slice (e.g., slice(0, 10, 2))

         Returns:
             The current instance for method chaining.
@@ -416,6 +419,10 @@
             self._invalidate_dependents("selected_transmits")
             return self

+        # Handle slice - convert to list of indices
+        if isinstance(selection, slice):
+            selection = list(range(n_tx_total))[selection]
+
         # Handle list of indices
         if isinstance(selection, list):
             # Validate indices