zea 0.0.5__py3-none-any.whl → 0.0.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68)
  1. zea/__init__.py +54 -19
  2. zea/agent/__init__.py +12 -12
  3. zea/agent/masks.py +2 -1
  4. zea/agent/selection.py +166 -0
  5. zea/backend/__init__.py +89 -0
  6. zea/backend/jax/__init__.py +14 -51
  7. zea/backend/tensorflow/__init__.py +0 -49
  8. zea/backend/tensorflow/dataloader.py +2 -1
  9. zea/backend/torch/__init__.py +27 -62
  10. zea/beamform/beamformer.py +100 -50
  11. zea/beamform/lens_correction.py +9 -2
  12. zea/beamform/pfield.py +9 -2
  13. zea/config.py +34 -25
  14. zea/data/__init__.py +22 -16
  15. zea/data/convert/camus.py +2 -1
  16. zea/data/convert/echonet.py +4 -4
  17. zea/data/convert/echonetlvh/convert_raw_to_usbmd.py +1 -1
  18. zea/data/convert/matlab.py +11 -4
  19. zea/data/data_format.py +31 -30
  20. zea/data/datasets.py +7 -5
  21. zea/data/file.py +104 -2
  22. zea/data/layers.py +5 -6
  23. zea/datapaths.py +16 -4
  24. zea/display.py +7 -5
  25. zea/interface.py +14 -16
  26. zea/internal/_generate_keras_ops.py +6 -7
  27. zea/internal/cache.py +2 -49
  28. zea/internal/config/validation.py +1 -2
  29. zea/internal/core.py +69 -6
  30. zea/internal/device.py +6 -2
  31. zea/internal/dummy_scan.py +330 -0
  32. zea/internal/operators.py +114 -2
  33. zea/internal/parameters.py +101 -70
  34. zea/internal/registry.py +1 -1
  35. zea/internal/setup_zea.py +5 -6
  36. zea/internal/utils.py +282 -0
  37. zea/io_lib.py +247 -19
  38. zea/keras_ops.py +74 -4
  39. zea/log.py +9 -7
  40. zea/metrics.py +365 -65
  41. zea/models/__init__.py +30 -20
  42. zea/models/base.py +30 -14
  43. zea/models/carotid_segmenter.py +19 -4
  44. zea/models/diffusion.py +187 -26
  45. zea/models/echonet.py +22 -8
  46. zea/models/echonetlvh.py +31 -18
  47. zea/models/lpips.py +19 -2
  48. zea/models/lv_segmentation.py +96 -0
  49. zea/models/preset_utils.py +5 -5
  50. zea/models/presets.py +36 -0
  51. zea/models/regional_quality.py +142 -0
  52. zea/models/taesd.py +21 -5
  53. zea/models/unet.py +15 -1
  54. zea/ops.py +414 -207
  55. zea/probes.py +6 -6
  56. zea/scan.py +109 -49
  57. zea/simulator.py +24 -21
  58. zea/tensor_ops.py +411 -206
  59. zea/tools/hf.py +1 -1
  60. zea/tools/selection_tool.py +47 -86
  61. zea/utils.py +92 -480
  62. zea/visualize.py +177 -39
  63. {zea-0.0.5.dist-info → zea-0.0.7.dist-info}/METADATA +9 -3
  64. zea-0.0.7.dist-info/RECORD +114 -0
  65. {zea-0.0.5.dist-info → zea-0.0.7.dist-info}/WHEEL +1 -1
  66. zea-0.0.5.dist-info/RECORD +0 -110
  67. {zea-0.0.5.dist-info → zea-0.0.7.dist-info}/entry_points.txt +0 -0
  68. {zea-0.0.5.dist-info → zea-0.0.7.dist-info/licenses}/LICENSE +0 -0
zea/io_lib.py CHANGED
@@ -19,7 +19,7 @@ from PIL import Image, ImageSequence
19
19
  from zea import log
