zea 0.0.1__py3-none-any.whl → 0.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
zea/__init__.py CHANGED
@@ -7,7 +7,7 @@ from . import log
7
7
 
8
8
  # dynamically add __version__ attribute (see pyproject.toml)
9
9
  # __version__ = __import__("importlib.metadata").metadata.version(__package__)
10
- __version__ = "0.0.1"
10
+ __version__ = "0.0.2"
11
11
 
12
12
 
13
13
  def setup():
@@ -7,6 +7,45 @@ from zea.beamform.lens_correction import calculate_lens_corrected_delays
7
7
  from zea.tensor_ops import safe_vectorize
8
8
 
9
9
 
10
def fnum_window_fn_rect(normalized_angle):
    """Rectangular window function for f-number masking.

    Args:
        normalized_angle (ops.Tensor): Normalized angle values; 0.0 is straight
            in front of the element, 1.0 the edge of the f-number cone.

    Returns:
        ops.Tensor: 1.0 inside the cone (angle <= 1.0), 0.0 outside.
    """
    inside_cone = normalized_angle <= 1.0
    return ops.where(inside_cone, 1.0, 0.0)
13
+
14
+
15
def fnum_window_fn_hann(normalized_angle):
    """Hann window function for f-number masking.

    Args:
        normalized_angle (ops.Tensor): Normalized angle values; 0.0 is straight
            in front of the element, 1.0 the edge of the f-number cone.

    Returns:
        ops.Tensor: Smooth taper from 1.0 at angle 0 down to 0.0 at the cone
        edge; exactly 0.0 beyond the cone.
    """
    # Hann taper, evaluated everywhere; the where() selects it inside the cone.
    hann_taper = 0.5 * (1 + ops.cos(np.pi * normalized_angle))
    return ops.where(normalized_angle <= 1.0, hann_taper, 0.0)
23
+
24
+
25
def fnum_window_fn_tukey(normalized_angle, alpha=0.5):
    """Tukey window function for f-number masking.

    Args:
        normalized_angle (ops.Tensor): Normalized angle values in the range [0, 1].
        alpha (float, optional): The alpha parameter for the Tukey window. 0.0 corresponds to a
            rectangular window, 1.0 corresponds to a Hann window. Defaults to 0.5.

    Returns:
        ops.Tensor: Window values in [0, 1]; flat at 1.0 up to ``1 - alpha``,
        cosine-tapered to 0.0 at the cone edge, and 0.0 beyond it.
    """
    # Fold negative angles over and restrict evaluation to [0, 1].
    angle = ops.clip(ops.abs(normalized_angle), 0.0, 1.0)

    # The window stays flat (== 1) up to this point, then tapers off.
    flat_end = 1.0 - alpha

    # Cosine taper between flat_end and 1.0. The epsilon keeps the division
    # well-defined when alpha == 0 (pure rectangular window).
    taper = 0.5 * (1 + ops.cos(np.pi * (angle - flat_end) / (ops.abs(alpha) + 1e-6)))

    return ops.where(
        angle < flat_end,
        1.0,
        ops.where(angle < 1.0, taper, 0.0),
    )
