FastSIMUS 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,219 @@
1
+ """Custom Metal kernel for pfield computation on Apple Silicon.
2
+
3
+ Fuses geometry, phase initialization, and frequency sweep into a single
4
+ GPU kernel. One thread per grid point computes the full pressure contribution
5
+ on-the-fly, avoiding large intermediate arrays.
6
+
7
+ Requires: MLX (mlx package) on Apple Silicon.
8
+
9
+ Limitations:
10
+ - Soft baffle only (BaffleType.SOFT assumed)
11
+ - Center-frequency directivity only (full_frequency_directivity=False)
12
+ - Linear arrays only (convex array support needs testing)
13
+ """
14
+
15
+ from __future__ import annotations
16
+
17
+ from math import inf, pi
18
+ from pathlib import Path
19
+ from typing import TYPE_CHECKING, Any, cast
20
+
21
+ import mlx.core as mx
22
+
23
+ from fast_simus._pfield_math import NEPER_TO_DB, _subelement_centroids
24
+ from fast_simus.medium_params import MediumParams
25
+ from fast_simus.transducer_params import TransducerParams
26
+ from fast_simus.utils._array_api import Array, _ArrayNamespace
27
+ from fast_simus.utils.geometry import element_positions
28
+
29
+ if TYPE_CHECKING:
30
+ from fast_simus.pfield import PfieldPlan
31
+
32
+ _metal_source_cache: str | None = None
33
+
34
+
35
+ def _load_kernel_source() -> str:
36
+ """Read the Metal kernel body from ``pfield.metal`` (cached)."""
37
+ global _metal_source_cache
38
+ if _metal_source_cache is None:
39
+ _metal_source_cache = (Path(__file__).parent / "pfield.metal").read_text()
40
+ return _metal_source_cache
41
+
42
+
43
# Compiled pfield kernels keyed by (n_elem, n_sub, n_freq). The sizes are baked
# into the kernel source through #define lines, so each distinct combination
# needs its own compiled kernel.
_kernel_cache: dict[tuple[int, int, int], Any] = {}


def build_pfield_kernel(n_elem: int, n_sub: int, n_freq: int) -> Any:
    """Return the compiled pfield Metal kernel for the given dimensions.

    The first request for a ``(n_elem, n_sub, n_freq)`` triple compiles the
    kernel body returned by ``_load_kernel_source`` with ``N_ELEM``, ``N_SUB``,
    ``N_FREQ`` and ``N_ES = n_elem * n_sub`` defined in its header. Later
    requests for the same triple reuse the cached kernel object.

    Args:
        n_elem: Number of transducer elements.
        n_sub: Number of sub-elements per element.
        n_freq: Number of frequency samples.

    Returns:
        Compiled Metal kernel callable produced by ``mx.fast.metal_kernel``.
    """
    dims = (n_elem, n_sub, n_freq)
    try:
        return _kernel_cache[dims]
    except KeyError:
        pass

    compile_time_defines = (
        f"#define N_ELEM {n_elem}\n"
        f"#define N_SUB {n_sub}\n"
        f"#define N_FREQ {n_freq}\n"
        f"#define N_ES {n_elem * n_sub}\n"
    )
    # Buffer names must match the identifiers used inside the kernel body.
    buffer_names = [
        "grid_x",
        "grid_z",
        "elem_x",
        "elem_z",
        "theta_e",
        "sub_dx",
        "sub_dz",
        "da_init_re",
        "da_init_im",
        "da_step_re",
        "da_step_im",
        "pp_mag_sq",
        "is_out",
        "scalars",
    ]
    compiled = mx.fast.metal_kernel(
        name=f"pfield_{n_elem}_{n_sub}_{n_freq}",
        input_names=buffer_names,
        output_names=["pressure"],
        header=compile_time_defines,
        source=_load_kernel_source(),
    )
    _kernel_cache[dims] = compiled
    return compiled
