FastSIMUS 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fast_simus/__init__.py +33 -0
- fast_simus/_pfield_math.py +261 -0
- fast_simus/_pfield_strategies.py +203 -0
- fast_simus/_simus_strategies.py +210 -0
- fast_simus/backends/__init__.py +1 -0
- fast_simus/backends/mlx.py +101 -0
- fast_simus/kernels/__init__.py +9 -0
- fast_simus/kernels/cuda_simus.py +321 -0
- fast_simus/kernels/metal_pfield.py +219 -0
- fast_simus/kernels/metal_simus.py +377 -0
- fast_simus/kernels/pfield.metal +97 -0
- fast_simus/kernels/simus_fused.cu +332 -0
- fast_simus/kernels/simus_rx_simd.metal +128 -0
- fast_simus/kernels/simus_tx_tiled.metal +175 -0
- fast_simus/medium_params.py +22 -0
- fast_simus/pfield.py +475 -0
- fast_simus/py.typed +0 -0
- fast_simus/simus.py +567 -0
- fast_simus/spectrum.py +107 -0
- fast_simus/transducer_params.py +160 -0
- fast_simus/transducer_presets.py +102 -0
- fast_simus/tx_delay.py +276 -0
- fast_simus/utils/__init__.py +5 -0
- fast_simus/utils/_array_api.py +294 -0
- fast_simus/utils/geometry.py +88 -0
- fastsimus-0.0.1.dist-info/METADATA +594 -0
- fastsimus-0.0.1.dist-info/RECORD +28 -0
- fastsimus-0.0.1.dist-info/WHEEL +4 -0
fast_simus/simus.py
ADDED
|
@@ -0,0 +1,567 @@
|
|
|
1
|
+
"""Ultrasound RF signal simulation for linear and convex arrays.
|
|
2
|
+
|
|
3
|
+
Implements the SIMUS algorithm: for each frequency in the transmit bandwidth,
|
|
4
|
+
compute forward TX pressure at scatterers, scatter by reflection coefficients,
|
|
5
|
+
back-propagate to receive elements (acoustic reciprocity), accumulate complex
|
|
6
|
+
RF spectrum, then IFFT to time-domain RF signals.
|
|
7
|
+
|
|
8
|
+
All functions are Array API compliant and work with NumPy, JAX, CuPy backends.
|
|
9
|
+
|
|
10
|
+
References:
|
|
11
|
+
Garcia D. SIMUS: an open-source simulator for medical ultrasound imaging.
|
|
12
|
+
Part I: theory & examples. CMPB, 2022;218:106726.
|
|
13
|
+
"""
|
|
14
|
+
|
|
15
|
+
from __future__ import annotations
|
|
16
|
+
|
|
17
|
+
from enum import StrEnum
|
|
18
|
+
from math import ceil, inf, log2, pi
|
|
19
|
+
from types import ModuleType
|
|
20
|
+
from typing import NamedTuple, cast
|
|
21
|
+
|
|
22
|
+
import array_api_extra as xpx
|
|
23
|
+
from array_api_compat import is_jax_namespace
|
|
24
|
+
from jaxtyping import Complex, Float
|
|
25
|
+
|
|
26
|
+
from fast_simus._pfield_math import (
|
|
27
|
+
_distances_and_angles,
|
|
28
|
+
_init_exponentials,
|
|
29
|
+
_obliquity_factor,
|
|
30
|
+
_select_frequencies,
|
|
31
|
+
_subelement_centroids,
|
|
32
|
+
)
|
|
33
|
+
from fast_simus.medium_params import MediumParams
|
|
34
|
+
from fast_simus.spectrum import probe_spectrum as _probe_spectrum_fn
|
|
35
|
+
from fast_simus.spectrum import pulse_spectrum as _pulse_spectrum_fn
|
|
36
|
+
from fast_simus.transducer_params import TransducerParams
|
|
37
|
+
from fast_simus.utils._array_api import (
|
|
38
|
+
Array,
|
|
39
|
+
_ArrayNamespace,
|
|
40
|
+
_ArrayNamespaceWithFFT,
|
|
41
|
+
array_namespace,
|
|
42
|
+
is_cupy_namespace,
|
|
43
|
+
)
|
|
44
|
+
from fast_simus.utils.geometry import element_positions
|
|
45
|
+
|
|
46
|
+
# Default medium, built once from MediumParams' own defaults. It is used as the
# default ``medium`` argument of simus_precompute, simus_compute and simus.
# NOTE(review): one instance is shared by every call that omits ``medium``;
# this assumes MediumParams is never mutated in place — confirm it is immutable.
_DEFAULT_MEDIUM = MediumParams()
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def _two_way_pulse_duration(
    freq_center: float,
    bandwidth: float,
    tx_n_wavelengths: float,
    xp: _ArrayNamespace,
) -> float:
    """Estimate the duration of the two-way (pulse-echo) pulse.

    Follows the pulse-length logic of PyMUST's ``getpulse(param, 2)``:
    1. Multiply the transmit pulse spectrum by the squared one-way probe
       response.
    2. Bring it to the time domain with an inverse real FFT, centre it with
       ``fftshift`` and normalize it to a unit peak.
    3. Use the span of samples exceeding 1/1023 to fix the length of the
       trimmed pulse.

    Args:
        freq_center: Center frequency in Hz.
        bandwidth: Fractional bandwidth (0.75 = 75%).
        tx_n_wavelengths: Number of wavelengths of the TX pulse.
        xp: Array namespace; must expose an ``fft`` extension.

    Returns:
        Pulse duration in seconds. This is the trimmed sample count times the
        fixed 1 ns time step, or ``tx_n_wavelengths / freq_center`` when no
        sample exceeds the threshold.

    Raises:
        RuntimeError: If ``xp`` has no ``fft`` attribute.
    """
    # Use a hasattr() probe rather than isinstance(xp, _ArrayNamespaceWithFFT).
    # On Python 3.12+, runtime Protocol checks use getattr_static, which misses
    # lazily-loaded submodules such as numpy.fft.
    # See https://docs.python.org/3/whatsnew/3.12.html#typing
    if not hasattr(xp, "fft"):
        msg = "simus requires an array backend with FFT support (e.g. numpy, jax, cupy)"
        raise RuntimeError(msg)
    fft_ns = cast(_ArrayNamespaceWithFFT, xp)

    time_step = 1e-9
    half_rate = 1.0 / time_step / 2.0  # Nyquist frequency of the 1 ns grid
    freq_resolution = freq_center / tx_n_wavelengths / 32
    n_fft = 2 ** ceil(log2(half_rate / freq_resolution))
    angular_freqs = 2.0 * pi * xp.linspace(0, half_rate, n_fft)

    # Pulse-echo spectrum: transmit pulse times the probe response applied twice.
    tx_pulse = _pulse_spectrum_fn(angular_freqs, freq_center, tx_n_wavelengths)
    probe = _probe_spectrum_fn(angular_freqs, freq_center, bandwidth)
    waveform = fft_ns.fft.fftshift(fft_ns.fft.irfft(tx_pulse * probe**2))
    waveform = waveform / xp.max(xp.abs(waveform))

    # First/last sample above threshold, computed with masked min/max so it
    # stays expressible in array-API operations.
    n_samples = waveform.shape[0]
    is_above = waveform > (1.0 / 1023)
    sample_idx = xp.arange(n_samples)
    first = int(xp.min(xp.where(is_above, sample_idx, xp.asarray(n_samples))))
    last = int(xp.max(xp.where(is_above, sample_idx, xp.asarray(-1))))

    if first >= n_samples:
        # Nothing crosses the threshold: fall back to the nominal TX pulse length.
        return tx_n_wavelengths / freq_center

    half_width = min(first + 1, 2 * n_fft - 2 - last)
    trimmed = waveform[-half_width : half_width - 2 : -1]
    return float(trimmed.shape[0] * time_step)
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
class SimusStrategy(StrEnum):
    """Backend strategy for the simus frequency sweep.

    Members are string-valued (``StrEnum``), so each compares equal to its
    lowercase value. When no strategy is given, ``_select_simus_strategy``
    picks one from the array namespace:
    - JAX selects SCAN.
    - ``mlx.core`` selects METAL.
    - CuPy selects CUDA.
    - Anything else selects PYTHON.

    Attributes:
        PYTHON: Python for-loop (NumPy/CuPy, constant memory).
        SCAN: JAX lax.scan for O(1) compilation cost.
        METAL: Custom Metal kernel on Apple Silicon (MLX).
        CUDA: Custom CUDA kernel on NVIDIA GPUs (CuPy + NVRTC).
    """

    PYTHON = "python"  # generic fallback path (_simus_freq_outer_python)
    SCAN = "scan"  # JAX path (_simus_freq_outer_scan)
    METAL = "metal"  # fast_simus.kernels.metal_simus.simus_metal
    CUDA = "cuda"  # fast_simus.kernels.cuda_simus.simus_cuda
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
class SimusResult(NamedTuple):
    """Result of simus RF signal simulation.

    Attributes:
        rf: Time-domain RF signals, shape (n_samples, n_elements). n_samples
            is ``(plan.n_fft + 1) // 2``, i.e. the first half of the IFFT
            output, after smooth -100 dB thresholding.
        spectrum: Complex RF spectrum, shape (n_freq_full, n_elements). It is
            zero outside the selected frequency band.
    """

    # Time-domain RF, one column per receive element.
    rf: Float[Array, "n_samples n_elements"]
    # Full-grid complex spectrum (before conjugation/IFFT), one column per element.
    spectrum: Complex[Array, "n_freq_full n_elements"]
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
class SimusPlan(NamedTuple):
    """Precomputed plan for simus computation.

    Produced by ``simus_precompute``. It contains all data-dependent
    quantities, so ``simus_compute`` works with static array shapes.

    Attributes:
        selected_freqs: Significant frequency samples in Hz.
        pulse_spectrum: Pulse spectrum at selected frequencies.
        probe_spectrum: Probe response at selected frequencies.
        n_sub: Number of sub-elements per transducer element.
        seg_length: Sub-element length in meters.
        correction_factor: Scaling factor for the integration
            (df * element_width, or element_width when tx_n_wavelengths=inf).
        n_freq_full: Total number of frequency bins (0 to 2*fc).
        freq_idx_start: Index of first selected frequency in full spectrum.
        n_fft: Number of points for the IFFT (from fs, fc, Nf).
    """

    selected_freqs: Float[Array, " n_frequencies"]  # Hz
    pulse_spectrum: Complex[Array, " n_frequencies"]
    probe_spectrum: Float[Array, " n_frequencies"]
    n_sub: int
    seg_length: float  # meters; element_width / n_sub
    correction_factor: float
    n_freq_full: int  # 2 * ceil(fc / df) + 1
    freq_idx_start: int  # row where selected_freqs starts in the full grid
    n_fft: int  # IFFT length; RF keeps the first (n_fft + 1) // 2 samples
