zea 0.0.5__py3-none-any.whl → 0.0.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68)
  1. zea/__init__.py +54 -19
  2. zea/agent/__init__.py +12 -12
  3. zea/agent/masks.py +2 -1
  4. zea/agent/selection.py +166 -0
  5. zea/backend/__init__.py +89 -0
  6. zea/backend/jax/__init__.py +14 -51
  7. zea/backend/tensorflow/__init__.py +0 -49
  8. zea/backend/tensorflow/dataloader.py +2 -1
  9. zea/backend/torch/__init__.py +27 -62
  10. zea/beamform/beamformer.py +100 -50
  11. zea/beamform/lens_correction.py +9 -2
  12. zea/beamform/pfield.py +9 -2
  13. zea/config.py +34 -25
  14. zea/data/__init__.py +22 -16
  15. zea/data/convert/camus.py +2 -1
  16. zea/data/convert/echonet.py +4 -4
  17. zea/data/convert/echonetlvh/convert_raw_to_usbmd.py +1 -1
  18. zea/data/convert/matlab.py +11 -4
  19. zea/data/data_format.py +31 -30
  20. zea/data/datasets.py +7 -5
  21. zea/data/file.py +104 -2
  22. zea/data/layers.py +5 -6
  23. zea/datapaths.py +16 -4
  24. zea/display.py +7 -5
  25. zea/interface.py +14 -16
  26. zea/internal/_generate_keras_ops.py +6 -7
  27. zea/internal/cache.py +2 -49
  28. zea/internal/config/validation.py +1 -2
  29. zea/internal/core.py +69 -6
  30. zea/internal/device.py +6 -2
  31. zea/internal/dummy_scan.py +330 -0
  32. zea/internal/operators.py +114 -2
  33. zea/internal/parameters.py +101 -70
  34. zea/internal/registry.py +1 -1
  35. zea/internal/setup_zea.py +5 -6
  36. zea/internal/utils.py +282 -0
  37. zea/io_lib.py +247 -19
  38. zea/keras_ops.py +74 -4
  39. zea/log.py +9 -7
  40. zea/metrics.py +365 -65
  41. zea/models/__init__.py +30 -20
  42. zea/models/base.py +30 -14
  43. zea/models/carotid_segmenter.py +19 -4
  44. zea/models/diffusion.py +187 -26
  45. zea/models/echonet.py +22 -8
  46. zea/models/echonetlvh.py +31 -18
  47. zea/models/lpips.py +19 -2
  48. zea/models/lv_segmentation.py +96 -0
  49. zea/models/preset_utils.py +5 -5
  50. zea/models/presets.py +36 -0
  51. zea/models/regional_quality.py +142 -0
  52. zea/models/taesd.py +21 -5
  53. zea/models/unet.py +15 -1
  54. zea/ops.py +414 -207
  55. zea/probes.py +6 -6
  56. zea/scan.py +109 -49
  57. zea/simulator.py +24 -21
  58. zea/tensor_ops.py +411 -206
  59. zea/tools/hf.py +1 -1
  60. zea/tools/selection_tool.py +47 -86
  61. zea/utils.py +92 -480
  62. zea/visualize.py +177 -39
  63. {zea-0.0.5.dist-info → zea-0.0.7.dist-info}/METADATA +9 -3
  64. zea-0.0.7.dist-info/RECORD +114 -0
  65. {zea-0.0.5.dist-info → zea-0.0.7.dist-info}/WHEEL +1 -1
  66. zea-0.0.5.dist-info/RECORD +0 -110
  67. {zea-0.0.5.dist-info → zea-0.0.7.dist-info}/entry_points.txt +0 -0
  68. {zea-0.0.5.dist-info → zea-0.0.7.dist-info/licenses}/LICENSE +0 -0
zea/models/carotid_segmenter.py CHANGED
@@ -1,9 +1,24 @@
  """Carotid segmentation model.
 
- Original implementation of paper:
- - "Unsupervised domain adaptation method for segmenting cross-sectional CCA images"
- - https://doi.org/10.1016/j.cmpb.2022.107037
- - Author: Luuk van Knippenberg
+ To try this model, simply load one of the available presets:
+
+ .. doctest::
+
+ >>> from zea.models.carotid_segmenter import CarotidSegmenter
+
+ >>> model = CarotidSegmenter.from_preset("carotid-segmenter")
+
+ .. important::
+ This is a ``zea`` implementation of the model.
+ For the original paper see:
+
+ van Knippenberg, Luuk, et al.
+ "Unsupervised domain adaptation method for segmenting cross-sectional CCA images."
+ *https://doi.org/10.1016/j.cmpb.2022.107037*
+
+ .. seealso::
+ A tutorial notebook where this model is used:
+ :doc:`../notebooks/models/carotid_segmentation_example`.
  """
 
  import keras
zea/models/diffusion.py CHANGED
@@ -1,4 +1,15 @@
- """Diffusion models"""
+ """
+ Diffusion models for ultrasound image generation and posterior sampling.
+
+ To try this model, simply load one of the available presets:
+
+ .. doctest::
+
+ >>> from zea.models.diffusion import DiffusionModel
+
+ >>> model = DiffusionModel.from_preset("diffusion-echonet-dynamic") # doctest: +SKIP
+
+ """
 
  import abc
  from typing import Literal
@@ -6,11 +17,12 @@ from typing import Literal
  import keras
  from keras import ops
 
- from zea.backend import _import_tf
+ from zea.backend import _import_tf, jit
  from zea.backend.autograd import AutoGrad
  from zea.internal.core import Object
  from zea.internal.operators import Operator
  from zea.internal.registry import diffusion_guidance_registry, model_registry, operator_registry
+ from zea.internal.utils import fn_requires_argument
  from zea.models.dense import get_time_conditional_dense_network
  from zea.models.generative import DeepGenerativeModel
  from zea.models.preset_utils import register_presets
@@ -18,7 +30,6 @@ from zea.models.presets import diffusion_model_presets
  from zea.models.unet import get_time_conditional_unetwork
  from zea.models.utils import LossTrackerWrapper
  from zea.tensor_ops import L2, fori_loop, split_seed
- from zea.utils import fn_requires_argument
 
  tf = _import_tf()
 
@@ -182,7 +193,7 @@ class DiffusionModel(DeepGenerativeModel):
  **kwargs: Additional arguments.
 
  Returns:
- Generated samples.
+ Generated samples of shape `(n_samples, *input_shape)`.
  """
  seed, seed1 = split_seed(seed, 2)
 
@@ -233,7 +244,8 @@
  `(batch_size, n_samples, *input_shape)`.
 
  """
- shape = ops.shape(measurements)
+ batch_size = ops.shape(measurements)[0]
+ shape = (batch_size, n_samples, *self.input_shape)
 
  def _tile_with_sample_dim(tensor):
  """Tile the tensor with an additional sample dimension."""
@@ -250,7 +262,7 @@
  seed1, seed2 = split_seed(seed, 2)
 
  initial_noise = keras.random.normal(
- shape=ops.shape(measurements),
+ shape=(batch_size * n_samples, *self.input_shape),
  seed=seed1,
  )
 
@@ -262,9 +274,9 @@
  initial_step=initial_step,
  seed=seed2,
  **kwargs,
- )
- # returns: (batch_size, n_samples, *input_shape)
- return ops.reshape(out, (shape[0], n_samples, *shape[1:]))
+ ) # ( batch_size * n_samples, *self.input_shape)
+
+ return ops.reshape(out, shape) # (batch_size, n_samples, *input_shape)
 
  def log_likelihood(self, data, **kwargs):
  """Approximate log-likelihood of the data under the model.
@@ -776,7 +788,12 @@ register_presets(diffusion_model_presets, DiffusionModel)
  class DiffusionGuidance(abc.ABC, Object):
  """Base class for diffusion guidance methods."""
 
- def __init__(self, diffusion_model, operator, disable_jit=False):
+ def __init__(
+ self,
+ diffusion_model: DiffusionModel,
+ operator: Operator,
+ disable_jit: bool = False,
+ ):
  """Initialize the diffusion guidance.
 
  Args:
@@ -823,12 +840,12 @@ class DPS(DiffusionGuidance):
  omega,
  **kwargs,
  ):
- """Compute measurement error for diffusion posterior sampling.
+ """
+ Compute measurement error for diffusion posterior sampling.
 
  Args:
  noisy_images: Noisy images.
- measurement: Target measurement.
- operator: Forward operator.
+ measurements: Target measurement.
  noise_rates: Current noise rates.
  signal_rates: Current signal rates.
  omega: Weight for the measurement error.
@@ -849,20 +866,164 @@ class DPS(DiffusionGuidance):
  return measurement_error, (pred_noises, pred_images)
 
  def __call__(self, noisy_images, **kwargs):
- """Call the gradient function.
-
- Returns a function with the following signature:
- (
- noisy_images,
- measurement,
- operator,
- noise_rates,
- signal_rates,
- omega,
- **operator_kwargs,
- ) -> gradients, (error, (pred_noises, pred_images))
+ """
+ Call the gradient function.
 
- where operator_kwargs are the kwargs for the operator.
+ Args:
+ noisy_images: Noisy images.
+ measurement: Target measurement.
+ operator: Forward operator.
+ noise_rates: Current noise rates.
+ signal_rates: Current signal rates.
+ omega: Weight for the measurement error.
+ **kwargs: Additional arguments for the operator.
 
+ Returns:
+ Tuple of (gradients, (measurement_error, (pred_noises, pred_images)))
  """
  return self.gradient_fn(noisy_images, **kwargs)
+
+
+ @diffusion_guidance_registry(name="dds")
+ class DDS(DiffusionGuidance):
+ """
+ Decomposed Diffusion Sampling guidance.
+
+ Reference paper: https://arxiv.org/pdf/2303.05754
+ """
+
+ def setup(self):
+ """Setup DDS guidance function."""
+ if not self.disable_jit:
+ self.call = jit(self.call)
+
+ def Acg(self, x, **op_kwargs):
+ # we transform the operator from A(x) to A.T(A(x)) to get the normal equations,
+ # so that it is suitable for conjugate gradient. (symmetric, positive definite)
+ # Normal equations: A^T y = A^T A x
+ return self.operator.transpose(self.operator.forward(x, **op_kwargs), **op_kwargs)
+
+ def conjugate_gradient_inner_loop(self, i, loop_state, eps=1e-5):
+ """
+ A single iteration of the conjugate gradient method.
+ This involves minimizing the error of x along the current search
+ vector p, and then choosing the next search vector.
+
+ Reference code from: https://github.com/svi-diffusion/
+ """
+ p, rs_old, r, x, eps, op_kwargs = loop_state
+
+ # compute alpha
+ Ap = self.Acg(p, **op_kwargs) # transform search vector p by A
+ a = rs_old / ops.sum(p * Ap) # minimize f along the line p
+
+ x_new = x + a * p # set new x at the minimum of f along line p
+ r_new = r - a * Ap # shortcut to compute next residual
+
+ # compute Gram-Schmidt coefficient beta to choose next search vector
+ # so that p_new is A-orthogonal to p_current.
+ rs_new = ops.sum(r_new * r_new)
+ p_new = r_new + (rs_new / rs_old) * p
+
+ # this is like a jittable 'break' -- if the residual
+ # is less than eps, then we just return the old
+ # loop state rather than the updated one.
+ next_loop_state = ops.cond(
+ ops.abs(ops.sqrt(rs_old)) < eps,
+ lambda: (p, rs_old, r, x, eps, op_kwargs),
+ lambda: (p_new, rs_new, r_new, x_new, eps, op_kwargs),
+ )
+
+ return next_loop_state
+
+ def call(
+ self,
+ noisy_images,
+ measurements,
+ noise_rates,
+ signal_rates,
+ n_inner,
+ eps,
+ verbose,
+ **op_kwargs,
+ ):
+ """
+ Call the DDS guidance function
+
+ Args:
+ noisy_images: Noisy images.
+ measurement: Target measurement.
+ noise_rates: Current noise rates.
+ signal_rates: Current signal rates.
+ n_inner: Number of conjugate gradient steps.
+ verbose: Whether to calculate error.
+
+ Returns:
+ Tuple of (gradients, (measurement_error, (pred_noises, pred_images)))
+ """
+ pred_noises, pred_images = self.diffusion_model.denoise(
+ noisy_images,
+ noise_rates,
+ signal_rates,
+ training=False,
+ )
+ measurements_cg = self.operator.transpose(measurements, **op_kwargs)
+ r = measurements_cg - self.Acg(pred_images, **op_kwargs) # residual
+ p = ops.copy(r) # initial search vector = residual
+ rs_old = ops.sum(r * r) # residual dot product
+ _, _, _, pred_images_updated_cg, _, _ = fori_loop(
+ 0,
+ n_inner,
+ self.conjugate_gradient_inner_loop,
+ (p, rs_old, r, pred_images, eps, op_kwargs),
+ )
+
+ # Not strictly necessary, just for debugging
+ error = ops.cond(
+ verbose,
+ lambda: L2(measurements - self.operator.forward(pred_images_updated_cg, **op_kwargs)),
+ lambda: 0.0,
+ )
+
+ pred_images = pred_images_updated_cg
+ # we have already performed the guidance steps in self.conjugate_gradient_method, so
+ # we can set these gradients to zero.
+ gradients = ops.zeros_like(pred_images)
+ return gradients, (error, (pred_noises, pred_images))
+
+ def __call__(
+ self,
+ noisy_images,
+ measurements,
+ noise_rates,
+ signal_rates,
+ n_inner=5,
+ eps=1e-5,
+ verbose=False,
+ **op_kwargs,
+ ):
+ """
+ Call the DDS guidance function
+
+ Args:
+ noisy_images: Noisy images.
+ measurement: Target measurement.
+ noise_rates: Current noise rates.
+ signal_rates: Current signal rates.
+ n_inner: Number of conjugate gradient steps.
+ verbose: Whether to calculate error.
+ **kwargs: Additional arguments for the operator.
+
+ Returns:
+ Tuple of (gradients, (measurement_error, (pred_noises, pred_images)))
+ """
+ return self.call(
+ noisy_images,
+ measurements,
+ noise_rates,
+ signal_rates,
+ n_inner,
+ eps,
+ verbose,
+ **op_kwargs,
+ )
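
The `Acg` and `conjugate_gradient_inner_loop` additions above apply conjugate gradient to the normal equations A^T A x = A^T y. A standalone NumPy sketch of the same update rule on a toy linear operator (not zea code; sizes and iteration count are arbitrary):

    import numpy as np

    rng = np.random.default_rng(0)
    A = rng.standard_normal((8, 4))   # toy forward operator
    y = rng.standard_normal(8)        # toy measurement

    def Acg(v):
        # Mirrors DDS.Acg: apply A, then A^T, giving a symmetric positive (semi)definite operator.
        return A.T @ (A @ v)

    x = np.zeros(4)
    r = A.T @ y - Acg(x)   # residual of the normal equations
    p = r.copy()           # first search direction is the residual
    rs_old = r @ r
    for _ in range(5):     # corresponds to n_inner
        Ap = Acg(p)
        alpha = rs_old / (p @ Ap)       # step size minimizing along p
        x = x + alpha * p
        r = r - alpha * Ap
        rs_new = r @ r
        p = r + (rs_new / rs_old) * p   # next direction, A-conjugate to the previous one
        rs_old = rs_new

    print(np.allclose(Acg(x), A.T @ y, atol=1e-6))  # x solves the normal equations
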
zea/models/echonet.py CHANGED
@@ -1,6 +1,25 @@
- """Echonet-Dynamic segmentation model for cardiac ultrasound segmentation.
- Link below does not work it seems, this is slightly different but does have some info:
- https://github.com/bryanhe/dynamic
+ """
+ Echonet-Dynamic segmentation model for cardiac ultrasound segmentation.
+
+ To try this model, simply load one of the available presets:
+
+ .. doctest::
+
+ >>> from zea.models.echonet import EchoNetDynamic
+
+ >>> model = EchoNetDynamic.from_preset("echonet-dynamic") # doctest: +SKIP
+
+ .. important::
+ This is a ``zea`` implementation of the model.
+ For the original paper and code, see `here <https://echonet.github.io/dynamic/>`_.
+
+ Ouyang, David, et al. "Video-based AI for beat-to-beat assessment of cardiac function."
+ *Nature 580.7802 (2020): 252-256*
+
+ .. seealso::
+ A tutorial notebook where this model is used:
+ :doc:`../notebooks/models/left_ventricle_segmentation_example`.
+
  """
 
  from pathlib import Path
@@ -33,11 +52,6 @@ tf = _import_tf()
  class EchoNetDynamic(BaseModel):
  """EchoNet-Dynamic segmentation model for cardiac ultrasound segmentation.
 
- Original paper and code: https://echonet.github.io/dynamic/
-
- This class extracts useful parts of the original code and wraps it in a
- easy to use class.
-
  Preprocessing should normalize the input images with mean and standard deviation.
 
  """
zea/models/echonetlvh.py CHANGED
@@ -1,4 +1,26 @@
- """EchoNetLVH model for segmentation of PLAX view cardiac ultrasound. For more details see https://echonet.github.io/lvh/index.html."""
+ """EchoNetLVH model for segmentation of PLAX view cardiac ultrasound.
+
+ To try this model, simply load one of the available presets:
+
+ .. doctest::
+
+ >>> from zea.models.echonetlvh import EchoNetLVH
+
+ >>> model = EchoNetLVH.from_preset("echonetlvh")
+
+ .. important::
+ This is a ``zea`` implementation of the model.
+ For the original paper and code, see `here <https://echonet.github.io/lvh/>`_.
+
+ Duffy, Grant, et al.
+ "High-throughput precision phenotyping of left ventricular hypertrophy with cardiovascular deep learning."
+ *JAMA cardiology 7.4 (2022): 386-395*
+
+ .. seealso::
+ A tutorial notebook where this model is used:
+ :doc:`../notebooks/agent/task_based_perception_action_loop`.
+
+ """ # noqa: E501
 
  import numpy as np
  from keras import ops
@@ -8,7 +30,7 @@ from zea.models.base import BaseModel
  from zea.models.deeplabv3 import DeeplabV3Plus
  from zea.models.preset_utils import register_presets
  from zea.models.presets import echonet_lvh_presets
- from zea.utils import translate
+ from zea.tensor_ops import translate
 
 
  @model_registry(name="echonetlvh")
@@ -18,14 +40,16 @@ class EchoNetLVH(BaseModel):
 
  This model performs semantic segmentation on echocardiogram images to identify
  key anatomical landmarks for measuring left ventricular wall thickness:
- - LVPWd_1: Left Ventricular Posterior Wall point 1
- - LVPWd_2: Left Ventricular Posterior Wall point 2
- - IVSd_1: Interventricular Septum point 1
- - IVSd_2: Interventricular Septum point 2
+
+ - **LVPWd_1**: Left Ventricular Posterior Wall point 1
+ - **LVPWd_2**: Left Ventricular Posterior Wall point 2
+ - **IVSd_1**: Interventricular Septum point 1
+ - **IVSd_2**: Interventricular Septum point 2
 
  The model outputs 4-channel logits corresponding to heatmaps for each landmark.
 
- For more information, see the original project page at https://echonet.github.io/lvh/index.html
+
+ For more information, see the original project page:
+ https://echonet.github.io/lvh/
  """
 
@@ -37,17 +61,6 @@ class EchoNetLVH(BaseModel):
  """
  super().__init__(**kwargs)
 
- # Scan conversion constants for echonet processing
- self.rho_range = (0, 224) # Radial distance range in pixels
- self.theta_range = (np.deg2rad(-45), np.deg2rad(45)) # Angular range in radians
- self.fill_value = -1.0 # Fill value for scan conversion
- self.resolution = 1.0 # mm per pixel resolution
-
- # Network input/output dimensions
- self.n_rho = 224
- self.n_theta = 224
- self.output_shape = (224, 224, 4)
-
  # Pre-computed coordinate grid for efficient processing
  self.coordinate_grid = ops.stack(
  ops.cast(ops.convert_to_tensor(np.indices((224, 224))), "float32"), axis=-1
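
The `coordinate_grid` context line above builds a dense pixel-coordinate grid. A NumPy equivalent shows what it contains, namely a (224, 224, 2) array of (row, col) indices:

    import numpy as np

    # np.indices((224, 224)) yields two (224, 224) arrays of row and column indices;
    # stacking them on the last axis gives one (row, col) pair per pixel.
    grid = np.stack(np.indices((224, 224)).astype("float32"), axis=-1)
    print(grid.shape)   # (224, 224, 2)
    print(grid[5, 7])   # [5. 7.]
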
zea/models/lpips.py CHANGED
@@ -1,7 +1,24 @@
  """LPIPS model for perceptual similarity.
 
- See original code: https://github.com/richzhang/PerceptualSimilarity
- As well as the paper: https://arxiv.org/abs/1801.03924
+ To try this model, simply load one of the available presets:
+
+ .. doctest::
+
+ >>> from zea.models.lpips import LPIPS
+
+ >>> model = LPIPS.from_preset("lpips")
+
+ .. important::
+ This is a ``zea`` implementation of the model.
+ For the original paper and code, see `here <https://github.com/richzhang/PerceptualSimilarity>`_.
+
+ Zhang, Richard, et al.
+ "The Unreasonable Effectiveness of Deep Features as a Perceptual Metric."
+ *https://arxiv.org/abs/1801.03924*
+
+ .. seealso::
+ A tutorial notebook where this model is used:
+ :doc:`../notebooks/metrics/lpips_example`.
 
  """
 
zea/models/lv_segmentation.py ADDED
@@ -0,0 +1,96 @@
+ """
+ nnU-Net segmentation model trained on the augmented CAMUS dataset.
+
+ To try this model, simply load one of the available presets:
+
+ .. doctest::
+
+ >>> from zea.models.lv_segmentation import AugmentedCamusSeg
+
+ >>> model = AugmentedCamusSeg.from_preset("augmented_camus_seg")
+
+ The model segments both the left ventricle and myocardium.
+
+ At the time of writing (17 September 2025) and to the best of our knowledge,
+ it is the state-of-the-art model for left ventricle segmentation on the CAMUS dataset.
+
+ .. important::
+ This is a ``zea`` implementation of the model.
+ For the original paper and code, see `here <https://github.com/GillesVanDeVyver/EchoGAINS>`_.
+
+ Van De Vyver, Gilles, et al.
+ "Generative augmentations for improved cardiac ultrasound segmentation using diffusion models."
+ *https://arxiv.org/abs/2502.20100*
+
+ .. seealso::
+ A tutorial notebook where this model is used:
+ :doc:`../notebooks/models/left_ventricle_segmentation_example`.
+
+ .. note::
+ The model is originally a PyTorch model converted to ONNX. To use this model, you must have `onnxruntime` installed. This is required for ONNX model inference.
+
+ You can install it using pip:
+
+ .. code-block:: bash
+
+ pip install onnxruntime
+
+ """ # noqa: E501
+
+ from keras import ops
+
+ from zea.internal.registry import model_registry
+ from zea.models.base import BaseModel
+ from zea.models.preset_utils import get_preset_loader, register_presets
+ from zea.models.presets import augmented_camus_seg_presets
+
+
+ @model_registry(name="augmented_camus_seg")
+ class AugmentedCamusSeg(BaseModel):
+ """
+ nnU-Net based left ventricle and myocardium segmentation model.
+
+ - Trained on the augmented CAMUS dataset.
+ - This class loads an ONNX model and provides inference for cardiac ultrasound segmentation tasks.
+
+ """ # noqa: E501
+
+ def call(self, inputs):
+ """
+ Run inference on the input data using the loaded ONNX model.
+
+ Args:
+ inputs (np.ndarray): Input image or batch of images for segmentation.
+ Shape: [batch, 1, 256, 256]
+ Range: Any numeric range; normalized internally.
+
+ Returns:
+ np.ndarray: Segmentation mask(s) for left ventricle and myocardium.
+ Shape: [batch, 3, 256, 256] (logits for background, LV, myocardium)
+
+ Raises:
+ ValueError: If model weights are not loaded.
+ """
+ if not hasattr(self, "onnx_sess"):
+ raise ValueError("Model weights not loaded. Please call custom_load_weights() first.")
+ input_name = self.onnx_sess.get_inputs()[0].name
+ output_name = self.onnx_sess.get_outputs()[0].name
+ inputs = ops.convert_to_numpy(inputs).astype("float32")
+ output = self.onnx_sess.run([output_name], {input_name: inputs})[0]
+ return output
+
+ def custom_load_weights(self, preset, **kwargs):
+ """Load the ONNX weights for the segmentation model."""
+ try:
+ import onnxruntime
+ except ImportError:
+ raise ImportError(
+ "onnxruntime is not installed. Please run "
+ "`pip install onnxruntime` to use this model."
+ )
+ loader = get_preset_loader(preset)
+ filename = loader.get_file("model.onnx")
+ self.onnx_sess = onnxruntime.InferenceSession(filename)
+
+
+ register_presets(augmented_camus_seg_presets, AugmentedCamusSeg)
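
Based on the docstrings in this new module, a hedged usage sketch follows (it requires `onnxruntime` and downloads the preset weights; the dummy input only illustrates the documented [batch, 1, 256, 256] shape, and the exact invocation may differ):

    import numpy as np

    from zea.models.lv_segmentation import AugmentedCamusSeg

    # Assumes `from_preset` also loads the ONNX weights, as the doctest above suggests.
    model = AugmentedCamusSeg.from_preset("augmented_camus_seg")

    dummy = np.zeros((1, 1, 256, 256), dtype="float32")  # [batch, 1, 256, 256]
    logits = model.call(dummy)                           # [batch, 3, 256, 256]: background, LV, myocardium
    print(logits.shape)
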
zea/models/preset_utils.py CHANGED
@@ -136,7 +136,7 @@ def load_json(preset, config_file=CONFIG_FILE):
  return config
 
 
- def load_serialized_object(config, **kwargs):
+ def load_serialized_object(config, cls, **kwargs):
  """Load a serialized Keras object from a config."""
  # `dtype` in config might be a serialized `DTypePolicy` or `DTypePolicyMap`.
  # Ensure that `dtype` is properly configured.
@@ -145,7 +145,7 @@ def load_serialized_object(config, **kwargs):
 
  config["config"] = {**config["config"], **kwargs}
  # return keras.saving.deserialize_keras_object(config)
- return zea.models.base.deserialize_zea_object(config)
+ return zea.models.base.deserialize_zea_object(config, cls)
 
 
  def check_config_class(config):
@@ -275,7 +275,7 @@ class KerasPresetLoader(PresetLoader):
 
  def load_model(self, cls, load_weights, **kwargs):
  """Load a model from a serialized Keras config."""
- model = load_serialized_object(self.config, **kwargs)
+ model = load_serialized_object(self.config, cls=cls, **kwargs)
 
  if not load_weights:
  return model
@@ -306,7 +306,7 @@ class KerasPresetLoader(PresetLoader):
  def load_image_converter(self, cls, **kwargs):
  """Load an image converter from the preset."""
  converter_config = load_json(self.preset, IMAGE_CONVERTER_CONFIG_FILE)
- return load_serialized_object(converter_config, **kwargs)
+ return load_serialized_object(converter_config, cls, **kwargs)
 
  def get_file(self, path):
  """Get a file from the preset."""
@@ -322,7 +322,7 @@ class KerasPresetLoader(PresetLoader):
  if not issubclass(check_config_class(preprocessor_json), cls):
  return super().load_preprocessor(cls, **kwargs)
  # We found a `preprocessing.json` with a complete config for our class.
- preprocessor = load_serialized_object(preprocessor_json, **kwargs)
+ preprocessor = load_serialized_object(preprocessor_json, cls, **kwargs)
  if hasattr(preprocessor, "load_preset_assets"):
  preprocessor.load_preset_assets(self.preset)
  return preprocessor
zea/models/presets.py CHANGED
@@ -47,6 +47,34 @@ echonet_dynamic_presets = {
  },
  }
 
+ augmented_camus_seg_presets = {
+ "augmented_camus_seg": {
+ "metadata": {
+ "description": (
+ "Augmented CAMUS segmentation model for cardiac ultrasound segmentation. "
+ "Original paper and code: https://arxiv.org/abs/2502.20100"
+ ),
+ "params": 33468899,
+ "path": "lv_segmentation",
+ },
+ "hf_handle": "hf://zeahub/augmented-camus-segmentation",
+ },
+ }
+
+ regional_quality_presets = {
+ "mobilenetv2_regional_quality": {
+ "metadata": {
+ "description": (
+ "MobileNetV2-based regional myocardial image quality scoring model. "
+ "Original GitHub repository and code: https://github.com/GillesVanDeVyver/arqee"
+ ),
+ "params": 2217064,
+ "path": "regional_quality",
+ },
+ "hf_handle": "hf://zeahub/mobilenetv2-regional-quality",
+ }
+ }
+
  echonet_lvh_presets = {
  "echonetlvh": {
  "metadata": {
@@ -97,6 +125,14 @@ diffusion_model_presets = {
  },
  "hf_handle": "hf://zeahub/diffusion-echonet-dynamic",
  },
+ "diffusion-echonetlvh-3-frame": {
+ "metadata": {
+ "description": ("3-frame diffusion model trained on EchoNetLVH dataset."),
+ "params": 0,
+ "path": "diffusion",
+ },
+ "hf_handle": "hf://zeahub/diffusion-echonetlvh",
+ },
  }
 
  carotid_segmenter_presets = {
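
Because `register_presets(diffusion_model_presets, DiffusionModel)` (shown earlier in this diff) ties these entries to the `DiffusionModel` class, the new preset should be loadable by its key. A hedged sketch (downloads the weights from the Hugging Face hub):

    from zea.models.diffusion import DiffusionModel

    # New in 0.0.7: 3-frame diffusion model trained on the EchoNetLVH dataset.
    model = DiffusionModel.from_preset("diffusion-echonetlvh-3-frame")
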