FastSIMUS 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
fast_simus/__init__.py ADDED
@@ -0,0 +1,33 @@
1
+ """FastSIMUS - Fast Simulator for Medical Ultrasound based on SIMUS/MUST."""
2
+
3
+ from fast_simus.medium_params import MediumParams
4
+ from fast_simus.pfield import PfieldPlan, PfieldStrategy, pfield, pfield_compute, pfield_precompute
5
+ from fast_simus.simus import SimusPlan, SimusResult, SimusStrategy, simus, simus_compute, simus_precompute
6
+ from fast_simus.transducer_params import BaffleType, TransducerParams
7
+ from fast_simus.tx_delay import (
8
+ diverging_wave,
9
+ focused,
10
+ plane_wave,
11
+ )
12
+ from fast_simus.utils.geometry import element_positions
13
+
14
# Public API of the package: every name listed here is re-exported from the
# submodule imports at the top of this file.
__all__ = [
    "BaffleType",
    "MediumParams",
    "PfieldPlan",
    "PfieldStrategy",
    "SimusPlan",
    "SimusResult",
    "SimusStrategy",
    "TransducerParams",
    "diverging_wave",
    "element_positions",
    "focused",
    "pfield",
    "pfield_compute",
    "pfield_precompute",
    "plane_wave",
    "simus",
    "simus_compute",
    "simus_precompute",
]
@@ -0,0 +1,261 @@
1
+ """Physics helpers for pfield computation.
2
+
3
+ Pure Array API functions for geometry, phase initialization, frequency
4
+ selection, and obliquity. No loop structure or backend-specific code.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ from math import ceil, log, pi
10
+ from typing import NamedTuple
11
+
12
+ from beartype import beartype as typechecker
13
+ from jaxtyping import Complex, Float, jaxtyped
14
+
15
+ from fast_simus.spectrum import probe_spectrum, pulse_spectrum
16
+ from fast_simus.transducer_params import BaffleType
17
+ from fast_simus.utils._array_api import Array, _ArrayNamespace
18
+
19
# Conversion factor from nepers to decibels: 20 / ln(10) ≈ 8.6859 dB per neper.
# Dividing an attenuation in dB by this constant yields the value in nepers.
# Shared by the Array API path and the Metal kernel wrapper.
NEPER_TO_DB = 20.0 / log(10.0)
22
+
23
+
24
class _FrequencyPlan(NamedTuple):
    """Frequency sampling plan for pfield computation.

    Attributes:
        selected_freqs: Selected frequencies in Hz, shape (n_frequencies,).
        pulse_spectrum: Complex pulse spectrum evaluated at the selected
            frequencies, shape (n_frequencies,).
        probe_spectrum: Real probe spectrum evaluated at the selected
            frequencies, shape (n_frequencies,).
        freq_step: Frequency step in Hz.
    """

    selected_freqs: Float[Array, " n_frequencies"]
    pulse_spectrum: Complex[Array, " n_frequencies"]
    probe_spectrum: Float[Array, " n_frequencies"]
    freq_step: float
38
+
39
+
40
@jaxtyped(typechecker=typechecker)
def _subelement_centroids(
    element_width: float,
    n_sub: int,
    theta_e: Float[Array, " n_elements"],
    xp: _ArrayNamespace,
) -> Float[Array, "n_elements n_sub 2"]:
    """Return the centroid offset of every sub-element from its element center.

    Each element of width ``element_width`` is cut into ``n_sub`` equal
    segments. The segment midpoints, measured along the element face, are
    projected onto the lateral and axial axes using the element angle.

    Args:
        element_width: Element width in meters.
        n_sub: Number of sub-elements per element.
        theta_e: Element angular positions in radians.
        xp: Array namespace.

    Returns:
        Offsets with shape (n_elements, n_sub, 2); ``[..., 0]`` is the
        lateral (x) component and ``[..., 1]`` the axial (z) component.
    """
    pitch = element_width / n_sub
    left_edge = -element_width / 2.0
    half_pitch = pitch / 2.0

    # Midpoint of each segment along the element face, relative to its center.
    midpoints = xp.asarray([left_edge + half_pitch + idx * pitch for idx in range(n_sub)])
    # (1, n_sub) so it broadcasts against the (n_elements, 1) angle terms below.
    along_face = xp.reshape(midpoints, (1, n_sub))

    lateral = along_face * xp.cos(theta_e)[:, None]
    axial = along_face * xp.sin(-theta_e)[:, None]
    return xp.stack([lateral, axial], axis=-1)
68
+
69
+
70
@jaxtyped(typechecker=typechecker)
def _distances_and_angles(
    points: Float[Array, "*batch 2"],
    subelement_offsets: Float[Array, "n_elements n_sub 2"],
    element_pos: Float[Array, "n_elements 2"],
    theta_e: Float[Array, " n_elements"],
    speed_of_sound: float,
    freq_center: float,
    xp: _ArrayNamespace,
) -> tuple[
    Float[Array, "*batch n_elements n_sub"],
    Float[Array, "*batch n_elements n_sub"],
    Float[Array, "*batch n_elements n_sub"],
]:
    """Measure range and incidence angle from every sub-element to every point.

    Args:
        points: Grid point positions, last axis is (x, z).
        subelement_offsets: Sub-element offsets relative to element centers.
        element_pos: Element center positions.
        theta_e: Element angular positions.
        speed_of_sound: Speed of sound in m/s.
        freq_center: Center frequency in Hz.
        xp: Array namespace.

    Returns:
        Tuple ``(distances, sin_theta, theta_arr)``:
            - distances: Point-to-sub-element distances, floored at half a
              wavelength at the center frequency.
            - sin_theta: Sine of the angle relative to the element normal.
            - theta_arr: Angle relative to the element normal.
    """
    # Vector from each sub-element to each grid point:
    # (*batch, n_elements, n_sub, 2).
    separation = points[..., None, None, :] - subelement_offsets - element_pos[:, None, :]
    dx = separation[..., 0]
    dz = separation[..., 1]
    raw_range = xp.sqrt(dx**2 + dz**2)

    # The angle is derived from the unclipped range; the small epsilon keeps
    # the ratio finite when a point coincides with a sub-element.
    eps = xp.asarray(1e-16)
    angle = xp.asin((dx + eps) / (raw_range + eps)) - theta_e[:, None]

    # Floor the returned range at half a wavelength (c / fc / 2).
    half_wavelength = xp.asarray(speed_of_sound / freq_center / 2.0)
    floored_range = xp.where(raw_range < half_wavelength, half_wavelength, raw_range)

    return floored_range, xp.sin(angle), angle
119
+
120
+
121
def _select_frequencies(
    fc: float,
    bandwidth: float,
    tx_n_wavelengths: float,
    db_thresh: float,
    max_freq_step: float,
    xp: _ArrayNamespace,
) -> _FrequencyPlan:
    """Choose the frequency samples that carry significant spectral energy.

    A uniform grid over ``[0, 2 * fc]`` is built with a spacing no larger
    than ``max_freq_step``. The combined pulse/probe spectrum is normalized
    to its peak, and the contiguous range between the first and last sample
    above ``db_thresh`` is kept.

    Args:
        fc: Center frequency in Hz.
        bandwidth: Fractional bandwidth.
        tx_n_wavelengths: Number of wavelengths in TX pulse.
        db_thresh: Threshold in dB for frequency component selection.
        max_freq_step: Upper bound for frequency step.
        xp: Array namespace.

    Returns:
        A ``_FrequencyPlan`` holding the selected frequencies, the pulse and
        probe spectra evaluated at them, and the grid spacing.
    """
    two_pi = xp.asarray(2.0 * pi)

    # Odd sample count so the grid is symmetric about fc.
    n_samples = int(2 * ceil(fc / max_freq_step) + 1)
    grid = xp.linspace(0, 2 * fc, n_samples)
    # The grid starts at 0, so its second entry equals the spacing.
    spacing = float(grid[1])

    # Relative gain of the combined spectrum, in dB below its peak.
    omega_grid = two_pi * grid
    magnitude = xp.abs(
        pulse_spectrum(omega_grid, fc, tx_n_wavelengths) * probe_spectrum(omega_grid, fc, bandwidth)
    )
    relative_db = 20.0 * xp.log10(xp.asarray(1e-200) + magnitude / xp.max(magnitude))

    lo, hi = _first_last_true(xp, relative_db > db_thresh)
    kept = grid[lo : hi + 1]

    omega_kept = two_pi * kept
    return _FrequencyPlan(
        kept,
        pulse_spectrum(omega_kept, fc, tx_n_wavelengths),
        probe_spectrum(omega_kept, fc, bandwidth),
        spacing,
    )
163
+
164
+
165
@jaxtyped(typechecker=typechecker)
def _obliquity_factor(
    theta_arr: Float[Array, "*batch n_elements n_sub"],
    baffle: BaffleType | float,
    xp: _ArrayNamespace,
) -> Float[Array, "*batch n_elements n_sub"]:
    """Evaluate the obliquity factor for the given baffle model.

    Args:
        theta_arr: Angles relative to element normal.
        baffle: Baffle type or impedance ratio.
        xp: Array namespace.

    Returns:
        Obliquity factor with the same shape as ``theta_arr``: 1 for a rigid
        baffle, ``cos(theta)`` for a soft baffle, and
        ``cos(theta) / (cos(theta) + baffle)`` otherwise. Angles at or beyond
        +/- pi/2 are replaced by a near-zero floor.
    """
    if baffle == BaffleType.RIGID:
        factor = xp.ones(theta_arr.shape)
    elif baffle == BaffleType.SOFT:
        factor = xp.cos(theta_arr)
    else:
        # Any other value is treated as a numeric impedance ratio.
        cosine = xp.cos(theta_arr)
        factor = cosine / (cosine + float(baffle))

    # Directions at or behind the element plane contribute (almost) nothing.
    behind_plane = xp.abs(theta_arr) >= xp.asarray(pi / 2)
    floor_value = xp.asarray(1e-16)
    return xp.where(behind_plane, floor_value, factor)
200
+
201
+
202
@jaxtyped(typechecker=typechecker)
def _init_exponentials(
    freq_start: float | Float[Array, ""],
    speed_of_sound: float,
    attenuation: float,
    distances: Float[Array, "*batch n_elements n_sub"],
    obliquity_factor: Float[Array, "*batch n_elements n_sub"],
    freq_step: float | Float[Array, ""],
    xp: _ArrayNamespace,
) -> tuple[
    Complex[Array, "*batch n_elements n_sub"],
    Complex[Array, "*batch n_elements n_sub"],
]:
    """Build the starting term and ratio of the per-frequency progression.

    Args:
        freq_start: Initial frequency in Hz (scalar or 0-d array).
        speed_of_sound: Speed of sound in m/s.
        attenuation: Attenuation coefficient in dB/cm/MHz.
        distances: Distances.
        obliquity_factor: Obliquity factor.
        freq_step: Frequency step in Hz (scalar or 0-d array).
        xp: Array namespace.

    Returns:
        Tuple ``(phase_decay, phase_decay_step)``:
            - phase_decay: Complex exponential at ``freq_start``, already
              scaled by ``obliquity_factor / sqrt(distances)``.
            - phase_decay_step: Complex multiplier that advances the
              exponential by one ``freq_step``.
    """
    # dB/cm/MHz -> Np/m/Hz is applied as (/ NEPER_TO_DB) (/ 1e6) (* 1e2).
    nepers = attenuation / NEPER_TO_DB
    two_pi_scalar = 2.0 * pi

    # Per-step ratio: exp((-alpha_step + 1j * k_step) * distances).
    k_step = two_pi_scalar * freq_step / speed_of_sound
    alpha_step = nepers * freq_step / 1e6 * 1e2
    ratio = xp.exp(xp.asarray(-alpha_step + 1j * k_step) * distances)

    # Starting term: exp(-alpha0 * distances + 1j * mod(k0 * distances, 2pi)).
    k_start = two_pi_scalar * freq_start / speed_of_sound
    alpha_start = nepers * freq_start / 1e6 * 1e2
    travelled_phase = xp.asarray(k_start) * distances
    full_turn = xp.asarray(2.0 * pi)
    # Wrap the phase into [0, 2pi) before exponentiating.
    wrapped_phase = travelled_phase - full_turn * xp.floor(travelled_phase / full_turn)
    start = xp.exp(xp.asarray(-alpha_start) * distances + xp.asarray(1j) * wrapped_phase)

    # Fold in obliquity and the 2D spreading term 1/sqrt(distances) (no elevation).
    start = start * obliquity_factor / xp.sqrt(distances)

    return start, ratio
248
+
249
+
250
def _first_last_true(xp: _ArrayNamespace, mask: Array) -> tuple[int, int]:
    """Locate the first and last True entries of a 1D boolean array.

    Implemented with ``argmax`` only (no ``nonzero``) so it stays
    JAX-compatible.

    Args:
        xp: Array namespace.
        mask: 1D boolean array.

    Returns:
        ``(first, last)`` indices of the outermost True entries, or
        ``(0, 0)`` when the mask is empty or contains no True value.
    """
    length = mask.shape[0]
    if length == 0:
        return 0, 0

    # array_api_strict rejects argmax on bool, so work on an int32 copy.
    as_int = xp.asarray(mask, dtype=xp.int32)
    if int(xp.max(as_int)) == 0:
        return 0, 0

    leading = int(xp.argmax(as_int))
    # argmax on the reversed array gives the distance of the last True from the end.
    from_end = int(xp.argmax(as_int[::-1]))
    return leading, length - 1 - from_end
@@ -0,0 +1,203 @@
1
+ """Loop drivers for pfield frequency sweep (Layer 3).
2
+
3
+ Each driver evaluates the same per-frequency math (defined by _freq_step_body()) using a different mechanism:
4
+
5
+ - _freq_outer_python: Python for-loop (NumPy/CuPy, constant memory)
6
+ - _pfield_freq_vectorized: tensor broadcast (small grids only)
7
+ - _freq_outer_scan: JAX lax.scan for O(1) compilation cost
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ from math import pi
13
+
14
+ import array_api_extra as xpx
15
+ from beartype import beartype as typechecker
16
+ from jaxtyping import Bool, Complex, Float, jaxtyped
17
+
18
+ from fast_simus.utils._array_api import Array, _ArrayNamespace
19
+
20
+
21
def _freq_step_body(
    phase: Complex[Array, " *grid n_sources"],
    phase_step: Complex[Array, " *grid n_sources"],
    spectrum_k: complex | Array,
    xp: _ArrayNamespace,
    *,
    directivity_k: Float[Array, " *grid n_sources"] | None = None,
) -> tuple[Complex[Array, " *grid n_sources"], Float[Array, " *grid"]]:
    """Evaluate one frequency: weight, contract over sources, advance the phase.

    Single source of truth for per-frequency math. Source points are
    already flattened (n_elements * n_sub) with 1/n_sub absorbed.

    Args:
        phase: Current phase state (geometric progression).
        phase_step: Per-step multiplier for geometric progression.
        spectrum_k: Combined pulse*probe spectrum weight for this frequency.
        xp: Array namespace.
        directivity_k: Per-source directivity for this frequency (optional).

    Returns:
        Tuple ``(updated_phase, rp_k)`` where ``updated_phase`` is
        ``phase * phase_step`` and ``rp_k`` is ``|P_k|^2`` at this frequency.
    """
    weighted = phase if directivity_k is None else phase * directivity_k
    # Sum the contributions of all sources, then apply the spectral weight.
    field_k = spectrum_k * xp.sum(weighted, axis=-1)
    # |P_k|^2 written as P_k * conj(P_k); the imaginary part is zero.
    power_k = xp.real(field_k * xp.conj(field_k))
    return phase * phase_step, power_k