20
20
  from zea.data.file import File
21
21
 
22
- _SUPPORTED_VID_TYPES = [".avi", ".mp4", ".gif"]
22
+ _SUPPORTED_VID_TYPES = [".mp4", ".gif"]
23
23
  _SUPPORTED_IMG_TYPES = [".jpg", ".png", ".JPEG", ".PNG", ".jpeg"]
24
24
  _SUPPORTED_ZEA_TYPES = [".hdf5", ".h5"]
25
25
 
@@ -27,7 +27,7 @@ _SUPPORTED_ZEA_TYPES = [".hdf5", ".h5"]
27
27
  def load_video(filename, mode="L"):
28
28
  """Load a video file and return a numpy array of frames.
29
29
 
30
- Supported file types: avi, mp4, gif.
30
+ Supported file types: mp4, gif.
31
31
 
32
32
  Args:
33
33
  filename (str): The path to the video file.
@@ -57,12 +57,24 @@ def load_video(filename, mode="L"):
57
57
  with Image.open(filename) as im:
58
58
  for frame in ImageSequence.Iterator(im):
59
59
  frames.append(_convert_image_mode(frame, mode=mode))
60
- else: # .mp4, .avi
61
- reader = imageio.get_reader(filename)
62
- for frame in reader:
63
- img = Image.fromarray(frame)
64
- frames.append(_convert_image_mode(img, mode=mode))
65
- reader.close()
60
+ elif ext == ".mp4":
61
+ # Use imageio with FFMPEG format for MP4 files
62
+ try:
63
+ reader = imageio.get_reader(filename, format="FFMPEG")
64
+ except (ImportError, ValueError) as exc:
65
+ raise ImportError(
66
+ "FFMPEG plugin is required to load MP4 files. "
67
+ "Please install it with 'pip install imageio-ffmpeg'."
68
+ ) from exc
69
+
70
+ try:
71
+ for frame in reader:
72
+ img = Image.fromarray(frame)
73
+ frames.append(_convert_image_mode(img, mode=mode))
74
+ finally:
75
+ reader.close()
76
+ else:
77
+ raise ValueError(f"Unsupported file extension: {ext}")
66
78
 
67
79
  return np.stack(frames, axis=0)
68
80
 
@@ -98,17 +110,151 @@ def load_image(filename, mode="L"):
98
110
  return _convert_image_mode(img, mode=mode)
99
111
 
100
112
 
101
- def _convert_image_mode(img, mode="L"):
102
- """Convert a PIL Image to the specified mode and return as numpy array."""
103
- if mode not in {"L", "RGB"}:
104
- raise ValueError(f"Unsupported mode: {mode}, must be one of: L, RGB")
105
- if mode == "L":
106
- img = img.convert("L")
107
- arr = np.array(img)
108
- elif mode == "RGB":
109
- img = img.convert("RGB")
110
- arr = np.array(img)
111
- return arr
113
+ def save_video(images, filename, fps=20, **kwargs):
114
+ """Saves a sequence of images to a video file.
115
+
116
+ Supported file types: mp4, gif.
117
+
118
+ Args:
119
+ images (list or np.ndarray): List or array of images. Must have shape
120
+ (n_frames, height, width, channels) or (n_frames, height, width).
121
+ If channel axis is not present, or is 1, grayscale image is assumed,
122
+ which is then converted to RGB. Images should be uint8.
123
+ filename (str or Path): Filename to which data should be written.
124
+ fps (int): Frames per second of rendered format.
125
+ **kwargs: Additional keyword arguments passed to the specific save function.
126
+ For GIF files, this includes `shared_color_palette` (bool).
127
+
128
+ Raises:
129
+ ValueError: If the file extension is not supported.
130
+
131
+ """
132
+ filename = Path(filename)
133
+ ext = filename.suffix.lower()
134
+
135
+ if ext == ".mp4":
136
+ return save_to_mp4(images, filename, fps=fps)
137
+ elif ext == ".gif":
138
+ return save_to_gif(images, filename, fps=fps, **kwargs)
139
+ else:
140
+ raise ValueError(f"Unsupported file extension: {ext}")
141
+
142
+
143
+ def save_to_gif(images, filename, fps=20, shared_color_palette=False):
144
+ """Saves a sequence of images to a GIF file.
145
+
146
+ .. note::
147
+ It's recommended to use :func:`save_video` for a more general interface
148
+ that supports multiple formats.
149
+
150
+ Args:
151
+ images (list or np.ndarray): List or array of images. Must have shape
152
+ (n_frames, height, width, channels) or (n_frames, height, width).
153
+ If channel axis is not present, or is 1, grayscale image is assumed,
154
+ which is then converted to RGB. Images should be uint8.
155
+ filename (str or Path): Filename to which data should be written.
156
+ fps (int): Frames per second of rendered format.
157
+ shared_color_palette (bool, optional): If True, creates a global
158
+ color palette across all frames, ensuring consistent colors
159
+ throughout the GIF. Defaults to False, which is default behavior
160
+ of PIL.Image.save. Note: True can cause slow saving for longer
161
+ sequences, and also lead to larger file sizes in some cases.
162
+
163
+ """
164
+ images = preprocess_for_saving(images)
165
+
166
+ if fps > 50:
167
+ log.warning(f"Cannot set fps ({fps}) > 50. Setting it automatically to 50.")
168
+ fps = 50
169
+
170
+ duration = int(round(1000 / fps)) # milliseconds per frame
171
+
172
+ pillow_imgs = [Image.fromarray(img) for img in images]
173
+
174
+ if shared_color_palette:
175
+ # Apply the same palette to all frames without dithering for consistent color mapping
176
+ # Convert all images to RGB and combine their colors for palette generation
177
+ all_colors = np.vstack([np.array(img.convert("RGB")).reshape(-1, 3) for img in pillow_imgs])
178
+ combined_image = Image.fromarray(all_colors.reshape(-1, 1, 3))
179
+
180
+ # Generate palette from all frames
181
+ global_palette = combined_image.quantize(
182
+ colors=256,
183
+ method=Image.MEDIANCUT,
184
+ kmeans=1,
185
+ )
186
+
187
+ # Apply the same palette to all frames without dithering
188
+ pillow_imgs = [
189
+ img.convert("RGB").quantize(
190
+ palette=global_palette,
191
+ dither=Image.NONE,
192
+ )
193
+ for img in pillow_imgs
194
+ ]
195
+
196
+ pillow_img, *pillow_imgs = pillow_imgs
197
+
198
+ pillow_img.save(
199
+ fp=filename,
200
+ format="GIF",
201
+ append_images=pillow_imgs,
202
+ save_all=True,
203
+ loop=0,
204
+ duration=duration,
205
+ interlace=False,
206
+ optimize=False,
207
+ )
208
+ log.success(f"Successfully saved GIF to -> {log.yellow(filename)}")
209
+
210
+
211
+ def save_to_mp4(images, filename, fps=20):
212
+ """Saves a sequence of images to an MP4 file.
213
+
214
+ .. note::
215
+ It's recommended to use :func:`save_video` for a more general interface
216
+ that supports multiple formats.
217
+
218
+ Args:
219
+ images (list or np.ndarray): List or array of images. Must have shape
220
+ (n_frames, height, width, channels) or (n_frames, height, width).
221
+ If channel axis is not present, or is 1, grayscale image is assumed,
222
+ which is then converted to RGB. Images should be uint8.
223
+ filename (str or Path): Filename to which data should be written.
224
+ fps (int): Frames per second of rendered format.
225
+
226
+ Raises:
227
+ ImportError: If imageio-ffmpeg is not installed.
228
+
229
+ Returns:
230
+ str: Success message.
231
+
232
+ """
233
+ images = preprocess_for_saving(images)
234
+
235
+ filename = str(filename)
236
+
237
+ parent_dir = Path(filename).parent
238
+ parent_dir.mkdir(parents=True, exist_ok=True)
239
+
240
+ # Use imageio with FFMPEG format for MP4 files
241
+ try:
242
+ writer = imageio.get_writer(
243
+ filename, fps=fps, format="FFMPEG", codec="libx264", pixelformat="yuv420p"
244
+ )
245
+ except (ImportError, ValueError) as exc:
246
+ raise ImportError(
247
+ "FFMPEG plugin is required to save MP4 files. "
248
+ "Please install it with 'pip install imageio-ffmpeg'."
249
+ ) from exc
250
+
251
+ try:
252
+ for image in images:
253
+ writer.append_data(image)
254
+ finally:
255
+ writer.close()
256
+
257
+ return log.success(f"Successfully saved MP4 to -> {filename}")
112
258
 
113
259
 
114
260
  def search_file_tree(
@@ -341,3 +487,85 @@ def retry_on_io_error(max_retries=3, initial_delay=0.5, retry_action=None):
341
487
  return wrapper
342
488
 
343
489
  return decorator
490
+
491
+
492
+ def _convert_image_mode(img, mode="L"):
493
+ """Convert a PIL Image to the specified mode and return as numpy array."""
494
+ if mode not in {"L", "RGB"}:
495
+ raise ValueError(f"Unsupported mode: {mode}, must be one of: L, RGB")
496
+ if mode == "L":
497
+ img = img.convert("L")
498
+ arr = np.array(img)
499
+ elif mode == "RGB":
500
+ img = img.convert("RGB")
501
+ arr = np.array(img)
502
+ return arr
503
+
504
+
505
+ def grayscale_to_rgb(image):
506
+ """Converts a grayscale image to an RGB image.
507
+
508
+ Args:
509
+ image (ndarray): Grayscale image. Must have shape (height, width).
510
+
511
+ Returns:
512
+ ndarray: RGB image.
513
+ """
514
+ assert image.ndim == 2, "Input image must be grayscale."
515
+ # Stack the grayscale image into 3 channels (RGB)
516
+ return np.stack([image] * 3, axis=-1)
517
+
518
+
519
+ def _assert_uint8_images(images: np.ndarray):
520
+ """
521
+ Asserts that the input images have the correct properties.
522
+
523
+ Args:
524
+ images (np.ndarray): The input images.
525
+
526
+ Raises:
527
+ AssertionError: If the dtype of images is not uint8.
528
+ AssertionError: If the shape of images is not (n_frames, height, width, channels)
529
+ or (n_frames, height, width) for grayscale images.
530
+ AssertionError: If images have anything other than 1 (grayscale),
531
+ 3 (rgb) or 4 (rgba) channels.
532
+ """
533
+ assert images.dtype == np.uint8, f"dtype of images should be uint8, got {images.dtype}"
534
+
535
+ assert images.ndim in (3, 4), (
536
+ "images must have shape (n_frames, height, width, channels),"
537
+ f" or (n_frames, height, width) for grayscale images. Got {images.shape}"
538
+ )
539
+
540
+ if images.ndim == 4:
541
+ assert images.shape[-1] in (1, 3, 4), (
542
+ "Grayscale images must have 1 channel, "
543
+ "RGB images must have 3 channels, and RGBA images must have 4 channels. "
544
+ f"Got shape: {images.shape}, channels: {images.shape[-1]}"
545
+ )
546
+
547
+
548
+ def preprocess_for_saving(images):
549
+ """Preprocesses images for saving to GIF or MP4.
550
+
551
+ Args:
552
+ images (ndarray, list[ndarray]): Images. Must have shape (n_frames, height, width, channels)
553
+ or (n_frames, height, width).
554
+ """
555
+ images = np.array(images)
556
+ _assert_uint8_images(images)
557
+
558
+ # Remove channel axis if it is 1 (grayscale image)
559
+ if images.ndim == 4 and images.shape[-1] == 1:
560
+ images = np.squeeze(images, axis=-1)
561
+
562
+ # convert grayscale images to RGB
563
+ if images.ndim == 3:
564
+ images = [grayscale_to_rgb(image) for image in images]
565
+ images = np.array(images)
566
+
567
+ # drop alpha channel if present (RGBA -> RGB)
568
+ if images.ndim == 4 and images.shape[-1] == 4:
569
+ images = images[..., :3]
570
+
571
+ return images
zea/keras_ops.py CHANGED
@@ -3,14 +3,14 @@ and :mod:`keras.ops.image` functions.
3
3
 
4
4
  They can be used in zea pipelines like any other :class:`zea.Operation`, for example:
5
5
 
6
- .. code-block:: python
6
+ .. doctest::
7
7
 
8
- from zea.keras_ops import Squeeze
8
+ >>> from zea.keras_ops import Squeeze
9
9
 
10
- op = Squeeze(axis=1)
10
+ >>> op = Squeeze(axis=1)
11
11
 
12
12
  This file is generated automatically. Do not edit manually.
13
- Generated with Keras 3.11.3
13
+ Generated with Keras 3.12.0
14
14
  """
