FastSIMUS 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fast_simus/__init__.py +33 -0
- fast_simus/_pfield_math.py +261 -0
- fast_simus/_pfield_strategies.py +203 -0
- fast_simus/_simus_strategies.py +210 -0
- fast_simus/backends/__init__.py +1 -0
- fast_simus/backends/mlx.py +101 -0
- fast_simus/kernels/__init__.py +9 -0
- fast_simus/kernels/cuda_simus.py +321 -0
- fast_simus/kernels/metal_pfield.py +219 -0
- fast_simus/kernels/metal_simus.py +377 -0
- fast_simus/kernels/pfield.metal +97 -0
- fast_simus/kernels/simus_fused.cu +332 -0
- fast_simus/kernels/simus_rx_simd.metal +128 -0
- fast_simus/kernels/simus_tx_tiled.metal +175 -0
- fast_simus/medium_params.py +22 -0
- fast_simus/pfield.py +475 -0
- fast_simus/py.typed +0 -0
- fast_simus/simus.py +567 -0
- fast_simus/spectrum.py +107 -0
- fast_simus/transducer_params.py +160 -0
- fast_simus/transducer_presets.py +102 -0
- fast_simus/tx_delay.py +276 -0
- fast_simus/utils/__init__.py +5 -0
- fast_simus/utils/_array_api.py +294 -0
- fast_simus/utils/geometry.py +88 -0
- fastsimus-0.0.1.dist-info/METADATA +594 -0
- fastsimus-0.0.1.dist-info/RECORD +28 -0
- fastsimus-0.0.1.dist-info/WHEEL +4 -0
|
@@ -0,0 +1,332 @@
|
|
|
1
|
+
/*
|
|
2
|
+
* Fused TX+RX SIMUS kernel -- v25c: register-resident TX, sv_arr in
|
|
3
|
+
* shmem (correct, fp32).
|
|
4
|
+
*
|
|
5
|
+
* v25b cached sv_arr[B*ELEM_TILE] in registers (56 floats at B=7 ET=4).
|
|
6
|
+
* That competes with tk_re/tk_im for the 255-reg cap and forces 400 B
|
|
7
|
+
* spill of tk into local memory -- which NCU showed saturates L2 at
|
|
8
|
+
* 76.5 % throughput.
|
|
9
|
+
*
|
|
10
|
+
* v25c drops sv_arr from registers and reads GEO_STP_RX_RE/IM directly
|
|
11
|
+
* from shmem inside the cmul, freeing those 56 regs for tk. Cost: an
|
|
12
|
+
* extra shmem read per (si, et, fi) cmul advance -- bank-conflict free
|
|
13
|
+
* since every thread reads the same (si, se) word -- a shmem broadcast.
|
|
14
|
+
*
|
|
15
|
+
* v25 with one structural fix: every loop that indexes tk_re/tk_im is
|
|
16
|
+
* now `for fi in 0..MAX_FPT` with `#pragma unroll` and predicated
|
|
17
|
+
* validity, so fi is statically known. v25 spilled tk entirely
|
|
18
|
+
* (576 B local mem at B=9 ET=4) because dynamic fi forced
|
|
19
|
+
* tk_re[si*MAX_FPT + fi] off-register. Static fi unrolling makes tk
|
|
20
|
+
* actually register-resident, eliminating the local-memory traffic.
|
|
21
|
+
*
|
|
22
|
+
* Why this is safe: in v11 each thread `lid` writes to sh_tx[si*N_FREQ + f]
|
|
23
|
+
* for f in {lid, lid+TG_SIZE, lid+2*TG_SIZE, ...} during Phase 2, and reads
|
|
24
|
+
* the same slots during Phase 3. There is no cross-thread sharing of TX --
|
|
25
|
+
* the shmem allocation was a temporary, not a broadcast surface.
|
|
26
|
+
*
|
|
27
|
+
* Storing TX in per-thread register arrays tk_re[B_SCAT*MAX_FPT],
|
|
28
|
+
* tk_im[B_SCAT*MAX_FPT] eliminates the dominant shmem cost
|
|
29
|
+
* (2*B_SCAT*N_FREQ floats; 60 KB at B=9 N_FREQ=854) without changing
|
|
30
|
+
* precision or arithmetic. Also lets us drop the pre-Phase-3 sync that
|
|
31
|
+
* was only needed to publish sh_tx writes.
|
|
32
|
+
*
|
|
33
|
+
* Per-thread TX register cost: 2*B_SCAT*MAX_FPT floats. For
|
|
34
|
+
* B=9 N_FREQ=854 TG=128 -> MAX_FPT=7 -> 126 floats. This may force a
* trade-off between occupancy gain and register spill; expected to come out well ahead given shmem savings
|
|
36
|
+
* (76.5 KB -> 16.5 KB at B=9 ET=4 unlocks ~5 blk/SM vs v11's 1).
|
|
37
|
+
*
|
|
38
|
+
* Compile-time: N_ELEM, N_SUB, N_FREQ, N_ES, TILE_SE, TG_SIZE, MAX_FPT,
|
|
39
|
+
* B_SCAT, ELEM_TILE
|
|
40
|
+
*
|
|
41
|
+
* Shared memory: (7*B_SCAT*N_ES + 3*N_ELEM) * 4 bytes
|
|
42
|
+
*/
|
|
43
|
+
|
|
44
|
+
/* Single-precision pi, unless the build already provides M_PI_F. */
#if !defined(M_PI_F)
#define M_PI_F 3.14159265358979323846f
#endif

