FastSIMUS 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fast_simus/__init__.py +33 -0
- fast_simus/_pfield_math.py +261 -0
- fast_simus/_pfield_strategies.py +203 -0
- fast_simus/_simus_strategies.py +210 -0
- fast_simus/backends/__init__.py +1 -0
- fast_simus/backends/mlx.py +101 -0
- fast_simus/kernels/__init__.py +9 -0
- fast_simus/kernels/cuda_simus.py +321 -0
- fast_simus/kernels/metal_pfield.py +219 -0
- fast_simus/kernels/metal_simus.py +377 -0
- fast_simus/kernels/pfield.metal +97 -0
- fast_simus/kernels/simus_fused.cu +332 -0
- fast_simus/kernels/simus_rx_simd.metal +128 -0
- fast_simus/kernels/simus_tx_tiled.metal +175 -0
- fast_simus/medium_params.py +22 -0
- fast_simus/pfield.py +475 -0
- fast_simus/py.typed +0 -0
- fast_simus/simus.py +567 -0
- fast_simus/spectrum.py +107 -0
- fast_simus/transducer_params.py +160 -0
- fast_simus/transducer_presets.py +102 -0
- fast_simus/tx_delay.py +276 -0
- fast_simus/utils/__init__.py +5 -0
- fast_simus/utils/_array_api.py +294 -0
- fast_simus/utils/geometry.py +88 -0
- fastsimus-0.0.1.dist-info/METADATA +594 -0
- fastsimus-0.0.1.dist-info/RECORD +28 -0
- fastsimus-0.0.1.dist-info/WHEEL +4 -0
|
@@ -0,0 +1,219 @@
|
|
|
1
|
+
"""Custom Metal kernel for pfield computation on Apple Silicon.
|
|
2
|
+
|
|
3
|
+
Fuses geometry, phase initialization, and frequency sweep into a single
|
|
4
|
+
GPU kernel. One thread per grid point computes the full pressure contribution
|
|
5
|
+
on-the-fly, avoiding large intermediate arrays.
|
|
6
|
+
|
|
7
|
+
Requires: MLX (mlx package) on Apple Silicon.
|
|
8
|
+
|
|
9
|
+
Limitations:
|
|
10
|
+
- Soft baffle only (BaffleType.SOFT assumed)
|
|
11
|
+
- Center-frequency directivity only (full_frequency_directivity=False)
|
|
12
|
+
- Linear arrays only (convex array support needs testing)
|
|
13
|
+
"""
|
|
14
|
+
|
|
15
|
+
from __future__ import annotations
|
|
16
|
+
|
|
17
|
+
from math import inf, pi
|
|
18
|
+
from pathlib import Path
|
|
19
|
+
from typing import TYPE_CHECKING, Any, cast
|
|
20
|
+
|
|
21
|
+
import mlx.core as mx
|
|
22
|
+
|
|
23
|
+
from fast_simus._pfield_math import NEPER_TO_DB, _subelement_centroids
|
|
24
|
+
from fast_simus.medium_params import MediumParams
|
|
25
|
+
from fast_simus.transducer_params import TransducerParams
|
|
26
|
+
from fast_simus.utils._array_api import Array, _ArrayNamespace
|
|
27
|
+
from fast_simus.utils.geometry import element_positions
|
|
28
|
+
|
|
29
|
+
if TYPE_CHECKING:
|
|
30
|
+
from fast_simus.pfield import PfieldPlan
|
|
31
|
+
|
|
32
|
+
_metal_source_cache: str | None = None
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def _load_kernel_source() -> str:
|
|
36
|
+
"""Read the Metal kernel body from ``pfield.metal`` (cached)."""
|
|
37
|
+
global _metal_source_cache
|
|
38
|
+
if _metal_source_cache is None:
|
|
39
|
+
_metal_source_cache = (Path(__file__).parent / "pfield.metal").read_text()
|
|
40
|
+
return _metal_source_cache
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
# Compiled pfield kernels keyed by (n_elem, n_sub, n_freq).
_kernel_cache: dict[tuple[int, int, int], Any] = {}


