FastSIMUS 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
fast_simus/pfield.py ADDED
@@ -0,0 +1,475 @@
1
+ """Pressure field computation for ultrasound transducer arrays.
2
+
3
+ Implements PFIELD algorithm for simulating ultrasound beam patterns from
4
+ phased/linear/convex arrays using Fraunhofer (far-field) approximation in
5
+ the azimuthal plane and Fresnel (paraxial) approximation in elevation.
6
+
7
+ All functions are Array API compliant and work with NumPy, JAX, CuPy backends.
8
+
9
+ References:
10
+ Garcia D. SIMUS: an open-source simulator for medical ultrasound imaging.
11
+ Part I: theory & examples. CMPB, 2022;218:106726.
12
+ """
13
+
14
+ from __future__ import annotations
15
+
16
+ from enum import StrEnum
17
+ from math import ceil, inf, pi, prod
18
+ from types import ModuleType
19
+ from typing import TYPE_CHECKING, NamedTuple, cast
20
+
21
+ import array_api_extra as xpx
22
+ from array_api_compat import is_jax_namespace
23
+ from beartype import beartype as typechecker
24
+ from jaxtyping import Bool, Complex, Float, jaxtyped
25
+
26
+ from fast_simus._pfield_math import (
27
+ _distances_and_angles,
28
+ _init_exponentials,
29
+ _obliquity_factor,
30
+ _select_frequencies,
31
+ _subelement_centroids,
32
+ )
33
+ from fast_simus.medium_params import MediumParams
34
+ from fast_simus.transducer_params import BaffleType, TransducerParams
35
+ from fast_simus.utils._array_api import Array, _ArrayNamespace, array_namespace, is_mlx_namespace
36
+ from fast_simus.utils.geometry import element_positions
37
+
38
# Module-wide default for the ``medium`` argument of the public entry points.
# Default argument values are evaluated once, so every call that omits
# ``medium`` shares this single instance.
# NOTE(review): assumes MediumParams is never mutated by callers -- confirm.
_DEFAULT_MEDIUM: MediumParams = MediumParams()
39
+
40
+
41
+ class PfieldStrategy(StrEnum):
42
+ """Backend strategy for the pfield frequency sweep.
43
+
44
+ The three-layer pfield architecture separates:
45
+ - Layer 1 (setup): geometry, phase init -- pure Array API, shared by all
46
+ - Layer 2 (step body): per-frequency math -- pure Array API function
47
+ - Layer 3 (loop driver): iteration mechanism -- backend-specific
48
+
49
+ This enum selects the Layer 3 loop driver. When None is passed to
50
+ pfield_compute, the strategy is auto-selected based on the detected backend.
51
+ """
52
+
53
+ VECTORIZED = "vectorized"
54
+ SCAN = "scan"
55
+ METAL = "metal"
56
+
57
+
58
class PfieldPlan(NamedTuple):
    """Static inputs for ``pfield_compute``, produced by ``pfield_precompute``.

    Everything whose value or shape depends on the data (frequency selection,
    sub-element count, integration weight) is resolved ahead of time, so
    ``pfield_compute`` only deals with fixed shapes and can be JIT-compiled.

    Create instances through ``pfield_precompute`` rather than by hand.

    Attributes:
        selected_freqs: Retained frequency samples in Hz, uniformly spaced.
        pulse_spectrum: Complex pulse spectrum at ``selected_freqs``.
        probe_spectrum: Real probe response at ``selected_freqs``.
        n_sub: Number of sub-elements each transducer element is split into.
        seg_length: Sub-element length in meters (``element_width / n_sub``).
        correction_factor: Weight applied to the accumulated power before
            the square root -- ``df * element_width`` in general, or plain
            ``element_width`` when ``tx_n_wavelengths`` is infinite.
    """

    # Frequency axis of the sweep (Hz).
    selected_freqs: Float[Array, " n_frequencies"]
    # Spectra sampled on that axis.
    pulse_spectrum: Complex[Array, " n_frequencies"]
    probe_spectrum: Float[Array, " n_frequencies"]
    # Python scalars describing the element discretisation.
    n_sub: int
    seg_length: float
    # RMS integration weight.
    correction_factor: float