|
|
163
|
+
|
|
164
|
+
|
|
165
|
+
def simus_precompute(
    scatterers: Float[Array, "*batch 2"],
    rc: Float[Array, " *batch"],
    delays: Float[Array, " n_elements"],
    params: TransducerParams,
    medium: MediumParams = _DEFAULT_MEDIUM,
    *,
    fs: float | None = None,
    tx_n_wavelengths: float | int = 1.0,
    db_thresh: float | int = -60.0,
    element_splitting: int | None = None,
    frequency_step: float | int = 1.0,
) -> SimusPlan:
    """Precompute static quantities for simus computation.

    The frequency step is chosen from the largest scatterer-to-element
    distance, the two-way pulse length and the largest transmit delay,
    matching PyMUST's simus.

    Args:
        scatterers: Scatterer positions in meters. Shape ``(*batch, 2)``.
        rc: Reflection coefficients. Shape ``(*batch,)``. Not used by this
            function; accepted so the signature mirrors ``simus_compute``.
        delays: Transmit time delays in seconds. Shape ``(n_elements,)``.
            NaN delays are treated as zero here.
        params: Transducer parameters.
        medium: Medium parameters.
        fs: Sampling frequency in Hz. Defaults to 4 * fc.
        tx_n_wavelengths: Number of wavelengths in the TX pulse.
        db_thresh: Threshold in dB for frequency component selection.
        element_splitting: Number of sub-elements per element (None = auto).
        frequency_step: Scaling factor for the frequency step.

    Returns:
        SimusPlan with static-shaped arrays and precomputed scalars.

    Raises:
        ValueError: If ``element_splitting`` is given and is < 1, or if
            ``frequency_step`` is not positive.
    """
    if element_splitting is not None and element_splitting < 1:
        msg = f"element_splitting must be a positive integer, got {element_splitting!r}"
        raise ValueError(msg)
    if not (frequency_step > 0):
        msg = f"frequency_step must be positive, got {frequency_step!r}"
        raise ValueError(msg)

    xp = array_namespace(scatterers, delays)
    speed_of_sound = medium.speed_of_sound
    fc = params.freq_center

    if fs is None:
        fs = 4.0 * fc

    # NaN-clean delays
    delays_clean = xp.where(xp.isnan(delays), xp.asarray(0.0), delays)

    # Element splitting
    if element_splitting is not None:
        n_sub = element_splitting
    else:
        # Auto: sub-elements no wider than the shortest wavelength in the band
        # (upper band edge fc * (1 + bandwidth / 2)).
        lambda_min = speed_of_sound / (fc * (1.0 + params.bandwidth / 2.0))
        n_sub = ceil(params.element_width / lambda_min)

    seg_length = params.element_width / n_sub

    # Max distance for frequency step (use element centers, matching PyMUST simus)
    element_pos, _, _ = element_positions(params.n_elements, params.pitch, params.radius, xp)
    x = scatterers[..., 0]
    z = scatterers[..., 1]
    d2 = (xp.reshape(x, (-1, 1)) - element_pos[:, 0]) ** 2 + (xp.reshape(z, (-1, 1)) - element_pos[:, 1]) ** 2
    max_d = float(xp.max(xp.sqrt(d2)))

    # Two-way pulse length correction (matches MATLAB: getpulse(param,2))
    if tx_n_wavelengths != float("inf"):
        tp = _two_way_pulse_duration(fc, params.bandwidth, tx_n_wavelengths, xp)
        max_d = max_d + tp * speed_of_sound

    # Round-trip frequency step (matches PyMUST simus df formula)
    df = 1.0 / 2.0 / (2.0 * max_d / speed_of_sound + float(xp.max(delays_clean)))
    df = float(frequency_step) * df

    # Full frequency grid
    n_freq_full = int(2 * ceil(fc / df) + 1)

    # Frequency selection using shared helper
    freq_plan = _select_frequencies(fc, params.bandwidth, tx_n_wavelengths, db_thresh, df, xp)
    df_actual = freq_plan.freq_step

    # Find start index of selected frequencies in full spectrum
    freq_idx_start = round(float(freq_plan.selected_freqs[0]) / df_actual) if df_actual > 0 else 0

    # Correction factor
    correction_factor = 1.0 if tx_n_wavelengths == float("inf") else df_actual
    correction_factor = correction_factor * params.element_width

    # IFFT length
    n_fft = ceil(fs / 2.0 / fc * (n_freq_full - 1))

    return SimusPlan(
        selected_freqs=freq_plan.selected_freqs,
        pulse_spectrum=freq_plan.pulse_spectrum,
        probe_spectrum=freq_plan.probe_spectrum,
        n_sub=n_sub,
        seg_length=seg_length,
        correction_factor=correction_factor,
        n_freq_full=n_freq_full,
        freq_idx_start=freq_idx_start,
        n_fft=n_fft,
    )
