zea 0.0.4__py3-none-any.whl → 0.0.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
zea/__init__.py CHANGED
@@ -7,10 +7,10 @@ from . import log
7
7
 
8
8
  # dynamically add __version__ attribute (see pyproject.toml)
9
9
  # __version__ = __import__("importlib.metadata").metadata.version(__package__)
10
- __version__ = "0.0.4"
10
+ __version__ = "0.0.5"
11
11
 
12
12
 
13
- def setup():
13
+ def _bootstrap_backend():
14
14
  """Setup function to initialize the zea package."""
15
15
 
16
16
  def _check_backend_installed():
@@ -40,14 +40,14 @@ def setup():
40
40
 
41
41
  _check_backend_installed()
42
42
 
43
- import keras
43
+ from keras.backend import backend as keras_backend
44
44
 
45
- log.info(f"Using backend {keras.backend.backend()!r}")
45
+ log.info(f"Using backend {keras_backend()!r}")
46
46
 
47
47
 
48
48
  # call and clean up namespace
49
- setup()
50
- del setup
49
+ _bootstrap_backend()
50
+ del _bootstrap_backend
51
51
 
52
52
  from . import (
53
53
  agent,
@@ -55,6 +55,7 @@ from . import (
55
55
  data,
56
56
  display,
57
57
  io_lib,
58
+ keras_ops,
58
59
  metrics,
59
60
  models,
60
61
  simulator,
@@ -68,7 +69,7 @@ from .data.file import File, load_file
68
69
  from .datapaths import set_data_paths
69
70
  from .interface import Interface
70
71
  from .internal.device import init_device
71
- from .internal.setup_zea import set_backend, setup, setup_config
72
+ from .internal.setup_zea import setup, setup_config
72
73
  from .ops import Pipeline
73
74
  from .probes import Probe
74
75
  from .scan import Scan
zea/__main__.py CHANGED
@@ -9,30 +9,22 @@ import argparse
9
9
  import sys
10
10
  from pathlib import Path
11
11
 
12
- from zea import log
13
12
  from zea.visualize import set_mpl_style
14
13
 
15
14
 
16
- def get_args():
15
+ def get_parser():
17
16
  """Command line argument parser"""
18
- parser = argparse.ArgumentParser(description="Process ultrasound data.")
19
- parser.add_argument("-c", "--config", type=str, default=None, help="path to config file.")
17
+ parser = argparse.ArgumentParser(
18
+ description="Load and process ultrasound data based on a configuration file."
19
+ )
20
+ parser.add_argument("-c", "--config", type=str, default=None, help="path to the config file.")
20
21
  parser.add_argument(
21
22
  "-t",
22
23
  "--task",
23
24
  default="view",
24
25
  choices=["view"],
25
26
  type=str,
26
- help="which task to run",
27
- )
28
- parser.add_argument(
29
- "--backend",
30
- default=None,
31
- type=str,
32
- help=(
33
- "Keras backend to use. Default is the one set by the environment "
34
- "variable KERAS_BACKEND."
35
- ),
27
+ help="Which task to run. Currently only 'view' is supported.",
36
28
  )
37
29
  parser.add_argument(
38
30
  "--skip_validate_file",
@@ -40,27 +32,18 @@ def get_args():
40
32
  action="store_true",
41
33
  help="Skip zea file integrity checks. Use with caution.",
42
34
  )
43
- parser.add_argument("--gui", default=False, action=argparse.BooleanOptionalAction)
44
- args = parser.parse_args()
45
- return args
35
+ return parser
46
36
 
47
37
 
48
38
  def main():
49
39
  """main entrypoint for zea"""
50
- args = get_args()
40
+ args = get_parser().parse_args()
51
41
 
52
42
  set_mpl_style()
53
43
 
54
- if args.backend:
55
- from zea.internal.setup_zea import set_backend
56
-
57
- set_backend(args.backend)
58
-
59
44
  wd = Path(__file__).parent.resolve()
60
45
  sys.path.append(str(wd))
61
46
 
62
- import keras
63
-
64
47
  from zea.interface import Interface
65
48
  from zea.internal.setup_zea import setup
66
49
 
@@ -72,7 +55,6 @@ def main():
72
55
  validate_file=not args.skip_validate_file,
73
56
  )
74
57
 
75
- log.info(f"Using {keras.backend.backend()} backend")
76
58
  cli.run(plot=True)
77
59
  else:
78
60
  raise ValueError(f"Unknown task {args.task}, see `zea --help` for available tasks.")
zea/data/__main__.py CHANGED
@@ -9,8 +9,8 @@ import argparse
9
9
  from zea import Folder
10
10
 
11
11
 
12
- def main():
13
- parser = argparse.ArgumentParser(description="Copy a zea.Folder to a new location.")
12
+ def get_parser():
13
+ parser = argparse.ArgumentParser(description="Copy a :class:`zea.Folder` to a new location.")
14
14
  parser.add_argument("src", help="Source folder path")
15
15
  parser.add_argument("dst", help="Destination folder path")
16
16
  parser.add_argument("key", help="Key to access in the hdf5 files")
@@ -20,8 +20,11 @@ def main():
20
20
  choices=["a", "w", "r+", "x"],
21
21
  help="Mode in which to open the destination files (default: 'a')",
22
22
  )