47
+
48
+
10
49
  def tof_correction(
11
50
  data,
12
51
  flatgrid,
@@ -24,6 +63,7 @@ def tof_correction(
24
63
  apply_lens_correction=False,
25
64
  lens_thickness=1e-3,
26
65
  lens_sound_speed=1000,
66
+ fnum_window_fn=fnum_window_fn_rect,
27
67
  ):
28
68
  """Time-of-flight correction for a flat grid.
29
69
 
@@ -54,6 +94,9 @@ def tof_correction(
54
94
  lens correction. Defaults to 1e-3.
55
95
  lens_sound_speed (float, optional): Speed of sound in the lens in m/s. Used
56
96
  for lens correction Defaults to 1000.
97
+ fnum_window_fn (callable, optional): F-number function to define the transition from
98
+ straight in front of the element (fn(0.0)) to the largest angle within the f-number cone
99
+ (fn(1.0)). The function should be zero for fn(x>1.0).
57
100
 
58
101
  Returns:
59
102
  (ops.Tensor): time-of-flight corrected data
@@ -100,7 +143,7 @@ def tof_correction(
100
143
  mask = ops.cond(
101
144
  fnum == 0,
102
145
  lambda: ops.ones((n_pix, n_el, 1)),
103
- lambda: apod_mask(flatgrid, probe_geometry, fnum),
146
+ lambda: fnumber_mask(flatgrid, probe_geometry, fnum, fnum_window_fn=fnum_window_fn),
104
147
  )
105
148
 
106
149
  def _apply_delays(data_tx, txdel):
@@ -408,64 +451,48 @@ def distance_Tx_generic(
408
451
  return dist
409
452
 
410
453
 
411
- def apod_mask(grid, probe_geometry, f_number):
454
def fnumber_mask(flatgrid, probe_geometry, f_number, fnum_window_fn):
    """Apodization mask for the receive beamformer.

    Computes a mask to disregard pixels outside of the vision cone of a
    transducer element. Transducer elements can only accurately measure
    signals within some range of incidence angles. Waves coming in from the
    side do not register correctly leading to a worse image.

    Args:
        flatgrid (ops.Tensor): The flattened image grid `(n_pix, 3)`.
        probe_geometry (ops.Tensor): The transducer element positions of shape
            `(n_el, 3)`.
        f_number (int): The receive f-number. (The f-number is the ratio between
            distance from the transducer and the size of the aperture below which
            transducer elements contribute to the signal for a pixel.) Must be
            nonzero here; the f_number == 0 "no masking" case is handled by the
            caller.
        fnum_window_fn (callable): F-number function to define the transition from
            straight in front of the element (fn(0.0)) to the largest angle within
            the f-number cone (fn(1.0)). The function should be zero for fn(x>1.0).

    Returns:
        Tensor: Mask of shape `(n_pix, n_el, 1)`
    """
    # Vector from every element to every pixel, shape (n_pix, n_el, 3).
    offsets = flatgrid[:, None] - probe_geometry[None]
    distances = ops.linalg.norm(offsets, axis=-1)

    # Cosine of the incidence angle is the z-component of the unit vector.
    # The epsilon guards against division by zero for coincident points.
    cos_alpha = offsets[..., 2] / (distances + 1e-6)
    alpha = ops.arccos(cos_alpha)

    # The f-number is fnum = z/aperture = 1/(2 * tan(alpha)).
    # Rearranging gives the largest admissible angle: alpha = arctan(1/(2 * fnum)).
    max_alpha = ops.arctan(1 / (2 * f_number))

    mask = fnum_window_fn(alpha / max_alpha)

    # Add dummy channel dimension
    return mask[..., None]
@@ -0,0 +1,43 @@
1
import numpy as np


def fish():
    """Returns a scatterer phantom for ultrasound simulation tests.

    Returns:
        ndarray: The scatterer positions of shape (n_scat, 3).
    """
    # The size is the height of the fish
    size = 11e-3
    z_offset = 2.0 * size

    # Parametric fish curve, see https://en.wikipedia.org/wiki/Fish_curve
    def fish_curve(t, size=1):
        x = size * (np.cos(t) - np.sin(t) ** 2 / np.sqrt(2))
        y = size * np.cos(t) * np.sin(t)
        return x, y

    body_x, body_z = fish_curve(np.linspace(0, 2 * np.pi, 100), size=size)

    # A few extra scatterers added behind/below the body outline.
    extra_x = size * np.array([0.7, 1.1, 1.4, 1.2])
    extra_z = -size * np.array([0.1, 0.25, 0.6, 1.0])

    scat_x = np.concatenate([body_x, extra_x])
    scat_z = np.concatenate([body_z, extra_z]) + z_offset
    # Phantom lies in the y=0 plane.
    scat_y = np.zeros_like(scat_x)

    return np.stack([scat_x, scat_y, scat_z], axis=1)
zea/beamform/pixelgrid.py CHANGED
@@ -7,67 +7,27 @@ from zea import log
7
7
  eps = 1e-10
8
8
 
9
9
 
10
- def get_grid(
11
- xlims,
12
- zlims,
13
- Nx: int,
14
- Nz: int,
15
- sound_speed,
16
- center_frequency,
17
- pixels_per_wavelength,
18
- verbose=False,
19
- ):
20
- """Creates a pixelgrid based on scan class parameters."""
21
-
22
- if Nx and Nz:
23
- grid = cartesian_pixel_grid(xlims, zlims, Nx=int(Nx), Nz=int(Nz))
24
- else:
25
- wvln = sound_speed / center_frequency
26
- dx = wvln / pixels_per_wavelength
27
- dz = dx
28
- grid = cartesian_pixel_grid(xlims, zlims, dx=dx, dz=dz)
29
- if verbose:
30
- print(
31
- f"Pixelgrid was set automatically to Nx: {grid.shape[1]}, Nz: {grid.shape[0]}, "
32
- f"using {pixels_per_wavelength} pixels per wavelength."
33
- )
34
- return grid
35
-
36
-
37
10
def check_for_aliasing(scan):
    """Checks if the scan class parameters will cause spatial aliasing due to a too low pixel
    density. If so, a warning is printed with a suggestion to increase the pixel density by either
    increasing the number of pixels, or decreasing the pixel spacing, depending on which parameter
    was set by the user.

    Args:
        scan: Scan object providing ``xlims``, ``zlims``, ``Nx``, ``Nz`` and
            ``wavelength`` attributes.
    """
    width = scan.xlims[1] - scan.xlims[0]
    depth = scan.zlims[1] - scan.zlims[0]
    wvln = scan.wavelength

    # Nyquist criterion: the pixel spacing must not exceed half a wavelength.
    # Note: messages use '>' to match the triggering condition (the previous
    # text printed '<', which contradicted the check), and both format the
    # wavelength consistently with :.7f.
    if width / scan.Nx > wvln / 2:
        log.warning(
            f"width/Nx = {width / scan.Nx:.7f} > wavelength/2 = {wvln / 2:.7f}. "
            f"Consider either increasing scan.Nx to {int(np.ceil(width / (wvln / 2)))} or more, or "
            "increasing scan.pixels_per_wavelength to 2 or more."
        )
    if depth / scan.Nz > wvln / 2:
        log.warning(
            f"depth/Nz = {depth / scan.Nz:.7f} > wavelength/2 = {wvln / 2:.7f}. "
            f"Consider either increasing scan.Nz to {int(np.ceil(depth / (wvln / 2)))} or more, or "
            "increasing scan.pixels_per_wavelength to 2 or more."
        )
71
31
 
72
32
 
73
33
  def cartesian_pixel_grid(xlims, zlims, Nx=None, Nz=None, dx=None, dz=None):
@@ -130,7 +90,7 @@ def radial_pixel_grid(rlims, dr, oris, dirs):
130
90
  Cartesian coordinates (x, y, z), with nr being the number of radial pixels.
131
91
  """
132
92
  # Get focusing positions in rho-theta coordinates
133
- r = np.arange(rlims[0], rlims[1] + eps, dr) # Depth rho
93
+ r = np.arange(rlims[0], rlims[1], dr) # Depth rho
134
94
  t = dirs[:, 0] # Use azimuthal angle theta (ignore elevation angle)
135
95
  tt, rr = np.meshgrid(t, r, indexing="ij")
136
96
 
@@ -140,3 +100,31 @@ def radial_pixel_grid(rlims, dr, oris, dirs):
140
100
  yy = 0 * xx
141
101
  grid = np.stack((xx, yy, zz), axis=-1)
142
102
  return grid
103
+
104
+
105
def polar_pixel_grid(polar_limits, zlims, Nz: int, Nr: int):
    """Generate a polar grid.

    Uses radial_pixel_grid but based on parameters that are present in the scan class.

    Args:
        polar_limits (tuple): Azimuthal limits of pixel grid ([azimuth_min, azimuth_max])
        zlims (tuple): Depth limits of pixel grid ([zmin, zmax])
        Nz (int, optional): Number of depth pixels.
        Nr (int, optional): Number of azimuthal pixels.

    Returns:
        grid (np.ndarray): Pixel grid of size (Nz, Nr, 3) in Cartesian coordinates (x, y, z)
    """
    assert len(polar_limits) == 2, "polar_limits must be a tuple of length 2."
    assert len(zlims) == 2, "zlims must be a tuple of length 2."

    # Radial step such that Nz samples span the depth range.
    step = (zlims[1] - zlims[0]) / Nz

    # Every ray originates from the probe origin.
    origins = np.zeros((Nr, 3), dtype=int)

    # Azimuth angles sweep the polar limits; elevation is zero for all rays.
    azimuth = np.linspace(*polar_limits, Nr)
    elevation = np.zeros(Nr)
    directions = np.column_stack((azimuth, elevation))

    # radial_pixel_grid returns (Nr, nr, 3); transpose to (Nz, Nr, 3).
    return radial_pixel_grid(zlims, step, origins, directions).transpose(1, 0, 2)
zea/data/__main__.py ADDED
@@ -0,0 +1,31 @@
1
+ """Command-line interface for copying a zea.Folder to a new location.
2
+
3
+ Usage:
4
+ python -m zea.data <source_folder> <destination_folder> <key>
5
+ """
6
+
7
+ import argparse
8
+
9
+ from zea import Folder
10
+
11
+
12
+ def main():
13
+ parser = argparse.ArgumentParser(description="Copy a zea.Folder to a new location.")
14
+ parser.add_argument("src", help="Source folder path")
15
+ parser.add_argument("dst", help="Destination folder path")
16
+ parser.add_argument("key", help="Key to access in the hdf5 files")
17
+ parser.add_argument(
18
+ "--mode",
19
+ default="a",
20
+ choices=["a", "w", "r+", "x"],
21
+ help="Mode in which to open the destination files (default: 'a')",
22
+ )
23
+
24
+ args = parser.parse_args()
25
+
26
+ src_folder = Folder(args.src, args.key, validate=False)
27
+ src_folder.copy(args.dst, args.key, mode=args.mode)
28
+
29
+
30
+ if __name__ == "__main__":
31
+ main()
zea/data/data_format.py CHANGED
@@ -84,10 +84,12 @@ def generate_example_dataset(
84
84
  aligned_data = np.ones((n_frames, n_tx, n_ax, n_el, n_ch))
85
85
  envelope_data = np.ones((n_frames, n_z, n_x))
86
86
  beamformed_data = np.ones((n_frames, n_z, n_x, n_ch))
87
+ image_sc = np.ones_like(image)
87
88
  else:
88
89
  aligned_data = None
89
90
  envelope_data = None
90
91
  beamformed_data = None
92
+ image_sc = None
91
93
 
92
94
  generate_zea_dataset(
93
95
  path,
@@ -96,6 +98,7 @@ def generate_example_dataset(
96
98
  envelope_data=envelope_data,
97
99
  beamformed_data=beamformed_data,
98
100
  image=image,
101
+ image_sc=image_sc,
99
102
  probe_geometry=probe_geometry,
100
103
  sampling_frequency=sampling_frequency,
101
104
  center_frequency=center_frequency,
zea/data/datasets.py CHANGED
@@ -113,14 +113,14 @@ class H5FileHandleCache:
113
113
  self._file_handle_cache = OrderedDict()
114
114
 
115
115
 
116
- def find_h5_files(paths: str | list, key: str, search_file_tree_kwargs: dict | None = None):
116
+ def find_h5_files(paths: str | list, key: str = None, search_file_tree_kwargs: dict | None = None):
117
117
  """
118
- Find HDF5 files from a directory or list of directories and retrieve their shapes.
118
+ Find HDF5 files from a directory or list of directories and optionally retrieve their shapes.
119
119
 
120
120
  Args:
121
121
  paths (str or list): A single directory path, a list of directory paths,
122
122
  or a single HDF5 file path.
123
- key (str): The key to access the HDF5 dataset.
123
+ key (str, optional): The key to get the file shapes for.
124
124
  search_file_tree_kwargs (dict, optional): Additional keyword arguments for the
125
125
  search_file_tree function. Defaults to None.
126
126
 
@@ -147,7 +147,8 @@ def find_h5_files(paths: str | list, key: str, search_file_tree_kwargs: dict | N
147
147
  if Path(path).is_file():
148
148
  path = Path(path)
149
149
  # If the path is a file, get its shape directly
150
- file_shapes.append(File.get_shape(path, key))
150
+ if key is not None:
151
+ file_shapes.append(File.get_shape(path, key))
151
152
  file_paths.append(str(path))
152
153
  continue
153
154
 
@@ -171,7 +172,7 @@ class Folder:
171
172
  def __init__(
172
173
  self,
173
174
  folder_path: list[str] | list[Path],
174
- key: str,
175
+ key: str = None,
175
176
  search_file_tree_kwargs: dict | None = None,
176
177
  validate: bool = True,
177
178
  hf_cache_dir: str = HF_DATASETS_DIR,
@@ -192,7 +193,7 @@ class Folder:
192
193
 
193
194
  super().__init__(**kwargs)
194
195
 
195
- self.folder_path = folder_path
196
+ self.folder_path = Path(folder_path)
196
197
  self.key = key
197
198
  self.search_file_tree_kwargs = search_file_tree_kwargs
198
199
  self.validate = validate
@@ -346,6 +347,64 @@ class Folder:
346
347
  def __str__(self):
347
348
  return f"Folder with {self.n_files} files in '{self.folder_path}' (key='{self.key}')"
348
349
 
350
+ def copy(self, to_path: str | Path, key: str = None, mode: str | None = None):
351
+ """Copy the data for all or a specific key to a new location.
352
+
353
+ Has the option to copy all keys or only a specific key. By default, it only copies if the
354
+ destination file does not already contain the key. You can change the mode to 'w' to
355
+ overwrite the destination file. Will always copy metadata such as dataset attributes and
356
+ scan object.
357
+
358
+ Args:
359
+ to_path (str or Path): The destination path where files will be copied.
360
+ key (str, optional): The key to copy from the source files.
361
+ If 'all' or '*', all keys will be copied. Defaults to None, which
362
+ uses the key set in the Folder instance.
363
+ mode (str): The mode in which to open the destination files.
364
+ Defaults to 'a' (append mode), and 'w' (write mode) if key is 'all' or '*'.
365
+ See: https://docs.h5py.org/en/stable/high/file.html#opening-creating-files
366
+ """
367
+ if key is None and self.key is None:
368
+ raise ValueError(
369
+ "No key specified. Please provide a key to copy the data for, or set the "
370
+ "key attribute of the Folder instance."
371
+ )
372
+ elif key is None:
373
+ key = self.key
374
+
375
+ all_keys = key == "all" or key == "*"
376
+
377
+ if mode is None:
378
+ mode = "a" if not all_keys else "w"
379
+
380
+ if all_keys:
381
+ key_msg = "Including all keys."
382
+ assert mode in ["w", "x"], (
383
+ "If you want to copy all keys, the destination file must be opened "
384
+ "in 'w' or 'x' mode, which means it will be overwritten or created."
385
+ )
386
+ else:
387
+ key_msg = f"Only copying key '{key}'."
388
+ assert mode in ["a", "w", "r+", "x"], (
389
+ f"Invalid mode '{mode}'. Must be one of 'a', 'w', 'r+', or 'x'."
390
+ )
391
+
392
+ to_path = Path(to_path)
393
+ to_path.mkdir(parents=True, exist_ok=True)
394
+
395
+ for file_path in tqdm.tqdm(
396
+ self.file_paths,
397
+ total=self.n_files,
398
+ desc=f"Copying dataset from {self.folder_path} to {to_path}. {key_msg}",
399
+ ):
400
+ dst_path = Path(file_path).relative_to(self.folder_path)
401
+ with File(file_path) as src, File(to_path / dst_path, mode) as dst:
402
+ if all_keys:
403
+ for obj in src.keys():
404
+ src.copy(obj, dst)
405
+ else:
406
+ src.copy_key(key, dst)
407
+
349
408
 
350
409
  class Dataset(H5FileHandleCache):
351
410
  """Iterate over File(s) and Folder(s)."""
@@ -582,3 +641,10 @@ def count_samples_per_directory(file_names, directories):
582
641
  )
583
642
 
584
643
  return counts
644
+
645
+
646
+ if __name__ == "__main__":
647
+ src_folder = Folder(
648
+ "/mnt/z/Ultrasound-BMd/data/USBMD_datasets/CAMUS/val/patient0450", "image", validate=False
649
+ )
650
+ src_folder.copy("./CAMIUS", key="all")
zea/data/file.py CHANGED
@@ -175,6 +175,8 @@ class File(h5py.File):
175
175
  if isinstance(key, enum.Enum):
176
176
  key = key.value
177
177
 
178
+ assert isinstance(key, str), f"Key must be a string, got {type(key)}. "
179
+
178
180
  # Return the key if it is in the file
179
181
  if key in self.keys():
180
182
  return key
@@ -448,6 +450,35 @@ class File(h5py.File):
448
450
  def __str__(self):
449
451
  return f"zea HDF5 File: '{self.path.name}' (mode={self.mode})"
450
452
 
453
+ def copy_key(self, key: str, dst: "File"):
454
+ """Copy a specific key to another file.
455
+
456
+ Will always copy the attributes and the scan data if it exists. Will warn if the key is
457
+ not in this file or if the key already exists in the destination file.
458
+
459
+ Args:
460
+ key (str): The key to copy.
461
+ dst (File): The destination file to copy the key to.
462
+ """
463
+ key = self.format_key(key)
464
+
465
+ # Copy the key if it does not already exist in the destination file
466
+ if key in dst:
467
+ log.warning(f"Skipping key '{key}' because it already exists in dst file {dst.path}.")
468
+ elif key in self:
469
+ self.copy(key, dst, name=key)
470
+ else:
471
+ log.warning(f"Key '{key}' not found in src file {self.path}. Skipping copy.")
472
+
473
+ # Copy attributes from src to dst
474
+ for attr_key, attr_value in self.attrs.items():
475
+ dst[key].attrs[attr_key] = attr_value
476
+
477
+ # Copy scan data if requested
478
+ if "scan" in self and "scan" not in dst:
479
+ # Copy the scan data if it exists
480
+ self.copy("scan", dst)
481
+
451
482
  def summary(self):
452
483
  """Print the contents of the file."""
453
484
  _print_hdf5_attrs(self)
zea/display.py CHANGED
@@ -3,6 +3,7 @@
3
3
  from functools import partial
4
4
  from typing import Tuple, Union
5
5
 
6
+ import keras
6
7
  import numpy as np
7
8
  import scipy
8
9
  from keras import ops
@@ -357,6 +358,11 @@ def map_coordinates(inputs, coordinates, order, fill_mode="constant", fill_value
357
358
 
358
359
  def _interpolate_batch(images, coordinates, fill_value=0.0, order=1, vectorize=True):
359
360
  """Interpolate a batch of images."""
361
+
362
+ # TODO: figure out why tensorflow map_coordinates is broken
363
+ if keras.backend.backend() == "tensorflow":
364
+ assert order > 1, "Some bug in tensorflow in map_coordinates, set order > 1 to use scipy."
365
+
360
366
  image_shape = images.shape
361
367
  num_image_dims = coordinates.shape[0]
362
368
 
zea/internal/cache.py CHANGED
@@ -80,7 +80,7 @@ _CACHE_DIR = ZEA_CACHE_DIR / "cached_funcs"
80
80
  _CACHE_DIR.mkdir(parents=True, exist_ok=True)
81
81
 
82
82
 
83
- def serialize_elements(key_elements: list):
83
+ def serialize_elements(key_elements: list, shorten: bool = False) -> str:
84
84
  """Serialize elements of a list to generate a cache key.
85
85
 
86
86
  In general uses the string representation of the elements unless
@@ -90,18 +90,18 @@ def serialize_elements(key_elements: list):
90
90
  Args:
91
91
  key_elements (list): List of elements to serialize. Can be nested lists
92
92
  or tuples. In this case the elements are serialized recursively.
93
+ shorten (bool): If True, the serialized string is hashed to a shorter
94
+ representation using MD5. Defaults to False.
93
95
 
94
96
  Returns:
95
- list[str]: List of serialized elements. In cases where the elements were
96
- lists of tuples those are combined into a single string.
97
+ str: A serialized string representation of the elements, joined by underscores.
97
98
 
98
99
  """
99
100
  serialized_elements = []
100
101
  for element in key_elements:
101
102
  if isinstance(element, (list, tuple)):
102
103
  # If element is a list or tuple, serialize its elements recursively
103
- element = serialize_elements(element)
104
- serialized_elements.append("_".join(element))
104
+ serialized_elements.append(serialize_elements(element))
105
105
  elif hasattr(element, "serialized"):
106
106
  # Use the serialized attribute if it exists (e.g. for zea.core.Object)
107
107
  serialized_elements.append(str(element.serialized))
@@ -120,7 +120,10 @@ def serialize_elements(key_elements: list):
120
120
  element = hashlib.md5(element).hexdigest()
121
121
  serialized_elements.append(element)
122
122
 
123
- return serialized_elements
123
+ serialized = "_".join(serialized_elements)
124
+ if shorten:
125
+ return hashlib.md5(serialized.encode()).hexdigest()
126
+ return serialized
124
127
 
125
128
 
126
129
  def get_function_source(func):
@@ -182,8 +185,10 @@ def generate_cache_key(func, args, kwargs, arg_names):
182
185
  if name in bound_args.arguments:
183
186
  key_elements.append(bound_args.arguments[name])
184
187
 
185
- key = "_".join(serialize_elements(key_elements))
186
- return f"{func.__qualname__}_" + hashlib.md5(key.encode()).hexdigest()
188
+ # Add keras backend
189
+ key_elements.append(keras.backend.backend())
190
+
191
+ return f"{func.__qualname__}_" + serialize_elements(key_elements, shorten=True)
187
192
 
188
193
 
189
194
  def cache_output(*arg_names, verbose=False):