87
+
88
+
89
def pfield_metal(
    positions: mx.array,
    params: TransducerParams,
    plan: PfieldPlan,
    medium: MediumParams,
    delays_clean: mx.array,
    tx_apodization: mx.array,
) -> mx.array:
    """Compute pressure field using a custom Metal kernel.

    Computes geometry on-the-fly per grid point, avoiding large intermediate
    arrays (*grid, n_elements, n_sub). The result is the raw pressure
    accumulation, NOT the final sqrt; the caller applies sqrt after the
    dispatch block. The accumulation is, for each grid point:

    - for every frequency, |pulse * probe|^2 times the squared magnitude of the
      summed element/sub-element phasors;
    - summed over all frequencies;
    - divided by n_sub^2.

    Args:
        positions: Grid positions (x, z) in meters. Shape ``(*grid_shape, 2)``.
        params: Transducer parameters.
        plan: Precomputed frequency plan from ``pfield_precompute``.
        medium: Medium parameters.
        delays_clean: NaN-cleaned delays. Shape ``(n_elements,)``.
        tx_apodization: Per-element apodization (NaN-zeroed). Shape ``(n_elements,)``.

    Returns:
        Raw float32 pressure accumulation, shape ``(*grid_shape,)``.

        - Grid points with z < 0 are 0. For a finite ``params.radius``, points
          inside the array's arc are also 0.
        - An empty grid or an empty frequency set yields zeros without
          compiling or launching the kernel.

        Caller must apply ``xp.sqrt(result)`` to get RMS pressure.
    """
    c = medium.speed_of_sound
    alpha = medium.attenuation
    n_elem = params.n_elements
    n_sub = plan.n_sub
    n_freq = int(plan.selected_freqs.shape[0])
    grid_shape = positions.shape[:-1]

    x_flat = positions[..., 0].reshape(-1)
    z_flat = positions[..., 1].reshape(-1)
    n_grid = int(x_flat.shape[0])

    # Degenerate inputs. With no grid points there is nothing to compute, and
    # with no frequencies the sum over frequencies is identically zero.
    # Returning early avoids indexing selected_freqs[0] on an empty plan and
    # avoids a zero-thread kernel dispatch.
    if n_grid == 0 or n_freq == 0:
        return mx.zeros(grid_shape, dtype=mx.float32)

    xp_mx = cast(_ArrayNamespace, mx)

    # Element geometry (positions, per-element angle, apex offset for curved arrays)
    elem_pos, theta_e, apex_offset = element_positions(
        n_elem,
        params.pitch,
        params.radius,
        xp_mx,
    )
    if theta_e is None:
        theta_e = mx.zeros(n_elem, dtype=mx.float32)

    # Subelement offsets -- reuse shared geometry, reshape to flat (n_elem*n_sub,)
    offsets = _subelement_centroids(params.element_width, n_sub, cast("Array", theta_e), xp_mx)
    sub_dx = cast(mx.array, offsets[..., 0]).reshape(-1)
    sub_dz = cast(mx.array, offsets[..., 1]).reshape(-1)

    # is_out mask (float32: 1.0=out, 0.0=in): z < 0, or inside the arc of a curved array
    is_out = (z_flat < 0).astype(mx.float32)
    if params.radius != inf:
        in_arc = (x_flat**2 + (z_flat + apex_offset) ** 2) <= params.radius**2
        is_out = mx.maximum(is_out, in_arc.astype(mx.float32))

    # Derive freq_start / freq_step from the canonical selected_freqs array
    # (the kernel assumes uniformly spaced frequencies).
    freq_start = float(plan.selected_freqs[0])
    freq_step = float(plan.selected_freqs[1] - plan.selected_freqs[0]) if n_freq > 1 else 0.0

    # Delay+apodization phasor at the first frequency, split into real/imag
    ph_init = mx.array(2.0 * pi * freq_start, dtype=mx.float32) * delays_clean
    da_init_re = (mx.cos(ph_init) * tx_apodization).astype(mx.float32)
    da_init_im = (mx.sin(ph_init) * tx_apodization).astype(mx.float32)

    # Per-frequency-step delay phasor increment
    ph_step = mx.array(2.0 * pi * freq_step, dtype=mx.float32) * delays_clean
    da_step_re = mx.cos(ph_step).astype(mx.float32)
    da_step_im = mx.sin(ph_step).astype(mx.float32)

    # |pulse_spectrum|^2 * |probe_spectrum|^2. abs() on the probe keeps this a
    # true squared magnitude even if the probe spectrum is complex (astype to
    # float32 would silently drop the imaginary part); for a real probe it is
    # bit-identical to squaring the value directly.
    _pulse = cast(mx.array, plan.pulse_spectrum)
    _probe = cast(mx.array, plan.probe_spectrum)
    pp_mag_sq = mx.abs(_pulse).astype(mx.float32) ** 2 * mx.abs(_probe).astype(mx.float32) ** 2

    # Scalar physics parameters (order must match the kernel's scalars[0..7] reads)
    wavenumber_init = 2.0 * pi * freq_start / c
    attenuation_init = alpha / NEPER_TO_DB * freq_start / 1e6 * 1e2
    wavenumber_step = 2.0 * pi * freq_step / c
    attenuation_step = alpha / NEPER_TO_DB * freq_step / 1e6 * 1e2
    min_distance = c / params.freq_center / 2.0  # distance clamp: half the center wavelength
    center_wavenumber = 2.0 * pi * params.freq_center / c
    # 1/n_sub^2 because kernel sums (not means) over sub-elements.
    # correction_factor is applied by the caller uniformly across all strategies.
    effective_correction = 1.0 / (n_sub**2)

    scalars = mx.array(
        [
            wavenumber_init,
            attenuation_init,
            wavenumber_step,
            attenuation_step,
            min_distance,
            plan.seg_length,
            center_wavenumber,
            effective_correction,
        ],
        dtype=mx.float32,
    )

    # Build (cached) kernel and dispatch one thread per grid point
    kernel = build_pfield_kernel(n_elem, n_sub, n_freq)

    outputs = kernel(
        inputs=[
            x_flat.astype(mx.float32),
            z_flat.astype(mx.float32),
            elem_pos[:, 0].astype(mx.float32),
            elem_pos[:, 1].astype(mx.float32),
            theta_e.astype(mx.float32),
            sub_dx.astype(mx.float32),
            sub_dz.astype(mx.float32),
            da_init_re,
            da_init_im,
            da_step_re,
            da_step_im,
            pp_mag_sq,
            is_out.astype(mx.float32),
            scalars,
        ],
        output_shapes=[(n_grid,)],
        output_dtypes=[mx.float32],
        grid=(n_grid, 1, 1),
        threadgroup=(256, 1, 1),
    )

    # Return raw accumulation (acc / n_sub^2). The caller applies
    # sqrt(pressure_accum * correction_factor) uniformly for all strategies.
    return outputs[0].reshape(grid_shape)
