zea 0.0.6__py3-none-any.whl → 0.0.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (79)
  1. zea/__init__.py +54 -19
  2. zea/agent/__init__.py +12 -12
  3. zea/agent/masks.py +2 -1
  4. zea/backend/tensorflow/dataloader.py +2 -5
  5. zea/beamform/beamformer.py +100 -50
  6. zea/beamform/lens_correction.py +9 -2
  7. zea/beamform/pfield.py +9 -2
  8. zea/beamform/pixelgrid.py +1 -1
  9. zea/config.py +34 -25
  10. zea/data/__init__.py +22 -25
  11. zea/data/augmentations.py +221 -28
  12. zea/data/convert/__init__.py +1 -6
  13. zea/data/convert/__main__.py +123 -0
  14. zea/data/convert/camus.py +101 -40
  15. zea/data/convert/echonet.py +187 -86
  16. zea/data/convert/echonetlvh/README.md +2 -3
  17. zea/data/convert/echonetlvh/{convert_raw_to_usbmd.py → __init__.py} +174 -103
  18. zea/data/convert/echonetlvh/manual_rejections.txt +73 -0
  19. zea/data/convert/echonetlvh/precompute_crop.py +43 -64
  20. zea/data/convert/picmus.py +37 -40
  21. zea/data/convert/utils.py +86 -0
  22. zea/data/convert/{matlab.py → verasonics.py} +44 -65
  23. zea/data/data_format.py +155 -34
  24. zea/data/dataloader.py +12 -7
  25. zea/data/datasets.py +112 -71
  26. zea/data/file.py +184 -73
  27. zea/data/file_operations.py +496 -0
  28. zea/data/layers.py +3 -3
  29. zea/data/preset_utils.py +1 -1
  30. zea/datapaths.py +16 -4
  31. zea/display.py +14 -13
  32. zea/interface.py +14 -16
  33. zea/internal/_generate_keras_ops.py +6 -7
  34. zea/internal/cache.py +2 -49
  35. zea/internal/checks.py +6 -12
  36. zea/internal/config/validation.py +1 -2
  37. zea/internal/core.py +69 -6
  38. zea/internal/device.py +6 -2
  39. zea/internal/dummy_scan.py +330 -0
  40. zea/internal/operators.py +118 -2
  41. zea/internal/parameters.py +101 -70
  42. zea/internal/setup_zea.py +5 -6
  43. zea/internal/utils.py +282 -0
  44. zea/io_lib.py +322 -146
  45. zea/keras_ops.py +74 -4
  46. zea/log.py +9 -7
  47. zea/metrics.py +15 -7
  48. zea/models/__init__.py +31 -21
  49. zea/models/base.py +30 -14
  50. zea/models/carotid_segmenter.py +19 -4
  51. zea/models/diffusion.py +235 -23
  52. zea/models/echonet.py +22 -8
  53. zea/models/echonetlvh.py +31 -7
  54. zea/models/lpips.py +19 -2
  55. zea/models/lv_segmentation.py +30 -11
  56. zea/models/preset_utils.py +5 -5
  57. zea/models/regional_quality.py +30 -10
  58. zea/models/taesd.py +21 -5
  59. zea/models/unet.py +15 -1
  60. zea/ops.py +770 -336
  61. zea/probes.py +6 -6
  62. zea/scan.py +121 -51
  63. zea/simulator.py +24 -21
  64. zea/tensor_ops.py +477 -353
  65. zea/tools/fit_scan_cone.py +90 -160
  66. zea/tools/hf.py +1 -1
  67. zea/tools/selection_tool.py +47 -86
  68. zea/tracking/__init__.py +16 -0
  69. zea/tracking/base.py +94 -0
  70. zea/tracking/lucas_kanade.py +474 -0
  71. zea/tracking/segmentation.py +110 -0
  72. zea/utils.py +101 -480
  73. zea/visualize.py +177 -39
  74. {zea-0.0.6.dist-info → zea-0.0.8.dist-info}/METADATA +6 -2
  75. zea-0.0.8.dist-info/RECORD +122 -0
  76. zea-0.0.6.dist-info/RECORD +0 -112
  77. {zea-0.0.6.dist-info → zea-0.0.8.dist-info}/WHEEL +0 -0
  78. {zea-0.0.6.dist-info → zea-0.0.8.dist-info}/entry_points.txt +0 -0
  79. {zea-0.0.6.dist-info → zea-0.0.8.dist-info}/licenses/LICENSE +0 -0
zea/data/data_format.py CHANGED
@@ -6,12 +6,14 @@ import inspect
 from dataclasses import dataclass
 from pathlib import Path
 
+import h5py
 import numpy as np
+from keras.utils import pad_sequences
 
 from zea import log
 from zea.data.file import File, validate_file
 from zea.internal.checks import _DATA_TYPES
-from zea.utils import first_not_none_item
+from zea.internal.utils import first_not_none_item
 
 
 @dataclass
@@ -19,15 +21,15 @@ class DatasetElement:
     """Class to store a dataset element with a name, data, description and unit. Used to
     supply additional dataset elements to the generate_zea_dataset function."""
 
-    # The group name to store the dataset under. This can be a nested group, e.g.
-    # "scan/waveforms"
-    group_name: str
     # The name of the dataset. This will be the key in the group.
     dataset_name: str
     # The data to store in the dataset.
    data: np.ndarray
    description: str
    unit: str
+    # The group name to store the dataset under. This can be a nested group, e.g.
+    # "lens/profiles"
+    group_name: str = ""
 
 
 def generate_example_dataset(
@@ -110,9 +112,43 @@ def generate_example_dataset(
         focus_distances=focus_distances,
         polar_angles=polar_angles,
         azimuth_angles=azimuth_angles,
+        additional_elements=_generate_example_dataset_elements(),
+        description="This is an example dataset generated by zea",
     )
 
 
+def _generate_example_dataset_elements() -> list[DatasetElement]:
+    """Generates a list of example DatasetElement objects to be used as additional
+    elements in the generate_zea_dataset function.
+
+    Returns:
+        list: A list of DatasetElement objects.
+    """
+    example_elements = [
+        DatasetElement(
+            dataset_name="temperature",
+            data=np.array(42),
+            description="The temperature during the measurement",
+            unit="unitless",
+        ),
+        DatasetElement(
+            dataset_name="lens_profile",
+            data=np.random.rand(100),
+            description="An example lens profile",
+            unit="mm",
+            group_name="lens",
+        ),
+        DatasetElement(
+            dataset_name="lens_material",
+            data=np.array(["material1", "material2", "material3"], dtype=h5py.string_dtype()),
+            description="An example lens material list",
+            unit="unitless",
+            group_name="lens",
+        ),
+    ]
+    return example_elements
+
+
 def validate_input_data(raw_data, aligned_data, envelope_data, beamformed_data, image, image_sc):
     """
     Validates input data for generate_zea_dataset
@@ -470,38 +506,45 @@ def _write_datasets(
     )
 
     if waveforms_one_way is not None:
-        for n in range(len(waveforms_one_way)):
-            _add_dataset(
-                group_name=scan_group_name + "/waveforms_one_way",
-                name=f"waveform_{str(n).zfill(3)}",
-                data=waveforms_one_way[n],
-                description=(
-                    "One-way waveform as simulated by the Verasonics system, "
-                    "sampled at 250MHz. This is the waveform after being filtered "
-                    "by the tranducer bandwidth once."
-                ),
-                unit="V",
-            )
+        _add_dataset(
+            group_name=scan_group_name,
+            name="waveforms_one_way",
+            data=pad_sequences(waveforms_one_way, dtype=np.float32, padding="post"),
+            description=(
+                "One-way waveform as simulated by the Verasonics system, "
+                "sampled at 250MHz. This is the waveform after being filtered "
+                "by the transducer bandwidth once."
+            ),
+            unit="V",
+        )
 
     if waveforms_two_way is not None:
-        for n in range(len(waveforms_two_way)):
-            _add_dataset(
-                group_name=scan_group_name + "/waveforms_two_way",
-                name=f"waveform_{str(n).zfill(3)}",
-                data=waveforms_two_way[n],
-                description=(
-                    "Two-way waveform as simulated by the Verasonics system, "
-                    "sampled at 250MHz. This is the waveform after being filtered "
-                    "by the tranducer bandwidth twice."
-                ),
-                unit="V",
-            )
+        _add_dataset(
+            group_name=scan_group_name,
+            name="waveforms_two_way",
+            data=pad_sequences(waveforms_two_way, dtype=np.float32, padding="post"),
+            description=(
+                "Two-way waveform as simulated by the Verasonics system, "
+                "sampled at 250MHz. This is the waveform after being filtered "
+                "by the transducer bandwidth twice."
+            ),
+            unit="V",
+        )
 
     # Add additional elements
     if additional_elements is not None:
+        # Write scan group
+        non_standard_elements_group_name = "non_standard_elements"
+        non_standard_elements_group = dataset.create_group(non_standard_elements_group_name)
+        non_standard_elements_group.attrs["description"] = (
+            "This group contains non-standard elements that can be added by the user."
+        )
         for element in additional_elements:
+            group_name = non_standard_elements_group_name
+            if element.group_name != "":
+                group_name += f"/{element.group_name}"
             _add_dataset(
-                group_name=element.group_name,
+                group_name=group_name,
                 name=element.dataset_name,
                 data=element.data,
                 description=element.description,
@@ -539,6 +582,7 @@ def generate_zea_dataset(
     additional_elements=None,
     event_structure=False,
     cast_to_float=True,
+    overwrite=False,
 ):
     """Generates a dataset in the zea format.
 
@@ -585,10 +629,10 @@ def generate_zea_dataset(
             waveform was used for each transmit event.
         waveforms_one_way (list): List of one-way waveforms as simulated by the Verasonics
             system, sampled at 250MHz. This is the waveform after being filtered by the
-            tranducer bandwidth once. Every element in the list is a 1D numpy array.
+            transducer bandwidth once. Every element in the list is a 1D numpy array.
         waveforms_two_way (list): List of two-way waveforms as simulated by the Verasonics
             system, sampled at 250MHz. This is the waveform after being filtered by the
-            tranducer bandwidth twice. Every element in the list is a 1D numpy array.
+            transducer bandwidth twice. Every element in the list is a 1D numpy array.
         additional_elements (List[DatasetElement]): A list of additional dataset
             elements to be added to the dataset. Each element should be a DatasetElement
             object. The additional elements are added under the scan group.
@@ -598,6 +642,7 @@ def generate_zea_dataset(
             Instead of just a single data and scan group.
         cast_to_float (bool): Whether to store data as float32. You may want to set this
             to False if storing images.
+        overwrite (bool): Whether to overwrite the file if it already exists. Defaults to False.
 
     """
     # check if all args are lists
@@ -637,10 +682,10 @@ def generate_zea_dataset(
     # make sure input arguments of func is same length as data_and_parameters
     # except `path` and `event_structure` arguments and ofcourse `data_and_parameters` itself
     assert (
-        len(data_and_parameters) == len(inspect.signature(generate_zea_dataset).parameters) - 3
+        len(data_and_parameters) == len(inspect.signature(generate_zea_dataset).parameters) - 4
     ), (
         "All arguments should be put in data_and_parameters except "
-        "`path`, `event_structure`, and `cast_to_float` arguments."
+        "`path`, `event_structure`, `cast_to_float`, and `overwrite` arguments."
     )
 
     if event_structure:
@@ -682,7 +727,7 @@ def generate_zea_dataset(
     # Convert path to Path object
     path = Path(path)
 
-    if path.exists():
+    if path.exists() and not overwrite:
         raise FileExistsError(f"The file {path} already exists.")
 
     # Create the directory if it does not exist
@@ -720,3 +765,79 @@ def generate_zea_dataset(
 
     validate_file(path)
     log.info(f"zea dataset written to {log.yellow(path)}")
+
+
+def load_description(path):
+    """Loads the description of a zea dataset.
+
+    Args:
+        path (str): The path to the zea dataset.
+
+    Returns:
+        str: The description of the dataset, or an empty string if not found.
+    """
+    path = Path(path)
+
+    with File(path, "r") as file:
+        description = file.attrs.get("description", "")
+
+    return description
+
+
+def load_additional_elements(path):
+    """Loads additional dataset elements from a zea dataset.
+
+    Args:
+        path (str): The path to the zea dataset.
+
+    Returns:
+        list: A list of DatasetElement objects.
+    """
+    path = Path(path)
+
+    with File(path, "r") as file:
+        if "non_standard_elements" not in file:
+            return []
+
+        additional_elements = _load_additional_elements_from_group(file, "non_standard_elements")
+
+    return additional_elements
+
+
+def _load_additional_elements_from_group(file, path):
+    """Recursively loads additional dataset elements from a group."""
+    elements = []
+    for name, item in file[path].items():
+        if isinstance(item, h5py.Dataset):
+            elements.append(_load_dataset_element_from_group(file, f"{path}/{name}"))
+        elif isinstance(item, h5py.Group):
+            elements.extend(_load_additional_elements_from_group(file, f"{path}/{name}"))
+    return elements
+
+
+def _load_dataset_element_from_group(file, path):
+    """Loads a specific dataset element from a group.
+
+    Args:
+        file (h5py.File): The HDF5 file object.
+        path (str): The full path to the dataset element.
+            e.g., "non_standard_elements/lens/lens_profile"
+
+    Returns:
+        DatasetElement: The loaded dataset element.
+    """
+
+    dataset = file[path]
+    description = dataset.attrs.get("description", "")
+    unit = dataset.attrs.get("unit", "")
+    data = dataset[()]
+
+    path_parts = path.split("/")
+
+    return DatasetElement(
+        dataset_name=path_parts[-1],
+        data=data,
+        description=description,
+        unit=unit,
+        group_name="/".join(path_parts[1:-1]),
+    )
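
Note: to see how the reworked additional-elements API fits together, the sketch below writes custom elements under the new "non_standard_elements" group and reads them back. It is a minimal illustration assuming the import path zea.data.data_format and a positional path plus an image keyword on generate_zea_dataset, as suggested by the diff above; the output path and data values are made up.

    import numpy as np
    from zea.data.data_format import (
        DatasetElement,
        generate_zea_dataset,
        load_additional_elements,
    )

    # group_name now defaults to "", so the first element lands directly under
    # "non_standard_elements"; a non-empty group_name nests it, e.g.
    # "non_standard_elements/lens/lens_profile".
    extra = [
        DatasetElement(
            dataset_name="temperature",
            data=np.array(21.5),
            description="Room temperature during the measurement",
            unit="celsius",
        ),
        DatasetElement(
            dataset_name="lens_profile",
            data=np.random.rand(100),
            description="Example lens profile",
            unit="mm",
            group_name="lens",
        ),
    ]

    generate_zea_dataset(
        "out/example.hdf5",  # hypothetical output path
        image=np.random.rand(4, 256, 256),  # hypothetical image stack
        additional_elements=extra,
        overwrite=True,  # new in 0.0.8: re-running no longer raises FileExistsError
    )

    for element in load_additional_elements("out/example.hdf5"):
        print(element.group_name, element.dataset_name, element.unit)
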
zea/data/dataloader.py CHANGED
@@ -5,7 +5,7 @@ H5 dataloader for loading images from zea datasets.
 import re
 from itertools import product
 from pathlib import Path
-from typing import List
+from typing import List, Tuple, Union
 
 import numpy as np
 
@@ -65,12 +65,12 @@ def generate_h5_indices(
         (
             "/folder/path_to_file.hdf5",
             "data/image",
-            [range(0, 1), slice(None, 256, None), slice(None, 256, None)],
+            (range(0, 1), slice(None, 256, None), slice(None, 256, None)),
         ),
         (
             "/folder/path_to_file.hdf5",
             "data/image",
-            [range(1, 2), slice(None, 256, None), slice(None, 256, None)],
+            (range(1, 2), slice(None, 256, None), slice(None, 256, None)),
         ),
         ...,
     ]
@@ -117,7 +117,7 @@ def generate_h5_indices(
         # Optionally limit frames to load from each file
         n_frames_in_file = min(n_frames_in_file, limit_n_frames)
         indices = [
-            range(i, i + block_size, frame_index_stride)
+            list(range(i, i + block_size, frame_index_stride))
             for i in range(0, n_frames_in_file - block_size + 1, block_step_size)
         ]
         yield [indices]
@@ -132,7 +132,7 @@ def generate_h5_indices(
             continue
 
         if additional_axes_iter:
-            axis_indices += [range(shape[axis]) for axis in additional_axes_iter]
+            axis_indices += [list(range(shape[axis])) for axis in additional_axes_iter]
 
         axis_indices = product(*axis_indices)
 
@@ -140,7 +140,7 @@ def generate_h5_indices(
             full_indices = [slice(size) for size in shape]
             for i, axis in enumerate([initial_frame_axis] + list(additional_axes_iter)):
                 full_indices[axis] = axis_index[i]
-            indices.append((file, key, full_indices))
+            indices.append((file, key, tuple(full_indices)))
 
     if skipped_files > 0:
         log.warning(
@@ -321,7 +321,12 @@ class H5Generator(Dataset):
         initial_delay=INITIAL_RETRY_DELAY,
         retry_action=_h5_reopen_on_io_error,
     )
-    def load(self, file: File, key: str, indices: tuple | str):
+    def load(
+        self,
+        file: File,
+        key: str,
+        indices: Tuple[Union[list, slice, int], ...] | List[int] | int | None = None,
+    ):
         """Extract data from hdf5 file.
         Args:
             file_name (str): name of the file to extract image from.
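
Note: the switch from list-based to tuple-based indices matters because NumPy and h5py treat a tuple as one selection per axis, whereas a top-level list is interpreted as fancy indexing. A small NumPy stand-in (array contents and shapes here are invented) shows the index layout generate_h5_indices now yields:

    import numpy as np

    frames = np.arange(5 * 300 * 300).reshape(5, 300, 300)  # stand-in for a "data/image" dataset

    # One yielded entry: a concrete list of frame indices plus a slice per remaining axis.
    index = (list(range(0, 2)), slice(None, 256, None), slice(None, 256, None))

    block = frames[index]  # tuple -> per-axis selection
    print(block.shape)     # (2, 256, 256)
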
zea/data/datasets.py CHANGED
@@ -31,9 +31,12 @@ Features
 
 """
 
+import functools
+import multiprocessing
+import os
 from collections import OrderedDict
 from pathlib import Path
-from typing import List
+from typing import List, Tuple
 
 import numpy as np
 import tqdm
@@ -48,14 +51,12 @@ from zea.data.preset_utils import (
     _hf_resolve_path,
 )
 from zea.datapaths import format_data_path
+from zea.internal.cache import cache_output
+from zea.internal.core import hash_elements
+from zea.internal.utils import calculate_file_hash, reduce_to_signature
 from zea.io_lib import search_file_tree
 from zea.tools.hf import HFPath
-from zea.utils import (
-    calculate_file_hash,
-    date_string_to_readable,
-    get_date_string,
-    reduce_to_signature,
-)
+from zea.utils import date_string_to_readable, get_date_string
 
 _CHECK_MAX_DATASET_SIZE = 10000
 _VALIDATED_FLAG_FILE = "validated.flag"
@@ -104,16 +105,78 @@ class H5FileHandleCache:
 
         return self._file_handle_cache[file_path]
 
+    def close(self):
+        """Close all cached file handles."""
+        cache: OrderedDict = getattr(self, "_file_handle_cache", None)
+        if not cache:
+            return
+
+        # iterate over a static list to avoid mutation during iteration
+        for fh in list(cache.values()):
+            if fh is None:
+                continue
+            try:
+                # attempt to close unconditionally and swallow exceptions
+                fh.close()
+            except Exception:
+                # During interpreter shutdown or if the h5py internals are already
+                # torn down, close() can raise weird errors (e.g. TypeError).
+                # Swallow them here to avoid exceptions from __del__.
+                pass
+
+        cache.clear()  # clear the cache dict
+
     def __del__(self):
-        """Ensure cached files are closed."""
-        if hasattr(self, "_file_handle_cache"):
-            for _, file in self._file_handle_cache.items():
-                if file is not None and self._check_if_open(file):
-                    file.close()
-            self._file_handle_cache = OrderedDict()
+        self.close()
+
+
+@cache_output("filepaths", "key", "_filepath_hash", verbose=True)
+def _find_h5_file_shapes(filepaths, key, _filepath_hash, verbose=True):
+    # NOTE: we cache the output of this function such that file loading over the network is
+    # faster for repeated calls with the same filepaths, key and _filepath_hash
+
+    assert _filepath_hash is not None
 
+    get_shape = functools.partial(File.get_shape, key=key)
 
-def find_h5_files(paths: str | list, key: str = None, search_file_tree_kwargs: dict | None = None):
+    if os.environ.get("ZEA_FIND_H5_SHAPES_PARALLEL", "1") in ("1", "true", "yes"):
+        # using multiprocessing to speed up reading hdf5 files
+        # make sure to call find_h5_file_shapes from within a function
+        # or use if __name__ == "__main__" to avoid freezing the main process
+
+        with multiprocessing.Pool() as pool:
+            file_shapes = list(
+                tqdm.tqdm(
+                    pool.imap(get_shape, filepaths),
+                    total=len(filepaths),
+                    desc="Getting file shapes in each h5 file",
+                    disable=not verbose,
+                )
+            )
+    else:
+        file_shapes = []
+        for file_path in tqdm.tqdm(
+            filepaths,
+            desc="Getting file shapes in each h5 file",
+            disable=not verbose,
+        ):
+            file_shapes.append(get_shape(file_path))
+
+    return file_shapes
+
+
+def _file_hash(filepaths):
+    # NOTE: this is really fast, even over network filesystemss
+    total_size = 0
+    modified_times = []
+    for fp in filepaths:
+        if os.path.isfile(fp):
+            total_size += os.path.getsize(fp)
+            modified_times.append(os.path.getmtime(fp))
+    return hash_elements([total_size, modified_times])
+
+
+def find_h5_files(paths: str | list, key: str = None) -> Tuple[List[str], List[tuple]]:
     """
     Find HDF5 files from a directory or list of directories and optionally retrieve their shapes.
 
@@ -121,17 +184,11 @@ def find_h5_files(paths: str | list, key: str = None, search_file_tree_kwargs: d
         paths (str or list): A single directory path, a list of directory paths,
             or a single HDF5 file path.
         key (str, optional): The key to get the file shapes for.
-        search_file_tree_kwargs (dict, optional): Additional keyword arguments for the
-            search_file_tree function. Defaults to None.
 
     Returns:
-        - file_paths (list): List of file paths to the HDF5 files.
-        - file_shapes (list): List of shapes of the HDF5 datasets.
+        - file_paths (list): List of file paths (str) to the HDF5 files.
+        - file_shapes (list): List of shapes (tuple) of the HDF5 datasets.
     """
-
-    if search_file_tree_kwargs is None:
-        search_file_tree_kwargs = {}
-
     # Make sure paths is a list
     if not isinstance(paths, (tuple, list)):
         paths = [paths]
@@ -152,14 +209,12 @@ def find_h5_files(paths: str | list, key: str = None, search_file_tree_kwargs: d
             file_paths.append(str(path))
             continue
 
-        dataset_info = search_file_tree(
-            path,
-            filetypes=FILE_TYPES,
-            hdf5_key_for_length=key,
-            **search_file_tree_kwargs,
-        )
-        file_shapes += dataset_info["file_shapes"]
-        file_paths += [str(Path(path) / file_path) for file_path in dataset_info["file_paths"]]
+        _filepaths = list(search_file_tree(path, filetypes=FILE_TYPES))
+        file_shapes += _find_h5_file_shapes(_filepaths, key, _file_hash(_filepaths))
+        file_paths += _filepaths
+
+    # Convert file paths to strings
+    file_paths = [str(fp) for fp in file_paths]
 
     return file_paths, file_shapes
 
@@ -172,8 +227,7 @@ class Folder:
     def __init__(
         self,
         folder_path: list[str] | list[Path],
-        key: str = None,
-        search_file_tree_kwargs: dict | None = None,
+        key: str,
         validate: bool = True,
         hf_cache_dir: str = HF_DATASETS_DIR,
         **kwargs,
@@ -195,11 +249,8 @@ class Folder:
 
         self.folder_path = Path(folder_path)
         self.key = key
-        self.search_file_tree_kwargs = search_file_tree_kwargs
         self.validate = validate
-        self.file_paths, self.file_shapes = find_h5_files(
-            folder_path, self.key, self.search_file_tree_kwargs
-        )
+        self.file_paths, self.file_shapes = find_h5_files(folder_path, self.key)
         assert self.n_files > 0, f"No files in folder: {folder_path}"
         if self.validate:
             self.validate_folder()
@@ -241,7 +292,7 @@ class Folder:
             return
 
         num_frames_per_file = []
-        validated_succesfully = True
+        validated_successfully = True
        for file_path in tqdm.tqdm(
            self.file_paths,
            total=self.n_files,
@@ -253,9 +304,9 @@ class Folder:
                 validation_error_log.append(f"File {file_path} is not a valid zea dataset.\n{e}\n")
                 # convert into warning
                 log.warning(f"Error in file {file_path}.\n{e}")
-                validated_succesfully = False
+                validated_successfully = False
 
-        if not validated_succesfully:
+        if not validated_successfully:
             log.warning(
                 "Check warnings above for details. No validation file was created. "
                 f"See {validation_error_file_path} for details."
@@ -319,24 +370,27 @@ class Folder:
         data_types = self.get_data_types(self.file_paths[0])
 
         number_of_frames = sum(num_frames_per_file)
-        with open(validation_file_path, "w", encoding="utf-8") as f:
-            f.write(f"Dataset: {path}\n")
-            f.write(f"Validated on: {get_date_string()}\n")
-            f.write(f"Number of files: {self.n_files}\n")
-            f.write(f"Number of frames: {number_of_frames}\n")
-            f.write(f"Data types: {', '.join(data_types)}\n")
-            f.write(f"{'-' * 80}\n")
-            # write all file names (not entire path) with number of frames on a new line
-            for file_path, num_frames in zip(self.file_paths, num_frames_per_file):
-                f.write(f"{file_path.name}: {num_frames}\n")
-            f.write(f"{'-' * 80}\n")
-
-        # Write the hash of the validation file
-        validation_file_hash = calculate_file_hash(validation_file_path)
-        with open(validation_file_path, "a", encoding="utf-8") as f:
-            # *** validation file hash *** (80 total line length)
-            f.write("*** validation file hash ***\n")
-            f.write(f"hash: {validation_file_hash}")
+        try:
+            with open(validation_file_path, "w", encoding="utf-8") as f:
+                f.write(f"Dataset: {path}\n")
+                f.write(f"Validated on: {get_date_string()}\n")
+                f.write(f"Number of files: {self.n_files}\n")
+                f.write(f"Number of frames: {number_of_frames}\n")
+                f.write(f"Data types: {', '.join(data_types)}\n")
+                f.write(f"{'-' * 80}\n")
+                # write all file names (not entire path) with number of frames on a new line
+                for file_path, num_frames in zip(self.file_paths, num_frames_per_file):
+                    f.write(f"{file_path.name}: {num_frames}\n")
+                f.write(f"{'-' * 80}\n")
+
+            # Write the hash of the validation file
+            validation_file_hash = calculate_file_hash(validation_file_path)
+            with open(validation_file_path, "a", encoding="utf-8") as f:
+                # *** validation file hash *** (80 total line length)
+                f.write("*** validation file hash ***\n")
+                f.write(f"hash: {validation_file_hash}")
+        except Exception as e:
+            log.warning(f"Unable to write validation flag: {e}")
 
     def __repr__(self):
         return (
@@ -413,7 +467,6 @@ class Dataset(H5FileHandleCache):
         self,
         file_paths: List[str] | str,
         key: str,
-        search_file_tree_kwargs: dict | None = None,
         validate: bool = True,
         directory_splits: list | None = None,
         **kwargs,
@@ -424,9 +477,6 @@ class Dataset(H5FileHandleCache):
            file_paths (str or list): (list of) path(s) to the folder(s) containing the HDF5 file(s)
                or list of HDF5 file paths. Can be a mixed list of folders and files.
            key (str): The key to access the HDF5 dataset.
-            search_file_tree_kwargs (dict, optional): Additional keyword arguments for the
-                search_file_tree function. These are only used when `file_paths` are directories.
-                Defaults to None.
            validate (bool, optional): Whether to validate the dataset. Defaults to True.
            directory_splits (list, optional): List of directory split by. Is a list of floats
                between 0 and 1, with the same length as the number of file_paths given.
@@ -435,7 +485,6 @@ class Dataset(H5FileHandleCache):
         """
         super().__init__(**kwargs)
         self.key = key
-        self.search_file_tree_kwargs = search_file_tree_kwargs
         self.validate = validate
 
         self.file_paths, self.file_shapes = self.find_files_and_shapes(file_paths)
@@ -475,7 +524,7 @@ class Dataset(H5FileHandleCache):
             file_path = Path(file_path)
 
             if file_path.is_dir():
-                folder = Folder(file_path, self.key, self.search_file_tree_kwargs, self.validate)
+                folder = Folder(file_path, self.key, self.validate)
                 file_paths += folder.file_paths
                 file_shapes += folder.file_shapes
                 del folder
@@ -539,14 +588,6 @@ class Dataset(H5FileHandleCache):
     def __str__(self):
         return f"Dataset with {self.n_files} files (key='{self.key}')"
 
-    def close(self):
-        """Close all cached file handles."""
-        for file in self._file_handle_cache.values():
-            if file is not None and file.id.valid:
-                file.close()
-        self._file_handle_cache.clear()
-        log.info("Closed all cached file handles.")
-
     def __enter__(self):
         return self
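
Note: the new shape-scanning path keys its cache on a cheap fingerprint of the file list rather than on file contents. As a rough self-contained approximation of what _file_hash does (using hashlib here instead of zea's hash_elements, whose implementation is not shown in this diff), the idea is:

    import hashlib
    import os

    def cheap_fingerprint(filepaths):
        """Combine total size and modification times instead of hashing file contents."""
        total_size = 0
        mtimes = []
        for fp in filepaths:
            if os.path.isfile(fp):
                total_size += os.path.getsize(fp)
                mtimes.append(os.path.getmtime(fp))
        return hashlib.sha256(repr((total_size, sorted(mtimes))).encode()).hexdigest()

Because the fingerprint changes whenever a file is added, removed, or rewritten, repeated find_h5_files calls over an unchanged (possibly network-mounted) dataset can reuse the cached shapes; when shapes do need to be read, setting ZEA_FIND_H5_SHAPES_PARALLEL=0 selects the sequential scan instead of the multiprocessing pool.
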