FastSIMUS 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,332 @@
1
+ /*
2
+ * Fused TX+RX SIMUS kernel -- v25c: register-resident TX, sv_arr in
3
+ * shmem (correct, fp32).
4
+ *
5
+ * v25b cached sv_arr[B*ELEM_TILE] in registers (56 floats at B=7 ET=4).
6
+ * That competes with tk_re/tk_im for the 255-reg cap and forces 400 B
7
+ * spill of tk into local memory -- which NCU showed saturates L2 at
8
+ * 76.5 % throughput.
9
+ *
10
+ * v25c drops sv_arr from registers and reads GEO_STP_RX_RE/IM directly
11
+ * from shmem inside the cmul, freeing those 56 regs for tk. Cost: an
12
+ * extra shmem read per (si, et, fi) cmul advance -- bank-conflict free
13
+ * since each thread reads its own row.
14
+ *
15
+ * v25 with one structural fix: every loop that indexes tk_re/tk_im is
16
+ * now `for fi in 0..MAX_FPT` with `#pragma unroll` and predicated
17
+ * validity, so fi is statically known. v25 spilled tk entirely
18
+ * (576 B local mem at B=9 ET=4) because dynamic fi forced
19
+ * tk_re[si*MAX_FPT + fi] off-register. Static fi unrolling makes tk
20
+ * actually register-resident, eliminating the local-memory traffic.
21
+ *
22
+ * Why this is safe: in v11 each thread `lid` writes to sh_tx[si*N_FREQ + f]
23
+ * for f in {lid, lid+TG_SIZE, lid+2*TG_SIZE, ...} during Phase 2, and reads
24
+ * the same slots during Phase 3. There is no cross-thread sharing of TX --
25
+ * the shmem allocation was a temporary, not a broadcast surface.
26
+ *
27
+ * Storing TX in per-thread register arrays tk_re[B_SCAT*MAX_FPT],
28
+ * tk_im[B_SCAT*MAX_FPT] eliminates the dominant shmem cost
29
+ * (2*B_SCAT*N_FREQ floats; 60 KB at B=9 N_FREQ=854) without changing
30
+ * precision or arithmetic. Also lets us drop the pre-Phase-3 sync that
31
+ * was only needed to publish sh_tx writes.
32
+ *
33
+ * Per-thread TX register cost: 2*B_SCAT*MAX_FPT floats. For
34
+ * B=9 N_FREQ=854 TG=128 -> MAX_FPT=7 -> 126 floats. May force 1->something
35
+ * trade vs spill; expected to come out well ahead given shmem savings
36
+ * (76.5 KB -> 16.5 KB at B=9 ET=4 unlocks ~5 blk/SM vs v11's 1).
37
+ *
38
+ * Compile-time: N_ELEM, N_SUB, N_FREQ, N_ES, TILE_SE, TG_SIZE, MAX_FPT,
39
+ * B_SCAT, ELEM_TILE
40
+ *
41
+ * Shared memory: (7*B_SCAT*N_ES + 3*N_ELEM) * 4 bytes
42
+ */
43
+
44
+ #ifndef M_PI_F
45
+ #define M_PI_F 3.14159265358979323846f
46
+ #endif
47
+
48
/* Minimal complex number: x is the real part, y the imaginary part. */
struct f2 {
    float x;
    float y;
};

/* Complex product a * b. */
__device__ __forceinline__ f2 cmul(f2 a, f2 b) {
    f2 prod;
    prod.x = a.x * b.x - a.y * b.y;
    prod.y = a.x * b.y + a.y * b.x;
    return prod;
}
53
+
54
/*
 * Row accessors into the dynamic shared-memory geometry table.
 * The table is seven consecutive planes of B_SCAT rows, each row N_ES floats;
 * GEO_PLANE(p, s) is the start of row `s` (scatterer slot) in plane `p`.
 * All of them expand against a variable named `shmem` at the point of use.
 */
#define GEO_PLANE(p, s)   (shmem + (((p)*B_SCAT + (s)) * N_ES))

