FastSIMUS 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,210 @@
1
+ """Loop drivers for simus frequency sweep (Layer 3).
2
+
3
+ Each driver iterates _simus_freq_step_body() using a different mechanism:
4
+
5
+ - _simus_freq_outer_python: Python for-loop (NumPy/CuPy, constant memory)
6
+ - _simus_freq_outer_scan: JAX lax.scan for O(1) compilation cost
7
+
8
+ The simus step body differs from pfield's: instead of accumulating |P_k|^2
9
+ per grid point, it computes the full TX->scatter->RX chain and accumulates
10
+ complex RF spectrum per element.
11
+ """
12
+
13
+ from __future__ import annotations
14
+
15
+ from math import pi
16
+
17
+ import array_api_extra as xpx
18
+ from jaxtyping import Bool, Complex, Float
19
+
20
+ from fast_simus.utils._array_api import Array, _ArrayNamespace
21
+
22
+
23
def _simus_freq_step_body(
    phase: Complex[Array, "n_scat n_elem n_sub"],
    phase_step: Complex[Array, "n_scat n_elem n_sub"],
    delay_apod_phase: Complex[Array, " n_elem"],
    delay_apod_step: Complex[Array, " n_elem"],
    rc: Float[Array, " n_scat"],
    pulse_probe_k: complex | Array,
    probe_k: float | Array,
    is_out: Bool[Array, " n_scat"],
    xp: _ArrayNamespace,
    *,
    directivity_k: Float[Array, "n_scat n_elem n_sub"] | None = None,
) -> tuple[
    Complex[Array, "n_scat n_elem n_sub"],
    Complex[Array, " n_elem"],
    Complex[Array, " n_elem"],
]:
    """Advance the simus sweep by a single frequency bin.

    Runs the transmit -> scatter -> receive chain for the current frequency,
    then advances both geometric progressions so that the next call sees the
    phases of the following frequency.

    Args:
        phase: Current sub-element phase tensor, shape (n_scat, n_elem, n_sub).
        phase_step: Factor multiplied into ``phase`` after this step.
        delay_apod_phase: Current per-element delay/apodization phasor.
        delay_apod_step: Factor multiplied into ``delay_apod_phase`` after this step.
        rc: Reflection coefficient of each scatterer.
        pulse_probe_k: Pulse spectrum times probe response at this frequency
            (the transmit weight).
        probe_k: Probe response at this frequency (the receive weight).
        is_out: True for scatterers whose incident pressure is forced to zero.
        xp: Array namespace used for every array operation.
        directivity_k: Optional per-sub-element weight multiplied into
            ``phase`` before averaging over the sub-element axis.

    Returns:
        ``(next_phase, next_delay_apod_phase, spect_k)`` where ``spect_k`` is
        this frequency's complex RF spectrum, shape (n_elem,).
    """
    # Collapse the sub-element axis (optionally directivity-weighted).
    weighted_phase = phase if directivity_k is None else phase * directivity_k
    rp_mono = xp.mean(weighted_phase, axis=-1)

    # Transmit: contracting over elements gives the incident pressure per scatterer.
    incident = (rp_mono @ delay_apod_phase[..., None])[..., 0]
    incident = pulse_probe_k * incident
    incident = xp.where(is_out, xp.asarray(0.0 + 0j), incident)

    # Receive: each element sums rc_i * p_i * rp_mono[i, e] over scatterers i.
    received = (rc * incident) @ rp_mono
    spect_k = probe_k * received

    return phase * phase_step, delay_apod_phase * delay_apod_step, spect_k