51
+
52
+
53
def _freq_outer_python(
    phase_decay_init: Complex[Array, " *grid n_sources"],
    phase_decay_step: Complex[Array, " *grid n_sources"],
    is_out: Bool[Array, " *grid"],
    wavenumbers: Float[Array, " n_freq"],
    pulse_spect: Complex[Array, " n_freq"],
    probe_spect: Float[Array, " n_freq"],
    seg_length: float,
    sin_theta: Float[Array, " *grid n_sources"],
    full_frequency_directivity: bool,
    xp: _ArrayNamespace,
) -> Float[Array, " *grid"]:
    """Python for-loop driver for NumPy/CuPy: constant O(grid * sources) memory.

    Walks the frequencies one at a time through _freq_step_body and adds
    each |P_k|^2 to the running total. Peak memory is independent of n_freq.

    Args:
        phase_decay_init: Phase state at the first frequency.
        phase_decay_step: Per-frequency multiplier of the phase state.
        is_out: Mask of grid points whose contribution is forced to zero.
        wavenumbers: Wavenumber of each frequency sample.
        pulse_spect: Pulse spectrum at each frequency sample.
        probe_spect: Probe spectrum at each frequency sample.
        seg_length: Sub-element length used in the sinc directivity.
        sin_theta: Sine of the angle to each source.
        full_frequency_directivity: When True, a sinc directivity is
            recomputed for every frequency; otherwise none is applied here.
        xp: Array namespace.

    Returns:
        Accumulated |P_k|^2 over all frequencies, one value per grid point.
    """
    weights = pulse_spect * probe_spect
    masked_value = xp.asarray(0.0)

    state = phase_decay_init
    total = xp.zeros(state.shape[:-1])

    for k in range(int(wavenumbers.shape[0])):
        if full_frequency_directivity:
            # xpx.sinc is the normalized sinc, hence the division by pi.
            directivity = xpx.sinc(wavenumbers[k] * seg_length / 2.0 * sin_theta / pi, xp=xp)
        else:
            directivity = None
        state, power_k = _freq_step_body(state, phase_decay_step, weights[k], xp, directivity_k=directivity)
        total = total + xp.where(is_out, masked_value, power_k)

    return total
