zea 0.0.9__tar.gz → 0.0.10__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (133)
  1. {zea-0.0.9 → zea-0.0.10}/PKG-INFO +3 -5
  2. {zea-0.0.9 → zea-0.0.10}/pyproject.toml +3 -5
  3. {zea-0.0.9 → zea-0.0.10}/zea/__init__.py +11 -5
  4. {zea-0.0.9 → zea-0.0.10}/zea/agent/masks.py +15 -3
  5. {zea-0.0.9 → zea-0.0.10}/zea/agent/selection.py +12 -3
  6. {zea-0.0.9 → zea-0.0.10}/zea/backend/__init__.py +17 -3
  7. {zea-0.0.9 → zea-0.0.10}/zea/beamform/beamformer.py +158 -89
  8. {zea-0.0.9 → zea-0.0.10}/zea/beamform/delays.py +12 -9
  9. {zea-0.0.9 → zea-0.0.10}/zea/beamform/lens_correction.py +0 -73
  10. {zea-0.0.9 → zea-0.0.10}/zea/beamform/pfield.py +4 -10
  11. zea-0.0.10/zea/beamform/phantoms.py +145 -0
  12. zea-0.0.10/zea/beamform/pixelgrid.py +189 -0
  13. {zea-0.0.9 → zea-0.0.10}/zea/config.py +2 -2
  14. {zea-0.0.9 → zea-0.0.10}/zea/data/convert/__main__.py +17 -7
  15. {zea-0.0.9 → zea-0.0.10}/zea/data/convert/echonetlvh/__init__.py +47 -50
  16. {zea-0.0.9 → zea-0.0.10}/zea/data/convert/echonetlvh/precompute_crop.py +12 -0
  17. {zea-0.0.9 → zea-0.0.10}/zea/data/convert/images.py +1 -1
  18. {zea-0.0.9 → zea-0.0.10}/zea/data/convert/verasonics.py +375 -119
  19. {zea-0.0.9 → zea-0.0.10}/zea/data/data_format.py +53 -8
  20. {zea-0.0.9 → zea-0.0.10}/zea/data/datasets.py +0 -7
  21. {zea-0.0.9 → zea-0.0.10}/zea/data/file.py +8 -2
  22. {zea-0.0.9 → zea-0.0.10}/zea/data/file_operations.py +2 -0
  23. {zea-0.0.9 → zea-0.0.10}/zea/display.py +26 -9
  24. {zea-0.0.9 → zea-0.0.10}/zea/doppler.py +1 -1
  25. {zea-0.0.9 → zea-0.0.10}/zea/func/__init__.py +6 -0
  26. {zea-0.0.9 → zea-0.0.10}/zea/func/tensor.py +4 -2
  27. {zea-0.0.9 → zea-0.0.10}/zea/func/ultrasound.py +158 -62
  28. {zea-0.0.9 → zea-0.0.10}/zea/internal/config/parameters.py +1 -1
  29. {zea-0.0.9 → zea-0.0.10}/zea/internal/device.py +6 -1
  30. {zea-0.0.9 → zea-0.0.10}/zea/internal/dummy_scan.py +15 -8
  31. zea-0.0.10/zea/internal/notebooks.py +152 -0
  32. {zea-0.0.9 → zea-0.0.10}/zea/internal/parameters.py +20 -0
  33. {zea-0.0.9 → zea-0.0.10}/zea/internal/registry.py +1 -1
  34. {zea-0.0.9 → zea-0.0.10}/zea/metrics.py +84 -68
  35. {zea-0.0.9 → zea-0.0.10}/zea/models/__init__.py +1 -0
  36. {zea-0.0.9 → zea-0.0.10}/zea/models/diffusion.py +5 -0
  37. zea-0.0.10/zea/models/hvae/__init__.py +243 -0
  38. zea-0.0.10/zea/models/hvae/model.py +1139 -0
  39. zea-0.0.10/zea/models/hvae/utils.py +616 -0
  40. {zea-0.0.9 → zea-0.0.10}/zea/models/layers.py +1 -1
  41. {zea-0.0.9 → zea-0.0.10}/zea/models/lpips.py +12 -2
  42. {zea-0.0.9 → zea-0.0.10}/zea/models/presets.py +16 -0
  43. {zea-0.0.9 → zea-0.0.10}/zea/ops/__init__.py +6 -4
  44. {zea-0.0.9 → zea-0.0.10}/zea/ops/base.py +28 -29
  45. {zea-0.0.9 → zea-0.0.10}/zea/ops/pipeline.py +17 -7
  46. {zea-0.0.9 → zea-0.0.10}/zea/ops/tensor.py +50 -73
  47. {zea-0.0.9 → zea-0.0.10}/zea/ops/ultrasound.py +250 -103
  48. {zea-0.0.9 → zea-0.0.10}/zea/probes.py +2 -0
  49. {zea-0.0.9 → zea-0.0.10}/zea/scan.py +223 -83
  50. {zea-0.0.9 → zea-0.0.10}/zea/simulator.py +8 -8
  51. {zea-0.0.9 → zea-0.0.10}/zea/tools/fit_scan_cone.py +8 -6
  52. {zea-0.0.9 → zea-0.0.10}/zea/tools/selection_tool.py +13 -7
  53. {zea-0.0.9 → zea-0.0.10}/zea/visualize.py +3 -1
  54. zea-0.0.9/zea/beamform/phantoms.py +0 -43
  55. zea-0.0.9/zea/beamform/pixelgrid.py +0 -131
  56. zea-0.0.9/zea/internal/notebooks.py +0 -39
  57. {zea-0.0.9 → zea-0.0.10}/LICENSE +0 -0
  58. {zea-0.0.9 → zea-0.0.10}/README.md +0 -0
  59. {zea-0.0.9 → zea-0.0.10}/zea/__main__.py +0 -0
  60. {zea-0.0.9 → zea-0.0.10}/zea/agent/__init__.py +0 -0
  61. {zea-0.0.9 → zea-0.0.10}/zea/agent/gumbel.py +0 -0
  62. {zea-0.0.9 → zea-0.0.10}/zea/backend/autograd.py +0 -0
  63. {zea-0.0.9 → zea-0.0.10}/zea/backend/jax/__init__.py +0 -0
  64. {zea-0.0.9 → zea-0.0.10}/zea/backend/tensorflow/__init__.py +0 -0
  65. {zea-0.0.9 → zea-0.0.10}/zea/backend/tensorflow/dataloader.py +0 -0
  66. {zea-0.0.9 → zea-0.0.10}/zea/backend/tensorflow/layers/__init__.py +0 -0
  67. {zea-0.0.9 → zea-0.0.10}/zea/backend/tensorflow/layers/apodization.py +0 -0
  68. {zea-0.0.9 → zea-0.0.10}/zea/backend/tensorflow/layers/utils.py +0 -0
  69. {zea-0.0.9 → zea-0.0.10}/zea/backend/tensorflow/losses.py +0 -0
  70. {zea-0.0.9 → zea-0.0.10}/zea/backend/tensorflow/models/__init__.py +0 -0
  71. {zea-0.0.9 → zea-0.0.10}/zea/backend/tensorflow/models/lista.py +0 -0
  72. {zea-0.0.9 → zea-0.0.10}/zea/backend/tensorflow/scripts/convert-echonet-dynamic.py +0 -0
  73. {zea-0.0.9 → zea-0.0.10}/zea/backend/tensorflow/scripts/convert-taesd.py +0 -0
  74. {zea-0.0.9 → zea-0.0.10}/zea/backend/tensorflow/utils/__init__.py +0 -0
  75. {zea-0.0.9 → zea-0.0.10}/zea/backend/tensorflow/utils/callbacks.py +0 -0
  76. {zea-0.0.9 → zea-0.0.10}/zea/backend/tensorflow/utils/utils.py +0 -0
  77. {zea-0.0.9 → zea-0.0.10}/zea/backend/tf2jax.py +0 -0
  78. {zea-0.0.9 → zea-0.0.10}/zea/backend/torch/__init__.py +0 -0
  79. {zea-0.0.9 → zea-0.0.10}/zea/backend/torch/losses.py +0 -0
  80. {zea-0.0.9 → zea-0.0.10}/zea/beamform/__init__.py +0 -0
  81. {zea-0.0.9 → zea-0.0.10}/zea/data/__init__.py +0 -0
  82. {zea-0.0.9 → zea-0.0.10}/zea/data/__main__.py +0 -0
  83. {zea-0.0.9 → zea-0.0.10}/zea/data/augmentations.py +0 -0
  84. {zea-0.0.9 → zea-0.0.10}/zea/data/convert/__init__.py +0 -0
  85. {zea-0.0.9 → zea-0.0.10}/zea/data/convert/camus.py +0 -0
  86. {zea-0.0.9 → zea-0.0.10}/zea/data/convert/echonet.py +0 -0
  87. {zea-0.0.9 → zea-0.0.10}/zea/data/convert/echonetlvh/README.md +0 -0
  88. {zea-0.0.9 → zea-0.0.10}/zea/data/convert/echonetlvh/manual_rejections.txt +0 -0
  89. {zea-0.0.9 → zea-0.0.10}/zea/data/convert/picmus.py +0 -0
  90. {zea-0.0.9 → zea-0.0.10}/zea/data/convert/utils.py +0 -0
  91. {zea-0.0.9 → zea-0.0.10}/zea/data/dataloader.py +0 -0
  92. {zea-0.0.9 → zea-0.0.10}/zea/data/layers.py +0 -0
  93. {zea-0.0.9 → zea-0.0.10}/zea/data/preset_utils.py +0 -0
  94. {zea-0.0.9 → zea-0.0.10}/zea/data/utils.py +0 -0
  95. {zea-0.0.9 → zea-0.0.10}/zea/datapaths.py +0 -0
  96. {zea-0.0.9 → zea-0.0.10}/zea/interface.py +0 -0
  97. {zea-0.0.9 → zea-0.0.10}/zea/internal/_generate_keras_ops.py +0 -0
  98. {zea-0.0.9 → zea-0.0.10}/zea/internal/cache.py +0 -0
  99. {zea-0.0.9 → zea-0.0.10}/zea/internal/checks.py +0 -0
  100. {zea-0.0.9 → zea-0.0.10}/zea/internal/config/create.py +0 -0
  101. {zea-0.0.9 → zea-0.0.10}/zea/internal/config/validation.py +0 -0
  102. {zea-0.0.9 → zea-0.0.10}/zea/internal/core.py +0 -0
  103. {zea-0.0.9 → zea-0.0.10}/zea/internal/git_info.py +0 -0
  104. {zea-0.0.9 → zea-0.0.10}/zea/internal/operators.py +0 -0
  105. {zea-0.0.9 → zea-0.0.10}/zea/internal/setup_zea.py +0 -0
  106. {zea-0.0.9 → zea-0.0.10}/zea/internal/utils.py +0 -0
  107. {zea-0.0.9 → zea-0.0.10}/zea/internal/viewer.py +0 -0
  108. {zea-0.0.9 → zea-0.0.10}/zea/io_lib.py +0 -0
  109. {zea-0.0.9 → zea-0.0.10}/zea/log.py +0 -0
  110. {zea-0.0.9 → zea-0.0.10}/zea/models/base.py +0 -0
  111. {zea-0.0.9 → zea-0.0.10}/zea/models/carotid_segmenter.py +0 -0
  112. {zea-0.0.9 → zea-0.0.10}/zea/models/deeplabv3.py +0 -0
  113. {zea-0.0.9 → zea-0.0.10}/zea/models/dense.py +0 -0
  114. {zea-0.0.9 → zea-0.0.10}/zea/models/echonet.py +0 -0
  115. {zea-0.0.9 → zea-0.0.10}/zea/models/echonetlvh.py +0 -0
  116. {zea-0.0.9 → zea-0.0.10}/zea/models/generative.py +0 -0
  117. {zea-0.0.9 → zea-0.0.10}/zea/models/gmm.py +0 -0
  118. {zea-0.0.9 → zea-0.0.10}/zea/models/lv_segmentation.py +0 -0
  119. {zea-0.0.9 → zea-0.0.10}/zea/models/preset_utils.py +0 -0
  120. {zea-0.0.9 → zea-0.0.10}/zea/models/regional_quality.py +0 -0
  121. {zea-0.0.9 → zea-0.0.10}/zea/models/taesd.py +0 -0
  122. {zea-0.0.9 → zea-0.0.10}/zea/models/unet.py +0 -0
  123. {zea-0.0.9 → zea-0.0.10}/zea/models/utils.py +0 -0
  124. {zea-0.0.9 → zea-0.0.10}/zea/ops/keras_ops.py +0 -0
  125. {zea-0.0.9 → zea-0.0.10}/zea/tools/__init__.py +0 -0
  126. {zea-0.0.9 → zea-0.0.10}/zea/tools/hf.py +0 -0
  127. {zea-0.0.9 → zea-0.0.10}/zea/tools/wndb.py +0 -0
  128. {zea-0.0.9 → zea-0.0.10}/zea/tracking/__init__.py +0 -0
  129. {zea-0.0.9 → zea-0.0.10}/zea/tracking/base.py +0 -0
  130. {zea-0.0.9 → zea-0.0.10}/zea/tracking/lucas_kanade.py +0 -0
  131. {zea-0.0.9 → zea-0.0.10}/zea/tracking/segmentation.py +0 -0
  132. {zea-0.0.9 → zea-0.0.10}/zea/utils.py +0 -0
  133. {zea-0.0.9 → zea-0.0.10}/zea/zea_darkmode.mplstyle +0 -0
{zea-0.0.9 → zea-0.0.10}/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: zea
- Version: 0.0.9
+ Version: 0.0.10
  Summary: A Toolbox for Cognitive Ultrasound Imaging. Provides a set of tools for processing of ultrasound data, all built in your favorite machine learning framework.
  License-File: LICENSE
  Keywords: ultrasound,machine learning,beamforming
@@ -44,8 +44,6 @@ Requires-Dist: jax ; extra == "backends"
  Requires-Dist: jax[cuda12-pip] (>=0.4.26) ; extra == "jax"
  Requires-Dist: keras (>=3.12)
  Requires-Dist: matplotlib (>=3.8)
- Requires-Dist: mock ; extra == "dev"
- Requires-Dist: mock ; extra == "docs"
  Requires-Dist: myst-parser ; extra == "dev"
  Requires-Dist: myst-parser ; extra == "docs"
  Requires-Dist: nbsphinx ; extra == "dev"
@@ -76,8 +74,6 @@ Requires-Dist: simpleitk (>=2.2.1) ; extra == "dev"
  Requires-Dist: simpleitk (>=2.2.1) ; extra == "tests"
  Requires-Dist: sphinx ; extra == "dev"
  Requires-Dist: sphinx ; extra == "docs"
- Requires-Dist: sphinx-argparse ; extra == "dev"
- Requires-Dist: sphinx-argparse ; extra == "docs"
  Requires-Dist: sphinx-autobuild ; extra == "dev"
  Requires-Dist: sphinx-autobuild ; extra == "docs"
  Requires-Dist: sphinx-autodoc-typehints ; extra == "dev"
@@ -88,6 +84,8 @@ Requires-Dist: sphinx-reredirects ; extra == "dev"
  Requires-Dist: sphinx-reredirects ; extra == "docs"
  Requires-Dist: sphinx_design ; extra == "dev"
  Requires-Dist: sphinx_design ; extra == "docs"
+ Requires-Dist: sphinxcontrib-autoprogram ; extra == "dev"
+ Requires-Dist: sphinxcontrib-autoprogram ; extra == "docs"
  Requires-Dist: sphinxcontrib-bibtex ; extra == "dev"
  Requires-Dist: sphinxcontrib-bibtex ; extra == "docs"
  Requires-Dist: tensorflow ; extra == "backends"
{zea-0.0.9 → zea-0.0.10}/pyproject.toml
@@ -1,6 +1,6 @@
  [project]
  name = "zea"
- version = "0.0.9"
+ version = "0.0.10"
  description = "A Toolbox for Cognitive Ultrasound Imaging. Provides a set of tools for processing of ultrasound data, all built in your favorite machine learning framework."
  authors = [
      { name = "Tristan Stevens", email = "t.s.w.stevens@tue.nl" },
@@ -65,10 +65,9 @@ dev = [
      "sphinx-autodoc-typehints",
      "sphinx-copybutton",
      "sphinx_design",
-     "sphinx-argparse",
+     "sphinxcontrib-autoprogram",
      "sphinx-reredirects",
      "sphinxcontrib-bibtex",
-     "mock",
      "myst-parser",
      "nbsphinx",
      "furo",
@@ -97,10 +96,9 @@ docs = [
      "sphinx-autodoc-typehints",
      "sphinx-copybutton",
      "sphinx_design",
-     "sphinx-argparse",
+     "sphinxcontrib-autoprogram",
      "sphinx-reredirects",
      "sphinxcontrib-bibtex",
-     "mock",
      "myst-parser",
      "nbsphinx",
      "furo",
{zea-0.0.9 → zea-0.0.10}/zea/__init__.py
@@ -2,12 +2,16 @@

  import importlib.util
  import os
+ from importlib.metadata import PackageNotFoundError, version

  from . import log

- # dynamically add __version__ attribute (see pyproject.toml)
- # __version__ = __import__("importlib.metadata").metadata.version(__package__)
- __version__ = "0.0.9"
+ try:
+     # dynamically add __version__ attribute (see pyproject.toml)
+     __version__ = version("zea")
+ except PackageNotFoundError:
+     # Package is not installed (e.g., running from source)
+     __version__ = "dev"


  def _bootstrap_backend():
@@ -80,8 +84,10 @@ def _bootstrap_backend():
      log.info(f"Using backend {keras_backend()!r}")


- # call and clean up namespace
- _bootstrap_backend()
+ # Skip backend bootstrap when building on ReadTheDocs
+ if os.environ.get("READTHEDOCS") != "True":
+     _bootstrap_backend()
+

  del _bootstrap_backend

  from . import (
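A short usage note on the hunk above (assuming zea is importable): the try/except means the attribute always exists, resolving to the installed distribution's version or to "dev" in a source checkout.

import zea

# "0.0.10" when installed from the registry, "dev" when running from an uninstalled source tree
print(zea.__version__)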
{zea-0.0.9 → zea-0.0.10}/zea/agent/masks.py
@@ -4,6 +4,8 @@ Mask generation utilities.
  These masks are used as a measurement operator for focused scan-line subsampling.
  """

+ from __future__ import annotations
+
  from typing import List

  import keras
@@ -117,11 +119,21 @@ def initial_equispaced_lines(
          Tensor: k-hot-encoded line vector of shape (n_possible_actions).
              Needs to be converted to image size.
      """
+     assert n_actions > 0, "Number of actions must be > 0."
+     assert n_possible_actions > 0, "Number of possible actions must be > 0."
+     assert n_actions <= n_possible_actions, (
+         "Number of actions must be less than or equal to number of possible actions."
+     )
+
      if assert_equal_spacing:
          _assert_equal_spacing(n_actions, n_possible_actions)
-         selected_indices = ops.arange(0, n_possible_actions, n_possible_actions // n_actions)
-     else:
-         selected_indices = ops.linspace(0, n_possible_actions - 1, n_actions, dtype="int32")
+
+     # Distribute indices as evenly as possible
+     # This approach ensures spacing differs by at most 1 when not divisible
+     step = n_possible_actions / n_actions
+     selected_indices = ops.cast(
+         ops.round(ops.arange(0, n_actions, dtype="float32") * step), "int32"
+     )

      return indices_to_k_hot(selected_indices, n_possible_actions, dtype=dtype)
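A quick numeric check of the new index distribution in the hunk above (sizes invented), mirroring the keras.ops logic with plain NumPy:

import numpy as np

# Hypothetical sizes: 10 possible scan lines, 3 to select
n_possible_actions, n_actions = 10, 3
step = n_possible_actions / n_actions                          # 3.333...
selected = np.round(np.arange(n_actions) * step).astype(int)
print(selected)  # [0 3 7] -> spacings of 3 and 4, differing by at most 1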
 
{zea-0.0.9 → zea-0.0.10}/zea/agent/selection.py
@@ -96,6 +96,7 @@ class GreedyEntropy(LinesActionModel):
          std_dev: float = 1,
          num_lines_to_update: int = 5,
          entropy_sigma: float = 1.0,
+         average_entropy_across_batch: bool = False,
      ):
          """Initialize the GreedyEntropy action selection model.

@@ -110,6 +111,10 @@ class GreedyEntropy(LinesActionModel):
              to update. Must be odd.
          entropy_sigma (float, optional): The standard deviation of the Gaussian
              Mixture components used to approximate the posterior.
+         average_entropy_across_batch (bool, optional): Whether to average entropy
+             across the batch when selecting lines. This can be useful when
+             selecting planes in 3D imaging, where the batch dimension represents
+             a third spatial dimension. Defaults to False.
          """
          super().__init__(n_actions, n_possible_actions, img_width, img_height)

@@ -117,6 +122,7 @@ class GreedyEntropy(LinesActionModel):
          # of the selected line is set to 0 once it's been selected.
          assert num_lines_to_update % 2 == 1, "num_samples must be odd."
          self.num_lines_to_update = num_lines_to_update
+         self.average_entropy_across_batch = average_entropy_across_batch

          # see here what I mean by upside_down_gaussian:
          # https://colab.research.google.com/drive/1CQp_Z6nADzOFsybdiH5Cag0vtVZjjioU?usp=sharing
@@ -153,7 +159,7 @@ class GreedyEntropy(LinesActionModel):
          assert particles.shape[1] > 1, "The entropy cannot be approximated using a single particle."

          if n_possible_actions is None:
-             n_possible_actions = particles.shape[-1]
+             n_possible_actions = ops.shape(particles)[-1]

          # TODO: I think we only need to compute the lower triangular
          # of this matrix, since it's symmetric
@@ -164,7 +170,8 @@ class GreedyEntropy(LinesActionModel):
          # Vertically stack all columns corresponding with the same line
          # This way we can just sum across the height axis and get the entropy
          # for each pixel in a given line
-         batch_size, n_particles, _, height, _ = gaussian_error_per_pixel_i_j.shape
+         batch_size, n_particles, _, height, _ = ops.shape(gaussian_error_per_pixel_i_j)
+
          gaussian_error_per_pixel_stacked = ops.transpose(
              ops.reshape(
                  ops.transpose(gaussian_error_per_pixel_i_j, (0, 1, 2, 4, 3)),
@@ -274,6 +281,8 @@ class GreedyEntropy(LinesActionModel):

          pixelwise_entropy = self.compute_pixelwise_entropy(particles)
          linewise_entropy = ops.sum(pixelwise_entropy, axis=1)
+         if self.average_entropy_across_batch:
+             linewise_entropy = ops.expand_dims(ops.mean(linewise_entropy, axis=0), axis=0)

          # Greedily select best line, reweight entropies, and repeat
          all_selected_lines = []
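A small NumPy sketch (shapes invented) of what the new average_entropy_across_batch option does in the hunk above: per-line entropies are averaged over the batch axis with a singleton dimension kept, so every batch element, for example every elevation plane in a 3D acquisition, ends up being subsampled with the same set of lines.

import numpy as np

# Hypothetical line-wise entropies for a batch of 4 frames and 8 candidate lines
linewise_entropy = np.random.rand(4, 8)

# average_entropy_across_batch=True: collapse the batch axis, keep a singleton dim
shared_entropy = np.mean(linewise_entropy, axis=0, keepdims=True)  # shape (1, 8)

# Greedy selection now picks the same line(s) for every frame in the batch
best_line = np.argmax(shared_entropy, axis=-1)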
@@ -334,7 +343,7 @@ class EquispacedLines(LinesActionModel):
          n_possible_actions: int,
          img_width: int,
          img_height: int,
-         assert_equal_spacing=True,
+         assert_equal_spacing: bool = True,
      ):
          super().__init__(n_actions, n_possible_actions, img_width, img_height)
{zea-0.0.9 → zea-0.0.10}/zea/backend/__init__.py
@@ -59,8 +59,21 @@ def _import_torch():
          return None


+ def _get_backend():
+     try:
+         backend_result = keras.backend.backend()
+         if isinstance(backend_result, str):
+             return backend_result
+         else:
+             # to handle mocked backends during testing
+             return None
+     except Exception:
+         return None
+
+
  tf_mod = _import_tf()
  jax_mod = _import_jax()
+ backend = _get_backend()


  def tf_function(func=None, jit_compile=False, **kwargs):
@@ -184,7 +197,7 @@ class on_device:
          self._context.__exit__(exc_type, exc_val, exc_tb)


- if keras.backend.backend() in ["tensorflow", "jax", "numpy"]:
+ if backend in [None, "tensorflow", "jax", "numpy"]:

      def func_on_device(func, device, *args, **kwargs):
          """Moves all tensor arguments of a function to a specified device before calling it.
@@ -199,7 +212,8 @@ if keras.backend.backend() in ["tensorflow", "jax", "numpy"]:
          """
          with on_device(device):
              return func(*args, **kwargs)
- elif keras.backend.backend() == "torch":
+
+ elif backend == "torch":
      from zea.backend.torch import func_on_device
  else:
-     raise ValueError(f"Unsupported backend: {keras.backend.backend()}")
+     raise ValueError(f"Unsupported backend: {backend}")
{zea-0.0.9 → zea-0.0.10}/zea/beamform/beamformer.py
@@ -4,7 +4,7 @@ import keras
  import numpy as np
  from keras import ops

- from zea.beamform.lens_correction import calculate_lens_corrected_delays
+ from zea.beamform.lens_correction import compute_lens_corrected_travel_times
  from zea.func.tensor import vmap


@@ -62,6 +62,7 @@ def tof_correction(
      focus_distances,
      t_peak,
      tx_waveform_indices,
+     transmit_origins,
      apply_lens_correction=False,
      lens_thickness=1e-3,
      lens_sound_speed=1000,
@@ -87,6 +88,7 @@
              Shape `(n_waveforms,)`.
          tx_waveform_indices (ops.Tensor): The indices of the waveform used for each
              transmit of shape `(n_tx,)`.
+         transmit_origins (ops.Tensor): Transmit origins of shape (n_tx, 3).
          apply_lens_correction (bool, optional): Whether to apply lens correction to
              time-of-flights. This makes it slower, but more accurate in the near-field.
              Defaults to False.
@@ -120,8 +122,7 @@
      # rxdel has shape (n_el, n_pix)
      # --------------------------------------------------------------------

-     delay_fn = calculate_lens_corrected_delays if apply_lens_correction else calculate_delays
-     txdel, rxdel = delay_fn(
+     txdel, rxdel = calculate_delays(
          flatgrid,
          t0_delays,
          tx_apodizations,
@@ -133,10 +134,12 @@
          n_el,
          focus_distances,
          polar_angles,
-         t_peak=t_peak,
-         tx_waveform_indices=tx_waveform_indices,
-         lens_thickness=lens_thickness,
-         lens_sound_speed=lens_sound_speed,
+         t_peak,
+         tx_waveform_indices,
+         transmit_origins,
+         apply_lens_correction,
+         lens_thickness,
+         lens_sound_speed,
      )

      n_pix = ops.shape(flatgrid)[0]
@@ -207,7 +210,11 @@
      polar_angles,
      t_peak,
      tx_waveform_indices,
-     **kwargs,
+     transmit_origins,
+     apply_lens_correction=False,
+     lens_thickness=None,
+     lens_sound_speed=None,
+     n_iter=2,
  ):
      """Calculates the delays in samples to every pixel in the grid.

@@ -242,6 +249,16 @@
              `(n_waveforms,)`.
          tx_waveform_indices (Tensor): The indices of the waveform used for each
              transmit of shape `(n_tx,)`.
+         transmit_origins (Tensor): Transmit origins of shape (n_tx, 3).
+         apply_lens_correction (bool, optional): Whether to apply lens correction to
+             time-of-flights. This makes it slower, but more accurate in the near-field.
+             Defaults to False.
+         lens_thickness (float, optional): Thickness of the lens in meters. Used for
+             lens correction.
+         lens_sound_speed (float, optional): Speed of sound in the lens in m/s. Used
+             for lens correction.
+         n_iter (int, optional): Number of iterations for the Newton-Raphson method
+             used in lens correction. Defaults to 2.


      Returns:
@@ -252,38 +269,56 @@
              `(n_pix, n_el)`.
      """

-     def _tx_distances(polar_angles, t0_delays, tx_apodizations, focus_distances):
-         return distance_Tx_generic(
-             grid,
-             t0_delays,
-             tx_apodizations,
+     # Validate input shapes
+     for arr in [t0_delays, grid, tx_apodizations, probe_geometry]:
+         assert arr.ndim == 2
+     assert probe_geometry.shape[0] == n_el
+     assert t0_delays.shape[0] == n_tx
+
+     if not apply_lens_correction:
+         # Compute receive distances in meters of shape (n_pix, n_el)
+         rx_distances = distance_Rx(grid, probe_geometry)
+
+         # Convert distances to delays in seconds
+         rx_delays = rx_distances / sound_speed
+     else:
+         # Compute lens-corrected travel times from each element to each pixel
+         assert lens_thickness is not None, "lens_thickness must be provided for lens correction."
+         assert lens_sound_speed is not None, (
+             "lens_sound_speed must be provided for lens correction."
+         )
+         rx_delays = compute_lens_corrected_travel_times(
              probe_geometry,
-             focus_distances,
-             polar_angles,
+             grid,
+             lens_thickness,
+             lens_sound_speed,
              sound_speed,
+             n_iter=n_iter,
          )

-     tx_distances = vmap(_tx_distances)(polar_angles, t0_delays, tx_apodizations, focus_distances)
-     tx_distances = ops.transpose(tx_distances, (1, 0))
-     # tx_distances shape is now (n_pix, n_tx)
+     # Compute transmit delays
+     tx_delays = vmap(transmit_delays, in_axes=(None, 0, 0, None, 0, 0, 0, None, 0), out_axes=1)(
+         grid,
+         t0_delays,
+         tx_apodizations,
+         rx_delays,
+         focus_distances,
+         polar_angles,
+         initial_times,
+         None,
+         transmit_origins,
+     )

-     # Compute receive distances
-     def _rx_distances(probe_geometry):
-         return distance_Rx(grid, probe_geometry)
+     # Add the offset to the transmit peak time
+     tx_delays += ops.take(t_peak, tx_waveform_indices)[None]

-     rx_distances = vmap(_rx_distances)(probe_geometry)
-     rx_distances = ops.transpose(rx_distances, (1, 0))
-     # rx_distances shape is now (n_pix, n_el)
+     # TODO: nan to num needed?
+     # tx_delays = ops.nan_to_num(tx_delays, nan=0.0, posinf=0.0, neginf=0.0)
+     # rx_delays = ops.nan_to_num(rx_delays, nan=0.0, posinf=0.0, neginf=0.0)

-     # Compute the delays [in samples] from the distances
-     # The units here are ([m]/[m/s]-[s])*[1/s] resulting in a unitless quantity
-     # TODO: Add pulse width to transmit delays
-     tx_delays = (
-         tx_distances / sound_speed
-         - initial_times[None]
-         + ops.take(t_peak, tx_waveform_indices)[None]
-     ) * sampling_frequency
-     rx_delays = (rx_distances / sound_speed) * sampling_frequency
+     # Convert from seconds to samples
+     tx_delays *= sampling_frequency
+     rx_delays *= sampling_frequency

      return tx_delays, rx_delays
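A small numeric check of the unit handling in calculate_delays above (numbers invented): travel times are kept in seconds and only converted to samples at the end by multiplying with the sampling frequency.

import numpy as np

sound_speed = 1540.0        # m/s
sampling_frequency = 25e6   # Hz, hypothetical

# A pixel 30 mm away from a receive element
rx_distance = 0.03                                # meters
rx_delay_seconds = rx_distance / sound_speed      # ~19.5 microseconds
rx_delay_samples = rx_delay_seconds * sampling_frequency
print(round(rx_delay_samples))                    # ~487 samples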
 
@@ -414,7 +449,7 @@ def complex_rotate(iq, theta):
  def distance_Rx(grid, probe_geometry):
      """Computes distance to user-defined pixels from elements.

-     Expects all inputs to be numpy arrays specified in SI units.
+     Expects all inputs to be arrays specified in SI units.

      Args:
          grid (ops.Tensor): Pixel positions in x,y,z of shape `(n_pix, 3)`.
@@ -425,83 +460,117 @@ def distance_Rx(grid, probe_geometry):
              `(n_pix, n_el)`.
      """
      # Get norm of distance vector between elements and pixels via broadcasting
-     dist = ops.linalg.norm(grid - probe_geometry[None, ...], axis=-1)
+     dist = ops.linalg.norm(grid[:, None, :] - probe_geometry[None, :, :], axis=-1)
      return dist


- def distance_Tx_generic(
+ def transmit_delays(
      grid,
      t0_delays,
      tx_apodization,
-     probe_geometry,
+     rx_delays,
      focus_distance,
      polar_angle,
-     sound_speed=1540,
+     initial_time,
+     azimuth_angle=None,
+     transmit_origin=None,
  ):
-     """Generic transmit distance calculation.
+     """
+     Computes the transmit delay from transmission to each pixel in the grid.
+
+     Uses the first-arrival time for pixels before the focus (or virtual source)
+     and the last-arrival time for pixels beyond the focus.

-     Computes distance to user-defined pixels for generic transmits based on
-     the t0_delays.
+     The receive delays can be precomputed since they do not depend on the
+     transmit parameters.

      Args:
-         grid (ops.Tensor): Flattened tensor of pixel positions in x,y,z of shape
-             `(n_pix, 3)`
-         t0_delays (ops.Tensor): The transmit delays in seconds of shape `(n_el,)`,
-             shifted such that the smallest delay is 0. Defaults to None.
-         tx_apodization (ops.Tensor): The transmit apodizations of shape
-             `(n_el,)`.
-         probe_geometry (ops.Tensor): The positions of the transducer elements of shape
-             `(n_el, 3)`.
+         grid (ops.Tensor): Flattened tensor of pixel positions in x,y,z of shape `(n_pix, 3)`
+         t0_delays (Tensor): The transmit delays in seconds of shape (n_el,).
+         tx_apodization (Tensor): The transmit apodization of shape (n_el,).
+         rx_delays (Tensor): The travel times in seconds from elements to pixels
+             of shape (n_pix, n_el).
          focus_distance (float): The focus distance in meters.
          polar_angle (float): The polar angle in radians.
-         sound_speed (float): The speed of sound in m/s. Defaults to 1540.
+         initial_time (float): The initial time for this transmit in seconds.
+         azimuth_angle (float, optional): The azimuth angle in radians. Defaults to 0.0.
+         transmit_origin (ops.Tensor, optional): The origin of the transmit beam of shape (3,).
+             If None, defaults to (0, 0, 0). Defaults to None.

      Returns:
-         Tensor: Distance from each pixel to each element in meters
-             of shape `(n_pix,)`
+         Tensor: The transmit delays of shape `(n_pix,)`.
      """
-     # Get the individual x, y, and z components of the pixel coordinates
-     x = grid[:, 0]
-     y = grid[:, 1]
-     z = grid[:, 2]
-
-     # Reshape x, y, and z to shape (n_pix, 1)
-     x = x[..., None]
-     y = y[..., None]
-     z = z[..., None]
-
-     # Get the individual x, y, and z coordinates of the elements and add a
-     # dummy dimension at the beginning to shape (1, n_el).
-     ele_x = probe_geometry[None, :, 0]
-     ele_y = probe_geometry[None, :, 1]
-     ele_z = probe_geometry[None, :, 2]
-
-     # Compute the differences dx, dy, and dz of shape (n_pix, n_el)
-     dx = x - ele_x
-     dy = y - ele_y
-     dz = z - ele_z
-
-     # Define an infinite offset for elements that do not fire to not consider them in
-     # the transmit distance calculation.
+     # Add a large offset for elements that are not used in the transmit to
+     # disqualify them from being the closest element
      offset = ops.where(tx_apodization == 0, np.inf, 0.0)

-     # Compute the distance between the elements and the pixels of shape
-     # (n_pix, n_el)
-     dist = t0_delays[None] * sound_speed + ops.sqrt(dx**2 + dy**2 + dz**2)
-
-     # Compute the z-coordinate of the focal point
-     focal_z = ops.cos(polar_angle) * focus_distance
+     # Compute total travel time from t=0 to each pixel via each element
+     # rx_delays has shape (n_pix, n_el)
+     # t0_delays has shape (n_el,)
+     total_times = rx_delays + t0_delays[None, :]
+
+     if azimuth_angle is None:
+         azimuth_angle = ops.zeros_like(polar_angle)
+
+     # Set origin to (0, 0, 0) if not provided
+     if transmit_origin is None:
+         transmit_origin = ops.zeros(3, dtype=grid.dtype)
+
+     # Compute the 3D position of the focal point
+     # The beam direction vector
+     beam_direction = ops.stack(
+         [
+             ops.sin(polar_angle) * ops.cos(azimuth_angle),
+             ops.sin(polar_angle) * ops.sin(azimuth_angle),
+             ops.cos(polar_angle),
+         ]
+     )

-     # Compute the effective distance of the pixels to the wavefront by computing the
-     # largest distance over all the elements when the pixel is behind the virtual
-     # source and the smallest distance otherwise.
-     dist = ops.where(
-         ops.cast(ops.sign(focus_distance), "float32") * (grid[:, 2] - focal_z) <= 0.0,
-         ops.min(dist + offset[None], 1),
-         ops.max(dist - offset[None], 1),
+     # Handle plane wave case where focus_distance is set to zero
+     # We use np.inf to consider the first wavefront arrival for all pixels
+     focus_distance = ops.where(focus_distance == 0.0, np.inf, focus_distance)
+
+     # Compute focal point position: origin + focus_distance * beam_direction
+     # For negative focus_distance (diverging/virtual source), this is behind the origin
+     focal_point = transmit_origin + focus_distance * beam_direction  # shape (3,)
+
+     # Deal with plane wave case where focus_distance is infinite and beam_direction is zero
+     # (np.inf * 0.0 -> nan) so we convert nan to zero
+     focal_point = ops.where(ops.isnan(focal_point), 0.0, focal_point)
+
+     # Compute the position of each pixel relative to the focal point
+     pixel_relative_to_focus = grid - focal_point[None, :]  # shape (n_pix, 3)
+
+     # Project onto the beam direction to determine if pixel is before or after focus
+     # Positive projection means pixel is in the direction of beam propagation (beyond focus)
+     # Negative projection means pixel is behind the focus (before focus)
+     projection_along_beam = ops.sum(
+         pixel_relative_to_focus * beam_direction[None, :], axis=-1
+     )  # shape (n_pix,)
+
+     # For focused waves (positive focus_distance):
+     # - Use min time for pixels before focus (projection < 0)
+     # - Use max time for pixels beyond focus (projection > 0)
+     # For diverging waves (negative focus_distance, virtual source):
+     # - The sign of focus_distance flips the logic
+     # - Use min time for pixels between transducer and virtual source
+     # - Use max time for pixels beyond transducer
+     is_before_focus = ops.cast(ops.sign(focus_distance), "float32") * projection_along_beam < 0.0
+
+     # Compute the effective time of the pixels to the wavefront by computing the
+     # smallest time over all elements (first wavefront arrival) for pixels before
+     # the focus, and the largest time (last wavefront contribution) for pixels
+     # beyond the focus.
+     tx_delay = ops.where(
+         is_before_focus,
+         ops.min(total_times + offset[None, :], axis=-1),
+         ops.max(total_times - offset[None, :], axis=-1),
      )

-     return dist
+     # Subtract the initial time offset for this transmit
+     tx_delay = tx_delay - initial_time
+
+     return tx_delay


  def fnumber_mask(flatgrid, probe_geometry, f_number, fnum_window_fn):
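To make the geometry in the transmit_delays hunk concrete, here is a small NumPy sketch (all numbers invented) of the beam direction, focal point, and before/after-focus test for a single focused transmit:

import numpy as np

# A single transmit steered 0.2 rad off-axis, focused at 30 mm (hypothetical values)
polar_angle, azimuth_angle = 0.2, 0.0
focus_distance = 0.03
transmit_origin = np.zeros(3)

beam_direction = np.array([
    np.sin(polar_angle) * np.cos(azimuth_angle),
    np.sin(polar_angle) * np.sin(azimuth_angle),
    np.cos(polar_angle),
])
focal_point = transmit_origin + focus_distance * beam_direction

# A pixel at 10 mm depth on the array axis lies before the focus, so the
# minimum over elements (first wavefront arrival) would be used for it
pixel = np.array([0.0, 0.0, 0.01])
projection = np.dot(pixel - focal_point, beam_direction)
before_focus = np.sign(focus_distance) * projection < 0.0
print(before_focus)  # True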
{zea-0.0.9 → zea-0.0.10}/zea/beamform/delays.py
@@ -52,7 +52,7 @@ def compute_t0_delays_planewave(probe_geometry, polar_angles, azimuth_angles=0,


  def compute_t0_delays_focused(
-     origins,
+     transmit_origins,
      focus_distances,
      probe_geometry,
      polar_angles,
@@ -63,12 +63,12 @@
      the first element fires at t=0.

      Args:
-         origins (np.ndarray): The origin of the focused transmit of shape (n_tx, 3,).
-         focus_distance (float): The distance to the focus.
+         transmit_origins (np.ndarray): The origin of the focused transmit of shape (n_tx, 3,).
+         focus_distances (np.ndarray): The distance to the focus for each transmit of shape (n_tx,).
          probe_geometry (np.ndarray): The positions of the elements in the array of
              shape (element, 3).
-         polar_angles (np.ndarray): The polar angles of the planewave in radians of shape (n_tx,).
-         azimuth_angles (np.ndarray, optional): The azimuth angles of the planewave in
+         polar_angles (np.ndarray): The polar angles in radians of shape (n_tx,).
+         azimuth_angles (np.ndarray, optional): The azimuth angles in
              radians of shape (n_tx,).
          sound_speed (float, optional): The speed of sound. Defaults to 1540.

@@ -79,12 +79,15 @@
      assert polar_angles.shape == (n_tx,), (
          f"polar_angles must have length n_tx = {n_tx}. Got length {len(polar_angles)}."
      )
-     assert origins.shape == (n_tx, 3), (
-         f"origins must have shape (n_tx, 3). Got shape {origins.shape}."
+     assert transmit_origins.shape == (n_tx, 3), (
+         f"transmit_origins must have shape (n_tx, 3). Got shape {transmit_origins.shape}."
      )
      assert probe_geometry.shape[1] == 3 and probe_geometry.ndim == 2, (
          f"probe_geometry must have shape (element, 3). Got shape {probe_geometry.shape}."
      )
+     assert focus_distances.shape == (n_tx,), (
+         f"focus_distances must have length n_tx = {n_tx}. Got length {len(focus_distances)}."
+     )

      # Convert single angles to arrays for broadcasting
      polar_angles = np.atleast_1d(polar_angles)
@@ -107,12 +110,12 @@
      )

      # Add a new dimension for broadcasting
-     # The shape is now (n_tx, n_el, 3)
+     # The shape is now (n_tx, 1, 3)
      v = np.expand_dims(v, axis=1)

      # Compute the location of the virtual source by adding the focus distance
      # to the origin along the wave vectors.
-     virtual_sources = origins[:, None] + focus_distances[:, None, None] * v
+     virtual_sources = transmit_origins[:, None] + focus_distances[:, None, None] * v

      # Compute the distances between the virtual sources and each element
      dist = np.linalg.norm(virtual_sources - probe_geometry, axis=-1)
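A tiny NumPy sketch (probe layout and numbers invented) of the virtual-source construction in the hunk above; the remaining step of compute_t0_delays_focused, not shown in this diff, turns these distances into per-element firing delays with the earliest element at t=0.

import numpy as np

# Hypothetical 5-element linear array centered at the origin
probe_geometry = np.stack(
    [np.linspace(-0.01, 0.01, 5), np.zeros(5), np.zeros(5)], axis=-1
)                                    # (n_el, 3)
transmit_origins = np.zeros((1, 3))  # single transmit at the array center
focus_distances = np.array([0.02])   # 20 mm focus
v = np.array([[0.0, 0.0, 1.0]])      # unit wave vector, straight ahead

# Virtual source = origin + focus_distance along the wave vector
virtual_sources = transmit_origins[:, None] + focus_distances[:, None, None] * v[:, None]

# Distance from the virtual source to every element, shape (n_tx, n_el)
dist = np.linalg.norm(virtual_sources - probe_geometry, axis=-1)
print(dist.shape)  # (1, 5)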