FastSIMUS 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
fast_simus/simus.py ADDED
@@ -0,0 +1,567 @@
1
+ """Ultrasound RF signal simulation for linear and convex arrays.
2
+
3
+ Implements the SIMUS algorithm: for each frequency in the transmit bandwidth,
4
+ compute forward TX pressure at scatterers, scatter by reflection coefficients,
5
+ back-propagate to receive elements (acoustic reciprocity), accumulate complex
6
+ RF spectrum, then IFFT to time-domain RF signals.
7
+
8
+ All functions are Array API compliant and work with NumPy, JAX, CuPy backends.
9
+
10
+ References:
11
+ Garcia D. SIMUS: an open-source simulator for medical ultrasound imaging.
12
+ Part I: theory & examples. CMPB, 2022;218:106726.
13
+ """
14
+
15
+ from __future__ import annotations
16
+
17
+ from enum import StrEnum
18
+ from math import ceil, inf, log2, pi
19
+ from types import ModuleType
20
+ from typing import NamedTuple, cast
21
+
22
+ import array_api_extra as xpx
23
+ from array_api_compat import is_jax_namespace
24
+ from jaxtyping import Complex, Float
25
+
26
+ from fast_simus._pfield_math import (
27
+ _distances_and_angles,
28
+ _init_exponentials,
29
+ _obliquity_factor,
30
+ _select_frequencies,
31
+ _subelement_centroids,
32
+ )
33
+ from fast_simus.medium_params import MediumParams
34
+ from fast_simus.spectrum import probe_spectrum as _probe_spectrum_fn
35
+ from fast_simus.spectrum import pulse_spectrum as _pulse_spectrum_fn
36
+ from fast_simus.transducer_params import TransducerParams
37
+ from fast_simus.utils._array_api import (
38
+ Array,
39
+ _ArrayNamespace,
40
+ _ArrayNamespaceWithFFT,
41
+ array_namespace,
42
+ is_cupy_namespace,
43
+ )
44
+ from fast_simus.utils.geometry import element_positions
45
+
46
# Default value for the ``medium`` parameter of the simus entry points.
# NOTE(review): one instance is shared by every call that omits ``medium``; this
# module only reads attributes from it, which is safe as long as MediumParams is
# immutable (e.g. a frozen dataclass) — TODO confirm.
_DEFAULT_MEDIUM: MediumParams = MediumParams()
47
+
48
+
49
def _two_way_pulse_duration(
    freq_center: float,
    bandwidth: float,
    tx_n_wavelengths: float,
    xp: _ArrayNamespace,
) -> float:
    """Compute the temporal extent of the two-way (pulse-echo) pulse.

    Replicates the pulse duration computation from PyMUST's ``getpulse(param, 2)``:
    the two-way spectrum ``pulse_spectrum * probe_spectrum**2`` is sampled on a
    1 ns time grid, inverse-FFT'd, centered with ``fftshift``, normalized to a
    unit peak magnitude, and trimmed symmetrically to the samples above 1/1023.

    Args:
        freq_center: Center frequency in Hz.
        bandwidth: Fractional bandwidth (0.75 = 75%).
        tx_n_wavelengths: Number of wavelengths of the TX pulse. Must be finite:
            the spectral sampling step is ``freq_center / tx_n_wavelengths / 32``.
        xp: Array namespace (must have FFT extension).

    Returns:
        Pulse duration in seconds. If no sample exceeds the threshold, falls
        back to ``tx_n_wavelengths / freq_center``.

    Raises:
        RuntimeError: If ``xp`` has no ``fft`` extension.
    """
    # hasattr instead of isinstance(_ArrayNamespaceWithFFT) because Python 3.12+
    # Protocol isinstance uses getattr_static, which misses lazy sub-module attrs
    # like numpy.fft. See https://docs.python.org/3/whatsnew/3.12.html#typing
    if not hasattr(xp, "fft"):
        msg = "simus requires an array backend with FFT support (e.g. numpy, jax, cupy)"
        raise RuntimeError(msg)
    xp_fft = cast(_ArrayNamespaceWithFFT, xp)

    dt = 1e-9  # time resolution of the synthesized pulse (1 ns)
    df = freq_center / tx_n_wavelengths / 32
    p = ceil(log2(1.0 / dt / 2.0 / df))
    n_fft = 2**p
    omega = 2.0 * pi * xp.linspace(0, 1.0 / dt / 2.0, n_fft)

    # Two-way spectrum: pulse * probe^2
    ps = _pulse_spectrum_fn(omega, freq_center, tx_n_wavelengths)
    pr = _probe_spectrum_fn(omega, freq_center, bandwidth)
    two_way = ps * pr**2

    pulse = xp_fft.fft.fftshift(xp_fft.fft.irfft(two_way))
    pulse = pulse / xp.max(xp.abs(pulse))

    # First / last sample above threshold, computed without boolean indexing
    # (keeps the code usable on backends that dislike data-dependent shapes).
    above = pulse > (1.0 / 1023)
    n = above.shape[0]  # irfft output length, 2 * (n_fft - 1)
    indices = xp.arange(n)
    idx1 = int(xp.min(xp.where(above, indices, xp.asarray(n))))
    idx2 = int(xp.max(xp.where(above, indices, xp.asarray(-1))))

    if idx1 >= n:
        return tx_n_wavelengths / freq_center

    # Smallest 1-based distance of a significant sample from either edge. The
    # kept window is the symmetric span of 0-based indices [margin-1, n-margin].
    # Computing its length directly avoids the Python slice
    # pulse[-margin : margin - 2 : -1], which is EMPTY when margin == 1 (stop
    # index -1 wraps to the end) and used to report a duration of 0.
    margin = min(idx1 + 1, n - idx2)
    n_kept = max(n - 2 * margin + 2, 0)
    return float(n_kept * dt)