def build_pfield_kernel(n_elem: int, n_sub: int, n_freq: int) -> Any:
    """Return the pfield Metal kernel specialised for the given sizes.

    The sizes are baked into the kernel as ``#define`` constants, so one
    kernel object is created per distinct ``(n_elem, n_sub, n_freq)`` and
    memoised in ``_kernel_cache``.

    Args:
        n_elem: Number of transducer elements.
        n_sub: Number of sub-elements per element.
        n_freq: Number of frequency samples.

    Returns:
        The kernel callable produced by ``mx.fast.metal_kernel``.
    """
    cache_key = (n_elem, n_sub, n_freq)
    if cache_key not in _kernel_cache:
        defines = {
            "N_ELEM": n_elem,
            "N_SUB": n_sub,
            "N_FREQ": n_freq,
            "N_ES": n_elem * n_sub,
        }
        _kernel_cache[cache_key] = mx.fast.metal_kernel(
            name=f"pfield_{n_elem}_{n_sub}_{n_freq}",
            input_names=[
                "grid_x",
                "grid_z",
                "elem_x",
                "elem_z",
                "theta_e",
                "sub_dx",
                "sub_dz",
                "da_init_re",
                "da_init_im",
                "da_step_re",
                "da_step_im",
                "pp_mag_sq",
                "is_out",
                "scalars",
            ],
            output_names=["pressure"],
            header="".join(f"#define {macro} {value}\n" for macro, value in defines.items()),
            source=_load_kernel_source(),
        )
    return _kernel_cache[cache_key]
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
def pfield_metal(
    positions: mx.array,
    params: TransducerParams,
    plan: PfieldPlan,
    medium: MediumParams,
    delays_clean: mx.array,
    tx_apodization: mx.array,
) -> mx.array:
    """Evaluate the pfield pressure accumulation with the fused Metal kernel.

    Each GPU thread owns one grid point and derives its geometry on the fly,
    so no ``(*grid, n_elements, n_sub)`` intermediates are materialised. The
    value returned is the raw accumulation (frequency sum of ``|P_k|^2``
    scaled by ``1 / n_sub**2``), not its square root.

    Args:
        positions: Grid coordinates (x, z) in meters, shape ``(*grid_shape, 2)``.
        params: Transducer parameters.
        plan: Frequency plan produced by ``pfield_precompute``.
        medium: Medium parameters.
        delays_clean: Transmit delays with NaNs removed, shape ``(n_elements,)``.
        tx_apodization: Per-element apodization with NaNs zeroed,
            shape ``(n_elements,)``.

    Returns:
        Raw pressure accumulation of shape ``(*grid_shape,)``. The caller
        must apply ``xp.sqrt(result)`` to obtain the RMS pressure.
    """
    f32 = mx.float32
    xp_mx = cast(_ArrayNamespace, mx)
    sound_speed = medium.speed_of_sound
    n_elem = params.n_elements
    n_sub = plan.n_sub
    freqs = plan.selected_freqs
    n_freq = int(freqs.shape[0])
    grid_shape = positions.shape[:-1]

    # --- Transducer geometry -------------------------------------------------
    elem_pos, theta_e, apex_offset = element_positions(n_elem, params.pitch, params.radius, xp_mx)
    if theta_e is None:
        # No element angles reported: treat every element as untilted.
        theta_e = mx.zeros(n_elem, dtype=f32)

    # Sub-element centroid offsets, flattened to (n_elem * n_sub,).
    centroids = _subelement_centroids(params.element_width, n_sub, cast("Array", theta_e), xp_mx)
    sub_dx = cast(mx.array, centroids[..., 0]).reshape(-1)
    sub_dz = cast(mx.array, centroids[..., 1]).reshape(-1)

    # --- Grid points and out-of-field mask (1.0 = outside, 0.0 = inside) -----
    x_flat = positions[..., 0].reshape(-1)
    z_flat = positions[..., 1].reshape(-1)
    is_out = (z_flat < 0).astype(f32)
    if params.radius != inf:
        inside_arc = (x_flat**2 + (z_flat + apex_offset) ** 2) <= params.radius**2
        is_out = mx.maximum(is_out, inside_arc.astype(f32))

    # --- Uniform frequency axis taken from the plan --------------------------
    freq_start = float(freqs[0])
    freq_step = 0.0
    if n_freq > 1:
        freq_step = float(freqs[1] - freqs[0])

    # --- Delay/apodization phasors: value at f0 and per-step rotation --------
    phase_init = mx.array(2.0 * pi * freq_start, dtype=f32) * delays_clean
    da_init_re = (mx.cos(phase_init) * tx_apodization).astype(f32)
    da_init_im = (mx.sin(phase_init) * tx_apodization).astype(f32)
    phase_step = mx.array(2.0 * pi * freq_step, dtype=f32) * delays_clean
    da_step_re = mx.cos(phase_step).astype(f32)
    da_step_im = mx.sin(phase_step).astype(f32)

    # |pulse_spectrum * probe_spectrum|^2 per frequency.
    pulse = cast(mx.array, plan.pulse_spectrum)
    probe = cast(mx.array, plan.probe_spectrum)
    pp_mag_sq = mx.abs(pulse).astype(f32) ** 2 * probe.astype(f32) ** 2

    # --- Packed scalar parameters (index order is read by the kernel) --------
    two_pi = 2.0 * pi
    neper_scale = medium.attenuation / NEPER_TO_DB
    scalars = mx.array(
        [
            two_pi * freq_start / sound_speed,  # [0] wavenumber at f0
            neper_scale * freq_start / 1e6 * 1e2,  # [1] attenuation at f0
            two_pi * freq_step / sound_speed,  # [2] wavenumber step
            neper_scale * freq_step / 1e6 * 1e2,  # [3] attenuation step
            sound_speed / params.freq_center / 2.0,  # [4] min distance (half wavelength)
            plan.seg_length,  # [5] segment length
            two_pi * params.freq_center / sound_speed,  # [6] center wavenumber
            # [7] 1/n_sub^2: the kernel sums (not averages) over sub-elements;
            # correction_factor is applied by the caller for all strategies.
            1.0 / (n_sub**2),
        ],
        dtype=f32,
    )

    # --- Dispatch: one thread per grid point ---------------------------------
    n_grid = int(x_flat.shape[0])
    kernel = build_pfield_kernel(n_elem, n_sub, n_freq)
    kernel_inputs = [
        x_flat.astype(f32),
        z_flat.astype(f32),
        elem_pos[:, 0].astype(f32),
        elem_pos[:, 1].astype(f32),
        theta_e.astype(f32),
        sub_dx.astype(f32),
        sub_dz.astype(f32),
        da_init_re,
        da_init_im,
        da_step_re,
        da_step_im,
        pp_mag_sq,
        is_out.astype(f32),
        scalars,
    ]
    (pressure,) = kernel(
        inputs=kernel_inputs,
        output_shapes=[(n_grid,)],
        output_dtypes=[f32],
        grid=(n_grid, 1, 1),
        threadgroup=(256, 1, 1),
    )

    # Raw accumulation (acc / n_sub^2); the caller applies
    # sqrt(pressure_accum * correction_factor) uniformly for all strategies.
    return pressure.reshape(grid_shape)