#define GEO_AMP(s)        GEO_PLANE(0, s)  /* amplitude                      */
#define GEO_KW_R(s)       GEO_PLANE(1, s)  /* kw_init    * clamped distance  */
#define GEO_KR_STEP(s)    GEO_PLANE(2, s)  /* kw_step    * clamped distance  */
#define GEO_ALPHA_R(s)    GEO_PLANE(3, s)  /* alpha_init * clamped distance  */
#define GEO_AR_STEP(s)    GEO_PLANE(4, s)  /* alpha_step * clamped distance  */
#define GEO_STP_RX_RE(s)  GEO_PLANE(5, s)  /* per-stride step, real part     */
#define GEO_STP_RX_IM(s)  GEO_PLANE(6, s)  /* per-stride step, imaginary part */
61
+
62
/*
 * simus_fused_kernel -- fused TX + RX spectrum accumulation.
 *
 * Launch contract, as far as this code establishes it:
 *   - 1-D block. All per-thread striding uses the compile-time TG_SIZE (not
 *     blockDim.x), so the block is expected to be launched with TG_SIZE threads.
 *   - 1-D grid of any size. Block b processes batches of B_SCAT scatterers
 *     starting at b*B_SCAT and advancing by gridDim.x*B_SCAT.
 *   - Dynamic shared memory: (7*B_SCAT*N_ES + 3*N_ELEM) floats.
 *   - Thread `lid` owns frequencies f = lid + fi*TG_SIZE for fi in [0, MAX_FPT);
 *     frequencies that would need fi >= MAX_FPT are never visited.
 *
 * Inputs (indexing as used below):
 *   scat_x, scat_z, rc_arr        per scatterer, [n_scat]
 *   elem_x, elem_z,
 *   cos_te, sin_neg_te            per element, indexed by se / N_SUB
 *   sub_dx, sub_dz                per sub-element, [N_ES]
 *   da_init_re, da_init_im, dps   per element, [N_ELEM]
 *   pp_re, pp_im, probe           per frequency, [N_FREQ]
 *   radius, apex_offset           when radius < 1e30f, scatterers with
 *                                 x^2 + (z + apex_offset)^2 <= radius^2 are
 *                                 treated as out of field
 * Outputs:
 *   spect_re, spect_im            indexed [elem * N_FREQ + f]; results are
 *                                 added with atomicAdd, so the caller presumably
 *                                 zeroes them beforehand (confirm at call site).
 */