23
+ return parser
24
+
23
25
 
24
- args = parser.parse_args()
26
def main():
    """Entry point: copy one key of a zea Folder to a new location."""
    args = get_parser().parse_args()

    source = Folder(args.src, args.key, validate=False)
    source.copy(args.dst, args.key, mode=args.mode)
zea/data/file.py CHANGED
@@ -290,7 +290,7 @@ class File(h5py.File):
290
290
  """
291
291
  scan_parameters = {}
292
292
  if "scan" in self:
293
- scan_parameters = recursively_load_dict_contents_from_group(self, "scan")
293
+ scan_parameters = self.recursively_load_dict_contents_from_group("scan")
294
294
  elif "event" in list(self.keys())[0]:
295
295
  if event is None:
296
296
  raise ValueError(
@@ -305,38 +305,17 @@ class File(h5py.File):
305
305
  f"Found number of events: {len(self.keys())}."
306
306
  )
307
307
 
308
- scan_parameters = recursively_load_dict_contents_from_group(self, f"event_{event}/scan")
308
+ scan_parameters = self.recursively_load_dict_contents_from_group(f"event_{event}/scan")
309
309
  else:
310
310
  log.warning("Could not find scan parameters in file.")
311
311
 
312
312
  return scan_parameters
313
313
 
314
314
  def get_scan_parameters(self, event=None) -> dict:
315
- """Returns a dictionary of default parameters to initialize a scan
316
- object that works with the file.
315
+ """Returns a dictionary of scan parameters stored in the file."""
316
+ return self.get_parameters(event)
317
317
 
318
- Returns:
319
- dict: The default parameters (the keys are identical to the
320
- __init__ parameters of the Scan class).
321
- """
322
- file_scan_parameters = self.get_parameters(event)
323
-
324
- scan_parameters = {}
325
- for parameter, value in file_scan_parameters.items():
326
- if parameter in Scan.VALID_PARAMS:
327
- param_type = Scan.VALID_PARAMS[parameter]["type"]
328
- if param_type in (bool, int, float):
329
- scan_parameters[parameter] = param_type(value)
330
- elif isinstance(param_type, tuple) and float in param_type:
331
- scan_parameters[parameter] = float(value)
332
- else:
333
- scan_parameters[parameter] = value
334
-
335
- if len(scan_parameters) == 0:
336
- log.info(f"Could not find proper scan parameters in {self}.")
337
- return scan_parameters
338
-
339
- def scan(self, event=None, **kwargs) -> Scan:
318
+ def scan(self, event=None, safe=True, **kwargs) -> Scan:
340
319
  """Returns a Scan object initialized with the parameters from the file.
341
320
 
342
321
  Args:
@@ -348,6 +327,9 @@ class File(h5py.File):
348
327
  ...
349
328
 
350
329
  Defaults to None. In that case no event structure is expected.
330
+ safe (bool, optional): If True, will only use parameters that are
331
+ defined in the Scan class. If False, will use all parameters
332
+ from the file. Defaults to True.
351
333
  **kwargs: Additional keyword arguments to pass to the Scan object.
352
334
  These will override the parameters from the file if they are
353
335
  present in the file.
@@ -355,7 +337,7 @@ class File(h5py.File):
355
337
  Returns:
356
338
  Scan: The scan object.