|
|
261
|
+
|
|
262
|
+
|
|
263
|
+
def _prepare_simus_sweep(
    scatterers: Float[Array, "*batch 2"],
    delays_clean: Float[Array, " n_elements"],
    tx_apodization: Float[Array, " n_elements"],
    plan: SimusPlan,
    params: TransducerParams,
    medium: MediumParams,
    *,
    full_frequency_directivity: bool,
    xp: _ArrayNamespace,
) -> dict:
    """Build the geometry and phase arrays consumed by the simus sweep strategies.

    Unlike pfield's ``_prepare_frequency_sweep``, the per-element structure
    (n_scat, n_elem, n_sub) is kept instead of being flattened to
    (n_scat, n_sources). The transmit delay/apodization terms are returned
    as their own geometric progression rather than folded into the
    propagation phase, because the TX/RX chain needs them separately.

    Returns:
        A dict with keys ``phase_init``, ``phase_step``, ``delay_apod_init``,
        ``delay_apod_step``, ``is_out``, ``wavenumbers``, ``pulse_spect``,
        ``probe_spect``, ``seg_length``, ``sin_theta`` and
        ``full_frequency_directivity``. The strategies receive it as keyword
        arguments.
    """
    c = medium.speed_of_sound
    alpha = medium.attenuation

    elem_pos, elem_theta, apex_offset = element_positions(params.n_elements, params.pitch, params.radius, xp)
    if elem_theta is None:
        elem_theta = xp.zeros(params.n_elements)

    sub_offsets = _subelement_centroids(params.element_width, plan.n_sub, elem_theta, xp)

    # Flag scatterers outside the insonified region: negative depth and, for a
    # convex array, anything inside the circle of the array's curvature.
    lateral = scatterers[..., 0]
    axial = scatterers[..., 1]
    outside = axial < 0
    if params.radius != inf:
        inside_curvature = (lateral**2 + (axial + apex_offset) ** 2) <= params.radius**2
        outside = outside | inside_curvature

    distances, sin_theta, theta = _distances_and_angles(
        scatterers, sub_offsets, elem_pos, elem_theta, c, params.freq_center, xp
    )
    obliquity = _obliquity_factor(theta, params.baffle, xp)

    freqs = plan.selected_freqs
    f_first = freqs[0]
    f_step = (freqs[1] - freqs[0]) if freqs.shape[0] > 1 else xp.asarray(0.0)

    phase_init, phase_step = _init_exponentials(f_first, c, alpha, distances, obliquity, f_step, xp)

    if not full_frequency_directivity:
        # Fold a center-frequency element directivity (sinc) into the initial phase.
        k_center = 2.0 * pi * params.freq_center / c
        directivity_arg = xp.asarray(k_center * plan.seg_length / 2.0) * sin_theta / pi
        phase_init = phase_init * xpx.sinc(directivity_arg, xp=xp)

    # Transmit delay + apodization kept as a separate geometric progression.
    two_pi_j = xp.asarray(1j * 2.0 * pi)
    delay_apod_init = xp.exp(two_pi_j * f_first * delays_clean) * tx_apodization
    delay_apod_step = xp.exp(two_pi_j * f_step * delays_clean)

    wavenumbers = xp.asarray(2.0 * pi) * freqs / c

    return dict(
        phase_init=phase_init,
        phase_step=phase_step,
        delay_apod_init=delay_apod_init,
        delay_apod_step=delay_apod_step,
        is_out=outside,
        wavenumbers=wavenumbers,
        pulse_spect=plan.pulse_spectrum,
        probe_spect=plan.probe_spectrum,
        seg_length=plan.seg_length,
        sin_theta=sin_theta,
        full_frequency_directivity=full_frequency_directivity,
    )