|
|
@@ -0,0 +1,377 @@
|
|
|
1
|
+
"""Custom Metal kernel for simus RF spectrum on Apple Silicon.
|
|
2
|
+
|
|
3
|
+
Two-kernel architecture for optimal GPU occupancy:
|
|
4
|
+
- Kernel A (TX): Element-tiled progression with shared-memory geometry.
|
|
5
|
+
One threadgroup per scatterer; threads cooperatively compute geometry,
|
|
6
|
+
then each thread processes sub-element tiles with ALU-only geometric
|
|
7
|
+
progression. TILE_SE=16, threadgroup=64.
|
|
8
|
+
- Kernel B (RX): SIMD-reduce RX with SCAT_REDUCE scatterers per
|
|
9
|
+
threadgroup. Adjacent SIMD threads handle the same element from
|
|
10
|
+
different scatterers and use simd_shuffle_xor to sum contributions
|
|
11
|
+
before a single atomic write. Cuts atomic ops by SCAT_REDUCE (2x)
|
|
12
|
+
while preserving coalesced output access.
|
|
13
|
+
Threadgroup size = N_ELEM * SCAT_REDUCE (128 for P4-2v with SR=2).
|
|
14
|
+
|
|
15
|
+
For large scatterer counts, scatterers are processed in chunks that fit
|
|
16
|
+
within ``MAX_TX_INTERMEDIATE_BYTES``, with the split-path spectrum
|
|
17
|
+
accumulated across chunks via simple addition.
|
|
18
|
+
|
|
19
|
+
Requires: MLX (mlx package) on Apple Silicon.
|
|
20
|
+
|
|
21
|
+
Limitations:
|
|
22
|
+
- Soft baffle only (BaffleType.SOFT assumed)
|
|
23
|
+
- Center-frequency directivity only (full_frequency_directivity=False)
|
|
24
|
+
- Linear arrays only (convex array support needs testing)
|
|
25
|
+
"""
|
|
26
|
+
|
|
27
|
+
from __future__ import annotations
|
|
28
|
+
|
|
29
|
+
from math import inf, pi
|
|
30
|
+
from pathlib import Path
|
|
31
|
+
from typing import TYPE_CHECKING, Any, cast
|
|
32
|
+
|
|
33
|
+
import mlx.core as mx
|
|
34
|
+
|
|
35
|
+
from fast_simus._pfield_math import NEPER_TO_DB, _subelement_centroids
|
|
36
|
+
from fast_simus.medium_params import MediumParams
|
|
37
|
+
from fast_simus.transducer_params import TransducerParams
|
|
38
|
+
from fast_simus.utils._array_api import Array, _ArrayNamespace
|
|
39
|
+
from fast_simus.utils.geometry import element_positions
|
|
40
|
+
|
|
41
|
+
if TYPE_CHECKING:
|
|
42
|
+
from fast_simus.simus import SimusPlan
|
|
43
|
+
|
|
44
|
+
# Directory holding the .metal kernel sources loaded by _load_source().
_KERNELS_DIR = Path(__file__).parent

