zea 0.0.0__py3-none-any.whl → 0.0.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- zea/__init__.py +74 -0
- zea/__main__.py +82 -0
- zea/agent/__init__.py +27 -0
- zea/agent/gumbel.py +107 -0
- zea/agent/masks.py +192 -0
- zea/agent/selection.py +478 -0
- zea/backend/__init__.py +114 -0
- zea/backend/autograd.py +181 -0
- zea/backend/jax/__init__.py +70 -0
- zea/backend/tensorflow/__init__.py +66 -0
- zea/backend/tensorflow/dataloader.py +372 -0
- zea/backend/tensorflow/layers/__init__.py +0 -0
- zea/backend/tensorflow/layers/apodization.py +37 -0
- zea/backend/tensorflow/layers/utils.py +105 -0
- zea/backend/tensorflow/losses.py +64 -0
- zea/backend/tensorflow/models/__init__.py +0 -0
- zea/backend/tensorflow/models/lista.py +112 -0
- zea/backend/tensorflow/scripts/convert-echonet-dynamic.py +139 -0
- zea/backend/tensorflow/scripts/convert-taesd.py +88 -0
- zea/backend/tensorflow/utils/__init__.py +0 -0
- zea/backend/tensorflow/utils/callbacks.py +4 -0
- zea/backend/tensorflow/utils/utils.py +35 -0
- zea/backend/tf2jax.py +10 -0
- zea/backend/torch/__init__.py +74 -0
- zea/backend/torch/losses.py +64 -0
- zea/beamform/__init__.py +20 -0
- zea/beamform/beamformer.py +498 -0
- zea/beamform/delays.py +153 -0
- zea/beamform/lens_correction.py +198 -0
- zea/beamform/pfield.py +406 -0
- zea/beamform/phantoms.py +43 -0
- zea/beamform/pixelgrid.py +130 -0
- zea/config.py +520 -0
- zea/data/__init__.py +55 -0
- zea/data/__main__.py +31 -0
- zea/data/augmentations.py +329 -0
- zea/data/convert/__init__.py +6 -0
- zea/data/convert/camus.py +259 -0
- zea/data/convert/echonet.py +438 -0
- zea/data/convert/echonetlvh/README.md +9 -0
- zea/data/convert/echonetlvh/convert_raw_to_usbmd.py +499 -0
- zea/data/convert/echonetlvh/precompute_crop.py +251 -0
- zea/data/convert/images.py +135 -0
- zea/data/convert/matlab.py +1230 -0
- zea/data/convert/picmus.py +182 -0
- zea/data/data_format.py +718 -0
- zea/data/dataloader.py +409 -0
- zea/data/datasets.py +650 -0
- zea/data/file.py +844 -0
- zea/data/layers.py +167 -0
- zea/data/preset_utils.py +109 -0
- zea/data/utils.py +90 -0
- zea/datapaths.py +566 -0
- zea/display.py +666 -0
- zea/interface.py +546 -0
- zea/internal/cache.py +284 -0
- zea/internal/checks.py +301 -0
- zea/internal/config/create.py +161 -0
- zea/internal/config/parameters.py +123 -0
- zea/internal/config/validation.py +165 -0
- zea/internal/convert.py +150 -0
- zea/internal/core.py +314 -0
- zea/internal/device.py +414 -0
- zea/internal/git_info.py +44 -0
- zea/internal/operators.py +72 -0
- zea/internal/parameters.py +425 -0
- zea/internal/registry.py +203 -0
- zea/internal/setup_zea.py +223 -0
- zea/internal/viewer.py +450 -0
- zea/io_lib.py +350 -0
- zea/log.py +356 -0
- zea/metrics.py +158 -0
- zea/models/__init__.py +88 -0
- zea/models/base.py +195 -0
- zea/models/carotid_segmenter.py +168 -0
- zea/models/dense.py +132 -0
- zea/models/diffusion.py +842 -0
- zea/models/echonet.py +181 -0
- zea/models/generative.py +75 -0
- zea/models/gmm.py +208 -0
- zea/models/layers.py +69 -0
- zea/models/lpips.py +181 -0
- zea/models/preset_utils.py +414 -0
- zea/models/presets.py +100 -0
- zea/models/taesd.py +245 -0
- zea/models/unet.py +207 -0
- zea/models/utils.py +60 -0
- zea/ops.py +3026 -0
- zea/probes.py +224 -0
- zea/scan.py +619 -0
- zea/simulator.py +343 -0
- zea/tensor_ops.py +1327 -0
- zea/tools/__init__.py +8 -0
- zea/tools/fit_scan_cone.py +709 -0
- zea/tools/hf.py +174 -0
- zea/tools/selection_tool.py +847 -0
- zea/tools/wndb.py +22 -0
- zea/utils.py +664 -0
- zea/visualize.py +634 -0
- zea/zea_darkmode.mplstyle +798 -0
- zea-0.0.2.dist-info/LICENSE +202 -0
- zea-0.0.2.dist-info/METADATA +115 -0
- zea-0.0.2.dist-info/RECORD +105 -0
- {zea-0.0.0.dist-info → zea-0.0.2.dist-info}/WHEEL +1 -2
- zea-0.0.2.dist-info/entry_points.txt +3 -0
- zea-0.0.0.dist-info/METADATA +0 -17
- zea-0.0.0.dist-info/RECORD +0 -5
- zea-0.0.0.dist-info/top_level.txt +0 -1
zea/__init__.py
CHANGED
|
@@ -0,0 +1,74 @@
|
|
|
1
|
+
"""``zea``: *A Toolbox for Cognitive Ultrasound Imaging.*"""
|
|
2
|
+
|
|
3
|
+
import importlib.util
|
|
4
|
+
import os
|
|
5
|
+
|
|
6
|
+
from . import log
|
|
7
|
+
|
|
8
|
+
# dynamically add __version__ attribute (see pyproject.toml)
|
|
9
|
+
# __version__ = __import__("importlib.metadata").metadata.version(__package__)
|
|
10
|
+
__version__ = "0.0.2"
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def setup():
    """Initialize the zea package when it is imported.

    Checks that at least one ML backend can be found, then imports ``keras``
    and logs the backend that keras reports.

    Raises:
        ImportError: If none of torch, tensorflow or jax can be found.
    """

    def _check_backend_installed():
        """Ensure at least one ML backend (torch, tensorflow, jax) can be found.

        Raises:
            ImportError: With an install hint when no backend is found.
        """
        candidates = ("torch", "tensorflow", "jax")
        if any(importlib.util.find_spec(name) is not None for name in candidates):
            return

        # "numpy" is the label used in the message when KERAS_BACKEND is unset.
        backend_env = os.environ.get("KERAS_BACKEND", "numpy")
        guide_url = {
            "torch": "https://pytorch.org/get-started/locally/",
            "tensorflow": "https://www.tensorflow.org/install",
            "jax": "https://docs.jax.dev/en/latest/installation.html",
        }.get(backend_env, "https://keras.io/getting_started/")

        raise ImportError(
            "No ML backend (torch, tensorflow, jax) installed in current environment. "
            f"Please install at least one ML backend before importing {__package__} or "
            f"any other library. Current KERAS_BACKEND is set to '{backend_env}', "
            f"please install it first, see: {guide_url}. One simple alternative is to "
            f"install with default backend: `pip install {__package__}[jax]`."
        )

    _check_backend_installed()

    # keras is imported only after the backend check above has passed.
    import keras

    log.info(f"Using backend {keras.backend.backend()!r}")