78
+
79
+
80
def _simus_freq_outer_python(
    phase_init: Complex[Array, "n_scat n_elem n_sub"],
    phase_step: Complex[Array, "n_scat n_elem n_sub"],
    delay_apod_init: Complex[Array, " n_elem"],
    delay_apod_step: Complex[Array, " n_elem"],
    rc: Float[Array, " n_scat"],
    is_out: Bool[Array, " n_scat"],
    wavenumbers: Float[Array, " n_freq"],
    pulse_spect: Complex[Array, " n_freq"],
    probe_spect: Float[Array, " n_freq"],
    seg_length: float,
    sin_theta: Float[Array, "n_scat n_elem n_sub"],
    full_frequency_directivity: bool,
    xp: _ArrayNamespace,
) -> Complex[Array, "n_freq n_elem"]:
    """Sweep all frequencies with a plain Python loop (NumPy/CuPy backends).

    Each iteration calls ``_simus_freq_step_body`` once and stores its
    (n_elem,) spectrum in row ``k`` of the output. Only one frequency's
    working set is alive at a time, so peak memory is
    O(n_scat * n_elem * n_sub), independent of n_freq.

    When ``full_frequency_directivity`` is true, the sub-element directivity
    is recomputed for every frequency from ``wavenumbers[k]``,
    ``seg_length`` and ``sin_theta``; otherwise no directivity is applied.

    Returns:
        Complex spectrum of shape (n_freq, n_elem), allocated with the dtype
        of ``phase_init``.
    """
    tx_weights = pulse_spect * probe_spect
    num_freqs = int(wavenumbers.shape[0])
    result = xp.zeros((num_freqs, phase_init.shape[1]), dtype=phase_init.dtype)

    phase, delay_apod = phase_init, delay_apod_init
    for k in range(num_freqs):
        directivity_k = None
        if full_frequency_directivity:
            # array_api_extra's sinc is the normalized sin(pi x)/(pi x), hence the / pi.
            sinc_arg = wavenumbers[k] * seg_length / 2.0 * sin_theta / pi
            directivity_k = xpx.sinc(sinc_arg, xp=xp)
        phase, delay_apod, spect_k = _simus_freq_step_body(
            phase,
            phase_step,
            delay_apod,
            delay_apod_step,
            rc,
            tx_weights[k],
            probe_spect[k],
            is_out,
            xp,
            directivity_k=directivity_k,
        )
        result = _set_row(result, k, spect_k)

    return result
141
+
142
+
143
def _set_row(
    arr: Complex[Array, "n_freq n_elem"],
    k: int,
    row: Complex[Array, " n_elem"],
) -> Complex[Array, "n_freq n_elem"]:
    """Return ``arr`` with row ``k`` replaced by ``row``.

    The write goes through ``array_api_extra.at`` so that it also works on
    backends with immutable arrays (e.g. JAX). Callers must always use the
    returned array rather than relying on ``arr`` being mutated.
    """
    row_index = (k, slice(None))
    return xpx.at(arr)[row_index].set(row)  # type: ignore[attr-defined]
150
+
151
+
152
def _simus_freq_outer_scan(
    phase_init: Complex[Array, "n_scat n_elem n_sub"],
    phase_step: Complex[Array, "n_scat n_elem n_sub"],
    delay_apod_init: Complex[Array, " n_elem"],
    delay_apod_step: Complex[Array, " n_elem"],
    rc: Float[Array, " n_scat"],
    is_out: Bool[Array, " n_scat"],
    wavenumbers: Float[Array, " n_freq"],
    pulse_spect: Complex[Array, " n_freq"],
    probe_spect: Float[Array, " n_freq"],
    seg_length: float,
    sin_theta: Float[Array, "n_scat n_elem n_sub"],
    full_frequency_directivity: bool,
    xp: _ArrayNamespace,
) -> Complex[Array, "n_freq n_elem"]:
    """JAX ``lax.scan`` driver: scan over frequencies with full tensor carry.

    The carry holds ``(phase, delay_apod_phase)`` with shapes
    (n_scat, n_elem, n_sub) and (n_elem,). Each scan step delegates to
    ``_simus_freq_step_body``, the same per-frequency kernel used by
    ``_simus_freq_outer_python``, so both drivers share a single
    implementation of the TX -> scatter -> RX chain. Each step emits
    ``spect_k`` of shape (n_elem,), which scan stacks into (n_freq, n_elem).

    ``full_frequency_directivity`` is a Python bool, so the directivity branch
    is resolved while tracing and does not appear in the compiled loop.
    """
    import jax

    tx_weights = pulse_spect * probe_spect

    def scan_fn(carry, xs):
        phase, delay_apod = carry
        spectrum_k, probe_k, wavenum_k = xs
        directivity_k = None
        if full_frequency_directivity:
            # array_api_extra's sinc is the normalized sin(pi x)/(pi x), hence the / pi.
            sinc_arg = wavenum_k * seg_length / 2.0 * sin_theta / pi
            directivity_k = xpx.sinc(sinc_arg, xp=xp)
        phase, delay_apod, spect_k = _simus_freq_step_body(
            phase,
            phase_step,
            delay_apod,
            delay_apod_step,
            rc,
            spectrum_k,
            probe_k,
            is_out,
            xp,
            directivity_k=directivity_k,
        )
        return (phase, delay_apod), spect_k

    init_carry = (phase_init, delay_apod_init)
    # wavenumbers is always scanned (unused when directivity is off) so one
    # scan function serves both modes.
    _, spect_all = jax.lax.scan(scan_fn, init_carry, (tx_weights, probe_spect, wavenumbers))
    return spect_all