@@ -0,0 +1,377 @@
1
+ """Custom Metal kernel for simus RF spectrum on Apple Silicon.
2
+
3
+ Two-kernel architecture for optimal GPU occupancy:
4
+ - Kernel A (TX): Element-tiled progression with shared-memory geometry.
5
+ One threadgroup per scatterer; threads cooperatively compute geometry,
6
+ then each thread processes sub-element tiles with ALU-only geometric
7
+ progression. TILE_SE=16, threadgroup=64.
8
+ - Kernel B (RX): SIMD-reduce RX with SCAT_REDUCE scatterers per
9
+ threadgroup. Adjacent SIMD threads handle the same element from
10
+ different scatterers and use simd_shuffle_xor to sum contributions
11
+ before a single atomic write. Cuts atomic ops by SCAT_REDUCE (2x)
12
+ while preserving coalesced output access.
13
+ Threadgroup size = N_ELEM * SCAT_REDUCE (128 for P4-2v with SR=2).
14
+
15
+ For large scatterer counts, scatterers are processed in chunks that fit
16
+ within ``MAX_TX_INTERMEDIATE_BYTES``, with the split-path spectrum
17
+ accumulated across chunks via simple addition.
18
+
19
+ Requires: MLX (mlx package) on Apple Silicon.
20
+
21
+ Limitations:
22
+ - Soft baffle only (BaffleType.SOFT assumed)
23
+ - Center-frequency directivity only (full_frequency_directivity=False)
24
+ - Linear arrays only (convex array support needs testing)
25
+ """
26
+
27
+ from __future__ import annotations
28
+
29
+ from math import inf, pi
30
+ from pathlib import Path
31
+ from typing import TYPE_CHECKING, Any, cast
32
+
33
+ import mlx.core as mx
34
+
35
+ from fast_simus._pfield_math import NEPER_TO_DB, _subelement_centroids
36
+ from fast_simus.medium_params import MediumParams
37
+ from fast_simus.transducer_params import TransducerParams
38
+ from fast_simus.utils._array_api import Array, _ArrayNamespace
39
+ from fast_simus.utils.geometry import element_positions
40
+
41
+ if TYPE_CHECKING:
42
+ from fast_simus.simus import SimusPlan
43
+
44
# Directory that holds the ``.metal`` kernel sources read by ``_load_source``.
_KERNELS_DIR = Path(__file__).parent

# Memory budget for one chunk's TX intermediate (re + im float32 spectra);
# ``_dispatch_split`` sizes scatterer chunks so they stay within it.
MAX_TX_INTERMEDIATE_BYTES = 256 << 20  # 256 MiB

# Launch geometry, injected into the kernels through #define lines.
_TX_TILE_SE = 16  # TX kernel: sub-element tile width (TILE_SE)
_TX_TILE_TG = 64  # TX kernel: threadgroup size (TG_SIZE)
_RX_SCAT_REDUCE = 2  # RX kernel: scatterers reduced per threadgroup (SCAT_REDUCE)

# TX tiled kernel: register pressure is only TILE_SE * 2 * 8 bytes per thread
# (256 bytes for TILE_SE=16), well within Apple Silicon's register budget.
# Throughput-tuned scatterer chunk sizes, keyed by element count.
_TX_OPTIMAL_CHUNK: dict[int, int] = {
    64: 10000,  # P4-2v class (64 elem, 256B registers/thread)
    128: 5000,  # L11-5v class (128 elem, 256B registers/thread)
}
# Chunk size for element counts missing from the table above.
_TX_DEFAULT_CHUNK = 10000
59
+
60
+ # ---------------------------------------------------------------------------
61
+ # Source caching
62
+ # ---------------------------------------------------------------------------
63
# Kernel sources already read from disk, keyed by file name.
_source_cache: dict[str, str] = {}