# Upper bound on the TX intermediate buffer (re + im per scatterer/frequency).
MAX_TX_INTERMEDIATE_BYTES = 256 << 20  # 256 MiB

# TX kernel tiling: sub-elements per tile and threads per threadgroup.
_TX_TILE_SE = 16
_TX_TILE_TG = 64
# Scatterers reduced together per RX threadgroup.
_RX_SCAT_REDUCE = 2

# Per-thread TX register footprint is TILE_SE * 2 * 8 bytes (256 B at
# TILE_SE=16), well within Apple Silicon's register budget. Chunk sizes below
# are the TX-throughput-optimal scatterer counts, keyed by element count.
_TX_OPTIMAL_CHUNK: dict[int, int] = {
    64: 10_000,  # P4-2v class
    128: 5_000,  # L11-5v class
}
_TX_DEFAULT_CHUNK = 10_000
|
|
59
|
+
|
|
60
|
+
# ---------------------------------------------------------------------------
# Source caching
# ---------------------------------------------------------------------------
# Kernel source text keyed by file name (relative to _KERNELS_DIR).
_source_cache: dict[str, str] = {}


def _load_source(filename: str) -> str:
    """Return the text of kernel file *filename*, reading it at most once."""
    text = _source_cache.get(filename)
    if text is None:
        text = (_KERNELS_DIR / filename).read_text()
        _source_cache[filename] = text
    return text
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
# ---------------------------------------------------------------------------
# Kernel builders (cached by dimension tuple)
# ---------------------------------------------------------------------------
# Compiled kernels keyed by a tuple starting with a kernel tag ("tx_tiled",
# "rx_simd") followed by the dimensions baked into that kernel.
_kernel_cache: dict[tuple, Any] = {}


def _make_header(n_elem: int, n_sub: int, n_freq: int, n_scat: int) -> str:
    """Return the ``#define`` preamble shared by the TX and RX kernels."""
    defines = (
        ("N_ELEM", n_elem),
        ("N_SUB", n_sub),
        ("N_FREQ", n_freq),
        ("N_ES", n_elem * n_sub),
        ("N_SCAT", n_scat),
    )
    return "".join(f"#define {macro} {value}\n" for macro, value in defines)
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
def _build_tx(n_elem: int, n_sub: int, n_freq: int, n_scat: int) -> Any:
    """Return the tiled TX kernel (element-tiled progression with shared geometry).

    The kernel is specialised on all four sizes plus the TX tiling constants,
    and memoised in ``_kernel_cache`` under a ``"tx_tiled"`` key.
    """
    cache_key = ("tx_tiled", n_elem, n_sub, n_freq, n_scat)
    kernel = _kernel_cache.get(cache_key)
    if kernel is None:
        tg_size = _TX_TILE_TG
        tiling_defines = (
            f"#define TILE_SE {_TX_TILE_SE}\n"
            f"#define TG_SIZE {tg_size}\n"
            # Frequencies each thread handles (ceil(n_freq / TG_SIZE)).
            f"#define MAX_FPT (({n_freq} + {tg_size} - 1) / {tg_size})\n"
        )
        kernel = mx.fast.metal_kernel(
            name=f"simus_tx_tiled_{n_elem}_{n_sub}_{n_freq}_{n_scat}",
            input_names=[
                "scat_x",
                "scat_z",
                "elem_x",
                "elem_z",
                "theta_e",
                "sub_dx",
                "sub_dz",
                "da_init_re",
                "da_init_im",
                "delay_phase_step",
                "pp_re",
                "pp_im",
                "is_out",
                "scalars",
            ],
            output_names=["tx_re", "tx_im"],
            header=_make_header(n_elem, n_sub, n_freq, n_scat) + tiling_defines,
            source=_load_source("simus_tx_tiled.metal"),
        )
        _kernel_cache[cache_key] = kernel
    return kernel
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
def _build_rx(n_elem: int, n_sub: int, n_freq: int, n_scat: int) -> Any:
    """Return the SIMD-reduce RX kernel for the given sizes.

    Each threadgroup handles SCAT_REDUCE scatterers. Adjacent threads take the
    same element from different scatterers and combine their contributions
    with simd_shuffle_xor before issuing a single atomic write. This divides
    the number of atomic writes by SCAT_REDUCE while keeping output access
    coalesced. The kernel writes to atomic outputs and is memoised in
    ``_kernel_cache``.
    """
    reduce_n = _RX_SCAT_REDUCE
    cache_key = ("rx_simd", n_elem, n_sub, n_freq, n_scat, reduce_n)
    kernel = _kernel_cache.get(cache_key)
    if kernel is None:
        kernel = mx.fast.metal_kernel(
            name=f"simus_rx_simd_{n_elem}_{n_sub}_{n_freq}_{n_scat}_{reduce_n}",
            input_names=[
                "scat_x",
                "scat_z",
                "elem_x",
                "elem_z",
                "theta_e",
                "sub_dx",
                "sub_dz",
                "tx_re",
                "tx_im",
                "probe",
                "rc",
                "scalars",
            ],
            output_names=["spect_re", "spect_im"],
            header=_make_header(n_elem, n_sub, n_freq, n_scat) + f"#define SCAT_REDUCE {reduce_n}\n",
            source=_load_source("simus_rx_simd.metal"),
            atomic_outputs=True,
        )
        _kernel_cache[cache_key] = kernel
    return kernel
