zea 0.0.6__py3-none-any.whl → 0.0.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- zea/__init__.py +54 -19
- zea/agent/__init__.py +12 -12
- zea/agent/masks.py +2 -1
- zea/backend/tensorflow/dataloader.py +2 -5
- zea/beamform/beamformer.py +100 -50
- zea/beamform/lens_correction.py +9 -2
- zea/beamform/pfield.py +9 -2
- zea/beamform/pixelgrid.py +1 -1
- zea/config.py +34 -25
- zea/data/__init__.py +22 -25
- zea/data/augmentations.py +221 -28
- zea/data/convert/__init__.py +1 -6
- zea/data/convert/__main__.py +123 -0
- zea/data/convert/camus.py +101 -40
- zea/data/convert/echonet.py +187 -86
- zea/data/convert/echonetlvh/README.md +2 -3
- zea/data/convert/echonetlvh/{convert_raw_to_usbmd.py → __init__.py} +174 -103
- zea/data/convert/echonetlvh/manual_rejections.txt +73 -0
- zea/data/convert/echonetlvh/precompute_crop.py +43 -64
- zea/data/convert/picmus.py +37 -40
- zea/data/convert/utils.py +86 -0
- zea/data/convert/{matlab.py → verasonics.py} +44 -65
- zea/data/data_format.py +155 -34
- zea/data/dataloader.py +12 -7
- zea/data/datasets.py +112 -71
- zea/data/file.py +184 -73
- zea/data/file_operations.py +496 -0
- zea/data/layers.py +3 -3
- zea/data/preset_utils.py +1 -1
- zea/datapaths.py +16 -4
- zea/display.py +14 -13
- zea/interface.py +14 -16
- zea/internal/_generate_keras_ops.py +6 -7
- zea/internal/cache.py +2 -49
- zea/internal/checks.py +6 -12
- zea/internal/config/validation.py +1 -2
- zea/internal/core.py +69 -6
- zea/internal/device.py +6 -2
- zea/internal/dummy_scan.py +330 -0
- zea/internal/operators.py +118 -2
- zea/internal/parameters.py +101 -70
- zea/internal/setup_zea.py +5 -6
- zea/internal/utils.py +282 -0
- zea/io_lib.py +322 -146
- zea/keras_ops.py +74 -4
- zea/log.py +9 -7
- zea/metrics.py +15 -7
- zea/models/__init__.py +31 -21
- zea/models/base.py +30 -14
- zea/models/carotid_segmenter.py +19 -4
- zea/models/diffusion.py +235 -23
- zea/models/echonet.py +22 -8
- zea/models/echonetlvh.py +31 -7
- zea/models/lpips.py +19 -2
- zea/models/lv_segmentation.py +30 -11
- zea/models/preset_utils.py +5 -5
- zea/models/regional_quality.py +30 -10
- zea/models/taesd.py +21 -5
- zea/models/unet.py +15 -1
- zea/ops.py +770 -336
- zea/probes.py +6 -6
- zea/scan.py +121 -51
- zea/simulator.py +24 -21
- zea/tensor_ops.py +477 -353
- zea/tools/fit_scan_cone.py +90 -160
- zea/tools/hf.py +1 -1
- zea/tools/selection_tool.py +47 -86
- zea/tracking/__init__.py +16 -0
- zea/tracking/base.py +94 -0
- zea/tracking/lucas_kanade.py +474 -0
- zea/tracking/segmentation.py +110 -0
- zea/utils.py +101 -480
- zea/visualize.py +177 -39
- {zea-0.0.6.dist-info → zea-0.0.8.dist-info}/METADATA +6 -2
- zea-0.0.8.dist-info/RECORD +122 -0
- zea-0.0.6.dist-info/RECORD +0 -112
- {zea-0.0.6.dist-info → zea-0.0.8.dist-info}/WHEEL +0 -0
- {zea-0.0.6.dist-info → zea-0.0.8.dist-info}/entry_points.txt +0 -0
- {zea-0.0.6.dist-info → zea-0.0.8.dist-info}/licenses/LICENSE +0 -0
zea/interface.py
CHANGED
|
@@ -3,15 +3,15 @@
|
|
|
3
3
|
Example usage
|
|
4
4
|
^^^^^^^^^^^^^^
|
|
5
5
|
|
|
6
|
-
..
|
|
6
|
+
.. doctest::
|
|
7
7
|
|
|
8
|
-
import zea
|
|
9
|
-
from zea.internal.setup_zea import setup_config
|
|
8
|
+
>>> import zea
|
|
9
|
+
>>> from zea.internal.setup_zea import setup_config
|
|
10
10
|
|
|
11
|
-
config = setup_config("hf://zeahub/configs/config_camus.yaml")
|
|
11
|
+
>>> config = setup_config("hf://zeahub/configs/config_camus.yaml")
|
|
12
12
|
|
|
13
|
-
interface = zea.Interface(config)
|
|
14
|
-
interface.run(plot=True)
|
|
13
|
+
>>> interface = zea.Interface(config)
|
|
14
|
+
>>> interface.run(plot=True) # doctest: +SKIP
|
|
15
15
|
|
|
16
16
|
"""
|
|
17
17
|
|
|
@@ -31,15 +31,15 @@ from zea.data.file import File
|
|
|
31
31
|
from zea.datapaths import format_data_path
|
|
32
32
|
from zea.display import to_8bit
|
|
33
33
|
from zea.internal.core import DataTypes
|
|
34
|
+
from zea.internal.utils import keep_trying
|
|
34
35
|
from zea.internal.viewer import (
|
|
35
36
|
ImageViewerMatplotlib,
|
|
36
37
|
ImageViewerOpenCV,
|
|
37
38
|
filename_from_window_dialog,
|
|
38
39
|
running_in_notebook,
|
|
39
40
|
)
|
|
40
|
-
from zea.io_lib import matplotlib_figure_to_numpy
|
|
41
|
+
from zea.io_lib import matplotlib_figure_to_numpy, save_video
|
|
41
42
|
from zea.ops import Pipeline
|
|
42
|
-
from zea.utils import keep_trying, save_to_gif, save_to_mp4
|
|
43
43
|
|
|
44
44
|
|
|
45
45
|
class Interface:
|
|
@@ -266,10 +266,11 @@ class Interface:
|
|
|
266
266
|
save = self.config.plot.save
|
|
267
267
|
|
|
268
268
|
if self.frame_no == "all":
|
|
269
|
-
|
|
270
|
-
asyncio.
|
|
271
|
-
|
|
272
|
-
|
|
269
|
+
try:
|
|
270
|
+
loop = asyncio.get_running_loop()
|
|
271
|
+
loop.create_task(self.run_movie(save)) # already running loop
|
|
272
|
+
except RuntimeError:
|
|
273
|
+
asyncio.run(self.run_movie(save)) # no loop yet
|
|
273
274
|
|
|
274
275
|
else:
|
|
275
276
|
if plot:
|
|
@@ -520,10 +521,7 @@ class Interface:
|
|
|
520
521
|
|
|
521
522
|
fps = self.config.plot.fps
|
|
522
523
|
|
|
523
|
-
|
|
524
|
-
save_to_gif(images, path, fps=fps)
|
|
525
|
-
elif self.config.plot.video_extension == "mp4":
|
|
526
|
-
save_to_mp4(images, path, fps=fps)
|
|
524
|
+
save_video(images, path, fps=fps)
|
|
527
525
|
|
|
528
526
|
if self.verbose:
|
|
529
527
|
log.info(f"Video saved to {log.yellow(path)}")
|
|
@@ -3,11 +3,10 @@ and :mod:`keras.ops.image` functions.
|
|
|
3
3
|
|
|
4
4
|
They can be used in zea pipelines like any other :class:`zea.Operation`, for example:
|
|
5
5
|
|
|
6
|
-
..
|
|
6
|
+
.. doctest::
|
|
7
7
|
|
|
8
|
-
from zea.keras_ops import Squeeze
|
|
9
|
-
|
|
10
|
-
op = Squeeze(axis=1)
|
|
8
|
+
>>> from zea.keras_ops import Squeeze
|
|
9
|
+
>>> op = Squeeze(axis=1)
|
|
11
10
|
"""
|
|
12
11
|
|
|
13
12
|
import inspect
|
|
@@ -77,11 +76,11 @@ and :mod:`keras.ops.image` functions.
|
|
|
77
76
|
|
|
78
77
|
They can be used in zea pipelines like any other :class:`zea.Operation`, for example:
|
|
79
78
|
|
|
80
|
-
..
|
|
79
|
+
.. doctest::
|
|
81
80
|
|
|
82
|
-
from zea.keras_ops import Squeeze
|
|
81
|
+
>>> from zea.keras_ops import Squeeze
|
|
83
82
|
|
|
84
|
-
op = Squeeze(axis=1)
|
|
83
|
+
>>> op = Squeeze(axis=1)
|
|
85
84
|
|
|
86
85
|
This file is generated automatically. Do not edit manually.
|
|
87
86
|
Generated with Keras {keras.__version__}
|
zea/internal/cache.py
CHANGED
|
@@ -21,10 +21,8 @@
|
|
|
21
21
|
|
|
22
22
|
import ast
|
|
23
23
|
import atexit
|
|
24
|
-
import hashlib
|
|
25
24
|
import inspect
|
|
26
25
|
import os
|
|
27
|
-
import pickle
|
|
28
26
|
import tempfile
|
|
29
27
|
import textwrap
|
|
30
28
|
from pathlib import Path
|
|
@@ -33,6 +31,7 @@ import joblib
|
|
|
33
31
|
import keras
|
|
34
32
|
|
|
35
33
|
from zea import log
|
|
34
|
+
from zea.internal.core import hash_elements
|
|
36
35
|
|
|
37
36
|
_DEFAULT_ZEA_CACHE_DIR = Path.home() / ".cache" / "zea"
|
|
38
37
|
|
|
@@ -80,52 +79,6 @@ _CACHE_DIR = ZEA_CACHE_DIR / "cached_funcs"
|
|
|
80
79
|
_CACHE_DIR.mkdir(parents=True, exist_ok=True)
|
|
81
80
|
|
|
82
81
|
|
|
83
|
-
def serialize_elements(key_elements: list, shorten: bool = False) -> str:
|
|
84
|
-
"""Serialize elements of a list to generate a cache key.
|
|
85
|
-
|
|
86
|
-
In general uses the string representation of the elements unless
|
|
87
|
-
the element has a `serialized` attribute, in which case it uses that.
|
|
88
|
-
For instance this is useful for custom classes that inherit from `zea.core.Object`.
|
|
89
|
-
|
|
90
|
-
Args:
|
|
91
|
-
key_elements (list): List of elements to serialize. Can be nested lists
|
|
92
|
-
or tuples. In this case the elements are serialized recursively.
|
|
93
|
-
shorten (bool): If True, the serialized string is hashed to a shorter
|
|
94
|
-
representation using MD5. Defaults to False.
|
|
95
|
-
|
|
96
|
-
Returns:
|
|
97
|
-
str: A serialized string representation of the elements, joined by underscores.
|
|
98
|
-
|
|
99
|
-
"""
|
|
100
|
-
serialized_elements = []
|
|
101
|
-
for element in key_elements:
|
|
102
|
-
if isinstance(element, (list, tuple)):
|
|
103
|
-
# If element is a list or tuple, serialize its elements recursively
|
|
104
|
-
serialized_elements.append(serialize_elements(element))
|
|
105
|
-
elif hasattr(element, "serialized"):
|
|
106
|
-
# Use the serialized attribute if it exists (e.g. for zea.core.Object)
|
|
107
|
-
serialized_elements.append(str(element.serialized))
|
|
108
|
-
elif isinstance(element, str):
|
|
109
|
-
# If element is a string, use it as is
|
|
110
|
-
serialized_elements.append(element)
|
|
111
|
-
elif isinstance(element, keras.random.SeedGenerator):
|
|
112
|
-
# If element is a SeedGenerator, use the state
|
|
113
|
-
element = keras.ops.convert_to_numpy(element.state.value)
|
|
114
|
-
element = pickle.dumps(element)
|
|
115
|
-
element = hashlib.md5(element).hexdigest()
|
|
116
|
-
serialized_elements.append(element)
|
|
117
|
-
else:
|
|
118
|
-
# Otherwise, serialize the element using pickle and hash it
|
|
119
|
-
element = pickle.dumps(element)
|
|
120
|
-
element = hashlib.md5(element).hexdigest()
|
|
121
|
-
serialized_elements.append(element)
|
|
122
|
-
|
|
123
|
-
serialized = "_".join(serialized_elements)
|
|
124
|
-
if shorten:
|
|
125
|
-
return hashlib.md5(serialized.encode()).hexdigest()
|
|
126
|
-
return serialized
|
|
127
|
-
|
|
128
|
-
|
|
129
82
|
def get_function_source(func):
|
|
130
83
|
"""Recursively get the source code of a function and its nested functions."""
|
|
131
84
|
try:
|
|
@@ -188,7 +141,7 @@ def generate_cache_key(func, args, kwargs, arg_names):
|
|
|
188
141
|
# Add keras backend
|
|
189
142
|
key_elements.append(keras.backend.backend())
|
|
190
143
|
|
|
191
|
-
return f"{func.__qualname__}_" +
|
|
144
|
+
return f"{func.__qualname__}_" + hash_elements(key_elements)
|
|
192
145
|
|
|
193
146
|
|
|
194
147
|
def cache_output(*arg_names, verbose=False):
|
zea/internal/checks.py
CHANGED
|
@@ -64,8 +64,7 @@ def _check_raw_data(data=None, shape=None, with_batch_dim=None):
|
|
|
64
64
|
shape (tuple, optional): shape of the data. Defaults to None.
|
|
65
65
|
either data or shape must be provided.
|
|
66
66
|
with_batch_dim (bool, optional): whether data has frame dimension at the start.
|
|
67
|
-
Setting this to True requires the data to have 5 dimensions. Defaults to
|
|
68
|
-
False.
|
|
67
|
+
Setting this to True requires the data to have 5 dimensions. Defaults to None.
|
|
69
68
|
|
|
70
69
|
Raises:
|
|
71
70
|
AssertionError: if data does not have expected shape
|
|
@@ -105,8 +104,7 @@ def _check_aligned_data(data=None, shape=None, with_batch_dim=None):
|
|
|
105
104
|
shape (tuple, optional): shape of the data. Defaults to None.
|
|
106
105
|
either data or shape must be provided.
|
|
107
106
|
with_batch_dim (bool, optional): whether data has frame dimension at the start.
|
|
108
|
-
Setting this to True requires the data to have 5 dimensions. Defaults to
|
|
109
|
-
False.
|
|
107
|
+
Setting this to True requires the data to have 5 dimensions. Defaults to None.
|
|
110
108
|
|
|
111
109
|
Raises:
|
|
112
110
|
AssertionError: if data does not have expected shape
|
|
@@ -147,8 +145,7 @@ def _check_beamformed_data(data=None, shape=None, with_batch_dim=None):
|
|
|
147
145
|
shape (tuple, optional): shape of the data. Defaults to None.
|
|
148
146
|
either data or shape must be provided.
|
|
149
147
|
with_batch_dim (bool, optional): whether data has frame dimension at the start.
|
|
150
|
-
Setting this to True requires the data to have 4 dimensions. Defaults to
|
|
151
|
-
False.
|
|
148
|
+
Setting this to True requires the data to have 4 dimensions. Defaults to None.
|
|
152
149
|
|
|
153
150
|
Raises:
|
|
154
151
|
AssertionError: if data does not have expected shape
|
|
@@ -190,8 +187,7 @@ def _check_envelope_data(data=None, shape=None, with_batch_dim=None):
|
|
|
190
187
|
shape (tuple, optional): shape of the data. Defaults to None.
|
|
191
188
|
either data or shape must be provided.
|
|
192
189
|
with_batch_dim (bool, optional): whether data has frame dimension at the start.
|
|
193
|
-
Setting this to True requires the data to have
|
|
194
|
-
False.
|
|
190
|
+
Setting this to True requires the data to have 3 dimensions. Defaults to None.
|
|
195
191
|
|
|
196
192
|
Raises:
|
|
197
193
|
AssertionError: if data does not have expected shape
|
|
@@ -227,8 +223,7 @@ def _check_image(data=None, shape=None, with_batch_dim=None):
|
|
|
227
223
|
shape (tuple, optional): shape of the data. Defaults to None.
|
|
228
224
|
either data or shape must be provided.
|
|
229
225
|
with_batch_dim (bool, optional): whether data has frame dimension at the start.
|
|
230
|
-
Setting this to True requires the data to have
|
|
231
|
-
False.
|
|
226
|
+
Setting this to True requires the data to have 3 dimensions. Defaults to None.
|
|
232
227
|
|
|
233
228
|
Raises:
|
|
234
229
|
AssertionError: if data does not have expected shape.
|
|
@@ -264,8 +259,7 @@ def _check_image_sc(data=None, shape=None, with_batch_dim=None):
|
|
|
264
259
|
shape (tuple, optional): shape of the data. Defaults to None.
|
|
265
260
|
either data or shape must be provided.
|
|
266
261
|
with_batch_dim (bool, optional): whether data has frame dimension at the start.
|
|
267
|
-
Setting this to True requires the data to have
|
|
268
|
-
False.
|
|
262
|
+
Setting this to True requires the data to have 3 dimensions. Defaults to None.
|
|
269
263
|
|
|
270
264
|
Raises:
|
|
271
265
|
AssertionError: if data does not have expected shape.
|
|
@@ -15,9 +15,8 @@ from pathlib import Path
|
|
|
15
15
|
|
|
16
16
|
from schema import And, Optional, Or, Regex, Schema
|
|
17
17
|
|
|
18
|
-
import zea.metrics # noqa: F401
|
|
19
18
|
from zea.internal.checks import _DATA_TYPES
|
|
20
|
-
from zea.
|
|
19
|
+
from zea.metrics import metrics_registry
|
|
21
20
|
|
|
22
21
|
# predefined checks, later used in schema to check validity of parameter
|
|
23
22
|
any_number = Or(
|
zea/internal/core.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
"""Base classes for the toolbox"""
|
|
2
2
|
|
|
3
3
|
import enum
|
|
4
|
+
import hashlib
|
|
4
5
|
import json
|
|
5
6
|
import pickle
|
|
6
7
|
from copy import deepcopy
|
|
@@ -8,7 +9,8 @@ from copy import deepcopy
|
|
|
8
9
|
import keras
|
|
9
10
|
import numpy as np
|
|
10
11
|
|
|
11
|
-
from zea.utils import reduce_to_signature
|
|
12
|
+
from zea.internal.utils import reduce_to_signature
|
|
13
|
+
from zea.utils import update_dictionary
|
|
12
14
|
|
|
13
15
|
CONVERT_TO_KERAS_TYPES = (np.ndarray, int, float, list, tuple, bool)
|
|
14
16
|
BASE_FLOAT_PRECISION = "float32"
|
|
@@ -76,7 +78,7 @@ class Object:
|
|
|
76
78
|
attributes.pop(
|
|
77
79
|
"_serialized", None
|
|
78
80
|
) # Remove the cached serialized attribute to avoid recursion
|
|
79
|
-
self._serialized =
|
|
81
|
+
self._serialized = serialize_elements([attributes])
|
|
80
82
|
return self._serialized
|
|
81
83
|
|
|
82
84
|
def __setattr__(self, name: str, value):
|
|
@@ -167,9 +169,7 @@ def _skip_to_tensor(value):
|
|
|
167
169
|
# Skip str (because JIT does not support it)
|
|
168
170
|
# Skip methods and functions
|
|
169
171
|
# Skip byte strings
|
|
170
|
-
|
|
171
|
-
return True
|
|
172
|
-
return False
|
|
172
|
+
return isinstance(value, str) or callable(value) or isinstance(value, bytes)
|
|
173
173
|
|
|
174
174
|
|
|
175
175
|
def dict_to_tensor(dictionary, keep_as_is=None):
|
|
@@ -184,8 +184,9 @@ def dict_to_tensor(dictionary, keep_as_is=None):
|
|
|
184
184
|
# Get the value from the dictionary
|
|
185
185
|
value = dictionary[key]
|
|
186
186
|
|
|
187
|
-
if isinstance(value, Object):
|
|
187
|
+
if isinstance(value, Object) and hasattr(value, "to_tensor"):
|
|
188
188
|
snapshot[key] = value.to_tensor(keep_as_is=keep_as_is)
|
|
189
|
+
continue
|
|
189
190
|
|
|
190
191
|
# Skip certain types
|
|
191
192
|
if _skip_to_tensor(value):
|
|
@@ -288,3 +289,65 @@ class ZEADecoderJSON(json.JSONDecoder):
|
|
|
288
289
|
obj[key] = self._MOD_TYPES_MAP[value] if value is not None else None
|
|
289
290
|
|
|
290
291
|
return obj
|
|
292
|
+
|
|
293
|
+
|
|
294
|
+
def serialize_elements(key_elements: list) -> str:
|
|
295
|
+
"""Serialize elements of a list to a string.
|
|
296
|
+
|
|
297
|
+
Generally, uses the pickle representation of the elements.
|
|
298
|
+
|
|
299
|
+
Args:
|
|
300
|
+
key_elements (list): List of elements to serialize. Can be nested lists
|
|
301
|
+
or tuples. In this case the elements are serialized recursively.
|
|
302
|
+
|
|
303
|
+
Returns:
|
|
304
|
+
str: A serialized string representation of the elements, joined by underscores.
|
|
305
|
+
"""
|
|
306
|
+
|
|
307
|
+
def _serialize(element) -> str:
|
|
308
|
+
return pickle.dumps(element).hex()
|
|
309
|
+
|
|
310
|
+
def _serialize_element(element) -> str:
|
|
311
|
+
if isinstance(element, (list, tuple)):
|
|
312
|
+
# If element is a list or tuple, serialize its elements recursively
|
|
313
|
+
element = serialize_elements(element)
|
|
314
|
+
elif isinstance(element, Object) and hasattr(element, "serialized"):
|
|
315
|
+
# Use the serialized attribute if it exists
|
|
316
|
+
element = str(element.serialized)
|
|
317
|
+
elif isinstance(element, keras.random.SeedGenerator):
|
|
318
|
+
# If element is a SeedGenerator, use the state
|
|
319
|
+
element = keras.ops.convert_to_numpy(element.state.value)
|
|
320
|
+
element = _serialize(element)
|
|
321
|
+
elif isinstance(element, dict):
|
|
322
|
+
# If element is a dictionary, sort its keys and serialize its values recursively.
|
|
323
|
+
# This is needed to ensure the internal state and ordering of the dictionary does
|
|
324
|
+
# not affect the serialization.
|
|
325
|
+
keys = list(sorted(element.keys()))
|
|
326
|
+
values = [element[k] for k in keys]
|
|
327
|
+
keys = serialize_elements(keys)
|
|
328
|
+
values = serialize_elements(values)
|
|
329
|
+
element = f"k_{keys}_v_{values}"
|
|
330
|
+
else:
|
|
331
|
+
# Otherwise, serialize the element directly
|
|
332
|
+
element = _serialize(element)
|
|
333
|
+
|
|
334
|
+
return element
|
|
335
|
+
|
|
336
|
+
serialized_elements = []
|
|
337
|
+
for element in key_elements:
|
|
338
|
+
serialized_elements.append(_serialize_element(element))
|
|
339
|
+
|
|
340
|
+
return "_".join(serialized_elements)
|
|
341
|
+
|
|
342
|
+
|
|
343
|
+
def hash_elements(key_elements: list) -> str:
|
|
344
|
+
"""Generate an MD5 hash of the elements.
|
|
345
|
+
|
|
346
|
+
Args:
|
|
347
|
+
key_elements (list): List of elements to serialize and hash.
|
|
348
|
+
|
|
349
|
+
Returns:
|
|
350
|
+
str: An MD5 hash of the serialized elements.
|
|
351
|
+
"""
|
|
352
|
+
serialized = serialize_elements(key_elements)
|
|
353
|
+
return hashlib.md5(serialized.encode()).hexdigest()
|
zea/internal/device.py
CHANGED
|
@@ -377,7 +377,11 @@ def init_device(
|
|
|
377
377
|
allow_preallocate: bool = True,
|
|
378
378
|
verbose: bool = True,
|
|
379
379
|
):
|
|
380
|
-
"""
|
|
380
|
+
"""Automatically selects a GPU or CPU device.
|
|
381
|
+
|
|
382
|
+
Useful to call at the start of a script to set the device for
|
|
383
|
+
tensorflow, jax or pytorch. The function will select a GPU based
|
|
384
|
+
on available memory, or fall back to CPU if no GPU is available.
|
|
381
385
|
|
|
382
386
|
Args:
|
|
383
387
|
backend (str): String indicating which backend to use. Can be
|
|
@@ -412,7 +416,7 @@ def init_device(
|
|
|
412
416
|
elif backend in ["numpy", "cpu"]:
|
|
413
417
|
device = "cpu"
|
|
414
418
|
else:
|
|
415
|
-
raise ValueError(f"Unknown backend ({backend})
|
|
419
|
+
raise ValueError(f"Unknown backend ({backend}).")
|
|
416
420
|
|
|
417
421
|
# Early exit if device is CPU
|
|
418
422
|
if device == "cpu":
|