zea 0.0.6__py3-none-any.whl → 0.0.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. zea/__init__.py +54 -19
  2. zea/agent/__init__.py +12 -12
  3. zea/agent/masks.py +2 -1
  4. zea/backend/tensorflow/dataloader.py +2 -1
  5. zea/beamform/beamformer.py +100 -50
  6. zea/beamform/lens_correction.py +9 -2
  7. zea/beamform/pfield.py +9 -2
  8. zea/config.py +34 -25
  9. zea/data/__init__.py +22 -16
  10. zea/data/convert/camus.py +2 -1
  11. zea/data/convert/echonet.py +4 -4
  12. zea/data/convert/echonetlvh/convert_raw_to_usbmd.py +1 -1
  13. zea/data/convert/matlab.py +11 -4
  14. zea/data/data_format.py +31 -30
  15. zea/data/datasets.py +7 -5
  16. zea/data/file.py +104 -2
  17. zea/data/layers.py +3 -3
  18. zea/datapaths.py +16 -4
  19. zea/display.py +7 -5
  20. zea/interface.py +14 -16
  21. zea/internal/_generate_keras_ops.py +6 -7
  22. zea/internal/cache.py +2 -49
  23. zea/internal/config/validation.py +1 -2
  24. zea/internal/core.py +69 -6
  25. zea/internal/device.py +6 -2
  26. zea/internal/dummy_scan.py +330 -0
  27. zea/internal/operators.py +114 -2
  28. zea/internal/parameters.py +101 -70
  29. zea/internal/setup_zea.py +5 -6
  30. zea/internal/utils.py +282 -0
  31. zea/io_lib.py +247 -19
  32. zea/keras_ops.py +74 -4
  33. zea/log.py +9 -7
  34. zea/metrics.py +15 -7
  35. zea/models/__init__.py +30 -20
  36. zea/models/base.py +30 -14
  37. zea/models/carotid_segmenter.py +19 -4
  38. zea/models/diffusion.py +173 -12
  39. zea/models/echonet.py +22 -8
  40. zea/models/echonetlvh.py +31 -7
  41. zea/models/lpips.py +19 -2
  42. zea/models/lv_segmentation.py +28 -11
  43. zea/models/preset_utils.py +5 -5
  44. zea/models/regional_quality.py +30 -10
  45. zea/models/taesd.py +21 -5
  46. zea/models/unet.py +15 -1
  47. zea/ops.py +390 -196
  48. zea/probes.py +6 -6
  49. zea/scan.py +109 -49
  50. zea/simulator.py +24 -21
  51. zea/tensor_ops.py +406 -302
  52. zea/tools/hf.py +1 -1
  53. zea/tools/selection_tool.py +47 -86
  54. zea/utils.py +92 -480
  55. zea/visualize.py +177 -39
  56. {zea-0.0.6.dist-info → zea-0.0.7.dist-info}/METADATA +4 -2
  57. zea-0.0.7.dist-info/RECORD +114 -0
  58. zea-0.0.6.dist-info/RECORD +0 -112
  59. {zea-0.0.6.dist-info → zea-0.0.7.dist-info}/WHEEL +0 -0
  60. {zea-0.0.6.dist-info → zea-0.0.7.dist-info}/entry_points.txt +0 -0
  61. {zea-0.0.6.dist-info → zea-0.0.7.dist-info}/licenses/LICENSE +0 -0
zea/io_lib.py CHANGED
@@ -19,7 +19,7 @@ from PIL import Image, ImageSequence
19
19
  from zea import log
20
20
  from zea.data.file import File
21
21
 
22
- _SUPPORTED_VID_TYPES = [".avi", ".mp4", ".gif"]
22
+ _SUPPORTED_VID_TYPES = [".mp4", ".gif"]
23
23
  _SUPPORTED_IMG_TYPES = [".jpg", ".png", ".JPEG", ".PNG", ".jpeg"]
24
24
  _SUPPORTED_ZEA_TYPES = [".hdf5", ".h5"]
25
25
 
@@ -27,7 +27,7 @@ _SUPPORTED_ZEA_TYPES = [".hdf5", ".h5"]
27
27
  def load_video(filename, mode="L"):
