zea 0.0.6__py3-none-any.whl → 0.0.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- zea/__init__.py +54 -19
- zea/agent/__init__.py +12 -12
- zea/agent/masks.py +2 -1
- zea/backend/tensorflow/dataloader.py +2 -5
- zea/beamform/beamformer.py +100 -50
- zea/beamform/lens_correction.py +9 -2
- zea/beamform/pfield.py +9 -2
- zea/beamform/pixelgrid.py +1 -1
- zea/config.py +34 -25
- zea/data/__init__.py +22 -25
- zea/data/augmentations.py +221 -28
- zea/data/convert/__init__.py +1 -6
- zea/data/convert/__main__.py +123 -0
- zea/data/convert/camus.py +101 -40
- zea/data/convert/echonet.py +187 -86
- zea/data/convert/echonetlvh/README.md +2 -3
- zea/data/convert/echonetlvh/{convert_raw_to_usbmd.py → __init__.py} +174 -103
- zea/data/convert/echonetlvh/manual_rejections.txt +73 -0
- zea/data/convert/echonetlvh/precompute_crop.py +43 -64
- zea/data/convert/picmus.py +37 -40
- zea/data/convert/utils.py +86 -0
- zea/data/convert/{matlab.py → verasonics.py} +44 -65
- zea/data/data_format.py +155 -34
- zea/data/dataloader.py +12 -7
- zea/data/datasets.py +112 -71
- zea/data/file.py +184 -73
- zea/data/file_operations.py +496 -0
- zea/data/layers.py +3 -3
- zea/data/preset_utils.py +1 -1
- zea/datapaths.py +16 -4
- zea/display.py +14 -13
- zea/interface.py +14 -16
- zea/internal/_generate_keras_ops.py +6 -7
- zea/internal/cache.py +2 -49
- zea/internal/checks.py +6 -12
- zea/internal/config/validation.py +1 -2
- zea/internal/core.py +69 -6
- zea/internal/device.py +6 -2
- zea/internal/dummy_scan.py +330 -0
- zea/internal/operators.py +118 -2
- zea/internal/parameters.py +101 -70
- zea/internal/setup_zea.py +5 -6
- zea/internal/utils.py +282 -0
- zea/io_lib.py +322 -146
- zea/keras_ops.py +74 -4
- zea/log.py +9 -7
- zea/metrics.py +15 -7
- zea/models/__init__.py +31 -21
- zea/models/base.py +30 -14
- zea/models/carotid_segmenter.py +19 -4
- zea/models/diffusion.py +235 -23
- zea/models/echonet.py +22 -8
- zea/models/echonetlvh.py +31 -7
- zea/models/lpips.py +19 -2
- zea/models/lv_segmentation.py +30 -11
- zea/models/preset_utils.py +5 -5
- zea/models/regional_quality.py +30 -10
- zea/models/taesd.py +21 -5
- zea/models/unet.py +15 -1
- zea/ops.py +770 -336
- zea/probes.py +6 -6
- zea/scan.py +121 -51
- zea/simulator.py +24 -21
- zea/tensor_ops.py +477 -353
- zea/tools/fit_scan_cone.py +90 -160
- zea/tools/hf.py +1 -1
- zea/tools/selection_tool.py +47 -86
- zea/tracking/__init__.py +16 -0
- zea/tracking/base.py +94 -0
- zea/tracking/lucas_kanade.py +474 -0
- zea/tracking/segmentation.py +110 -0
- zea/utils.py +101 -480
- zea/visualize.py +177 -39
- {zea-0.0.6.dist-info → zea-0.0.8.dist-info}/METADATA +6 -2
- zea-0.0.8.dist-info/RECORD +122 -0
- zea-0.0.6.dist-info/RECORD +0 -112
- {zea-0.0.6.dist-info → zea-0.0.8.dist-info}/WHEEL +0 -0
- {zea-0.0.6.dist-info → zea-0.0.8.dist-info}/entry_points.txt +0 -0
- {zea-0.0.6.dist-info → zea-0.0.8.dist-info}/licenses/LICENSE +0 -0
zea/data/data_format.py
CHANGED

@@ -6,12 +6,14 @@ import inspect
 from dataclasses import dataclass
 from pathlib import Path

+import h5py
 import numpy as np
+from keras.utils import pad_sequences

 from zea import log
 from zea.data.file import File, validate_file
 from zea.internal.checks import _DATA_TYPES
-from zea.utils import first_not_none_item
+from zea.internal.utils import first_not_none_item


 @dataclass
@@ -19,15 +21,15 @@ class DatasetElement:
     """Class to store a dataset element with a name, data, description and unit. Used to
     supply additional dataset elements to the generate_zea_dataset function."""

-    # The group name to store the dataset under. This can be a nested group, e.g.
-    # "scan/waveforms"
-    group_name: str
     # The name of the dataset. This will be the key in the group.
     dataset_name: str
     # The data to store in the dataset.
     data: np.ndarray
     description: str
     unit: str
+    # The group name to store the dataset under. This can be a nested group, e.g.
+    # "lens/profiles"
+    group_name: str = ""


 def generate_example_dataset(
@@ -110,9 +112,43 @@ def generate_example_dataset(
         focus_distances=focus_distances,
         polar_angles=polar_angles,
         azimuth_angles=azimuth_angles,
+        additional_elements=_generate_example_dataset_elements(),
+        description="This is an example dataset generated by zea",
     )


+def _generate_example_dataset_elements() -> list[DatasetElement]:
+    """Generates a list of example DatasetElement objects to be used as additional
+    elements in the generate_zea_dataset function.
+
+    Returns:
+        list: A list of DatasetElement objects.
+    """
+    example_elements = [
+        DatasetElement(
+            dataset_name="temperature",
+            data=np.array(42),
+            description="The temperature during the measurement",
+            unit="unitless",
+        ),
+        DatasetElement(
+            dataset_name="lens_profile",
+            data=np.random.rand(100),
+            description="An example lens profile",
+            unit="mm",
+            group_name="lens",
+        ),
+        DatasetElement(
+            dataset_name="lens_material",
+            data=np.array(["material1", "material2", "material3"], dtype=h5py.string_dtype()),
+            description="An example lens material list",
+            unit="unitless",
+            group_name="lens",
+        ),
+    ]
+    return example_elements
+
+
 def validate_input_data(raw_data, aligned_data, envelope_data, beamformed_data, image, image_sc):
     """
     Validates input data for generate_zea_dataset
@@ -470,38 +506,45 @@ def _write_datasets(
         )

     if waveforms_one_way is not None:
-
-
-
-
-
-
-
-
-
-
-
-        )
+        _add_dataset(
+            group_name=scan_group_name,
+            name="waveforms_one_way",
+            data=pad_sequences(waveforms_one_way, dtype=np.float32, padding="post"),
+            description=(
+                "One-way waveform as simulated by the Verasonics system, "
+                "sampled at 250MHz. This is the waveform after being filtered "
+                "by the transducer bandwidth once."
+            ),
+            unit="V",
+        )

     if waveforms_two_way is not None:
-
-
-
-
-
-
-
-
-
-
-
-        )
+        _add_dataset(
+            group_name=scan_group_name,
+            name="waveforms_two_way",
+            data=pad_sequences(waveforms_two_way, dtype=np.float32, padding="post"),
+            description=(
+                "Two-way waveform as simulated by the Verasonics system, "
+                "sampled at 250MHz. This is the waveform after being filtered "
+                "by the transducer bandwidth twice."
+            ),
+            unit="V",
+        )

     # Add additional elements
     if additional_elements is not None:
+        # Write scan group
+        non_standard_elements_group_name = "non_standard_elements"
+        non_standard_elements_group = dataset.create_group(non_standard_elements_group_name)
+        non_standard_elements_group.attrs["description"] = (
+            "This group contains non-standard elements that can be added by the user."
+        )
         for element in additional_elements:
+            group_name = non_standard_elements_group_name
+            if element.group_name != "":
+                group_name += f"/{element.group_name}"
             _add_dataset(
-                group_name=
+                group_name=group_name,
                 name=element.dataset_name,
                 data=element.data,
                 description=element.description,
@@ -539,6 +582,7 @@ def generate_zea_dataset(
     additional_elements=None,
     event_structure=False,
     cast_to_float=True,
+    overwrite=False,
 ):
     """Generates a dataset in the zea format.

@@ -585,10 +629,10 @@ def generate_zea_dataset(
             waveform was used for each transmit event.
         waveforms_one_way (list): List of one-way waveforms as simulated by the Verasonics
             system, sampled at 250MHz. This is the waveform after being filtered by the
-
+            transducer bandwidth once. Every element in the list is a 1D numpy array.
         waveforms_two_way (list): List of two-way waveforms as simulated by the Verasonics
             system, sampled at 250MHz. This is the waveform after being filtered by the
-
+            transducer bandwidth twice. Every element in the list is a 1D numpy array.
         additional_elements (List[DatasetElement]): A list of additional dataset
             elements to be added to the dataset. Each element should be a DatasetElement
             object. The additional elements are added under the scan group.
@@ -598,6 +642,7 @@ def generate_zea_dataset(
             Instead of just a single data and scan group.
         cast_to_float (bool): Whether to store data as float32. You may want to set this
             to False if storing images.
+        overwrite (bool): Whether to overwrite the file if it already exists. Defaults to False.

     """
     # check if all args are lists
@@ -637,10 +682,10 @@ def generate_zea_dataset(
     # make sure input arguments of func is same length as data_and_parameters
     # except `path` and `event_structure` arguments and ofcourse `data_and_parameters` itself
     assert (
-        len(data_and_parameters) == len(inspect.signature(generate_zea_dataset).parameters) -
+        len(data_and_parameters) == len(inspect.signature(generate_zea_dataset).parameters) - 4
     ), (
         "All arguments should be put in data_and_parameters except "
-        "`path`, `event_structure`, and `
+        "`path`, `event_structure`, `cast_to_float`, and `overwrite` arguments."
     )

     if event_structure:
@@ -682,7 +727,7 @@ def generate_zea_dataset(
     # Convert path to Path object
     path = Path(path)

-    if path.exists():
+    if path.exists() and not overwrite:
         raise FileExistsError(f"The file {path} already exists.")

     # Create the directory if it does not exist
@@ -720,3 +765,79 @@ def generate_zea_dataset(

     validate_file(path)
     log.info(f"zea dataset written to {log.yellow(path)}")
+
+
+def load_description(path):
+    """Loads the description of a zea dataset.
+
+    Args:
+        path (str): The path to the zea dataset.
+
+    Returns:
+        str: The description of the dataset, or an empty string if not found.
+    """
+    path = Path(path)
+
+    with File(path, "r") as file:
+        description = file.attrs.get("description", "")
+
+    return description
+
+
+def load_additional_elements(path):
+    """Loads additional dataset elements from a zea dataset.
+
+    Args:
+        path (str): The path to the zea dataset.
+
+    Returns:
+        list: A list of DatasetElement objects.
+    """
+    path = Path(path)
+
+    with File(path, "r") as file:
+        if "non_standard_elements" not in file:
+            return []
+
+        additional_elements = _load_additional_elements_from_group(file, "non_standard_elements")
+
+    return additional_elements
+
+
+def _load_additional_elements_from_group(file, path):
+    """Recursively loads additional dataset elements from a group."""
+    elements = []
+    for name, item in file[path].items():
+        if isinstance(item, h5py.Dataset):
+            elements.append(_load_dataset_element_from_group(file, f"{path}/{name}"))
+        elif isinstance(item, h5py.Group):
+            elements.extend(_load_additional_elements_from_group(file, f"{path}/{name}"))
+    return elements
+
+
+def _load_dataset_element_from_group(file, path):
+    """Loads a specific dataset element from a group.
+
+    Args:
+        file (h5py.File): The HDF5 file object.
+        path (str): The full path to the dataset element.
+            e.g., "non_standard_elements/lens/lens_profile"
+
+    Returns:
+        DatasetElement: The loaded dataset element.
+    """
+    dataset = file[path]
+    description = dataset.attrs.get("description", "")
+    unit = dataset.attrs.get("unit", "")
+    data = dataset[()]
+
+    path_parts = path.split("/")
+
+    return DatasetElement(
+        dataset_name=path_parts[-1],
+        data=data,
+        description=description,
+        unit=unit,
+        group_name="/".join(path_parts[1:-1]),
+    )
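Putting the data_format.py changes together: group_name is now an optional field on DatasetElement, additional elements land in a new non_standard_elements group, generate_zea_dataset gains an overwrite flag, and load_description/load_additional_elements read things back. A rough round-trip sketch using only names visible in this diff; the path, the raw_data array, and the assumption that generate_zea_dataset accepts path/raw_data/description keywords are placeholders, not confirmed by the package docs:

```python
# Hedged sketch (not from the package): round trip of the new additional-elements API.
import numpy as np

from zea.data.data_format import (
    DatasetElement,
    generate_zea_dataset,
    load_additional_elements,
    load_description,
)

elements = [
    # group_name defaults to "" now, so this lands directly under non_standard_elements
    DatasetElement(
        dataset_name="temperature",
        data=np.array(21.5),
        description="Temperature during the measurement",
        unit="unitless",
    ),
    # nested group: stored under non_standard_elements/lens
    DatasetElement(
        dataset_name="lens_profile",
        data=np.random.rand(100),
        description="An example lens profile",
        unit="mm",
        group_name="lens",
    ),
]

generate_zea_dataset(
    path="example.hdf5",  # placeholder path
    raw_data=np.zeros((1, 1, 64, 128, 1), dtype=np.float32),  # placeholder array
    description="Example dataset",
    additional_elements=elements,
    overwrite=True,  # new in 0.0.8: replace an existing file instead of raising
)

print(load_description("example.hdf5"))
for element in load_additional_elements("example.hdf5"):
    print(element.group_name, element.dataset_name, element.unit)
```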
zea/data/dataloader.py
CHANGED

@@ -5,7 +5,7 @@ H5 dataloader for loading images from zea datasets.
 import re
 from itertools import product
 from pathlib import Path
-from typing import List
+from typing import List, Tuple, Union

 import numpy as np

@@ -65,12 +65,12 @@ def generate_h5_indices(
         (
             "/folder/path_to_file.hdf5",
             "data/image",
-
+            (range(0, 1), slice(None, 256, None), slice(None, 256, None)),
         ),
         (
             "/folder/path_to_file.hdf5",
             "data/image",
-
+            (range(1, 2), slice(None, 256, None), slice(None, 256, None)),
         ),
         ...,
     ]
@@ -117,7 +117,7 @@ def generate_h5_indices(
         # Optionally limit frames to load from each file
         n_frames_in_file = min(n_frames_in_file, limit_n_frames)
         indices = [
-            range(i, i + block_size, frame_index_stride)
+            list(range(i, i + block_size, frame_index_stride))
             for i in range(0, n_frames_in_file - block_size + 1, block_step_size)
         ]
         yield [indices]
@@ -132,7 +132,7 @@ def generate_h5_indices(
             continue

         if additional_axes_iter:
-            axis_indices += [range(shape[axis]) for axis in additional_axes_iter]
+            axis_indices += [list(range(shape[axis])) for axis in additional_axes_iter]

         axis_indices = product(*axis_indices)

@@ -140,7 +140,7 @@ def generate_h5_indices(
         full_indices = [slice(size) for size in shape]
         for i, axis in enumerate([initial_frame_axis] + list(additional_axes_iter)):
             full_indices[axis] = axis_index[i]
-        indices.append((file, key, full_indices))
+        indices.append((file, key, tuple(full_indices)))

     if skipped_files > 0:
         log.warning(
@@ -321,7 +321,12 @@ class H5Generator(Dataset):
         initial_delay=INITIAL_RETRY_DELAY,
         retry_action=_h5_reopen_on_io_error,
     )
-    def load(
+    def load(
+        self,
+        file: File,
+        key: str,
+        indices: Tuple[Union[list, slice, int], ...] | List[int] | int | None = None,
+    ):
         """Extract data from hdf5 file.
         Args:
             file_name (str): name of the file to extract image from.
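The widened load signature and the regenerated docstring example both describe per-axis index tuples. A plain-NumPy illustration of how such a tuple selects data; this is not zea code, and the array below is just a stand-in for an HDF5 dataset:

```python
# Index tuples like (range(0, 1), slice(None, 256, None), slice(None, 256, None))
# combine one fancy index for the frame axis with slices for the spatial axes.
# h5py datasets accept the same kind of tuple, which is presumably what
# H5Generator.load forwards when reading a block.
import numpy as np

frames = np.zeros((10, 512, 512))  # dummy (n_frames, height, width) stack
indices = (list(range(0, 2)), slice(None, 256, None), slice(None, 256, None))

block = frames[indices]  # frames 0-1, cropped to the top-left 256x256 region
print(block.shape)  # (2, 256, 256)
```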
zea/data/datasets.py
CHANGED

@@ -31,9 +31,12 @@ Features

 """

+import functools
+import multiprocessing
+import os
 from collections import OrderedDict
 from pathlib import Path
-from typing import List
+from typing import List, Tuple

 import numpy as np
 import tqdm
@@ -48,14 +51,12 @@ from zea.data.preset_utils import (
     _hf_resolve_path,
 )
 from zea.datapaths import format_data_path
+from zea.internal.cache import cache_output
+from zea.internal.core import hash_elements
+from zea.internal.utils import calculate_file_hash, reduce_to_signature
 from zea.io_lib import search_file_tree
 from zea.tools.hf import HFPath
-from zea.utils import
-    calculate_file_hash,
-    date_string_to_readable,
-    get_date_string,
-    reduce_to_signature,
-)
+from zea.utils import date_string_to_readable, get_date_string

 _CHECK_MAX_DATASET_SIZE = 10000
 _VALIDATED_FLAG_FILE = "validated.flag"
@@ -104,16 +105,78 @@ class H5FileHandleCache:

         return self._file_handle_cache[file_path]

+    def close(self):
+        """Close all cached file handles."""
+        cache: OrderedDict = getattr(self, "_file_handle_cache", None)
+        if not cache:
+            return
+
+        # iterate over a static list to avoid mutation during iteration
+        for fh in list(cache.values()):
+            if fh is None:
+                continue
+            try:
+                # attempt to close unconditionally and swallow exceptions
+                fh.close()
+            except Exception:
+                # During interpreter shutdown or if the h5py internals are already
+                # torn down, close() can raise weird errors (e.g. TypeError).
+                # Swallow them here to avoid exceptions from __del__.
+                pass
+
+        cache.clear()  # clear the cache dict
+
     def __del__(self):
-
-
-
-
-
-
+        self.close()
+
+
+@cache_output("filepaths", "key", "_filepath_hash", verbose=True)
+def _find_h5_file_shapes(filepaths, key, _filepath_hash, verbose=True):
+    # NOTE: we cache the output of this function such that file loading over the network is
+    # faster for repeated calls with the same filepaths, key and _filepath_hash
+
+    assert _filepath_hash is not None

+    get_shape = functools.partial(File.get_shape, key=key)

-
+    if os.environ.get("ZEA_FIND_H5_SHAPES_PARALLEL", "1") in ("1", "true", "yes"):
+        # using multiprocessing to speed up reading hdf5 files
+        # make sure to call find_h5_file_shapes from within a function
+        # or use if __name__ == "__main__" to avoid freezing the main process
+
+        with multiprocessing.Pool() as pool:
+            file_shapes = list(
+                tqdm.tqdm(
+                    pool.imap(get_shape, filepaths),
+                    total=len(filepaths),
+                    desc="Getting file shapes in each h5 file",
+                    disable=not verbose,
+                )
+            )
+    else:
+        file_shapes = []
+        for file_path in tqdm.tqdm(
+            filepaths,
+            desc="Getting file shapes in each h5 file",
+            disable=not verbose,
+        ):
+            file_shapes.append(get_shape(file_path))
+
+    return file_shapes
+
+
+def _file_hash(filepaths):
+    # NOTE: this is really fast, even over network filesystemss
+    total_size = 0
+    modified_times = []
+    for fp in filepaths:
+        if os.path.isfile(fp):
+            total_size += os.path.getsize(fp)
+            modified_times.append(os.path.getmtime(fp))
+    return hash_elements([total_size, modified_times])
+
+
+def find_h5_files(paths: str | list, key: str = None) -> Tuple[List[str], List[tuple]]:
     """
     Find HDF5 files from a directory or list of directories and optionally retrieve their shapes.

@@ -121,17 +184,11 @@ def find_h5_files(paths: str | list, key: str = None, search_file_tree_kwargs: d
         paths (str or list): A single directory path, a list of directory paths,
             or a single HDF5 file path.
         key (str, optional): The key to get the file shapes for.
-        search_file_tree_kwargs (dict, optional): Additional keyword arguments for the
-            search_file_tree function. Defaults to None.

     Returns:
-        - file_paths (list): List of file paths to the HDF5 files.
-        - file_shapes (list): List of shapes of the HDF5 datasets.
+        - file_paths (list): List of file paths (str) to the HDF5 files.
+        - file_shapes (list): List of shapes (tuple) of the HDF5 datasets.
     """
-
-    if search_file_tree_kwargs is None:
-        search_file_tree_kwargs = {}
-
     # Make sure paths is a list
     if not isinstance(paths, (tuple, list)):
         paths = [paths]
@@ -152,14 +209,12 @@ def find_h5_files(paths: str | list, key: str = None, search_file_tree_kwargs: d
             file_paths.append(str(path))
             continue

-
-
-
-
-
-
-        file_shapes += dataset_info["file_shapes"]
-        file_paths += [str(Path(path) / file_path) for file_path in dataset_info["file_paths"]]
+        _filepaths = list(search_file_tree(path, filetypes=FILE_TYPES))
+        file_shapes += _find_h5_file_shapes(_filepaths, key, _file_hash(_filepaths))
+        file_paths += _filepaths
+
+    # Convert file paths to strings
+    file_paths = [str(fp) for fp in file_paths]

     return file_paths, file_shapes

@@ -172,8 +227,7 @@ class Folder:
     def __init__(
         self,
         folder_path: list[str] | list[Path],
-        key: str
-        search_file_tree_kwargs: dict | None = None,
+        key: str,
         validate: bool = True,
         hf_cache_dir: str = HF_DATASETS_DIR,
         **kwargs,
@@ -195,11 +249,8 @@ class Folder:

         self.folder_path = Path(folder_path)
         self.key = key
-        self.search_file_tree_kwargs = search_file_tree_kwargs
         self.validate = validate
-        self.file_paths, self.file_shapes = find_h5_files(
-            folder_path, self.key, self.search_file_tree_kwargs
-        )
+        self.file_paths, self.file_shapes = find_h5_files(folder_path, self.key)
         assert self.n_files > 0, f"No files in folder: {folder_path}"
         if self.validate:
             self.validate_folder()
@@ -241,7 +292,7 @@ class Folder:
             return

         num_frames_per_file = []
-
+        validated_successfully = True
         for file_path in tqdm.tqdm(
             self.file_paths,
             total=self.n_files,
@@ -253,9 +304,9 @@ class Folder:
                 validation_error_log.append(f"File {file_path} is not a valid zea dataset.\n{e}\n")
                 # convert into warning
                 log.warning(f"Error in file {file_path}.\n{e}")
-
+                validated_successfully = False

-        if not
+        if not validated_successfully:
             log.warning(
                 "Check warnings above for details. No validation file was created. "
                 f"See {validation_error_file_path} for details."
@@ -319,24 +370,27 @@ class Folder:
         data_types = self.get_data_types(self.file_paths[0])

         number_of_frames = sum(num_frames_per_file)
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        try:
+            with open(validation_file_path, "w", encoding="utf-8") as f:
+                f.write(f"Dataset: {path}\n")
+                f.write(f"Validated on: {get_date_string()}\n")
+                f.write(f"Number of files: {self.n_files}\n")
+                f.write(f"Number of frames: {number_of_frames}\n")
+                f.write(f"Data types: {', '.join(data_types)}\n")
+                f.write(f"{'-' * 80}\n")
+                # write all file names (not entire path) with number of frames on a new line
+                for file_path, num_frames in zip(self.file_paths, num_frames_per_file):
+                    f.write(f"{file_path.name}: {num_frames}\n")
+                f.write(f"{'-' * 80}\n")
+
+            # Write the hash of the validation file
+            validation_file_hash = calculate_file_hash(validation_file_path)
+            with open(validation_file_path, "a", encoding="utf-8") as f:
+                # *** validation file hash *** (80 total line length)
+                f.write("*** validation file hash ***\n")
+                f.write(f"hash: {validation_file_hash}")
+        except Exception as e:
+            log.warning(f"Unable to write validation flag: {e}")

     def __repr__(self):
         return (
@@ -413,7 +467,6 @@ class Dataset(H5FileHandleCache):
         self,
         file_paths: List[str] | str,
         key: str,
-        search_file_tree_kwargs: dict | None = None,
         validate: bool = True,
         directory_splits: list | None = None,
         **kwargs,
@@ -424,9 +477,6 @@ class Dataset(H5FileHandleCache):
             file_paths (str or list): (list of) path(s) to the folder(s) containing the HDF5 file(s)
                 or list of HDF5 file paths. Can be a mixed list of folders and files.
             key (str): The key to access the HDF5 dataset.
-            search_file_tree_kwargs (dict, optional): Additional keyword arguments for the
-                search_file_tree function. These are only used when `file_paths` are directories.
-                Defaults to None.
             validate (bool, optional): Whether to validate the dataset. Defaults to True.
             directory_splits (list, optional): List of directory split by. Is a list of floats
                 between 0 and 1, with the same length as the number of file_paths given.
@@ -435,7 +485,6 @@ class Dataset(H5FileHandleCache):
         """
         super().__init__(**kwargs)
         self.key = key
-        self.search_file_tree_kwargs = search_file_tree_kwargs
         self.validate = validate

         self.file_paths, self.file_shapes = self.find_files_and_shapes(file_paths)
@@ -475,7 +524,7 @@ class Dataset(H5FileHandleCache):
             file_path = Path(file_path)

             if file_path.is_dir():
-                folder = Folder(file_path, self.key, self.
+                folder = Folder(file_path, self.key, self.validate)
                 file_paths += folder.file_paths
                 file_shapes += folder.file_shapes
                 del folder
@@ -539,14 +588,6 @@ class Dataset(H5FileHandleCache):
     def __str__(self):
         return f"Dataset with {self.n_files} files (key='{self.key}')"

-    def close(self):
-        """Close all cached file handles."""
-        for file in self._file_handle_cache.values():
-            if file is not None and file.id.valid:
-                file.close()
-        self._file_handle_cache.clear()
-        log.info("Closed all cached file handles.")
-
     def __enter__(self):
         return self

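The datasets.py changes drop search_file_tree_kwargs everywhere, move the per-file shape lookup into a cached helper keyed on a cheap size/mtime hash, and parallelise it with a multiprocessing pool that can be switched off through an environment variable. A minimal usage sketch under those assumptions; the dataset directory and key below are placeholders:

```python
# Hedged sketch of the reworked shape discovery, using only names visible in this
# diff (find_h5_files and the ZEA_FIND_H5_SHAPES_PARALLEL switch).
import os

# Shape lookup uses a multiprocessing.Pool by default; set this to "0" to force the
# sequential fallback (useful inside worker processes or when debugging).
os.environ["ZEA_FIND_H5_SHAPES_PARALLEL"] = "0"

from zea.data.datasets import find_h5_files

# search_file_tree_kwargs is gone in 0.0.8; the call is now just (paths, key).
file_paths, file_shapes = find_h5_files("/data/my_zea_dataset", key="data/image")
for file_path, shape in zip(file_paths, file_shapes):
    print(file_path, shape)
```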
|