15
15
 
16
16
  import keras
@@ -388,6 +388,16 @@ class Cholesky(Lambda):
388
388
  except AttributeError as e:
389
389
  raise MissingKerasOps("Cholesky", "keras.ops.cholesky") from e
390
390
 
391
+ @ops_registry("keras.ops.cholesky_inverse")
392
+ class CholeskyInverse(Lambda):
393
+ """Operation wrapping keras.ops.cholesky_inverse."""
394
+
395
+ def __init__(self, **kwargs):
396
+ try:
397
+ super().__init__(func=keras.ops.cholesky_inverse, **kwargs)
398
+ except AttributeError as e:
399
+ raise MissingKerasOps("CholeskyInverse", "keras.ops.cholesky_inverse") from e
400
+
391
401
  @ops_registry("keras.ops.clip")
392
402
  class Clip(Lambda):
393
403
  """Operation wrapping keras.ops.clip."""
@@ -918,6 +928,36 @@ class Isnan(Lambda):
918
928
  except AttributeError as e:
919
929
  raise MissingKerasOps("Isnan", "keras.ops.isnan") from e
920
930
 
931
+ @ops_registry("keras.ops.isneginf")
932
+ class Isneginf(Lambda):
933
+ """Operation wrapping keras.ops.isneginf."""
934
+
935
+ def __init__(self, **kwargs):
936
+ try:
937
+ super().__init__(func=keras.ops.isneginf, **kwargs)
938
+ except AttributeError as e:
939
+ raise MissingKerasOps("Isneginf", "keras.ops.isneginf") from e
940
+
941
+ @ops_registry("keras.ops.isposinf")
942
+ class Isposinf(Lambda):
943
+ """Operation wrapping keras.ops.isposinf."""
944
+
945
+ def __init__(self, **kwargs):
946
+ try:
947
+ super().__init__(func=keras.ops.isposinf, **kwargs)
948
+ except AttributeError as e:
949
+ raise MissingKerasOps("Isposinf", "keras.ops.isposinf") from e
950
+
951
+ @ops_registry("keras.ops.isreal")
952
+ class Isreal(Lambda):
953
+ """Operation wrapping keras.ops.isreal."""
954
+
955
+ def __init__(self, **kwargs):
956
+ try:
957
+ super().__init__(func=keras.ops.isreal, **kwargs)
958
+ except AttributeError as e:
959
+ raise MissingKerasOps("Isreal", "keras.ops.isreal") from e
960
+
921
961
  @ops_registry("keras.ops.istft")