|
|
158
|
+
|
|
159
|
+
|
|
160
|
+
# ---------------------------------------------------------------------------
# Input preparation
# ---------------------------------------------------------------------------
def _prepare_common(
    scatterers: mx.array,
    rc: mx.array,
    params: TransducerParams,
    plan: SimusPlan,
    medium: MediumParams,
    delays_clean: mx.array,
    tx_apodization: mx.array,
) -> dict[str, Any]:
    """Assemble every GPU-side input the TX/RX kernels need.

    Returns a dict with:
    - float32 arrays: scatterer coordinates, element and sub-element geometry,
      delay/apodization phasors, pulse*probe and probe spectra, reflection
      coefficients, the out-of-field mask, and the packed ``scalars`` vector;
    - the integer sizes ``n_elem``, ``n_sub``, ``n_freq`` and ``n_scat``.
    """
    f32 = mx.float32
    xp_mx = cast(_ArrayNamespace, mx)
    sound_speed = medium.speed_of_sound
    n_elem = params.n_elements
    n_sub = plan.n_sub
    freqs = plan.selected_freqs
    n_freq = int(freqs.shape[0])
    n_scat = int(scatterers.shape[0])

    # Element positions/angles and flattened sub-element centroid offsets.
    elem_pos, theta_e, apex_offset = element_positions(n_elem, params.pitch, params.radius, xp_mx)
    if theta_e is None:
        theta_e = mx.zeros(n_elem, dtype=f32)
    centroids = _subelement_centroids(params.element_width, n_sub, cast("Array", theta_e), xp_mx)
    sub_dx = cast(mx.array, centroids[..., 0]).reshape(-1)
    sub_dz = cast(mx.array, centroids[..., 1]).reshape(-1)

    # Out-of-field mask: 1.0 for scatterers above z=0 or inside the convex arc.
    scat_x = scatterers[:, 0]
    scat_z = scatterers[:, 1]
    is_out = (scat_z < 0).astype(f32)
    if params.radius != inf:
        inside_arc = (scat_x**2 + (scat_z + apex_offset) ** 2) <= params.radius**2
        is_out = mx.maximum(is_out, inside_arc.astype(f32))

    freq_start = float(freqs[0])
    freq_step = 0.0
    if n_freq > 1:
        freq_step = float(freqs[1] - freqs[0])

    # Delay/apodization phasor at f0, and the raw delay phase increment per step.
    two_pi = 2.0 * pi
    phase_init = mx.array(two_pi * freq_start, dtype=f32) * delays_clean
    da_init_re = (mx.cos(phase_init) * tx_apodization).astype(f32)
    da_init_im = (mx.sin(phase_init) * tx_apodization).astype(f32)
    phase_step = mx.array(two_pi * freq_step, dtype=f32) * delays_clean

    pulse = cast(mx.array, plan.pulse_spectrum)
    probe = cast(mx.array, plan.probe_spectrum)
    pulse_probe = pulse * probe

    # Packed scalar parameters; index order is what the kernels read.
    neper_scale = medium.attenuation / NEPER_TO_DB
    scalars = mx.array(
        [
            two_pi * freq_start / sound_speed,  # [0] wavenumber at f0
            neper_scale * freq_start / 1e6 * 1e2,  # [1] attenuation at f0
            two_pi * freq_step / sound_speed,  # [2] wavenumber step
            neper_scale * freq_step / 1e6 * 1e2,  # [3] attenuation step
            sound_speed / params.freq_center / 2.0,  # [4] min distance
            plan.seg_length,  # [5] segment length
            two_pi * params.freq_center / sound_speed,  # [6] center wavenumber
            1.0 / n_sub,  # [7] inverse sub-element count
        ],
        dtype=f32,
    )

    return {
        "x_flat": scat_x.astype(f32),
        "z_flat": scat_z.astype(f32),
        "elem_x": elem_pos[:, 0].astype(f32),
        "elem_z": elem_pos[:, 1].astype(f32),
        "theta_e": theta_e.astype(f32),
        "sub_dx": sub_dx.astype(f32),
        "sub_dz": sub_dz.astype(f32),
        "da_init_re": da_init_re,
        "da_init_im": da_init_im,
        "delay_phase_step": phase_step.astype(f32),
        "pp_re": mx.real(pulse_probe).astype(f32),
        "pp_im": mx.imag(pulse_probe).astype(f32),
        "probe_real": probe.astype(f32),
        "rc": rc.astype(f32),
        "is_out": is_out,
        "scalars": scalars,
        "n_elem": n_elem,
        "n_sub": n_sub,
        "n_freq": n_freq,
        "n_scat": n_scat,
    }
