zea 0.0.6__py3-none-any.whl → 0.0.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- zea/__init__.py +54 -19
- zea/agent/__init__.py +12 -12
- zea/agent/masks.py +2 -1
- zea/backend/tensorflow/dataloader.py +2 -1
- zea/beamform/beamformer.py +100 -50
- zea/beamform/lens_correction.py +9 -2
- zea/beamform/pfield.py +9 -2
- zea/config.py +34 -25
- zea/data/__init__.py +22 -16
- zea/data/convert/camus.py +2 -1
- zea/data/convert/echonet.py +4 -4
- zea/data/convert/echonetlvh/convert_raw_to_usbmd.py +1 -1
- zea/data/convert/matlab.py +11 -4
- zea/data/data_format.py +31 -30
- zea/data/datasets.py +7 -5
- zea/data/file.py +104 -2
- zea/data/layers.py +3 -3
- zea/datapaths.py +16 -4
- zea/display.py +7 -5
- zea/interface.py +14 -16
- zea/internal/_generate_keras_ops.py +6 -7
- zea/internal/cache.py +2 -49
- zea/internal/config/validation.py +1 -2
- zea/internal/core.py +69 -6
- zea/internal/device.py +6 -2
- zea/internal/dummy_scan.py +330 -0
- zea/internal/operators.py +114 -2
- zea/internal/parameters.py +101 -70
- zea/internal/setup_zea.py +5 -6
- zea/internal/utils.py +282 -0
- zea/io_lib.py +247 -19
- zea/keras_ops.py +74 -4
- zea/log.py +9 -7
- zea/metrics.py +15 -7
- zea/models/__init__.py +30 -20
- zea/models/base.py +30 -14
- zea/models/carotid_segmenter.py +19 -4
- zea/models/diffusion.py +173 -12
- zea/models/echonet.py +22 -8
- zea/models/echonetlvh.py +31 -7
- zea/models/lpips.py +19 -2
- zea/models/lv_segmentation.py +28 -11
- zea/models/preset_utils.py +5 -5
- zea/models/regional_quality.py +30 -10
- zea/models/taesd.py +21 -5
- zea/models/unet.py +15 -1
- zea/ops.py +390 -196
- zea/probes.py +6 -6
- zea/scan.py +109 -49
- zea/simulator.py +24 -21
- zea/tensor_ops.py +406 -302
- zea/tools/hf.py +1 -1
- zea/tools/selection_tool.py +47 -86
- zea/utils.py +92 -480
- zea/visualize.py +177 -39
- {zea-0.0.6.dist-info → zea-0.0.7.dist-info}/METADATA +4 -2
- zea-0.0.7.dist-info/RECORD +114 -0
- zea-0.0.6.dist-info/RECORD +0 -112
- {zea-0.0.6.dist-info → zea-0.0.7.dist-info}/WHEEL +0 -0
- {zea-0.0.6.dist-info → zea-0.0.7.dist-info}/entry_points.txt +0 -0
- {zea-0.0.6.dist-info → zea-0.0.7.dist-info}/licenses/LICENSE +0 -0
zea/io_lib.py
CHANGED
|
@@ -19,7 +19,7 @@ from PIL import Image, ImageSequence
|
|
|
19
19
|
from zea import log
|
|
20
20
|
from zea.data.file import File
|
|
21
21
|
|
|
22
|
-
_SUPPORTED_VID_TYPES = [".
|
|
22
|
+
_SUPPORTED_VID_TYPES = [".mp4", ".gif"]
|
|
23
23
|
_SUPPORTED_IMG_TYPES = [".jpg", ".png", ".JPEG", ".PNG", ".jpeg"]
|
|
24
24
|
_SUPPORTED_ZEA_TYPES = [".hdf5", ".h5"]
|
|
25
25
|
|
|
@@ -27,7 +27,7 @@ _SUPPORTED_ZEA_TYPES = [".hdf5", ".h5"]
|
|
|
27
27
|
def load_video(filename, mode="L"):
|
|
28
28
|
"""Load a video file and return a numpy array of frames.
|
|
29
29
|
|
|
30
|
-
Supported file types:
|
|
30
|
+
Supported file types: mp4, gif.
|
|
31
31
|
|
|
32
32
|
Args:
|
|
33
33
|
filename (str): The path to the video file.
|
|
@@ -57,12 +57,24 @@ def load_video(filename, mode="L"):
|
|
|
57
57
|
with Image.open(filename) as im:
|
|
58
58
|
for frame in ImageSequence.Iterator(im):
|
|
59
59
|
frames.append(_convert_image_mode(frame, mode=mode))
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
60
|
+
elif ext == ".mp4":
|
|
61
|
+
# Use imageio with FFMPEG format for MP4 files
|
|
62
|
+
try:
|
|
63
|
+
reader = imageio.get_reader(filename, format="FFMPEG")
|
|
64
|
+
except (ImportError, ValueError) as exc:
|
|
65
|
+
raise ImportError(
|
|
66
|
+
"FFMPEG plugin is required to load MP4 files. "
|
|
67
|
+
"Please install it with 'pip install imageio-ffmpeg'."
|
|
68
|
+
) from exc
|
|
69
|
+
|
|
70
|
+
try:
|
|
71
|
+
for frame in reader:
|
|
72
|
+
img = Image.fromarray(frame)
|
|
73
|
+
frames.append(_convert_image_mode(img, mode=mode))
|
|
74
|
+
finally:
|
|
75
|
+
reader.close()
|
|
76
|
+
else:
|
|
77
|
+
raise ValueError(f"Unsupported file extension: {ext}")
|
|
66
78
|
|
|
67
79
|
return np.stack(frames, axis=0)
|
|
68
80
|
|
|
@@ -98,17 +110,151 @@ def load_image(filename, mode="L"):
|
|
|
98
110
|
return _convert_image_mode(img, mode=mode)
|
|
99
111
|
|
|
100
112
|
|
|
101
|
-
def
|
|
102
|
-
"""
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
113
|
+
def save_video(images, filename, fps=20, **kwargs):
    """Write a sequence of frames to a video file, picking the writer from the extension.

    Supported file types: mp4, gif. The extension is matched case-insensitively.

    Args:
        images (list or np.ndarray): Frames with shape
            (n_frames, height, width, channels) or (n_frames, height, width).
            A missing or size-1 channel axis is treated as grayscale and
            converted to RGB. Frames should be uint8.
        filename (str or Path): Destination file path.
        fps (int): Frames per second of the written video.
        **kwargs: Extra keyword arguments forwarded to :func:`save_to_gif`
            (for example ``shared_color_palette``). They are not forwarded
            when writing mp4.

    Returns:
        Whatever the format-specific writer (:func:`save_to_gif` or
        :func:`save_to_mp4`) returns.

    Raises:
        ValueError: If the file extension is not supported.
    """
    path = Path(filename)
    suffix = path.suffix.lower()

    if suffix == ".gif":
        return save_to_gif(images, path, fps=fps, **kwargs)
    if suffix == ".mp4":
        return save_to_mp4(images, path, fps=fps)
    raise ValueError(f"Unsupported file extension: {suffix}")
|
|
141
|
+
|
|
142
|
+
|
|
143
|
+
def save_to_gif(images, filename, fps=20, shared_color_palette=False):
    """Saves a sequence of images to a GIF file.

    .. note::
        It's recommended to use :func:`save_video` for a more general interface
        that supports multiple formats.

    Args:
        images (list or np.ndarray): List or array of images. Must have shape
            (n_frames, height, width, channels) or (n_frames, height, width).
            If channel axis is not present, or is 1, grayscale image is assumed,
            which is then converted to RGB. Images should be uint8.
        filename (str or Path): Filename to which data should be written.
            Missing parent directories are created.
        fps (int): Frames per second of rendered format. Must be positive;
            values above 50 are clipped to 50 (with a warning).
        shared_color_palette (bool, optional): If True, creates a global
            color palette across all frames, ensuring consistent colors
            throughout the GIF. Defaults to False, which is default behavior
            of PIL.Image.save. Note: True can cause slow saving for longer
            sequences, and also lead to larger file sizes in some cases.

    Raises:
        ValueError: If ``fps`` is not positive or ``images`` contains no frames.
    """
    import os

    # Guard before computing the per-frame duration (1000 / fps).
    if fps <= 0:
        raise ValueError(f"fps must be positive, got {fps}")

    images = preprocess_for_saving(images)
    if len(images) == 0:
        # Otherwise the first/rest unpacking below fails with a cryptic error.
        raise ValueError("Cannot save an empty sequence of images to GIF.")

    if fps > 50:
        log.warning(f"Cannot set fps ({fps}) > 50. Setting it automatically to 50.")
        fps = 50

    duration = int(round(1000 / fps))  # milliseconds per frame

    pillow_imgs = [Image.fromarray(img) for img in images]

    if shared_color_palette:
        # Pool the pixels of every frame into one tall image so a single
        # palette is generated from the colors of the whole sequence.
        all_colors = np.vstack(
            [np.array(img.convert("RGB")).reshape(-1, 3) for img in pillow_imgs]
        )
        combined_image = Image.fromarray(all_colors.reshape(-1, 1, 3))

        # Generate palette from all frames
        global_palette = combined_image.quantize(
            colors=256,
            method=Image.MEDIANCUT,
            kmeans=1,
        )

        # Apply the same palette to all frames without dithering for
        # consistent color mapping.
        pillow_imgs = [
            img.convert("RGB").quantize(
                palette=global_palette,
                dither=Image.NONE,
            )
            for img in pillow_imgs
        ]

    # Create the destination directory for path-like targets (save_to_mp4
    # does the same); file-like objects are handed to PIL untouched.
    if isinstance(filename, (str, os.PathLike)):
        Path(filename).parent.mkdir(parents=True, exist_ok=True)

    first_img, *other_imgs = pillow_imgs

    first_img.save(
        fp=filename,
        format="GIF",
        append_images=other_imgs,
        save_all=True,
        loop=0,
        duration=duration,
        interlace=False,
        optimize=False,
    )
    log.success(f"Successfully saved GIF to -> {log.yellow(filename)}")
|
|
209
|
+
|
|
210
|
+
|
|
211
|
+
def save_to_mp4(images, filename, fps=20):
    """Saves a sequence of images to an MP4 file.

    .. note::
        It's recommended to use :func:`save_video` for a more general interface
        that supports multiple formats.

    Args:
        images (list or np.ndarray): List or array of images. Must have shape
            (n_frames, height, width, channels) or (n_frames, height, width).
            If channel axis is not present, or is 1, grayscale image is assumed,
            which is then converted to RGB. Images should be uint8.
        filename (str or Path): Filename to which data should be written.
            Missing parent directories are created.
        fps (int): Frames per second of rendered format. Must be positive.

    Raises:
        ValueError: If ``fps`` is not positive or ``images`` contains no frames.
        ImportError: If imageio-ffmpeg is not installed.

    Returns:
        str: Success message.
    """
    if fps <= 0:
        raise ValueError(f"fps must be positive, got {fps}")

    images = preprocess_for_saving(images)
    if len(images) == 0:
        # Fail before the ffmpeg writer is opened instead of closing a writer
        # that never received a frame and still reporting success.
        raise ValueError("Cannot save an empty sequence of images to MP4.")

    filename = str(filename)

    parent_dir = Path(filename).parent
    parent_dir.mkdir(parents=True, exist_ok=True)

    # Use imageio with FFMPEG format for MP4 files
    try:
        writer = imageio.get_writer(
            filename, fps=fps, format="FFMPEG", codec="libx264", pixelformat="yuv420p"
        )
    except (ImportError, ValueError) as exc:
        raise ImportError(
            "FFMPEG plugin is required to save MP4 files. "
            "Please install it with 'pip install imageio-ffmpeg'."
        ) from exc

    # Always release the ffmpeg writer, even if appending a frame fails.
    try:
        for image in images:
            writer.append_data(image)
    finally:
        writer.close()

    return log.success(f"Successfully saved MP4 to -> {filename}")
|
|
112
258
|
|
|
113
259
|
|
|
114
260
|
def search_file_tree(
|
|
@@ -341,3 +487,85 @@ def retry_on_io_error(max_retries=3, initial_delay=0.5, retry_action=None):
|
|
|
341
487
|
return wrapper
|
|
342
488
|
|
|
343
489
|
return decorator
|
|
490
|
+
|
|
491
|
+
|
|
492
|
+
def _convert_image_mode(img, mode="L"):
|
|
493
|
+
"""Convert a PIL Image to the specified mode and return as numpy array."""
|
|
494
|
+
if mode not in {"L", "RGB"}:
|
|
495
|
+
raise ValueError(f"Unsupported mode: {mode}, must be one of: L, RGB")
|
|
496
|
+
if mode == "L":
|
|
497
|
+
img = img.convert("L")
|
|
498
|
+
arr = np.array(img)
|
|
499
|
+
elif mode == "RGB":
|
|
500
|
+
img = img.convert("RGB")
|
|
501
|
+
arr = np.array(img)
|
|
502
|
+
return arr
|
|
503
|
+
|
|
504
|
+
|
|
505
|
+
def grayscale_to_rgb(image):
    """Converts a grayscale image to an RGB image.

    The single channel is replicated into three identical channels.

    Args:
        image (ndarray): Grayscale image. Must have shape (height, width).

    Returns:
        ndarray: RGB image of shape (height, width, 3) with the input's dtype.

    Raises:
        AssertionError: If ``image`` is not two-dimensional. Raised explicitly
            (not via ``assert``) so the check still runs under ``python -O``.
    """
    if image.ndim != 2:
        raise AssertionError("Input image must be grayscale.")
    # Stack the grayscale image into 3 channels (RGB)
    return np.stack([image] * 3, axis=-1)
|
|
517
|
+
|
|
518
|
+
|
|
519
|
+
def _assert_uint8_images(images: np.ndarray):
|
|
520
|
+
"""
|
|
521
|
+
Asserts that the input images have the correct properties.
|
|
522
|
+
|
|
523
|
+
Args:
|
|
524
|
+
images (np.ndarray): The input images.
|
|
525
|
+
|
|
526
|
+
Raises:
|
|
527
|
+
AssertionError: If the dtype of images is not uint8.
|
|
528
|
+
AssertionError: If the shape of images is not (n_frames, height, width, channels)
|
|
529
|
+
or (n_frames, height, width) for grayscale images.
|
|
530
|
+
AssertionError: If images have anything other than 1 (grayscale),
|
|
531
|
+
3 (rgb) or 4 (rgba) channels.
|
|
532
|
+
"""
|
|
533
|
+
assert images.dtype == np.uint8, f"dtype of images should be uint8, got {images.dtype}"
|
|
534
|
+
|
|
535
|
+
assert images.ndim in (3, 4), (
|
|
536
|
+
"images must have shape (n_frames, height, width, channels),"
|
|
537
|
+
f" or (n_frames, height, width) for grayscale images. Got {images.shape}"
|
|
538
|
+
)
|
|
539
|
+
|
|
540
|
+
if images.ndim == 4:
|
|
541
|
+
assert images.shape[-1] in (1, 3, 4), (
|
|
542
|
+
"Grayscale images must have 1 channel, "
|
|
543
|
+
"RGB images must have 3 channels, and RGBA images must have 4 channels. "
|
|
544
|
+
f"Got shape: {images.shape}, channels: {images.shape[-1]}"
|
|
545
|
+
)
|
|
546
|
+
|
|
547
|
+
|
|
548
|
+
def preprocess_for_saving(images):
    """Preprocesses images for saving to GIF or MP4.

    The frames are validated with ``_assert_uint8_images``. Grayscale frames
    (no channel axis, or a channel axis of size 1) are expanded to three
    identical RGB channels, and an alpha channel is dropped.

    Args:
        images (ndarray, list[ndarray]): Images. Must have shape (n_frames, height, width, channels)
            or (n_frames, height, width).

    Returns:
        np.ndarray: uint8 array of shape (n_frames, height, width, 3).
    """
    images = np.array(images)
    _assert_uint8_images(images)

    # Remove channel axis if it is 1 (grayscale image)
    if images.ndim == 4 and images.shape[-1] == 1:
        images = np.squeeze(images, axis=-1)

    # Convert grayscale to RGB for the whole stack at once. Unlike stacking a
    # per-frame list, this keeps the uint8 dtype and (0, H, W, 3) shape when
    # there are no frames.
    if images.ndim == 3:
        images = np.stack([images] * 3, axis=-1)

    # drop alpha channel if present (RGBA -> RGB)
    if images.ndim == 4 and images.shape[-1] == 4:
        images = images[..., :3]

    return images
|
zea/keras_ops.py
CHANGED
|
@@ -3,14 +3,14 @@ and :mod:`keras.ops.image` functions.
|
|
|
3
3
|
|
|
4
4
|
They can be used in zea pipelines like any other :class:`zea.Operation`, for example:
|
|
5
5
|
|
|
6
|
-
..
|
|
6
|
+
.. doctest::
|
|
7
7
|
|
|
8
|
-
from zea.keras_ops import Squeeze
|
|
8
|
+
>>> from zea.keras_ops import Squeeze
|
|
9
9
|
|
|
10
|
-
op = Squeeze(axis=1)
|
|
10
|
+
>>> op = Squeeze(axis=1)
|
|
11
11
|
|
|
12
12
|
This file is generated automatically. Do not edit manually.
|
|
13
|
-
Generated with Keras 3.
|
|
13
|
+
Generated with Keras 3.12.0
|
|
14
14
|
"""
|
|
15
15
|
|
|
16
16
|
import keras
|
|
@@ -388,6 +388,16 @@ class Cholesky(Lambda):
|
|
|
388
388
|
except AttributeError as e:
|
|
389
389
|
raise MissingKerasOps("Cholesky", "keras.ops.cholesky") from e
|
|
390
390
|
|
|
391
|
+
@ops_registry("keras.ops.cholesky_inverse")
class CholeskyInverse(Lambda):
    """Lambda operation that applies ``keras.ops.cholesky_inverse``.

    Raises:
        MissingKerasOps: If an ``AttributeError`` occurs while resolving or
            wrapping ``keras.ops.cholesky_inverse``.
    """

    def __init__(self, **kwargs):
        try:
            op_fn = keras.ops.cholesky_inverse
            super().__init__(func=op_fn, **kwargs)
        except AttributeError as err:
            raise MissingKerasOps("CholeskyInverse", "keras.ops.cholesky_inverse") from err
|
|
400
|
+
|
|
391
401
|
@ops_registry("keras.ops.clip")
|
|
392
402
|
class Clip(Lambda):
|
|
393
403
|
"""Operation wrapping keras.ops.clip."""
|
|
@@ -918,6 +928,36 @@ class Isnan(Lambda):
|
|
|
918
928
|
except AttributeError as e:
|
|
919
929
|
raise MissingKerasOps("Isnan", "keras.ops.isnan") from e
|
|
920
930
|
|
|
931
|
+
@ops_registry("keras.ops.isneginf")
class Isneginf(Lambda):
    """Lambda operation that applies ``keras.ops.isneginf``.

    Raises:
        MissingKerasOps: If an ``AttributeError`` occurs while resolving or
            wrapping ``keras.ops.isneginf``.
    """

    def __init__(self, **kwargs):
        try:
            op_fn = keras.ops.isneginf
            super().__init__(func=op_fn, **kwargs)
        except AttributeError as err:
            raise MissingKerasOps("Isneginf", "keras.ops.isneginf") from err
|
|
940
|
+
|
|
941
|
+
@ops_registry("keras.ops.isposinf")
class Isposinf(Lambda):
    """Lambda operation that applies ``keras.ops.isposinf``.

    Raises:
        MissingKerasOps: If an ``AttributeError`` occurs while resolving or
            wrapping ``keras.ops.isposinf``.
    """

    def __init__(self, **kwargs):
        try:
            op_fn = keras.ops.isposinf
            super().__init__(func=op_fn, **kwargs)
        except AttributeError as err:
            raise MissingKerasOps("Isposinf", "keras.ops.isposinf") from err
|
|
950
|
+
|
|
951
|
+
@ops_registry("keras.ops.isreal")
class Isreal(Lambda):
    """Lambda operation that applies ``keras.ops.isreal``.

    Raises:
        MissingKerasOps: If an ``AttributeError`` occurs while resolving or
            wrapping ``keras.ops.isreal``.
    """

    def __init__(self, **kwargs):
        try:
            op_fn = keras.ops.isreal
            super().__init__(func=op_fn, **kwargs)
        except AttributeError as err:
            raise MissingKerasOps("Isreal", "keras.ops.isreal") from err
|
|
960
|
+
|
|
921
961
|
@ops_registry("keras.ops.istft")
|
|
922
962
|
class Istft(Lambda):
|
|
923
963
|
"""Operation wrapping keras.ops.istft."""
|
|
@@ -1828,6 +1868,16 @@ class Trunc(Lambda):
|
|
|
1828
1868
|
except AttributeError as e:
|
|
1829
1869
|
raise MissingKerasOps("Trunc", "keras.ops.trunc") from e
|
|
1830
1870
|
|
|
1871
|
+
@ops_registry("keras.ops.unfold")
class Unfold(Lambda):
    """Lambda operation that applies ``keras.ops.unfold``.

    Raises:
        MissingKerasOps: If an ``AttributeError`` occurs while resolving or
            wrapping ``keras.ops.unfold``.
    """

    def __init__(self, **kwargs):
        try:
            op_fn = keras.ops.unfold
            super().__init__(func=op_fn, **kwargs)
        except AttributeError as err:
            raise MissingKerasOps("Unfold", "keras.ops.unfold") from err
|
|
1880
|
+
|
|
1831
1881
|
@ops_registry("keras.ops.unstack")
|
|
1832
1882
|
class Unstack(Lambda):
|
|
1833
1883
|
"""Operation wrapping keras.ops.unstack."""
|
|
@@ -1848,6 +1898,16 @@ class Var(Lambda):
|
|
|
1848
1898
|
except AttributeError as e:
|
|
1849
1899
|
raise MissingKerasOps("Var", "keras.ops.var") from e
|
|
1850
1900
|
|
|
1901
|
+
@ops_registry("keras.ops.view")
class View(Lambda):
    """Lambda operation that applies ``keras.ops.view``.

    Raises:
        MissingKerasOps: If an ``AttributeError`` occurs while resolving or
            wrapping ``keras.ops.view``.
    """

    def __init__(self, **kwargs):
        try:
            op_fn = keras.ops.view
            super().__init__(func=op_fn, **kwargs)
        except AttributeError as err:
            raise MissingKerasOps("View", "keras.ops.view") from err
|
|
1910
|
+
|
|
1851
1911
|
@ops_registry("keras.ops.view_as_complex")
|
|
1852
1912
|
class ViewAsComplex(Lambda):
|
|
1853
1913
|
"""Operation wrapping keras.ops.view_as_complex."""
|
|
@@ -1987,3 +2047,13 @@ class RgbToHsv(Lambda):
|
|
|
1987
2047
|
super().__init__(func=keras.ops.image.rgb_to_hsv, **kwargs)
|
|
1988
2048
|
except AttributeError as e:
|
|
1989
2049
|
raise MissingKerasOps("RgbToHsv", "keras.ops.image.rgb_to_hsv") from e
|
|
2050
|
+
|
|
2051
|
+
@ops_registry("keras.ops.image.scale_and_translate")
class ScaleAndTranslate(Lambda):
    """Lambda operation that applies ``keras.ops.image.scale_and_translate``.

    Raises:
        MissingKerasOps: If an ``AttributeError`` occurs while resolving or
            wrapping ``keras.ops.image.scale_and_translate``.
    """

    def __init__(self, **kwargs):
        try:
            op_fn = keras.ops.image.scale_and_translate
            super().__init__(func=op_fn, **kwargs)
        except AttributeError as err:
            raise MissingKerasOps(
                "ScaleAndTranslate", "keras.ops.image.scale_and_translate"
            ) from err
|
zea/log.py
CHANGED
|
@@ -6,7 +6,7 @@ to the console and to a file with color support.
|
|
|
6
6
|
Example usage
|
|
7
7
|
^^^^^^^^^^^^^^
|
|
8
8
|
|
|
9
|
-
..
|
|
9
|
+
.. testsetup::
|
|
10
10
|
|
|
11
11
|
from zea import log
|
|
12
12
|
|
|
@@ -323,18 +323,20 @@ def set_level(level):
|
|
|
323
323
|
|
|
324
324
|
Also sets the log level for the file logger if it exists.
|
|
325
325
|
|
|
326
|
-
Example:
|
|
327
|
-
>>> from zea import log
|
|
328
|
-
>>> with log.set_level("WARNING"):
|
|
329
|
-
... log.info("This will not be shown")
|
|
330
|
-
... log.warning("This will be shown")
|
|
331
|
-
|
|
332
326
|
Args:
|
|
333
327
|
level (str or int): The log level to set temporarily
|
|
334
328
|
(e.g., "DEBUG", "INFO", logging.WARNING).
|
|
335
329
|
|
|
336
330
|
Yields:
|
|
337
331
|
None
|
|
332
|
+
|
|
333
|
+
Example:
|
|
334
|
+
.. doctest::
|
|
335
|
+
|
|
336
|
+
>>> from zea import log
|
|
337
|
+
>>> with log.set_level("ERROR"):
|
|
338
|
+
... _ = log.info("Info messages will not be shown")
|
|
339
|
+
... _ = log.error("Error messages will be shown")
|
|
338
340
|
"""
|
|
339
341
|
prev_level = logger.level
|
|
340
342
|
prev_file_level = file_logger.level if file_logger else None
|
zea/metrics.py
CHANGED
|
@@ -10,8 +10,9 @@ from keras import ops
|
|
|
10
10
|
from zea import log, tensor_ops
|
|
11
11
|
from zea.backend import func_on_device
|
|
12
12
|
from zea.internal.registry import metrics_registry
|
|
13
|
+
from zea.internal.utils import reduce_to_signature
|
|
13
14
|
from zea.models.lpips import LPIPS
|
|
14
|
-
from zea.
|
|
15
|
+
from zea.tensor_ops import translate
|
|
15
16
|
|
|
16
17
|
|
|
17
18
|
def get_metric(name, **kwargs):
|
|
@@ -313,11 +314,18 @@ class Metrics:
|
|
|
313
314
|
if specified.
|
|
314
315
|
|
|
315
316
|
Example:
|
|
316
|
-
..
|
|
317
|
-
|
|
318
|
-
|
|
319
|
-
|
|
320
|
-
|
|
317
|
+
.. doctest::
|
|
318
|
+
|
|
319
|
+
>>> from zea import metrics
|
|
320
|
+
>>> import numpy as np
|
|
321
|
+
|
|
322
|
+
>>> metrics = metrics.Metrics(["psnr", "lpips"], image_range=[0, 255])
|
|
323
|
+
>>> y_true = np.random.rand(4, 128, 128, 1)
|
|
324
|
+
>>> y_pred = np.random.rand(4, 128, 128, 1)
|
|
325
|
+
>>> result = metrics(y_true, y_pred)
|
|
326
|
+
>>> result = {k: float(v) for k, v in result.items()}
|
|
327
|
+
>>> print(result) # doctest: +ELLIPSIS
|
|
328
|
+
{'psnr': ..., 'lpips': ...}
|
|
321
329
|
"""
|
|
322
330
|
|
|
323
331
|
def __init__(
|
|
@@ -364,7 +372,7 @@ class Metrics:
|
|
|
364
372
|
# Because most metric functions do not support batching, we vmap over the batch axes.
|
|
365
373
|
metric_fn = fun
|
|
366
374
|
for ax in reversed(batch_axes):
|
|
367
|
-
metric_fn = tensor_ops.vmap(metric_fn, in_axes=ax)
|
|
375
|
+
metric_fn = tensor_ops.vmap(metric_fn, in_axes=ax, _use_torch_vmap=True)
|
|
368
376
|
|
|
369
377
|
out = func_on_device(metric_fn, device, y_true, y_pred)
|
|
370
378
|
|
zea/models/__init__.py
CHANGED
|
@@ -2,31 +2,37 @@
|
|
|
2
2
|
|
|
3
3
|
``zea`` contains a collection of models for various tasks, all located in the :mod:`zea.models` package.
|
|
4
4
|
|
|
5
|
-
|
|
5
|
+
See the following dropdown for a list of available models:
|
|
6
6
|
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
- :class:`zea.models.
|
|
10
|
-
- :class:`zea.models.
|
|
11
|
-
- :class:`zea.models.
|
|
12
|
-
- :class:`zea.models.
|
|
7
|
+
.. dropdown:: **Available models**
|
|
8
|
+
|
|
9
|
+
- :class:`zea.models.echonet.EchoNetDynamic`: A model for left ventricle segmentation.
|
|
10
|
+
- :class:`zea.models.carotid_segmenter.CarotidSegmenter`: A model for carotid artery segmentation.
|
|
11
|
+
- :class:`zea.models.echonetlvh.EchoNetLVH`: A model for left ventricular hypertrophy segmentation.
|
|
12
|
+
- :class:`zea.models.unet.UNet`: A simple U-Net implementation.
|
|
13
|
+
- :class:`zea.models.lpips.LPIPS`: A model implementing the perceptual similarity metric.
|
|
14
|
+
- :class:`zea.models.taesd.TinyAutoencoder`: A tiny autoencoder model for image compression.
|
|
15
|
+
- :class:`zea.models.regional_quality.MobileNetv2RegionalQuality`: A scoring model for myocardial regions in apical views.
|
|
16
|
+
- :class:`zea.models.lv_segmentation.AugmentedCamusSeg`: An nnU-Net-based left ventricle and myocardium segmentation model.
|
|
13
17
|
|
|
14
18
|
Presets for these models can be found in :mod:`zea.models.presets`.
|
|
15
19
|
|
|
16
20
|
To use these models, you can import them directly from the :mod:`zea.models` module and load the pretrained weights using the :meth:`from_preset` method. For example:
|
|
17
21
|
|
|
18
|
-
..
|
|
22
|
+
.. doctest::
|
|
19
23
|
|
|
20
|
-
from zea.models.unet import UNet
|
|
24
|
+
>>> from zea.models.unet import UNet
|
|
21
25
|
|
|
22
|
-
model = UNet.from_preset("unet-echonet-inpainter")
|
|
26
|
+
>>> model = UNet.from_preset("unet-echonet-inpainter")
|
|
23
27
|
|
|
24
28
|
You can list all available presets using the :attr:`presets` attribute:
|
|
25
29
|
|
|
26
|
-
..
|
|
30
|
+
.. doctest::
|
|
27
31
|
|
|
28
|
-
|
|
29
|
-
|
|
32
|
+
>>> from zea.models.unet import UNet
|
|
33
|
+
>>> presets = list(UNet.presets.keys())
|
|
34
|
+
>>> print(f"Available built-in zea presets for UNet: {presets}")
|
|
35
|
+
Available built-in zea presets for UNet: ['unet-echonet-inpainter']
|
|
30
36
|
|
|
31
37
|
|
|
32
38
|
zea generative models
|
|
@@ -40,19 +46,21 @@ Typically, these models have some additional methods, such as:
|
|
|
40
46
|
- :meth:`posterior_sample` for drawing samples from the posterior given measurements
|
|
41
47
|
- :meth:`log_density` for computing the log-probability of data under the model
|
|
42
48
|
|
|
43
|
-
|
|
49
|
+
See the following dropdown for a list of available *generative* models:
|
|
50
|
+
|
|
51
|
+
.. dropdown:: **Available models**
|
|
44
52
|
|
|
45
|
-
- :class:`zea.models.diffusion.DiffusionModel`: A deep generative diffusion model for ultrasound image generation.
|
|
46
|
-
- :class:`zea.models.gmm.GaussianMixtureModel`: A Gaussian Mixture Model.
|
|
53
|
+
- :class:`zea.models.diffusion.DiffusionModel`: A deep generative diffusion model for ultrasound image generation.
|
|
54
|
+
- :class:`zea.models.gmm.GaussianMixtureModel`: A Gaussian Mixture Model.
|
|
47
55
|
|
|
48
56
|
An example of how to use the :class:`zea.models.diffusion.DiffusionModel` is shown below:
|
|
49
57
|
|
|
50
|
-
..
|
|
58
|
+
.. doctest::
|
|
51
59
|
|
|
52
|
-
from zea.models.diffusion import DiffusionModel
|
|
60
|
+
>>> from zea.models.diffusion import DiffusionModel
|
|
53
61
|
|
|
54
|
-
model = DiffusionModel.from_preset("diffusion-echonet-dynamic")
|
|
55
|
-
samples = model.sample(n_samples=4)
|
|
62
|
+
>>> model = DiffusionModel.from_preset("diffusion-echonet-dynamic") # doctest: +SKIP
|
|
63
|
+
>>> samples = model.sample(n_samples=4) # doctest: +SKIP
|
|
56
64
|
|
|
57
65
|
|
|
58
66
|
Contributing and adding new models
|
|
@@ -84,7 +92,9 @@ from . import (
|
|
|
84
92
|
gmm,
|
|
85
93
|
layers,
|
|
86
94
|
lpips,
|
|
95
|
+
lv_segmentation,
|
|
87
96
|
presets,
|
|
97
|
+
regional_quality,
|
|
88
98
|
taesd,
|
|
89
99
|
unet,
|
|
90
100
|
utils,
|
zea/models/base.py
CHANGED
|
@@ -8,6 +8,7 @@ import importlib
|
|
|
8
8
|
import keras
|
|
9
9
|
from keras.src.saving.serialization_lib import record_object_after_deserialization
|
|
10
10
|
|
|
11
|
+
from zea import log
|
|
11
12
|
from zea.internal.core import classproperty
|
|
12
13
|
from zea.models.preset_utils import builtin_presets, get_preset_loader, get_preset_saver
|
|
13
14
|
|
|
@@ -77,7 +78,7 @@ class BaseModel(keras.models.Model):
|
|
|
77
78
|
initialized.
|
|
78
79
|
**kwargs: Additional keyword arguments.
|
|
79
80
|
|
|
80
|
-
|
|
81
|
+
Example:
|
|
81
82
|
.. code-block:: python
|
|
82
83
|
|
|
83
84
|
# Load a Gemma backbone with pre-trained weights.
|
|
@@ -96,14 +97,29 @@ class BaseModel(keras.models.Model):
|
|
|
96
97
|
|
|
97
98
|
"""
|
|
98
99
|
loader = get_preset_loader(preset)
|
|
99
|
-
|
|
100
|
-
if
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
100
|
+
loader_cls = loader.check_model_class()
|
|
101
|
+
if cls != loader_cls:
|
|
102
|
+
full_cls_name = f"{cls.__module__}.{cls.__name__}"
|
|
103
|
+
full_loader_cls_name = f"{loader_cls.__module__}.{loader_cls.__name__}"
|
|
104
|
+
if issubclass(cls, loader_cls):
|
|
105
|
+
log.warning(
|
|
106
|
+
f"The preset '{preset}' is for model class '{full_loader_cls_name}', but you "
|
|
107
|
+
f"are calling from a subclass '{full_cls_name}', so the returned object will "
|
|
108
|
+
f"be of type '{full_cls_name}'."
|
|
109
|
+
)
|
|
110
|
+
elif issubclass(loader_cls, cls):
|
|
111
|
+
log.warning(
|
|
112
|
+
f"The preset '{preset}' is for model class '{full_loader_cls_name}', "
|
|
113
|
+
f"which is a subclass of the calling class '{full_cls_name}', "
|
|
114
|
+
f"so the returned object will be of type '{full_cls_name}'."
|
|
115
|
+
)
|
|
116
|
+
else:
|
|
117
|
+
raise ValueError(
|
|
118
|
+
f"The preset '{preset}' is for model class '{full_loader_cls_name}', "
|
|
119
|
+
f"which is not compatible with the calling class '{full_cls_name}'. "
|
|
120
|
+
f"Please call '{full_loader_cls_name}.from_preset()' instead."
|
|
121
|
+
)
|
|
122
|
+
return loader.load_model(cls, load_weights, **kwargs)
|
|
107
123
|
|
|
108
124
|
def save_to_preset(self, preset_dir):
|
|
109
125
|
"""Save backbone to a preset directory.
|
|
@@ -115,7 +131,7 @@ class BaseModel(keras.models.Model):
|
|
|
115
131
|
saver.save_model(self)
|
|
116
132
|
|
|
117
133
|
|
|
118
|
-
def deserialize_zea_object(config):
|
|
134
|
+
def deserialize_zea_object(config, cls=None):
|
|
119
135
|
"""Retrieve the object by deserializing the config dict.
|
|
120
136
|
|
|
121
137
|
Need to borrow this function from keras and customize a bit to allow
|
|
@@ -132,10 +148,10 @@ def deserialize_zea_object(config):
|
|
|
132
148
|
class_name = config["class_name"]
|
|
133
149
|
inner_config = config["config"] or {}
|
|
134
150
|
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
151
|
+
if cls is None:
|
|
152
|
+
module = config.get("module", None)
|
|
153
|
+
registered_name = config.get("registered_name", class_name)
|
|
154
|
+
cls = _retrieve_class(module, registered_name, config)
|
|
139
155
|
|
|
140
156
|
if not hasattr(cls, "from_config"):
|
|
141
157
|
raise TypeError(
|