extern "C" __global__
void simus_fused_kernel(
    const float* __restrict__ scat_x,
    const float* __restrict__ scat_z,
    const float* __restrict__ rc_arr,
    const float* __restrict__ elem_x,
    const float* __restrict__ elem_z,
    const float* __restrict__ cos_te,
    const float* __restrict__ sin_neg_te,
    const float* __restrict__ sub_dx,
    const float* __restrict__ sub_dz,
    const float* __restrict__ da_init_re,
    const float* __restrict__ da_init_im,
    const float* __restrict__ dps,
    const float* __restrict__ pp_re,
    const float* __restrict__ pp_im,
    const float* __restrict__ probe,
    float* __restrict__ spect_re,
    float* __restrict__ spect_im,
    int n_scat,
    float kw_init, float alpha_init,
    float kw_step, float alpha_step,
    float min_dist, float seg_len,
    float center_kw, float inv_nsub,
    float radius, float apex_offset
) {
    int lid = threadIdx.x;
    float lid_f = (float)lid;
    float stride_f = (float)TG_SIZE;

    extern __shared__ float shmem[];

    /* TX moved to per-thread registers; shmem only holds geometry + per-elem broadcast. */
    float tk_re[B_SCAT * MAX_FPT];
    float tk_im[B_SCAT * MAX_FPT];

    /* Per-element tables live right after the seven geometry planes. */
    float* sh_da_init_re_l = shmem + 7 * B_SCAT * N_ES;
    float* sh_da_init_im_l = sh_da_init_re_l + N_ELEM;
    float* sh_dps_l        = sh_da_init_im_l + N_ELEM;

    for (int e = lid; e < N_ELEM; e += TG_SIZE) {
        sh_da_init_re_l[e] = da_init_re[e];
        sh_da_init_im_l[e] = da_init_im[e];
        sh_dps_l[e]        = dps[e];
    }
    __syncthreads();

    const int N_TILES       = (N_ES + TILE_SE - 1) / TILE_SE;
    const int N_ELEM_GROUPS = (N_ES + ELEM_TILE - 1) / ELEM_TILE;

    bool out_flag[B_SCAT];

    for (int scat_base = blockIdx.x * B_SCAT;
         scat_base < n_scat;
         scat_base += gridDim.x * B_SCAT)
    {
        /* Number of real scatterers in this batch (the last batch may be partial). */
        int actual_b = B_SCAT;
        if (scat_base + B_SCAT > n_scat)
            actual_b = n_scat - scat_base;

        /* Zero-init register TX so si >= actual_b reads finite zeros in
         * Phase 3 (cv is also forced to 0 there, but 0*NaN would propagate). */
        #pragma unroll
        for (int si = 0; si < B_SCAT; si++) {
            #pragma unroll
            for (int fi = 0; fi < MAX_FPT; fi++) {
                tk_re[si * MAX_FPT + fi] = 0.0f;
                tk_im[si * MAX_FPT + fi] = 0.0f;
            }
        }

        /* ---- Phase 1+2: geometry + TX for each scatterer in batch ---- */
        for (int si = 0; si < actual_b; si++) {
            int scat_idx = scat_base + si;
            float sx = scat_x[scat_idx];
            float sz = scat_z[scat_idx];
            float rc = rc_arr[scat_idx];

            bool is_out = (sz < 0.0f);
            if (radius < 1e30f) {
                float da = sx, db = sz + apex_offset;
                is_out = is_out || ((da*da + db*db) <= radius*radius);
            }
            out_flag[si] = is_out;

            /* Phase 1: threads cooperatively fill row si of every geometry plane. */
            for (int se = lid; se < N_ES; se += TG_SIZE) {
                int elem = se / N_SUB;
                float ex_ = elem_x[elem], ez_ = elem_z[elem];
                float ct = cos_te[elem], snt = sin_neg_te[elem];
                float dx = sx - ex_ - sub_dx[se];
                float dz = sz - ez_ - sub_dz[se];
                float r2 = dx*dx + dz*dz;
                float inv_r = rsqrtf(r2 + 1e-30f);
                float r = r2 * inv_r;
                float rc_ = fmaxf(r, min_dist);

                float sin_th = (dx*ct + dz*snt) * inv_r;
                float cos_th = (dz*ct - dx*snt) * inv_r;
                float obliq = (cos_th <= 0.0f) ? 1e-16f : cos_th;
                float sa = center_kw * seg_len * 0.5f * sin_th;
                float sv = (fabsf(sa) < 1e-8f) ? 1.0f : __fdividef(__sinf(sa), sa);

                GEO_AMP(si)[se]     = obliq * sv * rsqrtf(rc_);
                GEO_KW_R(si)[se]    = kw_init * rc_;
                GEO_KR_STEP(si)[se] = kw_step * rc_;
                GEO_ALPHA_R(si)[se] = alpha_init * rc_;
                GEO_AR_STEP(si)[se] = alpha_step * rc_;

                /* Step that advances one sub-element state by TG_SIZE frequencies. */
                float stp_phase = stride_f * kw_step * rc_;
                float stp_alpha = stride_f * alpha_step * rc_;
                float sm = expf(-stp_alpha);
                float sp_re, sp_im;
                __sincosf(stp_phase, &sp_im, &sp_re);
                sp_re *= sm; sp_im *= sm;
                GEO_STP_RX_RE(si)[se] = sp_re;
                GEO_STP_RX_IM(si)[se] = sp_im;
            }
            /* Publish row si. is_out depends only on per-scatterer values and
             * kernel arguments, so every thread takes the same path below. */
            __syncthreads();

            if (is_out) {
                #pragma unroll
                for (int fi = 0; fi < MAX_FPT; fi++) {
                    tk_re[si * MAX_FPT + fi] = 0.0f;
                    tk_im[si * MAX_FPT + fi] = 0.0f;
                }
                continue;
            }

            /* Phase 2: TX sweep */
            float sum_re[MAX_FPT], sum_im[MAX_FPT];
            #pragma unroll
            for (int i = 0; i < MAX_FPT; i++) { sum_re[i] = 0.0f; sum_im[i] = 0.0f; }

            for (int tile = 0; tile < N_TILES; tile++) {
                int ts = tile * TILE_SE;
                int te = ts + TILE_SE;
                if (te > N_ES) te = N_ES;
                int tl = te - ts;

                f2 cv[TILE_SE], sv[TILE_SE];
                #pragma unroll
                for (int j = 0; j < TILE_SE; j++) {
                    /* Padding lanes of a partial tile: zero state, identity step. */
                    if (j >= tl) { cv[j] = {0.0f, 0.0f}; sv[j] = {1.0f, 0.0f}; continue; }
                    int se = ts + j, em = se / N_SUB;
                    float ph = GEO_KW_R(si)[se] + lid_f * GEO_KR_STEP(si)[se];
                    float av = GEO_ALPHA_R(si)[se] + lid_f * GEO_AR_STEP(si)[se];
                    float ai = GEO_AMP(si)[se] * expf(-av);
                    float vr, vi;
                    __sincosf(ph, &vi, &vr);
                    vr *= ai; vi *= ai;
                    float dp = lid_f * sh_dps_l[em];
                    float dr, di;
                    __sincosf(dp, &di, &dr);
                    float dvr = sh_da_init_re_l[em]*dr - sh_da_init_im_l[em]*di;
                    float dvi = sh_da_init_re_l[em]*di + sh_da_init_im_l[em]*dr;
                    cv[j] = {vr*dvr - vi*dvi, vr*dvi + vi*dvr};

                    float sp_re = GEO_STP_RX_RE(si)[se];
                    float sp_im = GEO_STP_RX_IM(si)[se];
                    float das_phase = stride_f * sh_dps_l[em];
                    float das_re, das_im;
                    __sincosf(das_phase, &das_im, &das_re);
                    sv[j] = {sp_re*das_re - sp_im*das_im, sp_re*das_im + sp_im*das_re};
                }

                #pragma unroll
                for (int fi = 0; fi < MAX_FPT; fi++) {
                    int f_chk = lid + fi * TG_SIZE;
                    if (f_chk >= N_FREQ) break;
                    #pragma unroll
                    for (int j = 0; j < TILE_SE; j++) {
                        sum_re[fi] += cv[j].x; sum_im[fi] += cv[j].y;
                        cv[j] = cmul(cv[j], sv[j]);
                    }
                }
            }

            #pragma unroll
            for (int fi = 0; fi < MAX_FPT; fi++) {
                int f = lid + fi * TG_SIZE;
                bool valid = (f < N_FREQ);
                float tr = sum_re[fi] * inv_nsub;
                float ti = sum_im[fi] * inv_nsub;
                float ppr = valid ? pp_re[f] : 0.0f;
                float ppi = valid ? pp_im[f] : 0.0f;
                tk_re[si * MAX_FPT + fi] = valid ? (ppr*tr - ppi*ti) * rc : 0.0f;
                tk_im[si * MAX_FPT + fi] = valid ? (ppr*ti + ppi*tr) * rc : 0.0f;
            }
            /* No __syncthreads here -- TX is private to this thread. */
        }

        /* ---- Phase 3: element-tiled RX with B_SCAT accumulation ---- */
        for (int eg = 0; eg < N_ELEM_GROUPS; eg++) {
            int se_base = eg * ELEM_TILE;
            int etl = ELEM_TILE;
            if (se_base + etl > N_ES) etl = N_ES - se_base;

            /* Initialize B_SCAT * ELEM_TILE RX states. sv_arr stays in shmem
             * (re-read per cmul advance) to free registers for tk. */
            f2 cv[B_SCAT * ELEM_TILE];

            #pragma unroll
            for (int si = 0; si < B_SCAT; si++) {
                #pragma unroll
                for (int et = 0; et < ELEM_TILE; et++) {
                    int idx = si * ELEM_TILE + et;
                    /* out_flag[si] is only set for si < actual_b; the
                     * short-circuit keeps it from being read otherwise. */
                    if (si >= actual_b || out_flag[si] || et >= etl) {
                        cv[idx] = {0.0f, 0.0f};
                        continue;
                    }
                    int se = se_base + et;
                    float ph = GEO_KW_R(si)[se] + lid_f * GEO_KR_STEP(si)[se];
                    float av = GEO_ALPHA_R(si)[se] + lid_f * GEO_AR_STEP(si)[se];
                    float ai = GEO_AMP(si)[se] * expf(-av);
                    float vr, vi;
                    __sincosf(ph, &vi, &vr);
                    cv[idx] = {vr * ai, vi * ai};
                }
            }

            /* Sweep frequencies with B_SCAT * ELEM_TILE independent chains.
             * fi is statically unrolled so tk_re[si*MAX_FPT + fi] uses a
             * compile-time index, keeping tk truly register-resident. */
            #pragma unroll
            for (int fi = 0; fi < MAX_FPT; fi++) {
                int f = lid + fi * TG_SIZE;
                bool valid = (f < N_FREQ);
                float pf = valid ? probe[f] : 0.0f;

                float acc_re[ELEM_TILE];
                float acc_im[ELEM_TILE];
                #pragma unroll
                for (int et = 0; et < ELEM_TILE; et++) {
                    acc_re[et] = 0.0f;
                    acc_im[et] = 0.0f;
                }

                #pragma unroll
                for (int si = 0; si < B_SCAT; si++) {
                    float tkr = tk_re[si * MAX_FPT + fi];
                    float tki = tk_im[si * MAX_FPT + fi];

                    /* Rows si >= actual_b were not written for this batch; on a
                     * block's first batch they are uninitialised shared memory. */
                    const bool row_live = (si < actual_b);

                    #pragma unroll
                    for (int et = 0; et < ELEM_TILE; et++) {
                        int idx = si * ELEM_TILE + et;
                        int se = se_base + et;
                        float rr = cv[idx].x * inv_nsub;
                        float ri = cv[idx].y * inv_nsub;
                        acc_re[et] += (tkr*rr - tki*ri) * pf;
                        acc_im[et] += (tkr*ri + tki*rr) * pf;
                        /* Only advance lanes backed by a row written in this
                         * batch and by an in-range sub-element. Dead lanes keep
                         * their exact zero; multiplying that zero by a stale or
                         * uninitialised step could yield NaN and contaminate
                         * acc_* from the next fi onwards. */
                        if (row_live && et < etl) {
                            f2 sv_local = {GEO_STP_RX_RE(si)[se], GEO_STP_RX_IM(si)[se]};
                            cv[idx] = cmul(cv[idx], sv_local);
                        }
                    }
                }

                if (!valid) continue;
                #pragma unroll
                for (int et = 0; et < ELEM_TILE; et++) {
                    if (et >= etl) break;
                    int elem = (se_base + et) / N_SUB;
                    atomicAdd(&spect_re[elem * N_FREQ + f], acc_re[et]);
                    atomicAdd(&spect_im[elem * N_FREQ + f], acc_im[et]);
                }
            }
        }

        /* All Phase 3 reads of the geometry rows must finish before the next
         * batch's Phase 1 overwrites them. */
        __syncthreads();
    }
}
@@ -0,0 +1,128 @@
1
+ // Kernel B: SIMD-reduce RX -- multiple scatterers per threadgroup with
2
+ // cross-scatterer SIMD reduction to cut atomic writes by SCAT_REDUCE.
3
+ //
4
+ // Thread layout: tid = elem_idx * SCAT_REDUCE + scat_batch
5
+ // - Adjacent threads handle the SAME element from DIFFERENT scatterers
6
+ // - Within a SIMD group (32 threads): 32/SR elements * SR scatterers
7
+ // - simd_shuffle_xor reduces groups of SR threads (same element, different scat)
8
+ // - Only scat_batch==0 threads write atomics -> SR fewer atomics
9
+ //
10
+ // Coalescing: writing threads (scat_batch==0) are at stride SR in the SIMD group.
11
+ // They write to consecutive element addresses -> coalesced atomics.
12
+ //
13
+ // TG = N_ELEM * SCAT_REDUCE (e.g., 64*2 = 128 for P4-2v with SR=2)
14
+ //
15
+ // Compile-time constants:
16
+ // N_ELEM, N_SUB, N_FREQ, N_SCAT, SCAT_REDUCE
17
+
18
// Kernel body: RX accumulation with a cross-scatterer SIMD reduction.
// Thread layout: lid = elem_idx * SCAT_REDUCE + scat_batch, so SCAT_REDUCE
// adjacent lanes hold the same element for different scatterers.
// SCAT_REDUCE must be a power of two no larger than 32 for the xor ladder
// below to sum exactly the lanes that share an element.
uint tg_scat_base = threadgroup_position_in_grid.x * SCAT_REDUCE;
uint lid = thread_position_in_threadgroup.x;
uint scat_batch = lid % SCAT_REDUCE;
uint elem_idx = lid / SCAT_REDUCE;
uint scat_idx = tg_scat_base + scat_batch;