82
+
83
+
84
class _SweepInputs(NamedTuple):
    """Arguments handed to the Array API frequency-sweep drivers.

    The (element, sub-element) axes are already merged into one source axis
    of length ``n_sources = n_elements * n_sub``, and the ``1 / n_sub``
    averaging factor is already included in ``phase_decay_init``. Field
    names double as the drivers' keyword arguments (the tuple is splatted
    with ``_asdict()``).
    """

    # Complex weight per grid point and source at the first selected frequency.
    phase_decay_init: Complex[Array, " *grid n_sources"]
    # Multiplier advancing that weight by one frequency step.
    phase_decay_step: Complex[Array, " *grid n_sources"]
    # True for grid points behind the probe (z < 0) or, for convex probes,
    # inside the probe's circle.
    is_out: Bool[Array, " *grid"]
    # 2*pi*f/c at each selected frequency.
    wavenumbers: Float[Array, " n_freq"]
    # Spectra at the selected frequencies.
    pulse_spect: Complex[Array, " n_freq"]
    probe_spect: Float[Array, " n_freq"]
    # Sub-element length in meters.
    seg_length: float
    # Sine of each source's directivity angle (feeds the sub-element sinc).
    sin_theta: Float[Array, " *grid n_sources"]
    # If False, center-frequency directivity is already folded into
    # phase_decay_init.
    full_frequency_directivity: bool
100
+
101
+
102
def _prepare_frequency_sweep(
    positions: Float[Array, "*grid_shape 2"],
    delays_clean: Float[Array, " n_elements"],
    tx_apodization: Float[Array, " n_elements"],
    plan: PfieldPlan,
    params: TransducerParams,
    medium: MediumParams,
    *,
    full_frequency_directivity: bool,
    xp: _ArrayNamespace,
) -> _SweepInputs:
    """Build the geometric-progression inputs used by the Array API drivers.

    This is the setup step for the VECTORIZED and SCAN strategies; the Metal
    kernel evaluates geometry on the fly and does not call it. For every grid
    point and source point it yields a starting complex weight and a
    per-frequency multiplier, so a driver can walk the frequency axis by
    repeated multiplication.

    Args:
        positions: Grid positions ``(x, z)`` in meters.
        delays_clean: Transmit delays in seconds, NaNs already replaced.
        tx_apodization: Transmit apodization weight per element.
        plan: Precomputed plan (frequencies, spectra, sub-element layout).
        params: Transducer parameters.
        medium: Medium parameters (speed of sound, attenuation).
        full_frequency_directivity: When False, the center-frequency
            sub-element directivity is folded into the starting weights here.
        xp: Array namespace.

    Returns:
        A ``_SweepInputs`` whose source axis is flattened to
        ``n_elements * n_sub`` with the ``1 / n_sub`` factor already applied.
    """
    c = medium.speed_of_sound
    freqs = plan.selected_freqs

    # --- Transducer geometry ------------------------------------------------
    elem_xz, elem_angles, apex_h = element_positions(params.n_elements, params.pitch, params.radius, xp)
    if elem_angles is None:
        # No per-element angles returned (presumably a flat array): zero tilt.
        elem_angles = xp.zeros(params.n_elements)
    sub_offsets = _subelement_centroids(params.element_width, plan.n_sub, elem_angles, xp)

    # --- Grid points excluded from the field --------------------------------
    axial = positions[..., 1]
    outside = axial < 0
    if params.radius != inf:
        lateral = positions[..., 0]
        # Convex probe: also exclude points inside the circle of radius R
        # centred at (0, -apex_h).
        outside = outside | ((lateral**2 + (axial + apex_h) ** 2) <= params.radius**2)

    # --- Propagation geometry and baffle ------------------------------------
    distances, sin_theta, theta = _distances_and_angles(
        positions, sub_offsets, elem_xz, elem_angles, c, params.freq_center, xp
    )
    obliquity = _obliquity_factor(theta, params.baffle, xp)

    # --- Geometric progression over frequency -------------------------------
    f_first = freqs[0]
    if freqs.shape[0] > 1:
        f_step = freqs[1] - freqs[0]
    else:
        f_step = xp.asarray(0.0)
    weight0, ratio = _init_exponentials(f_first, c, medium.attenuation, distances, obliquity, f_step, xp)

    if not full_frequency_directivity:
        # Center-frequency sub-element directivity. xpx.sinc is the normalized
        # sinc sin(pi*u)/(pi*u), hence the division by pi.
        k_center = 2.0 * pi * params.freq_center / c
        u = xp.asarray(k_center * plan.seg_length / 2.0) * sin_theta / pi
        weight0 = weight0 * xpx.sinc(u, xp=xp)

    # Fold transmit delays and apodization into the progression so that the
    # drivers never need a separate per-frequency delay multiply.
    j2pi = xp.asarray(1j * 2.0 * pi)
    tx_first = xp.exp(j2pi * f_first * delays_clean) * tx_apodization
    tx_ratio = xp.exp(j2pi * f_step * delays_clean)
    weight0 = weight0 * tx_first[:, None]
    ratio = ratio * tx_ratio[:, None]

    # Average over sub-elements up front, then merge the trailing
    # (element, sub-element) axes into one source axis so every driver can
    # reduce with a single sum(axis=-1).
    weight0 = weight0 / plan.n_sub

    def merge_source_axes(a: Array) -> Array:
        *lead, n_elem, n_per_elem = a.shape
        return xp.reshape(a, (*lead, n_elem * n_per_elem))

    return _SweepInputs(
        phase_decay_init=merge_source_axes(weight0),
        phase_decay_step=merge_source_axes(ratio),
        is_out=outside,
        wavenumbers=xp.asarray(2.0 * pi) * freqs / c,
        pulse_spect=plan.pulse_spectrum,
        probe_spect=plan.probe_spectrum,
        seg_length=plan.seg_length,
        sin_theta=merge_source_axes(sin_theta),
        full_frequency_directivity=full_frequency_directivity,
    )
