zea 0.0.1__py3-none-any.whl → 0.0.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- zea/__init__.py +1 -1
- zea/beamform/beamformer.py +69 -42
- zea/beamform/phantoms.py +43 -0
- zea/beamform/pixelgrid.py +43 -55
- zea/data/__main__.py +31 -0
- zea/data/data_format.py +3 -0
- zea/data/datasets.py +72 -6
- zea/data/file.py +31 -0
- zea/display.py +6 -0
- zea/internal/cache.py +13 -8
- zea/internal/core.py +32 -9
- zea/internal/operators.py +0 -29
- zea/internal/parameters.py +115 -85
- zea/io_lib.py +1 -1
- zea/models/taesd.py +3 -2
- zea/ops.py +31 -12
- zea/scan.py +136 -78
- zea/simulator.py +3 -5
- zea/visualize.py +57 -0
- {zea-0.0.1.dist-info → zea-0.0.2.dist-info}/METADATA +2 -2
- {zea-0.0.1.dist-info → zea-0.0.2.dist-info}/RECORD +24 -22
- {zea-0.0.1.dist-info → zea-0.0.2.dist-info}/LICENSE +0 -0
- {zea-0.0.1.dist-info → zea-0.0.2.dist-info}/WHEEL +0 -0
- {zea-0.0.1.dist-info → zea-0.0.2.dist-info}/entry_points.txt +0 -0
zea/__init__.py
CHANGED
zea/beamform/beamformer.py
CHANGED
|
@@ -7,6 +7,45 @@ from zea.beamform.lens_correction import calculate_lens_corrected_delays
|
|
|
7
7
|
from zea.tensor_ops import safe_vectorize
|
|
8
8
|
|
|
9
9
|
|
|
10
|
+
def fnum_window_fn_rect(normalized_angle):
    """Rectangular (boxcar) window for f-number masking.

    Args:
        normalized_angle (ops.Tensor): Angle divided by the largest angle allowed by
            the f-number.

    Returns:
        ops.Tensor: 1.0 where ``normalized_angle <= 1.0`` and 0.0 elsewhere.
    """
    inside_cone = normalized_angle <= 1.0
    return ops.where(inside_cone, 1.0, 0.0)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def fnum_window_fn_hann(normalized_angle):
    """Hann-shaped window for f-number masking.

    Follows ``0.5 * (1 + cos(pi * x))``, which is 1.0 at ``x == 0`` and 0.0 at
    ``x == 1``, and is 0.0 for every ``x > 1`` (outside the f-number cone).

    Args:
        normalized_angle (ops.Tensor): Angle divided by the largest angle allowed by
            the f-number.

    Returns:
        ops.Tensor: Window weights with the same shape as ``normalized_angle``.
    """
    taper = 0.5 * (1 + ops.cos(np.pi * normalized_angle))
    return ops.where(normalized_angle <= 1.0, taper, 0.0)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def fnum_window_fn_tukey(normalized_angle, alpha=0.5):
    """Tukey (tapered-cosine) window for f-number masking.

    The window is flat (1.0) for ``|x| < 1 - alpha``. Between ``1 - alpha`` and 1 it
    follows a cosine taper, and it is 0.0 for ``|x| >= 1``.

    Args:
        normalized_angle (ops.Tensor): Angle divided by the largest angle allowed by
            the f-number. Its absolute value is taken and clipped to [0, 1].
        alpha (float, optional): The alpha parameter for the Tukey window (fraction of
            the window that is tapered). 0.0 corresponds to a rectangular window,
            1.0 corresponds to a Hann window. Defaults to 0.5.

    Returns:
        ops.Tensor: Window weights with the same shape as ``normalized_angle``.
    """
    x = ops.clip(ops.abs(normalized_angle), 0.0, 1.0)
    flat_end = 1.0 - alpha
    # The epsilon keeps the division finite for alpha == 0 (pure rectangle).
    taper_width = ops.abs(alpha) + 1e-6
    taper = 0.5 * (1 + ops.cos(np.pi * (x - flat_end) / taper_width))
    tapered_or_zero = ops.where(x < 1.0, taper, 0.0)
    return ops.where(x < flat_end, 1.0, tapered_or_zero)