// Threads past the scatterer or element range still run the loop so they can
// take part in the shuffles; they contribute zero.
bool valid = (scat_idx < (uint)N_SCAT && elem_idx < (uint)N_ELEM);

float sx, sz, rc_i, ex, ez, te;
float kw_init, alpha_init, kw_step, alpha_step, min_dist, seg_len, center_kw, inv_nsub;

kw_init = scalars[0];
alpha_init = scalars[1];
kw_step = scalars[2];
alpha_step = scalars[3];
min_dist = scalars[4];
seg_len = scalars[5];
center_kw = scalars[6];
inv_nsub = scalars[7];

// Per-sub-element state at the current frequency and its per-frequency step.
float2 cur[N_SUB];
float2 stp_arr[N_SUB];

// Loop-invariant; hoisted out of the sub-element loop.
const float TWO_PI = 2.0f * M_PI_F;

if (valid) {
    sx = scat_x[scat_idx];
    sz = scat_z[scat_idx];
    rc_i = rc[scat_idx];
    ex = elem_x[elem_idx];
    ez = elem_z[elem_idx];
    te = theta_e[elem_idx];

    for (int s = 0; s < N_SUB; s++) {
        int sub_idx = elem_idx * N_SUB + s;
        float dx = sx - ex - sub_dx[sub_idx];
        float dz = sz - ez - sub_dz[sub_idx];
        float r = metal::precise::sqrt(dx * dx + dz * dz);
        float rc_ = max(r, min_dist);

        float th = metal::precise::asin((dx + 1e-16f) / (r + 1e-16f)) - te;
        float obliq = (fabs(th) >= M_PI_2_F) ? 1e-16f : metal::precise::cos(th);

        // Wrap the base phase into [0, 2*pi) before taking cos/sin.
        float kwr = kw_init * rc_;
        float ph_wrap = kwr - TWO_PI * metal::precise::floor(kwr / TWO_PI);
        float ai = obliq / metal::precise::sqrt(rc_) * metal::precise::exp(-alpha_init * rc_);
        float2 pi_ = float2(ai * metal::precise::cos(ph_wrap),
                            ai * metal::precise::sin(ph_wrap));

        float as_ = metal::precise::exp(-alpha_step * rc_);
        float phs = kw_step * rc_;
        float2 ps_ = float2(as_ * metal::precise::cos(phs),
                            as_ * metal::precise::sin(phs));

        float sa = center_kw * seg_len * 0.5f * metal::precise::sin(th);
        float sv = (fabs(sa) < 1e-8f) ? 1.0f : metal::precise::sin(sa) / sa;
        pi_ *= sv;

        cur[s] = pi_;
        stp_arr[s] = ps_;
    }
}