/* Minimal complex value: x holds the real part, y the imaginary part. */
struct f2 {
    float x;
    float y;
};

/* Complex product lhs * rhs (textbook four-multiply form). */
__device__ __forceinline__ f2 cmul(f2 lhs, f2 rhs) {
    f2 prod;
    prod.x = lhs.x * rhs.x - lhs.y * rhs.y;
    prod.y = lhs.x * rhs.y + lhs.y * rhs.x;
    return prod;
}
|
|
53
|
+
|
|
54
|
+
/* Per-batch geometry scratch in dynamic shared memory (`shmem`): seven
 * planes, each made of B_SCAT rows of N_ES floats.  GEO_PLANE_ROW(p, s)
 * points at row s (scatterer slot within the batch) of plane p. */
#define GEO_PLANE_ROW(plane, s) (shmem + (((plane) * B_SCAT + (s)) * N_ES))

#define GEO_AMP(s)       GEO_PLANE_ROW(0, s)  /* amplitude factor        */
#define GEO_KW_R(s)      GEO_PLANE_ROW(1, s)  /* kw_init * r             */
#define GEO_KR_STEP(s)   GEO_PLANE_ROW(2, s)  /* kw_step * r             */
#define GEO_ALPHA_R(s)   GEO_PLANE_ROW(3, s)  /* alpha_init * r          */
#define GEO_AR_STEP(s)   GEO_PLANE_ROW(4, s)  /* alpha_step * r          */
#define GEO_STP_RX_RE(s) GEO_PLANE_ROW(5, s)  /* RX stride step, real    */
#define GEO_STP_RX_IM(s) GEO_PLANE_ROW(6, s)  /* RX stride step, imag    */
|
|
61
|
+
|
|
62
|
+
/*
 * simus_fused_kernel -- fused TX + RX spectrum accumulation.
 *
 * Work decomposition (as implemented below):
 *   - Blocks walk the scatterers in batches of B_SCAT with a grid-stride
 *     loop (start blockIdx.x * B_SCAT, step gridDim.x * B_SCAT).
 *   - Thread `lid` owns frequency indices f = lid + fi * TG_SIZE for
 *     fi in [0, MAX_FPT); indices with f >= N_FREQ are predicated off.
 *     Assumes MAX_FPT >= ceil(N_FREQ / TG_SIZE) -- TODO confirm on host.
 *   - Sub-elements are indexed se in [0, N_ES); their element is se / N_SUB.
 *
 * Per batch:
 *   Phase 1  all threads cooperatively write per-(scatterer, sub-element)
 *            geometry into shmem rows GEO_*(si).
 *   Phase 2  each thread sums TX over all sub-elements using a geometric
 *            progression over its own frequencies and keeps the result
 *            (times pp_re/pp_im and rc_arr[scat]) in registers tk_re/tk_im.
 *   Phase 3  RX: for groups of ELEM_TILE sub-elements the RX progression is
 *            multiplied by TX and probe[f] and atomically added into
 *            spect_re/spect_im[elem * N_FREQ + f].
 *
 * Scalars: the phase at frequency index f is (kw_init + f*kw_step) * r and
 * the attenuation exponent is (alpha_init + f*alpha_step) * r, with r
 * clamped below by min_dist.  center_kw and seg_len parameterise the sinc
 * directivity; inv_nsub scales the sub-element sums.  Scatterers with
 * z < 0, or (when radius < 1e30) within `radius` of (0, -apex_offset), are
 * flagged "out" and contribute nothing.
 *
 * spect_re / spect_im are accumulated with atomicAdd, so they are
 * presumably zeroed by the caller before launch -- verify against host code.
 */
