zea 0.0.4__py3-none-any.whl → 0.0.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- zea/__init__.py +8 -7
- zea/__main__.py +8 -26
- zea/data/__main__.py +6 -3
- zea/data/file.py +19 -74
- zea/display.py +1 -5
- zea/doppler.py +75 -0
- zea/internal/_generate_keras_ops.py +125 -0
- zea/internal/core.py +10 -3
- zea/internal/device.py +33 -16
- zea/internal/notebooks.py +39 -0
- zea/internal/operators.py +10 -0
- zea/internal/parameters.py +75 -19
- zea/internal/viewer.py +24 -24
- zea/io_lib.py +60 -62
- zea/keras_ops.py +1989 -0
- zea/models/__init__.py +6 -3
- zea/models/deeplabv3.py +131 -0
- zea/models/diffusion.py +4 -4
- zea/models/echonetlvh.py +290 -0
- zea/models/presets.py +14 -0
- zea/ops.py +28 -45
- zea/scan.py +10 -3
- zea/tensor_ops.py +150 -0
- zea/tools/fit_scan_cone.py +2 -2
- zea/tools/selection_tool.py +28 -9
- {zea-0.0.4.dist-info → zea-0.0.5.dist-info}/METADATA +5 -2
- {zea-0.0.4.dist-info → zea-0.0.5.dist-info}/RECORD +30 -25
- zea/internal/convert.py +0 -150
- {zea-0.0.4.dist-info → zea-0.0.5.dist-info}/LICENSE +0 -0
- {zea-0.0.4.dist-info → zea-0.0.5.dist-info}/WHEEL +0 -0
- {zea-0.0.4.dist-info → zea-0.0.5.dist-info}/entry_points.txt +0 -0
zea/__init__.py
CHANGED
|
@@ -7,10 +7,10 @@ from . import log
|
|
|
7
7
|
|
|
8
8
|
# dynamically add __version__ attribute (see pyproject.toml)
|
|
9
9
|
# __version__ = __import__("importlib.metadata").metadata.version(__package__)
|
|
10
|
-
__version__ = "0.0.
|
|
10
|
+
__version__ = "0.0.5"
|
|
11
11
|
|
|
12
12
|
|
|
13
|
-
def
|
|
13
|
+
def _bootstrap_backend():
|
|
14
14
|
"""Setup function to initialize the zea package."""
|
|
15
15
|
|
|
16
16
|
def _check_backend_installed():
|
|
@@ -40,14 +40,14 @@ def setup():
|
|
|
40
40
|
|
|
41
41
|
_check_backend_installed()
|
|
42
42
|
|
|
43
|
-
import
|
|
43
|
+
from keras.backend import backend as keras_backend
|
|
44
44
|
|
|
45
|
-
log.info(f"Using backend {
|
|
45
|
+
log.info(f"Using backend {keras_backend()!r}")
|
|
46
46
|
|
|
47
47
|
|
|
48
48
|
# call and clean up namespace
|
|
49
|
-
|
|
50
|
-
del
|
|
49
|
+
_bootstrap_backend()
|
|
50
|
+
del _bootstrap_backend
|
|
51
51
|
|
|
52
52
|
from . import (
|
|
53
53
|
agent,
|
|
@@ -55,6 +55,7 @@ from . import (
|
|
|
55
55
|
data,
|
|
56
56
|
display,
|
|
57
57
|
io_lib,
|
|
58
|
+
keras_ops,
|
|
58
59
|
metrics,
|
|
59
60
|
models,
|
|
60
61
|
simulator,
|
|
@@ -68,7 +69,7 @@ from .data.file import File, load_file
|
|
|
68
69
|
from .datapaths import set_data_paths
|
|
69
70
|
from .interface import Interface
|
|
70
71
|
from .internal.device import init_device
|
|
71
|
-
from .internal.setup_zea import
|
|
72
|
+
from .internal.setup_zea import setup, setup_config
|
|
72
73
|
from .ops import Pipeline
|
|
73
74
|
from .probes import Probe
|
|
74
75
|
from .scan import Scan
|
zea/__main__.py
CHANGED
|
@@ -9,30 +9,22 @@ import argparse
|
|
|
9
9
|
import sys
|
|
10
10
|
from pathlib import Path
|
|
11
11
|
|
|
12
|
-
from zea import log
|
|
13
12
|
from zea.visualize import set_mpl_style
|
|
14
13
|
|
|
15
14
|
|
|
16
|
-
def
|
|
15
|
+
def get_parser():
|
|
17
16
|
"""Command line argument parser"""
|
|
18
|
-
parser = argparse.ArgumentParser(
|
|
19
|
-
|
|
17
|
+
parser = argparse.ArgumentParser(
|
|
18
|
+
description="Load and process ultrasound data based on a configuration file."
|
|
19
|
+
)
|
|
20
|
+
parser.add_argument("-c", "--config", type=str, default=None, help="path to the config file.")
|
|
20
21
|
parser.add_argument(
|
|
21
22
|
"-t",
|
|
22
23
|
"--task",
|
|
23
24
|
default="view",
|
|
24
25
|
choices=["view"],
|
|
25
26
|
type=str,
|
|
26
|
-
help="
|
|
27
|
-
)
|
|
28
|
-
parser.add_argument(
|
|
29
|
-
"--backend",
|
|
30
|
-
default=None,
|
|
31
|
-
type=str,
|
|
32
|
-
help=(
|
|
33
|
-
"Keras backend to use. Default is the one set by the environment "
|
|
34
|
-
"variable KERAS_BACKEND."
|
|
35
|
-
),
|
|
27
|
+
help="Which task to run. Currently only 'view' is supported.",
|
|
36
28
|
)
|
|
37
29
|
parser.add_argument(
|
|
38
30
|
"--skip_validate_file",
|
|
@@ -40,27 +32,18 @@ def get_args():
|
|
|
40
32
|
action="store_true",
|
|
41
33
|
help="Skip zea file integrity checks. Use with caution.",
|
|
42
34
|
)
|
|
43
|
-
parser
|
|
44
|
-
args = parser.parse_args()
|
|
45
|
-
return args
|
|
35
|
+
return parser
|
|
46
36
|
|
|
47
37
|
|
|
48
38
|
def main():
|
|
49
39
|
"""main entrypoint for zea"""
|
|
50
|
-
args =
|
|
40
|
+
args = get_parser().parse_args()
|
|
51
41
|
|
|
52
42
|
set_mpl_style()
|
|
53
43
|
|
|
54
|
-
if args.backend:
|
|
55
|
-
from zea.internal.setup_zea import set_backend
|
|
56
|
-
|
|
57
|
-
set_backend(args.backend)
|
|
58
|
-
|
|
59
44
|
wd = Path(__file__).parent.resolve()
|
|
60
45
|
sys.path.append(str(wd))
|
|
61
46
|
|
|
62
|
-
import keras
|
|
63
|
-
|
|
64
47
|
from zea.interface import Interface
|
|
65
48
|
from zea.internal.setup_zea import setup
|
|
66
49
|
|
|
@@ -72,7 +55,6 @@ def main():
|
|
|
72
55
|
validate_file=not args.skip_validate_file,
|
|
73
56
|
)
|
|
74
57
|
|
|
75
|
-
log.info(f"Using {keras.backend.backend()} backend")
|
|
76
58
|
cli.run(plot=True)
|
|
77
59
|
else:
|
|
78
60
|
raise ValueError(f"Unknown task {args.task}, see `zea --help` for available tasks.")
|
zea/data/__main__.py
CHANGED
|
@@ -9,8 +9,8 @@ import argparse
|
|
|
9
9
|
from zea import Folder
|
|
10
10
|
|
|
11
11
|
|
|
12
|
-
def
|
|
13
|
-
parser = argparse.ArgumentParser(description="Copy a zea.Folder to a new location.")
|
|
12
|
+
def get_parser():
|
|
13
|
+
parser = argparse.ArgumentParser(description="Copy a :class:`zea.Folder` to a new location.")
|
|
14
14
|
parser.add_argument("src", help="Source folder path")
|
|
15
15
|
parser.add_argument("dst", help="Destination folder path")
|
|
16
16
|
parser.add_argument("key", help="Key to access in the hdf5 files")
|
|
@@ -20,8 +20,11 @@ def main():
|
|
|
20
20
|
choices=["a", "w", "r+", "x"],
|
|
21
21
|
help="Mode in which to open the destination files (default: 'a')",
|
|
22
22
|
)
|
|
23
|
+
return parser
|
|
24
|
+
|
|
23
25
|
|
|
24
|
-
|
|
26
|
+
def main():
|
|
27
|
+
args = get_parser().parse_args()
|
|
25
28
|
|
|
26
29
|
src_folder = Folder(args.src, args.key, validate=False)
|
|
27
30
|
src_folder.copy(args.dst, args.key, mode=args.mode)
|
zea/data/file.py
CHANGED
|
@@ -290,7 +290,7 @@ class File(h5py.File):
|
|
|
290
290
|
"""
|
|
291
291
|
scan_parameters = {}
|
|
292
292
|
if "scan" in self:
|
|
293
|
-
scan_parameters = recursively_load_dict_contents_from_group(
|
|
293
|
+
scan_parameters = self.recursively_load_dict_contents_from_group("scan")
|
|
294
294
|
elif "event" in list(self.keys())[0]:
|
|
295
295
|
if event is None:
|
|
296
296
|
raise ValueError(
|
|
@@ -305,38 +305,17 @@ class File(h5py.File):
|
|
|
305
305
|
f"Found number of events: {len(self.keys())}."
|
|
306
306
|
)
|
|
307
307
|
|
|
308
|
-
scan_parameters = recursively_load_dict_contents_from_group(
|
|
308
|
+
scan_parameters = self.recursively_load_dict_contents_from_group(f"event_{event}/scan")
|
|
309
309
|
else:
|
|
310
310
|
log.warning("Could not find scan parameters in file.")
|
|
311
311
|
|
|
312
312
|
return scan_parameters
|
|
313
313
|
|
|
314
314
|
def get_scan_parameters(self, event=None) -> dict:
|
|
315
|
-
"""Returns a dictionary of
|
|
316
|
-
|
|
315
|
+
"""Returns a dictionary of scan parameters stored in the file."""
|
|
316
|
+
return self.get_parameters(event)
|
|
317
317
|
|
|
318
|
-
|
|
319
|
-
dict: The default parameters (the keys are identical to the
|
|
320
|
-
__init__ parameters of the Scan class).
|
|
321
|
-
"""
|
|
322
|
-
file_scan_parameters = self.get_parameters(event)
|
|
323
|
-
|
|
324
|
-
scan_parameters = {}
|
|
325
|
-
for parameter, value in file_scan_parameters.items():
|
|
326
|
-
if parameter in Scan.VALID_PARAMS:
|
|
327
|
-
param_type = Scan.VALID_PARAMS[parameter]["type"]
|
|
328
|
-
if param_type in (bool, int, float):
|
|
329
|
-
scan_parameters[parameter] = param_type(value)
|
|
330
|
-
elif isinstance(param_type, tuple) and float in param_type:
|
|
331
|
-
scan_parameters[parameter] = float(value)
|
|
332
|
-
else:
|
|
333
|
-
scan_parameters[parameter] = value
|
|
334
|
-
|
|
335
|
-
if len(scan_parameters) == 0:
|
|
336
|
-
log.info(f"Could not find proper scan parameters in {self}.")
|
|
337
|
-
return scan_parameters
|
|
338
|
-
|
|
339
|
-
def scan(self, event=None, **kwargs) -> Scan:
|
|
318
|
+
def scan(self, event=None, safe=True, **kwargs) -> Scan:
|
|
340
319
|
"""Returns a Scan object initialized with the parameters from the file.
|
|
341
320
|
|
|
342
321
|
Args:
|
|
@@ -348,6 +327,9 @@ class File(h5py.File):
|
|
|
348
327
|
...
|
|
349
328
|
|
|
350
329
|
Defaults to None. In that case no event structure is expected.
|
|
330
|
+
safe (bool, optional): If True, will only use parameters that are
|
|
331
|
+
defined in the Scan class. If False, will use all parameters
|
|
332
|
+
from the file. Defaults to True.
|
|
351
333
|
**kwargs: Additional keyword arguments to pass to the Scan object.
|
|
352
334
|
These will override the parameters from the file if they are
|
|
353
335
|
present in the file.
|
|
@@ -355,7 +337,7 @@ class File(h5py.File):
|
|
|
355
337
|
Returns:
|
|
356
338
|
Scan: The scan object.
|
|
357
339
|
"""
|
|
358
|
-
return Scan.merge(self.get_scan_parameters(event), kwargs)
|
|
340
|
+
return Scan.merge(self.get_scan_parameters(event), kwargs, safe=safe)
|
|
359
341
|
|
|
360
342
|
def get_probe_parameters(self, event=None) -> dict:
|
|
361
343
|
"""Returns a dictionary of probe parameters to initialize a probe
|
|
@@ -388,21 +370,24 @@ class File(h5py.File):
|
|
|
388
370
|
probe_parameters_file = self.get_probe_parameters(event)
|
|
389
371
|
return Probe.from_parameters(self.probe_name, probe_parameters_file)
|
|
390
372
|
|
|
391
|
-
def recursively_load_dict_contents_from_group(self, path: str
|
|
373
|
+
def recursively_load_dict_contents_from_group(self, path: str) -> dict:
|
|
392
374
|
"""Load dict from contents of group
|
|
393
375
|
|
|
394
376
|
Values inside the group are converted to numpy arrays
|
|
395
|
-
or primitive types (int, float, str).
|
|
396
|
-
arrays are converted to the corresponding primitive type (if squeeze=True)
|
|
377
|
+
or primitive types (int, float, str).
|
|
397
378
|
|
|
398
379
|
Args:
|
|
399
380
|
path (str): path to group
|
|
400
|
-
squeeze (bool, optional): squeeze arrays with single element.
|
|
401
|
-
Defaults to False.
|
|
402
381
|
Returns:
|
|
403
382
|
dict: dictionary with contents of group
|
|
404
383
|
"""
|
|
405
|
-
|
|
384
|
+
ans = {}
|
|
385
|
+
for key, item in self[path].items():
|
|
386
|
+
if isinstance(item, h5py.Dataset):
|
|
387
|
+
ans[key] = item[()]
|
|
388
|
+
elif isinstance(item, h5py.Group):
|
|
389
|
+
ans[key] = self.recursively_load_dict_contents_from_group(path + "/" + key + "/")
|
|
390
|
+
return ans
|
|
406
391
|
|
|
407
392
|
@classmethod
|
|
408
393
|
def get_shape(cls, path: str, key: str) -> tuple:
|
|
@@ -519,54 +504,14 @@ def load_file(
|
|
|
519
504
|
# the number of selected transmits
|
|
520
505
|
if data_type in ["raw_data", "aligned_data"]:
|
|
521
506
|
indices = File._prepare_indices(indices)
|
|
522
|
-
n_tx = data.shape[1]
|
|
523
507
|
if isinstance(indices, tuple) and len(indices) > 1:
|
|
524
|
-
|
|
525
|
-
transmits = np.arange(n_tx)[tx_idx]
|
|
526
|
-
scan_kwargs["selected_transmits"] = transmits
|
|
508
|
+
scan_kwargs["selected_transmits"] = indices[1]
|
|
527
509
|
|
|
528
510
|
scan = file.scan(**scan_kwargs)
|
|
529
511
|
|
|
530
512
|
return data, scan, probe
|
|
531
513
|
|
|
532
514
|
|
|
533
|
-
def recursively_load_dict_contents_from_group(
|
|
534
|
-
h5file: h5py._hl.files.File, path: str, squeeze: bool = False
|
|
535
|
-
) -> dict:
|
|
536
|
-
"""Load dict from contents of group
|
|
537
|
-
|
|
538
|
-
Values inside the group are converted to numpy arrays
|
|
539
|
-
or primitive types (int, float, str). Single element
|
|
540
|
-
arrays are converted to the corresponding primitive type (if squeeze=True)
|
|
541
|
-
|
|
542
|
-
Args:
|
|
543
|
-
h5file (h5py._hl.files.File): h5py file object
|
|
544
|
-
path (str): path to group
|
|
545
|
-
squeeze (bool, optional): squeeze arrays with single element.
|
|
546
|
-
Defaults to False.
|
|
547
|
-
Returns:
|
|
548
|
-
dict: dictionary with contents of group
|
|
549
|
-
"""
|
|
550
|
-
ans = {}
|
|
551
|
-
for key, item in h5file[path].items():
|
|
552
|
-
if isinstance(item, h5py._hl.dataset.Dataset):
|
|
553
|
-
ans[key] = item[()]
|
|
554
|
-
# all ones in shape
|
|
555
|
-
if squeeze:
|
|
556
|
-
if ans[key].shape == () or all(i == 1 for i in ans[key].shape):
|
|
557
|
-
# check for strings
|
|
558
|
-
if isinstance(ans[key], str):
|
|
559
|
-
ans[key] = str(ans[key])
|
|
560
|
-
# check for integers
|
|
561
|
-
elif int(ans[key]) == float(ans[key]):
|
|
562
|
-
ans[key] = int(ans[key])
|
|
563
|
-
else:
|
|
564
|
-
ans[key] = float(ans[key])
|
|
565
|
-
elif isinstance(item, h5py._hl.group.Group):
|
|
566
|
-
ans[key] = recursively_load_dict_contents_from_group(h5file, path + "/" + key + "/")
|
|
567
|
-
return ans
|
|
568
|
-
|
|
569
|
-
|
|
570
515
|
def _print_hdf5_attrs(hdf5_obj, prefix=""):
|
|
571
516
|
"""Recursively prints all keys, attributes, and shapes in an HDF5 file.
|
|
572
517
|
|
zea/display.py
CHANGED
|
@@ -3,7 +3,6 @@
|
|
|
3
3
|
from functools import partial
|
|
4
4
|
from typing import Tuple, Union
|
|
5
5
|
|
|
6
|
-
import keras
|
|
7
6
|
import numpy as np
|
|
8
7
|
import scipy
|
|
9
8
|
from keras import ops
|
|
@@ -342,6 +341,7 @@ def map_coordinates(inputs, coordinates, order, fill_mode="constant", fill_value
|
|
|
342
341
|
"""map_coordinates using keras.ops or scipy.ndimage when order > 1."""
|
|
343
342
|
if order > 1:
|
|
344
343
|
inputs = ops.convert_to_numpy(inputs)
|
|
344
|
+
coordinates = ops.convert_to_numpy(coordinates)
|
|
345
345
|
out = scipy.ndimage.map_coordinates(
|
|
346
346
|
inputs, coordinates, order=order, mode=fill_mode, cval=fill_value
|
|
347
347
|
)
|
|
@@ -359,10 +359,6 @@ def map_coordinates(inputs, coordinates, order, fill_mode="constant", fill_value
|
|
|
359
359
|
def _interpolate_batch(images, coordinates, fill_value=0.0, order=1, vectorize=True):
|
|
360
360
|
"""Interpolate a batch of images."""
|
|
361
361
|
|
|
362
|
-
# TODO: figure out why tensorflow map_coordinates is broken
|
|
363
|
-
if keras.backend.backend() == "tensorflow":
|
|
364
|
-
assert order > 1, "Some bug in tensorflow in map_coordinates, set order > 1 to use scipy."
|
|
365
|
-
|
|
366
362
|
image_shape = images.shape
|
|
367
363
|
num_image_dims = coordinates.shape[0]
|
|
368
364
|
|
zea/doppler.py
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
1
|
+
"""Doppler functions for processing I/Q ultrasound data."""
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
from keras import ops
|
|
5
|
+
|
|
6
|
+
from zea import tensor_ops
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def color_doppler(
|
|
10
|
+
data,
|
|
11
|
+
center_frequency,
|
|
12
|
+
pulse_repetition_frequency,
|
|
13
|
+
sound_speed,
|
|
14
|
+
hamming_size=None,
|
|
15
|
+
lag=1,
|
|
16
|
+
):
|
|
17
|
+
"""Compute Color Doppler from packet of I/Q Data.
|
|
18
|
+
|
|
19
|
+
Args:
|
|
20
|
+
data (ndarray): I/Q complex data of shape (n_frames, grid_size_z, grid_size_x).
|
|
21
|
+
n_frames corresponds to the ensemble length used to compute
|
|
22
|
+
the Doppler signal.
|
|
23
|
+
center_frequency (float): Center frequency of the ultrasound probe in Hz.
|
|
24
|
+
pulse_repetition_frequency (float): Pulse repetition frequency in Hz.
|
|
25
|
+
sound_speed (float): Speed of sound in the medium in m/s.
|
|
26
|
+
hamming_size (int or tuple, optional): Size of the Hamming window to apply
|
|
27
|
+
for spatial averaging. If None, no window is applied.
|
|
28
|
+
If an integer, it is applied to both dimensions. If a tuple, it should
|
|
29
|
+
contain two integers for the row and column dimensions.
|
|
30
|
+
lag (int, optional): Lag for the auto-correlation computation.
|
|
31
|
+
Defaults to 1, meaning Doppler is computed from the current frame
|
|
32
|
+
and the next frame.
|
|
33
|
+
|
|
34
|
+
Returns:
|
|
35
|
+
doppler_velocities (ndarray): Doppler velocity map of shape (grid_size_z, grid_size_x) in
|
|
36
|
+
meters/second.
|
|
37
|
+
|
|
38
|
+
"""
|
|
39
|
+
assert data.ndim == 3, "Data must be a 3-D array"
|
|
40
|
+
if not (isinstance(lag, int) and lag >= 1):
|
|
41
|
+
raise ValueError("lag must be an integer >= 1")
|
|
42
|
+
n_frames = data.shape[0]
|
|
43
|
+
assert n_frames > lag, "Data must have more frames than the lag"
|
|
44
|
+
|
|
45
|
+
if hamming_size is None:
|
|
46
|
+
hamming_size = np.array([1, 1], dtype=int)
|
|
47
|
+
elif np.isscalar(hamming_size):
|
|
48
|
+
hamming_size = np.array([int(hamming_size), int(hamming_size)], dtype=int)
|
|
49
|
+
else:
|
|
50
|
+
assert len(hamming_size) == 2, "hamming_size must be an integer or a tuple of two integers"
|
|
51
|
+
hamming_size = np.array(hamming_size, dtype=int)
|
|
52
|
+
if not np.all(hamming_size > 0):
|
|
53
|
+
raise ValueError("hamming_size must contain integers > 0")
|
|
54
|
+
|
|
55
|
+
# Auto-correlation method
|
|
56
|
+
iq1 = data[: n_frames - lag]
|
|
57
|
+
iq2 = data[lag:]
|
|
58
|
+
autocorr = ops.sum(iq1 * ops.conj(iq2), axis=0) # Ensemble auto-correlation
|
|
59
|
+
|
|
60
|
+
# Spatial weighted average
|
|
61
|
+
if hamming_size[0] != 1 and hamming_size[1] != 1:
|
|
62
|
+
h_row = np.hamming(hamming_size[0])
|
|
63
|
+
h_col = np.hamming(hamming_size[1])
|
|
64
|
+
autocorr = tensor_ops.apply_along_axis(
|
|
65
|
+
lambda x: tensor_ops.correlate(x, h_row, mode="same"), 0, autocorr
|
|
66
|
+
)
|
|
67
|
+
autocorr = tensor_ops.apply_along_axis(
|
|
68
|
+
lambda x: tensor_ops.correlate(x, h_col, mode="same"), 1, autocorr
|
|
69
|
+
)
|
|
70
|
+
|
|
71
|
+
# Doppler velocity
|
|
72
|
+
nyquist_velocity = sound_speed * pulse_repetition_frequency / (4 * center_frequency * lag)
|
|
73
|
+
phase = ops.arctan2(ops.imag(autocorr), ops.real(autocorr))
|
|
74
|
+
doppler_velocities = -nyquist_velocity * phase / np.pi
|
|
75
|
+
return doppler_velocities
|
|
@@ -0,0 +1,125 @@
|
|
|
1
|
+
"""This file creates a :class:`zea.Operation` for all unary :mod:`keras.ops`
|
|
2
|
+
and :mod:`keras.ops.image` functions.
|
|
3
|
+
|
|
4
|
+
They can be used in zea pipelines like any other :class:`zea.Operation`, for example:
|
|
5
|
+
|
|
6
|
+
.. code-block:: python
|
|
7
|
+
|
|
8
|
+
from zea.keras_ops import Squeeze
|
|
9
|
+
|
|
10
|
+
op = Squeeze(axis=1)
|
|
11
|
+
"""
|
|
12
|
+
|
|
13
|
+
import inspect
|
|
14
|
+
import shutil
|
|
15
|
+
import tempfile
|
|
16
|
+
from pathlib import Path
|
|
17
|
+
|
|
18
|
+
import keras
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def _filter_funcs_by_first_arg(funcs, arg_name):
|
|
22
|
+
"""Filter a list of (name, func) tuples to those whose first argument matches arg_name."""
|
|
23
|
+
filtered = []
|
|
24
|
+
for name, func in funcs:
|
|
25
|
+
try:
|
|
26
|
+
sig = inspect.signature(func)
|
|
27
|
+
params = list(sig.parameters.keys())
|
|
28
|
+
if params and params[0] == arg_name:
|
|
29
|
+
filtered.append((name, func))
|
|
30
|
+
except (ValueError, TypeError):
|
|
31
|
+
# Skip functions that can't be inspected
|
|
32
|
+
continue
|
|
33
|
+
return filtered
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def _functions_from_namespace(namespace):
|
|
37
|
+
"""Get all functions from a given namespace."""
|
|
38
|
+
return [(name, obj) for name, obj in inspect.getmembers(namespace) if inspect.isfunction(obj)]
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def _unary_functions_from_namespace(namespace, arg_name="x"):
|
|
42
|
+
"""Get all unary functions from a given namespace."""
|
|
43
|
+
funcs = _functions_from_namespace(namespace)
|
|
44
|
+
return _filter_funcs_by_first_arg(funcs, arg_name)
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def _snake_to_pascal(name):
|
|
48
|
+
"""Convert a snake_case name to PascalCase."""
|
|
49
|
+
return "".join(word.capitalize() for word in name.split("_"))
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def _generate_operation_class_code(name, namespace):
|
|
53
|
+
"""Generate Python code for a zea.Operation class for a given keras.ops function."""
|
|
54
|
+
class_name = _snake_to_pascal(name)
|
|
55
|
+
module_path = f"{namespace.__name__}.{name}"
|
|
56
|
+
doc = f"Operation wrapping {module_path}."
|
|
57
|
+
|
|
58
|
+
return f'''
|
|
59
|
+
@ops_registry("{module_path}")
|
|
60
|
+
class {class_name}(Lambda):
|
|
61
|
+
"""{doc}"""
|
|
62
|
+
|
|
63
|
+
def __init__(self, **kwargs):
|
|
64
|
+
try:
|
|
65
|
+
super().__init__(func={module_path}, **kwargs)
|
|
66
|
+
except AttributeError as e:
|
|
67
|
+
raise MissingKerasOps("{class_name}", "{module_path}") from e
|
|
68
|
+
'''
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def _generate_ops_file():
|
|
72
|
+
"""Generate a .py file with all operation class definitions."""
|
|
73
|
+
|
|
74
|
+
# File header with version info
|
|
75
|
+
content = f'''"""Auto-generated :class:`zea.Operation` for all unary :mod:`keras.ops`
|
|
76
|
+
and :mod:`keras.ops.image` functions.
|
|
77
|
+
|
|
78
|
+
They can be used in zea pipelines like any other :class:`zea.Operation`, for example:
|
|
79
|
+
|
|
80
|
+
.. code-block:: python
|
|
81
|
+
|
|
82
|
+
from zea.keras_ops import Squeeze
|
|
83
|
+
|
|
84
|
+
op = Squeeze(axis=1)
|
|
85
|
+
|
|
86
|
+
This file is generated automatically. Do not edit manually.
|
|
87
|
+
Generated with Keras {keras.__version__}
|
|
88
|
+
"""
|
|
89
|
+
|
|
90
|
+
import keras
|
|
91
|
+
|
|
92
|
+
from zea.internal.registry import ops_registry
|
|
93
|
+
from zea.ops import Lambda
|
|
94
|
+
|
|
95
|
+
class MissingKerasOps(ValueError):
|
|
96
|
+
def __init__(self, class_name: str, func: str):
|
|
97
|
+
super().__init__(
|
|
98
|
+
f"Failed to create {{class_name}} with {{func}}. " +
|
|
99
|
+
"This may be due to an incompatible version of `keras`. " +
|
|
100
|
+
"Please try to upgrade `keras` to the latest version by running " +
|
|
101
|
+
"`pip install --upgrade keras`."
|
|
102
|
+
)
|
|
103
|
+
|
|
104
|
+
'''
|
|
105
|
+
|
|
106
|
+
for name, _ in _unary_functions_from_namespace(keras.ops, "x"):
|
|
107
|
+
content += _generate_operation_class_code(name, keras.ops)
|
|
108
|
+
|
|
109
|
+
for name, _ in _unary_functions_from_namespace(keras.ops.image, "images"):
|
|
110
|
+
content += _generate_operation_class_code(name, keras.ops.image)
|
|
111
|
+
|
|
112
|
+
# Write to a temporary file first, then move to final location
|
|
113
|
+
target_path = Path(__file__).parent.parent / "keras_ops.py"
|
|
114
|
+
with tempfile.NamedTemporaryFile("w", delete=False, encoding="utf-8") as tmp_file:
|
|
115
|
+
tmp_file.write(content)
|
|
116
|
+
temp_path = Path(tmp_file.name)
|
|
117
|
+
|
|
118
|
+
# Atomic move to avoid partial writes
|
|
119
|
+
shutil.move(temp_path, target_path)
|
|
120
|
+
|
|
121
|
+
print("Done generating `keras_ops.py`.")
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
if __name__ == "__main__":
|
|
125
|
+
_generate_ops_file()
|
zea/internal/core.py
CHANGED
|
@@ -119,11 +119,18 @@ class Object:
|
|
|
119
119
|
return cls(**reduced_params)
|
|
120
120
|
|
|
121
121
|
@classmethod
|
|
122
|
-
def merge(cls, obj1: dict, obj2: dict):
|
|
123
|
-
"""Merge multiple objects and safely initialize a new object.
|
|
122
|
+
def merge(cls, obj1: dict, obj2: dict, safe: bool = False):
|
|
123
|
+
"""Merge multiple objects and safely initialize a new object.
|
|
124
|
+
|
|
125
|
+
Optionally can safely initialize the object, which removes any invalid
|
|
126
|
+
arguments.
|
|
127
|
+
"""
|
|
124
128
|
# TODO: support actual zea.core.Objects, now we only support dictionaries
|
|
125
129
|
params = update_dictionary(obj1, obj2)
|
|
126
|
-
|
|
130
|
+
if not safe:
|
|
131
|
+
return cls(**params)
|
|
132
|
+
else:
|
|
133
|
+
return cls.safe_initialize(**params)
|
|
127
134
|
|
|
128
135
|
@classmethod
|
|
129
136
|
def _tree_unflatten(cls, aux, children):
|
zea/internal/device.py
CHANGED
|
@@ -64,24 +64,43 @@ def get_gpu_memory(verbose=True):
|
|
|
64
64
|
|
|
65
65
|
Returns:
|
|
66
66
|
memory_free_values: list of available memory for each gpu in MiB.
|
|
67
|
+
Returns empty list if nvidia-smi is not available.
|
|
67
68
|
"""
|
|
68
69
|
if not check_nvidia_smi():
|
|
69
70
|
log.warning(
|
|
70
71
|
"nvidia-smi is not available. Please install nvidia-utils. "
|
|
71
|
-
"Cannot retrieve GPU memory. Falling back to CPU
|
|
72
|
+
"Cannot retrieve GPU memory. Falling back to CPU."
|
|
72
73
|
)
|
|
73
|
-
return
|
|
74
|
+
return []
|
|
74
75
|
|
|
75
76
|
def _output_to_list(x):
|
|
76
77
|
return x.decode("ascii").split("\n")[:-1]
|
|
77
78
|
|
|
78
|
-
COMMAND =
|
|
79
|
+
COMMAND = [
|
|
80
|
+
"nvidia-smi",
|
|
81
|
+
"--query-gpu=memory.free",
|
|
82
|
+
"--format=csv,noheader,nounits",
|
|
83
|
+
]
|
|
84
|
+
# Fail-safe timeout (seconds). Override with ZEA_NVIDIA_SMI_TIMEOUT; set <=0 to disable.
|
|
85
|
+
smi_timeout = float(os.getenv("ZEA_NVIDIA_SMI_TIMEOUT", "30"))
|
|
79
86
|
try:
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
87
|
+
if smi_timeout > 0:
|
|
88
|
+
raw = sp.check_output(COMMAND, timeout=smi_timeout)
|
|
89
|
+
else:
|
|
90
|
+
raw = sp.check_output(COMMAND)
|
|
91
|
+
memory_free_info = _output_to_list(raw)
|
|
92
|
+
except sp.TimeoutExpired:
|
|
93
|
+
log.warning(f"nvidia-smi timed out after {smi_timeout}s. Falling back to CPU.")
|
|
94
|
+
return []
|
|
95
|
+
except sp.SubprocessError as e:
|
|
96
|
+
log.warning(f"Failed to retrieve GPU memory: {e}")
|
|
97
|
+
return []
|
|
98
|
+
|
|
99
|
+
memory_free_values = [int(x) for x in memory_free_info]
|
|
83
100
|
|
|
84
|
-
|
|
101
|
+
if verbose:
|
|
102
|
+
header = "GPU settings"
|
|
103
|
+
print("-" * 2 + header.center(50 - 4, "-") + "-" * 2)
|
|
85
104
|
|
|
86
105
|
# only show enabled devices
|
|
87
106
|
if os.environ.get("CUDA_VISIBLE_DEVICES", "") != "":
|
|
@@ -89,10 +108,12 @@ def get_gpu_memory(verbose=True):
|
|
|
89
108
|
gpus = [int(gpu) for gpu in gpus.split(",")][: len(memory_free_values)]
|
|
90
109
|
if verbose:
|
|
91
110
|
# Report the number of disabled GPUs out of the total
|
|
92
|
-
num_disabled_gpus = len(memory_free_values) - len(gpus)
|
|
93
111
|
num_gpus = len(memory_free_values)
|
|
94
|
-
|
|
95
|
-
|
|
112
|
+
num_disabled_gpus = num_gpus - len(gpus)
|
|
113
|
+
if num_gpus > 0:
|
|
114
|
+
print(f"{num_disabled_gpus}/{num_gpus} GPUs were disabled")
|
|
115
|
+
else:
|
|
116
|
+
print("No GPUs detected by nvidia-smi.")
|
|
96
117
|
|
|
97
118
|
memory_free_values = [memory_free_values[gpu] for gpu in gpus]
|
|
98
119
|
|
|
@@ -253,15 +274,11 @@ def get_device(device="auto:1", verbose=True, hide_others=True):
|
|
|
253
274
|
os.environ["CUDA_VISIBLE_DEVICES"] = ""
|
|
254
275
|
# returns None to indicate CPU
|
|
255
276
|
|
|
256
|
-
if device.lower() == "cpu":
|
|
277
|
+
if isinstance(device, str) and device.lower() == "cpu":
|
|
257
278
|
return _cpu_case()
|
|
258
279
|
|
|
259
|
-
if verbose:
|
|
260
|
-
header = "GPU settings"
|
|
261
|
-
print("-" * 2 + header.center(50 - 4, "-") + "-" * 2)
|
|
262
|
-
|
|
263
280
|
memory = get_gpu_memory(verbose=verbose)
|
|
264
|
-
if memory
|
|
281
|
+
if len(memory) == 0: # nvidia-smi not working, fallback to CPU
|
|
265
282
|
return _cpu_case()
|
|
266
283
|
|
|
267
284
|
gpu_ids = list(range(len(memory)))
|
|
@@ -0,0 +1,39 @@
|
|
|
1
|
+
import matplotlib.pyplot as plt
|
|
2
|
+
import numpy as np
|
|
3
|
+
from matplotlib import animation
|
|
4
|
+
|
|
5
|
+
from zea import Scan
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def animate_images(
|
|
9
|
+
images, path, scan: Scan = None, interval=100, cmap="gray", figsize=(5, 4), dpi=80
|
|
10
|
+
):
|
|
11
|
+
"""Helper function to animate a list of images."""
|
|
12
|
+
if interval <= 0:
|
|
13
|
+
raise ValueError("interval must be a positive integer (milliseconds).")
|
|
14
|
+
if len(images) == 0:
|
|
15
|
+
raise ValueError("images must be a non-empty sequence.")
|
|
16
|
+
fig, ax = plt.subplots(figsize=figsize, dpi=dpi)
|
|
17
|
+
if scan is not None:
|
|
18
|
+
extent = scan.extent * 1e3 if getattr(scan, "extent") is not None else None
|
|
19
|
+
else:
|
|
20
|
+
extent = None
|
|
21
|
+
im = ax.imshow(np.array(images[0]), animated=True, cmap=cmap, extent=extent)
|
|
22
|
+
ax.set_xlabel("X (mm)")
|
|
23
|
+
ax.set_ylabel("Z (mm)")
|
|
24
|
+
|
|
25
|
+
def update(frame):
|
|
26
|
+
im.set_array(np.array(images[frame]))
|
|
27
|
+
return [im]
|
|
28
|
+
|
|
29
|
+
ani = animation.FuncAnimation(
|
|
30
|
+
fig,
|
|
31
|
+
update,
|
|
32
|
+
frames=len(images),
|
|
33
|
+
blit=True,
|
|
34
|
+
interval=interval,
|
|
35
|
+
)
|
|
36
|
+
plt.close(fig)
|
|
37
|
+
fps = max(1, 1000 // interval)
|
|
38
|
+
|
|
39
|
+
ani.save(path, writer="imagemagick", fps=fps)
|