|
|
261
|
+
|
|
262
|
+
|
|
263
|
+
# ---------------------------------------------------------------------------
# Dispatch
# ---------------------------------------------------------------------------
def _pad_tail(values: mx.array, size: int, fill: float | None = None) -> mx.array:
    """Extend the 1-D array ``values`` to ``size`` entries.

    New entries are set to ``fill``. When ``fill`` is None they repeat the
    last existing entry, so padded coordinates stay finite and equal to a
    real scatterer's. Arrays that are already ``size`` long are returned
    unchanged.
    """
    missing = size - int(values.shape[0])
    if missing <= 0:
        return values
    if fill is None:
        tail = mx.broadcast_to(values[-1:], (missing,))
    else:
        tail = mx.full((missing,), fill, dtype=values.dtype)
    return mx.concatenate([values, tail])


def _dispatch_split(d: dict[str, Any]) -> mx.array:
    """Two-kernel path with automatic chunking for large scatterer counts.

    Scatterers are processed in chunks that fit the TX intermediate buffer
    within ``MAX_TX_INTERMEDIATE_BYTES``. Each chunk runs TX then RX, and the
    per-chunk spectra are summed.

    Both kernels are compiled with ``N_SCAT = chunk_size``. To keep every
    dispatch consistent with that compile-time constant, each chunk
    (including the final, partial one) is padded to exactly ``chunk_size``
    scatterers. ``chunk_size`` is also kept a multiple of ``SCAT_REDUCE``, so
    the RX grid covers the chunk with no spare scatterer slots.

    Padding entries reuse the last real scatterer's position, with a
    reflection coefficient of 0 and ``is_out = 1``. NOTE(review): this
    assumes the kernels scale contributions by ``rc`` and zero out-of-field
    scatterers, as the input names suggest; confirm against
    ``simus_tx_tiled.metal`` / ``simus_rx_simd.metal``.

    Args:
        d: Inputs produced by ``_prepare_common``.

    Returns:
        Complex64 RF spectrum of shape ``(n_freq, n_elem)``.
    """
    n_elem, n_sub, n_freq, n_scat = d["n_elem"], d["n_sub"], d["n_freq"], d["n_scat"]
    spect_size = n_freq * n_elem
    sr = _RX_SCAT_REDUCE

    if n_scat == 0:
        # No scatterers: nothing to compile or dispatch, the spectrum is silent.
        return mx.zeros((n_freq, n_elem), dtype=mx.complex64)

    # Use TX-throughput-optimal chunk sizes, capped by the memory budget and
    # by the problem size (rounded up to a multiple of SCAT_REDUCE), so small
    # inputs are not padded up to a full default-sized chunk.
    bytes_per_scat = n_freq * 4 * 2  # float32 re + im
    mem_chunk = max(1, MAX_TX_INTERMEDIATE_BYTES // bytes_per_scat)
    perf_chunk = _TX_OPTIMAL_CHUNK.get(n_elem, _TX_DEFAULT_CHUNK)
    scat_rounded = -(-n_scat // sr) * sr
    chunk_size = min(mem_chunk, perf_chunk, scat_rounded)
    # Round down to a multiple of SCAT_REDUCE, keeping at least one RX group.
    chunk_size = max(sr, chunk_size - chunk_size % sr)

    # Geometry arrays shared across all chunks
    geom_tx = [
        d["elem_x"],
        d["elem_z"],
        d["theta_e"],
        d["sub_dx"],
        d["sub_dz"],
        d["da_init_re"],
        d["da_init_im"],
        d["delay_phase_step"],
        d["pp_re"],
        d["pp_im"],
    ]
    geom_rx = [d["elem_x"], d["elem_z"], d["theta_e"], d["sub_dx"], d["sub_dz"]]
    probe = d["probe_real"]
    scalars = d["scalars"]

    # Kernels are specialised on chunk_size and memoised by the builders.
    k_tx = _build_tx(n_elem, n_sub, n_freq, chunk_size)
    k_rx = _build_rx(n_elem, n_sub, n_freq, chunk_size)

    tg = _TX_TILE_TG
    rx_tg = n_elem * sr
    n_rx_groups = chunk_size // sr  # exact: chunk_size is a multiple of sr
    tx_len = chunk_size * n_freq

    total_re = mx.zeros(spect_size, dtype=mx.float32)
    total_im = mx.zeros(spect_size, dtype=mx.float32)

    for start in range(0, n_scat, chunk_size):
        end = min(start + chunk_size, n_scat)

        # Every dispatch sees exactly chunk_size (== N_SCAT) scatterers.
        cx = _pad_tail(d["x_flat"][start:end], chunk_size)
        cz = _pad_tail(d["z_flat"][start:end], chunk_size)
        crc = _pad_tail(d["rc"][start:end], chunk_size, fill=0.0)
        c_out = _pad_tail(d["is_out"][start:end], chunk_size, fill=1.0)

        # TX kernel: one threadgroup per scatterer (tiled progression)
        tx_out = k_tx(
            inputs=[cx, cz, *geom_tx, c_out, scalars],
            output_shapes=[(tx_len,), (tx_len,)],
            output_dtypes=[mx.float32, mx.float32],
            grid=(chunk_size * tg, 1, 1),
            threadgroup=(tg, 1, 1),
        )

        # RX kernel: SCAT_REDUCE scatterers per threadgroup, SIMD reduction,
        # atomically accumulated into zero-initialised outputs.
        rx_out = k_rx(
            inputs=[cx, cz, *geom_rx, tx_out[0], tx_out[1], probe, crc, scalars],
            output_shapes=[(spect_size,), (spect_size,)],
            output_dtypes=[mx.float32, mx.float32],
            grid=(n_rx_groups * rx_tg, 1, 1),
            threadgroup=(rx_tg, 1, 1),
            init_value=0.0,
        )

        total_re = total_re + rx_out[0]
        total_im = total_im + rx_out[1]

    spect_re = total_re.reshape(n_freq, n_elem)
    spect_im = total_im.reshape(n_freq, n_elem)
    return (spect_re + 1j * spect_im).astype(mx.complex64)
|
|
344
|
+
|
|
345
|
+
|
|
346
|
+
# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------
def simus_metal(
    scatterers: mx.array,
    rc: mx.array,
    params: TransducerParams,
    plan: SimusPlan,
    medium: MediumParams,
    delays_clean: mx.array,
    tx_apodization: mx.array,
) -> mx.array:
    """Compute the simus RF spectrum with the custom Metal TX/RX kernels.

    Inputs are first converted to the float32 GPU layout the kernels expect,
    then dispatched through the chunked two-kernel (TX then RX) path, which
    adds up the spectra of successive scatterer chunks.

    Args:
        scatterers: Scatterer positions (x, z) in meters. Shape ``(n_scat, 2)``.
        rc: Reflection coefficients. Shape ``(n_scat,)``.
        params: Transducer parameters.
        plan: Precomputed frequency plan from ``simus_precompute``.
        medium: Medium parameters.
        delays_clean: NaN-cleaned delays. Shape ``(n_elements,)``.
        tx_apodization: Per-element apodization (NaN-zeroed). Shape ``(n_elements,)``.

    Returns:
        Complex RF spectrum, shape ``(n_freq, n_elements)``.
    """
    gpu_inputs = _prepare_common(
        scatterers=scatterers,
        rc=rc,
        params=params,
        plan=plan,
        medium=medium,
        delays_clean=delays_clean,
        tx_apodization=tx_apodization,
    )
    return _dispatch_split(gpu_inputs)
|
|
@@ -0,0 +1,97 @@
|
|
|
1
|
+
// Kernel body for pfield pressure-field computation via mx.fast.metal_kernel().
//
// This file contains ONLY the kernel body -- the code that runs inside the
// auto-generated [[kernel]] void ...() { ... } wrapper. mx.fast.metal_kernel()
// injects input/output buffer parameters automatically based on input_names
// and output_names.
//
// Compile-time constants (injected via header=):
//   N_ELEM -- number of transducer elements
//   N_SUB  -- number of sub-elements per element
//   N_FREQ -- number of frequency samples
//   N_ES   -- N_ELEM * N_SUB (total element-subelement pairs)
//
// One thread handles one grid point g. For every element/sub-element pair it
// builds the complex phasor at the first frequency (cur) and the multiplicative
// per-frequency step (stp). It then sweeps N_FREQ frequencies by repeated
// complex multiplication, accumulating pp_mag_sq[f] * |sum_j cur_j|^2.

uint g = thread_position_in_grid.x;

// Points flagged as outside the field always yield zero pressure, so skip the
// O(N_ES * N_FREQ) work for them. An early return is safe here because this
// kernel uses no threadgroup barriers or SIMD-group operations.
if (is_out[g] > 0.5f) {
    pressure[g] = 0.0f;
    return;
}

float gx = grid_x[g];
float gz = grid_z[g];

// Packed scalar parameters (index layout fixed by the host-side `scalars`).
float kw_init = scalars[0];     // wavenumber at the first frequency
float alpha_init = scalars[1];  // attenuation at the first frequency
float kw_step = scalars[2];     // wavenumber increment per frequency step
float alpha_step = scalars[3];  // attenuation increment per frequency step
float min_dist = scalars[4];    // lower clamp on the propagation distance
float seg_len = scalars[5];     // segment length used by the sinc directivity
float center_kw = scalars[6];   // center-frequency wavenumber (directivity)
float eff_corr = scalars[7];    // final scale factor applied to acc

const float TWO_PI = 2.0f * M_PI_F;

float2 cur[N_ES];
float2 stp[N_ES];

for (int e = 0; e < N_ELEM; e++) {
    float ex = elem_x[e];
    float ez = elem_z[e];
    float te = theta_e[e];
    float di_re = da_init_re[e], di_im = da_init_im[e];
    float ds_re = da_step_re[e], ds_im = da_step_im[e];

    for (int s = 0; s < N_SUB; s++) {
        int idx = e * N_SUB + s;

        float dx = gx - ex - sub_dx[idx];
        float dz = gz - ez - sub_dz[idx];
        float r = metal::precise::sqrt(dx * dx + dz * dz);
        float rc = max(r, min_dist);

        // Angle relative to element normal (unclipped distance for angle).
        // Clamp the ratio so rounding can never push it outside asin's
        // domain [-1, 1], which would produce NaN pressure.
        float sin_ratio = clamp((dx + 1e-16f) / (r + 1e-16f), -1.0f, 1.0f);
        float th = metal::precise::asin(sin_ratio) - te;

        // Soft baffle obliquity
        float obliq = (fabs(th) >= M_PI_2_F) ? 1e-16f : metal::precise::cos(th);

        // Phase init: obliq/sqrt(r) * exp(-alpha*r + j*wrap(k*r, 2pi))
        float kwr = kw_init * rc;
        float ph_wrap = kwr - TWO_PI * metal::precise::floor(kwr / TWO_PI);
        float ai = obliq / metal::precise::sqrt(rc) * metal::precise::exp(-alpha_init * rc);
        float2 pi_ = float2(ai * metal::precise::cos(ph_wrap),
                            ai * metal::precise::sin(ph_wrap));

        // Phase step: exp((-alpha_step + j*k_step) * r)
        float as_ = metal::precise::exp(-alpha_step * rc);
        float phs = kw_step * rc;
        float2 ps_ = float2(as_ * metal::precise::cos(phs),
                            as_ * metal::precise::sin(phs));

        // Center-frequency sinc directivity
        float sa = center_kw * seg_len * 0.5f * metal::precise::sin(th);
        float sv = (fabs(sa) < 1e-8f) ? 1.0f : metal::precise::sin(sa) / sa;
        pi_ *= sv;

        // Absorb delay+apodization (complex multiply)
        cur[idx] = float2(
            pi_.x * di_re - pi_.y * di_im,
            pi_.x * di_im + pi_.y * di_re
        );
        stp[idx] = float2(
            ps_.x * ds_re - ps_.y * ds_im,
            ps_.x * ds_im + ps_.y * ds_re
        );
    }
}

// Frequency sweep: accumulate sum_f |pulse_probe_f|^2 * |sum_es phase_es_f|^2
float acc = 0.0f;
for (int f = 0; f < N_FREQ; f++) {
    float sr = 0.0f, si = 0.0f;
    for (int j = 0; j < N_ES; j++) {
        sr += cur[j].x;
        si += cur[j].y;
        // Advance this pair's phasor to the next frequency.
        float cr = cur[j].x, ci = cur[j].y;
        float tr = stp[j].x, ti = stp[j].y;
        cur[j] = float2(cr * tr - ci * ti, cr * ti + ci * tr);
    }
    acc += pp_mag_sq[f] * (sr * sr + si * si);
}

pressure[g] = acc * eff_corr;
|