922
962
  class Istft(Lambda):
923
963
  """Operation wrapping keras.ops.istft."""
@@ -1828,6 +1868,16 @@ class Trunc(Lambda):
1828
1868
  except AttributeError as e:
1829
1869
  raise MissingKerasOps("Trunc", "keras.ops.trunc") from e
1830
1870
 
1871
+ @ops_registry("keras.ops.unfold")
1872
+ class Unfold(Lambda):
1873
+ """Operation wrapping keras.ops.unfold."""
1874
+
1875
+ def __init__(self, **kwargs):
1876
+ try:
1877
+ super().__init__(func=keras.ops.unfold, **kwargs)
1878
+ except AttributeError as e:
1879
+ raise MissingKerasOps("Unfold", "keras.ops.unfold") from e
1880
+
1831
1881
  @ops_registry("keras.ops.unstack")
1832
1882
  class Unstack(Lambda):
1833
1883
  """Operation wrapping keras.ops.unstack."""
@@ -1848,6 +1898,16 @@ class Var(Lambda):
1848
1898
  except AttributeError as e:
1849
1899
  raise MissingKerasOps("Var", "keras.ops.var") from e
1850
1900
 
1901
+ @ops_registry("keras.ops.view")
1902
+ class View(Lambda):
1903
+ """Operation wrapping keras.ops.view."""
1904
+
1905
+ def __init__(self, **kwargs):
1906
+ try:
1907
+ super().__init__(func=keras.ops.view, **kwargs)
1908
+ except AttributeError as e:
1909
+ raise MissingKerasOps("View", "keras.ops.view") from e
1910
+
1851
1911
  @ops_registry("keras.ops.view_as_complex")