28
28
  """Load a video file and return a numpy array of frames.
29
29
 
30
- Supported file types: avi, mp4, gif.
30
+ Supported file types: mp4, gif.
31
31
 
32
32
  Args:
33
33
  filename (str): The path to the video file.
@@ -57,12 +57,24 @@ def load_video(filename, mode="L"):
57
57
  with Image.open(filename) as im:
58
58
  for frame in ImageSequence.Iterator(im):
59
59
  frames.append(_convert_image_mode(frame, mode=mode))
60
- else: # .mp4, .avi
61
- reader = imageio.get_reader(filename)
62
- for frame in reader:
63
- img = Image.fromarray(frame)
64
- frames.append(_convert_image_mode(img, mode=mode))
65
- reader.close()
60
+ elif ext == ".mp4":
61
+ # Use imageio with FFMPEG format for MP4 files
62
+ try:
63
+ reader = imageio.get_reader(filename, format="FFMPEG")
64
+ except (ImportError, ValueError) as exc:
65
+ raise ImportError(
66
+ "FFMPEG plugin is required to load MP4 files. "
67
+ "Please install it with 'pip install imageio-ffmpeg'."
68
+ ) from exc
69
+
70
+ try:
71
+ for frame in reader:
72
+ img = Image.fromarray(frame)
73
+ frames.append(_convert_image_mode(img, mode=mode))
74
+ finally:
75
+ reader.close()
76
+ else:
77
+ raise ValueError(f"Unsupported file extension: {ext}")
66
78
 
67
79
  return np.stack(frames, axis=0)
68
80
 
@@ -98,17 +110,151 @@ def load_image(filename, mode="L"):
98
110
  return _convert_image_mode(img, mode=mode)
99
111
 
100
112
 
101
- def _convert_image_mode(img, mode="L"):
102
- """Convert a PIL Image to the specified mode and return as numpy array."""
103
- if mode not in {"L", "RGB"}:
104
- raise ValueError(f"Unsupported mode: {mode}, must be one of: L, RGB")
105
- if mode == "L":
106
- img = img.convert("L")
107
- arr = np.array(img)
108
- elif mode == "RGB":
109
- img = img.convert("RGB")
110
- arr = np.array(img)
111
- return arr
113
+ def save_video(images, filename, fps=20, **kwargs):
114
+ """Saves a sequence of images to a video file.
115
+
116
+ Supported file types: mp4, gif.
117
+
118
+ Args:
119
+ images (list or np.ndarray): List or array of images. Must have shape
120
+ (n_frames, height, width, channels) or (n_frames, height, width).
121
+ If channel axis is not present, or is 1, grayscale image is assumed,
122
+ which is then converted to RGB. Images should be uint8.
123
+ filename (str or Path): Filename to which data should be written.
124
+ fps (int): Frames per second of rendered format.
125
+ **kwargs: Additional keyword arguments passed to the specific save function.
126
+ For GIF files, this includes `shared_color_palette` (bool).
127
+
128
+ Raises:
129
+ ValueError: If the file extension is not supported.
130
+
131
+ """
132
+ filename = Path(filename)
133
+ ext = filename.suffix.lower()
134
+
135
+ if ext == ".mp4":
136
+ return save_to_mp4(images, filename, fps=fps)
137
+ elif ext == ".gif":
138
+ return save_to_gif(images, filename, fps=fps, **kwargs)
139
+ else:
140
+ raise ValueError(f"Unsupported file extension: {ext}")
141
+
142
+
143
+ def save_to_gif(images, filename, fps=20, shared_color_palette=False):
144
+ """Saves a sequence of images to a GIF file.
145
+
146
+ .. note::
147
+ It's recommended to use :func:`save_video` for a more general interface
148
+ that supports multiple formats.
149
+
150
+ Args:
151
+ images (list or np.ndarray): List or array of images. Must have shape
152
+ (n_frames, height, width, channels) or (n_frames, height, width).
153
+ If channel axis is not present, or is 1, grayscale image is assumed,
154
+ which is then converted to RGB. Images should be uint8.
155
+ filename (str or Path): Filename to which data should be written.
156
+ fps (int): Frames per second of rendered format.
157
+ shared_color_palette (bool, optional): If True, creates a global
158
+ color palette across all frames, ensuring consistent colors
159
+ throughout the GIF. Defaults to False, which is default behavior
160
+ of PIL.Image.save. Note: True can cause slow saving for longer
161
+ sequences, and also lead to larger file sizes in some cases.
162
+
163
+ """
164
+ images = preprocess_for_saving(images)
165
+
166
+ if fps > 50:
167
+ log.warning(f"Cannot set fps ({fps}) > 50. Setting it automatically to 50.")
168
+ fps = 50
169
+
170
+ duration = int(round(1000 / fps)) # milliseconds per frame
171
+
172
+ pillow_imgs = [Image.fromarray(img) for img in images]
173
+
174
+ if shared_color_palette:
175
+ # Apply the same palette to all frames without dithering for consistent color mapping
176
+ # Convert all images to RGB and combine their colors for palette generation
177
+ all_colors = np.vstack([np.array(img.convert("RGB")).reshape(-1, 3) for img in pillow_imgs])
178
+ combined_image = Image.fromarray(all_colors.reshape(-1, 1, 3))
179
+
180
+ # Generate palette from all frames
181
+ global_palette = combined_image.quantize(
182
+ colors=256,
183
+ method=Image.MEDIANCUT,
184
+ kmeans=1,
185
+ )
186
+
187
+ # Apply the same palette to all frames without dithering
188
+ pillow_imgs = [
189
+ img.convert("RGB").quantize(
190
+ palette=global_palette,
191
+ dither=Image.NONE,
192
+ )
193
+ for img in pillow_imgs
194
+ ]
195
+
196
+ pillow_img, *pillow_imgs = pillow_imgs
197
+
198
+ pillow_img.save(
199
+ fp=filename,
200
+ format="GIF",
201
+ append_images=pillow_imgs,
202
+ save_all=True,
203
+ loop=0,
204
+ duration=duration,
205
+ interlace=False,
206
+ optimize=False,
207
+ )
208
+ log.success(f"Successfully saved GIF to -> {log.yellow(filename)}")
209
+
210
+
211
+ def save_to_mp4(images, filename, fps=20):
212
+ """Saves a sequence of images to an MP4 file.
213
+
214
+ .. note::
215
+ It's recommended to use :func:`save_video` for a more general interface
216
+ that supports multiple formats.
217
+
218
+ Args:
219
+ images (list or np.ndarray): List or array of images. Must have shape
220
+ (n_frames, height, width, channels) or (n_frames, height, width).
221
+ If channel axis is not present, or is 1, grayscale image is assumed,
222
+ which is then converted to RGB. Images should be uint8.
223
+ filename (str or Path): Filename to which data should be written.
224
+ fps (int): Frames per second of rendered format.
225
+
226
+ Raises:
227
+ ImportError: If imageio-ffmpeg is not installed.
228
+
229
+ Returns:
230
+ str: Success message.
231
+
232
+ """
233
+ images = preprocess_for_saving(images)
234
+
235
+ filename = str(filename)
236
+
237
+ parent_dir = Path(filename).parent
238
+ parent_dir.mkdir(parents=True, exist_ok=True)
239
+
240
+ # Use imageio with FFMPEG format for MP4 files
241
+ try:
242
+ writer = imageio.get_writer(
243
+ filename, fps=fps, format="FFMPEG", codec="libx264", pixelformat="yuv420p"
244
+ )
245
+ except (ImportError, ValueError) as exc:
246
+ raise ImportError(
247
+ "FFMPEG plugin is required to save MP4 files. "
248
+ "Please install it with 'pip install imageio-ffmpeg'."
249
+ ) from exc
250
+
251
+ try:
252
+ for image in images:
253
+ writer.append_data(image)
254
+ finally:
255
+ writer.close()
256
+
257
+ return log.success(f"Successfully saved MP4 to -> {log.yellow(filename)}")
112
258
 
113
259
 
114
260
  def search_file_tree(
@@ -341,3 +487,85 @@ def retry_on_io_error(max_retries=3, initial_delay=0.5, retry_action=None):
341
487
  return wrapper
342
488
 
343
489
  return decorator
490
+
491
+
492
+ def _convert_image_mode(img, mode="L"):
493
+ """Convert a PIL Image to the specified mode and return as numpy array."""
494
+ if mode not in {"L", "RGB"}:
495
+ raise ValueError(f"Unsupported mode: {mode}, must be one of: L, RGB")
496
+ if mode == "L":
497
+ img = img.convert("L")
498
+ arr = np.array(img)
499
+ elif mode == "RGB":
500
+ img = img.convert("RGB")
501
+ arr = np.array(img)
502
+ return arr
503
+
504
+
505
+ def grayscale_to_rgb(image):
506
+ """Converts a grayscale image to an RGB image.
507
+
508
+ Args:
509
+ image (ndarray): Grayscale image. Must have shape (height, width).
510
+
511
+ Returns:
512
+ ndarray: RGB image.
513
+ """
514
+ assert image.ndim == 2, "Input image must be grayscale."
515
+ # Stack the grayscale image into 3 channels (RGB)
516
+ return np.stack([image] * 3, axis=-1)
517
+
518
+
519
+ def _assert_uint8_images(images: np.ndarray):
520
+ """
521
+ Asserts that the input images have the correct properties.
522
+
523
+ Args:
524
+ images (np.ndarray): The input images.
525
+
526
+ Raises:
527
+ AssertionError: If the dtype of images is not uint8.
528
+ AssertionError: If the shape of images is not (n_frames, height, width, channels)
529
+ or (n_frames, height, width) for grayscale images.
530
+ AssertionError: If images have anything other than 1 (grayscale),
531
+ 3 (rgb) or 4 (rgba) channels.
532
+ """
533
+ assert images.dtype == np.uint8, f"dtype of images should be uint8, got {images.dtype}"
534
+
535
+ assert images.ndim in (3, 4), (
536
+ "images must have shape (n_frames, height, width, channels),"
537
+ f" or (n_frames, height, width) for grayscale images. Got {images.shape}"
538
+ )
539
+
540
+ if images.ndim == 4:
541
+ assert images.shape[-1] in (1, 3, 4), (
542
+ "Grayscale images must have 1 channel, "
543
+ "RGB images must have 3 channels, and RGBA images must have 4 channels. "
544
+ f"Got shape: {images.shape}, channels: {images.shape[-1]}"
545
+ )
546
+
547
+
548
+ def preprocess_for_saving(images):
549
+ """Preprocesses images for saving to GIF or MP4.
550
+
551
+ Args:
552
+ images (ndarray, list[ndarray]): Images. Must have shape (n_frames, height, width, channels)
553
+ or (n_frames, height, width).
554
+ """
555
+ images = np.array(images)
556
+ _assert_uint8_images(images)
557
+
558
+ # Remove channel axis if it is 1 (grayscale image)
559
+ if images.ndim == 4 and images.shape[-1] == 1:
560
+ images = np.squeeze(images, axis=-1)
561
+
562
+ # convert grayscale images to RGB
563
+ if images.ndim == 3:
564
+ images = [grayscale_to_rgb(image) for image in images]
565
+ images = np.array(images)
566
+
567
+ # drop alpha channel if present (RGBA -> RGB)
568
+ if images.ndim == 4 and images.shape[-1] == 4:
569
+ images = images[..., :3]
570
+
571
+ return images
zea/keras_ops.py CHANGED
@@ -3,14 +3,14 @@ and :mod:`keras.ops.image` functions.
3
3
 
4
4
  They can be used in zea pipelines like any other :class:`zea.Operation`, for example:
5
5
 
6
- .. code-block:: python
6
+ .. doctest::
7
7
 
8
- from zea.keras_ops import Squeeze
8
+ >>> from zea.keras_ops import Squeeze
9
9
 
10
- op = Squeeze(axis=1)
10
+ >>> op = Squeeze(axis=1)
11
11
 
12
12
  This file is generated automatically. Do not edit manually.
13
- Generated with Keras 3.11.3
13
+ Generated with Keras 3.12.0
14
14
  """