# Run the initialization once, then drop the helper from the module namespace.
setup()
del setup
|
|
51
|
+
|
|
52
|
+
from . import (
|
|
53
|
+
agent,
|
|
54
|
+
beamform,
|
|
55
|
+
data,
|
|
56
|
+
display,
|
|
57
|
+
io_lib,
|
|
58
|
+
metrics,
|
|
59
|
+
models,
|
|
60
|
+
simulator,
|
|
61
|
+
tensor_ops,
|
|
62
|
+
utils,
|
|
63
|
+
visualize,
|
|
64
|
+
)
|
|
65
|
+
from .config import Config
|
|
66
|
+
from .data.datasets import Dataset, Folder
|
|
67
|
+
from .data.file import File, load_file
|
|
68
|
+
from .datapaths import set_data_paths
|
|
69
|
+
from .interface import Interface
|
|
70
|
+
from .internal.device import init_device
|
|
71
|
+
from .internal.setup_zea import set_backend, setup, setup_config
|
|
72
|
+
from .ops import Pipeline
|
|
73
|
+
from .probes import Probe
|
|
74
|
+
from .scan import Scan
|
zea/__main__.py
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
1
|
+
"""Main entry point for zea
|
|
2
|
+
|
|
3
|
+
Run as `zea --config path/to/config.yaml` to start the zea interface.
|
|
4
|
+
Or do not pass a config file to open a file dialog to choose a config file.
|
|
5
|
+
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import argparse
|
|
9
|
+
import sys
|
|
10
|
+
from pathlib import Path
|
|
11
|
+
|
|
12
|
+
from zea import log
|
|
13
|
+
from zea.visualize import set_mpl_style
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def get_args():
    """Parse the zea command line.

    Returns:
        argparse.Namespace: With attributes ``config``, ``task``, ``backend``,
        ``skip_validate_file`` and ``gui``.
    """
    cli_parser = argparse.ArgumentParser(description="Process ultrasound data.")

    cli_parser.add_argument(
        "-c",
        "--config",
        type=str,
        default=None,
        help="path to config file.",
    )
    cli_parser.add_argument(
        "-t",
        "--task",
        type=str,
        choices=["view"],
        default="view",
        help="which task to run",
    )
    cli_parser.add_argument(
        "--backend",
        type=str,
        default=None,
        help=(
            "Keras backend to use. Default is the one set by the environment "
            "variable KERAS_BACKEND."
        ),
    )
    cli_parser.add_argument(
        "--skip_validate_file",
        action="store_true",
        default=False,
        help="Skip zea file integrity checks. Use with caution.",
    )
    # BooleanOptionalAction provides both ``--gui`` and ``--no-gui``.
    cli_parser.add_argument("--gui", action=argparse.BooleanOptionalAction, default=False)

    return cli_parser.parse_args()
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def main():
    """Command line entry point: parse arguments, set up zea and run the requested task.

    Raises:
        ValueError: If ``args.task`` is not a known task.
    """
    args = get_args()
    set_mpl_style()

    # Select the backend first when one was requested on the command line;
    # keras is only imported further down.
    if args.backend:
        from zea.internal.setup_zea import set_backend

        set_backend(args.backend)

    # Make the directory containing this file importable.
    package_dir = Path(__file__).parent.resolve()
    sys.path.append(str(package_dir))

    import keras

    from zea.interface import Interface
    from zea.internal.setup_zea import setup

    config = setup(args.config)

    if args.task != "view":
        raise ValueError(f"Unknown task {args.task}, see `zea --help` for available tasks.")

    cli = Interface(config, validate_file=not args.skip_validate_file)
    log.info(f"Using {keras.backend.backend()} backend")
    cli.run(plot=True)


if __name__ == "__main__":
    main()
|
zea/agent/__init__.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
"""Agent subpackage for closing action-perception loop in ultrasound imaging.
|
|
2
|
+
|
|
3
|
+
The `agent` subpackage provides tools and utilities for agent-based algorithms within the ``zea`` framework, including mask generation and action selection strategies. See :mod:`zea.agent.masks` and :mod:`zea.agent.selection` for key functions implementing intelligent focused transmit selection, such as the :class:`zea.agent.selection.GreedyEntropy` algorithm.
|
|
4
|
+
|
|
5
|
+
For a practical example, see :doc:`../notebooks/agent/agent_example`.
|
|
6
|
+
|
|
7
|
+
Example usage
|
|
8
|
+
^^^^^^^^^^^^^
|
|
9
|
+
|
|
10
|
+
.. code-block:: python
|
|
11
|
+
|
|
12
|
+
import zea
|
|
13
|
+
import numpy as np
|
|
14
|
+
|
|
15
|
+
agent = zea.agent.selection.GreedyEntropy(
|
|
16
|
+
n_actions=7,
|
|
17
|
+
n_possible_actions=112,
|
|
18
|
+
img_width=112,
|
|
19
|
+
img_height=112,
|
|
20
|
+
)
|
|
21
|
+
|
|
22
|
+
# (batch, samples, height, width)
|
|
23
|
+
particles = np.random.rand(1, 10, 112, 112)
|
|
24
|
+
lines, mask = agent.sample(particles)
|
|
25
|
+
"""
|
|
26
|
+
|
|
27
|
+
from . import masks, selection
|
zea/agent/gumbel.py
ADDED
|
@@ -0,0 +1,107 @@
|
|
|
1
|
+
"""Gumbel-Softmax trick implemented with the multi-backend ``keras.ops``."""
|
|
2
|
+
|
|
3
|
+
import keras
|
|
4
|
+
import numpy as np
|
|
5
|
+
from keras import ops
|
|
6
|
+
|
|
7
|
+
# Choose the product function used for shape arithmetic in this module:
# - JAX does not allow shapes to be tensors, so use NumPy there.
# - For every other backend ``ops.prod`` is used, which allows tensorflow tracing.
prod = np.prod if keras.backend.backend() == "jax" else ops.prod
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class SubsetOperator:
    """SubsetOperator applies the Gumbel-Softmax trick for continuous top-k selection.

    Args:
        k (int): The number of elements to select.
        tau (float, optional): The temperature parameter for Gumbel-Softmax. Defaults to 1.0.
        hard (bool, optional): Whether to use straight-through Gumbel-Softmax. Defaults to False.
        n_value_dims (int, optional): Number of trailing value dimensions, forwarded to
            :func:`hard_straight_through` when ``hard=True``. Defaults to 1.
            E.g. for an image mask, ``n_value_dims=2``.

    Sources:
        - `Reparameterizable Subset Sampling via Continuous Relaxations <https://github.com/ermongroup/subsets>`_
        - `Sampling Subsets with Gumbel-Top Relaxations <https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/DL2/sampling/subsets.html>`_
    """  # noqa: E501

    def __init__(self, k, tau=1.0, hard=False, n_value_dims=1):
        self.k = k
        self.tau = tau
        self.hard = hard
        # Smallest positive normal float32; keeps the logarithms below finite.
        self.EPSILON = np.finfo(np.float32).tiny
        self.n_value_dims = n_value_dims  # for a image mask: n_value_dims=2

    def gumbel_sample(self, shape):
        """Samples from Gumbel(0,1) distribution.

        Args:
            shape: Shape of the noise tensor to draw.

        Returns:
            Tensor: Gumbel noise of the requested shape.
        """
        uniform = keras.random.uniform(shape, minval=0, maxval=1)
        return -ops.log(-ops.log(uniform + self.EPSILON) + self.EPSILON)

    def __call__(self, scores):
        """Draw a relaxed k-hot sample from ``scores``.

        Args:
            scores (Tensor): Unnormalized log-weights. The softmax is taken over
                axis 1, so this presumably expects (batch, n_elements) -- confirm
                against callers.

        Returns:
            Tensor: Relaxed k-hot tensor with the same shape as ``scores``. When
            ``hard=True`` it is passed through :func:`hard_straight_through`.
        """
        # Gumbel-Softmax trick to make the sampling differentiable
        gumbel_noise = self.gumbel_sample(ops.shape(scores))
        scores = scores + gumbel_noise

        # Continuous top-k selection
        khot = ops.zeros_like(scores)
        onehot_approx = ops.zeros_like(scores)

        for _ in range(self.k):
            # Element-wise clamp so the log below stays finite. This must be
            # ``ops.maximum``: ``ops.max`` is a reduction whose second positional
            # argument is ``axis``.
            khot_mask = ops.maximum(1.0 - onehot_approx, self.EPSILON)
            # Suppress (in log space) whatever the previous iteration selected.
            scores = scores + ops.log(khot_mask)
            onehot_approx = ops.softmax(scores / self.tau, axis=1)
            khot = khot + onehot_approx

        # Optionally convert soft selection to hard selection using straight-through estimator
        if self.hard:
            return hard_straight_through(khot, self.k, self.n_value_dims)
        return khot
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
def hard_straight_through(khot_orig, k, n_value_dims=1):
    """Turn a soft k-hot tensor into a hard one with a straight-through gradient.

    The trailing ``n_value_dims`` dimensions are flattened, the ``k`` largest
    entries of every flattened row are set to one in a hard mask, and the hard
    mask is combined with the soft input so that gradients follow the soft input.

    Args:
        khot_orig (Tensor): The original (soft) k-hot encoded tensor.
        k (int): The number of top elements to select.
        n_value_dims (int, optional): The number of value dimensions in the input
            tensor. Defaults to 1. E.g. for a 2D image mask, `n_value_dims=2`.

    Returns:
        Tensor: The result of the straight-through estimator, with the same shape
        as `khot_orig`.
    """
    input_shape = ops.shape(khot_orig)

    # Collapse the value dimensions into one axis: (n_rows, n_values).
    n_values = prod(input_shape[-n_value_dims:])
    flat = ops.reshape(khot_orig, (-1, n_values))

    # Column positions of the k largest entries in every row.
    _, top_indices = ops.top_k(flat, k)

    # Build (row, column) coordinate pairs for the scatter below.
    row_ids = ops.repeat(ops.arange(ops.shape(flat)[0]), k)
    col_ids = ops.reshape(top_indices, (-1,))
    coords = ops.stack([row_ids, col_ids], axis=-1)

    # Hard k-hot tensor: ones at the selected coordinates, zeros elsewhere.
    updates = ops.ones(prod(ops.shape(top_indices)), "float32")
    hard = ops.scatter(coords, updates, ops.shape(flat))

    # Straight-through estimator: the soft term sits behind stop_gradient in one
    # place and not in the other, so gradients flow through ``flat`` only.
    result = hard - ops.stop_gradient(flat) + flat

    return ops.reshape(result, input_shape)
|
zea/agent/masks.py
ADDED
|
@@ -0,0 +1,192 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Mask generation utilities.
|
|
3
|
+
|
|
4
|
+
These masks are used as a measurement operator for focused scan-line subsampling.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from typing import List
|
|
8
|
+
|
|
9
|
+
import keras
|
|
10
|
+
from keras import ops
|
|
11
|
+
|
|
12
|
+
from zea import tensor_ops
|
|
13
|
+
from zea.agent.gumbel import hard_straight_through
|
|
14
|
+
|
|
15
|
+
_DEFAULT_DTYPE = "bool"
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def indices_to_k_hot(
    indices: List[int],
    n_possible_actions: int,
    dtype=_DEFAULT_DTYPE,
):
    """Convert a list of indices to a k-hot encoded vector.

    A k-hot encoded vector is suitable during tracing when the number of actions can change.
    This is the default representation for actions in zea.

    Args:
        indices (List[int]): Indices whose entries are set to one.
        n_possible_actions (int): Total number of possible actions, i.e. the
            length of the returned vector.
        dtype (str, optional): Data type of the mask. Defaults to _DEFAULT_DTYPE.

    Returns:
        Tensor: k-hot-encoded vector of shape (n_possible_actions).
    """
    empty_mask = ops.zeros(n_possible_actions, dtype=dtype)
    # scatter_update takes one coordinate per row, hence the extra axis.
    positions = ops.expand_dims(indices, axis=1)
    updates = ops.ones(len(indices), dtype=dtype)
    return ops.scatter_update(empty_mask, positions, updates)
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def k_hot_to_indices(selected_lines, n_actions: int, fill_value=-1):
    """Convert k-hot encoded lines to the indices of the selected actions.

    Args:
        selected_lines (Tensor): k-hot encoded lines of shape (batch_size, n_possible_actions).
        n_actions (int): Number of lines selected; passed as the ``size`` of every
            row's index list.
        fill_value (int, optional): Value to fill in case there are not enough
            selected actions. Defaults to -1.

    Returns:
        Tensor: Indices of selected actions of shape (batch_size, n_actions).
            If there are fewer than `n_actions` selected, the remaining indices will be
            filled with `fill_value`.
    """

    def _row_to_indices(k_hot_row):
        # nonzero returns one entry per dimension; rows are 1D, so take the first.
        return tensor_ops.nonzero(k_hot_row > 0, size=n_actions, fill_value=fill_value)[0]

    # Apply the per-row lookup across the batch dimension.
    return ops.vectorized_map(_row_to_indices, selected_lines)
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def random_uniform_lines(
    n_actions: int,
    n_possible_actions: int,
    n_masks: int,
    seed: int | keras.random.SeedGenerator = None,
    dtype=_DEFAULT_DTYPE,
):
    """Generate masks with randomly chosen lines.

    Guarantees precisely n_actions.

    Args:
        n_actions (int): Number of actions to be selected.
        n_possible_actions (int): Number of possible actions.
        n_masks (int): Number of masks to generate.
        seed (int | SeedGenerator | jax.random.key, optional): Seed for random number generation.
            Defaults to None.
        dtype (str, optional): Data type of the returned masks. Defaults to _DEFAULT_DTYPE.

    Returns:
        Tensor: k-hot-encoded line vectors of shape (n_masks, n_possible_actions).
            Needs to be converted to image size.
    """
    # Draw a random score per candidate line, then keep the top n_actions per mask.
    random_scores = keras.random.uniform(
        [n_masks, n_possible_actions], seed=seed, dtype="float32"
    )
    k_hot = hard_straight_through(random_scores, n_actions)
    return ops.cast(k_hot, dtype=dtype)
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
def _assert_equal_spacing(n_actions, n_possible_actions):
|
|
93
|
+
assert n_possible_actions % n_actions == 0, (
|
|
94
|
+
"Number of actions must divide evenly into possible actions to use equispaced sampling."
|
|
95
|
+
)
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
def initial_equispaced_lines(
    n_actions, n_possible_actions, dtype=_DEFAULT_DTYPE, assert_equal_spacing=True
):
    """Generate an initial equispaced k-hot line mask.

    For example, if ``n_actions=2`` and ``n_possible_actions=6``,
    then ``initial_mask=[1, 0, 0, 1, 0, 0]``.

    Args:
        n_actions (int): Number of actions to be selected.
        n_possible_actions (int): Number of possible actions.
        dtype (str, optional): Data type of the mask. Defaults to _DEFAULT_DTYPE.
        assert_equal_spacing (bool, optional): If True, asserts that
            `n_possible_actions` is divisible by `n_actions`, this means that every
            line will have the exact same spacing. Otherwise, there might be
            some spacing differences. Defaults to True.

    Returns:
        Tensor: k-hot-encoded line vector of shape (n_possible_actions).
            Needs to be converted to image size.
    """
    if not assert_equal_spacing:
        # Spread the lines over the full range; spacing may differ by rounding.
        line_indices = ops.linspace(0, n_possible_actions - 1, n_actions, dtype="int32")
    else:
        _assert_equal_spacing(n_actions, n_possible_actions)
        stride = n_possible_actions // n_actions
        line_indices = ops.arange(0, n_possible_actions, stride)

    return indices_to_k_hot(line_indices, n_possible_actions, dtype=dtype)
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
def next_equispaced_lines(previous_lines, shift=1):
    """Roll an equispaced mask along its last axis to get the next mask.

    Args:
        previous_lines (Tensor): Previous mask of shape (..., n_possible_actions).
        shift (int, optional): Number of positions to roll to the right.
            Defaults to 1.

    Returns:
        Tensor: The rolled mask, with the same shape as ``previous_lines``.
    """
    rolled = ops.roll(previous_lines, shift=shift, axis=-1)
    return rolled
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
def lines_to_im_size(lines, img_size: tuple):
    """
    Convert k-hot-encoded line vectors to image size.

    Every line entry is repeated along the width until the image width is
    reached, and the resulting row is then repeated ``height`` times.

    Args:
        lines (Tensor): shape is (n_masks, n_possible_actions)
        img_size (tuple): (height, width)

    Returns:
        Tensor: Masks of shape (n_masks, height, width)
    """
    height, width = img_size
    n_lines = ops.shape(lines)[1]

    remainder = width % n_lines
    assert remainder == 0, (
        f"Width must be divisible by number of actions. Got remainder: {remainder}."
    )

    # Widen each line to cover width // n_lines pixel columns.
    pixels_per_line = width // n_lines
    row_masks = ops.repeat(lines, pixels_per_line, axis=1)

    # Insert a height axis and repeat the row along it.
    return ops.repeat(row_masks[:, None], height, axis=1)
|
|
161
|
+
|
|
162
|
+
|
|
163
|
+
def make_line_mask(
    line_indices: List[int],
    image_shape: List[int],
    line_width: int = 1,
    dtype=_DEFAULT_DTYPE,
):
    """
    Creates a mask with vertical (i.e. second axis) lines at specified indices.

    Args:
        line_indices (List[int]): A list of indices where the lines should be drawn.
            They index into ``width // line_width`` possible line positions.
        image_shape (List[int]): The shape of the image as [height, width, channels].
        line_width (int, optional): The width of each line. Defaults to 1.
        dtype (str, optional): The data type of the mask. Defaults to _DEFAULT_DTYPE.

    Returns:
        mask (Tensor): A tensor of the same shape as `image_shape` with lines drawn
            at the specified indices.
    """
    height, width, channels = image_shape

    # One k-hot entry per possible line position.
    n_line_positions = width // line_width
    line_vector = indices_to_k_hot(line_indices, n_line_positions, dtype=dtype)

    # lines_to_im_size works on a batch, so add a leading axis and drop it again.
    batched_lines = ops.expand_dims(line_vector, axis=0)
    mask_hw = lines_to_im_size(batched_lines, (height, width))[0]

    # Add a channel axis and broadcast to (height, width, channels).
    return ops.broadcast_to(mask_hw[..., None], (height, width, channels))
|