extern "C" __global__
void simus_fused_kernel(
    const float* __restrict__ scat_x,
    const float* __restrict__ scat_z,
    const float* __restrict__ rc_arr,
    const float* __restrict__ elem_x,
    const float* __restrict__ elem_z,
    const float* __restrict__ cos_te,
    const float* __restrict__ sin_neg_te,
    const float* __restrict__ sub_dx,
    const float* __restrict__ sub_dz,
    const float* __restrict__ da_init_re,
    const float* __restrict__ da_init_im,
    const float* __restrict__ dps,
    const float* __restrict__ pp_re,
    const float* __restrict__ pp_im,
    const float* __restrict__ probe,
    float* __restrict__ spect_re,
    float* __restrict__ spect_im,
    int n_scat,
    float kw_init, float alpha_init,
    float kw_step, float alpha_step,
    float min_dist, float seg_len,
    float center_kw, float inv_nsub,
    float radius, float apex_offset
) {
    int lid = threadIdx.x;
    float lid_f = (float)lid;
    float stride_f = (float)TG_SIZE;

    extern __shared__ float shmem[];

    /* TX moved to per-thread registers; shmem only holds geometry + per-elem broadcast. */
    float tk_re[B_SCAT * MAX_FPT];
    float tk_im[B_SCAT * MAX_FPT];

    /* Per-element delay/apodisation data is stored after the 7 geometry planes. */
    float* sh_da_init_re_l = shmem + 7 * B_SCAT * N_ES;
    float* sh_da_init_im_l = sh_da_init_re_l + N_ELEM;
    float* sh_dps_l = sh_da_init_im_l + N_ELEM;

    for (int e = lid; e < N_ELEM; e += TG_SIZE) {
        sh_da_init_re_l[e] = da_init_re[e];
        sh_da_init_im_l[e] = da_init_im[e];
        sh_dps_l[e] = dps[e];
    }
    __syncthreads();

    const int N_TILES = (N_ES + TILE_SE - 1) / TILE_SE;
    const int N_ELEM_GROUPS = (N_ES + ELEM_TILE - 1) / ELEM_TILE;

    /* out_flag[si] is only written for si < actual_b; Phase 3 tests
     * si >= actual_b first, so unwritten entries are never read. */
    bool out_flag[B_SCAT];

    for (int scat_base = blockIdx.x * B_SCAT;
         scat_base < n_scat;
         scat_base += gridDim.x * B_SCAT)
    {
        /* Number of real scatterers in this batch (the tail batch may be short). */
        int actual_b = B_SCAT;
        if (scat_base + B_SCAT > n_scat)
            actual_b = n_scat - scat_base;

        /* Zero-init register TX so si >= actual_b reads finite zeros in
         * Phase 3 (cv is also forced to 0 there, but 0*NaN would propagate). */
        #pragma unroll
        for (int si = 0; si < B_SCAT; si++) {
            #pragma unroll
            for (int fi = 0; fi < MAX_FPT; fi++) {
                tk_re[si * MAX_FPT + fi] = 0.0f;
                tk_im[si * MAX_FPT + fi] = 0.0f;
            }
        }

        /* ---- Phase 1+2: geometry + TX for each scatterer in batch ---- */
        for (int si = 0; si < actual_b; si++) {
            int scat_idx = scat_base + si;
            float sx = scat_x[scat_idx];
            float sz = scat_z[scat_idx];
            float rc = rc_arr[scat_idx];

            /* Out: behind the aperture (z < 0) or, when radius is finite,
             * inside the disc of that radius centred at (0, -apex_offset). */
            bool is_out = (sz < 0.0f);
            if (radius < 1e30f) {
                float da = sx, db = sz + apex_offset;
                is_out = is_out || ((da*da + db*db) <= radius*radius);
            }
            out_flag[si] = is_out;

            /* Phase 1: geometry is written even for "out" scatterers, so every
             * row si < actual_b is fully initialised before Phase 3 reads it. */
            for (int se = lid; se < N_ES; se += TG_SIZE) {
                int elem = se / N_SUB;
                float ex_ = elem_x[elem], ez_ = elem_z[elem];
                float ct = cos_te[elem], snt = sin_neg_te[elem];
                float dx = sx - ex_ - sub_dx[se];
                float dz = sz - ez_ - sub_dz[se];
                float r2 = dx*dx + dz*dz;
                float inv_r = rsqrtf(r2 + 1e-30f);  /* epsilon keeps r2 == 0 finite */
                float r = r2 * inv_r;
                float rc_ = fmaxf(r, min_dist);

                /* (dx, dz) / r rotated by the element angle via cos_te / sin_neg_te. */
                float sin_th = (dx*ct + dz*snt) * inv_r;
                float cos_th = (dz*ct - dx*snt) * inv_r;
                float obliq = (cos_th <= 0.0f) ? 1e-16f : cos_th;
                float sa = center_kw * seg_len * 0.5f * sin_th;
                float sv = (fabsf(sa) < 1e-8f) ? 1.0f : __fdividef(__sinf(sa), sa);

                GEO_AMP(si)[se] = obliq * sv * rsqrtf(rc_);
                GEO_KW_R(si)[se] = kw_init * rc_;
                GEO_KR_STEP(si)[se] = kw_step * rc_;
                GEO_ALPHA_R(si)[se] = alpha_init * rc_;
                GEO_AR_STEP(si)[se] = alpha_step * rc_;

                /* Step advancing the RX progression by TG_SIZE frequency
                 * indices: exp(-stride*alpha_step*r) * exp(i*stride*kw_step*r). */
                float stp_phase = stride_f * kw_step * rc_;
                float stp_alpha = stride_f * alpha_step * rc_;
                float sm = expf(-stp_alpha);
                float sp_re, sp_im;
                __sincosf(stp_phase, &sp_im, &sp_re);
                sp_re *= sm; sp_im *= sm;
                GEO_STP_RX_RE(si)[se] = sp_re;
                GEO_STP_RX_IM(si)[se] = sp_im;
            }
            /* actual_b is block-uniform, so every thread reaches this barrier. */
            __syncthreads();

            if (is_out) {
                #pragma unroll
                for (int fi = 0; fi < MAX_FPT; fi++) {
                    tk_re[si * MAX_FPT + fi] = 0.0f;
                    tk_im[si * MAX_FPT + fi] = 0.0f;
                }
                continue;
            }

            /* Phase 2: TX sweep */
            float sum_re[MAX_FPT], sum_im[MAX_FPT];
            #pragma unroll
            for (int i = 0; i < MAX_FPT; i++) { sum_re[i] = 0.0f; sum_im[i] = 0.0f; }

            for (int tile = 0; tile < N_TILES; tile++) {
                int ts = tile * TILE_SE;
                int te = ts + TILE_SE;
                if (te > N_ES) te = N_ES;
                int tl = te - ts;

                /* cv: TX term at this thread's first frequency (f = lid);
                 * sv: multiplier advancing it by TG_SIZE frequency indices.
                 * Padding slots (j >= tl) stay at 0 with a unit step. */
                f2 cv[TILE_SE], sv[TILE_SE];
                #pragma unroll
                for (int j = 0; j < TILE_SE; j++) {
                    if (j >= tl) { cv[j] = {0.0f, 0.0f}; sv[j] = {1.0f, 0.0f}; continue; }
                    int se = ts + j, em = se / N_SUB;
                    float ph = GEO_KW_R(si)[se] + lid_f * GEO_KR_STEP(si)[se];
                    float av = GEO_ALPHA_R(si)[se] + lid_f * GEO_AR_STEP(si)[se];
                    float ai = GEO_AMP(si)[se] * expf(-av);
                    float vr, vi;
                    __sincosf(ph, &vi, &vr);
                    vr *= ai; vi *= ai;
                    /* Delay/apodisation factor da_init * exp(i * lid * dps). */
                    float dp = lid_f * sh_dps_l[em];
                    float dr, di;
                    __sincosf(dp, &di, &dr);
                    float dvr = sh_da_init_re_l[em]*dr - sh_da_init_im_l[em]*di;
                    float dvi = sh_da_init_re_l[em]*di + sh_da_init_im_l[em]*dr;
                    cv[j] = {vr*dvr - vi*dvi, vr*dvi + vi*dvr};

                    /* TX step = RX step * exp(i * TG_SIZE * dps). */
                    float sp_re = GEO_STP_RX_RE(si)[se];
                    float sp_im = GEO_STP_RX_IM(si)[se];
                    float das_phase = stride_f * sh_dps_l[em];
                    float das_re, das_im;
                    __sincosf(das_phase, &das_im, &das_re);
                    sv[j] = {sp_re*das_re - sp_im*das_im, sp_re*das_im + sp_im*das_re};
                }

                #pragma unroll
                for (int fi = 0; fi < MAX_FPT; fi++) {
                    int f_chk = lid + fi * TG_SIZE;
                    if (f_chk >= N_FREQ) break;
                    #pragma unroll
                    for (int j = 0; j < TILE_SE; j++) {
                        sum_re[fi] += cv[j].x; sum_im[fi] += cv[j].y;
                        cv[j] = cmul(cv[j], sv[j]);
                    }
                }
            }

            #pragma unroll
            for (int fi = 0; fi < MAX_FPT; fi++) {
                int f = lid + fi * TG_SIZE;
                bool valid = (f < N_FREQ);
                float tr = sum_re[fi] * inv_nsub;
                float ti = sum_im[fi] * inv_nsub;
                float ppr = valid ? pp_re[f] : 0.0f;
                float ppi = valid ? pp_im[f] : 0.0f;
                tk_re[si * MAX_FPT + fi] = valid ? (ppr*tr - ppi*ti) * rc : 0.0f;
                tk_im[si * MAX_FPT + fi] = valid ? (ppr*ti + ppi*tr) * rc : 0.0f;
            }
            /* No __syncthreads here -- TX is private to this thread. */
        }

        /* ---- Phase 3: element-tiled RX with B_SCAT accumulation ---- */
        for (int eg = 0; eg < N_ELEM_GROUPS; eg++) {
            int se_base = eg * ELEM_TILE;
            int etl = ELEM_TILE;
            if (se_base + etl > N_ES) etl = N_ES - se_base;

            /* Initialize B_SCAT * ELEM_TILE RX states. sv_arr stays in shmem
             * (re-read per cmul advance) to free registers for tk. */
            f2 cv[B_SCAT * ELEM_TILE];

            #pragma unroll
            for (int si = 0; si < B_SCAT; si++) {
                #pragma unroll
                for (int et = 0; et < ELEM_TILE; et++) {
                    int idx = si * ELEM_TILE + et;
                    if (si >= actual_b || out_flag[si] || et >= etl) {
                        cv[idx] = {0.0f, 0.0f};
                        continue;
                    }
                    int se = se_base + et;
                    float ph = GEO_KW_R(si)[se] + lid_f * GEO_KR_STEP(si)[se];
                    float av = GEO_ALPHA_R(si)[se] + lid_f * GEO_AR_STEP(si)[se];
                    float ai = GEO_AMP(si)[se] * expf(-av);
                    float vr, vi;
                    __sincosf(ph, &vi, &vr);
                    cv[idx] = {vr * ai, vi * ai};
                }
            }

            /* Sweep frequencies with B_SCAT * ELEM_TILE independent chains.
             * fi is statically unrolled so tk_re[si*MAX_FPT + fi] uses a
             * compile-time index, keeping tk truly register-resident. */
            #pragma unroll
            for (int fi = 0; fi < MAX_FPT; fi++) {
                int f = lid + fi * TG_SIZE;
                bool valid = (f < N_FREQ);
                float pf = valid ? probe[f] : 0.0f;

                float acc_re[ELEM_TILE];
                float acc_im[ELEM_TILE];
                #pragma unroll
                for (int et = 0; et < ELEM_TILE; et++) {
                    acc_re[et] = 0.0f;
                    acc_im[et] = 0.0f;
                }

                #pragma unroll
                for (int si = 0; si < B_SCAT; si++) {
                    float tkr = tk_re[si * MAX_FPT + fi];
                    float tki = tk_im[si * MAX_FPT + fi];
                    /* Only rows si < actual_b were written in this batch. */
                    bool row_live = (si < actual_b);

                    #pragma unroll
                    for (int et = 0; et < ELEM_TILE; et++) {
                        int idx = si * ELEM_TILE + et;
                        int se = se_base + et;
                        float rr = cv[idx].x * inv_nsub;
                        float ri = cv[idx].y * inv_nsub;
                        acc_re[et] += (tkr*rr - tki*ri) * pf;
                        acc_im[et] += (tkr*ri + tki*rr) * pf;
                        /* Dead slots (si >= actual_b, or et >= etl so that
                         * se >= N_ES) would read stale or never-initialised
                         * shmem.  A NaN/Inf there turns the zero cv into NaN
                         * (0 * NaN), which then reaches the spectrum through
                         * acc on the next fi.  Use a zero step instead; live
                         * slots read exactly what they did before. */
                        f2 sv_local = {0.0f, 0.0f};
                        if (row_live && et < etl)
                            sv_local = {GEO_STP_RX_RE(si)[se], GEO_STP_RX_IM(si)[se]};
                        cv[idx] = cmul(cv[idx], sv_local);
                    }
                }

                if (!valid) continue;
                #pragma unroll
                for (int et = 0; et < ELEM_TILE; et++) {
                    if (et >= etl) break;
                    int elem = (se_base + et) / N_SUB;
                    atomicAdd(&spect_re[elem * N_FREQ + f], acc_re[et]);
                    atomicAdd(&spect_im[elem * N_FREQ + f], acc_im[et]);
                }
            }
        }

        /* Geometry rows are overwritten by the next batch's Phase 1. */
        __syncthreads();
    }
}
|
|
@@ -0,0 +1,128 @@
|
|
|
1
|
+
// Kernel B: SIMD-reduce RX -- multiple scatterers per threadgroup with
// cross-scatterer SIMD reduction to cut atomic writes by SCAT_REDUCE.
//
// Thread layout: tid = elem_idx * SCAT_REDUCE + scat_batch
// - Adjacent threads handle the SAME element from DIFFERENT scatterers
// - Within a SIMD group (32 threads): 32/SR elements * SR scatterers
// - simd_shuffle_xor reduces groups of SR threads (same element, different scat)
// - Only scat_batch==0 threads write atomics -> SR fewer atomics
//
// Coalescing: writing threads (scat_batch==0) are at stride SR in the SIMD group.
// They write to consecutive element addresses -> coalesced atomics.
//
// TG = N_ELEM * SCAT_REDUCE (e.g., 64*2 = 128 for P4-2v with SR=2)
//
// Inputs read below: scat_x/scat_z/rc (per scatterer), elem_x/elem_z/theta_e
// (per element), sub_dx/sub_dz (per sub-element, index elem*N_SUB + s),
// tx_re/tx_im (per scatterer and frequency, index scat*N_FREQ + f),
// probe (per frequency) and scalars[0..7].
// Output: spect_re/spect_im at index f * N_ELEM + elem (frequency-major),
// accumulated with relaxed atomic adds -- presumably zeroed by the caller
// and bound as atomic float outputs; TODO confirm against the host launch.
//
// Compile-time constants:
// N_ELEM, N_SUB, N_FREQ, N_SCAT, SCAT_REDUCE