185
+
186
+
187
def _metal_supported(params: TransducerParams, full_frequency_directivity: bool) -> bool:
    """Tell whether the Metal kernel can handle this configuration.

    The kernel only implements center-frequency directivity and the SOFT
    baffle; custom (non-string) baffle specifications are rejected.
    """
    if full_frequency_directivity:
        return False
    baffle = params.baffle
    return isinstance(baffle, str | BaffleType) and baffle == BaffleType.SOFT
194
+
195
+
196
def _select_strategy(
    xp: _ArrayNamespace,
    grid_size: int,
    params: TransducerParams,
    full_frequency_directivity: bool,
    *,
    strategy: PfieldStrategy | None = None,
) -> PfieldStrategy:
    """Resolve the loop driver (Layer 3) used for the frequency sweep.

    Args:
        xp: Array namespace of the inputs, used for auto-selection.
        grid_size: Number of grid points. Not used by the current selection
            rules; kept for interface stability.
        params: Transducer parameters (the baffle type matters for METAL).
        full_frequency_directivity: Whether per-frequency directivity is
            requested (not supported by METAL).
        strategy: Explicit request. ``PfieldStrategy`` members or their
            string values are accepted; None means auto-select.

    Returns:
        The explicit strategy as a ``PfieldStrategy`` member when given.
        Otherwise SCAN for JAX, METAL for MLX when the configuration is
        supported, and VECTORIZED in every other case.

    Raises:
        ValueError: If ``strategy`` is not a valid ``PfieldStrategy`` value.
        NotImplementedError: If METAL is requested explicitly for a
            configuration the Metal kernel cannot handle.
    """
    if strategy is not None:
        # Normalize through the enum: an unknown value (e.g. a typo such as
        # "scann") raises ValueError here instead of silently falling through
        # to the vectorized driver, which pfield_compute uses for any
        # non-SCAN, non-METAL value.
        requested = PfieldStrategy(strategy)
        if requested == PfieldStrategy.METAL and not _metal_supported(params, full_frequency_directivity):
            unsupported: list[str] = []
            if full_frequency_directivity:
                unsupported.append("full_frequency_directivity=True")
            if params.baffle != BaffleType.SOFT:
                unsupported.append(f"baffle={params.baffle!r} (only SOFT supported)")
            raise NotImplementedError(
                f"Metal kernel does not support: {', '.join(unsupported)}. Use strategy=None for auto-selection."
            )
        return requested
    if is_jax_namespace(cast(ModuleType, xp)):
        return PfieldStrategy.SCAN
    if is_mlx_namespace(xp) and _metal_supported(params, full_frequency_directivity):
        return PfieldStrategy.METAL
    return PfieldStrategy.VECTORIZED