357
339
  """
358
- return Scan.merge(self.get_scan_parameters(event), kwargs)
340
+ return Scan.merge(self.get_scan_parameters(event), kwargs, safe=safe)
359
341
 
360
342
  def get_probe_parameters(self, event=None) -> dict:
361
343
  """Returns a dictionary of probe parameters to initialize a probe
@@ -388,21 +370,24 @@ class File(h5py.File):
388
370
  probe_parameters_file = self.get_probe_parameters(event)
389
371
  return Probe.from_parameters(self.probe_name, probe_parameters_file)
390
372
 
391
- def recursively_load_dict_contents_from_group(self, path: str, squeeze: bool = False) -> dict:
373
+ def recursively_load_dict_contents_from_group(self, path: str) -> dict:
392
374
  """Load dict from contents of group
393
375
 
394
376
  Values inside the group are converted to numpy arrays
395
- or primitive types (int, float, str). Single element
396
- arrays are converted to the corresponding primitive type (if squeeze=True)
377
+ or primitive types (int, float, str).
397
378
 
398
379
  Args:
399
380
  path (str): path to group
400
- squeeze (bool, optional): squeeze arrays with single element.
401
- Defaults to False.
402
381
  Returns:
403
382
  dict: dictionary with contents of group
404
383
  """
405
- return recursively_load_dict_contents_from_group(self, path, squeeze)
384
+ ans = {}
385
+ for key, item in self[path].items():
386
+ if isinstance(item, h5py.Dataset):
387
+ ans[key] = item[()]
388
+ elif isinstance(item, h5py.Group):
389
+ ans[key] = self.recursively_load_dict_contents_from_group(path + "/" + key + "/")
390
+ return ans
406
391
 
407
392
  @classmethod
408
393
  def get_shape(cls, path: str, key: str) -> tuple:
