zea 0.0.6__py3-none-any.whl → 0.0.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (79)
  1. zea/__init__.py +54 -19
  2. zea/agent/__init__.py +12 -12
  3. zea/agent/masks.py +2 -1
  4. zea/backend/tensorflow/dataloader.py +2 -5
  5. zea/beamform/beamformer.py +100 -50
  6. zea/beamform/lens_correction.py +9 -2
  7. zea/beamform/pfield.py +9 -2
  8. zea/beamform/pixelgrid.py +1 -1
  9. zea/config.py +34 -25
  10. zea/data/__init__.py +22 -25
  11. zea/data/augmentations.py +221 -28
  12. zea/data/convert/__init__.py +1 -6
  13. zea/data/convert/__main__.py +123 -0
  14. zea/data/convert/camus.py +101 -40
  15. zea/data/convert/echonet.py +187 -86
  16. zea/data/convert/echonetlvh/README.md +2 -3
  17. zea/data/convert/echonetlvh/{convert_raw_to_usbmd.py → __init__.py} +174 -103
  18. zea/data/convert/echonetlvh/manual_rejections.txt +73 -0
  19. zea/data/convert/echonetlvh/precompute_crop.py +43 -64
  20. zea/data/convert/picmus.py +37 -40
  21. zea/data/convert/utils.py +86 -0
  22. zea/data/convert/{matlab.py → verasonics.py} +44 -65
  23. zea/data/data_format.py +155 -34
  24. zea/data/dataloader.py +12 -7
  25. zea/data/datasets.py +112 -71
  26. zea/data/file.py +184 -73
  27. zea/data/file_operations.py +496 -0
  28. zea/data/layers.py +3 -3
  29. zea/data/preset_utils.py +1 -1
  30. zea/datapaths.py +16 -4
  31. zea/display.py +14 -13
  32. zea/interface.py +14 -16
  33. zea/internal/_generate_keras_ops.py +6 -7
  34. zea/internal/cache.py +2 -49
  35. zea/internal/checks.py +6 -12
  36. zea/internal/config/validation.py +1 -2
  37. zea/internal/core.py +69 -6
  38. zea/internal/device.py +6 -2
  39. zea/internal/dummy_scan.py +330 -0
  40. zea/internal/operators.py +118 -2
  41. zea/internal/parameters.py +101 -70
  42. zea/internal/setup_zea.py +5 -6
  43. zea/internal/utils.py +282 -0
  44. zea/io_lib.py +322 -146
  45. zea/keras_ops.py +74 -4
  46. zea/log.py +9 -7
  47. zea/metrics.py +15 -7
  48. zea/models/__init__.py +31 -21
  49. zea/models/base.py +30 -14
  50. zea/models/carotid_segmenter.py +19 -4
  51. zea/models/diffusion.py +235 -23
  52. zea/models/echonet.py +22 -8
  53. zea/models/echonetlvh.py +31 -7
  54. zea/models/lpips.py +19 -2
  55. zea/models/lv_segmentation.py +30 -11
  56. zea/models/preset_utils.py +5 -5
  57. zea/models/regional_quality.py +30 -10
  58. zea/models/taesd.py +21 -5
  59. zea/models/unet.py +15 -1
  60. zea/ops.py +770 -336
  61. zea/probes.py +6 -6
  62. zea/scan.py +121 -51
  63. zea/simulator.py +24 -21
  64. zea/tensor_ops.py +477 -353
  65. zea/tools/fit_scan_cone.py +90 -160
  66. zea/tools/hf.py +1 -1
  67. zea/tools/selection_tool.py +47 -86
  68. zea/tracking/__init__.py +16 -0
  69. zea/tracking/base.py +94 -0
  70. zea/tracking/lucas_kanade.py +474 -0
  71. zea/tracking/segmentation.py +110 -0
  72. zea/utils.py +101 -480
  73. zea/visualize.py +177 -39
  74. {zea-0.0.6.dist-info → zea-0.0.8.dist-info}/METADATA +6 -2
  75. zea-0.0.8.dist-info/RECORD +122 -0
  76. zea-0.0.6.dist-info/RECORD +0 -112
  77. {zea-0.0.6.dist-info → zea-0.0.8.dist-info}/WHEEL +0 -0
  78. {zea-0.0.6.dist-info → zea-0.0.8.dist-info}/entry_points.txt +0 -0
  79. {zea-0.0.6.dist-info → zea-0.0.8.dist-info}/licenses/LICENSE +0 -0
zea/interface.py CHANGED
@@ -3,15 +3,15 @@
 Example usage
 ^^^^^^^^^^^^^^
 
-.. code-block:: python
+.. doctest::
 
-    import zea
-    from zea.internal.setup_zea import setup_config
+    >>> import zea
+    >>> from zea.internal.setup_zea import setup_config
 
-    config = setup_config("hf://zeahub/configs/config_camus.yaml")
+    >>> config = setup_config("hf://zeahub/configs/config_camus.yaml")
 
-    interface = zea.Interface(config)
-    interface.run(plot=True)
+    >>> interface = zea.Interface(config)
+    >>> interface.run(plot=True)  # doctest: +SKIP
 
 """
 
@@ -31,15 +31,15 @@ from zea.data.file import File
 from zea.datapaths import format_data_path
 from zea.display import to_8bit
 from zea.internal.core import DataTypes
+from zea.internal.utils import keep_trying
 from zea.internal.viewer import (
     ImageViewerMatplotlib,
     ImageViewerOpenCV,
     filename_from_window_dialog,
     running_in_notebook,
 )
-from zea.io_lib import matplotlib_figure_to_numpy
+from zea.io_lib import matplotlib_figure_to_numpy, save_video
 from zea.ops import Pipeline
-from zea.utils import keep_trying, save_to_gif, save_to_mp4
 
 
 class Interface:
@@ -266,10 +266,11 @@ class Interface:
         save = self.config.plot.save
 
         if self.frame_no == "all":
-            if not asyncio.get_event_loop().is_running():
-                asyncio.run(self.run_movie(save))
-            else:
-                asyncio.create_task(self.run_movie(save))
+            try:
+                loop = asyncio.get_running_loop()
+                loop.create_task(self.run_movie(save))  # already running loop
+            except RuntimeError:
+                asyncio.run(self.run_movie(save))  # no loop yet
 
         else:
             if plot:
@@ -520,10 +521,7 @@ class Interface:
 
         fps = self.config.plot.fps
 
-        if self.config.plot.video_extension == "gif":
-            save_to_gif(images, path, fps=fps)
-        elif self.config.plot.video_extension == "mp4":
-            save_to_mp4(images, path, fps=fps)
+        save_video(images, path, fps=fps)
 
         if self.verbose:
             log.info(f"Video saved to {log.yellow(path)}")
zea/internal/_generate_keras_ops.py CHANGED
@@ -3,11 +3,10 @@ and :mod:`keras.ops.image` functions.
 
 They can be used in zea pipelines like any other :class:`zea.Operation`, for example:
 
-.. code-block:: python
+.. doctest::
 
-    from zea.keras_ops import Squeeze
-
-    op = Squeeze(axis=1)
+    >>> from zea.keras_ops import Squeeze
+    >>> op = Squeeze(axis=1)
 """
 
 import inspect
@@ -77,11 +76,11 @@ and :mod:`keras.ops.image` functions.
 
 They can be used in zea pipelines like any other :class:`zea.Operation`, for example:
 
-.. code-block:: python
+.. doctest::
 
-    from zea.keras_ops import Squeeze
+    >>> from zea.keras_ops import Squeeze
 
-    op = Squeeze(axis=1)
+    >>> op = Squeeze(axis=1)
 
 This file is generated automatically. Do not edit manually.
 Generated with Keras {keras.__version__}
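Both hunks convert the usage example from a passive `.. code-block:: python` into an executable `.. doctest::` block, so the snippet can be verified instead of being documentation-only. A hedged way to run such embedded examples locally (assumes zea and a Keras backend are installed; the project's actual doc/test tooling is not shown in this diff):

import doctest

import zea.keras_ops

# Execute the >>> examples in the module docstring; directives such as
# "# doctest: +SKIP" are honoured by the standard doctest module.
results = doctest.testmod(zea.keras_ops, verbose=False)
print(results)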
zea/internal/cache.py CHANGED
@@ -21,10 +21,8 @@
 
 import ast
 import atexit
-import hashlib
 import inspect
 import os
-import pickle
 import tempfile
 import textwrap
 from pathlib import Path
@@ -33,6 +31,7 @@ import joblib
 import keras
 
 from zea import log
+from zea.internal.core import hash_elements
 
 _DEFAULT_ZEA_CACHE_DIR = Path.home() / ".cache" / "zea"
 
@@ -80,52 +79,6 @@ _CACHE_DIR = ZEA_CACHE_DIR / "cached_funcs"
 _CACHE_DIR.mkdir(parents=True, exist_ok=True)
 
 
-def serialize_elements(key_elements: list, shorten: bool = False) -> str:
-    """Serialize elements of a list to generate a cache key.
-
-    In general uses the string representation of the elements unless
-    the element has a `serialized` attribute, in which case it uses that.
-    For instance this is useful for custom classes that inherit from `zea.core.Object`.
-
-    Args:
-        key_elements (list): List of elements to serialize. Can be nested lists
-            or tuples. In this case the elements are serialized recursively.
-        shorten (bool): If True, the serialized string is hashed to a shorter
-            representation using MD5. Defaults to False.
-
-    Returns:
-        str: A serialized string representation of the elements, joined by underscores.
-
-    """
-    serialized_elements = []
-    for element in key_elements:
-        if isinstance(element, (list, tuple)):
-            # If element is a list or tuple, serialize its elements recursively
-            serialized_elements.append(serialize_elements(element))
-        elif hasattr(element, "serialized"):
-            # Use the serialized attribute if it exists (e.g. for zea.core.Object)
-            serialized_elements.append(str(element.serialized))
-        elif isinstance(element, str):
-            # If element is a string, use it as is
-            serialized_elements.append(element)
-        elif isinstance(element, keras.random.SeedGenerator):
-            # If element is a SeedGenerator, use the state
-            element = keras.ops.convert_to_numpy(element.state.value)
-            element = pickle.dumps(element)
-            element = hashlib.md5(element).hexdigest()
-            serialized_elements.append(element)
-        else:
-            # Otherwise, serialize the element using pickle and hash it
-            element = pickle.dumps(element)
-            element = hashlib.md5(element).hexdigest()
-            serialized_elements.append(element)
-
-    serialized = "_".join(serialized_elements)
-    if shorten:
-        return hashlib.md5(serialized.encode()).hexdigest()
-    return serialized
-
-
 def get_function_source(func):
     """Recursively get the source code of a function and its nested functions."""
     try:
@@ -188,7 +141,7 @@ def generate_cache_key(func, args, kwargs, arg_names):
     # Add keras backend
     key_elements.append(keras.backend.backend())
 
-    return f"{func.__qualname__}_" + serialize_elements(key_elements, shorten=True)
+    return f"{func.__qualname__}_" + hash_elements(key_elements)
 
 
 def cache_output(*arg_names, verbose=False):
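With the local `serialize_elements` removed, cache keys are now built from `zea.internal.core.hash_elements` (added further down in this diff). A rough sketch of the resulting key format; the element list below is purely illustrative:

from zea.internal.core import hash_elements

# Illustrative key elements: function source, positional args, kwargs, backend name.
key_elements = ["def my_func(x): ...", (1, 2.5), {"sigma": 0.5}, "tensorflow"]
cache_key = "my_func_" + hash_elements(key_elements)
print(cache_key)  # "my_func_" followed by a 32-character MD5 hex digest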
zea/internal/checks.py CHANGED
@@ -64,8 +64,7 @@ def _check_raw_data(data=None, shape=None, with_batch_dim=None):
         shape (tuple, optional): shape of the data. Defaults to None.
             either data or shape must be provided.
         with_batch_dim (bool, optional): whether data has frame dimension at the start.
-            Setting this to True requires the data to have 5 dimensions. Defaults to
-            False.
+            Setting this to True requires the data to have 5 dimensions. Defaults to None.
 
     Raises:
         AssertionError: if data does not have expected shape
@@ -105,8 +104,7 @@ def _check_aligned_data(data=None, shape=None, with_batch_dim=None):
         shape (tuple, optional): shape of the data. Defaults to None.
             either data or shape must be provided.
         with_batch_dim (bool, optional): whether data has frame dimension at the start.
-            Setting this to True requires the data to have 5 dimensions. Defaults to
-            False.
+            Setting this to True requires the data to have 5 dimensions. Defaults to None.
 
     Raises:
         AssertionError: if data does not have expected shape
@@ -147,8 +145,7 @@ def _check_beamformed_data(data=None, shape=None, with_batch_dim=None):
         shape (tuple, optional): shape of the data. Defaults to None.
             either data or shape must be provided.
         with_batch_dim (bool, optional): whether data has frame dimension at the start.
-            Setting this to True requires the data to have 4 dimensions. Defaults to
-            False.
+            Setting this to True requires the data to have 4 dimensions. Defaults to None.
 
     Raises:
         AssertionError: if data does not have expected shape
@@ -190,8 +187,7 @@ def _check_envelope_data(data=None, shape=None, with_batch_dim=None):
         shape (tuple, optional): shape of the data. Defaults to None.
             either data or shape must be provided.
         with_batch_dim (bool, optional): whether data has frame dimension at the start.
-            Setting this to True requires the data to have 4 dimensions. Defaults to
-            False.
+            Setting this to True requires the data to have 3 dimensions. Defaults to None.
 
     Raises:
         AssertionError: if data does not have expected shape
@@ -227,8 +223,7 @@ def _check_image(data=None, shape=None, with_batch_dim=None):
         shape (tuple, optional): shape of the data. Defaults to None.
             either data or shape must be provided.
         with_batch_dim (bool, optional): whether data has frame dimension at the start.
-            Setting this to True requires the data to have 4 dimensions. Defaults to
-            False.
+            Setting this to True requires the data to have 3 dimensions. Defaults to None.
 
     Raises:
         AssertionError: if data does not have expected shape.
@@ -264,8 +259,7 @@ def _check_image_sc(data=None, shape=None, with_batch_dim=None):
         shape (tuple, optional): shape of the data. Defaults to None.
            either data or shape must be provided.
         with_batch_dim (bool, optional): whether data has frame dimension at the start.
-            Setting this to True requires the data to have 4 dimensions. Defaults to
-            False.
+            Setting this to True requires the data to have 3 dimensions. Defaults to None.
 
     Raises:
         AssertionError: if data does not have expected shape.
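These docstring fixes pin down the per-data-type dimensionality: 5 dims for raw/aligned data, 4 for beamformed data, and 3 for envelope, image, and scan-converted image data when the leading frame dimension is present. An illustrative call, assuming the checkers accept a `shape` tuple as documented; the exact meaning of each axis is not spelled out in these hunks:

from zea.internal.checks import _check_image

# With a leading frame dimension the image check expects 3 dims ...
_check_image(shape=(10, 512, 512), with_batch_dim=True)
# ... and, by implication of the docstring, 2 dims without one (assumption).
_check_image(shape=(512, 512), with_batch_dim=False)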
zea/internal/config/validation.py CHANGED
@@ -15,9 +15,8 @@ from pathlib import Path
 
 from schema import And, Optional, Or, Regex, Schema
 
-import zea.metrics  # noqa: F401
 from zea.internal.checks import _DATA_TYPES
-from zea.internal.registry import metrics_registry
+from zea.metrics import metrics_registry
 
 # predefined checks, later used in schema to check validity of parameter
 any_number = Or(
zea/internal/core.py CHANGED
@@ -1,6 +1,7 @@
 """Base classes for the toolbox"""
 
 import enum
+import hashlib
 import json
 import pickle
 from copy import deepcopy
@@ -8,7 +9,8 @@ from copy import deepcopy
 import keras
 import numpy as np
 
-from zea.utils import reduce_to_signature, update_dictionary
+from zea.internal.utils import reduce_to_signature
+from zea.utils import update_dictionary
 
 CONVERT_TO_KERAS_TYPES = (np.ndarray, int, float, list, tuple, bool)
 BASE_FLOAT_PRECISION = "float32"
@@ -76,7 +78,7 @@ class Object:
         attributes.pop(
             "_serialized", None
         )  # Remove the cached serialized attribute to avoid recursion
-        self._serialized = pickle.dumps(attributes)
+        self._serialized = serialize_elements([attributes])
         return self._serialized
 
     def __setattr__(self, name: str, value):
@@ -167,9 +169,7 @@ def _skip_to_tensor(value):
     # Skip str (because JIT does not support it)
    # Skip methods and functions
     # Skip byte strings
-    if isinstance(value, str) or callable(value) or isinstance(value, bytes):
-        return True
-    return False
+    return isinstance(value, str) or callable(value) or isinstance(value, bytes)
 
 
 def dict_to_tensor(dictionary, keep_as_is=None):
@@ -184,8 +184,9 @@ def dict_to_tensor(dictionary, keep_as_is=None):
         # Get the value from the dictionary
         value = dictionary[key]
 
-        if isinstance(value, Object):
+        if isinstance(value, Object) and hasattr(value, "to_tensor"):
             snapshot[key] = value.to_tensor(keep_as_is=keep_as_is)
+            continue
 
         # Skip certain types
         if _skip_to_tensor(value):
@@ -288,3 +289,65 @@ class ZEADecoderJSON(json.JSONDecoder):
                 obj[key] = self._MOD_TYPES_MAP[value] if value is not None else None
 
         return obj
+
+
+def serialize_elements(key_elements: list) -> str:
+    """Serialize elements of a list to a string.
+
+    Generally, uses the pickle representation of the elements.
+
+    Args:
+        key_elements (list): List of elements to serialize. Can be nested lists
+            or tuples. In this case the elements are serialized recursively.
+
+    Returns:
+        str: A serialized string representation of the elements, joined by underscores.
+    """
+
+    def _serialize(element) -> str:
+        return pickle.dumps(element).hex()
+
+    def _serialize_element(element) -> str:
+        if isinstance(element, (list, tuple)):
+            # If element is a list or tuple, serialize its elements recursively
+            element = serialize_elements(element)
+        elif isinstance(element, Object) and hasattr(element, "serialized"):
+            # Use the serialized attribute if it exists
+            element = str(element.serialized)
+        elif isinstance(element, keras.random.SeedGenerator):
+            # If element is a SeedGenerator, use the state
+            element = keras.ops.convert_to_numpy(element.state.value)
+            element = _serialize(element)
+        elif isinstance(element, dict):
+            # If element is a dictionary, sort its keys and serialize its values recursively.
+            # This is needed to ensure the internal state and ordering of the dictionary does
+            # not affect the serialization.
+            keys = list(sorted(element.keys()))
+            values = [element[k] for k in keys]
+            keys = serialize_elements(keys)
+            values = serialize_elements(values)
+            element = f"k_{keys}_v_{values}"
+        else:
+            # Otherwise, serialize the element directly
+            element = _serialize(element)
+
+        return element
+
+    serialized_elements = []
+    for element in key_elements:
+        serialized_elements.append(_serialize_element(element))
+
+    return "_".join(serialized_elements)
+
+
+def hash_elements(key_elements: list) -> str:
+    """Generate an MD5 hash of the elements.
+
+    Args:
+        key_elements (list): List of elements to serialize and hash.
+
+    Returns:
+        str: An MD5 hash of the serialized elements.
+    """
+    serialized = serialize_elements(key_elements)
+    return hashlib.md5(serialized.encode()).hexdigest()
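A small sketch of the behaviour the new helpers are designed for: dictionary key order no longer affects the serialized string, and `hash_elements` condenses it into a stable MD5 key (toy values only):

from zea.internal.core import hash_elements, serialize_elements

a = {"sound_speed": 1540, "center_frequency": 5e6}
b = {"center_frequency": 5e6, "sound_speed": 1540}  # same content, different insertion order

# Keys are sorted before serialization, so both orderings give identical results.
assert serialize_elements([a]) == serialize_elements([b])
assert hash_elements([a]) == hash_elements([b])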
zea/internal/device.py CHANGED
@@ -377,7 +377,11 @@ def init_device(
     allow_preallocate: bool = True,
     verbose: bool = True,
 ):
-    """Selects a GPU or CPU device based on the config.
+    """Automatically selects a GPU or CPU device.
+
+    Useful to call at the start of a script to set the device for
+    tensorflow, jax or pytorch. The function will select a GPU based
+    on available memory, or fall back to CPU if no GPU is available.
 
     Args:
         backend (str): String indicating which backend to use. Can be
@@ -412,7 +416,7 @@ def init_device(
     elif backend in ["numpy", "cpu"]:
         device = "cpu"
     else:
-        raise ValueError(f"Unknown backend ({backend}) in config.")
+        raise ValueError(f"Unknown backend ({backend}).")
 
     # Early exit if device is CPU
     if device == "cpu":
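The reworded docstring frames `init_device` as a helper to call once at the start of a script. A hedged usage sketch; only the `backend`, `allow_preallocate`, and `verbose` parameters are visible in these hunks, and any return value is not shown, so the call below relies on side effects only:

from zea.internal.device import init_device

# "cpu" and "numpy" are the backend values confirmed by this hunk; GPU backends
# (tensorflow/jax/torch) are implied by the docstring but truncated here.
init_device(backend="cpu", verbose=True)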