uint tg_scat_base = threadgroup_position_in_grid.x * SCAT_REDUCE;
uint lid = thread_position_in_threadgroup.x;
uint scat_batch = lid % SCAT_REDUCE;
uint elem_idx = lid / SCAT_REDUCE;
uint scat_idx = tg_scat_base + scat_batch;

// Out-of-range threads still run the frequency loop below so that every
// lane takes part in the SIMD shuffles; they just contribute zeros.
bool valid = (scat_idx < (uint)N_SCAT && elem_idx < (uint)N_ELEM);

// Only assigned / read when `valid`.
float sx, sz, rc_i, ex, ez, te;
float kw_init, alpha_init, kw_step, alpha_step, min_dist, seg_len, center_kw, inv_nsub;

kw_init = scalars[0];
alpha_init = scalars[1];
kw_step = scalars[2];
alpha_step = scalars[3];
min_dist = scalars[4];
seg_len = scalars[5];
center_kw = scalars[6];
inv_nsub = scalars[7];

// Per sub-element: cur[s] is the RX term at the current frequency index,
// stp_arr[s] the multiplier that advances it by one frequency index.
float2 cur[N_SUB];
float2 stp_arr[N_SUB];

if (valid) {
sx = scat_x[scat_idx];
sz = scat_z[scat_idx];
rc_i = rc[scat_idx];
ex = elem_x[elem_idx];
ez = elem_z[elem_idx];
te = theta_e[elem_idx];

for (int s = 0; s < N_SUB; s++) {
int sub_idx = elem_idx * N_SUB + s;
float dx = sx - ex - sub_dx[sub_idx];
float dz = sz - ez - sub_dz[sub_idx];
float r = metal::precise::sqrt(dx * dx + dz * dz);
float rc_ = max(r, min_dist);

// Angle relative to the element orientation; obliquity is clamped to a
// tiny positive value at or beyond +/- 90 degrees.
float th = metal::precise::asin((dx + 1e-16f) / (r + 1e-16f)) - te;
float obliq = (fabs(th) >= M_PI_2_F) ? 1e-16f : metal::precise::cos(th);

// Reduce the initial phase kw_init*r into [0, 2*pi) before cos/sin.
float kwr = kw_init * rc_;
float TWO_PI = 2.0f * M_PI_F;
float ph_wrap = kwr - TWO_PI * metal::precise::floor(kwr / TWO_PI);
float ai = obliq / metal::precise::sqrt(rc_) * metal::precise::exp(-alpha_init * rc_);
float2 pi_ = float2(ai * metal::precise::cos(ph_wrap),
ai * metal::precise::sin(ph_wrap));

// Per-frequency-index step: exp(-alpha_step*r) * exp(i*kw_step*r).
float as_ = metal::precise::exp(-alpha_step * rc_);
float phs = kw_step * rc_;
float2 ps_ = float2(as_ * metal::precise::cos(phs),
as_ * metal::precise::sin(phs));

// Sinc directivity of a segment of length seg_len at center_kw.
float sa = center_kw * seg_len * 0.5f * metal::precise::sin(th);
float sv = (fabs(sa) < 1e-8f) ? 1.0f : metal::precise::sin(sa) / sa;
pi_ *= sv;

cur[s] = pi_;
stp_arr[s] = ps_;
}
}