221
+
222
+
223
def pfield_precompute(
    positions: Float[Array, "*grid_shape 2"],
    delays: Float[Array, " n_elements"],
    params: TransducerParams,
    medium: MediumParams = _DEFAULT_MEDIUM,
    *,
    tx_n_wavelengths: float | int = 1.0,
    db_thresh: float | int = -60.0,
    element_splitting: int | None = None,
    frequency_step: float | int = 1.0,
) -> PfieldPlan:
    """Precompute static quantities for pfield computation.

    Extracts every data-dependent scalar and every dynamically-shaped array,
    so that ``pfield_compute`` only works with static shapes and is suitable
    for JAX JIT compilation.

    Args:
        positions: Grid positions in meters. Shape ``(*grid_shape, 2)``.
        delays: Transmit time delays in seconds. Shape ``(n_elements,)``.
        params: Transducer parameters.
        medium: Medium parameters.
        tx_n_wavelengths: Number of wavelengths in the TX pulse.
        db_thresh: Threshold in dB for frequency component selection.
        element_splitting: Number of sub-elements per element (must be >= 1).
            If None, ``ceil(element_width / lambda_min)`` (at least 1), where
            ``lambda_min`` is the wavelength at ``freq_center * (1 + bandwidth / 2)``.
        frequency_step: Scaling factor for the frequency step (must be > 0).

    Returns:
        PfieldPlan with static-shaped arrays and precomputed scalars.

    Raises:
        ValueError: If ``element_splitting`` is less than 1, if
            ``frequency_step`` is not strictly positive, or if the grid has
            no points.
    """
    # Validate scalar options before any array work. element_splitting=0
    # would otherwise divide by zero in seg_length, and a non-positive (or
    # NaN) frequency_step would produce a meaningless frequency spacing.
    if element_splitting is not None and element_splitting < 1:
        raise ValueError(f"element_splitting must be >= 1, got {element_splitting!r}")
    if not frequency_step > 0:
        raise ValueError(f"frequency_step must be > 0, got {frequency_step!r}")

    xp = array_namespace(positions, delays)
    speed_of_sound = medium.speed_of_sound

    if positions.size == 0:
        raise ValueError("Grid has no points")

    # NaN-clean delays (NaN marks an element that is switched off) for the
    # max-delay term below.
    delays_clean = xp.where(xp.isnan(delays), xp.asarray(0.0), delays)

    # Element splitting needs a Python int, hence math.ceil on a Python float.
    if element_splitting is not None:
        n_sub = element_splitting
    else:
        lambda_min = speed_of_sound / (params.freq_center * (1.0 + params.bandwidth / 2.0))
        # max(1, ...) keeps a zero-width element from yielding n_sub == 0.
        n_sub = max(1, ceil(params.element_width / lambda_min))

    seg_length = params.element_width / n_sub

    # Geometry, needed here only for the maximum source-to-point distance.
    element_pos, theta_elements, _ = element_positions(params.n_elements, params.pitch, params.radius, xp)
    if theta_elements is None:
        theta_elements = xp.zeros(params.n_elements)
    subelement_offsets = _subelement_centroids(params.element_width, n_sub, theta_elements, xp)
    distances, _, _ = _distances_and_angles(
        positions, subelement_offsets, element_pos, theta_elements, speed_of_sound, params.freq_center, xp
    )

    # Anti-aliasing frequency step, df = 1 / (r_max / c + tau_max), then
    # scaled by the user factor. float() extracts Python scalars from arrays.
    df = 1.0 / (float(xp.max(distances)) / speed_of_sound + float(xp.max(delays_clean)))
    df = float(frequency_step) * df

    # Frequency selection uses boolean masking -> data-dependent n_frequencies.
    freq_plan = _select_frequencies(params.freq_center, params.bandwidth, tx_n_wavelengths, db_thresh, df, xp)
    df = freq_plan.freq_step

    # RMS integration weight: df in general, 1 when tx_n_wavelengths is
    # infinite; always scaled by the element width.
    correction_factor = (1.0 if tx_n_wavelengths == inf else df) * params.element_width

    return PfieldPlan(
        selected_freqs=freq_plan.selected_freqs,
        pulse_spectrum=freq_plan.pulse_spectrum,
        probe_spectrum=freq_plan.probe_spectrum,
        n_sub=n_sub,
        seg_length=seg_length,
        correction_factor=correction_factor,
    )