1852
1912
  class ViewAsComplex(Lambda):
1853
1913
  """Operation wrapping keras.ops.view_as_complex."""
@@ -1987,3 +2047,13 @@ class RgbToHsv(Lambda):
1987
2047
  super().__init__(func=keras.ops.image.rgb_to_hsv, **kwargs)
1988
2048
  except AttributeError as e:
1989
2049
  raise MissingKerasOps("RgbToHsv", "keras.ops.image.rgb_to_hsv") from e
2050
+
2051
+ @ops_registry("keras.ops.image.scale_and_translate")
2052
+ class ScaleAndTranslate(Lambda):
2053
+ """Operation wrapping keras.ops.image.scale_and_translate."""
2054
+
2055
+ def __init__(self, **kwargs):
2056
+ try:
2057
+ super().__init__(func=keras.ops.image.scale_and_translate, **kwargs)
2058
+ except AttributeError as e:
2059
+ raise MissingKerasOps("ScaleAndTranslate", "keras.ops.image.scale_and_translate") from e
zea/log.py CHANGED
@@ -6,7 +6,7 @@ to the console and to a file with color support.
6
6
  Example usage
7
7
  ^^^^^^^^^^^^^^
8
8
 
9
- .. code-block:: python
9
+ .. testsetup::
10
10
 
11
11
  from zea import log
12
12
 
@@ -323,18 +323,20 @@ def set_level(level):
323
323
 
324
324
  Also sets the log level for the file logger if it exists.
325
325
 
326
- Example:
327
- >>> from zea import log
328
- >>> with log.set_level("WARNING"):
329
- ... log.info("This will not be shown")
330
- ... log.warning("This will be shown")
331
-
332
326
  Args:
333
327
  level (str or int): The log level to set temporarily
334
328
  (e.g., "DEBUG", "INFO", logging.WARNING).
335
329
 
336
330
  Yields:
337
331
  None
332
+
333
+ Example:
334
+ .. doctest::
335
+
336
+ >>> from zea import log
337
+ >>> with log.set_level("ERROR"):
338
+ ... _ = log.info("Info messages will not be shown")
339
+ ... _ = log.error("Error messages will be shown")
338
340
  """
339
341
  prev_level = logger.level
340
342
  prev_file_level = file_logger.level if file_logger else None