|
|
334
|
+
|
|
335
|
+
|
|
336
|
+
def _irfft_and_threshold(
    spect_selected: Complex[Array, "n_freq_sel n_elem"],
    plan: SimusPlan,
    n_elements: int,
    xp: _ArrayNamespace,
) -> tuple[Float[Array, "n_samples n_elem"], Complex[Array, "n_freq_full n_elem"]]:
    """Embed the band-limited spectrum in the full grid, IFFT it, and soft-gate tiny samples.

    Steps:
    1. The selected rows are written into a zero spectrum of shape
       ``(plan.n_freq_full, n_elements)`` starting at ``plan.freq_idx_start``.
    2. The conjugate spectrum goes through ``irfft`` with ``plan.n_fft``
       points, and the first ``(n_fft + 1) // 2`` samples are kept.
    3. Samples below about -100 dB of the global peak are attenuated with a
       smooth tanh gate.

    Returns:
        ``(rf, full_spectrum)``.

    Raises:
        RuntimeError: If ``xp`` has no ``fft`` attribute.
    """
    if not hasattr(xp, "fft"):
        msg = "simus requires an array backend with FFT support (e.g. numpy, jax, cupy)"
        raise RuntimeError(msg)
    fft_ns = cast(_ArrayNamespaceWithFFT, xp)

    row_start = plan.freq_idx_start
    row_stop = row_start + spect_selected.shape[0]
    empty = xp.zeros((plan.n_freq_full, n_elements), dtype=spect_selected.dtype)
    full_spectrum = xpx.at(empty)[row_start:row_stop, :].set(spect_selected)  # type: ignore[attr-defined]

    n_keep = (plan.n_fft + 1) // 2
    rf = fft_ns.fft.irfft(xp.conj(full_spectrum), n=plan.n_fft, axis=0)[:n_keep, ...]

    # Smooth (differentiable) gate centred on a relative level of 1e-5 (-100 dB);
    # 1e-30 guards against division by zero for an all-zero signal.
    rel_thresh = 1e-5
    magnitude = xp.abs(rf)
    rel_rf = magnitude / (xp.max(magnitude) + xp.asarray(1e-30))
    gate = 0.5 * (1.0 + xp.tanh((rel_rf - rel_thresh) / (rel_thresh / 10.0)))  # type: ignore[attr-defined]

    return rf * gate, full_spectrum