15
15
 
16
16
  import keras
@@ -388,6 +388,16 @@ class Cholesky(Lambda):
388
388
  except AttributeError as e:
389
389
  raise MissingKerasOps("Cholesky", "keras.ops.cholesky") from e
390
390
 
391
+ @ops_registry("keras.ops.cholesky_inverse")
392
+ class CholeskyInverse(Lambda):
393
+ """Operation wrapping keras.ops.cholesky_inverse."""
394
+
395
+ def __init__(self, **kwargs):
396
+ try:
397
+ super().__init__(func=keras.ops.cholesky_inverse, **kwargs)
398
+ except AttributeError as e:
399
+ raise MissingKerasOps("CholeskyInverse", "keras.ops.cholesky_inverse") from e
400
+
391
401
  @ops_registry("keras.ops.clip")
392
402
  class Clip(Lambda):
393
403
  """Operation wrapping keras.ops.clip."""
@@ -918,6 +928,36 @@ class Isnan(Lambda):
918
928
  except AttributeError as e:
919
929
  raise MissingKerasOps("Isnan", "keras.ops.isnan") from e
920
930
 
931
+ @ops_registry("keras.ops.isneginf")
932
+ class Isneginf(Lambda):
933
+ """Operation wrapping keras.ops.isneginf."""
934
+
935
+ def __init__(self, **kwargs):
936
+ try:
937
+ super().__init__(func=keras.ops.isneginf, **kwargs)
938
+ except AttributeError as e:
939
+ raise MissingKerasOps("Isneginf", "keras.ops.isneginf") from e
940
+
941
+ @ops_registry("keras.ops.isposinf")
942
+ class Isposinf(Lambda):
943
+ """Operation wrapping keras.ops.isposinf."""
944
+
945
+ def __init__(self, **kwargs):
946
+ try:
947
+ super().__init__(func=keras.ops.isposinf, **kwargs)
948
+ except AttributeError as e:
949
+ raise MissingKerasOps("Isposinf", "keras.ops.isposinf") from e
950
+
951
+ @ops_registry("keras.ops.isreal")
952
+ class Isreal(Lambda):
953
+ """Operation wrapping keras.ops.isreal."""
954
+
955
+ def __init__(self, **kwargs):
956
+ try:
957
+ super().__init__(func=keras.ops.isreal, **kwargs)
958
+ except AttributeError as e:
959
+ raise MissingKerasOps("Isreal", "keras.ops.isreal") from e
960
+
921
961
  @ops_registry("keras.ops.istft")
