tide-GPR 0.0.9__py3-none-manylinux_2_28_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tide/__init__.py +65 -0
- tide/autograd_utils.py +26 -0
- tide/backend_utils.py +536 -0
- tide/callbacks.py +348 -0
- tide/cfl.py +64 -0
- tide/csrc/CMakeLists.txt +263 -0
- tide/csrc/common_cpu.h +31 -0
- tide/csrc/common_gpu.h +56 -0
- tide/csrc/maxwell.c +2133 -0
- tide/csrc/maxwell.cu +2297 -0
- tide/csrc/maxwell_born.cu +0 -0
- tide/csrc/staggered_grid.h +175 -0
- tide/csrc/staggered_grid_3d.h +124 -0
- tide/csrc/storage_utils.c +78 -0
- tide/csrc/storage_utils.cu +135 -0
- tide/csrc/storage_utils.h +36 -0
- tide/grid_utils.py +31 -0
- tide/maxwell.py +2651 -0
- tide/padding.py +139 -0
- tide/resampling.py +246 -0
- tide/staggered.py +567 -0
- tide/storage.py +131 -0
- tide/tide/libtide_C.so +0 -0
- tide/utils.py +274 -0
- tide/validation.py +71 -0
- tide/wavelets.py +72 -0
- tide_gpr-0.0.9.dist-info/METADATA +256 -0
- tide_gpr-0.0.9.dist-info/RECORD +31 -0
- tide_gpr-0.0.9.dist-info/WHEEL +5 -0
- tide_gpr-0.0.9.dist-info/licenses/LICENSE +46 -0
- tide_gpr.libs/libgomp-24e2ab19.so.1.0.0 +0 -0
tide/csrc/maxwell.cu
ADDED
@@ -0,0 +1,2297 @@
/*
 * Maxwell wave equation propagator (CUDA implementation)
 *
 * This file contains the CUDA implementation of the 2D TM Maxwell equations
 * propagator with complete Adjoint State Method (ASM) support for gradient computation.
 *
 * TM mode fields: Ey (electric), Hx, Hz (magnetic)
 *
 * EXACT DISCRETE Adjoint State Method for Maxwell TM equations:
 * =============================================================
 * Forward equations (discrete):
 *   E_y^{n+1} = C_a * E_y^n + C_b * (D_x[H_z] - D_z[H_x])
 *   H_x^{n+1/2} = H_x^{n-1/2} - C_q * D_z^h[E_y]
 *   H_z^{n+1/2} = H_z^{n-1/2} + C_q * D_x^h[E_y]
 *
 * Exact discrete adjoint equations (time-reversed with transposed operators):
 *   λ_Ey^n = C_a * λ_Ey^{n+1} + C_q * (D_x^{hT}[λ_Hz] - D_z^{hT}[λ_Hx])
 *   λ_Hx^{n-1/2} = λ_Hx^{n+1/2} - C_b * D_z^T[λ_Ey]
 *   λ_Hz^{n-1/2} = λ_Hz^{n+1/2} + C_b * D_x^T[λ_Ey]
 *
 * Model gradients:
 *   ∂J/∂C_a = Σ_n E_y^n * λ_Ey^{n+1}
 *   ∂J/∂C_b = Σ_n curl_H^n * λ_Ey^{n+1}
 *
 * Gradient accumulation strategy:
 * - Use per-shot gradient arrays (grad_ca_shot, grad_cb_shot)
 * - Each shot writes to its own memory region (no race condition)
 * - Use combine_grad kernel to sum across shots at the end
 */

#include <stdio.h>
#include <cstdint>
#include <cstdlib>
#include <climits>
#include <math.h>
#include <cuda_bf16.h>
#if defined(__has_include)
#if __has_include(<cuda_fp8.h>)
#include <cuda_fp8.h>
#define TIDE_HAVE_CUDA_FP8 1
#else
#define TIDE_HAVE_CUDA_FP8 0
#endif
#else
#define TIDE_HAVE_CUDA_FP8 0
#endif
#include "common_gpu.h"
#include "staggered_grid.h"
#include "storage_utils.h"

#ifndef TIDE_DEVICE
#define TIDE_DEVICE cuda
#endif

// CPU storage pipelining: Number of ping-pong buffers for async D2H/H2D copies
// Increasing this reduces synchronization stalls between compute and copy
#ifndef NUM_BUFFERS
#define NUM_BUFFERS 3
#endif

// Profiling support: enable with -DTIDE_PROFILING during compilation
#ifdef TIDE_PROFILING
#define PROF_EVENT_CREATE(e) cudaEventCreate(&(e))
#define PROF_RECORD(e, s) cudaEventRecord((e), (s))
#define PROF_ELAPSED(start, end, ms) cudaEventElapsedTime(&(ms), (start), (end))
#define PROF_PRINT(name, ms) fprintf(stderr, "[TIDE PROF] %s: %.3f ms\n", (name), (ms))
#else
#define PROF_EVENT_CREATE(e) ((void)0)
#define PROF_RECORD(e, s) ((void)0)
#define PROF_ELAPSED(start, end, ms) ((void)0)
#define PROF_PRINT(name, ms) ((void)0)
#endif

#define CAT_I(name, accuracy, dtype, device) \
  maxwell_tm_##accuracy##_##dtype##_##name##_##device
#define CAT(name, accuracy, dtype, device) \
  CAT_I(name, accuracy, dtype, device)
#define FUNC(name) CAT(name, TIDE_STENCIL, TIDE_DTYPE, TIDE_DEVICE)
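// For example, if TIDE_STENCIL were 2 and TIDE_DTYPE were float (illustrative
// values only; the actual definitions are presumably supplied by the build),
// FUNC(forward) would expand to maxwell_tm_2_float_forward_cuda. The CAT
// indirection exists so the macro arguments are expanded before token pasting.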

// 2D indexing macros
#define ND_INDEX(i, dy, dx) (i + (dy)*nx + (dx))
#define ND_INDEX_J(j, dy, dx) (j + (dy)*nx + (dx))

#define gpuErrchk(ans) \
  { gpuAssert((ans), __FILE__, __LINE__); }
// Field access macros
#define EY(dy, dx) ey[ND_INDEX(i, dy, dx)]
#define HX(dy, dx) hx[ND_INDEX(i, dy, dx)]
#define HZ(dy, dx) hz[ND_INDEX(i, dy, dx)]

// Adjoint field access macros
#define LAMBDA_EY(dy, dx) lambda_ey[ND_INDEX(i, dy, dx)]
#define LAMBDA_HX(dy, dx) lambda_hx[ND_INDEX(i, dy, dx)]
#define LAMBDA_HZ(dy, dx) lambda_hz[ND_INDEX(i, dy, dx)]

// Material parameter access macros
#define CA(dy, dx) ca_shot[ND_INDEX_J(j, dy, dx)]
#define CB(dy, dx) cb_shot[ND_INDEX_J(j, dy, dx)]
#define CQ(dy, dx) cq_shot[ND_INDEX_J(j, dy, dx)]

// PML memory variable macros
#define M_HX_Z(dy, dx) m_hx_z[ND_INDEX(i, dy, dx)]
#define M_HZ_X(dy, dx) m_hz_x[ND_INDEX(i, dy, dx)]
#define M_EY_X(dy, dx) m_ey_x[ND_INDEX(i, dy, dx)]
#define M_EY_Z(dy, dx) m_ey_z[ND_INDEX(i, dy, dx)]

// Adjoint PML memory variable macros
#define M_LAMBDA_EY_X(dy, dx) m_lambda_ey_x[ND_INDEX(i, dy, dx)]
#define M_LAMBDA_EY_Z(dy, dx) m_lambda_ey_z[ND_INDEX(i, dy, dx)]
#define M_LAMBDA_HX_Z(dy, dx) m_lambda_hx_z[ND_INDEX(i, dy, dx)]
#define M_LAMBDA_HZ_X(dy, dx) m_lambda_hz_x[ND_INDEX(i, dy, dx)]

#define MAX(a, b) (a > b ? a : b)

// Vacuum permittivity (F/m) to convert dL/d(epsilon_abs) -> dL/d(epsilon_r)
#define EP0 ((TIDE_DTYPE)8.8541878128e-12)

namespace {

// Device constants
__constant__ TIDE_DTYPE rdy;
__constant__ TIDE_DTYPE rdx;
__constant__ int64_t n_shots;
__constant__ int64_t ny;
__constant__ int64_t nx;
__constant__ int64_t shot_numel;
__constant__ int64_t n_sources_per_shot;
__constant__ int64_t n_receivers_per_shot;
__constant__ int64_t pml_y0;
__constant__ int64_t pml_y1;
__constant__ int64_t pml_x0;
__constant__ int64_t pml_x1;
__constant__ bool ca_batched;
__constant__ bool cb_batched;
__constant__ bool cq_batched;
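// How these device constants are filled in is not shown in this excerpt; the
// host launcher presumably uses cudaMemcpyToSymbol before launching the
// kernels below. A minimal sketch, excluded from compilation (editorial
// illustration only; the *_h names are made up here, not taken from this file):
#if 0
inline void set_grid_constants_sketch(int64_t ny_h, int64_t nx_h,
                                      int64_t n_shots_h) {
  gpuErrchk(cudaMemcpyToSymbol(ny, &ny_h, sizeof(ny_h)));
  gpuErrchk(cudaMemcpyToSymbol(nx, &nx_h, sizeof(nx_h)));
  gpuErrchk(cudaMemcpyToSymbol(n_shots, &n_shots_h, sizeof(n_shots_h)));
  int64_t shot_numel_h = ny_h * nx_h;  // one shot spans the whole ny*nx grid
  gpuErrchk(cudaMemcpyToSymbol(shot_numel, &shot_numel_h, sizeof(shot_numel_h)));
}
#endif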

// Add source to field
__global__ void add_sources_ey(TIDE_DTYPE *__restrict const ey,
                               TIDE_DTYPE const *__restrict const f,
                               int64_t const *__restrict const sources_i) {
  int64_t source_idx =
      (int64_t)blockIdx.x * (int64_t)blockDim.x + (int64_t)threadIdx.x;
  int64_t shot_idx =
      (int64_t)blockIdx.y * (int64_t)blockDim.y + (int64_t)threadIdx.y;
  if (source_idx < n_sources_per_shot && shot_idx < n_shots) {
    int64_t k = shot_idx * n_sources_per_shot + source_idx;
    int64_t const src = sources_i[k];
    if (0 <= src) {
      ey[shot_idx * shot_numel + src] += f[k];
    }
  }
}

// Add adjoint source at receiver locations (for backward pass)
__global__ void add_adjoint_sources_ey(TIDE_DTYPE *__restrict const ey,
                                       TIDE_DTYPE const *__restrict const f,
                                       int64_t const *__restrict const receivers_i) {
  int64_t receiver_idx =
      (int64_t)blockIdx.x * (int64_t)blockDim.x + (int64_t)threadIdx.x;
  int64_t shot_idx =
      (int64_t)blockIdx.y * (int64_t)blockDim.y + (int64_t)threadIdx.y;
  if (receiver_idx < n_receivers_per_shot && shot_idx < n_shots) {
    int64_t k = shot_idx * n_receivers_per_shot + receiver_idx;
    int64_t const rec = receivers_i[k];
    if (0 <= rec) {
      ey[shot_idx * shot_numel + rec] += f[k];
    }
  }
}

// Record field at receiver locations
__global__ void record_receivers_ey(TIDE_DTYPE *__restrict const r,
                                    TIDE_DTYPE const *__restrict const ey,
                                    int64_t const *__restrict receivers_i) {
  int64_t receiver_idx =
      (int64_t)blockIdx.x * (int64_t)blockDim.x + (int64_t)threadIdx.x;
  int64_t shot_idx =
      (int64_t)blockIdx.y * (int64_t)blockDim.y + (int64_t)threadIdx.y;
  if (receiver_idx < n_receivers_per_shot && shot_idx < n_shots) {
    int64_t k = shot_idx * n_receivers_per_shot + receiver_idx;
    int64_t const rec = receivers_i[k];
    if (0 <= rec) {
      r[k] = ey[shot_idx * shot_numel + rec];
    }
  }
}

// Record adjoint field at source locations (for backward pass - source gradient)
__global__ void record_adjoint_at_sources(TIDE_DTYPE *__restrict const grad_f,
                                          TIDE_DTYPE const *__restrict const lambda_ey,
                                          int64_t const *__restrict sources_i) {
  int64_t source_idx =
      (int64_t)blockIdx.x * (int64_t)blockDim.x + (int64_t)threadIdx.x;
  int64_t shot_idx =
      (int64_t)blockIdx.y * (int64_t)blockDim.y + (int64_t)threadIdx.y;
  if (source_idx < n_sources_per_shot && shot_idx < n_shots) {
    int64_t k = shot_idx * n_sources_per_shot + source_idx;
    int64_t const src = sources_i[k];
    if (0 <= src) {
      grad_f[k] = lambda_ey[shot_idx * shot_numel + src];
    }
  }
}
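// Note on the four kernels above: sources_i / receivers_i hold one flat grid
// index per (shot, source/receiver) pair, laid out as
// k = shot_idx * n_{sources,receivers}_per_shot + local_idx, and a negative
// index marks an inactive entry that is simply skipped.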

// FP8 E4M3 (1 sign, 4 exponent, 3 mantissa) encode/decode.
__device__ __forceinline__ uint8_t fp8_e4m3_from_float(float x) {
#if TIDE_HAVE_CUDA_FP8
  return (uint8_t)__nv_cvt_float_to_fp8(x, __NV_SATFINITE, __NV_E4M3);
#else
  if (x == 0.0f) {
    return 0;
  }
  uint8_t sign = (x < 0.0f) ? 0x80 : 0;
  float ax = fabsf(x);
  if (!isfinite(ax)) {
    return (uint8_t)(sign | 0x7F);
  }
  int exp;
  float m = frexpf(ax, &exp);  // ax = m * 2^exp, m in [0.5, 1)
  int e = exp - 1;
  int exp_field = e + 7;
  int mant = 0;

  if (exp_field <= 0) {
    mant = __float2int_rn(ax * 512.0f);
    if (mant <= 0) {
      return sign;
    }
    if (mant > 7) {
      mant = 7;
    }
    exp_field = 0;
  } else if (exp_field >= 0xF) {
    exp_field = 0xE;
    mant = 7;
  } else {
    float frac = m * 2.0f - 1.0f;
    mant = __float2int_rn(frac * 8.0f);
    if (mant == 8) {
      mant = 0;
      exp_field += 1;
      if (exp_field >= 0xF) {
        exp_field = 0xE;
        mant = 7;
      }
    }
  }

  return (uint8_t)(sign | ((uint8_t)exp_field << 3) | (uint8_t)(mant & 0x7));
#endif
}

__device__ __forceinline__ float fp8_e4m3_to_float(uint8_t v) {
#if TIDE_HAVE_CUDA_FP8
  __half h = __nv_cvt_fp8_to_halfraw((__nv_fp8_storage_t)v, __NV_E4M3);
  return __half2float(h);
#else
  if (v == 0) {
    return 0.0f;
  }
  int sign = v & 0x80;
  int exp_field = (v >> 3) & 0xF;
  int mant = v & 0x7;
  float val;
  if (exp_field == 0) {
    float frac = (float)mant / 8.0f;
    val = ldexpf(frac, -6);
  } else {
    float frac = 1.0f + (float)mant / 8.0f;
    val = ldexpf(frac, exp_field - 7);
  }
  return sign ? -val : val;
#endif
}
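// Worked example for the fallback path above (editorial note): encoding
// x = 0.30f gives exp_field = 5 and mant = 2, i.e. the byte 0x2A, which
// decodes back to 1.25 * 2^(5-7) = 0.3125f -- the quantization error the
// E4M3 snapshots trade for a footprint 4x smaller than FP32.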

// Forward kernel: Update H fields (Hx and Hz)
__global__ __launch_bounds__(256) void forward_kernel_h(
    TIDE_DTYPE const *__restrict const cq,
    TIDE_DTYPE const *__restrict const ey,
    TIDE_DTYPE *__restrict const hx,
    TIDE_DTYPE *__restrict const hz,
    TIDE_DTYPE *__restrict const m_ey_x,
    TIDE_DTYPE *__restrict const m_ey_z,
    TIDE_DTYPE const *__restrict const ay,
    TIDE_DTYPE const *__restrict const ayh,
    TIDE_DTYPE const *__restrict const ax,
    TIDE_DTYPE const *__restrict const axh,
    TIDE_DTYPE const *__restrict const by,
    TIDE_DTYPE const *__restrict const byh,
    TIDE_DTYPE const *__restrict const bx,
    TIDE_DTYPE const *__restrict const bxh,
    TIDE_DTYPE const *__restrict const ky,
    TIDE_DTYPE const *__restrict const kyh,
    TIDE_DTYPE const *__restrict const kx,
    TIDE_DTYPE const *__restrict const kxh) {

#if FD_PAD > 1
  // Shared-memory tiling for Ey stencil loads.
  // Assumes blockDim.z == 1 (one shot per block).
  extern __shared__ TIDE_DTYPE shmem[];
  TIDE_DTYPE *__restrict const tile_ey = shmem;
#endif

  int64_t x = (int64_t)blockIdx.x * (int64_t)blockDim.x +
              (int64_t)threadIdx.x + FD_PAD;
  int64_t y = (int64_t)blockIdx.y * (int64_t)blockDim.y +
              (int64_t)threadIdx.y + FD_PAD;
  int64_t shot_idx = (int64_t)blockIdx.z * (int64_t)blockDim.z +
                     (int64_t)threadIdx.z;

  if (shot_idx >= n_shots) return;

#if FD_PAD > 1
  int64_t const tile_w = (int64_t)blockDim.x + 2 * (int64_t)FD_PAD;
  int64_t const tile_h = (int64_t)blockDim.y + 2 * (int64_t)FD_PAD;
  int64_t const tile_pitch = tile_w;
  int64_t const x0 = (int64_t)blockIdx.x * (int64_t)blockDim.x + FD_PAD;
  int64_t const y0 = (int64_t)blockIdx.y * (int64_t)blockDim.y + FD_PAD;
  int64_t const base = shot_idx * shot_numel;

  int64_t const t = (int64_t)threadIdx.y * (int64_t)blockDim.x +
                    (int64_t)threadIdx.x;
  int64_t const nthreads = (int64_t)blockDim.x * (int64_t)blockDim.y;
  int64_t const tile_numel = tile_w * tile_h;
  // Original scalar loading (optimization 2.1: vectorized loading disabled due to overhead)
  for (int64_t idx = t; idx < tile_numel; idx += nthreads) {
    int64_t const ly = idx / tile_w;
    int64_t const lx = idx - ly * tile_w;
    int64_t const gx = x0 - FD_PAD + lx;
    int64_t const gy = y0 - FD_PAD + ly;
    if (0 <= gx && gx < nx && 0 <= gy && gy < ny) {
      tile_ey[ly * tile_pitch + lx] = __ldg(&ey[base + gy * nx + gx]);
    } else {
      tile_ey[ly * tile_pitch + lx] = (TIDE_DTYPE)0;
    }
  }
  __syncthreads();

#define EY_L(dy, dx) tile_ey[((int64_t)threadIdx.y + (int64_t)FD_PAD + (dy)) * tile_pitch + ((int64_t)threadIdx.x + (int64_t)FD_PAD + (dx))]
#else
#define EY_L(dy, dx) EY(dy, dx)
#endif

  if (y < ny - FD_PAD + 1 && x < nx - FD_PAD + 1 && shot_idx < n_shots) {
    int64_t const pml_y0h = pml_y0;
    int64_t const pml_y1h = MAX(pml_y0, pml_y1 - 1);
    int64_t const pml_x0h = pml_x0;
    int64_t const pml_x1h = MAX(pml_x0, pml_x1 - 1);

    int64_t j = y * nx + x;
    int64_t i = shot_idx * shot_numel + j;

    TIDE_DTYPE const cq_shot_i = cq_batched ? cq[i] : cq[j];

    // Pre-load PML coefficients into registers (optimization 1.2)
    TIDE_DTYPE byh_val = __ldg(&byh[y]);
    TIDE_DTYPE ayh_val = __ldg(&ayh[y]);
    TIDE_DTYPE kyh_val = __ldg(&kyh[y]);
    TIDE_DTYPE bxh_val = __ldg(&bxh[x]);
    TIDE_DTYPE axh_val = __ldg(&axh[x]);
    TIDE_DTYPE kxh_val = __ldg(&kxh[x]);

    // Update Hx: Hx = Hx - cq * dEy/dz
    if (y < ny - FD_PAD) {
      bool pml_y = y < pml_y0h || y >= pml_y1h;

      TIDE_DTYPE dey_dz = DIFFYH1(EY_L);

      if (pml_y) {
        m_ey_z[i] = byh_val * m_ey_z[i] + ayh_val * dey_dz;
        dey_dz = dey_dz / kyh_val + m_ey_z[i];
      }

      hx[i] -= cq_shot_i * dey_dz;
    }

    // Update Hz: Hz = Hz + cq * dEy/dx
    if (x < nx - FD_PAD) {
      bool pml_x = x < pml_x0h || x >= pml_x1h;

      TIDE_DTYPE dey_dx = DIFFXH1(EY_L);

      if (pml_x) {
        m_ey_x[i] = bxh_val * m_ey_x[i] + axh_val * dey_dx;
        dey_dx = dey_dx / kxh_val + m_ey_x[i];
      }

      hz[i] += cq_shot_i * dey_dx;
    }
  }

#undef EY_L
}

// Forward kernel: Update E field (Ey) - standard version
__global__ __launch_bounds__(256) void forward_kernel_e(
    TIDE_DTYPE const *__restrict const ca,
    TIDE_DTYPE const *__restrict const cb,
    TIDE_DTYPE const *__restrict const hx,
    TIDE_DTYPE const *__restrict const hz,
    TIDE_DTYPE *__restrict const ey,
    TIDE_DTYPE *__restrict const m_hx_z,
    TIDE_DTYPE *__restrict const m_hz_x,
    TIDE_DTYPE const *__restrict const ay,
    TIDE_DTYPE const *__restrict const ayh,
    TIDE_DTYPE const *__restrict const ax,
    TIDE_DTYPE const *__restrict const axh,
    TIDE_DTYPE const *__restrict const by,
    TIDE_DTYPE const *__restrict const byh,
    TIDE_DTYPE const *__restrict const bx,
    TIDE_DTYPE const *__restrict const bxh,
    TIDE_DTYPE const *__restrict const ky,
    TIDE_DTYPE const *__restrict const kyh,
    TIDE_DTYPE const *__restrict const kx,
    TIDE_DTYPE const *__restrict const kxh) {

#if FD_PAD > 1
  // Shared-memory tiling for Hx/Hz stencil loads.
  // Assumes blockDim.z == 1 (one shot per block).
  extern __shared__ TIDE_DTYPE shmem[];
  int64_t const tile_w = (int64_t)blockDim.x + 2 * (int64_t)FD_PAD;
  int64_t const tile_h = (int64_t)blockDim.y + 2 * (int64_t)FD_PAD;
  int64_t const tile_pitch = tile_w;
  int64_t const tile_numel = tile_w * tile_h;
  TIDE_DTYPE *__restrict const tile_hx = shmem;
  TIDE_DTYPE *__restrict const tile_hz = shmem + tile_numel;
#endif

  int64_t x = (int64_t)blockIdx.x * (int64_t)blockDim.x +
              (int64_t)threadIdx.x + FD_PAD;
  int64_t y = (int64_t)blockIdx.y * (int64_t)blockDim.y +
              (int64_t)threadIdx.y + FD_PAD;
  int64_t shot_idx = (int64_t)blockIdx.z * (int64_t)blockDim.z +
                     (int64_t)threadIdx.z;

  if (shot_idx >= n_shots) return;

#if FD_PAD > 1
  int64_t const x0 = (int64_t)blockIdx.x * (int64_t)blockDim.x + FD_PAD;
  int64_t const y0 = (int64_t)blockIdx.y * (int64_t)blockDim.y + FD_PAD;
  int64_t const base = shot_idx * shot_numel;
  int64_t const t = (int64_t)threadIdx.y * (int64_t)blockDim.x +
                    (int64_t)threadIdx.x;
  int64_t const nthreads = (int64_t)blockDim.x * (int64_t)blockDim.y;
  // Original scalar loading (optimization 2.1: vectorized loading disabled due to overhead)
  for (int64_t idx = t; idx < tile_numel; idx += nthreads) {
    int64_t const ly = idx / tile_w;
    int64_t const lx = idx - ly * tile_w;
    int64_t const gx = x0 - FD_PAD + lx;
    int64_t const gy = y0 - FD_PAD + ly;
    if (0 <= gx && gx < nx && 0 <= gy && gy < ny) {
      int64_t const g = base + gy * nx + gx;
      int64_t const offset = ly * tile_pitch + lx;
      tile_hx[offset] = __ldg(&hx[g]);
      tile_hz[offset] = __ldg(&hz[g]);
    } else {
      int64_t const offset = ly * tile_pitch + lx;
      tile_hx[offset] = (TIDE_DTYPE)0;
      tile_hz[offset] = (TIDE_DTYPE)0;
    }
  }
  __syncthreads();

#define HX_L(dy, dx) tile_hx[((int64_t)threadIdx.y + (int64_t)FD_PAD + (dy)) * tile_pitch + ((int64_t)threadIdx.x + (int64_t)FD_PAD + (dx))]
#define HZ_L(dy, dx) tile_hz[((int64_t)threadIdx.y + (int64_t)FD_PAD + (dy)) * tile_pitch + ((int64_t)threadIdx.x + (int64_t)FD_PAD + (dx))]
#else
#define HX_L(dy, dx) HX(dy, dx)
#define HZ_L(dy, dx) HZ(dy, dx)
#endif

  if (y < ny - FD_PAD + 1 && x < nx - FD_PAD + 1 && shot_idx < n_shots) {
    int64_t j = y * nx + x;
    int64_t i = shot_idx * shot_numel + j;

    TIDE_DTYPE const ca_shot_i = ca_batched ? ca[i] : ca[j];
    TIDE_DTYPE const cb_shot_i = cb_batched ? cb[i] : cb[j];

    bool pml_y = y < pml_y0 || y >= pml_y1;
    bool pml_x = x < pml_x0 || x >= pml_x1;

    TIDE_DTYPE dhz_dx = DIFFX1(HZ_L);
    TIDE_DTYPE dhx_dz = DIFFY1(HX_L);

    // Pre-load PML coefficients into registers (optimization 1.2)
    TIDE_DTYPE bx_val = __ldg(&bx[x]);
    TIDE_DTYPE ax_val = __ldg(&ax[x]);
    TIDE_DTYPE kx_val = __ldg(&kx[x]);
    TIDE_DTYPE by_val = __ldg(&by[y]);
    TIDE_DTYPE ay_val = __ldg(&ay[y]);
    TIDE_DTYPE ky_val = __ldg(&ky[y]);

    if (pml_x) {
      m_hz_x[i] = bx_val * m_hz_x[i] + ax_val * dhz_dx;
      dhz_dx = dhz_dx / kx_val + m_hz_x[i];
    }

    if (pml_y) {
      m_hx_z[i] = by_val * m_hx_z[i] + ay_val * dhx_dz;
      dhx_dz = dhx_dz / ky_val + m_hx_z[i];
    }

    ey[i] = ca_shot_i * ey[i] + cb_shot_i * (dhz_dx - dhx_dz);
  }

#undef HX_L
#undef HZ_L
}
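// The host-side launch configuration for the kernels above is outside this
// excerpt. A plausible geometry consistent with their indexing, excluded from
// compilation and given as an editorial sketch only (names ending in _h are
// hypothetical, and the in-kernel bounds checks absorb any block overshoot):
#if 0
inline void forward_launch_geometry_sketch(int64_t ny_h, int64_t nx_h,
                                           int64_t n_shots_h) {
  dim3 block(32, 8, 1);  // blockDim.z == 1, as the tiling code above assumes
  dim3 grid((unsigned)((nx_h - 2 * FD_PAD) / block.x + 1),
            (unsigned)((ny_h - 2 * FD_PAD) / block.y + 1),
            (unsigned)n_shots_h);
  // One Ey tile for forward_kernel_h; the E-field kernels allocate two tiles.
  size_t shmem_bytes =
      (block.x + 2 * FD_PAD) * (block.y + 2 * FD_PAD) * sizeof(TIDE_DTYPE);
  // forward_kernel_h<<<grid, block, shmem_bytes, stream>>>(...);
  (void)grid;
  (void)shmem_bytes;
}
#endif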

// Forward kernel: Update E field (Ey) with storage for gradient computation
__global__ void forward_kernel_e_with_storage(
    TIDE_DTYPE const *__restrict const ca,
    TIDE_DTYPE const *__restrict const cb,
    TIDE_DTYPE const *__restrict const hx,
    TIDE_DTYPE const *__restrict const hz,
    TIDE_DTYPE *__restrict const ey,
    TIDE_DTYPE *__restrict const m_hx_z,
    TIDE_DTYPE *__restrict const m_hz_x,
    TIDE_DTYPE *__restrict const ey_store,      // Can be NULL
    TIDE_DTYPE *__restrict const curl_h_store,  // Can be NULL
    TIDE_DTYPE const *__restrict const ay,
    TIDE_DTYPE const *__restrict const ayh,
    TIDE_DTYPE const *__restrict const ax,
    TIDE_DTYPE const *__restrict const axh,
    TIDE_DTYPE const *__restrict const by,
    TIDE_DTYPE const *__restrict const byh,
    TIDE_DTYPE const *__restrict const bx,
    TIDE_DTYPE const *__restrict const bxh,
    TIDE_DTYPE const *__restrict const ky,
    TIDE_DTYPE const *__restrict const kyh,
    TIDE_DTYPE const *__restrict const kx,
    TIDE_DTYPE const *__restrict const kxh,
    bool const ca_requires_grad,
    bool const cb_requires_grad) {

#if FD_PAD > 1
  // Shared-memory tiling for Hx/Hz stencil loads.
  // Assumes blockDim.z == 1 (one shot per block).
  extern __shared__ TIDE_DTYPE shmem[];
  int64_t const tile_w = (int64_t)blockDim.x + 2 * (int64_t)FD_PAD;
  int64_t const tile_h = (int64_t)blockDim.y + 2 * (int64_t)FD_PAD;
  int64_t const tile_pitch = tile_w;
  int64_t const tile_numel = tile_w * tile_h;
  TIDE_DTYPE *__restrict const tile_hx = shmem;
  TIDE_DTYPE *__restrict const tile_hz = shmem + tile_numel;
#endif

  int64_t x = (int64_t)blockIdx.x * (int64_t)blockDim.x +
              (int64_t)threadIdx.x + FD_PAD;
  int64_t y = (int64_t)blockIdx.y * (int64_t)blockDim.y +
              (int64_t)threadIdx.y + FD_PAD;
  int64_t shot_idx = (int64_t)blockIdx.z * (int64_t)blockDim.z +
                     (int64_t)threadIdx.z;

  if (shot_idx >= n_shots) return;

#if FD_PAD > 1
  int64_t const x0 = (int64_t)blockIdx.x * (int64_t)blockDim.x + FD_PAD;
  int64_t const y0 = (int64_t)blockIdx.y * (int64_t)blockDim.y + FD_PAD;
  int64_t const base = shot_idx * shot_numel;
  int64_t const t = (int64_t)threadIdx.y * (int64_t)blockDim.x +
                    (int64_t)threadIdx.x;
  int64_t const nthreads = (int64_t)blockDim.x * (int64_t)blockDim.y;
  // Original scalar loading (optimization 2.1: vectorized loading disabled due to overhead)
  for (int64_t idx = t; idx < tile_numel; idx += nthreads) {
    int64_t const ly = idx / tile_w;
    int64_t const lx = idx - ly * tile_w;
    int64_t const gx = x0 - FD_PAD + lx;
    int64_t const gy = y0 - FD_PAD + ly;
    if (0 <= gx && gx < nx && 0 <= gy && gy < ny) {
      int64_t const g = base + gy * nx + gx;
      int64_t const offset = ly * tile_pitch + lx;
      tile_hx[offset] = __ldg(&hx[g]);
      tile_hz[offset] = __ldg(&hz[g]);
    } else {
      int64_t const offset = ly * tile_pitch + lx;
      tile_hx[offset] = (TIDE_DTYPE)0;
      tile_hz[offset] = (TIDE_DTYPE)0;
    }
  }
  __syncthreads();

#define HX_L(dy, dx) tile_hx[((int64_t)threadIdx.y + (int64_t)FD_PAD + (dy)) * tile_pitch + ((int64_t)threadIdx.x + (int64_t)FD_PAD + (dx))]
#define HZ_L(dy, dx) tile_hz[((int64_t)threadIdx.y + (int64_t)FD_PAD + (dy)) * tile_pitch + ((int64_t)threadIdx.x + (int64_t)FD_PAD + (dx))]
#else
#define HX_L(dy, dx) HX(dy, dx)
#define HZ_L(dy, dx) HZ(dy, dx)
#endif

  if (y < ny - FD_PAD + 1 && x < nx - FD_PAD + 1 && shot_idx < n_shots) {
    int64_t j = y * nx + x;
    int64_t i = shot_idx * shot_numel + j;

    TIDE_DTYPE const ca_shot_i = ca_batched ? ca[i] : ca[j];
    TIDE_DTYPE const cb_shot_i = cb_batched ? cb[i] : cb[j];

    bool pml_y = y < pml_y0 || y >= pml_y1;
    bool pml_x = x < pml_x0 || x >= pml_x1;

    TIDE_DTYPE dhz_dx = DIFFX1(HZ_L);
    TIDE_DTYPE dhx_dz = DIFFY1(HX_L);

    // Pre-load PML coefficients into registers (optimization 1.2)
    TIDE_DTYPE bx_val = __ldg(&bx[x]);
    TIDE_DTYPE ax_val = __ldg(&ax[x]);
    TIDE_DTYPE kx_val = __ldg(&kx[x]);
    TIDE_DTYPE by_val = __ldg(&by[y]);
    TIDE_DTYPE ay_val = __ldg(&ay[y]);
    TIDE_DTYPE ky_val = __ldg(&ky[y]);

    if (pml_x) {
      m_hz_x[i] = bx_val * m_hz_x[i] + ax_val * dhz_dx;
      dhz_dx = dhz_dx / kx_val + m_hz_x[i];
    }

    if (pml_y) {
      m_hx_z[i] = by_val * m_hx_z[i] + ay_val * dhx_dz;
      dhx_dz = dhx_dz / ky_val + m_hx_z[i];
    }

    TIDE_DTYPE curl_h = dhz_dx - dhx_dz;

    // Store values for gradient computation (before E update)
    if (ca_requires_grad && ey_store != nullptr) {
      ey_store[i] = ey[i];
    }
    if (cb_requires_grad && curl_h_store != nullptr) {
      curl_h_store[i] = curl_h;
    }

    ey[i] = ca_shot_i * ey[i] + cb_shot_i * curl_h;
  }

#undef HX_L
#undef HZ_L
}

// Forward kernel: Update E field (Ey) with BF16 storage for gradient computation
// Stores Ey and curl_H in __nv_bfloat16 to reduce snapshot bandwidth/size.
__global__ void forward_kernel_e_with_storage_bf16(
    TIDE_DTYPE const *__restrict const ca,
    TIDE_DTYPE const *__restrict const cb,
    TIDE_DTYPE const *__restrict const hx,
    TIDE_DTYPE const *__restrict const hz,
    TIDE_DTYPE *__restrict const ey,
    TIDE_DTYPE *__restrict const m_hx_z,
    TIDE_DTYPE *__restrict const m_hz_x,
    __nv_bfloat16 *__restrict const ey_store,      // Can be NULL
    __nv_bfloat16 *__restrict const curl_h_store,  // Can be NULL
    TIDE_DTYPE const *__restrict const ay,
    TIDE_DTYPE const *__restrict const ayh,
    TIDE_DTYPE const *__restrict const ax,
    TIDE_DTYPE const *__restrict const axh,
    TIDE_DTYPE const *__restrict const by,
    TIDE_DTYPE const *__restrict const byh,
    TIDE_DTYPE const *__restrict const bx,
    TIDE_DTYPE const *__restrict const bxh,
    TIDE_DTYPE const *__restrict const ky,
    TIDE_DTYPE const *__restrict const kyh,
    TIDE_DTYPE const *__restrict const kx,
    TIDE_DTYPE const *__restrict const kxh,
    bool const ca_requires_grad,
    bool const cb_requires_grad) {

#if FD_PAD > 1
  // Shared-memory tiling for Hx/Hz stencil loads.
  // Assumes blockDim.z == 1 (one shot per block).
  extern __shared__ TIDE_DTYPE shmem[];
  int64_t const tile_w = (int64_t)blockDim.x + 2 * (int64_t)FD_PAD;
  int64_t const tile_h = (int64_t)blockDim.y + 2 * (int64_t)FD_PAD;
  int64_t const tile_pitch = tile_w;
  int64_t const tile_numel = tile_w * tile_h;
  TIDE_DTYPE *__restrict const tile_hx = shmem;
  TIDE_DTYPE *__restrict const tile_hz = shmem + tile_numel;
#endif

  int64_t x = (int64_t)blockIdx.x * (int64_t)blockDim.x +
              (int64_t)threadIdx.x + FD_PAD;
  int64_t y = (int64_t)blockIdx.y * (int64_t)blockDim.y +
              (int64_t)threadIdx.y + FD_PAD;
  int64_t shot_idx = (int64_t)blockIdx.z * (int64_t)blockDim.z +
                     (int64_t)threadIdx.z;

  if (shot_idx >= n_shots) return;

#if FD_PAD > 1
  int64_t const x0 = (int64_t)blockIdx.x * (int64_t)blockDim.x + FD_PAD;
  int64_t const y0 = (int64_t)blockIdx.y * (int64_t)blockDim.y + FD_PAD;
  int64_t const base = shot_idx * shot_numel;
  int64_t const t = (int64_t)threadIdx.y * (int64_t)blockDim.x +
                    (int64_t)threadIdx.x;
  int64_t const nthreads = (int64_t)blockDim.x * (int64_t)blockDim.y;
  // Original scalar loading (optimization 2.1: vectorized loading disabled due to overhead)
  for (int64_t idx = t; idx < tile_numel; idx += nthreads) {
    int64_t const ly = idx / tile_w;
    int64_t const lx = idx - ly * tile_w;
    int64_t const gx = x0 - FD_PAD + lx;
    int64_t const gy = y0 - FD_PAD + ly;
    if (0 <= gx && gx < nx && 0 <= gy && gy < ny) {
      int64_t const g = base + gy * nx + gx;
      int64_t const offset = ly * tile_pitch + lx;
      tile_hx[offset] = __ldg(&hx[g]);
      tile_hz[offset] = __ldg(&hz[g]);
    } else {
      int64_t const offset = ly * tile_pitch + lx;
      tile_hx[offset] = (TIDE_DTYPE)0;
      tile_hz[offset] = (TIDE_DTYPE)0;
    }
  }
  __syncthreads();

#define HX_L(dy, dx) tile_hx[((int64_t)threadIdx.y + (int64_t)FD_PAD + (dy)) * tile_pitch + ((int64_t)threadIdx.x + (int64_t)FD_PAD + (dx))]
#define HZ_L(dy, dx) tile_hz[((int64_t)threadIdx.y + (int64_t)FD_PAD + (dy)) * tile_pitch + ((int64_t)threadIdx.x + (int64_t)FD_PAD + (dx))]
#else
#define HX_L(dy, dx) HX(dy, dx)
#define HZ_L(dy, dx) HZ(dy, dx)
#endif

  if (y < ny - FD_PAD + 1 && x < nx - FD_PAD + 1 && shot_idx < n_shots) {
    int64_t j = y * nx + x;
    int64_t i = shot_idx * shot_numel + j;

    TIDE_DTYPE const ca_shot_i = ca_batched ? ca[i] : ca[j];
    TIDE_DTYPE const cb_shot_i = cb_batched ? cb[i] : cb[j];

    bool pml_y = y < pml_y0 || y >= pml_y1;
    bool pml_x = x < pml_x0 || x >= pml_x1;

    TIDE_DTYPE dhz_dx = DIFFX1(HZ_L);
    TIDE_DTYPE dhx_dz = DIFFY1(HX_L);

    // Pre-load PML coefficients into registers (optimization 1.2)
    TIDE_DTYPE bx_val = __ldg(&bx[x]);
    TIDE_DTYPE ax_val = __ldg(&ax[x]);
    TIDE_DTYPE kx_val = __ldg(&kx[x]);
    TIDE_DTYPE by_val = __ldg(&by[y]);
    TIDE_DTYPE ay_val = __ldg(&ay[y]);
    TIDE_DTYPE ky_val = __ldg(&ky[y]);

    if (pml_x) {
      m_hz_x[i] = bx_val * m_hz_x[i] + ax_val * dhz_dx;
      dhz_dx = dhz_dx / kx_val + m_hz_x[i];
    }

    if (pml_y) {
      m_hx_z[i] = by_val * m_hx_z[i] + ay_val * dhx_dz;
      dhx_dz = dhx_dz / ky_val + m_hx_z[i];
    }

    TIDE_DTYPE curl_h = dhz_dx - dhx_dz;

    if (ca_requires_grad && ey_store != nullptr) {
      ey_store[i] = __float2bfloat16((float)ey[i]);
    }
    if (cb_requires_grad && curl_h_store != nullptr) {
      curl_h_store[i] = __float2bfloat16((float)curl_h);
    }

    ey[i] = ca_shot_i * ey[i] + cb_shot_i * curl_h;
  }

#undef HX_L
#undef HZ_L
}

// Forward kernel: Update E field (Ey) with FP8 storage for gradient computation.
__global__ void forward_kernel_e_with_storage_fp8(
    TIDE_DTYPE const *__restrict const ca,
    TIDE_DTYPE const *__restrict const cb,
    TIDE_DTYPE const *__restrict const hx,
    TIDE_DTYPE const *__restrict const hz,
    TIDE_DTYPE *__restrict const ey,
    TIDE_DTYPE *__restrict const m_hx_z,
    TIDE_DTYPE *__restrict const m_hz_x,
    uint8_t *__restrict const ey_store,      // Can be NULL
    uint8_t *__restrict const curl_h_store,  // Can be NULL
    TIDE_DTYPE const *__restrict const ay,
    TIDE_DTYPE const *__restrict const ayh,
    TIDE_DTYPE const *__restrict const ax,
    TIDE_DTYPE const *__restrict const axh,
    TIDE_DTYPE const *__restrict const by,
    TIDE_DTYPE const *__restrict const byh,
    TIDE_DTYPE const *__restrict const bx,
    TIDE_DTYPE const *__restrict const bxh,
    TIDE_DTYPE const *__restrict const ky,
    TIDE_DTYPE const *__restrict const kyh,
    TIDE_DTYPE const *__restrict const kx,
    TIDE_DTYPE const *__restrict const kxh,
    bool const ca_requires_grad,
    bool const cb_requires_grad) {

#if FD_PAD > 1
  // Shared-memory tiling for Hx/Hz stencil loads.
  // Assumes blockDim.z == 1 (one shot per block).
  extern __shared__ TIDE_DTYPE shmem[];
  int64_t const tile_w = (int64_t)blockDim.x + 2 * (int64_t)FD_PAD;
  int64_t const tile_h = (int64_t)blockDim.y + 2 * (int64_t)FD_PAD;
  int64_t const tile_pitch = tile_w;
  int64_t const tile_numel = tile_w * tile_h;
  TIDE_DTYPE *__restrict const tile_hx = shmem;
  TIDE_DTYPE *__restrict const tile_hz = shmem + tile_numel;
#endif

  int64_t x = (int64_t)blockIdx.x * (int64_t)blockDim.x +
              (int64_t)threadIdx.x + FD_PAD;
  int64_t y = (int64_t)blockIdx.y * (int64_t)blockDim.y +
              (int64_t)threadIdx.y + FD_PAD;
  int64_t shot_idx = (int64_t)blockIdx.z * (int64_t)blockDim.z +
                     (int64_t)threadIdx.z;

  if (shot_idx >= n_shots) return;

#if FD_PAD > 1
  int64_t const x0 = (int64_t)blockIdx.x * (int64_t)blockDim.x + FD_PAD;
  int64_t const y0 = (int64_t)blockIdx.y * (int64_t)blockDim.y + FD_PAD;
  int64_t const base = shot_idx * shot_numel;
  int64_t const t = (int64_t)threadIdx.y * (int64_t)blockDim.x +
                    (int64_t)threadIdx.x;
  int64_t const nthreads = (int64_t)blockDim.x * (int64_t)blockDim.y;
  for (int64_t idx = t; idx < tile_numel; idx += nthreads) {
    int64_t const ly = idx / tile_w;
    int64_t const lx = idx - ly * tile_w;
    int64_t const gx = x0 - FD_PAD + lx;
    int64_t const gy = y0 - FD_PAD + ly;
    if (0 <= gx && gx < nx && 0 <= gy && gy < ny) {
      int64_t const g = base + gy * nx + gx;
      int64_t const offset = ly * tile_pitch + lx;
      tile_hx[offset] = __ldg(&hx[g]);
      tile_hz[offset] = __ldg(&hz[g]);
    } else {
      int64_t const offset = ly * tile_pitch + lx;
      tile_hx[offset] = (TIDE_DTYPE)0;
      tile_hz[offset] = (TIDE_DTYPE)0;
    }
  }
  __syncthreads();

#define HX_L(dy, dx) tile_hx[((int64_t)threadIdx.y + (int64_t)FD_PAD + (dy)) * tile_pitch + ((int64_t)threadIdx.x + (int64_t)FD_PAD + (dx))]
#define HZ_L(dy, dx) tile_hz[((int64_t)threadIdx.y + (int64_t)FD_PAD + (dy)) * tile_pitch + ((int64_t)threadIdx.x + (int64_t)FD_PAD + (dx))]
#else
#define HX_L(dy, dx) HX(dy, dx)
#define HZ_L(dy, dx) HZ(dy, dx)
#endif

  if (y < ny - FD_PAD + 1 && x < nx - FD_PAD + 1 && shot_idx < n_shots) {
    int64_t j = y * nx + x;
    int64_t i = shot_idx * shot_numel + j;

    TIDE_DTYPE const ca_shot_i = ca_batched ? ca[i] : ca[j];
    TIDE_DTYPE const cb_shot_i = cb_batched ? cb[i] : cb[j];

    bool pml_y = y < pml_y0 || y >= pml_y1;
    bool pml_x = x < pml_x0 || x >= pml_x1;

    TIDE_DTYPE dhz_dx = DIFFX1(HZ_L);
    TIDE_DTYPE dhx_dz = DIFFY1(HX_L);

    // Pre-load PML coefficients into registers (optimization 1.2)
    TIDE_DTYPE bx_val = __ldg(&bx[x]);
    TIDE_DTYPE ax_val = __ldg(&ax[x]);
    TIDE_DTYPE kx_val = __ldg(&kx[x]);
    TIDE_DTYPE by_val = __ldg(&by[y]);
    TIDE_DTYPE ay_val = __ldg(&ay[y]);
    TIDE_DTYPE ky_val = __ldg(&ky[y]);

    if (pml_x) {
      m_hz_x[i] = bx_val * m_hz_x[i] + ax_val * dhz_dx;
      dhz_dx = dhz_dx / kx_val + m_hz_x[i];
    }

    if (pml_y) {
      m_hx_z[i] = by_val * m_hx_z[i] + ay_val * dhx_dz;
      dhx_dz = dhx_dz / ky_val + m_hx_z[i];
    }

    TIDE_DTYPE curl_h = dhz_dx - dhx_dz;

    if (ca_requires_grad && ey_store != nullptr) {
      ey_store[i] = fp8_e4m3_from_float((float)ey[i]);
    }
    if (cb_requires_grad && curl_h_store != nullptr) {
      curl_h_store[i] = fp8_e4m3_from_float((float)curl_h);
    }

    ey[i] = ca_shot_i * ey[i] + cb_shot_i * curl_h;
  }

#undef HX_L
#undef HZ_L
}
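// Snapshot precision trade-off across the three storage variants above:
// FP32 keeps full precision at 4 bytes per stored value, BF16 halves that,
// and FP8 E4M3 quarters it; the cheaper formats shrink snapshot memory and
// copy bandwidth at the cost of quantization error in the Ca/Cb gradients.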
|
|
892
|
+
|
|
893
|
+
// Backward kernel: Update adjoint λ_H fields
|
|
894
|
+
__global__ void backward_kernel_lambda_h(
|
|
895
|
+
TIDE_DTYPE const *__restrict const cb,
|
|
896
|
+
TIDE_DTYPE const *__restrict const lambda_ey,
|
|
897
|
+
TIDE_DTYPE *__restrict const lambda_hx,
|
|
898
|
+
TIDE_DTYPE *__restrict const lambda_hz,
|
|
899
|
+
TIDE_DTYPE *__restrict const m_lambda_ey_x,
|
|
900
|
+
TIDE_DTYPE *__restrict const m_lambda_ey_z,
|
|
901
|
+
TIDE_DTYPE const *__restrict const ay,
|
|
902
|
+
TIDE_DTYPE const *__restrict const ayh,
|
|
903
|
+
TIDE_DTYPE const *__restrict const ax,
|
|
904
|
+
TIDE_DTYPE const *__restrict const axh,
|
|
905
|
+
TIDE_DTYPE const *__restrict const by,
|
|
906
|
+
TIDE_DTYPE const *__restrict const byh,
|
|
907
|
+
TIDE_DTYPE const *__restrict const bx,
|
|
908
|
+
TIDE_DTYPE const *__restrict const bxh,
|
|
909
|
+
TIDE_DTYPE const *__restrict const ky,
|
|
910
|
+
TIDE_DTYPE const *__restrict const kyh,
|
|
911
|
+
TIDE_DTYPE const *__restrict const kx,
|
|
912
|
+
TIDE_DTYPE const *__restrict const kxh) {
|
|
913
|
+
|
|
914
|
+
int64_t x = (int64_t)blockIdx.x * (int64_t)blockDim.x +
|
|
915
|
+
(int64_t)threadIdx.x + FD_PAD;
|
|
916
|
+
int64_t y = (int64_t)blockIdx.y * (int64_t)blockDim.y +
|
|
917
|
+
(int64_t)threadIdx.y + FD_PAD;
|
|
918
|
+
int64_t shot_idx = (int64_t)blockIdx.z * (int64_t)blockDim.z +
|
|
919
|
+
(int64_t)threadIdx.z;
|
|
920
|
+
|
|
921
|
+
if (y < ny - FD_PAD + 1 && x < nx - FD_PAD + 1 && shot_idx < n_shots) {
|
|
922
|
+
int64_t const pml_y0h = pml_y0;
|
|
923
|
+
int64_t const pml_y1h = MAX(pml_y0, pml_y1 - 1);
|
|
924
|
+
int64_t const pml_x0h = pml_x0;
|
|
925
|
+
int64_t const pml_x1h = MAX(pml_x0, pml_x1 - 1);
|
|
926
|
+
|
|
927
|
+
int64_t j = y * nx + x;
|
|
928
|
+
int64_t i = shot_idx * shot_numel + j;
|
|
929
|
+
|
|
930
|
+
TIDE_DTYPE const cb_shot_i = cb_batched ? cb[i] : cb[j];
|
|
931
|
+
|
|
932
|
+
// Update λ_Hx: λ_Hx = λ_Hx - cb * D_z^T[λ_Ey]
|
|
933
|
+
// EXACT ADJOINT: use transpose of DIFFYH1 -> which is DIFFY1
|
|
934
|
+
if (y < ny - FD_PAD) {
|
|
935
|
+
bool pml_y = y < pml_y0h || y >= pml_y1h;
|
|
936
|
+
|
|
937
|
+
TIDE_DTYPE d_lambda_ey_dz = DIFFY1(LAMBDA_EY);
|
|
938
|
+
|
|
939
|
+
if (pml_y) {
|
|
940
|
+
m_lambda_ey_z[i] = __ldg(&byh[y]) * m_lambda_ey_z[i] + __ldg(&ayh[y]) * d_lambda_ey_dz;
|
|
941
|
+
d_lambda_ey_dz = d_lambda_ey_dz / __ldg(&kyh[y]) + m_lambda_ey_z[i];
|
|
942
|
+
}
|
|
943
|
+
|
|
944
|
+
lambda_hx[i] -= cb_shot_i * d_lambda_ey_dz;
|
|
945
|
+
}
|
|
946
|
+
|
|
947
|
+
// Update λ_Hz: λ_Hz = λ_Hz + cb * D_x^T[λ_Ey]
|
|
948
|
+
// EXACT ADJOINT: use transpose of DIFFXH1 -> which is DIFFX1
|
|
949
|
+
if (x < nx - FD_PAD) {
|
|
950
|
+
bool pml_x = x < pml_x0h || x >= pml_x1h;
|
|
951
|
+
|
|
952
|
+
TIDE_DTYPE d_lambda_ey_dx = DIFFX1(LAMBDA_EY);
|
|
953
|
+
|
|
954
|
+
if (pml_x) {
|
|
955
|
+
m_lambda_ey_x[i] = __ldg(&bxh[x]) * m_lambda_ey_x[i] + __ldg(&axh[x]) * d_lambda_ey_dx;
|
|
956
|
+
d_lambda_ey_dx = d_lambda_ey_dx / __ldg(&kxh[x]) + m_lambda_ey_x[i];
|
|
957
|
+
}
|
|
958
|
+
|
|
959
|
+
lambda_hz[i] += cb_shot_i * d_lambda_ey_dx;
|
|
960
|
+
}
|
|
961
|
+
}
|
|
962
|
+
}
|
|
963
|
+
|
|
964
|
+
// Backward kernel: Update adjoint λ_Ey field with per-shot gradient accumulation
|
|
965
|
+
// Uses pml_y0/pml_y1/pml_x0/pml_x1 for both adjoint propagation and gradient masking
|
|
966
|
+
// NO atomicAdd - each shot writes to its own memory region
|
|
967
|
+
__global__ void backward_kernel_lambda_e_with_grad(
|
|
968
|
+
TIDE_DTYPE const *__restrict const ca,
|
|
969
|
+
TIDE_DTYPE const *__restrict const cq,
|
|
970
|
+
TIDE_DTYPE const *__restrict const lambda_hx,
|
|
971
|
+
TIDE_DTYPE const *__restrict const lambda_hz,
|
|
972
|
+
TIDE_DTYPE *__restrict const lambda_ey,
|
|
973
|
+
TIDE_DTYPE *__restrict const m_lambda_hx_z,
|
|
974
|
+
TIDE_DTYPE *__restrict const m_lambda_hz_x,
|
|
975
|
+
TIDE_DTYPE const *__restrict const ey_store,
|
|
976
|
+
TIDE_DTYPE const *__restrict const curl_h_store,
|
|
977
|
+
TIDE_DTYPE *__restrict const grad_ca_shot, // [n_shots, ny, nx] - per-shot gradient
|
|
978
|
+
TIDE_DTYPE *__restrict const grad_cb_shot, // [n_shots, ny, nx] - per-shot gradient
|
|
979
|
+
TIDE_DTYPE const *__restrict const ay,
|
|
980
|
+
TIDE_DTYPE const *__restrict const ayh,
|
|
981
|
+
TIDE_DTYPE const *__restrict const ax,
|
|
982
|
+
TIDE_DTYPE const *__restrict const axh,
|
|
983
|
+
TIDE_DTYPE const *__restrict const by,
|
|
984
|
+
TIDE_DTYPE const *__restrict const byh,
|
|
985
|
+
TIDE_DTYPE const *__restrict const bx,
|
|
986
|
+
TIDE_DTYPE const *__restrict const bxh,
|
|
987
|
+
TIDE_DTYPE const *__restrict const ky,
|
|
988
|
+
TIDE_DTYPE const *__restrict const kyh,
|
|
989
|
+
TIDE_DTYPE const *__restrict const kx,
|
|
990
|
+
TIDE_DTYPE const *__restrict const kxh,
|
|
991
|
+
bool const ca_requires_grad,
|
|
992
|
+
bool const cb_requires_grad,
|
|
993
|
+
int64_t const step_ratio_val) {
|
|
994
|
+
|
|
995
|
+
int64_t x = (int64_t)blockIdx.x * (int64_t)blockDim.x +
|
|
996
|
+
(int64_t)threadIdx.x + FD_PAD;
|
|
997
|
+
int64_t y = (int64_t)blockIdx.y * (int64_t)blockDim.y +
|
|
998
|
+
(int64_t)threadIdx.y + FD_PAD;
|
|
999
|
+
int64_t shot_idx = (int64_t)blockIdx.z * (int64_t)blockDim.z +
|
|
1000
|
+
(int64_t)threadIdx.z;
|
|
1001
|
+
|
|
1002
|
+
if (y < ny - FD_PAD + 1 && x < nx - FD_PAD + 1 && shot_idx < n_shots) {
|
|
1003
|
+
int64_t j = y * nx + x;
|
|
1004
|
+
int64_t i = shot_idx * shot_numel + j;
|
|
1005
|
+
|
|
1006
|
+
TIDE_DTYPE const ca_shot_i = ca_batched ? ca[i] : ca[j];
|
|
1007
|
+
TIDE_DTYPE const cq_shot_i = cq_batched ? cq[i] : cq[j];
|
|
1008
|
+
|
|
1009
|
+
// Determine PML region (pml_y/pml_x = true means in PML region)
|
|
1010
|
+
bool pml_y = y < pml_y0 || y >= pml_y1;
|
|
1011
|
+
bool pml_x = x < pml_x0 || x >= pml_x1;
|
|
1012
|
+
|
|
1013
|
+
// Compute D_x^{hT}[λ_Hz] at integer grid points
|
|
1014
|
+
// EXACT ADJOINT: use transpose of DIFFX1 -> which is DIFFXH1
|
|
1015
|
+
TIDE_DTYPE d_lambda_hz_dx = DIFFXH1(LAMBDA_HZ);
|
|
1016
|
+
// Compute D_z^{hT}[λ_Hx] at integer grid points
|
|
1017
|
+
// EXACT ADJOINT: use transpose of DIFFY1 -> which is DIFFYH1
|
|
1018
|
+
TIDE_DTYPE d_lambda_hx_dz = DIFFYH1(LAMBDA_HX);
|
|
1019
|
+
|
|
1020
|
+
// Pre-load PML coefficients into registers (optimization 1.2)
|
|
1021
|
+
TIDE_DTYPE bx_val = __ldg(&bx[x]);
|
|
1022
|
+
TIDE_DTYPE ax_val = __ldg(&ax[x]);
|
|
1023
|
+
TIDE_DTYPE kx_val = __ldg(&kx[x]);
|
|
1024
|
+
TIDE_DTYPE by_val = __ldg(&by[y]);
|
|
1025
|
+
TIDE_DTYPE ay_val = __ldg(&ay[y]);
|
|
1026
|
+
TIDE_DTYPE ky_val = __ldg(&ky[y]);
|
|
1027
|
+
|
|
1028
|
+
// Apply adjoint CPML for d(λ_Hz)/dx (only in PML region)
|
|
1029
|
+
if (pml_x) {
|
|
1030
|
+
m_lambda_hz_x[i] = bx_val * m_lambda_hz_x[i] + ax_val * d_lambda_hz_dx;
|
|
1031
|
+
d_lambda_hz_dx = d_lambda_hz_dx / kx_val + m_lambda_hz_x[i];
|
|
1032
|
+
}
|
|
1033
|
+
|
|
1034
|
+
// Apply adjoint CPML for d(λ_Hx)/dz (only in PML region)
|
|
1035
|
+
if (pml_y) {
|
|
1036
|
+
m_lambda_hx_z[i] = by_val * m_lambda_hx_z[i] + ay_val * d_lambda_hx_dz;
|
|
1037
|
+
d_lambda_hx_dz = d_lambda_hx_dz / ky_val + m_lambda_hx_z[i];
|
|
1038
|
+
}
|
|
1039
|
+
|
|
1040
|
+
// curl_λH = d(λ_Hz)/dx - d(λ_Hx)/dz
|
|
1041
|
+
TIDE_DTYPE curl_lambda_h = d_lambda_hz_dx - d_lambda_hx_dz;
|
|
1042
|
+
|
|
1043
|
+
// Store current λ_Ey before update (this is λ_Ey^{n+1})
|
|
1044
|
+
TIDE_DTYPE lambda_ey_curr = lambda_ey[i];
|
|
1045
|
+
|
|
1046
|
+
// Update λ_Ey: λ_Ey^n = C_a * λ_Ey^{n+1} + C_q * curl_λH
|
|
1047
|
+
lambda_ey[i] = ca_shot_i * lambda_ey_curr + cq_shot_i * curl_lambda_h;
|
|
1048
|
+
|
|
1049
|
+
// Accumulate per-shot gradients only in interior region (!pml_y && !pml_x)
|
|
1050
|
+
if (!pml_y && !pml_x) {
|
|
1051
|
+
// grad_ca_shot[shot_idx, y, x] += λ_Ey^{n+1} * E_y^n
|
|
1052
|
+
// Convert from BF16 back to FP32 for computation
|
|
1053
|
+
if (ca_requires_grad && ey_store != nullptr) {
|
|
1054
|
+
TIDE_DTYPE ey_n = ey_store[i];
|
|
1055
|
+
grad_ca_shot[i] += lambda_ey_curr * ey_n * (TIDE_DTYPE)step_ratio_val;
|
|
1056
|
+
}
|
|
1057
|
+
|
|
1058
|
+
// grad_cb_shot[shot_idx, y, x] += λ_Ey^{n+1} * curl_H^n
|
|
1059
|
+
if (cb_requires_grad && curl_h_store != nullptr) {
|
|
1060
|
+
TIDE_DTYPE curl_h_n = curl_h_store[i];
|
|
1061
|
+
grad_cb_shot[i] += lambda_ey_curr * curl_h_n * (TIDE_DTYPE)step_ratio_val;
|
|
1062
|
+
}
|
|
1063
|
+
}
|
|
1064
|
+
}
|
|
1065
|
+
}
|
|
1066
|
+
|
|
1067
|
+
// Backward kernel: Update adjoint λ_Ey field with BF16 snapshot loads.
|
|
1068
|
+
__global__ void backward_kernel_lambda_e_with_grad_bf16(
|
|
1069
|
+
TIDE_DTYPE const *__restrict const ca,
|
|
1070
|
+
TIDE_DTYPE const *__restrict const cq,
|
|
1071
|
+
TIDE_DTYPE const *__restrict const lambda_hx,
|
|
1072
|
+
TIDE_DTYPE const *__restrict const lambda_hz,
|
|
1073
|
+
TIDE_DTYPE *__restrict const lambda_ey,
|
|
1074
|
+
TIDE_DTYPE *__restrict const m_lambda_hx_z,
|
|
1075
|
+
TIDE_DTYPE *__restrict const m_lambda_hz_x,
|
|
1076
|
+
__nv_bfloat16 const *__restrict const ey_store,
|
|
1077
|
+
__nv_bfloat16 const *__restrict const curl_h_store,
|
|
1078
|
+
TIDE_DTYPE *__restrict const grad_ca_shot,
|
|
1079
|
+
TIDE_DTYPE *__restrict const grad_cb_shot,
|
|
1080
|
+
TIDE_DTYPE const *__restrict const ay,
|
|
1081
|
+
TIDE_DTYPE const *__restrict const ayh,
|
|
1082
|
+
TIDE_DTYPE const *__restrict const ax,
|
|
1083
|
+
TIDE_DTYPE const *__restrict const axh,
|
|
1084
|
+
TIDE_DTYPE const *__restrict const by,
|
|
1085
|
+
TIDE_DTYPE const *__restrict const byh,
|
|
1086
|
+
TIDE_DTYPE const *__restrict const bx,
|
|
1087
|
+
TIDE_DTYPE const *__restrict const bxh,
|
|
1088
|
+
TIDE_DTYPE const *__restrict const ky,
|
|
1089
|
+
TIDE_DTYPE const *__restrict const kyh,
|
|
1090
|
+
TIDE_DTYPE const *__restrict const kx,
|
|
1091
|
+
TIDE_DTYPE const *__restrict const kxh,
|
|
1092
|
+
bool const ca_requires_grad,
|
|
1093
|
+
bool const cb_requires_grad,
|
|
1094
|
+
int64_t const step_ratio_val) {
|
|
1095
|
+
|
|
1096
|
+
int64_t x = (int64_t)blockIdx.x * (int64_t)blockDim.x +
|
|
1097
|
+
(int64_t)threadIdx.x + FD_PAD;
|
|
1098
|
+
int64_t y = (int64_t)blockIdx.y * (int64_t)blockDim.y +
|
|
1099
|
+
(int64_t)threadIdx.y + FD_PAD;
|
|
1100
|
+
int64_t shot_idx = (int64_t)blockIdx.z * (int64_t)blockDim.z +
|
|
1101
|
+
(int64_t)threadIdx.z;
|
|
1102
|
+
|
|
1103
|
+
if (y < ny - FD_PAD + 1 && x < nx - FD_PAD + 1 && shot_idx < n_shots) {
|
|
1104
|
+
int64_t j = y * nx + x;
|
|
1105
|
+
int64_t i = shot_idx * shot_numel + j;
|
|
1106
|
+
|
|
1107
|
+
TIDE_DTYPE const ca_shot_i = ca_batched ? ca[i] : ca[j];
|
|
1108
|
+
TIDE_DTYPE const cq_shot_i = cq_batched ? cq[i] : cq[j];
|
|
1109
|
+
|
|
1110
|
+
bool pml_y = y < pml_y0 || y >= pml_y1;
|
|
1111
|
+
bool pml_x = x < pml_x0 || x >= pml_x1;
|
|
1112
|
+
|
|
1113
|
+
// EXACT ADJOINT: use transposed difference operators
|
|
1114
|
+
TIDE_DTYPE d_lambda_hz_dx = DIFFXH1(LAMBDA_HZ);
|
|
1115
|
+
TIDE_DTYPE d_lambda_hx_dz = DIFFYH1(LAMBDA_HX);
|
|
1116
|
+
|
|
1117
|
+
// Pre-load PML coefficients into registers (optimization 1.2)
|
|
1118
|
+
TIDE_DTYPE bx_val = __ldg(&bx[x]);
|
|
1119
|
+
TIDE_DTYPE ax_val = __ldg(&ax[x]);
|
|
1120
|
+
TIDE_DTYPE kx_val = __ldg(&kx[x]);
|
|
1121
|
+
TIDE_DTYPE by_val = __ldg(&by[y]);
|
|
1122
|
+
TIDE_DTYPE ay_val = __ldg(&ay[y]);
|
|
1123
|
+
TIDE_DTYPE ky_val = __ldg(&ky[y]);
|
|
1124
|
+
|
|
1125
|
+
if (pml_x) {
|
|
1126
|
+
m_lambda_hz_x[i] = bx_val * m_lambda_hz_x[i] + ax_val * d_lambda_hz_dx;
|
|
1127
|
+
d_lambda_hz_dx = d_lambda_hz_dx / kx_val + m_lambda_hz_x[i];
|
|
1128
|
+
}
|
|
1129
|
+
|
|
1130
|
+
if (pml_y) {
|
|
1131
|
+
m_lambda_hx_z[i] = by_val * m_lambda_hx_z[i] + ay_val * d_lambda_hx_dz;
|
|
1132
|
+
d_lambda_hx_dz = d_lambda_hx_dz / ky_val + m_lambda_hx_z[i];
|
|
1133
|
+
}
|
|
1134
|
+
|
|
1135
|
+
TIDE_DTYPE curl_lambda_h = d_lambda_hz_dx - d_lambda_hx_dz;
|
|
1136
|
+
|
|
1137
|
+
TIDE_DTYPE lambda_ey_curr = lambda_ey[i];
|
|
1138
|
+
lambda_ey[i] = ca_shot_i * lambda_ey_curr + cq_shot_i * curl_lambda_h;
|
|
1139
|
+
|
|
1140
|
+
if (!pml_y && !pml_x) {
|
|
1141
|
+
if (ca_requires_grad && ey_store != nullptr) {
|
|
1142
|
+
TIDE_DTYPE ey_n = (TIDE_DTYPE)__bfloat162float(ey_store[i]);
|
|
1143
|
+
grad_ca_shot[i] += lambda_ey_curr * ey_n * (TIDE_DTYPE)step_ratio_val;
|
|
1144
|
+
}
|
|
1145
|
+
if (cb_requires_grad && curl_h_store != nullptr) {
|
|
1146
|
+
TIDE_DTYPE curl_h_n = (TIDE_DTYPE)__bfloat162float(curl_h_store[i]);
|
|
1147
|
+
grad_cb_shot[i] += lambda_ey_curr * curl_h_n * (TIDE_DTYPE)step_ratio_val;
|
|
1148
|
+
}
|
|
1149
|
+
}
|
|
1150
|
+
}
|
|
1151
|
+
}
|
|
1152
|
+
|
|
1153
|
+
// Backward kernel: Update adjoint λ_Ey field with FP8 snapshot loads.
__global__ void backward_kernel_lambda_e_with_grad_fp8(
    TIDE_DTYPE const *__restrict const ca,
    TIDE_DTYPE const *__restrict const cq,
    TIDE_DTYPE const *__restrict const lambda_hx,
    TIDE_DTYPE const *__restrict const lambda_hz,
    TIDE_DTYPE *__restrict const lambda_ey,
    TIDE_DTYPE *__restrict const m_lambda_hx_z,
    TIDE_DTYPE *__restrict const m_lambda_hz_x,
    uint8_t const *__restrict const ey_store,
    uint8_t const *__restrict const curl_h_store,
    TIDE_DTYPE *__restrict const grad_ca_shot,
    TIDE_DTYPE *__restrict const grad_cb_shot,
    TIDE_DTYPE const *__restrict const ay,
    TIDE_DTYPE const *__restrict const ayh,
    TIDE_DTYPE const *__restrict const ax,
    TIDE_DTYPE const *__restrict const axh,
    TIDE_DTYPE const *__restrict const by,
    TIDE_DTYPE const *__restrict const byh,
    TIDE_DTYPE const *__restrict const bx,
    TIDE_DTYPE const *__restrict const bxh,
    TIDE_DTYPE const *__restrict const ky,
    TIDE_DTYPE const *__restrict const kyh,
    TIDE_DTYPE const *__restrict const kx,
    TIDE_DTYPE const *__restrict const kxh,
    bool const ca_requires_grad,
    bool const cb_requires_grad,
    int64_t const step_ratio_val) {

  int64_t x = (int64_t)blockIdx.x * (int64_t)blockDim.x +
              (int64_t)threadIdx.x + FD_PAD;
  int64_t y = (int64_t)blockIdx.y * (int64_t)blockDim.y +
              (int64_t)threadIdx.y + FD_PAD;
  int64_t shot_idx = (int64_t)blockIdx.z * (int64_t)blockDim.z +
                     (int64_t)threadIdx.z;

  if (y < ny - FD_PAD + 1 && x < nx - FD_PAD + 1 && shot_idx < n_shots) {
    int64_t j = y * nx + x;
    int64_t i = shot_idx * shot_numel + j;

    TIDE_DTYPE const ca_shot_i = ca_batched ? ca[i] : ca[j];
    TIDE_DTYPE const cq_shot_i = cq_batched ? cq[i] : cq[j];

    bool pml_y = y < pml_y0 || y >= pml_y1;
    bool pml_x = x < pml_x0 || x >= pml_x1;

    TIDE_DTYPE d_lambda_hz_dx = DIFFXH1(LAMBDA_HZ);
    TIDE_DTYPE d_lambda_hx_dz = DIFFYH1(LAMBDA_HX);

    TIDE_DTYPE bx_val = __ldg(&bx[x]);
    TIDE_DTYPE ax_val = __ldg(&ax[x]);
    TIDE_DTYPE kx_val = __ldg(&kx[x]);
    TIDE_DTYPE by_val = __ldg(&by[y]);
    TIDE_DTYPE ay_val = __ldg(&ay[y]);
    TIDE_DTYPE ky_val = __ldg(&ky[y]);

    if (pml_x) {
      m_lambda_hz_x[i] = bx_val * m_lambda_hz_x[i] + ax_val * d_lambda_hz_dx;
      d_lambda_hz_dx = d_lambda_hz_dx / kx_val + m_lambda_hz_x[i];
    }

    if (pml_y) {
      m_lambda_hx_z[i] = by_val * m_lambda_hx_z[i] + ay_val * d_lambda_hx_dz;
      d_lambda_hx_dz = d_lambda_hx_dz / ky_val + m_lambda_hx_z[i];
    }

    TIDE_DTYPE curl_lambda_h = d_lambda_hz_dx - d_lambda_hx_dz;

    TIDE_DTYPE lambda_ey_curr = lambda_ey[i];
    lambda_ey[i] = ca_shot_i * lambda_ey_curr + cq_shot_i * curl_lambda_h;

    if (!pml_y && !pml_x) {
      if (ca_requires_grad && ey_store != nullptr) {
        TIDE_DTYPE ey_n = (TIDE_DTYPE)fp8_e4m3_to_float(ey_store[i]);
        grad_ca_shot[i] += lambda_ey_curr * ey_n * (TIDE_DTYPE)step_ratio_val;
      }
      if (cb_requires_grad && curl_h_store != nullptr) {
        TIDE_DTYPE curl_h_n = (TIDE_DTYPE)fp8_e4m3_to_float(curl_h_store[i]);
        grad_cb_shot[i] += lambda_ey_curr * curl_h_n * (TIDE_DTYPE)step_ratio_val;
      }
    }
  }
}

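// Note on the gradient accumulation in the *_with_grad kernels above:
// contributions are only added for interior cells (!pml_y && !pml_x), each
// shot writes to its own grad_*_shot slice (index i = shot_idx * shot_numel
// + j) so no atomics are needed, and the snapshot pointer is passed as null
// whenever the corresponding coefficient does not require a gradient, which
// skips the load entirely.
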
// Backward kernel: Update adjoint λ_Ey field (no gradient accumulation).
__global__ void backward_kernel_lambda_e(
    TIDE_DTYPE const *__restrict const ca,
    TIDE_DTYPE const *__restrict const cq,
    TIDE_DTYPE const *__restrict const lambda_hx,
    TIDE_DTYPE const *__restrict const lambda_hz,
    TIDE_DTYPE *__restrict const lambda_ey,
    TIDE_DTYPE *__restrict const m_lambda_hx_z,
    TIDE_DTYPE *__restrict const m_lambda_hz_x,
    TIDE_DTYPE const *__restrict const ay,
    TIDE_DTYPE const *__restrict const ayh,
    TIDE_DTYPE const *__restrict const ax,
    TIDE_DTYPE const *__restrict const axh,
    TIDE_DTYPE const *__restrict const by,
    TIDE_DTYPE const *__restrict const byh,
    TIDE_DTYPE const *__restrict const bx,
    TIDE_DTYPE const *__restrict const bxh,
    TIDE_DTYPE const *__restrict const ky,
    TIDE_DTYPE const *__restrict const kyh,
    TIDE_DTYPE const *__restrict const kx,
    TIDE_DTYPE const *__restrict const kxh) {
  int64_t x = (int64_t)blockIdx.x * (int64_t)blockDim.x +
              (int64_t)threadIdx.x + FD_PAD;
  int64_t y = (int64_t)blockIdx.y * (int64_t)blockDim.y +
              (int64_t)threadIdx.y + FD_PAD;
  int64_t shot_idx = (int64_t)blockIdx.z * (int64_t)blockDim.z +
                     (int64_t)threadIdx.z;

  if (y < ny - FD_PAD + 1 && x < nx - FD_PAD + 1 && shot_idx < n_shots) {
    int64_t j = y * nx + x;
    int64_t i = shot_idx * shot_numel + j;

    (void)ayh;
    (void)axh;
    (void)byh;
    (void)bxh;
    (void)kyh;
    (void)kxh;

    TIDE_DTYPE const ca_shot_i = ca_batched ? ca[i] : ca[j];
    TIDE_DTYPE const cq_shot_i = cq_batched ? cq[i] : cq[j];

    bool pml_y = y < pml_y0 || y >= pml_y1;
    bool pml_x = x < pml_x0 || x >= pml_x1;

    // EXACT ADJOINT: use transposed difference operators
    TIDE_DTYPE d_lambda_hz_dx = DIFFXH1(LAMBDA_HZ);
    TIDE_DTYPE d_lambda_hx_dz = DIFFYH1(LAMBDA_HX);

    // Pre-load PML coefficients into registers (optimization 1.2)
    TIDE_DTYPE bx_val = __ldg(&bx[x]);
    TIDE_DTYPE ax_val = __ldg(&ax[x]);
    TIDE_DTYPE kx_val = __ldg(&kx[x]);
    TIDE_DTYPE by_val = __ldg(&by[y]);
    TIDE_DTYPE ay_val = __ldg(&ay[y]);
    TIDE_DTYPE ky_val = __ldg(&ky[y]);

    if (pml_x) {
      m_lambda_hz_x[i] = bx_val * m_lambda_hz_x[i] + ax_val * d_lambda_hz_dx;
      d_lambda_hz_dx = d_lambda_hz_dx / kx_val + m_lambda_hz_x[i];
    }

    if (pml_y) {
      m_lambda_hx_z[i] = by_val * m_lambda_hx_z[i] + ay_val * d_lambda_hx_dz;
      d_lambda_hx_dz = d_lambda_hx_dz / ky_val + m_lambda_hx_z[i];
    }

    TIDE_DTYPE curl_lambda_h = d_lambda_hz_dx - d_lambda_hx_dz;

    TIDE_DTYPE lambda_ey_curr = lambda_ey[i];
    lambda_ey[i] = ca_shot_i * lambda_ey_curr + cq_shot_i * curl_lambda_h;
  }
}

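// backward_kernel_lambda_e above performs the same adjoint λ_Ey update as the
// *_with_grad kernels but takes no snapshot or gradient arguments; the host
// driver in FUNC(backward) launches it on steps where neither C_a nor C_b
// needs a gradient contribution, avoiding the snapshot loads and per-shot
// gradient writes.
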
// Combine per-shot gradients into final gradient (sum across shots)
__global__ void combine_grad(TIDE_DTYPE *__restrict const grad,
                             TIDE_DTYPE const *__restrict const grad_shot) {
  int64_t x = (int64_t)blockIdx.x * (int64_t)blockDim.x +
              (int64_t)threadIdx.x + FD_PAD;
  int64_t y = (int64_t)blockIdx.y * (int64_t)blockDim.y +
              (int64_t)threadIdx.y + FD_PAD;
  if (y < ny - FD_PAD && x < nx - FD_PAD) {
    int64_t j = y * nx + x;
    int64_t const stride = shot_numel;
    TIDE_DTYPE sum = 0;
#pragma unroll 4
    for (int64_t shot_idx = 0; shot_idx < n_shots; ++shot_idx) {
      sum += grad_shot[shot_idx * stride + j];
    }
    grad[j] += sum;
  }
}

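// combine_grad above reduces the per-shot workspace into the shared model
// gradient: grad_shot is laid out as [n_shots, ny, nx] with stride
// shot_numel = ny * nx, so, for example, with n_shots == 2 the update is
//   grad[j] += grad_shot[j] + grad_shot[shot_numel + j].
// Writing per-shot slices during the adjoint sweep and reducing once here is
// what lets the gradient kernels avoid atomic additions across shots.
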
__global__ void convert_grad_ca_cb_to_eps_sigma(
    TIDE_DTYPE const *__restrict const ca,
    TIDE_DTYPE const *__restrict const cb,
    TIDE_DTYPE const *__restrict const grad_ca,
    TIDE_DTYPE const *__restrict const grad_cb,
    TIDE_DTYPE const *__restrict const grad_ca_shot,
    TIDE_DTYPE const *__restrict const grad_cb_shot,
    TIDE_DTYPE *__restrict const grad_eps,
    TIDE_DTYPE *__restrict const grad_sigma,
    TIDE_DTYPE const dt,
    bool const ca_requires_grad,
    bool const cb_requires_grad,
    bool const ca_batched_h,
    bool const cb_batched_h) {
  int64_t x = (int64_t)blockIdx.x * (int64_t)blockDim.x +
              (int64_t)threadIdx.x;
  int64_t y = (int64_t)blockIdx.y * (int64_t)blockDim.y +
              (int64_t)threadIdx.y;
  if (x >= nx || y >= ny) {
    return;
  }

  int64_t shot_idx = (int64_t)blockIdx.z;
  if (!ca_batched_h) {
    shot_idx = 0;
  }

  int64_t const j = y * nx + x;
  int64_t const idx_shot = shot_idx * shot_numel + j;
  int64_t const out_idx = ca_batched_h ? idx_shot : j;
  int64_t const ca_idx = ca_batched_h ? idx_shot : j;
  int64_t const cb_idx = cb_batched_h ? idx_shot : j;

  TIDE_DTYPE const ca_val = ca[ca_idx];
  TIDE_DTYPE const cb_val = cb[cb_idx];
  TIDE_DTYPE const cb_sq = cb_val * cb_val;
  TIDE_DTYPE const inv_dt = (TIDE_DTYPE)1 / dt;

  TIDE_DTYPE grad_ca_val = 0;
  if (ca_requires_grad) {
    grad_ca_val = ca_batched_h ? grad_ca_shot[idx_shot] : grad_ca[j];
  }

  TIDE_DTYPE grad_cb_val = 0;
  if (cb_requires_grad) {
    grad_cb_val = cb_batched_h ? grad_cb_shot[idx_shot] : grad_cb[j];
  }

  TIDE_DTYPE const dca_de = ((TIDE_DTYPE)1 - ca_val) * cb_val * inv_dt;
  TIDE_DTYPE const dcb_de = -cb_sq * inv_dt;
  TIDE_DTYPE const dca_ds = -((TIDE_DTYPE)0.5) * ((TIDE_DTYPE)1 + ca_val) * cb_val;
  TIDE_DTYPE const dcb_ds = -((TIDE_DTYPE)0.5) * cb_sq;

  if (grad_eps != nullptr) {
    TIDE_DTYPE const grad_e = grad_ca_val * dca_de + grad_cb_val * dcb_de;
    grad_eps[out_idx] = grad_e * EP0;
  }
  if (grad_sigma != nullptr) {
    grad_sigma[out_idx] = grad_ca_val * dca_ds + grad_cb_val * dcb_ds;
  }
}

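// The conversion above applies the chain rule
//   grad_eps   = grad_ca * dCa/d(eps)   + grad_cb * dCb/d(eps)
//   grad_sigma = grad_ca * dCa/d(sigma) + grad_cb * dCb/d(sigma)
// with the partial derivatives expressed purely in terms of Ca, Cb and dt.
// The factors match the usual lossy-medium FDTD coefficients, assuming
//   Ca = (2*eps - sigma*dt) / (2*eps + sigma*dt)  and
//   Cb = 2*dt / (2*eps + sigma*dt);
// under that assumption
//   dCa/d(eps)   = (1 - Ca) * Cb / dt,     dCb/d(eps)   = -Cb^2 / dt,
//   dCa/d(sigma) = -0.5 * (1 + Ca) * Cb,   dCb/d(sigma) = -0.5 * Cb^2.
// The trailing EP0 factor suggests grad_eps is reported with respect to
// relative rather than absolute permittivity.
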
} // namespace

// Forward propagation function
extern "C" void FUNC(forward)(
    TIDE_DTYPE const *const ca,
    TIDE_DTYPE const *const cb,
    TIDE_DTYPE const *const cq,
    TIDE_DTYPE const *const f,
    TIDE_DTYPE *const ey,
    TIDE_DTYPE *const hx,
    TIDE_DTYPE *const hz,
    TIDE_DTYPE *const m_ey_x,
    TIDE_DTYPE *const m_ey_z,
    TIDE_DTYPE *const m_hx_z,
    TIDE_DTYPE *const m_hz_x,
    TIDE_DTYPE *const r,
    TIDE_DTYPE const *const ay,
    TIDE_DTYPE const *const by,
    TIDE_DTYPE const *const ayh,
    TIDE_DTYPE const *const byh,
    TIDE_DTYPE const *const ax,
    TIDE_DTYPE const *const bx,
    TIDE_DTYPE const *const axh,
    TIDE_DTYPE const *const bxh,
    TIDE_DTYPE const *const ky,
    TIDE_DTYPE const *const kyh,
    TIDE_DTYPE const *const kx,
    TIDE_DTYPE const *const kxh,
    int64_t const *const sources_i,
    int64_t const *const receivers_i,
    TIDE_DTYPE const rdy_h,
    TIDE_DTYPE const rdx_h,
    TIDE_DTYPE const dt_h,
    int64_t const nt,
    int64_t const n_shots_h,
    int64_t const ny_h,
    int64_t const nx_h,
    int64_t const n_sources_per_shot_h,
    int64_t const n_receivers_per_shot_h,
    int64_t const step_ratio_h,
    bool const ca_batched_h,
    bool const cb_batched_h,
    bool const cq_batched_h,
    int64_t const start_t,
    int64_t const pml_y0_h,
    int64_t const pml_x0_h,
    int64_t const pml_y1_h,
    int64_t const pml_x1_h,
    int64_t const n_threads,
    int64_t const device) {

  cudaSetDevice(device);
  (void)dt_h;
  (void)step_ratio_h;
  (void)n_threads;

  int64_t const shot_numel_h = ny_h * nx_h;

  // Copy constants to device with caching to avoid redundant copies
  static TIDE_DTYPE cached_rdy = 0, cached_rdx = 0;
  static int64_t cached_n_shots = -1, cached_ny = -1, cached_nx = -1;
  static int64_t cached_shot_numel = -1, cached_n_sources_per_shot = -1, cached_n_receivers_per_shot = -1;
  static int64_t cached_pml_y0 = -1, cached_pml_y1 = -1;
  static int64_t cached_pml_x0 = -1, cached_pml_x1 = -1;
  static bool cached_ca_batched = false, cached_cb_batched = false, cached_cq_batched = false;
  static int64_t cached_device = -1;
  static bool first_call = true;

  if (first_call || cached_device != device || cached_rdy != rdy_h || cached_rdx != rdx_h ||
      cached_n_shots != n_shots_h || cached_ny != ny_h || cached_nx != nx_h ||
      cached_shot_numel != shot_numel_h || cached_n_sources_per_shot != n_sources_per_shot_h ||
      cached_n_receivers_per_shot != n_receivers_per_shot_h ||
      cached_pml_y0 != pml_y0_h || cached_pml_y1 != pml_y1_h ||
      cached_pml_x0 != pml_x0_h || cached_pml_x1 != pml_x1_h ||
      cached_ca_batched != ca_batched_h || cached_cb_batched != cb_batched_h ||
      cached_cq_batched != cq_batched_h) {

    cudaMemcpyToSymbol(rdy, &rdy_h, sizeof(TIDE_DTYPE));
    cudaMemcpyToSymbol(rdx, &rdx_h, sizeof(TIDE_DTYPE));
    cudaMemcpyToSymbol(n_shots, &n_shots_h, sizeof(int64_t));
    cudaMemcpyToSymbol(ny, &ny_h, sizeof(int64_t));
    cudaMemcpyToSymbol(nx, &nx_h, sizeof(int64_t));
    cudaMemcpyToSymbol(shot_numel, &shot_numel_h, sizeof(int64_t));
    cudaMemcpyToSymbol(n_sources_per_shot, &n_sources_per_shot_h, sizeof(int64_t));
    cudaMemcpyToSymbol(n_receivers_per_shot, &n_receivers_per_shot_h, sizeof(int64_t));
    cudaMemcpyToSymbol(pml_y0, &pml_y0_h, sizeof(int64_t));
    cudaMemcpyToSymbol(pml_y1, &pml_y1_h, sizeof(int64_t));
    cudaMemcpyToSymbol(pml_x0, &pml_x0_h, sizeof(int64_t));
    cudaMemcpyToSymbol(pml_x1, &pml_x1_h, sizeof(int64_t));
    cudaMemcpyToSymbol(ca_batched, &ca_batched_h, sizeof(bool));
    cudaMemcpyToSymbol(cb_batched, &cb_batched_h, sizeof(bool));
    cudaMemcpyToSymbol(cq_batched, &cq_batched_h, sizeof(bool));

    cached_rdy = rdy_h; cached_rdx = rdx_h;
    cached_n_shots = n_shots_h; cached_ny = ny_h; cached_nx = nx_h;
    cached_shot_numel = shot_numel_h; cached_n_sources_per_shot = n_sources_per_shot_h;
    cached_n_receivers_per_shot = n_receivers_per_shot_h;
    cached_pml_y0 = pml_y0_h; cached_pml_y1 = pml_y1_h;
    cached_pml_x0 = pml_x0_h; cached_pml_x1 = pml_x1_h;
    cached_ca_batched = ca_batched_h; cached_cb_batched = cb_batched_h;
    cached_cq_batched = cq_batched_h;
    cached_device = device;
    first_call = false;
  }

  dim3 dimBlock(32, 8, 1);
  int64_t gridx = (nx_h - 2 * FD_PAD + 2 + dimBlock.x - 1) / dimBlock.x;
  int64_t gridy = (ny_h - 2 * FD_PAD + 2 + dimBlock.y - 1) / dimBlock.y;
  int64_t gridz = n_shots_h;
  dim3 dimGrid(gridx, gridy, gridz);
#if FD_PAD > 1
  size_t const shmem_h_bytes =
      (size_t)(dimBlock.x + 2 * FD_PAD) * (size_t)(dimBlock.y + 2 * FD_PAD) *
      sizeof(TIDE_DTYPE);
  size_t const shmem_e_bytes = 2 * shmem_h_bytes;
#else
  size_t const shmem_h_bytes = 0;
  size_t const shmem_e_bytes = 0;
#endif

  dim3 dimBlock_sources(32, 1, 1);
  dim3 dimGrid_sources(
      (n_sources_per_shot_h + dimBlock_sources.x - 1) / dimBlock_sources.x,
      n_shots_h, 1);

  dim3 dimBlock_receivers(32, 1, 1);
  dim3 dimGrid_receivers(
      (n_receivers_per_shot_h + dimBlock_receivers.x - 1) / dimBlock_receivers.x,
      n_shots_h, 1);

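  // Launch geometry: one thread per (x, y) grid point with blockIdx.z
  // indexing the shot; the "- 2 * FD_PAD + 2" extent appears to cover the
  // interior cells plus the staggered half-cell row/column at the upper PML
  // edge.  When FD_PAD > 1 the field kernels stage a halo of width FD_PAD in
  // shared memory, so the dynamic shared-memory size is one padded tile for
  // the H update and, presumably because it reads both Hx and Hz, two tiles
  // for the E update.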
  auto run_step = [&](int64_t t) {
    forward_kernel_h<<<dimGrid, dimBlock, shmem_h_bytes>>>(
        cq, ey, hx, hz, m_ey_x, m_ey_z,
        ay, ayh, ax, axh, by, byh, bx, bxh,
        ky, kyh, kx, kxh);
    forward_kernel_e<<<dimGrid, dimBlock, shmem_e_bytes>>>(
        ca, cb, hx, hz, ey, m_hx_z, m_hz_x,
        ay, ayh, ax, axh, by, byh, bx, bxh,
        ky, kyh, kx, kxh);

    if (n_sources_per_shot_h > 0) {
      add_sources_ey<<<dimGrid_sources, dimBlock_sources>>>(
          ey, f + t * n_shots_h * n_sources_per_shot_h, sources_i);
    }

    if (n_receivers_per_shot_h > 0) {
      record_receivers_ey<<<dimGrid_receivers, dimBlock_receivers>>>(
          r + t * n_shots_h * n_receivers_per_shot_h, ey, receivers_i);
    }
  };

  for (int64_t t = start_t; t < start_t + nt; ++t) {
    run_step(t);
  }

  gpuErrchk(cudaPeekAtLastError());
}

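// Each call to FUNC(forward) advances the TM fields nt steps: per step it
// launches the H update, then the E update, injects the source samples for
// that step into Ey at sources_i, and records Ey at receivers_i into r.  The
// static cached_* variables let repeated calls with an unchanged geometry
// skip the cudaMemcpyToSymbol traffic for the device-symbol grid parameters;
// only the first call, or a change of device/shape/PML layout, pays that
// cost.
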
extern "C" void FUNC(forward_with_storage)(
    TIDE_DTYPE const *const ca,
    TIDE_DTYPE const *const cb,
    TIDE_DTYPE const *const cq,
    TIDE_DTYPE const *const f,
    TIDE_DTYPE *const ey,
    TIDE_DTYPE *const hx,
    TIDE_DTYPE *const hz,
    TIDE_DTYPE *const m_ey_x,
    TIDE_DTYPE *const m_ey_z,
    TIDE_DTYPE *const m_hx_z,
    TIDE_DTYPE *const m_hz_x,
    TIDE_DTYPE *const r,
    void *const ey_store_1,
    void *const ey_store_3,
    char const *const *const ey_filenames,
    void *const curl_store_1,
    void *const curl_store_3,
    char const *const *const curl_filenames,
    TIDE_DTYPE const *const ay,
    TIDE_DTYPE const *const by,
    TIDE_DTYPE const *const ayh,
    TIDE_DTYPE const *const byh,
    TIDE_DTYPE const *const ax,
    TIDE_DTYPE const *const bx,
    TIDE_DTYPE const *const axh,
    TIDE_DTYPE const *const bxh,
    TIDE_DTYPE const *const ky,
    TIDE_DTYPE const *const kyh,
    TIDE_DTYPE const *const kx,
    TIDE_DTYPE const *const kxh,
    int64_t const *const sources_i,
    int64_t const *const receivers_i,
    TIDE_DTYPE const rdy_h,
    TIDE_DTYPE const rdx_h,
    TIDE_DTYPE const dt_h,
    int64_t const nt,
    int64_t const n_shots_h,
    int64_t const ny_h,
    int64_t const nx_h,
    int64_t const n_sources_per_shot_h,
    int64_t const n_receivers_per_shot_h,
    int64_t const step_ratio_h,
    int64_t const storage_mode_h,
    int64_t const shot_bytes_uncomp_h,
    bool const ca_requires_grad,
    bool const cb_requires_grad,
    bool const ca_batched_h,
    bool const cb_batched_h,
    bool const cq_batched_h,
    int64_t const start_t,
    int64_t const pml_y0_h,
    int64_t const pml_x0_h,
    int64_t const pml_y1_h,
    int64_t const pml_x1_h,
    int64_t const n_threads,
    int64_t const device) {

  cudaSetDevice(device);
  (void)n_threads;

  int64_t const shot_numel_h = ny_h * nx_h;
  size_t const bytes_per_step_store =
      (size_t)shot_bytes_uncomp_h * (size_t)n_shots_h;
  bool const storage_bf16_h = (shot_bytes_uncomp_h == shot_numel_h * 2);
  bool const storage_fp8_h = (shot_bytes_uncomp_h == shot_numel_h);
  cudaStream_t copy_stream = nullptr;
  cudaEvent_t store_ready;
  cudaEvent_t copy_done[NUM_BUFFERS];
  bool copy_in_flight[NUM_BUFFERS];
  for (int i = 0; i < NUM_BUFFERS; i++) copy_in_flight[i] = false;

#ifdef TIDE_PROFILING
  cudaEvent_t prof_wait_start, prof_wait_end, prof_copy_start, prof_copy_end;
  float total_wait_ms = 0.0f, total_copy_ms = 0.0f;
  int n_waits = 0, n_copies = 0;
#endif

  if (storage_mode_h == STORAGE_CPU) {
    gpuErrchk(cudaStreamCreateWithFlags(&copy_stream, cudaStreamNonBlocking));
#ifdef TIDE_PROFILING
    PROF_EVENT_CREATE(store_ready);
    PROF_EVENT_CREATE(prof_wait_start);
    PROF_EVENT_CREATE(prof_wait_end);
    PROF_EVENT_CREATE(prof_copy_start);
    PROF_EVENT_CREATE(prof_copy_end);
    for (int i = 0; i < NUM_BUFFERS; i++) {
      PROF_EVENT_CREATE(copy_done[i]);
    }
#else
    gpuErrchk(cudaEventCreateWithFlags(&store_ready, cudaEventDisableTiming));
    for (int i = 0; i < NUM_BUFFERS; i++) {
      gpuErrchk(cudaEventCreateWithFlags(&copy_done[i], cudaEventDisableTiming));
    }
#endif
  }

  // Copy constants to device with caching to avoid redundant copies
  static TIDE_DTYPE cached_rdy2 = 0, cached_rdx2 = 0;
  static int64_t cached_n_shots2 = -1, cached_ny2 = -1, cached_nx2 = -1;
  static int64_t cached_shot_numel2 = -1, cached_n_sources_per_shot2 = -1, cached_n_receivers_per_shot2 = -1;
  static int64_t cached_pml_y02 = -1, cached_pml_y12 = -1;
  static int64_t cached_pml_x02 = -1, cached_pml_x12 = -1;
  static bool cached_ca_batched2 = false, cached_cb_batched2 = false, cached_cq_batched2 = false;
  static int64_t cached_device2 = -1;
  static bool first_call2 = true;

  if (first_call2 || cached_device2 != device || cached_rdy2 != rdy_h || cached_rdx2 != rdx_h ||
      cached_n_shots2 != n_shots_h || cached_ny2 != ny_h || cached_nx2 != nx_h ||
      cached_shot_numel2 != shot_numel_h || cached_n_sources_per_shot2 != n_sources_per_shot_h ||
      cached_n_receivers_per_shot2 != n_receivers_per_shot_h ||
      cached_pml_y02 != pml_y0_h || cached_pml_y12 != pml_y1_h ||
      cached_pml_x02 != pml_x0_h || cached_pml_x12 != pml_x1_h ||
      cached_ca_batched2 != ca_batched_h || cached_cb_batched2 != cb_batched_h ||
      cached_cq_batched2 != cq_batched_h) {

    cudaMemcpyToSymbol(rdy, &rdy_h, sizeof(TIDE_DTYPE));
    cudaMemcpyToSymbol(rdx, &rdx_h, sizeof(TIDE_DTYPE));
    cudaMemcpyToSymbol(n_shots, &n_shots_h, sizeof(int64_t));
    cudaMemcpyToSymbol(ny, &ny_h, sizeof(int64_t));
    cudaMemcpyToSymbol(nx, &nx_h, sizeof(int64_t));
    cudaMemcpyToSymbol(shot_numel, &shot_numel_h, sizeof(int64_t));
    cudaMemcpyToSymbol(n_sources_per_shot, &n_sources_per_shot_h, sizeof(int64_t));
    cudaMemcpyToSymbol(n_receivers_per_shot, &n_receivers_per_shot_h, sizeof(int64_t));
    cudaMemcpyToSymbol(pml_y0, &pml_y0_h, sizeof(int64_t));
    cudaMemcpyToSymbol(pml_y1, &pml_y1_h, sizeof(int64_t));
    cudaMemcpyToSymbol(pml_x0, &pml_x0_h, sizeof(int64_t));
    cudaMemcpyToSymbol(pml_x1, &pml_x1_h, sizeof(int64_t));
    cudaMemcpyToSymbol(ca_batched, &ca_batched_h, sizeof(bool));
    cudaMemcpyToSymbol(cb_batched, &cb_batched_h, sizeof(bool));
    cudaMemcpyToSymbol(cq_batched, &cq_batched_h, sizeof(bool));

    cached_rdy2 = rdy_h; cached_rdx2 = rdx_h;
    cached_n_shots2 = n_shots_h; cached_ny2 = ny_h; cached_nx2 = nx_h;
    cached_shot_numel2 = shot_numel_h; cached_n_sources_per_shot2 = n_sources_per_shot_h;
    cached_n_receivers_per_shot2 = n_receivers_per_shot_h;
    cached_pml_y02 = pml_y0_h; cached_pml_y12 = pml_y1_h;
    cached_pml_x02 = pml_x0_h; cached_pml_x12 = pml_x1_h;
    cached_ca_batched2 = ca_batched_h; cached_cb_batched2 = cb_batched_h;
    cached_cq_batched2 = cq_batched_h;
    cached_device2 = device;
    first_call2 = false;
  }

  dim3 dimBlock(32, 8, 1);
  int64_t gridx = (nx_h - 2 * FD_PAD + 2 + dimBlock.x - 1) / dimBlock.x;
  int64_t gridy = (ny_h - 2 * FD_PAD + 2 + dimBlock.y - 1) / dimBlock.y;
  int64_t gridz = n_shots_h;
  dim3 dimGrid(gridx, gridy, gridz);
#if FD_PAD > 1
  size_t const shmem_h_bytes =
      (size_t)(dimBlock.x + 2 * FD_PAD) * (size_t)(dimBlock.y + 2 * FD_PAD) *
      sizeof(TIDE_DTYPE);
  size_t const shmem_e_bytes = 2 * shmem_h_bytes;
#else
  size_t const shmem_h_bytes = 0;
  size_t const shmem_e_bytes = 0;
#endif

  dim3 dimBlock_sources(32, 1, 1);
  dim3 dimGrid_sources(
      (n_sources_per_shot_h + dimBlock_sources.x - 1) / dimBlock_sources.x,
      n_shots_h, 1);

  dim3 dimBlock_receivers(32, 1, 1);
  dim3 dimGrid_receivers(
      (n_receivers_per_shot_h + dimBlock_receivers.x - 1) / dimBlock_receivers.x,
      n_shots_h, 1);

  FILE *fp_ey = nullptr;
  FILE *fp_curl = nullptr;
  if (storage_mode_h == STORAGE_DISK) {
    if (ca_requires_grad) fp_ey = fopen(ey_filenames[0], "wb");
    if (cb_requires_grad) fp_curl = fopen(curl_filenames[0], "wb");
  }

  auto store1_offset_bytes = [&](int64_t step_idx) -> size_t {
    if (storage_mode_h == STORAGE_DEVICE) {
      return (size_t)step_idx * bytes_per_step_store;
    }
    if (storage_mode_h == STORAGE_CPU) {
      return (size_t)(step_idx % NUM_BUFFERS) * bytes_per_step_store;
    }
    return 0;
  };

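  // Snapshot staging layout: the *_store_1 buffers live on the device and the
  // *_store_3 buffers are the final destination.  STORAGE_DEVICE keeps every
  // stored step resident on the GPU (offset step_idx * bytes_per_step_store),
  // STORAGE_CPU cycles a NUM_BUFFERS-deep ring on the device and copies each
  // stored step to host memory on copy_stream, and STORAGE_DISK streams
  // through slot 0 and writes to the files opened above via
  // storage_save_snapshot_gpu().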
  auto run_step = [&](int64_t t) {
    forward_kernel_h<<<dimGrid, dimBlock, shmem_h_bytes>>>(
        cq, ey, hx, hz, m_ey_x, m_ey_z,
        ay, ayh, ax, axh, by, byh, bx, bxh,
        ky, kyh, kx, kxh);

    bool const store_step = ((t % step_ratio_h) == 0);
    bool const store_ey = store_step && ca_requires_grad;
    bool const store_curl = store_step && cb_requires_grad;
    bool const want_store = store_ey || store_curl;
    if (want_store) {
      int64_t const step_idx = t / step_ratio_h;
      int const store_buf = (storage_mode_h == STORAGE_CPU) ? (int)(step_idx % NUM_BUFFERS) : 0;
      if (storage_mode_h == STORAGE_CPU && copy_in_flight[store_buf]) {
#ifdef TIDE_PROFILING
        PROF_RECORD(prof_wait_start, 0);
#endif
        gpuErrchk(cudaStreamWaitEvent(0, copy_done[store_buf], 0));
#ifdef TIDE_PROFILING
        PROF_RECORD(prof_wait_end, 0);
        gpuErrchk(cudaDeviceSynchronize());
        float wait_ms;
        PROF_ELAPSED(prof_wait_start, prof_wait_end, wait_ms);
        total_wait_ms += wait_ms;
        n_waits++;
#endif
        copy_in_flight[store_buf] = false;
      }
      size_t const store1_offset = store1_offset_bytes(step_idx);

      void *__restrict const ey_store_1_t =
          (uint8_t *)ey_store_1 + store1_offset;
      void *__restrict const ey_store_3_t =
          (uint8_t *)ey_store_3 +
          (storage_mode_h == STORAGE_CPU
               ? (size_t)step_idx * bytes_per_step_store
               : 0);

      void *__restrict const curl_store_1_t =
          (uint8_t *)curl_store_1 + store1_offset;
      void *__restrict const curl_store_3_t =
          (uint8_t *)curl_store_3 +
          (storage_mode_h == STORAGE_CPU
               ? (size_t)step_idx * bytes_per_step_store
               : 0);

      if (storage_fp8_h) {
        forward_kernel_e_with_storage_fp8<<<dimGrid, dimBlock, shmem_e_bytes>>>(
            ca, cb, hx, hz, ey, m_hx_z, m_hz_x,
            store_ey ? (uint8_t *)ey_store_1_t : nullptr,
            store_curl ? (uint8_t *)curl_store_1_t : nullptr, ay, ayh, ax, axh,
            by, byh, bx, bxh, ky, kyh, kx, kxh, store_ey, store_curl);
      } else if (storage_bf16_h) {
        forward_kernel_e_with_storage_bf16<<<dimGrid, dimBlock, shmem_e_bytes>>>(
            ca, cb, hx, hz, ey, m_hx_z, m_hz_x,
            store_ey ? (__nv_bfloat16 *)ey_store_1_t : nullptr,
            store_curl ? (__nv_bfloat16 *)curl_store_1_t : nullptr, ay, ayh, ax,
            axh, by, byh, bx, bxh, ky, kyh, kx, kxh, store_ey, store_curl);
      } else {
        forward_kernel_e_with_storage<<<dimGrid, dimBlock, shmem_e_bytes>>>(
            ca, cb, hx, hz, ey, m_hx_z, m_hz_x,
            store_ey ? (TIDE_DTYPE *)ey_store_1_t : nullptr,
            store_curl ? (TIDE_DTYPE *)curl_store_1_t : nullptr, ay, ayh, ax,
            axh, by, byh, bx, bxh, ky, kyh, kx, kxh, store_ey, store_curl);
      }

      if (storage_mode_h == STORAGE_CPU) {
        gpuErrchk(cudaEventRecord(store_ready, 0));
        gpuErrchk(cudaStreamWaitEvent(copy_stream, store_ready, 0));
#ifdef TIDE_PROFILING
        PROF_RECORD(prof_copy_start, copy_stream);
#endif
        if (store_ey) {
          gpuErrchk(cudaMemcpyAsync(
              ey_store_3_t, ey_store_1_t, bytes_per_step_store,
              cudaMemcpyDeviceToHost, copy_stream));
        }
        if (store_curl) {
          gpuErrchk(cudaMemcpyAsync(
              curl_store_3_t, curl_store_1_t, bytes_per_step_store,
              cudaMemcpyDeviceToHost, copy_stream));
        }
#ifdef TIDE_PROFILING
        PROF_RECORD(prof_copy_end, copy_stream);
#endif
        gpuErrchk(cudaEventRecord(copy_done[store_buf], copy_stream));
        copy_in_flight[store_buf] = true;
#ifdef TIDE_PROFILING
        n_copies++;
#endif
      } else {
        if (store_ey) {
          storage_save_snapshot_gpu(
              ey_store_1_t, ey_store_3_t, fp_ey, storage_mode_h, step_idx,
              (size_t)shot_bytes_uncomp_h, (size_t)n_shots_h);
        }
        if (store_curl) {
          storage_save_snapshot_gpu(
              curl_store_1_t, curl_store_3_t, fp_curl, storage_mode_h, step_idx,
              (size_t)shot_bytes_uncomp_h, (size_t)n_shots_h);
        }
      }
    } else {
      forward_kernel_e<<<dimGrid, dimBlock, shmem_e_bytes>>>(
          ca, cb, hx, hz, ey, m_hx_z, m_hz_x, ay, ayh, ax, axh, by, byh, bx,
          bxh, ky, kyh, kx, kxh);
    }

    if (n_sources_per_shot_h > 0) {
      add_sources_ey<<<dimGrid_sources, dimBlock_sources>>>(
          ey, f + t * n_shots_h * n_sources_per_shot_h, sources_i);
    }

    if (n_receivers_per_shot_h > 0) {
      record_receivers_ey<<<dimGrid_receivers, dimBlock_receivers>>>(
          r + t * n_shots_h * n_receivers_per_shot_h, ey, receivers_i);
    }
  };

  for (int64_t t = start_t; t < start_t + nt; ++t) {
    run_step(t);
  }

  if (storage_mode_h == STORAGE_CPU) {
    gpuErrchk(cudaStreamSynchronize(copy_stream));
    for (int i = 0; i < NUM_BUFFERS; i++) {
      gpuErrchk(cudaEventDestroy(copy_done[i]));
    }
    gpuErrchk(cudaEventDestroy(store_ready));
    gpuErrchk(cudaStreamDestroy(copy_stream));
  }

  if (fp_ey != nullptr) fclose(fp_ey);
  if (fp_curl != nullptr) fclose(fp_curl);

  gpuErrchk(cudaPeekAtLastError());
}

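// Forward-side offload handshake under STORAGE_CPU: after the E kernel writes
// a snapshot into ring slot b = step_idx % NUM_BUFFERS, store_ready is
// recorded on the compute stream and copy_stream waits on it before issuing
// the device-to-host cudaMemcpyAsync; copy_done[b] is then recorded so that a
// later forward step only stalls if it wants to reuse slot b while that copy
// is still in flight.  For example, with NUM_BUFFERS == 2, stored step k
// overwrites the slot used by stored step k - 2, so compute and transfer of
// consecutive stored steps can overlap.
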
extern "C" void FUNC(backward)(
    TIDE_DTYPE const *const ca,
    TIDE_DTYPE const *const cb,
    TIDE_DTYPE const *const cq,
    TIDE_DTYPE const *const grad_r,
    TIDE_DTYPE *const lambda_ey,
    TIDE_DTYPE *const lambda_hx,
    TIDE_DTYPE *const lambda_hz,
    TIDE_DTYPE *const m_lambda_ey_x,
    TIDE_DTYPE *const m_lambda_ey_z,
    TIDE_DTYPE *const m_lambda_hx_z,
    TIDE_DTYPE *const m_lambda_hz_x,
    void *const ey_store_1,
    void *const ey_store_3,
    char const *const *const ey_filenames,
    void *const curl_store_1,
    void *const curl_store_3,
    char const *const *const curl_filenames,
    TIDE_DTYPE *const grad_f,
    TIDE_DTYPE *const grad_ca,
    TIDE_DTYPE *const grad_cb,
    TIDE_DTYPE *const grad_eps,
    TIDE_DTYPE *const grad_sigma,
    TIDE_DTYPE *const grad_ca_shot, // [n_shots, ny, nx] - per-shot gradient workspace
    TIDE_DTYPE *const grad_cb_shot, // [n_shots, ny, nx] - per-shot gradient workspace
    TIDE_DTYPE const *const ay,
    TIDE_DTYPE const *const by,
    TIDE_DTYPE const *const ayh,
    TIDE_DTYPE const *const byh,
    TIDE_DTYPE const *const ax,
    TIDE_DTYPE const *const bx,
    TIDE_DTYPE const *const axh,
    TIDE_DTYPE const *const bxh,
    TIDE_DTYPE const *const ky,
    TIDE_DTYPE const *const kyh,
    TIDE_DTYPE const *const kx,
    TIDE_DTYPE const *const kxh,
    int64_t const *const sources_i,
    int64_t const *const receivers_i,
    TIDE_DTYPE const rdy_h,
    TIDE_DTYPE const rdx_h,
    TIDE_DTYPE const dt_h,
    int64_t const nt,
    int64_t const n_shots_h,
    int64_t const ny_h,
    int64_t const nx_h,
    int64_t const n_sources_per_shot_h,
    int64_t const n_receivers_per_shot_h,
    int64_t const step_ratio_h,
    int64_t const storage_mode_h,
    int64_t const shot_bytes_uncomp_h,
    bool const ca_requires_grad,
    bool const cb_requires_grad,
    bool const ca_batched_h,
    bool const cb_batched_h,
    bool const cq_batched_h,
    int64_t const start_t,
    int64_t const pml_y0_h,
    int64_t const pml_x0_h,
    int64_t const pml_y1_h,
    int64_t const pml_x1_h,
    int64_t const n_threads,
    int64_t const device) {

  cudaSetDevice(device);
  (void)dt_h;
  (void)n_threads;

  int64_t const shot_numel_h = ny_h * nx_h;
  size_t const bytes_per_step_store =
      (size_t)shot_bytes_uncomp_h * (size_t)n_shots_h;
  bool const storage_bf16_h = (shot_bytes_uncomp_h == shot_numel_h * 2);
  bool const storage_fp8_h = (shot_bytes_uncomp_h == shot_numel_h);
  cudaStream_t copy_stream = nullptr;
  cudaEvent_t copy_done[NUM_BUFFERS];
  bool copy_in_flight[NUM_BUFFERS];
  for (int i = 0; i < NUM_BUFFERS; i++) copy_in_flight[i] = false;

#ifdef TIDE_PROFILING
  cudaEvent_t prof_prefetch_start, prof_prefetch_end, prof_wait_start, prof_wait_end;
  float total_prefetch_ms = 0.0f, total_wait_ms = 0.0f;
  int n_prefetches = 0, n_waits = 0;
#endif

  if (storage_mode_h == STORAGE_CPU) {
    gpuErrchk(cudaStreamCreateWithFlags(&copy_stream, cudaStreamNonBlocking));
#ifdef TIDE_PROFILING
    PROF_EVENT_CREATE(prof_prefetch_start);
    PROF_EVENT_CREATE(prof_prefetch_end);
    PROF_EVENT_CREATE(prof_wait_start);
    PROF_EVENT_CREATE(prof_wait_end);
    for (int i = 0; i < NUM_BUFFERS; i++) {
      PROF_EVENT_CREATE(copy_done[i]);
    }
#else
    for (int i = 0; i < NUM_BUFFERS; i++) {
      gpuErrchk(cudaEventCreateWithFlags(&copy_done[i], cudaEventDisableTiming));
    }
#endif
  }

  // Copy constants to device with caching to avoid redundant copies
  static TIDE_DTYPE cached_rdy3 = 0, cached_rdx3 = 0;
  static int64_t cached_n_shots3 = -1, cached_ny3 = -1, cached_nx3 = -1;
  static int64_t cached_shot_numel3 = -1, cached_n_sources_per_shot3 = -1, cached_n_receivers_per_shot3 = -1;
  static int64_t cached_pml_y03 = -1, cached_pml_y13 = -1;
  static int64_t cached_pml_x03 = -1, cached_pml_x13 = -1;
  static bool cached_ca_batched3 = false, cached_cb_batched3 = false, cached_cq_batched3 = false;
  static int64_t cached_device3 = -1;
  static bool first_call3 = true;

  if (first_call3 || cached_device3 != device || cached_rdy3 != rdy_h || cached_rdx3 != rdx_h ||
      cached_n_shots3 != n_shots_h || cached_ny3 != ny_h || cached_nx3 != nx_h ||
      cached_shot_numel3 != shot_numel_h || cached_n_sources_per_shot3 != n_sources_per_shot_h ||
      cached_n_receivers_per_shot3 != n_receivers_per_shot_h ||
      cached_pml_y03 != pml_y0_h || cached_pml_y13 != pml_y1_h ||
      cached_pml_x03 != pml_x0_h || cached_pml_x13 != pml_x1_h ||
      cached_ca_batched3 != ca_batched_h || cached_cb_batched3 != cb_batched_h ||
      cached_cq_batched3 != cq_batched_h) {

    cudaMemcpyToSymbol(rdy, &rdy_h, sizeof(TIDE_DTYPE));
    cudaMemcpyToSymbol(rdx, &rdx_h, sizeof(TIDE_DTYPE));
    cudaMemcpyToSymbol(n_shots, &n_shots_h, sizeof(int64_t));
    cudaMemcpyToSymbol(ny, &ny_h, sizeof(int64_t));
    cudaMemcpyToSymbol(nx, &nx_h, sizeof(int64_t));
    cudaMemcpyToSymbol(shot_numel, &shot_numel_h, sizeof(int64_t));
    cudaMemcpyToSymbol(n_sources_per_shot, &n_sources_per_shot_h, sizeof(int64_t));
    cudaMemcpyToSymbol(n_receivers_per_shot, &n_receivers_per_shot_h, sizeof(int64_t));
    cudaMemcpyToSymbol(pml_y0, &pml_y0_h, sizeof(int64_t));
    cudaMemcpyToSymbol(pml_y1, &pml_y1_h, sizeof(int64_t));
    cudaMemcpyToSymbol(pml_x0, &pml_x0_h, sizeof(int64_t));
    cudaMemcpyToSymbol(pml_x1, &pml_x1_h, sizeof(int64_t));
    cudaMemcpyToSymbol(ca_batched, &ca_batched_h, sizeof(bool));
    cudaMemcpyToSymbol(cb_batched, &cb_batched_h, sizeof(bool));
    cudaMemcpyToSymbol(cq_batched, &cq_batched_h, sizeof(bool));

    cached_rdy3 = rdy_h; cached_rdx3 = rdx_h;
    cached_n_shots3 = n_shots_h; cached_ny3 = ny_h; cached_nx3 = nx_h;
    cached_shot_numel3 = shot_numel_h; cached_n_sources_per_shot3 = n_sources_per_shot_h;
    cached_n_receivers_per_shot3 = n_receivers_per_shot_h;
    cached_pml_y03 = pml_y0_h; cached_pml_y13 = pml_y1_h;
    cached_pml_x03 = pml_x0_h; cached_pml_x13 = pml_x1_h;
    cached_ca_batched3 = ca_batched_h; cached_cb_batched3 = cb_batched_h;
    cached_cq_batched3 = cq_batched_h;
    cached_device3 = device;
    first_call3 = false;
  }

  dim3 dimBlock(32, 8, 1);
  int64_t gridx = (nx_h - 2 * FD_PAD + 2 + dimBlock.x - 1) / dimBlock.x;
  int64_t gridy = (ny_h - 2 * FD_PAD + 2 + dimBlock.y - 1) / dimBlock.y;
  int64_t gridz = n_shots_h;
  dim3 dimGrid(gridx, gridy, gridz);

  dim3 dimBlock_sources(32, 1, 1);
  dim3 dimGrid_sources(
      (n_sources_per_shot_h + dimBlock_sources.x - 1) / dimBlock_sources.x,
      n_shots_h, 1);

  dim3 dimBlock_receivers(32, 1, 1);
  dim3 dimGrid_receivers(
      (n_receivers_per_shot_h + dimBlock_receivers.x - 1) / dimBlock_receivers.x,
      n_shots_h, 1);

  FILE *fp_ey = nullptr;
  FILE *fp_curl = nullptr;
  if (storage_mode_h == STORAGE_DISK) {
    if (ca_requires_grad) fp_ey = fopen(ey_filenames[0], "rb");
    if (cb_requires_grad) fp_curl = fopen(curl_filenames[0], "rb");
  }

  auto store1_offset_bytes = [&](int64_t store_idx) -> size_t {
    if (storage_mode_h == STORAGE_DEVICE) {
      return (size_t)store_idx * bytes_per_step_store;
    }
    if (storage_mode_h == STORAGE_CPU) {
      return (size_t)(store_idx % NUM_BUFFERS) * bytes_per_step_store;
    }
    return 0;
  };

  auto store3_offset_bytes = [&](int64_t store_idx) -> size_t {
    return (storage_mode_h == STORAGE_CPU)
               ? (size_t)store_idx * bytes_per_step_store
               : 0;
  };

  auto prefetch_snapshots = [&](int64_t store_idx, bool want_ey, bool want_curl) {
    if (storage_mode_h != STORAGE_CPU || (!want_ey && !want_curl)) {
      return;
    }
    int const store_buf = (int)(store_idx % NUM_BUFFERS);
    if (copy_in_flight[store_buf]) {
      gpuErrchk(cudaStreamWaitEvent(copy_stream, copy_done[store_buf], 0));
    }
#ifdef TIDE_PROFILING
    PROF_RECORD(prof_prefetch_start, copy_stream);
#endif
    size_t const store1_offset = store1_offset_bytes(store_idx);
    size_t const store3_offset = store3_offset_bytes(store_idx);
    void *ey_store_1_t = (uint8_t *)ey_store_1 + store1_offset;
    void *curl_store_1_t = (uint8_t *)curl_store_1 + store1_offset;
    void *ey_store_3_t = (uint8_t *)ey_store_3 + store3_offset;
    void *curl_store_3_t = (uint8_t *)curl_store_3 + store3_offset;
    if (want_ey) {
      gpuErrchk(cudaMemcpyAsync(
          ey_store_1_t, ey_store_3_t, bytes_per_step_store,
          cudaMemcpyHostToDevice, copy_stream));
    }
    if (want_curl) {
      gpuErrchk(cudaMemcpyAsync(
          curl_store_1_t, curl_store_3_t, bytes_per_step_store,
          cudaMemcpyHostToDevice, copy_stream));
    }
#ifdef TIDE_PROFILING
    PROF_RECORD(prof_prefetch_end, copy_stream);
#endif
    gpuErrchk(cudaEventRecord(copy_done[store_buf], copy_stream));
    copy_in_flight[store_buf] = true;
#ifdef TIDE_PROFILING
    n_prefetches++;
#endif
  };

  int64_t const t_min = start_t - nt;
  if (storage_mode_h == STORAGE_CPU && (ca_requires_grad || cb_requires_grad)) {
    int64_t t_prefetch = start_t - 1;
    int64_t const mod = t_prefetch % step_ratio_h;
    if (mod != 0) t_prefetch -= mod;
    if (t_prefetch >= t_min) {
      prefetch_snapshots(
          t_prefetch / step_ratio_h, ca_requires_grad, cb_requires_grad);
    }
  }

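  // Each iteration of the loop below handles forward step t: the receiver
  // residual grad_r is injected into lambda_ey, the adjoint field at the
  // source positions is recorded into grad_f, and on steps where t is a
  // multiple of step_ratio the matching forward snapshot is combined with
  // lambda_ey to accumulate grad_ca_shot / grad_cb_shot, following the
  // gradient expressions in the file header.  Under STORAGE_CPU the snapshot
  // needed by the next gradient step is prefetched on copy_stream while the
  // current kernels run.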
  // Time reversed loop
  for (int64_t t = start_t - 1; t >= start_t - nt; --t) {
    // Inject adjoint source (receiver residual) at receiver locations
    // Use add_adjoint_sources_ey which checks n_receivers_per_shot
    if (n_receivers_per_shot_h > 0) {
      add_adjoint_sources_ey<<<dimGrid_receivers, dimBlock_receivers>>>(
          lambda_ey, grad_r + t * n_shots_h * n_receivers_per_shot_h, receivers_i);
    }

    // Record adjoint field at source locations for source gradient
    // Use record_adjoint_at_sources which checks n_sources_per_shot
    if (n_sources_per_shot_h > 0) {
      record_adjoint_at_sources<<<dimGrid_sources, dimBlock_sources>>>(
          grad_f + t * n_shots_h * n_sources_per_shot_h,
          lambda_ey, sources_i);
    }

    int64_t const store_idx = t / step_ratio_h;
    bool const do_grad = (t % step_ratio_h) == 0;
    bool const grad_ey = do_grad && ca_requires_grad;
    bool const grad_curl = do_grad && cb_requires_grad;

    size_t const store1_offset = store1_offset_bytes(store_idx);
    size_t const store3_offset = store3_offset_bytes(store_idx);

    void *__restrict const ey_store_1_t =
        (uint8_t *)ey_store_1 + store1_offset;
    void *__restrict const ey_store_3_t =
        (uint8_t *)ey_store_3 + store3_offset;

    void *__restrict const curl_store_1_t =
        (uint8_t *)curl_store_1 + store1_offset;
    void *__restrict const curl_store_3_t =
        (uint8_t *)curl_store_3 + store3_offset;

    if (storage_mode_h == STORAGE_CPU && (grad_ey || grad_curl)) {
      int const store_buf = (int)(store_idx % NUM_BUFFERS);
      if (!copy_in_flight[store_buf]) {
        prefetch_snapshots(store_idx, grad_ey, grad_curl);
      }
#ifdef TIDE_PROFILING
      PROF_RECORD(prof_wait_start, 0);
#endif
      gpuErrchk(cudaStreamWaitEvent(0, copy_done[store_buf], 0));
#ifdef TIDE_PROFILING
      PROF_RECORD(prof_wait_end, 0);
      gpuErrchk(cudaDeviceSynchronize());
      float wait_ms;
      PROF_ELAPSED(prof_wait_start, prof_wait_end, wait_ms);
      total_wait_ms += wait_ms;
      n_waits++;
#endif
      copy_in_flight[store_buf] = false;
    } else if (storage_mode_h == STORAGE_DISK) {
      if (grad_ey) {
        storage_load_snapshot_gpu(
            (void *)ey_store_1_t, (void *)ey_store_3_t, fp_ey, storage_mode_h,
            store_idx, (size_t)shot_bytes_uncomp_h, (size_t)n_shots_h);
      }
      if (grad_curl) {
        storage_load_snapshot_gpu(
            (void *)curl_store_1_t, (void *)curl_store_3_t, fp_curl, storage_mode_h,
            store_idx, (size_t)shot_bytes_uncomp_h, (size_t)n_shots_h);
      }
    }

    // Backward λ_H fields update
    backward_kernel_lambda_h<<<dimGrid, dimBlock>>>(
        cb, lambda_ey, lambda_hx, lambda_hz,
        m_lambda_ey_x, m_lambda_ey_z,
        ay, ayh, ax, axh, by, byh, bx, bxh,
        ky, kyh, kx, kxh);

    // Backward λ_Ey update (specialized kernel when no gradient is needed).
    if (grad_ey || grad_curl) {
      if (storage_fp8_h) {
        backward_kernel_lambda_e_with_grad_fp8<<<dimGrid, dimBlock>>>(
            ca, cq, lambda_hx, lambda_hz, lambda_ey,
            m_lambda_hx_z, m_lambda_hz_x,
            grad_ey ? (uint8_t const *)ey_store_1_t : nullptr,
            grad_curl ? (uint8_t const *)curl_store_1_t : nullptr,
            grad_ca_shot, grad_cb_shot,
            ay, ayh, ax, axh, by, byh, bx, bxh,
            ky, kyh, kx, kxh,
            grad_ey, grad_curl,
            step_ratio_h);
      } else if (storage_bf16_h) {
        backward_kernel_lambda_e_with_grad_bf16<<<dimGrid, dimBlock>>>(
            ca, cq, lambda_hx, lambda_hz, lambda_ey,
            m_lambda_hx_z, m_lambda_hz_x,
            grad_ey ? (__nv_bfloat16 const *)ey_store_1_t : nullptr,
            grad_curl ? (__nv_bfloat16 const *)curl_store_1_t : nullptr,
            grad_ca_shot, grad_cb_shot,
            ay, ayh, ax, axh, by, byh, bx, bxh,
            ky, kyh, kx, kxh,
            grad_ey, grad_curl,
            step_ratio_h);
      } else {
        backward_kernel_lambda_e_with_grad<<<dimGrid, dimBlock>>>(
            ca, cq, lambda_hx, lambda_hz, lambda_ey,
            m_lambda_hx_z, m_lambda_hz_x,
            grad_ey ? (TIDE_DTYPE const *)ey_store_1_t : nullptr,
            grad_curl ? (TIDE_DTYPE const *)curl_store_1_t : nullptr,
            grad_ca_shot, grad_cb_shot,
            ay, ayh, ax, axh, by, byh, bx, bxh,
            ky, kyh, kx, kxh,
            grad_ey, grad_curl,
            step_ratio_h);
      }
    } else {
      backward_kernel_lambda_e<<<dimGrid, dimBlock>>>(
          ca, cq, lambda_hx, lambda_hz, lambda_ey,
          m_lambda_hx_z, m_lambda_hz_x,
          ay, ayh, ax, axh, by, byh, bx, bxh,
          ky, kyh, kx, kxh);
    }

    if (storage_mode_h == STORAGE_CPU && do_grad &&
        (ca_requires_grad || cb_requires_grad)) {
      int64_t const next_t = t - step_ratio_h;
      if (next_t >= t_min) {
        prefetch_snapshots(store_idx - 1, ca_requires_grad, cb_requires_grad);
      }
    }
  }

  if (storage_mode_h == STORAGE_CPU) {
    gpuErrchk(cudaStreamSynchronize(copy_stream));
#ifdef TIDE_PROFILING
    // Compute and print profiling statistics
    if (n_prefetches > 0) {
      gpuErrchk(cudaDeviceSynchronize());
      float avg_prefetch_ms = 0.0f;
      for (int i = 0; i < NUM_BUFFERS; i++) {
        float ms;
        // Note: per-copy timing would require more events
      }
      PROF_PRINT("Backward H2D prefetch count", (float)n_prefetches);
    }
    if (n_waits > 0) {
      float avg_wait_ms = total_wait_ms / n_waits;
      PROF_PRINT("Backward avg wait time", avg_wait_ms);
      PROF_PRINT("Backward total wait time", total_wait_ms);
    }
    PROF_EVENT_CREATE(prof_prefetch_start); // Dummy to avoid unused warning
    cudaEventDestroy(prof_prefetch_start);
    cudaEventDestroy(prof_prefetch_end);
    cudaEventDestroy(prof_wait_start);
    cudaEventDestroy(prof_wait_end);
#endif
    for (int i = 0; i < NUM_BUFFERS; i++) {
      gpuErrchk(cudaEventDestroy(copy_done[i]));
    }
    gpuErrchk(cudaStreamDestroy(copy_stream));
  }

  if (fp_ey != nullptr) fclose(fp_ey);
  if (fp_curl != nullptr) fclose(fp_curl);

  // Combine per-shot gradients (only if not batched - batched case keeps per-shot grads)
  dim3 dimBlock_combine(32, 32, 1);
  dim3 dimGrid_combine(
      (nx_h - 2 * FD_PAD + dimBlock_combine.x - 1) / dimBlock_combine.x,
      (ny_h - 2 * FD_PAD + dimBlock_combine.y - 1) / dimBlock_combine.y, 1);

  if (ca_requires_grad && !ca_batched_h) {
    combine_grad<<<dimGrid_combine, dimBlock_combine>>>(grad_ca, grad_ca_shot);
  }
  if (cb_requires_grad && !cb_batched_h) {
    combine_grad<<<dimGrid_combine, dimBlock_combine>>>(grad_cb, grad_cb_shot);
  }

  if ((grad_eps != nullptr || grad_sigma != nullptr) && (ca_requires_grad || cb_requires_grad)) {
    dim3 dimBlock_conv(32, 8, 1);
    dim3 dimGrid_conv(
        (nx_h + dimBlock_conv.x - 1) / dimBlock_conv.x,
        (ny_h + dimBlock_conv.y - 1) / dimBlock_conv.y,
        ca_batched_h ? n_shots_h : 1);
    convert_grad_ca_cb_to_eps_sigma<<<dimGrid_conv, dimBlock_conv>>>(
        ca, cb, grad_ca, grad_cb, grad_ca_shot, grad_cb_shot,
        grad_eps, grad_sigma, dt_h,
        ca_requires_grad, cb_requires_grad,
        ca_batched_h, cb_batched_h);
  }

  gpuErrchk(cudaPeekAtLastError());
}