tide-GPR 0.0.9__py3-none-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
tide/csrc/maxwell.cu ADDED
@@ -0,0 +1,2297 @@
1
+ /*
2
+ * Maxwell wave equation propagator (CUDA implementation)
3
+ *
4
+ * This file contains the CUDA implementation of the 2D TM Maxwell equations
5
+ * propagator with complete Adjoint State Method (ASM) support for gradient computation.
6
+ *
7
+ * TM mode fields: Ey (electric), Hx, Hz (magnetic)
8
+ *
9
+ * EXACT DISCRETE Adjoint State Method for Maxwell TM equations:
10
+ * =============================================================
11
+ * Forward equations (discrete):
12
+ * E_y^{n+1} = C_a * E_y^n + C_b * (D_x[H_z] - D_z[H_x])
13
+ * H_x^{n+1/2} = H_x^{n-1/2} - C_q * D_z^h[E_y]
14
+ * H_z^{n+1/2} = H_z^{n-1/2} + C_q * D_x^h[E_y]
15
+ *
16
+ * Exact discrete adjoint equations (time-reversed with transposed operators):
17
+ * λ_Ey^n = C_a * λ_Ey^{n+1} + C_q * (D_x^{hT}[λ_Hz] - D_z^{hT}[λ_Hx])
18
+ * λ_Hx^{n-1/2} = λ_Hx^{n+1/2} - C_b * D_z^T[λ_Ey]
19
+ * λ_Hz^{n-1/2} = λ_Hz^{n+1/2} + C_b * D_x^T[λ_Ey]
20
+ *
21
+ * Model gradients:
22
+ * ∂J/∂C_a = Σ_n E_y^n * λ_Ey^{n+1}
23
+ * ∂J/∂C_b = Σ_n curl_H^n * λ_Ey^{n+1}
24
+ *
25
+ * Gradient accumulation strategy:
26
+ * - Use per-shot gradient arrays (grad_ca_shot, grad_cb_shot)
27
+ * - Each shot writes to its own memory region (no race condition)
28
+ * - Use combine_grad kernel to sum across shots at the end
29
+ */
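// Editorial sketch (not part of the package source): the "exact discrete
// adjoint" property asserted above can be checked numerically with a
// dot-product test, <D x, y> == <x, D^T y>, applied to each difference
// operator and the operator the backward kernels use in its place.  The
// stand-alone host program below illustrates the test for a generic 1D
// forward difference; the names and the stencil are illustrative only and
// are not the DIFF* operators defined in staggered_grid.h.

#include <cstdio>
#include <vector>

static void diff_forward(const std::vector<double> &x, std::vector<double> &y) {
  // (D x)_i = x_{i+1} - x_i for i < n-1, and (D x)_{n-1} = 0.
  size_t const n = x.size();
  for (size_t i = 0; i + 1 < n; ++i) y[i] = x[i + 1] - x[i];
  y[n - 1] = 0.0;
}

static void diff_forward_transpose(const std::vector<double> &y, std::vector<double> &x) {
  // Matrix transpose of diff_forward: (D^T y)_0 = -y_0,
  // (D^T y)_i = y_{i-1} - y_i for 0 < i < n-1, and (D^T y)_{n-1} = y_{n-2}.
  size_t const n = y.size();
  x[0] = -y[0];
  for (size_t i = 1; i + 1 < n; ++i) x[i] = y[i - 1] - y[i];
  x[n - 1] = y[n - 2];
}

int main() {
  std::vector<double> x = {1.0, 2.0, -1.0, 4.0}, y = {0.5, -2.0, 3.0, 1.0};
  std::vector<double> Dx(4), DTy(4);
  diff_forward(x, Dx);
  diff_forward_transpose(y, DTy);
  double lhs = 0.0, rhs = 0.0;
  for (int i = 0; i < 4; ++i) { lhs += Dx[i] * y[i]; rhs += x[i] * DTy[i]; }
  std::printf("<Dx,y> = %g, <x,DTy> = %g\n", lhs, rhs);  // both print 21.5
  return 0;
}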
30
+
31
+ #include <stdio.h>
32
+ #include <cstdint>
33
+ #include <cstdlib>
34
+ #include <climits>
35
+ #include <math.h>
36
+ #include <cuda_bf16.h>
37
+ #if defined(__has_include)
38
+ #if __has_include(<cuda_fp8.h>)
39
+ #include <cuda_fp8.h>
40
+ #define TIDE_HAVE_CUDA_FP8 1
41
+ #else
42
+ #define TIDE_HAVE_CUDA_FP8 0
43
+ #endif
44
+ #else
45
+ #define TIDE_HAVE_CUDA_FP8 0
46
+ #endif
47
+ #include "common_gpu.h"
48
+ #include "staggered_grid.h"
49
+ #include "storage_utils.h"
50
+
51
+ #ifndef TIDE_DEVICE
52
+ #define TIDE_DEVICE cuda
53
+ #endif
54
+
55
+ // CPU storage pipelining: Number of ping-pong buffers for async D2H/H2D copies
56
+ // Increasing this reduces synchronization stalls between compute and copy
57
+ #ifndef NUM_BUFFERS
58
+ #define NUM_BUFFERS 3
59
+ #endif
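// Editorial sketch (not part of the package source): the ping-pong scheme
// that NUM_BUFFERS refers to overlaps kernel work with device-to-host
// snapshot copies by rotating through a small pool of buffers, each paired
// with its own copy stream.  The helper below only illustrates the pattern;
// d_snap, h_buf, snapshot_bytes and compute_stream are hypothetical names,
// not symbols from this file, and h_buf is assumed to be pinned memory
// allocated with cudaMallocHost so the copies can run asynchronously.

#include <cuda_runtime.h>

static void pingpong_d2h_sketch(float *const d_snap[NUM_BUFFERS],
                                float *const h_buf[NUM_BUFFERS],
                                size_t snapshot_bytes, int nt,
                                cudaStream_t compute_stream) {
  cudaStream_t copy_stream[NUM_BUFFERS];
  cudaEvent_t snapshot_ready[NUM_BUFFERS];
  for (int b = 0; b < NUM_BUFFERS; ++b) {
    cudaStreamCreate(&copy_stream[b]);
    cudaEventCreate(&snapshot_ready[b]);
  }
  for (int t = 0; t < nt; ++t) {
    int const b = t % NUM_BUFFERS;
    // Block until the copy issued NUM_BUFFERS steps ago has drained, so both
    // d_snap[b] and h_buf[b] are free to be reused for this step.
    cudaStreamSynchronize(copy_stream[b]);
    // ... launch the forward kernels for step t on compute_stream,
    //     writing the values to be saved into d_snap[b] ...
    cudaEventRecord(snapshot_ready[b], compute_stream);
    cudaStreamWaitEvent(copy_stream[b], snapshot_ready[b], 0);
    cudaMemcpyAsync(h_buf[b], d_snap[b], snapshot_bytes,
                    cudaMemcpyDeviceToHost, copy_stream[b]);
  }
  for (int b = 0; b < NUM_BUFFERS; ++b) {
    cudaStreamSynchronize(copy_stream[b]);
    cudaStreamDestroy(copy_stream[b]);
    cudaEventDestroy(snapshot_ready[b]);
  }
}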
60
+
61
+ // Profiling support: enable with -DTIDE_PROFILING during compilation
62
+ #ifdef TIDE_PROFILING
63
+ #define PROF_EVENT_CREATE(e) cudaEventCreate(&(e))
64
+ #define PROF_RECORD(e, s) cudaEventRecord((e), (s))
65
+ #define PROF_ELAPSED(start, end, ms) cudaEventElapsedTime(&(ms), (start), (end))
66
+ #define PROF_PRINT(name, ms) fprintf(stderr, "[TIDE PROF] %s: %.3f ms\n", (name), (ms))
67
+ #else
68
+ #define PROF_EVENT_CREATE(e) ((void)0)
69
+ #define PROF_RECORD(e, s) ((void)0)
70
+ #define PROF_ELAPSED(start, end, ms) ((void)0)
71
+ #define PROF_PRINT(name, ms) ((void)0)
72
+ #endif
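// Editorial sketch (not part of the package source): typical use of the
// profiling macros above, which only do real work when the file is compiled
// with -DTIDE_PROFILING.  The kernel, launch configuration and stream names
// are placeholders.
//
//   cudaEvent_t prof_start, prof_stop;
//   float prof_ms = 0.0f;
//   PROF_EVENT_CREATE(prof_start);
//   PROF_EVENT_CREATE(prof_stop);
//   PROF_RECORD(prof_start, stream);
//   some_kernel<<<grid, block, 0, stream>>>(/* ... */);
//   PROF_RECORD(prof_stop, stream);
//   #ifdef TIDE_PROFILING
//   cudaEventSynchronize(prof_stop);  // the stop event must have completed
//   #endif                            // before cudaEventElapsedTime is valid
//   PROF_ELAPSED(prof_start, prof_stop, prof_ms);
//   PROF_PRINT("some_kernel", prof_ms);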
73
+
74
+ #define CAT_I(name, accuracy, dtype, device) \
75
+ maxwell_tm_##accuracy##_##dtype##_##name##_##device
76
+ #define CAT(name, accuracy, dtype, device) \
77
+ CAT_I(name, accuracy, dtype, device)
78
+ #define FUNC(name) CAT(name, TIDE_STENCIL, TIDE_DTYPE, TIDE_DEVICE)
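// Editorial note (not part of the package source): the two-level CAT/CAT_I
// indirection is the usual trick to force TIDE_STENCIL, TIDE_DTYPE and
// TIDE_DEVICE to be expanded before token pasting.  As an illustration only,
// if this translation unit were built with -DTIDE_STENCIL=2 -DTIDE_DTYPE=float
// (and TIDE_DEVICE left at its default of cuda), then
//
//   FUNC(forward)  ->  maxwell_tm_2_float_forward_cuda
//
// The values actually used are chosen by the package's build system.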
79
+
80
+ // 2D indexing macros
81
+ #define ND_INDEX(i, dy, dx) ((i) + (dy)*nx + (dx))
82
+ #define ND_INDEX_J(j, dy, dx) ((j) + (dy)*nx + (dx))
83
+
84
+ #define gpuErrchk(ans) \
85
+ { gpuAssert((ans), __FILE__, __LINE__); }
86
+ // Field access macros
87
+ #define EY(dy, dx) ey[ND_INDEX(i, dy, dx)]
88
+ #define HX(dy, dx) hx[ND_INDEX(i, dy, dx)]
89
+ #define HZ(dy, dx) hz[ND_INDEX(i, dy, dx)]
90
+
91
+ // Adjoint field access macros
92
+ #define LAMBDA_EY(dy, dx) lambda_ey[ND_INDEX(i, dy, dx)]
93
+ #define LAMBDA_HX(dy, dx) lambda_hx[ND_INDEX(i, dy, dx)]
94
+ #define LAMBDA_HZ(dy, dx) lambda_hz[ND_INDEX(i, dy, dx)]
95
+
96
+ // Material parameter access macros
97
+ #define CA(dy, dx) ca_shot[ND_INDEX_J(j, dy, dx)]
98
+ #define CB(dy, dx) cb_shot[ND_INDEX_J(j, dy, dx)]
99
+ #define CQ(dy, dx) cq_shot[ND_INDEX_J(j, dy, dx)]
100
+
101
+ // PML memory variable macros
102
+ #define M_HX_Z(dy, dx) m_hx_z[ND_INDEX(i, dy, dx)]
103
+ #define M_HZ_X(dy, dx) m_hz_x[ND_INDEX(i, dy, dx)]
104
+ #define M_EY_X(dy, dx) m_ey_x[ND_INDEX(i, dy, dx)]
105
+ #define M_EY_Z(dy, dx) m_ey_z[ND_INDEX(i, dy, dx)]
106
+
107
+ // Adjoint PML memory variable macros
108
+ #define M_LAMBDA_EY_X(dy, dx) m_lambda_ey_x[ND_INDEX(i, dy, dx)]
109
+ #define M_LAMBDA_EY_Z(dy, dx) m_lambda_ey_z[ND_INDEX(i, dy, dx)]
110
+ #define M_LAMBDA_HX_Z(dy, dx) m_lambda_hx_z[ND_INDEX(i, dy, dx)]
111
+ #define M_LAMBDA_HZ_X(dy, dx) m_lambda_hz_x[ND_INDEX(i, dy, dx)]
112
+
113
+ #define MAX(a, b) ((a) > (b) ? (a) : (b))
114
+
115
+ // Vacuum permittivity (F/m) to convert dL/d(epsilon_abs) -> dL/d(epsilon_r)
116
+ #define EP0 ((TIDE_DTYPE)8.8541878128e-12)
117
+
118
+ namespace {
119
+
120
+ // Device constants
121
+ __constant__ TIDE_DTYPE rdy;
122
+ __constant__ TIDE_DTYPE rdx;
123
+ __constant__ int64_t n_shots;
124
+ __constant__ int64_t ny;
125
+ __constant__ int64_t nx;
126
+ __constant__ int64_t shot_numel;
127
+ __constant__ int64_t n_sources_per_shot;
128
+ __constant__ int64_t n_receivers_per_shot;
129
+ __constant__ int64_t pml_y0;
130
+ __constant__ int64_t pml_y1;
131
+ __constant__ int64_t pml_x0;
132
+ __constant__ int64_t pml_x1;
133
+ __constant__ bool ca_batched;
134
+ __constant__ bool cb_batched;
135
+ __constant__ bool cq_batched;
136
+
137
+ // Add source to field
138
+ __global__ void add_sources_ey(TIDE_DTYPE *__restrict const ey,
139
+ TIDE_DTYPE const *__restrict const f,
140
+ int64_t const *__restrict const sources_i) {
141
+ int64_t source_idx =
142
+ (int64_t)blockIdx.x * (int64_t)blockDim.x + (int64_t)threadIdx.x;
143
+ int64_t shot_idx =
144
+ (int64_t)blockIdx.y * (int64_t)blockDim.y + (int64_t)threadIdx.y;
145
+ if (source_idx < n_sources_per_shot && shot_idx < n_shots) {
146
+ int64_t k = shot_idx * n_sources_per_shot + source_idx;
147
+ int64_t const src = sources_i[k];
148
+ if (0 <= src) {
149
+ ey[shot_idx * shot_numel + src] += f[k];
150
+ }
151
+ }
152
+ }
153
+
154
+ // Add adjoint source at receiver locations (for backward pass)
155
+ __global__ void add_adjoint_sources_ey(TIDE_DTYPE *__restrict const ey,
156
+ TIDE_DTYPE const *__restrict const f,
157
+ int64_t const *__restrict const receivers_i) {
158
+ int64_t receiver_idx =
159
+ (int64_t)blockIdx.x * (int64_t)blockDim.x + (int64_t)threadIdx.x;
160
+ int64_t shot_idx =
161
+ (int64_t)blockIdx.y * (int64_t)blockDim.y + (int64_t)threadIdx.y;
162
+ if (receiver_idx < n_receivers_per_shot && shot_idx < n_shots) {
163
+ int64_t k = shot_idx * n_receivers_per_shot + receiver_idx;
164
+ int64_t const rec = receivers_i[k];
165
+ if (0 <= rec) {
166
+ ey[shot_idx * shot_numel + rec] += f[k];
167
+ }
168
+ }
169
+ }
170
+
171
+ // Record field at receiver locations
172
+ __global__ void record_receivers_ey(TIDE_DTYPE *__restrict const r,
173
+ TIDE_DTYPE const *__restrict const ey,
174
+ int64_t const *__restrict receivers_i) {
175
+ int64_t receiver_idx =
176
+ (int64_t)blockIdx.x * (int64_t)blockDim.x + (int64_t)threadIdx.x;
177
+ int64_t shot_idx =
178
+ (int64_t)blockIdx.y * (int64_t)blockDim.y + (int64_t)threadIdx.y;
179
+ if (receiver_idx < n_receivers_per_shot && shot_idx < n_shots) {
180
+ int64_t k = shot_idx * n_receivers_per_shot + receiver_idx;
181
+ int64_t const rec = receivers_i[k];
182
+ if (0 <= rec) {
183
+ r[k] = ey[shot_idx * shot_numel + rec];
184
+ }
185
+ }
186
+ }
187
+
188
+ // Record adjoint field at source locations (for backward pass - source gradient)
189
+ __global__ void record_adjoint_at_sources(TIDE_DTYPE *__restrict const grad_f,
190
+ TIDE_DTYPE const *__restrict const lambda_ey,
191
+ int64_t const *__restrict sources_i) {
192
+ int64_t source_idx =
193
+ (int64_t)blockIdx.x * (int64_t)blockDim.x + (int64_t)threadIdx.x;
194
+ int64_t shot_idx =
195
+ (int64_t)blockIdx.y * (int64_t)blockDim.y + (int64_t)threadIdx.y;
196
+ if (source_idx < n_sources_per_shot && shot_idx < n_shots) {
197
+ int64_t k = shot_idx * n_sources_per_shot + source_idx;
198
+ int64_t const src = sources_i[k];
199
+ if (0 <= src) {
200
+ grad_f[k] = lambda_ey[shot_idx * shot_numel + src];
201
+ }
202
+ }
203
+ }
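// Editorial sketch (not part of the package source): the four kernels above
// share the same launch shape, one thread per (per-shot index, shot) pair.
// A plausible host-side launch, with an illustrative block size and
// hypothetical pointer names (ey_t, f_t), would be:
//
//   dim3 block_src(32, 8, 1);
//   dim3 grid_src((n_sources_per_shot_h + block_src.x - 1) / block_src.x,
//                 (n_shots_h + block_src.y - 1) / block_src.y, 1);
//   add_sources_ey<<<grid_src, block_src>>>(ey_t, f_t, sources_i);
//
// Entries of sources_i / receivers_i that are negative fail the "0 <= src"
// check and are simply skipped, which presumably is how inactive or padded
// source/receiver slots are encoded.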
204
+
205
+
206
+ // FP8 E4M3 (1 sign, 4 exponent, 3 mantissa) encode/decode.
207
+ __device__ __forceinline__ uint8_t fp8_e4m3_from_float(float x) {
208
+ #if TIDE_HAVE_CUDA_FP8
209
+ return (uint8_t)__nv_cvt_float_to_fp8(x, __NV_SATFINITE, __NV_E4M3);
210
+ #else
211
+ if (x == 0.0f) {
212
+ return 0;
213
+ }
214
+ uint8_t sign = (x < 0.0f) ? 0x80 : 0;
215
+ float ax = fabsf(x);
216
+ if (!isfinite(ax)) {
217
+ return (uint8_t)(sign | 0x7F);
218
+ }
219
+ int exp;
220
+ float m = frexpf(ax, &exp); // ax = m * 2^exp, m in [0.5, 1)
221
+ int e = exp - 1;
222
+ int exp_field = e + 7;
223
+ int mant = 0;
224
+
225
+ if (exp_field <= 0) {
226
+ mant = __float2int_rn(ax * 512.0f);
227
+ if (mant <= 0) {
228
+ return sign;
229
+ }
230
+ if (mant > 7) {
231
+ mant = 7;
232
+ }
233
+ exp_field = 0;
234
+ } else if (exp_field >= 0xF) {
235
+ exp_field = 0xE;
236
+ mant = 7;
237
+ } else {
238
+ float frac = m * 2.0f - 1.0f;
239
+ mant = __float2int_rn(frac * 8.0f);
240
+ if (mant == 8) {
241
+ mant = 0;
242
+ exp_field += 1;
243
+ if (exp_field >= 0xF) {
244
+ exp_field = 0xE;
245
+ mant = 7;
246
+ }
247
+ }
248
+ }
249
+
250
+ return (uint8_t)(sign | ((uint8_t)exp_field << 3) | (uint8_t)(mant & 0x7));
251
+ #endif
252
+ }
253
+
254
+ __device__ __forceinline__ float fp8_e4m3_to_float(uint8_t v) {
255
+ #if TIDE_HAVE_CUDA_FP8
256
+ __half h = __nv_cvt_fp8_to_halfraw((__nv_fp8_storage_t)v, __NV_E4M3);
257
+ return __half2float(h);
258
+ #else
259
+ if (v == 0) {
260
+ return 0.0f;
261
+ }
262
+ int sign = v & 0x80;
263
+ int exp_field = (v >> 3) & 0xF;
264
+ int mant = v & 0x7;
265
+ float val;
266
+ if (exp_field == 0) {
267
+ float frac = (float)mant / 8.0f;
268
+ val = ldexpf(frac, -6);
269
+ } else {
270
+ float frac = 1.0f + (float)mant / 8.0f;
271
+ val = ldexpf(frac, exp_field - 7);
272
+ }
273
+ return sign ? -val : val;
274
+ #endif
275
+ }
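// Editorial note (not part of the package source): a worked example of the
// software fallback path above.  Encoding x = 1.5f: frexpf gives m = 0.75 and
// exp = 1, so e = 0 and exp_field = 7 (bias 7); frac = 2*m - 1 = 0.5 rounds
// to mant = 4, giving the byte (7 << 3) | 4 = 0x3C.  Decoding 0x3C reverses
// this: (1 + 4/8) * 2^(7-7) = 1.5.  A quick device-side round-trip check
// (illustrative only):
//
//   __device__ bool fp8_roundtrip_ok() {
//     return fp8_e4m3_to_float(fp8_e4m3_from_float(1.5f)) == 1.5f;
//   }
//
// With only 3 mantissa bits most values do not round-trip exactly; 1.5 does
// because it is representable in E4M3.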
276
+
277
+
278
+ // Forward kernel: Update H fields (Hx and Hz)
279
+ __global__ __launch_bounds__(256) void forward_kernel_h(
280
+ TIDE_DTYPE const *__restrict const cq,
281
+ TIDE_DTYPE const *__restrict const ey,
282
+ TIDE_DTYPE *__restrict const hx,
283
+ TIDE_DTYPE *__restrict const hz,
284
+ TIDE_DTYPE *__restrict const m_ey_x,
285
+ TIDE_DTYPE *__restrict const m_ey_z,
286
+ TIDE_DTYPE const *__restrict const ay,
287
+ TIDE_DTYPE const *__restrict const ayh,
288
+ TIDE_DTYPE const *__restrict const ax,
289
+ TIDE_DTYPE const *__restrict const axh,
290
+ TIDE_DTYPE const *__restrict const by,
291
+ TIDE_DTYPE const *__restrict const byh,
292
+ TIDE_DTYPE const *__restrict const bx,
293
+ TIDE_DTYPE const *__restrict const bxh,
294
+ TIDE_DTYPE const *__restrict const ky,
295
+ TIDE_DTYPE const *__restrict const kyh,
296
+ TIDE_DTYPE const *__restrict const kx,
297
+ TIDE_DTYPE const *__restrict const kxh) {
298
+
299
+ #if FD_PAD > 1
300
+ // Shared-memory tiling for Ey stencil loads.
301
+ // Assumes blockDim.z == 1 (one shot per block).
302
+ extern __shared__ TIDE_DTYPE shmem[];
303
+ TIDE_DTYPE *__restrict const tile_ey = shmem;
304
+ #endif
305
+
306
+ int64_t x = (int64_t)blockIdx.x * (int64_t)blockDim.x +
307
+ (int64_t)threadIdx.x + FD_PAD;
308
+ int64_t y = (int64_t)blockIdx.y * (int64_t)blockDim.y +
309
+ (int64_t)threadIdx.y + FD_PAD;
310
+ int64_t shot_idx = (int64_t)blockIdx.z * (int64_t)blockDim.z +
311
+ (int64_t)threadIdx.z;
312
+
313
+ if (shot_idx >= n_shots) return;
314
+
315
+ #if FD_PAD > 1
316
+ int64_t const tile_w = (int64_t)blockDim.x + 2 * (int64_t)FD_PAD;
317
+ int64_t const tile_h = (int64_t)blockDim.y + 2 * (int64_t)FD_PAD;
318
+ int64_t const tile_pitch = tile_w;
319
+ int64_t const x0 = (int64_t)blockIdx.x * (int64_t)blockDim.x + FD_PAD;
320
+ int64_t const y0 = (int64_t)blockIdx.y * (int64_t)blockDim.y + FD_PAD;
321
+ int64_t const base = shot_idx * shot_numel;
322
+
323
+ int64_t const t = (int64_t)threadIdx.y * (int64_t)blockDim.x +
324
+ (int64_t)threadIdx.x;
325
+ int64_t const nthreads = (int64_t)blockDim.x * (int64_t)blockDim.y;
326
+ int64_t const tile_numel = tile_w * tile_h;
327
+ // Original scalar loading (optimization 2.1: vectorized loading disabled due to overhead)
328
+ for (int64_t idx = t; idx < tile_numel; idx += nthreads) {
329
+ int64_t const ly = idx / tile_w;
330
+ int64_t const lx = idx - ly * tile_w;
331
+ int64_t const gx = x0 - FD_PAD + lx;
332
+ int64_t const gy = y0 - FD_PAD + ly;
333
+ if (0 <= gx && gx < nx && 0 <= gy && gy < ny) {
334
+ tile_ey[ly * tile_pitch + lx] = __ldg(&ey[base + gy * nx + gx]);
335
+ } else {
336
+ tile_ey[ly * tile_pitch + lx] = (TIDE_DTYPE)0;
337
+ }
338
+ }
339
+ __syncthreads();
340
+
341
+ #define EY_L(dy, dx) tile_ey[((int64_t)threadIdx.y + (int64_t)FD_PAD + (dy)) * tile_pitch + ((int64_t)threadIdx.x + (int64_t)FD_PAD + (dx))]
342
+ #else
343
+ #define EY_L(dy, dx) EY(dy, dx)
344
+ #endif
345
+
346
+ if (y < ny - FD_PAD + 1 && x < nx - FD_PAD + 1 && shot_idx < n_shots) {
347
+ int64_t const pml_y0h = pml_y0;
348
+ int64_t const pml_y1h = MAX(pml_y0, pml_y1 - 1);
349
+ int64_t const pml_x0h = pml_x0;
350
+ int64_t const pml_x1h = MAX(pml_x0, pml_x1 - 1);
351
+
352
+ int64_t j = y * nx + x;
353
+ int64_t i = shot_idx * shot_numel + j;
354
+
355
+ TIDE_DTYPE const cq_shot_i = cq_batched ? cq[i] : cq[j];
356
+
357
+ // Pre-load PML coefficients into registers (optimization 1.2)
358
+ TIDE_DTYPE byh_val = __ldg(&byh[y]);
359
+ TIDE_DTYPE ayh_val = __ldg(&ayh[y]);
360
+ TIDE_DTYPE kyh_val = __ldg(&kyh[y]);
361
+ TIDE_DTYPE bxh_val = __ldg(&bxh[x]);
362
+ TIDE_DTYPE axh_val = __ldg(&axh[x]);
363
+ TIDE_DTYPE kxh_val = __ldg(&kxh[x]);
364
+
365
+ // Update Hx: Hx = Hx - cq * dEy/dz
366
+ if (y < ny - FD_PAD) {
367
+ bool pml_y = y < pml_y0h || y >= pml_y1h;
368
+
369
+ TIDE_DTYPE dey_dz = DIFFYH1(EY_L);
370
+
371
+ if (pml_y) {
372
+ m_ey_z[i] = byh_val * m_ey_z[i] + ayh_val * dey_dz;
373
+ dey_dz = dey_dz / kyh_val + m_ey_z[i];
374
+ }
375
+
376
+ hx[i] -= cq_shot_i * dey_dz;
377
+ }
378
+
379
+ // Update Hz: Hz = Hz + cq * dEy/dx
380
+ if (x < nx - FD_PAD) {
381
+ bool pml_x = x < pml_x0h || x >= pml_x1h;
382
+
383
+ TIDE_DTYPE dey_dx = DIFFXH1(EY_L);
384
+
385
+ if (pml_x) {
386
+ m_ey_x[i] = bxh_val * m_ey_x[i] + axh_val * dey_dx;
387
+ dey_dx = dey_dx / kxh_val + m_ey_x[i];
388
+ }
389
+
390
+ hz[i] += cq_shot_i * dey_dx;
391
+ }
392
+ }
393
+
394
+ #undef EY_L
395
+ }
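// Editorial note (not part of the package source): a worked example of the
// shared-memory tile size used when FD_PAD > 1.  With the 32 x 8 thread
// blocks configured by the host code and, say, FD_PAD = 2 (a fourth-order
// stencil; the actual value depends on TIDE_STENCIL), the Ey tile is
// (32 + 2*2) * (8 + 2*2) = 36 * 12 = 432 elements, i.e. 1728 bytes in float
// or 3456 bytes in double.  The E-field kernels that follow load two such
// tiles (Hx and Hz), which is why their dynamic shared-memory request is
// twice this amount.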
396
+
397
+ // Forward kernel: Update E field (Ey) - standard version
398
+ __global__ __launch_bounds__(256) void forward_kernel_e(
399
+ TIDE_DTYPE const *__restrict const ca,
400
+ TIDE_DTYPE const *__restrict const cb,
401
+ TIDE_DTYPE const *__restrict const hx,
402
+ TIDE_DTYPE const *__restrict const hz,
403
+ TIDE_DTYPE *__restrict const ey,
404
+ TIDE_DTYPE *__restrict const m_hx_z,
405
+ TIDE_DTYPE *__restrict const m_hz_x,
406
+ TIDE_DTYPE const *__restrict const ay,
407
+ TIDE_DTYPE const *__restrict const ayh,
408
+ TIDE_DTYPE const *__restrict const ax,
409
+ TIDE_DTYPE const *__restrict const axh,
410
+ TIDE_DTYPE const *__restrict const by,
411
+ TIDE_DTYPE const *__restrict const byh,
412
+ TIDE_DTYPE const *__restrict const bx,
413
+ TIDE_DTYPE const *__restrict const bxh,
414
+ TIDE_DTYPE const *__restrict const ky,
415
+ TIDE_DTYPE const *__restrict const kyh,
416
+ TIDE_DTYPE const *__restrict const kx,
417
+ TIDE_DTYPE const *__restrict const kxh) {
418
+
419
+ #if FD_PAD > 1
420
+ // Shared-memory tiling for Hx/Hz stencil loads.
421
+ // Assumes blockDim.z == 1 (one shot per block).
422
+ extern __shared__ TIDE_DTYPE shmem[];
423
+ int64_t const tile_w = (int64_t)blockDim.x + 2 * (int64_t)FD_PAD;
424
+ int64_t const tile_h = (int64_t)blockDim.y + 2 * (int64_t)FD_PAD;
425
+ int64_t const tile_pitch = tile_w;
426
+ int64_t const tile_numel = tile_w * tile_h;
427
+ TIDE_DTYPE *__restrict const tile_hx = shmem;
428
+ TIDE_DTYPE *__restrict const tile_hz = shmem + tile_numel;
429
+ #endif
430
+
431
+ int64_t x = (int64_t)blockIdx.x * (int64_t)blockDim.x +
432
+ (int64_t)threadIdx.x + FD_PAD;
433
+ int64_t y = (int64_t)blockIdx.y * (int64_t)blockDim.y +
434
+ (int64_t)threadIdx.y + FD_PAD;
435
+ int64_t shot_idx = (int64_t)blockIdx.z * (int64_t)blockDim.z +
436
+ (int64_t)threadIdx.z;
437
+
438
+ if (shot_idx >= n_shots) return;
439
+
440
+ #if FD_PAD > 1
441
+ int64_t const x0 = (int64_t)blockIdx.x * (int64_t)blockDim.x + FD_PAD;
442
+ int64_t const y0 = (int64_t)blockIdx.y * (int64_t)blockDim.y + FD_PAD;
443
+ int64_t const base = shot_idx * shot_numel;
444
+ int64_t const t = (int64_t)threadIdx.y * (int64_t)blockDim.x +
445
+ (int64_t)threadIdx.x;
446
+ int64_t const nthreads = (int64_t)blockDim.x * (int64_t)blockDim.y;
447
+ // Original scalar loading (optimization 2.1: vectorized loading disabled due to overhead)
448
+ for (int64_t idx = t; idx < tile_numel; idx += nthreads) {
449
+ int64_t const ly = idx / tile_w;
450
+ int64_t const lx = idx - ly * tile_w;
451
+ int64_t const gx = x0 - FD_PAD + lx;
452
+ int64_t const gy = y0 - FD_PAD + ly;
453
+ if (0 <= gx && gx < nx && 0 <= gy && gy < ny) {
454
+ int64_t const g = base + gy * nx + gx;
455
+ int64_t const offset = ly * tile_pitch + lx;
456
+ tile_hx[offset] = __ldg(&hx[g]);
457
+ tile_hz[offset] = __ldg(&hz[g]);
458
+ } else {
459
+ int64_t const offset = ly * tile_pitch + lx;
460
+ tile_hx[offset] = (TIDE_DTYPE)0;
461
+ tile_hz[offset] = (TIDE_DTYPE)0;
462
+ }
463
+ }
464
+ __syncthreads();
465
+
466
+ #define HX_L(dy, dx) tile_hx[((int64_t)threadIdx.y + (int64_t)FD_PAD + (dy)) * tile_pitch + ((int64_t)threadIdx.x + (int64_t)FD_PAD + (dx))]
467
+ #define HZ_L(dy, dx) tile_hz[((int64_t)threadIdx.y + (int64_t)FD_PAD + (dy)) * tile_pitch + ((int64_t)threadIdx.x + (int64_t)FD_PAD + (dx))]
468
+ #else
469
+ #define HX_L(dy, dx) HX(dy, dx)
470
+ #define HZ_L(dy, dx) HZ(dy, dx)
471
+ #endif
472
+
473
+ if (y < ny - FD_PAD + 1 && x < nx - FD_PAD + 1 && shot_idx < n_shots) {
474
+ int64_t j = y * nx + x;
475
+ int64_t i = shot_idx * shot_numel + j;
476
+
477
+ TIDE_DTYPE const ca_shot_i = ca_batched ? ca[i] : ca[j];
478
+ TIDE_DTYPE const cb_shot_i = cb_batched ? cb[i] : cb[j];
479
+
480
+ bool pml_y = y < pml_y0 || y >= pml_y1;
481
+ bool pml_x = x < pml_x0 || x >= pml_x1;
482
+
483
+ TIDE_DTYPE dhz_dx = DIFFX1(HZ_L);
484
+ TIDE_DTYPE dhx_dz = DIFFY1(HX_L);
485
+
486
+ // Pre-load PML coefficients into registers (optimization 1.2)
487
+ TIDE_DTYPE bx_val = __ldg(&bx[x]);
488
+ TIDE_DTYPE ax_val = __ldg(&ax[x]);
489
+ TIDE_DTYPE kx_val = __ldg(&kx[x]);
490
+ TIDE_DTYPE by_val = __ldg(&by[y]);
491
+ TIDE_DTYPE ay_val = __ldg(&ay[y]);
492
+ TIDE_DTYPE ky_val = __ldg(&ky[y]);
493
+
494
+ if (pml_x) {
495
+ m_hz_x[i] = bx_val * m_hz_x[i] + ax_val * dhz_dx;
496
+ dhz_dx = dhz_dx / kx_val + m_hz_x[i];
497
+ }
498
+
499
+ if (pml_y) {
500
+ m_hx_z[i] = by_val * m_hx_z[i] + ay_val * dhx_dz;
501
+ dhx_dz = dhx_dz / ky_val + m_hx_z[i];
502
+ }
503
+
504
+ ey[i] = ca_shot_i * ey[i] + cb_shot_i * (dhz_dx - dhx_dz);
505
+ }
506
+
507
+ #undef HX_L
508
+ #undef HZ_L
509
+ }
510
+
511
+ // Forward kernel: Update E field (Ey) with storage for gradient computation
512
+ __global__ void forward_kernel_e_with_storage(
513
+ TIDE_DTYPE const *__restrict const ca,
514
+ TIDE_DTYPE const *__restrict const cb,
515
+ TIDE_DTYPE const *__restrict const hx,
516
+ TIDE_DTYPE const *__restrict const hz,
517
+ TIDE_DTYPE *__restrict const ey,
518
+ TIDE_DTYPE *__restrict const m_hx_z,
519
+ TIDE_DTYPE *__restrict const m_hz_x,
520
+ TIDE_DTYPE *__restrict const ey_store, // Can be NULL
521
+ TIDE_DTYPE *__restrict const curl_h_store, // Can be NULL
522
+ TIDE_DTYPE const *__restrict const ay,
523
+ TIDE_DTYPE const *__restrict const ayh,
524
+ TIDE_DTYPE const *__restrict const ax,
525
+ TIDE_DTYPE const *__restrict const axh,
526
+ TIDE_DTYPE const *__restrict const by,
527
+ TIDE_DTYPE const *__restrict const byh,
528
+ TIDE_DTYPE const *__restrict const bx,
529
+ TIDE_DTYPE const *__restrict const bxh,
530
+ TIDE_DTYPE const *__restrict const ky,
531
+ TIDE_DTYPE const *__restrict const kyh,
532
+ TIDE_DTYPE const *__restrict const kx,
533
+ TIDE_DTYPE const *__restrict const kxh,
534
+ bool const ca_requires_grad,
535
+ bool const cb_requires_grad) {
536
+
537
+ #if FD_PAD > 1
538
+ // Shared-memory tiling for Hx/Hz stencil loads.
539
+ // Assumes blockDim.z == 1 (one shot per block).
540
+ extern __shared__ TIDE_DTYPE shmem[];
541
+ int64_t const tile_w = (int64_t)blockDim.x + 2 * (int64_t)FD_PAD;
542
+ int64_t const tile_h = (int64_t)blockDim.y + 2 * (int64_t)FD_PAD;
543
+ int64_t const tile_pitch = tile_w;
544
+ int64_t const tile_numel = tile_w * tile_h;
545
+ TIDE_DTYPE *__restrict const tile_hx = shmem;
546
+ TIDE_DTYPE *__restrict const tile_hz = shmem + tile_numel;
547
+ #endif
548
+
549
+ int64_t x = (int64_t)blockIdx.x * (int64_t)blockDim.x +
550
+ (int64_t)threadIdx.x + FD_PAD;
551
+ int64_t y = (int64_t)blockIdx.y * (int64_t)blockDim.y +
552
+ (int64_t)threadIdx.y + FD_PAD;
553
+ int64_t shot_idx = (int64_t)blockIdx.z * (int64_t)blockDim.z +
554
+ (int64_t)threadIdx.z;
555
+
556
+ if (shot_idx >= n_shots) return;
557
+
558
+ #if FD_PAD > 1
559
+ int64_t const x0 = (int64_t)blockIdx.x * (int64_t)blockDim.x + FD_PAD;
560
+ int64_t const y0 = (int64_t)blockIdx.y * (int64_t)blockDim.y + FD_PAD;
561
+ int64_t const base = shot_idx * shot_numel;
562
+ int64_t const t = (int64_t)threadIdx.y * (int64_t)blockDim.x +
563
+ (int64_t)threadIdx.x;
564
+ int64_t const nthreads = (int64_t)blockDim.x * (int64_t)blockDim.y;
565
+ // Original scalar loading (optimization 2.1: vectorized loading disabled due to overhead)
566
+ for (int64_t idx = t; idx < tile_numel; idx += nthreads) {
567
+ int64_t const ly = idx / tile_w;
568
+ int64_t const lx = idx - ly * tile_w;
569
+ int64_t const gx = x0 - FD_PAD + lx;
570
+ int64_t const gy = y0 - FD_PAD + ly;
571
+ if (0 <= gx && gx < nx && 0 <= gy && gy < ny) {
572
+ int64_t const g = base + gy * nx + gx;
573
+ int64_t const offset = ly * tile_pitch + lx;
574
+ tile_hx[offset] = __ldg(&hx[g]);
575
+ tile_hz[offset] = __ldg(&hz[g]);
576
+ } else {
577
+ int64_t const offset = ly * tile_pitch + lx;
578
+ tile_hx[offset] = (TIDE_DTYPE)0;
579
+ tile_hz[offset] = (TIDE_DTYPE)0;
580
+ }
581
+ }
582
+ __syncthreads();
583
+
584
+ #define HX_L(dy, dx) tile_hx[((int64_t)threadIdx.y + (int64_t)FD_PAD + (dy)) * tile_pitch + ((int64_t)threadIdx.x + (int64_t)FD_PAD + (dx))]
585
+ #define HZ_L(dy, dx) tile_hz[((int64_t)threadIdx.y + (int64_t)FD_PAD + (dy)) * tile_pitch + ((int64_t)threadIdx.x + (int64_t)FD_PAD + (dx))]
586
+ #else
587
+ #define HX_L(dy, dx) HX(dy, dx)
588
+ #define HZ_L(dy, dx) HZ(dy, dx)
589
+ #endif
590
+
591
+ if (y < ny - FD_PAD + 1 && x < nx - FD_PAD + 1 && shot_idx < n_shots) {
592
+ int64_t j = y * nx + x;
593
+ int64_t i = shot_idx * shot_numel + j;
594
+
595
+ TIDE_DTYPE const ca_shot_i = ca_batched ? ca[i] : ca[j];
596
+ TIDE_DTYPE const cb_shot_i = cb_batched ? cb[i] : cb[j];
597
+
598
+ bool pml_y = y < pml_y0 || y >= pml_y1;
599
+ bool pml_x = x < pml_x0 || x >= pml_x1;
600
+
601
+ TIDE_DTYPE dhz_dx = DIFFX1(HZ_L);
602
+ TIDE_DTYPE dhx_dz = DIFFY1(HX_L);
603
+
604
+ // Pre-load PML coefficients into registers (optimization 1.2)
605
+ TIDE_DTYPE bx_val = __ldg(&bx[x]);
606
+ TIDE_DTYPE ax_val = __ldg(&ax[x]);
607
+ TIDE_DTYPE kx_val = __ldg(&kx[x]);
608
+ TIDE_DTYPE by_val = __ldg(&by[y]);
609
+ TIDE_DTYPE ay_val = __ldg(&ay[y]);
610
+ TIDE_DTYPE ky_val = __ldg(&ky[y]);
611
+
612
+ if (pml_x) {
613
+ m_hz_x[i] = bx_val * m_hz_x[i] + ax_val * dhz_dx;
614
+ dhz_dx = dhz_dx / kx_val + m_hz_x[i];
615
+ }
616
+
617
+ if (pml_y) {
618
+ m_hx_z[i] = by_val * m_hx_z[i] + ay_val * dhx_dz;
619
+ dhx_dz = dhx_dz / ky_val + m_hx_z[i];
620
+ }
621
+
622
+ TIDE_DTYPE curl_h = dhz_dx - dhx_dz;
623
+
624
+ // Store values for gradient computation (before E update)
625
+ if (ca_requires_grad && ey_store != nullptr) {
626
+ ey_store[i] = ey[i];
627
+ }
628
+ if (cb_requires_grad && curl_h_store != nullptr) {
629
+ curl_h_store[i] = curl_h;
630
+ }
631
+
632
+ ey[i] = ca_shot_i * ey[i] + cb_shot_i * curl_h;
633
+ }
634
+
635
+ #undef HX_L
636
+ #undef HZ_L
637
+ }
638
+
639
+ // Forward kernel: Update E field (Ey) with BF16 storage for gradient computation
640
+ // Stores Ey and curl_H in __nv_bfloat16 to reduce snapshot bandwidth/size.
641
+ __global__ void forward_kernel_e_with_storage_bf16(
642
+ TIDE_DTYPE const *__restrict const ca,
643
+ TIDE_DTYPE const *__restrict const cb,
644
+ TIDE_DTYPE const *__restrict const hx,
645
+ TIDE_DTYPE const *__restrict const hz,
646
+ TIDE_DTYPE *__restrict const ey,
647
+ TIDE_DTYPE *__restrict const m_hx_z,
648
+ TIDE_DTYPE *__restrict const m_hz_x,
649
+ __nv_bfloat16 *__restrict const ey_store, // Can be NULL
650
+ __nv_bfloat16 *__restrict const curl_h_store, // Can be NULL
651
+ TIDE_DTYPE const *__restrict const ay,
652
+ TIDE_DTYPE const *__restrict const ayh,
653
+ TIDE_DTYPE const *__restrict const ax,
654
+ TIDE_DTYPE const *__restrict const axh,
655
+ TIDE_DTYPE const *__restrict const by,
656
+ TIDE_DTYPE const *__restrict const byh,
657
+ TIDE_DTYPE const *__restrict const bx,
658
+ TIDE_DTYPE const *__restrict const bxh,
659
+ TIDE_DTYPE const *__restrict const ky,
660
+ TIDE_DTYPE const *__restrict const kyh,
661
+ TIDE_DTYPE const *__restrict const kx,
662
+ TIDE_DTYPE const *__restrict const kxh,
663
+ bool const ca_requires_grad,
664
+ bool const cb_requires_grad) {
665
+
666
+ #if FD_PAD > 1
667
+ // Shared-memory tiling for Hx/Hz stencil loads.
668
+ // Assumes blockDim.z == 1 (one shot per block).
669
+ extern __shared__ TIDE_DTYPE shmem[];
670
+ int64_t const tile_w = (int64_t)blockDim.x + 2 * (int64_t)FD_PAD;
671
+ int64_t const tile_h = (int64_t)blockDim.y + 2 * (int64_t)FD_PAD;
672
+ int64_t const tile_pitch = tile_w;
673
+ int64_t const tile_numel = tile_w * tile_h;
674
+ TIDE_DTYPE *__restrict const tile_hx = shmem;
675
+ TIDE_DTYPE *__restrict const tile_hz = shmem + tile_numel;
676
+ #endif
677
+
678
+ int64_t x = (int64_t)blockIdx.x * (int64_t)blockDim.x +
679
+ (int64_t)threadIdx.x + FD_PAD;
680
+ int64_t y = (int64_t)blockIdx.y * (int64_t)blockDim.y +
681
+ (int64_t)threadIdx.y + FD_PAD;
682
+ int64_t shot_idx = (int64_t)blockIdx.z * (int64_t)blockDim.z +
683
+ (int64_t)threadIdx.z;
684
+
685
+ if (shot_idx >= n_shots) return;
686
+
687
+ #if FD_PAD > 1
688
+ int64_t const x0 = (int64_t)blockIdx.x * (int64_t)blockDim.x + FD_PAD;
689
+ int64_t const y0 = (int64_t)blockIdx.y * (int64_t)blockDim.y + FD_PAD;
690
+ int64_t const base = shot_idx * shot_numel;
691
+ int64_t const t = (int64_t)threadIdx.y * (int64_t)blockDim.x +
692
+ (int64_t)threadIdx.x;
693
+ int64_t const nthreads = (int64_t)blockDim.x * (int64_t)blockDim.y;
694
+ // Original scalar loading (optimization 2.1: vectorized loading disabled due to overhead)
695
+ for (int64_t idx = t; idx < tile_numel; idx += nthreads) {
696
+ int64_t const ly = idx / tile_w;
697
+ int64_t const lx = idx - ly * tile_w;
698
+ int64_t const gx = x0 - FD_PAD + lx;
699
+ int64_t const gy = y0 - FD_PAD + ly;
700
+ if (0 <= gx && gx < nx && 0 <= gy && gy < ny) {
701
+ int64_t const g = base + gy * nx + gx;
702
+ int64_t const offset = ly * tile_pitch + lx;
703
+ tile_hx[offset] = __ldg(&hx[g]);
704
+ tile_hz[offset] = __ldg(&hz[g]);
705
+ } else {
706
+ int64_t const offset = ly * tile_pitch + lx;
707
+ tile_hx[offset] = (TIDE_DTYPE)0;
708
+ tile_hz[offset] = (TIDE_DTYPE)0;
709
+ }
710
+ }
711
+ __syncthreads();
712
+
713
+ #define HX_L(dy, dx) tile_hx[((int64_t)threadIdx.y + (int64_t)FD_PAD + (dy)) * tile_pitch + ((int64_t)threadIdx.x + (int64_t)FD_PAD + (dx))]
714
+ #define HZ_L(dy, dx) tile_hz[((int64_t)threadIdx.y + (int64_t)FD_PAD + (dy)) * tile_pitch + ((int64_t)threadIdx.x + (int64_t)FD_PAD + (dx))]
715
+ #else
716
+ #define HX_L(dy, dx) HX(dy, dx)
717
+ #define HZ_L(dy, dx) HZ(dy, dx)
718
+ #endif
719
+
720
+ if (y < ny - FD_PAD + 1 && x < nx - FD_PAD + 1 && shot_idx < n_shots) {
721
+ int64_t j = y * nx + x;
722
+ int64_t i = shot_idx * shot_numel + j;
723
+
724
+ TIDE_DTYPE const ca_shot_i = ca_batched ? ca[i] : ca[j];
725
+ TIDE_DTYPE const cb_shot_i = cb_batched ? cb[i] : cb[j];
726
+
727
+ bool pml_y = y < pml_y0 || y >= pml_y1;
728
+ bool pml_x = x < pml_x0 || x >= pml_x1;
729
+
730
+ TIDE_DTYPE dhz_dx = DIFFX1(HZ_L);
731
+ TIDE_DTYPE dhx_dz = DIFFY1(HX_L);
732
+
733
+ // Pre-load PML coefficients into registers (optimization 1.2)
734
+ TIDE_DTYPE bx_val = __ldg(&bx[x]);
735
+ TIDE_DTYPE ax_val = __ldg(&ax[x]);
736
+ TIDE_DTYPE kx_val = __ldg(&kx[x]);
737
+ TIDE_DTYPE by_val = __ldg(&by[y]);
738
+ TIDE_DTYPE ay_val = __ldg(&ay[y]);
739
+ TIDE_DTYPE ky_val = __ldg(&ky[y]);
740
+
741
+ if (pml_x) {
742
+ m_hz_x[i] = bx_val * m_hz_x[i] + ax_val * dhz_dx;
743
+ dhz_dx = dhz_dx / kx_val + m_hz_x[i];
744
+ }
745
+
746
+ if (pml_y) {
747
+ m_hx_z[i] = by_val * m_hx_z[i] + ay_val * dhx_dz;
748
+ dhx_dz = dhx_dz / ky_val + m_hx_z[i];
749
+ }
750
+
751
+ TIDE_DTYPE curl_h = dhz_dx - dhx_dz;
752
+
753
+ if (ca_requires_grad && ey_store != nullptr) {
754
+ ey_store[i] = __float2bfloat16((float)ey[i]);
755
+ }
756
+ if (cb_requires_grad && curl_h_store != nullptr) {
757
+ curl_h_store[i] = __float2bfloat16((float)curl_h);
758
+ }
759
+
760
+ ey[i] = ca_shot_i * ey[i] + cb_shot_i * curl_h;
761
+ }
762
+
763
+ #undef HX_L
764
+ #undef HZ_L
765
+ }
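// Editorial note (not part of the package source): __nv_bfloat16 keeps the
// 8-bit exponent range of float but only 7 explicit mantissa bits, so these
// snapshots halve the storage and copy traffic at a cost of roughly two to
// three significant decimal digits.  The conversions used above are the
// standard <cuda_bf16.h> intrinsics, e.g.
//
//   __nv_bfloat16 b = __float2bfloat16(1.0009765625f);  // rounds to 1.0f
//   float back      = __bfloat162float(b);              // back == 1.0f
//
// The rounding only affects the stored Ey / curl_H factors; the adjoint field
// that multiplies them in the backward kernels stays at full TIDE_DTYPE
// precision.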
766
+
767
+ // Forward kernel: Update E field (Ey) with FP8 storage for gradient computation.
768
+ __global__ void forward_kernel_e_with_storage_fp8(
769
+ TIDE_DTYPE const *__restrict const ca,
770
+ TIDE_DTYPE const *__restrict const cb,
771
+ TIDE_DTYPE const *__restrict const hx,
772
+ TIDE_DTYPE const *__restrict const hz,
773
+ TIDE_DTYPE *__restrict const ey,
774
+ TIDE_DTYPE *__restrict const m_hx_z,
775
+ TIDE_DTYPE *__restrict const m_hz_x,
776
+ uint8_t *__restrict const ey_store, // Can be NULL
777
+ uint8_t *__restrict const curl_h_store, // Can be NULL
778
+ TIDE_DTYPE const *__restrict const ay,
779
+ TIDE_DTYPE const *__restrict const ayh,
780
+ TIDE_DTYPE const *__restrict const ax,
781
+ TIDE_DTYPE const *__restrict const axh,
782
+ TIDE_DTYPE const *__restrict const by,
783
+ TIDE_DTYPE const *__restrict const byh,
784
+ TIDE_DTYPE const *__restrict const bx,
785
+ TIDE_DTYPE const *__restrict const bxh,
786
+ TIDE_DTYPE const *__restrict const ky,
787
+ TIDE_DTYPE const *__restrict const kyh,
788
+ TIDE_DTYPE const *__restrict const kx,
789
+ TIDE_DTYPE const *__restrict const kxh,
790
+ bool const ca_requires_grad,
791
+ bool const cb_requires_grad) {
792
+
793
+ #if FD_PAD > 1
794
+ // Shared-memory tiling for Hx/Hz stencil loads.
795
+ // Assumes blockDim.z == 1 (one shot per block).
796
+ extern __shared__ TIDE_DTYPE shmem[];
797
+ int64_t const tile_w = (int64_t)blockDim.x + 2 * (int64_t)FD_PAD;
798
+ int64_t const tile_h = (int64_t)blockDim.y + 2 * (int64_t)FD_PAD;
799
+ int64_t const tile_pitch = tile_w;
800
+ int64_t const tile_numel = tile_w * tile_h;
801
+ TIDE_DTYPE *__restrict const tile_hx = shmem;
802
+ TIDE_DTYPE *__restrict const tile_hz = shmem + tile_numel;
803
+ #endif
804
+
805
+ int64_t x = (int64_t)blockIdx.x * (int64_t)blockDim.x +
806
+ (int64_t)threadIdx.x + FD_PAD;
807
+ int64_t y = (int64_t)blockIdx.y * (int64_t)blockDim.y +
808
+ (int64_t)threadIdx.y + FD_PAD;
809
+ int64_t shot_idx = (int64_t)blockIdx.z * (int64_t)blockDim.z +
810
+ (int64_t)threadIdx.z;
811
+
812
+ if (shot_idx >= n_shots) return;
813
+
814
+ #if FD_PAD > 1
815
+ int64_t const x0 = (int64_t)blockIdx.x * (int64_t)blockDim.x + FD_PAD;
816
+ int64_t const y0 = (int64_t)blockIdx.y * (int64_t)blockDim.y + FD_PAD;
817
+ int64_t const base = shot_idx * shot_numel;
818
+ int64_t const t = (int64_t)threadIdx.y * (int64_t)blockDim.x +
819
+ (int64_t)threadIdx.x;
820
+ int64_t const nthreads = (int64_t)blockDim.x * (int64_t)blockDim.y;
821
+ for (int64_t idx = t; idx < tile_numel; idx += nthreads) {
822
+ int64_t const ly = idx / tile_w;
823
+ int64_t const lx = idx - ly * tile_w;
824
+ int64_t const gx = x0 - FD_PAD + lx;
825
+ int64_t const gy = y0 - FD_PAD + ly;
826
+ if (0 <= gx && gx < nx && 0 <= gy && gy < ny) {
827
+ int64_t const g = base + gy * nx + gx;
828
+ int64_t const offset = ly * tile_pitch + lx;
829
+ tile_hx[offset] = __ldg(&hx[g]);
830
+ tile_hz[offset] = __ldg(&hz[g]);
831
+ } else {
832
+ int64_t const offset = ly * tile_pitch + lx;
833
+ tile_hx[offset] = (TIDE_DTYPE)0;
834
+ tile_hz[offset] = (TIDE_DTYPE)0;
835
+ }
836
+ }
837
+ __syncthreads();
838
+
839
+ #define HX_L(dy, dx) tile_hx[((int64_t)threadIdx.y + (int64_t)FD_PAD + (dy)) * tile_pitch + ((int64_t)threadIdx.x + (int64_t)FD_PAD + (dx))]
840
+ #define HZ_L(dy, dx) tile_hz[((int64_t)threadIdx.y + (int64_t)FD_PAD + (dy)) * tile_pitch + ((int64_t)threadIdx.x + (int64_t)FD_PAD + (dx))]
841
+ #else
842
+ #define HX_L(dy, dx) HX(dy, dx)
843
+ #define HZ_L(dy, dx) HZ(dy, dx)
844
+ #endif
845
+
846
+ if (y < ny - FD_PAD + 1 && x < nx - FD_PAD + 1 && shot_idx < n_shots) {
847
+ int64_t j = y * nx + x;
848
+ int64_t i = shot_idx * shot_numel + j;
849
+
850
+ TIDE_DTYPE const ca_shot_i = ca_batched ? ca[i] : ca[j];
851
+ TIDE_DTYPE const cb_shot_i = cb_batched ? cb[i] : cb[j];
852
+
853
+ bool pml_y = y < pml_y0 || y >= pml_y1;
854
+ bool pml_x = x < pml_x0 || x >= pml_x1;
855
+
856
+ TIDE_DTYPE dhz_dx = DIFFX1(HZ_L);
857
+ TIDE_DTYPE dhx_dz = DIFFY1(HX_L);
858
+
859
+ // Pre-load PML coefficients into registers (optimization 1.2)
860
+ TIDE_DTYPE bx_val = __ldg(&bx[x]);
861
+ TIDE_DTYPE ax_val = __ldg(&ax[x]);
862
+ TIDE_DTYPE kx_val = __ldg(&kx[x]);
863
+ TIDE_DTYPE by_val = __ldg(&by[y]);
864
+ TIDE_DTYPE ay_val = __ldg(&ay[y]);
865
+ TIDE_DTYPE ky_val = __ldg(&ky[y]);
866
+
867
+ if (pml_x) {
868
+ m_hz_x[i] = bx_val * m_hz_x[i] + ax_val * dhz_dx;
869
+ dhz_dx = dhz_dx / kx_val + m_hz_x[i];
870
+ }
871
+
872
+ if (pml_y) {
873
+ m_hx_z[i] = by_val * m_hx_z[i] + ay_val * dhx_dz;
874
+ dhx_dz = dhx_dz / ky_val + m_hx_z[i];
875
+ }
876
+
877
+ TIDE_DTYPE curl_h = dhz_dx - dhx_dz;
878
+
879
+ if (ca_requires_grad && ey_store != nullptr) {
880
+ ey_store[i] = fp8_e4m3_from_float((float)ey[i]);
881
+ }
882
+ if (cb_requires_grad && curl_h_store != nullptr) {
883
+ curl_h_store[i] = fp8_e4m3_from_float((float)curl_h);
884
+ }
885
+
886
+ ey[i] = ca_shot_i * ey[i] + cb_shot_i * curl_h;
887
+ }
888
+
889
+ #undef HX_L
890
+ #undef HZ_L
891
+ }
892
+
893
+ // Backward kernel: Update adjoint λ_H fields
894
+ __global__ void backward_kernel_lambda_h(
895
+ TIDE_DTYPE const *__restrict const cb,
896
+ TIDE_DTYPE const *__restrict const lambda_ey,
897
+ TIDE_DTYPE *__restrict const lambda_hx,
898
+ TIDE_DTYPE *__restrict const lambda_hz,
899
+ TIDE_DTYPE *__restrict const m_lambda_ey_x,
900
+ TIDE_DTYPE *__restrict const m_lambda_ey_z,
901
+ TIDE_DTYPE const *__restrict const ay,
902
+ TIDE_DTYPE const *__restrict const ayh,
903
+ TIDE_DTYPE const *__restrict const ax,
904
+ TIDE_DTYPE const *__restrict const axh,
905
+ TIDE_DTYPE const *__restrict const by,
906
+ TIDE_DTYPE const *__restrict const byh,
907
+ TIDE_DTYPE const *__restrict const bx,
908
+ TIDE_DTYPE const *__restrict const bxh,
909
+ TIDE_DTYPE const *__restrict const ky,
910
+ TIDE_DTYPE const *__restrict const kyh,
911
+ TIDE_DTYPE const *__restrict const kx,
912
+ TIDE_DTYPE const *__restrict const kxh) {
913
+
914
+ int64_t x = (int64_t)blockIdx.x * (int64_t)blockDim.x +
915
+ (int64_t)threadIdx.x + FD_PAD;
916
+ int64_t y = (int64_t)blockIdx.y * (int64_t)blockDim.y +
917
+ (int64_t)threadIdx.y + FD_PAD;
918
+ int64_t shot_idx = (int64_t)blockIdx.z * (int64_t)blockDim.z +
919
+ (int64_t)threadIdx.z;
920
+
921
+ if (y < ny - FD_PAD + 1 && x < nx - FD_PAD + 1 && shot_idx < n_shots) {
922
+ int64_t const pml_y0h = pml_y0;
923
+ int64_t const pml_y1h = MAX(pml_y0, pml_y1 - 1);
924
+ int64_t const pml_x0h = pml_x0;
925
+ int64_t const pml_x1h = MAX(pml_x0, pml_x1 - 1);
926
+
927
+ int64_t j = y * nx + x;
928
+ int64_t i = shot_idx * shot_numel + j;
929
+
930
+ TIDE_DTYPE const cb_shot_i = cb_batched ? cb[i] : cb[j];
931
+
932
+ // Update λ_Hx: λ_Hx = λ_Hx - cb * D_z^T[λ_Ey]
933
+ // EXACT ADJOINT: use transpose of DIFFYH1 -> which is DIFFY1
934
+ if (y < ny - FD_PAD) {
935
+ bool pml_y = y < pml_y0h || y >= pml_y1h;
936
+
937
+ TIDE_DTYPE d_lambda_ey_dz = DIFFY1(LAMBDA_EY);
938
+
939
+ if (pml_y) {
940
+ m_lambda_ey_z[i] = __ldg(&byh[y]) * m_lambda_ey_z[i] + __ldg(&ayh[y]) * d_lambda_ey_dz;
941
+ d_lambda_ey_dz = d_lambda_ey_dz / __ldg(&kyh[y]) + m_lambda_ey_z[i];
942
+ }
943
+
944
+ lambda_hx[i] -= cb_shot_i * d_lambda_ey_dz;
945
+ }
946
+
947
+ // Update λ_Hz: λ_Hz = λ_Hz + cb * D_x^T[λ_Ey]
948
+ // EXACT ADJOINT: use transpose of DIFFXH1 -> which is DIFFX1
949
+ if (x < nx - FD_PAD) {
950
+ bool pml_x = x < pml_x0h || x >= pml_x1h;
951
+
952
+ TIDE_DTYPE d_lambda_ey_dx = DIFFX1(LAMBDA_EY);
953
+
954
+ if (pml_x) {
955
+ m_lambda_ey_x[i] = __ldg(&bxh[x]) * m_lambda_ey_x[i] + __ldg(&axh[x]) * d_lambda_ey_dx;
956
+ d_lambda_ey_dx = d_lambda_ey_dx / __ldg(&kxh[x]) + m_lambda_ey_x[i];
957
+ }
958
+
959
+ lambda_hz[i] += cb_shot_i * d_lambda_ey_dx;
960
+ }
961
+ }
962
+ }
963
+
964
+ // Backward kernel: Update adjoint λ_Ey field with per-shot gradient accumulation
965
+ // Uses pml_y0/pml_y1/pml_x0/pml_x1 for both adjoint propagation and gradient masking
966
+ // NO atomicAdd - each shot writes to its own memory region
967
+ __global__ void backward_kernel_lambda_e_with_grad(
968
+ TIDE_DTYPE const *__restrict const ca,
969
+ TIDE_DTYPE const *__restrict const cq,
970
+ TIDE_DTYPE const *__restrict const lambda_hx,
971
+ TIDE_DTYPE const *__restrict const lambda_hz,
972
+ TIDE_DTYPE *__restrict const lambda_ey,
973
+ TIDE_DTYPE *__restrict const m_lambda_hx_z,
974
+ TIDE_DTYPE *__restrict const m_lambda_hz_x,
975
+ TIDE_DTYPE const *__restrict const ey_store,
976
+ TIDE_DTYPE const *__restrict const curl_h_store,
977
+ TIDE_DTYPE *__restrict const grad_ca_shot, // [n_shots, ny, nx] - per-shot gradient
978
+ TIDE_DTYPE *__restrict const grad_cb_shot, // [n_shots, ny, nx] - per-shot gradient
979
+ TIDE_DTYPE const *__restrict const ay,
980
+ TIDE_DTYPE const *__restrict const ayh,
981
+ TIDE_DTYPE const *__restrict const ax,
982
+ TIDE_DTYPE const *__restrict const axh,
983
+ TIDE_DTYPE const *__restrict const by,
984
+ TIDE_DTYPE const *__restrict const byh,
985
+ TIDE_DTYPE const *__restrict const bx,
986
+ TIDE_DTYPE const *__restrict const bxh,
987
+ TIDE_DTYPE const *__restrict const ky,
988
+ TIDE_DTYPE const *__restrict const kyh,
989
+ TIDE_DTYPE const *__restrict const kx,
990
+ TIDE_DTYPE const *__restrict const kxh,
991
+ bool const ca_requires_grad,
992
+ bool const cb_requires_grad,
993
+ int64_t const step_ratio_val) {
994
+
995
+ int64_t x = (int64_t)blockIdx.x * (int64_t)blockDim.x +
996
+ (int64_t)threadIdx.x + FD_PAD;
997
+ int64_t y = (int64_t)blockIdx.y * (int64_t)blockDim.y +
998
+ (int64_t)threadIdx.y + FD_PAD;
999
+ int64_t shot_idx = (int64_t)blockIdx.z * (int64_t)blockDim.z +
1000
+ (int64_t)threadIdx.z;
1001
+
1002
+ if (y < ny - FD_PAD + 1 && x < nx - FD_PAD + 1 && shot_idx < n_shots) {
1003
+ int64_t j = y * nx + x;
1004
+ int64_t i = shot_idx * shot_numel + j;
1005
+
1006
+ TIDE_DTYPE const ca_shot_i = ca_batched ? ca[i] : ca[j];
1007
+ TIDE_DTYPE const cq_shot_i = cq_batched ? cq[i] : cq[j];
1008
+
1009
+ // Determine PML region (pml_y/pml_x = true means in PML region)
1010
+ bool pml_y = y < pml_y0 || y >= pml_y1;
1011
+ bool pml_x = x < pml_x0 || x >= pml_x1;
1012
+
1013
+ // Compute D_x^{hT}[λ_Hz] at integer grid points
1014
+ // EXACT ADJOINT: use transpose of DIFFX1 -> which is DIFFXH1
1015
+ TIDE_DTYPE d_lambda_hz_dx = DIFFXH1(LAMBDA_HZ);
1016
+ // Compute D_z^{hT}[λ_Hx] at integer grid points
1017
+ // EXACT ADJOINT: use transpose of DIFFY1 -> which is DIFFYH1
1018
+ TIDE_DTYPE d_lambda_hx_dz = DIFFYH1(LAMBDA_HX);
1019
+
1020
+ // Pre-load PML coefficients into registers (optimization 1.2)
1021
+ TIDE_DTYPE bx_val = __ldg(&bx[x]);
1022
+ TIDE_DTYPE ax_val = __ldg(&ax[x]);
1023
+ TIDE_DTYPE kx_val = __ldg(&kx[x]);
1024
+ TIDE_DTYPE by_val = __ldg(&by[y]);
1025
+ TIDE_DTYPE ay_val = __ldg(&ay[y]);
1026
+ TIDE_DTYPE ky_val = __ldg(&ky[y]);
1027
+
1028
+ // Apply adjoint CPML for d(λ_Hz)/dx (only in PML region)
1029
+ if (pml_x) {
1030
+ m_lambda_hz_x[i] = bx_val * m_lambda_hz_x[i] + ax_val * d_lambda_hz_dx;
1031
+ d_lambda_hz_dx = d_lambda_hz_dx / kx_val + m_lambda_hz_x[i];
1032
+ }
1033
+
1034
+ // Apply adjoint CPML for d(λ_Hx)/dz (only in PML region)
1035
+ if (pml_y) {
1036
+ m_lambda_hx_z[i] = by_val * m_lambda_hx_z[i] + ay_val * d_lambda_hx_dz;
1037
+ d_lambda_hx_dz = d_lambda_hx_dz / ky_val + m_lambda_hx_z[i];
1038
+ }
1039
+
1040
+ // curl_λH = d(λ_Hz)/dx - d(λ_Hx)/dz
1041
+ TIDE_DTYPE curl_lambda_h = d_lambda_hz_dx - d_lambda_hx_dz;
1042
+
1043
+ // Store current λ_Ey before update (this is λ_Ey^{n+1})
1044
+ TIDE_DTYPE lambda_ey_curr = lambda_ey[i];
1045
+
1046
+ // Update λ_Ey: λ_Ey^n = C_a * λ_Ey^{n+1} + C_q * curl_λH
1047
+ lambda_ey[i] = ca_shot_i * lambda_ey_curr + cq_shot_i * curl_lambda_h;
1048
+
1049
+ // Accumulate per-shot gradients only in interior region (!pml_y && !pml_x)
1050
+ if (!pml_y && !pml_x) {
1051
+ // grad_ca_shot[shot_idx, y, x] += λ_Ey^{n+1} * E_y^n
1052
+ // Snapshots are stored at full TIDE_DTYPE precision in this variant
1053
+ if (ca_requires_grad && ey_store != nullptr) {
1054
+ TIDE_DTYPE ey_n = ey_store[i];
1055
+ grad_ca_shot[i] += lambda_ey_curr * ey_n * (TIDE_DTYPE)step_ratio_val;
1056
+ }
1057
+
1058
+ // grad_cb_shot[shot_idx, y, x] += λ_Ey^{n+1} * curl_H^n
1059
+ if (cb_requires_grad && curl_h_store != nullptr) {
1060
+ TIDE_DTYPE curl_h_n = curl_h_store[i];
1061
+ grad_cb_shot[i] += lambda_ey_curr * curl_h_n * (TIDE_DTYPE)step_ratio_val;
1062
+ }
1063
+ }
1064
+ }
1065
+ }
1066
+
1067
+ // Backward kernel: Update adjoint λ_Ey field with BF16 snapshot loads.
1068
+ __global__ void backward_kernel_lambda_e_with_grad_bf16(
1069
+ TIDE_DTYPE const *__restrict const ca,
1070
+ TIDE_DTYPE const *__restrict const cq,
1071
+ TIDE_DTYPE const *__restrict const lambda_hx,
1072
+ TIDE_DTYPE const *__restrict const lambda_hz,
1073
+ TIDE_DTYPE *__restrict const lambda_ey,
1074
+ TIDE_DTYPE *__restrict const m_lambda_hx_z,
1075
+ TIDE_DTYPE *__restrict const m_lambda_hz_x,
1076
+ __nv_bfloat16 const *__restrict const ey_store,
1077
+ __nv_bfloat16 const *__restrict const curl_h_store,
1078
+ TIDE_DTYPE *__restrict const grad_ca_shot,
1079
+ TIDE_DTYPE *__restrict const grad_cb_shot,
1080
+ TIDE_DTYPE const *__restrict const ay,
1081
+ TIDE_DTYPE const *__restrict const ayh,
1082
+ TIDE_DTYPE const *__restrict const ax,
1083
+ TIDE_DTYPE const *__restrict const axh,
1084
+ TIDE_DTYPE const *__restrict const by,
1085
+ TIDE_DTYPE const *__restrict const byh,
1086
+ TIDE_DTYPE const *__restrict const bx,
1087
+ TIDE_DTYPE const *__restrict const bxh,
1088
+ TIDE_DTYPE const *__restrict const ky,
1089
+ TIDE_DTYPE const *__restrict const kyh,
1090
+ TIDE_DTYPE const *__restrict const kx,
1091
+ TIDE_DTYPE const *__restrict const kxh,
1092
+ bool const ca_requires_grad,
1093
+ bool const cb_requires_grad,
1094
+ int64_t const step_ratio_val) {
1095
+
1096
+ int64_t x = (int64_t)blockIdx.x * (int64_t)blockDim.x +
1097
+ (int64_t)threadIdx.x + FD_PAD;
1098
+ int64_t y = (int64_t)blockIdx.y * (int64_t)blockDim.y +
1099
+ (int64_t)threadIdx.y + FD_PAD;
1100
+ int64_t shot_idx = (int64_t)blockIdx.z * (int64_t)blockDim.z +
1101
+ (int64_t)threadIdx.z;
1102
+
1103
+ if (y < ny - FD_PAD + 1 && x < nx - FD_PAD + 1 && shot_idx < n_shots) {
1104
+ int64_t j = y * nx + x;
1105
+ int64_t i = shot_idx * shot_numel + j;
1106
+
1107
+ TIDE_DTYPE const ca_shot_i = ca_batched ? ca[i] : ca[j];
1108
+ TIDE_DTYPE const cq_shot_i = cq_batched ? cq[i] : cq[j];
1109
+
1110
+ bool pml_y = y < pml_y0 || y >= pml_y1;
1111
+ bool pml_x = x < pml_x0 || x >= pml_x1;
1112
+
1113
+ // EXACT ADJOINT: use transposed difference operators
1114
+ TIDE_DTYPE d_lambda_hz_dx = DIFFXH1(LAMBDA_HZ);
1115
+ TIDE_DTYPE d_lambda_hx_dz = DIFFYH1(LAMBDA_HX);
1116
+
1117
+ // Pre-load PML coefficients into registers (optimization 1.2)
1118
+ TIDE_DTYPE bx_val = __ldg(&bx[x]);
1119
+ TIDE_DTYPE ax_val = __ldg(&ax[x]);
1120
+ TIDE_DTYPE kx_val = __ldg(&kx[x]);
1121
+ TIDE_DTYPE by_val = __ldg(&by[y]);
1122
+ TIDE_DTYPE ay_val = __ldg(&ay[y]);
1123
+ TIDE_DTYPE ky_val = __ldg(&ky[y]);
1124
+
1125
+ if (pml_x) {
1126
+ m_lambda_hz_x[i] = bx_val * m_lambda_hz_x[i] + ax_val * d_lambda_hz_dx;
1127
+ d_lambda_hz_dx = d_lambda_hz_dx / kx_val + m_lambda_hz_x[i];
1128
+ }
1129
+
1130
+ if (pml_y) {
1131
+ m_lambda_hx_z[i] = by_val * m_lambda_hx_z[i] + ay_val * d_lambda_hx_dz;
1132
+ d_lambda_hx_dz = d_lambda_hx_dz / ky_val + m_lambda_hx_z[i];
1133
+ }
1134
+
1135
+ TIDE_DTYPE curl_lambda_h = d_lambda_hz_dx - d_lambda_hx_dz;
1136
+
1137
+ TIDE_DTYPE lambda_ey_curr = lambda_ey[i];
1138
+ lambda_ey[i] = ca_shot_i * lambda_ey_curr + cq_shot_i * curl_lambda_h;
1139
+
1140
+ if (!pml_y && !pml_x) {
1141
+ if (ca_requires_grad && ey_store != nullptr) {
1142
+ TIDE_DTYPE ey_n = (TIDE_DTYPE)__bfloat162float(ey_store[i]);
1143
+ grad_ca_shot[i] += lambda_ey_curr * ey_n * (TIDE_DTYPE)step_ratio_val;
1144
+ }
1145
+ if (cb_requires_grad && curl_h_store != nullptr) {
1146
+ TIDE_DTYPE curl_h_n = (TIDE_DTYPE)__bfloat162float(curl_h_store[i]);
1147
+ grad_cb_shot[i] += lambda_ey_curr * curl_h_n * (TIDE_DTYPE)step_ratio_val;
1148
+ }
1149
+ }
1150
+ }
1151
+ }
1152
+
1153
+ // Backward kernel: Update adjoint λ_Ey field with FP8 snapshot loads.
1154
+ __global__ void backward_kernel_lambda_e_with_grad_fp8(
1155
+ TIDE_DTYPE const *__restrict const ca,
1156
+ TIDE_DTYPE const *__restrict const cq,
1157
+ TIDE_DTYPE const *__restrict const lambda_hx,
1158
+ TIDE_DTYPE const *__restrict const lambda_hz,
1159
+ TIDE_DTYPE *__restrict const lambda_ey,
1160
+ TIDE_DTYPE *__restrict const m_lambda_hx_z,
1161
+ TIDE_DTYPE *__restrict const m_lambda_hz_x,
1162
+ uint8_t const *__restrict const ey_store,
1163
+ uint8_t const *__restrict const curl_h_store,
1164
+ TIDE_DTYPE *__restrict const grad_ca_shot,
1165
+ TIDE_DTYPE *__restrict const grad_cb_shot,
1166
+ TIDE_DTYPE const *__restrict const ay,
1167
+ TIDE_DTYPE const *__restrict const ayh,
1168
+ TIDE_DTYPE const *__restrict const ax,
1169
+ TIDE_DTYPE const *__restrict const axh,
1170
+ TIDE_DTYPE const *__restrict const by,
1171
+ TIDE_DTYPE const *__restrict const byh,
1172
+ TIDE_DTYPE const *__restrict const bx,
1173
+ TIDE_DTYPE const *__restrict const bxh,
1174
+ TIDE_DTYPE const *__restrict const ky,
1175
+ TIDE_DTYPE const *__restrict const kyh,
1176
+ TIDE_DTYPE const *__restrict const kx,
1177
+ TIDE_DTYPE const *__restrict const kxh,
1178
+ bool const ca_requires_grad,
1179
+ bool const cb_requires_grad,
1180
+ int64_t const step_ratio_val) {
1181
+
1182
+ int64_t x = (int64_t)blockIdx.x * (int64_t)blockDim.x +
1183
+ (int64_t)threadIdx.x + FD_PAD;
1184
+ int64_t y = (int64_t)blockIdx.y * (int64_t)blockDim.y +
1185
+ (int64_t)threadIdx.y + FD_PAD;
1186
+ int64_t shot_idx = (int64_t)blockIdx.z * (int64_t)blockDim.z +
1187
+ (int64_t)threadIdx.z;
1188
+
1189
+ if (y < ny - FD_PAD + 1 && x < nx - FD_PAD + 1 && shot_idx < n_shots) {
1190
+ int64_t j = y * nx + x;
1191
+ int64_t i = shot_idx * shot_numel + j;
1192
+
1193
+ TIDE_DTYPE const ca_shot_i = ca_batched ? ca[i] : ca[j];
1194
+ TIDE_DTYPE const cq_shot_i = cq_batched ? cq[i] : cq[j];
1195
+
1196
+ bool pml_y = y < pml_y0 || y >= pml_y1;
1197
+ bool pml_x = x < pml_x0 || x >= pml_x1;
1198
+
1199
+ TIDE_DTYPE d_lambda_hz_dx = DIFFXH1(LAMBDA_HZ);
1200
+ TIDE_DTYPE d_lambda_hx_dz = DIFFYH1(LAMBDA_HX);
1201
+
1202
+ TIDE_DTYPE bx_val = __ldg(&bx[x]);
1203
+ TIDE_DTYPE ax_val = __ldg(&ax[x]);
1204
+ TIDE_DTYPE kx_val = __ldg(&kx[x]);
1205
+ TIDE_DTYPE by_val = __ldg(&by[y]);
1206
+ TIDE_DTYPE ay_val = __ldg(&ay[y]);
1207
+ TIDE_DTYPE ky_val = __ldg(&ky[y]);
1208
+
1209
+ if (pml_x) {
1210
+ m_lambda_hz_x[i] = bx_val * m_lambda_hz_x[i] + ax_val * d_lambda_hz_dx;
1211
+ d_lambda_hz_dx = d_lambda_hz_dx / kx_val + m_lambda_hz_x[i];
1212
+ }
1213
+
1214
+ if (pml_y) {
1215
+ m_lambda_hx_z[i] = by_val * m_lambda_hx_z[i] + ay_val * d_lambda_hx_dz;
1216
+ d_lambda_hx_dz = d_lambda_hx_dz / ky_val + m_lambda_hx_z[i];
1217
+ }
1218
+
1219
+ TIDE_DTYPE curl_lambda_h = d_lambda_hz_dx - d_lambda_hx_dz;
1220
+
1221
+ TIDE_DTYPE lambda_ey_curr = lambda_ey[i];
1222
+ lambda_ey[i] = ca_shot_i * lambda_ey_curr + cq_shot_i * curl_lambda_h;
1223
+
1224
+ if (!pml_y && !pml_x) {
1225
+ if (ca_requires_grad && ey_store != nullptr) {
1226
+ TIDE_DTYPE ey_n = (TIDE_DTYPE)fp8_e4m3_to_float(ey_store[i]);
1227
+ grad_ca_shot[i] += lambda_ey_curr * ey_n * (TIDE_DTYPE)step_ratio_val;
1228
+ }
1229
+ if (cb_requires_grad && curl_h_store != nullptr) {
1230
+ TIDE_DTYPE curl_h_n = (TIDE_DTYPE)fp8_e4m3_to_float(curl_h_store[i]);
1231
+ grad_cb_shot[i] += lambda_ey_curr * curl_h_n * (TIDE_DTYPE)step_ratio_val;
1232
+ }
1233
+ }
1234
+ }
1235
+ }
1236
+
1237
+ // Backward kernel: Update adjoint λ_Ey field (no gradient accumulation).
1238
+ __global__ void backward_kernel_lambda_e(
1239
+ TIDE_DTYPE const *__restrict const ca,
1240
+ TIDE_DTYPE const *__restrict const cq,
1241
+ TIDE_DTYPE const *__restrict const lambda_hx,
1242
+ TIDE_DTYPE const *__restrict const lambda_hz,
1243
+ TIDE_DTYPE *__restrict const lambda_ey,
1244
+ TIDE_DTYPE *__restrict const m_lambda_hx_z,
1245
+ TIDE_DTYPE *__restrict const m_lambda_hz_x,
1246
+ TIDE_DTYPE const *__restrict const ay,
1247
+ TIDE_DTYPE const *__restrict const ayh,
1248
+ TIDE_DTYPE const *__restrict const ax,
1249
+ TIDE_DTYPE const *__restrict const axh,
1250
+ TIDE_DTYPE const *__restrict const by,
1251
+ TIDE_DTYPE const *__restrict const byh,
1252
+ TIDE_DTYPE const *__restrict const bx,
1253
+ TIDE_DTYPE const *__restrict const bxh,
1254
+ TIDE_DTYPE const *__restrict const ky,
1255
+ TIDE_DTYPE const *__restrict const kyh,
1256
+ TIDE_DTYPE const *__restrict const kx,
1257
+ TIDE_DTYPE const *__restrict const kxh) {
1258
+ int64_t x = (int64_t)blockIdx.x * (int64_t)blockDim.x +
1259
+ (int64_t)threadIdx.x + FD_PAD;
1260
+ int64_t y = (int64_t)blockIdx.y * (int64_t)blockDim.y +
1261
+ (int64_t)threadIdx.y + FD_PAD;
1262
+ int64_t shot_idx = (int64_t)blockIdx.z * (int64_t)blockDim.z +
1263
+ (int64_t)threadIdx.z;
1264
+
1265
+ if (y < ny - FD_PAD + 1 && x < nx - FD_PAD + 1 && shot_idx < n_shots) {
1266
+ int64_t j = y * nx + x;
1267
+ int64_t i = shot_idx * shot_numel + j;
1268
+
1269
+ (void)ayh;
1270
+ (void)axh;
1271
+ (void)byh;
1272
+ (void)bxh;
1273
+ (void)kyh;
1274
+ (void)kxh;
1275
+
1276
+ TIDE_DTYPE const ca_shot_i = ca_batched ? ca[i] : ca[j];
1277
+ TIDE_DTYPE const cq_shot_i = cq_batched ? cq[i] : cq[j];
1278
+
1279
+ bool pml_y = y < pml_y0 || y >= pml_y1;
1280
+ bool pml_x = x < pml_x0 || x >= pml_x1;
1281
+
1282
+ // EXACT ADJOINT: use transposed difference operators
1283
+ TIDE_DTYPE d_lambda_hz_dx = DIFFXH1(LAMBDA_HZ);
1284
+ TIDE_DTYPE d_lambda_hx_dz = DIFFYH1(LAMBDA_HX);
1285
+
1286
+ // Pre-load PML coefficients into registers (optimization 1.2)
1287
+ TIDE_DTYPE bx_val = __ldg(&bx[x]);
1288
+ TIDE_DTYPE ax_val = __ldg(&ax[x]);
1289
+ TIDE_DTYPE kx_val = __ldg(&kx[x]);
1290
+ TIDE_DTYPE by_val = __ldg(&by[y]);
1291
+ TIDE_DTYPE ay_val = __ldg(&ay[y]);
1292
+ TIDE_DTYPE ky_val = __ldg(&ky[y]);
1293
+
1294
+ if (pml_x) {
1295
+ m_lambda_hz_x[i] = bx_val * m_lambda_hz_x[i] + ax_val * d_lambda_hz_dx;
1296
+ d_lambda_hz_dx = d_lambda_hz_dx / kx_val + m_lambda_hz_x[i];
1297
+ }
1298
+
1299
+ if (pml_y) {
1300
+ m_lambda_hx_z[i] = by_val * m_lambda_hx_z[i] + ay_val * d_lambda_hx_dz;
1301
+ d_lambda_hx_dz = d_lambda_hx_dz / ky_val + m_lambda_hx_z[i];
1302
+ }
1303
+
1304
+ TIDE_DTYPE curl_lambda_h = d_lambda_hz_dx - d_lambda_hx_dz;
1305
+
1306
+ TIDE_DTYPE lambda_ey_curr = lambda_ey[i];
1307
+ lambda_ey[i] = ca_shot_i * lambda_ey_curr + cq_shot_i * curl_lambda_h;
1308
+ }
1309
+ }
1310
+
1311
+ // Combine per-shot gradients into final gradient (sum across shots)
1312
+ __global__ void combine_grad(TIDE_DTYPE *__restrict const grad,
1313
+ TIDE_DTYPE const *__restrict const grad_shot) {
1314
+ int64_t x = (int64_t)blockIdx.x * (int64_t)blockDim.x +
1315
+ (int64_t)threadIdx.x + FD_PAD;
1316
+ int64_t y = (int64_t)blockIdx.y * (int64_t)blockDim.y +
1317
+ (int64_t)threadIdx.y + FD_PAD;
1318
+ if (y < ny - FD_PAD && x < nx - FD_PAD) {
1319
+ int64_t j = y * nx + x;
1320
+ int64_t const stride = shot_numel;
1321
+ TIDE_DTYPE sum = 0;
1322
+ #pragma unroll 4
1323
+ for (int64_t shot_idx = 0; shot_idx < n_shots; ++shot_idx) {
1324
+ sum += grad_shot[shot_idx * stride + j];
1325
+ }
1326
+ grad[j] += sum;
1327
+ }
1328
+ }
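// Editorial sketch (not part of the package source): combine_grad collapses
// the race-free per-shot gradients into the final [ny, nx] gradient at the
// end of the backward pass.  A plausible launch (block shape illustrative
// only) covers the interior grid once, with no z dimension:
//
//   dim3 block_cg(32, 8, 1);
//   dim3 grid_cg((nx_h - 2 * FD_PAD + block_cg.x - 1) / block_cg.x,
//                (ny_h - 2 * FD_PAD + block_cg.y - 1) / block_cg.y, 1);
//   combine_grad<<<grid_cg, block_cg>>>(grad_ca, grad_ca_shot);
//   combine_grad<<<grid_cg, block_cg>>>(grad_cb, grad_cb_shot);
//
// Because each shot owned its own grad_*_shot slab during the backward
// kernels, the sum here needs no atomics.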
1329
+
1330
+ __global__ void convert_grad_ca_cb_to_eps_sigma(
1331
+ TIDE_DTYPE const *__restrict const ca,
1332
+ TIDE_DTYPE const *__restrict const cb,
1333
+ TIDE_DTYPE const *__restrict const grad_ca,
1334
+ TIDE_DTYPE const *__restrict const grad_cb,
1335
+ TIDE_DTYPE const *__restrict const grad_ca_shot,
1336
+ TIDE_DTYPE const *__restrict const grad_cb_shot,
1337
+ TIDE_DTYPE *__restrict const grad_eps,
1338
+ TIDE_DTYPE *__restrict const grad_sigma,
1339
+ TIDE_DTYPE const dt,
1340
+ bool const ca_requires_grad,
1341
+ bool const cb_requires_grad,
1342
+ bool const ca_batched_h,
1343
+ bool const cb_batched_h) {
1344
+ int64_t x = (int64_t)blockIdx.x * (int64_t)blockDim.x +
1345
+ (int64_t)threadIdx.x;
1346
+ int64_t y = (int64_t)blockIdx.y * (int64_t)blockDim.y +
1347
+ (int64_t)threadIdx.y;
1348
+ if (x >= nx || y >= ny) {
1349
+ return;
1350
+ }
1351
+
1352
+ int64_t shot_idx = (int64_t)blockIdx.z;
1353
+ if (!ca_batched_h) {
1354
+ shot_idx = 0;
1355
+ }
1356
+
1357
+ int64_t const j = y * nx + x;
1358
+ int64_t const idx_shot = shot_idx * shot_numel + j;
1359
+ int64_t const out_idx = ca_batched_h ? idx_shot : j;
1360
+ int64_t const ca_idx = ca_batched_h ? idx_shot : j;
1361
+ int64_t const cb_idx = cb_batched_h ? idx_shot : j;
1362
+
1363
+ TIDE_DTYPE const ca_val = ca[ca_idx];
1364
+ TIDE_DTYPE const cb_val = cb[cb_idx];
1365
+ TIDE_DTYPE const cb_sq = cb_val * cb_val;
1366
+ TIDE_DTYPE const inv_dt = (TIDE_DTYPE)1 / dt;
1367
+
1368
+ TIDE_DTYPE grad_ca_val = 0;
1369
+ if (ca_requires_grad) {
1370
+ grad_ca_val = ca_batched_h ? grad_ca_shot[idx_shot] : grad_ca[j];
1371
+ }
1372
+
1373
+ TIDE_DTYPE grad_cb_val = 0;
1374
+ if (cb_requires_grad) {
1375
+ grad_cb_val = cb_batched_h ? grad_cb_shot[idx_shot] : grad_cb[j];
1376
+ }
1377
+
1378
+ TIDE_DTYPE const dca_de = ((TIDE_DTYPE)1 - ca_val) * cb_val * inv_dt;
1379
+ TIDE_DTYPE const dcb_de = -cb_sq * inv_dt;
1380
+ TIDE_DTYPE const dca_ds = -((TIDE_DTYPE)0.5) * ((TIDE_DTYPE)1 + ca_val) * cb_val;
1381
+ TIDE_DTYPE const dcb_ds = -((TIDE_DTYPE)0.5) * cb_sq;
1382
+
1383
+ if (grad_eps != nullptr) {
1384
+ TIDE_DTYPE const grad_e = grad_ca_val * dca_de + grad_cb_val * dcb_de;
1385
+ grad_eps[out_idx] = grad_e * EP0;
1386
+ }
1387
+ if (grad_sigma != nullptr) {
1388
+ grad_sigma[out_idx] = grad_ca_val * dca_ds + grad_cb_val * dcb_ds;
1389
+ }
1390
+ }
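// Editorial note (not part of the package source): the dca_de / dcb_de /
// dca_ds / dcb_ds factors above are consistent with the usual lossy-medium
// update coefficients, written here with absolute permittivity eps and
// conductivity sigma (this is an assumption about the definitions used
// elsewhere in the package):
//
//   Ca = (1 - sigma*dt/(2*eps)) / (1 + sigma*dt/(2*eps))
//   Cb = (dt/eps)               / (1 + sigma*dt/(2*eps))
//
// Differentiating and eliminating eps and sigma in favour of Ca and Cb gives
//
//   dCa/deps   =  (1 - Ca) * Cb / dt        dCa/dsigma = -(1 + Ca) * Cb / 2
//   dCb/deps   = -Cb^2 / dt                 dCb/dsigma = -Cb^2 / 2
//
// which is exactly what the kernel uses.  The final multiplication of the
// permittivity gradient by EP0 converts d/d(eps_abs) to d/d(eps_r), as noted
// at the EP0 definition near the top of the file.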
1391
+
1392
+ } // namespace
1393
+
1394
+ // Forward propagation function
1395
+ extern "C" void FUNC(forward)(
1396
+ TIDE_DTYPE const *const ca,
1397
+ TIDE_DTYPE const *const cb,
1398
+ TIDE_DTYPE const *const cq,
1399
+ TIDE_DTYPE const *const f,
1400
+ TIDE_DTYPE *const ey,
1401
+ TIDE_DTYPE *const hx,
1402
+ TIDE_DTYPE *const hz,
1403
+ TIDE_DTYPE *const m_ey_x,
1404
+ TIDE_DTYPE *const m_ey_z,
1405
+ TIDE_DTYPE *const m_hx_z,
1406
+ TIDE_DTYPE *const m_hz_x,
1407
+ TIDE_DTYPE *const r,
1408
+ TIDE_DTYPE const *const ay,
1409
+ TIDE_DTYPE const *const by,
1410
+ TIDE_DTYPE const *const ayh,
1411
+ TIDE_DTYPE const *const byh,
1412
+ TIDE_DTYPE const *const ax,
1413
+ TIDE_DTYPE const *const bx,
1414
+ TIDE_DTYPE const *const axh,
1415
+ TIDE_DTYPE const *const bxh,
1416
+ TIDE_DTYPE const *const ky,
1417
+ TIDE_DTYPE const *const kyh,
1418
+ TIDE_DTYPE const *const kx,
1419
+ TIDE_DTYPE const *const kxh,
1420
+ int64_t const *const sources_i,
1421
+ int64_t const *const receivers_i,
1422
+ TIDE_DTYPE const rdy_h,
1423
+ TIDE_DTYPE const rdx_h,
1424
+ TIDE_DTYPE const dt_h,
1425
+ int64_t const nt,
1426
+ int64_t const n_shots_h,
1427
+ int64_t const ny_h,
1428
+ int64_t const nx_h,
1429
+ int64_t const n_sources_per_shot_h,
1430
+ int64_t const n_receivers_per_shot_h,
1431
+ int64_t const step_ratio_h,
1432
+ bool const ca_batched_h,
1433
+ bool const cb_batched_h,
1434
+ bool const cq_batched_h,
1435
+ int64_t const start_t,
1436
+ int64_t const pml_y0_h,
1437
+ int64_t const pml_x0_h,
1438
+ int64_t const pml_y1_h,
1439
+ int64_t const pml_x1_h,
1440
+ int64_t const n_threads,
1441
+ int64_t const device) {
1442
+
1443
+ cudaSetDevice(device);
1444
+ (void)dt_h;
1445
+ (void)step_ratio_h;
1446
+ (void)n_threads;
1447
+
1448
+ int64_t const shot_numel_h = ny_h * nx_h;
1449
+
1450
+ // Copy constants to device with caching to avoid redundant copies
1451
+ static TIDE_DTYPE cached_rdy = 0, cached_rdx = 0;
1452
+ static int64_t cached_n_shots = -1, cached_ny = -1, cached_nx = -1;
1453
+ static int64_t cached_shot_numel = -1, cached_n_sources_per_shot = -1, cached_n_receivers_per_shot = -1;
1454
+ static int64_t cached_pml_y0 = -1, cached_pml_y1 = -1;
1455
+ static int64_t cached_pml_x0 = -1, cached_pml_x1 = -1;
1456
+ static bool cached_ca_batched = false, cached_cb_batched = false, cached_cq_batched = false;
1457
+ static int64_t cached_device = -1;
1458
+ static bool first_call = true;
1459
+
1460
+ if (first_call || cached_device != device || cached_rdy != rdy_h || cached_rdx != rdx_h ||
1461
+ cached_n_shots != n_shots_h || cached_ny != ny_h || cached_nx != nx_h ||
1462
+ cached_shot_numel != shot_numel_h || cached_n_sources_per_shot != n_sources_per_shot_h ||
1463
+ cached_n_receivers_per_shot != n_receivers_per_shot_h ||
1464
+ cached_pml_y0 != pml_y0_h || cached_pml_y1 != pml_y1_h ||
1465
+ cached_pml_x0 != pml_x0_h || cached_pml_x1 != pml_x1_h ||
1466
+ cached_ca_batched != ca_batched_h || cached_cb_batched != cb_batched_h ||
1467
+ cached_cq_batched != cq_batched_h) {
1468
+
1469
+ cudaMemcpyToSymbol(rdy, &rdy_h, sizeof(TIDE_DTYPE));
1470
+ cudaMemcpyToSymbol(rdx, &rdx_h, sizeof(TIDE_DTYPE));
1471
+ cudaMemcpyToSymbol(n_shots, &n_shots_h, sizeof(int64_t));
1472
+ cudaMemcpyToSymbol(ny, &ny_h, sizeof(int64_t));
1473
+ cudaMemcpyToSymbol(nx, &nx_h, sizeof(int64_t));
1474
+ cudaMemcpyToSymbol(shot_numel, &shot_numel_h, sizeof(int64_t));
1475
+ cudaMemcpyToSymbol(n_sources_per_shot, &n_sources_per_shot_h, sizeof(int64_t));
1476
+ cudaMemcpyToSymbol(n_receivers_per_shot, &n_receivers_per_shot_h, sizeof(int64_t));
1477
+ cudaMemcpyToSymbol(pml_y0, &pml_y0_h, sizeof(int64_t));
1478
+ cudaMemcpyToSymbol(pml_y1, &pml_y1_h, sizeof(int64_t));
1479
+ cudaMemcpyToSymbol(pml_x0, &pml_x0_h, sizeof(int64_t));
1480
+ cudaMemcpyToSymbol(pml_x1, &pml_x1_h, sizeof(int64_t));
1481
+ cudaMemcpyToSymbol(ca_batched, &ca_batched_h, sizeof(bool));
1482
+ cudaMemcpyToSymbol(cb_batched, &cb_batched_h, sizeof(bool));
1483
+ cudaMemcpyToSymbol(cq_batched, &cq_batched_h, sizeof(bool));
1484
+
1485
+ cached_rdy = rdy_h; cached_rdx = rdx_h;
1486
+ cached_n_shots = n_shots_h; cached_ny = ny_h; cached_nx = nx_h;
1487
+ cached_shot_numel = shot_numel_h; cached_n_sources_per_shot = n_sources_per_shot_h;
1488
+ cached_n_receivers_per_shot = n_receivers_per_shot_h;
1489
+ cached_pml_y0 = pml_y0_h; cached_pml_y1 = pml_y1_h;
1490
+ cached_pml_x0 = pml_x0_h; cached_pml_x1 = pml_x1_h;
1491
+ cached_ca_batched = ca_batched_h; cached_cb_batched = cb_batched_h;
1492
+ cached_cq_batched = cq_batched_h;
1493
+ cached_device = device;
1494
+ first_call = false;
1495
+ }
1496
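+ // Note: the cache lives in function-local statics shared by all callers in
+ // the process. Alternating calls with different geometries or devices simply
+ // re-upload the symbols, but concurrent calls from multiple host threads are
+ // not synchronized here and would presumably need external locking.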
+
1497
+ dim3 dimBlock(32, 8, 1);
1498
+ int64_t gridx = (nx_h - 2 * FD_PAD + 2 + dimBlock.x - 1) / dimBlock.x;
1499
+ int64_t gridy = (ny_h - 2 * FD_PAD + 2 + dimBlock.y - 1) / dimBlock.y;
1500
+ int64_t gridz = n_shots_h;
1501
+ dim3 dimGrid(gridx, gridy, gridz);
1502
+ #if FD_PAD > 1
1503
+ size_t const shmem_h_bytes =
1504
+ (size_t)(dimBlock.x + 2 * FD_PAD) * (size_t)(dimBlock.y + 2 * FD_PAD) *
1505
+ sizeof(TIDE_DTYPE);
1506
+ size_t const shmem_e_bytes = 2 * shmem_h_bytes;
1507
+ #else
1508
+ size_t const shmem_h_bytes = 0;
1509
+ size_t const shmem_e_bytes = 0;
1510
+ #endif
1511
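+ // Worked example (values assumed for illustration): with FD_PAD = 2 and the
+ // 32x8 block above, shmem_h_bytes = (32 + 4) * (8 + 4) * sizeof(TIDE_DTYPE)
+ // = 432 elements, i.e. 1728 bytes in float32, and shmem_e_bytes = 3456 bytes,
+ // presumably because the E-field update stages both an Hx and an Hz tile.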
+
1512
+ dim3 dimBlock_sources(32, 1, 1);
1513
+ dim3 dimGrid_sources(
1514
+ (n_sources_per_shot_h + dimBlock_sources.x - 1) / dimBlock_sources.x,
1515
+ n_shots_h, 1);
1516
+
1517
+ dim3 dimBlock_receivers(32, 1, 1);
1518
+ dim3 dimGrid_receivers(
1519
+ (n_receivers_per_shot_h + dimBlock_receivers.x - 1) / dimBlock_receivers.x,
1520
+ n_shots_h, 1);
1521
+
1522
+ auto run_step = [&](int64_t t) {
1523
+ forward_kernel_h<<<dimGrid, dimBlock, shmem_h_bytes>>>(
1524
+ cq, ey, hx, hz, m_ey_x, m_ey_z,
1525
+ ay, ayh, ax, axh, by, byh, bx, bxh,
1526
+ ky, kyh, kx, kxh);
1527
+ forward_kernel_e<<<dimGrid, dimBlock, shmem_e_bytes>>>(
1528
+ ca, cb, hx, hz, ey, m_hx_z, m_hz_x,
1529
+ ay, ayh, ax, axh, by, byh, bx, bxh,
1530
+ ky, kyh, kx, kxh);
1531
+
1532
+ if (n_sources_per_shot_h > 0) {
1533
+ add_sources_ey<<<dimGrid_sources, dimBlock_sources>>>(
1534
+ ey, f + t * n_shots_h * n_sources_per_shot_h, sources_i);
1535
+ }
1536
+
1537
+ if (n_receivers_per_shot_h > 0) {
1538
+ record_receivers_ey<<<dimGrid_receivers, dimBlock_receivers>>>(
1539
+ r + t * n_shots_h * n_receivers_per_shot_h, ey, receivers_i);
1540
+ }
1541
+ };
1542
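+ // The offsets above imply a time-major layout for the source amplitudes and
+ // receiver records: the block for time step t starts at
+ // f + t * n_shots * n_sources_per_shot and r + t * n_shots *
+ // n_receivers_per_shot; the ordering of shots and traces within that block
+ // is handled inside add_sources_ey / record_receivers_ey.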
+
1543
+ for (int64_t t = start_t; t < start_t + nt; ++t) {
1544
+ run_step(t);
1545
+ }
1546
+
1547
+ gpuErrchk(cudaPeekAtLastError());
1548
+ }
1549
+
1550
+ extern "C" void FUNC(forward_with_storage)(
1551
+ TIDE_DTYPE const *const ca,
1552
+ TIDE_DTYPE const *const cb,
1553
+ TIDE_DTYPE const *const cq,
1554
+ TIDE_DTYPE const *const f,
1555
+ TIDE_DTYPE *const ey,
1556
+ TIDE_DTYPE *const hx,
1557
+ TIDE_DTYPE *const hz,
1558
+ TIDE_DTYPE *const m_ey_x,
1559
+ TIDE_DTYPE *const m_ey_z,
1560
+ TIDE_DTYPE *const m_hx_z,
1561
+ TIDE_DTYPE *const m_hz_x,
1562
+ TIDE_DTYPE *const r,
1563
+ void *const ey_store_1,
1564
+ void *const ey_store_3,
1565
+ char const *const *const ey_filenames,
1566
+ void *const curl_store_1,
1567
+ void *const curl_store_3,
1568
+ char const *const *const curl_filenames,
1569
+ TIDE_DTYPE const *const ay,
1570
+ TIDE_DTYPE const *const by,
1571
+ TIDE_DTYPE const *const ayh,
1572
+ TIDE_DTYPE const *const byh,
1573
+ TIDE_DTYPE const *const ax,
1574
+ TIDE_DTYPE const *const bx,
1575
+ TIDE_DTYPE const *const axh,
1576
+ TIDE_DTYPE const *const bxh,
1577
+ TIDE_DTYPE const *const ky,
1578
+ TIDE_DTYPE const *const kyh,
1579
+ TIDE_DTYPE const *const kx,
1580
+ TIDE_DTYPE const *const kxh,
1581
+ int64_t const *const sources_i,
1582
+ int64_t const *const receivers_i,
1583
+ TIDE_DTYPE const rdy_h,
1584
+ TIDE_DTYPE const rdx_h,
1585
+ TIDE_DTYPE const dt_h,
1586
+ int64_t const nt,
1587
+ int64_t const n_shots_h,
1588
+ int64_t const ny_h,
1589
+ int64_t const nx_h,
1590
+ int64_t const n_sources_per_shot_h,
1591
+ int64_t const n_receivers_per_shot_h,
1592
+ int64_t const step_ratio_h,
1593
+ int64_t const storage_mode_h,
1594
+ int64_t const shot_bytes_uncomp_h,
1595
+ bool const ca_requires_grad,
1596
+ bool const cb_requires_grad,
1597
+ bool const ca_batched_h,
1598
+ bool const cb_batched_h,
1599
+ bool const cq_batched_h,
1600
+ int64_t const start_t,
1601
+ int64_t const pml_y0_h,
1602
+ int64_t const pml_x0_h,
1603
+ int64_t const pml_y1_h,
1604
+ int64_t const pml_x1_h,
1605
+ int64_t const n_threads,
1606
+ int64_t const device) {
1607
+
1608
+ cudaSetDevice(device);
1609
+ (void)n_threads;
1610
+
1611
+ int64_t const shot_numel_h = ny_h * nx_h;
1612
+ size_t const bytes_per_step_store =
1613
+ (size_t)shot_bytes_uncomp_h * (size_t)n_shots_h;
1614
+ bool const storage_bf16_h = (shot_bytes_uncomp_h == shot_numel_h * 2);
1615
+ bool const storage_fp8_h = (shot_bytes_uncomp_h == shot_numel_h);
1616
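+ // shot_bytes_uncomp_h doubles as the precision selector for stored snapshots:
+ // shot_numel bytes -> fp8, 2 * shot_numel -> bf16, otherwise full TIDE_DTYPE.
+ // E.g. (sizes assumed for illustration) ny = 300, nx = 400 gives shot_numel =
+ // 120000, so a bf16 snapshot occupies 240000 bytes per shot.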
+ cudaStream_t copy_stream = nullptr;
1617
+ cudaEvent_t store_ready;
1618
+ cudaEvent_t copy_done[NUM_BUFFERS];
1619
+ bool copy_in_flight[NUM_BUFFERS];
1620
+ for (int i = 0; i < NUM_BUFFERS; i++) copy_in_flight[i] = false;
1621
+
1622
+ #ifdef TIDE_PROFILING
1623
+ cudaEvent_t prof_wait_start, prof_wait_end, prof_copy_start, prof_copy_end;
1624
+ float total_wait_ms = 0.0f, total_copy_ms = 0.0f;
1625
+ int n_waits = 0, n_copies = 0;
1626
+ #endif
1627
+
1628
+ if (storage_mode_h == STORAGE_CPU) {
1629
+ gpuErrchk(cudaStreamCreateWithFlags(&copy_stream, cudaStreamNonBlocking));
1630
+ #ifdef TIDE_PROFILING
1631
+ PROF_EVENT_CREATE(store_ready);
1632
+ PROF_EVENT_CREATE(prof_wait_start);
1633
+ PROF_EVENT_CREATE(prof_wait_end);
1634
+ PROF_EVENT_CREATE(prof_copy_start);
1635
+ PROF_EVENT_CREATE(prof_copy_end);
1636
+ for (int i = 0; i < NUM_BUFFERS; i++) {
1637
+ PROF_EVENT_CREATE(copy_done[i]);
1638
+ }
1639
+ #else
1640
+ gpuErrchk(cudaEventCreateWithFlags(&store_ready, cudaEventDisableTiming));
1641
+ for (int i = 0; i < NUM_BUFFERS; i++) {
1642
+ gpuErrchk(cudaEventCreateWithFlags(&copy_done[i], cudaEventDisableTiming));
1643
+ }
1644
+ #endif
1645
+ }
1646
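+ // CPU-storage pipeline used below (NUM_BUFFERS-deep ping-pong): the compute
+ // stream writes snapshot step_idx into device staging slot
+ // step_idx % NUM_BUFFERS and records store_ready; copy_stream waits on
+ // store_ready, copies that slot to host asynchronously, and records
+ // copy_done[slot]. Before a slot is reused, the compute stream waits on
+ // copy_done[slot], so compute and D2H copies overlap across slots. For the
+ // async copies to actually overlap, ey_store_3 / curl_store_3 are presumably
+ // pinned host allocations.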
+
1647
+ // Copy constants to device with caching to avoid redundant copies
1648
+ static TIDE_DTYPE cached_rdy2 = 0, cached_rdx2 = 0;
1649
+ static int64_t cached_n_shots2 = -1, cached_ny2 = -1, cached_nx2 = -1;
1650
+ static int64_t cached_shot_numel2 = -1, cached_n_sources_per_shot2 = -1, cached_n_receivers_per_shot2 = -1;
1651
+ static int64_t cached_pml_y02 = -1, cached_pml_y12 = -1;
1652
+ static int64_t cached_pml_x02 = -1, cached_pml_x12 = -1;
1653
+ static bool cached_ca_batched2 = false, cached_cb_batched2 = false, cached_cq_batched2 = false;
1654
+ static int64_t cached_device2 = -1;
1655
+ static bool first_call2 = true;
1656
+
1657
+ if (first_call2 || cached_device2 != device || cached_rdy2 != rdy_h || cached_rdx2 != rdx_h ||
1658
+ cached_n_shots2 != n_shots_h || cached_ny2 != ny_h || cached_nx2 != nx_h ||
1659
+ cached_shot_numel2 != shot_numel_h || cached_n_sources_per_shot2 != n_sources_per_shot_h ||
1660
+ cached_n_receivers_per_shot2 != n_receivers_per_shot_h ||
1661
+ cached_pml_y02 != pml_y0_h || cached_pml_y12 != pml_y1_h ||
1662
+ cached_pml_x02 != pml_x0_h || cached_pml_x12 != pml_x1_h ||
1663
+ cached_ca_batched2 != ca_batched_h || cached_cb_batched2 != cb_batched_h ||
1664
+ cached_cq_batched2 != cq_batched_h) {
1665
+
1666
+ cudaMemcpyToSymbol(rdy, &rdy_h, sizeof(TIDE_DTYPE));
1667
+ cudaMemcpyToSymbol(rdx, &rdx_h, sizeof(TIDE_DTYPE));
1668
+ cudaMemcpyToSymbol(n_shots, &n_shots_h, sizeof(int64_t));
1669
+ cudaMemcpyToSymbol(ny, &ny_h, sizeof(int64_t));
1670
+ cudaMemcpyToSymbol(nx, &nx_h, sizeof(int64_t));
1671
+ cudaMemcpyToSymbol(shot_numel, &shot_numel_h, sizeof(int64_t));
1672
+ cudaMemcpyToSymbol(n_sources_per_shot, &n_sources_per_shot_h, sizeof(int64_t));
1673
+ cudaMemcpyToSymbol(n_receivers_per_shot, &n_receivers_per_shot_h, sizeof(int64_t));
1674
+ cudaMemcpyToSymbol(pml_y0, &pml_y0_h, sizeof(int64_t));
1675
+ cudaMemcpyToSymbol(pml_y1, &pml_y1_h, sizeof(int64_t));
1676
+ cudaMemcpyToSymbol(pml_x0, &pml_x0_h, sizeof(int64_t));
1677
+ cudaMemcpyToSymbol(pml_x1, &pml_x1_h, sizeof(int64_t));
1678
+ cudaMemcpyToSymbol(ca_batched, &ca_batched_h, sizeof(bool));
1679
+ cudaMemcpyToSymbol(cb_batched, &cb_batched_h, sizeof(bool));
1680
+ cudaMemcpyToSymbol(cq_batched, &cq_batched_h, sizeof(bool));
1681
+
1682
+ cached_rdy2 = rdy_h; cached_rdx2 = rdx_h;
1683
+ cached_n_shots2 = n_shots_h; cached_ny2 = ny_h; cached_nx2 = nx_h;
1684
+ cached_shot_numel2 = shot_numel_h; cached_n_sources_per_shot2 = n_sources_per_shot_h;
1685
+ cached_n_receivers_per_shot2 = n_receivers_per_shot_h;
1686
+ cached_pml_y02 = pml_y0_h; cached_pml_y12 = pml_y1_h;
1687
+ cached_pml_x02 = pml_x0_h; cached_pml_x12 = pml_x1_h;
1688
+ cached_ca_batched2 = ca_batched_h; cached_cb_batched2 = cb_batched_h;
1689
+ cached_cq_batched2 = cq_batched_h;
1690
+ cached_device2 = device;
1691
+ first_call2 = false;
1692
+ }
1693
+
1694
+ dim3 dimBlock(32, 8, 1);
1695
+ int64_t gridx = (nx_h - 2 * FD_PAD + 2 + dimBlock.x - 1) / dimBlock.x;
1696
+ int64_t gridy = (ny_h - 2 * FD_PAD + 2 + dimBlock.y - 1) / dimBlock.y;
1697
+ int64_t gridz = n_shots_h;
1698
+ dim3 dimGrid(gridx, gridy, gridz);
1699
+ #if FD_PAD > 1
1700
+ size_t const shmem_h_bytes =
1701
+ (size_t)(dimBlock.x + 2 * FD_PAD) * (size_t)(dimBlock.y + 2 * FD_PAD) *
1702
+ sizeof(TIDE_DTYPE);
1703
+ size_t const shmem_e_bytes = 2 * shmem_h_bytes;
1704
+ #else
1705
+ size_t const shmem_h_bytes = 0;
1706
+ size_t const shmem_e_bytes = 0;
1707
+ #endif
1708
+
1709
+ dim3 dimBlock_sources(32, 1, 1);
1710
+ dim3 dimGrid_sources(
1711
+ (n_sources_per_shot_h + dimBlock_sources.x - 1) / dimBlock_sources.x,
1712
+ n_shots_h, 1);
1713
+
1714
+ dim3 dimBlock_receivers(32, 1, 1);
1715
+ dim3 dimGrid_receivers(
1716
+ (n_receivers_per_shot_h + dimBlock_receivers.x - 1) / dimBlock_receivers.x,
1717
+ n_shots_h, 1);
1718
+
1719
+ FILE *fp_ey = nullptr;
1720
+ FILE *fp_curl = nullptr;
1721
+ if (storage_mode_h == STORAGE_DISK) {
1722
+ if (ca_requires_grad) fp_ey = fopen(ey_filenames[0], "wb");
1723
+ if (cb_requires_grad) fp_curl = fopen(curl_filenames[0], "wb");
1724
+ }
1725
+
1726
+ auto store1_offset_bytes = [&](int64_t step_idx) -> size_t {
1727
+ if (storage_mode_h == STORAGE_DEVICE) {
1728
+ return (size_t)step_idx * bytes_per_step_store;
1729
+ }
1730
+ if (storage_mode_h == STORAGE_CPU) {
1731
+ return (size_t)(step_idx % NUM_BUFFERS) * bytes_per_step_store;
1732
+ }
1733
+ return 0;
1734
+ };
1735
+
1736
+ auto run_step = [&](int64_t t) {
1737
+ forward_kernel_h<<<dimGrid, dimBlock, shmem_h_bytes>>>(
1738
+ cq, ey, hx, hz, m_ey_x, m_ey_z,
1739
+ ay, ayh, ax, axh, by, byh, bx, bxh,
1740
+ ky, kyh, kx, kxh);
1741
+
1742
+ bool const store_step = ((t % step_ratio_h) == 0);
1743
+ bool const store_ey = store_step && ca_requires_grad;
1744
+ bool const store_curl = store_step && cb_requires_grad;
1745
+ bool const want_store = store_ey || store_curl;
1746
+ if (want_store) {
1747
+ int64_t const step_idx = t / step_ratio_h;
1748
+ int const store_buf = (storage_mode_h == STORAGE_CPU) ? (int)(step_idx % NUM_BUFFERS) : 0;
1749
+ if (storage_mode_h == STORAGE_CPU && copy_in_flight[store_buf]) {
1750
+ #ifdef TIDE_PROFILING
1751
+ PROF_RECORD(prof_wait_start, 0);
1752
+ #endif
1753
+ gpuErrchk(cudaStreamWaitEvent(0, copy_done[store_buf], 0));
1754
+ #ifdef TIDE_PROFILING
1755
+ PROF_RECORD(prof_wait_end, 0);
1756
+ gpuErrchk(cudaDeviceSynchronize());
1757
+ float wait_ms;
1758
+ PROF_ELAPSED(prof_wait_start, prof_wait_end, wait_ms);
1759
+ total_wait_ms += wait_ms;
1760
+ n_waits++;
1761
+ #endif
1762
+ copy_in_flight[store_buf] = false;
1763
+ }
1764
+ size_t const store1_offset = store1_offset_bytes(step_idx);
1765
+
1766
+ void *__restrict const ey_store_1_t =
1767
+ (uint8_t *)ey_store_1 + store1_offset;
1768
+ void *__restrict const ey_store_3_t =
1769
+ (uint8_t *)ey_store_3 +
1770
+ (storage_mode_h == STORAGE_CPU
1771
+ ? (size_t)step_idx * bytes_per_step_store
1772
+ : 0);
1773
+
1774
+ void *__restrict const curl_store_1_t =
1775
+ (uint8_t *)curl_store_1 + store1_offset;
1776
+ void *__restrict const curl_store_3_t =
1777
+ (uint8_t *)curl_store_3 +
1778
+ (storage_mode_h == STORAGE_CPU
1779
+ ? (size_t)step_idx * bytes_per_step_store
1780
+ : 0);
1781
+
1782
+ if (storage_fp8_h) {
1783
+ forward_kernel_e_with_storage_fp8<<<dimGrid, dimBlock, shmem_e_bytes>>>(
1784
+ ca, cb, hx, hz, ey, m_hx_z, m_hz_x,
1785
+ store_ey ? (uint8_t *)ey_store_1_t : nullptr,
1786
+ store_curl ? (uint8_t *)curl_store_1_t : nullptr, ay, ayh, ax, axh,
1787
+ by, byh, bx, bxh, ky, kyh, kx, kxh, store_ey, store_curl);
1788
+ } else if (storage_bf16_h) {
1789
+ forward_kernel_e_with_storage_bf16<<<dimGrid, dimBlock, shmem_e_bytes>>>(
1790
+ ca, cb, hx, hz, ey, m_hx_z, m_hz_x,
1791
+ store_ey ? (__nv_bfloat16 *)ey_store_1_t : nullptr,
1792
+ store_curl ? (__nv_bfloat16 *)curl_store_1_t : nullptr, ay, ayh, ax,
1793
+ axh, by, byh, bx, bxh, ky, kyh, kx, kxh, store_ey, store_curl);
1794
+ } else {
1795
+ forward_kernel_e_with_storage<<<dimGrid, dimBlock, shmem_e_bytes>>>(
1796
+ ca, cb, hx, hz, ey, m_hx_z, m_hz_x,
1797
+ store_ey ? (TIDE_DTYPE *)ey_store_1_t : nullptr,
1798
+ store_curl ? (TIDE_DTYPE *)curl_store_1_t : nullptr, ay, ayh, ax,
1799
+ axh, by, byh, bx, bxh, ky, kyh, kx, kxh, store_ey, store_curl);
1800
+ }
1801
+
1802
+ if (storage_mode_h == STORAGE_CPU) {
1803
+ gpuErrchk(cudaEventRecord(store_ready, 0));
1804
+ gpuErrchk(cudaStreamWaitEvent(copy_stream, store_ready, 0));
1805
+ #ifdef TIDE_PROFILING
1806
+ PROF_RECORD(prof_copy_start, copy_stream);
1807
+ #endif
1808
+ if (store_ey) {
1809
+ gpuErrchk(cudaMemcpyAsync(
1810
+ ey_store_3_t, ey_store_1_t, bytes_per_step_store,
1811
+ cudaMemcpyDeviceToHost, copy_stream));
1812
+ }
1813
+ if (store_curl) {
1814
+ gpuErrchk(cudaMemcpyAsync(
1815
+ curl_store_3_t, curl_store_1_t, bytes_per_step_store,
1816
+ cudaMemcpyDeviceToHost, copy_stream));
1817
+ }
1818
+ #ifdef TIDE_PROFILING
1819
+ PROF_RECORD(prof_copy_end, copy_stream);
1820
+ #endif
1821
+ gpuErrchk(cudaEventRecord(copy_done[store_buf], copy_stream));
1822
+ copy_in_flight[store_buf] = true;
1823
+ #ifdef TIDE_PROFILING
1824
+ n_copies++;
1825
+ #endif
1826
+ } else {
1827
+ if (store_ey) {
1828
+ storage_save_snapshot_gpu(
1829
+ ey_store_1_t, ey_store_3_t, fp_ey, storage_mode_h, step_idx,
1830
+ (size_t)shot_bytes_uncomp_h, (size_t)n_shots_h);
1831
+ }
1832
+ if (store_curl) {
1833
+ storage_save_snapshot_gpu(
1834
+ curl_store_1_t, curl_store_3_t, fp_curl, storage_mode_h, step_idx,
1835
+ (size_t)shot_bytes_uncomp_h, (size_t)n_shots_h);
1836
+ }
1837
+ }
1838
+ } else {
1839
+ forward_kernel_e<<<dimGrid, dimBlock, shmem_e_bytes>>>(
1840
+ ca, cb, hx, hz, ey, m_hx_z, m_hz_x, ay, ayh, ax, axh, by, byh, bx,
1841
+ bxh, ky, kyh, kx, kxh);
1842
+ }
1843
+
1844
+ if (n_sources_per_shot_h > 0) {
1845
+ add_sources_ey<<<dimGrid_sources, dimBlock_sources>>>(
1846
+ ey, f + t * n_shots_h * n_sources_per_shot_h, sources_i);
1847
+ }
1848
+
1849
+ if (n_receivers_per_shot_h > 0) {
1850
+ record_receivers_ey<<<dimGrid_receivers, dimBlock_receivers>>>(
1851
+ r + t * n_shots_h * n_receivers_per_shot_h, ey, receivers_i);
1852
+ }
1853
+ };
1854
+
1855
+ for (int64_t t = start_t; t < start_t + nt; ++t) {
1856
+ run_step(t);
1857
+ }
1858
+
1859
+ if (storage_mode_h == STORAGE_CPU) {
1860
+ gpuErrchk(cudaStreamSynchronize(copy_stream));
1861
+ for (int i = 0; i < NUM_BUFFERS; i++) {
1862
+ gpuErrchk(cudaEventDestroy(copy_done[i]));
1863
+ }
1864
+ gpuErrchk(cudaEventDestroy(store_ready));
1865
+ gpuErrchk(cudaStreamDestroy(copy_stream));
1866
+ }
1867
+
1868
+ if (fp_ey != nullptr) fclose(fp_ey);
1869
+ if (fp_curl != nullptr) fclose(fp_curl);
1870
+
1871
+ gpuErrchk(cudaPeekAtLastError());
1872
+ }
1873
+
1874
+
1875
+
1876
+ extern "C" void FUNC(backward)(
1877
+ TIDE_DTYPE const *const ca,
1878
+ TIDE_DTYPE const *const cb,
1879
+ TIDE_DTYPE const *const cq,
1880
+ TIDE_DTYPE const *const grad_r,
1881
+ TIDE_DTYPE *const lambda_ey,
1882
+ TIDE_DTYPE *const lambda_hx,
1883
+ TIDE_DTYPE *const lambda_hz,
1884
+ TIDE_DTYPE *const m_lambda_ey_x,
1885
+ TIDE_DTYPE *const m_lambda_ey_z,
1886
+ TIDE_DTYPE *const m_lambda_hx_z,
1887
+ TIDE_DTYPE *const m_lambda_hz_x,
1888
+ void *const ey_store_1,
1889
+ void *const ey_store_3,
1890
+ char const *const *const ey_filenames,
1891
+ void *const curl_store_1,
1892
+ void *const curl_store_3,
1893
+ char const *const *const curl_filenames,
1894
+ TIDE_DTYPE *const grad_f,
1895
+ TIDE_DTYPE *const grad_ca,
1896
+ TIDE_DTYPE *const grad_cb,
1897
+ TIDE_DTYPE *const grad_eps,
1898
+ TIDE_DTYPE *const grad_sigma,
1899
+ TIDE_DTYPE *const grad_ca_shot, // [n_shots, ny, nx] - per-shot gradient workspace
1900
+ TIDE_DTYPE *const grad_cb_shot, // [n_shots, ny, nx] - per-shot gradient workspace
1901
+ TIDE_DTYPE const *const ay,
1902
+ TIDE_DTYPE const *const by,
1903
+ TIDE_DTYPE const *const ayh,
1904
+ TIDE_DTYPE const *const byh,
1905
+ TIDE_DTYPE const *const ax,
1906
+ TIDE_DTYPE const *const bx,
1907
+ TIDE_DTYPE const *const axh,
1908
+ TIDE_DTYPE const *const bxh,
1909
+ TIDE_DTYPE const *const ky,
1910
+ TIDE_DTYPE const *const kyh,
1911
+ TIDE_DTYPE const *const kx,
1912
+ TIDE_DTYPE const *const kxh,
1913
+ int64_t const *const sources_i,
1914
+ int64_t const *const receivers_i,
1915
+ TIDE_DTYPE const rdy_h,
1916
+ TIDE_DTYPE const rdx_h,
1917
+ TIDE_DTYPE const dt_h,
1918
+ int64_t const nt,
1919
+ int64_t const n_shots_h,
1920
+ int64_t const ny_h,
1921
+ int64_t const nx_h,
1922
+ int64_t const n_sources_per_shot_h,
1923
+ int64_t const n_receivers_per_shot_h,
1924
+ int64_t const step_ratio_h,
1925
+ int64_t const storage_mode_h,
1926
+ int64_t const shot_bytes_uncomp_h,
1927
+ bool const ca_requires_grad,
1928
+ bool const cb_requires_grad,
1929
+ bool const ca_batched_h,
1930
+ bool const cb_batched_h,
1931
+ bool const cq_batched_h,
1932
+ int64_t const start_t,
1933
+ int64_t const pml_y0_h,
1934
+ int64_t const pml_x0_h,
1935
+ int64_t const pml_y1_h,
1936
+ int64_t const pml_x1_h,
1937
+ int64_t const n_threads,
1938
+ int64_t const device) {
1939
+
1940
+ cudaSetDevice(device);
1941
+ (void)dt_h;
1942
+ (void)n_threads;
1943
+
1944
+ int64_t const shot_numel_h = ny_h * nx_h;
1945
+ size_t const bytes_per_step_store =
1946
+ (size_t)shot_bytes_uncomp_h * (size_t)n_shots_h;
1947
+ bool const storage_bf16_h = (shot_bytes_uncomp_h == shot_numel_h * 2);
1948
+ bool const storage_fp8_h = (shot_bytes_uncomp_h == shot_numel_h);
1949
+ cudaStream_t copy_stream = nullptr;
1950
+ cudaEvent_t copy_done[NUM_BUFFERS];
1951
+ bool copy_in_flight[NUM_BUFFERS];
1952
+ for (int i = 0; i < NUM_BUFFERS; i++) copy_in_flight[i] = false;
1953
+
1954
+ #ifdef TIDE_PROFILING
1955
+ cudaEvent_t prof_prefetch_start, prof_prefetch_end, prof_wait_start, prof_wait_end;
1956
+ float total_prefetch_ms = 0.0f, total_wait_ms = 0.0f;
1957
+ int n_prefetches = 0, n_waits = 0;
1958
+ #endif
1959
+
1960
+ if (storage_mode_h == STORAGE_CPU) {
1961
+ gpuErrchk(cudaStreamCreateWithFlags(&copy_stream, cudaStreamNonBlocking));
1962
+ #ifdef TIDE_PROFILING
1963
+ PROF_EVENT_CREATE(prof_prefetch_start);
1964
+ PROF_EVENT_CREATE(prof_prefetch_end);
1965
+ PROF_EVENT_CREATE(prof_wait_start);
1966
+ PROF_EVENT_CREATE(prof_wait_end);
1967
+ for (int i = 0; i < NUM_BUFFERS; i++) {
1968
+ PROF_EVENT_CREATE(copy_done[i]);
1969
+ }
1970
+ #else
1971
+ for (int i = 0; i < NUM_BUFFERS; i++) {
1972
+ gpuErrchk(cudaEventCreateWithFlags(&copy_done[i], cudaEventDisableTiming));
1973
+ }
1974
+ #endif
1975
+ }
1976
+
1977
+ // Copy constants to device with caching to avoid redundant copies
1978
+ static TIDE_DTYPE cached_rdy3 = 0, cached_rdx3 = 0;
1979
+ static int64_t cached_n_shots3 = -1, cached_ny3 = -1, cached_nx3 = -1;
1980
+ static int64_t cached_shot_numel3 = -1, cached_n_sources_per_shot3 = -1, cached_n_receivers_per_shot3 = -1;
1981
+ static int64_t cached_pml_y03 = -1, cached_pml_y13 = -1;
1982
+ static int64_t cached_pml_x03 = -1, cached_pml_x13 = -1;
1983
+ static bool cached_ca_batched3 = false, cached_cb_batched3 = false, cached_cq_batched3 = false;
1984
+ static int64_t cached_device3 = -1;
1985
+ static bool first_call3 = true;
1986
+
1987
+ if (first_call3 || cached_device3 != device || cached_rdy3 != rdy_h || cached_rdx3 != rdx_h ||
1988
+ cached_n_shots3 != n_shots_h || cached_ny3 != ny_h || cached_nx3 != nx_h ||
1989
+ cached_shot_numel3 != shot_numel_h || cached_n_sources_per_shot3 != n_sources_per_shot_h ||
1990
+ cached_n_receivers_per_shot3 != n_receivers_per_shot_h ||
1991
+ cached_pml_y03 != pml_y0_h || cached_pml_y13 != pml_y1_h ||
1992
+ cached_pml_x03 != pml_x0_h || cached_pml_x13 != pml_x1_h ||
1993
+ cached_ca_batched3 != ca_batched_h || cached_cb_batched3 != cb_batched_h ||
1994
+ cached_cq_batched3 != cq_batched_h) {
1995
+
1996
+ cudaMemcpyToSymbol(rdy, &rdy_h, sizeof(TIDE_DTYPE));
1997
+ cudaMemcpyToSymbol(rdx, &rdx_h, sizeof(TIDE_DTYPE));
1998
+ cudaMemcpyToSymbol(n_shots, &n_shots_h, sizeof(int64_t));
1999
+ cudaMemcpyToSymbol(ny, &ny_h, sizeof(int64_t));
2000
+ cudaMemcpyToSymbol(nx, &nx_h, sizeof(int64_t));
2001
+ cudaMemcpyToSymbol(shot_numel, &shot_numel_h, sizeof(int64_t));
2002
+ cudaMemcpyToSymbol(n_sources_per_shot, &n_sources_per_shot_h, sizeof(int64_t));
2003
+ cudaMemcpyToSymbol(n_receivers_per_shot, &n_receivers_per_shot_h, sizeof(int64_t));
2004
+ cudaMemcpyToSymbol(pml_y0, &pml_y0_h, sizeof(int64_t));
2005
+ cudaMemcpyToSymbol(pml_y1, &pml_y1_h, sizeof(int64_t));
2006
+ cudaMemcpyToSymbol(pml_x0, &pml_x0_h, sizeof(int64_t));
2007
+ cudaMemcpyToSymbol(pml_x1, &pml_x1_h, sizeof(int64_t));
2008
+ cudaMemcpyToSymbol(ca_batched, &ca_batched_h, sizeof(bool));
2009
+ cudaMemcpyToSymbol(cb_batched, &cb_batched_h, sizeof(bool));
2010
+ cudaMemcpyToSymbol(cq_batched, &cq_batched_h, sizeof(bool));
2011
+
2012
+ cached_rdy3 = rdy_h; cached_rdx3 = rdx_h;
2013
+ cached_n_shots3 = n_shots_h; cached_ny3 = ny_h; cached_nx3 = nx_h;
2014
+ cached_shot_numel3 = shot_numel_h; cached_n_sources_per_shot3 = n_sources_per_shot_h;
2015
+ cached_n_receivers_per_shot3 = n_receivers_per_shot_h;
2016
+ cached_pml_y03 = pml_y0_h; cached_pml_y13 = pml_y1_h;
2017
+ cached_pml_x03 = pml_x0_h; cached_pml_x13 = pml_x1_h;
2018
+ cached_ca_batched3 = ca_batched_h; cached_cb_batched3 = cb_batched_h;
2019
+ cached_cq_batched3 = cq_batched_h;
2020
+ cached_device3 = device;
2021
+ first_call3 = false;
2022
+ }
2023
+
2024
+ dim3 dimBlock(32, 8, 1);
2025
+ int64_t gridx = (nx_h - 2 * FD_PAD + 2 + dimBlock.x - 1) / dimBlock.x;
2026
+ int64_t gridy = (ny_h - 2 * FD_PAD + 2 + dimBlock.y - 1) / dimBlock.y;
2027
+ int64_t gridz = n_shots_h;
2028
+ dim3 dimGrid(gridx, gridy, gridz);
2029
+
2030
+ dim3 dimBlock_sources(32, 1, 1);
2031
+ dim3 dimGrid_sources(
2032
+ (n_sources_per_shot_h + dimBlock_sources.x - 1) / dimBlock_sources.x,
2033
+ n_shots_h, 1);
2034
+
2035
+ dim3 dimBlock_receivers(32, 1, 1);
2036
+ dim3 dimGrid_receivers(
2037
+ (n_receivers_per_shot_h + dimBlock_receivers.x - 1) / dimBlock_receivers.x,
2038
+ n_shots_h, 1);
2039
+
2040
+ FILE *fp_ey = nullptr;
2041
+ FILE *fp_curl = nullptr;
2042
+ if (storage_mode_h == STORAGE_DISK) {
2043
+ if (ca_requires_grad) fp_ey = fopen(ey_filenames[0], "rb");
2044
+ if (cb_requires_grad) fp_curl = fopen(curl_filenames[0], "rb");
2045
+ }
2046
+
2047
+ auto store1_offset_bytes = [&](int64_t store_idx) -> size_t {
2048
+ if (storage_mode_h == STORAGE_DEVICE) {
2049
+ return (size_t)store_idx * bytes_per_step_store;
2050
+ }
2051
+ if (storage_mode_h == STORAGE_CPU) {
2052
+ return (size_t)(store_idx % NUM_BUFFERS) * bytes_per_step_store;
2053
+ }
2054
+ return 0;
2055
+ };
2056
+
2057
+ auto store3_offset_bytes = [&](int64_t store_idx) -> size_t {
2058
+ return (storage_mode_h == STORAGE_CPU)
2059
+ ? (size_t)store_idx * bytes_per_step_store
2060
+ : 0;
2061
+ };
2062
+
2063
+ auto prefetch_snapshots = [&](int64_t store_idx, bool want_ey, bool want_curl) {
2064
+ if (storage_mode_h != STORAGE_CPU || (!want_ey && !want_curl)) {
2065
+ return;
2066
+ }
2067
+ int const store_buf = (int)(store_idx % NUM_BUFFERS);
2068
+ if (copy_in_flight[store_buf]) {
2069
+ gpuErrchk(cudaStreamWaitEvent(copy_stream, copy_done[store_buf], 0));
2070
+ }
2071
+ #ifdef TIDE_PROFILING
2072
+ PROF_RECORD(prof_prefetch_start, copy_stream);
2073
+ #endif
2074
+ size_t const store1_offset = store1_offset_bytes(store_idx);
2075
+ size_t const store3_offset = store3_offset_bytes(store_idx);
2076
+ void *ey_store_1_t = (uint8_t *)ey_store_1 + store1_offset;
2077
+ void *curl_store_1_t = (uint8_t *)curl_store_1 + store1_offset;
2078
+ void *ey_store_3_t = (uint8_t *)ey_store_3 + store3_offset;
2079
+ void *curl_store_3_t = (uint8_t *)curl_store_3 + store3_offset;
2080
+ if (want_ey) {
2081
+ gpuErrchk(cudaMemcpyAsync(
2082
+ ey_store_1_t, ey_store_3_t, bytes_per_step_store,
2083
+ cudaMemcpyHostToDevice, copy_stream));
2084
+ }
2085
+ if (want_curl) {
2086
+ gpuErrchk(cudaMemcpyAsync(
2087
+ curl_store_1_t, curl_store_3_t, bytes_per_step_store,
2088
+ cudaMemcpyHostToDevice, copy_stream));
2089
+ }
2090
+ #ifdef TIDE_PROFILING
2091
+ PROF_RECORD(prof_prefetch_end, copy_stream);
2092
+ #endif
2093
+ gpuErrchk(cudaEventRecord(copy_done[store_buf], copy_stream));
2094
+ copy_in_flight[store_buf] = true;
2095
+ #ifdef TIDE_PROFILING
2096
+ n_prefetches++;
2097
+ #endif
2098
+ };
2099
+
2100
+ int64_t const t_min = start_t - nt;
2101
+ if (storage_mode_h == STORAGE_CPU && (ca_requires_grad || cb_requires_grad)) {
2102
+ int64_t t_prefetch = start_t - 1;
2103
+ int64_t const mod = t_prefetch % step_ratio_h;
2104
+ if (mod != 0) t_prefetch -= mod;
2105
+ if (t_prefetch >= t_min) {
2106
+ prefetch_snapshots(
2107
+ t_prefetch / step_ratio_h, ca_requires_grad, cb_requires_grad);
2108
+ }
2109
+ }
2110
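+ // Reverse-time prefetch: the loop below consumes snapshot store_idx from its
+ // device staging slot while copy_stream is already uploading snapshot
+ // store_idx - 1 host-to-device into another slot; the compute stream only
+ // blocks on copy_done for the slot it is about to read (see
+ // prefetch_snapshots above).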
+
2111
+ // Time reversed loop
2112
+ for (int64_t t = start_t - 1; t >= start_t - nt; --t) {
2113
+ // Inject adjoint source (receiver residual) at receiver locations
2114
+ // Use add_adjoint_sources_ey which checks n_receivers_per_shot
2115
+ if (n_receivers_per_shot_h > 0) {
2116
+ add_adjoint_sources_ey<<<dimGrid_receivers, dimBlock_receivers>>>(
2117
+ lambda_ey, grad_r + t * n_shots_h * n_receivers_per_shot_h, receivers_i);
2118
+ }
2119
+
2120
+ // Record adjoint field at source locations for source gradient
2121
+ // Use record_adjoint_at_sources which checks n_sources_per_shot
2122
+ if (n_sources_per_shot_h > 0) {
2123
+ record_adjoint_at_sources<<<dimGrid_sources, dimBlock_sources>>>(
2124
+ grad_f + t * n_shots_h * n_sources_per_shot_h,
2125
+ lambda_ey, sources_i);
2126
+ }
2127
+
2128
+ int64_t const store_idx = t / step_ratio_h;
2129
+ bool const do_grad = (t % step_ratio_h) == 0;
2130
+ bool const grad_ey = do_grad && ca_requires_grad;
2131
+ bool const grad_curl = do_grad && cb_requires_grad;
2132
+
2133
+ size_t const store1_offset = store1_offset_bytes(store_idx);
2134
+ size_t const store3_offset = store3_offset_bytes(store_idx);
2135
+
2136
+ void *__restrict const ey_store_1_t =
2137
+ (uint8_t *)ey_store_1 + store1_offset;
2138
+ void *__restrict const ey_store_3_t =
2139
+ (uint8_t *)ey_store_3 + store3_offset;
2140
+
2141
+ void *__restrict const curl_store_1_t =
2142
+ (uint8_t *)curl_store_1 + store1_offset;
2143
+ void *__restrict const curl_store_3_t =
2144
+ (uint8_t *)curl_store_3 + store3_offset;
2145
+
2146
+ if (storage_mode_h == STORAGE_CPU && (grad_ey || grad_curl)) {
2147
+ int const store_buf = (int)(store_idx % NUM_BUFFERS);
2148
+ if (!copy_in_flight[store_buf]) {
2149
+ prefetch_snapshots(store_idx, grad_ey, grad_curl);
2150
+ }
2151
+ #ifdef TIDE_PROFILING
2152
+ PROF_RECORD(prof_wait_start, 0);
2153
+ #endif
2154
+ gpuErrchk(cudaStreamWaitEvent(0, copy_done[store_buf], 0));
2155
+ #ifdef TIDE_PROFILING
2156
+ PROF_RECORD(prof_wait_end, 0);
2157
+ gpuErrchk(cudaDeviceSynchronize());
2158
+ float wait_ms;
2159
+ PROF_ELAPSED(prof_wait_start, prof_wait_end, wait_ms);
2160
+ total_wait_ms += wait_ms;
2161
+ n_waits++;
2162
+ #endif
2163
+ copy_in_flight[store_buf] = false;
2164
+ } else if (storage_mode_h == STORAGE_DISK) {
2165
+ if (grad_ey) {
2166
+ storage_load_snapshot_gpu(
2167
+ (void *)ey_store_1_t, (void *)ey_store_3_t, fp_ey, storage_mode_h,
2168
+ store_idx, (size_t)shot_bytes_uncomp_h, (size_t)n_shots_h);
2169
+ }
2170
+ if (grad_curl) {
2171
+ storage_load_snapshot_gpu(
2172
+ (void *)curl_store_1_t, (void *)curl_store_3_t, fp_curl, storage_mode_h,
2173
+ store_idx, (size_t)shot_bytes_uncomp_h, (size_t)n_shots_h);
2174
+ }
2175
+ }
2176
+
2177
+ // Backward λ_H fields update
2178
+ backward_kernel_lambda_h<<<dimGrid, dimBlock>>>(
2179
+ cb, lambda_ey, lambda_hx, lambda_hz,
2180
+ m_lambda_ey_x, m_lambda_ey_z,
2181
+ ay, ayh, ax, axh, by, byh, bx, bxh,
2182
+ ky, kyh, kx, kxh);
2183
+
2184
+ // Backward λ_Ey update (specialized kernel when no gradient is needed).
2185
+ if (grad_ey || grad_curl) {
2186
+ if (storage_fp8_h) {
2187
+ backward_kernel_lambda_e_with_grad_fp8<<<dimGrid, dimBlock>>>(
2188
+ ca, cq, lambda_hx, lambda_hz, lambda_ey,
2189
+ m_lambda_hx_z, m_lambda_hz_x,
2190
+ grad_ey ? (uint8_t const *)ey_store_1_t : nullptr,
2191
+ grad_curl ? (uint8_t const *)curl_store_1_t : nullptr,
2192
+ grad_ca_shot, grad_cb_shot,
2193
+ ay, ayh, ax, axh, by, byh, bx, bxh,
2194
+ ky, kyh, kx, kxh,
2195
+ grad_ey, grad_curl,
2196
+ step_ratio_h);
2197
+ } else if (storage_bf16_h) {
2198
+ backward_kernel_lambda_e_with_grad_bf16<<<dimGrid, dimBlock>>>(
2199
+ ca, cq, lambda_hx, lambda_hz, lambda_ey,
2200
+ m_lambda_hx_z, m_lambda_hz_x,
2201
+ grad_ey ? (__nv_bfloat16 const *)ey_store_1_t : nullptr,
2202
+ grad_curl ? (__nv_bfloat16 const *)curl_store_1_t : nullptr,
2203
+ grad_ca_shot, grad_cb_shot,
2204
+ ay, ayh, ax, axh, by, byh, bx, bxh,
2205
+ ky, kyh, kx, kxh,
2206
+ grad_ey, grad_curl,
2207
+ step_ratio_h);
2208
+ } else {
2209
+ backward_kernel_lambda_e_with_grad<<<dimGrid, dimBlock>>>(
2210
+ ca, cq, lambda_hx, lambda_hz, lambda_ey,
2211
+ m_lambda_hx_z, m_lambda_hz_x,
2212
+ grad_ey ? (TIDE_DTYPE const *)ey_store_1_t : nullptr,
2213
+ grad_curl ? (TIDE_DTYPE const *)curl_store_1_t : nullptr,
2214
+ grad_ca_shot, grad_cb_shot,
2215
+ ay, ayh, ax, axh, by, byh, bx, bxh,
2216
+ ky, kyh, kx, kxh,
2217
+ grad_ey, grad_curl,
2218
+ step_ratio_h);
2219
+ }
2220
+ } else {
2221
+ backward_kernel_lambda_e<<<dimGrid, dimBlock>>>(
2222
+ ca, cq, lambda_hx, lambda_hz, lambda_ey,
2223
+ m_lambda_hx_z, m_lambda_hz_x,
2224
+ ay, ayh, ax, axh, by, byh, bx, bxh,
2225
+ ky, kyh, kx, kxh);
2226
+ }
2227
+
2228
+ if (storage_mode_h == STORAGE_CPU && do_grad &&
2229
+ (ca_requires_grad || cb_requires_grad)) {
2230
+ int64_t const next_t = t - step_ratio_h;
2231
+ if (next_t >= t_min) {
2232
+ prefetch_snapshots(store_idx - 1, ca_requires_grad, cb_requires_grad);
2233
+ }
2234
+ }
2235
+ }
2236
+
2237
+ if (storage_mode_h == STORAGE_CPU) {
2238
+ gpuErrchk(cudaStreamSynchronize(copy_stream));
2239
+ #ifdef TIDE_PROFILING
2240
+ // Compute and print profiling statistics
2241
+ if (n_prefetches > 0) {
2242
+ gpuErrchk(cudaDeviceSynchronize());
2243
+ // Note: per-copy prefetch timing would require per-copy events; only the prefetch count is reported.
2248
+ PROF_PRINT("Backward H2D prefetch count", (float)n_prefetches);
2249
+ }
2250
+ if (n_waits > 0) {
2251
+ float avg_wait_ms = total_wait_ms / n_waits;
2252
+ PROF_PRINT("Backward avg wait time", avg_wait_ms);
2253
+ PROF_PRINT("Backward total wait time", total_wait_ms);
2254
+ }
2256
+ cudaEventDestroy(prof_prefetch_start);
2257
+ cudaEventDestroy(prof_prefetch_end);
2258
+ cudaEventDestroy(prof_wait_start);
2259
+ cudaEventDestroy(prof_wait_end);
2260
+ #endif
2261
+ for (int i = 0; i < NUM_BUFFERS; i++) {
2262
+ gpuErrchk(cudaEventDestroy(copy_done[i]));
2263
+ }
2264
+ gpuErrchk(cudaStreamDestroy(copy_stream));
2265
+ }
2266
+
2267
+ if (fp_ey != nullptr) fclose(fp_ey);
2268
+ if (fp_curl != nullptr) fclose(fp_curl);
2269
+
2270
+ // Combine per-shot gradients into the global gradient (only when the parameter is not batched; a batched parameter keeps its per-shot gradients)
2271
+ dim3 dimBlock_combine(32, 32, 1);
2272
+ dim3 dimGrid_combine(
2273
+ (nx_h - 2 * FD_PAD + dimBlock_combine.x - 1) / dimBlock_combine.x,
2274
+ (ny_h - 2 * FD_PAD + dimBlock_combine.y - 1) / dimBlock_combine.y, 1);
2275
+
2276
+ if (ca_requires_grad && !ca_batched_h) {
2277
+ combine_grad<<<dimGrid_combine, dimBlock_combine>>>(grad_ca, grad_ca_shot);
2278
+ }
2279
+ if (cb_requires_grad && !cb_batched_h) {
2280
+ combine_grad<<<dimGrid_combine, dimBlock_combine>>>(grad_cb, grad_cb_shot);
2281
+ }
2282
+
2283
+ if ((grad_eps != nullptr || grad_sigma != nullptr) && (ca_requires_grad || cb_requires_grad)) {
2284
+ dim3 dimBlock_conv(32, 8, 1);
2285
+ dim3 dimGrid_conv(
2286
+ (nx_h + dimBlock_conv.x - 1) / dimBlock_conv.x,
2287
+ (ny_h + dimBlock_conv.y - 1) / dimBlock_conv.y,
2288
+ ca_batched_h ? n_shots_h : 1);
2289
+ convert_grad_ca_cb_to_eps_sigma<<<dimGrid_conv, dimBlock_conv>>>(
2290
+ ca, cb, grad_ca, grad_cb, grad_ca_shot, grad_cb_shot,
2291
+ grad_eps, grad_sigma, dt_h,
2292
+ ca_requires_grad, cb_requires_grad,
2293
+ ca_batched_h, cb_batched_h);
2294
+ }
2295
+
2296
+ gpuErrchk(cudaPeekAtLastError());
2297
+ }