922
962
  class Istft(Lambda):
923
963
  """Operation wrapping keras.ops.istft."""
@@ -1828,6 +1868,16 @@ class Trunc(Lambda):
1828
1868
  except AttributeError as e:
1829
1869
  raise MissingKerasOps("Trunc", "keras.ops.trunc") from e
1830
1870
 
1871
+ @ops_registry("keras.ops.unfold")
1872
+ class Unfold(Lambda):
1873
+ """Operation wrapping keras.ops.unfold."""
1874
+
1875
+ def __init__(self, **kwargs):
1876
+ try:
1877
+ super().__init__(func=keras.ops.unfold, **kwargs)
1878
+ except AttributeError as e:
1879
+ raise MissingKerasOps("Unfold", "keras.ops.unfold") from e
1880
+
1831
1881
  @ops_registry("keras.ops.unstack")
1832
1882
  class Unstack(Lambda):
1833
1883
  """Operation wrapping keras.ops.unstack."""
@@ -1848,6 +1898,16 @@ class Var(Lambda):
1848
1898
  except AttributeError as e:
1849
1899
  raise MissingKerasOps("Var", "keras.ops.var") from e
1850
1900
 
1901
+ @ops_registry("keras.ops.view")
1902
+ class View(Lambda):
1903
+ """Operation wrapping keras.ops.view."""
1904
+
1905
+ def __init__(self, **kwargs):
1906
+ try:
1907
+ super().__init__(func=keras.ops.view, **kwargs)
1908
+ except AttributeError as e:
1909
+ raise MissingKerasOps("View", "keras.ops.view") from e
1910
+
1851
1911
  @ops_registry("keras.ops.view_as_complex")