|
|
47
|
+
|
|
48
|
+
|
|
10
49
|
def tof_correction(
|
|
11
50
|
data,
|
|
12
51
|
flatgrid,
|
|
@@ -24,6 +63,7 @@ def tof_correction(
|
|
|
24
63
|
apply_lens_correction=False,
|
|
25
64
|
lens_thickness=1e-3,
|
|
26
65
|
lens_sound_speed=1000,
|
|
66
|
+
fnum_window_fn=fnum_window_fn_rect,
|
|
27
67
|
):
|
|
28
68
|
"""Time-of-flight correction for a flat grid.
|
|
29
69
|
|
|
@@ -54,6 +94,9 @@ def tof_correction(
|
|
|
54
94
|
lens correction. Defaults to 1e-3.
|
|
55
95
|
lens_sound_speed (float, optional): Speed of sound in the lens in m/s. Used
|
|
56
96
|
for lens correction Defaults to 1000.
|
|
97
|
+
fnum_window_fn (callable, optional): F-number function to define the transition from
|
|
98
|
+
straight in front of the element (fn(0.0)) to the largest angle within the f-number cone
|
|
99
|
+
(fn(1.0)). The function should be zero for fn(x>1.0).
|
|
57
100
|
|
|
58
101
|
Returns:
|
|
59
102
|
(ops.Tensor): time-of-flight corrected data
|
|
@@ -100,7 +143,7 @@ def tof_correction(
|
|
|
100
143
|
mask = ops.cond(
|
|
101
144
|
fnum == 0,
|
|
102
145
|
lambda: ops.ones((n_pix, n_el, 1)),
|
|
103
|
-
lambda:
|
|
146
|
+
lambda: fnumber_mask(flatgrid, probe_geometry, fnum, fnum_window_fn=fnum_window_fn),
|
|
104
147
|
)
|
|
105
148
|
|
|
106
149
|
def _apply_delays(data_tx, txdel):
|
|
@@ -408,64 +451,48 @@ def distance_Tx_generic(
|
|
|
408
451
|
return dist
|
|
409
452
|
|
|
410
453
|
|
|
411
|
-
def
|
|
454
|
+
def fnumber_mask(flatgrid, probe_geometry, f_number, fnum_window_fn):
    """Apodization mask for the receive beamformer.

    Transducer elements can only accurately measure signals within some range of
    incidence angles. Waves coming in from the side do not register correctly,
    which leads to a worse image.

    For every (pixel, element) pair this function:

    1. computes the angle between the z-axis and the vector from the element to
       the pixel;
    2. divides that angle by the largest angle the f-number allows;
    3. maps the result through ``fnum_window_fn``.

    Args:
        flatgrid (ops.Tensor): The flattened image grid `(n_pix, 3)`.
        probe_geometry (ops.Tensor): The transducer element positions of shape
            `(n_el, 3)`.
        f_number (int): The receive f-number, i.e. the ratio between the distance
            from the transducer and the size of the aperture. This function does
            not special-case zero: ``1 / (2 * f_number)`` would divide by zero. In
            this module, ``tof_correction`` substitutes an all-ones mask when the
            f-number is 0.
        fnum_window_fn (callable): F-number function to define the transition from
            straight in front of the element (fn(0.0)) to the largest angle within
            the f-number cone (fn(1.0)). The function should be zero for fn(x>1.0).

    Returns:
        Tensor: Mask of shape `(n_pix, n_el, 1)`
    """
    # (n_pix, n_el, 3) vectors pointing from every element to every pixel.
    offsets = flatgrid[:, None] - probe_geometry[None]
    distances = ops.linalg.norm(offsets, axis=-1)

    # Cosine of the angle with the z-axis. The epsilon avoids dividing by zero
    # when a pixel coincides with an element.
    cos_angle = offsets[..., 2] / (distances + 1e-6)
    angle = ops.arccos(cos_angle)

    # fnum = z / aperture = 1 / (2 * tan(angle))  =>  max_angle = arctan(1 / (2 * fnum))
    max_angle = ops.arctan(1 / (2 * f_number))

    weights = fnum_window_fn(angle / max_angle)

    # Trailing singleton axis for the RF/IQ channel dimension.
    return weights[..., None]
|
zea/beamform/phantoms.py
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
def fish():
    """Return a fish-shaped scatterer phantom for ultrasound simulation tests.

    The outline consists of 100 points sampled on the fish curve
    (https://en.wikipedia.org/wiki/Fish_curve), scaled to a height of 11e-3. Four
    extra point scatterers follow the outline. All points lie in the y = 0 plane
    and are shifted along z by twice the height.

    Returns:
        ndarray: The scatterer positions of shape (n_scat, 3), with n_scat = 104 and
        columns ordered (x, y, z).
    """
    height = 11e-3
    t = np.linspace(0, 2 * np.pi, 100)

    # Parametric fish curve, drawn in the x-z plane.
    outline_x = height * (np.cos(t) - np.sin(t) ** 2 / np.sqrt(2))
    outline_z = height * np.cos(t) * np.sin(t)

    # Isolated scatterers, expressed as multiples of the height.
    extra_x = height * np.array([0.7, 1.1, 1.4, 1.2])
    extra_z = -height * np.array([0.1, 0.25, 0.6, 1.0])

    x = np.concatenate([outline_x, extra_x])
    z = np.concatenate([outline_z, extra_z]) + 2.0 * height
    y = np.zeros_like(x)
    return np.stack([x, y, z], axis=1)
|
zea/beamform/pixelgrid.py
CHANGED
|
@@ -7,67 +7,27 @@ from zea import log
|
|
|
7
7
|
eps = 1e-10
|
|
8
8
|
|
|
9
9
|
|
|
10
|
-
def get_grid(
|
|
11
|
-
xlims,
|
|
12
|
-
zlims,
|
|
13
|
-
Nx: int,
|
|
14
|
-
Nz: int,
|
|
15
|
-
sound_speed,
|
|
16
|
-
center_frequency,
|
|
17
|
-
pixels_per_wavelength,
|
|
18
|
-
verbose=False,
|
|
19
|
-
):
|
|
20
|
-
"""Creates a pixelgrid based on scan class parameters."""
|
|
21
|
-
|
|
22
|
-
if Nx and Nz:
|
|
23
|
-
grid = cartesian_pixel_grid(xlims, zlims, Nx=int(Nx), Nz=int(Nz))
|
|
24
|
-
else:
|
|
25
|
-
wvln = sound_speed / center_frequency
|
|
26
|
-
dx = wvln / pixels_per_wavelength
|
|
27
|
-
dz = dx
|
|
28
|
-
grid = cartesian_pixel_grid(xlims, zlims, dx=dx, dz=dz)
|
|
29
|
-
if verbose:
|
|
30
|
-
print(
|
|
31
|
-
f"Pixelgrid was set automatically to Nx: {grid.shape[1]}, Nz: {grid.shape[0]}, "
|
|
32
|
-
f"using {pixels_per_wavelength} pixels per wavelength."
|
|
33
|
-
)
|
|
34
|
-
return grid
|
|
35
|
-
|
|
36
|
-
|
|
37
10
|
def check_for_aliasing(scan):
    """Checks if the scan class parameters will cause spatial aliasing due to a too low pixel
    density.

    The pixel spacing along x (``width / Nx``) and along z (``depth / Nz``) is
    compared with half the wavelength. For every axis where the spacing is larger,
    a warning is logged. The warning suggests the number of pixels needed on that
    axis, or increasing ``scan.pixels_per_wavelength``.

    Args:
        scan: Scan object providing ``xlims``, ``zlims``, ``Nx``, ``Nz`` and
            ``wavelength``.
    """
    width = scan.xlims[1] - scan.xlims[0]
    depth = scan.zlims[1] - scan.zlims[0]
    half_wavelength = scan.wavelength / 2

    for extent_name, extent, n_name, n_pixels in (
        ("width", width, "Nx", scan.Nx),
        ("depth", depth, "Nz", scan.Nz),
    ):
        spacing = extent / n_pixels
        if spacing > half_wavelength:
            required = int(np.ceil(extent / half_wavelength))
            # The message states the violated condition (spacing > wavelength/2).
            log.warning(
                f"{extent_name}/{n_name} = {spacing:.7f} > "
                f"wavelength/2 = {half_wavelength:.7f}. "
                f"Consider either increasing scan.{n_name} to {required} or more, or "
                "increasing scan.pixels_per_wavelength to 2 or more."
            )
|
|
71
31
|
|
|
72
32
|
|
|
73
33
|
def cartesian_pixel_grid(xlims, zlims, Nx=None, Nz=None, dx=None, dz=None):
|
|
@@ -130,7 +90,7 @@ def radial_pixel_grid(rlims, dr, oris, dirs):
|
|
|
130
90
|
Cartesian coordinates (x, y, z), with nr being the number of radial pixels.
|
|
131
91
|
"""
|
|
132
92
|
# Get focusing positions in rho-theta coordinates
|
|
133
|
-
r = np.arange(rlims[0], rlims[1]
|
|
93
|
+
r = np.arange(rlims[0], rlims[1], dr) # Depth rho
|
|
134
94
|
t = dirs[:, 0] # Use azimuthal angle theta (ignore elevation angle)
|
|
135
95
|
tt, rr = np.meshgrid(t, r, indexing="ij")
|
|
136
96
|
|
|
@@ -140,3 +100,31 @@ def radial_pixel_grid(rlims, dr, oris, dirs):
|
|
|
140
100
|
yy = 0 * xx
|
|
141
101
|
grid = np.stack((xx, yy, zz), axis=-1)
|
|
142
102
|
return grid
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
def polar_pixel_grid(polar_limits, zlims, Nz: int, Nr: int):
    """Generate a polar grid.

    Uses radial_pixel_grid but based on parameters that are present in the scan class.
    All scan lines originate at (0, 0, 0) and are spread evenly in azimuth between
    the two polar limits (inclusive).

    Args:
        polar_limits (tuple): Azimuthal limits of pixel grid ([azimuth_min, azimuth_max])
        zlims (tuple): Depth limits of pixel grid ([zmin, zmax]); depths are sampled
            as zmin + i * (zmax - zmin) / Nz for i in range(Nz), so zmax is exclusive.
        Nz (int): Number of depth pixels.
        Nr (int): Number of azimuthal pixels.

    Returns:
        grid (np.ndarray): Pixel grid of size (Nz, Nr, 3) in Cartesian coordinates (x, y, z)
    """
    assert len(polar_limits) == 2, "polar_limits must be a tuple of length 2."
    assert len(zlims) == 2, "zlims must be a tuple of length 2."

    dr = (zlims[1] - zlims[0]) / Nz

    # radial_pixel_grid samples depth with np.arange(rlims[0], rlims[1], dr). With a
    # floating-point step, (zmax - zmin) / dr can round to slightly above Nz, and then
    # arange returns Nz + 1 samples. Ending the range half a step early always gives
    # exactly Nz samples with the same values zmin + i * dr.
    rlims = (zlims[0], zlims[0] + (Nz - 0.5) * dr)

    oris = np.tile(np.array([0, 0, 0]), (Nr, 1))
    dirs_az = np.linspace(*polar_limits, Nr)
    dirs_el = np.zeros(Nr)
    dirs = np.vstack((dirs_az, dirs_el)).T
    return radial_pixel_grid(rlims, dr, oris, dirs).transpose(1, 0, 2)
|
zea/data/__main__.py
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
"""Command-line interface for copying a zea.Folder to a new location.
|
|
2
|
+
|
|
3
|
+
Usage:
|
|
4
|
+
python -m zea.data <source_folder> <destination_folder> <key>
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import argparse
|
|
8
|
+
|
|
9
|
+
from zea import Folder
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def main():
    """Parse command-line arguments and copy a zea.Folder with ``Folder.copy``."""
    parser = argparse.ArgumentParser(description="Copy a zea.Folder to a new location.")
    parser.add_argument("src", help="Source folder path")
    parser.add_argument("dst", help="Destination folder path")
    parser.add_argument("key", help="Key to access in the hdf5 files")
    parser.add_argument(
        "--mode",
        default=None,
        choices=["a", "w", "r+", "x"],
        help=(
            "Mode in which to open the destination files "
            "(default: 'a', or 'w' when key is 'all' or '*')"
        ),
    )

    args = parser.parse_args()

    copy_all = args.key in ("all", "*")
    # 'all' / '*' are not real dataset keys, so they are not handed to Folder. Folder
    # may use its key to inspect the files (TODO confirm). The pseudo-key is passed
    # to Folder.copy explicitly below.
    src_folder = Folder(args.src, None if copy_all else args.key, validate=False)

    # mode=None lets Folder.copy pick 'a' for a single key and 'w' for all keys. A fixed
    # 'a' default would always trip Folder.copy's mode check when copying all keys.
    src_folder.copy(args.dst, args.key, mode=args.mode)


if __name__ == "__main__":
    main()
|
zea/data/data_format.py
CHANGED
|
@@ -84,10 +84,12 @@ def generate_example_dataset(
|
|
|
84
84
|
aligned_data = np.ones((n_frames, n_tx, n_ax, n_el, n_ch))
|
|
85
85
|
envelope_data = np.ones((n_frames, n_z, n_x))
|
|
86
86
|
beamformed_data = np.ones((n_frames, n_z, n_x, n_ch))
|
|
87
|
+
image_sc = np.ones_like(image)
|
|
87
88
|
else:
|
|
88
89
|
aligned_data = None
|
|
89
90
|
envelope_data = None
|
|
90
91
|
beamformed_data = None
|
|
92
|
+
image_sc = None
|
|
91
93
|
|
|
92
94
|
generate_zea_dataset(
|
|
93
95
|
path,
|
|
@@ -96,6 +98,7 @@ def generate_example_dataset(
|
|
|
96
98
|
envelope_data=envelope_data,
|
|
97
99
|
beamformed_data=beamformed_data,
|
|
98
100
|
image=image,
|
|
101
|
+
image_sc=image_sc,
|
|
99
102
|
probe_geometry=probe_geometry,
|
|
100
103
|
sampling_frequency=sampling_frequency,
|
|
101
104
|
center_frequency=center_frequency,
|
zea/data/datasets.py
CHANGED
|
@@ -113,14 +113,14 @@ class H5FileHandleCache:
|
|
|
113
113
|
self._file_handle_cache = OrderedDict()
|
|
114
114
|
|
|
115
115
|
|
|
116
|
-
def find_h5_files(paths: str | list, key: str, search_file_tree_kwargs: dict | None = None):
|
|
116
|
+
def find_h5_files(paths: str | list, key: str = None, search_file_tree_kwargs: dict | None = None):
|
|
117
117
|
"""
|
|
118
|
-
Find HDF5 files from a directory or list of directories and retrieve their shapes.
|
|
118
|
+
Find HDF5 files from a directory or list of directories and optionally retrieve their shapes.
|
|
119
119
|
|
|
120
120
|
Args:
|
|
121
121
|
paths (str or list): A single directory path, a list of directory paths,
|
|
122
122
|
or a single HDF5 file path.
|
|
123
|
-
key (str): The key to
|
|
123
|
+
key (str, optional): The key to get the file shapes for.
|
|
124
124
|
search_file_tree_kwargs (dict, optional): Additional keyword arguments for the
|
|
125
125
|
search_file_tree function. Defaults to None.
|
|
126
126
|
|
|
@@ -147,7 +147,8 @@ def find_h5_files(paths: str | list, key: str, search_file_tree_kwargs: dict | N
|
|
|
147
147
|
if Path(path).is_file():
|
|
148
148
|
path = Path(path)
|
|
149
149
|
# If the path is a file, get its shape directly
|
|
150
|
-
|
|
150
|
+
if key is not None:
|
|
151
|
+
file_shapes.append(File.get_shape(path, key))
|
|
151
152
|
file_paths.append(str(path))
|
|
152
153
|
continue
|
|
153
154
|
|
|
@@ -171,7 +172,7 @@ class Folder:
|
|
|
171
172
|
def __init__(
|
|
172
173
|
self,
|
|
173
174
|
folder_path: list[str] | list[Path],
|
|
174
|
-
key: str,
|
|
175
|
+
key: str = None,
|
|
175
176
|
search_file_tree_kwargs: dict | None = None,
|
|
176
177
|
validate: bool = True,
|
|
177
178
|
hf_cache_dir: str = HF_DATASETS_DIR,
|
|
@@ -192,7 +193,7 @@ class Folder:
|
|
|
192
193
|
|
|
193
194
|
super().__init__(**kwargs)
|
|
194
195
|
|
|
195
|
-
self.folder_path = folder_path
|
|
196
|
+
self.folder_path = Path(folder_path)
|
|
196
197
|
self.key = key
|
|
197
198
|
self.search_file_tree_kwargs = search_file_tree_kwargs
|
|
198
199
|
self.validate = validate
|
|
@@ -346,6 +347,64 @@ class Folder:
|
|
|
346
347
|
def __str__(self):
    """Describe the folder: number of files, location and the configured key."""
    n_files = self.n_files
    return f"Folder with {n_files} files in '{self.folder_path}' (key='{self.key}')"
|
|
348
349
|
|
|
350
|
+
def copy(self, to_path: str | Path, key: str = None, mode: str | None = None):
    """Copy the data for all or a specific key to a new location.

    Every file in ``self.file_paths`` is copied to ``to_path``, keeping its path
    relative to ``self.folder_path``. Missing sub-directories are created.

    - With a specific key, only that key is copied through ``File.copy_key``. By
      default this appends to existing destination files and skips keys they
      already contain.
    - With key ``'all'`` or ``'*'``, every top-level object of each source file is
      copied, and the destination files must be opened in ``'w'`` or ``'x'`` mode.

    Args:
        to_path (str or Path): The destination directory where files will be copied.
        key (str, optional): The key to copy from the source files.
            If 'all' or '*', all keys will be copied. Defaults to None, which
            uses the key set in the Folder instance.
        mode (str, optional): The mode in which to open the destination files.
            Defaults to 'a' (append mode), and 'w' (write mode) if key is 'all' or '*'.
            See: https://docs.h5py.org/en/stable/high/file.html#opening-creating-files

    Raises:
        ValueError: If neither ``key`` nor ``self.key`` is set.
        AssertionError: If ``mode`` is not allowed for the requested key(s).
    """
    if key is None:
        if self.key is None:
            raise ValueError(
                "No key specified. Please provide a key to copy the data for, or set the "
                "key attribute of the Folder instance."
            )
        key = self.key

    all_keys = key == "all" or key == "*"

    if mode is None:
        mode = "w" if all_keys else "a"

    if all_keys:
        key_msg = "Including all keys."
        assert mode in ["w", "x"], (
            "If you want to copy all keys, the destination file must be opened "
            "in 'w' or 'x' mode, which means it will be overwritten or created."
        )
    else:
        key_msg = f"Only copying key '{key}'."
        assert mode in ["a", "w", "r+", "x"], (
            f"Invalid mode '{mode}'. Must be one of 'a', 'w', 'r+', or 'x'."
        )

    to_path = Path(to_path)
    to_path.mkdir(parents=True, exist_ok=True)

    for file_path in tqdm.tqdm(
        self.file_paths,
        total=self.n_files,
        desc=f"Copying dataset from {self.folder_path} to {to_path}. {key_msg}",
    ):
        dst_file = to_path / Path(file_path).relative_to(self.folder_path)
        # Source files can live in nested sub-directories, and opening an HDF5 file
        # does not create missing parent directories.
        dst_file.parent.mkdir(parents=True, exist_ok=True)
        with File(file_path) as src, File(dst_file, mode) as dst:
            if all_keys:
                for obj in src.keys():
                    src.copy(obj, dst)
            else:
                src.copy_key(key, dst)
|
|
407
|
+
|
|
349
408
|
|
|
350
409
|
class Dataset(H5FileHandleCache):
|
|
351
410
|
"""Iterate over File(s) and Folder(s)."""
|
|
@@ -582,3 +641,10 @@ def count_samples_per_directory(file_names, directories):
|
|
|
582
641
|
)
|
|
583
642
|
|
|
584
643
|
return counts
|
|
644
|
+
|
|
645
|
+
|
|
646
|
+
if __name__ == "__main__":
    # Running this module directly exposes the same command-line interface as
    # ``python -m zea.data``, instead of copying a hard-coded, machine-specific path.
    from zea.data.__main__ import main

    main()
|
zea/data/file.py
CHANGED
|
@@ -175,6 +175,8 @@ class File(h5py.File):
|
|
|
175
175
|
if isinstance(key, enum.Enum):
|
|
176
176
|
key = key.value
|
|
177
177
|
|
|
178
|
+
assert isinstance(key, str), f"Key must be a string, got {type(key)}. "
|
|
179
|
+
|
|
178
180
|
# Return the key if it is in the file
|
|
179
181
|
if key in self.keys():
|
|
180
182
|
return key
|
|
@@ -448,6 +450,35 @@ class File(h5py.File):
|
|
|
448
450
|
def __str__(self):
    """Short description: file name and the mode the file was opened in."""
    file_name = self.path.name
    return f"zea HDF5 File: '{file_name}' (mode={self.mode})"
|
|
450
452
|
|
|
453
|
+
def copy_key(self, key: str, dst: "File"):
    """Copy a specific key to another file.

    The key is copied only if it exists in this file and not yet in ``dst``.
    Otherwise a warning is logged and the key is skipped. In every case:

    - the file-level (root) attributes of this file are copied to the root of ``dst``;
    - the ``scan`` group is copied if it exists here and is missing in ``dst``.

    Args:
        key (str): The key to copy. It is normalized with ``self.format_key`` first.
        dst (File): The destination file, opened in a writable mode.
    """
    key = self.format_key(key)

    # Copy the key if it does not already exist in the destination file
    if key in dst:
        log.warning(f"Skipping key '{key}' because it already exists in dst file {dst.path}.")
    elif key in self:
        self.copy(key, dst, name=key)
    else:
        log.warning(f"Key '{key}' not found in src file {self.path}. Skipping copy.")

    # Root attributes describe the file, so they go to the destination root. Writing
    # them to dst[key] also raised KeyError whenever the key had just been skipped
    # because it was missing from both files.
    for attr_key, attr_value in self.attrs.items():
        dst.attrs[attr_key] = attr_value

    # Copy the scan data if it exists and the destination does not have it yet
    if "scan" in self and "scan" not in dst:
        self.copy("scan", dst)
|
|
481
|
+
|
|
451
482
|
def summary(self):
|
|
452
483
|
"""Print the contents of the file."""
|
|
453
484
|
_print_hdf5_attrs(self)
|
zea/display.py
CHANGED
|
@@ -3,6 +3,7 @@
|
|
|
3
3
|
from functools import partial
|
|
4
4
|
from typing import Tuple, Union
|
|
5
5
|
|
|
6
|
+
import keras
|
|
6
7
|
import numpy as np
|
|
7
8
|
import scipy
|
|
8
9
|
from keras import ops
|
|
@@ -357,6 +358,11 @@ def map_coordinates(inputs, coordinates, order, fill_mode="constant", fill_value
|
|
|
357
358
|
|
|
358
359
|
def _interpolate_batch(images, coordinates, fill_value=0.0, order=1, vectorize=True):
|
|
359
360
|
"""Interpolate a batch of images."""
|
|
361
|
+
|
|
362
|
+
# TODO: figure out why tensorflow map_coordinates is broken
|
|
363
|
+
if keras.backend.backend() == "tensorflow":
|
|
364
|
+
assert order > 1, "Some bug in tensorflow in map_coordinates, set order > 1 to use scipy."
|
|
365
|
+
|
|
360
366
|
image_shape = images.shape
|
|
361
367
|
num_image_dims = coordinates.shape[0]
|
|
362
368
|
|
zea/internal/cache.py
CHANGED
|
@@ -80,7 +80,7 @@ _CACHE_DIR = ZEA_CACHE_DIR / "cached_funcs"
|
|
|
80
80
|
_CACHE_DIR.mkdir(parents=True, exist_ok=True)
|
|
81
81
|
|
|
82
82
|
|
|
83
|
-
def serialize_elements(key_elements: list):
|
|
83
|
+
def serialize_elements(key_elements: list, shorten: bool = False) -> str:
|
|
84
84
|
"""Serialize elements of a list to generate a cache key.
|
|
85
85
|
|
|
86
86
|
In general uses the string representation of the elements unless
|
|
@@ -90,18 +90,18 @@ def serialize_elements(key_elements: list):
|
|
|
90
90
|
Args:
|
|
91
91
|
key_elements (list): List of elements to serialize. Can be nested lists
|
|
92
92
|
or tuples. In this case the elements are serialized recursively.
|
|
93
|
+
shorten (bool): If True, the serialized string is hashed to a shorter
|
|
94
|
+
representation using MD5. Defaults to False.
|
|
93
95
|
|
|
94
96
|
Returns:
|
|
95
|
-
|
|
96
|
-
lists of tuples those are combined into a single string.
|
|
97
|
+
str: A serialized string representation of the elements, joined by underscores.
|
|
97
98
|
|
|
98
99
|
"""
|
|
99
100
|
serialized_elements = []
|
|
100
101
|
for element in key_elements:
|
|
101
102
|
if isinstance(element, (list, tuple)):
|
|
102
103
|
# If element is a list or tuple, serialize its elements recursively
|
|
103
|
-
|
|
104
|
-
serialized_elements.append("_".join(element))
|
|
104
|
+
serialized_elements.append(serialize_elements(element))
|
|
105
105
|
elif hasattr(element, "serialized"):
|
|
106
106
|
# Use the serialized attribute if it exists (e.g. for zea.core.Object)
|
|
107
107
|
serialized_elements.append(str(element.serialized))
|
|
@@ -120,7 +120,10 @@ def serialize_elements(key_elements: list):
|
|
|
120
120
|
element = hashlib.md5(element).hexdigest()
|
|
121
121
|
serialized_elements.append(element)
|
|
122
122
|
|
|
123
|
-
|
|
123
|
+
serialized = "_".join(serialized_elements)
|
|
124
|
+
if shorten:
|
|
125
|
+
return hashlib.md5(serialized.encode()).hexdigest()
|
|
126
|
+
return serialized
|
|
124
127
|
|
|
125
128
|
|
|
126
129
|
def get_function_source(func):
|
|
@@ -182,8 +185,10 @@ def generate_cache_key(func, args, kwargs, arg_names):
|
|
|
182
185
|
if name in bound_args.arguments:
|
|
183
186
|
key_elements.append(bound_args.arguments[name])
|
|
184
187
|
|
|
185
|
-
|
|
186
|
-
|
|
188
|
+
# Add keras backend
|
|
189
|
+
key_elements.append(keras.backend.backend())
|
|
190
|
+
|
|
191
|
+
return f"{func.__qualname__}_" + serialize_elements(key_elements, shorten=True)
|
|
187
192
|
|
|
188
193
|
|
|
189
194
|
def cache_output(*arg_names, verbose=False):
|