for (int f = 0; f < N_FREQ; f++) {
    float c_re = 0.0f, c_im = 0.0f;

    if (valid) {
        float sr = 0.0f, si = 0.0f;
        for (int s = 0; s < N_SUB; s++) {
            sr += cur[s].x;
            si += cur[s].y;
            // Advance this sub-element to the next frequency (complex multiply).
            float cr = cur[s].x, ci = cur[s].y;
            float tr = stp_arr[s].x, ti = stp_arr[s].y;
            cur[s] = float2(cr * tr - ci * ti, cr * ti + ci * tr);
        }
        float rp_re = sr * inv_nsub;
        float rp_im = si * inv_nsub;

        int tx_idx = scat_idx * N_FREQ + f;
        float pk_re = tx_re[tx_idx];
        float pk_im = tx_im[tx_idx];

        float probe_f = probe[f];
        c_re = rc_i * (pk_re * rp_re - pk_im * rp_im) * probe_f;
        c_im = rc_i * (pk_re * rp_im + pk_im * rp_re) * probe_f;
    }

    // SIMD reduce across SCAT_REDUCE scatterers for the same element.
    // All threads participate (invalid threads contribute 0).
    #if SCAT_REDUCE >= 2
    c_re += simd_shuffle_xor(c_re, 1);
    c_im += simd_shuffle_xor(c_im, 1);
    #endif
    #if SCAT_REDUCE >= 4
    c_re += simd_shuffle_xor(c_re, 2);
    c_im += simd_shuffle_xor(c_im, 2);
    #endif
    #if SCAT_REDUCE >= 8
    c_re += simd_shuffle_xor(c_re, 4);
    c_im += simd_shuffle_xor(c_im, 4);
    #endif
    #if SCAT_REDUCE >= 16
    c_re += simd_shuffle_xor(c_re, 8);
    c_im += simd_shuffle_xor(c_im, 8);
    #endif
    // Final stage for a full 32-lane group; without it SCAT_REDUCE = 32 would
    // only sum 16 of the 32 scatterers into the writing lane.
    #if SCAT_REDUCE >= 32
    c_re += simd_shuffle_xor(c_re, 16);
    c_im += simd_shuffle_xor(c_im, 16);
    #endif

    // One atomic per element and frequency: only the scat_batch == 0 lane writes.
    if (scat_batch == 0 && valid) {
        int offset = f * N_ELEM + elem_idx;
        atomic_fetch_add_explicit(&spect_re[offset], c_re, memory_order_relaxed);
        atomic_fetch_add_explicit(&spect_im[offset], c_im, memory_order_relaxed);
    }
}
@@ -0,0 +1,175 @@
1
+ // Kernel: Element-tiled progression with shared-memory geometry.
2
+ // One threadgroup per scatterer; threads cooperatively compute geometry
3
+ // AND da-absorbed stride steps into shared memory, then each thread
4
+ // processes sub-element tiles with geometric progression (ALU-only inner loop).
5
+ //
6
+ // Low register pressure: only TILE_SE*2 float2 per thread (256 bytes for
7
+ // TILE_SE=16). ALU-only inner loop (0 SFU calls in the frequency sweep).
8
+ //
9
+ // Shared memory layout:
10
+ // amp[N_ES] frequency-independent amplitude
11
+ // kw_r[N_ES] kw_init * r (base phase)
12
+ // kr_step[N_ES] kw_step * r (phase increment per freq index)
13
+ // alpha_r[N_ES] alpha_init * r (base attenuation)
14
+ // ar_step[N_ES] alpha_step * r (attenuation increment per freq)
15
+ // stp[N_ES] float2 stride step, da-absorbed (same for all threads)
16
+ // da_init_re[N_ELEM] delay+apod init real part
17
+ // da_init_im[N_ELEM] delay+apod init imag part
18
+ // dps[N_ELEM] delay_phase_step per element
19
+ //
20
+ // Total: N_ES*(5*4 + 8) + N_ELEM*3*4 bytes
21
+ // = 64*(20+8) + 64*12 = 1792 + 768 = 2560 bytes (N_ES=64)
22
+ //
23
+ // Output: tx_re[N_SCAT * N_FREQ], tx_im[N_SCAT * N_FREQ]
24
+ //
25
+ // Compile-time constants:
26
+ // N_ELEM, N_SUB, N_FREQ, N_ES, N_SCAT, TILE_SE, TG_SIZE, MAX_FPT
27
+
28
// Kernel body: one threadgroup per scatterer. Threads first fill the
// threadgroup tables below cooperatively, then each thread sweeps its own
// frequencies f = lid, lid + tpg, lid + 2*tpg, ... with a geometric progression.
// Precondition: ceil(N_FREQ / tpg) <= MAX_FPT, otherwise sum_re/sum_im overflow.
threadgroup float sh_amp[N_ES];
threadgroup float sh_kw_r[N_ES];
threadgroup float sh_kr_step[N_ES];
threadgroup float sh_alpha_r[N_ES];
threadgroup float sh_ar_step[N_ES];
threadgroup float2 sh_stp[N_ES];
threadgroup float sh_da_init_re[N_ELEM];
threadgroup float sh_da_init_im[N_ELEM];
threadgroup float sh_dps[N_ELEM];