|
|
368
|
+
|
|
369
|
+
|
|
370
|
+
def _select_simus_strategy(xp: _ArrayNamespace, strategy: SimusStrategy | None) -> SimusStrategy:
    """Resolve the simus strategy, auto-selecting from the array backend when unset.

    Auto-selection order:
    - A JAX namespace selects SCAN.
    - ``mlx.core`` (if importable) selects METAL.
    - A CuPy namespace selects CUDA.
    - Anything else selects PYTHON.

    Args:
        xp: Array namespace of the inputs.
        strategy: Explicit strategy (enum member or its string value), or None.

    Returns:
        A ``SimusStrategy`` member.

    Raises:
        ValueError: If ``strategy`` is not a valid ``SimusStrategy`` value.
    """
    if strategy is not None:
        # Coerce through the enum so a plain string is validated. An unknown
        # value would otherwise match no branch in simus_compute and silently
        # fall through to the Python path.
        return SimusStrategy(strategy)

    if is_jax_namespace(cast(ModuleType, xp)):
        return SimusStrategy.SCAN

    try:
        import mlx.core

        if xp is mlx.core:
            return SimusStrategy.METAL
    except ImportError:
        # MLX not installed: the METAL strategy cannot apply.
        pass

    if is_cupy_namespace(xp):
        return SimusStrategy.CUDA

    return SimusStrategy.PYTHON