1852
1912
  class ViewAsComplex(Lambda):
1853
1913
  """Operation wrapping keras.ops.view_as_complex."""
@@ -1987,3 +2047,13 @@ class RgbToHsv(Lambda):
1987
2047
  super().__init__(func=keras.ops.image.rgb_to_hsv, **kwargs)
1988
2048
  except AttributeError as e:
1989
2049
  raise MissingKerasOps("RgbToHsv", "keras.ops.image.rgb_to_hsv") from e
2050
+
2051
+ @ops_registry("keras.ops.image.scale_and_translate")
2052
+ class ScaleAndTranslate(Lambda):
2053
+ """Operation wrapping keras.ops.image.scale_and_translate."""
2054
+
2055
+ def __init__(self, **kwargs):
2056
+ try:
2057
+ super().__init__(func=keras.ops.image.scale_and_translate, **kwargs)
2058
+ except AttributeError as e:
2059
+ raise MissingKerasOps("ScaleAndTranslate", "keras.ops.image.scale_and_translate") from e
zea/log.py CHANGED
@@ -6,7 +6,7 @@ to the console and to a file with color support.
6
6
  Example usage
7
7
  ^^^^^^^^^^^^^^
8
8
 
9
- .. code-block:: python
9
+ .. testsetup::
10
10
 