uint scat_idx = threadgroup_position_in_grid.x;
uint lid = thread_position_in_threadgroup.x;
uint tpg = threads_per_threadgroup.x;

// scat_idx is the same for every thread of the threadgroup, so the whole
// group leaves together and the barrier below is never partially reached.
if (scat_idx >= N_SCAT) return;

float sx = scat_x[scat_idx];
float sz = scat_z[scat_idx];
float is_out_i = is_out[scat_idx];

float kw_init_v = scalars[0];
float alpha_init_v = scalars[1];
float kw_step_v = scalars[2];
float alpha_step_v = scalars[3];
float min_dist = scalars[4];
float seg_len = scalars[5];
float center_kw = scalars[6];
float inv_nsub = scalars[7];

float lid_f = float(lid);
// The progression step must cover exactly the gap between two consecutive
// frequencies of one thread. Those frequencies are spaced by tpg (see the
// loops over f below), so the stride is taken from tpg rather than from the
// compile-time TG_SIZE; the two agree when the kernel is launched with
// TG_SIZE threads per threadgroup.
float stride_f = float(tpg);

// ---- Phase 1A: Cooperatively compute per-sub-element geometry ----
for (uint se = lid; se < (uint)N_ES; se += tpg) {
    int elem = se / N_SUB;

    float ex = elem_x[elem];
    float ez = elem_z[elem];
    float te = theta_e[elem];

    // elem * N_SUB + (se % N_SUB) is se itself, so index the sub-element
    // offsets directly.
    float dx = sx - ex - sub_dx[se];
    float dz = sz - ez - sub_dz[se];
    float r = metal::precise::sqrt(dx * dx + dz * dz);
    float rc_ = max(r, min_dist);

    float th = metal::precise::asin((dx + 1e-16f) / (r + 1e-16f)) - te;
    float obliq = (fabs(th) >= M_PI_2_F) ? 1e-16f : metal::fast::cos(th);

    float sa = center_kw * seg_len * 0.5f * metal::fast::sin(th);
    float sv = (fabs(sa) < 1e-8f) ? 1.0f : metal::fast::sin(sa) / sa;

    sh_amp[se] = obliq * sv / metal::precise::sqrt(rc_);
    sh_kw_r[se] = kw_init_v * rc_;
    sh_kr_step[se] = kw_step_v * rc_;
    sh_alpha_r[se] = alpha_init_v * rc_;
    sh_ar_step[se] = alpha_step_v * rc_;

    // Precompute da-absorbed stride step (same for ALL threads).
    // stp = exp((-alpha_step*stride + j*kw_step*stride) * r) * da_step^stride
    float stp_phase = stride_f * kw_step_v * rc_;
    float stp_alpha = stride_f * alpha_step_v * rc_;
    float sm = metal::fast::exp(-stp_alpha);
    float sp_re = sm * metal::fast::cos(stp_phase);
    float sp_im = sm * metal::fast::sin(stp_phase);

    float das_phase = stride_f * delay_phase_step[elem];
    float das_re = metal::fast::cos(das_phase);
    float das_im = metal::fast::sin(das_phase);
    sh_stp[se] = float2(sp_re * das_re - sp_im * das_im,
                        sp_re * das_im + sp_im * das_re);
}

