zea 0.0.4__py3-none-any.whl → 0.0.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- zea/__init__.py +8 -7
- zea/__main__.py +8 -26
- zea/agent/selection.py +166 -0
- zea/backend/__init__.py +89 -0
- zea/backend/jax/__init__.py +14 -51
- zea/backend/tensorflow/__init__.py +0 -49
- zea/backend/torch/__init__.py +27 -62
- zea/data/__main__.py +6 -3
- zea/data/file.py +19 -74
- zea/data/layers.py +2 -3
- zea/display.py +1 -5
- zea/doppler.py +75 -0
- zea/internal/_generate_keras_ops.py +125 -0
- zea/internal/core.py +10 -3
- zea/internal/device.py +33 -16
- zea/internal/notebooks.py +39 -0
- zea/internal/operators.py +10 -0
- zea/internal/parameters.py +75 -19
- zea/internal/registry.py +1 -1
- zea/internal/viewer.py +24 -24
- zea/io_lib.py +60 -62
- zea/keras_ops.py +1989 -0
- zea/metrics.py +357 -65
- zea/models/__init__.py +6 -3
- zea/models/deeplabv3.py +131 -0
- zea/models/diffusion.py +18 -18
- zea/models/echonetlvh.py +279 -0
- zea/models/lv_segmentation.py +79 -0
- zea/models/presets.py +50 -0
- zea/models/regional_quality.py +122 -0
- zea/ops.py +52 -56
- zea/scan.py +10 -3
- zea/tensor_ops.py +251 -0
- zea/tools/fit_scan_cone.py +2 -2
- zea/tools/selection_tool.py +28 -9
- {zea-0.0.4.dist-info → zea-0.0.6.dist-info}/METADATA +10 -3
- {zea-0.0.4.dist-info → zea-0.0.6.dist-info}/RECORD +40 -33
- {zea-0.0.4.dist-info → zea-0.0.6.dist-info}/WHEEL +1 -1
- zea/internal/convert.py +0 -150
- {zea-0.0.4.dist-info → zea-0.0.6.dist-info}/entry_points.txt +0 -0
- {zea-0.0.4.dist-info → zea-0.0.6.dist-info/licenses}/LICENSE +0 -0
zea/__init__.py
CHANGED
|
@@ -7,10 +7,10 @@ from . import log
|
|
|
7
7
|
|
|
8
8
|
# dynamically add __version__ attribute (see pyproject.toml)
|
|
9
9
|
# __version__ = __import__("importlib.metadata").metadata.version(__package__)
|
|
10
|
-
__version__ = "0.0.
|
|
10
|
+
__version__ = "0.0.6"
|
|
11
11
|
|
|
12
12
|
|
|
13
|
-
def
|
|
13
|
+
def _bootstrap_backend():
|
|
14
14
|
"""Setup function to initialize the zea package."""
|
|
15
15
|
|
|
16
16
|
def _check_backend_installed():
|
|
@@ -40,14 +40,14 @@ def setup():
|
|
|
40
40
|
|
|
41
41
|
_check_backend_installed()
|
|
42
42
|
|
|
43
|
-
import
|
|
43
|
+
from keras.backend import backend as keras_backend
|
|
44
44
|
|
|
45
|
-
log.info(f"Using backend {
|
|
45
|
+
log.info(f"Using backend {keras_backend()!r}")
|
|
46
46
|
|
|
47
47
|
|
|
48
48
|
# call and clean up namespace
|
|
49
|
-
|
|
50
|
-
del
|
|
49
|
+
_bootstrap_backend()
|
|
50
|
+
del _bootstrap_backend
|
|
51
51
|
|
|
52
52
|
from . import (
|
|
53
53
|
agent,
|
|
@@ -55,6 +55,7 @@ from . import (
|
|
|
55
55
|
data,
|
|
56
56
|
display,
|
|
57
57
|
io_lib,
|
|
58
|
+
keras_ops,
|
|
58
59
|
metrics,
|
|
59
60
|
models,
|
|
60
61
|
simulator,
|
|
@@ -68,7 +69,7 @@ from .data.file import File, load_file
|
|
|
68
69
|
from .datapaths import set_data_paths
|
|
69
70
|
from .interface import Interface
|
|
70
71
|
from .internal.device import init_device
|
|
71
|
-
from .internal.setup_zea import
|
|
72
|
+
from .internal.setup_zea import setup, setup_config
|
|
72
73
|
from .ops import Pipeline
|
|
73
74
|
from .probes import Probe
|
|
74
75
|
from .scan import Scan
|
zea/__main__.py
CHANGED
|
@@ -9,30 +9,22 @@ import argparse
|
|
|
9
9
|
import sys
|
|
10
10
|
from pathlib import Path
|
|
11
11
|
|
|
12
|
-
from zea import log
|
|
13
12
|
from zea.visualize import set_mpl_style
|
|
14
13
|
|
|
15
14
|
|
|
16
|
-
def
|
|
15
|
+
def get_parser():
|
|
17
16
|
"""Command line argument parser"""
|
|
18
|
-
parser = argparse.ArgumentParser(
|
|
19
|
-
|
|
17
|
+
parser = argparse.ArgumentParser(
|
|
18
|
+
description="Load and process ultrasound data based on a configuration file."
|
|
19
|
+
)
|
|
20
|
+
parser.add_argument("-c", "--config", type=str, default=None, help="path to the config file.")
|
|
20
21
|
parser.add_argument(
|
|
21
22
|
"-t",
|
|
22
23
|
"--task",
|
|
23
24
|
default="view",
|
|
24
25
|
choices=["view"],
|
|
25
26
|
type=str,
|
|
26
|
-
help="
|
|
27
|
-
)
|
|
28
|
-
parser.add_argument(
|
|
29
|
-
"--backend",
|
|
30
|
-
default=None,
|
|
31
|
-
type=str,
|
|
32
|
-
help=(
|
|
33
|
-
"Keras backend to use. Default is the one set by the environment "
|
|
34
|
-
"variable KERAS_BACKEND."
|
|
35
|
-
),
|
|
27
|
+
help="Which task to run. Currently only 'view' is supported.",
|
|
36
28
|
)
|
|
37
29
|
parser.add_argument(
|
|
38
30
|
"--skip_validate_file",
|
|
@@ -40,27 +32,18 @@ def get_args():
|
|
|
40
32
|
action="store_true",
|
|
41
33
|
help="Skip zea file integrity checks. Use with caution.",
|
|
42
34
|
)
|
|
43
|
-
parser
|
|
44
|
-
args = parser.parse_args()
|
|
45
|
-
return args
|
|
35
|
+
return parser
|
|
46
36
|
|
|
47
37
|
|
|
48
38
|
def main():
|
|
49
39
|
"""main entrypoint for zea"""
|
|
50
|
-
args =
|
|
40
|
+
args = get_parser().parse_args()
|
|
51
41
|
|
|
52
42
|
set_mpl_style()
|
|
53
43
|
|
|
54
|
-
if args.backend:
|
|
55
|
-
from zea.internal.setup_zea import set_backend
|
|
56
|
-
|
|
57
|
-
set_backend(args.backend)
|
|
58
|
-
|
|
59
44
|
wd = Path(__file__).parent.resolve()
|
|
60
45
|
sys.path.append(str(wd))
|
|
61
46
|
|
|
62
|
-
import keras
|
|
63
|
-
|
|
64
47
|
from zea.interface import Interface
|
|
65
48
|
from zea.internal.setup_zea import setup
|
|
66
49
|
|
|
@@ -72,7 +55,6 @@ def main():
|
|
|
72
55
|
validate_file=not args.skip_validate_file,
|
|
73
56
|
)
|
|
74
57
|
|
|
75
|
-
log.info(f"Using {keras.backend.backend()} backend")
|
|
76
58
|
cli.run(plot=True)
|
|
77
59
|
else:
|
|
78
60
|
raise ValueError(f"Unknown task {args.task}, see `zea --help` for available tasks.")
|
zea/agent/selection.py
CHANGED
|
@@ -11,11 +11,14 @@ For a comprehensive example usage, see: :doc:`../notebooks/agent/agent_example`
|
|
|
11
11
|
All strategies are stateless, meaning that they do not maintain any internal state.
|
|
12
12
|
"""
|
|
13
13
|
|
|
14
|
+
from typing import Callable
|
|
15
|
+
|
|
14
16
|
import keras
|
|
15
17
|
from keras import ops
|
|
16
18
|
|
|
17
19
|
from zea import tensor_ops
|
|
18
20
|
from zea.agent import masks
|
|
21
|
+
from zea.backend.autograd import AutoGrad
|
|
19
22
|
from zea.internal.registry import action_selection_registry
|
|
20
23
|
|
|
21
24
|
|
|
@@ -493,3 +496,166 @@ class CovarianceSamplingLines(LinesActionModel):
|
|
|
493
496
|
best_mask = ops.squeeze(best_mask, axis=0)
|
|
494
497
|
|
|
495
498
|
return best_mask, self.lines_to_im_size(best_mask)
|
|
499
|
+
|
|
500
|
+
|
|
501
|
+
class TaskBasedLines(GreedyEntropy):
|
|
502
|
+
"""Task-based line selection for maximizing information gain.
|
|
503
|
+
|
|
504
|
+
This action selection strategy chooses lines to maximize information gain with respect
|
|
505
|
+
to a downstream task outcome. It uses gradient-based saliency to identify which image
|
|
506
|
+
regions contribute most to task uncertainty, then selects lines accordingly.
|
|
507
|
+
"""
|
|
508
|
+
|
|
509
|
+
def __init__(
|
|
510
|
+
self,
|
|
511
|
+
n_actions: int,
|
|
512
|
+
n_possible_actions: int,
|
|
513
|
+
img_width: int,
|
|
514
|
+
img_height: int,
|
|
515
|
+
downstream_task_function: Callable,
|
|
516
|
+
mean: float = 0,
|
|
517
|
+
std_dev: float = 1,
|
|
518
|
+
num_lines_to_update: int = 5,
|
|
519
|
+
**kwargs,
|
|
520
|
+
):
|
|
521
|
+
"""Initialize the TaskBasedLines action selection model.
|
|
522
|
+
|
|
523
|
+
Args:
|
|
524
|
+
n_actions (int): The number of actions the agent can take.
|
|
525
|
+
n_possible_actions (int): The number of possible actions (line positions).
|
|
526
|
+
img_width (int): The width of the input image.
|
|
527
|
+
img_height (int): The height of the input image.
|
|
528
|
+
downstream_task_function (Callable): A differentiable function that takes a
|
|
529
|
+
batch of inputs and produces scalar outputs. This represents the downstream
|
|
530
|
+
task for which information gain should be maximized.
|
|
531
|
+
mean (float, optional): The mean of the RBF used for reweighting. Defaults to 0.
|
|
532
|
+
std_dev (float, optional): The standard deviation of the RBF used for reweighting.
|
|
533
|
+
Defaults to 1.
|
|
534
|
+
num_lines_to_update (int, optional): The number of lines around the selected line
|
|
535
|
+
to update during reweighting. Must be odd. Defaults to 5.
|
|
536
|
+
**kwargs: Additional keyword arguments passed to the parent class.
|
|
537
|
+
"""
|
|
538
|
+
super().__init__(
|
|
539
|
+
n_actions,
|
|
540
|
+
n_possible_actions,
|
|
541
|
+
img_width,
|
|
542
|
+
img_height,
|
|
543
|
+
mean,
|
|
544
|
+
std_dev,
|
|
545
|
+
num_lines_to_update,
|
|
546
|
+
)
|
|
547
|
+
self.downstream_task_function = downstream_task_function
|
|
548
|
+
|
|
549
|
+
def compute_output_and_saliency_propagation(self, particles):
|
|
550
|
+
"""Compute saliency-weighted posterior variance for task-based selection.
|
|
551
|
+
|
|
552
|
+
This method computes how much each pixel contributes to the variance of the
|
|
553
|
+
downstream task output. It uses automatic differentiation to compute gradients
|
|
554
|
+
of the task function with respect to each particle, then weights the posterior
|
|
555
|
+
variance by the squared mean gradient.
|
|
556
|
+
|
|
557
|
+
Args:
|
|
558
|
+
particles (Tensor): Particles of shape (batch_size, n_particles, height, width)
|
|
559
|
+
representing the posterior distribution over images.
|
|
560
|
+
|
|
561
|
+
Returns:
|
|
562
|
+
Tensor: Pixelwise contribution to downstream task variance,
|
|
563
|
+
of shape (batch_size, height, width). Higher values indicate pixels
|
|
564
|
+
that contribute more to task uncertainty.
|
|
565
|
+
"""
|
|
566
|
+
autograd = AutoGrad()
|
|
567
|
+
|
|
568
|
+
autograd.set_function(self.downstream_task_function)
|
|
569
|
+
downstream_grad_and_value_fn = autograd.get_gradient_and_value_jit_fn()
|
|
570
|
+
jacobian, _ = ops.vectorized_map(
|
|
571
|
+
lambda p: ops.vectorized_map(
|
|
572
|
+
downstream_grad_and_value_fn,
|
|
573
|
+
p,
|
|
574
|
+
),
|
|
575
|
+
particles,
|
|
576
|
+
)
|
|
577
|
+
|
|
578
|
+
posterior_variance = ops.var(particles, axis=1)
|
|
579
|
+
mean_jacobian = ops.mean(jacobian, axis=1)
|
|
580
|
+
return posterior_variance * (mean_jacobian**2)
|
|
581
|
+
|
|
582
|
+
def sum_neighbouring_columns_into_n_possible_actions(self, full_linewise_salience):
|
|
583
|
+
"""Aggregate column-wise saliency into line-wise saliency scores.
|
|
584
|
+
|
|
585
|
+
This method groups neighboring columns together to create saliency scores
|
|
586
|
+
for each possible line action. Since each line action may correspond to
|
|
587
|
+
multiple image columns, this aggregation is necessary to match the action space.
|
|
588
|
+
|
|
589
|
+
Args:
|
|
590
|
+
full_linewise_salience (Tensor): Saliency values for each column,
|
|
591
|
+
of shape (batch_size, full_image_width).
|
|
592
|
+
|
|
593
|
+
Returns:
|
|
594
|
+
Tensor: Aggregated saliency scores for each possible action,
|
|
595
|
+
of shape (batch_size, n_possible_actions).
|
|
596
|
+
|
|
597
|
+
Raises:
|
|
598
|
+
AssertionError: If the image width is not evenly divisible by n_possible_actions.
|
|
599
|
+
"""
|
|
600
|
+
batch_size = ops.shape(full_linewise_salience)[0]
|
|
601
|
+
full_image_width = ops.shape(full_linewise_salience)[1]
|
|
602
|
+
assert full_image_width % self.n_possible_actions == 0, (
|
|
603
|
+
"n_possible_actions must divide evenly into image width"
|
|
604
|
+
)
|
|
605
|
+
cols_per_action = full_image_width // self.n_possible_actions
|
|
606
|
+
stacked_linewise_salience = ops.reshape(
|
|
607
|
+
full_linewise_salience,
|
|
608
|
+
(batch_size, self.n_possible_actions, cols_per_action),
|
|
609
|
+
)
|
|
610
|
+
return ops.sum(stacked_linewise_salience, axis=2)
|
|
611
|
+
|
|
612
|
+
def sample(self, particles):
|
|
613
|
+
"""Sample actions using task-based information gain maximization.
|
|
614
|
+
|
|
615
|
+
This method computes which lines would provide the most information about
|
|
616
|
+
the downstream task by:
|
|
617
|
+
1. Computing pixelwise contribution to task variance using gradients
|
|
618
|
+
2. Aggregating contributions into line-wise scores
|
|
619
|
+
3. Greedily selecting lines with highest contribution scores
|
|
620
|
+
4. Reweighting scores around selected lines (inherited from GreedyEntropy)
|
|
621
|
+
|
|
622
|
+
Args:
|
|
623
|
+
particles (Tensor): Particles representing the posterior distribution,
|
|
624
|
+
of shape (batch_size, n_particles, height, width).
|
|
625
|
+
|
|
626
|
+
Returns:
|
|
627
|
+
Tuple[Tensor, Tensor, Tensor]:
|
|
628
|
+
- selected_lines_k_hot: Selected lines as k-hot vectors,
|
|
629
|
+
shaped (batch_size, n_possible_actions)
|
|
630
|
+
- masks: Binary masks of shape (batch_size, img_height, img_width)
|
|
631
|
+
- pixelwise_contribution_to_var_dst: Pixelwise contribution to downstream
|
|
632
|
+
task variance, of shape (batch_size, height, width)
|
|
633
|
+
|
|
634
|
+
Note:
|
|
635
|
+
Unlike the parent GreedyEntropy class, this method returns an additional
|
|
636
|
+
tensor containing the pixelwise contribution scores for analysis.
|
|
637
|
+
"""
|
|
638
|
+
pixelwise_contribution_to_var_dst = self.compute_output_and_saliency_propagation(particles)
|
|
639
|
+
linewise_contribution_to_var_dst = ops.sum(pixelwise_contribution_to_var_dst, axis=1)
|
|
640
|
+
actionwise_contribution_to_var_dst = self.sum_neighbouring_columns_into_n_possible_actions(
|
|
641
|
+
linewise_contribution_to_var_dst
|
|
642
|
+
)
|
|
643
|
+
|
|
644
|
+
# Greedily select best line, reweight entropies, and repeat
|
|
645
|
+
all_selected_lines = []
|
|
646
|
+
for _ in range(self.n_actions):
|
|
647
|
+
max_contribution_line, actionwise_contribution_to_var_dst = ops.vectorized_map(
|
|
648
|
+
self.select_line_and_reweight_entropy,
|
|
649
|
+
actionwise_contribution_to_var_dst,
|
|
650
|
+
)
|
|
651
|
+
all_selected_lines.append(max_contribution_line)
|
|
652
|
+
|
|
653
|
+
selected_lines_k_hot = ops.any(
|
|
654
|
+
ops.one_hot(all_selected_lines, self.n_possible_actions, dtype=masks._DEFAULT_DTYPE),
|
|
655
|
+
axis=0,
|
|
656
|
+
)
|
|
657
|
+
return (
|
|
658
|
+
selected_lines_k_hot,
|
|
659
|
+
self.lines_to_im_size(selected_lines_k_hot),
|
|
660
|
+
pixelwise_contribution_to_var_dst,
|
|
661
|
+
)
|
zea/backend/__init__.py
CHANGED
|
@@ -25,6 +25,8 @@ Key Features
|
|
|
25
25
|
|
|
26
26
|
"""
|
|
27
27
|
|
|
28
|
+
from contextlib import nullcontext
|
|
29
|
+
|
|
28
30
|
import keras
|
|
29
31
|
|
|
30
32
|
from zea import log
|
|
@@ -114,3 +116,90 @@ def _jit_compile(func, jax=True, tensorflow=True, **kwargs):
|
|
|
114
116
|
log.warning("Initialize zea.Pipeline with jit_options=None to suppress this warning.")
|
|
115
117
|
log.warning("Falling back to non-compiled mode.")
|
|
116
118
|
return func
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
class on_device:
|
|
122
|
+
"""Context manager to set the device regardless of backend.
|
|
123
|
+
|
|
124
|
+
For the `torch` backend, you need to manually move the model and data to the device before
|
|
125
|
+
using this context manager.
|
|
126
|
+
|
|
127
|
+
Args:
|
|
128
|
+
device (str): Device string, e.g. ``'cuda'``, ``'gpu'``, or ``'cpu'``.
|
|
129
|
+
|
|
130
|
+
Example:
|
|
131
|
+
.. code-block:: python
|
|
132
|
+
|
|
133
|
+
with zea.backend.on_device("gpu:3"):
|
|
134
|
+
pipeline = zea.Pipeline([zea.keras_ops.Abs()])
|
|
135
|
+
output = pipeline(data=keras.random.normal((10, 10))) # output is on "cuda:3"
|
|
136
|
+
"""
|
|
137
|
+
|
|
138
|
+
def __init__(self, device: str):
|
|
139
|
+
self.device = self.get_device(device)
|
|
140
|
+
self._context = self.get_context(self.device)
|
|
141
|
+
|
|
142
|
+
def get_context(self, device):
|
|
143
|
+
if device is None:
|
|
144
|
+
return nullcontext()
|
|
145
|
+
|
|
146
|
+
if keras.backend.backend() == "tensorflow":
|
|
147
|
+
import tensorflow as tf
|
|
148
|
+
|
|
149
|
+
return tf.device(device)
|
|
150
|
+
|
|
151
|
+
if keras.backend.backend() == "jax":
|
|
152
|
+
import jax
|
|
153
|
+
|
|
154
|
+
return jax.default_device(device)
|
|
155
|
+
if keras.backend.backend() == "torch":
|
|
156
|
+
import torch
|
|
157
|
+
|
|
158
|
+
return torch.device(device)
|
|
159
|
+
|
|
160
|
+
return nullcontext()
|
|
161
|
+
|
|
162
|
+
def get_device(self, device: str):
|
|
163
|
+
if device is None:
|
|
164
|
+
return None
|
|
165
|
+
|
|
166
|
+
device = device.lower()
|
|
167
|
+
|
|
168
|
+
if keras.backend.backend() == "tensorflow":
|
|
169
|
+
return device.replace("cuda", "gpu")
|
|
170
|
+
|
|
171
|
+
if keras.backend.backend() == "jax":
|
|
172
|
+
from zea.backend.jax import str_to_jax_device
|
|
173
|
+
|
|
174
|
+
device = device.replace("cuda", "gpu")
|
|
175
|
+
return str_to_jax_device(device)
|
|
176
|
+
|
|
177
|
+
if keras.backend.backend() == "torch":
|
|
178
|
+
return device.replace("gpu", "cuda")
|
|
179
|
+
|
|
180
|
+
def __enter__(self):
|
|
181
|
+
self._context.__enter__()
|
|
182
|
+
|
|
183
|
+
def __exit__(self, exc_type, exc_val, exc_tb):
|
|
184
|
+
self._context.__exit__(exc_type, exc_val, exc_tb)
|
|
185
|
+
|
|
186
|
+
|
|
187
|
+
if keras.backend.backend() in ["tensorflow", "jax", "numpy"]:
|
|
188
|
+
|
|
189
|
+
def func_on_device(func, device, *args, **kwargs):
|
|
190
|
+
"""Moves all tensor arguments of a function to a specified device before calling it.
|
|
191
|
+
|
|
192
|
+
Args:
|
|
193
|
+
func (callable): Function to be called.
|
|
194
|
+
device (str): Device to move tensors to.
|
|
195
|
+
*args: Positional arguments to be passed to the function.
|
|
196
|
+
**kwargs: Keyword arguments to be passed to the function.
|
|
197
|
+
Returns:
|
|
198
|
+
The output of the function.
|
|
199
|
+
"""
|
|
200
|
+
with on_device(device):
|
|
201
|
+
return func(*args, **kwargs)
|
|
202
|
+
elif keras.backend.backend() == "torch":
|
|
203
|
+
from zea.backend.torch import func_on_device
|
|
204
|
+
else:
|
|
205
|
+
raise ValueError(f"Unsupported backend: {keras.backend.backend()}")
|
zea/backend/jax/__init__.py
CHANGED
|
@@ -1,47 +1,21 @@
|
|
|
1
1
|
"""Jax utilities for zea."""
|
|
2
2
|
|
|
3
3
|
import jax
|
|
4
|
-
import numpy as np
|
|
5
4
|
|
|
6
5
|
|
|
7
|
-
def
|
|
8
|
-
"""
|
|
9
|
-
|
|
6
|
+
def str_to_jax_device(device):
|
|
7
|
+
"""Convert a device string to a JAX device.
|
|
10
8
|
Args:
|
|
11
|
-
|
|
12
|
-
inputs (ndarray): Input array.
|
|
13
|
-
device (str): Device string, e.g. ``'cuda'``, ``'gpu'``, or ``'cpu'``.
|
|
14
|
-
return_numpy (bool, optional): Whether to convert output
|
|
15
|
-
data back to numpy. Defaults to False.
|
|
16
|
-
**kwargs: Additional keyword arguments to be passed to the ``func``.
|
|
17
|
-
|
|
9
|
+
device (str): Device string, e.g. ``'gpu:0'``, or ``'cpu:0'``.
|
|
18
10
|
Returns:
|
|
19
|
-
jax.
|
|
20
|
-
|
|
21
|
-
Raises:
|
|
22
|
-
AssertionError: If ``func`` is not a function from the JAX library.
|
|
23
|
-
|
|
24
|
-
Note:
|
|
25
|
-
This function converts the ``inputs`` array to a JAX array and moves
|
|
26
|
-
it to the specified ``device``. It then applies the ``func`` function to the inputs
|
|
27
|
-
and returns the output data. If the output is a dictionary, it extracts the first value
|
|
28
|
-
from the dictionary. If ``return_numpy`` is True, it converts the output data back to a
|
|
29
|
-
numpy array before returning.
|
|
30
|
-
|
|
31
|
-
Example:
|
|
32
|
-
.. code-block:: python
|
|
33
|
-
|
|
34
|
-
import jax.numpy as jnp
|
|
35
|
-
|
|
11
|
+
jax.Device: The corresponding JAX device.
|
|
12
|
+
"""
|
|
36
13
|
|
|
37
|
-
|
|
38
|
-
|
|
14
|
+
if not isinstance(device, str):
|
|
15
|
+
raise ValueError(f"Device must be a string, got {type(device)}")
|
|
39
16
|
|
|
17
|
+
device = device.lower().replace("cuda", "gpu")
|
|
40
18
|
|
|
41
|
-
inputs = [1, 2, 3, 4, 5]
|
|
42
|
-
device = "gpu"
|
|
43
|
-
output = on_device_jax(square, inputs, device)
|
|
44
|
-
"""
|
|
45
19
|
device = device.split(":")
|
|
46
20
|
if len(device) == 2:
|
|
47
21
|
device_type, device_number = device
|
|
@@ -51,20 +25,9 @@ def on_device_jax(func, inputs, device, return_numpy=False, **kwargs):
|
|
|
51
25
|
device_type = device[0]
|
|
52
26
|
device_number = 0
|
|
53
27
|
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
with jax.default_device(jax_device):
|
|
62
|
-
outputs = func(inputs, **kwargs)
|
|
63
|
-
|
|
64
|
-
if isinstance(outputs, dict):
|
|
65
|
-
outputs = list(outputs.values())[0]
|
|
66
|
-
|
|
67
|
-
if return_numpy:
|
|
68
|
-
outputs = np.array(outputs)
|
|
69
|
-
|
|
70
|
-
return outputs
|
|
28
|
+
available = jax.devices(device_type)
|
|
29
|
+
if len(available) == 0:
|
|
30
|
+
raise ValueError(f"No JAX devices available for type '{device_type}'.")
|
|
31
|
+
if device_number < 0 or device_number >= len(available):
|
|
32
|
+
raise ValueError(f"Device '{device}' is not available; JAX devices found: {available}")
|
|
33
|
+
return available[device_number]
|
|
@@ -15,52 +15,3 @@ sys.path = [str(p) if isinstance(p, PosixPath) else p for p in sys.path]
|
|
|
15
15
|
import tensorflow as tf # noqa: E402
|
|
16
16
|
|
|
17
17
|
from .dataloader import make_dataloader # noqa: E402
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
def on_device_tf(func, inputs, device, return_numpy=False, **kwargs):
|
|
21
|
-
"""Applies a Tensorflow function to inputs on a specified device.
|
|
22
|
-
|
|
23
|
-
Args:
|
|
24
|
-
func (function): Function to apply to the input data.
|
|
25
|
-
inputs (ndarray): Input array.
|
|
26
|
-
device (str): Device string, e.g. ``'cuda'``, ``'gpu'``, or ``'cpu'``.
|
|
27
|
-
return_numpy (bool, optional): Whether to convert output
|
|
28
|
-
data back to numpy. Defaults to False.
|
|
29
|
-
**kwargs: Additional keyword arguments to be passed to the ``func``.
|
|
30
|
-
|
|
31
|
-
Returns:
|
|
32
|
-
tf.Tensor or ndarray: The output data.
|
|
33
|
-
|
|
34
|
-
Raises:
|
|
35
|
-
AssertionError: If ``func`` is not a function from the tensorflow library.
|
|
36
|
-
|
|
37
|
-
Note:
|
|
38
|
-
This function converts the ``inputs`` array to a tf.Tensor and moves
|
|
39
|
-
it to the specified ``device``. It then applies the ``func`` function to the inputs
|
|
40
|
-
and returns the output data. If the output is a dictionary, it extracts the first value
|
|
41
|
-
from the dictionary. If ``return_numpy`` is True, it converts the output data back to a
|
|
42
|
-
numpy array before returning.
|
|
43
|
-
|
|
44
|
-
Example:
|
|
45
|
-
.. code-block:: python
|
|
46
|
-
|
|
47
|
-
import tensorflow as tf
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
def square(x):
|
|
51
|
-
return x**2
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
inputs = [1, 2, 3, 4, 5]
|
|
55
|
-
device = "cuda"
|
|
56
|
-
output = on_device_tf(square, inputs, device)
|
|
57
|
-
"""
|
|
58
|
-
device = device.replace("cuda", "gpu")
|
|
59
|
-
|
|
60
|
-
with tf.device(device):
|
|
61
|
-
outputs = func(inputs, **kwargs)
|
|
62
|
-
|
|
63
|
-
if return_numpy:
|
|
64
|
-
if not isinstance(outputs, np.ndarray):
|
|
65
|
-
outputs = outputs.numpy()
|
|
66
|
-
return outputs
|
zea/backend/torch/__init__.py
CHANGED
|
@@ -3,72 +3,37 @@
|
|
|
3
3
|
Initialize modules for registries.
|
|
4
4
|
"""
|
|
5
5
|
|
|
6
|
-
import numpy as np
|
|
7
6
|
import torch
|
|
8
7
|
|
|
9
8
|
|
|
10
|
-
def
|
|
11
|
-
"""
|
|
9
|
+
def func_on_device(func, device, *args, **kwargs):
|
|
10
|
+
"""Moves all tensor arguments of a function to a specified device before calling it.
|
|
12
11
|
|
|
13
12
|
Args:
|
|
14
|
-
func (
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
data back to numpy. Defaults to False.
|
|
19
|
-
**kwargs: Additional keyword arguments to be passed to the ``func``.
|
|
20
|
-
|
|
13
|
+
func (callable): Function to be called.
|
|
14
|
+
device (str or torch.device): Device to move tensors to.
|
|
15
|
+
*args: Positional arguments to be passed to the function.
|
|
16
|
+
**kwargs: Keyword arguments to be passed to the function.
|
|
21
17
|
Returns:
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
Raises:
|
|
25
|
-
AssertionError: If ``func`` is not a function from the torch library.
|
|
26
|
-
|
|
27
|
-
Note:
|
|
28
|
-
This function converts the ``inputs`` array to a torch.Tensor and moves it to
|
|
29
|
-
the specified ``device``. It then applies the ``func`` function to the inputs and
|
|
30
|
-
returns the output data. If the output is a dictionary, it extracts the first value
|
|
31
|
-
from the dictionary. If ``return_numpy`` is True, it converts the output data back to a
|
|
32
|
-
numpy array before returning.
|
|
33
|
-
|
|
34
|
-
Example:
|
|
35
|
-
.. code-block:: python
|
|
36
|
-
|
|
37
|
-
import torch
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
def square(x):
|
|
41
|
-
return x**2
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
inputs = [1, 2, 3, 4, 5]
|
|
45
|
-
device = "cuda"
|
|
46
|
-
output = on_device_torch(square, inputs, device)
|
|
47
|
-
print(output)
|
|
18
|
+
The output of the function.
|
|
48
19
|
"""
|
|
49
|
-
device
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
if return_numpy:
|
|
71
|
-
if not isinstance(outputs, np.ndarray):
|
|
72
|
-
outputs = outputs.cpu().numpy()
|
|
73
|
-
|
|
74
|
-
return outputs
|
|
20
|
+
if device is None:
|
|
21
|
+
return func(*args, **kwargs)
|
|
22
|
+
|
|
23
|
+
if isinstance(device, str):
|
|
24
|
+
device = torch.device(device)
|
|
25
|
+
|
|
26
|
+
def move_to_device(x):
|
|
27
|
+
if isinstance(x, torch.Tensor):
|
|
28
|
+
return x.to(device)
|
|
29
|
+
elif isinstance(x, (list, tuple)):
|
|
30
|
+
return type(x)(move_to_device(i) for i in x)
|
|
31
|
+
elif isinstance(x, dict):
|
|
32
|
+
return {k: move_to_device(v) for k, v in x.items()}
|
|
33
|
+
else:
|
|
34
|
+
return x
|
|
35
|
+
|
|
36
|
+
args = move_to_device(args)
|
|
37
|
+
kwargs = move_to_device(kwargs)
|
|
38
|
+
|
|
39
|
+
return func(*args, **kwargs)
|
zea/data/__main__.py
CHANGED
|
@@ -9,8 +9,8 @@ import argparse
|
|
|
9
9
|
from zea import Folder
|
|
10
10
|
|
|
11
11
|
|
|
12
|
-
def
|
|
13
|
-
parser = argparse.ArgumentParser(description="Copy a zea.Folder to a new location.")
|
|
12
|
+
def get_parser():
|
|
13
|
+
parser = argparse.ArgumentParser(description="Copy a :class:`zea.Folder` to a new location.")
|
|
14
14
|
parser.add_argument("src", help="Source folder path")
|
|
15
15
|
parser.add_argument("dst", help="Destination folder path")
|
|
16
16
|
parser.add_argument("key", help="Key to access in the hdf5 files")
|
|
@@ -20,8 +20,11 @@ def main():
|
|
|
20
20
|
choices=["a", "w", "r+", "x"],
|
|
21
21
|
help="Mode in which to open the destination files (default: 'a')",
|
|
22
22
|
)
|
|
23
|
+
return parser
|
|
24
|
+
|
|
23
25
|
|
|
24
|
-
|
|
26
|
+
def main():
|
|
27
|
+
args = get_parser().parse_args()
|
|
25
28
|
|
|
26
29
|
src_folder = Folder(args.src, args.key, validate=False)
|
|
27
30
|
src_folder.copy(args.dst, args.key, mode=args.mode)
|