FastSIMUS 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fast_simus/__init__.py +33 -0
- fast_simus/_pfield_math.py +261 -0
- fast_simus/_pfield_strategies.py +203 -0
- fast_simus/_simus_strategies.py +210 -0
- fast_simus/backends/__init__.py +1 -0
- fast_simus/backends/mlx.py +101 -0
- fast_simus/kernels/__init__.py +9 -0
- fast_simus/kernels/cuda_simus.py +321 -0
- fast_simus/kernels/metal_pfield.py +219 -0
- fast_simus/kernels/metal_simus.py +377 -0
- fast_simus/kernels/pfield.metal +97 -0
- fast_simus/kernels/simus_fused.cu +332 -0
- fast_simus/kernels/simus_rx_simd.metal +128 -0
- fast_simus/kernels/simus_tx_tiled.metal +175 -0
- fast_simus/medium_params.py +22 -0
- fast_simus/pfield.py +475 -0
- fast_simus/py.typed +0 -0
- fast_simus/simus.py +567 -0
- fast_simus/spectrum.py +107 -0
- fast_simus/transducer_params.py +160 -0
- fast_simus/transducer_presets.py +102 -0
- fast_simus/tx_delay.py +276 -0
- fast_simus/utils/__init__.py +5 -0
- fast_simus/utils/_array_api.py +294 -0
- fast_simus/utils/geometry.py +88 -0
- fastsimus-0.0.1.dist-info/METADATA +594 -0
- fastsimus-0.0.1.dist-info/RECORD +28 -0
- fastsimus-0.0.1.dist-info/WHEEL +4 -0
|
@@ -0,0 +1,210 @@
|
|
|
1
|
+
"""Loop drivers for simus frequency sweep (Layer 3).
|
|
2
|
+
|
|
3
|
+
Each driver iterates _simus_freq_step_body() using a different mechanism:
|
|
4
|
+
|
|
5
|
+
- _simus_freq_outer_python: Python for-loop (NumPy/CuPy, constant memory)
|
|
6
|
+
- _simus_freq_outer_scan: JAX lax.scan for O(1) compilation cost
|
|
7
|
+
|
|
8
|
+
The simus step body differs from pfield's: instead of accumulating |P_k|^2
|
|
9
|
+
per grid point, it computes the full TX->scatter->RX chain and accumulates
|
|
10
|
+
complex RF spectrum per element.
|
|
11
|
+
"""
|
|
12
|
+
|
|
13
|
+
from __future__ import annotations
|
|
14
|
+
|
|
15
|
+
from math import pi
|
|
16
|
+
|
|
17
|
+
import array_api_extra as xpx
|
|
18
|
+
from jaxtyping import Bool, Complex, Float
|
|
19
|
+
|
|
20
|
+
from fast_simus.utils._array_api import Array, _ArrayNamespace
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def _simus_freq_step_body(
|
|
24
|
+
phase: Complex[Array, "n_scat n_elem n_sub"],
|
|
25
|
+
phase_step: Complex[Array, "n_scat n_elem n_sub"],
|
|
26
|
+
delay_apod_phase: Complex[Array, " n_elem"],
|
|
27
|
+
delay_apod_step: Complex[Array, " n_elem"],
|
|
28
|
+
rc: Float[Array, " n_scat"],
|
|
29
|
+
pulse_probe_k: complex | Array,
|
|
30
|
+
probe_k: float | Array,
|
|
31
|
+
is_out: Bool[Array, " n_scat"],
|
|
32
|
+
xp: _ArrayNamespace,
|
|
33
|
+
*,
|
|
34
|
+
directivity_k: Float[Array, "n_scat n_elem n_sub"] | None = None,
|
|
35
|
+
) -> tuple[
|
|
36
|
+
Complex[Array, "n_scat n_elem n_sub"],
|
|
37
|
+
Complex[Array, " n_elem"],
|
|
38
|
+
Complex[Array, " n_elem"],
|
|
39
|
+
]:
|
|
40
|
+
"""One frequency step: TX forward, scatter, RX backprop.
|
|
41
|
+
|
|
42
|
+
Args:
|
|
43
|
+
phase: Geometric progression state (n_scat, n_elem, n_sub).
|
|
44
|
+
phase_step: Per-step multiplier for geometric progression.
|
|
45
|
+
delay_apod_phase: Current delay+apodization phase per element.
|
|
46
|
+
delay_apod_step: Per-step delay+apodization multiplier.
|
|
47
|
+
rc: Reflection coefficients per scatterer.
|
|
48
|
+
pulse_probe_k: Combined pulse*probe spectrum weight for this frequency.
|
|
49
|
+
probe_k: Probe-only spectrum weight for RX.
|
|
50
|
+
is_out: Boolean mask for out-of-field scatterers.
|
|
51
|
+
xp: Array namespace.
|
|
52
|
+
directivity_k: Per-source directivity (optional, for full_frequency_directivity).
|
|
53
|
+
|
|
54
|
+
Returns:
|
|
55
|
+
Tuple of (updated_phase, updated_delay_apod, spect_k) where
|
|
56
|
+
spect_k is the complex RF spectrum contribution for this frequency,
|
|
57
|
+
shape (n_elements,).
|
|
58
|
+
"""
|
|
59
|
+
if directivity_k is not None:
|
|
60
|
+
rp_mono = xp.mean(phase * directivity_k, axis=-1)
|
|
61
|
+
else:
|
|
62
|
+
rp_mono = xp.mean(phase, axis=-1)
|
|
63
|
+
|
|
64
|
+
# TX: contract over elements -> pressure at each scatterer
|
|
65
|
+
p_k = pulse_probe_k * (rp_mono @ delay_apod_phase[..., None])[..., 0]
|
|
66
|
+
p_k = xp.where(is_out, xp.asarray(0.0 + 0j), p_k)
|
|
67
|
+
|
|
68
|
+
# RX: contract over scatterers -> spectrum per element
|
|
69
|
+
# (rc * p_k)^T @ rp_mono = sum_i(rc_i * p_k_i * rp_mono[i, e])
|
|
70
|
+
weighted = rc * p_k
|
|
71
|
+
spect_k = weighted @ rp_mono
|
|
72
|
+
spect_k = probe_k * spect_k
|
|
73
|
+
|
|
74
|
+
phase = phase * phase_step
|
|
75
|
+
delay_apod_phase = delay_apod_phase * delay_apod_step
|
|
76
|
+
|
|
77
|
+
return phase, delay_apod_phase, spect_k
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def _simus_freq_outer_python(
    phase_init: Complex[Array, "n_scat n_elem n_sub"],
    phase_step: Complex[Array, "n_scat n_elem n_sub"],
    delay_apod_init: Complex[Array, " n_elem"],
    delay_apod_step: Complex[Array, " n_elem"],
    rc: Float[Array, " n_scat"],
    is_out: Bool[Array, " n_scat"],
    wavenumbers: Float[Array, " n_freq"],
    pulse_spect: Complex[Array, " n_freq"],
    probe_spect: Float[Array, " n_freq"],
    seg_length: float,
    sin_theta: Float[Array, "n_scat n_elem n_sub"],
    full_frequency_directivity: bool,
    xp: _ArrayNamespace,
) -> Complex[Array, "n_freq n_elem"]:
    """Python for-loop driver: step through the frequencies one at a time.

    Each iteration calls ``_simus_freq_step_body`` once, which advances the
    two geometric phase progressions. The iteration writes that frequency's
    per-element spectrum into row ``k`` of the output. Only the current
    (n_scat, n_elem, n_sub) phase tensor is kept between iterations, so peak
    memory is O(n_scat * n_elem * n_sub), independent of n_freq.

    Returns:
        Complex RF spectrum of shape (n_freq, n_elements), allocated with the
        dtype of ``phase_init``.
    """
    tx_weights = pulse_spect * probe_spect
    num_freqs = int(wavenumbers.shape[0])
    out = xp.zeros((num_freqs, phase_init.shape[1]), dtype=phase_init.dtype)

    state = phase_init
    delay_state = delay_apod_init
    for k in range(num_freqs):
        # Per-frequency directivity is only computed when requested;
        # otherwise the step body averages the raw phase.
        directivity_k = None
        if full_frequency_directivity:
            directivity_k = xpx.sinc(wavenumbers[k] * seg_length / 2.0 * sin_theta / pi, xp=xp)

        state, delay_state, spect_k = _simus_freq_step_body(
            state,
            phase_step,
            delay_state,
            delay_apod_step,
            rc,
            tx_weights[k],
            probe_spect[k],
            is_out,
            xp,
            directivity_k=directivity_k,
        )
        out = _set_row(out, k, spect_k)

    return out
|
|
141
|
+
|
|
142
|
+
|
|
143
|
+
def _set_row(
    arr: Complex[Array, "n_freq n_elem"],
    k: int,
    row: Complex[Array, " n_elem"],
) -> Complex[Array, "n_freq n_elem"]:
    """Return ``arr`` with row ``k`` replaced by ``row``, Array API compatible.

    The update goes through ``array_api_extra.at``. Depending on the backend,
    it may modify ``arr`` in place or produce a new array, so callers must
    always use the returned value.
    """
    row_index = (k, slice(None))  # same as arr[k, :]
    return xpx.at(arr)[row_index].set(row)  # type: ignore[attr-defined]
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
def _simus_freq_outer_scan(
    phase_init: Complex[Array, "n_scat n_elem n_sub"],
    phase_step: Complex[Array, "n_scat n_elem n_sub"],
    delay_apod_init: Complex[Array, " n_elem"],
    delay_apod_step: Complex[Array, " n_elem"],
    rc: Float[Array, " n_scat"],
    is_out: Bool[Array, " n_scat"],
    wavenumbers: Float[Array, " n_freq"],
    pulse_spect: Complex[Array, " n_freq"],
    probe_spect: Float[Array, " n_freq"],
    seg_length: float,
    sin_theta: Float[Array, "n_scat n_elem n_sub"],
    full_frequency_directivity: bool,
    xp: _ArrayNamespace,
) -> Complex[Array, "n_freq n_elem"]:
    """JAX lax.scan driver: scan over frequencies with full tensor carry.

    The carry holds (phase, delay_apod_phase) with shapes
    (n_scat, n_elem, n_sub) and (n_elem,). Each step delegates to
    ``_simus_freq_step_body``, the same per-frequency math the Python driver
    uses, so both drivers stay in sync. Each step outputs spect_k with shape
    (n_elem,), and scan stacks these into (n_freq, n_elem).

    ``full_frequency_directivity`` is a Python bool, so the branch inside the
    scan function is resolved once at trace time. ``wavenumbers`` is only
    scanned over when the per-frequency directivity needs it.
    """
    import jax

    spectra = pulse_spect * probe_spect

    def scan_fn(carry, xs):
        phase, delay_apod = carry
        if full_frequency_directivity:
            spectrum_k, probe_k, wavenum_k = xs
            directivity_k = xpx.sinc(wavenum_k * seg_length / 2.0 * sin_theta / pi, xp=xp)
        else:
            spectrum_k, probe_k = xs
            directivity_k = None
        phase, delay_apod, spect_k = _simus_freq_step_body(
            phase,
            phase_step,
            delay_apod,
            delay_apod_step,
            rc,
            spectrum_k,
            probe_k,
            is_out,
            xp,
            directivity_k=directivity_k,
        )
        return (phase, delay_apod), spect_k

    if full_frequency_directivity:
        scanned = (spectra, probe_spect, wavenumbers)
    else:
        scanned = (spectra, probe_spect)

    init_carry = (phase_init, delay_apod_init)
    _, spect_all = jax.lax.scan(scan_fn, init_carry, scanned)

    return spect_all
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""Backend-specific compatibility modules."""
|
|
@@ -0,0 +1,101 @@
|
|
|
1
|
+
"""Array API compatibility shim for MLX.
|
|
2
|
+
|
|
3
|
+
Temporary until array_api_compat gains native MLX support.
|
|
4
|
+
Tracking: https://github.com/data-apis/array-api-compat/issues/162
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
import contextlib
|
|
10
|
+
from typing import Any
|
|
11
|
+
|
|
12
|
+
import array_api_compat
|
|
13
|
+
import array_api_compat.common._helpers as _helpers
|
|
14
|
+
|
|
15
|
+
_MLX_ARRAY_API_ALIASES: dict[str, str] = {
|
|
16
|
+
"asin": "arcsin",
|
|
17
|
+
"acos": "arccos",
|
|
18
|
+
"atan2": "arctan2",
|
|
19
|
+
"bool": "bool_",
|
|
20
|
+
}
|
|
21
|
+
|
|
22
|
+
_MLX_ISDTYPE_KIND_MAP: dict[str, str] = {
|
|
23
|
+
"bool": "bool_",
|
|
24
|
+
"signed integer": "signedinteger",
|
|
25
|
+
"unsigned integer": "unsignedinteger",
|
|
26
|
+
"integral": "integer",
|
|
27
|
+
"real floating": "floating",
|
|
28
|
+
"complex floating": "complexfloating",
|
|
29
|
+
"numeric": "number",
|
|
30
|
+
}
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def _make_isdtype(xp: Any) -> Any:
|
|
34
|
+
def isdtype(dtype: Any, kind: Any) -> bool:
|
|
35
|
+
if isinstance(kind, str):
|
|
36
|
+
category = _MLX_ISDTYPE_KIND_MAP.get(kind)
|
|
37
|
+
if category is None:
|
|
38
|
+
msg = f"Unrecognized dtype kind: {kind!r}"
|
|
39
|
+
raise ValueError(msg)
|
|
40
|
+
return bool(xp.issubdtype(dtype, getattr(xp, category)))
|
|
41
|
+
if isinstance(kind, tuple):
|
|
42
|
+
return any(isdtype(dtype, k) for k in kind)
|
|
43
|
+
return dtype == kind
|
|
44
|
+
|
|
45
|
+
return isdtype
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def _patch_namespace(xp: Any) -> None:
    """Add Array API aliases to mlx.core (idempotent).

    Only names that are missing are filled in, so existing attributes are
    never overwritten and repeated calls are harmless:

    * each standard name in ``_MLX_ARRAY_API_ALIASES`` is pointed at its MLX
      spelling, when that spelling exists;
    * ``isdtype`` is built from ``issubdtype``;
    * ``astype(x, dtype)`` is added as a free function;
    * ``asarray`` is wrapped (once, tracked by a marker attribute) so it only
      forwards ``dtype``. Any other keyword is accepted and dropped.
    """
    aliases_to_add = {
        standard: getattr(xp, native)
        for standard, native in _MLX_ARRAY_API_ALIASES.items()
        if not hasattr(xp, standard) and hasattr(xp, native)
    }
    for standard, target in aliases_to_add.items():
        setattr(xp, standard, target)

    if hasattr(xp, "issubdtype") and not hasattr(xp, "isdtype"):
        xp.isdtype = _make_isdtype(xp)

    if not hasattr(xp, "astype"):

        def _astype(x: Any, dtype: Any, /, *, copy: bool = False) -> Any:
            # ``copy`` is accepted for Array API signature compatibility but ignored.
            return x.astype(dtype)

        xp.astype = _astype

    native_asarray = xp.asarray
    if getattr(native_asarray, "_fastsimus_wrapped", False):
        return  # already wrapped by an earlier call

    def _asarray(a: Any, *, dtype: Any = None, **_kwargs: Any) -> Any:
        # Forward only ``dtype``; other Array API keywords (e.g. copy, device) are dropped.
        if dtype is None:
            return native_asarray(a)
        return native_asarray(a, dtype=dtype)

    _asarray._fastsimus_wrapped = True  # type: ignore[attr-defined]
    xp.asarray = _asarray
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def _patch_device(xp: Any) -> None:
    """Patch array_api_compat device() for MLX unified memory.

    Installs a replacement ``device()`` in ``array_api_compat``'s helpers
    module and package namespace. If array_api_extra's private compat module
    can be imported, the replacement is installed there as well. For objects
    whose type comes from an ``mlx*`` module, the replacement returns
    ``xp.default_device()``; everything else goes to the previously
    installed implementation. A marker attribute on the replacement makes
    repeated calls no-ops.
    """
    previous_device = _helpers.device
    if getattr(previous_device, "_fastsimus_mlx", False):
        return  # already patched

    def _device_with_mlx(x: Any, /) -> Any:
        if not type(x).__module__.startswith("mlx"):
            return previous_device(x)
        return xp.default_device()

    _device_with_mlx._fastsimus_mlx = True  # type: ignore[attr-defined]

    _helpers.device = _device_with_mlx  # type: ignore[assignment]
    array_api_compat.device = _device_with_mlx  # type: ignore[assignment]

    # array_api_extra's private compat module exposes its own ``device`` name;
    # patch it too when that module is importable in the installed version.
    with contextlib.suppress(ImportError):
        import array_api_extra._lib._utils._compat as _xpx_compat  # type: ignore[import-untyped]

        _xpx_compat.device = _device_with_mlx
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
def ensure_compat(xp: Any) -> None:
    """Apply all MLX compatibility patches (idempotent).

    Runs the namespace patch (missing Array API names on ``xp``) and then
    the ``device()`` patch. Each of them is safe to call repeatedly.
    """
    for apply_patch in (_patch_namespace, _patch_device):
        apply_patch(xp)
|
|
@@ -0,0 +1,9 @@
|
|
|
1
|
+
"""Backend-specific fused kernels for FastSIMUS.
|
|
2
|
+
|
|
3
|
+
Custom kernels provide maximum performance by fusing the entire computation
|
|
4
|
+
into one or a few GPU dispatches. Each kernel is a different algorithm from the
|
|
5
|
+
Array API path (e.g., on-the-fly geometry instead of precomputed arrays).
|
|
6
|
+
|
|
7
|
+
Available kernels:
|
|
8
|
+
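- metal_pfield: Apple Silicon Metal kernel for pfield (requires MLX)
- metal_simus: Apple Silicon Metal kernels for simus (``simus_metal``)
- cuda_simus: CUDA kernel for simus, compiled at runtime via CuPy/NVRTC (requires CuPy)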
|
|
9
|
+
"""
|
|
@@ -0,0 +1,321 @@
|
|
|
1
|
+
"""CuPy CUDA backend for simus.
|
|
2
|
+
|
|
3
|
+
Compiles the v25c register-resident TX kernel via NVRTC at runtime
|
|
4
|
+
(``cupy.RawModule``) -- no nanobind, no setuptools build step. Pinned to
|
|
5
|
+
``(B_SCAT=10, ELEM_TILE=2)`` for RTX 4090 / sm_89 / P4-2v; performance may
|
|
6
|
+
regress on other probes / GPUs (see exp22 + the FastSIMUS-cuda-tune
|
|
7
|
+
follow-up).
|
|
8
|
+
|
|
9
|
+
Output layout matches ``metal_simus.simus_metal``: complex64
|
|
10
|
+
``(n_freq, n_elements)``. The shipped kernel does its own per-scatterer
|
|
11
|
+
Phase-1 geometry from a flat input set, so the
|
|
12
|
+
``(n_scat, n_elem, n_sub)`` phase tensors that ``_simus_freq_outer_python``
|
|
13
|
+
consumes are *not* fed in here.
|
|
14
|
+
|
|
15
|
+
Requires: CuPy on a CUDA host. Use ``cupy-cuda12x`` for CUDA 12/Pascal
|
|
16
|
+
hosts and ``cupy-cuda13x`` for CUDA 13/Turing-or-newer hosts.
|
|
17
|
+
"""
|
|
18
|
+
|
|
19
|
+
from __future__ import annotations
|
|
20
|
+
|
|
21
|
+
from math import inf, pi
|
|
22
|
+
from pathlib import Path
|
|
23
|
+
from typing import TYPE_CHECKING, Any, cast
|
|
24
|
+
|
|
25
|
+
import cupy as cp
|
|
26
|
+
|
|
27
|
+
from fast_simus._pfield_math import NEPER_TO_DB, _subelement_centroids
|
|
28
|
+
from fast_simus.medium_params import MediumParams
|
|
29
|
+
from fast_simus.transducer_params import TransducerParams
|
|
30
|
+
from fast_simus.utils._array_api import _ArrayNamespace
|
|
31
|
+
from fast_simus.utils.geometry import element_positions
|
|
32
|
+
|
|
33
|
+
if TYPE_CHECKING:
|
|
34
|
+
from fast_simus.simus import SimusPlan
|
|
35
|
+
|
|
36
|
+
_KERNELS_DIR = Path(__file__).parent
|
|
37
|
+
_SOURCE_NAME = "simus_fused.cu"
|
|
38
|
+
|
|
39
|
+
# Pinned tuning -- see docs/progress/experiments/exp22-svshmem-et2.md.
|
|
40
|
+
# These constants are RTX 4090 / sm_89 / P4-2v optimal; not autotuned.
|
|
41
|
+
_B_SCAT = 10
|
|
42
|
+
_ELEM_TILE = 2
|
|
43
|
+
_TG_SIZE = 128
|
|
44
|
+
_TILE_SE = 16
|
|
45
|
+
_GRID_BLOCKS = 256 # 2 * 128 SMs on RTX 4090
|
|
46
|
+
|
|
47
|
+
# CuPy / NVRTC auto-derives ``--gpu-architecture`` from the current device,
|
|
48
|
+
# so we don't pin it here. Tuning constants (B_SCAT, ELEM_TILE, TG_SIZE)
|
|
49
|
+
# are still hardwired for sm_89 and may need adjustment for sm_80 / sm_90.
|
|
50
|
+
|
|
51
|
+
# Default static dynamic-shmem cap (48 KB) is below what some probes
|
|
52
|
+
# need (e.g. L11-5v with n_sub=2 hits ~64 KB). We raise the per-kernel
|
|
53
|
+
# cap via cuFuncSetAttribute when required. Modern GPUs (sm_75+) support
|
|
54
|
+
# up to ~96-100 KB dynamic shared memory per block.
|
|
55
|
+
_DEFAULT_SHMEM_CAP_BYTES = 48 * 1024
|
|
56
|
+
_MAX_DYNAMIC_SHMEM_BYTES = 96 * 1024
|
|
57
|
+
|
|
58
|
+
_source_cache: dict[str, str] = {}
|
|
59
|
+
_kernel_cache: dict[tuple[int, int, int], Any] = {}
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def _load_source(filename: str) -> str:
    """Return the text of kernel source ``filename`` from the kernels directory.

    Each file is read once and memoized in ``_source_cache``. The file is
    decoded explicitly as UTF-8. Without an explicit encoding, the result
    would depend on the host's locale encoding (e.g. cp1252 on Windows), and
    a non-ASCII character in the source could fail to decode or be mangled
    before it reaches NVRTC.
    """
    source = _source_cache.get(filename)
    if source is None:
        source = (_KERNELS_DIR / filename).read_text(encoding="utf-8")
        _source_cache[filename] = source
    return source
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
def _shmem_bytes(n_elem: int, n_sub: int) -> int:
    """Return the dynamic shared memory (in bytes) the v25c kernel needs.

    Layout (see ``simus_fused.cu``): 7 * B_SCAT * N_ES fp32 values of TX/RX
    geometry, followed by 3 * N_ELEM fp32 values of per-element broadcast
    data (da_init_re, da_init_im, dps), where N_ES = n_elem * n_sub.
    """
    bytes_per_float = 4
    geometry_floats = 7 * _B_SCAT * (n_elem * n_sub)
    per_element_floats = 3 * n_elem
    return bytes_per_float * (geometry_floats + per_element_floats)
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def _get_kernel(n_elem: int, n_sub: int, n_freq: int) -> Any:
    """Compile + cache simus_fused_kernel for the given problem shape.

    The problem dimensions and tuning constants are baked in as preprocessor
    defines, so each distinct ``(n_elem, n_sub, n_freq)`` triple gets its own
    NVRTC compilation, memoized in ``_kernel_cache``. ``n_scat`` is not part
    of the key because it is passed at launch time: the kernel grid-strides
    over scatterers (one fused launch covers the whole sweep, unlike the
    Metal split-kernel path).
    """
    cache_key = (n_elem, n_sub, n_freq)
    cached = _kernel_cache.get(cache_key)
    if cached is not None:
        return cached

    defines = {
        "N_ELEM": n_elem,
        "N_SUB": n_sub,
        "N_FREQ": n_freq,
        "N_ES": n_elem * n_sub,
        "TILE_SE": _TILE_SE,
        "TG_SIZE": _TG_SIZE,
        "MAX_FPT": -(-n_freq // _TG_SIZE),  # ceil(n_freq / TG_SIZE)
        "B_SCAT": _B_SCAT,
        "ELEM_TILE": _ELEM_TILE,
    }
    options = (
        "--std=c++17",
        "--use_fast_math",
        "--extra-device-vectorization",
        *(f"-D{name}={value}" for name, value in defines.items()),
    )

    compiled = cp.RawModule(
        code=_load_source(_SOURCE_NAME),
        backend="nvrtc",
        options=options,
        name_expressions=("simus_fused_kernel",),
    )
    kernel = compiled.get_function("simus_fused_kernel")
    _kernel_cache[cache_key] = kernel
    return kernel
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
def _prepare_inputs(
    scatterers: Any,
    rc: Any,
    delays_clean: Any,
    tx_apodization: Any,
    plan: SimusPlan,
    params: TransducerParams,
    medium: MediumParams,
) -> dict[str, Any]:
    """Pack the inputs the v25c kernel expects.

    The returned dict holds:

    * the 15 device arrays passed to the kernel (``scat_x`` .. ``probe_real``);
    * the 11 scalar kernel arguments (``n_scat`` .. ``apex_offset``);
    * the host-side shape ints ``n_elem``, ``n_sub`` and ``n_freq`` used for
      compilation and buffer sizing.

    Mirrors ``metal_simus._prepare_common`` but without the
    ``(n_scat, n_elem, n_sub)`` expansion: v25c does its own Phase-1
    geometry from the flat per-element / per-sub-element inputs.

    Raises:
        ValueError: If ``rc`` does not hold one value per scatterer, or if
            ``delays_clean`` / ``tx_apodization`` do not yield one value per
            element. The kernel only receives ``n_scat`` and the compiled-in
            N_ELEM, not per-buffer lengths, so a shorter buffer would
            presumably be read past its end on the GPU.
    """
    c = medium.speed_of_sound
    alpha = medium.attenuation
    n_elem = params.n_elements
    n_sub = plan.n_sub
    n_freq = int(plan.selected_freqs.shape[0])

    n_scat = int(scatterers.shape[0])
    n_rc = int(rc.size)
    if n_rc != n_scat:
        msg = f"rc must hold one value per scatterer: got {n_rc} values for {n_scat} scatterers."
        raise ValueError(msg)

    # element_positions with `xp=cp` returns CuPy arrays directly.
    xp_cp = cast(_ArrayNamespace, cp)
    elem_pos, theta_e, apex_offset = element_positions(n_elem, params.pitch, params.radius, xp_cp)
    if theta_e is None:
        # No per-element angles returned (presumably a linear array): use zero tilt.
        theta_e = cp.zeros(n_elem, dtype=cp.float32)

    # Sub-element offsets per (elem, sub) flattened to N_ES with se = elem*n_sub + sub
    # (see kernel line `int elem = se / N_SUB;`).
    offsets = cast(cp.ndarray, _subelement_centroids(params.element_width, n_sub, theta_e, xp_cp))
    sub_dx = cp.ascontiguousarray(offsets[..., 0].reshape(-1).astype(cp.float32))
    sub_dz = cp.ascontiguousarray(offsets[..., 1].reshape(-1).astype(cp.float32))

    cos_te = cp.ascontiguousarray(cp.cos(theta_e).astype(cp.float32))
    sin_neg_te = cp.ascontiguousarray(cp.sin(-theta_e).astype(cp.float32))

    # Frequency-grid scalars. Only the first two frequencies are read. The
    # kernel advances by a constant freq_step, so selected_freqs is assumed
    # to be uniformly spaced.
    freq_start = float(plan.selected_freqs[0])
    freq_step = float(plan.selected_freqs[1] - plan.selected_freqs[0]) if n_freq > 1 else 0.0

    # Delay+apodization as separate per-element arrays. The kernel folds
    # tx_apodization into the initial value and steps phase by 2*pi*freq_step
    # per outer-frequency iteration.
    da_init_re = cp.ascontiguousarray((cp.cos(2 * pi * freq_start * delays_clean) * tx_apodization).astype(cp.float32))
    da_init_im = cp.ascontiguousarray((cp.sin(2 * pi * freq_start * delays_clean) * tx_apodization).astype(cp.float32))
    dps = cp.ascontiguousarray((2 * pi * freq_step * delays_clean).astype(cp.float32))
    if int(da_init_re.size) != n_elem or int(dps.size) != n_elem:
        msg = (
            f"delays_clean and tx_apodization must give one value per element ({n_elem}); "
            f"got {int(dps.size)} delays and {int(da_init_re.size)} delay/apodization values."
        )
        raise ValueError(msg)

    # Pulse * probe (complex), and probe magnitude separately for the RX leg.
    pulse_probe = cast(cp.ndarray, plan.pulse_spectrum * plan.probe_spectrum).astype(cp.complex64)
    pp_re = cp.ascontiguousarray(cp.real(pulse_probe).astype(cp.float32))
    pp_im = cp.ascontiguousarray(cp.imag(pulse_probe).astype(cp.float32))
    probe_real = cp.ascontiguousarray(cp.asarray(plan.probe_spectrum).astype(cp.float32))

    # Linear arrays have params.radius == inf. Pass 1e31 instead, which is
    # finite as fp32 (FLT_MAX ~ 3.4e38).
    # NOTE(review): an earlier comment said this keeps the kernel's
    # `radius * radius` finite in fp32, but 1e31**2 = 1e62 > FLT_MAX
    # overflows to inf. Check simus_fused.cu before changing the value;
    # results may depend on it.
    radius_v = params.radius if params.radius != inf else 1e31

    # Pad scatterers to a multiple of B_SCAT. The kernel processes
    # B_SCAT scatterers per block and, when actual_b < B_SCAT, leaves
    # shmem GEO_* slots for si in [actual_b, B_SCAT) uninitialized.
    # Phase 3's cmul(cv=0, garbage) then produces NaN if the garbage is
    # NaN. Padding with valid positions (copies of the first scatterer) and
    # rc=0 makes Phase 1 populate all si slots while contributing zero to the
    # spectrum (rc=0 zeros tk in Phase 2, and the GEO progression stays finite).
    n_scat_padded = ((n_scat + _B_SCAT - 1) // _B_SCAT) * _B_SCAT
    if n_scat_padded > n_scat:
        pad = n_scat_padded - n_scat
        scat_x = cp.concatenate(
            [scatterers[:, 0].astype(cp.float32), cp.repeat(scatterers[:1, 0].astype(cp.float32), pad)],
        )
        scat_z = cp.concatenate(
            [scatterers[:, 1].astype(cp.float32), cp.repeat(scatterers[:1, 1].astype(cp.float32), pad)],
        )
        rc_padded = cp.concatenate([rc.astype(cp.float32), cp.zeros(pad, dtype=cp.float32)])
    else:
        scat_x = scatterers[:, 0].astype(cp.float32)
        scat_z = scatterers[:, 1].astype(cp.float32)
        rc_padded = rc.astype(cp.float32)

    return {
        "scat_x": cp.ascontiguousarray(scat_x),
        "scat_z": cp.ascontiguousarray(scat_z),
        "rc": cp.ascontiguousarray(rc_padded),
        "elem_x": cp.ascontiguousarray(elem_pos[:, 0].astype(cp.float32)),
        "elem_z": cp.ascontiguousarray(elem_pos[:, 1].astype(cp.float32)),
        "cos_te": cos_te,
        "sin_neg_te": sin_neg_te,
        "sub_dx": sub_dx,
        "sub_dz": sub_dz,
        "da_init_re": da_init_re,
        "da_init_im": da_init_im,
        "dps": dps,
        "pp_re": pp_re,
        "pp_im": pp_im,
        "probe_real": probe_real,
        "n_scat": n_scat_padded,
        "kw_init": 2 * pi * freq_start / c,
        "alpha_init": alpha / NEPER_TO_DB * freq_start / 1e6 * 1e2,
        "kw_step": 2 * pi * freq_step / c,
        "alpha_step": alpha / NEPER_TO_DB * freq_step / 1e6 * 1e2,
        "min_dist": c / params.freq_center / 2.0,
        "seg_len": plan.seg_length,
        "center_kw": 2 * pi * params.freq_center / c,
        "inv_nsub": 1.0 / n_sub,
        "radius_v": radius_v,
        "apex_offset": apex_offset,
        "n_elem": n_elem,
        "n_sub": n_sub,
        "n_freq": n_freq,
    }
|
|
230
|
+
|
|
231
|
+
|
|
232
|
+
def simus_cuda(
    scatterers: Any,
    rc: Any,
    params: TransducerParams,
    plan: SimusPlan,
    medium: MediumParams,
    delays_clean: Any,
    tx_apodization: Any,
) -> Any:
    """Compute simus RF spectrum using the v25c CUDA kernel via CuPy/NVRTC.

    A single fused TX+RX kernel covers all scatterers in one launch, so no
    host-side chunking is needed.

    Args:
        scatterers: Scatterer positions (x, z) in meters. Shape ``(n_scat, 2)``.
        rc: Reflection coefficients. Shape ``(n_scat,)``.
        params: Transducer parameters.
        plan: Precomputed frequency plan from ``simus_precompute``.
        medium: Medium parameters.
        delays_clean: NaN-cleaned delays. Shape ``(n_elements,)``.
        tx_apodization: Per-element apodization (NaN-zeroed). Shape ``(n_elements,)``.

    Returns:
        Complex RF spectrum, shape ``(n_freq, n_elements)``, dtype ``complex64``.

    Raises:
        RuntimeError: If the kernel's shared-memory requirement exceeds
            ``_MAX_DYNAMIC_SHMEM_BYTES``.
    """
    inputs = _prepare_inputs(scatterers, rc, delays_clean, tx_apodization, plan, params, medium)
    n_elem = inputs["n_elem"]
    n_sub = inputs["n_sub"]
    n_freq = inputs["n_freq"]

    shmem = _shmem_bytes(n_elem, n_sub)
    if shmem > _MAX_DYNAMIC_SHMEM_BYTES:
        msg = (
            f"v25c shmem {shmem} B exceeds the {_MAX_DYNAMIC_SHMEM_BYTES} B "
            f"per-block cap for (n_elem={n_elem}, n_sub={n_sub}); needs a "
            f"smaller B_SCAT or a different probe."
        )
        raise RuntimeError(msg)

    kernel = _get_kernel(n_elem, n_sub, n_freq)
    if shmem > _DEFAULT_SHMEM_CAP_BYTES:
        # Opt this kernel in to more than the default 48 KB of dynamic shared memory.
        kernel.max_dynamic_shared_size_bytes = shmem

    # Flat accumulators; the kernel atomicAdds into index elem*N_FREQ + f.
    spect_re = cp.zeros(n_elem * n_freq, dtype=cp.float32)
    spect_im = cp.zeros(n_elem * n_freq, dtype=cp.float32)

    # Kernel parameter order: 15 input arrays, 2 output arrays, n_scat as
    # int32, then 10 fp32 scalars.
    input_arrays = (
        "scat_x",
        "scat_z",
        "rc",
        "elem_x",
        "elem_z",
        "cos_te",
        "sin_neg_te",
        "sub_dx",
        "sub_dz",
        "da_init_re",
        "da_init_im",
        "dps",
        "pp_re",
        "pp_im",
        "probe_real",
    )
    fp32_scalars = (
        "kw_init",
        "alpha_init",
        "kw_step",
        "alpha_step",
        "min_dist",
        "seg_len",
        "center_kw",
        "inv_nsub",
        "radius_v",
        "apex_offset",
    )
    args = (
        *(inputs[name] for name in input_arrays),
        spect_re,
        spect_im,
        cp.int32(inputs["n_scat"]),
        *(cp.float32(inputs[name]) for name in fp32_scalars),
    )

    kernel(
        grid=(_GRID_BLOCKS, 1, 1),
        block=(_TG_SIZE, 1, 1),
        args=args,
        shared_mem=shmem,
    )

    # The kernel's layout is row-major (n_elem, n_freq). Transpose it to
    # (n_freq, n_elem) to match the metal_simus / _simus_freq_outer_python
    # output convention.
    combined = spect_re + 1j * spect_im
    return combined.reshape(n_elem, n_freq).T.astype(cp.complex64)
|