// ---- Phase 1B: Cooperatively load per-element da info ----
for (uint e = lid; e < (uint)N_ELEM; e += tpg) {
    sh_da_init_re[e] = da_init_re[e];
    sh_da_init_im[e] = da_init_im[e];
    sh_dps[e] = delay_phase_step[e];
}

threadgroup_barrier(mem_flags::mem_threadgroup);

// ---- Phase 2: Tiled progression sweep ----
constexpr int N_TILES = (N_ES + TILE_SE - 1) / TILE_SE;

float sum_re[MAX_FPT];
float sum_im[MAX_FPT];
// Number of frequencies owned by this thread.
int my_n_freq = 0;
for (uint f = lid; f < (uint)N_FREQ; f += tpg) my_n_freq++;
for (int i = 0; i < MAX_FPT; i++) { sum_re[i] = 0.0f; sum_im[i] = 0.0f; }

for (int tile = 0; tile < N_TILES; tile++) {
    int tile_start = tile * TILE_SE;
    int tile_end = min(tile_start + TILE_SE, N_ES);
    int tile_len = tile_end - tile_start;

    float2 cur_t[TILE_SE];
    float2 stp_t[TILE_SE];

    // Init cur at this thread's starting frequency, read stp from shared
    for (int te = 0; te < tile_len; te++) {
        int se = tile_start + te;
        int elem = se / N_SUB;

        float phase = sh_kw_r[se] + lid_f * sh_kr_step[se];
        float alpha_val = sh_alpha_r[se] + lid_f * sh_ar_step[se];
        float ai = sh_amp[se] * metal::fast::exp(-alpha_val);
        float pi_re = ai * metal::fast::cos(phase);
        float pi_im = ai * metal::fast::sin(phase);

        float da_ph = lid_f * sh_dps[elem];
        float da_cs_re = metal::fast::cos(da_ph);
        float da_cs_im = metal::fast::sin(da_ph);
        float da_re = sh_da_init_re[elem] * da_cs_re - sh_da_init_im[elem] * da_cs_im;
        float da_im = sh_da_init_re[elem] * da_cs_im + sh_da_init_im[elem] * da_cs_re;

        cur_t[te] = float2(pi_re * da_re - pi_im * da_im,
                           pi_re * da_im + pi_im * da_re);
        stp_t[te] = sh_stp[se];
    }

    // Sweep: ALU-only inner loop
    for (int fi = 0; fi < my_n_freq; fi++) {
        for (int te = 0; te < tile_len; te++) {
            sum_re[fi] += cur_t[te].x;
            sum_im[fi] += cur_t[te].y;
            float cr = cur_t[te].x, ci = cur_t[te].y;
            float tr = stp_t[te].x, ti = stp_t[te].y;
            cur_t[te] = float2(cr * tr - ci * ti, cr * ti + ci * tr);
        }
    }
}

