zea 0.0.5__py3-none-any.whl → 0.0.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- zea/__init__.py +54 -19
- zea/agent/__init__.py +12 -12
- zea/agent/masks.py +2 -1
- zea/agent/selection.py +166 -0
- zea/backend/__init__.py +89 -0
- zea/backend/jax/__init__.py +14 -51
- zea/backend/tensorflow/__init__.py +0 -49
- zea/backend/tensorflow/dataloader.py +2 -1
- zea/backend/torch/__init__.py +27 -62
- zea/beamform/beamformer.py +100 -50
- zea/beamform/lens_correction.py +9 -2
- zea/beamform/pfield.py +9 -2
- zea/config.py +34 -25
- zea/data/__init__.py +22 -16
- zea/data/convert/camus.py +2 -1
- zea/data/convert/echonet.py +4 -4
- zea/data/convert/echonetlvh/convert_raw_to_usbmd.py +1 -1
- zea/data/convert/matlab.py +11 -4
- zea/data/data_format.py +31 -30
- zea/data/datasets.py +7 -5
- zea/data/file.py +104 -2
- zea/data/layers.py +5 -6
- zea/datapaths.py +16 -4
- zea/display.py +7 -5
- zea/interface.py +14 -16
- zea/internal/_generate_keras_ops.py +6 -7
- zea/internal/cache.py +2 -49
- zea/internal/config/validation.py +1 -2
- zea/internal/core.py +69 -6
- zea/internal/device.py +6 -2
- zea/internal/dummy_scan.py +330 -0
- zea/internal/operators.py +114 -2
- zea/internal/parameters.py +101 -70
- zea/internal/registry.py +1 -1
- zea/internal/setup_zea.py +5 -6
- zea/internal/utils.py +282 -0
- zea/io_lib.py +247 -19
- zea/keras_ops.py +74 -4
- zea/log.py +9 -7
- zea/metrics.py +365 -65
- zea/models/__init__.py +30 -20
- zea/models/base.py +30 -14
- zea/models/carotid_segmenter.py +19 -4
- zea/models/diffusion.py +187 -26
- zea/models/echonet.py +22 -8
- zea/models/echonetlvh.py +31 -18
- zea/models/lpips.py +19 -2
- zea/models/lv_segmentation.py +96 -0
- zea/models/preset_utils.py +5 -5
- zea/models/presets.py +36 -0
- zea/models/regional_quality.py +142 -0
- zea/models/taesd.py +21 -5
- zea/models/unet.py +15 -1
- zea/ops.py +414 -207
- zea/probes.py +6 -6
- zea/scan.py +109 -49
- zea/simulator.py +24 -21
- zea/tensor_ops.py +411 -206
- zea/tools/hf.py +1 -1
- zea/tools/selection_tool.py +47 -86
- zea/utils.py +92 -480
- zea/visualize.py +177 -39
- {zea-0.0.5.dist-info → zea-0.0.7.dist-info}/METADATA +9 -3
- zea-0.0.7.dist-info/RECORD +114 -0
- {zea-0.0.5.dist-info → zea-0.0.7.dist-info}/WHEEL +1 -1
- zea-0.0.5.dist-info/RECORD +0 -110
- {zea-0.0.5.dist-info → zea-0.0.7.dist-info}/entry_points.txt +0 -0
- {zea-0.0.5.dist-info → zea-0.0.7.dist-info/licenses}/LICENSE +0 -0
zea/models/carotid_segmenter.py
CHANGED
|
@@ -1,9 +1,24 @@
|
|
|
1
1
|
"""Carotid segmentation model.
|
|
2
2
|
|
|
3
|
-
|
|
4
|
-
|
|
5
|
-
|
|
6
|
-
|
|
3
|
+
To try this model, simply load one of the available presets:
|
|
4
|
+
|
|
5
|
+
.. doctest::
|
|
6
|
+
|
|
7
|
+
>>> from zea.models.carotid_segmenter import CarotidSegmenter
|
|
8
|
+
|
|
9
|
+
>>> model = CarotidSegmenter.from_preset("carotid-segmenter")
|
|
10
|
+
|
|
11
|
+
.. important::
|
|
12
|
+
This is a ``zea`` implementation of the model.
|
|
13
|
+
For the original paper see:
|
|
14
|
+
|
|
15
|
+
van Knippenberg, Luuk, et al.
|
|
16
|
+
"Unsupervised domain adaptation method for segmenting cross-sectional CCA images."
|
|
17
|
+
*https://doi.org/10.1016/j.cmpb.2022.107037*
|
|
18
|
+
|
|
19
|
+
.. seealso::
|
|
20
|
+
A tutorial notebook where this model is used:
|
|
21
|
+
:doc:`../notebooks/models/carotid_segmentation_example`.
|
|
7
22
|
"""
|
|
8
23
|
|
|
9
24
|
import keras
|
zea/models/diffusion.py
CHANGED
|
@@ -1,4 +1,15 @@
|
|
|
1
|
-
"""
|
|
1
|
+
"""
|
|
2
|
+
Diffusion models for ultrasound image generation and posterior sampling.
|
|
3
|
+
|
|
4
|
+
To try this model, simply load one of the available presets:
|
|
5
|
+
|
|
6
|
+
.. doctest::
|
|
7
|
+
|
|
8
|
+
>>> from zea.models.diffusion import DiffusionModel
|
|
9
|
+
|
|
10
|
+
>>> model = DiffusionModel.from_preset("diffusion-echonet-dynamic") # doctest: +SKIP
|
|
11
|
+
|
|
12
|
+
"""
|
|
2
13
|
|
|
3
14
|
import abc
|
|
4
15
|
from typing import Literal
|
|
@@ -6,11 +17,12 @@ from typing import Literal
|
|
|
6
17
|
import keras
|
|
7
18
|
from keras import ops
|
|
8
19
|
|
|
9
|
-
from zea.backend import _import_tf
|
|
20
|
+
from zea.backend import _import_tf, jit
|
|
10
21
|
from zea.backend.autograd import AutoGrad
|
|
11
22
|
from zea.internal.core import Object
|
|
12
23
|
from zea.internal.operators import Operator
|
|
13
24
|
from zea.internal.registry import diffusion_guidance_registry, model_registry, operator_registry
|
|
25
|
+
from zea.internal.utils import fn_requires_argument
|
|
14
26
|
from zea.models.dense import get_time_conditional_dense_network
|
|
15
27
|
from zea.models.generative import DeepGenerativeModel
|
|
16
28
|
from zea.models.preset_utils import register_presets
|
|
@@ -18,7 +30,6 @@ from zea.models.presets import diffusion_model_presets
|
|
|
18
30
|
from zea.models.unet import get_time_conditional_unetwork
|
|
19
31
|
from zea.models.utils import LossTrackerWrapper
|
|
20
32
|
from zea.tensor_ops import L2, fori_loop, split_seed
|
|
21
|
-
from zea.utils import fn_requires_argument
|
|
22
33
|
|
|
23
34
|
tf = _import_tf()
|
|
24
35
|
|
|
@@ -182,7 +193,7 @@ class DiffusionModel(DeepGenerativeModel):
|
|
|
182
193
|
**kwargs: Additional arguments.
|
|
183
194
|
|
|
184
195
|
Returns:
|
|
185
|
-
Generated samples
|
|
196
|
+
Generated samples of shape `(n_samples, *input_shape)`.
|
|
186
197
|
"""
|
|
187
198
|
seed, seed1 = split_seed(seed, 2)
|
|
188
199
|
|
|
@@ -233,7 +244,8 @@ class DiffusionModel(DeepGenerativeModel):
|
|
|
233
244
|
`(batch_size, n_samples, *input_shape)`.
|
|
234
245
|
|
|
235
246
|
"""
|
|
236
|
-
|
|
247
|
+
batch_size = ops.shape(measurements)[0]
|
|
248
|
+
shape = (batch_size, n_samples, *self.input_shape)
|
|
237
249
|
|
|
238
250
|
def _tile_with_sample_dim(tensor):
|
|
239
251
|
"""Tile the tensor with an additional sample dimension."""
|
|
@@ -250,7 +262,7 @@ class DiffusionModel(DeepGenerativeModel):
|
|
|
250
262
|
seed1, seed2 = split_seed(seed, 2)
|
|
251
263
|
|
|
252
264
|
initial_noise = keras.random.normal(
|
|
253
|
-
shape=
|
|
265
|
+
shape=(batch_size * n_samples, *self.input_shape),
|
|
254
266
|
seed=seed1,
|
|
255
267
|
)
|
|
256
268
|
|
|
@@ -262,9 +274,9 @@ class DiffusionModel(DeepGenerativeModel):
|
|
|
262
274
|
initial_step=initial_step,
|
|
263
275
|
seed=seed2,
|
|
264
276
|
**kwargs,
|
|
265
|
-
)
|
|
266
|
-
|
|
267
|
-
return ops.reshape(out, (
|
|
277
|
+
) # (batch_size * n_samples, *self.input_shape)
|
|
278
|
+
|
|
279
|
+
return ops.reshape(out, shape) # (batch_size, n_samples, *input_shape)
|
|
268
280
|
|
|
269
281
|
def log_likelihood(self, data, **kwargs):
|
|
270
282
|
"""Approximate log-likelihood of the data under the model.
|
|
@@ -776,7 +788,12 @@ register_presets(diffusion_model_presets, DiffusionModel)
|
|
|
776
788
|
class DiffusionGuidance(abc.ABC, Object):
|
|
777
789
|
"""Base class for diffusion guidance methods."""
|
|
778
790
|
|
|
779
|
-
def __init__(
|
|
791
|
+
def __init__(
|
|
792
|
+
self,
|
|
793
|
+
diffusion_model: DiffusionModel,
|
|
794
|
+
operator: Operator,
|
|
795
|
+
disable_jit: bool = False,
|
|
796
|
+
):
|
|
780
797
|
"""Initialize the diffusion guidance.
|
|
781
798
|
|
|
782
799
|
Args:
|
|
@@ -823,12 +840,12 @@ class DPS(DiffusionGuidance):
|
|
|
823
840
|
omega,
|
|
824
841
|
**kwargs,
|
|
825
842
|
):
|
|
826
|
-
"""
|
|
843
|
+
"""
|
|
844
|
+
Compute measurement error for diffusion posterior sampling.
|
|
827
845
|
|
|
828
846
|
Args:
|
|
829
847
|
noisy_images: Noisy images.
|
|
830
|
-
|
|
831
|
-
operator: Forward operator.
|
|
848
|
+
measurements: Target measurement.
|
|
832
849
|
noise_rates: Current noise rates.
|
|
833
850
|
signal_rates: Current signal rates.
|
|
834
851
|
omega: Weight for the measurement error.
|
|
@@ -849,20 +866,164 @@ class DPS(DiffusionGuidance):
|
|
|
849
866
|
return measurement_error, (pred_noises, pred_images)
|
|
850
867
|
|
|
851
868
|
def __call__(self, noisy_images, **kwargs):
|
|
852
|
-
"""
|
|
853
|
-
|
|
854
|
-
Returns a function with the following signature:
|
|
855
|
-
(
|
|
856
|
-
noisy_images,
|
|
857
|
-
measurement,
|
|
858
|
-
operator,
|
|
859
|
-
noise_rates,
|
|
860
|
-
signal_rates,
|
|
861
|
-
omega,
|
|
862
|
-
**operator_kwargs,
|
|
863
|
-
) -> gradients, (error, (pred_noises, pred_images))
|
|
869
|
+
"""
|
|
870
|
+
Call the gradient function.
|
|
864
871
|
|
|
865
|
-
|
|
872
|
+
Args:
|
|
873
|
+
noisy_images: Noisy images.
|
|
874
|
+
measurements: Target measurement.
|
|
875
|
+
operator: Forward operator.
|
|
876
|
+
noise_rates: Current noise rates.
|
|
877
|
+
signal_rates: Current signal rates.
|
|
878
|
+
omega: Weight for the measurement error.
|
|
879
|
+
**kwargs: Additional arguments for the operator.
|
|
866
880
|
|
|
881
|
+
Returns:
|
|
882
|
+
Tuple of (gradients, (measurement_error, (pred_noises, pred_images)))
|
|
867
883
|
"""
|
|
868
884
|
return self.gradient_fn(noisy_images, **kwargs)
|
|
885
|
+
|
|
886
|
+
|
|
887
|
+
@diffusion_guidance_registry(name="dds")
|
|
888
|
+
class DDS(DiffusionGuidance):
|
|
889
|
+
"""
|
|
890
|
+
Decomposed Diffusion Sampling guidance.
|
|
891
|
+
|
|
892
|
+
Reference paper: https://arxiv.org/pdf/2303.05754
|
|
893
|
+
"""
|
|
894
|
+
|
|
895
|
+
def setup(self):
|
|
896
|
+
"""Setup DDS guidance function."""
|
|
897
|
+
if not self.disable_jit:
|
|
898
|
+
self.call = jit(self.call)
|
|
899
|
+
|
|
900
|
+
def Acg(self, x, **op_kwargs):
|
|
901
|
+
# we transform the operator from A(x) to A.T(A(x)) to get the normal equations,
|
|
902
|
+
# so that it is suitable for conjugate gradient. (symmetric, positive definite)
|
|
903
|
+
# Normal equations: A^T y = A^T A x
|
|
904
|
+
return self.operator.transpose(self.operator.forward(x, **op_kwargs), **op_kwargs)
|
|
905
|
+
|
|
906
|
+
def conjugate_gradient_inner_loop(self, i, loop_state, eps=1e-5):
|
|
907
|
+
"""
|
|
908
|
+
A single iteration of the conjugate gradient method.
|
|
909
|
+
This involves minimizing the error of x along the current search
|
|
910
|
+
vector p, and then choosing the next search vector.
|
|
911
|
+
|
|
912
|
+
Reference code from: https://github.com/svi-diffusion/
|
|
913
|
+
"""
|
|
914
|
+
p, rs_old, r, x, eps, op_kwargs = loop_state
|
|
915
|
+
|
|
916
|
+
# compute alpha
|
|
917
|
+
Ap = self.Acg(p, **op_kwargs) # transform search vector p by A
|
|
918
|
+
a = rs_old / ops.sum(p * Ap) # minimize f along the line p
|
|
919
|
+
|
|
920
|
+
x_new = x + a * p # set new x at the minimum of f along line p
|
|
921
|
+
r_new = r - a * Ap # shortcut to compute next residual
|
|
922
|
+
|
|
923
|
+
# compute Gram-Schmidt coefficient beta to choose next search vector
|
|
924
|
+
# so that p_new is A-orthogonal to p_current.
|
|
925
|
+
rs_new = ops.sum(r_new * r_new)
|
|
926
|
+
p_new = r_new + (rs_new / rs_old) * p
|
|
927
|
+
|
|
928
|
+
# this is like a jittable 'break' -- if the residual
|
|
929
|
+
# is less than eps, then we just return the old
|
|
930
|
+
# loop state rather than the updated one.
|
|
931
|
+
next_loop_state = ops.cond(
|
|
932
|
+
ops.abs(ops.sqrt(rs_old)) < eps,
|
|
933
|
+
lambda: (p, rs_old, r, x, eps, op_kwargs),
|
|
934
|
+
lambda: (p_new, rs_new, r_new, x_new, eps, op_kwargs),
|
|
935
|
+
)
|
|
936
|
+
|
|
937
|
+
return next_loop_state
|
|
938
|
+
|
|
939
|
+
def call(
|
|
940
|
+
self,
|
|
941
|
+
noisy_images,
|
|
942
|
+
measurements,
|
|
943
|
+
noise_rates,
|
|
944
|
+
signal_rates,
|
|
945
|
+
n_inner,
|
|
946
|
+
eps,
|
|
947
|
+
verbose,
|
|
948
|
+
**op_kwargs,
|
|
949
|
+
):
|
|
950
|
+
"""
|
|
951
|
+
Call the DDS guidance function
|
|
952
|
+
|
|
953
|
+
Args:
|
|
954
|
+
noisy_images: Noisy images.
|
|
955
|
+
measurements: Target measurement.
|
|
956
|
+
noise_rates: Current noise rates.
|
|
957
|
+
signal_rates: Current signal rates.
|
|
958
|
+
n_inner: Number of conjugate gradient steps.
|
|
959
|
+
verbose: Whether to calculate error.
|
|
960
|
+
|
|
961
|
+
Returns:
|
|
962
|
+
Tuple of (gradients, (measurement_error, (pred_noises, pred_images)))
|
|
963
|
+
"""
|
|
964
|
+
pred_noises, pred_images = self.diffusion_model.denoise(
|
|
965
|
+
noisy_images,
|
|
966
|
+
noise_rates,
|
|
967
|
+
signal_rates,
|
|
968
|
+
training=False,
|
|
969
|
+
)
|
|
970
|
+
measurements_cg = self.operator.transpose(measurements, **op_kwargs)
|
|
971
|
+
r = measurements_cg - self.Acg(pred_images, **op_kwargs) # residual
|
|
972
|
+
p = ops.copy(r) # initial search vector = residual
|
|
973
|
+
rs_old = ops.sum(r * r) # residual dot product
|
|
974
|
+
_, _, _, pred_images_updated_cg, _, _ = fori_loop(
|
|
975
|
+
0,
|
|
976
|
+
n_inner,
|
|
977
|
+
self.conjugate_gradient_inner_loop,
|
|
978
|
+
(p, rs_old, r, pred_images, eps, op_kwargs),
|
|
979
|
+
)
|
|
980
|
+
|
|
981
|
+
# Not strictly necessary, just for debugging
|
|
982
|
+
error = ops.cond(
|
|
983
|
+
verbose,
|
|
984
|
+
lambda: L2(measurements - self.operator.forward(pred_images_updated_cg, **op_kwargs)),
|
|
985
|
+
lambda: 0.0,
|
|
986
|
+
)
|
|
987
|
+
|
|
988
|
+
pred_images = pred_images_updated_cg
|
|
989
|
+
# we have already performed the guidance steps in self.conjugate_gradient_inner_loop, so
|
|
990
|
+
# we can set these gradients to zero.
|
|
991
|
+
gradients = ops.zeros_like(pred_images)
|
|
992
|
+
return gradients, (error, (pred_noises, pred_images))
|
|
993
|
+
|
|
994
|
+
def __call__(
|
|
995
|
+
self,
|
|
996
|
+
noisy_images,
|
|
997
|
+
measurements,
|
|
998
|
+
noise_rates,
|
|
999
|
+
signal_rates,
|
|
1000
|
+
n_inner=5,
|
|
1001
|
+
eps=1e-5,
|
|
1002
|
+
verbose=False,
|
|
1003
|
+
**op_kwargs,
|
|
1004
|
+
):
|
|
1005
|
+
"""
|
|
1006
|
+
Call the DDS guidance function
|
|
1007
|
+
|
|
1008
|
+
Args:
|
|
1009
|
+
noisy_images: Noisy images.
|
|
1010
|
+
measurements: Target measurement.
|
|
1011
|
+
noise_rates: Current noise rates.
|
|
1012
|
+
signal_rates: Current signal rates.
|
|
1013
|
+
n_inner: Number of conjugate gradient steps.
|
|
1014
|
+
verbose: Whether to calculate error.
|
|
1015
|
+
**op_kwargs: Additional arguments for the operator.
|
|
1016
|
+
|
|
1017
|
+
Returns:
|
|
1018
|
+
Tuple of (gradients, (measurement_error, (pred_noises, pred_images)))
|
|
1019
|
+
"""
|
|
1020
|
+
return self.call(
|
|
1021
|
+
noisy_images,
|
|
1022
|
+
measurements,
|
|
1023
|
+
noise_rates,
|
|
1024
|
+
signal_rates,
|
|
1025
|
+
n_inner,
|
|
1026
|
+
eps,
|
|
1027
|
+
verbose,
|
|
1028
|
+
**op_kwargs,
|
|
1029
|
+
)
|
zea/models/echonet.py
CHANGED
|
@@ -1,6 +1,25 @@
|
|
|
1
|
-
"""
|
|
2
|
-
|
|
3
|
-
|
|
1
|
+
"""
|
|
2
|
+
Echonet-Dynamic segmentation model for cardiac ultrasound segmentation.
|
|
3
|
+
|
|
4
|
+
To try this model, simply load one of the available presets:
|
|
5
|
+
|
|
6
|
+
.. doctest::
|
|
7
|
+
|
|
8
|
+
>>> from zea.models.echonet import EchoNetDynamic
|
|
9
|
+
|
|
10
|
+
>>> model = EchoNetDynamic.from_preset("echonet-dynamic") # doctest: +SKIP
|
|
11
|
+
|
|
12
|
+
.. important::
|
|
13
|
+
This is a ``zea`` implementation of the model.
|
|
14
|
+
For the original paper and code, see `here <https://echonet.github.io/dynamic/>`_.
|
|
15
|
+
|
|
16
|
+
Ouyang, David, et al. "Video-based AI for beat-to-beat assessment of cardiac function."
|
|
17
|
+
*Nature 580.7802 (2020): 252-256*
|
|
18
|
+
|
|
19
|
+
.. seealso::
|
|
20
|
+
A tutorial notebook where this model is used:
|
|
21
|
+
:doc:`../notebooks/models/left_ventricle_segmentation_example`.
|
|
22
|
+
|
|
4
23
|
"""
|
|
5
24
|
|
|
6
25
|
from pathlib import Path
|
|
@@ -33,11 +52,6 @@ tf = _import_tf()
|
|
|
33
52
|
class EchoNetDynamic(BaseModel):
|
|
34
53
|
"""EchoNet-Dynamic segmentation model for cardiac ultrasound segmentation.
|
|
35
54
|
|
|
36
|
-
Original paper and code: https://echonet.github.io/dynamic/
|
|
37
|
-
|
|
38
|
-
This class extracts useful parts of the original code and wraps it in a
|
|
39
|
-
easy to use class.
|
|
40
|
-
|
|
41
55
|
Preprocessing should normalize the input images with mean and standard deviation.
|
|
42
56
|
|
|
43
57
|
"""
|
zea/models/echonetlvh.py
CHANGED
|
@@ -1,4 +1,26 @@
|
|
|
1
|
-
"""EchoNetLVH model for segmentation of PLAX view cardiac ultrasound.
|
|
1
|
+
"""EchoNetLVH model for segmentation of PLAX view cardiac ultrasound.
|
|
2
|
+
|
|
3
|
+
To try this model, simply load one of the available presets:
|
|
4
|
+
|
|
5
|
+
.. doctest::
|
|
6
|
+
|
|
7
|
+
>>> from zea.models.echonetlvh import EchoNetLVH
|
|
8
|
+
|
|
9
|
+
>>> model = EchoNetLVH.from_preset("echonetlvh")
|
|
10
|
+
|
|
11
|
+
.. important::
|
|
12
|
+
This is a ``zea`` implementation of the model.
|
|
13
|
+
For the original paper and code, see `here <https://echonet.github.io/lvh/>`_.
|
|
14
|
+
|
|
15
|
+
Duffy, Grant, et al.
|
|
16
|
+
"High-throughput precision phenotyping of left ventricular hypertrophy with cardiovascular deep learning."
|
|
17
|
+
*JAMA cardiology 7.4 (2022): 386-395*
|
|
18
|
+
|
|
19
|
+
.. seealso::
|
|
20
|
+
A tutorial notebook where this model is used:
|
|
21
|
+
:doc:`../notebooks/agent/task_based_perception_action_loop`.
|
|
22
|
+
|
|
23
|
+
""" # noqa: E501
|
|
2
24
|
|
|
3
25
|
import numpy as np
|
|
4
26
|
from keras import ops
|
|
@@ -8,7 +30,7 @@ from zea.models.base import BaseModel
|
|
|
8
30
|
from zea.models.deeplabv3 import DeeplabV3Plus
|
|
9
31
|
from zea.models.preset_utils import register_presets
|
|
10
32
|
from zea.models.presets import echonet_lvh_presets
|
|
11
|
-
from zea.
|
|
33
|
+
from zea.tensor_ops import translate
|
|
12
34
|
|
|
13
35
|
|
|
14
36
|
@model_registry(name="echonetlvh")
|
|
@@ -18,14 +40,16 @@ class EchoNetLVH(BaseModel):
|
|
|
18
40
|
|
|
19
41
|
This model performs semantic segmentation on echocardiogram images to identify
|
|
20
42
|
key anatomical landmarks for measuring left ventricular wall thickness:
|
|
21
|
-
|
|
22
|
-
-
|
|
23
|
-
-
|
|
24
|
-
-
|
|
43
|
+
|
|
44
|
+
- **LVPWd_1**: Left Ventricular Posterior Wall point 1
|
|
45
|
+
- **LVPWd_2**: Left Ventricular Posterior Wall point 2
|
|
46
|
+
- **IVSd_1**: Interventricular Septum point 1
|
|
47
|
+
- **IVSd_2**: Interventricular Septum point 2
|
|
25
48
|
|
|
26
49
|
The model outputs 4-channel logits corresponding to heatmaps for each landmark.
|
|
27
50
|
|
|
28
|
-
For more information, see the original project page
|
|
51
|
+
For more information, see the original project page:
|
|
52
|
+
https://echonet.github.io/lvh/
|
|
29
53
|
"""
|
|
30
54
|
|
|
31
55
|
def __init__(self, **kwargs):
|
|
@@ -37,17 +61,6 @@ class EchoNetLVH(BaseModel):
|
|
|
37
61
|
"""
|
|
38
62
|
super().__init__(**kwargs)
|
|
39
63
|
|
|
40
|
-
# Scan conversion constants for echonet processing
|
|
41
|
-
self.rho_range = (0, 224) # Radial distance range in pixels
|
|
42
|
-
self.theta_range = (np.deg2rad(-45), np.deg2rad(45)) # Angular range in radians
|
|
43
|
-
self.fill_value = -1.0 # Fill value for scan conversion
|
|
44
|
-
self.resolution = 1.0 # mm per pixel resolution
|
|
45
|
-
|
|
46
|
-
# Network input/output dimensions
|
|
47
|
-
self.n_rho = 224
|
|
48
|
-
self.n_theta = 224
|
|
49
|
-
self.output_shape = (224, 224, 4)
|
|
50
|
-
|
|
51
64
|
# Pre-computed coordinate grid for efficient processing
|
|
52
65
|
self.coordinate_grid = ops.stack(
|
|
53
66
|
ops.cast(ops.convert_to_tensor(np.indices((224, 224))), "float32"), axis=-1
|
zea/models/lpips.py
CHANGED
|
@@ -1,7 +1,24 @@
|
|
|
1
1
|
"""LPIPS model for perceptual similarity.
|
|
2
2
|
|
|
3
|
-
|
|
4
|
-
|
|
3
|
+
To try this model, simply load one of the available presets:
|
|
4
|
+
|
|
5
|
+
.. doctest::
|
|
6
|
+
|
|
7
|
+
>>> from zea.models.lpips import LPIPS
|
|
8
|
+
|
|
9
|
+
>>> model = LPIPS.from_preset("lpips")
|
|
10
|
+
|
|
11
|
+
.. important::
|
|
12
|
+
This is a ``zea`` implementation of the model.
|
|
13
|
+
For the original paper and code, see `here <https://github.com/richzhang/PerceptualSimilarity>`_.
|
|
14
|
+
|
|
15
|
+
Zhang, Richard, et al.
|
|
16
|
+
"The Unreasonable Effectiveness of Deep Features as a Perceptual Metric."
|
|
17
|
+
*https://arxiv.org/abs/1801.03924*
|
|
18
|
+
|
|
19
|
+
.. seealso::
|
|
20
|
+
A tutorial notebook where this model is used:
|
|
21
|
+
:doc:`../notebooks/metrics/lpips_example`.
|
|
5
22
|
|
|
6
23
|
"""
|
|
7
24
|
|
|
@@ -0,0 +1,96 @@
|
|
|
1
|
+
"""
|
|
2
|
+
nnU-Net segmentation model trained on the augmented CAMUS dataset.
|
|
3
|
+
|
|
4
|
+
To try this model, simply load one of the available presets:
|
|
5
|
+
|
|
6
|
+
.. doctest::
|
|
7
|
+
|
|
8
|
+
>>> from zea.models.lv_segmentation import AugmentedCamusSeg
|
|
9
|
+
|
|
10
|
+
>>> model = AugmentedCamusSeg.from_preset("augmented_camus_seg")
|
|
11
|
+
|
|
12
|
+
The model segments both the left ventricle and myocardium.
|
|
13
|
+
|
|
14
|
+
At the time of writing (17 September 2025) and to the best of our knowledge,
|
|
15
|
+
it is the state-of-the-art model for left ventricle segmentation on the CAMUS dataset.
|
|
16
|
+
|
|
17
|
+
.. important::
|
|
18
|
+
This is a ``zea`` implementation of the model.
|
|
19
|
+
For the original paper and code, see `here <https://github.com/GillesVanDeVyver/EchoGAINS>`_.
|
|
20
|
+
|
|
21
|
+
Van De Vyver, Gilles, et al.
|
|
22
|
+
"Generative augmentations for improved cardiac ultrasound segmentation using diffusion models."
|
|
23
|
+
*https://arxiv.org/abs/2502.20100*
|
|
24
|
+
|
|
25
|
+
.. seealso::
|
|
26
|
+
A tutorial notebook where this model is used:
|
|
27
|
+
:doc:`../notebooks/models/left_ventricle_segmentation_example`.
|
|
28
|
+
|
|
29
|
+
.. note::
|
|
30
|
+
The model is originally a PyTorch model converted to ONNX. To use this model, you must have `onnxruntime` installed. This is required for ONNX model inference.
|
|
31
|
+
|
|
32
|
+
You can install it using pip:
|
|
33
|
+
|
|
34
|
+
.. code-block:: bash
|
|
35
|
+
|
|
36
|
+
pip install onnxruntime
|
|
37
|
+
|
|
38
|
+
""" # noqa: E501
|
|
39
|
+
|
|
40
|
+
from keras import ops
|
|
41
|
+
|
|
42
|
+
from zea.internal.registry import model_registry
|
|
43
|
+
from zea.models.base import BaseModel
|
|
44
|
+
from zea.models.preset_utils import get_preset_loader, register_presets
|
|
45
|
+
from zea.models.presets import augmented_camus_seg_presets
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
@model_registry(name="augmented_camus_seg")
|
|
49
|
+
class AugmentedCamusSeg(BaseModel):
|
|
50
|
+
"""
|
|
51
|
+
nnU-Net based left ventricle and myocardium segmentation model.
|
|
52
|
+
|
|
53
|
+
- Trained on the augmented CAMUS dataset.
|
|
54
|
+
- This class loads an ONNX model and provides inference for cardiac ultrasound segmentation tasks.
|
|
55
|
+
|
|
56
|
+
""" # noqa: E501
|
|
57
|
+
|
|
58
|
+
def call(self, inputs):
|
|
59
|
+
"""
|
|
60
|
+
Run inference on the input data using the loaded ONNX model.
|
|
61
|
+
|
|
62
|
+
Args:
|
|
63
|
+
inputs (np.ndarray): Input image or batch of images for segmentation.
|
|
64
|
+
Shape: [batch, 1, 256, 256]
|
|
65
|
+
Range: Any numeric range; normalized internally.
|
|
66
|
+
|
|
67
|
+
Returns:
|
|
68
|
+
np.ndarray: Segmentation mask(s) for left ventricle and myocardium.
|
|
69
|
+
Shape: [batch, 3, 256, 256] (logits for background, LV, myocardium)
|
|
70
|
+
|
|
71
|
+
Raises:
|
|
72
|
+
ValueError: If model weights are not loaded.
|
|
73
|
+
"""
|
|
74
|
+
if not hasattr(self, "onnx_sess"):
|
|
75
|
+
raise ValueError("Model weights not loaded. Please call custom_load_weights() first.")
|
|
76
|
+
input_name = self.onnx_sess.get_inputs()[0].name
|
|
77
|
+
output_name = self.onnx_sess.get_outputs()[0].name
|
|
78
|
+
inputs = ops.convert_to_numpy(inputs).astype("float32")
|
|
79
|
+
output = self.onnx_sess.run([output_name], {input_name: inputs})[0]
|
|
80
|
+
return output
|
|
81
|
+
|
|
82
|
+
def custom_load_weights(self, preset, **kwargs):
|
|
83
|
+
"""Load the ONNX weights for the segmentation model."""
|
|
84
|
+
try:
|
|
85
|
+
import onnxruntime
|
|
86
|
+
except ImportError:
|
|
87
|
+
raise ImportError(
|
|
88
|
+
"onnxruntime is not installed. Please run "
|
|
89
|
+
"`pip install onnxruntime` to use this model."
|
|
90
|
+
)
|
|
91
|
+
loader = get_preset_loader(preset)
|
|
92
|
+
filename = loader.get_file("model.onnx")
|
|
93
|
+
self.onnx_sess = onnxruntime.InferenceSession(filename)
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
register_presets(augmented_camus_seg_presets, AugmentedCamusSeg)
|
zea/models/preset_utils.py
CHANGED
|
@@ -136,7 +136,7 @@ def load_json(preset, config_file=CONFIG_FILE):
|
|
|
136
136
|
return config
|
|
137
137
|
|
|
138
138
|
|
|
139
|
-
def load_serialized_object(config, **kwargs):
|
|
139
|
+
def load_serialized_object(config, cls, **kwargs):
|
|
140
140
|
"""Load a serialized Keras object from a config."""
|
|
141
141
|
# `dtype` in config might be a serialized `DTypePolicy` or `DTypePolicyMap`.
|
|
142
142
|
# Ensure that `dtype` is properly configured.
|
|
@@ -145,7 +145,7 @@ def load_serialized_object(config, **kwargs):
|
|
|
145
145
|
|
|
146
146
|
config["config"] = {**config["config"], **kwargs}
|
|
147
147
|
# return keras.saving.deserialize_keras_object(config)
|
|
148
|
-
return zea.models.base.deserialize_zea_object(config)
|
|
148
|
+
return zea.models.base.deserialize_zea_object(config, cls)
|
|
149
149
|
|
|
150
150
|
|
|
151
151
|
def check_config_class(config):
|
|
@@ -275,7 +275,7 @@ class KerasPresetLoader(PresetLoader):
|
|
|
275
275
|
|
|
276
276
|
def load_model(self, cls, load_weights, **kwargs):
|
|
277
277
|
"""Load a model from a serialized Keras config."""
|
|
278
|
-
model = load_serialized_object(self.config, **kwargs)
|
|
278
|
+
model = load_serialized_object(self.config, cls=cls, **kwargs)
|
|
279
279
|
|
|
280
280
|
if not load_weights:
|
|
281
281
|
return model
|
|
@@ -306,7 +306,7 @@ class KerasPresetLoader(PresetLoader):
|
|
|
306
306
|
def load_image_converter(self, cls, **kwargs):
|
|
307
307
|
"""Load an image converter from the preset."""
|
|
308
308
|
converter_config = load_json(self.preset, IMAGE_CONVERTER_CONFIG_FILE)
|
|
309
|
-
return load_serialized_object(converter_config, **kwargs)
|
|
309
|
+
return load_serialized_object(converter_config, cls, **kwargs)
|
|
310
310
|
|
|
311
311
|
def get_file(self, path):
|
|
312
312
|
"""Get a file from the preset."""
|
|
@@ -322,7 +322,7 @@ class KerasPresetLoader(PresetLoader):
|
|
|
322
322
|
if not issubclass(check_config_class(preprocessor_json), cls):
|
|
323
323
|
return super().load_preprocessor(cls, **kwargs)
|
|
324
324
|
# We found a `preprocessing.json` with a complete config for our class.
|
|
325
|
-
preprocessor = load_serialized_object(preprocessor_json, **kwargs)
|
|
325
|
+
preprocessor = load_serialized_object(preprocessor_json, cls, **kwargs)
|
|
326
326
|
if hasattr(preprocessor, "load_preset_assets"):
|
|
327
327
|
preprocessor.load_preset_assets(self.preset)
|
|
328
328
|
return preprocessor
|
zea/models/presets.py
CHANGED
|
@@ -47,6 +47,34 @@ echonet_dynamic_presets = {
|
|
|
47
47
|
},
|
|
48
48
|
}
|
|
49
49
|
|
|
50
|
+
augmented_camus_seg_presets = {
|
|
51
|
+
"augmented_camus_seg": {
|
|
52
|
+
"metadata": {
|
|
53
|
+
"description": (
|
|
54
|
+
"Augmented CAMUS segmentation model for cardiac ultrasound segmentation. "
|
|
55
|
+
"Original paper and code: https://arxiv.org/abs/2502.20100"
|
|
56
|
+
),
|
|
57
|
+
"params": 33468899,
|
|
58
|
+
"path": "lv_segmentation",
|
|
59
|
+
},
|
|
60
|
+
"hf_handle": "hf://zeahub/augmented-camus-segmentation",
|
|
61
|
+
},
|
|
62
|
+
}
|
|
63
|
+
|
|
64
|
+
regional_quality_presets = {
|
|
65
|
+
"mobilenetv2_regional_quality": {
|
|
66
|
+
"metadata": {
|
|
67
|
+
"description": (
|
|
68
|
+
"MobileNetV2-based regional myocardial image quality scoring model. "
|
|
69
|
+
"Original GitHub repository and code: https://github.com/GillesVanDeVyver/arqee"
|
|
70
|
+
),
|
|
71
|
+
"params": 2217064,
|
|
72
|
+
"path": "regional_quality",
|
|
73
|
+
},
|
|
74
|
+
"hf_handle": "hf://zeahub/mobilenetv2-regional-quality",
|
|
75
|
+
}
|
|
76
|
+
}
|
|
77
|
+
|
|
50
78
|
echonet_lvh_presets = {
|
|
51
79
|
"echonetlvh": {
|
|
52
80
|
"metadata": {
|
|
@@ -97,6 +125,14 @@ diffusion_model_presets = {
|
|
|
97
125
|
},
|
|
98
126
|
"hf_handle": "hf://zeahub/diffusion-echonet-dynamic",
|
|
99
127
|
},
|
|
128
|
+
"diffusion-echonetlvh-3-frame": {
|
|
129
|
+
"metadata": {
|
|
130
|
+
"description": ("3-frame diffusion model trained on EchoNetLVH dataset."),
|
|
131
|
+
"params": 0,
|
|
132
|
+
"path": "diffusion",
|
|
133
|
+
},
|
|
134
|
+
"hf_handle": "hf://zeahub/diffusion-echonetlvh",
|
|
135
|
+
},
|
|
100
136
|
}
|
|
101
137
|
|
|
102
138
|
carotid_segmenter_presets = {
|