|
|
390
|
+
|
|
391
|
+
|
|
392
|
+
def simus_compute(
    scatterers: Float[Array, "*batch 2"],
    rc: Float[Array, " *batch"],
    delays: Float[Array, " n_elements"],
    plan: SimusPlan,
    params: TransducerParams,
    medium: MediumParams = _DEFAULT_MEDIUM,
    *,
    tx_apodization: Float[Array, " n_elements"] | None = None,
    full_frequency_directivity: bool = False,
    strategy: SimusStrategy | None = None,
) -> SimusResult:
    """Compute RF signals given a precomputed plan.

    Args:
        scatterers: Scatterer positions in meters. Shape ``(*batch, 2)``; any
            leading batch dimensions are flattened before the sweep.
        rc: Reflection coefficients. Shape ``(*batch,)``, one per scatterer.
        delays: Transmit time delays in seconds. Shape ``(n_elements,)``.
            Elements with a NaN delay get zero delay and zero apodization.
        plan: Precomputed plan from ``simus_precompute``.
        params: Transducer parameters.
        medium: Medium parameters.
        tx_apodization: Transmit apodization weights. Shape ``(n_elements,)``.
            Defaults to all ones.
        full_frequency_directivity: If True, compute element directivity at
            every frequency.
        strategy: Backend strategy for the frequency sweep. If None,
            auto-selects based on the detected array backend.

    Returns:
        SimusResult with RF signals and complex spectrum.

    Raises:
        ValueError: If ``scatterers`` does not have a trailing dimension of 2,
            or if ``rc`` does not hold one value per scatterer.
    """
    xp = array_namespace(scatterers, rc, delays)

    if tx_apodization is None:
        tx_apodization = xp.ones(params.n_elements)

    nan_mask = xp.isnan(delays)
    tx_apodization = xp.where(nan_mask, xp.asarray(0.0), tx_apodization)
    delays_clean = xp.where(nan_mask, xp.asarray(0.0), delays)

    # Flatten scatterers for the frequency sweep. Every documented batch shape
    # is supported: (*batch, 2) -> (n_scat, 2) and (*batch,) -> (n_scat,).
    # A lone scatterer of shape (2,) / rc of shape () becomes a batch of one.
    if scatterers.ndim == 0 or scatterers.shape[-1] != 2:
        msg = f"scatterers must have shape (*batch, 2), got {tuple(scatterers.shape)}"
        raise ValueError(msg)
    scatterers_flat = xp.reshape(scatterers, (-1, 2))
    rc_flat = xp.reshape(rc, (-1,))
    if rc_flat.shape[0] != scatterers_flat.shape[0]:
        msg = (
            f"rc must have one value per scatterer: got {rc_flat.shape[0]} "
            f"reflection coefficients for {scatterers_flat.shape[0]} scatterers"
        )
        raise ValueError(msg)

    selected = _select_simus_strategy(xp, strategy)

    # NOTE(review): full_frequency_directivity is not forwarded to the Metal or
    # CUDA kernels below; confirm their directivity model matches the request.
    if selected == SimusStrategy.METAL:
        import mlx.core as mx

        from fast_simus.kernels.metal_simus import simus_metal

        spect_selected = cast(
            Array,
            simus_metal(
                scatterers=cast(mx.array, scatterers_flat),
                rc=cast(mx.array, rc_flat),
                params=params,
                plan=plan,
                medium=medium,
                delays_clean=cast(mx.array, delays_clean),
                tx_apodization=cast(mx.array, tx_apodization),
            ),
        )
    elif selected == SimusStrategy.CUDA:
        from fast_simus.kernels.cuda_simus import simus_cuda

        spect_selected = cast(
            Array,
            simus_cuda(
                scatterers=scatterers_flat,
                rc=rc_flat,
                params=params,
                plan=plan,
                medium=medium,
                delays_clean=delays_clean,
                tx_apodization=tx_apodization,
            ),
        )
    else:
        sweep = _prepare_simus_sweep(
            scatterers_flat,
            delays_clean,
            tx_apodization,
            plan,
            params,
            medium,
            full_frequency_directivity=full_frequency_directivity,
            xp=xp,
        )
        if selected == SimusStrategy.SCAN:
            from fast_simus._simus_strategies import _simus_freq_outer_scan

            spect_selected = _simus_freq_outer_scan(rc=rc_flat, xp=xp, **sweep)
        else:
            from fast_simus._simus_strategies import _simus_freq_outer_python

            spect_selected = _simus_freq_outer_python(rc=rc_flat, xp=xp, **sweep)

    # Apply correction factor
    spect_selected = spect_selected * xp.asarray(plan.correction_factor)

    rf, full_spectrum = _irfft_and_threshold(spect_selected, plan, params.n_elements, xp)

    return SimusResult(rf=rf, spectrum=full_spectrum)