89
+
90
+
91
@jaxtyped(typechecker=typechecker)
def _pfield_freq_vectorized(
    phase_decay_init: Complex[Array, " *grid n_sources"],
    phase_decay_step: Complex[Array, " *grid n_sources"],
    is_out: Bool[Array, " *grid"],
    wavenumbers: Float[Array, " n_freq"],
    pulse_spect: Complex[Array, " n_freq"],
    probe_spect: Float[Array, " n_freq"],
    seg_length: float,
    sin_theta: Float[Array, " *grid n_sources"],
    full_frequency_directivity: bool,
    xp: _ArrayNamespace,
) -> Float[Array, " *grid"]:
    """Vectorized frequency sweep: every frequency evaluated in one broadcast.

    Reference implementation kept for testing and small-grid use cases.
    Production code uses _freq_outer_python (constant memory) instead.

    Relies on the geometric progression phase_decay[k] = init * step^k.
    Source points are pre-flattened with 1/n_sub absorbed.

    Best for small grids where the (*grid, n_sources, n_freq) tensor fits
    in memory. For large grids, use an iterative driver instead.

    Args:
        phase_decay_init: Phase state at the first frequency.
        phase_decay_step: Per-frequency multiplier of the phase state.
        is_out: Mask of grid points whose pressure is forced to zero.
        wavenumbers: Wavenumber of each frequency sample.
        pulse_spect: Pulse spectrum at each frequency sample.
        probe_spect: Probe spectrum at each frequency sample.
        seg_length: Sub-element length used in the sinc directivity.
        sin_theta: Sine of the angle to each source.
        full_frequency_directivity: When True, apply a per-frequency sinc
            directivity to every source.
        xp: Array namespace.

    Returns:
        Sum over frequencies of |P_k|^2, one value per grid point.
    """
    n_freq = wavenumbers.shape[0]
    powers = xp.arange(n_freq, dtype=wavenumbers.dtype)

    # init * step**k for every k: (*grid, n_sources, n_freq).
    progression = phase_decay_init[..., None] * phase_decay_step[..., None] ** powers

    if full_frequency_directivity:
        # xpx.sinc is the normalized sinc, hence the division by pi.
        directivity = xpx.sinc(wavenumbers * seg_length / 2.0 * sin_theta[..., None] / pi, xp=xp)
        progression = directivity * progression

    # Contract the source axis, then weight by the combined spectrum:
    # (*grid, n_sources, n_freq) -> (*grid, n_freq).
    per_freq = pulse_spect * probe_spect * xp.sum(progression, axis=-2)
    per_freq = xp.where(is_out[..., None], xp.asarray(0.0 + 0j), per_freq)

    return xp.sum(xp.abs(per_freq) ** 2, axis=-1)