// ---- Phase 3: Apply inv_nsub, pulse*probe spectrum, write output ----
int fi = 0;
for (uint f = lid; f < (uint)N_FREQ; f += tpg, fi++) {
    float tx_re_v = sum_re[fi] * inv_nsub;
    float tx_im_v = sum_im[fi] * inv_nsub;

    float pp_re_f = pp_re[f], pp_im_f = pp_im[f];
    float pk_re = pp_re_f * tx_re_v - pp_im_f * tx_im_v;
    float pk_im = pp_re_f * tx_im_v + pp_im_f * tx_re_v;
    // is_out is a float flag; values above 0.5 zero this scatterer's output.
    if (is_out_i > 0.5f) { pk_re = 0.0f; pk_im = 0.0f; }

    int out_idx = scat_idx * N_FREQ + f;
    tx_re[out_idx] = pk_re;
    tx_im[out_idx] = pk_im;
}
@@ -0,0 +1,22 @@
1
+ """Medium parameter definitions for ultrasound propagation."""
2
+
3
+ from pydantic import BaseModel, ConfigDict, Field
4
+
5
+
6
+ class MediumParams(BaseModel):  # validated pydantic model holding two scalar medium properties
7
+ """Medium parameters for ultrasound propagation.
8
+
9
+ This class encapsulates physical properties of the propagation medium
10
+ (e.g., soft tissue, water) that affect ultrasound wave propagation.
11
+ """
12
+
13
+ model_config = ConfigDict(
14
+ use_attribute_docstrings=True,  # pydantic: the string literal after each field becomes its description
15
+ frozen=True,  # pydantic: assigning to a field after construction is rejected
16
+ )
17
+
18
+ speed_of_sound: float = Field(default=1540.0, gt=0)  # must be strictly positive
19
+ """Speed of sound in m/s."""
20
+
21
+ attenuation: float = Field(default=0.0, ge=0)  # zero is allowed, negative values are rejected
22
+ """Attenuation coefficient in dB/cm/MHz."""