@@ -519,54 +504,14 @@ def load_file(
519
504
  # the number of selected transmits
520
505
  if data_type in ["raw_data", "aligned_data"]:
521
506
  indices = File._prepare_indices(indices)
522
- n_tx = data.shape[1]
523
507
  if isinstance(indices, tuple) and len(indices) > 1:
524
- tx_idx = indices[1]
525
- transmits = np.arange(n_tx)[tx_idx]
526
- scan_kwargs["selected_transmits"] = transmits
508
+ scan_kwargs["selected_transmits"] = indices[1]
527
509
 
528
510
  scan = file.scan(**scan_kwargs)
529
511
 
530
512
  return data, scan, probe
531
513
 
532
514
 
533
- def recursively_load_dict_contents_from_group(
534
- h5file: h5py._hl.files.File, path: str, squeeze: bool = False
535
- ) -> dict:
536
- """Load dict from contents of group
537
-
538
- Values inside the group are converted to numpy arrays
539
- or primitive types (int, float, str). Single element
540
- arrays are converted to the corresponding primitive type (if squeeze=True)
541
-
542
- Args:
543
- h5file (h5py._hl.files.File): h5py file object
544
- path (str): path to group
545
- squeeze (bool, optional): squeeze arrays with single element.
546
- Defaults to False.
547
- Returns:
548
- dict: dictionary with contents of group
549
- """
550
- ans = {}
551
- for key, item in h5file[path].items():
552
- if isinstance(item, h5py._hl.dataset.Dataset):
553
- ans[key] = item[()]
554
- # all ones in shape
555
- if squeeze:
556
- if ans[key].shape == () or all(i == 1 for i in ans[key].shape):
557
- # check for strings
558
- if isinstance(ans[key], str):
559
- ans[key] = str(ans[key])
560
- # check for integers
561
- elif int(ans[key]) == float(ans[key]):
562
- ans[key] = int(ans[key])
563
- else:
564
- ans[key] = float(ans[key])
565
- elif isinstance(item, h5py._hl.group.Group):
566
- ans[key] = recursively_load_dict_contents_from_group(h5file, path + "/" + key + "/")
567
- return ans
568
-
569
-
570
515
  def _print_hdf5_attrs(hdf5_obj, prefix=""):
571
516
  """Recursively prints all keys, attributes, and shapes in an HDF5 file.
572
517
 
zea/display.py CHANGED
@@ -3,7 +3,6 @@
3
3
  from functools import partial
4
4
  from typing import Tuple, Union
5
5
 
6
- import keras
7
6
  import numpy as np
8
7
  import scipy
9
8
  from keras import ops
@@ -342,6 +341,7 @@ def map_coordinates(inputs, coordinates, order, fill_mode="constant", fill_value
342
341
  """map_coordinates using keras.ops or scipy.ndimage when order > 1."""
343
342
  if order > 1:
344
343
  inputs = ops.convert_to_numpy(inputs)
344
+ coordinates = ops.convert_to_numpy(coordinates)
345
345
  out = scipy.ndimage.map_coordinates(
346
346
  inputs, coordinates, order=order, mode=fill_mode, cval=fill_value
347
347
  )
@@ -359,10 +359,6 @@ def map_coordinates(inputs, coordinates, order, fill_mode="constant", fill_value
359
359
  def _interpolate_batch(images, coordinates, fill_value=0.0, order=1, vectorize=True):
360
360
  """Interpolate a batch of images."""
361
361
 
362
- # TODO: figure out why tensorflow map_coordinates is broken
363
- if keras.backend.backend() == "tensorflow":
364
- assert order > 1, "Some bug in tensorflow in map_coordinates, set order > 1 to use scipy."
365
-
366
362
  image_shape = images.shape
367
363
  num_image_dims = coordinates.shape[0]
368
364
 
zea/doppler.py ADDED
@@ -0,0 +1,75 @@
1
+ """Doppler functions for processing I/Q ultrasound data."""
2
+
3
+ import numpy as np
4
+ from keras import ops
5
+
6
+ from zea import tensor_ops
7
+
8
+
9
def color_doppler(
    data,
    center_frequency,
    pulse_repetition_frequency,
    sound_speed,
    hamming_size=None,
    lag=1,
):
    """Compute Color Doppler from packet of I/Q Data.

    Implements the lag-``lag`` autocorrelation (Kasai) estimator: the phase of
    the ensemble auto-correlation is mapped linearly to an axial velocity,
    bounded by the Nyquist velocity.

    Args:
        data (ndarray): I/Q complex data of shape (n_frames, grid_size_z, grid_size_x).
            n_frames corresponds to the ensemble length used to compute
            the Doppler signal.
        center_frequency (float): Center frequency of the ultrasound probe in Hz.
        pulse_repetition_frequency (float): Pulse repetition frequency in Hz.
        sound_speed (float): Speed of sound in the medium in m/s.
        hamming_size (int or tuple, optional): Size of the Hamming window to apply
            for spatial averaging. If None, no window is applied.
            If an integer, it is applied to both dimensions. If a tuple, it should
            contain two integers for the row and column dimensions.
        lag (int, optional): Lag for the auto-correlation computation.
            Defaults to 1, meaning Doppler is computed from the current frame
            and the next frame.

    Returns:
        doppler_velocities (ndarray): Doppler velocity map of shape
            (grid_size_z, grid_size_x) in meters/second.

    Raises:
        ValueError: If ``lag`` is not a positive integer, or if
            ``hamming_size`` contains non-positive entries.
    """
    assert data.ndim == 3, "Data must be a 3-D array"
    if not (isinstance(lag, int) and lag >= 1):
        raise ValueError("lag must be an integer >= 1")
    n_frames = data.shape[0]
    assert n_frames > lag, "Data must have more frames than the lag"

    # Normalize hamming_size to a length-2 integer array; (1, 1) means "no window".
    if hamming_size is None:
        hamming_size = np.array([1, 1], dtype=int)
    elif np.isscalar(hamming_size):
        hamming_size = np.array([int(hamming_size), int(hamming_size)], dtype=int)
    else:
        assert len(hamming_size) == 2, "hamming_size must be an integer or a tuple of two integers"
        hamming_size = np.array(hamming_size, dtype=int)
    if not np.all(hamming_size > 0):
        raise ValueError("hamming_size must contain integers > 0")

    # Auto-correlation method
    iq1 = data[: n_frames - lag]
    iq2 = data[lag:]
    autocorr = ops.sum(iq1 * ops.conj(iq2), axis=0)  # Ensemble auto-correlation

    # Spatial weighted average.
    # Fix: use `or` so a window requested along only one axis is still applied.
    # The previous `and` silently skipped smoothing whenever either size was 1
    # (e.g. hamming_size=(5, 1) did nothing). np.hamming(1) is [1.], so a
    # size-1 axis reduces to an identity correlation and stays harmless.
    if hamming_size[0] != 1 or hamming_size[1] != 1:
        h_row = np.hamming(hamming_size[0])
        h_col = np.hamming(hamming_size[1])
        autocorr = tensor_ops.apply_along_axis(
            lambda x: tensor_ops.correlate(x, h_row, mode="same"), 0, autocorr
        )
        autocorr = tensor_ops.apply_along_axis(
            lambda x: tensor_ops.correlate(x, h_col, mode="same"), 1, autocorr
        )

    # Doppler velocity
    nyquist_velocity = sound_speed * pulse_repetition_frequency / (4 * center_frequency * lag)
    phase = ops.arctan2(ops.imag(autocorr), ops.real(autocorr))
    doppler_velocities = -nyquist_velocity * phase / np.pi
    return doppler_velocities
@@ -0,0 +1,125 @@
1
+ """This file creates a :class:`zea.Operation` for all unary :mod:`keras.ops`
2
+ and :mod:`keras.ops.image` functions.
3
+
4
+ They can be used in zea pipelines like any other :class:`zea.Operation`, for example:
5
+
6
+ .. code-block:: python
7
+
8
+ from zea.keras_ops import Squeeze
9
+
10
+ op = Squeeze(axis=1)
11
+ """
12
+
13
+ import inspect
14
+ import shutil
15
+ import tempfile
16
+ from pathlib import Path
17
+
18
+ import keras
19
+
20
+
21
+ def _filter_funcs_by_first_arg(funcs, arg_name):
22
+ """Filter a list of (name, func) tuples to those whose first argument matches arg_name."""
23
+ filtered = []
24
+ for name, func in funcs:
25
+ try:
26
+ sig = inspect.signature(func)
27
+ params = list(sig.parameters.keys())
28
+ if params and params[0] == arg_name:
29
+ filtered.append((name, func))
30
+ except (ValueError, TypeError):
31
+ # Skip functions that can't be inspected
32
+ continue
33
+ return filtered
34
+
35
+
36
+ def _functions_from_namespace(namespace):
37
+ """Get all functions from a given namespace."""
38
+ return [(name, obj) for name, obj in inspect.getmembers(namespace) if inspect.isfunction(obj)]
39
+
40
+
41
def _unary_functions_from_namespace(namespace, arg_name="x"):
    """Get all unary functions from a given namespace."""
    return _filter_funcs_by_first_arg(_functions_from_namespace(namespace), arg_name)
45
+
46
+
47
+ def _snake_to_pascal(name):
48
+ """Convert a snake_case name to PascalCase."""
49
+ return "".join(word.capitalize() for word in name.split("_"))
50
+
51
+
52
def _generate_operation_class_code(name, namespace):
    """Generate Python code for a zea.Operation class for a given keras.ops function."""
    pascal_name = _snake_to_pascal(name)
    qualified = f"{namespace.__name__}.{name}"
    docstring = f"Operation wrapping {qualified}."

    return f'''
@ops_registry("{qualified}")
class {pascal_name}(Lambda):
    """{docstring}"""

    def __init__(self, **kwargs):
        try:
            super().__init__(func={qualified}, **kwargs)
        except AttributeError as e:
            raise MissingKerasOps("{pascal_name}", "{qualified}") from e
'''
69
+
70
+
71
def _generate_ops_file():
    """Generate ``zea/keras_ops.py`` with all operation class definitions.

    Builds the full module source in memory, then writes it atomically so a
    concurrent reader never observes a half-written file.
    """

    # File header with version info
    content = f'''"""Auto-generated :class:`zea.Operation` for all unary :mod:`keras.ops`
and :mod:`keras.ops.image` functions.

They can be used in zea pipelines like any other :class:`zea.Operation`, for example:

.. code-block:: python

    from zea.keras_ops import Squeeze

    op = Squeeze(axis=1)

This file is generated automatically. Do not edit manually.
Generated with Keras {keras.__version__}
"""

import keras

from zea.internal.registry import ops_registry
from zea.ops import Lambda


class MissingKerasOps(ValueError):
    def __init__(self, class_name: str, func: str):
        super().__init__(
            f"Failed to create {{class_name}} with {{func}}. " +
            "This may be due to an incompatible version of `keras`. " +
            "Please try to upgrade `keras` to the latest version by running " +
            "`pip install --upgrade keras`."
        )

'''

    # One Operation subclass per unary function in keras.ops / keras.ops.image.
    for name, _ in _unary_functions_from_namespace(keras.ops, "x"):
        content += _generate_operation_class_code(name, keras.ops)

    for name, _ in _unary_functions_from_namespace(keras.ops.image, "images"):
        content += _generate_operation_class_code(name, keras.ops.image)

    target_path = Path(__file__).parent.parent / "keras_ops.py"

    # Fix: create the temp file in the *target* directory. The previous code
    # wrote to the system temp dir and used shutil.move, which degrades to a
    # non-atomic copy+delete whenever the temp dir is on another filesystem.
    # A same-directory rename is atomic on both POSIX and Windows.
    with tempfile.NamedTemporaryFile(
        "w",
        dir=target_path.parent,
        suffix=".tmp",
        delete=False,
        encoding="utf-8",
    ) as tmp_file:
        tmp_file.write(content)
        temp_path = Path(tmp_file.name)

    # Atomic rename; replaces any existing keras_ops.py in one step.
    temp_path.replace(target_path)

    print("Done generating `keras_ops.py`.")
122
+
123
+
124
+ if __name__ == "__main__":
125
+ _generate_ops_file()
zea/internal/core.py CHANGED
@@ -119,11 +119,18 @@ class Object:
119
119
  return cls(**reduced_params)
120
120
 
121
121
  @classmethod
122
- def merge(cls, obj1: dict, obj2: dict):
123
- """Merge multiple objects and safely initialize a new object."""
122
+ def merge(cls, obj1: dict, obj2: dict, safe: bool = False):
123
+ """Merge multiple objects and safely initialize a new object.
124
+
125
+ Optionally can safely initialize the object, which removes any invalid
126
+ arguments.
127
+ """
124
128
  # TODO: support actual zea.core.Objects, now we only support dictionaries
125
129
  params = update_dictionary(obj1, obj2)
126
- return cls.safe_initialize(**params)
130
+ if not safe:
131
+ return cls(**params)
132
+ else:
133
+ return cls.safe_initialize(**params)
127
134
 
128
135
  @classmethod
129
136
  def _tree_unflatten(cls, aux, children):
zea/internal/device.py CHANGED
@@ -64,24 +64,43 @@ def get_gpu_memory(verbose=True):
64
64
 
65
65
  Returns:
66
66
  memory_free_values: list of available memory for each gpu in MiB.
67
+ Returns empty list if nvidia-smi is not available.
67
68
  """
68
69
  if not check_nvidia_smi():
69
70
  log.warning(
70
71
  "nvidia-smi is not available. Please install nvidia-utils. "
71
- "Cannot retrieve GPU memory. Falling back to CPU.."
72
+ "Cannot retrieve GPU memory. Falling back to CPU."
72
73
  )
73
- return None
74
+ return []
74
75
 
75
76
  def _output_to_list(x):
76
77
  return x.decode("ascii").split("\n")[:-1]
77
78
 
78
- COMMAND = "nvidia-smi --query-gpu=memory.free --format=csv"
79
+ COMMAND = [
80
+ "nvidia-smi",
81
+ "--query-gpu=memory.free",
82
+ "--format=csv,noheader,nounits",
83
+ ]
84
+ # Fail-safe timeout (seconds). Override with ZEA_NVIDIA_SMI_TIMEOUT; set <=0 to disable.
85
+ smi_timeout = float(os.getenv("ZEA_NVIDIA_SMI_TIMEOUT", "30"))
79
86
  try:
80
- memory_free_info = _output_to_list(sp.check_output(COMMAND.split()))[1:]
81
- except Exception as e:
82
- print(f"An error occurred: {e}")
87
+ if smi_timeout > 0:
88
+ raw = sp.check_output(COMMAND, timeout=smi_timeout)
89
+ else:
90
+ raw = sp.check_output(COMMAND)
91
+ memory_free_info = _output_to_list(raw)
92
+ except sp.TimeoutExpired:
93
+ log.warning(f"nvidia-smi timed out after {smi_timeout}s. Falling back to CPU.")
94
+ return []
95
+ except sp.SubprocessError as e:
96
+ log.warning(f"Failed to retrieve GPU memory: {e}")
97
+ return []
98
+
99
+ memory_free_values = [int(x) for x in memory_free_info]
83
100
 
84
- memory_free_values = [int(x.split()[0]) for i, x in enumerate(memory_free_info)]
101
+ if verbose:
102
+ header = "GPU settings"
103
+ print("-" * 2 + header.center(50 - 4, "-") + "-" * 2)
85
104
 
86
105
  # only show enabled devices
87
106
  if os.environ.get("CUDA_VISIBLE_DEVICES", "") != "":
@@ -89,10 +108,12 @@ def get_gpu_memory(verbose=True):
89
108
  gpus = [int(gpu) for gpu in gpus.split(",")][: len(memory_free_values)]
90
109
  if verbose:
91
110
  # Report the number of disabled GPUs out of the total
92
- num_disabled_gpus = len(memory_free_values) - len(gpus)
93
111
  num_gpus = len(memory_free_values)
94
-
95
- print(f"{num_disabled_gpus / num_gpus} GPUs were disabled")
112
+ num_disabled_gpus = num_gpus - len(gpus)
113
+ if num_gpus > 0:
114
+ print(f"{num_disabled_gpus}/{num_gpus} GPUs were disabled")
115
+ else:
116
+ print("No GPUs detected by nvidia-smi.")
96
117
 
97
118
  memory_free_values = [memory_free_values[gpu] for gpu in gpus]
98
119
 
@@ -253,15 +274,11 @@ def get_device(device="auto:1", verbose=True, hide_others=True):
253
274
  os.environ["CUDA_VISIBLE_DEVICES"] = ""
254
275
  # returns None to indicate CPU
255
276
 
256
- if device.lower() == "cpu":
277
+ if isinstance(device, str) and device.lower() == "cpu":
257
278
  return _cpu_case()
258
279
 
259
- if verbose:
260
- header = "GPU settings"
261
- print("-" * 2 + header.center(50 - 4, "-") + "-" * 2)
262
-
263
280
  memory = get_gpu_memory(verbose=verbose)
264
- if memory is None: # nvidia-smi not working, fallback to CPU
281
+ if len(memory) == 0: # nvidia-smi not working, fallback to CPU
265
282
  return _cpu_case()
266
283
 
267
284
  gpu_ids = list(range(len(memory)))
@@ -0,0 +1,39 @@
1
+ import matplotlib.pyplot as plt
2
+ import numpy as np
3
+ from matplotlib import animation
4
+
5
+ from zea import Scan
6
+
7
+
8
def animate_images(
    images, path, scan: Scan = None, interval=100, cmap="gray", figsize=(5, 4), dpi=80
):
    """Helper function to animate a list of images.

    Args:
        images: non-empty sequence of frames; each is converted via np.array.
        path: output file path, written with the imagemagick writer.
        scan (Scan, optional): when given and its ``extent`` is not None, the
            imshow extent is set to ``scan.extent * 1e3`` (axes are labeled
            in mm, so extent is presumably in meters — confirm with callers).
        interval (int): delay between frames in milliseconds; must be > 0.
        cmap: matplotlib colormap name.
        figsize: figure size in inches.
        dpi: figure resolution.
    """
    if interval <= 0:
        raise ValueError("interval must be a positive integer (milliseconds).")
    if len(images) == 0:
        raise ValueError("images must be a non-empty sequence.")

    # getattr(scan, "extent") with no default is just attribute access.
    extent = None
    if scan is not None and scan.extent is not None:
        extent = scan.extent * 1e3

    fig, axis = plt.subplots(figsize=figsize, dpi=dpi)
    artist = axis.imshow(np.array(images[0]), animated=True, cmap=cmap, extent=extent)
    axis.set_xlabel("X (mm)")
    axis.set_ylabel("Z (mm)")

    def _draw(frame_index):
        artist.set_array(np.array(images[frame_index]))
        return [artist]

    anim = animation.FuncAnimation(
        fig,
        _draw,
        frames=len(images),
        blit=True,
        interval=interval,
    )
    plt.close(fig)

    # Clamp to at least 1 fps for very long intervals.
    frames_per_second = max(1, 1000 // interval)
    anim.save(path, writer="imagemagick", fps=frames_per_second)