105
+
106
+
107
+ class SimusStrategy(StrEnum):
108
+ """Backend strategy for the simus frequency sweep.
109
+
110
+ Attributes:
111
+ PYTHON: Python for-loop (NumPy/CuPy, constant memory).
112
+ SCAN: JAX lax.scan for O(1) compilation cost.
113
+ METAL: Custom Metal kernel on Apple Silicon (MLX).
114
+ CUDA: Custom CUDA kernel on NVIDIA GPUs (CuPy + NVRTC).
115
+ """
116
+
117
+ PYTHON = "python"
118
+ SCAN = "scan"
119
+ METAL = "metal"
120
+ CUDA = "cuda"
121
+
122
+
123
class SimusResult(NamedTuple):
    """Output of a simus simulation.

    Attributes:
        rf: Real time-domain RF signals, one column per element;
            shape ``(n_samples, n_elements)``.
        spectrum: Complex RF spectrum on the full frequency grid, one column
            per element; shape ``(n_freq_full, n_elements)``.
    """

    # Time-domain signals (first field, so ``rf, spectrum = result`` works).
    rf: Float[Array, "n_samples n_elements"]
    # Frequency-domain counterpart of ``rf``.
    spectrum: Complex[Array, "n_freq_full n_elements"]
133
+
134
+
135
class SimusPlan(NamedTuple):
    """Data-dependent quantities resolved ahead of ``simus_compute``.

    Everything here is either a Python scalar or an array whose shape is fixed
    by the plan, so ``simus_compute`` runs with static array shapes.

    Attributes:
        selected_freqs: Frequencies (Hz) retained by the dB threshold.
        pulse_spectrum: Transmit pulse spectrum evaluated at ``selected_freqs``.
        probe_spectrum: Probe frequency response evaluated at ``selected_freqs``.
        n_sub: Sub-elements per transducer element.
        seg_length: Length of one sub-element, in meters.
        correction_factor: Integration scale factor, ``df * element_width``
            (or just ``element_width`` when ``tx_n_wavelengths`` is infinite).
        n_freq_full: Number of bins of the full grid spanning 0 to 2*fc.
        freq_idx_start: Position of ``selected_freqs[0]`` in that full grid.
        n_fft: IFFT length, derived from fs, fc and ``n_freq_full``.
    """

    # Arrays restricted to the selected band.
    selected_freqs: Float[Array, " n_frequencies"]
    pulse_spectrum: Complex[Array, " n_frequencies"]
    probe_spectrum: Float[Array, " n_frequencies"]
    # Element discretization.
    n_sub: int
    seg_length: float
    # Scalars for scaling and placing the spectrum before the IFFT.
    correction_factor: float
    n_freq_full: int
    freq_idx_start: int
    n_fft: int