299
+
300
+
301
def pfield_compute(
    positions: Float[Array, "*grid_shape 2"],
    delays: Float[Array, " n_elements"],
    plan: PfieldPlan,
    params: TransducerParams,
    medium: MediumParams = _DEFAULT_MEDIUM,
    *,
    tx_apodization: Float[Array, " n_elements"] | None = None,
    full_frequency_directivity: bool = False,
    strategy: PfieldStrategy | None = None,
) -> Float[Array, " *grid_shape"]:
    """Compute the RMS pressure field given a precomputed plan.

    Only static-shape operations are performed, so the function is suitable
    for JAX JIT compilation when ``plan`` and ``params`` are treated as
    static arguments.

    Args:
        positions: Grid positions in meters. Shape ``(*grid_shape, 2)``.
        delays: Transmit time delays in seconds. Shape ``(n_elements,)``.
        plan: Precomputed plan from ``pfield_precompute``.
        params: Transducer parameters.
        medium: Medium parameters.
        tx_apodization: Transmit apodization weights. Shape ``(n_elements,)``.
            Elements with NaN delays are automatically zeroed.
        full_frequency_directivity: If True, compute element directivity at
            every frequency. If False, use center-frequency-only directivity.
        strategy: Backend strategy for the frequency sweep. If None,
            auto-selects based on the detected array backend.

    Returns:
        RMS pressure field with shape ``(*grid_shape,)``.

    Raises:
        ValueError: If ``delays`` or ``tx_apodization`` does not have shape
            ``(params.n_elements,)``.
        NotImplementedError: If ``PfieldStrategy.METAL`` is requested for a
            configuration the Metal kernel does not support.
    """
    xp = array_namespace(positions, delays, tx_apodization)

    n_elements = params.n_elements
    # Delays and apodization are broadcast along the per-element axis with
    # ``[:, None]`` in the sweep setup; a wrong length (notably 1) would
    # broadcast silently instead of failing. Shapes are static, so these
    # checks are JIT-safe.
    expected_shape = (n_elements,)
    if tuple(delays.shape) != expected_shape:
        raise ValueError(f"delays must have shape {expected_shape}, got {tuple(delays.shape)}")
    if tx_apodization is None:
        tx_apodization = xp.ones(n_elements)
    elif tuple(tx_apodization.shape) != expected_shape:
        raise ValueError(f"tx_apodization must have shape {expected_shape}, got {tuple(tx_apodization.shape)}")

    # Elements with NaN delays are switched off and given a harmless delay.
    nan_mask = xp.isnan(delays)
    tx_apodization = xp.where(nan_mask, xp.asarray(0.0), tx_apodization)
    delays_clean = xp.where(nan_mask, xp.asarray(0.0), delays)

    grid_size = prod(positions.shape[:-1])
    selected = _select_strategy(xp, grid_size, params, full_frequency_directivity, strategy=strategy)

    if selected == PfieldStrategy.METAL:
        # Imported lazily so the Metal kernel module is only loaded when used.
        from fast_simus.kernels.metal_pfield import pfield_metal

        if TYPE_CHECKING:
            import mlx.core as mx

        pressure_accum = cast(
            Array,
            pfield_metal(
                positions=cast("mx.array", positions),
                params=params,
                plan=plan,
                medium=medium,
                delays_clean=cast("mx.array", delays_clean),
                tx_apodization=cast("mx.array", tx_apodization),
            ),
        )
    else:
        from fast_simus._pfield_strategies import _freq_outer_python, _freq_outer_scan

        sweep = _prepare_frequency_sweep(
            positions,
            delays_clean,
            tx_apodization,
            plan,
            params,
            medium,
            full_frequency_directivity=full_frequency_directivity,
            xp=xp,
        )
        driver = _freq_outer_scan if selected == PfieldStrategy.SCAN else _freq_outer_python
        pressure_accum = driver(**sweep._asdict(), xp=xp)

    # Weight the accumulated power by the integration factor and take the root.
    return xp.sqrt(pressure_accum * plan.correction_factor)