|
|
496
|
+
|
|
497
|
+
|
|
498
|
+
def simus(
    scatterers: Float[Array, "*batch 2"],
    rc: Float[Array, " *batch"],
    delays: Float[Array, " n_elements"],
    params: TransducerParams,
    medium: MediumParams = _DEFAULT_MEDIUM,
    *,
    fs: float | None = None,
    tx_apodization: Float[Array, " n_elements"] | None = None,
    tx_n_wavelengths: float | int = 1.0,
    db_thresh: float | int = -60.0,
    full_frequency_directivity: bool = False,
    element_splitting: int | None = None,
    frequency_step: float | int = 1.0,
    strategy: SimusStrategy | None = None,
) -> SimusResult:
    """Simulate ultrasound RF signals for a linear or convex array.

    One-shot convenience wrapper: builds a ``SimusPlan`` with
    ``simus_precompute`` and immediately evaluates it with ``simus_compute``.
    The underlying SIMUS algorithm performs TX forward propagation,
    scattering by the reflection coefficients, RX back-propagation (acoustic
    reciprocity), and an IFFT to the time domain.

    Args:
        scatterers: Scatterer positions in meters. Shape ``(*batch, 2)`` where
            ``[..., 0]`` is lateral (x) and ``[..., 1]`` is axial (z).
        rc: Reflection coefficients. Shape ``(*batch,)``. Same size as scatterers
            (excluding last dimension).
        delays: Transmit time delays in seconds. Shape ``(n_elements,)``.
        params: Transducer parameters (geometry, frequency, bandwidth).
        medium: Medium parameters (speed of sound, attenuation).
        fs: Sampling frequency in Hz. Defaults to ``4 * params.freq_center``.
        tx_apodization: Transmit apodization weights. Shape ``(n_elements,)``.
        tx_n_wavelengths: Number of wavelengths in the TX pulse.
        db_thresh: Threshold in dB for frequency component selection.
        full_frequency_directivity: If True, compute element directivity at
            every frequency. If False, use center-frequency-only directivity.
        element_splitting: Number of sub-elements per element (None = auto).
        frequency_step: Scaling factor for the frequency step.
        strategy: Backend strategy for the frequency sweep. If None,
            auto-selects based on the detected array backend.

    Returns:
        SimusResult with:
            - rf: Time-domain RF signals, shape (n_samples, n_elements)
            - spectrum: Complex RF spectrum, shape (n_freq_full, n_elements)
    """
    precomputed = simus_precompute(
        scatterers=scatterers,
        rc=rc,
        delays=delays,
        params=params,
        medium=medium,
        fs=fs,
        tx_n_wavelengths=tx_n_wavelengths,
        db_thresh=db_thresh,
        element_splitting=element_splitting,
        frequency_step=frequency_step,
    )
    result = simus_compute(
        scatterers=scatterers,
        rc=rc,
        delays=delays,
        plan=precomputed,
        params=params,
        medium=medium,
        tx_apodization=tx_apodization,
        full_frequency_directivity=full_frequency_directivity,
        strategy=strategy,
    )
    return result