163
+
164
+
165
def simus_precompute(
    scatterers: Float[Array, "*batch 2"],
    rc: Float[Array, " *batch"],
    delays: Float[Array, " n_elements"],
    params: TransducerParams,
    medium: MediumParams = _DEFAULT_MEDIUM,
    *,
    fs: float | None = None,
    tx_n_wavelengths: float | int = 1.0,
    db_thresh: float | int = -60.0,
    element_splitting: int | None = None,
    frequency_step: float | int = 1.0,
) -> SimusPlan:
    """Precompute static quantities for simus computation.

    Resolves everything that depends on data values (scatterer extent, maximum
    transmit delay, selected frequency band) into Python scalars and
    fixed-shape arrays, so that ``simus_compute`` sees static shapes.

    Args:
        scatterers: Scatterer positions in meters. Shape ``(*batch, 2)``.
        rc: Reflection coefficients. Shape ``(*batch,)``. Not used here; accepted
            so the signature mirrors ``simus_compute``.
        delays: Transmit time delays in seconds. Shape ``(n_elements,)``. NaN
            entries are treated as zero delay.
        params: Transducer parameters.
        medium: Medium parameters.
        fs: Sampling frequency in Hz. Defaults to 4 * fc.
        tx_n_wavelengths: Number of wavelengths in the TX pulse.
        db_thresh: Threshold in dB for frequency component selection.
        element_splitting: Number of sub-elements per element. ``None`` picks
            ``ceil(element_width / lambda_min)``, where ``lambda_min`` is the
            wavelength at ``fc * (1 + bandwidth / 2)``.
        frequency_step: Scaling factor for the frequency step.

    Returns:
        SimusPlan with static-shaped arrays and precomputed scalars.

    Raises:
        ValueError: If ``element_splitting`` is given and is less than 1.
    """
    xp = array_namespace(scatterers, delays)
    speed_of_sound = medium.speed_of_sound
    fc = params.freq_center

    if fs is None:
        fs = 4.0 * fc

    # NaN delays mark inactive elements; they contribute zero delay here.
    delays_clean = xp.where(xp.isnan(delays), xp.asarray(0.0), delays)

    # Element splitting
    if element_splitting is not None:
        # 0 would divide by zero below; negatives give a negative seg_length.
        if element_splitting < 1:
            msg = f"element_splitting must be a positive integer, got {element_splitting!r}"
            raise ValueError(msg)
        n_sub = element_splitting
    else:
        lambda_min = speed_of_sound / (fc * (1.0 + params.bandwidth / 2.0))
        n_sub = ceil(params.element_width / lambda_min)

    seg_length = params.element_width / n_sub

    # Max distance for frequency step (use element centers, matching PyMUST simus)
    element_pos, _, _ = element_positions(params.n_elements, params.pitch, params.radius, xp)

    x = scatterers[..., 0]
    z = scatterers[..., 1]
    # (n_scat, n_elements) squared distances; reshape(-1, 1) flattens any batch.
    d2 = (xp.reshape(x, (-1, 1)) - element_pos[:, 0]) ** 2 + (xp.reshape(z, (-1, 1)) - element_pos[:, 1]) ** 2
    max_d = float(xp.max(xp.sqrt(d2)))

    # Two-way pulse length correction (matches MATLAB: getpulse(param,2)).
    # Skipped for a continuous wave, whose pulse duration is unbounded.
    if tx_n_wavelengths != inf:
        tp = _two_way_pulse_duration(fc, params.bandwidth, tx_n_wavelengths, xp)
        max_d = max_d + tp * speed_of_sound

    # Round-trip frequency step (matches PyMUST simus df formula)
    df = 1.0 / 2.0 / (2.0 * max_d / speed_of_sound + float(xp.max(delays_clean)))
    df = float(frequency_step) * df

    # Full frequency grid from 0 to 2*fc
    n_freq_full = int(2 * ceil(fc / df) + 1)

    # Frequency selection using shared helper
    freq_plan = _select_frequencies(fc, params.bandwidth, tx_n_wavelengths, db_thresh, df, xp)
    df_actual = freq_plan.freq_step

    # Find start index of selected frequencies in full spectrum
    freq_idx_start = round(float(freq_plan.selected_freqs[0]) / df_actual) if df_actual > 0 else 0

    # Correction factor (no df factor for a continuous wave)
    correction_factor = 1.0 if tx_n_wavelengths == inf else df_actual
    correction_factor = correction_factor * params.element_width

    # IFFT length
    n_fft = ceil(fs / 2.0 / fc * (n_freq_full - 1))

    return SimusPlan(
        selected_freqs=freq_plan.selected_freqs,
        pulse_spectrum=freq_plan.pulse_spectrum,
        probe_spectrum=freq_plan.probe_spectrum,
        n_sub=n_sub,
        seg_length=seg_length,
        correction_factor=correction_factor,
        n_freq_full=n_freq_full,
        freq_idx_start=freq_idx_start,
        n_fft=n_fft,
    )
