zea 0.0.5__py3-none-any.whl → 0.0.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- zea/__init__.py +54 -19
- zea/agent/__init__.py +12 -12
- zea/agent/masks.py +2 -1
- zea/agent/selection.py +166 -0
- zea/backend/__init__.py +89 -0
- zea/backend/jax/__init__.py +14 -51
- zea/backend/tensorflow/__init__.py +0 -49
- zea/backend/tensorflow/dataloader.py +2 -1
- zea/backend/torch/__init__.py +27 -62
- zea/beamform/beamformer.py +100 -50
- zea/beamform/lens_correction.py +9 -2
- zea/beamform/pfield.py +9 -2
- zea/config.py +34 -25
- zea/data/__init__.py +22 -16
- zea/data/convert/camus.py +2 -1
- zea/data/convert/echonet.py +4 -4
- zea/data/convert/echonetlvh/convert_raw_to_usbmd.py +1 -1
- zea/data/convert/matlab.py +11 -4
- zea/data/data_format.py +31 -30
- zea/data/datasets.py +7 -5
- zea/data/file.py +104 -2
- zea/data/layers.py +5 -6
- zea/datapaths.py +16 -4
- zea/display.py +7 -5
- zea/interface.py +14 -16
- zea/internal/_generate_keras_ops.py +6 -7
- zea/internal/cache.py +2 -49
- zea/internal/config/validation.py +1 -2
- zea/internal/core.py +69 -6
- zea/internal/device.py +6 -2
- zea/internal/dummy_scan.py +330 -0
- zea/internal/operators.py +114 -2
- zea/internal/parameters.py +101 -70
- zea/internal/registry.py +1 -1
- zea/internal/setup_zea.py +5 -6
- zea/internal/utils.py +282 -0
- zea/io_lib.py +247 -19
- zea/keras_ops.py +74 -4
- zea/log.py +9 -7
- zea/metrics.py +365 -65
- zea/models/__init__.py +30 -20
- zea/models/base.py +30 -14
- zea/models/carotid_segmenter.py +19 -4
- zea/models/diffusion.py +187 -26
- zea/models/echonet.py +22 -8
- zea/models/echonetlvh.py +31 -18
- zea/models/lpips.py +19 -2
- zea/models/lv_segmentation.py +96 -0
- zea/models/preset_utils.py +5 -5
- zea/models/presets.py +36 -0
- zea/models/regional_quality.py +142 -0
- zea/models/taesd.py +21 -5
- zea/models/unet.py +15 -1
- zea/ops.py +414 -207
- zea/probes.py +6 -6
- zea/scan.py +109 -49
- zea/simulator.py +24 -21
- zea/tensor_ops.py +411 -206
- zea/tools/hf.py +1 -1
- zea/tools/selection_tool.py +47 -86
- zea/utils.py +92 -480
- zea/visualize.py +177 -39
- {zea-0.0.5.dist-info → zea-0.0.7.dist-info}/METADATA +9 -3
- zea-0.0.7.dist-info/RECORD +114 -0
- {zea-0.0.5.dist-info → zea-0.0.7.dist-info}/WHEEL +1 -1
- zea-0.0.5.dist-info/RECORD +0 -110
- {zea-0.0.5.dist-info → zea-0.0.7.dist-info}/entry_points.txt +0 -0
- {zea-0.0.5.dist-info → zea-0.0.7.dist-info/licenses}/LICENSE +0 -0
zea/data/data_format.py
CHANGED
|
@@ -7,11 +7,12 @@ from dataclasses import dataclass
|
|
|
7
7
|
from pathlib import Path
|
|
8
8
|
|
|
9
9
|
import numpy as np
|
|
10
|
+
from keras.utils import pad_sequences
|
|
10
11
|
|
|
11
12
|
from zea import log
|
|
12
13
|
from zea.data.file import File, validate_file
|
|
13
14
|
from zea.internal.checks import _DATA_TYPES
|
|
14
|
-
from zea.utils import first_not_none_item
|
|
15
|
+
from zea.internal.utils import first_not_none_item
|
|
15
16
|
|
|
16
17
|
|
|
17
18
|
@dataclass
|
|
@@ -470,32 +471,30 @@ def _write_datasets(
|
|
|
470
471
|
)
|
|
471
472
|
|
|
472
473
|
if waveforms_one_way is not None:
|
|
473
|
-
|
|
474
|
-
|
|
475
|
-
|
|
476
|
-
|
|
477
|
-
|
|
478
|
-
|
|
479
|
-
|
|
480
|
-
|
|
481
|
-
|
|
482
|
-
|
|
483
|
-
|
|
484
|
-
)
|
|
474
|
+
_add_dataset(
|
|
475
|
+
group_name=scan_group_name,
|
|
476
|
+
name="waveforms_one_way",
|
|
477
|
+
data=pad_sequences(waveforms_one_way, dtype=np.float32, padding="post"),
|
|
478
|
+
description=(
|
|
479
|
+
"One-way waveform as simulated by the Verasonics system, "
|
|
480
|
+
"sampled at 250MHz. This is the waveform after being filtered "
|
|
481
|
+
"by the transducer bandwidth once."
|
|
482
|
+
),
|
|
483
|
+
unit="V",
|
|
484
|
+
)
|
|
485
485
|
|
|
486
486
|
if waveforms_two_way is not None:
|
|
487
|
-
|
|
488
|
-
|
|
489
|
-
|
|
490
|
-
|
|
491
|
-
|
|
492
|
-
|
|
493
|
-
|
|
494
|
-
|
|
495
|
-
|
|
496
|
-
|
|
497
|
-
|
|
498
|
-
)
|
|
487
|
+
_add_dataset(
|
|
488
|
+
group_name=scan_group_name,
|
|
489
|
+
name="waveforms_two_way",
|
|
490
|
+
data=pad_sequences(waveforms_two_way, dtype=np.float32, padding="post"),
|
|
491
|
+
description=(
|
|
492
|
+
"Two-way waveform as simulated by the Verasonics system, "
|
|
493
|
+
"sampled at 250MHz. This is the waveform after being filtered "
|
|
494
|
+
"by the transducer bandwidth twice."
|
|
495
|
+
),
|
|
496
|
+
unit="V",
|
|
497
|
+
)
|
|
499
498
|
|
|
500
499
|
# Add additional elements
|
|
501
500
|
if additional_elements is not None:
|
|
@@ -539,6 +538,7 @@ def generate_zea_dataset(
|
|
|
539
538
|
additional_elements=None,
|
|
540
539
|
event_structure=False,
|
|
541
540
|
cast_to_float=True,
|
|
541
|
+
overwrite=False,
|
|
542
542
|
):
|
|
543
543
|
"""Generates a dataset in the zea format.
|
|
544
544
|
|
|
@@ -585,10 +585,10 @@ def generate_zea_dataset(
|
|
|
585
585
|
waveform was used for each transmit event.
|
|
586
586
|
waveforms_one_way (list): List of one-way waveforms as simulated by the Verasonics
|
|
587
587
|
system, sampled at 250MHz. This is the waveform after being filtered by the
|
|
588
|
-
|
|
588
|
+
transducer bandwidth once. Every element in the list is a 1D numpy array.
|
|
589
589
|
waveforms_two_way (list): List of two-way waveforms as simulated by the Verasonics
|
|
590
590
|
system, sampled at 250MHz. This is the waveform after being filtered by the
|
|
591
|
-
|
|
591
|
+
transducer bandwidth twice. Every element in the list is a 1D numpy array.
|
|
592
592
|
additional_elements (List[DatasetElement]): A list of additional dataset
|
|
593
593
|
elements to be added to the dataset. Each element should be a DatasetElement
|
|
594
594
|
object. The additional elements are added under the scan group.
|
|
@@ -598,6 +598,7 @@ def generate_zea_dataset(
|
|
|
598
598
|
Instead of just a single data and scan group.
|
|
599
599
|
cast_to_float (bool): Whether to store data as float32. You may want to set this
|
|
600
600
|
to False if storing images.
|
|
601
|
+
overwrite (bool): Whether to overwrite the file if it already exists. Defaults to False.
|
|
601
602
|
|
|
602
603
|
"""
|
|
603
604
|
# check if all args are lists
|
|
@@ -637,10 +638,10 @@ def generate_zea_dataset(
|
|
|
637
638
|
# make sure input arguments of func is same length as data_and_parameters
|
|
638
639
|
# except `path` and `event_structure` arguments and ofcourse `data_and_parameters` itself
|
|
639
640
|
assert (
|
|
640
|
-
len(data_and_parameters) == len(inspect.signature(generate_zea_dataset).parameters) -
|
|
641
|
+
len(data_and_parameters) == len(inspect.signature(generate_zea_dataset).parameters) - 4
|
|
641
642
|
), (
|
|
642
643
|
"All arguments should be put in data_and_parameters except "
|
|
643
|
-
"`path`, `event_structure`, and `
|
|
644
|
+
"`path`, `event_structure`, `cast_to_float`, and `overwrite` arguments."
|
|
644
645
|
)
|
|
645
646
|
|
|
646
647
|
if event_structure:
|
|
@@ -682,7 +683,7 @@ def generate_zea_dataset(
|
|
|
682
683
|
# Convert path to Path object
|
|
683
684
|
path = Path(path)
|
|
684
685
|
|
|
685
|
-
if path.exists():
|
|
686
|
+
if path.exists() and not overwrite:
|
|
686
687
|
raise FileExistsError(f"The file {path} already exists.")
|
|
687
688
|
|
|
688
689
|
# Create the directory if it does not exist
|
zea/data/datasets.py
CHANGED
|
@@ -48,13 +48,15 @@ from zea.data.preset_utils import (
|
|
|
48
48
|
_hf_resolve_path,
|
|
49
49
|
)
|
|
50
50
|
from zea.datapaths import format_data_path
|
|
51
|
+
from zea.internal.utils import (
|
|
52
|
+
calculate_file_hash,
|
|
53
|
+
reduce_to_signature,
|
|
54
|
+
)
|
|
51
55
|
from zea.io_lib import search_file_tree
|
|
52
56
|
from zea.tools.hf import HFPath
|
|
53
57
|
from zea.utils import (
|
|
54
|
-
calculate_file_hash,
|
|
55
58
|
date_string_to_readable,
|
|
56
59
|
get_date_string,
|
|
57
|
-
reduce_to_signature,
|
|
58
60
|
)
|
|
59
61
|
|
|
60
62
|
_CHECK_MAX_DATASET_SIZE = 10000
|
|
@@ -241,7 +243,7 @@ class Folder:
|
|
|
241
243
|
return
|
|
242
244
|
|
|
243
245
|
num_frames_per_file = []
|
|
244
|
-
|
|
246
|
+
validated_successfully = True
|
|
245
247
|
for file_path in tqdm.tqdm(
|
|
246
248
|
self.file_paths,
|
|
247
249
|
total=self.n_files,
|
|
@@ -253,9 +255,9 @@ class Folder:
|
|
|
253
255
|
validation_error_log.append(f"File {file_path} is not a valid zea dataset.\n{e}\n")
|
|
254
256
|
# convert into warning
|
|
255
257
|
log.warning(f"Error in file {file_path}.\n{e}")
|
|
256
|
-
|
|
258
|
+
validated_successfully = False
|
|
257
259
|
|
|
258
|
-
if not
|
|
260
|
+
if not validated_successfully:
|
|
259
261
|
log.warning(
|
|
260
262
|
"Check warnings above for details. No validation file was created. "
|
|
261
263
|
f"See {validation_error_file_path} for details."
|
zea/data/file.py
CHANGED
|
@@ -6,6 +6,7 @@ from typing import List
|
|
|
6
6
|
|
|
7
7
|
import h5py
|
|
8
8
|
import numpy as np
|
|
9
|
+
from keras.utils import pad_sequences
|
|
9
10
|
|
|
10
11
|
from zea import log
|
|
11
12
|
from zea.data.preset_utils import HF_PREFIX, _hf_resolve_path
|
|
@@ -15,9 +16,9 @@ from zea.internal.checks import (
|
|
|
15
16
|
_REQUIRED_SCAN_KEYS,
|
|
16
17
|
get_check,
|
|
17
18
|
)
|
|
19
|
+
from zea.internal.utils import reduce_to_signature
|
|
18
20
|
from zea.probes import Probe
|
|
19
21
|
from zea.scan import Scan
|
|
20
|
-
from zea.utils import reduce_to_signature
|
|
21
22
|
|
|
22
23
|
|
|
23
24
|
def assert_key(file: h5py.File, key: str):
|
|
@@ -219,6 +220,41 @@ class File(h5py.File):
|
|
|
219
220
|
def load_data(self, data_type, indices: str | int | List[int] = "all"):
|
|
220
221
|
"""Load data from the file.
|
|
221
222
|
|
|
223
|
+
The indices parameter can be used to load a subset of the data. This can be
|
|
224
|
+
|
|
225
|
+
- 'all' to load all data
|
|
226
|
+
|
|
227
|
+
- an int to load a single frame
|
|
228
|
+
|
|
229
|
+
- a list of ints to load specific frames
|
|
230
|
+
|
|
231
|
+
- a tuple of lists, ranges or slices to index frames and transmits. Note that
|
|
232
|
+
indexing with lists of indices for both axes is not supported. In that case,
|
|
233
|
+
try to define one of the axes with a slice.
|
|
234
|
+
|
|
235
|
+
.. doctest::
|
|
236
|
+
|
|
237
|
+
>>> from zea import File
|
|
238
|
+
|
|
239
|
+
>>> path_to_file = (
|
|
240
|
+
... "hf://zeahub/picmus/database/experiments/contrast_speckle/"
|
|
241
|
+
... "contrast_speckle_expe_dataset_iq/contrast_speckle_expe_dataset_iq.hdf5"
|
|
242
|
+
... )
|
|
243
|
+
|
|
244
|
+
>>> with File(path_to_file, mode="r") as file:
|
|
245
|
+
... # data has shape (n_frames, n_tx, n_el, n_ax, n_ch)
|
|
246
|
+
... data = file.load_data("raw_data")
|
|
247
|
+
... data.shape
|
|
248
|
+
... # load first frame only
|
|
249
|
+
... data = file.load_data("raw_data", indices=0)
|
|
250
|
+
... data.shape
|
|
251
|
+
... # load frame 0 and transmits 0, 2 and 4
|
|
252
|
+
... data = file.load_data("raw_data", indices=(0, [0, 2, 4]))
|
|
253
|
+
... data.shape
|
|
254
|
+
(1, 75, 832, 128, 2)
|
|
255
|
+
(75, 832, 128, 2)
|
|
256
|
+
(3, 832, 128, 2)
|
|
257
|
+
|
|
222
258
|
Args:
|
|
223
259
|
data_type (str): The type of data to load. Options are 'raw_data', 'aligned_data',
|
|
224
260
|
'beamformed_data', 'envelope_data', 'image' and 'image_sc'.
|
|
@@ -337,7 +373,7 @@ class File(h5py.File):
|
|
|
337
373
|
Returns:
|
|
338
374
|
Scan: The scan object.
|
|
339
375
|
"""
|
|
340
|
-
return Scan.merge(self.get_scan_parameters(event), kwargs, safe=safe)
|
|
376
|
+
return Scan.merge(_reformat_waveforms(self.get_scan_parameters(event)), kwargs, safe=safe)
|
|
341
377
|
|
|
342
378
|
def get_probe_parameters(self, event=None) -> dict:
|
|
343
379
|
"""Returns a dictionary of probe parameters to initialize a probe
|
|
@@ -772,3 +808,69 @@ def _assert_unit_and_description_present(hdf5_file, _prefix=""):
|
|
|
772
808
|
assert "description" in hdf5_file[key].attrs.keys(), (
|
|
773
809
|
f"The file {_prefix}/{key} does not have a description attribute."
|
|
774
810
|
)
|
|
811
|
+
|
|
812
|
+
|
|
813
|
+
def _reformat_waveforms(scan_kwargs: dict) -> dict:
|
|
814
|
+
"""Reformat waveforms from dict to array if needed. This is for backwards compatibility and will
|
|
815
|
+
be removed in a future version of zea.
|
|
816
|
+
|
|
817
|
+
Args:
|
|
818
|
+
scan_kwargs (dict): The scan parameters.
|
|
819
|
+
|
|
820
|
+
Returns:
|
|
821
|
+
scan_kwargs (dict): The scan parameters with the keys waveforms_one_way and
|
|
822
|
+
waveforms_two_way reformatted to arrays if they were stored as dicts.
|
|
823
|
+
"""
|
|
824
|
+
|
|
825
|
+
# TODO: remove this in a future version of zea
|
|
826
|
+
if "waveforms_one_way" in scan_kwargs and isinstance(scan_kwargs["waveforms_one_way"], dict):
|
|
827
|
+
log.warning(
|
|
828
|
+
"The waveforms_one_way parameter is stored as a dictionary in the file. "
|
|
829
|
+
"Converting to array. This will be deprecated in future versions of zea. "
|
|
830
|
+
"Please update your files to store waveforms as arrays of shape `(n_tx, n_samples)`."
|
|
831
|
+
)
|
|
832
|
+
scan_kwargs["waveforms_one_way"] = _waveforms_dict_to_array(
|
|
833
|
+
scan_kwargs["waveforms_one_way"]
|
|
834
|
+
)
|
|
835
|
+
|
|
836
|
+
if "waveforms_two_way" in scan_kwargs and isinstance(scan_kwargs["waveforms_two_way"], dict):
|
|
837
|
+
log.warning(
|
|
838
|
+
"The waveforms_two_way parameter is stored as a dictionary in the file. "
|
|
839
|
+
"Converting to array. This will be deprecated in future versions of zea. "
|
|
840
|
+
"Please update your files to store waveforms as arrays of shape `(n_tx, n_samples)`."
|
|
841
|
+
)
|
|
842
|
+
scan_kwargs["waveforms_two_way"] = _waveforms_dict_to_array(
|
|
843
|
+
scan_kwargs["waveforms_two_way"]
|
|
844
|
+
)
|
|
845
|
+
return scan_kwargs
|
|
846
|
+
|
|
847
|
+
|
|
848
|
+
def _waveforms_dict_to_array(waveforms_dict: dict):
|
|
849
|
+
"""Convert waveforms stored as a dictionary to a padded numpy array."""
|
|
850
|
+
waveforms = dict_to_sorted_list(waveforms_dict)
|
|
851
|
+
return pad_sequences(waveforms, dtype=np.float32, padding="post")
|
|
852
|
+
|
|
853
|
+
|
|
854
|
+
def dict_to_sorted_list(dictionary: dict):
|
|
855
|
+
"""Convert a dictionary with sortable keys to a sorted list of values.
|
|
856
|
+
|
|
857
|
+
.. note::
|
|
858
|
+
|
|
859
|
+
This function operates on the top level of the dictionary only.
|
|
860
|
+
If the dictionary contains nested dictionaries, those will not be sorted.
|
|
861
|
+
|
|
862
|
+
Example:
|
|
863
|
+
.. doctest::
|
|
864
|
+
|
|
865
|
+
>>> from zea.data.file import dict_to_sorted_list
|
|
866
|
+
>>> input_dict = {"number_000": 5, "number_001": 1, "number_002": 23}
|
|
867
|
+
>>> dict_to_sorted_list(input_dict)
|
|
868
|
+
[5, 1, 23]
|
|
869
|
+
|
|
870
|
+
Args:
|
|
871
|
+
dictionary (dict): The dictionary to convert. The keys must be sortable.
|
|
872
|
+
|
|
873
|
+
Returns:
|
|
874
|
+
list: The sorted list of values.
|
|
875
|
+
"""
|
|
876
|
+
return [value for _, value in sorted(dictionary.items())]
|
zea/data/layers.py
CHANGED
|
@@ -2,7 +2,7 @@
|
|
|
2
2
|
|
|
3
3
|
import keras
|
|
4
4
|
import numpy as np
|
|
5
|
-
from keras.src.layers.preprocessing.
|
|
5
|
+
from keras.src.layers.preprocessing.data_layer import DataLayer
|
|
6
6
|
|
|
7
7
|
from zea.ops import Pad as PadOp
|
|
8
8
|
from zea.utils import map_negative_indices
|
|
@@ -11,7 +11,7 @@ from zea.utils import map_negative_indices
|
|
|
11
11
|
class Pad(PadOp):
|
|
12
12
|
"""Pad layer for padding tensors to a specified shape which can be used in tf.data pipelines."""
|
|
13
13
|
|
|
14
|
-
__call__ =
|
|
14
|
+
__call__ = DataLayer.__call__
|
|
15
15
|
|
|
16
16
|
def call(self, inputs):
|
|
17
17
|
"""
|
|
@@ -20,12 +20,12 @@ class Pad(PadOp):
|
|
|
20
20
|
return super().call(data=inputs)["data"]
|
|
21
21
|
|
|
22
22
|
|
|
23
|
-
class Resizer(
|
|
23
|
+
class Resizer(DataLayer):
|
|
24
24
|
"""
|
|
25
25
|
Resize layer for resizing images. Can deal with N-dimensional images.
|
|
26
26
|
Can do resize, center_crop, random_crop and crop_or_pad.
|
|
27
27
|
|
|
28
|
-
Can be used in tf.data pipelines.
|
|
28
|
+
Can be used in `tf.data` pipelines.
|
|
29
29
|
"""
|
|
30
30
|
|
|
31
31
|
def __init__(
|
|
@@ -36,7 +36,6 @@ class Resizer(TFDataLayer):
|
|
|
36
36
|
seed: int | None = None,
|
|
37
37
|
**resize_kwargs,
|
|
38
38
|
):
|
|
39
|
-
# noqa: E501
|
|
40
39
|
"""
|
|
41
40
|
Initializes the data loader with the specified parameters.
|
|
42
41
|
|
|
@@ -47,7 +46,7 @@ class Resizer(TFDataLayer):
|
|
|
47
46
|
['random_crop'](https://keras.io/api/layers/preprocessing_layers/image_augmentation/random_crop/),
|
|
48
47
|
['resize'](https://keras.io/api/layers/preprocessing_layers/image_preprocessing/resizing/),
|
|
49
48
|
'crop_or_pad': resizes an image to a target width and height by either centrally
|
|
50
|
-
|
|
49
|
+
cropping the image, padding it evenly with zeros or a combination of both.
|
|
51
50
|
resize_axes (tuple | None, optional): The axes along which to resize.
|
|
52
51
|
Must be of length 2. Defaults to None. In that case, can only process
|
|
53
52
|
default tensors of shape (batch, height, width, channels), where the
|
zea/datapaths.py
CHANGED
|
@@ -11,12 +11,24 @@ to set up your local data paths.
|
|
|
11
11
|
Example usage
|
|
12
12
|
^^^^^^^^^^^^^
|
|
13
13
|
|
|
14
|
-
..
|
|
14
|
+
.. doctest::
|
|
15
15
|
|
|
16
|
-
|
|
16
|
+
>>> import yaml
|
|
17
|
+
>>> from zea.datapaths import set_data_paths
|
|
17
18
|
|
|
18
|
-
|
|
19
|
-
|
|
19
|
+
>>> user_config = {"data_root": "/path/to/data", "output": "/path/to/output"}
|
|
20
|
+
>>> with open("users.yaml", "w", encoding="utf-8") as file:
|
|
21
|
+
... yaml.dump(user_config, file)
|
|
22
|
+
|
|
23
|
+
>>> user = set_data_paths("users.yaml")
|
|
24
|
+
>>> print(user.data_root)
|
|
25
|
+
/path/to/data
|
|
26
|
+
|
|
27
|
+
.. testcleanup::
|
|
28
|
+
|
|
29
|
+
import os
|
|
30
|
+
|
|
31
|
+
os.remove("users.yaml")
|
|
20
32
|
|
|
21
33
|
"""
|
|
22
34
|
|
zea/display.py
CHANGED
|
@@ -9,8 +9,8 @@ from keras import ops
|
|
|
9
9
|
from PIL import Image
|
|
10
10
|
|
|
11
11
|
from zea import log
|
|
12
|
+
from zea.tensor_ops import translate
|
|
12
13
|
from zea.tools.fit_scan_cone import fit_and_crop_around_scan_cone
|
|
13
|
-
from zea.utils import translate
|
|
14
14
|
|
|
15
15
|
|
|
16
16
|
def to_8bit(image, dynamic_range: Union[None, tuple] = None, pillow: bool = True):
|
|
@@ -340,12 +340,14 @@ def scan_convert(
|
|
|
340
340
|
def map_coordinates(inputs, coordinates, order, fill_mode="constant", fill_value=0):
|
|
341
341
|
"""map_coordinates using keras.ops or scipy.ndimage when order > 1."""
|
|
342
342
|
if order > 1:
|
|
343
|
-
|
|
344
|
-
|
|
343
|
+
# Preserve original dtype before conversion
|
|
344
|
+
original_dtype = ops.dtype(inputs)
|
|
345
|
+
inputs_np = ops.convert_to_numpy(inputs).astype(np.float32)
|
|
346
|
+
coordinates_np = ops.convert_to_numpy(coordinates).astype(np.float32)
|
|
345
347
|
out = scipy.ndimage.map_coordinates(
|
|
346
|
-
|
|
348
|
+
inputs_np, coordinates_np, order=order, mode=fill_mode, cval=fill_value
|
|
347
349
|
)
|
|
348
|
-
return ops.convert_to_tensor(out)
|
|
350
|
+
return ops.convert_to_tensor(out.astype(original_dtype))
|
|
349
351
|
else:
|
|
350
352
|
return ops.image.map_coordinates(
|
|
351
353
|
inputs,
|
zea/interface.py
CHANGED
|
@@ -3,15 +3,15 @@
|
|
|
3
3
|
Example usage
|
|
4
4
|
^^^^^^^^^^^^^^
|
|
5
5
|
|
|
6
|
-
..
|
|
6
|
+
.. doctest::
|
|
7
7
|
|
|
8
|
-
import zea
|
|
9
|
-
from zea.internal.setup_zea import setup_config
|
|
8
|
+
>>> import zea
|
|
9
|
+
>>> from zea.internal.setup_zea import setup_config
|
|
10
10
|
|
|
11
|
-
config = setup_config("hf://zeahub/configs/config_camus.yaml")
|
|
11
|
+
>>> config = setup_config("hf://zeahub/configs/config_camus.yaml")
|
|
12
12
|
|
|
13
|
-
interface = zea.Interface(config)
|
|
14
|
-
interface.run(plot=True)
|
|
13
|
+
>>> interface = zea.Interface(config)
|
|
14
|
+
>>> interface.run(plot=True) # doctest: +SKIP
|
|
15
15
|
|
|
16
16
|
"""
|
|
17
17
|
|
|
@@ -31,15 +31,15 @@ from zea.data.file import File
|
|
|
31
31
|
from zea.datapaths import format_data_path
|
|
32
32
|
from zea.display import to_8bit
|
|
33
33
|
from zea.internal.core import DataTypes
|
|
34
|
+
from zea.internal.utils import keep_trying
|
|
34
35
|
from zea.internal.viewer import (
|
|
35
36
|
ImageViewerMatplotlib,
|
|
36
37
|
ImageViewerOpenCV,
|
|
37
38
|
filename_from_window_dialog,
|
|
38
39
|
running_in_notebook,
|
|
39
40
|
)
|
|
40
|
-
from zea.io_lib import matplotlib_figure_to_numpy
|
|
41
|
+
from zea.io_lib import matplotlib_figure_to_numpy, save_video
|
|
41
42
|
from zea.ops import Pipeline
|
|
42
|
-
from zea.utils import keep_trying, save_to_gif, save_to_mp4
|
|
43
43
|
|
|
44
44
|
|
|
45
45
|
class Interface:
|
|
@@ -266,10 +266,11 @@ class Interface:
|
|
|
266
266
|
save = self.config.plot.save
|
|
267
267
|
|
|
268
268
|
if self.frame_no == "all":
|
|
269
|
-
|
|
270
|
-
asyncio.
|
|
271
|
-
|
|
272
|
-
|
|
269
|
+
try:
|
|
270
|
+
loop = asyncio.get_running_loop()
|
|
271
|
+
loop.create_task(self.run_movie(save)) # already running loop
|
|
272
|
+
except RuntimeError:
|
|
273
|
+
asyncio.run(self.run_movie(save)) # no loop yet
|
|
273
274
|
|
|
274
275
|
else:
|
|
275
276
|
if plot:
|
|
@@ -520,10 +521,7 @@ class Interface:
|
|
|
520
521
|
|
|
521
522
|
fps = self.config.plot.fps
|
|
522
523
|
|
|
523
|
-
|
|
524
|
-
save_to_gif(images, path, fps=fps)
|
|
525
|
-
elif self.config.plot.video_extension == "mp4":
|
|
526
|
-
save_to_mp4(images, path, fps=fps)
|
|
524
|
+
save_video(images, path, fps=fps)
|
|
527
525
|
|
|
528
526
|
if self.verbose:
|
|
529
527
|
log.info(f"Video saved to {log.yellow(path)}")
|
|
@@ -3,11 +3,10 @@ and :mod:`keras.ops.image` functions.
|
|
|
3
3
|
|
|
4
4
|
They can be used in zea pipelines like any other :class:`zea.Operation`, for example:
|
|
5
5
|
|
|
6
|
-
..
|
|
6
|
+
.. doctest::
|
|
7
7
|
|
|
8
|
-
from zea.keras_ops import Squeeze
|
|
9
|
-
|
|
10
|
-
op = Squeeze(axis=1)
|
|
8
|
+
>>> from zea.keras_ops import Squeeze
|
|
9
|
+
>>> op = Squeeze(axis=1)
|
|
11
10
|
"""
|
|
12
11
|
|
|
13
12
|
import inspect
|
|
@@ -77,11 +76,11 @@ and :mod:`keras.ops.image` functions.
|
|
|
77
76
|
|
|
78
77
|
They can be used in zea pipelines like any other :class:`zea.Operation`, for example:
|
|
79
78
|
|
|
80
|
-
..
|
|
79
|
+
.. doctest::
|
|
81
80
|
|
|
82
|
-
from zea.keras_ops import Squeeze
|
|
81
|
+
>>> from zea.keras_ops import Squeeze
|
|
83
82
|
|
|
84
|
-
op = Squeeze(axis=1)
|
|
83
|
+
>>> op = Squeeze(axis=1)
|
|
85
84
|
|
|
86
85
|
This file is generated automatically. Do not edit manually.
|
|
87
86
|
Generated with Keras {keras.__version__}
|
zea/internal/cache.py
CHANGED
|
@@ -21,10 +21,8 @@
|
|
|
21
21
|
|
|
22
22
|
import ast
|
|
23
23
|
import atexit
|
|
24
|
-
import hashlib
|
|
25
24
|
import inspect
|
|
26
25
|
import os
|
|
27
|
-
import pickle
|
|
28
26
|
import tempfile
|
|
29
27
|
import textwrap
|
|
30
28
|
from pathlib import Path
|
|
@@ -33,6 +31,7 @@ import joblib
|
|
|
33
31
|
import keras
|
|
34
32
|
|
|
35
33
|
from zea import log
|
|
34
|
+
from zea.internal.core import hash_elements
|
|
36
35
|
|
|
37
36
|
_DEFAULT_ZEA_CACHE_DIR = Path.home() / ".cache" / "zea"
|
|
38
37
|
|
|
@@ -80,52 +79,6 @@ _CACHE_DIR = ZEA_CACHE_DIR / "cached_funcs"
|
|
|
80
79
|
_CACHE_DIR.mkdir(parents=True, exist_ok=True)
|
|
81
80
|
|
|
82
81
|
|
|
83
|
-
def serialize_elements(key_elements: list, shorten: bool = False) -> str:
|
|
84
|
-
"""Serialize elements of a list to generate a cache key.
|
|
85
|
-
|
|
86
|
-
In general uses the string representation of the elements unless
|
|
87
|
-
the element has a `serialized` attribute, in which case it uses that.
|
|
88
|
-
For instance this is useful for custom classes that inherit from `zea.core.Object`.
|
|
89
|
-
|
|
90
|
-
Args:
|
|
91
|
-
key_elements (list): List of elements to serialize. Can be nested lists
|
|
92
|
-
or tuples. In this case the elements are serialized recursively.
|
|
93
|
-
shorten (bool): If True, the serialized string is hashed to a shorter
|
|
94
|
-
representation using MD5. Defaults to False.
|
|
95
|
-
|
|
96
|
-
Returns:
|
|
97
|
-
str: A serialized string representation of the elements, joined by underscores.
|
|
98
|
-
|
|
99
|
-
"""
|
|
100
|
-
serialized_elements = []
|
|
101
|
-
for element in key_elements:
|
|
102
|
-
if isinstance(element, (list, tuple)):
|
|
103
|
-
# If element is a list or tuple, serialize its elements recursively
|
|
104
|
-
serialized_elements.append(serialize_elements(element))
|
|
105
|
-
elif hasattr(element, "serialized"):
|
|
106
|
-
# Use the serialized attribute if it exists (e.g. for zea.core.Object)
|
|
107
|
-
serialized_elements.append(str(element.serialized))
|
|
108
|
-
elif isinstance(element, str):
|
|
109
|
-
# If element is a string, use it as is
|
|
110
|
-
serialized_elements.append(element)
|
|
111
|
-
elif isinstance(element, keras.random.SeedGenerator):
|
|
112
|
-
# If element is a SeedGenerator, use the state
|
|
113
|
-
element = keras.ops.convert_to_numpy(element.state.value)
|
|
114
|
-
element = pickle.dumps(element)
|
|
115
|
-
element = hashlib.md5(element).hexdigest()
|
|
116
|
-
serialized_elements.append(element)
|
|
117
|
-
else:
|
|
118
|
-
# Otherwise, serialize the element using pickle and hash it
|
|
119
|
-
element = pickle.dumps(element)
|
|
120
|
-
element = hashlib.md5(element).hexdigest()
|
|
121
|
-
serialized_elements.append(element)
|
|
122
|
-
|
|
123
|
-
serialized = "_".join(serialized_elements)
|
|
124
|
-
if shorten:
|
|
125
|
-
return hashlib.md5(serialized.encode()).hexdigest()
|
|
126
|
-
return serialized
|
|
127
|
-
|
|
128
|
-
|
|
129
82
|
def get_function_source(func):
|
|
130
83
|
"""Recursively get the source code of a function and its nested functions."""
|
|
131
84
|
try:
|
|
@@ -188,7 +141,7 @@ def generate_cache_key(func, args, kwargs, arg_names):
|
|
|
188
141
|
# Add keras backend
|
|
189
142
|
key_elements.append(keras.backend.backend())
|
|
190
143
|
|
|
191
|
-
return f"{func.__qualname__}_" +
|
|
144
|
+
return f"{func.__qualname__}_" + hash_elements(key_elements)
|
|
192
145
|
|
|
193
146
|
|
|
194
147
|
def cache_output(*arg_names, verbose=False):
|
|
@@ -15,9 +15,8 @@ from pathlib import Path
|
|
|
15
15
|
|
|
16
16
|
from schema import And, Optional, Or, Regex, Schema
|
|
17
17
|
|
|
18
|
-
import zea.metrics # noqa: F401
|
|
19
18
|
from zea.internal.checks import _DATA_TYPES
|
|
20
|
-
from zea.
|
|
19
|
+
from zea.metrics import metrics_registry
|
|
21
20
|
|
|
22
21
|
# predefined checks, later used in schema to check validity of parameter
|
|
23
22
|
any_number = Or(
|