def _load_source(filename: str) -> str:
    """Return the text of kernel source ``filename`` in ``_KERNELS_DIR`` (cached).

    Files are decoded as UTF-8 explicitly so the result does not depend on the
    platform's locale encoding. Each file is read at most once per process;
    later calls return the cached text even if the file changes on disk.

    Args:
        filename: Name of the ``.metal`` file relative to ``_KERNELS_DIR``.

    Returns:
        The file's source text.
    """
    if filename not in _source_cache:
        _source_cache[filename] = (_KERNELS_DIR / filename).read_text(encoding="utf-8")
    return _source_cache[filename]
70
+
71
+
72
+ # ---------------------------------------------------------------------------
73
+ # Kernel builders (cached by dimension tuple)
74
+ # ---------------------------------------------------------------------------
75
# Compiled kernels. Keys are tuples whose first item names the kernel variant
# ("tx_tiled" / "rx_simd"), followed by the sizes baked into its header.
_kernel_cache: dict[tuple, Any] = {}


def _make_header(n_elem: int, n_sub: int, n_freq: int, n_scat: int) -> str:
    """Return the ``#define`` preamble shared by the TX and RX kernels.

    Defines ``N_ELEM``, ``N_SUB``, ``N_FREQ``, ``N_ES`` (``n_elem * n_sub``)
    and ``N_SCAT``, one per line, each terminated by a newline.
    """
    macros = (
        ("N_ELEM", n_elem),
        ("N_SUB", n_sub),
        ("N_FREQ", n_freq),
        ("N_ES", n_elem * n_sub),
        ("N_SCAT", n_scat),
    )
    return "".join(f"#define {macro} {value}\n" for macro, value in macros)
86
+
87
+
88
def _build_tx(n_elem: int, n_sub: int, n_freq: int, n_scat: int) -> Any:
    """Return the tiled TX kernel for these sizes, compiling it on first use.

    Element-tiled progression with shared geometry. On top of the common
    ``_make_header`` defines, the header adds:

    - ``TILE_SE`` (``_TX_TILE_SE``);
    - ``TG_SIZE`` (``_TX_TILE_TG``);
    - ``MAX_FPT``, the ceiling of ``n_freq / TG_SIZE`` (presumably the maximum
      number of frequencies handled per thread).

    The compiled kernel is cached under ``("tx_tiled", n_elem, n_sub, n_freq,
    n_scat)``.
    """
    cache_key = ("tx_tiled", n_elem, n_sub, n_freq, n_scat)
    try:
        return _kernel_cache[cache_key]
    except KeyError:
        pass

    threads = _TX_TILE_TG
    header_parts = [
        _make_header(n_elem, n_sub, n_freq, n_scat),
        f"#define TILE_SE {_TX_TILE_SE}\n",
        f"#define TG_SIZE {threads}\n",
        f"#define MAX_FPT (({n_freq} + {threads} - 1) / {threads})\n",
    ]
    compiled = mx.fast.metal_kernel(
        name=f"simus_tx_tiled_{n_elem}_{n_sub}_{n_freq}_{n_scat}",
        input_names=[
            "scat_x",
            "scat_z",
            "elem_x",
            "elem_z",
            "theta_e",
            "sub_dx",
            "sub_dz",
            "da_init_re",
            "da_init_im",
            "delay_phase_step",
            "pp_re",
            "pp_im",
            "is_out",
            "scalars",
        ],
        output_names=["tx_re", "tx_im"],
        header="".join(header_parts),
        source=_load_source("simus_tx_tiled.metal"),
    )
    _kernel_cache[cache_key] = compiled
    return compiled
122
+
123
+
124
def _build_rx(n_elem: int, n_sub: int, n_freq: int, n_scat: int) -> Any:
    """Return the SIMD-reduce RX kernel for these sizes, compiling it on first use.

    Groups SCAT_REDUCE scatterers per threadgroup. Adjacent threads handle the
    same element from different scatterers and use simd_shuffle_xor to sum
    contributions before writing a single atomic. This cuts atomic writes by
    SCAT_REDUCE while preserving coalesced output access.

    The header is the common ``_make_header`` preamble plus ``SCAT_REDUCE``
    (``_RX_SCAT_REDUCE``). Outputs are declared atomic. The compiled kernel is
    cached under ``("rx_simd", n_elem, n_sub, n_freq, n_scat, SCAT_REDUCE)``.
    """
    reduce_factor = _RX_SCAT_REDUCE
    cache_key = ("rx_simd", n_elem, n_sub, n_freq, n_scat, reduce_factor)
    try:
        return _kernel_cache[cache_key]
    except KeyError:
        pass

    defines = _make_header(n_elem, n_sub, n_freq, n_scat)
    defines += f"#define SCAT_REDUCE {reduce_factor}\n"
    compiled = mx.fast.metal_kernel(
        name=f"simus_rx_simd_{n_elem}_{n_sub}_{n_freq}_{n_scat}_{reduce_factor}",
        input_names=[
            "scat_x",
            "scat_z",
            "elem_x",
            "elem_z",
            "theta_e",
            "sub_dx",
            "sub_dz",
            "tx_re",
            "tx_im",
            "probe",
            "rc",
            "scalars",
        ],
        output_names=["spect_re", "spect_im"],
        header=defines,
        source=_load_source("simus_rx_simd.metal"),
        atomic_outputs=True,
    )
    _kernel_cache[cache_key] = compiled
    return compiled