133
+
134
+
135
def _freq_outer_scan(
    phase_decay_init: Complex[Array, " *grid n_sources"],
    phase_decay_step: Complex[Array, " *grid n_sources"],
    is_out: Bool[Array, " *grid"],
    wavenumbers: Float[Array, " n_freq"],
    pulse_spect: Complex[Array, " n_freq"],
    probe_spect: Float[Array, " n_freq"],
    seg_length: float,
    sin_theta: Float[Array, " *grid n_sources"],
    full_frequency_directivity: bool,
    xp: _ArrayNamespace,
) -> Float[Array, " *grid"]:
    """JAX vmap+scan driver: per-grid-point kernel vmapped over the grid.

    The scan carry is only (n_sources,) per grid point, matching the Metal
    kernel's per-thread model. vmap handles parallelism across grid points.

    The |P_k|^2 accumulator takes its dtype from the operands instead of
    being fixed to float32. ``lax.scan`` requires the carry to keep the same
    dtype on every iteration, so a hard-coded float32 accumulator fails as
    soon as the inputs are double precision (``jax_enable_x64``).

    Args:
        phase_decay_init: Phase state at the first frequency.
        phase_decay_step: Per-frequency multiplier of the phase state.
        is_out: Mask of grid points whose contribution is forced to zero.
        wavenumbers: Wavenumber of each frequency sample.
        pulse_spect: Pulse spectrum at each frequency sample.
        probe_spect: Probe spectrum at each frequency sample.
        seg_length: Sub-element length used in the sinc directivity.
        sin_theta: Sine of the angle to each source.
        full_frequency_directivity: When True, a sinc directivity is
            recomputed for every frequency inside the scan.
        xp: Array namespace (expected to be ``jax.numpy``-compatible, since
            its functions are traced by ``jax.vmap`` / ``jax.lax.scan``).

    Returns:
        Accumulated |P_k|^2 over all frequencies, reshaped to the grid shape.
    """
    import jax
    import jax.numpy as jnp

    grid_shape = phase_decay_init.shape[:-1]
    n_sources = phase_decay_init.shape[-1]

    # Flatten grid to (n_grid, n_sources) and (n_grid,)
    init_flat = xp.reshape(phase_decay_init, (-1, n_sources))
    step_flat = xp.reshape(phase_decay_step, (-1, n_sources))
    is_out_flat = xp.reshape(is_out, (-1,))
    sin_theta_flat = xp.reshape(sin_theta, (-1, n_sources))

    spectra = pulse_spect * probe_spect

    def _accumulator_dtype(*operands):
        # Real counterpart of the dtype the pressure has after type promotion.
        return jnp.real(jnp.zeros((), dtype=jnp.result_type(*operands))).dtype

    if full_frequency_directivity:
        acc_dtype = _accumulator_dtype(init_flat, spectra, wavenumbers, sin_theta_flat)

        def _single_point_scan(
            phase_init_g: jax.Array, phase_step_g: jax.Array, is_out_g: jax.Array, sin_theta_g: jax.Array
        ) -> jax.Array:
            def scan_fn(carry, k):
                phase, rp = carry
                # xpx.sinc is the normalized sinc, hence the division by pi.
                sinc_arg = wavenumbers[k] * seg_length / 2.0 * sin_theta_g / pi
                directivity_k = xpx.sinc(sinc_arg, xp=xp)
                p_k = spectra[k] * xp.sum(phase * directivity_k)
                # Cast so the carry dtype cannot drift between iterations.
                rp_k = xp.real(p_k * xp.conj(p_k)).astype(rp.dtype)
                rp = rp + xp.where(is_out_g, xp.zeros_like(rp_k), rp_k)
                phase = phase * phase_step_g
                return (phase, rp), None

            carry0 = (phase_init_g, jnp.zeros((), dtype=acc_dtype))
            (_, rp), _ = jax.lax.scan(scan_fn, carry0, jnp.arange(spectra.shape[0]))
            return rp

        rp_flat = jax.vmap(_single_point_scan)(init_flat, step_flat, is_out_flat, sin_theta_flat)
    else:
        acc_dtype = _accumulator_dtype(init_flat, spectra)

        def _single_point_scan_no_dir(
            phase_init_g: jax.Array, phase_step_g: jax.Array, is_out_g: jax.Array
        ) -> jax.Array:
            def scan_fn(carry, spectrum_k):
                phase, rp = carry
                p_k = spectrum_k * xp.sum(phase)
                # Cast so the carry dtype cannot drift between iterations.
                rp_k = xp.real(p_k * xp.conj(p_k)).astype(rp.dtype)
                rp = rp + xp.where(is_out_g, xp.zeros_like(rp_k), rp_k)
                phase = phase * phase_step_g
                return (phase, rp), None

            carry0 = (phase_init_g, jnp.zeros((), dtype=acc_dtype))
            (_, rp), _ = jax.lax.scan(scan_fn, carry0, spectra)
            return rp

        rp_flat = jax.vmap(_single_point_scan_no_dir)(init_flat, step_flat, is_out_flat)

    return xp.reshape(rp_flat, grid_shape)