11
11
  from zea import log
12
12
 
@@ -323,18 +323,20 @@ def set_level(level):
323
323
 
324
324
  Also sets the log level for the file logger if it exists.
325
325
 
326
- Example:
327
- >>> from zea import log
328
- >>> with log.set_level("WARNING"):
329
- ... log.info("This will not be shown")
330
- ... log.warning("This will be shown")
331
-
332
326
  Args:
333
327
  level (str or int): The log level to set temporarily
334
328
  (e.g., "DEBUG", "INFO", logging.WARNING).
335
329
 
336
330
  Yields:
337
331
  None
332
+
333
+ Example:
334
+ .. doctest::
335
+
336
+ >>> from zea import log
337
+ >>> with log.set_level("ERROR"):
338
+ ... _ = log.info("Info messages will not be shown")
339
+ ... _ = log.error("Error messages will be shown")
338
340
  """
339
341
  prev_level = logger.level
340
342
  prev_file_level = file_logger.level if file_logger else None
zea/metrics.py CHANGED
@@ -10,8 +10,9 @@ from keras import ops
10
10
  from zea import log, tensor_ops
11
11
  from zea.backend import func_on_device
12
12
  from zea.internal.registry import metrics_registry
13
+ from zea.internal.utils import reduce_to_signature
13
14
  from zea.models.lpips import LPIPS
14
- from zea.utils import reduce_to_signature, translate
15
+ from zea.tensor_ops import translate
15
16
 
16
17
 
17
18
  def get_metric(name, **kwargs):
@@ -313,11 +314,18 @@ class Metrics:
313
314
  if specified.
314
315
 
315
316
  Example:
316
- .. code-block:: python
317
-
318
- metrics = zea.metrics.Metrics(["psnr", "lpips"], image_range=[0, 255])
319
- result = metrics(y_true, y_pred)
320
- print(result) # {"psnr": 30.5, "lpips": 0.15}
317
+ .. doctest::
318
+
319
+ >>> from zea import metrics
320
+ >>> import numpy as np
321
+
322
+ >>> metrics = metrics.Metrics(["psnr", "lpips"], image_range=[0, 255])
323
+ >>> y_true = np.random.rand(4, 128, 128, 1)
324
+ >>> y_pred = np.random.rand(4, 128, 128, 1)
325
+ >>> result = metrics(y_true, y_pred)
326
+ >>> result = {k: float(v) for k, v in result.items()}
327
+ >>> print(result) # doctest: +ELLIPSIS
328
+ {'psnr': ..., 'lpips': ...}
321
329
  """
322
330
 