158
+
159
+
160
+ # ---------------------------------------------------------------------------
161
+ # Input preparation
162
+ # ---------------------------------------------------------------------------
163
def _prepare_common(
    scatterers: mx.array,
    rc: mx.array,
    params: TransducerParams,
    plan: SimusPlan,
    medium: MediumParams,
    delays_clean: mx.array,
    tx_apodization: mx.array,
) -> dict[str, Any]:
    """Build the GPU-side input bundle shared by the TX and RX kernels.

    Args:
        scatterers: Scatterer positions (x, z) in meters. Shape ``(n_scat, 2)``.
        rc: Reflection coefficients. Shape ``(n_scat,)``.
        params: Transducer parameters.
        plan: Precomputed frequency plan.
        medium: Medium parameters (speed of sound, attenuation).
        delays_clean: NaN-cleaned transmit delays. Shape ``(n_elements,)``.
        tx_apodization: Per-element apodization. Shape ``(n_elements,)``.

    Returns:
        Dict of float32 MLX arrays keyed by kernel-input role, plus the integer
        sizes ``n_elem``, ``n_sub``, ``n_freq`` and ``n_scat``. The ``scalars``
        entry packs, in order:

        1. initial wavenumber;
        2. initial attenuation;
        3. wavenumber step;
        4. attenuation step;
        5. minimum distance;
        6. segment length;
        7. center wavenumber;
        8. ``1 / n_sub``.
    """
    f32 = mx.float32
    xp_mx = cast(_ArrayNamespace, mx)

    sound_speed = medium.speed_of_sound
    atten_coeff = medium.attenuation
    n_elem = params.n_elements
    n_sub = plan.n_sub
    freqs = plan.selected_freqs
    n_freq = int(freqs.shape[0])
    n_scat = int(scatterers.shape[0])

    # Element centres and angles; fall back to zero angles when none are returned.
    elem_pos, theta_e, apex_offset = element_positions(n_elem, params.pitch, params.radius, xp_mx)
    if theta_e is None:
        theta_e = mx.zeros(n_elem, dtype=f32)

    # Sub-element offsets flattened to (n_elem * n_sub,).
    centroids = _subelement_centroids(params.element_width, n_sub, cast("Array", theta_e), xp_mx)
    sub_dx = cast(mx.array, centroids[..., 0]).reshape(-1)
    sub_dz = cast(mx.array, centroids[..., 1]).reshape(-1)

    # Out-of-field mask (1.0 = skip): z < 0, or inside the arc of a curved array.
    x_scat = scatterers[:, 0]
    z_scat = scatterers[:, 1]
    outside = (z_scat < 0).astype(f32)
    if params.radius != inf:
        inside_arc = (x_scat**2 + (z_scat + apex_offset) ** 2) <= params.radius**2
        outside = mx.maximum(outside, inside_arc.astype(f32))

    # Uniform frequency grid: first sample and spacing.
    freq_start = float(freqs[0])
    freq_step = 0.0 if n_freq <= 1 else float(freqs[1] - freqs[0])

    two_pi = 2.0 * pi

    # Delay/apodization phasor at the first frequency and per-step delay phase.
    phase_at_start = mx.array(two_pi * freq_start, dtype=f32) * delays_clean
    da_init_re = (mx.cos(phase_at_start) * tx_apodization).astype(f32)
    da_init_im = (mx.sin(phase_at_start) * tx_apodization).astype(f32)
    delay_phase_step = (mx.array(two_pi * freq_step, dtype=f32) * delays_clean).astype(f32)

    # Pulse x probe spectrum (complex) and the probe spectrum on its own.
    pulse = cast(mx.array, plan.pulse_spectrum)
    probe = cast(mx.array, plan.probe_spectrum)
    pulse_probe = pulse * probe

    def scaled_attenuation(freq: float) -> float:
        # Same conversion for the start value and the per-step increment.
        return atten_coeff / NEPER_TO_DB * freq / 1e6 * 1e2

    scalars = mx.array(
        [
            two_pi * freq_start / sound_speed,
            scaled_attenuation(freq_start),
            two_pi * freq_step / sound_speed,
            scaled_attenuation(freq_step),
            sound_speed / params.freq_center / 2.0,
            plan.seg_length,
            two_pi * params.freq_center / sound_speed,
            1.0 / n_sub,
        ],
        dtype=f32,
    )

    return {
        "x_flat": x_scat.astype(f32),
        "z_flat": z_scat.astype(f32),
        "elem_x": elem_pos[:, 0].astype(f32),
        "elem_z": elem_pos[:, 1].astype(f32),
        "theta_e": theta_e.astype(f32),
        "sub_dx": sub_dx.astype(f32),
        "sub_dz": sub_dz.astype(f32),
        "da_init_re": da_init_re,
        "da_init_im": da_init_im,
        "delay_phase_step": delay_phase_step,
        "pp_re": mx.real(pulse_probe).astype(f32),
        "pp_im": mx.imag(pulse_probe).astype(f32),
        "probe_real": probe.astype(f32),
        "rc": rc.astype(f32),
        "is_out": outside,
        "scalars": scalars,
        "n_elem": n_elem,
        "n_sub": n_sub,
        "n_freq": n_freq,
        "n_scat": n_scat,
    }
