zea 0.0.9__tar.gz → 0.0.10__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {zea-0.0.9 → zea-0.0.10}/PKG-INFO +3 -5
- {zea-0.0.9 → zea-0.0.10}/pyproject.toml +3 -5
- {zea-0.0.9 → zea-0.0.10}/zea/__init__.py +11 -5
- {zea-0.0.9 → zea-0.0.10}/zea/agent/masks.py +15 -3
- {zea-0.0.9 → zea-0.0.10}/zea/agent/selection.py +12 -3
- {zea-0.0.9 → zea-0.0.10}/zea/backend/__init__.py +17 -3
- {zea-0.0.9 → zea-0.0.10}/zea/beamform/beamformer.py +158 -89
- {zea-0.0.9 → zea-0.0.10}/zea/beamform/delays.py +12 -9
- {zea-0.0.9 → zea-0.0.10}/zea/beamform/lens_correction.py +0 -73
- {zea-0.0.9 → zea-0.0.10}/zea/beamform/pfield.py +4 -10
- zea-0.0.10/zea/beamform/phantoms.py +145 -0
- zea-0.0.10/zea/beamform/pixelgrid.py +189 -0
- {zea-0.0.9 → zea-0.0.10}/zea/config.py +2 -2
- {zea-0.0.9 → zea-0.0.10}/zea/data/convert/__main__.py +17 -7
- {zea-0.0.9 → zea-0.0.10}/zea/data/convert/echonetlvh/__init__.py +47 -50
- {zea-0.0.9 → zea-0.0.10}/zea/data/convert/echonetlvh/precompute_crop.py +12 -0
- {zea-0.0.9 → zea-0.0.10}/zea/data/convert/images.py +1 -1
- {zea-0.0.9 → zea-0.0.10}/zea/data/convert/verasonics.py +375 -119
- {zea-0.0.9 → zea-0.0.10}/zea/data/data_format.py +53 -8
- {zea-0.0.9 → zea-0.0.10}/zea/data/datasets.py +0 -7
- {zea-0.0.9 → zea-0.0.10}/zea/data/file.py +8 -2
- {zea-0.0.9 → zea-0.0.10}/zea/data/file_operations.py +2 -0
- {zea-0.0.9 → zea-0.0.10}/zea/display.py +26 -9
- {zea-0.0.9 → zea-0.0.10}/zea/doppler.py +1 -1
- {zea-0.0.9 → zea-0.0.10}/zea/func/__init__.py +6 -0
- {zea-0.0.9 → zea-0.0.10}/zea/func/tensor.py +4 -2
- {zea-0.0.9 → zea-0.0.10}/zea/func/ultrasound.py +158 -62
- {zea-0.0.9 → zea-0.0.10}/zea/internal/config/parameters.py +1 -1
- {zea-0.0.9 → zea-0.0.10}/zea/internal/device.py +6 -1
- {zea-0.0.9 → zea-0.0.10}/zea/internal/dummy_scan.py +15 -8
- zea-0.0.10/zea/internal/notebooks.py +152 -0
- {zea-0.0.9 → zea-0.0.10}/zea/internal/parameters.py +20 -0
- {zea-0.0.9 → zea-0.0.10}/zea/internal/registry.py +1 -1
- {zea-0.0.9 → zea-0.0.10}/zea/metrics.py +84 -68
- {zea-0.0.9 → zea-0.0.10}/zea/models/__init__.py +1 -0
- {zea-0.0.9 → zea-0.0.10}/zea/models/diffusion.py +5 -0
- zea-0.0.10/zea/models/hvae/__init__.py +243 -0
- zea-0.0.10/zea/models/hvae/model.py +1139 -0
- zea-0.0.10/zea/models/hvae/utils.py +616 -0
- {zea-0.0.9 → zea-0.0.10}/zea/models/layers.py +1 -1
- {zea-0.0.9 → zea-0.0.10}/zea/models/lpips.py +12 -2
- {zea-0.0.9 → zea-0.0.10}/zea/models/presets.py +16 -0
- {zea-0.0.9 → zea-0.0.10}/zea/ops/__init__.py +6 -4
- {zea-0.0.9 → zea-0.0.10}/zea/ops/base.py +28 -29
- {zea-0.0.9 → zea-0.0.10}/zea/ops/pipeline.py +17 -7
- {zea-0.0.9 → zea-0.0.10}/zea/ops/tensor.py +50 -73
- {zea-0.0.9 → zea-0.0.10}/zea/ops/ultrasound.py +250 -103
- {zea-0.0.9 → zea-0.0.10}/zea/probes.py +2 -0
- {zea-0.0.9 → zea-0.0.10}/zea/scan.py +223 -83
- {zea-0.0.9 → zea-0.0.10}/zea/simulator.py +8 -8
- {zea-0.0.9 → zea-0.0.10}/zea/tools/fit_scan_cone.py +8 -6
- {zea-0.0.9 → zea-0.0.10}/zea/tools/selection_tool.py +13 -7
- {zea-0.0.9 → zea-0.0.10}/zea/visualize.py +3 -1
- zea-0.0.9/zea/beamform/phantoms.py +0 -43
- zea-0.0.9/zea/beamform/pixelgrid.py +0 -131
- zea-0.0.9/zea/internal/notebooks.py +0 -39
- {zea-0.0.9 → zea-0.0.10}/LICENSE +0 -0
- {zea-0.0.9 → zea-0.0.10}/README.md +0 -0
- {zea-0.0.9 → zea-0.0.10}/zea/__main__.py +0 -0
- {zea-0.0.9 → zea-0.0.10}/zea/agent/__init__.py +0 -0
- {zea-0.0.9 → zea-0.0.10}/zea/agent/gumbel.py +0 -0
- {zea-0.0.9 → zea-0.0.10}/zea/backend/autograd.py +0 -0
- {zea-0.0.9 → zea-0.0.10}/zea/backend/jax/__init__.py +0 -0
- {zea-0.0.9 → zea-0.0.10}/zea/backend/tensorflow/__init__.py +0 -0
- {zea-0.0.9 → zea-0.0.10}/zea/backend/tensorflow/dataloader.py +0 -0
- {zea-0.0.9 → zea-0.0.10}/zea/backend/tensorflow/layers/__init__.py +0 -0
- {zea-0.0.9 → zea-0.0.10}/zea/backend/tensorflow/layers/apodization.py +0 -0
- {zea-0.0.9 → zea-0.0.10}/zea/backend/tensorflow/layers/utils.py +0 -0
- {zea-0.0.9 → zea-0.0.10}/zea/backend/tensorflow/losses.py +0 -0
- {zea-0.0.9 → zea-0.0.10}/zea/backend/tensorflow/models/__init__.py +0 -0
- {zea-0.0.9 → zea-0.0.10}/zea/backend/tensorflow/models/lista.py +0 -0
- {zea-0.0.9 → zea-0.0.10}/zea/backend/tensorflow/scripts/convert-echonet-dynamic.py +0 -0
- {zea-0.0.9 → zea-0.0.10}/zea/backend/tensorflow/scripts/convert-taesd.py +0 -0
- {zea-0.0.9 → zea-0.0.10}/zea/backend/tensorflow/utils/__init__.py +0 -0
- {zea-0.0.9 → zea-0.0.10}/zea/backend/tensorflow/utils/callbacks.py +0 -0
- {zea-0.0.9 → zea-0.0.10}/zea/backend/tensorflow/utils/utils.py +0 -0
- {zea-0.0.9 → zea-0.0.10}/zea/backend/tf2jax.py +0 -0
- {zea-0.0.9 → zea-0.0.10}/zea/backend/torch/__init__.py +0 -0
- {zea-0.0.9 → zea-0.0.10}/zea/backend/torch/losses.py +0 -0
- {zea-0.0.9 → zea-0.0.10}/zea/beamform/__init__.py +0 -0
- {zea-0.0.9 → zea-0.0.10}/zea/data/__init__.py +0 -0
- {zea-0.0.9 → zea-0.0.10}/zea/data/__main__.py +0 -0
- {zea-0.0.9 → zea-0.0.10}/zea/data/augmentations.py +0 -0
- {zea-0.0.9 → zea-0.0.10}/zea/data/convert/__init__.py +0 -0
- {zea-0.0.9 → zea-0.0.10}/zea/data/convert/camus.py +0 -0
- {zea-0.0.9 → zea-0.0.10}/zea/data/convert/echonet.py +0 -0
- {zea-0.0.9 → zea-0.0.10}/zea/data/convert/echonetlvh/README.md +0 -0
- {zea-0.0.9 → zea-0.0.10}/zea/data/convert/echonetlvh/manual_rejections.txt +0 -0
- {zea-0.0.9 → zea-0.0.10}/zea/data/convert/picmus.py +0 -0
- {zea-0.0.9 → zea-0.0.10}/zea/data/convert/utils.py +0 -0
- {zea-0.0.9 → zea-0.0.10}/zea/data/dataloader.py +0 -0
- {zea-0.0.9 → zea-0.0.10}/zea/data/layers.py +0 -0
- {zea-0.0.9 → zea-0.0.10}/zea/data/preset_utils.py +0 -0
- {zea-0.0.9 → zea-0.0.10}/zea/data/utils.py +0 -0
- {zea-0.0.9 → zea-0.0.10}/zea/datapaths.py +0 -0
- {zea-0.0.9 → zea-0.0.10}/zea/interface.py +0 -0
- {zea-0.0.9 → zea-0.0.10}/zea/internal/_generate_keras_ops.py +0 -0
- {zea-0.0.9 → zea-0.0.10}/zea/internal/cache.py +0 -0
- {zea-0.0.9 → zea-0.0.10}/zea/internal/checks.py +0 -0
- {zea-0.0.9 → zea-0.0.10}/zea/internal/config/create.py +0 -0
- {zea-0.0.9 → zea-0.0.10}/zea/internal/config/validation.py +0 -0
- {zea-0.0.9 → zea-0.0.10}/zea/internal/core.py +0 -0
- {zea-0.0.9 → zea-0.0.10}/zea/internal/git_info.py +0 -0
- {zea-0.0.9 → zea-0.0.10}/zea/internal/operators.py +0 -0
- {zea-0.0.9 → zea-0.0.10}/zea/internal/setup_zea.py +0 -0
- {zea-0.0.9 → zea-0.0.10}/zea/internal/utils.py +0 -0
- {zea-0.0.9 → zea-0.0.10}/zea/internal/viewer.py +0 -0
- {zea-0.0.9 → zea-0.0.10}/zea/io_lib.py +0 -0
- {zea-0.0.9 → zea-0.0.10}/zea/log.py +0 -0
- {zea-0.0.9 → zea-0.0.10}/zea/models/base.py +0 -0
- {zea-0.0.9 → zea-0.0.10}/zea/models/carotid_segmenter.py +0 -0
- {zea-0.0.9 → zea-0.0.10}/zea/models/deeplabv3.py +0 -0
- {zea-0.0.9 → zea-0.0.10}/zea/models/dense.py +0 -0
- {zea-0.0.9 → zea-0.0.10}/zea/models/echonet.py +0 -0
- {zea-0.0.9 → zea-0.0.10}/zea/models/echonetlvh.py +0 -0
- {zea-0.0.9 → zea-0.0.10}/zea/models/generative.py +0 -0
- {zea-0.0.9 → zea-0.0.10}/zea/models/gmm.py +0 -0
- {zea-0.0.9 → zea-0.0.10}/zea/models/lv_segmentation.py +0 -0
- {zea-0.0.9 → zea-0.0.10}/zea/models/preset_utils.py +0 -0
- {zea-0.0.9 → zea-0.0.10}/zea/models/regional_quality.py +0 -0
- {zea-0.0.9 → zea-0.0.10}/zea/models/taesd.py +0 -0
- {zea-0.0.9 → zea-0.0.10}/zea/models/unet.py +0 -0
- {zea-0.0.9 → zea-0.0.10}/zea/models/utils.py +0 -0
- {zea-0.0.9 → zea-0.0.10}/zea/ops/keras_ops.py +0 -0
- {zea-0.0.9 → zea-0.0.10}/zea/tools/__init__.py +0 -0
- {zea-0.0.9 → zea-0.0.10}/zea/tools/hf.py +0 -0
- {zea-0.0.9 → zea-0.0.10}/zea/tools/wndb.py +0 -0
- {zea-0.0.9 → zea-0.0.10}/zea/tracking/__init__.py +0 -0
- {zea-0.0.9 → zea-0.0.10}/zea/tracking/base.py +0 -0
- {zea-0.0.9 → zea-0.0.10}/zea/tracking/lucas_kanade.py +0 -0
- {zea-0.0.9 → zea-0.0.10}/zea/tracking/segmentation.py +0 -0
- {zea-0.0.9 → zea-0.0.10}/zea/utils.py +0 -0
- {zea-0.0.9 → zea-0.0.10}/zea/zea_darkmode.mplstyle +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: zea
|
|
3
|
-
Version: 0.0.
|
|
3
|
+
Version: 0.0.10
|
|
4
4
|
Summary: A Toolbox for Cognitive Ultrasound Imaging. Provides a set of tools for processing of ultrasound data, all built in your favorite machine learning framework.
|
|
5
5
|
License-File: LICENSE
|
|
6
6
|
Keywords: ultrasound,machine learning,beamforming
|
|
@@ -44,8 +44,6 @@ Requires-Dist: jax ; extra == "backends"
|
|
|
44
44
|
Requires-Dist: jax[cuda12-pip] (>=0.4.26) ; extra == "jax"
|
|
45
45
|
Requires-Dist: keras (>=3.12)
|
|
46
46
|
Requires-Dist: matplotlib (>=3.8)
|
|
47
|
-
Requires-Dist: mock ; extra == "dev"
|
|
48
|
-
Requires-Dist: mock ; extra == "docs"
|
|
49
47
|
Requires-Dist: myst-parser ; extra == "dev"
|
|
50
48
|
Requires-Dist: myst-parser ; extra == "docs"
|
|
51
49
|
Requires-Dist: nbsphinx ; extra == "dev"
|
|
@@ -76,8 +74,6 @@ Requires-Dist: simpleitk (>=2.2.1) ; extra == "dev"
|
|
|
76
74
|
Requires-Dist: simpleitk (>=2.2.1) ; extra == "tests"
|
|
77
75
|
Requires-Dist: sphinx ; extra == "dev"
|
|
78
76
|
Requires-Dist: sphinx ; extra == "docs"
|
|
79
|
-
Requires-Dist: sphinx-argparse ; extra == "dev"
|
|
80
|
-
Requires-Dist: sphinx-argparse ; extra == "docs"
|
|
81
77
|
Requires-Dist: sphinx-autobuild ; extra == "dev"
|
|
82
78
|
Requires-Dist: sphinx-autobuild ; extra == "docs"
|
|
83
79
|
Requires-Dist: sphinx-autodoc-typehints ; extra == "dev"
|
|
@@ -88,6 +84,8 @@ Requires-Dist: sphinx-reredirects ; extra == "dev"
|
|
|
88
84
|
Requires-Dist: sphinx-reredirects ; extra == "docs"
|
|
89
85
|
Requires-Dist: sphinx_design ; extra == "dev"
|
|
90
86
|
Requires-Dist: sphinx_design ; extra == "docs"
|
|
87
|
+
Requires-Dist: sphinxcontrib-autoprogram ; extra == "dev"
|
|
88
|
+
Requires-Dist: sphinxcontrib-autoprogram ; extra == "docs"
|
|
91
89
|
Requires-Dist: sphinxcontrib-bibtex ; extra == "dev"
|
|
92
90
|
Requires-Dist: sphinxcontrib-bibtex ; extra == "docs"
|
|
93
91
|
Requires-Dist: tensorflow ; extra == "backends"
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
[project]
|
|
2
2
|
name = "zea"
|
|
3
|
-
version = "0.0.
|
|
3
|
+
version = "0.0.10"
|
|
4
4
|
description = "A Toolbox for Cognitive Ultrasound Imaging. Provides a set of tools for processing of ultrasound data, all built in your favorite machine learning framework."
|
|
5
5
|
authors = [
|
|
6
6
|
{ name = "Tristan Stevens", email = "t.s.w.stevens@tue.nl" },
|
|
@@ -65,10 +65,9 @@ dev = [
|
|
|
65
65
|
"sphinx-autodoc-typehints",
|
|
66
66
|
"sphinx-copybutton",
|
|
67
67
|
"sphinx_design",
|
|
68
|
-
"
|
|
68
|
+
"sphinxcontrib-autoprogram",
|
|
69
69
|
"sphinx-reredirects",
|
|
70
70
|
"sphinxcontrib-bibtex",
|
|
71
|
-
"mock",
|
|
72
71
|
"myst-parser",
|
|
73
72
|
"nbsphinx",
|
|
74
73
|
"furo",
|
|
@@ -97,10 +96,9 @@ docs = [
|
|
|
97
96
|
"sphinx-autodoc-typehints",
|
|
98
97
|
"sphinx-copybutton",
|
|
99
98
|
"sphinx_design",
|
|
100
|
-
"
|
|
99
|
+
"sphinxcontrib-autoprogram",
|
|
101
100
|
"sphinx-reredirects",
|
|
102
101
|
"sphinxcontrib-bibtex",
|
|
103
|
-
"mock",
|
|
104
102
|
"myst-parser",
|
|
105
103
|
"nbsphinx",
|
|
106
104
|
"furo",
|
|
@@ -2,12 +2,16 @@
|
|
|
2
2
|
|
|
3
3
|
import importlib.util
|
|
4
4
|
import os
|
|
5
|
+
from importlib.metadata import PackageNotFoundError, version
|
|
5
6
|
|
|
6
7
|
from . import log
|
|
7
8
|
|
|
8
|
-
|
|
9
|
-
# __version__
|
|
10
|
-
__version__ = "
|
|
9
|
+
try:
|
|
10
|
+
# dynamically add __version__ attribute (see pyproject.toml)
|
|
11
|
+
__version__ = version("zea")
|
|
12
|
+
except PackageNotFoundError:
|
|
13
|
+
# Package is not installed (e.g., running from source)
|
|
14
|
+
__version__ = "dev"
|
|
11
15
|
|
|
12
16
|
|
|
13
17
|
def _bootstrap_backend():
|
|
@@ -80,8 +84,10 @@ def _bootstrap_backend():
|
|
|
80
84
|
log.info(f"Using backend {keras_backend()!r}")
|
|
81
85
|
|
|
82
86
|
|
|
83
|
-
#
|
|
84
|
-
|
|
87
|
+
# Skip backend bootstrap when building on ReadTheDocs
|
|
88
|
+
if os.environ.get("READTHEDOCS") != "True":
|
|
89
|
+
_bootstrap_backend()
|
|
90
|
+
|
|
85
91
|
del _bootstrap_backend
|
|
86
92
|
|
|
87
93
|
from . import (
|
|
@@ -4,6 +4,8 @@ Mask generation utilities.
|
|
|
4
4
|
These masks are used as a measurement operator for focused scan-line subsampling.
|
|
5
5
|
"""
|
|
6
6
|
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
7
9
|
from typing import List
|
|
8
10
|
|
|
9
11
|
import keras
|
|
@@ -117,11 +119,21 @@ def initial_equispaced_lines(
|
|
|
117
119
|
Tensor: k-hot-encoded line vector of shape (n_possible_actions).
|
|
118
120
|
Needs to be converted to image size.
|
|
119
121
|
"""
|
|
122
|
+
assert n_actions > 0, "Number of actions must be > 0."
|
|
123
|
+
assert n_possible_actions > 0, "Number of possible actions must be > 0."
|
|
124
|
+
assert n_actions <= n_possible_actions, (
|
|
125
|
+
"Number of actions must be less than or equal to number of possible actions."
|
|
126
|
+
)
|
|
127
|
+
|
|
120
128
|
if assert_equal_spacing:
|
|
121
129
|
_assert_equal_spacing(n_actions, n_possible_actions)
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
130
|
+
|
|
131
|
+
# Distribute indices as evenly as possible
|
|
132
|
+
# This approach ensures spacing differs by at most 1 when not divisible
|
|
133
|
+
step = n_possible_actions / n_actions
|
|
134
|
+
selected_indices = ops.cast(
|
|
135
|
+
ops.round(ops.arange(0, n_actions, dtype="float32") * step), "int32"
|
|
136
|
+
)
|
|
125
137
|
|
|
126
138
|
return indices_to_k_hot(selected_indices, n_possible_actions, dtype=dtype)
|
|
127
139
|
|
|
@@ -96,6 +96,7 @@ class GreedyEntropy(LinesActionModel):
|
|
|
96
96
|
std_dev: float = 1,
|
|
97
97
|
num_lines_to_update: int = 5,
|
|
98
98
|
entropy_sigma: float = 1.0,
|
|
99
|
+
average_entropy_across_batch: bool = False,
|
|
99
100
|
):
|
|
100
101
|
"""Initialize the GreedyEntropy action selection model.
|
|
101
102
|
|
|
@@ -110,6 +111,10 @@ class GreedyEntropy(LinesActionModel):
|
|
|
110
111
|
to update. Must be odd.
|
|
111
112
|
entropy_sigma (float, optional): The standard deviation of the Gaussian
|
|
112
113
|
Mixture components used to approximate the posterior.
|
|
114
|
+
average_entropy_across_batch (bool, optional): Whether to average entropy
|
|
115
|
+
across the batch when selecting lines. This can be useful when
|
|
116
|
+
selecting planes in 3D imaging, where the batch dimension represents
|
|
117
|
+
a third spatial dimension. Defaults to False.
|
|
113
118
|
"""
|
|
114
119
|
super().__init__(n_actions, n_possible_actions, img_width, img_height)
|
|
115
120
|
|
|
@@ -117,6 +122,7 @@ class GreedyEntropy(LinesActionModel):
|
|
|
117
122
|
# of the selected line is set to 0 once it's been selected.
|
|
118
123
|
assert num_lines_to_update % 2 == 1, "num_samples must be odd."
|
|
119
124
|
self.num_lines_to_update = num_lines_to_update
|
|
125
|
+
self.average_entropy_across_batch = average_entropy_across_batch
|
|
120
126
|
|
|
121
127
|
# see here what I mean by upside_down_gaussian:
|
|
122
128
|
# https://colab.research.google.com/drive/1CQp_Z6nADzOFsybdiH5Cag0vtVZjjioU?usp=sharing
|
|
@@ -153,7 +159,7 @@ class GreedyEntropy(LinesActionModel):
|
|
|
153
159
|
assert particles.shape[1] > 1, "The entropy cannot be approximated using a single particle."
|
|
154
160
|
|
|
155
161
|
if n_possible_actions is None:
|
|
156
|
-
n_possible_actions =
|
|
162
|
+
n_possible_actions = ops.shape(particles)[-1]
|
|
157
163
|
|
|
158
164
|
# TODO: I think we only need to compute the lower triangular
|
|
159
165
|
# of this matrix, since it's symmetric
|
|
@@ -164,7 +170,8 @@ class GreedyEntropy(LinesActionModel):
|
|
|
164
170
|
# Vertically stack all columns corresponding with the same line
|
|
165
171
|
# This way we can just sum across the height axis and get the entropy
|
|
166
172
|
# for each pixel in a given line
|
|
167
|
-
batch_size, n_particles, _, height, _ =
|
|
173
|
+
batch_size, n_particles, _, height, _ = ops.shape(gaussian_error_per_pixel_i_j)
|
|
174
|
+
|
|
168
175
|
gaussian_error_per_pixel_stacked = ops.transpose(
|
|
169
176
|
ops.reshape(
|
|
170
177
|
ops.transpose(gaussian_error_per_pixel_i_j, (0, 1, 2, 4, 3)),
|
|
@@ -274,6 +281,8 @@ class GreedyEntropy(LinesActionModel):
|
|
|
274
281
|
|
|
275
282
|
pixelwise_entropy = self.compute_pixelwise_entropy(particles)
|
|
276
283
|
linewise_entropy = ops.sum(pixelwise_entropy, axis=1)
|
|
284
|
+
if self.average_entropy_across_batch:
|
|
285
|
+
linewise_entropy = ops.expand_dims(ops.mean(linewise_entropy, axis=0), axis=0)
|
|
277
286
|
|
|
278
287
|
# Greedily select best line, reweight entropies, and repeat
|
|
279
288
|
all_selected_lines = []
|
|
@@ -334,7 +343,7 @@ class EquispacedLines(LinesActionModel):
|
|
|
334
343
|
n_possible_actions: int,
|
|
335
344
|
img_width: int,
|
|
336
345
|
img_height: int,
|
|
337
|
-
assert_equal_spacing=True,
|
|
346
|
+
assert_equal_spacing: bool = True,
|
|
338
347
|
):
|
|
339
348
|
super().__init__(n_actions, n_possible_actions, img_width, img_height)
|
|
340
349
|
|
|
@@ -59,8 +59,21 @@ def _import_torch():
|
|
|
59
59
|
return None
|
|
60
60
|
|
|
61
61
|
|
|
62
|
+
def _get_backend():
|
|
63
|
+
try:
|
|
64
|
+
backend_result = keras.backend.backend()
|
|
65
|
+
if isinstance(backend_result, str):
|
|
66
|
+
return backend_result
|
|
67
|
+
else:
|
|
68
|
+
# to handle mocked backends during testing
|
|
69
|
+
return None
|
|
70
|
+
except Exception:
|
|
71
|
+
return None
|
|
72
|
+
|
|
73
|
+
|
|
62
74
|
tf_mod = _import_tf()
|
|
63
75
|
jax_mod = _import_jax()
|
|
76
|
+
backend = _get_backend()
|
|
64
77
|
|
|
65
78
|
|
|
66
79
|
def tf_function(func=None, jit_compile=False, **kwargs):
|
|
@@ -184,7 +197,7 @@ class on_device:
|
|
|
184
197
|
self._context.__exit__(exc_type, exc_val, exc_tb)
|
|
185
198
|
|
|
186
199
|
|
|
187
|
-
if
|
|
200
|
+
if backend in [None, "tensorflow", "jax", "numpy"]:
|
|
188
201
|
|
|
189
202
|
def func_on_device(func, device, *args, **kwargs):
|
|
190
203
|
"""Moves all tensor arguments of a function to a specified device before calling it.
|
|
@@ -199,7 +212,8 @@ if keras.backend.backend() in ["tensorflow", "jax", "numpy"]:
|
|
|
199
212
|
"""
|
|
200
213
|
with on_device(device):
|
|
201
214
|
return func(*args, **kwargs)
|
|
202
|
-
|
|
215
|
+
|
|
216
|
+
elif backend == "torch":
|
|
203
217
|
from zea.backend.torch import func_on_device
|
|
204
218
|
else:
|
|
205
|
-
raise ValueError(f"Unsupported backend: {
|
|
219
|
+
raise ValueError(f"Unsupported backend: {backend}")
|
|
@@ -4,7 +4,7 @@ import keras
|
|
|
4
4
|
import numpy as np
|
|
5
5
|
from keras import ops
|
|
6
6
|
|
|
7
|
-
from zea.beamform.lens_correction import
|
|
7
|
+
from zea.beamform.lens_correction import compute_lens_corrected_travel_times
|
|
8
8
|
from zea.func.tensor import vmap
|
|
9
9
|
|
|
10
10
|
|
|
@@ -62,6 +62,7 @@ def tof_correction(
|
|
|
62
62
|
focus_distances,
|
|
63
63
|
t_peak,
|
|
64
64
|
tx_waveform_indices,
|
|
65
|
+
transmit_origins,
|
|
65
66
|
apply_lens_correction=False,
|
|
66
67
|
lens_thickness=1e-3,
|
|
67
68
|
lens_sound_speed=1000,
|
|
@@ -87,6 +88,7 @@ def tof_correction(
|
|
|
87
88
|
Shape `(n_waveforms,)`.
|
|
88
89
|
tx_waveform_indices (ops.Tensor): The indices of the waveform used for each
|
|
89
90
|
transmit of shape `(n_tx,)`.
|
|
91
|
+
transmit_origins (ops.Tensor): Transmit origins of shape (n_tx, 3).
|
|
90
92
|
apply_lens_correction (bool, optional): Whether to apply lens correction to
|
|
91
93
|
time-of-flights. This makes it slower, but more accurate in the near-field.
|
|
92
94
|
Defaults to False.
|
|
@@ -120,8 +122,7 @@ def tof_correction(
|
|
|
120
122
|
# rxdel has shape (n_el, n_pix)
|
|
121
123
|
# --------------------------------------------------------------------
|
|
122
124
|
|
|
123
|
-
|
|
124
|
-
txdel, rxdel = delay_fn(
|
|
125
|
+
txdel, rxdel = calculate_delays(
|
|
125
126
|
flatgrid,
|
|
126
127
|
t0_delays,
|
|
127
128
|
tx_apodizations,
|
|
@@ -133,10 +134,12 @@ def tof_correction(
|
|
|
133
134
|
n_el,
|
|
134
135
|
focus_distances,
|
|
135
136
|
polar_angles,
|
|
136
|
-
t_peak
|
|
137
|
-
tx_waveform_indices
|
|
138
|
-
|
|
139
|
-
|
|
137
|
+
t_peak,
|
|
138
|
+
tx_waveform_indices,
|
|
139
|
+
transmit_origins,
|
|
140
|
+
apply_lens_correction,
|
|
141
|
+
lens_thickness,
|
|
142
|
+
lens_sound_speed,
|
|
140
143
|
)
|
|
141
144
|
|
|
142
145
|
n_pix = ops.shape(flatgrid)[0]
|
|
@@ -207,7 +210,11 @@ def calculate_delays(
|
|
|
207
210
|
polar_angles,
|
|
208
211
|
t_peak,
|
|
209
212
|
tx_waveform_indices,
|
|
210
|
-
|
|
213
|
+
transmit_origins,
|
|
214
|
+
apply_lens_correction=False,
|
|
215
|
+
lens_thickness=None,
|
|
216
|
+
lens_sound_speed=None,
|
|
217
|
+
n_iter=2,
|
|
211
218
|
):
|
|
212
219
|
"""Calculates the delays in samples to every pixel in the grid.
|
|
213
220
|
|
|
@@ -242,6 +249,16 @@ def calculate_delays(
|
|
|
242
249
|
`(n_waveforms,)`.
|
|
243
250
|
tx_waveform_indices (Tensor): The indices of the waveform used for each
|
|
244
251
|
transmit of shape `(n_tx,)`.
|
|
252
|
+
transmit_origins (Tensor): Transmit origins of shape (n_tx, 3).
|
|
253
|
+
apply_lens_correction (bool, optional): Whether to apply lens correction to
|
|
254
|
+
time-of-flights. This makes it slower, but more accurate in the near-field.
|
|
255
|
+
Defaults to False.
|
|
256
|
+
lens_thickness (float, optional): Thickness of the lens in meters. Used for
|
|
257
|
+
lens correction.
|
|
258
|
+
lens_sound_speed (float, optional): Speed of sound in the lens in m/s. Used
|
|
259
|
+
for lens correction.
|
|
260
|
+
n_iter (int, optional): Number of iterations for the Newton-Raphson method
|
|
261
|
+
used in lens correction. Defaults to 2.
|
|
245
262
|
|
|
246
263
|
|
|
247
264
|
Returns:
|
|
@@ -252,38 +269,56 @@ def calculate_delays(
|
|
|
252
269
|
`(n_pix, n_el)`.
|
|
253
270
|
"""
|
|
254
271
|
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
272
|
+
# Validate input shapes
|
|
273
|
+
for arr in [t0_delays, grid, tx_apodizations, probe_geometry]:
|
|
274
|
+
assert arr.ndim == 2
|
|
275
|
+
assert probe_geometry.shape[0] == n_el
|
|
276
|
+
assert t0_delays.shape[0] == n_tx
|
|
277
|
+
|
|
278
|
+
if not apply_lens_correction:
|
|
279
|
+
# Compute receive distances in meters of shape (n_pix, n_el)
|
|
280
|
+
rx_distances = distance_Rx(grid, probe_geometry)
|
|
281
|
+
|
|
282
|
+
# Convert distances to delays in seconds
|
|
283
|
+
rx_delays = rx_distances / sound_speed
|
|
284
|
+
else:
|
|
285
|
+
# Compute lens-corrected travel times from each element to each pixel
|
|
286
|
+
assert lens_thickness is not None, "lens_thickness must be provided for lens correction."
|
|
287
|
+
assert lens_sound_speed is not None, (
|
|
288
|
+
"lens_sound_speed must be provided for lens correction."
|
|
289
|
+
)
|
|
290
|
+
rx_delays = compute_lens_corrected_travel_times(
|
|
260
291
|
probe_geometry,
|
|
261
|
-
|
|
262
|
-
|
|
292
|
+
grid,
|
|
293
|
+
lens_thickness,
|
|
294
|
+
lens_sound_speed,
|
|
263
295
|
sound_speed,
|
|
296
|
+
n_iter=n_iter,
|
|
264
297
|
)
|
|
265
298
|
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
|
|
299
|
+
# Compute transmit delays
|
|
300
|
+
tx_delays = vmap(transmit_delays, in_axes=(None, 0, 0, None, 0, 0, 0, None, 0), out_axes=1)(
|
|
301
|
+
grid,
|
|
302
|
+
t0_delays,
|
|
303
|
+
tx_apodizations,
|
|
304
|
+
rx_delays,
|
|
305
|
+
focus_distances,
|
|
306
|
+
polar_angles,
|
|
307
|
+
initial_times,
|
|
308
|
+
None,
|
|
309
|
+
transmit_origins,
|
|
310
|
+
)
|
|
269
311
|
|
|
270
|
-
#
|
|
271
|
-
|
|
272
|
-
return distance_Rx(grid, probe_geometry)
|
|
312
|
+
# Add the offset to the transmit peak time
|
|
313
|
+
tx_delays += ops.take(t_peak, tx_waveform_indices)[None]
|
|
273
314
|
|
|
274
|
-
|
|
275
|
-
|
|
276
|
-
#
|
|
315
|
+
# TODO: nan to num needed?
|
|
316
|
+
# tx_delays = ops.nan_to_num(tx_delays, nan=0.0, posinf=0.0, neginf=0.0)
|
|
317
|
+
# rx_delays = ops.nan_to_num(rx_delays, nan=0.0, posinf=0.0, neginf=0.0)
|
|
277
318
|
|
|
278
|
-
#
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
tx_delays = (
|
|
282
|
-
tx_distances / sound_speed
|
|
283
|
-
- initial_times[None]
|
|
284
|
-
+ ops.take(t_peak, tx_waveform_indices)[None]
|
|
285
|
-
) * sampling_frequency
|
|
286
|
-
rx_delays = (rx_distances / sound_speed) * sampling_frequency
|
|
319
|
+
# Convert from seconds to samples
|
|
320
|
+
tx_delays *= sampling_frequency
|
|
321
|
+
rx_delays *= sampling_frequency
|
|
287
322
|
|
|
288
323
|
return tx_delays, rx_delays
|
|
289
324
|
|
|
@@ -414,7 +449,7 @@ def complex_rotate(iq, theta):
|
|
|
414
449
|
def distance_Rx(grid, probe_geometry):
|
|
415
450
|
"""Computes distance to user-defined pixels from elements.
|
|
416
451
|
|
|
417
|
-
Expects all inputs to be
|
|
452
|
+
Expects all inputs to be arrays specified in SI units.
|
|
418
453
|
|
|
419
454
|
Args:
|
|
420
455
|
grid (ops.Tensor): Pixel positions in x,y,z of shape `(n_pix, 3)`.
|
|
@@ -425,83 +460,117 @@ def distance_Rx(grid, probe_geometry):
|
|
|
425
460
|
`(n_pix, n_el)`.
|
|
426
461
|
"""
|
|
427
462
|
# Get norm of distance vector between elements and pixels via broadcasting
|
|
428
|
-
dist = ops.linalg.norm(grid - probe_geometry[None,
|
|
463
|
+
dist = ops.linalg.norm(grid[:, None, :] - probe_geometry[None, :, :], axis=-1)
|
|
429
464
|
return dist
|
|
430
465
|
|
|
431
466
|
|
|
432
|
-
def
|
|
467
|
+
def transmit_delays(
|
|
433
468
|
grid,
|
|
434
469
|
t0_delays,
|
|
435
470
|
tx_apodization,
|
|
436
|
-
|
|
471
|
+
rx_delays,
|
|
437
472
|
focus_distance,
|
|
438
473
|
polar_angle,
|
|
439
|
-
|
|
474
|
+
initial_time,
|
|
475
|
+
azimuth_angle=None,
|
|
476
|
+
transmit_origin=None,
|
|
440
477
|
):
|
|
441
|
-
"""
|
|
478
|
+
"""
|
|
479
|
+
Computes the transmit delay from transmission to each pixel in the grid.
|
|
480
|
+
|
|
481
|
+
Uses the first-arrival time for pixels before the focus (or virtual source)
|
|
482
|
+
and the last-arrival time for pixels beyond the focus.
|
|
442
483
|
|
|
443
|
-
|
|
444
|
-
|
|
484
|
+
The receive delays can be precomputed since they do not depend on the
|
|
485
|
+
transmit parameters.
|
|
445
486
|
|
|
446
487
|
Args:
|
|
447
|
-
grid (ops.Tensor): Flattened tensor of pixel positions in x,y,z of shape
|
|
448
|
-
|
|
449
|
-
|
|
450
|
-
|
|
451
|
-
|
|
452
|
-
`(n_el,)`.
|
|
453
|
-
probe_geometry (ops.Tensor): The positions of the transducer elements of shape
|
|
454
|
-
`(n_el, 3)`.
|
|
488
|
+
grid (ops.Tensor): Flattened tensor of pixel positions in x,y,z of shape `(n_pix, 3)`
|
|
489
|
+
t0_delays (Tensor): The transmit delays in seconds of shape (n_el,).
|
|
490
|
+
tx_apodization (Tensor): The transmit apodization of shape (n_el,).
|
|
491
|
+
rx_delays (Tensor): The travel times in seconds from elements to pixels
|
|
492
|
+
of shape (n_pix, n_el).
|
|
455
493
|
focus_distance (float): The focus distance in meters.
|
|
456
494
|
polar_angle (float): The polar angle in radians.
|
|
457
|
-
|
|
495
|
+
initial_time (float): The initial time for this transmit in seconds.
|
|
496
|
+
azimuth_angle (float, optional): The azimuth angle in radians. Defaults to 0.0.
|
|
497
|
+
transmit_origin (ops.Tensor, optional): The origin of the transmit beam of shape (3,).
|
|
498
|
+
If None, defaults to (0, 0, 0). Defaults to None.
|
|
458
499
|
|
|
459
500
|
Returns:
|
|
460
|
-
Tensor:
|
|
461
|
-
of shape `(n_pix,)`
|
|
501
|
+
Tensor: The transmit delays of shape `(n_pix,)`.
|
|
462
502
|
"""
|
|
463
|
-
#
|
|
464
|
-
|
|
465
|
-
y = grid[:, 1]
|
|
466
|
-
z = grid[:, 2]
|
|
467
|
-
|
|
468
|
-
# Reshape x, y, and z to shape (n_pix, 1)
|
|
469
|
-
x = x[..., None]
|
|
470
|
-
y = y[..., None]
|
|
471
|
-
z = z[..., None]
|
|
472
|
-
|
|
473
|
-
# Get the individual x, y, and z coordinates of the elements and add a
|
|
474
|
-
# dummy dimension at the beginning to shape (1, n_el).
|
|
475
|
-
ele_x = probe_geometry[None, :, 0]
|
|
476
|
-
ele_y = probe_geometry[None, :, 1]
|
|
477
|
-
ele_z = probe_geometry[None, :, 2]
|
|
478
|
-
|
|
479
|
-
# Compute the differences dx, dy, and dz of shape (n_pix, n_el)
|
|
480
|
-
dx = x - ele_x
|
|
481
|
-
dy = y - ele_y
|
|
482
|
-
dz = z - ele_z
|
|
483
|
-
|
|
484
|
-
# Define an infinite offset for elements that do not fire to not consider them in
|
|
485
|
-
# the transmit distance calculation.
|
|
503
|
+
# Add a large offset for elements that are not used in the transmit to
|
|
504
|
+
# disqualify them from being the closest element
|
|
486
505
|
offset = ops.where(tx_apodization == 0, np.inf, 0.0)
|
|
487
506
|
|
|
488
|
-
# Compute
|
|
489
|
-
# (n_pix, n_el)
|
|
490
|
-
|
|
491
|
-
|
|
492
|
-
|
|
493
|
-
|
|
507
|
+
# Compute total travel time from t=0 to each pixel via each element
|
|
508
|
+
# rx_delays has shape (n_pix, n_el)
|
|
509
|
+
# t0_delays has shape (n_el,)
|
|
510
|
+
total_times = rx_delays + t0_delays[None, :]
|
|
511
|
+
|
|
512
|
+
if azimuth_angle is None:
|
|
513
|
+
azimuth_angle = ops.zeros_like(polar_angle)
|
|
514
|
+
|
|
515
|
+
# Set origin to (0, 0, 0) if not provided
|
|
516
|
+
if transmit_origin is None:
|
|
517
|
+
transmit_origin = ops.zeros(3, dtype=grid.dtype)
|
|
518
|
+
|
|
519
|
+
# Compute the 3D position of the focal point
|
|
520
|
+
# The beam direction vector
|
|
521
|
+
beam_direction = ops.stack(
|
|
522
|
+
[
|
|
523
|
+
ops.sin(polar_angle) * ops.cos(azimuth_angle),
|
|
524
|
+
ops.sin(polar_angle) * ops.sin(azimuth_angle),
|
|
525
|
+
ops.cos(polar_angle),
|
|
526
|
+
]
|
|
527
|
+
)
|
|
494
528
|
|
|
495
|
-
#
|
|
496
|
-
#
|
|
497
|
-
|
|
498
|
-
|
|
499
|
-
|
|
500
|
-
|
|
501
|
-
|
|
529
|
+
# Handle plane wave case where focus_distance is set to zero
|
|
530
|
+
# We use np.inf to consider the first wavefront arrival for all pixels
|
|
531
|
+
focus_distance = ops.where(focus_distance == 0.0, np.inf, focus_distance)
|
|
532
|
+
|
|
533
|
+
# Compute focal point position: origin + focus_distance * beam_direction
|
|
534
|
+
# For negative focus_distance (diverging/virtual source), this is behind the origin
|
|
535
|
+
focal_point = transmit_origin + focus_distance * beam_direction # shape (3,)
|
|
536
|
+
|
|
537
|
+
# Deal with plane wave case where focus_distance is infinite and beam_direction is zero
|
|
538
|
+
# (np.inf * 0.0 -> nan) so we convert nan to zero
|
|
539
|
+
focal_point = ops.where(ops.isnan(focal_point), 0.0, focal_point)
|
|
540
|
+
|
|
541
|
+
# Compute the position of each pixel relative to the focal point
|
|
542
|
+
pixel_relative_to_focus = grid - focal_point[None, :] # shape (n_pix, 3)
|
|
543
|
+
|
|
544
|
+
# Project onto the beam direction to determine if pixel is before or after focus
|
|
545
|
+
# Positive projection means pixel is in the direction of beam propagation (beyond focus)
|
|
546
|
+
# Negative projection means pixel is behind the focus (before focus)
|
|
547
|
+
projection_along_beam = ops.sum(
|
|
548
|
+
pixel_relative_to_focus * beam_direction[None, :], axis=-1
|
|
549
|
+
) # shape (n_pix,)
|
|
550
|
+
|
|
551
|
+
# For focused waves (positive focus_distance):
|
|
552
|
+
# - Use min time for pixels before focus (projection < 0)
|
|
553
|
+
# - Use max time for pixels beyond focus (projection > 0)
|
|
554
|
+
# For diverging waves (negative focus_distance, virtual source):
|
|
555
|
+
# - The sign of focus_distance flips the logic
|
|
556
|
+
# - Use min time for pixels between transducer and virtual source
|
|
557
|
+
# - Use max time for pixels beyond transducer
|
|
558
|
+
is_before_focus = ops.cast(ops.sign(focus_distance), "float32") * projection_along_beam < 0.0
|
|
559
|
+
|
|
560
|
+
# Compute the effective time of the pixels to the wavefront by computing the
|
|
561
|
+
# smallest time over all elements (first wavefront arrival) for pixels before
|
|
562
|
+
# the focus, and the largest time (last wavefront contribution) for pixels
|
|
563
|
+
# beyond the focus.
|
|
564
|
+
tx_delay = ops.where(
|
|
565
|
+
is_before_focus,
|
|
566
|
+
ops.min(total_times + offset[None, :], axis=-1),
|
|
567
|
+
ops.max(total_times - offset[None, :], axis=-1),
|
|
502
568
|
)
|
|
503
569
|
|
|
504
|
-
|
|
570
|
+
# Subtract the initial time offset for this transmit
|
|
571
|
+
tx_delay = tx_delay - initial_time
|
|
572
|
+
|
|
573
|
+
return tx_delay
|
|
505
574
|
|
|
506
575
|
|
|
507
576
|
def fnumber_mask(flatgrid, probe_geometry, f_number, fnum_window_fn):
|
|
@@ -52,7 +52,7 @@ def compute_t0_delays_planewave(probe_geometry, polar_angles, azimuth_angles=0,
|
|
|
52
52
|
|
|
53
53
|
|
|
54
54
|
def compute_t0_delays_focused(
|
|
55
|
-
|
|
55
|
+
transmit_origins,
|
|
56
56
|
focus_distances,
|
|
57
57
|
probe_geometry,
|
|
58
58
|
polar_angles,
|
|
@@ -63,12 +63,12 @@ def compute_t0_delays_focused(
|
|
|
63
63
|
the first element fires at t=0.
|
|
64
64
|
|
|
65
65
|
Args:
|
|
66
|
-
|
|
67
|
-
|
|
66
|
+
transmit_origins (np.ndarray): The origin of the focused transmit of shape (n_tx, 3,).
|
|
67
|
+
focus_distances (np.ndarray): The distance to the focus for each transmit of shape (n_tx,).
|
|
68
68
|
probe_geometry (np.ndarray): The positions of the elements in the array of
|
|
69
69
|
shape (element, 3).
|
|
70
|
-
polar_angles (np.ndarray): The polar angles
|
|
71
|
-
azimuth_angles (np.ndarray, optional): The azimuth angles
|
|
70
|
+
polar_angles (np.ndarray): The polar angles in radians of shape (n_tx,).
|
|
71
|
+
azimuth_angles (np.ndarray, optional): The azimuth angles in
|
|
72
72
|
radians of shape (n_tx,).
|
|
73
73
|
sound_speed (float, optional): The speed of sound. Defaults to 1540.
|
|
74
74
|
|
|
@@ -79,12 +79,15 @@ def compute_t0_delays_focused(
|
|
|
79
79
|
assert polar_angles.shape == (n_tx,), (
|
|
80
80
|
f"polar_angles must have length n_tx = {n_tx}. Got length {len(polar_angles)}."
|
|
81
81
|
)
|
|
82
|
-
assert
|
|
83
|
-
f"
|
|
82
|
+
assert transmit_origins.shape == (n_tx, 3), (
|
|
83
|
+
f"transmit_origins must have shape (n_tx, 3). Got shape {transmit_origins.shape}."
|
|
84
84
|
)
|
|
85
85
|
assert probe_geometry.shape[1] == 3 and probe_geometry.ndim == 2, (
|
|
86
86
|
f"probe_geometry must have shape (element, 3). Got shape {probe_geometry.shape}."
|
|
87
87
|
)
|
|
88
|
+
assert focus_distances.shape == (n_tx,), (
|
|
89
|
+
f"focus_distances must have length n_tx = {n_tx}. Got length {len(focus_distances)}."
|
|
90
|
+
)
|
|
88
91
|
|
|
89
92
|
# Convert single angles to arrays for broadcasting
|
|
90
93
|
polar_angles = np.atleast_1d(polar_angles)
|
|
@@ -107,12 +110,12 @@ def compute_t0_delays_focused(
|
|
|
107
110
|
)
|
|
108
111
|
|
|
109
112
|
# Add a new dimension for broadcasting
|
|
110
|
-
# The shape is now (n_tx,
|
|
113
|
+
# The shape is now (n_tx, 1, 3)
|
|
111
114
|
v = np.expand_dims(v, axis=1)
|
|
112
115
|
|
|
113
116
|
# Compute the location of the virtual source by adding the focus distance
|
|
114
117
|
# to the origin along the wave vectors.
|
|
115
|
-
virtual_sources =
|
|
118
|
+
virtual_sources = transmit_origins[:, None] + focus_distances[:, None, None] * v
|
|
116
119
|
|
|
117
120
|
# Compute the distances between the virtual sources and each element
|
|
118
121
|
dist = np.linalg.norm(virtual_sources - probe_geometry, axis=-1)
|