for (int f = 0; f < N_FREQ; f++) {
float c_re = 0.0f, c_im = 0.0f;

if (valid) {
// Sum the sub-element terms at frequency f, then advance each one.
float sr = 0.0f, si = 0.0f;
for (int s = 0; s < N_SUB; s++) {
sr += cur[s].x;
si += cur[s].y;
float cr = cur[s].x, ci = cur[s].y;
float tr = stp_arr[s].x, ti = stp_arr[s].y;
cur[s] = float2(cr * tr - ci * ti, cr * ti + ci * tr);
}
float rp_re = sr * inv_nsub;
float rp_im = si * inv_nsub;

// Complex product TX * RX, scaled by rc and the probe response.
int tx_idx = scat_idx * N_FREQ + f;
float pk_re = tx_re[tx_idx];
float pk_im = tx_im[tx_idx];

float probe_f = probe[f];
c_re = rc_i * (pk_re * rp_re - pk_im * rp_im) * probe_f;
c_im = rc_i * (pk_re * rp_im + pk_im * rp_re) * probe_f;
}

// SIMD reduce across SCAT_REDUCE scatterers for the same element.
// All threads participate (invalid threads contribute 0).
// The xor masks 1, 2, 4, 8 form a complete butterfly only when
// SCAT_REDUCE is a power of two no larger than 16 -- TODO confirm the
// host restricts SCAT_REDUCE accordingly.
#if SCAT_REDUCE >= 2
c_re += simd_shuffle_xor(c_re, 1);
c_im += simd_shuffle_xor(c_im, 1);
#endif
#if SCAT_REDUCE >= 4
c_re += simd_shuffle_xor(c_re, 2);
c_im += simd_shuffle_xor(c_im, 2);
#endif
#if SCAT_REDUCE >= 8
c_re += simd_shuffle_xor(c_re, 4);
c_im += simd_shuffle_xor(c_im, 4);
#endif
#if SCAT_REDUCE >= 16
c_re += simd_shuffle_xor(c_re, 8);
c_im += simd_shuffle_xor(c_im, 8);
#endif

// One lane per (element, scatterer group) publishes the reduced sum.
if (scat_batch == 0 && valid) {
int offset = f * N_ELEM + elem_idx;
atomic_fetch_add_explicit(&spect_re[offset], c_re, memory_order_relaxed);
atomic_fetch_add_explicit(&spect_im[offset], c_im, memory_order_relaxed);
}
}
|
|
@@ -0,0 +1,175 @@
|
|
|
1
|
+
// Kernel: Element-tiled progression with shared-memory geometry.
// One threadgroup per scatterer; threads cooperatively compute geometry
// AND da-absorbed stride steps into shared memory, then each thread
// processes sub-element tiles with geometric progression (ALU-only inner loop).
//
// Low register pressure: only TILE_SE*2 float2 per thread (256 bytes for
// TILE_SE=16). ALU-only inner loop (0 SFU calls in the frequency sweep).
//
// Frequency ownership: thread `lid` owns frequency indices
// lid, lid + TG_SIZE, lid + 2*TG_SIZE, ...  The stride step sh_stp is built
// for a stride of TG_SIZE (stride_f), and sum_re/sum_im hold MAX_FPT
// entries, so every loop over owned frequencies strides by TG_SIZE rather
// than by threads_per_threadgroup.  (Striding by the launch size would
// compute wrong frequencies, or overrun sum_re, whenever the two differ.)
// Cooperative shared-memory fills may use any stride and keep using tpg.
//
// Shared memory layout:
//   amp[N_ES]            frequency-independent amplitude
//   kw_r[N_ES]           kw_init * r     (base phase)
//   kr_step[N_ES]        kw_step * r     (phase increment per freq index)
//   alpha_r[N_ES]        alpha_init * r  (base attenuation)
//   ar_step[N_ES]        alpha_step * r  (attenuation increment per freq)
//   stp[N_ES] float2     stride step, da-absorbed (same for all threads)
//   da_init_re[N_ELEM]   delay+apod init real part
//   da_init_im[N_ELEM]   delay+apod init imag part
//   dps[N_ELEM]          delay_phase_step per element
//
// Total: N_ES*(5*4 + 8) + N_ELEM*3*4 bytes
//      = 64*(20+8) + 64*12 = 1792 + 768 = 2560 bytes (N_ES=64)
//
// Output: tx_re[N_SCAT * N_FREQ], tx_im[N_SCAT * N_FREQ]
//
// Compile-time constants:
//   N_ELEM, N_SUB, N_FREQ, N_ES, N_SCAT, TILE_SE, TG_SIZE, MAX_FPT
//   (assumes MAX_FPT >= ceil(N_FREQ / TG_SIZE) -- TODO confirm on host)