|
fast_simus/spectrum.py
ADDED
|
@@ -0,0 +1,107 @@
|
|
|
1
|
+
"""Spectrum computation for ultrasound pulse and probe response.
|
|
2
|
+
|
|
3
|
+
Provides functions to compute frequency-domain representations of the
|
|
4
|
+
transmitted pulse and probe frequency response.
|
|
5
|
+
|
|
6
|
+
All functions are Array API compliant and work with NumPy, JAX, CuPy backends.
|
|
7
|
+
|
|
8
|
+
References:
|
|
9
|
+
Garcia D. SIMUS: an open-source simulator for medical ultrasound imaging.
|
|
10
|
+
Part I: theory & examples. CMPB, 2022;218:106726.
|
|
11
|
+
"""
|
|
12
|
+
|
|
13
|
+
from __future__ import annotations
|
|
14
|
+
|
|
15
|
+
from math import log, pi
|
|
16
|
+
|
|
17
|
+
import array_api_extra as xpx
|
|
18
|
+
from beartype import beartype as typechecker
|
|
19
|
+
from jaxtyping import Complex, Float, jaxtyped
|
|
20
|
+
|
|
21
|
+
from fast_simus.utils._array_api import Array, array_namespace
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
@jaxtyped(typechecker=typechecker)
def pulse_spectrum(
    angular_freq: Float[Array, " n_freqs"],
    freq_center: float | int,
    tx_n_wavelengths: float | int = 1.0,
) -> Complex[Array, " n_freqs"]:
    """Compute the pulse spectrum for a windowed sine pulse.

    Computes the frequency-domain representation of a windowed sine pulse
    with the given center frequency and number of wavelengths. The pulse
    duration is ``T = tx_n_wavelengths / freq_center``. The spectrum is
    ``1j * (sinc(T*(w - wc)/(2*pi)) - sinc(T*(w + wc)/(2*pi)))`` with the
    normalized sinc and ``wc = 2*pi*freq_center``.

    Args:
        angular_freq: Angular frequency in rad/s. Shape (n_freqs,).
        freq_center: Center frequency in Hz. Must be positive.
        tx_n_wavelengths: Number of wavelengths of the TX pulse.

    Returns:
        Complex spectrum at the given angular frequencies. Shape (n_freqs,).

    Raises:
        ValueError: If ``freq_center`` is not positive.
    """
    # Enforce the documented precondition. Zero would otherwise raise a bare
    # ZeroDivisionError, and a negative value would silently yield a
    # sign-flipped spectrum.
    if not (freq_center > 0):
        msg = f"freq_center must be positive, got {freq_center!r}"
        raise ValueError(msg)

    pulse_duration_s = tx_n_wavelengths / freq_center
    angular_freq_center = 2.0 * pi * freq_center

    xp = array_namespace(angular_freq)
    # xpx.sinc is the normalized sinc sin(pi*x)/(pi*x), hence the extra / pi.
    sinc_arg_lower = pulse_duration_s * (angular_freq - angular_freq_center) / 2.0 / pi
    sinc_arg_upper = pulse_duration_s * (angular_freq + angular_freq_center) / 2.0 / pi
    # array-api-extra does not have type interoperability
    return 1j * (xpx.sinc(sinc_arg_lower, xp=xp) - xpx.sinc(sinc_arg_upper, xp=xp))
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
@jaxtyped(typechecker=typechecker)
def probe_spectrum(
    angular_freq: Float[Array, " n_freqs"],
    freq_center: float | int,
    bandwidth: float | int = 0.75,
) -> Float[Array, " n_freqs"]:
    """Compute the probe frequency response.

    Computes the one-way probe frequency response using a generalized normal
    window. The bandwidth parameter defines the pulse-echo 6 dB fractional
    bandwidth.

    The returned spectrum is the square root of the pulse-echo response,
    appropriate for one-way (transmit-only or receive-only) use.

    Args:
        angular_freq: Angular frequency in rad/s. Shape (n_freqs,).
        freq_center: Center frequency in Hz. Must be positive.
        bandwidth: Fractional bandwidth (0.75 = 75%). Must be in (0, 2.0).

    Returns:
        Real-valued probe response (one-way) at the given angular frequencies.
        Shape (n_freqs,).

    Raises:
        ValueError: If ``bandwidth`` is outside (0, 2.0) or ``freq_center``
            is not positive.

    References:
        Generalized normal window:
        https://en.wikipedia.org/wiki/Window_function#Generalized_normal_window

        Reference implementation (log(126) constant verified against):
        https://github.com/creatis-ULTIM/PyMUST/blob/df02b42/src/pymust/utils.py#L141
    """
    # Validate bandwidth (checked first to keep the existing error precedence).
    if not (0.0 < bandwidth < 2.0):
        msg = f"bandwidth must be in (0, 2.0), got {bandwidth!r}"
        raise ValueError(msg)
    # Validate the documented freq_center precondition. Zero would otherwise
    # end in a bare ZeroDivisionError inside log().
    if not (freq_center > 0):
        msg = f"freq_center must be positive, got {freq_center!r}"
        raise ValueError(msg)

    angular_freq_center = 2.0 * pi * freq_center
    # Convert fractional bandwidth to angular bandwidth
    angular_bandwidth = bandwidth * angular_freq_center
    # Pulse-echo response: exp(-(|w - wc| / sigma) ** p), a generalized normal window.
    # - sigma (below) makes the pulse-echo response exactly exp(-log 2) = 1/2
    #   (-6 dB) at |w - wc| = angular_bandwidth / 2, i.e. at the band edges.
    #   The one-way response is therefore -3 dB there.
    # - The constant 126 fixes the DC value. At |w - wc| = wc the exponent
    #   evaluates to (2 * wc / wB) ** p * log 2 = 126 * log 2, so the
    #   pulse-echo response is 2**-126 (the smallest normal float32 value),
    #   i.e. effectively zero at 0 Hz.
    # bandwidth < 2 guarantees 2 * wc / wB > 1, so p is finite and positive.
    shape_param = log(126) / log(2.0 * angular_freq_center / angular_bandwidth)
    # Denominator of the exponent
    sigma = angular_bandwidth / 2.0 / (log(2) ** (1.0 / shape_param))

    xp = array_namespace(angular_freq)
    # Pulse-echo (squared) response
    spectrum_squared = xp.exp(-((xp.abs(angular_freq - angular_freq_center) / sigma) ** shape_param))
    # One-way response (square root)
    return xp.sqrt(spectrum_squared)
|