zea 0.0.5__py3-none-any.whl → 0.0.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68)
  1. zea/__init__.py +54 -19
  2. zea/agent/__init__.py +12 -12
  3. zea/agent/masks.py +2 -1
  4. zea/agent/selection.py +166 -0
  5. zea/backend/__init__.py +89 -0
  6. zea/backend/jax/__init__.py +14 -51
  7. zea/backend/tensorflow/__init__.py +0 -49
  8. zea/backend/tensorflow/dataloader.py +2 -1
  9. zea/backend/torch/__init__.py +27 -62
  10. zea/beamform/beamformer.py +100 -50
  11. zea/beamform/lens_correction.py +9 -2
  12. zea/beamform/pfield.py +9 -2
  13. zea/config.py +34 -25
  14. zea/data/__init__.py +22 -16
  15. zea/data/convert/camus.py +2 -1
  16. zea/data/convert/echonet.py +4 -4
  17. zea/data/convert/echonetlvh/convert_raw_to_usbmd.py +1 -1
  18. zea/data/convert/matlab.py +11 -4
  19. zea/data/data_format.py +31 -30
  20. zea/data/datasets.py +7 -5
  21. zea/data/file.py +104 -2
  22. zea/data/layers.py +5 -6
  23. zea/datapaths.py +16 -4
  24. zea/display.py +7 -5
  25. zea/interface.py +14 -16
  26. zea/internal/_generate_keras_ops.py +6 -7
  27. zea/internal/cache.py +2 -49
  28. zea/internal/config/validation.py +1 -2
  29. zea/internal/core.py +69 -6
  30. zea/internal/device.py +6 -2
  31. zea/internal/dummy_scan.py +330 -0
  32. zea/internal/operators.py +114 -2
  33. zea/internal/parameters.py +101 -70
  34. zea/internal/registry.py +1 -1
  35. zea/internal/setup_zea.py +5 -6
  36. zea/internal/utils.py +282 -0
  37. zea/io_lib.py +247 -19
  38. zea/keras_ops.py +74 -4
  39. zea/log.py +9 -7
  40. zea/metrics.py +365 -65
  41. zea/models/__init__.py +30 -20
  42. zea/models/base.py +30 -14
  43. zea/models/carotid_segmenter.py +19 -4
  44. zea/models/diffusion.py +187 -26
  45. zea/models/echonet.py +22 -8
  46. zea/models/echonetlvh.py +31 -18
  47. zea/models/lpips.py +19 -2
  48. zea/models/lv_segmentation.py +96 -0
  49. zea/models/preset_utils.py +5 -5
  50. zea/models/presets.py +36 -0
  51. zea/models/regional_quality.py +142 -0
  52. zea/models/taesd.py +21 -5
  53. zea/models/unet.py +15 -1
  54. zea/ops.py +414 -207
  55. zea/probes.py +6 -6
  56. zea/scan.py +109 -49
  57. zea/simulator.py +24 -21
  58. zea/tensor_ops.py +411 -206
  59. zea/tools/hf.py +1 -1
  60. zea/tools/selection_tool.py +47 -86
  61. zea/utils.py +92 -480
  62. zea/visualize.py +177 -39
  63. {zea-0.0.5.dist-info → zea-0.0.7.dist-info}/METADATA +9 -3
  64. zea-0.0.7.dist-info/RECORD +114 -0
  65. {zea-0.0.5.dist-info → zea-0.0.7.dist-info}/WHEEL +1 -1
  66. zea-0.0.5.dist-info/RECORD +0 -110
  67. {zea-0.0.5.dist-info → zea-0.0.7.dist-info}/entry_points.txt +0 -0
  68. {zea-0.0.5.dist-info → zea-0.0.7.dist-info/licenses}/LICENSE +0 -0
zea/data/data_format.py CHANGED
@@ -7,11 +7,12 @@ from dataclasses import dataclass
 from pathlib import Path
 
 import numpy as np
+from keras.utils import pad_sequences
 
 from zea import log
 from zea.data.file import File, validate_file
 from zea.internal.checks import _DATA_TYPES
-from zea.utils import first_not_none_item
+from zea.internal.utils import first_not_none_item
 
 
 @dataclass
@@ -470,32 +471,30 @@ def _write_datasets(
     )
 
     if waveforms_one_way is not None:
-        for n in range(len(waveforms_one_way)):
-            _add_dataset(
-                group_name=scan_group_name + "/waveforms_one_way",
-                name=f"waveform_{str(n).zfill(3)}",
-                data=waveforms_one_way[n],
-                description=(
-                    "One-way waveform as simulated by the Verasonics system, "
-                    "sampled at 250MHz. This is the waveform after being filtered "
-                    "by the tranducer bandwidth once."
-                ),
-                unit="V",
-            )
+        _add_dataset(
+            group_name=scan_group_name,
+            name="waveforms_one_way",
+            data=pad_sequences(waveforms_one_way, dtype=np.float32, padding="post"),
+            description=(
+                "One-way waveform as simulated by the Verasonics system, "
+                "sampled at 250MHz. This is the waveform after being filtered "
+                "by the transducer bandwidth once."
+            ),
+            unit="V",
+        )
 
     if waveforms_two_way is not None:
-        for n in range(len(waveforms_two_way)):
-            _add_dataset(
-                group_name=scan_group_name + "/waveforms_two_way",
-                name=f"waveform_{str(n).zfill(3)}",
-                data=waveforms_two_way[n],
-                description=(
-                    "Two-way waveform as simulated by the Verasonics system, "
-                    "sampled at 250MHz. This is the waveform after being filtered "
-                    "by the tranducer bandwidth twice."
-                ),
-                unit="V",
-            )
+        _add_dataset(
+            group_name=scan_group_name,
+            name="waveforms_two_way",
+            data=pad_sequences(waveforms_two_way, dtype=np.float32, padding="post"),
+            description=(
+                "Two-way waveform as simulated by the Verasonics system, "
+                "sampled at 250MHz. This is the waveform after being filtered "
+                "by the transducer bandwidth twice."
+            ),
+            unit="V",
+        )
 
     # Add additional elements
     if additional_elements is not None:
@@ -539,6 +538,7 @@ def generate_zea_dataset(
     additional_elements=None,
     event_structure=False,
     cast_to_float=True,
+    overwrite=False,
 ):
     """Generates a dataset in the zea format.
 
@@ -585,10 +585,10 @@
             waveform was used for each transmit event.
         waveforms_one_way (list): List of one-way waveforms as simulated by the Verasonics
             system, sampled at 250MHz. This is the waveform after being filtered by the
-            tranducer bandwidth once. Every element in the list is a 1D numpy array.
+            transducer bandwidth once. Every element in the list is a 1D numpy array.
         waveforms_two_way (list): List of two-way waveforms as simulated by the Verasonics
             system, sampled at 250MHz. This is the waveform after being filtered by the
-            tranducer bandwidth twice. Every element in the list is a 1D numpy array.
+            transducer bandwidth twice. Every element in the list is a 1D numpy array.
         additional_elements (List[DatasetElement]): A list of additional dataset
            elements to be added to the dataset. Each element should be a DatasetElement
            object. The additional elements are added under the scan group.
@@ -598,6 +598,7 @@
            Instead of just a single data and scan group.
        cast_to_float (bool): Whether to store data as float32. You may want to set this
            to False if storing images.
+       overwrite (bool): Whether to overwrite the file if it already exists. Defaults to False.
 
    """
    # check if all args are lists
@@ -637,10 +638,10 @@
    # make sure input arguments of func is same length as data_and_parameters
    # except `path` and `event_structure` arguments and ofcourse `data_and_parameters` itself
    assert (
-        len(data_and_parameters) == len(inspect.signature(generate_zea_dataset).parameters) - 3
+        len(data_and_parameters) == len(inspect.signature(generate_zea_dataset).parameters) - 4
    ), (
        "All arguments should be put in data_and_parameters except "
-        "`path`, `event_structure`, and `cast_to_float` arguments."
+        "`path`, `event_structure`, `cast_to_float`, and `overwrite` arguments."
    )
 
    if event_structure:
@@ -682,7 +683,7 @@
    # Convert path to Path object
    path = Path(path)
 
-    if path.exists():
+    if path.exists() and not overwrite:
        raise FileExistsError(f"The file {path} already exists.")
 
    # Create the directory if it does not exist
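Taken together, these hunks change how variable-length per-transmit waveforms are stored: instead of one dataset per transmit under a waveforms_one_way/ group, zea 0.0.7 writes a single (n_tx, n_samples) array, zero-padded at the tail with keras.utils.pad_sequences, and the write is now guarded by the new overwrite flag on generate_zea_dataset. A minimal sketch of the padding behavior (the waveform lengths below are made up for illustration):

    import numpy as np
    from keras.utils import pad_sequences

    # Three hypothetical one-way waveforms of different lengths, one per transmit
    waveforms_one_way = [np.ones(400), np.ones(512), np.ones(480)]

    # zea 0.0.7 stores these as one zero-padded (n_tx, n_samples) dataset
    padded = pad_sequences(waveforms_one_way, dtype=np.float32, padding="post")
    print(padded.shape)  # (3, 512); shorter waveforms get trailing zeros

With overwrite=True, an existing file at path no longer raises FileExistsError; the default (False) preserves the old behavior.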
zea/data/datasets.py CHANGED
@@ -48,13 +48,15 @@ from zea.data.preset_utils import (
     _hf_resolve_path,
 )
 from zea.datapaths import format_data_path
+from zea.internal.utils import (
+    calculate_file_hash,
+    reduce_to_signature,
+)
 from zea.io_lib import search_file_tree
 from zea.tools.hf import HFPath
 from zea.utils import (
-    calculate_file_hash,
     date_string_to_readable,
     get_date_string,
-    reduce_to_signature,
 )
 
 _CHECK_MAX_DATASET_SIZE = 10000
@@ -241,7 +243,7 @@ class Folder:
            return
 
        num_frames_per_file = []
-        validated_succesfully = True
+        validated_successfully = True
        for file_path in tqdm.tqdm(
            self.file_paths,
            total=self.n_files,
@@ -253,9 +255,9 @@
                validation_error_log.append(f"File {file_path} is not a valid zea dataset.\n{e}\n")
                # convert into warning
                log.warning(f"Error in file {file_path}.\n{e}")
-                validated_succesfully = False
+                validated_successfully = False
 
-        if not validated_succesfully:
+        if not validated_successfully:
            log.warning(
                "Check warnings above for details. No validation file was created. "
                f"See {validation_error_file_path} for details."
zea/data/file.py CHANGED
@@ -6,6 +6,7 @@ from typing import List
 
 import h5py
 import numpy as np
+from keras.utils import pad_sequences
 
 from zea import log
 from zea.data.preset_utils import HF_PREFIX, _hf_resolve_path
@@ -15,9 +16,9 @@ from zea.internal.checks import (
     _REQUIRED_SCAN_KEYS,
     get_check,
 )
+from zea.internal.utils import reduce_to_signature
 from zea.probes import Probe
 from zea.scan import Scan
-from zea.utils import reduce_to_signature
 
 
 def assert_key(file: h5py.File, key: str):
@@ -219,6 +220,41 @@ class File(h5py.File):
    def load_data(self, data_type, indices: str | int | List[int] = "all"):
        """Load data from the file.
 
+        The indices parameter can be used to load a subset of the data. This can be
+
+        - 'all' to load all data
+
+        - an int to load a single frame
+
+        - a list of ints to load specific frames
+
+        - a tuple of lists, ranges or slices to index frames and transmits. Note that
+          indexing with lists of indices for both axes is not supported. In that case,
+          try to define one of the axes with a slice.
+
+        .. doctest::
+
+            >>> from zea import File
+
+            >>> path_to_file = (
+            ...     "hf://zeahub/picmus/database/experiments/contrast_speckle/"
+            ...     "contrast_speckle_expe_dataset_iq/contrast_speckle_expe_dataset_iq.hdf5"
+            ... )
+
+            >>> with File(path_to_file, mode="r") as file:
+            ...     # data has shape (n_frames, n_tx, n_el, n_ax, n_ch)
+            ...     data = file.load_data("raw_data")
+            ...     data.shape
+            ...     # load first frame only
+            ...     data = file.load_data("raw_data", indices=0)
+            ...     data.shape
+            ...     # load frame 0 and transmits 0, 2 and 4
+            ...     data = file.load_data("raw_data", indices=(0, [0, 2, 4]))
+            ...     data.shape
+            (1, 75, 832, 128, 2)
+            (75, 832, 128, 2)
+            (3, 832, 128, 2)
+
        Args:
            data_type (str): The type of data to load. Options are 'raw_data', 'aligned_data',
                'beamformed_data', 'envelope_data', 'image' and 'image_sc'.
@@ -337,7 +373,7 @@
        Returns:
            Scan: The scan object.
        """
-        return Scan.merge(self.get_scan_parameters(event), kwargs, safe=safe)
+        return Scan.merge(_reformat_waveforms(self.get_scan_parameters(event)), kwargs, safe=safe)
 
    def get_probe_parameters(self, event=None) -> dict:
        """Returns a dictionary of probe parameters to initialize a probe
@@ -772,3 +808,69 @@ def _assert_unit_and_description_present(hdf5_file, _prefix=""):
    assert "description" in hdf5_file[key].attrs.keys(), (
        f"The file {_prefix}/{key} does not have a description attribute."
    )
+
+
+def _reformat_waveforms(scan_kwargs: dict) -> dict:
+    """Reformat waveforms from dict to array if needed. This is for backwards compatibility and will
+    be removed in a future version of zea.
+
+    Args:
+        scan_kwargs (dict): The scan parameters.
+
+    Returns:
+        scan_kwargs (dict): The scan parameters with the keys waveforms_one_way and
+            waveforms_two_way reformatted to arrays if they were stored as dicts.
+    """
+
+    # TODO: remove this in a future version of zea
+    if "waveforms_one_way" in scan_kwargs and isinstance(scan_kwargs["waveforms_one_way"], dict):
+        log.warning(
+            "The waveforms_one_way parameter is stored as a dictionary in the file. "
+            "Converting to array. This will be deprecated in future versions of zea. "
+            "Please update your files to store waveforms as arrays of shape `(n_tx, n_samples)`."
+        )
+        scan_kwargs["waveforms_one_way"] = _waveforms_dict_to_array(
+            scan_kwargs["waveforms_one_way"]
+        )
+
+    if "waveforms_two_way" in scan_kwargs and isinstance(scan_kwargs["waveforms_two_way"], dict):
+        log.warning(
+            "The waveforms_two_way parameter is stored as a dictionary in the file. "
+            "Converting to array. This will be deprecated in future versions of zea. "
+            "Please update your files to store waveforms as arrays of shape `(n_tx, n_samples)`."
+        )
+        scan_kwargs["waveforms_two_way"] = _waveforms_dict_to_array(
+            scan_kwargs["waveforms_two_way"]
+        )
+    return scan_kwargs
+
+
+def _waveforms_dict_to_array(waveforms_dict: dict):
+    """Convert waveforms stored as a dictionary to a padded numpy array."""
+    waveforms = dict_to_sorted_list(waveforms_dict)
+    return pad_sequences(waveforms, dtype=np.float32, padding="post")
+
+
+def dict_to_sorted_list(dictionary: dict):
+    """Convert a dictionary with sortable keys to a sorted list of values.
+
+    .. note::
+
+        This function operates on the top level of the dictionary only.
+        If the dictionary contains nested dictionaries, those will not be sorted.
+
+    Example:
+        .. doctest::
+
+            >>> from zea.data.file import dict_to_sorted_list
+            >>> input_dict = {"number_000": 5, "number_001": 1, "number_002": 23}
+            >>> dict_to_sorted_list(input_dict)
+            [5, 1, 23]
+
+    Args:
+        dictionary (dict): The dictionary to convert. The keys must be sortable.
+
+    Returns:
+        list: The sorted list of values.
+    """
+    return [value for _, value in sorted(dictionary.items())]
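These new helpers close the compatibility loop with the data_format.py change above: files written by older zea versions keep one waveform per key ("waveform_000", "waveform_001", ...), and _reformat_waveforms converts that dict back into the new padded-array layout when a Scan is built from the file. A sketch of the conversion path, with made-up waveform lengths:

    import numpy as np
    from keras.utils import pad_sequences
    from zea.data.file import dict_to_sorted_list

    # Legacy layout: one entry per transmit, keyed by zero-padded index
    legacy = {"waveform_001": np.ones(300), "waveform_000": np.ones(512)}

    # Sort by key, then pad into a single (n_tx, n_samples) array,
    # mirroring what _waveforms_dict_to_array does internally
    waveforms = dict_to_sorted_list(legacy)
    padded = pad_sequences(waveforms, dtype=np.float32, padding="post")
    print(padded.shape)  # (2, 512)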
zea/data/layers.py CHANGED
@@ -2,7 +2,7 @@
 
 import keras
 import numpy as np
-from keras.src.layers.preprocessing.tf_data_layer import TFDataLayer
+from keras.src.layers.preprocessing.data_layer import DataLayer
 
 from zea.ops import Pad as PadOp
 from zea.utils import map_negative_indices
@@ -11,7 +11,7 @@ from zea.utils import map_negative_indices
 class Pad(PadOp):
    """Pad layer for padding tensors to a specified shape which can be used in tf.data pipelines."""
 
-    __call__ = TFDataLayer.__call__
+    __call__ = DataLayer.__call__
 
    def call(self, inputs):
        """
@@ -20,12 +20,12 @@
        return super().call(data=inputs)["data"]
 
 
-class Resizer(TFDataLayer):
+class Resizer(DataLayer):
    """
    Resize layer for resizing images. Can deal with N-dimensional images.
    Can do resize, center_crop, random_crop and crop_or_pad.
 
-    Can be used in tf.data pipelines.
+    Can be used in `tf.data` pipelines.
    """
 
    def __init__(
@@ -36,7 +36,6 @@
        seed: int | None = None,
        **resize_kwargs,
    ):
-        # noqa: E501
        """
        Initializes the data loader with the specified parameters.
 
@@ -47,7 +46,7 @@
            ['random_crop'](https://keras.io/api/layers/preprocessing_layers/image_augmentation/random_crop/),
            ['resize'](https://keras.io/api/layers/preprocessing_layers/image_preprocessing/resizing/),
            'crop_or_pad': resizes an image to a target width and height by either centrally
-            cropping the image, padding it evenly with zeros or a combination of both.
+                cropping the image, padding it evenly with zeros or a combination of both.
            resize_axes (tuple | None, optional): The axes along which to resize.
                Must be of length 2. Defaults to None. In that case, can only process
                default tensors of shape (batch, height, width, channels), where the
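The functional change here only tracks a rename inside Keras: the private preprocessing base class moved from keras.src.layers.preprocessing.tf_data_layer.TFDataLayer to keras.src.layers.preprocessing.data_layer.DataLayer. zea simply pins the new path; third-party code that must straddle both Keras releases could use a shim like the following sketch (not something this diff does):

    # Hypothetical compatibility shim across the Keras-internal rename
    try:
        from keras.src.layers.preprocessing.data_layer import DataLayer
    except ImportError:  # older Keras releases
        from keras.src.layers.preprocessing.tf_data_layer import (
            TFDataLayer as DataLayer,
        )

Note that both paths are private (keras.src), so either import can break between Keras versions; that is exactly what this hunk compensates for.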
zea/datapaths.py CHANGED
@@ -11,12 +11,24 @@ to set up your local data paths.
 Example usage
 ^^^^^^^^^^^^^
 
-.. code-block:: python
+.. doctest::
 
-    from zea.datapaths import set_data_paths
+    >>> import yaml
+    >>> from zea.datapaths import set_data_paths
 
-    user = set_data_paths("users.yaml")
-    print(user.data_root)
+    >>> user_config = {"data_root": "/path/to/data", "output": "/path/to/output"}
+    >>> with open("users.yaml", "w", encoding="utf-8") as file:
+    ...     yaml.dump(user_config, file)
+
+    >>> user = set_data_paths("users.yaml")
+    >>> print(user.data_root)
+    /path/to/data
+
+.. testcleanup::
+
+    import os
+
+    os.remove("users.yaml")
 
 """
 
zea/display.py CHANGED
@@ -9,8 +9,8 @@ from keras import ops
 from PIL import Image
 
 from zea import log
+from zea.tensor_ops import translate
 from zea.tools.fit_scan_cone import fit_and_crop_around_scan_cone
-from zea.utils import translate
 
 
 def to_8bit(image, dynamic_range: Union[None, tuple] = None, pillow: bool = True):
@@ -340,12 +340,14 @@
 def map_coordinates(inputs, coordinates, order, fill_mode="constant", fill_value=0):
    """map_coordinates using keras.ops or scipy.ndimage when order > 1."""
    if order > 1:
-        inputs = ops.convert_to_numpy(inputs)
-        coordinates = ops.convert_to_numpy(coordinates)
+        # Preserve original dtype before conversion
+        original_dtype = ops.dtype(inputs)
+        inputs_np = ops.convert_to_numpy(inputs).astype(np.float32)
+        coordinates_np = ops.convert_to_numpy(coordinates).astype(np.float32)
        out = scipy.ndimage.map_coordinates(
-            inputs, coordinates, order=order, mode=fill_mode, cval=fill_value
+            inputs_np, coordinates_np, order=order, mode=fill_mode, cval=fill_value
        )
-        return ops.convert_to_tensor(out)
+        return ops.convert_to_tensor(out.astype(original_dtype))
    else:
        return ops.image.map_coordinates(
            inputs,
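The scipy fallback now routes data through float32 and casts the result back, so callers keep their input dtype even when scipy's spline interpolation does not accept it (float16, for instance). A minimal sketch of the same round-trip pattern outside zea:

    import numpy as np
    import scipy.ndimage
    from keras import ops

    x = ops.ones((4, 4), dtype="float16")  # a dtype scipy's splines may reject
    coords = np.stack(
        np.meshgrid(np.arange(4), np.arange(4), indexing="ij")
    ).astype(np.float32)  # identity sampling grid, shape (2, 4, 4)

    original_dtype = ops.dtype(x)  # "float16"
    out = scipy.ndimage.map_coordinates(
        ops.convert_to_numpy(x).astype(np.float32), coords, order=3
    )
    y = ops.convert_to_tensor(out.astype(original_dtype))
    print(ops.dtype(y))  # float16, same as the input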
zea/interface.py CHANGED
@@ -3,15 +3,15 @@
 Example usage
 ^^^^^^^^^^^^^^
 
-.. code-block:: python
+.. doctest::
 
-    import zea
-    from zea.internal.setup_zea import setup_config
+    >>> import zea
+    >>> from zea.internal.setup_zea import setup_config
 
-    config = setup_config("hf://zeahub/configs/config_camus.yaml")
+    >>> config = setup_config("hf://zeahub/configs/config_camus.yaml")
 
-    interface = zea.Interface(config)
-    interface.run(plot=True)
+    >>> interface = zea.Interface(config)
+    >>> interface.run(plot=True)  # doctest: +SKIP
 
 """
@@ -31,15 +31,15 @@ from zea.data.file import File
 from zea.datapaths import format_data_path
 from zea.display import to_8bit
 from zea.internal.core import DataTypes
+from zea.internal.utils import keep_trying
 from zea.internal.viewer import (
     ImageViewerMatplotlib,
     ImageViewerOpenCV,
     filename_from_window_dialog,
     running_in_notebook,
 )
-from zea.io_lib import matplotlib_figure_to_numpy
+from zea.io_lib import matplotlib_figure_to_numpy, save_video
 from zea.ops import Pipeline
-from zea.utils import keep_trying, save_to_gif, save_to_mp4
 
 
 class Interface:
@@ -266,10 +266,11 @@
        save = self.config.plot.save
 
        if self.frame_no == "all":
-            if not asyncio.get_event_loop().is_running():
-                asyncio.run(self.run_movie(save))
-            else:
-                asyncio.create_task(self.run_movie(save))
+            try:
+                loop = asyncio.get_running_loop()
+                loop.create_task(self.run_movie(save))  # already running loop
+            except RuntimeError:
+                asyncio.run(self.run_movie(save))  # no loop yet
 
        else:
            if plot:
@@ -520,10 +521,7 @@
 
        fps = self.config.plot.fps
 
-        if self.config.plot.video_extension == "gif":
-            save_to_gif(images, path, fps=fps)
-        elif self.config.plot.video_extension == "mp4":
-            save_to_mp4(images, path, fps=fps)
+        save_video(images, path, fps=fps)
 
        if self.verbose:
            log.info(f"Video saved to {log.yellow(path)}")
zea/keras_ops.py CHANGED
@@ -3,11 +3,10 @@ and :mod:`keras.ops.image` functions.
 
 They can be used in zea pipelines like any other :class:`zea.Operation`, for example:
 
-.. code-block:: python
+.. doctest::
 
-    from zea.keras_ops import Squeeze
-
-    op = Squeeze(axis=1)
+    >>> from zea.keras_ops import Squeeze
+    >>> op = Squeeze(axis=1)
 """
 
 import inspect
 
@@ -77,11 +76,11 @@
 
 They can be used in zea pipelines like any other :class:`zea.Operation`, for example:
 
-.. code-block:: python
+.. doctest::
 
-    from zea.keras_ops import Squeeze
+    >>> from zea.keras_ops import Squeeze
 
-    op = Squeeze(axis=1)
+    >>> op = Squeeze(axis=1)
 
 This file is generated automatically. Do not edit manually.
 Generated with Keras {keras.__version__}
zea/internal/cache.py CHANGED
@@ -21,10 +21,8 @@
 
 import ast
 import atexit
-import hashlib
 import inspect
 import os
-import pickle
 import tempfile
 import textwrap
 from pathlib import Path
@@ -33,6 +31,7 @@ import joblib
 import keras
 
 from zea import log
+from zea.internal.core import hash_elements
 
 _DEFAULT_ZEA_CACHE_DIR = Path.home() / ".cache" / "zea"
 
@@ -80,52 +79,6 @@ _CACHE_DIR = ZEA_CACHE_DIR / "cached_funcs"
 _CACHE_DIR.mkdir(parents=True, exist_ok=True)
 
 
-def serialize_elements(key_elements: list, shorten: bool = False) -> str:
-    """Serialize elements of a list to generate a cache key.
-
-    In general uses the string representation of the elements unless
-    the element has a `serialized` attribute, in which case it uses that.
-    For instance this is useful for custom classes that inherit from `zea.core.Object`.
-
-    Args:
-        key_elements (list): List of elements to serialize. Can be nested lists
-            or tuples. In this case the elements are serialized recursively.
-        shorten (bool): If True, the serialized string is hashed to a shorter
-            representation using MD5. Defaults to False.
-
-    Returns:
-        str: A serialized string representation of the elements, joined by underscores.
-
-    """
-    serialized_elements = []
-    for element in key_elements:
-        if isinstance(element, (list, tuple)):
-            # If element is a list or tuple, serialize its elements recursively
-            serialized_elements.append(serialize_elements(element))
-        elif hasattr(element, "serialized"):
-            # Use the serialized attribute if it exists (e.g. for zea.core.Object)
-            serialized_elements.append(str(element.serialized))
-        elif isinstance(element, str):
-            # If element is a string, use it as is
-            serialized_elements.append(element)
-        elif isinstance(element, keras.random.SeedGenerator):
-            # If element is a SeedGenerator, use the state
-            element = keras.ops.convert_to_numpy(element.state.value)
-            element = pickle.dumps(element)
-            element = hashlib.md5(element).hexdigest()
-            serialized_elements.append(element)
-        else:
-            # Otherwise, serialize the element using pickle and hash it
-            element = pickle.dumps(element)
-            element = hashlib.md5(element).hexdigest()
-            serialized_elements.append(element)
-
-    serialized = "_".join(serialized_elements)
-    if shorten:
-        return hashlib.md5(serialized.encode()).hexdigest()
-    return serialized
-
-
 def get_function_source(func):
    """Recursively get the source code of a function and its nested functions."""
    try:
@@ -188,7 +141,7 @@ def generate_cache_key(func, args, kwargs, arg_names):
    # Add keras backend
    key_elements.append(keras.backend.backend())
 
-    return f"{func.__qualname__}_" + serialize_elements(key_elements, shorten=True)
+    return f"{func.__qualname__}_" + hash_elements(key_elements)
 
 
 def cache_output(*arg_names, verbose=False):
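generate_cache_key now delegates serialize-and-hash to hash_elements from zea.internal.core, whose implementation is not shown in this diff. Its contract matches the removed serialize_elements(..., shorten=True): collapse a heterogeneous list of key elements into one short, stable digest. A hypothetical stand-in to illustrate that contract, not zea's actual implementation:

    import hashlib
    import pickle

    def hash_elements(key_elements):
        # Hypothetical: reduce arbitrary picklable elements to a short, stable hex digest
        digest = hashlib.md5()
        for element in key_elements:
            digest.update(pickle.dumps(repr(element)))
        return digest.hexdigest()

    # Cache key in the style of generate_cache_key above
    key = "my_func_" + hash_elements(["float32", (128, 128), "tensorflow"])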
zea/internal/config/validation.py CHANGED
@@ -15,9 +15,8 @@ from pathlib import Path
 
 from schema import And, Optional, Or, Regex, Schema
 
-import zea.metrics  # noqa: F401
 from zea.internal.checks import _DATA_TYPES
-from zea.internal.registry import metrics_registry
+from zea.metrics import metrics_registry
 
 # predefined checks, later used in schema to check validity of parameter
 any_number = Or(