FastSIMUS 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fast_simus/__init__.py +33 -0
- fast_simus/_pfield_math.py +261 -0
- fast_simus/_pfield_strategies.py +203 -0
- fast_simus/_simus_strategies.py +210 -0
- fast_simus/backends/__init__.py +1 -0
- fast_simus/backends/mlx.py +101 -0
- fast_simus/kernels/__init__.py +9 -0
- fast_simus/kernels/cuda_simus.py +321 -0
- fast_simus/kernels/metal_pfield.py +219 -0
- fast_simus/kernels/metal_simus.py +377 -0
- fast_simus/kernels/pfield.metal +97 -0
- fast_simus/kernels/simus_fused.cu +332 -0
- fast_simus/kernels/simus_rx_simd.metal +128 -0
- fast_simus/kernels/simus_tx_tiled.metal +175 -0
- fast_simus/medium_params.py +22 -0
- fast_simus/pfield.py +475 -0
- fast_simus/py.typed +0 -0
- fast_simus/simus.py +567 -0
- fast_simus/spectrum.py +107 -0
- fast_simus/transducer_params.py +160 -0
- fast_simus/transducer_presets.py +102 -0
- fast_simus/tx_delay.py +276 -0
- fast_simus/utils/__init__.py +5 -0
- fast_simus/utils/_array_api.py +294 -0
- fast_simus/utils/geometry.py +88 -0
- fastsimus-0.0.1.dist-info/METADATA +594 -0
- fastsimus-0.0.1.dist-info/RECORD +28 -0
- fastsimus-0.0.1.dist-info/WHEEL +4 -0
fast_simus/__init__.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
"""FastSIMUS - Fast Simulator for Medical Ultrasound based on SIMUS/MUST."""
|
|
2
|
+
|
|
3
|
+
from fast_simus.medium_params import MediumParams
|
|
4
|
+
from fast_simus.pfield import PfieldPlan, PfieldStrategy, pfield, pfield_compute, pfield_precompute
|
|
5
|
+
from fast_simus.simus import SimusPlan, SimusResult, SimusStrategy, simus, simus_compute, simus_precompute
|
|
6
|
+
from fast_simus.transducer_params import BaffleType, TransducerParams
|
|
7
|
+
from fast_simus.tx_delay import (
|
|
8
|
+
diverging_wave,
|
|
9
|
+
focused,
|
|
10
|
+
plane_wave,
|
|
11
|
+
)
|
|
12
|
+
from fast_simus.utils.geometry import element_positions
|
|
13
|
+
|
|
14
|
+
__all__ = [
|
|
15
|
+
"BaffleType",
|
|
16
|
+
"MediumParams",
|
|
17
|
+
"PfieldPlan",
|
|
18
|
+
"PfieldStrategy",
|
|
19
|
+
"SimusPlan",
|
|
20
|
+
"SimusResult",
|
|
21
|
+
"SimusStrategy",
|
|
22
|
+
"TransducerParams",
|
|
23
|
+
"diverging_wave",
|
|
24
|
+
"element_positions",
|
|
25
|
+
"focused",
|
|
26
|
+
"pfield",
|
|
27
|
+
"pfield_compute",
|
|
28
|
+
"pfield_precompute",
|
|
29
|
+
"plane_wave",
|
|
30
|
+
"simus",
|
|
31
|
+
"simus_compute",
|
|
32
|
+
"simus_precompute",
|
|
33
|
+
]
|
|
@@ -0,0 +1,261 @@
|
|
|
1
|
+
"""Physics helpers for pfield computation.
|
|
2
|
+
|
|
3
|
+
Pure Array API functions for geometry, phase initialization, frequency
|
|
4
|
+
selection, and obliquity. No loop structure or backend-specific code.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
from math import ceil, log, pi
|
|
10
|
+
from typing import NamedTuple
|
|
11
|
+
|
|
12
|
+
from beartype import beartype as typechecker
|
|
13
|
+
from jaxtyping import Complex, Float, jaxtyped
|
|
14
|
+
|
|
15
|
+
from fast_simus.spectrum import probe_spectrum, pulse_spectrum
|
|
16
|
+
from fast_simus.transducer_params import BaffleType
|
|
17
|
+
from fast_simus.utils._array_api import Array, _ArrayNamespace
|
|
18
|
+
|
|
19
|
+
# Conversion factor: Nepers to dB -- 20/log(10) ≈ 8.6859
|
|
20
|
+
# Shared by Array API path and Metal kernel wrapper.
|
|
21
|
+
NEPER_TO_DB = 20.0 / log(10.0)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class _FrequencyPlan(NamedTuple):
    """Frequency samples and matching spectra consumed by the pfield sweep.

    Attributes:
        selected_freqs: Retained frequency samples, shape (n_frequencies,).
        pulse_spectrum: Complex pulse spectrum evaluated at ``selected_freqs``,
            shape (n_frequencies,).
        probe_spectrum: Real probe spectrum evaluated at ``selected_freqs``,
            shape (n_frequencies,).
        freq_step: Frequency step in Hz.
    """

    # Fields are positional: (selected_freqs, pulse_spectrum, probe_spectrum, freq_step).
    selected_freqs: Float[Array, " n_frequencies"]
    pulse_spectrum: Complex[Array, " n_frequencies"]
    probe_spectrum: Float[Array, " n_frequencies"]
    freq_step: float
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
@jaxtyped(typechecker=typechecker)
def _subelement_centroids(
    element_width: float,
    n_sub: int,
    theta_e: Float[Array, " n_elements"],
    xp: _ArrayNamespace,
) -> Float[Array, "n_elements n_sub 2"]:
    """Return sub-element centroid offsets measured from each element center.

    Every element is split into ``n_sub`` equal segments along its face. The
    signed distance of each segment centroid along the face is projected with
    ``cos(theta_e)`` (lateral) and ``sin(-theta_e)`` (axial).

    Args:
        element_width: Element width in meters.
        n_sub: Number of sub-elements per element.
        theta_e: Element angular positions in radians.
        xp: Array namespace.

    Returns:
        Offsets with shape (n_elements, n_sub, 2); ``[..., 0]`` is the lateral
        (x) component and ``[..., 1]`` is the axial (z) component.
    """
    seg_length = element_width / n_sub
    # Centroid of the first segment; the remaining ones follow at seg_length spacing.
    first_centroid = -element_width / 2.0 + seg_length / 2.0
    along_face = xp.asarray([first_centroid + i * seg_length for i in range(n_sub)])

    # Broadcast (1, n_sub) against (n_elements, 1).
    along_face_row = along_face[None, :]
    theta_col = theta_e[:, None]

    lateral = along_face_row * xp.cos(theta_col)
    axial = along_face_row * xp.sin(-theta_col)
    return xp.stack([lateral, axial], axis=-1)
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
@jaxtyped(typechecker=typechecker)
def _distances_and_angles(
    points: Float[Array, "*batch 2"],
    subelement_offsets: Float[Array, "n_elements n_sub 2"],
    element_pos: Float[Array, "n_elements 2"],
    theta_e: Float[Array, " n_elements"],
    speed_of_sound: float,
    freq_center: float,
    xp: _ArrayNamespace,
) -> tuple[
    Float[Array, "*batch n_elements n_sub"],
    Float[Array, "*batch n_elements n_sub"],
    Float[Array, "*batch n_elements n_sub"],
]:
    """Measure range and off-normal angle from each grid point to each sub-element.

    Args:
        points: Grid point positions (x, z).
        subelement_offsets: Sub-element offsets relative to element centers.
        element_pos: Element center positions (x, z).
        theta_e: Element angular positions in radians.
        speed_of_sound: Speed of sound in m/s.
        freq_center: Center frequency in Hz.
        xp: Array namespace.

    Returns:
        Tuple of (distances, sin_theta, theta_arr):
            - distances: Point-to-sub-element distances, floored at
              ``speed_of_sound / freq_center / 2`` (half a wavelength at the
              center frequency).
            - sin_theta: Sine of ``theta_arr``.
            - theta_arr: Angles relative to the element normal, computed from
              the unfloored distances.
    """
    # Displacement from every sub-element to every grid point:
    # (*batch, 1, 1, 2) - (n_elements, n_sub, 2) - (n_elements, 1, 2).
    offset = points[..., None, None, :] - subelement_offsets - element_pos[:, None, :]
    lateral = offset[..., 0]
    axial = offset[..., 1]
    radius = xp.sqrt(lateral**2 + axial**2)

    # Floor the returned distance; the angle below still uses the raw radius.
    half_wavelength = xp.asarray(speed_of_sound / freq_center / 2.0)
    radius_floored = xp.where(radius < half_wavelength, half_wavelength, radius)

    # Small epsilon in numerator and denominator keeps 0/0 finite.
    tiny = xp.asarray(1e-16)
    sine_ratio = (lateral + tiny) / (radius + tiny)
    theta_arr = xp.asin(sine_ratio) - theta_e[:, None]

    return radius_floored, xp.sin(theta_arr), theta_arr
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
def _select_frequencies(
    fc: float,
    bandwidth: float,
    tx_n_wavelengths: float,
    db_thresh: float,
    max_freq_step: float,
    xp: _ArrayNamespace,
) -> _FrequencyPlan:
    """Choose the frequency samples that carry significant spectral energy.

    A uniform grid on ``[0, 2 * fc]`` is built, the combined pulse*probe
    spectrum magnitude is normalised to its maximum, and the contiguous range
    between the first and last sample above ``db_thresh`` is kept.

    Args:
        fc: Center frequency in Hz.
        bandwidth: Fractional bandwidth.
        tx_n_wavelengths: Number of wavelengths in TX pulse.
        db_thresh: Threshold in dB for frequency component selection.
        max_freq_step: Upper bound for frequency step.
        xp: Array namespace.

    Returns:
        FrequencyPlan with the selected frequencies, the pulse and probe
        spectra at those frequencies, and the grid step.
    """
    two_pi = xp.asarray(2.0 * pi)

    def _spectra_at(freqs):
        # Pulse and probe spectra evaluated at angular frequency 2*pi*f.
        omega = two_pi * freqs
        return pulse_spectrum(omega, fc, tx_n_wavelengths), probe_spectrum(omega, fc, bandwidth)

    # Odd number of samples on [0, 2*fc] with spacing no larger than max_freq_step.
    n_freq = int(2 * ceil(fc / max_freq_step) + 1)
    frequencies = xp.linspace(0, 2 * fc, n_freq)
    # The grid starts at 0, so the second sample equals the spacing.
    freq_step = float(frequencies[1])

    # Relative gain in dB; the tiny offset guards log10 against an exact zero.
    pulse_all, probe_all = _spectra_at(frequencies)
    magnitude = xp.abs(pulse_all * probe_all)
    relative_db = 20.0 * xp.log10(xp.asarray(1e-200) + magnitude / xp.max(magnitude))
    idx_first, idx_last = _first_last_true(xp, relative_db > db_thresh)

    selected_freqs = frequencies[idx_first : idx_last + 1]
    pulse_spect, probe_spect = _spectra_at(selected_freqs)

    return _FrequencyPlan(selected_freqs, pulse_spect, probe_spect, freq_step)
|
|
163
|
+
|
|
164
|
+
|
|
165
|
+
@jaxtyped(typechecker=typechecker)
def _obliquity_factor(
    theta_arr: Float[Array, "*batch n_elements n_sub"],
    baffle: BaffleType | float,
    xp: _ArrayNamespace,
) -> Float[Array, "*batch n_elements n_sub"]:
    """Evaluate the baffle-dependent obliquity factor for each angle.

    - ``BaffleType.RIGID``: factor is 1 everywhere.
    - ``BaffleType.SOFT``: factor is ``cos(theta)``.
    - anything else: ``baffle`` is converted with ``float()`` and the factor is
      ``cos(theta) / (cos(theta) + baffle)``.

    Angles with ``|theta| >= pi / 2`` are overwritten with 1e-16.

    Args:
        theta_arr: Angles relative to element normal.
        baffle: Baffle type or impedance ratio.
        xp: Array namespace.

    Returns:
        Obliquity factor, same shape as ``theta_arr``.
    """
    if baffle != BaffleType.RIGID:
        cos_theta = xp.cos(theta_arr)
        if baffle == BaffleType.SOFT:
            factor = cos_theta
        else:
            factor = cos_theta / (cos_theta + float(baffle))
    else:
        factor = xp.ones(theta_arr.shape)

    # Near-zero floor for angles at or beyond the element's horizon.
    beyond_horizon = xp.abs(theta_arr) >= xp.asarray(pi / 2)
    return xp.where(beyond_horizon, xp.asarray(1e-16), factor)
|
|
200
|
+
|
|
201
|
+
|
|
202
|
+
@jaxtyped(typechecker=typechecker)
def _init_exponentials(
    freq_start: float | Float[Array, ""],
    speed_of_sound: float,
    attenuation: float,
    distances: Float[Array, "*batch n_elements n_sub"],
    obliquity_factor: Float[Array, "*batch n_elements n_sub"],
    freq_step: float | Float[Array, ""],
    xp: _ArrayNamespace,
) -> tuple[
    Complex[Array, "*batch n_elements n_sub"],
    Complex[Array, "*batch n_elements n_sub"],
]:
    """Build the starting term and ratio of the per-frequency geometric progression.

    The value at the first frequency is
    ``exp(-alpha0 * r + 1j * mod(k0 * r, 2*pi)) * obliquity / sqrt(r)`` and the
    per-step multiplier is ``exp((-alpha_step + 1j * k_step) * r)``, with ``r``
    being ``distances``.

    Args:
        freq_start: Initial frequency in Hz (scalar or 0-d array).
        speed_of_sound: Speed of sound in m/s.
        attenuation: Attenuation coefficient in dB/cm/MHz.
        distances: Distances.
        obliquity_factor: Obliquity factor.
        freq_step: Frequency step in Hz (scalar or 0-d array).
        xp: Array namespace.

    Returns:
        Tuple of (phase_decay, phase_decay_step):
            - phase_decay: Initial complex exponential array, already scaled by
              ``obliquity_factor / sqrt(distances)``.
            - phase_decay_step: Complex exponential increment per frequency step.
    """

    def _wavenumber(freq):
        return 2.0 * pi * freq / speed_of_sound

    def _attenuation_rate(freq):
        # Rescales the dB-based coefficient via NEPER_TO_DB, 1e6 and 1e2
        # (presumably dB/cm/MHz -> Np/m at `freq` -- confirm against callers).
        return attenuation / NEPER_TO_DB * freq / 1e6 * 1e2

    # Per-step ratio: exp((-alpha_step + 1j * k_step) * r).
    step_exponent = -_attenuation_rate(freq_step) + 1j * _wavenumber(freq_step)
    phase_decay_step = xp.exp(xp.asarray(step_exponent) * distances)

    # Initial term; the phase is wrapped into [0, 2*pi) before exponentiation.
    two_pi = xp.asarray(2.0 * pi)
    raw_phase = xp.asarray(_wavenumber(freq_start)) * distances
    wrapped_phase = raw_phase - two_pi * xp.floor(raw_phase / two_pi)
    damping = xp.asarray(-_attenuation_rate(freq_start)) * distances
    phase_decay = xp.exp(damping + xp.asarray(1j) * wrapped_phase)

    # Fold in obliquity and 1/sqrt(r) spreading (2D, no elevation).
    phase_decay = phase_decay * obliquity_factor / xp.sqrt(distances)

    return phase_decay, phase_decay_step
|
|
248
|
+
|
|
249
|
+
|
|
250
|
+
def _first_last_true(xp: _ArrayNamespace, mask: Array) -> tuple[int, int]:
    """Locate the first and last True entries of a 1D boolean array.

    Implemented with ``argmax`` rather than ``nonzero`` so it stays
    JAX-compatible.

    Args:
        xp: Array namespace.
        mask: 1D boolean array.

    Returns:
        ``(first, last)`` indices of the outermost True values, or ``(0, 0)``
        when the mask is empty or contains no True value.
    """
    length = mask.shape[0]
    if length == 0:
        return 0, 0

    # argmax on bool is rejected by array_api_strict, so work on integers.
    as_int = xp.asarray(mask, dtype=xp.int32)
    if int(xp.max(as_int)) == 0:
        return 0, 0

    # argmax returns the first maximum; on the reversed mask that is the last True.
    leading = int(xp.argmax(as_int))
    trailing_from_end = int(xp.argmax(as_int[::-1]))
    return leading, length - 1 - trailing_from_end
|
|
@@ -0,0 +1,203 @@
|
|
|
1
|
+
"""Loop drivers for pfield frequency sweep (Layer 3).
|
|
2
|
+
|
|
3
|
+
Each driver iterates the same _freq_step_body() using a different mechanism:
|
|
4
|
+
|
|
5
|
+
- _freq_outer_python: Python for-loop (NumPy/CuPy, constant memory)
|
|
6
|
+
- _pfield_freq_vectorized: tensor broadcast (small grids only)
|
|
7
|
+
- _freq_outer_scan: JAX lax.scan for O(1) compilation cost
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
from __future__ import annotations
|
|
11
|
+
|
|
12
|
+
from math import pi
|
|
13
|
+
|
|
14
|
+
import array_api_extra as xpx
|
|
15
|
+
from beartype import beartype as typechecker
|
|
16
|
+
from jaxtyping import Bool, Complex, Float, jaxtyped
|
|
17
|
+
|
|
18
|
+
from fast_simus.utils._array_api import Array, _ArrayNamespace
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def _freq_step_body(
    phase: Complex[Array, " *grid n_sources"],
    phase_step: Complex[Array, " *grid n_sources"],
    spectrum_k: complex | Array,
    xp: _ArrayNamespace,
    *,
    directivity_k: Float[Array, " *grid n_sources"] | None = None,
) -> tuple[Complex[Array, " *grid n_sources"], Float[Array, " *grid"]]:
    """Evaluate one frequency sample and advance the phase state.

    The (optionally directivity-weighted) phase is summed over the last axis,
    scaled by ``spectrum_k`` and converted to squared magnitude. The returned
    phase is the input phase multiplied by ``phase_step``; the directivity
    weight is not folded into it.

    Args:
        phase: Current phase state (geometric progression).
        phase_step: Per-step multiplier for geometric progression.
        spectrum_k: Combined pulse*probe spectrum weight for this frequency.
        xp: Array namespace.
        directivity_k: Per-source directivity for this frequency (optional).

    Returns:
        Tuple of (updated_phase, rp_k) where rp_k = |P_k|^2 at this frequency.
    """
    contribution = phase if directivity_k is None else phase * directivity_k
    # Contract over sources, then apply the spectral weight.
    pressure_k = spectrum_k * xp.sum(contribution, axis=-1)
    power_k = xp.real(pressure_k * xp.conj(pressure_k))
    return phase * phase_step, power_k
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def _freq_outer_python(
    phase_decay_init: Complex[Array, " *grid n_sources"],
    phase_decay_step: Complex[Array, " *grid n_sources"],
    is_out: Bool[Array, " *grid"],
    wavenumbers: Float[Array, " n_freq"],
    pulse_spect: Complex[Array, " n_freq"],
    probe_spect: Float[Array, " n_freq"],
    seg_length: float,
    sin_theta: Float[Array, " *grid n_sources"],
    full_frequency_directivity: bool,
    xp: _ArrayNamespace,
) -> Float[Array, " *grid"]:
    """Python for-loop driver for NumPy/CuPy: constant O(grid * sources) memory.

    Iterates one frequency at a time using _freq_step_body, accumulating
    |P_k|^2 into the result. Peak memory is independent of n_freq.

    Args:
        phase_decay_init: Phase state at the first frequency.
        phase_decay_step: Per-frequency multiplier of the phase state.
        is_out: Mask of grid points whose result is forced to zero.
        wavenumbers: Wavenumber per frequency; only read when
            ``full_frequency_directivity`` is True.
        pulse_spect: Pulse spectrum per frequency.
        probe_spect: Probe spectrum per frequency.
        seg_length: Sub-element length used in the sinc directivity argument.
        sin_theta: Sine of the angle to each source.
        full_frequency_directivity: When True, a sinc directivity weight is
            recomputed for every frequency.
        xp: Array namespace.

    Returns:
        Sum over frequencies of |P_k|^2 per grid point, zero where ``is_out``.
    """
    spectra = pulse_spect * probe_spect
    n_freq = int(wavenumbers.shape[0])

    phase = phase_decay_init
    rp = xp.zeros(phase.shape[:-1])

    for k in range(n_freq):
        directivity_k = None
        if full_frequency_directivity:
            sinc_arg = wavenumbers[k] * seg_length / 2.0 * sin_theta / pi
            directivity_k = xpx.sinc(sinc_arg, xp=xp)
        phase, rp_k = _freq_step_body(phase, phase_decay_step, spectra[k], xp, directivity_k=directivity_k)
        rp = rp + rp_k

    # The mask does not depend on k, so apply it once to the accumulated sum
    # instead of once per frequency; masked points end up exactly zero either way.
    return xp.where(is_out, xp.asarray(0.0), rp)
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
@jaxtyped(typechecker=typechecker)
def _pfield_freq_vectorized(
    phase_decay_init: Complex[Array, " *grid n_sources"],
    phase_decay_step: Complex[Array, " *grid n_sources"],
    is_out: Bool[Array, " *grid"],
    wavenumbers: Float[Array, " n_freq"],
    pulse_spect: Complex[Array, " n_freq"],
    probe_spect: Float[Array, " n_freq"],
    seg_length: float,
    sin_theta: Float[Array, " *grid n_sources"],
    full_frequency_directivity: bool,
    xp: _ArrayNamespace,
) -> Float[Array, " *grid"]:
    """Evaluate every frequency in one broadcast instead of looping.

    Reference implementation kept for testing and small-grid use cases;
    production code uses _freq_outer_python (constant memory) instead.

    The per-frequency phase follows the geometric progression
    ``phase[k] = init * step**k`` and is materialised as a
    (*grid, n_sources, n_freq) tensor, so this only suits grids for which that
    tensor fits in memory. Source points are pre-flattened with 1/n_sub
    absorbed.

    Args:
        phase_decay_init: Phase state at the first frequency.
        phase_decay_step: Per-frequency multiplier of the phase state.
        is_out: Mask of grid points whose result is forced to zero.
        wavenumbers: Wavenumber per frequency.
        pulse_spect: Pulse spectrum per frequency.
        probe_spect: Probe spectrum per frequency.
        seg_length: Sub-element length used in the sinc directivity argument.
        sin_theta: Sine of the angle to each source.
        full_frequency_directivity: When True, apply a per-frequency sinc
            directivity weight.
        xp: Array namespace.

    Returns:
        Sum over frequencies of |P_k|^2 per grid point, zero where ``is_out``.
    """
    # Powers 0..n_freq-1 of the step, in the wavenumber dtype.
    powers = xp.arange(wavenumbers.shape[0], dtype=wavenumbers.dtype)

    # (*grid, n_sources, n_freq)
    phase_all = phase_decay_init[..., None] * phase_decay_step[..., None] ** powers

    if full_frequency_directivity:
        sinc_arg = wavenumbers * seg_length / 2.0 * sin_theta[..., None] / pi
        phase_all = xpx.sinc(sinc_arg, xp=xp) * phase_all

    # Sum over sources, (*grid, n_sources, n_freq) -> (*grid, n_freq), then weight.
    pressure = pulse_spect * probe_spect * xp.sum(phase_all, axis=-2)

    # Zero out masked grid points for every frequency.
    pressure = xp.where(is_out[..., None], xp.asarray(0.0 + 0j), pressure)

    return xp.sum(xp.abs(pressure) ** 2, axis=-1)
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
def _freq_outer_scan(
    phase_decay_init: Complex[Array, " *grid n_sources"],
    phase_decay_step: Complex[Array, " *grid n_sources"],
    is_out: Bool[Array, " *grid"],
    wavenumbers: Float[Array, " n_freq"],
    pulse_spect: Complex[Array, " n_freq"],
    probe_spect: Float[Array, " n_freq"],
    seg_length: float,
    sin_theta: Float[Array, " *grid n_sources"],
    full_frequency_directivity: bool,
    xp: _ArrayNamespace,
) -> Float[Array, " *grid"]:
    """JAX vmap+scan driver: per-grid-point kernel vmapped over the grid.

    The scan carry is only (n_sources,) per grid point, matching the Metal
    kernel's per-thread model. vmap handles parallelism across grid points.

    The carry dtypes are derived from the inputs rather than fixed to float32,
    because ``jax.lax.scan`` requires the carry to keep the same dtype across
    iterations; a hard-coded float32 accumulator fails for 64-bit inputs.

    Args:
        phase_decay_init: Phase state at the first frequency.
        phase_decay_step: Per-frequency multiplier of the phase state.
        is_out: Mask of grid points whose result is forced to zero.
        wavenumbers: Wavenumber per frequency; only read when
            ``full_frequency_directivity`` is True.
        pulse_spect: Pulse spectrum per frequency.
        probe_spect: Probe spectrum per frequency.
        seg_length: Sub-element length used in the sinc directivity argument.
        sin_theta: Sine of the angle to each source.
        full_frequency_directivity: When True, a sinc directivity weight is
            recomputed for every frequency.
        xp: Array namespace (expected to be JAX-backed).

    Returns:
        Sum over frequencies of |P_k|^2 per grid point, zero where ``is_out``,
        reshaped to the grid shape.
    """
    import jax
    import jax.numpy as jnp

    grid_shape = phase_decay_init.shape[:-1]
    n_sources = phase_decay_init.shape[-1]

    # Flatten grid to (n_grid, n_sources) and (n_grid,)
    init_flat = xp.reshape(phase_decay_init, (-1, n_sources))
    step_flat = xp.reshape(phase_decay_step, (-1, n_sources))
    is_out_flat = xp.reshape(is_out, (-1,))
    sin_theta_flat = xp.reshape(sin_theta, (-1, n_sources))

    spectra = pulse_spect * probe_spect

    # The phase carry is repeatedly multiplied by the step, so start it in the
    # promoted dtype of both operands.
    phase_dtype = jnp.result_type(init_flat, step_flat)
    init_flat = jnp.asarray(init_flat, dtype=phase_dtype)

    # Dtype of the per-frequency pressure, and of its squared magnitude.
    if full_frequency_directivity:
        pressure_dtype = jnp.result_type(spectra, phase_dtype, wavenumbers, sin_theta_flat)
    else:
        pressure_dtype = jnp.result_type(spectra, phase_dtype)
    rp_dtype = jnp.zeros((), dtype=pressure_dtype).real.dtype
    # Used both as the accumulator start value and as the masked-point fill.
    zero = jnp.zeros((), dtype=rp_dtype)

    if full_frequency_directivity:

        def _single_point_scan(
            phase_init_g: jax.Array, phase_step_g: jax.Array, is_out_g: jax.Array, sin_theta_g: jax.Array
        ) -> jax.Array:
            def scan_fn(carry, k):
                phase, rp = carry
                sinc_arg = wavenumbers[k] * seg_length / 2.0 * sin_theta_g / pi
                directivity_k = xpx.sinc(sinc_arg, xp=xp)
                p_k = spectra[k] * xp.sum(phase * directivity_k)
                rp_k = xp.real(p_k * xp.conj(p_k))
                rp = rp + xp.where(is_out_g, zero, rp_k)
                phase = phase * phase_step_g
                return (phase, rp), None

            (_, rp), _ = jax.lax.scan(scan_fn, (phase_init_g, zero), jnp.arange(spectra.shape[0]))
            return rp

        rp_flat = jax.vmap(_single_point_scan)(init_flat, step_flat, is_out_flat, sin_theta_flat)
    else:

        def _single_point_scan_no_dir(
            phase_init_g: jax.Array, phase_step_g: jax.Array, is_out_g: jax.Array
        ) -> jax.Array:
            def scan_fn(carry, spectrum_k):
                phase, rp = carry
                p_k = spectrum_k * xp.sum(phase)
                rp_k = xp.real(p_k * xp.conj(p_k))
                rp = rp + xp.where(is_out_g, zero, rp_k)
                phase = phase * phase_step_g
                return (phase, rp), None

            (_, rp), _ = jax.lax.scan(scan_fn, (phase_init_g, zero), spectra)
            return rp

        rp_flat = jax.vmap(_single_point_scan_no_dir)(init_flat, step_flat, is_out_flat)

    return xp.reshape(rp_flat, grid_shape)
|