zea 0.0.6__py3-none-any.whl → 0.0.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- zea/__init__.py +54 -19
- zea/agent/__init__.py +12 -12
- zea/agent/masks.py +2 -1
- zea/backend/tensorflow/dataloader.py +2 -1
- zea/beamform/beamformer.py +100 -50
- zea/beamform/lens_correction.py +9 -2
- zea/beamform/pfield.py +9 -2
- zea/config.py +34 -25
- zea/data/__init__.py +22 -16
- zea/data/convert/camus.py +2 -1
- zea/data/convert/echonet.py +4 -4
- zea/data/convert/echonetlvh/convert_raw_to_usbmd.py +1 -1
- zea/data/convert/matlab.py +11 -4
- zea/data/data_format.py +31 -30
- zea/data/datasets.py +7 -5
- zea/data/file.py +104 -2
- zea/data/layers.py +3 -3
- zea/datapaths.py +16 -4
- zea/display.py +7 -5
- zea/interface.py +14 -16
- zea/internal/_generate_keras_ops.py +6 -7
- zea/internal/cache.py +2 -49
- zea/internal/config/validation.py +1 -2
- zea/internal/core.py +69 -6
- zea/internal/device.py +6 -2
- zea/internal/dummy_scan.py +330 -0
- zea/internal/operators.py +114 -2
- zea/internal/parameters.py +101 -70
- zea/internal/setup_zea.py +5 -6
- zea/internal/utils.py +282 -0
- zea/io_lib.py +247 -19
- zea/keras_ops.py +74 -4
- zea/log.py +9 -7
- zea/metrics.py +15 -7
- zea/models/__init__.py +30 -20
- zea/models/base.py +30 -14
- zea/models/carotid_segmenter.py +19 -4
- zea/models/diffusion.py +173 -12
- zea/models/echonet.py +22 -8
- zea/models/echonetlvh.py +31 -7
- zea/models/lpips.py +19 -2
- zea/models/lv_segmentation.py +28 -11
- zea/models/preset_utils.py +5 -5
- zea/models/regional_quality.py +30 -10
- zea/models/taesd.py +21 -5
- zea/models/unet.py +15 -1
- zea/ops.py +390 -196
- zea/probes.py +6 -6
- zea/scan.py +109 -49
- zea/simulator.py +24 -21
- zea/tensor_ops.py +406 -302
- zea/tools/hf.py +1 -1
- zea/tools/selection_tool.py +47 -86
- zea/utils.py +92 -480
- zea/visualize.py +177 -39
- {zea-0.0.6.dist-info → zea-0.0.7.dist-info}/METADATA +4 -2
- zea-0.0.7.dist-info/RECORD +114 -0
- zea-0.0.6.dist-info/RECORD +0 -112
- {zea-0.0.6.dist-info → zea-0.0.7.dist-info}/WHEEL +0 -0
- {zea-0.0.6.dist-info → zea-0.0.7.dist-info}/entry_points.txt +0 -0
- {zea-0.0.6.dist-info → zea-0.0.7.dist-info}/licenses/LICENSE +0 -0
zea/models/carotid_segmenter.py
CHANGED
@@ -1,9 +1,24 @@
 """Carotid segmentation model.
 
-
-
-
-
+To try this model, simply load one of the available presets:
+
+.. doctest::
+
+    >>> from zea.models.carotid_segmenter import CarotidSegmenter
+
+    >>> model = CarotidSegmenter.from_preset("carotid-segmenter")
+
+.. important::
+    This is a ``zea`` implementation of the model.
+    For the original paper see:
+
+    van Knippenberg, Luuk, et al.
+    "Unsupervised domain adaptation method for segmenting cross-sectional CCA images."
+    *https://doi.org/10.1016/j.cmpb.2022.107037*
+
+.. seealso::
+    A tutorial notebook where this model is used:
+    :doc:`../notebooks/models/carotid_segmentation_example`.
 """
 
 import keras
zea/models/diffusion.py
CHANGED
@@ -1,4 +1,15 @@
-"""
+"""
+Diffusion models for ultrasound image generation and posterior sampling.
+
+To try this model, simply load one of the available presets:
+
+.. doctest::
+
+    >>> from zea.models.diffusion import DiffusionModel
+
+    >>> model = DiffusionModel.from_preset("diffusion-echonet-dynamic")  # doctest: +SKIP
+
+"""
 
 import abc
 from typing import Literal
@@ -6,11 +17,12 @@ from typing import Literal
 import keras
 from keras import ops
 
-from zea.backend import _import_tf
+from zea.backend import _import_tf, jit
 from zea.backend.autograd import AutoGrad
 from zea.internal.core import Object
 from zea.internal.operators import Operator
 from zea.internal.registry import diffusion_guidance_registry, model_registry, operator_registry
+from zea.internal.utils import fn_requires_argument
 from zea.models.dense import get_time_conditional_dense_network
 from zea.models.generative import DeepGenerativeModel
 from zea.models.preset_utils import register_presets
@@ -18,7 +30,6 @@ from zea.models.presets import diffusion_model_presets
 from zea.models.unet import get_time_conditional_unetwork
 from zea.models.utils import LossTrackerWrapper
 from zea.tensor_ops import L2, fori_loop, split_seed
-from zea.utils import fn_requires_argument
 
 tf = _import_tf()
 
@@ -182,7 +193,7 @@ class DiffusionModel(DeepGenerativeModel):
             **kwargs: Additional arguments.
 
         Returns:
-            Generated samples
+            Generated samples of shape `(n_samples, *input_shape)`.
         """
         seed, seed1 = split_seed(seed, 2)
 
@@ -233,7 +244,8 @@
             `(batch_size, n_samples, *input_shape)`.
 
         """
-
+        batch_size = ops.shape(measurements)[0]
+        shape = (batch_size, n_samples, *self.input_shape)
 
         def _tile_with_sample_dim(tensor):
             """Tile the tensor with an additional sample dimension."""
@@ -250,7 +262,7 @@
         seed1, seed2 = split_seed(seed, 2)
 
         initial_noise = keras.random.normal(
-            shape=
+            shape=(batch_size * n_samples, *self.input_shape),
            seed=seed1,
         )
 
@@ -262,9 +274,9 @@
             initial_step=initial_step,
             seed=seed2,
             **kwargs,
-        )
-
-        return ops.reshape(out, (
+        )  # (batch_size * n_samples, *self.input_shape)
+
+        return ops.reshape(out, shape)  # (batch_size, n_samples, *input_shape)
 
     def log_likelihood(self, data, **kwargs):
         """Approximate log-likelihood of the data under the model.
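The hunks above fix the posterior sampling shapes: initial noise is now drawn for `batch_size * n_samples` images and the output is reshaped back to `(batch_size, n_samples, *input_shape)`. Below is a minimal sketch of that tile-and-reshape pattern using plain `keras.ops`; the shapes and variable names are illustrative, not the zea API.

import keras
from keras import ops

batch_size, n_samples, input_shape = 2, 3, (32, 32, 1)

# One measurement per batch element (illustrative stand-in for real data).
measurements = keras.random.normal((batch_size, *input_shape))

# Tile each measurement n_samples times, then flatten (batch, sample) into one
# axis so the sampler sees a plain batch of batch_size * n_samples images.
tiled = ops.repeat(ops.expand_dims(measurements, axis=1), n_samples, axis=1)
flat = ops.reshape(tiled, (batch_size * n_samples, *input_shape))

# ... run the sampler on `flat` ...
out = flat

# Group the samples per measurement again, as the fixed code does.
out = ops.reshape(out, (batch_size, n_samples, *input_shape))
print(out.shape)  # (2, 3, 32, 32, 1)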
@@ -776,7 +788,12 @@ register_presets(diffusion_model_presets, DiffusionModel)
 class DiffusionGuidance(abc.ABC, Object):
     """Base class for diffusion guidance methods."""
 
-    def __init__(
+    def __init__(
+        self,
+        diffusion_model: DiffusionModel,
+        operator: Operator,
+        disable_jit: bool = False,
+    ):
         """Initialize the diffusion guidance.
 
         Args:
@@ -828,8 +845,7 @@ class DPS(DiffusionGuidance):
 
         Args:
             noisy_images: Noisy images.
-
-            operator: Forward operator.
+            measurements: Target measurement.
             noise_rates: Current noise rates.
             signal_rates: Current signal rates.
             omega: Weight for the measurement error.
@@ -866,3 +882,148 @@
             Tuple of (gradients, (measurement_error, (pred_noises, pred_images)))
         """
         return self.gradient_fn(noisy_images, **kwargs)
+
+
+@diffusion_guidance_registry(name="dds")
+class DDS(DiffusionGuidance):
+    """
+    Decomposed Diffusion Sampling guidance.
+
+    Reference paper: https://arxiv.org/pdf/2303.05754
+    """
+
+    def setup(self):
+        """Setup DDS guidance function."""
+        if not self.disable_jit:
+            self.call = jit(self.call)
+
+    def Acg(self, x, **op_kwargs):
+        # we transform the operator from A(x) to A.T(A(x)) to get the normal equations,
+        # so that it is suitable for conjugate gradient. (symmetric, positive definite)
+        # Normal equations: A^T y = A^T A x
+        return self.operator.transpose(self.operator.forward(x, **op_kwargs), **op_kwargs)
+
+    def conjugate_gradient_inner_loop(self, i, loop_state, eps=1e-5):
+        """
+        A single iteration of the conjugate gradient method.
+        This involves minimizing the error of x along the current search
+        vector p, and then choosing the next search vector.
+
+        Reference code from: https://github.com/svi-diffusion/
+        """
+        p, rs_old, r, x, eps, op_kwargs = loop_state
+
+        # compute alpha
+        Ap = self.Acg(p, **op_kwargs)  # transform search vector p by A
+        a = rs_old / ops.sum(p * Ap)  # minimize f along the line p
+
+        x_new = x + a * p  # set new x at the minimum of f along line p
+        r_new = r - a * Ap  # shortcut to compute next residual
+
+        # compute Gram-Schmidt coefficient beta to choose next search vector
+        # so that p_new is A-orthogonal to p_current.
+        rs_new = ops.sum(r_new * r_new)
+        p_new = r_new + (rs_new / rs_old) * p
+
+        # this is like a jittable 'break' -- if the residual
+        # is less than eps, then we just return the old
+        # loop state rather than the updated one.
+        next_loop_state = ops.cond(
+            ops.abs(ops.sqrt(rs_old)) < eps,
+            lambda: (p, rs_old, r, x, eps, op_kwargs),
+            lambda: (p_new, rs_new, r_new, x_new, eps, op_kwargs),
+        )
+
+        return next_loop_state
+
+    def call(
+        self,
+        noisy_images,
+        measurements,
+        noise_rates,
+        signal_rates,
+        n_inner,
+        eps,
+        verbose,
+        **op_kwargs,
+    ):
+        """
+        Call the DDS guidance function
+
+        Args:
+            noisy_images: Noisy images.
+            measurement: Target measurement.
+            noise_rates: Current noise rates.
+            signal_rates: Current signal rates.
+            n_inner: Number of conjugate gradient steps.
+            verbose: Whether to calculate error.
+
+        Returns:
+            Tuple of (gradients, (measurement_error, (pred_noises, pred_images)))
+        """
+        pred_noises, pred_images = self.diffusion_model.denoise(
+            noisy_images,
+            noise_rates,
+            signal_rates,
+            training=False,
+        )
+        measurements_cg = self.operator.transpose(measurements, **op_kwargs)
+        r = measurements_cg - self.Acg(pred_images, **op_kwargs)  # residual
+        p = ops.copy(r)  # initial search vector = residual
+        rs_old = ops.sum(r * r)  # residual dot product
+        _, _, _, pred_images_updated_cg, _, _ = fori_loop(
+            0,
+            n_inner,
+            self.conjugate_gradient_inner_loop,
+            (p, rs_old, r, pred_images, eps, op_kwargs),
+        )
+
+        # Not strictly necessary, just for debugging
+        error = ops.cond(
+            verbose,
+            lambda: L2(measurements - self.operator.forward(pred_images_updated_cg, **op_kwargs)),
+            lambda: 0.0,
+        )
+
+        pred_images = pred_images_updated_cg
+        # we have already performed the guidance steps in self.conjugate_gradient_method, so
+        # we can set these gradients to zero.
+        gradients = ops.zeros_like(pred_images)
+        return gradients, (error, (pred_noises, pred_images))
+
+    def __call__(
+        self,
+        noisy_images,
+        measurements,
+        noise_rates,
+        signal_rates,
+        n_inner=5,
+        eps=1e-5,
+        verbose=False,
+        **op_kwargs,
+    ):
+        """
+        Call the DDS guidance function
+
+        Args:
+            noisy_images: Noisy images.
+            measurement: Target measurement.
+            noise_rates: Current noise rates.
+            signal_rates: Current signal rates.
+            n_inner: Number of conjugate gradient steps.
+            verbose: Whether to calculate error.
+            **kwargs: Additional arguments for the operator.
+
+        Returns:
+            Tuple of (gradients, (measurement_error, (pred_noises, pred_images)))
+        """
+        return self.call(
+            noisy_images,
+            measurements,
+            noise_rates,
+            signal_rates,
+            n_inner,
+            eps,
+            verbose,
+            **op_kwargs,
+        )
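The new `DDS` guidance enforces data consistency by running a few conjugate-gradient steps on the normal equations `A^T A x = A^T y` rather than taking gradient steps through the network. Below is a self-contained NumPy sketch of that inner loop, with a dense matrix standing in for the forward operator; it is illustrative only and not the zea implementation.

import numpy as np

def cg_normal_equations(A, y, x0, n_inner=5, eps=1e-5):
    """Solve A^T A x = A^T y with a few conjugate-gradient steps."""
    Acg = lambda v: A.T @ (A @ v)      # normal-equations operator (symmetric, positive definite)
    x = x0.copy()
    r = A.T @ y - Acg(x)               # initial residual
    p = r.copy()                       # initial search direction
    rs_old = r @ r
    for _ in range(n_inner):
        if np.sqrt(rs_old) < eps:      # converged early
            break
        Ap = Acg(p)
        alpha = rs_old / (p @ Ap)      # step size that minimizes the error along p
        x = x + alpha * p
        r = r - alpha * Ap             # shortcut for the next residual
        rs_new = r @ r
        p = r + (rs_new / rs_old) * p  # next A-orthogonal search direction
        rs_old = rs_new
    return x

rng = np.random.default_rng(0)
A = rng.normal(size=(20, 10))          # toy forward operator
x_true = rng.normal(size=10)
y = A @ x_true
x_hat = cg_normal_equations(A, y, x0=np.zeros(10), n_inner=50)
print(np.linalg.norm(A @ x_hat - y))   # small residual: data consistency enforced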
zea/models/echonet.py
CHANGED
@@ -1,6 +1,25 @@
-"""
-
-
+"""
+Echonet-Dynamic segmentation model for cardiac ultrasound segmentation.
+
+To try this model, simply load one of the available presets:
+
+.. doctest::
+
+    >>> from zea.models.echonet import EchoNetDynamic
+
+    >>> model = EchoNetDynamic.from_preset("echonet-dynamic")  # doctest: +SKIP
+
+.. important::
+    This is a ``zea`` implementation of the model.
+    For the original paper and code, see `here <https://echonet.github.io/dynamic/>`_.
+
+    Ouyang, David, et al. "Video-based AI for beat-to-beat assessment of cardiac function."
+    *Nature 580.7802 (2020): 252-256*
+
+.. seealso::
+    A tutorial notebook where this model is used:
+    :doc:`../notebooks/models/left_ventricle_segmentation_example`.
+
 """
 
 from pathlib import Path
@@ -33,11 +52,6 @@ tf = _import_tf()
 class EchoNetDynamic(BaseModel):
     """EchoNet-Dynamic segmentation model for cardiac ultrasound segmentation.
 
-    Original paper and code: https://echonet.github.io/dynamic/
-
-    This class extracts useful parts of the original code and wraps it in a
-    easy to use class.
-
     Preprocessing should normalize the input images with mean and standard deviation.
 
     """
zea/models/echonetlvh.py
CHANGED
@@ -1,4 +1,26 @@
-"""EchoNetLVH model for segmentation of PLAX view cardiac ultrasound.
+"""EchoNetLVH model for segmentation of PLAX view cardiac ultrasound.
+
+To try this model, simply load one of the available presets:
+
+.. doctest::
+
+    >>> from zea.models.echonetlvh import EchoNetLVH
+
+    >>> model = EchoNetLVH.from_preset("echonetlvh")
+
+.. important::
+    This is a ``zea`` implementation of the model.
+    For the original paper and code, see `here <https://echonet.github.io/lvh/>`_.
+
+    Duffy, Grant, et al.
+    "High-throughput precision phenotyping of left ventricular hypertrophy with cardiovascular deep learning."
+    *JAMA cardiology 7.4 (2022): 386-395*
+
+.. seealso::
+    A tutorial notebook where this model is used:
+    :doc:`../notebooks/agent/task_based_perception_action_loop`.
+
+"""  # noqa: E501
 
 import numpy as np
 from keras import ops
@@ -8,7 +30,7 @@ from zea.models.base import BaseModel
 from zea.models.deeplabv3 import DeeplabV3Plus
 from zea.models.preset_utils import register_presets
 from zea.models.presets import echonet_lvh_presets
-from zea.
+from zea.tensor_ops import translate
 
 
 @model_registry(name="echonetlvh")
@@ -18,14 +40,16 @@ class EchoNetLVH(BaseModel):
 
     This model performs semantic segmentation on echocardiogram images to identify
     key anatomical landmarks for measuring left ventricular wall thickness:
-
-    -
-    -
-    -
+
+    - **LVPWd_1**: Left Ventricular Posterior Wall point 1
+    - **LVPWd_2**: Left Ventricular Posterior Wall point 2
+    - **IVSd_1**: Interventricular Septum point 1
+    - **IVSd_2**: Interventricular Septum point 2
 
     The model outputs 4-channel logits corresponding to heatmaps for each landmark.
 
-    For more information, see the original project page
+    For more information, see the original project page:
+    https://echonet.github.io/lvh/
     """
 
     def __init__(self, **kwargs):
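The updated EchoNetLVH docstring notes that the model outputs 4-channel logits, one heatmap per landmark (LVPWd_1, LVPWd_2, IVSd_1, IVSd_2). A common post-processing step for such heatmaps is a per-channel argmax to recover pixel coordinates; the small sketch below is a hypothetical helper for illustration, not part of the zea API.

import numpy as np

def heatmaps_to_landmarks(logits):
    """Return the (row, col) of the peak in each landmark heatmap.

    logits: array of shape (height, width, n_landmarks).
    """
    h, w, n = logits.shape
    flat_idx = logits.reshape(h * w, n).argmax(axis=0)   # peak index per channel
    rows, cols = np.unravel_index(flat_idx, (h, w))
    return list(zip(rows.tolist(), cols.tolist()))

# Toy example: put a distinct peak in each of the 4 channels.
logits = np.zeros((224, 224, 4))
for i, (r, c) in enumerate([(50, 60), (80, 60), (120, 100), (150, 100)]):
    logits[r, c, i] = 10.0
print(heatmaps_to_landmarks(logits))
# [(50, 60), (80, 60), (120, 100), (150, 100)]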
zea/models/lpips.py
CHANGED
@@ -1,7 +1,24 @@
 """LPIPS model for perceptual similarity.
 
-
-
+To try this model, simply load one of the available presets:
+
+.. doctest::
+
+    >>> from zea.models.lpips import LPIPS
+
+    >>> model = LPIPS.from_preset("lpips")
+
+.. important::
+    This is a ``zea`` implementation of the model.
+    For the original paper and code, see `here <https://github.com/richzhang/PerceptualSimilarity>`_.
+
+    Zhang, Richard, et al.
+    "The Unreasonable Effectiveness of Deep Features as a Perceptual Metric."
+    *https://arxiv.org/abs/1801.03924*
+
+.. seealso::
+    A tutorial notebook where this model is used:
+    :doc:`../notebooks/metrics/lpips_example`.
 
 """
 
zea/models/lv_segmentation.py
CHANGED
@@ -1,23 +1,40 @@
 """
-
-Van De Vyver, Gilles, et al.
-"Generative augmentations for improved cardiac ultrasound segmentation using diffusion models."
-arXiv preprint arXiv:2502.20100 (2025).
+nnU-Net segmentation model trained on the augmented CAMUS dataset.
 
-
+To try this model, simply load one of the available presets:
+
+.. doctest::
+
+    >>> from zea.models.lv_segmentation import AugmentedCamusSeg
+
+    >>> model = AugmentedCamusSeg.from_preset("augmented_camus_seg")
+
+The model segments both the left ventricle and myocardium.
 
 At the time of writing (17 September 2025) and to the best of our knowledge,
 it is the state-of-the-art model for left ventricle segmentation on the CAMUS dataset.
 
-
+.. important::
+    This is a ``zea`` implementation of the model.
+    For the original paper and code, see `here <https://github.com/GillesVanDeVyver/EchoGAINS>`_.
+
+    Van De Vyver, Gilles, et al.
+    "Generative augmentations for improved cardiac ultrasound segmentation using diffusion models."
+    *https://arxiv.org/abs/2502.20100*
+
+.. seealso::
+    A tutorial notebook where this model is used:
+    :doc:`../notebooks/models/left_ventricle_segmentation_example`.
+
+.. note::
+    The model is originally a PyTorch model converted to ONNX. To use this model, you must have `onnxruntime` installed. This is required for ONNX model inference.
+
+    You can install it using pip:
 
-
------
-To use this model, you must install the `onnxruntime` Python package:
+    .. code-block:: bash
 
-
+        pip install onnxruntime
 
-This is required for ONNX model inference.
 """  # noqa: E501
 
 from keras import ops
zea/models/preset_utils.py
CHANGED
@@ -136,7 +136,7 @@ def load_json(preset, config_file=CONFIG_FILE):
     return config
 
 
-def load_serialized_object(config, **kwargs):
+def load_serialized_object(config, cls, **kwargs):
     """Load a serialized Keras object from a config."""
     # `dtype` in config might be a serialized `DTypePolicy` or `DTypePolicyMap`.
     # Ensure that `dtype` is properly configured.
@@ -145,7 +145,7 @@ def load_serialized_object(config, **kwargs):
 
     config["config"] = {**config["config"], **kwargs}
     # return keras.saving.deserialize_keras_object(config)
-    return zea.models.base.deserialize_zea_object(config)
+    return zea.models.base.deserialize_zea_object(config, cls)
 
 
 def check_config_class(config):
@@ -275,7 +275,7 @@ class KerasPresetLoader(PresetLoader):
 
     def load_model(self, cls, load_weights, **kwargs):
         """Load a model from a serialized Keras config."""
-        model = load_serialized_object(self.config, **kwargs)
+        model = load_serialized_object(self.config, cls=cls, **kwargs)
 
         if not load_weights:
             return model
@@ -306,7 +306,7 @@
     def load_image_converter(self, cls, **kwargs):
         """Load an image converter from the preset."""
         converter_config = load_json(self.preset, IMAGE_CONVERTER_CONFIG_FILE)
-        return load_serialized_object(converter_config, **kwargs)
+        return load_serialized_object(converter_config, cls, **kwargs)
 
     def get_file(self, path):
         """Get a file from the preset."""
@@ -322,7 +322,7 @@
         if not issubclass(check_config_class(preprocessor_json), cls):
             return super().load_preprocessor(cls, **kwargs)
         # We found a `preprocessing.json` with a complete config for our class.
-        preprocessor = load_serialized_object(preprocessor_json, **kwargs)
+        preprocessor = load_serialized_object(preprocessor_json, cls, **kwargs)
         if hasattr(preprocessor, "load_preset_assets"):
             preprocessor.load_preset_assets(self.preset)
         return preprocessor
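The preset_utils changes thread the requested class `cls` from the preset loader down into `deserialize_zea_object`, so deserialization can build the exact class the caller asked for instead of inferring it from the config alone. A hypothetical, simplified sketch of that pattern (the `DummyModel` and helper name are illustrative, not the zea implementation):

def load_serialized_object_sketch(config, cls, **kwargs):
    """Build an instance of `cls` from a serialized config dict, with overrides."""
    init_kwargs = {**config.get("config", {}), **kwargs}
    return cls(**init_kwargs)

class DummyModel:
    def __init__(self, depth=2, width=16):
        self.depth, self.width = depth, width

config = {"class_name": "DummyModel", "config": {"depth": 4}}
model = load_serialized_object_sketch(config, DummyModel, width=32)
print(model.depth, model.width)  # 4 32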
zea/models/regional_quality.py
CHANGED
@@ -1,21 +1,41 @@
 """
-
-Van De Vyver, et al. "Regional Image Quality Scoring for 2-D Echocardiography Using Deep Learning."
-Ultrasound in Medicine & Biology 51.4 (2025): 638-649.
+MobileNetv2 based image quality model for myocardial regions in apical views.
 
-
+To try this model, simply load one of the available presets:
 
-
+.. doctest::
+
+    >>> from zea.models.regional_quality import MobileNetv2RegionalQuality
+
+    >>> model = MobileNetv2RegionalQuality.from_preset("mobilenetv2_regional_quality")
+
+The model predicts the regional image quality of
 the myocardial regions in apical views. It can also be used to get the overall image quality by averaging the
 regional scores.
 
-
-
-
+At the time of writing (17 September 2025) and to the best of our knowledge,
+it is the state-of-the-art model for left ventricle segmentation on the CAMUS dataset.
+
+.. important::
+    This is a ``zea`` implementation of the model.
+    For the original paper and code, see `here <https://github.com/GillesVanDeVyver/arqee>`_.
+
+    Van De Vyver, et al. "Regional Image Quality Scoring for 2-D Echocardiography Using Deep Learning."
+    *Ultrasound in Medicine & Biology 51.4 (2025): 638-649*
+
+.. seealso::
+    A tutorial notebook where this model is used:
+    :doc:`../notebooks/metrics/myocardial_quality_example`.
+
+.. note::
+    The model is originally a PyTorch model converted to ONNX. To use this model, you must have `onnxruntime` installed. This is required for ONNX model inference.
+
+    You can install it using pip:
+
+    .. code-block:: bash
 
-
+        pip install onnxruntime
 
-This is required for ONNX model inference.
 """  # noqa: E501
 
 import numpy as np
zea/models/taesd.py
CHANGED
@@ -1,9 +1,19 @@
-"""
+"""
+Tiny Autoencoder (TAESD) model.
+
+.. doctest::
+
+    >>> from zea.models.taesd import TinyAutoencoder
+
+    >>> model = TinyAutoencoder.from_preset("taesdxl")  # doctest: +SKIP
 
-
+.. important::
+    This is a ``zea`` implementation of the model.
+    For the original code, see `here <https://github.com/madebyollin/taesd>`_.
 
-
-:
+.. seealso::
+    A tutorial notebook where this model is used:
+    :doc:`../notebooks/models/taesd_autoencoder_example`.
 
 """
 
@@ -23,7 +33,13 @@ tf = _import_tf()
 
 @model_registry(name="taesdxl")
 class TinyAutoencoder(BaseModel):
-    """
+    """Tiny Autoencoder model.
+
+    .. note::
+
+        This model currently only supports TensorFlow and Jax backends.
+
+    """
 
     def __init__(self, **kwargs):
         """
zea/models/unet.py
CHANGED
@@ -1,4 +1,18 @@
-"""UNet models and architectures
+"""UNet models and architectures.
+
+To try this model, simply load one of the available presets:
+
+.. doctest::
+
+    >>> from zea.models.unet import UNet
+
+    >>> model = UNet.from_preset("unet-echonet-inpainter")
+
+.. seealso::
+    A tutorial notebook where this model is used:
+    :doc:`../notebooks/models/unet_example`.
+
+"""
 
 import keras
 from keras import layers