323
331
  def __init__(
@@ -364,7 +372,7 @@ class Metrics:
364
372
  # Because most metric functions do not support batching, we vmap over the batch axes.
365
373
  metric_fn = fun
366
374
  for ax in reversed(batch_axes):
367
- metric_fn = tensor_ops.vmap(metric_fn, in_axes=ax)
375
+ metric_fn = tensor_ops.vmap(metric_fn, in_axes=ax, _use_torch_vmap=True)
368
376
 
369
377
  out = func_on_device(metric_fn, device, y_true, y_pred)
370
378
 
zea/models/__init__.py CHANGED
@@ -2,31 +2,37 @@
2
2
 
3
3
  ``zea`` contains a collection of models for various tasks, all located in the :mod:`zea.models` package.
4
4
 
5
- Currently, the following models are available (all inherited from :class:`zea.models.BaseModel`):
5
+ See the following dropdown for a list of available models:
6
6
 
7
- - :class:`zea.models.echonet.EchoNetDynamic`: A model for left ventricle segmentation.
8
- - :class:`zea.models.carotid_segmenter.CarotidSegmenter`: A model for carotid artery segmentation.
9
- - :class:`zea.models.echonetlvh.EchoNetLVH`: A model for left ventricle hypertrophy segmentation.
10
- - :class:`zea.models.unet.UNet`: A simple U-Net implementation.
11
- - :class:`zea.models.lpips.LPIPS`: A model implementing the perceptual similarity metric.
12
- - :class:`zea.models.taesd.TinyAutoencoder`: A tiny autoencoder model for image compression.
7
+ .. dropdown:: **Available models**
8
+
9
+ - :class:`zea.models.echonet.EchoNetDynamic`: A model for left ventricle segmentation.
10
+ - :class:`zea.models.carotid_segmenter.CarotidSegmenter`: A model for carotid artery segmentation.
11
+ - :class:`zea.models.echonetlvh.EchoNetLVH`: A model for left ventricle hypertrophy segmentation.
12
+ - :class:`zea.models.unet.UNet`: A simple U-Net implementation.
13
+ - :class:`zea.models.lpips.LPIPS`: A model implementing the perceptual similarity metric.
14
+ - :class:`zea.models.taesd.TinyAutoencoder`: A tiny autoencoder model for image compression.
15
+ - :class:`zea.models.regional_quality.MobileNetv2RegionalQuality`: A scoring model for myocardial regions in apical views.
16
+ - :class:`zea.models.lv_segmentation.AugmentedCamusSeg`: A nnU-Net based left ventricle and myocardium segmentation model.
13
17
 
14
18
  Presets for these models can be found in :mod:`zea.models.presets`.
15
19
 
16
20
  To use these models, you can import them directly from the :mod:`zea.models` module and load the pretrained weights using the :meth:`from_preset` method. For example:
17
21
 
18
- .. code-block:: python
22
+ .. doctest::
19
23
 
20
- from zea.models.unet import UNet
24
+ >>> from zea.models.unet import UNet
21
25
 
22
- model = UNet.from_preset("unet-echonet-inpainter")
26
+ >>> model = UNet.from_preset("unet-echonet-inpainter")
23
27
 
24
28
  You can list all available presets using the :attr:`presets` attribute:
25
29
 
26
- .. code-block:: python
30
+ .. doctest::
27
31
 
28
- presets = list(UNet.presets.keys())
29
- print(f"Available built-in zea presets for UNet: {presets}")
32
+ >>> from zea.models.unet import UNet
33
+ >>> presets = list(UNet.presets.keys())
34
+ >>> print(f"Available built-in zea presets for UNet: {presets}")
35
+ Available built-in zea presets for UNet: ['unet-echonet-inpainter']
30
36
 
31
37
 
32
38
  zea generative models
@@ -40,19 +46,21 @@ Typically, these models have some additional methods, such as:
40
46
  - :meth:`posterior_sample` for drawing samples from the posterior given measurements
41
47
  - :meth:`log_density` for computing the log-probability of data under the model
42
48
 
43
- The following generative models are currently available:
49
+ See the following dropdown for a list of available *generative* models:
50
+
51
+ .. dropdown:: **Available models**
44
52
 
45
- - :class:`zea.models.diffusion.DiffusionModel`: A deep generative diffusion model for ultrasound image generation.
46
- - :class:`zea.models.gmm.GaussianMixtureModel`: A Gaussian Mixture Model.
53
+ - :class:`zea.models.diffusion.DiffusionModel`: A deep generative diffusion model for ultrasound image generation.
54
+ - :class:`zea.models.gmm.GaussianMixtureModel`: A Gaussian Mixture Model.
47
55
 
48
56
  An example of how to use the :class:`zea.models.diffusion.DiffusionModel` is shown below:
49
57
 
50
- .. code-block:: python
58
+ .. doctest::
51
59
 
52
- from zea.models.diffusion import DiffusionModel
60
+ >>> from zea.models.diffusion import DiffusionModel
53
61
 
54
- model = DiffusionModel.from_preset("diffusion-echonet-dynamic")
55
- samples = model.sample(n_samples=4)
62
+ >>> model = DiffusionModel.from_preset("diffusion-echonet-dynamic") # doctest: +SKIP
63
+ >>> samples = model.sample(n_samples=4) # doctest: +SKIP
56
64
 
57
65
 
58
66
  Contributing and adding new models
@@ -84,7 +92,9 @@ from . import (
84
92
  gmm,
85
93
  layers,
86
94
  lpips,
95
+ lv_segmentation,
87
96
  presets,
97
+ regional_quality,
88
98
  taesd,
89
99
  unet,
90
100
  utils,
zea/models/base.py CHANGED
@@ -8,6 +8,7 @@ import importlib
8
8
  import keras
9
9
  from keras.src.saving.serialization_lib import record_object_after_deserialization
10
10
 
11
+ from zea import log
11
12
  from zea.internal.core import classproperty
12
13
  from zea.models.preset_utils import builtin_presets, get_preset_loader, get_preset_saver
13
14
 
@@ -77,7 +78,7 @@ class BaseModel(keras.models.Model):
77
78
  initialized.
78
79
  **kwargs: Additional keyword arguments.
79
80
 
80
- Examples:
81
+ Example:
81
82
  .. code-block:: python
82
83
 
83
84
  # Load a Gemma backbone with pre-trained weights.
@@ -96,14 +97,29 @@ class BaseModel(keras.models.Model):
96
97
 
97
98
  """
98
99
  loader = get_preset_loader(preset)
99
- model_cls = loader.check_model_class()
100
- if not issubclass(model_cls, cls):
101
- raise ValueError(
102
- f"Saved preset has type `{model_cls.__name__}` which is not "
103
- f"a subclass of calling class `{cls.__name__}`. Call "
104
- f"`from_preset` directly on `{model_cls.__name__}` instead."
105
- )
106
- return loader.load_model(model_cls, load_weights, **kwargs)
100
+ loader_cls = loader.check_model_class()
101
+ if cls != loader_cls:
102
+ full_cls_name = f"{cls.__module__}.{cls.__name__}"
103
+ full_loader_cls_name = f"{loader_cls.__module__}.{loader_cls.__name__}"
104
+ if issubclass(cls, loader_cls):
105
+ log.warning(
106
+ f"The preset '{preset}' is for model class '{full_loader_cls_name}', but you "
107
+ f"are calling from a subclass '{full_cls_name}', so the returned object will "
108
+ f"be of type '{full_cls_name}'."
109
+ )
110
+ elif issubclass(loader_cls, cls):
111
+ log.warning(
112
+ f"The preset '{preset}' is for model class '{full_loader_cls_name}', "
113
+ f"which is a subclass of the calling class '{full_cls_name}', "
114
+ f"so the returned object will be of type '{full_cls_name}'."
115
+ )
116
+ else:
117
+ raise ValueError(
118
+ f"The preset '{preset}' is for model class '{full_loader_cls_name}', "
119
+ f"which is not compatible with the calling class '{full_cls_name}'. "
120
+ f"Please call '{full_loader_cls_name}.from_preset()' instead."
121
+ )
122
+ return loader.load_model(cls, load_weights, **kwargs)
107
123
 
108
124
  def save_to_preset(self, preset_dir):
109
125
  """Save backbone to a preset directory.
@@ -115,7 +131,7 @@ class BaseModel(keras.models.Model):
115
131
  saver.save_model(self)
116
132
 
117
133
 
118
- def deserialize_zea_object(config):
134
+ def deserialize_zea_object(config, cls=None):
119
135
  """Retrieve the object by deserializing the config dict.
120
136
 
121
137
  Need to borrow this function from keras and customize a bit to allow
@@ -132,10 +148,10 @@ def deserialize_zea_object(config):
132
148
  class_name = config["class_name"]
133
149
  inner_config = config["config"] or {}
134
150
 
135
- module = config.get("module", None)
136
- registered_name = config.get("registered_name", class_name)
137
-
138
- cls = _retrieve_class(module, registered_name, config)
151
+ if cls is None:
152
+ module = config.get("module", None)
153
+ registered_name = config.get("registered_name", class_name)
154
+ cls = _retrieve_class(module, registered_name, config)
139
155
 
140
156
  if not hasattr(cls, "from_config"):
141
157
  raise TypeError(