261
+
262
+
263
def _prepare_simus_sweep(
    scatterers: Float[Array, "*batch 2"],
    delays_clean: Float[Array, " n_elements"],
    tx_apodization: Float[Array, " n_elements"],
    plan: SimusPlan,
    params: TransducerParams,
    medium: MediumParams,
    *,
    full_frequency_directivity: bool,
    xp: _ArrayNamespace,
) -> dict:
    """Build the geometry and phase arrays consumed by the frequency sweep.

    Per-element structure (n_scat, n_elem, n_sub) is preserved rather than
    flattened into (n_scat, n_sources) as pfield's _prepare_frequency_sweep does.
    The transmit delays and apodization get their own geometric progression
    (``delay_apod_init`` * ``delay_apod_step``**k) instead of being folded into
    the propagation phase, so the TX and RX chains can use them separately.

    Returns:
        Dict of sweep inputs keyed by the names the sweep strategies take as
        keyword arguments: phase_init, phase_step, delay_apod_init,
        delay_apod_step, is_out, wavenumbers, pulse_spect, probe_spect,
        seg_length, sin_theta, full_frequency_directivity.
    """
    elem_xz, elem_theta, apex_offset = element_positions(params.n_elements, params.pitch, params.radius, xp)
    if elem_theta is None:
        elem_theta = xp.zeros(params.n_elements)

    c = medium.speed_of_sound
    sub_offsets = _subelement_centroids(params.element_width, plan.n_sub, elem_theta, xp)

    # Scatterers behind the probe face (z < 0) are excluded; for a convex array,
    # so are those inside the circle of radius params.radius centred at
    # (0, -apex_offset).
    lateral = scatterers[..., 0]
    axial = scatterers[..., 1]
    outside = axial < 0
    if params.radius != inf:
        inside_arc = (lateral**2 + (axial + apex_offset) ** 2) <= params.radius**2
        outside = outside | inside_arc

    distances, sin_theta, theta = _distances_and_angles(
        scatterers, sub_offsets, elem_xz, elem_theta, c, params.freq_center, xp
    )
    obliquity = _obliquity_factor(theta, params.baffle, xp)

    freqs = plan.selected_freqs
    f0 = freqs[0]
    # A single selected frequency means there is no step to take.
    step = xp.asarray(0.0) if freqs.shape[0] <= 1 else freqs[1] - freqs[0]

    phase_init, phase_step = _init_exponentials(f0, c, medium.attenuation, distances, obliquity, step, xp)

    if not full_frequency_directivity:
        # Directivity evaluated once, at the center frequency.
        k_center = 2.0 * pi * params.freq_center / c
        directivity_arg = xp.asarray(k_center * plan.seg_length / 2.0) * sin_theta / pi
        phase_init = phase_init * xpx.sinc(directivity_arg, xp=xp)

    two_pi_j = xp.asarray(1j * 2.0 * pi)
    apod_start = xp.exp(two_pi_j * f0 * delays_clean) * tx_apodization
    apod_increment = xp.exp(two_pi_j * step * delays_clean)

    k = xp.asarray(2.0 * pi) * freqs / c

    sweep = {}
    sweep["phase_init"] = phase_init
    sweep["phase_step"] = phase_step
    sweep["delay_apod_init"] = apod_start
    sweep["delay_apod_step"] = apod_increment
    sweep["is_out"] = outside
    sweep["wavenumbers"] = k
    sweep["pulse_spect"] = plan.pulse_spectrum
    sweep["probe_spect"] = plan.probe_spectrum
    sweep["seg_length"] = plan.seg_length
    sweep["sin_theta"] = sin_theta
    sweep["full_frequency_directivity"] = full_frequency_directivity
    return sweep