threadgroup float sh_amp[N_ES];
threadgroup float sh_kw_r[N_ES];
threadgroup float sh_kr_step[N_ES];
threadgroup float sh_alpha_r[N_ES];
threadgroup float sh_ar_step[N_ES];
threadgroup float2 sh_stp[N_ES];
threadgroup float sh_da_init_re[N_ELEM];
threadgroup float sh_da_init_im[N_ELEM];
threadgroup float sh_dps[N_ELEM];

uint scat_idx = threadgroup_position_in_grid.x;
uint lid = thread_position_in_threadgroup.x;
uint tpg = threads_per_threadgroup.x;

// scat_idx is the same for the whole threadgroup, so this exit cannot
// leave other threads waiting at the barrier below.
if (scat_idx >= N_SCAT) return;

float sx = scat_x[scat_idx];
float sz = scat_z[scat_idx];
float is_out_i = is_out[scat_idx];

float kw_init_v = scalars[0];
float alpha_init_v = scalars[1];
float kw_step_v = scalars[2];
float alpha_step_v = scalars[3];
float min_dist = scalars[4];
float seg_len = scalars[5];
float center_kw = scalars[6];
float inv_nsub = scalars[7];

float lid_f = float(lid);
float stride_f = float(TG_SIZE);
// Distance between consecutive frequency indices owned by one thread;
// must equal stride_f, which sh_stp is built for.
const uint freq_stride = (uint)TG_SIZE;