@@ -0,0 +1 @@
1
+ """Backend-specific compatibility modules."""
@@ -0,0 +1,101 @@
1
+ """Array API compatibility shim for MLX.
2
+
3
+ Temporary until array_api_compat gains native MLX support.
4
+ Tracking: https://github.com/data-apis/array-api-compat/issues/162
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import contextlib
10
+ from typing import Any
11
+
12
+ import array_api_compat
13
+ import array_api_compat.common._helpers as _helpers
14
+
15
# Array API standard name -> existing mlx.core attribute to alias it to.
# _patch_namespace installs an alias only when the standard name is missing
# from the namespace and the mlx-side name exists.
_MLX_ARRAY_API_ALIASES: dict[str, str] = {
    "asin": "arcsin",
    "acos": "arccos",
    "atan2": "arctan2",
    "bool": "bool_",
}

# Array API ``isdtype`` kind string -> name of the mlx.core dtype category
# that the shim ``isdtype`` passes to ``xp.issubdtype``. Kind strings not
# listed here are rejected with ValueError by the shim.
_MLX_ISDTYPE_KIND_MAP: dict[str, str] = {
    "bool": "bool_",
    "signed integer": "signedinteger",
    "unsigned integer": "unsignedinteger",
    "integral": "integer",
    "real floating": "floating",
    "complex floating": "complexfloating",
    "numeric": "number",
}
31
+
32
+
33
def _make_isdtype(xp: Any) -> Any:
    """Build an Array API ``isdtype`` for ``xp`` on top of ``xp.issubdtype``.

    The returned callable accepts ``kind`` as:
      * a standard kind string, translated through ``_MLX_ISDTYPE_KIND_MAP``
        (unknown strings raise ``ValueError``);
      * a tuple of kinds, matching if any member matches;
      * anything else, compared to ``dtype`` with ``==``.
    """

    def isdtype(dtype: Any, kind: Any) -> bool:
        if isinstance(kind, tuple):
            return any(isdtype(dtype, sub_kind) for sub_kind in kind)
        if not isinstance(kind, str):
            return dtype == kind
        category_name = _MLX_ISDTYPE_KIND_MAP.get(kind)
        if category_name is None:
            msg = f"Unrecognized dtype kind: {kind!r}"
            raise ValueError(msg)
        return bool(xp.issubdtype(dtype, getattr(xp, category_name)))

    return isdtype
46
+
47
+
48
def _patch_namespace(xp: Any) -> None:
    """Fill Array API gaps in the mlx.core namespace (safe to call repeatedly).

    Installs, only where missing:
      * the standard-name aliases listed in ``_MLX_ARRAY_API_ALIASES``;
      * ``isdtype``, built from ``xp.issubdtype``;
      * ``astype(x, dtype, /, *, copy=False)``, which delegates to
        ``x.astype(dtype)`` and ignores ``copy``.
    Also wraps ``asarray`` exactly once (tracked by a marker attribute) so it
    accepts and discards every keyword argument except ``dtype``.
    """
    for std_name, native_name in _MLX_ARRAY_API_ALIASES.items():
        if hasattr(xp, std_name) or not hasattr(xp, native_name):
            continue
        setattr(xp, std_name, getattr(xp, native_name))

    if not hasattr(xp, "isdtype") and hasattr(xp, "issubdtype"):
        xp.isdtype = _make_isdtype(xp)

    if not hasattr(xp, "astype"):

        def _astype(x: Any, dtype: Any, /, *, copy: bool = False) -> Any:
            return x.astype(dtype)

        xp.astype = _astype

    if getattr(xp.asarray, "_fastsimus_wrapped", False):
        return

    native_asarray = xp.asarray

    def _asarray(a: Any, *, dtype: Any = None, **_kwargs: Any) -> Any:
        # Only forward dtype when given; other keywords are dropped.
        return native_asarray(a) if dtype is None else native_asarray(a, dtype=dtype)

    _asarray._fastsimus_wrapped = True  # type: ignore[attr-defined]
    xp.asarray = _asarray
74
+
75
+
76
def _patch_device(xp: Any) -> None:
    """Make ``device()`` lookups return ``xp.default_device()`` for MLX arrays.

    Replaces array_api_compat's ``device`` helper (both the
    ``common._helpers.device`` binding and the top-level
    ``array_api_compat.device``) with a wrapper. When it is importable,
    array_api_extra's internal compat ``device`` is replaced too. The wrapper
    returns ``xp.default_device()`` for any object whose type lives in a
    module starting with ``mlx`` and defers to the previous helper otherwise.
    The wrapper carries a marker attribute, so a second call is a no-op.
    """
    previous = _helpers.device
    if getattr(previous, "_fastsimus_mlx", False):
        return

    def _device_with_mlx(x: Any, /) -> Any:
        is_mlx_array = type(x).__module__.startswith("mlx")
        return xp.default_device() if is_mlx_array else previous(x)

    _device_with_mlx._fastsimus_mlx = True  # type: ignore[attr-defined]

    _helpers.device = _device_with_mlx  # type: ignore[assignment]
    array_api_compat.device = _device_with_mlx  # type: ignore[assignment]

    # array_api_extra is optional here; skip silently if its compat module is absent.
    with contextlib.suppress(ImportError):
        import array_api_extra._lib._utils._compat as _xpx_compat  # type: ignore[import-untyped]

        _xpx_compat.device = _device_with_mlx
96
+
97
+
98
def ensure_compat(xp: Any) -> None:
    """Make ``xp`` (mlx.core) usable through array_api_compat / array_api_extra.

    Applies the namespace patch first and then the ``device()`` patch. Both
    are idempotent, so calling this more than once is harmless.
    """
    for apply_patch in (_patch_namespace, _patch_device):
        apply_patch(xp)
@@ -0,0 +1,9 @@
1
+ """Backend-specific fused kernels for FastSIMUS.
2
+
3
+ Custom kernels provide maximum performance by fusing the entire computation
4
+ into a single GPU dispatch. Each kernel is a different algorithm from the
5
+ Array API path (e.g., on-the-fly geometry instead of precomputed arrays).
6
+
7
+ Available kernels:
8
+ - metal_pfield: Apple Silicon Metal kernel for pfield (requires MLX)
9
+ """
@@ -0,0 +1,321 @@
1
+ """CuPy CUDA backend for simus.
2
+
3
+ Compiles the v25c register-resident TX kernel via NVRTC at runtime
4
+ (``cupy.RawModule``) -- no nanobind, no setuptools build step. Pinned to
5
+ ``(B_SCAT=10, ELEM_TILE=2)`` for RTX 4090 / sm_89 / P4-2v; performance may
6
+ regress on other probes / GPUs (see exp22 + the FastSIMUS-cuda-tune
7
+ follow-up).
8
+
9
+ Output layout matches ``metal_simus.simus_metal``: complex64
10
+ ``(n_freq, n_elements)``. The shipped kernel does its own per-scatterer
11
+ Phase-1 geometry from a flat input set, so the
12
+ ``(n_scat, n_elem, n_sub)`` phase tensors that ``_simus_freq_outer_python``
13
+ consumes are *not* fed in here.
14
+
15
+ Requires: CuPy on a CUDA host. Use ``cupy-cuda12x`` for CUDA 12/Pascal
16
+ hosts and ``cupy-cuda13x`` for CUDA 13/Turing-or-newer hosts.
17
+ """
18
+
19
+ from __future__ import annotations
20
+
21
+ from math import inf, pi
22
+ from pathlib import Path
23
+ from typing import TYPE_CHECKING, Any, cast
24
+
25
+ import cupy as cp
26
+
27
+ from fast_simus._pfield_math import NEPER_TO_DB, _subelement_centroids
28
+ from fast_simus.medium_params import MediumParams
29
+ from fast_simus.transducer_params import TransducerParams
30
+ from fast_simus.utils._array_api import _ArrayNamespace
31
+ from fast_simus.utils.geometry import element_positions
32
+
33
+ if TYPE_CHECKING:
34
+ from fast_simus.simus import SimusPlan
35
+
36
# Directory holding the CUDA kernel sources, which are compiled from text at runtime.
_KERNELS_DIR = Path(__file__).parent
_SOURCE_NAME = "simus_fused.cu"

# Pinned tuning -- see docs/progress/experiments/exp22-svshmem-et2.md.
# These constants are RTX 4090 / sm_89 / P4-2v optimal; not autotuned.
_B_SCAT = 10  # scatterers per block; n_scat is also padded to a multiple of this
_ELEM_TILE = 2
_TG_SIZE = 128  # threads per block (launch block size)
_TILE_SE = 16
_GRID_BLOCKS = 256  # 2 * 128 SMs on RTX 4090

# CuPy / NVRTC auto-derives ``--gpu-architecture`` from the current device,
# so we don't pin it here. Tuning constants (B_SCAT, ELEM_TILE, TG_SIZE)
# are still hardwired for sm_89 and may need adjustment for sm_80 / sm_90.

# Without an explicit opt-in, a kernel may use at most 48 KB of dynamic
# shared memory per block. Some probes need more: for L11-5v (assuming its
# 128 elements) with n_sub=2, _shmem_bytes gives 73,216 B (~72 KB). The
# per-kernel cap is raised on demand (cuFuncSetAttribute, exposed by CuPy as
# RawKernel.max_dynamic_shared_size_bytes). The opt-in ceiling is
# device-dependent: 64 KB on sm_75, 99 KB on sm_86 / sm_89, 163 KB on sm_80,
# 227 KB on sm_90. 96 KB is therefore a valid upper bound for the sm_89
# target, but it exceeds what sm_75 allows.
_DEFAULT_SHMEM_CAP_BYTES = 48 * 1024
_MAX_DYNAMIC_SHMEM_BYTES = 96 * 1024

# Process-wide memo tables: kernel source text by filename, and compiled
# kernels keyed by (n_elem, n_sub, n_freq).
_source_cache: dict[str, str] = {}
_kernel_cache: dict[tuple[int, int, int], Any] = {}
60
+
61
+
62
def _load_source(filename: str) -> str:
    """Return the text of kernel source ``filename`` from the kernels directory.

    The file is read from disk at most once per process; later calls are
    served from ``_source_cache``.
    """
    text = _source_cache.get(filename)
    if text is None:
        text = _source_cache[filename] = (_KERNELS_DIR / filename).read_text()
    return text
66
+
67
+
68
def _shmem_bytes(n_elem: int, n_sub: int) -> int:
    """Dynamic shared-memory footprint, in bytes, of the v25c kernel.

    Layout (see ``simus_fused.cu``): 7 * B_SCAT * N_ES floats of TX/RX
    geometry plus 3 * N_ELEM floats of per-element broadcast data
    (da_init_re, da_init_im, dps), where N_ES = n_elem * n_sub and each
    float takes 4 bytes.
    """
    bytes_per_float = 4
    geometry_floats = 7 * _B_SCAT * (n_elem * n_sub)
    broadcast_floats = 3 * n_elem
    return bytes_per_float * (geometry_floats + broadcast_floats)
77
+
78
+
79
def _get_kernel(n_elem: int, n_sub: int, n_freq: int) -> Any:
    """Return the compiled ``simus_fused_kernel`` for a problem shape.

    The problem dimensions are baked into the source as ``-D`` macros, so each
    distinct ``(n_elem, n_sub, n_freq)`` triggers one NVRTC compilation, which
    is memoised in ``_kernel_cache``. ``n_scat`` is passed at launch time (the
    kernel grid-strides over scatterers in a single fused launch), so it is
    not part of the cache key.
    """
    cache_key = (n_elem, n_sub, n_freq)
    cached = _kernel_cache.get(cache_key)
    if cached is not None:
        return cached

    # Frequencies handled per thread: ceil(n_freq / TG_SIZE).
    freqs_per_thread = -(-n_freq // _TG_SIZE)
    macros = {
        "N_ELEM": n_elem,
        "N_SUB": n_sub,
        "N_FREQ": n_freq,
        "N_ES": n_elem * n_sub,
        "TILE_SE": _TILE_SE,
        "TG_SIZE": _TG_SIZE,
        "MAX_FPT": freqs_per_thread,
        "B_SCAT": _B_SCAT,
        "ELEM_TILE": _ELEM_TILE,
    }
    options = (
        "--std=c++17",
        "--use_fast_math",
        "--extra-device-vectorization",
        *(f"-D{macro}={value}" for macro, value in macros.items()),
    )

    module = cp.RawModule(
        code=_load_source(_SOURCE_NAME),
        backend="nvrtc",
        options=options,
        name_expressions=("simus_fused_kernel",),
    )
    compiled = module.get_function("simus_fused_kernel")
    _kernel_cache[cache_key] = compiled
    return compiled
117
+
118
+
119
def _prepare_inputs(
    scatterers: Any,
    rc: Any,
    delays_clean: Any,
    tx_apodization: Any,
    plan: SimusPlan,
    params: TransducerParams,
    medium: MediumParams,
) -> dict[str, Any]:
    """Pack the device arrays and scalars the v25c kernel expects.

    The returned dict holds:
      * 15 contiguous float32 arrays: scatterer x/z, reflection coefficients,
        element x/z, cos/sin of element angles, sub-element offsets, the
        delay/apodization phasor and its per-frequency phase step, and the
        pulse*probe and probe spectra;
      * the 11 scalar kernel arguments (``n_scat`` through ``apex_offset``);
      * the host-side sizes ``n_elem`` / ``n_sub`` / ``n_freq`` used for
        compilation and launch.

    ``n_scat`` in the result is the padded scatterer count (a multiple of
    ``_B_SCAT``); the padded entries carry ``rc = 0``.

    Mirrors ``metal_simus._prepare_common`` but without the
    ``(n_scat, n_elem, n_sub)`` expansion: v25c does its own Phase-1
    geometry from the flat per-element / per-sub-element inputs.
    """
    c = medium.speed_of_sound
    alpha = medium.attenuation
    n_elem = params.n_elements
    n_sub = plan.n_sub
    n_freq = int(plan.selected_freqs.shape[0])

    # element_positions with `xp=cp` returns CuPy arrays directly.
    xp_cp = cast(_ArrayNamespace, cp)
    elem_pos, theta_e, apex_offset = element_positions(n_elem, params.pitch, params.radius, xp_cp)
    if theta_e is None:
        # No per-element angles returned: treat every element as untilted.
        theta_e = cp.zeros(n_elem, dtype=cp.float32)

    # Sub-element offsets per (elem, sub) flattened to N_ES with se = elem*n_sub + sub
    # (see kernel line `int elem = se / N_SUB;`).
    offsets = cast(cp.ndarray, _subelement_centroids(params.element_width, n_sub, theta_e, xp_cp))
    sub_dx = cp.ascontiguousarray(offsets[..., 0].reshape(-1).astype(cp.float32))
    sub_dz = cp.ascontiguousarray(offsets[..., 1].reshape(-1).astype(cp.float32))

    cos_te = cp.ascontiguousarray(cp.cos(theta_e).astype(cp.float32))
    sin_neg_te = cp.ascontiguousarray(cp.sin(-theta_e).astype(cp.float32))

    # Frequency-grid scalars.
    # NOTE(review): only the first two selected frequencies are used to derive
    # the step, so this assumes a uniformly spaced grid; an empty plan fails
    # at selected_freqs[0].
    freq_start = float(plan.selected_freqs[0])
    freq_step = float(plan.selected_freqs[1] - plan.selected_freqs[0]) if n_freq > 1 else 0.0

    # Delay+apodization as separate per-element arrays. tx_apodization is
    # folded into the initial phasor here (at freq_start); dps is the phase
    # increment 2*pi*freq_step*delay applied per outer-frequency iteration.
    da_init_re = cp.ascontiguousarray((cp.cos(2 * pi * freq_start * delays_clean) * tx_apodization).astype(cp.float32))
    da_init_im = cp.ascontiguousarray((cp.sin(2 * pi * freq_start * delays_clean) * tx_apodization).astype(cp.float32))
    dps = cp.ascontiguousarray((2 * pi * freq_step * delays_clean).astype(cp.float32))

    # Pulse * probe (complex), and probe magnitude separately for the RX leg.
    pulse_probe = cast(cp.ndarray, plan.pulse_spectrum * plan.probe_spectrum).astype(cp.complex64)
    pp_re = cp.ascontiguousarray(cp.real(pulse_probe).astype(cp.float32))
    pp_im = cp.ascontiguousarray(cp.imag(pulse_probe).astype(cp.float32))
    probe_real = cp.ascontiguousarray(cp.asarray(plan.probe_spectrum).astype(cp.float32))

    # params.radius is inf for linear arrays; replace it with the finite
    # sentinel 1e31 before it is cast to fp32 (1e31 < FLT_MAX ~= 3.4e38).
    # NOTE(review): the square of this sentinel (1e62) is NOT finite in fp32,
    # so a `radius * radius` in the kernel would still overflow to inf.
    # Confirm how simus_fused.cu uses radius_v for linear arrays.
    radius_v = params.radius if params.radius != inf else 1e31

    # Pad scatterers to a multiple of B_SCAT. The kernel processes
    # B_SCAT scatterers per block and, when actual_b < B_SCAT, leaves
    # shmem GEO_* slots for si in [actual_b, B_SCAT) uninitialized.
    # Phase 3's cmul(cv=0, garbage) then produces NaN if the garbage is
    # NaN. Padding with valid positions (copies of the first scatterer) and
    # rc=0 makes Phase 1 populate all si slots while contributing zero to
    # the spectrum (rc=0 zeros tk in Phase 2, and the GEO progression stays
    # finite).
    n_scat = int(scatterers.shape[0])
    n_scat_padded = ((n_scat + _B_SCAT - 1) // _B_SCAT) * _B_SCAT
    if n_scat_padded > n_scat:
        pad = n_scat_padded - n_scat
        scat_x = cp.concatenate(
            [scatterers[:, 0].astype(cp.float32), cp.repeat(scatterers[:1, 0].astype(cp.float32), pad)],
        )
        scat_z = cp.concatenate(
            [scatterers[:, 1].astype(cp.float32), cp.repeat(scatterers[:1, 1].astype(cp.float32), pad)],
        )
        rc_padded = cp.concatenate([rc.astype(cp.float32), cp.zeros(pad, dtype=cp.float32)])
    else:
        scat_x = scatterers[:, 0].astype(cp.float32)
        scat_z = scatterers[:, 1].astype(cp.float32)
        rc_padded = rc.astype(cp.float32)

    return {
        "scat_x": cp.ascontiguousarray(scat_x),
        "scat_z": cp.ascontiguousarray(scat_z),
        "rc": cp.ascontiguousarray(rc_padded),
        "elem_x": cp.ascontiguousarray(elem_pos[:, 0].astype(cp.float32)),
        "elem_z": cp.ascontiguousarray(elem_pos[:, 1].astype(cp.float32)),
        "cos_te": cos_te,
        "sin_neg_te": sin_neg_te,
        "sub_dx": sub_dx,
        "sub_dz": sub_dz,
        "da_init_re": da_init_re,
        "da_init_im": da_init_im,
        "dps": dps,
        "pp_re": pp_re,
        "pp_im": pp_im,
        "probe_real": probe_real,
        "n_scat": n_scat_padded,
        # Wavenumber 2*pi*f/c at the first frequency and its per-step increment.
        "kw_init": 2 * pi * freq_start / c,
        # Attenuation scaled linearly in frequency; presumably alpha is in
        # dB/cm/MHz and this yields Np/m -- confirm the units of medium.attenuation.
        "alpha_init": alpha / NEPER_TO_DB * freq_start / 1e6 * 1e2,
        "kw_step": 2 * pi * freq_step / c,
        "alpha_step": alpha / NEPER_TO_DB * freq_step / 1e6 * 1e2,
        # Half a wavelength at the center frequency (c / fc / 2).
        "min_dist": c / params.freq_center / 2.0,
        "seg_len": plan.seg_length,
        "center_kw": 2 * pi * params.freq_center / c,
        "inv_nsub": 1.0 / n_sub,
        "radius_v": radius_v,
        "apex_offset": apex_offset,
        "n_elem": n_elem,
        "n_sub": n_sub,
        "n_freq": n_freq,
    }
230
+
231
+
232
def _device_shmem_cap_bytes() -> int:
    """Return the per-block dynamic shared-memory limit usable on the current device.

    This is the smaller of the pinned ``_MAX_DYNAMIC_SHMEM_BYTES`` and the
    device's opt-in maximum (``MaxSharedMemoryPerBlockOptin``, e.g. 64 KB on
    sm_75). If that attribute cannot be queried, the pinned cap is returned
    unchanged.
    """
    try:
        optin = int(cp.cuda.Device().attributes["MaxSharedMemoryPerBlockOptin"])
    except (AttributeError, KeyError, cp.cuda.runtime.CUDARuntimeError):
        # Older CuPy / driver without this attribute: keep the pinned cap.
        return _MAX_DYNAMIC_SHMEM_BYTES
    return min(_MAX_DYNAMIC_SHMEM_BYTES, optin)


def simus_cuda(
    scatterers: Any,
    rc: Any,
    params: TransducerParams,
    plan: SimusPlan,
    medium: MediumParams,
    delays_clean: Any,
    tx_apodization: Any,
) -> Any:
    """Compute simus RF spectrum using the v25c CUDA kernel via CuPy/NVRTC.

    Single fused TX+RX kernel that grid-strides over scatterers; no
    chunking is needed since per-thread state is in registers.

    Args:
        scatterers: Scatterer positions (x, z) in meters. Shape ``(n_scat, 2)``.
        rc: Reflection coefficients. Shape ``(n_scat,)``.
        params: Transducer parameters.
        plan: Precomputed frequency plan from ``simus_precompute``.
        medium: Medium parameters.
        delays_clean: NaN-cleaned delays. Shape ``(n_elements,)``.
        tx_apodization: Per-element apodization (NaN-zeroed). Shape ``(n_elements,)``.

    Returns:
        Complex RF spectrum, shape ``(n_freq, n_elements)``, dtype ``complex64``.

    Raises:
        RuntimeError: If the kernel's dynamic shared-memory requirement exceeds
            the per-block limit. The limit is the smaller of
            ``_MAX_DYNAMIC_SHMEM_BYTES`` and the current device's opt-in maximum.
            Checking this up front replaces an opaque launch failure.
    """
    d = _prepare_inputs(scatterers, rc, delays_clean, tx_apodization, plan, params, medium)
    n_elem, n_sub, n_freq = d["n_elem"], d["n_sub"], d["n_freq"]

    shmem = _shmem_bytes(n_elem, n_sub)
    shmem_cap = _device_shmem_cap_bytes()
    if shmem > shmem_cap:
        msg = (
            f"v25c shmem {shmem} B exceeds the {shmem_cap} B "
            f"per-block cap for (n_elem={n_elem}, n_sub={n_sub}); needs a "
            f"smaller B_SCAT or a different probe."
        )
        raise RuntimeError(msg)

    kernel = _get_kernel(n_elem, n_sub, n_freq)
    # Raise the per-kernel dynamic-shmem cap when we exceed the 48 KB default.
    # No-op when shmem fits under _DEFAULT_SHMEM_CAP_BYTES.
    if shmem > _DEFAULT_SHMEM_CAP_BYTES:
        kernel.max_dynamic_shared_size_bytes = shmem

    # Output buffers; the kernel uses atomicAdd into spect_re[elem*N_FREQ + f],
    # so they must start zeroed.
    spect_re = cp.zeros(n_elem * n_freq, dtype=cp.float32)
    spect_im = cp.zeros(n_elem * n_freq, dtype=cp.float32)

    args = (
        d["scat_x"],
        d["scat_z"],
        d["rc"],
        d["elem_x"],
        d["elem_z"],
        d["cos_te"],
        d["sin_neg_te"],
        d["sub_dx"],
        d["sub_dz"],
        d["da_init_re"],
        d["da_init_im"],
        d["dps"],
        d["pp_re"],
        d["pp_im"],
        d["probe_real"],
        spect_re,
        spect_im,
        cp.int32(d["n_scat"]),
        cp.float32(d["kw_init"]),
        cp.float32(d["alpha_init"]),
        cp.float32(d["kw_step"]),
        cp.float32(d["alpha_step"]),
        cp.float32(d["min_dist"]),
        cp.float32(d["seg_len"]),
        cp.float32(d["center_kw"]),
        cp.float32(d["inv_nsub"]),
        cp.float32(d["radius_v"]),
        cp.float32(d["apex_offset"]),
    )

    kernel(
        grid=(_GRID_BLOCKS, 1, 1),
        block=(_TG_SIZE, 1, 1),
        args=args,
        shared_mem=shmem,
    )

    # The kernel writes element-major (index elem*N_FREQ + f): view the buffers
    # as (n_elem, n_freq) and transpose to (n_freq, n_elem) complex64, matching
    # the metal_simus / _simus_freq_outer_python output convention.
    spect = (spect_re + 1j * spect_im).reshape(n_elem, n_freq).T
    return spect.astype(cp.complex64)