334
+
335
+
336
def _irfft_and_threshold(
    spect_selected: Complex[Array, "n_freq_sel n_elem"],
    plan: SimusPlan,
    n_elements: int,
    xp: _ArrayNamespace,
) -> tuple[Float[Array, "n_samples n_elem"], Complex[Array, "n_freq_full n_elem"]]:
    """Embed the selected band in the full spectrum and convert it to RF.

    1. Zero-fill an ``(n_freq_full, n_elements)`` spectrum and write
       ``spect_selected`` into rows starting at ``plan.freq_idx_start``.
    2. Inverse real FFT of the conjugated spectrum along axis 0 with
       ``n = plan.n_fft``; keep the first ``(n_fft + 1) // 2`` samples.
    3. Softly zero samples below about -100 dB (1e-5) of the peak magnitude
       using a tanh gate.

    Returns:
        ``(rf, full_spectrum)``.

    Raises:
        RuntimeError: If ``xp`` has no ``fft`` extension.
    """
    if not hasattr(xp, "fft"):
        msg = "simus requires an array backend with FFT support (e.g. numpy, jax, cupy)"
        raise RuntimeError(msg)
    fft_ns = cast(_ArrayNamespaceWithFFT, xp)

    row_start = plan.freq_idx_start
    row_stop = row_start + spect_selected.shape[0]
    empty = xp.zeros((plan.n_freq_full, n_elements), dtype=spect_selected.dtype)
    full_spectrum = xpx.at(empty)[row_start:row_stop, :].set(spect_selected)  # type: ignore[attr-defined]

    signal = fft_ns.fft.irfft(xp.conj(full_spectrum), n=plan.n_fft, axis=0)
    rf = signal[: (plan.n_fft + 1) // 2, ...]

    # Smooth gate: ~0 well below rel_thresh, ~1 well above, transition width
    # rel_thresh / 10. The tiny epsilon guards an all-zero signal.
    rel_thresh = 1e-5
    magnitude = xp.abs(rf)
    relative = magnitude / (xp.max(magnitude) + xp.asarray(1e-30))
    gate = 0.5 * (1.0 + xp.tanh((relative - rel_thresh) / (rel_thresh / 10.0)))  # type: ignore[attr-defined]

    return rf * gate, full_spectrum
368
+
369
+
370
def _select_simus_strategy(xp: _ArrayNamespace, strategy: SimusStrategy | None) -> SimusStrategy:
    """Resolve which SimusStrategy to run.

    An explicit ``strategy`` always wins. Otherwise the backend decides:
    JAX -> SCAN, ``mlx.core`` (only when MLX is importable) -> METAL,
    CuPy -> CUDA, anything else -> PYTHON.
    """
    if strategy is not None:
        return strategy

    if is_jax_namespace(cast(ModuleType, xp)):
        return SimusStrategy.SCAN

    # MLX is optional; a missing install simply means "not MLX".
    try:
        import mlx.core as mlx_core
    except ImportError:
        mlx_core = None
    if mlx_core is not None and xp is mlx_core:
        return SimusStrategy.METAL

    return SimusStrategy.CUDA if is_cupy_namespace(xp) else SimusStrategy.PYTHON
390
+
391
+
392
def simus_compute(
    scatterers: Float[Array, "*batch 2"],
    rc: Float[Array, " *batch"],
    delays: Float[Array, " n_elements"],
    plan: SimusPlan,
    params: TransducerParams,
    medium: MediumParams = _DEFAULT_MEDIUM,
    *,
    tx_apodization: Float[Array, " n_elements"] | None = None,
    full_frequency_directivity: bool = False,
    strategy: SimusStrategy | None = None,
) -> SimusResult:
    """Compute RF signals given a precomputed plan.

    All leading batch dimensions of ``scatterers`` (and of ``rc`` when it has
    more than one dimension) are flattened into a single scatterer axis before
    the frequency sweep. For example, a ``(H, W, 2)`` scatterer grid with
    ``(H, W)`` reflection coefficients becomes ``(H*W, 2)`` / ``(H*W,)``.

    Args:
        scatterers: Scatterer positions in meters. Shape ``(*batch, 2)``.
        rc: Reflection coefficients. Shape ``(*batch,)``.
        delays: Transmit time delays in seconds. Shape ``(n_elements,)``.
            Elements with a NaN delay are treated as inactive.
        plan: Precomputed plan from ``simus_precompute``.
        params: Transducer parameters.
        medium: Medium parameters.
        tx_apodization: Transmit apodization weights. Shape ``(n_elements,)``.
            Defaults to all ones; forced to 0 where ``delays`` is NaN.
        full_frequency_directivity: If True, compute element directivity at
            every frequency.
        strategy: Backend strategy for the frequency sweep. If None,
            auto-selects based on the detected array backend.

    Returns:
        SimusResult with RF signals and complex spectrum.

    Raises:
        ValueError: If the last axis of ``scatterers`` does not have length 2.
    """
    from math import prod

    xp = array_namespace(scatterers, rc, delays)

    if scatterers.ndim < 1 or scatterers.shape[-1] != 2:
        msg = f"scatterers must have shape (*batch, 2), got {tuple(scatterers.shape)}"
        raise ValueError(msg)

    if tx_apodization is None:
        tx_apodization = xp.ones(params.n_elements)

    # NaN delays mark inactive elements: silence them and zero their delay.
    nan_mask = xp.isnan(delays)
    tx_apodization = xp.where(nan_mask, xp.asarray(0.0), tx_apodization)
    delays_clean = xp.where(nan_mask, xp.asarray(0.0), delays)

    # Flatten scatterers for the frequency sweep. The scatterer count is the
    # product of ALL leading dims; using only shape[0] made reshape fail for
    # any batch with more than one leading dimension.
    n_scat = prod(scatterers.shape[:-1])
    scatterers_flat = xp.reshape(scatterers, (n_scat, 2)) if scatterers.ndim > 2 else scatterers
    rc_flat = xp.reshape(rc, (n_scat,)) if rc.ndim > 1 else rc

    selected = _select_simus_strategy(xp, strategy)

    if selected == SimusStrategy.METAL:
        import mlx.core as mx

        from fast_simus.kernels.metal_simus import simus_metal

        spect_selected = cast(
            Array,
            simus_metal(
                scatterers=cast(mx.array, scatterers_flat),
                rc=cast(mx.array, rc_flat),
                params=params,
                plan=plan,
                medium=medium,
                delays_clean=cast(mx.array, delays_clean),
                tx_apodization=cast(mx.array, tx_apodization),
            ),
        )
    elif selected == SimusStrategy.CUDA:
        from fast_simus.kernels.cuda_simus import simus_cuda

        spect_selected = cast(
            Array,
            simus_cuda(
                scatterers=scatterers_flat,
                rc=rc_flat,
                params=params,
                plan=plan,
                medium=medium,
                delays_clean=delays_clean,
                tx_apodization=tx_apodization,
            ),
        )
    else:
        sweep = _prepare_simus_sweep(
            scatterers_flat,
            delays_clean,
            tx_apodization,
            plan,
            params,
            medium,
            full_frequency_directivity=full_frequency_directivity,
            xp=xp,
        )
        if selected == SimusStrategy.SCAN:
            from fast_simus._simus_strategies import _simus_freq_outer_scan

            spect_selected = _simus_freq_outer_scan(rc=rc_flat, xp=xp, **sweep)
        else:
            from fast_simus._simus_strategies import _simus_freq_outer_python

            spect_selected = _simus_freq_outer_python(rc=rc_flat, xp=xp, **sweep)

    # Apply correction factor
    spect_selected = spect_selected * xp.asarray(plan.correction_factor)

    rf, full_spectrum = _irfft_and_threshold(spect_selected, plan, params.n_elements, xp)

    return SimusResult(rf=rf, spectrum=full_spectrum)
496
+
497
+
498
def simus(
    scatterers: Float[Array, "*batch 2"],
    rc: Float[Array, " *batch"],
    delays: Float[Array, " n_elements"],
    params: TransducerParams,
    medium: MediumParams = _DEFAULT_MEDIUM,
    *,
    fs: float | None = None,
    tx_apodization: Float[Array, " n_elements"] | None = None,
    tx_n_wavelengths: float | int = 1.0,
    db_thresh: float | int = -60.0,
    full_frequency_directivity: bool = False,
    element_splitting: int | None = None,
    frequency_step: float | int = 1.0,
    strategy: SimusStrategy | None = None,
) -> SimusResult:
    """Simulate ultrasound RF signals for a linear or convex array.

    One-shot convenience wrapper: builds a plan with ``simus_precompute`` and
    immediately runs ``simus_compute`` on it. The underlying SIMUS algorithm
    propagates the transmit field to the scatterers, scatters it, propagates
    it back to the receive elements (acoustic reciprocity), and IFFTs the
    accumulated spectrum to the time domain.

    Args:
        scatterers: Scatterer positions in meters, shape ``(*batch, 2)``;
            ``[..., 0]`` is lateral (x) and ``[..., 1]`` is axial (z).
        rc: Reflection coefficients, shape ``(*batch,)`` (same size as
            ``scatterers`` without its last dimension).
        delays: Transmit time delays in seconds, shape ``(n_elements,)``.
        params: Transducer parameters (geometry, frequency, bandwidth).
        medium: Medium parameters (speed of sound, attenuation).
        fs: Sampling frequency in Hz; defaults to ``4 * params.freq_center``.
        tx_apodization: Transmit apodization weights, shape ``(n_elements,)``.
        tx_n_wavelengths: Number of wavelengths in the TX pulse.
        db_thresh: Threshold in dB for frequency component selection.
        full_frequency_directivity: If True, element directivity is evaluated
            at every frequency; otherwise only at the center frequency.
        element_splitting: Number of sub-elements per element (None = auto).
        frequency_step: Scaling factor for the frequency step.
        strategy: Backend strategy for the frequency sweep; None auto-selects
            from the detected array backend.

    Returns:
        SimusResult with ``rf`` of shape (n_samples, n_elements) and
        ``spectrum`` of shape (n_freq_full, n_elements).
    """
    plan = simus_precompute(
        scatterers=scatterers,
        rc=rc,
        delays=delays,
        params=params,
        medium=medium,
        fs=fs,
        tx_n_wavelengths=tx_n_wavelengths,
        db_thresh=db_thresh,
        element_splitting=element_splitting,
        frequency_step=frequency_step,
    )
    result = simus_compute(
        scatterers=scatterers,
        rc=rc,
        delays=delays,
        plan=plan,
        params=params,
        medium=medium,
        tx_apodization=tx_apodization,
        full_frequency_directivity=full_frequency_directivity,
        strategy=strategy,
    )
    return result
fast_simus/spectrum.py ADDED
@@ -0,0 +1,107 @@
1
+ """Spectrum computation for ultrasound pulse and probe response.
2
+
3
+ Provides functions to compute frequency-domain representations of the
4
+ transmitted pulse and probe frequency response.
5
+
6
+ All functions are Array API compliant and work with NumPy, JAX, CuPy backends.
7
+
8
+ References:
9
+ Garcia D. SIMUS: an open-source simulator for medical ultrasound imaging.
10
+ Part I: theory & examples. CMPB, 2022;218:106726.
11
+ """
12
+
13
+ from __future__ import annotations
14
+
15
+ from math import log, pi
16
+
17
+ import array_api_extra as xpx
18
+ from beartype import beartype as typechecker
19
+ from jaxtyping import Complex, Float, jaxtyped
20
+
21
+ from fast_simus.utils._array_api import Array, array_namespace
22
+
23
+
24
@jaxtyped(typechecker=typechecker)
def pulse_spectrum(
    angular_freq: Float[Array, " n_freqs"],
    freq_center: float | int,
    tx_n_wavelengths: float | int = 1.0,
) -> Complex[Array, " n_freqs"]:
    """Compute the pulse spectrum for a windowed sine pulse.

    Computes the frequency-domain representation of a windowed sine pulse
    with the given center frequency and number of wavelengths, as
    ``1j * (sinc(T (w - wc) / 2pi) - sinc(T (w + wc) / 2pi))`` with
    ``T = tx_n_wavelengths / freq_center`` and the normalized sinc.

    Args:
        angular_freq: Angular frequency in rad/s. Shape (n_freqs,).
        freq_center: Center frequency in Hz. Must be positive.
        tx_n_wavelengths: Number of wavelengths of the TX pulse.

    Returns:
        Complex spectrum at the given angular frequencies. Shape (n_freqs,).

    Raises:
        ValueError: If ``freq_center`` is not positive.
    """
    # Enforce the documented precondition: 0 would divide by zero below and a
    # negative value silently produces a meaningless spectrum.
    if not freq_center > 0:
        msg = f"freq_center must be positive, got {freq_center!r}"
        raise ValueError(msg)

    pulse_duration_s = tx_n_wavelengths / freq_center
    angular_freq_center = 2.0 * pi * freq_center

    xp = array_namespace(angular_freq)
    sinc_arg_lower = pulse_duration_s * (angular_freq - angular_freq_center) / 2.0 / pi
    sinc_arg_upper = pulse_duration_s * (angular_freq + angular_freq_center) / 2.0 / pi
    # array-api-extra does not have type interoperability
    return 1j * (xpx.sinc(sinc_arg_lower, xp=xp) - xpx.sinc(sinc_arg_upper, xp=xp))
51
+
52
+
53
@jaxtyped(typechecker=typechecker)
def probe_spectrum(
    angular_freq: Float[Array, " n_freqs"],
    freq_center: float | int,
    bandwidth: float | int = 0.75,
) -> Float[Array, " n_freqs"]:
    """Compute the probe frequency response.

    Computes the one-way probe frequency response using a generalized normal
    window. The bandwidth parameter defines the pulse-echo 6 dB fractional
    bandwidth.

    The returned spectrum is the square root of the pulse-echo response,
    appropriate for one-way (transmit-only or receive-only) use.

    Args:
        angular_freq: Angular frequency in rad/s. Shape (n_freqs,).
        freq_center: Center frequency in Hz. Must be positive.
        bandwidth: Fractional bandwidth (0.75 = 75%). Must be in (0, 2.0).

    Returns:
        Real-valued probe response (one-way) at the given angular frequencies.
        Shape (n_freqs,).

    Raises:
        ValueError: If ``bandwidth`` is not in (0, 2.0) or ``freq_center`` is
            not positive.

    References:
        Generalized normal window:
        https://en.wikipedia.org/wiki/Window_function#Generalized_normal_window

        Reference implementation (log(126) constant verified against):
        https://github.com/creatis-ULTIM/PyMUST/blob/df02b42/src/pymust/utils.py#L141
    """
    # Validate bandwidth (2.0 would make log(2 wc / wB) = 0 below)
    if not (0.0 < bandwidth < 2.0):
        msg = f"bandwidth must be in (0, 2.0), got {bandwidth!r}"
        raise ValueError(msg)
    # Enforce the documented precondition: 0 divides by zero and a negative
    # value yields a negative sigma (NaN after the fractional power).
    if not freq_center > 0:
        msg = f"freq_center must be positive, got {freq_center!r}"
        raise ValueError(msg)

    angular_freq_center = 2.0 * pi * freq_center
    # Convert fractional bandwidth to angular bandwidth
    angular_bandwidth = bandwidth * angular_freq_center
    # Pulse-echo response: G(w) = exp(-(|w - wc| / sigma) ** p). The two
    # constants pin G at two frequencies:
    #   * sigma = (wB / 2) / log(2) ** (1 / p) gives G = 1/2 (-6 dB in amplitude)
    #     at the band edges w = wc +/- wB / 2, so `bandwidth` is the pulse-echo
    #     -6 dB fractional bandwidth (the one-way sqrt(G) is -3 dB there).
    #   * p = log(126) / log(2 wc / wB) gives (wc / sigma) ** p = 126 * log(2),
    #     so G(0) = 2 ** -126 (about the smallest normal float32): the response
    #     effectively vanishes at DC.
    shape_param = log(126) / log(2.0 * angular_freq_center / angular_bandwidth)
    sigma = angular_bandwidth / 2.0 / (log(2) ** (1.0 / shape_param))

    xp = array_namespace(angular_freq)
    # Pulse-echo (squared) response
    spectrum_squared = xp.exp(-((xp.abs(angular_freq - angular_freq_center) / sigma) ** shape_param))
    # One-way response (square root)
    return xp.sqrt(spectrum_squared)