// ---- Phase 1A: Cooperatively compute per-sub-element geometry ----
for (uint se = lid; se < (uint)N_ES; se += tpg) {
    int elem = se / N_SUB;
    int sub_global = elem * N_SUB + (se % N_SUB);

    float ex = elem_x[elem];
    float ez = elem_z[elem];
    float te = theta_e[elem];

    float dx = sx - ex - sub_dx[sub_global];
    float dz = sz - ez - sub_dz[sub_global];
    float r = metal::precise::sqrt(dx * dx + dz * dz);
    float rc_ = max(r, min_dist);

    float th = metal::precise::asin((dx + 1e-16f) / (r + 1e-16f)) - te;
    float obliq = (fabs(th) >= M_PI_2_F) ? 1e-16f : metal::fast::cos(th);

    float sa = center_kw * seg_len * 0.5f * metal::fast::sin(th);
    float sv = (fabs(sa) < 1e-8f) ? 1.0f : metal::fast::sin(sa) / sa;

    sh_amp[se] = obliq * sv / metal::precise::sqrt(rc_);
    sh_kw_r[se] = kw_init_v * rc_;
    sh_kr_step[se] = kw_step_v * rc_;
    sh_alpha_r[se] = alpha_init_v * rc_;
    sh_ar_step[se] = alpha_step_v * rc_;

    // Precompute da-absorbed stride step (same for ALL threads).
    // stp = exp((-alpha_step*stride + j*kw_step*stride) * r) * da_step^stride
    float stp_phase = stride_f * kw_step_v * rc_;
    float stp_alpha = stride_f * alpha_step_v * rc_;
    float sm = metal::fast::exp(-stp_alpha);
    float sp_re = sm * metal::fast::cos(stp_phase);
    float sp_im = sm * metal::fast::sin(stp_phase);

    float das_phase = stride_f * delay_phase_step[elem];
    float das_re = metal::fast::cos(das_phase);
    float das_im = metal::fast::sin(das_phase);
    sh_stp[se] = float2(sp_re * das_re - sp_im * das_im,
                        sp_re * das_im + sp_im * das_re);
}