261
+
262
+
263
+ # ---------------------------------------------------------------------------
264
+ # Dispatch
265
+ # ---------------------------------------------------------------------------
266
def _dispatch_split(d: dict[str, Any]) -> mx.array:
    """Run the two-kernel TX/RX path, chunking large scatterer counts.

    Scatterers are processed in chunks of at most ``chunk_size`` entries.
    ``chunk_size`` is the smaller of two limits:

    - the number of scatterers whose TX intermediate (``n_freq`` re + im
      float32 values each) fits in ``MAX_TX_INTERMEDIATE_BYTES``;
    - the throughput-tuned size from ``_TX_OPTIMAL_CHUNK`` (falling back to
      ``_TX_DEFAULT_CHUNK``).

    Each chunk runs the TX kernel and then the RX kernel. The per-chunk spectra
    are summed with (lazily evaluated) MLX array additions on MLX's default
    device, not on the host.

    Args:
        d: Input bundle built by ``_prepare_common``. It holds geometry,
            delay/apodization terms, spectra, scalars and the ``n_elem`` /
            ``n_sub`` / ``n_freq`` / ``n_scat`` sizes.

    Returns:
        Complex64 RF spectrum of shape ``(n_freq, n_elem)``. With zero
        scatterers the chunk loop does not run and the result is all zeros.
    """
    n_elem, n_sub, n_freq, n_scat = d["n_elem"], d["n_sub"], d["n_freq"], d["n_scat"]
    spect_size = n_freq * n_elem

    # Use TX-throughput-optimal chunk sizes, capped by the memory budget.
    bytes_per_scat = n_freq * 4 * 2  # float32 re + im per frequency
    mem_chunk = max(1, MAX_TX_INTERMEDIATE_BYTES // bytes_per_scat)
    perf_chunk = _TX_OPTIMAL_CHUNK.get(n_elem, _TX_DEFAULT_CHUNK)
    chunk_size = min(mem_chunk, perf_chunk)

    # Scatterer-independent inputs, shared by every chunk.
    geom_tx = [
        d["elem_x"],
        d["elem_z"],
        d["theta_e"],
        d["sub_dx"],
        d["sub_dz"],
        d["da_init_re"],
        d["da_init_im"],
        d["delay_phase_step"],
        d["pp_re"],
        d["pp_im"],
    ]
    geom_rx = [d["elem_x"], d["elem_z"], d["theta_e"], d["sub_dx"], d["sub_dz"]]
    probe = d["probe_real"]
    scalars = d["scalars"]

    # Launch geometry is identical for every chunk, so compute it once.
    tx_tg = _TX_TILE_TG  # TX: one threadgroup of tx_tg threads per scatterer
    sr = _RX_SCAT_REDUCE  # RX: sr scatterers share one threadgroup
    rx_tg = n_elem * sr  # RX: threadgroup size N_ELEM * SCAT_REDUCE

    # Build kernels for the standard chunk size (cached, compiled once per probe).
    # NOTE(review): N_SCAT is baked into both kernels as chunk_size, yet the last
    # (or only) chunk may hold cn < chunk_size scatterers. When cn is not a
    # multiple of sr, the final RX threadgroup also spans scatterer slots past
    # cn. Confirm the .metal sources bound scatterer indices by the real buffer
    # length rather than N_SCAT.
    k_tx = _build_tx(n_elem, n_sub, n_freq, chunk_size)
    k_rx = _build_rx(n_elem, n_sub, n_freq, chunk_size)

    total_re = mx.zeros(spect_size, dtype=mx.float32)
    total_im = mx.zeros(spect_size, dtype=mx.float32)

    for start in range(0, n_scat, chunk_size):
        end = min(start + chunk_size, n_scat)
        cn = end - start

        cx = d["x_flat"][start:end]
        cz = d["z_flat"][start:end]
        crc = d["rc"][start:end]
        c_out = d["is_out"][start:end]

        # TX kernel: one threadgroup per scatterer (tiled progression).
        tx_re, tx_im = k_tx(
            inputs=[cx, cz, *geom_tx, c_out, scalars],
            output_shapes=[(cn * n_freq,), (cn * n_freq,)],
            output_dtypes=[mx.float32, mx.float32],
            grid=(cn * tx_tg, 1, 1),
            threadgroup=(tx_tg, 1, 1),
        )

        # RX kernel: sr scatterers per threadgroup, SIMD reduction, atomic outputs
        # zero-initialised via init_value.
        n_tgs = (cn + sr - 1) // sr
        chunk_re, chunk_im = k_rx(
            inputs=[cx, cz, *geom_rx, tx_re, tx_im, probe, crc, scalars],
            output_shapes=[(spect_size,), (spect_size,)],
            output_dtypes=[mx.float32, mx.float32],
            grid=(n_tgs * rx_tg, 1, 1),
            threadgroup=(rx_tg, 1, 1),
            init_value=0.0,
        )

        # NOTE(review): MLX evaluates lazily, so nothing here forces one chunk's
        # TX buffers to be released before the next chunk's are created. If peak
        # memory must strictly honour MAX_TX_INTERMEDIATE_BYTES, consider
        # mx.eval(total_re, total_im) after each chunk.
        total_re = total_re + chunk_re
        total_im = total_im + chunk_im

    spect_re = total_re.reshape(n_freq, n_elem)
    spect_im = total_im.reshape(n_freq, n_elem)
    return (spect_re + 1j * spect_im).astype(mx.complex64)
344
+
345
+
346
+ # ---------------------------------------------------------------------------
347
+ # Public API
348
+ # ---------------------------------------------------------------------------
349
def simus_metal(
    scatterers: mx.array,
    rc: mx.array,
    params: TransducerParams,
    plan: SimusPlan,
    medium: MediumParams,
    delays_clean: mx.array,
    tx_apodization: mx.array,
) -> mx.array:
    """Compute the simus RF spectrum with the custom Metal TX/RX kernels.

    All kernel inputs are prepared by ``_prepare_common``. The bundle is then
    handed to ``_dispatch_split``, which runs the two-kernel TX/RX split. It
    chunks scatterers so each chunk's TX intermediate fits the memory budget,
    and adds up the chunk spectra.

    Args:
        scatterers: Scatterer positions (x, z) in meters. Shape ``(n_scat, 2)``.
        rc: Reflection coefficients. Shape ``(n_scat,)``.
        params: Transducer parameters.
        plan: Precomputed frequency plan from ``simus_precompute``.
        medium: Medium parameters.
        delays_clean: NaN-cleaned delays. Shape ``(n_elements,)``.
        tx_apodization: Per-element apodization (NaN-zeroed). Shape ``(n_elements,)``.

    Returns:
        Complex RF spectrum, shape ``(n_freq, n_elements)``.
    """
    prepared = _prepare_common(
        scatterers,
        rc,
        params,
        plan,
        medium,
        delays_clean,
        tx_apodization,
    )
    return _dispatch_split(prepared)
@@ -0,0 +1,97 @@
1
// Kernel body for pfield pressure-field computation via mx.fast.metal_kernel().
//
// This file contains ONLY the kernel body -- the code that runs inside the
// auto-generated [[kernel]] void ...() { ... } wrapper. mx.fast.metal_kernel()
// injects input/output buffer parameters automatically based on input_names
// and output_names.
//
// Compile-time constants (injected via header=):
//   N_ELEM -- number of transducer elements
//   N_SUB  -- number of sub-elements per element
//   N_FREQ -- number of frequency samples
//   N_ES   -- N_ELEM * N_SUB (total element-subelement pairs)
//
// Buffers (names fixed by the Python wrapper's input_names / output_names):
//   grid_x, grid_z          -- flattened grid coordinates, indexed by thread
//   elem_x, elem_z, theta_e -- per-element position and angle (N_ELEM entries)
//   sub_dx, sub_dz          -- sub-element offsets, flat index e * N_SUB + s (N_ES)
//   da_init_re / da_init_im -- per-element delay*apodization phasor at the first frequency
//   da_step_re / da_step_im -- per-element delay phasor increment per frequency step
//   pp_mag_sq               -- |pulse|^2 * |probe|^2 per frequency (N_FREQ entries)
//   is_out                  -- 1.0 where the grid point is masked out, else 0.0
//   scalars                 -- 8 packed float32 physics scalars (unpacked below)
//   pressure (output)       -- one float per grid point
//
// Algorithm per thread:
//   1. Build, for every element/sub-element pair, its complex phasor at the
//      first frequency plus the complex multiplier that advances it by one
//      frequency step.
//   2. Sweep the frequencies. Each step costs one complex multiply per pair,
//      with no trig.

// One thread per grid point.
// NOTE(review): there is no bounds check on g. This relies on exactly n_grid
// threads being launched (the host passes grid=(n_grid, 1, 1)). Confirm MLX
// never launches padding threads when n_grid is not a multiple of the
// threadgroup size.
uint g = thread_position_in_grid.x;

float gx = grid_x[g];
float gz = grid_z[g];

// Scalar layout must match the host-side `scalars` array order.
float kw_init = scalars[0];     // wavenumber at the first frequency
float alpha_init = scalars[1];  // attenuation coefficient at the first frequency
float kw_step = scalars[2];     // wavenumber increment per frequency step
float alpha_step = scalars[3];  // attenuation increment per frequency step
float min_dist = scalars[4];    // distance clamp (half the center wavelength on the host)
float seg_len = scalars[5];     // sub-element segment length for the sinc directivity
float center_kw = scalars[6];   // center-frequency wavenumber (directivity only)
float eff_corr = scalars[7];    // final scale factor (1 / n_sub^2 on the host)

// Per-pair state: cur = phasor at the current frequency, stp = per-step multiplier.
// NOTE(review): these are per-thread arrays of N_ES float2. For large N_ES they
// likely exceed the register budget and spill. Check occupancy for big probes.
float2 cur[N_ES];
float2 stp[N_ES];

for (int e = 0; e < N_ELEM; e++) {
    float ex = elem_x[e];
    float ez = elem_z[e];
    float te = theta_e[e];
    float di_re = da_init_re[e], di_im = da_init_im[e];
    float ds_re = da_step_re[e], ds_im = da_step_im[e];

    for (int s = 0; s < N_SUB; s++) {
        int idx = e * N_SUB + s;

        // Vector from the sub-element centre to the grid point.
        float dx = gx - ex - sub_dx[idx];
        float dz = gz - ez - sub_dz[idx];
        float r = metal::precise::sqrt(dx * dx + dz * dz);
        // Clamp the distance from below so 1/sqrt(rc) stays finite near the element.
        float rc = max(r, min_dist);

        // Angle relative to element normal (unclipped distance for angle).
        // The 1e-16f terms keep the ratio defined when r == 0.
        float th = metal::precise::asin((dx + 1e-16f) / (r + 1e-16f)) - te;

        // Soft baffle obliquity: cos(th), with a tiny floor at or beyond +/-90 degrees.
        float obliq = (fabs(th) >= M_PI_2_F) ? 1e-16f : metal::precise::cos(th);

        // Phase init: obliq/sqrt(rc) * exp(-alpha_init*rc + j*wrap(kw_init*rc, 2pi)).
        // Uses the clamped distance rc. Wrapping k*r to [0, 2pi) keeps the
        // float32 argument of cos/sin small.
        float kwr = kw_init * rc;
        float TWO_PI = 2.0f * M_PI_F;
        float ph_wrap = kwr - TWO_PI * metal::precise::floor(kwr / TWO_PI);
        float ai = obliq / metal::precise::sqrt(rc) * metal::precise::exp(-alpha_init * rc);
        float2 pi_ = float2(ai * metal::precise::cos(ph_wrap),
                            ai * metal::precise::sin(ph_wrap));

        // Phase step: exp((-alpha_step + j*kw_step) * rc), the factor that
        // advances this pair's propagation term by one frequency sample.
        float as_ = metal::precise::exp(-alpha_step * rc);
        float phs = kw_step * rc;
        float2 ps_ = float2(as_ * metal::precise::cos(phs),
                            as_ * metal::precise::sin(phs));

        // Center-frequency sinc directivity, evaluated once and applied at all
        // frequencies. The sinc is clamped to 1 near sa == 0 to avoid 0/0.
        float sa = center_kw * seg_len * 0.5f * metal::precise::sin(th);
        float sv = (fabs(sa) < 1e-8f) ? 1.0f : metal::precise::sin(sa) / sa;
        pi_ *= sv;

        // Absorb delay+apodization (complex multiply): cur = pi_ * da_init, stp = ps_ * da_step.
        cur[idx] = float2(
            pi_.x * di_re - pi_.y * di_im,
            pi_.x * di_im + pi_.y * di_re
        );
        stp[idx] = float2(
            ps_.x * ds_re - ps_.y * ds_im,
            ps_.x * ds_im + ps_.y * ds_re
        );
    }
}

// Frequency sweep: accumulate sum_f |pulse_probe_f|^2 * |sum_es phase_es_f|^2.
// For each frequency, (sr, si) is the coherent sum of all N_ES phasors. Each
// phasor is then advanced to the next frequency by multiplying it by its step.
float acc = 0.0f;
for (int f = 0; f < N_FREQ; f++) {
    float sr = 0.0f, si = 0.0f;
    for (int j = 0; j < N_ES; j++) {
        sr += cur[j].x;
        si += cur[j].y;
        float cr = cur[j].x, ci = cur[j].y;
        float tr = stp[j].x, ti = stp[j].y;
        cur[j] = float2(cr * tr - ci * ti, cr * ti + ci * tr);
    }
    acc += pp_mag_sq[f] * (sr * sr + si * si);
}

// Masked points write 0; all others write the scaled accumulation (no sqrt here).
pressure[g] = (is_out[g] > 0.5f) ? 0.0f : acc * eff_corr;