379
+
380
+
381
@jaxtyped(typechecker=typechecker)
def pfield(
    positions: Float[Array, "*grid_shape 2"],
    delays: Float[Array, " n_elements"],
    params: TransducerParams,
    medium: MediumParams = _DEFAULT_MEDIUM,
    *,
    tx_apodization: Float[Array, " n_elements"] | None = None,
    tx_n_wavelengths: float | int = 1.0,
    db_thresh: float | int = -60.0,
    full_frequency_directivity: bool = False,
    element_splitting: int | None = None,
    frequency_step: float | int = 1.0,
    strategy: PfieldStrategy | None = None,
) -> Float[Array, " *grid_shape"]:
    """Compute the RMS acoustic pressure field radiated by a transducer array.

    Returns the radiation pattern -- the root-mean-square of the acoustic
    pressure -- of a uniform linear or convex array whose elements fire with
    individual time delays. Only the 2-D (x, z) plane is modelled; there is
    no elevation focusing.

    Algorithm
    ---------
    Following Garcia 2022, Eq. 22, the pressure is the superposition of the
    contributions of all array elements:

        P(X,w,t) ~ P_TX(w) exp(-iwt) Sum_n W_n [exp(ikr_n)/r_n] D(theta_n,k) exp(iw*tau_n)

    with

    - P_TX(w): transmit pulse spectrum (windowed sinusoid x transducer response)
    - r_n: distance from sub-element n to the field point
    - D(theta_n,k): element directivity, sinc(kb*sin(theta)) x obliquity factor
    - W_n: transmit apodization weight
    - tau_n: transmit delay used for focusing/steering

    Each element is split into nu = ceil(width / lambda_min) sub-elements
    (nu = 1 when the element is no wider than lambda_min) to satisfy
    far-field conditions. The RMS field follows from integrating |P(X,w)|^2
    over the frequency band (Garcia 2022, Eqs. 41-42):

        P_RMS(X) = sqrt[Integral |P(X,w)|^2 dw] ~ sqrt[Delta_w Sum |P(X,w_j)|^2]

    The frequency spacing Delta_w is chosen adaptively to avoid phase
    aliasing, i.e. so that (Delta_w/c)*r_max + Delta_w*tau_max < 2*pi over
    the region of interest.

    Implementation Notes
    --------------------
    - **2D mode**: 1/sqrt(r) geometric spreading (no elevation focusing).
    - **Attenuation**: frequency-linear absorption exp(-alpha*f*r), alpha in dB/cm/MHz.
    - **Baffle**: the obliquity factor depends on the boundary condition
      (rigid/soft/custom).
    - **Directivity**: either frequency-dependent (slower) or evaluated at
      the center frequency only.

    Args:
        positions: Grid positions in meters. Shape ``(*grid_shape, 2)``;
            ``positions[..., 0]`` is lateral (x) and ``positions[..., 1]``
            is axial (z, into tissue).
        delays: Transmit time delays in seconds. Shape ``(n_elements,)``.
        params: Transducer parameters (geometry, frequency, bandwidth, baffle).
        medium: Medium parameters (speed of sound, attenuation).
        tx_apodization: Transmit apodization weights. Shape ``(n_elements,)``.
            Elements with NaN delays are automatically zeroed.
        tx_n_wavelengths: Number of wavelengths in the TX pulse.
        db_thresh: Threshold in dB for frequency component selection; only
            components above it (relative to the peak) are used.
        full_frequency_directivity: If True, compute element directivity at
            every frequency. If False, use center-frequency-only directivity.
        element_splitting: Number of sub-elements per transducer element.
            If None, computed as ceil(element_width / smallest_wavelength).
        frequency_step: Scaling factor for the frequency step. Values > 1
            speed up computation; values < 1 give smoother results.
        strategy: Backend strategy for the frequency sweep. If None,
            auto-selects based on the detected array backend.

    Returns:
        RMS pressure field with shape ``(*grid_shape,)``.

    Raises:
        ValueError: If the grid has no points.
        NotImplementedError: If ``PfieldStrategy.METAL`` is requested for an
            unsupported configuration.
    """
    # Stage 1: resolve every data-dependent quantity (frequency samples,
    # sub-element count, integration weight) into a static plan.
    plan = pfield_precompute(
        positions=positions,
        delays=delays,
        params=params,
        medium=medium,
        tx_n_wavelengths=tx_n_wavelengths,
        db_thresh=db_thresh,
        element_splitting=element_splitting,
        frequency_step=frequency_step,
    )
    # Stage 2: run the static-shape frequency sweep on the chosen backend.
    field = pfield_compute(
        positions=positions,
        delays=delays,
        plan=plan,
        params=params,
        medium=medium,
        tx_apodization=tx_apodization,
        full_frequency_directivity=full_frequency_directivity,
        strategy=strategy,
    )
    return field
fast_simus/py.typed ADDED
File without changes