// ---- Phase 1B: Cooperatively load per-element da info ----
for (uint e = lid; e < (uint)N_ELEM; e += tpg) {
    sh_da_init_re[e] = da_init_re[e];
    sh_da_init_im[e] = da_init_im[e];
    sh_dps[e] = delay_phase_step[e];
}

threadgroup_barrier(mem_flags::mem_threadgroup);

// ---- Phase 2: Tiled progression sweep ----
constexpr int N_TILES = (N_ES + TILE_SE - 1) / TILE_SE;

float sum_re[MAX_FPT];
float sum_im[MAX_FPT];
// Number of frequencies this thread owns (stride must match stride_f).
int my_n_freq = 0;
for (uint f = lid; f < (uint)N_FREQ; f += freq_stride) my_n_freq++;
for (int i = 0; i < MAX_FPT; i++) { sum_re[i] = 0.0f; sum_im[i] = 0.0f; }

for (int tile = 0; tile < N_TILES; tile++) {
    int tile_start = tile * TILE_SE;
    int tile_end = min(tile_start + TILE_SE, N_ES);
    int tile_len = tile_end - tile_start;

    float2 cur_t[TILE_SE];
    float2 stp_t[TILE_SE];

    // Init cur at this thread's starting frequency, read stp from shared
    for (int te = 0; te < tile_len; te++) {
        int se = tile_start + te;
        int elem = se / N_SUB;

        float phase = sh_kw_r[se] + lid_f * sh_kr_step[se];
        float alpha_val = sh_alpha_r[se] + lid_f * sh_ar_step[se];
        float ai = sh_amp[se] * metal::fast::exp(-alpha_val);
        float pi_re = ai * metal::fast::cos(phase);
        float pi_im = ai * metal::fast::sin(phase);

        float da_ph = lid_f * sh_dps[elem];
        float da_cs_re = metal::fast::cos(da_ph);
        float da_cs_im = metal::fast::sin(da_ph);
        float da_re = sh_da_init_re[elem] * da_cs_re - sh_da_init_im[elem] * da_cs_im;
        float da_im = sh_da_init_re[elem] * da_cs_im + sh_da_init_im[elem] * da_cs_re;

        cur_t[te] = float2(pi_re * da_re - pi_im * da_im,
                           pi_re * da_im + pi_im * da_re);
        stp_t[te] = sh_stp[se];
    }

    // Sweep: ALU-only inner loop
    for (int fi = 0; fi < my_n_freq; fi++) {
        for (int te = 0; te < tile_len; te++) {
            sum_re[fi] += cur_t[te].x;
            sum_im[fi] += cur_t[te].y;
            float cr = cur_t[te].x, ci = cur_t[te].y;
            float tr = stp_t[te].x, ti = stp_t[te].y;
            cur_t[te] = float2(cr * tr - ci * ti, cr * ti + ci * tr);
        }
    }
}

// ---- Phase 3: Apply inv_nsub, pulse*probe spectrum, write output ----
// Visits exactly the frequencies counted in my_n_freq, in the same order.
int fi = 0;
for (uint f = lid; f < (uint)N_FREQ; f += freq_stride, fi++) {
    float tx_re_v = sum_re[fi] * inv_nsub;
    float tx_im_v = sum_im[fi] * inv_nsub;

    float pp_re_f = pp_re[f], pp_im_f = pp_im[f];
    float pk_re = pp_re_f * tx_re_v - pp_im_f * tx_im_v;
    float pk_im = pp_re_f * tx_im_v + pp_im_f * tx_re_v;
    if (is_out_i > 0.5f) { pk_re = 0.0f; pk_im = 0.0f; }

    int out_idx = scat_idx * N_FREQ + f;
    tx_re[out_idx] = pk_re;
    tx_im[out_idx] = pk_im;
}
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
"""Medium parameter definitions for ultrasound propagation."""
|
|
2
|
+
|
|
3
|
+
from pydantic import BaseModel, ConfigDict, Field
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class MediumParams(BaseModel):
    """Immutable description of the medium an ultrasound wave travels through.

    Bundles the physical properties (sound speed and attenuation) of a
    propagation medium such as soft tissue or water. Instances are frozen,
    and each field's description comes from the string that follows it.
    """

    model_config = ConfigDict(frozen=True, use_attribute_docstrings=True)

    speed_of_sound: float = Field(1540.0, gt=0)
    """Speed of sound in m/s."""

    attenuation: float = Field(0.0, ge=0)
    """Attenuation coefficient in dB/cm/MHz."""