tide-GPR 0.0.9__py3-none-manylinux_2_28_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tide/__init__.py +65 -0
- tide/autograd_utils.py +26 -0
- tide/backend_utils.py +536 -0
- tide/callbacks.py +348 -0
- tide/cfl.py +64 -0
- tide/csrc/CMakeLists.txt +263 -0
- tide/csrc/common_cpu.h +31 -0
- tide/csrc/common_gpu.h +56 -0
- tide/csrc/maxwell.c +2133 -0
- tide/csrc/maxwell.cu +2297 -0
- tide/csrc/maxwell_born.cu +0 -0
- tide/csrc/staggered_grid.h +175 -0
- tide/csrc/staggered_grid_3d.h +124 -0
- tide/csrc/storage_utils.c +78 -0
- tide/csrc/storage_utils.cu +135 -0
- tide/csrc/storage_utils.h +36 -0
- tide/grid_utils.py +31 -0
- tide/maxwell.py +2651 -0
- tide/padding.py +139 -0
- tide/resampling.py +246 -0
- tide/staggered.py +567 -0
- tide/storage.py +131 -0
- tide/tide/libtide_C.so +0 -0
- tide/utils.py +274 -0
- tide/validation.py +71 -0
- tide/wavelets.py +72 -0
- tide_gpr-0.0.9.dist-info/METADATA +256 -0
- tide_gpr-0.0.9.dist-info/RECORD +31 -0
- tide_gpr-0.0.9.dist-info/WHEEL +5 -0
- tide_gpr-0.0.9.dist-info/licenses/LICENSE +46 -0
- tide_gpr.libs/libgomp-24e2ab19.so.1.0.0 +0 -0
tide/csrc/maxwell.c
ADDED
|
@@ -0,0 +1,2133 @@
|
|
|
1
|
+
/*
|
|
2
|
+
* Maxwell wave equation propagator (CPU implementation)
|
|
3
|
+
*
|
|
4
|
+
* This file contains the CPU implementation of the 2D TM Maxwell equations
|
|
5
|
+
* propagator with complete Adjoint State Method (ASM) support for gradient
|
|
6
|
+
* computation.
|
|
7
|
+
*
|
|
8
|
+
* TM mode fields: Ey (electric), Hx, Hz (magnetic)
|
|
9
|
+
*
|
|
10
|
+
* Adjoint State Method for Maxwell TM equations:
|
|
11
|
+
* ================================================
|
|
12
|
+
* Forward equations (discrete):
|
|
13
|
+
* E_y^{n+1} = C_a * E_y^n + C_b * (∂H_z/∂x - ∂H_x/∂z)
|
|
14
|
+
* H_x^{n+1/2} = H_x^{n-1/2} - C_q * ∂E_y/∂z
|
|
15
|
+
* H_z^{n+1/2} = H_z^{n-1/2} + C_q * ∂E_y/∂x
|
|
16
|
+
*
|
|
17
|
+
* Adjoint equations (time-reversed):
|
|
18
|
+
* λ_Ey^n = C_a * λ_Ey^{n+1} + C_q * (∂λ_Hz/∂x - ∂λ_Hx/∂z) + residual_injection
|
|
19
|
+
* λ_Hx^{n-1/2} = λ_Hx^{n+1/2} - C_b * ∂λ_Ey/∂z
|
|
20
|
+
* λ_Hz^{n-1/2} = λ_Hz^{n+1/2} + C_b * ∂λ_Ey/∂x
|
|
21
|
+
*
|
|
22
|
+
* Model gradients:
|
|
23
|
+
* ∂J/∂C_a = Σ_n E_y^n * λ_Ey^{n+1}
|
|
24
|
+
* ∂J/∂C_b = Σ_n curl_H^n * λ_Ey^{n+1}
|
|
25
|
+
*
|
|
26
|
+
* Storage requirements:
|
|
27
|
+
* - ey_store: E_y field at each step_ratio time step [nt/step_ratio, n_shots, ny, nx]
|
|
28
|
+
* - curl_h_store: (∂H_z/∂x - ∂H_x/∂z) at each step_ratio time step [nt/step_ratio, n_shots, ny, nx]
|
|
29
|
+
*/
|
|
30
|
+
|
|
31
|
+
#include <stdio.h>
|
|
32
|
+
#include <stdint.h>
|
|
33
|
+
#include <stdbool.h>
|
|
34
|
+
#include <stdlib.h>
|
|
35
|
+
#include <string.h>
|
|
36
|
+
#include <math.h>
|
|
37
|
+
|
|
38
|
+
#ifdef _OPENMP
|
|
39
|
+
#include <omp.h>
|
|
40
|
+
#endif
|
|
41
|
+
|
|
42
|
+
#include "common_cpu.h"
|
|
43
|
+
#include "staggered_grid.h"
|
|
44
|
+
#include "storage_utils.h"
|
|
45
|
+
|
|
46
|
+
/*
 * Symbol-name builder. FUNC(fn) expands to
 *   maxwell_tm_<TIDE_STENCIL>_<TIDE_DTYPE>_<fn>_cpu
 * CAT adds one level of indirection so TIDE_STENCIL and TIDE_DTYPE are
 * macro-expanded before CAT_I pastes the tokens together.
 */
#define CAT_I(fn, acc, ty, dev) maxwell_tm_##acc##_##ty##_##fn##_##dev
#define CAT(fn, acc, ty, dev) CAT_I(fn, acc, ty, dev)
#define FUNC(fn) CAT(fn, TIDE_STENCIL, TIDE_DTYPE, cpu)

/*
 * Row-major 2D indexing. IDX reads `nx`, and IDX_SHOT additionally reads
 * `shot_numel`, from the scope in which the macro is expanded.
 */
#define IDX(row, col) ((row) * nx + (col))
#define IDX_SHOT(s, row, col) ((s) * shot_numel + (row) * nx + (col))

/* Both arguments may be evaluated twice: pass side-effect-free expressions. */
#define MAX(p, q) ((p) > (q) ? (p) : (q))
#define MIN(p, q) ((p) < (q) ? (p) : (q))
|
|
58
|
+
|
|
59
|
+
// Vacuum permittivity (F/m) to convert dL/d(epsilon_abs) -> dL/d(epsilon_r)
#define EP0 ((TIDE_DTYPE)8.8541878128e-12)

/* bfloat16 kept as its raw bit pattern: the upper 16 bits of a binary32. */
typedef uint16_t tide_bfloat16;

/*
 * Convert a float to bfloat16 using round-to-nearest-even.
 *
 * NaN is handled before rounding. The generic rounding step adds up to
 * 0x8000 to the bit pattern, which would carry a NaN with a small payload
 * into the exponent field (0x7F800001 -> +Inf) or past the sign bit
 * (0x7FFFFFFF -> -0.0). Instead, the sign and upper payload bits are kept and
 * the quiet bit is forced, so the result is always a NaN. The test is done on
 * the integer bit pattern so it still works when compiled with -ffast-math.
 * Finite values and infinities are unaffected.
 */
static inline tide_bfloat16 tide_float_to_bf16(float value) {
  union {
    float f;
    uint32_t u;
  } tmp;
  tmp.f = value;
  if ((tmp.u & 0x7FFFFFFFu) > 0x7F800000u) {
    /* NaN: truncate and set the quiet bit (bit 6 of the bf16 mantissa). */
    return (tide_bfloat16)((tmp.u >> 16) | 0x0040u);
  }
  /* Round half to even on the 16 discarded bits. */
  uint32_t lsb = (tmp.u >> 16) & 1u;
  tmp.u += 0x7FFFu + lsb;
  return (tide_bfloat16)(tmp.u >> 16);
}

/* Widen a bfloat16 bit pattern to float; exact for every input. */
static inline float tide_bf16_to_float(tide_bfloat16 value) {
  union {
    uint32_t u;
    float f;
  } tmp;
  tmp.u = ((uint32_t)value) << 16;
  return tmp.f;
}
|
|
83
|
+
|
|
84
|
+
// FP8 E4M3 format (1 sign bit, 4 exponent bits, 3 mantissa bits)
|
|
85
|
+
// Dynamic range: ~10^-5 to ~448
|
|
86
|
+
typedef uint8_t tide_fp8_e4m3;
|
|
87
|
+
|
|
88
|
+
static inline tide_fp8_e4m3 tide_float_to_fp8_e4m3(float value) {
|
|
89
|
+
if (value == 0.0f) {
|
|
90
|
+
return 0;
|
|
91
|
+
}
|
|
92
|
+
|
|
93
|
+
union {
|
|
94
|
+
float f;
|
|
95
|
+
uint32_t u;
|
|
96
|
+
} tmp;
|
|
97
|
+
tmp.f = value;
|
|
98
|
+
|
|
99
|
+
uint8_t sign = (tmp.u >> 31) ? 0x80 : 0;
|
|
100
|
+
float ax = fabsf(value);
|
|
101
|
+
|
|
102
|
+
// Handle infinity/NaN - saturate to max representable
|
|
103
|
+
if (!isfinite(ax)) {
|
|
104
|
+
return (uint8_t)(sign | 0x7F);
|
|
105
|
+
}
|
|
106
|
+
|
|
107
|
+
// Use frexpf to extract mantissa and exponent
|
|
108
|
+
int exp;
|
|
109
|
+
float m = frexpf(ax, &exp); // ax = m * 2^exp, m in [0.5, 1)
|
|
110
|
+
int e = exp - 1;
|
|
111
|
+
int exp_field = e + 7; // Bias of 7 for E4M3
|
|
112
|
+
int mant = 0;
|
|
113
|
+
|
|
114
|
+
if (exp_field <= 0) {
|
|
115
|
+
// Subnormal - compute mantissa directly
|
|
116
|
+
// Smallest normal: 2^-6, subnormal range: 2^-9 to 2^-6
|
|
117
|
+
mant = (int)roundf(ax * 512.0f);
|
|
118
|
+
if (mant <= 0) {
|
|
119
|
+
return sign; // Underflow to zero
|
|
120
|
+
}
|
|
121
|
+
if (mant > 7) {
|
|
122
|
+
mant = 7; // Clamp to max subnormal
|
|
123
|
+
}
|
|
124
|
+
exp_field = 0;
|
|
125
|
+
} else if (exp_field >= 0xF) {
|
|
126
|
+
// Overflow - saturate to max normal
|
|
127
|
+
exp_field = 0xE;
|
|
128
|
+
mant = 7;
|
|
129
|
+
} else {
|
|
130
|
+
// Normal number
|
|
131
|
+
float frac = m * 2.0f - 1.0f; // Extract fractional part
|
|
132
|
+
mant = (int)roundf(frac * 8.0f);
|
|
133
|
+
if (mant == 8) {
|
|
134
|
+
// Rounding overflow - increment exponent
|
|
135
|
+
mant = 0;
|
|
136
|
+
exp_field += 1;
|
|
137
|
+
if (exp_field >= 0xF) {
|
|
138
|
+
exp_field = 0xE;
|
|
139
|
+
mant = 7;
|
|
140
|
+
}
|
|
141
|
+
}
|
|
142
|
+
}
|
|
143
|
+
|
|
144
|
+
return (uint8_t)(sign | ((uint8_t)exp_field << 3) | (uint8_t)(mant & 0x7));
|
|
145
|
+
}
|
|
146
|
+
|
|
147
|
+
static inline float tide_fp8_e4m3_to_float(tide_fp8_e4m3 value) {
|
|
148
|
+
if (value == 0) {
|
|
149
|
+
return 0.0f;
|
|
150
|
+
}
|
|
151
|
+
|
|
152
|
+
int sign = value & 0x80;
|
|
153
|
+
int exp_field = (value >> 3) & 0xF;
|
|
154
|
+
int mant = value & 0x7;
|
|
155
|
+
float val;
|
|
156
|
+
|
|
157
|
+
if (exp_field == 0) {
|
|
158
|
+
// Subnormal
|
|
159
|
+
float frac = (float)mant / 8.0f;
|
|
160
|
+
val = ldexpf(frac, -6); // 2^-6 * (mant/8)
|
|
161
|
+
} else {
|
|
162
|
+
// Normal
|
|
163
|
+
float frac = 1.0f + (float)mant / 8.0f;
|
|
164
|
+
val = ldexpf(frac, exp_field - 7);
|
|
165
|
+
}
|
|
166
|
+
|
|
167
|
+
return sign ? -val : val;
|
|
168
|
+
}
|
|
169
|
+
|
|
170
|
+
/*
 * Stencil access helpers. Each macro addresses the cell at offset (oy, ox)
 * from the current cell (y, x) of shot `shot_idx`. `y`, `x`, `shot_idx`,
 * `nx`, `shot_numel` and the named array must all be in scope where the
 * macro is expanded.
 */

/* Forward wavefields. */
#define EY(oy, ox) ey[IDX_SHOT(shot_idx, y + (oy), x + (ox))]
#define HX(oy, ox) hx[IDX_SHOT(shot_idx, y + (oy), x + (ox))]
#define HZ(oy, ox) hz[IDX_SHOT(shot_idx, y + (oy), x + (ox))]

/* Adjoint wavefields. */
#define LAMBDA_EY(oy, ox) lambda_ey[IDX_SHOT(shot_idx, y + (oy), x + (ox))]
#define LAMBDA_HX(oy, ox) lambda_hx[IDX_SHOT(shot_idx, y + (oy), x + (ox))]
#define LAMBDA_HZ(oy, ox) lambda_hz[IDX_SHOT(shot_idx, y + (oy), x + (ox))]

/*
 * Update coefficients. When the matching *_batched flag is false, one
 * coefficient array is shared by every shot and indexed without a shot offset.
 */
#define CA(oy, ox) \
  (ca_batched ? ca[IDX_SHOT(shot_idx, y + (oy), x + (ox))] \
              : ca[IDX(y + (oy), x + (ox))])
#define CB(oy, ox) \
  (cb_batched ? cb[IDX_SHOT(shot_idx, y + (oy), x + (ox))] \
              : cb[IDX(y + (oy), x + (ox))])
#define CQ(oy, ox) \
  (cq_batched ? cq[IDX_SHOT(shot_idx, y + (oy), x + (ox))] \
              : cq[IDX(y + (oy), x + (ox))])

/* PML memory variables of the forward problem. */
#define M_EY_X(oy, ox) m_ey_x[IDX_SHOT(shot_idx, y + (oy), x + (ox))]
#define M_EY_Z(oy, ox) m_ey_z[IDX_SHOT(shot_idx, y + (oy), x + (ox))]
#define M_HX_Z(oy, ox) m_hx_z[IDX_SHOT(shot_idx, y + (oy), x + (ox))]
#define M_HZ_X(oy, ox) m_hz_x[IDX_SHOT(shot_idx, y + (oy), x + (ox))]

/* PML memory variables of the adjoint problem. */
#define M_LAMBDA_EY_X(oy, ox) m_lambda_ey_x[IDX_SHOT(shot_idx, y + (oy), x + (ox))]
#define M_LAMBDA_EY_Z(oy, ox) m_lambda_ey_z[IDX_SHOT(shot_idx, y + (oy), x + (ox))]
#define M_LAMBDA_HX_Z(oy, ox) m_lambda_hx_z[IDX_SHOT(shot_idx, y + (oy), x + (ox))]
#define M_LAMBDA_HZ_X(oy, ox) m_lambda_hz_x[IDX_SHOT(shot_idx, y + (oy), x + (ox))]
|
|
196
|
+
|
|
197
|
+
static void convert_grad_ca_cb_to_eps_sigma(
|
|
198
|
+
TIDE_DTYPE const *__restrict const ca,
|
|
199
|
+
TIDE_DTYPE const *__restrict const cb,
|
|
200
|
+
TIDE_DTYPE const *__restrict const grad_ca,
|
|
201
|
+
TIDE_DTYPE const *__restrict const grad_cb,
|
|
202
|
+
TIDE_DTYPE *__restrict const grad_eps,
|
|
203
|
+
TIDE_DTYPE *__restrict const grad_sigma,
|
|
204
|
+
TIDE_DTYPE const dt,
|
|
205
|
+
int64_t const n_shots,
|
|
206
|
+
int64_t const ny,
|
|
207
|
+
int64_t const nx,
|
|
208
|
+
bool const ca_batched,
|
|
209
|
+
bool const cb_batched,
|
|
210
|
+
bool const ca_requires_grad,
|
|
211
|
+
bool const cb_requires_grad) {
|
|
212
|
+
if ((grad_eps == NULL && grad_sigma == NULL) || (!ca_requires_grad && !cb_requires_grad)) {
|
|
213
|
+
return;
|
|
214
|
+
}
|
|
215
|
+
|
|
216
|
+
int64_t const shot_numel = ny * nx;
|
|
217
|
+
int64_t const out_shots = ca_batched ? n_shots : 1;
|
|
218
|
+
TIDE_DTYPE const inv_dt = (TIDE_DTYPE)1 / dt;
|
|
219
|
+
|
|
220
|
+
TIDE_OMP_INDEX shot_idx;
|
|
221
|
+
TIDE_OMP_PARALLEL_FOR
|
|
222
|
+
for (shot_idx = 0; shot_idx < out_shots; ++shot_idx) {
|
|
223
|
+
int64_t const shot_offset = shot_idx * shot_numel;
|
|
224
|
+
int64_t const out_offset = ca_batched ? shot_offset : 0;
|
|
225
|
+
int64_t const ca_offset = ca_batched ? shot_offset : 0;
|
|
226
|
+
int64_t const cb_offset = cb_batched ? shot_offset : 0;
|
|
227
|
+
TIDE_DTYPE const *__restrict const ca_ptr = ca + ca_offset;
|
|
228
|
+
TIDE_DTYPE const *__restrict const cb_ptr = cb + cb_offset;
|
|
229
|
+
TIDE_OMP_SIMD_COLLAPSE2
|
|
230
|
+
for (int64_t y = 0; y < ny; ++y) {
|
|
231
|
+
for (int64_t x = 0; x < nx; ++x) {
|
|
232
|
+
int64_t const idx = IDX(y, x);
|
|
233
|
+
int64_t const out_idx = out_offset + idx;
|
|
234
|
+
|
|
235
|
+
TIDE_DTYPE const ca_val = ca_ptr[idx];
|
|
236
|
+
TIDE_DTYPE const cb_val = cb_ptr[idx];
|
|
237
|
+
|
|
238
|
+
TIDE_DTYPE const grad_ca_val =
|
|
239
|
+
(ca_requires_grad && grad_ca != NULL) ? grad_ca[out_idx] : (TIDE_DTYPE)0;
|
|
240
|
+
TIDE_DTYPE const grad_cb_val =
|
|
241
|
+
(cb_requires_grad && grad_cb != NULL) ? grad_cb[out_idx] : (TIDE_DTYPE)0;
|
|
242
|
+
|
|
243
|
+
TIDE_DTYPE const cb_sq = cb_val * cb_val;
|
|
244
|
+
TIDE_DTYPE const dca_de = ((TIDE_DTYPE)1 - ca_val) * cb_val * inv_dt;
|
|
245
|
+
TIDE_DTYPE const dcb_de = -cb_sq * inv_dt;
|
|
246
|
+
TIDE_DTYPE const dca_ds = -((TIDE_DTYPE)0.5) * ((TIDE_DTYPE)1 + ca_val) * cb_val;
|
|
247
|
+
TIDE_DTYPE const dcb_ds = -((TIDE_DTYPE)0.5) * cb_sq;
|
|
248
|
+
|
|
249
|
+
if (grad_eps != NULL) {
|
|
250
|
+
TIDE_DTYPE const grad_e = grad_ca_val * dca_de + grad_cb_val * dcb_de;
|
|
251
|
+
grad_eps[out_idx] = grad_e * EP0;
|
|
252
|
+
}
|
|
253
|
+
if (grad_sigma != NULL) {
|
|
254
|
+
grad_sigma[out_idx] = grad_ca_val * dca_ds + grad_cb_val * dcb_ds;
|
|
255
|
+
}
|
|
256
|
+
}
|
|
257
|
+
}
|
|
258
|
+
}
|
|
259
|
+
}
|
|
260
|
+
|
|
261
|
+
|
|
262
|
+
static void add_sources_ey(
|
|
263
|
+
TIDE_DTYPE *__restrict const ey,
|
|
264
|
+
TIDE_DTYPE const *__restrict const f,
|
|
265
|
+
int64_t const *__restrict const sources_i,
|
|
266
|
+
int64_t const n_shots,
|
|
267
|
+
int64_t const shot_numel,
|
|
268
|
+
int64_t const n_sources_per_shot) {
|
|
269
|
+
|
|
270
|
+
TIDE_OMP_INDEX shot_idx;
|
|
271
|
+
TIDE_OMP_PARALLEL_FOR
|
|
272
|
+
for (shot_idx = 0; shot_idx < n_shots; ++shot_idx) {
|
|
273
|
+
TIDE_OMP_SIMD
|
|
274
|
+
for (int64_t source_idx = 0; source_idx < n_sources_per_shot; ++source_idx) {
|
|
275
|
+
int64_t k = shot_idx * n_sources_per_shot + source_idx;
|
|
276
|
+
if (sources_i[k] >= 0) {
|
|
277
|
+
ey[shot_idx * shot_numel + sources_i[k]] += f[k];
|
|
278
|
+
}
|
|
279
|
+
}
|
|
280
|
+
}
|
|
281
|
+
}
|
|
282
|
+
|
|
283
|
+
static void subtract_sources_ey(
|
|
284
|
+
TIDE_DTYPE *__restrict const ey,
|
|
285
|
+
TIDE_DTYPE const *__restrict const f,
|
|
286
|
+
int64_t const *__restrict const sources_i,
|
|
287
|
+
int64_t const n_shots,
|
|
288
|
+
int64_t const shot_numel,
|
|
289
|
+
int64_t const n_sources_per_shot) {
|
|
290
|
+
|
|
291
|
+
TIDE_OMP_INDEX shot_idx;
|
|
292
|
+
TIDE_OMP_PARALLEL_FOR
|
|
293
|
+
for (shot_idx = 0; shot_idx < n_shots; ++shot_idx) {
|
|
294
|
+
TIDE_OMP_SIMD
|
|
295
|
+
for (int64_t source_idx = 0; source_idx < n_sources_per_shot; ++source_idx) {
|
|
296
|
+
int64_t k = shot_idx * n_sources_per_shot + source_idx;
|
|
297
|
+
if (sources_i[k] >= 0) {
|
|
298
|
+
ey[shot_idx * shot_numel + sources_i[k]] -= f[k];
|
|
299
|
+
}
|
|
300
|
+
}
|
|
301
|
+
}
|
|
302
|
+
}
|
|
303
|
+
|
|
304
|
+
|
|
305
|
+
static void record_receivers_ey(
|
|
306
|
+
TIDE_DTYPE *__restrict const r,
|
|
307
|
+
TIDE_DTYPE const *__restrict const ey,
|
|
308
|
+
int64_t const *__restrict const receivers_i,
|
|
309
|
+
int64_t const n_shots,
|
|
310
|
+
int64_t const shot_numel,
|
|
311
|
+
int64_t const n_receivers_per_shot) {
|
|
312
|
+
|
|
313
|
+
TIDE_OMP_INDEX shot_idx;
|
|
314
|
+
TIDE_OMP_PARALLEL_FOR
|
|
315
|
+
for (shot_idx = 0; shot_idx < n_shots; ++shot_idx) {
|
|
316
|
+
TIDE_OMP_SIMD
|
|
317
|
+
for (int64_t receiver_idx = 0; receiver_idx < n_receivers_per_shot; ++receiver_idx) {
|
|
318
|
+
int64_t k = shot_idx * n_receivers_per_shot + receiver_idx;
|
|
319
|
+
if (receivers_i[k] >= 0) {
|
|
320
|
+
r[k] = ey[shot_idx * shot_numel + receivers_i[k]];
|
|
321
|
+
}
|
|
322
|
+
}
|
|
323
|
+
}
|
|
324
|
+
}
|
|
325
|
+
|
|
326
|
+
static void gather_boundary_3_cpu(
|
|
327
|
+
TIDE_DTYPE const *__restrict const ey,
|
|
328
|
+
TIDE_DTYPE const *__restrict const hx,
|
|
329
|
+
TIDE_DTYPE const *__restrict const hz,
|
|
330
|
+
TIDE_DTYPE *__restrict const bey,
|
|
331
|
+
TIDE_DTYPE *__restrict const bhx,
|
|
332
|
+
TIDE_DTYPE *__restrict const bhz,
|
|
333
|
+
int64_t const *__restrict const boundary_indices,
|
|
334
|
+
int64_t const boundary_numel,
|
|
335
|
+
int64_t const n_shots,
|
|
336
|
+
int64_t const shot_numel) {
|
|
337
|
+
|
|
338
|
+
TIDE_OMP_INDEX shot_idx;
|
|
339
|
+
TIDE_OMP_PARALLEL_FOR
|
|
340
|
+
for (shot_idx = 0; shot_idx < n_shots; ++shot_idx) {
|
|
341
|
+
TIDE_OMP_SIMD
|
|
342
|
+
for (int64_t bi = 0; bi < boundary_numel; ++bi) {
|
|
343
|
+
int64_t const grid_idx = boundary_indices[bi];
|
|
344
|
+
int64_t const field_offset = shot_idx * shot_numel + grid_idx;
|
|
345
|
+
int64_t const store_offset = shot_idx * boundary_numel + bi;
|
|
346
|
+
bey[store_offset] = ey[field_offset];
|
|
347
|
+
bhx[store_offset] = hx[field_offset];
|
|
348
|
+
bhz[store_offset] = hz[field_offset];
|
|
349
|
+
}
|
|
350
|
+
}
|
|
351
|
+
}
|
|
352
|
+
|
|
353
|
+
static void scatter_boundary_cpu(
|
|
354
|
+
TIDE_DTYPE *__restrict const field,
|
|
355
|
+
TIDE_DTYPE const *__restrict const store,
|
|
356
|
+
int64_t const *__restrict const boundary_indices,
|
|
357
|
+
int64_t const boundary_numel,
|
|
358
|
+
int64_t const n_shots,
|
|
359
|
+
int64_t const shot_numel) {
|
|
360
|
+
|
|
361
|
+
TIDE_OMP_INDEX shot_idx;
|
|
362
|
+
TIDE_OMP_PARALLEL_FOR
|
|
363
|
+
for (shot_idx = 0; shot_idx < n_shots; ++shot_idx) {
|
|
364
|
+
TIDE_OMP_SIMD
|
|
365
|
+
for (int64_t bi = 0; bi < boundary_numel; ++bi) {
|
|
366
|
+
int64_t const grid_idx = boundary_indices[bi];
|
|
367
|
+
int64_t const field_offset = shot_idx * shot_numel + grid_idx;
|
|
368
|
+
int64_t const store_offset = shot_idx * boundary_numel + bi;
|
|
369
|
+
field[field_offset] = store[store_offset];
|
|
370
|
+
}
|
|
371
|
+
}
|
|
372
|
+
}
|
|
373
|
+
|
|
374
|
+
static void scatter_boundary_2_cpu(
|
|
375
|
+
TIDE_DTYPE *__restrict const hx,
|
|
376
|
+
TIDE_DTYPE *__restrict const hz,
|
|
377
|
+
TIDE_DTYPE const *__restrict const bhx,
|
|
378
|
+
TIDE_DTYPE const *__restrict const bhz,
|
|
379
|
+
int64_t const *__restrict const boundary_indices,
|
|
380
|
+
int64_t const boundary_numel,
|
|
381
|
+
int64_t const n_shots,
|
|
382
|
+
int64_t const shot_numel) {
|
|
383
|
+
|
|
384
|
+
TIDE_OMP_INDEX shot_idx;
|
|
385
|
+
TIDE_OMP_PARALLEL_FOR
|
|
386
|
+
for (shot_idx = 0; shot_idx < n_shots; ++shot_idx) {
|
|
387
|
+
TIDE_OMP_SIMD
|
|
388
|
+
for (int64_t bi = 0; bi < boundary_numel; ++bi) {
|
|
389
|
+
int64_t const grid_idx = boundary_indices[bi];
|
|
390
|
+
int64_t const field_offset = shot_idx * shot_numel + grid_idx;
|
|
391
|
+
int64_t const store_offset = shot_idx * boundary_numel + bi;
|
|
392
|
+
hx[field_offset] = bhx[store_offset];
|
|
393
|
+
hz[field_offset] = bhz[store_offset];
|
|
394
|
+
}
|
|
395
|
+
}
|
|
396
|
+
}
|
|
397
|
+
|
|
398
|
+
static void gather_boundary_3_cpu_bf16(
|
|
399
|
+
TIDE_DTYPE const *__restrict const ey,
|
|
400
|
+
TIDE_DTYPE const *__restrict const hx,
|
|
401
|
+
TIDE_DTYPE const *__restrict const hz,
|
|
402
|
+
tide_bfloat16 *__restrict const bey,
|
|
403
|
+
tide_bfloat16 *__restrict const bhx,
|
|
404
|
+
tide_bfloat16 *__restrict const bhz,
|
|
405
|
+
int64_t const *__restrict const boundary_indices,
|
|
406
|
+
int64_t const boundary_numel,
|
|
407
|
+
int64_t const n_shots,
|
|
408
|
+
int64_t const shot_numel) {
|
|
409
|
+
|
|
410
|
+
TIDE_OMP_INDEX shot_idx;
|
|
411
|
+
TIDE_OMP_PARALLEL_FOR
|
|
412
|
+
for (shot_idx = 0; shot_idx < n_shots; ++shot_idx) {
|
|
413
|
+
TIDE_OMP_SIMD
|
|
414
|
+
for (int64_t bi = 0; bi < boundary_numel; ++bi) {
|
|
415
|
+
int64_t const grid_idx = boundary_indices[bi];
|
|
416
|
+
int64_t const field_offset = shot_idx * shot_numel + grid_idx;
|
|
417
|
+
int64_t const store_offset = shot_idx * boundary_numel + bi;
|
|
418
|
+
bey[store_offset] = tide_float_to_bf16((float)ey[field_offset]);
|
|
419
|
+
bhx[store_offset] = tide_float_to_bf16((float)hx[field_offset]);
|
|
420
|
+
bhz[store_offset] = tide_float_to_bf16((float)hz[field_offset]);
|
|
421
|
+
}
|
|
422
|
+
}
|
|
423
|
+
}
|
|
424
|
+
|
|
425
|
+
static void scatter_boundary_cpu_bf16(
|
|
426
|
+
TIDE_DTYPE *__restrict const field,
|
|
427
|
+
tide_bfloat16 const *__restrict const store,
|
|
428
|
+
int64_t const *__restrict const boundary_indices,
|
|
429
|
+
int64_t const boundary_numel,
|
|
430
|
+
int64_t const n_shots,
|
|
431
|
+
int64_t const shot_numel) {
|
|
432
|
+
|
|
433
|
+
TIDE_OMP_INDEX shot_idx;
|
|
434
|
+
TIDE_OMP_PARALLEL_FOR
|
|
435
|
+
for (shot_idx = 0; shot_idx < n_shots; ++shot_idx) {
|
|
436
|
+
TIDE_OMP_SIMD
|
|
437
|
+
for (int64_t bi = 0; bi < boundary_numel; ++bi) {
|
|
438
|
+
int64_t const grid_idx = boundary_indices[bi];
|
|
439
|
+
int64_t const field_offset = shot_idx * shot_numel + grid_idx;
|
|
440
|
+
int64_t const store_offset = shot_idx * boundary_numel + bi;
|
|
441
|
+
field[field_offset] = (TIDE_DTYPE)tide_bf16_to_float(store[store_offset]);
|
|
442
|
+
}
|
|
443
|
+
}
|
|
444
|
+
}
|
|
445
|
+
|
|
446
|
+
static void scatter_boundary_2_cpu_bf16(
|
|
447
|
+
TIDE_DTYPE *__restrict const hx,
|
|
448
|
+
TIDE_DTYPE *__restrict const hz,
|
|
449
|
+
tide_bfloat16 const *__restrict const bhx,
|
|
450
|
+
tide_bfloat16 const *__restrict const bhz,
|
|
451
|
+
int64_t const *__restrict const boundary_indices,
|
|
452
|
+
int64_t const boundary_numel,
|
|
453
|
+
int64_t const n_shots,
|
|
454
|
+
int64_t const shot_numel) {
|
|
455
|
+
|
|
456
|
+
TIDE_OMP_INDEX shot_idx;
|
|
457
|
+
TIDE_OMP_PARALLEL_FOR
|
|
458
|
+
for (shot_idx = 0; shot_idx < n_shots; ++shot_idx) {
|
|
459
|
+
TIDE_OMP_SIMD
|
|
460
|
+
for (int64_t bi = 0; bi < boundary_numel; ++bi) {
|
|
461
|
+
int64_t const grid_idx = boundary_indices[bi];
|
|
462
|
+
int64_t const field_offset = shot_idx * shot_numel + grid_idx;
|
|
463
|
+
int64_t const store_offset = shot_idx * boundary_numel + bi;
|
|
464
|
+
hx[field_offset] = (TIDE_DTYPE)tide_bf16_to_float(bhx[store_offset]);
|
|
465
|
+
hz[field_offset] = (TIDE_DTYPE)tide_bf16_to_float(bhz[store_offset]);
|
|
466
|
+
}
|
|
467
|
+
}
|
|
468
|
+
}
|
|
469
|
+
|
|
470
|
+
// FP8 boundary storage functions
|
|
471
|
+
static void gather_boundary_3_cpu_fp8(
|
|
472
|
+
TIDE_DTYPE const *__restrict const ey,
|
|
473
|
+
TIDE_DTYPE const *__restrict const hx,
|
|
474
|
+
TIDE_DTYPE const *__restrict const hz,
|
|
475
|
+
tide_fp8_e4m3 *__restrict const bey,
|
|
476
|
+
tide_fp8_e4m3 *__restrict const bhx,
|
|
477
|
+
tide_fp8_e4m3 *__restrict const bhz,
|
|
478
|
+
int64_t const *__restrict const boundary_indices,
|
|
479
|
+
int64_t const boundary_numel,
|
|
480
|
+
int64_t const n_shots,
|
|
481
|
+
int64_t const shot_numel) {
|
|
482
|
+
|
|
483
|
+
TIDE_OMP_INDEX shot_idx;
|
|
484
|
+
TIDE_OMP_PARALLEL_FOR
|
|
485
|
+
for (shot_idx = 0; shot_idx < n_shots; ++shot_idx) {
|
|
486
|
+
TIDE_OMP_SIMD
|
|
487
|
+
for (int64_t bi = 0; bi < boundary_numel; ++bi) {
|
|
488
|
+
int64_t const grid_idx = boundary_indices[bi];
|
|
489
|
+
int64_t const field_offset = shot_idx * shot_numel + grid_idx;
|
|
490
|
+
int64_t const store_offset = shot_idx * boundary_numel + bi;
|
|
491
|
+
bey[store_offset] = tide_float_to_fp8_e4m3((float)ey[field_offset]);
|
|
492
|
+
bhx[store_offset] = tide_float_to_fp8_e4m3((float)hx[field_offset]);
|
|
493
|
+
bhz[store_offset] = tide_float_to_fp8_e4m3((float)hz[field_offset]);
|
|
494
|
+
}
|
|
495
|
+
}
|
|
496
|
+
}
|
|
497
|
+
|
|
498
|
+
static void scatter_boundary_cpu_fp8(
|
|
499
|
+
TIDE_DTYPE *__restrict const field,
|
|
500
|
+
tide_fp8_e4m3 const *__restrict const store,
|
|
501
|
+
int64_t const *__restrict const boundary_indices,
|
|
502
|
+
int64_t const boundary_numel,
|
|
503
|
+
int64_t const n_shots,
|
|
504
|
+
int64_t const shot_numel) {
|
|
505
|
+
|
|
506
|
+
TIDE_OMP_INDEX shot_idx;
|
|
507
|
+
TIDE_OMP_PARALLEL_FOR
|
|
508
|
+
for (shot_idx = 0; shot_idx < n_shots; ++shot_idx) {
|
|
509
|
+
TIDE_OMP_SIMD
|
|
510
|
+
for (int64_t bi = 0; bi < boundary_numel; ++bi) {
|
|
511
|
+
int64_t const grid_idx = boundary_indices[bi];
|
|
512
|
+
int64_t const field_offset = shot_idx * shot_numel + grid_idx;
|
|
513
|
+
int64_t const store_offset = shot_idx * boundary_numel + bi;
|
|
514
|
+
field[field_offset] = (TIDE_DTYPE)tide_fp8_e4m3_to_float(store[store_offset]);
|
|
515
|
+
}
|
|
516
|
+
}
|
|
517
|
+
}
|
|
518
|
+
|
|
519
|
+
static void scatter_boundary_2_cpu_fp8(
|
|
520
|
+
TIDE_DTYPE *__restrict const hx,
|
|
521
|
+
TIDE_DTYPE *__restrict const hz,
|
|
522
|
+
tide_fp8_e4m3 const *__restrict const bhx,
|
|
523
|
+
tide_fp8_e4m3 const *__restrict const bhz,
|
|
524
|
+
int64_t const *__restrict const boundary_indices,
|
|
525
|
+
int64_t const boundary_numel,
|
|
526
|
+
int64_t const n_shots,
|
|
527
|
+
int64_t const shot_numel) {
|
|
528
|
+
|
|
529
|
+
TIDE_OMP_INDEX shot_idx;
|
|
530
|
+
TIDE_OMP_PARALLEL_FOR
|
|
531
|
+
for (shot_idx = 0; shot_idx < n_shots; ++shot_idx) {
|
|
532
|
+
TIDE_OMP_SIMD
|
|
533
|
+
for (int64_t bi = 0; bi < boundary_numel; ++bi) {
|
|
534
|
+
int64_t const grid_idx = boundary_indices[bi];
|
|
535
|
+
int64_t const field_offset = shot_idx * shot_numel + grid_idx;
|
|
536
|
+
int64_t const store_offset = shot_idx * boundary_numel + bi;
|
|
537
|
+
hx[field_offset] = (TIDE_DTYPE)tide_fp8_e4m3_to_float(bhx[store_offset]);
|
|
538
|
+
hz[field_offset] = (TIDE_DTYPE)tide_fp8_e4m3_to_float(bhz[store_offset]);
|
|
539
|
+
}
|
|
540
|
+
}
|
|
541
|
+
}
|
|
542
|
+
|
|
543
|
+
static inline void *boundary_store_ptr(
|
|
544
|
+
void *store_1,
|
|
545
|
+
void *store_3,
|
|
546
|
+
int64_t storage_mode,
|
|
547
|
+
int64_t step_idx,
|
|
548
|
+
int64_t step_elems,
|
|
549
|
+
size_t elem_size) {
|
|
550
|
+
size_t const offset_bytes =
|
|
551
|
+
(size_t)step_idx * (size_t)step_elems * elem_size;
|
|
552
|
+
if (storage_mode == STORAGE_DEVICE) {
|
|
553
|
+
return (uint8_t *)store_1 + offset_bytes;
|
|
554
|
+
}
|
|
555
|
+
if (storage_mode == STORAGE_CPU && store_3 != NULL) {
|
|
556
|
+
return (uint8_t *)store_3 + offset_bytes;
|
|
557
|
+
}
|
|
558
|
+
return (uint8_t *)store_1;
|
|
559
|
+
}
|
|
560
|
+
|
|
561
|
+
|
|
562
|
+
static void forward_kernel_h(
|
|
563
|
+
TIDE_DTYPE const *__restrict const cq,
|
|
564
|
+
TIDE_DTYPE const *__restrict const ey,
|
|
565
|
+
TIDE_DTYPE *__restrict const hx,
|
|
566
|
+
TIDE_DTYPE *__restrict const hz,
|
|
567
|
+
TIDE_DTYPE *__restrict const m_ey_x,
|
|
568
|
+
TIDE_DTYPE *__restrict const m_ey_z,
|
|
569
|
+
TIDE_DTYPE const *__restrict const ay,
|
|
570
|
+
TIDE_DTYPE const *__restrict const ayh,
|
|
571
|
+
TIDE_DTYPE const *__restrict const ax,
|
|
572
|
+
TIDE_DTYPE const *__restrict const axh,
|
|
573
|
+
TIDE_DTYPE const *__restrict const by,
|
|
574
|
+
TIDE_DTYPE const *__restrict const byh,
|
|
575
|
+
TIDE_DTYPE const *__restrict const bx,
|
|
576
|
+
TIDE_DTYPE const *__restrict const bxh,
|
|
577
|
+
TIDE_DTYPE const *__restrict const ky,
|
|
578
|
+
TIDE_DTYPE const *__restrict const kyh,
|
|
579
|
+
TIDE_DTYPE const *__restrict const kx,
|
|
580
|
+
TIDE_DTYPE const *__restrict const kxh,
|
|
581
|
+
TIDE_DTYPE const rdy,
|
|
582
|
+
TIDE_DTYPE const rdx,
|
|
583
|
+
int64_t const n_shots,
|
|
584
|
+
int64_t const ny,
|
|
585
|
+
int64_t const nx,
|
|
586
|
+
int64_t const shot_numel,
|
|
587
|
+
int64_t const pml_y0,
|
|
588
|
+
int64_t const pml_y1,
|
|
589
|
+
int64_t const pml_x0,
|
|
590
|
+
int64_t const pml_x1,
|
|
591
|
+
bool const cq_batched) {
|
|
592
|
+
|
|
593
|
+
int64_t const pml_y0h = pml_y0;
|
|
594
|
+
int64_t const pml_y1h = MAX(pml_y0, pml_y1 - 1);
|
|
595
|
+
int64_t const pml_x0h = pml_x0;
|
|
596
|
+
int64_t const pml_x1h = MAX(pml_x0, pml_x1 - 1);
|
|
597
|
+
int64_t const pml_bounds_yh[] = {FD_PAD, pml_y0h, pml_y1h, ny - FD_PAD + 1};
|
|
598
|
+
int64_t const pml_bounds_xh[] = {FD_PAD, pml_x0h, pml_x1h, nx - FD_PAD + 1};
|
|
599
|
+
|
|
600
|
+
TIDE_OMP_INDEX shot_idx;
|
|
601
|
+
TIDE_OMP_PARALLEL_FOR
|
|
602
|
+
for (shot_idx = 0; shot_idx < n_shots; ++shot_idx) {
|
|
603
|
+
int64_t const shot_offset = shot_idx * shot_numel;
|
|
604
|
+
TIDE_DTYPE const *__restrict const cq_ptr =
|
|
605
|
+
cq_batched ? (cq + shot_offset) : cq;
|
|
606
|
+
for (int pml_y = 0; pml_y < 3; ++pml_y) {
|
|
607
|
+
for (int pml_x = 0; pml_x < 3; ++pml_x) {
|
|
608
|
+
TIDE_OMP_SIMD_COLLAPSE2
|
|
609
|
+
for (int64_t y = pml_bounds_yh[pml_y]; y < pml_bounds_yh[pml_y + 1]; ++y) {
|
|
610
|
+
for (int64_t x = pml_bounds_xh[pml_x]; x < pml_bounds_xh[pml_x + 1]; ++x) {
|
|
611
|
+
int64_t const idx = IDX(y, x);
|
|
612
|
+
TIDE_DTYPE const cq_val = cq_ptr[idx];
|
|
613
|
+
|
|
614
|
+
if (y < ny - FD_PAD) {
|
|
615
|
+
TIDE_DTYPE dey_dz = DIFFYH1(EY);
|
|
616
|
+
|
|
617
|
+
if (pml_y != 1) {
|
|
618
|
+
M_EY_Z(0, 0) = byh[y] * M_EY_Z(0, 0) + ayh[y] * dey_dz;
|
|
619
|
+
dey_dz = dey_dz / kyh[y] + M_EY_Z(0, 0);
|
|
620
|
+
}
|
|
621
|
+
|
|
622
|
+
HX(0, 0) -= cq_val * dey_dz;
|
|
623
|
+
}
|
|
624
|
+
|
|
625
|
+
if (x < nx - FD_PAD) {
|
|
626
|
+
TIDE_DTYPE dey_dx = DIFFXH1(EY);
|
|
627
|
+
|
|
628
|
+
if (pml_x != 1) {
|
|
629
|
+
M_EY_X(0, 0) = bxh[x] * M_EY_X(0, 0) + axh[x] * dey_dx;
|
|
630
|
+
dey_dx = dey_dx / kxh[x] + M_EY_X(0, 0);
|
|
631
|
+
}
|
|
632
|
+
|
|
633
|
+
HZ(0, 0) += cq_val * dey_dx;
|
|
634
|
+
}
|
|
635
|
+
}
|
|
636
|
+
}
|
|
637
|
+
}
|
|
638
|
+
}
|
|
639
|
+
}
|
|
640
|
+
}
|
|
641
|
+
|
|
642
|
+
|
|
643
|
+
/*
|
|
644
|
+
* Forward E kernel with optional storage for gradient computation
|
|
645
|
+
*
|
|
646
|
+
* When ca_requires_grad or cb_requires_grad is true, stores:
|
|
647
|
+
* - ey_store: E_y field before update (needed for grad_ca)
|
|
648
|
+
* - curl_h_store: (dHz/dx - dHx/dz) (needed for grad_cb)
|
|
649
|
+
*/
|
|
650
|
+
static void forward_kernel_e_with_storage(
|
|
651
|
+
TIDE_DTYPE const *__restrict const ca,
|
|
652
|
+
TIDE_DTYPE const *__restrict const cb,
|
|
653
|
+
TIDE_DTYPE const *__restrict const hx,
|
|
654
|
+
TIDE_DTYPE const *__restrict const hz,
|
|
655
|
+
TIDE_DTYPE *__restrict const ey,
|
|
656
|
+
TIDE_DTYPE *__restrict const m_hx_z,
|
|
657
|
+
TIDE_DTYPE *__restrict const m_hz_x,
|
|
658
|
+
TIDE_DTYPE const *__restrict const ay,
|
|
659
|
+
TIDE_DTYPE const *__restrict const ayh,
|
|
660
|
+
TIDE_DTYPE const *__restrict const ax,
|
|
661
|
+
TIDE_DTYPE const *__restrict const axh,
|
|
662
|
+
TIDE_DTYPE const *__restrict const by,
|
|
663
|
+
TIDE_DTYPE const *__restrict const byh,
|
|
664
|
+
TIDE_DTYPE const *__restrict const bx,
|
|
665
|
+
TIDE_DTYPE const *__restrict const bxh,
|
|
666
|
+
TIDE_DTYPE const *__restrict const ky,
|
|
667
|
+
TIDE_DTYPE const *__restrict const kyh,
|
|
668
|
+
TIDE_DTYPE const *__restrict const kx,
|
|
669
|
+
TIDE_DTYPE const *__restrict const kxh,
|
|
670
|
+
TIDE_DTYPE const rdy,
|
|
671
|
+
TIDE_DTYPE const rdx,
|
|
672
|
+
int64_t const n_shots,
|
|
673
|
+
int64_t const ny,
|
|
674
|
+
int64_t const nx,
|
|
675
|
+
int64_t const shot_numel,
|
|
676
|
+
int64_t const pml_y0,
|
|
677
|
+
int64_t const pml_y1,
|
|
678
|
+
int64_t const pml_x0,
|
|
679
|
+
int64_t const pml_x1,
|
|
680
|
+
bool const ca_batched,
|
|
681
|
+
bool const cb_batched,
|
|
682
|
+
bool const ca_requires_grad,
|
|
683
|
+
bool const cb_requires_grad,
|
|
684
|
+
TIDE_DTYPE *__restrict const ey_store,
|
|
685
|
+
TIDE_DTYPE *__restrict const curl_h_store) {
|
|
686
|
+
|
|
687
|
+
int64_t const pml_bounds_y[] = {FD_PAD, pml_y0, pml_y1, ny - FD_PAD + 1};
|
|
688
|
+
int64_t const pml_bounds_x[] = {FD_PAD, pml_x0, pml_x1, nx - FD_PAD + 1};
|
|
689
|
+
|
|
690
|
+
TIDE_OMP_INDEX shot_idx;
|
|
691
|
+
TIDE_OMP_PARALLEL_FOR
|
|
692
|
+
for (shot_idx = 0; shot_idx < n_shots; ++shot_idx) {
|
|
693
|
+
int64_t const shot_offset = shot_idx * shot_numel;
|
|
694
|
+
TIDE_DTYPE const *__restrict const ca_ptr =
|
|
695
|
+
ca_batched ? (ca + shot_offset) : ca;
|
|
696
|
+
TIDE_DTYPE const *__restrict const cb_ptr =
|
|
697
|
+
cb_batched ? (cb + shot_offset) : cb;
|
|
698
|
+
for (int pml_y = 0; pml_y < 3; ++pml_y) {
|
|
699
|
+
for (int pml_x = 0; pml_x < 3; ++pml_x) {
|
|
700
|
+
TIDE_OMP_SIMD_COLLAPSE2
|
|
701
|
+
for (int64_t y = pml_bounds_y[pml_y]; y < pml_bounds_y[pml_y + 1]; ++y) {
|
|
702
|
+
for (int64_t x = pml_bounds_x[pml_x]; x < pml_bounds_x[pml_x + 1]; ++x) {
|
|
703
|
+
int64_t const idx = IDX(y, x);
|
|
704
|
+
int64_t const store_idx = shot_offset + idx;
|
|
705
|
+
TIDE_DTYPE const ca_val = ca_ptr[idx];
|
|
706
|
+
TIDE_DTYPE const cb_val = cb_ptr[idx];
|
|
707
|
+
|
|
708
|
+
TIDE_DTYPE dhz_dx = DIFFX1(HZ);
|
|
709
|
+
TIDE_DTYPE dhx_dz = DIFFY1(HX);
|
|
710
|
+
|
|
711
|
+
if (pml_x != 1) {
|
|
712
|
+
M_HZ_X(0, 0) = bx[x] * M_HZ_X(0, 0) + ax[x] * dhz_dx;
|
|
713
|
+
dhz_dx = dhz_dx / kx[x] + M_HZ_X(0, 0);
|
|
714
|
+
}
|
|
715
|
+
|
|
716
|
+
if (pml_y != 1) {
|
|
717
|
+
M_HX_Z(0, 0) = by[y] * M_HX_Z(0, 0) + ay[y] * dhx_dz;
|
|
718
|
+
dhx_dz = dhx_dz / ky[y] + M_HX_Z(0, 0);
|
|
719
|
+
}
|
|
720
|
+
|
|
721
|
+
TIDE_DTYPE curl_h = dhz_dx - dhx_dz;
|
|
722
|
+
|
|
723
|
+
if (ca_requires_grad && ey_store != NULL) {
|
|
724
|
+
ey_store[store_idx] = EY(0, 0);
|
|
725
|
+
}
|
|
726
|
+
if (cb_requires_grad && curl_h_store != NULL) {
|
|
727
|
+
curl_h_store[store_idx] = curl_h;
|
|
728
|
+
}
|
|
729
|
+
|
|
730
|
+
EY(0, 0) = ca_val * EY(0, 0) + cb_val * curl_h;
|
|
731
|
+
}
|
|
732
|
+
}
|
|
733
|
+
}
|
|
734
|
+
}
|
|
735
|
+
}
|
|
736
|
+
}
|
|
737
|
+
|
|
738
|
+
static void forward_kernel_e_with_storage_bf16(
|
|
739
|
+
TIDE_DTYPE const *__restrict const ca,
|
|
740
|
+
TIDE_DTYPE const *__restrict const cb,
|
|
741
|
+
TIDE_DTYPE const *__restrict const hx,
|
|
742
|
+
TIDE_DTYPE const *__restrict const hz,
|
|
743
|
+
TIDE_DTYPE *__restrict const ey,
|
|
744
|
+
TIDE_DTYPE *__restrict const m_hx_z,
|
|
745
|
+
TIDE_DTYPE *__restrict const m_hz_x,
|
|
746
|
+
TIDE_DTYPE const *__restrict const ay,
|
|
747
|
+
TIDE_DTYPE const *__restrict const ayh,
|
|
748
|
+
TIDE_DTYPE const *__restrict const ax,
|
|
749
|
+
TIDE_DTYPE const *__restrict const axh,
|
|
750
|
+
TIDE_DTYPE const *__restrict const by,
|
|
751
|
+
TIDE_DTYPE const *__restrict const byh,
|
|
752
|
+
TIDE_DTYPE const *__restrict const bx,
|
|
753
|
+
TIDE_DTYPE const *__restrict const bxh,
|
|
754
|
+
TIDE_DTYPE const *__restrict const ky,
|
|
755
|
+
TIDE_DTYPE const *__restrict const kyh,
|
|
756
|
+
TIDE_DTYPE const *__restrict const kx,
|
|
757
|
+
TIDE_DTYPE const *__restrict const kxh,
|
|
758
|
+
TIDE_DTYPE const rdy,
|
|
759
|
+
TIDE_DTYPE const rdx,
|
|
760
|
+
int64_t const n_shots,
|
|
761
|
+
int64_t const ny,
|
|
762
|
+
int64_t const nx,
|
|
763
|
+
int64_t const shot_numel,
|
|
764
|
+
int64_t const pml_y0,
|
|
765
|
+
int64_t const pml_y1,
|
|
766
|
+
int64_t const pml_x0,
|
|
767
|
+
int64_t const pml_x1,
|
|
768
|
+
bool const ca_batched,
|
|
769
|
+
bool const cb_batched,
|
|
770
|
+
bool const ca_requires_grad,
|
|
771
|
+
bool const cb_requires_grad,
|
|
772
|
+
tide_bfloat16 *__restrict const ey_store,
|
|
773
|
+
tide_bfloat16 *__restrict const curl_h_store) {
|
|
774
|
+
|
|
775
|
+
int64_t const pml_bounds_y[] = {FD_PAD, pml_y0, pml_y1, ny - FD_PAD + 1};
|
|
776
|
+
int64_t const pml_bounds_x[] = {FD_PAD, pml_x0, pml_x1, nx - FD_PAD + 1};
|
|
777
|
+
|
|
778
|
+
TIDE_OMP_INDEX shot_idx;
|
|
779
|
+
TIDE_OMP_PARALLEL_FOR
|
|
780
|
+
for (shot_idx = 0; shot_idx < n_shots; ++shot_idx) {
|
|
781
|
+
int64_t const shot_offset = shot_idx * shot_numel;
|
|
782
|
+
TIDE_DTYPE const *__restrict const ca_ptr =
|
|
783
|
+
ca_batched ? (ca + shot_offset) : ca;
|
|
784
|
+
TIDE_DTYPE const *__restrict const cb_ptr =
|
|
785
|
+
cb_batched ? (cb + shot_offset) : cb;
|
|
786
|
+
for (int pml_y = 0; pml_y < 3; ++pml_y) {
|
|
787
|
+
for (int pml_x = 0; pml_x < 3; ++pml_x) {
|
|
788
|
+
TIDE_OMP_SIMD_COLLAPSE2
|
|
789
|
+
for (int64_t y = pml_bounds_y[pml_y]; y < pml_bounds_y[pml_y + 1]; ++y) {
|
|
790
|
+
for (int64_t x = pml_bounds_x[pml_x]; x < pml_bounds_x[pml_x + 1]; ++x) {
|
|
791
|
+
int64_t const idx = IDX(y, x);
|
|
792
|
+
int64_t const store_idx = shot_offset + idx;
|
|
793
|
+
TIDE_DTYPE const ca_val = ca_ptr[idx];
|
|
794
|
+
TIDE_DTYPE const cb_val = cb_ptr[idx];
|
|
795
|
+
|
|
796
|
+
TIDE_DTYPE dhz_dx = DIFFX1(HZ);
|
|
797
|
+
TIDE_DTYPE dhx_dz = DIFFY1(HX);
|
|
798
|
+
|
|
799
|
+
if (pml_x != 1) {
|
|
800
|
+
M_HZ_X(0, 0) = bx[x] * M_HZ_X(0, 0) + ax[x] * dhz_dx;
|
|
801
|
+
dhz_dx = dhz_dx / kx[x] + M_HZ_X(0, 0);
|
|
802
|
+
}
|
|
803
|
+
|
|
804
|
+
if (pml_y != 1) {
|
|
805
|
+
M_HX_Z(0, 0) = by[y] * M_HX_Z(0, 0) + ay[y] * dhx_dz;
|
|
806
|
+
dhx_dz = dhx_dz / ky[y] + M_HX_Z(0, 0);
|
|
807
|
+
}
|
|
808
|
+
|
|
809
|
+
TIDE_DTYPE curl_h = dhz_dx - dhx_dz;
|
|
810
|
+
if (ca_requires_grad && ey_store != NULL) {
|
|
811
|
+
ey_store[store_idx] = tide_float_to_bf16((float)EY(0, 0));
|
|
812
|
+
}
|
|
813
|
+
if (cb_requires_grad && curl_h_store != NULL) {
|
|
814
|
+
curl_h_store[store_idx] = tide_float_to_bf16((float)curl_h);
|
|
815
|
+
}
|
|
816
|
+
|
|
817
|
+
EY(0, 0) = ca_val * EY(0, 0) + cb_val * curl_h;
|
|
818
|
+
}
|
|
819
|
+
}
|
|
820
|
+
}
|
|
821
|
+
}
|
|
822
|
+
}
|
|
823
|
+
}
|
|
824
|
+
|
|
825
|
+
static void forward_kernel_e_with_storage_fp8(
|
|
826
|
+
TIDE_DTYPE const *__restrict const ca,
|
|
827
|
+
TIDE_DTYPE const *__restrict const cb,
|
|
828
|
+
TIDE_DTYPE const *__restrict const hx,
|
|
829
|
+
TIDE_DTYPE const *__restrict const hz,
|
|
830
|
+
TIDE_DTYPE *__restrict const ey,
|
|
831
|
+
TIDE_DTYPE *__restrict const m_hx_z,
|
|
832
|
+
TIDE_DTYPE *__restrict const m_hz_x,
|
|
833
|
+
TIDE_DTYPE const *__restrict const ay,
|
|
834
|
+
TIDE_DTYPE const *__restrict const ayh,
|
|
835
|
+
TIDE_DTYPE const *__restrict const ax,
|
|
836
|
+
TIDE_DTYPE const *__restrict const axh,
|
|
837
|
+
TIDE_DTYPE const *__restrict const by,
|
|
838
|
+
TIDE_DTYPE const *__restrict const byh,
|
|
839
|
+
TIDE_DTYPE const *__restrict const bx,
|
|
840
|
+
TIDE_DTYPE const *__restrict const bxh,
|
|
841
|
+
TIDE_DTYPE const *__restrict const ky,
|
|
842
|
+
TIDE_DTYPE const *__restrict const kyh,
|
|
843
|
+
TIDE_DTYPE const *__restrict const kx,
|
|
844
|
+
TIDE_DTYPE const *__restrict const kxh,
|
|
845
|
+
TIDE_DTYPE const rdy,
|
|
846
|
+
TIDE_DTYPE const rdx,
|
|
847
|
+
int64_t const n_shots,
|
|
848
|
+
int64_t const ny,
|
|
849
|
+
int64_t const nx,
|
|
850
|
+
int64_t const shot_numel,
|
|
851
|
+
int64_t const pml_y0,
|
|
852
|
+
int64_t const pml_y1,
|
|
853
|
+
int64_t const pml_x0,
|
|
854
|
+
int64_t const pml_x1,
|
|
855
|
+
bool const ca_batched,
|
|
856
|
+
bool const cb_batched,
|
|
857
|
+
bool const ca_requires_grad,
|
|
858
|
+
bool const cb_requires_grad,
|
|
859
|
+
tide_fp8_e4m3 *__restrict const ey_store,
|
|
860
|
+
tide_fp8_e4m3 *__restrict const curl_h_store) {
|
|
861
|
+
|
|
862
|
+
int64_t const pml_bounds_y[] = {FD_PAD, pml_y0, pml_y1, ny - FD_PAD + 1};
|
|
863
|
+
int64_t const pml_bounds_x[] = {FD_PAD, pml_x0, pml_x1, nx - FD_PAD + 1};
|
|
864
|
+
|
|
865
|
+
TIDE_OMP_INDEX shot_idx;
|
|
866
|
+
TIDE_OMP_PARALLEL_FOR
|
|
867
|
+
for (shot_idx = 0; shot_idx < n_shots; ++shot_idx) {
|
|
868
|
+
int64_t const shot_offset = shot_idx * shot_numel;
|
|
869
|
+
TIDE_DTYPE const *__restrict const ca_ptr =
|
|
870
|
+
ca_batched ? (ca + shot_offset) : ca;
|
|
871
|
+
TIDE_DTYPE const *__restrict const cb_ptr =
|
|
872
|
+
cb_batched ? (cb + shot_offset) : cb;
|
|
873
|
+
for (int pml_y = 0; pml_y < 3; ++pml_y) {
|
|
874
|
+
for (int pml_x = 0; pml_x < 3; ++pml_x) {
|
|
875
|
+
TIDE_OMP_SIMD_COLLAPSE2
|
|
876
|
+
for (int64_t y = pml_bounds_y[pml_y]; y < pml_bounds_y[pml_y + 1]; ++y) {
|
|
877
|
+
for (int64_t x = pml_bounds_x[pml_x]; x < pml_bounds_x[pml_x + 1]; ++x) {
|
|
878
|
+
int64_t const idx = IDX(y, x);
|
|
879
|
+
int64_t const store_idx = shot_offset + idx;
|
|
880
|
+
TIDE_DTYPE const ca_val = ca_ptr[idx];
|
|
881
|
+
TIDE_DTYPE const cb_val = cb_ptr[idx];
|
|
882
|
+
|
|
883
|
+
TIDE_DTYPE dhz_dx = DIFFX1(HZ);
|
|
884
|
+
TIDE_DTYPE dhx_dz = DIFFY1(HX);
|
|
885
|
+
|
|
886
|
+
if (pml_x != 1) {
|
|
887
|
+
M_HZ_X(0, 0) = bx[x] * M_HZ_X(0, 0) + ax[x] * dhz_dx;
|
|
888
|
+
dhz_dx = dhz_dx / kx[x] + M_HZ_X(0, 0);
|
|
889
|
+
}
|
|
890
|
+
|
|
891
|
+
if (pml_y != 1) {
|
|
892
|
+
M_HX_Z(0, 0) = by[y] * M_HX_Z(0, 0) + ay[y] * dhx_dz;
|
|
893
|
+
dhx_dz = dhx_dz / ky[y] + M_HX_Z(0, 0);
|
|
894
|
+
}
|
|
895
|
+
|
|
896
|
+
TIDE_DTYPE curl_h = dhz_dx - dhx_dz;
|
|
897
|
+
if (ca_requires_grad && ey_store != NULL) {
|
|
898
|
+
ey_store[store_idx] = tide_float_to_fp8_e4m3((float)EY(0, 0));
|
|
899
|
+
}
|
|
900
|
+
if (cb_requires_grad && curl_h_store != NULL) {
|
|
901
|
+
curl_h_store[store_idx] = tide_float_to_fp8_e4m3((float)curl_h);
|
|
902
|
+
}
|
|
903
|
+
|
|
904
|
+
EY(0, 0) = ca_val * EY(0, 0) + cb_val * curl_h;
|
|
905
|
+
}
|
|
906
|
+
}
|
|
907
|
+
}
|
|
908
|
+
}
|
|
909
|
+
}
|
|
910
|
+
}
|
|
911
|
+
|
|
912
|
+
|
|
913
|
+
#ifdef __cplusplus
|
|
914
|
+
extern "C"
|
|
915
|
+
#endif
|
|
916
|
+
#ifdef _WIN32
|
|
917
|
+
__declspec(dllexport)
|
|
918
|
+
#endif
|
|
919
|
+
void FUNC(forward)(
|
|
920
|
+
TIDE_DTYPE const *const ca,
|
|
921
|
+
TIDE_DTYPE const *const cb,
|
|
922
|
+
TIDE_DTYPE const *const cq,
|
|
923
|
+
TIDE_DTYPE const *const f,
|
|
924
|
+
TIDE_DTYPE *const ey,
|
|
925
|
+
TIDE_DTYPE *const hx,
|
|
926
|
+
TIDE_DTYPE *const hz,
|
|
927
|
+
TIDE_DTYPE *const m_ey_x,
|
|
928
|
+
TIDE_DTYPE *const m_ey_z,
|
|
929
|
+
TIDE_DTYPE *const m_hx_z,
|
|
930
|
+
TIDE_DTYPE *const m_hz_x,
|
|
931
|
+
TIDE_DTYPE *const r,
|
|
932
|
+
TIDE_DTYPE const *const ay,
|
|
933
|
+
TIDE_DTYPE const *const by,
|
|
934
|
+
TIDE_DTYPE const *const ayh,
|
|
935
|
+
TIDE_DTYPE const *const byh,
|
|
936
|
+
TIDE_DTYPE const *const ax,
|
|
937
|
+
TIDE_DTYPE const *const bx,
|
|
938
|
+
TIDE_DTYPE const *const axh,
|
|
939
|
+
TIDE_DTYPE const *const bxh,
|
|
940
|
+
TIDE_DTYPE const *const ky,
|
|
941
|
+
TIDE_DTYPE const *const kyh,
|
|
942
|
+
TIDE_DTYPE const *const kx,
|
|
943
|
+
TIDE_DTYPE const *const kxh,
|
|
944
|
+
int64_t const *const sources_i,
|
|
945
|
+
int64_t const *const receivers_i,
|
|
946
|
+
TIDE_DTYPE const rdy,
|
|
947
|
+
TIDE_DTYPE const rdx,
|
|
948
|
+
TIDE_DTYPE const dt,
|
|
949
|
+
int64_t const nt,
|
|
950
|
+
int64_t const n_shots,
|
|
951
|
+
int64_t const ny,
|
|
952
|
+
int64_t const nx,
|
|
953
|
+
int64_t const n_sources_per_shot,
|
|
954
|
+
int64_t const n_receivers_per_shot,
|
|
955
|
+
int64_t const step_ratio,
|
|
956
|
+
bool const ca_batched,
|
|
957
|
+
bool const cb_batched,
|
|
958
|
+
bool const cq_batched,
|
|
959
|
+
int64_t const start_t,
|
|
960
|
+
int64_t const pml_y0,
|
|
961
|
+
int64_t const pml_x0,
|
|
962
|
+
int64_t const pml_y1,
|
|
963
|
+
int64_t const pml_x1,
|
|
964
|
+
int64_t const n_threads,
|
|
965
|
+
int64_t const device /* unused for CPU */) {
|
|
966
|
+
|
|
967
|
+
(void)device;
|
|
968
|
+
(void)dt;
|
|
969
|
+
(void)step_ratio;
|
|
970
|
+
#ifdef _OPENMP
|
|
971
|
+
int const prev_threads = omp_get_max_threads();
|
|
972
|
+
if (n_threads > 0) {
|
|
973
|
+
omp_set_num_threads((int)n_threads);
|
|
974
|
+
}
|
|
975
|
+
#else
|
|
976
|
+
(void)n_threads;
|
|
977
|
+
#endif
|
|
978
|
+
|
|
979
|
+
int64_t const shot_numel = ny * nx;
|
|
980
|
+
|
|
981
|
+
for (int64_t t = start_t; t < start_t + nt; ++t) {
|
|
982
|
+
forward_kernel_h(
|
|
983
|
+
cq, ey, hx, hz, m_ey_x, m_ey_z,
|
|
984
|
+
ay, ayh, ax, axh, by, byh, bx, bxh,
|
|
985
|
+
ky, kyh, kx, kxh,
|
|
986
|
+
rdy, rdx,
|
|
987
|
+
n_shots, ny, nx, shot_numel,
|
|
988
|
+
pml_y0, pml_y1, pml_x0, pml_x1,
|
|
989
|
+
cq_batched);
|
|
990
|
+
|
|
991
|
+
forward_kernel_e_with_storage(
|
|
992
|
+
ca, cb, hx, hz, ey, m_hx_z, m_hz_x,
|
|
993
|
+
ay, ayh, ax, axh, by, byh, bx, bxh,
|
|
994
|
+
ky, kyh, kx, kxh,
|
|
995
|
+
rdy, rdx,
|
|
996
|
+
n_shots, ny, nx, shot_numel,
|
|
997
|
+
pml_y0, pml_y1, pml_x0, pml_x1,
|
|
998
|
+
ca_batched, cb_batched,
|
|
999
|
+
false, false, // No storage for standard forward
|
|
1000
|
+
NULL, NULL);
|
|
1001
|
+
|
|
1002
|
+
if (n_sources_per_shot > 0) {
|
|
1003
|
+
add_sources_ey(
|
|
1004
|
+
ey, f + t * n_shots * n_sources_per_shot, sources_i,
|
|
1005
|
+
n_shots, shot_numel, n_sources_per_shot);
|
|
1006
|
+
}
|
|
1007
|
+
|
|
1008
|
+
if (n_receivers_per_shot > 0) {
|
|
1009
|
+
record_receivers_ey(
|
|
1010
|
+
r + t * n_shots * n_receivers_per_shot,
|
|
1011
|
+
ey, receivers_i,
|
|
1012
|
+
n_shots, shot_numel, n_receivers_per_shot);
|
|
1013
|
+
}
|
|
1014
|
+
}
|
|
1015
|
+
#ifdef _OPENMP
|
|
1016
|
+
if (n_threads > 0) {
|
|
1017
|
+
omp_set_num_threads(prev_threads);
|
|
1018
|
+
}
|
|
1019
|
+
#endif
|
|
1020
|
+
}
|
|
1021
|
+
|
|
1022
|
+
|
|
1023
|
+
/*
|
|
1024
|
+
* Forward with storage for backward pass
|
|
1025
|
+
*
|
|
1026
|
+
* This function performs forward propagation while storing the values
|
|
1027
|
+
* needed for gradient computation in the backward pass.
|
|
1028
|
+
*/
|
|
1029
|
+
#ifdef __cplusplus
|
|
1030
|
+
extern "C"
|
|
1031
|
+
#endif
|
|
1032
|
+
#ifdef _WIN32
|
|
1033
|
+
__declspec(dllexport)
|
|
1034
|
+
#endif
|
|
1035
|
+
void FUNC(forward_with_storage)(
|
|
1036
|
+
TIDE_DTYPE const *const ca,
|
|
1037
|
+
TIDE_DTYPE const *const cb,
|
|
1038
|
+
TIDE_DTYPE const *const cq,
|
|
1039
|
+
TIDE_DTYPE const *const f,
|
|
1040
|
+
TIDE_DTYPE *const ey,
|
|
1041
|
+
TIDE_DTYPE *const hx,
|
|
1042
|
+
TIDE_DTYPE *const hz,
|
|
1043
|
+
TIDE_DTYPE *const m_ey_x,
|
|
1044
|
+
TIDE_DTYPE *const m_ey_z,
|
|
1045
|
+
TIDE_DTYPE *const m_hx_z,
|
|
1046
|
+
TIDE_DTYPE *const m_hz_x,
|
|
1047
|
+
TIDE_DTYPE *const r,
|
|
1048
|
+
TIDE_DTYPE *const ey_store_1,
|
|
1049
|
+
void *const ey_store_3,
|
|
1050
|
+
char const *const *const ey_filenames,
|
|
1051
|
+
TIDE_DTYPE *const curl_store_1,
|
|
1052
|
+
void *const curl_store_3,
|
|
1053
|
+
char const *const *const curl_filenames,
|
|
1054
|
+
TIDE_DTYPE const *const ay,
|
|
1055
|
+
TIDE_DTYPE const *const by,
|
|
1056
|
+
TIDE_DTYPE const *const ayh,
|
|
1057
|
+
TIDE_DTYPE const *const byh,
|
|
1058
|
+
TIDE_DTYPE const *const ax,
|
|
1059
|
+
TIDE_DTYPE const *const bx,
|
|
1060
|
+
TIDE_DTYPE const *const axh,
|
|
1061
|
+
TIDE_DTYPE const *const bxh,
|
|
1062
|
+
TIDE_DTYPE const *const ky,
|
|
1063
|
+
TIDE_DTYPE const *const kyh,
|
|
1064
|
+
TIDE_DTYPE const *const kx,
|
|
1065
|
+
TIDE_DTYPE const *const kxh,
|
|
1066
|
+
int64_t const *const sources_i,
|
|
1067
|
+
int64_t const *const receivers_i,
|
|
1068
|
+
TIDE_DTYPE const rdy,
|
|
1069
|
+
TIDE_DTYPE const rdx,
|
|
1070
|
+
TIDE_DTYPE const dt,
|
|
1071
|
+
int64_t const nt,
|
|
1072
|
+
int64_t const n_shots,
|
|
1073
|
+
int64_t const ny,
|
|
1074
|
+
int64_t const nx,
|
|
1075
|
+
int64_t const n_sources_per_shot,
|
|
1076
|
+
int64_t const n_receivers_per_shot,
|
|
1077
|
+
int64_t const step_ratio,
|
|
1078
|
+
int64_t const storage_mode,
|
|
1079
|
+
int64_t const shot_bytes_uncomp,
|
|
1080
|
+
bool const ca_requires_grad,
|
|
1081
|
+
bool const cb_requires_grad,
|
|
1082
|
+
bool const ca_batched,
|
|
1083
|
+
bool const cb_batched,
|
|
1084
|
+
bool const cq_batched,
|
|
1085
|
+
int64_t const start_t,
|
|
1086
|
+
int64_t const pml_y0,
|
|
1087
|
+
int64_t const pml_x0,
|
|
1088
|
+
int64_t const pml_y1,
|
|
1089
|
+
int64_t const pml_x1,
|
|
1090
|
+
int64_t const n_threads,
|
|
1091
|
+
int64_t const device /* unused for CPU */) {
|
|
1092
|
+
|
|
1093
|
+
(void)device;
|
|
1094
|
+
(void)dt;
|
|
1095
|
+
#ifdef _OPENMP
|
|
1096
|
+
int const prev_threads = omp_get_max_threads();
|
|
1097
|
+
if (n_threads > 0) {
|
|
1098
|
+
omp_set_num_threads((int)n_threads);
|
|
1099
|
+
}
|
|
1100
|
+
#else
|
|
1101
|
+
(void)n_threads;
|
|
1102
|
+
#endif
|
|
1103
|
+
|
|
1104
|
+
int64_t const shot_numel = ny * nx;
|
|
1105
|
+
int64_t const store_size = n_shots * shot_numel;
|
|
1106
|
+
bool const storage_fp8 = (shot_bytes_uncomp == shot_numel * 1);
|
|
1107
|
+
bool const storage_bf16 = (shot_bytes_uncomp == shot_numel * 2);
|
|
1108
|
+
|
|
1109
|
+
FILE **fp_ey = NULL;
|
|
1110
|
+
FILE **fp_curl = NULL;
|
|
1111
|
+
if (storage_mode == STORAGE_DISK) {
|
|
1112
|
+
if (ca_requires_grad) {
|
|
1113
|
+
fp_ey = (FILE **)malloc((size_t)n_shots * sizeof(FILE *));
|
|
1114
|
+
for (int64_t shot = 0; shot < n_shots; ++shot) {
|
|
1115
|
+
fp_ey[shot] = fopen(ey_filenames[shot], "wb");
|
|
1116
|
+
}
|
|
1117
|
+
}
|
|
1118
|
+
if (cb_requires_grad) {
|
|
1119
|
+
fp_curl = (FILE **)malloc((size_t)n_shots * sizeof(FILE *));
|
|
1120
|
+
for (int64_t shot = 0; shot < n_shots; ++shot) {
|
|
1121
|
+
fp_curl[shot] = fopen(curl_filenames[shot], "wb");
|
|
1122
|
+
}
|
|
1123
|
+
}
|
|
1124
|
+
}
|
|
1125
|
+
|
|
1126
|
+
for (int64_t t = start_t; t < start_t + nt; ++t) {
|
|
1127
|
+
forward_kernel_h(
|
|
1128
|
+
cq, ey, hx, hz, m_ey_x, m_ey_z,
|
|
1129
|
+
ay, ayh, ax, axh, by, byh, bx, bxh,
|
|
1130
|
+
ky, kyh, kx, kxh,
|
|
1131
|
+
rdy, rdx,
|
|
1132
|
+
n_shots, ny, nx, shot_numel,
|
|
1133
|
+
pml_y0, pml_y1, pml_x0, pml_x1,
|
|
1134
|
+
cq_batched);
|
|
1135
|
+
|
|
1136
|
+
bool const store_step = ((t % step_ratio) == 0);
|
|
1137
|
+
bool const store_ey = store_step && ca_requires_grad;
|
|
1138
|
+
bool const store_curl = store_step && cb_requires_grad;
|
|
1139
|
+
int64_t const step_idx = t / step_ratio;
|
|
1140
|
+
|
|
1141
|
+
int64_t const store_offset =
|
|
1142
|
+
(storage_mode == STORAGE_DEVICE ? step_idx * store_size : 0);
|
|
1143
|
+
|
|
1144
|
+
if (storage_fp8) {
|
|
1145
|
+
tide_fp8_e4m3 *const ey_store_1_t =
|
|
1146
|
+
(tide_fp8_e4m3 *)ey_store_1 + store_offset;
|
|
1147
|
+
tide_fp8_e4m3 *const curl_store_1_t =
|
|
1148
|
+
(tide_fp8_e4m3 *)curl_store_1 + store_offset;
|
|
1149
|
+
|
|
1150
|
+
forward_kernel_e_with_storage_fp8(
|
|
1151
|
+
ca, cb, hx, hz, ey, m_hx_z, m_hz_x,
|
|
1152
|
+
ay, ayh, ax, axh, by, byh, bx, bxh,
|
|
1153
|
+
ky, kyh, kx, kxh,
|
|
1154
|
+
rdy, rdx,
|
|
1155
|
+
n_shots, ny, nx, shot_numel,
|
|
1156
|
+
pml_y0, pml_y1, pml_x0, pml_x1,
|
|
1157
|
+
ca_batched, cb_batched,
|
|
1158
|
+
store_ey,
|
|
1159
|
+
store_curl,
|
|
1160
|
+
store_ey ? ey_store_1_t : NULL,
|
|
1161
|
+
store_curl ? curl_store_1_t : NULL);
|
|
1162
|
+
|
|
1163
|
+
if (store_ey && storage_mode == STORAGE_DISK) {
|
|
1164
|
+
for (int64_t shot = 0; shot < n_shots; ++shot) {
|
|
1165
|
+
storage_save_snapshot_cpu(
|
|
1166
|
+
(void *)(ey_store_1_t + shot * shot_numel), fp_ey[shot],
|
|
1167
|
+
storage_mode, step_idx, (size_t)shot_bytes_uncomp);
|
|
1168
|
+
}
|
|
1169
|
+
}
|
|
1170
|
+
if (store_curl && storage_mode == STORAGE_DISK) {
|
|
1171
|
+
for (int64_t shot = 0; shot < n_shots; ++shot) {
|
|
1172
|
+
storage_save_snapshot_cpu(
|
|
1173
|
+
(void *)(curl_store_1_t + shot * shot_numel), fp_curl[shot],
|
|
1174
|
+
storage_mode, step_idx, (size_t)shot_bytes_uncomp);
|
|
1175
|
+
}
|
|
1176
|
+
}
|
|
1177
|
+
} else if (storage_bf16) {
|
|
1178
|
+
tide_bfloat16 *const ey_store_1_t =
|
|
1179
|
+
(tide_bfloat16 *)ey_store_1 + store_offset;
|
|
1180
|
+
tide_bfloat16 *const curl_store_1_t =
|
|
1181
|
+
(tide_bfloat16 *)curl_store_1 + store_offset;
|
|
1182
|
+
|
|
1183
|
+
forward_kernel_e_with_storage_bf16(
|
|
1184
|
+
ca, cb, hx, hz, ey, m_hx_z, m_hz_x,
|
|
1185
|
+
ay, ayh, ax, axh, by, byh, bx, bxh,
|
|
1186
|
+
ky, kyh, kx, kxh,
|
|
1187
|
+
rdy, rdx,
|
|
1188
|
+
n_shots, ny, nx, shot_numel,
|
|
1189
|
+
pml_y0, pml_y1, pml_x0, pml_x1,
|
|
1190
|
+
ca_batched, cb_batched,
|
|
1191
|
+
store_ey,
|
|
1192
|
+
store_curl,
|
|
1193
|
+
store_ey ? ey_store_1_t : NULL,
|
|
1194
|
+
store_curl ? curl_store_1_t : NULL);
|
|
1195
|
+
|
|
1196
|
+
if (store_ey && storage_mode == STORAGE_DISK) {
|
|
1197
|
+
for (int64_t shot = 0; shot < n_shots; ++shot) {
|
|
1198
|
+
storage_save_snapshot_cpu(
|
|
1199
|
+
(void *)(ey_store_1_t + shot * shot_numel), fp_ey[shot],
|
|
1200
|
+
storage_mode, step_idx, (size_t)shot_bytes_uncomp);
|
|
1201
|
+
}
|
|
1202
|
+
}
|
|
1203
|
+
if (store_curl && storage_mode == STORAGE_DISK) {
|
|
1204
|
+
for (int64_t shot = 0; shot < n_shots; ++shot) {
|
|
1205
|
+
storage_save_snapshot_cpu(
|
|
1206
|
+
(void *)(curl_store_1_t + shot * shot_numel), fp_curl[shot],
|
|
1207
|
+
storage_mode, step_idx, (size_t)shot_bytes_uncomp);
|
|
1208
|
+
}
|
|
1209
|
+
}
|
|
1210
|
+
} else {
|
|
1211
|
+
TIDE_DTYPE *const ey_store_1_t =
|
|
1212
|
+
ey_store_1 + store_offset;
|
|
1213
|
+
TIDE_DTYPE *const curl_store_1_t =
|
|
1214
|
+
curl_store_1 + store_offset;
|
|
1215
|
+
|
|
1216
|
+
forward_kernel_e_with_storage(
|
|
1217
|
+
ca, cb, hx, hz, ey, m_hx_z, m_hz_x,
|
|
1218
|
+
ay, ayh, ax, axh, by, byh, bx, bxh,
|
|
1219
|
+
ky, kyh, kx, kxh,
|
|
1220
|
+
rdy, rdx,
|
|
1221
|
+
n_shots, ny, nx, shot_numel,
|
|
1222
|
+
pml_y0, pml_y1, pml_x0, pml_x1,
|
|
1223
|
+
ca_batched, cb_batched,
|
|
1224
|
+
store_ey,
|
|
1225
|
+
store_curl,
|
|
1226
|
+
store_ey ? ey_store_1_t : NULL,
|
|
1227
|
+
store_curl ? curl_store_1_t : NULL);
|
|
1228
|
+
|
|
1229
|
+
if (store_ey && storage_mode == STORAGE_DISK) {
|
|
1230
|
+
for (int64_t shot = 0; shot < n_shots; ++shot) {
|
|
1231
|
+
storage_save_snapshot_cpu(
|
|
1232
|
+
(void *)(ey_store_1_t + shot * shot_numel), fp_ey[shot],
|
|
1233
|
+
storage_mode, step_idx, (size_t)shot_bytes_uncomp);
|
|
1234
|
+
}
|
|
1235
|
+
}
|
|
1236
|
+
if (store_curl && storage_mode == STORAGE_DISK) {
|
|
1237
|
+
for (int64_t shot = 0; shot < n_shots; ++shot) {
|
|
1238
|
+
storage_save_snapshot_cpu(
|
|
1239
|
+
(void *)(curl_store_1_t + shot * shot_numel), fp_curl[shot],
|
|
1240
|
+
storage_mode, step_idx, (size_t)shot_bytes_uncomp);
|
|
1241
|
+
}
|
|
1242
|
+
}
|
|
1243
|
+
}
|
|
1244
|
+
|
|
1245
|
+
if (n_sources_per_shot > 0) {
|
|
1246
|
+
add_sources_ey(
|
|
1247
|
+
ey, f + t * n_shots * n_sources_per_shot, sources_i,
|
|
1248
|
+
n_shots, shot_numel, n_sources_per_shot);
|
|
1249
|
+
}
|
|
1250
|
+
|
|
1251
|
+
if (n_receivers_per_shot > 0) {
|
|
1252
|
+
record_receivers_ey(
|
|
1253
|
+
r + t * n_shots * n_receivers_per_shot,
|
|
1254
|
+
ey, receivers_i,
|
|
1255
|
+
n_shots, shot_numel, n_receivers_per_shot);
|
|
1256
|
+
}
|
|
1257
|
+
}
|
|
1258
|
+
|
|
1259
|
+
if (fp_ey != NULL) {
|
|
1260
|
+
for (int64_t shot = 0; shot < n_shots; ++shot) fclose(fp_ey[shot]);
|
|
1261
|
+
free(fp_ey);
|
|
1262
|
+
}
|
|
1263
|
+
if (fp_curl != NULL) {
|
|
1264
|
+
for (int64_t shot = 0; shot < n_shots; ++shot) fclose(fp_curl[shot]);
|
|
1265
|
+
free(fp_curl);
|
|
1266
|
+
}
|
|
1267
|
+
#ifdef _OPENMP
|
|
1268
|
+
if (n_threads > 0) {
|
|
1269
|
+
omp_set_num_threads(prev_threads);
|
|
1270
|
+
}
|
|
1271
|
+
#endif
|
|
1272
|
+
}
|
|
1273
|
+
|
|
1274
|
+
/*
 * Backward kernel for the adjoint λ_H field update.
 *
 * Adjoint equations for the H fields (time reversed; Cb takes the role that
 * Cq plays in the forward H update):
 *   λ_Hx^{n-1/2} = λ_Hx^{n+1/2} - C_b * ∂λ_Ey/∂z
 *   λ_Hz^{n-1/2} = λ_Hz^{n+1/2} + C_b * ∂λ_Ey/∂x
 *
 * The domain is walked as a 3x3 grid of regions on the half-grid bounds
 * below. The z-derivative gets the CPML memory update (m_lambda_ey_z,
 * byh/ayh/kyh) outside the interior y band, and the x-derivative gets the
 * matching update (m_lambda_ey_x, bxh/axh/kxh) outside the interior x band.
 * cb is offset per shot by shot_numel when cb_batched, otherwise shared.
 */
static void backward_kernel_lambda_h(
    TIDE_DTYPE const *__restrict const cb,
    TIDE_DTYPE const *__restrict const lambda_ey,
    TIDE_DTYPE *__restrict const lambda_hx,
    TIDE_DTYPE *__restrict const lambda_hz,
    TIDE_DTYPE *__restrict const m_lambda_ey_x,
    TIDE_DTYPE *__restrict const m_lambda_ey_z,
    TIDE_DTYPE const *__restrict const ay,
    TIDE_DTYPE const *__restrict const ayh,
    TIDE_DTYPE const *__restrict const ax,
    TIDE_DTYPE const *__restrict const axh,
    TIDE_DTYPE const *__restrict const by,
    TIDE_DTYPE const *__restrict const byh,
    TIDE_DTYPE const *__restrict const bx,
    TIDE_DTYPE const *__restrict const bxh,
    TIDE_DTYPE const *__restrict const ky,
    TIDE_DTYPE const *__restrict const kyh,
    TIDE_DTYPE const *__restrict const kx,
    TIDE_DTYPE const *__restrict const kxh,
    TIDE_DTYPE const rdy,
    TIDE_DTYPE const rdx,
    int64_t const n_shots,
    int64_t const ny,
    int64_t const nx,
    int64_t const shot_numel,
    int64_t const pml_y0,
    int64_t const pml_y1,
    int64_t const pml_x0,
    int64_t const pml_x1,
    bool const cb_batched) {

  // Half-grid region bounds: the interior starts at pml_*0 and ends one point
  // earlier than on the integer grid (clamped so it never precedes pml_*0).
  // NOTE(review): presumably this accounts for the staggered H positions —
  // confirm against the forward H kernel.
  int64_t const pml_y0h = pml_y0;
  int64_t const pml_y1h = MAX(pml_y0, pml_y1 - 1);
  int64_t const pml_x0h = pml_x0;
  int64_t const pml_x1h = MAX(pml_x0, pml_x1 - 1);
  int64_t const pml_bounds_yh[] = {FD_PAD, pml_y0h, pml_y1h, ny - FD_PAD + 1};
  int64_t const pml_bounds_xh[] = {FD_PAD, pml_x0h, pml_x1h, nx - FD_PAD + 1};

  TIDE_OMP_INDEX shot_idx;
  TIDE_OMP_PARALLEL_FOR
  for (shot_idx = 0; shot_idx < n_shots; ++shot_idx) {
    int64_t const shot_offset = shot_idx * shot_numel;
    TIDE_DTYPE const *__restrict const cb_ptr =
        cb_batched ? (cb + shot_offset) : cb;
    for (int pml_y = 0; pml_y < 3; ++pml_y) {
      for (int pml_x = 0; pml_x < 3; ++pml_x) {
        TIDE_OMP_SIMD_COLLAPSE2
        for (int64_t y = pml_bounds_yh[pml_y]; y < pml_bounds_yh[pml_y + 1]; ++y) {
          for (int64_t x = pml_bounds_xh[pml_x]; x < pml_bounds_xh[pml_x + 1]; ++x) {
            int64_t const idx = IDX(y, x);
            TIDE_DTYPE const cb_val = cb_ptr[idx];

            // The loop runs up to y == ny - FD_PAD inclusive; the last row is
            // excluded from the λ_Hx (z-derivative) update.
            if (y < ny - FD_PAD) {
              TIDE_DTYPE d_lambda_ey_dz = DIFFYH1(LAMBDA_EY);

              // CPML memory update for the z-derivative outside the interior.
              if (pml_y != 1) {
                M_LAMBDA_EY_Z(0, 0) = byh[y] * M_LAMBDA_EY_Z(0, 0) + ayh[y] * d_lambda_ey_dz;
                d_lambda_ey_dz = d_lambda_ey_dz / kyh[y] + M_LAMBDA_EY_Z(0, 0);
              }

              LAMBDA_HX(0, 0) -= cb_val * d_lambda_ey_dz;
            }

            // Likewise, the last column is excluded from the λ_Hz update.
            if (x < nx - FD_PAD) {
              TIDE_DTYPE d_lambda_ey_dx = DIFFXH1(LAMBDA_EY);

              // CPML memory update for the x-derivative outside the interior.
              if (pml_x != 1) {
                M_LAMBDA_EY_X(0, 0) = bxh[x] * M_LAMBDA_EY_X(0, 0) + axh[x] * d_lambda_ey_dx;
                d_lambda_ey_dx = d_lambda_ey_dx / kxh[x] + M_LAMBDA_EY_X(0, 0);
              }

              LAMBDA_HZ(0, 0) += cb_val * d_lambda_ey_dx;
            }
          }
        }
      }
    }
  }
}
|
|
1360
|
+
|
|
1361
|
+
|
|
1362
|
+
/*
 * Backward kernel for the adjoint λ_Ey update with gradient accumulation.
 *
 * Adjoint equation for the E field (time reversed; Cq takes the role that
 * Cb plays in the forward E update):
 *   λ_Ey^n = C_a * λ_Ey^{n+1} + C_q * (∂λ_Hz/∂x - ∂λ_Hx/∂z)
 *
 * Gradient accumulation, using the snapshots written by the forward pass:
 *   grad_ca += λ_Ey^{n+1} * E_y^n      * step_ratio
 *   grad_cb += λ_Ey^{n+1} * curl_H^n   * step_ratio
 * NOTE(review): the step_ratio factor presumably compensates for
 * snapshots being taken only every step_ratio steps — confirm.
 *
 * The pml_bounds arrays divide the domain into 9 regions (a 3x3 grid):
 *   pml_y/pml_x == 0: low-index (top/left) PML region
 *   pml_y/pml_x == 1: interior region, the only place gradients accumulate
 *   pml_y/pml_x == 2: high-index (bottom/right) PML region
 *
 * When the gradient target is not per-shot, several shots update the same
 * grad_* element. Those updates are made with `omp atomic` (when OpenMP is
 * enabled), because shots run in parallel.
 */
static void backward_kernel_lambda_e_with_grad(
    TIDE_DTYPE const *__restrict const ca,
    TIDE_DTYPE const *__restrict const cq,
    TIDE_DTYPE const *__restrict const lambda_hx,
    TIDE_DTYPE const *__restrict const lambda_hz,
    TIDE_DTYPE *__restrict const lambda_ey,
    TIDE_DTYPE *__restrict const m_lambda_hx_z,
    TIDE_DTYPE *__restrict const m_lambda_hz_x,
    TIDE_DTYPE const *__restrict const ey_store,
    TIDE_DTYPE const *__restrict const curl_h_store,
    TIDE_DTYPE *__restrict const grad_ca,
    TIDE_DTYPE *__restrict const grad_cb,
    TIDE_DTYPE const *__restrict const ay,
    TIDE_DTYPE const *__restrict const ayh,
    TIDE_DTYPE const *__restrict const ax,
    TIDE_DTYPE const *__restrict const axh,
    TIDE_DTYPE const *__restrict const by,
    TIDE_DTYPE const *__restrict const byh,
    TIDE_DTYPE const *__restrict const bx,
    TIDE_DTYPE const *__restrict const bxh,
    TIDE_DTYPE const *__restrict const ky,
    TIDE_DTYPE const *__restrict const kyh,
    TIDE_DTYPE const *__restrict const kx,
    TIDE_DTYPE const *__restrict const kxh,
    TIDE_DTYPE const rdy,
    TIDE_DTYPE const rdx,
    int64_t const n_shots,
    int64_t const ny,
    int64_t const nx,
    int64_t const shot_numel,
    int64_t const pml_y0,
    int64_t const pml_y1,
    int64_t const pml_x0,
    int64_t const pml_x1,
    bool const ca_batched,
    bool const cq_batched,
    bool const ca_requires_grad,
    bool const cb_requires_grad,
    int64_t const step_ratio) {

  // PML region bounds arrays
  // pml_bounds[0] = FD_PAD (start of computational domain)
  // pml_bounds[1] = pml_y0 (start of interior region)
  // pml_bounds[2] = pml_y1 (end of interior region)
  // pml_bounds[3] = ny - FD_PAD + 1 (end of computational domain, exclusive)
  int64_t const pml_bounds_y[] = {FD_PAD, pml_y0, pml_y1, ny - FD_PAD + 1};
  int64_t const pml_bounds_x[] = {FD_PAD, pml_x0, pml_x1, nx - FD_PAD + 1};

  TIDE_OMP_INDEX shot_idx;
  TIDE_OMP_PARALLEL_FOR
  for (shot_idx = 0; shot_idx < n_shots; ++shot_idx) {
    int64_t const shot_offset = shot_idx * shot_numel;
    // Per-shot model slices when batched, otherwise one shared model.
    TIDE_DTYPE const *__restrict const ca_ptr =
        ca_batched ? (ca + shot_offset) : ca;
    TIDE_DTYPE const *__restrict const cq_ptr =
        cq_batched ? (cq + shot_offset) : cq;
    // Loop over 3x3 grid of regions
    for (int pml_y = 0; pml_y < 3; ++pml_y) {
      for (int pml_x = 0; pml_x < 3; ++pml_x) {
        TIDE_OMP_SIMD_COLLAPSE2
        for (int64_t y = pml_bounds_y[pml_y]; y < pml_bounds_y[pml_y + 1]; ++y) {
          for (int64_t x = pml_bounds_x[pml_x]; x < pml_bounds_x[pml_x + 1]; ++x) {
            int64_t const idx = IDX(y, x);
            // Index into per-shot arrays (snapshots and batched gradients).
            int64_t const store_idx = shot_offset + idx;
            TIDE_DTYPE const ca_val = ca_ptr[idx];
            TIDE_DTYPE const cq_val = cq_ptr[idx];

            // Compute d(λ_Hz)/dx at integer grid points
            TIDE_DTYPE d_lambda_hz_dx = DIFFX1(LAMBDA_HZ);
            // Compute d(λ_Hx)/dz at integer grid points
            TIDE_DTYPE d_lambda_hx_dz = DIFFY1(LAMBDA_HX);

            // Apply adjoint CPML for d(λ_Hz)/dx (only outside the interior x band)
            if (pml_x != 1) {
              M_LAMBDA_HZ_X(0, 0) = bx[x] * M_LAMBDA_HZ_X(0, 0) + ax[x] * d_lambda_hz_dx;
              d_lambda_hz_dx = d_lambda_hz_dx / kx[x] + M_LAMBDA_HZ_X(0, 0);
            }

            // Apply adjoint CPML for d(λ_Hx)/dz (only outside the interior y band)
            if (pml_y != 1) {
              M_LAMBDA_HX_Z(0, 0) = by[y] * M_LAMBDA_HX_Z(0, 0) + ay[y] * d_lambda_hx_dz;
              d_lambda_hx_dz = d_lambda_hx_dz / ky[y] + M_LAMBDA_HX_Z(0, 0);
            }

            // curl_λH = d(λ_Hz)/dx - d(λ_Hx)/dz
            TIDE_DTYPE curl_lambda_h = d_lambda_hz_dx - d_lambda_hx_dz;

            // Save λ_Ey before the update (this is λ_Ey^{n+1}); the gradient
            // terms below need the pre-update value.
            TIDE_DTYPE lambda_ey_curr = LAMBDA_EY(0, 0);

            // Update λ_Ey: λ_Ey^n = C_a * λ_Ey^{n+1} + C_q * curl_λH
            LAMBDA_EY(0, 0) = ca_val * lambda_ey_curr + cq_val * curl_lambda_h;

            // Accumulate gradients only in interior region (pml_y == 1 && pml_x == 1)
            if (pml_y == 1 && pml_x == 1) {
              // grad_ca += λ_Ey^{n+1} * E_y^n
              if (ca_requires_grad && ey_store != NULL) {
                TIDE_DTYPE ey_n = ey_store[store_idx];
                if (ca_batched) {
                  // Per-shot gradient slot: no other shot writes here.
                  grad_ca[store_idx] += lambda_ey_curr * ey_n * (TIDE_DTYPE)step_ratio;
                } else {
                  // Shared gradient slot: shots run in parallel, so make the
                  // read-modify-write atomic.
#ifdef _OPENMP
#pragma omp atomic
#endif
                  grad_ca[idx] += lambda_ey_curr * ey_n * (TIDE_DTYPE)step_ratio;
                }
              }

              // grad_cb += λ_Ey^{n+1} * curl_H^n
              if (cb_requires_grad && curl_h_store != NULL) {
                TIDE_DTYPE curl_h_n = curl_h_store[store_idx];
                // NOTE(review): grad_cb's layout is selected by ca_batched
                // (this kernel has no cb_batched parameter). This assumes ca
                // and cb are always batched together — confirm with callers.
                if (ca_batched) {
                  grad_cb[store_idx] += lambda_ey_curr * curl_h_n * (TIDE_DTYPE)step_ratio;
                } else {
#ifdef _OPENMP
#pragma omp atomic
#endif
                  grad_cb[idx] += lambda_ey_curr * curl_h_n * (TIDE_DTYPE)step_ratio;
                }
              }
            }
          }
        }
      }
    }
  }
}
|
|
1504
|
+
|
|
1505
|
+
/*
 * Adjoint E_y update with gradient accumulation, for snapshots stored as
 * bfloat16 (tide_bfloat16).
 *
 * For every shot and every grid point from FD_PAD up to (ny|nx) - FD_PAD + 1
 * (exclusive) this:
 *   1. forms curl_λH = d(λ_Hz)/dx - d(λ_Hx)/dz, applying the CPML memory
 *      recursion (m_lambda_hz_x / m_lambda_hx_z with the a/b/k profiles)
 *      only in PML regions (pml_x != 1 / pml_y != 1);
 *   2. updates λ_Ey in place: λ_Ey^n = C_a * λ_Ey^{n+1} + C_q * curl_λH;
 *   3. in the interior region only (pml_y == 1 && pml_x == 1) accumulates
 *        grad_ca += λ_Ey^{n+1} * E_y^n    * step_ratio  (if ey_store != NULL)
 *        grad_cb += λ_Ey^{n+1} * curl_H^n * step_ratio  (if curl_h_store != NULL)
 *      with the stored snapshots widened via tide_bf16_to_float().
 *
 * Grid partition: the y (resp. x) range is split at pml_y0/pml_y1
 * (resp. pml_x0/pml_x1) into three regions; region index 1 is the non-PML
 * interior.
 *
 * Gradient layout: when ca_batched, gradients are written per shot at
 * shot_offset + idx, so threads never share an element. Otherwise every shot
 * adds into the same element idx, hence the omp atomic.
 *
 * The step_ratio factor presumably compensates for accumulating only on
 * every step_ratio-th time step (see the caller's do_grad) — confirm.
 *
 * NOTE(review): the grad_cb branch also selects its layout with ca_batched.
 * This kernel has no cb_batched parameter, while the backward driver passes
 * cb_batched separately to convert_grad_ca_cb_to_eps_sigma. It is only
 * correct when C_a and C_b share the same batching — confirm.
 * NOTE(review): ayh, axh, byh, bxh, kyh, kxh are not referenced in this body.
 * rdy/rdx are presumably consumed by the DIFFX1/DIFFY1 macros.
 */
static void backward_kernel_lambda_e_with_grad_bf16(
    TIDE_DTYPE const *__restrict const ca,
    TIDE_DTYPE const *__restrict const cq,
    TIDE_DTYPE const *__restrict const lambda_hx,
    TIDE_DTYPE const *__restrict const lambda_hz,
    TIDE_DTYPE *__restrict const lambda_ey,
    TIDE_DTYPE *__restrict const m_lambda_hx_z,
    TIDE_DTYPE *__restrict const m_lambda_hz_x,
    tide_bfloat16 const *__restrict const ey_store,
    tide_bfloat16 const *__restrict const curl_h_store,
    TIDE_DTYPE *__restrict const grad_ca,
    TIDE_DTYPE *__restrict const grad_cb,
    TIDE_DTYPE const *__restrict const ay,
    TIDE_DTYPE const *__restrict const ayh,
    TIDE_DTYPE const *__restrict const ax,
    TIDE_DTYPE const *__restrict const axh,
    TIDE_DTYPE const *__restrict const by,
    TIDE_DTYPE const *__restrict const byh,
    TIDE_DTYPE const *__restrict const bx,
    TIDE_DTYPE const *__restrict const bxh,
    TIDE_DTYPE const *__restrict const ky,
    TIDE_DTYPE const *__restrict const kyh,
    TIDE_DTYPE const *__restrict const kx,
    TIDE_DTYPE const *__restrict const kxh,
    TIDE_DTYPE const rdy,
    TIDE_DTYPE const rdx,
    int64_t const n_shots,
    int64_t const ny,
    int64_t const nx,
    int64_t const shot_numel,
    int64_t const pml_y0,
    int64_t const pml_y1,
    int64_t const pml_x0,
    int64_t const pml_x1,
    bool const ca_batched,
    bool const cq_batched,
    bool const ca_requires_grad,
    bool const cb_requires_grad,
    int64_t const step_ratio) {

  // Region boundaries: [lower PML | interior | upper PML] along each axis.
  int64_t const pml_bounds_y[] = {FD_PAD, pml_y0, pml_y1, ny - FD_PAD + 1};
  int64_t const pml_bounds_x[] = {FD_PAD, pml_x0, pml_x1, nx - FD_PAD + 1};

  TIDE_OMP_INDEX shot_idx;
  TIDE_OMP_PARALLEL_FOR
  for (shot_idx = 0; shot_idx < n_shots; ++shot_idx) {
    int64_t const shot_offset = shot_idx * shot_numel;
    // Batched coefficients hold one slab per shot; otherwise all shots share one.
    TIDE_DTYPE const *__restrict const ca_ptr =
        ca_batched ? (ca + shot_offset) : ca;
    TIDE_DTYPE const *__restrict const cq_ptr =
        cq_batched ? (cq + shot_offset) : cq;
    for (int pml_y = 0; pml_y < 3; ++pml_y) {
      for (int pml_x = 0; pml_x < 3; ++pml_x) {
        TIDE_OMP_SIMD_COLLAPSE2
        for (int64_t y = pml_bounds_y[pml_y]; y < pml_bounds_y[pml_y + 1]; ++y) {
          for (int64_t x = pml_bounds_x[pml_x]; x < pml_bounds_x[pml_x + 1]; ++x) {
            int64_t const idx = IDX(y, x);
            int64_t const store_idx = shot_offset + idx;
            TIDE_DTYPE const ca_val = ca_ptr[idx];
            TIDE_DTYPE const cq_val = cq_ptr[idx];

            TIDE_DTYPE d_lambda_hz_dx = DIFFX1(LAMBDA_HZ);
            TIDE_DTYPE d_lambda_hx_dz = DIFFY1(LAMBDA_HX);

            // Adjoint CPML recursion, applied only outside the interior along each axis.
            if (pml_x != 1) {
              M_LAMBDA_HZ_X(0, 0) = bx[x] * M_LAMBDA_HZ_X(0, 0) + ax[x] * d_lambda_hz_dx;
              d_lambda_hz_dx = d_lambda_hz_dx / kx[x] + M_LAMBDA_HZ_X(0, 0);
            }
            if (pml_y != 1) {
              M_LAMBDA_HX_Z(0, 0) = by[y] * M_LAMBDA_HX_Z(0, 0) + ay[y] * d_lambda_hx_dz;
              d_lambda_hx_dz = d_lambda_hx_dz / ky[y] + M_LAMBDA_HX_Z(0, 0);
            }

            // curl_λH = d(λ_Hz)/dx - d(λ_Hx)/dz
            TIDE_DTYPE curl_lambda_h = d_lambda_hz_dx - d_lambda_hx_dz;
            // λ_Ey^{n+1}, read before it is overwritten; needed for the gradients.
            TIDE_DTYPE lambda_ey_curr = LAMBDA_EY(0, 0);
            LAMBDA_EY(0, 0) = ca_val * lambda_ey_curr + cq_val * curl_lambda_h;

            // Gradients are accumulated in the interior region only.
            if (pml_y == 1 && pml_x == 1) {
              // grad_ca += λ_Ey^{n+1} * E_y^n
              if (ca_requires_grad && ey_store != NULL) {
                TIDE_DTYPE ey_n =
                    (TIDE_DTYPE)tide_bf16_to_float(ey_store[store_idx]);
                if (ca_batched) {
                  grad_ca[store_idx] += lambda_ey_curr * ey_n * (TIDE_DTYPE)step_ratio;
                } else {
                  // Shared across shots (parallel loop above): must be atomic.
#ifdef _OPENMP
#pragma omp atomic
#endif
                  grad_ca[idx] += lambda_ey_curr * ey_n * (TIDE_DTYPE)step_ratio;
                }
              }

              // grad_cb += λ_Ey^{n+1} * curl_H^n
              if (cb_requires_grad && curl_h_store != NULL) {
                TIDE_DTYPE curl_h_n =
                    (TIDE_DTYPE)tide_bf16_to_float(curl_h_store[store_idx]);
                // NOTE(review): layout chosen by ca_batched, not a cb flag — see header.
                if (ca_batched) {
                  grad_cb[store_idx] += lambda_ey_curr * curl_h_n * (TIDE_DTYPE)step_ratio;
                } else {
#ifdef _OPENMP
#pragma omp atomic
#endif
                  grad_cb[idx] += lambda_ey_curr * curl_h_n * (TIDE_DTYPE)step_ratio;
                }
              }
            }
          }
        }
      }
    }
  }
}
|
|
1615
|
+
|
|
1616
|
+
/*
 * Adjoint E_y update with gradient accumulation, for snapshots stored as
 * 8-bit floats (tide_fp8_e4m3).
 *
 * For every shot and every grid point from FD_PAD up to (ny|nx) - FD_PAD + 1
 * (exclusive) this:
 *   1. forms curl_λH = d(λ_Hz)/dx - d(λ_Hx)/dz, applying the CPML memory
 *      recursion (m_lambda_hz_x / m_lambda_hx_z with the a/b/k profiles)
 *      only in PML regions (pml_x != 1 / pml_y != 1);
 *   2. updates λ_Ey in place: λ_Ey^n = C_a * λ_Ey^{n+1} + C_q * curl_λH;
 *   3. in the interior region only (pml_y == 1 && pml_x == 1) accumulates
 *        grad_ca += λ_Ey^{n+1} * E_y^n    * step_ratio  (if ey_store != NULL)
 *        grad_cb += λ_Ey^{n+1} * curl_H^n * step_ratio  (if curl_h_store != NULL)
 *      with the stored snapshots widened via tide_fp8_e4m3_to_float().
 *
 * Grid partition: the y (resp. x) range is split at pml_y0/pml_y1
 * (resp. pml_x0/pml_x1) into three regions; region index 1 is the non-PML
 * interior.
 *
 * Gradient layout: when ca_batched, gradients are written per shot at
 * shot_offset + idx, so threads never share an element. Otherwise every shot
 * adds into the same element idx, hence the omp atomic.
 *
 * The step_ratio factor presumably compensates for accumulating only on
 * every step_ratio-th time step (see the caller's do_grad) — confirm.
 *
 * NOTE(review): the grad_cb branch also selects its layout with ca_batched.
 * This kernel has no cb_batched parameter, while the backward driver passes
 * cb_batched separately to convert_grad_ca_cb_to_eps_sigma. It is only
 * correct when C_a and C_b share the same batching — confirm.
 * NOTE(review): ayh, axh, byh, bxh, kyh, kxh are not referenced in this body.
 * rdy/rdx are presumably consumed by the DIFFX1/DIFFY1 macros.
 */
static void backward_kernel_lambda_e_with_grad_fp8(
    TIDE_DTYPE const *__restrict const ca,
    TIDE_DTYPE const *__restrict const cq,
    TIDE_DTYPE const *__restrict const lambda_hx,
    TIDE_DTYPE const *__restrict const lambda_hz,
    TIDE_DTYPE *__restrict const lambda_ey,
    TIDE_DTYPE *__restrict const m_lambda_hx_z,
    TIDE_DTYPE *__restrict const m_lambda_hz_x,
    tide_fp8_e4m3 const *__restrict const ey_store,
    tide_fp8_e4m3 const *__restrict const curl_h_store,
    TIDE_DTYPE *__restrict const grad_ca,
    TIDE_DTYPE *__restrict const grad_cb,
    TIDE_DTYPE const *__restrict const ay,
    TIDE_DTYPE const *__restrict const ayh,
    TIDE_DTYPE const *__restrict const ax,
    TIDE_DTYPE const *__restrict const axh,
    TIDE_DTYPE const *__restrict const by,
    TIDE_DTYPE const *__restrict const byh,
    TIDE_DTYPE const *__restrict const bx,
    TIDE_DTYPE const *__restrict const bxh,
    TIDE_DTYPE const *__restrict const ky,
    TIDE_DTYPE const *__restrict const kyh,
    TIDE_DTYPE const *__restrict const kx,
    TIDE_DTYPE const *__restrict const kxh,
    TIDE_DTYPE const rdy,
    TIDE_DTYPE const rdx,
    int64_t const n_shots,
    int64_t const ny,
    int64_t const nx,
    int64_t const shot_numel,
    int64_t const pml_y0,
    int64_t const pml_y1,
    int64_t const pml_x0,
    int64_t const pml_x1,
    bool const ca_batched,
    bool const cq_batched,
    bool const ca_requires_grad,
    bool const cb_requires_grad,
    int64_t const step_ratio) {

  // Region boundaries: [lower PML | interior | upper PML] along each axis.
  int64_t const pml_bounds_y[] = {FD_PAD, pml_y0, pml_y1, ny - FD_PAD + 1};
  int64_t const pml_bounds_x[] = {FD_PAD, pml_x0, pml_x1, nx - FD_PAD + 1};

  TIDE_OMP_INDEX shot_idx;
  TIDE_OMP_PARALLEL_FOR
  for (shot_idx = 0; shot_idx < n_shots; ++shot_idx) {
    int64_t const shot_offset = shot_idx * shot_numel;
    // Batched coefficients hold one slab per shot; otherwise all shots share one.
    TIDE_DTYPE const *__restrict const ca_ptr =
        ca_batched ? (ca + shot_offset) : ca;
    TIDE_DTYPE const *__restrict const cq_ptr =
        cq_batched ? (cq + shot_offset) : cq;
    for (int pml_y = 0; pml_y < 3; ++pml_y) {
      for (int pml_x = 0; pml_x < 3; ++pml_x) {
        TIDE_OMP_SIMD_COLLAPSE2
        for (int64_t y = pml_bounds_y[pml_y]; y < pml_bounds_y[pml_y + 1]; ++y) {
          for (int64_t x = pml_bounds_x[pml_x]; x < pml_bounds_x[pml_x + 1]; ++x) {
            int64_t const idx = IDX(y, x);
            int64_t const store_idx = shot_offset + idx;
            TIDE_DTYPE const ca_val = ca_ptr[idx];
            TIDE_DTYPE const cq_val = cq_ptr[idx];

            TIDE_DTYPE d_lambda_hz_dx = DIFFX1(LAMBDA_HZ);
            TIDE_DTYPE d_lambda_hx_dz = DIFFY1(LAMBDA_HX);

            // Adjoint CPML recursion, applied only outside the interior along each axis.
            if (pml_x != 1) {
              M_LAMBDA_HZ_X(0, 0) = bx[x] * M_LAMBDA_HZ_X(0, 0) + ax[x] * d_lambda_hz_dx;
              d_lambda_hz_dx = d_lambda_hz_dx / kx[x] + M_LAMBDA_HZ_X(0, 0);
            }
            if (pml_y != 1) {
              M_LAMBDA_HX_Z(0, 0) = by[y] * M_LAMBDA_HX_Z(0, 0) + ay[y] * d_lambda_hx_dz;
              d_lambda_hx_dz = d_lambda_hx_dz / ky[y] + M_LAMBDA_HX_Z(0, 0);
            }

            // curl_λH = d(λ_Hz)/dx - d(λ_Hx)/dz
            TIDE_DTYPE curl_lambda_h = d_lambda_hz_dx - d_lambda_hx_dz;
            // λ_Ey^{n+1}, read before it is overwritten; needed for the gradients.
            TIDE_DTYPE lambda_ey_curr = LAMBDA_EY(0, 0);
            LAMBDA_EY(0, 0) = ca_val * lambda_ey_curr + cq_val * curl_lambda_h;

            // Gradients are accumulated in the interior region only.
            if (pml_y == 1 && pml_x == 1) {
              // grad_ca += λ_Ey^{n+1} * E_y^n
              if (ca_requires_grad && ey_store != NULL) {
                TIDE_DTYPE ey_n =
                    (TIDE_DTYPE)tide_fp8_e4m3_to_float(ey_store[store_idx]);
                if (ca_batched) {
                  grad_ca[store_idx] += lambda_ey_curr * ey_n * (TIDE_DTYPE)step_ratio;
                } else {
                  // Shared across shots (parallel loop above): must be atomic.
#ifdef _OPENMP
#pragma omp atomic
#endif
                  grad_ca[idx] += lambda_ey_curr * ey_n * (TIDE_DTYPE)step_ratio;
                }
              }

              // grad_cb += λ_Ey^{n+1} * curl_H^n
              if (cb_requires_grad && curl_h_store != NULL) {
                TIDE_DTYPE curl_h_n =
                    (TIDE_DTYPE)tide_fp8_e4m3_to_float(curl_h_store[store_idx]);
                // NOTE(review): layout chosen by ca_batched, not a cb flag — see header.
                if (ca_batched) {
                  grad_cb[store_idx] += lambda_ey_curr * curl_h_n * (TIDE_DTYPE)step_ratio;
                } else {
#ifdef _OPENMP
#pragma omp atomic
#endif
                  grad_cb[idx] += lambda_ey_curr * curl_h_n * (TIDE_DTYPE)step_ratio;
                }
              }
            }
          }
        }
      }
    }
  }
}
|
|
1726
|
+
|
|
1727
|
+
/*
 * Inverse (time-reversed) E_y step over the non-PML region.
 *
 * Solves the forward update E_y^{n+1} = C_a * E_y^n + C_b * curl_H for E_y^n.
 * It overwrites ey in place with (ey - cb * curl_H) / ca, and writes the
 * curl_H = dHz/dx - dHx/dz it used into curl_h_out. curl_h_out uses the same
 * per-shot layout as ey (shot_offset + idx).
 *
 * Only points with y in [max(FD_PAD, pml_y0), min(ny - FD_PAD + 1, pml_y1))
 * and x in the analogous x range are touched. Entries of ey and curl_h_out
 * outside that rectangle are left unchanged.
 *
 * ca/cb are per-shot slabs when ca_batched/cb_batched, otherwise shared by
 * all shots.
 *
 * NOTE(review): divides by ca; assumes C_a is nonzero wherever this runs.
 * NOTE(review): presumably used to reconstruct the forward wavefield backward
 * in time instead of storing it — confirm against callers.
 */
static void inverse_kernel_e_and_curl(
    TIDE_DTYPE const *__restrict const ca,
    TIDE_DTYPE const *__restrict const cb,
    TIDE_DTYPE const *__restrict const hx,
    TIDE_DTYPE const *__restrict const hz,
    TIDE_DTYPE *__restrict const ey,
    TIDE_DTYPE *__restrict const curl_h_out,
    TIDE_DTYPE const rdy,
    TIDE_DTYPE const rdx,
    int64_t const n_shots,
    int64_t const ny,
    int64_t const nx,
    int64_t const shot_numel,
    int64_t const pml_y0,
    int64_t const pml_y1,
    int64_t const pml_x0,
    int64_t const pml_x1,
    bool const ca_batched,
    bool const cb_batched) {

  // Interior rectangle: the stencil-valid range clipped to the non-PML region.
  int64_t const y0 = MAX(FD_PAD, pml_y0);
  int64_t const y1 = MIN(ny - FD_PAD + 1, pml_y1);
  int64_t const x0 = MAX(FD_PAD, pml_x0);
  int64_t const x1 = MIN(nx - FD_PAD + 1, pml_x1);

  TIDE_OMP_INDEX shot_idx;
  TIDE_OMP_PARALLEL_FOR
  for (shot_idx = 0; shot_idx < n_shots; ++shot_idx) {
    int64_t const shot_offset = shot_idx * shot_numel;
    TIDE_DTYPE const *__restrict const ca_ptr =
        ca_batched ? (ca + shot_offset) : ca;
    TIDE_DTYPE const *__restrict const cb_ptr =
        cb_batched ? (cb + shot_offset) : cb;
    TIDE_OMP_SIMD_COLLAPSE2
    for (int64_t y = y0; y < y1; ++y) {
      for (int64_t x = x0; x < x1; ++x) {
        int64_t const idx = IDX(y, x);
        int64_t const store_idx = shot_offset + idx;
        TIDE_DTYPE const ca_val = ca_ptr[idx];
        TIDE_DTYPE const cb_val = cb_ptr[idx];

        TIDE_DTYPE const dhz_dx = DIFFX1(HZ);
        TIDE_DTYPE const dhx_dz = DIFFY1(HX);
        TIDE_DTYPE const curl_h = dhz_dx - dhx_dz;

        curl_h_out[store_idx] = curl_h;
        // E^n = (E^{n+1} - C_b * curl_H) / C_a
        ey[store_idx] = (ey[store_idx] - cb_val * curl_h) / ca_val;
      }
    }
  }
}
|
|
1778
|
+
|
|
1779
|
+
/*
 * Inverse (time-reversed) H step over the non-PML region.
 *
 * Applies Hx += cq * dEy/dz (DIFFYH1) and Hz -= cq * dEy/dx (DIFFXH1) in
 * place. This presumably undoes a forward step Hx -= cq * dEy/dz,
 * Hz += cq * dEy/dx; the forward kernel is not visible here, so confirm.
 *
 * Only points with y in [max(FD_PAD, pml_y0), min(ny - FD_PAD + 1, pml_y1))
 * and x in the analogous x range are touched. Hx is additionally skipped on
 * rows y >= ny - FD_PAD, and Hz on columns x >= nx - FD_PAD. Presumably this
 * is because the half-cell stencil would reach past the padded grid there.
 *
 * cq is a per-shot slab when cq_batched, otherwise shared by all shots.
 */
static void inverse_kernel_h(
    TIDE_DTYPE const *__restrict const cq,
    TIDE_DTYPE const *__restrict const ey,
    TIDE_DTYPE *__restrict const hx,
    TIDE_DTYPE *__restrict const hz,
    TIDE_DTYPE const rdy,
    TIDE_DTYPE const rdx,
    int64_t const n_shots,
    int64_t const ny,
    int64_t const nx,
    int64_t const shot_numel,
    int64_t const pml_y0,
    int64_t const pml_y1,
    int64_t const pml_x0,
    int64_t const pml_x1,
    bool const cq_batched) {

  // Interior rectangle: the stencil-valid range clipped to the non-PML region.
  int64_t const y0 = MAX(FD_PAD, pml_y0);
  int64_t const y1 = MIN(ny - FD_PAD + 1, pml_y1);
  int64_t const x0 = MAX(FD_PAD, pml_x0);
  int64_t const x1 = MIN(nx - FD_PAD + 1, pml_x1);

  TIDE_OMP_INDEX shot_idx;
  TIDE_OMP_PARALLEL_FOR
  for (shot_idx = 0; shot_idx < n_shots; ++shot_idx) {
    int64_t const shot_offset = shot_idx * shot_numel;
    TIDE_DTYPE const *__restrict const cq_ptr =
        cq_batched ? (cq + shot_offset) : cq;
    TIDE_OMP_SIMD_COLLAPSE2
    for (int64_t y = y0; y < y1; ++y) {
      for (int64_t x = x0; x < x1; ++x) {
        int64_t const idx = IDX(y, x);
        TIDE_DTYPE const cq_val = cq_ptr[idx];

        // Hx lives on the half-shifted y grid: skip the last row.
        if (y < ny - FD_PAD) {
          TIDE_DTYPE const dey_dz = DIFFYH1(EY);
          HX(0, 0) += cq_val * dey_dz;
        }
        // Hz lives on the half-shifted x grid: skip the last column.
        if (x < nx - FD_PAD) {
          TIDE_DTYPE const dey_dx = DIFFXH1(EY);
          HZ(0, 0) -= cq_val * dey_dx;
        }
      }
    }
  }
}
|
|
1825
|
+
|
|
1826
|
+
/*
|
|
1827
|
+
* Full backward pass for Maxwell TM equations
|
|
1828
|
+
*
|
|
1829
|
+
* Implements the Adjoint State Method to compute:
|
|
1830
|
+
* - grad_ca: gradient w.r.t. C_a coefficient
|
|
1831
|
+
* - grad_cb: gradient w.r.t. C_b coefficient
|
|
1832
|
+
* - grad_eps: gradient w.r.t. epsilon_r
|
|
1833
|
+
* - grad_sigma: gradient w.r.t. conductivity
|
|
1834
|
+
* - grad_f: gradient w.r.t. source amplitudes
|
|
1835
|
+
*/
|
|
1836
|
+
#ifdef __cplusplus
|
|
1837
|
+
extern "C"
|
|
1838
|
+
#endif
|
|
1839
|
+
#ifdef _WIN32
|
|
1840
|
+
__declspec(dllexport)
|
|
1841
|
+
#endif
|
|
1842
|
+
void FUNC(backward)(
|
|
1843
|
+
TIDE_DTYPE const *const ca,
|
|
1844
|
+
TIDE_DTYPE const *const cb,
|
|
1845
|
+
TIDE_DTYPE const *const cq,
|
|
1846
|
+
TIDE_DTYPE const *const grad_r,
|
|
1847
|
+
TIDE_DTYPE *const lambda_ey,
|
|
1848
|
+
TIDE_DTYPE *const lambda_hx,
|
|
1849
|
+
TIDE_DTYPE *const lambda_hz,
|
|
1850
|
+
TIDE_DTYPE *const m_lambda_ey_x,
|
|
1851
|
+
TIDE_DTYPE *const m_lambda_ey_z,
|
|
1852
|
+
TIDE_DTYPE *const m_lambda_hx_z,
|
|
1853
|
+
TIDE_DTYPE *const m_lambda_hz_x,
|
|
1854
|
+
TIDE_DTYPE *const ey_store_1,
|
|
1855
|
+
void *const ey_store_3,
|
|
1856
|
+
char const *const *const ey_filenames,
|
|
1857
|
+
TIDE_DTYPE *const curl_store_1,
|
|
1858
|
+
void *const curl_store_3,
|
|
1859
|
+
char const *const *const curl_filenames,
|
|
1860
|
+
TIDE_DTYPE *const grad_f,
|
|
1861
|
+
TIDE_DTYPE *const grad_ca,
|
|
1862
|
+
TIDE_DTYPE *const grad_cb,
|
|
1863
|
+
TIDE_DTYPE *const grad_eps,
|
|
1864
|
+
TIDE_DTYPE *const grad_sigma,
|
|
1865
|
+
TIDE_DTYPE *const grad_ca_shot, /* unused in CPU - for API compatibility */
|
|
1866
|
+
TIDE_DTYPE *const grad_cb_shot, /* unused in CPU - for API compatibility */
|
|
1867
|
+
TIDE_DTYPE const *const ay,
|
|
1868
|
+
TIDE_DTYPE const *const by,
|
|
1869
|
+
TIDE_DTYPE const *const ayh,
|
|
1870
|
+
TIDE_DTYPE const *const byh,
|
|
1871
|
+
TIDE_DTYPE const *const ax,
|
|
1872
|
+
TIDE_DTYPE const *const bx,
|
|
1873
|
+
TIDE_DTYPE const *const axh,
|
|
1874
|
+
TIDE_DTYPE const *const bxh,
|
|
1875
|
+
TIDE_DTYPE const *const ky,
|
|
1876
|
+
TIDE_DTYPE const *const kyh,
|
|
1877
|
+
TIDE_DTYPE const *const kx,
|
|
1878
|
+
TIDE_DTYPE const *const kxh,
|
|
1879
|
+
int64_t const *const sources_i,
|
|
1880
|
+
int64_t const *const receivers_i,
|
|
1881
|
+
TIDE_DTYPE const rdy,
|
|
1882
|
+
TIDE_DTYPE const rdx,
|
|
1883
|
+
TIDE_DTYPE const dt,
|
|
1884
|
+
int64_t const nt,
|
|
1885
|
+
int64_t const n_shots,
|
|
1886
|
+
int64_t const ny,
|
|
1887
|
+
int64_t const nx,
|
|
1888
|
+
int64_t const n_sources_per_shot,
|
|
1889
|
+
int64_t const n_receivers_per_shot,
|
|
1890
|
+
int64_t const step_ratio,
|
|
1891
|
+
int64_t const storage_mode,
|
|
1892
|
+
int64_t const shot_bytes_uncomp,
|
|
1893
|
+
bool const ca_requires_grad,
|
|
1894
|
+
bool const cb_requires_grad,
|
|
1895
|
+
bool const ca_batched,
|
|
1896
|
+
bool const cb_batched,
|
|
1897
|
+
bool const cq_batched,
|
|
1898
|
+
int64_t const start_t,
|
|
1899
|
+
int64_t const pml_y0,
|
|
1900
|
+
int64_t const pml_x0,
|
|
1901
|
+
int64_t const pml_y1,
|
|
1902
|
+
int64_t const pml_x1,
|
|
1903
|
+
int64_t const n_threads,
|
|
1904
|
+
int64_t const device /* unused for CPU */) {
|
|
1905
|
+
|
|
1906
|
+
(void)device;
|
|
1907
|
+
(void)grad_ca_shot; // Not needed in CPU version
|
|
1908
|
+
(void)grad_cb_shot; // Not needed in CPU version
|
|
1909
|
+
(void)ey_store_3;
|
|
1910
|
+
(void)curl_store_3;
|
|
1911
|
+
#ifdef _OPENMP
|
|
1912
|
+
int const prev_threads = omp_get_max_threads();
|
|
1913
|
+
if (n_threads > 0) {
|
|
1914
|
+
omp_set_num_threads((int)n_threads);
|
|
1915
|
+
}
|
|
1916
|
+
#else
|
|
1917
|
+
(void)n_threads;
|
|
1918
|
+
#endif
|
|
1919
|
+
|
|
1920
|
+
int64_t const shot_numel = ny * nx;
|
|
1921
|
+
int64_t const store_size = n_shots * shot_numel;
|
|
1922
|
+
bool const storage_fp8 = (shot_bytes_uncomp == shot_numel * 1);
|
|
1923
|
+
bool const storage_bf16 = (shot_bytes_uncomp == shot_numel * 2);
|
|
1924
|
+
|
|
1925
|
+
FILE **fp_ey = NULL;
|
|
1926
|
+
FILE **fp_curl = NULL;
|
|
1927
|
+
if (storage_mode == STORAGE_DISK) {
|
|
1928
|
+
if (ca_requires_grad) {
|
|
1929
|
+
fp_ey = (FILE **)malloc((size_t)n_shots * sizeof(FILE *));
|
|
1930
|
+
for (int64_t shot = 0; shot < n_shots; ++shot) {
|
|
1931
|
+
fp_ey[shot] = fopen(ey_filenames[shot], "rb");
|
|
1932
|
+
}
|
|
1933
|
+
}
|
|
1934
|
+
if (cb_requires_grad) {
|
|
1935
|
+
fp_curl = (FILE **)malloc((size_t)n_shots * sizeof(FILE *));
|
|
1936
|
+
for (int64_t shot = 0; shot < n_shots; ++shot) {
|
|
1937
|
+
fp_curl[shot] = fopen(curl_filenames[shot], "rb");
|
|
1938
|
+
}
|
|
1939
|
+
}
|
|
1940
|
+
}
|
|
1941
|
+
|
|
1942
|
+
// Time reversed loop: from t = start_t - 1 down to start_t - nt
|
|
1943
|
+
//
|
|
1944
|
+
// Forward order was: H_update -> E_update(store) -> source_inject -> record
|
|
1945
|
+
// Backward order is: record(adjoint) -> source_inject(adjoint) -> E_update(adjoint) -> H_update(adjoint)
|
|
1946
|
+
// Which translates to: grad_r_inject -> grad_f_record -> λ_E_update(grad_accum) -> λ_H_update
|
|
1947
|
+
|
|
1948
|
+
for (int64_t t = start_t - 1; t >= start_t - nt; --t) {
|
|
1949
|
+
// Determine storage index for this time step
|
|
1950
|
+
int64_t const store_idx = t / step_ratio;
|
|
1951
|
+
bool const do_grad = (t % step_ratio) == 0;
|
|
1952
|
+
bool const grad_ey = do_grad && ca_requires_grad;
|
|
1953
|
+
bool const grad_curl = do_grad && cb_requires_grad;
|
|
1954
|
+
|
|
1955
|
+
int64_t const store_offset =
|
|
1956
|
+
(storage_mode == STORAGE_DEVICE ? store_idx * store_size : 0);
|
|
1957
|
+
|
|
1958
|
+
if (storage_fp8) {
|
|
1959
|
+
tide_fp8_e4m3 *const ey_store_1_t =
|
|
1960
|
+
(tide_fp8_e4m3 *)ey_store_1 + store_offset;
|
|
1961
|
+
tide_fp8_e4m3 *const curl_store_1_t =
|
|
1962
|
+
(tide_fp8_e4m3 *)curl_store_1 + store_offset;
|
|
1963
|
+
|
|
1964
|
+
if (storage_mode == STORAGE_DISK) {
|
|
1965
|
+
if (grad_ey) {
|
|
1966
|
+
for (int64_t shot = 0; shot < n_shots; ++shot) {
|
|
1967
|
+
storage_load_snapshot_cpu(
|
|
1968
|
+
(void *)(ey_store_1_t + shot * shot_numel), fp_ey[shot],
|
|
1969
|
+
storage_mode, store_idx, (size_t)shot_bytes_uncomp);
|
|
1970
|
+
}
|
|
1971
|
+
}
|
|
1972
|
+
if (grad_curl) {
|
|
1973
|
+
for (int64_t shot = 0; shot < n_shots; ++shot) {
|
|
1974
|
+
storage_load_snapshot_cpu(
|
|
1975
|
+
(void *)(curl_store_1_t + shot * shot_numel), fp_curl[shot],
|
|
1976
|
+
storage_mode, store_idx, (size_t)shot_bytes_uncomp);
|
|
1977
|
+
}
|
|
1978
|
+
}
|
|
1979
|
+
}
|
|
1980
|
+
} else if (storage_bf16) {
|
|
1981
|
+
tide_bfloat16 *const ey_store_1_t =
|
|
1982
|
+
(tide_bfloat16 *)ey_store_1 + store_offset;
|
|
1983
|
+
tide_bfloat16 *const curl_store_1_t =
|
|
1984
|
+
(tide_bfloat16 *)curl_store_1 + store_offset;
|
|
1985
|
+
|
|
1986
|
+
if (storage_mode == STORAGE_DISK) {
|
|
1987
|
+
if (grad_ey) {
|
|
1988
|
+
for (int64_t shot = 0; shot < n_shots; ++shot) {
|
|
1989
|
+
storage_load_snapshot_cpu(
|
|
1990
|
+
(void *)(ey_store_1_t + shot * shot_numel), fp_ey[shot],
|
|
1991
|
+
storage_mode, store_idx, (size_t)shot_bytes_uncomp);
|
|
1992
|
+
}
|
|
1993
|
+
}
|
|
1994
|
+
if (grad_curl) {
|
|
1995
|
+
for (int64_t shot = 0; shot < n_shots; ++shot) {
|
|
1996
|
+
storage_load_snapshot_cpu(
|
|
1997
|
+
(void *)(curl_store_1_t + shot * shot_numel), fp_curl[shot],
|
|
1998
|
+
storage_mode, store_idx, (size_t)shot_bytes_uncomp);
|
|
1999
|
+
}
|
|
2000
|
+
}
|
|
2001
|
+
}
|
|
2002
|
+
} else {
|
|
2003
|
+
TIDE_DTYPE *const ey_store_1_t =
|
|
2004
|
+
ey_store_1 + store_offset;
|
|
2005
|
+
TIDE_DTYPE *const curl_store_1_t =
|
|
2006
|
+
curl_store_1 + store_offset;
|
|
2007
|
+
|
|
2008
|
+
if (storage_mode == STORAGE_DISK) {
|
|
2009
|
+
if (grad_ey) {
|
|
2010
|
+
for (int64_t shot = 0; shot < n_shots; ++shot) {
|
|
2011
|
+
storage_load_snapshot_cpu(
|
|
2012
|
+
(void *)(ey_store_1_t + shot * shot_numel), fp_ey[shot],
|
|
2013
|
+
storage_mode, store_idx, (size_t)shot_bytes_uncomp);
|
|
2014
|
+
}
|
|
2015
|
+
}
|
|
2016
|
+
if (grad_curl) {
|
|
2017
|
+
for (int64_t shot = 0; shot < n_shots; ++shot) {
|
|
2018
|
+
storage_load_snapshot_cpu(
|
|
2019
|
+
(void *)(curl_store_1_t + shot * shot_numel), fp_curl[shot],
|
|
2020
|
+
storage_mode, store_idx, (size_t)shot_bytes_uncomp);
|
|
2021
|
+
}
|
|
2022
|
+
}
|
|
2023
|
+
}
|
|
2024
|
+
}
|
|
2025
|
+
|
|
2026
|
+
// Inject adjoint residuals into λ_Ey^{t+1} (adjoint of receiver recording)
|
|
2027
|
+
if (n_receivers_per_shot > 0) {
|
|
2028
|
+
add_sources_ey(
|
|
2029
|
+
lambda_ey, grad_r + t * n_shots * n_receivers_per_shot, receivers_i,
|
|
2030
|
+
n_shots, shot_numel, n_receivers_per_shot);
|
|
2031
|
+
}
|
|
2032
|
+
|
|
2033
|
+
// Record adjoint source gradient using λ_Ey^{t+1} (adjoint of source injection)
|
|
2034
|
+
if (n_sources_per_shot > 0) {
|
|
2035
|
+
record_receivers_ey(
|
|
2036
|
+
grad_f + t * n_shots * n_sources_per_shot,
|
|
2037
|
+
lambda_ey, sources_i,
|
|
2038
|
+
n_shots, shot_numel, n_sources_per_shot);
|
|
2039
|
+
}
|
|
2040
|
+
|
|
2041
|
+
// Backward λ_Ey update with gradient accumulation
|
|
2042
|
+
// This computes: λ_Ey^n = C_a * λ_Ey^{n+1} + C_q * curl_λH
|
|
2043
|
+
// And accumulates: grad_ca += λ_Ey^{n+1} * E_y^n, grad_cb += λ_Ey^{n+1} * curl_H^n
|
|
2044
|
+
if (storage_fp8 && (grad_ey || grad_curl)) {
|
|
2045
|
+
tide_fp8_e4m3 *const ey_store_1_t =
|
|
2046
|
+
(tide_fp8_e4m3 *)ey_store_1 + store_offset;
|
|
2047
|
+
tide_fp8_e4m3 *const curl_store_1_t =
|
|
2048
|
+
(tide_fp8_e4m3 *)curl_store_1 + store_offset;
|
|
2049
|
+
backward_kernel_lambda_e_with_grad_fp8(
|
|
2050
|
+
ca, cq, lambda_hx, lambda_hz, lambda_ey,
|
|
2051
|
+
m_lambda_hx_z, m_lambda_hz_x,
|
|
2052
|
+
grad_ey ? ey_store_1_t : NULL,
|
|
2053
|
+
grad_curl ? curl_store_1_t : NULL,
|
|
2054
|
+
grad_ca, grad_cb,
|
|
2055
|
+
ay, ayh, ax, axh, by, byh, bx, bxh,
|
|
2056
|
+
ky, kyh, kx, kxh,
|
|
2057
|
+
rdy, rdx,
|
|
2058
|
+
n_shots, ny, nx, shot_numel,
|
|
2059
|
+
pml_y0, pml_y1, pml_x0, pml_x1,
|
|
2060
|
+
ca_batched, cq_batched,
|
|
2061
|
+
grad_ey, grad_curl,
|
|
2062
|
+
step_ratio);
|
|
2063
|
+
} else if (storage_bf16 && (grad_ey || grad_curl)) {
|
|
2064
|
+
tide_bfloat16 *const ey_store_1_t =
|
|
2065
|
+
(tide_bfloat16 *)ey_store_1 + store_offset;
|
|
2066
|
+
tide_bfloat16 *const curl_store_1_t =
|
|
2067
|
+
(tide_bfloat16 *)curl_store_1 + store_offset;
|
|
2068
|
+
backward_kernel_lambda_e_with_grad_bf16(
|
|
2069
|
+
ca, cq, lambda_hx, lambda_hz, lambda_ey,
|
|
2070
|
+
m_lambda_hx_z, m_lambda_hz_x,
|
|
2071
|
+
grad_ey ? ey_store_1_t : NULL,
|
|
2072
|
+
grad_curl ? curl_store_1_t : NULL,
|
|
2073
|
+
grad_ca, grad_cb,
|
|
2074
|
+
ay, ayh, ax, axh, by, byh, bx, bxh,
|
|
2075
|
+
ky, kyh, kx, kxh,
|
|
2076
|
+
rdy, rdx,
|
|
2077
|
+
n_shots, ny, nx, shot_numel,
|
|
2078
|
+
pml_y0, pml_y1, pml_x0, pml_x1,
|
|
2079
|
+
ca_batched, cq_batched,
|
|
2080
|
+
grad_ey, grad_curl,
|
|
2081
|
+
step_ratio);
|
|
2082
|
+
} else {
|
|
2083
|
+
TIDE_DTYPE *const ey_store_1_t =
|
|
2084
|
+
(storage_fp8 || storage_bf16) ? NULL : (ey_store_1 + store_offset);
|
|
2085
|
+
TIDE_DTYPE *const curl_store_1_t =
|
|
2086
|
+
(storage_fp8 || storage_bf16) ? NULL : (curl_store_1 + store_offset);
|
|
2087
|
+
backward_kernel_lambda_e_with_grad(
|
|
2088
|
+
ca, cq, lambda_hx, lambda_hz, lambda_ey,
|
|
2089
|
+
m_lambda_hx_z, m_lambda_hz_x,
|
|
2090
|
+
grad_ey ? ey_store_1_t : NULL,
|
|
2091
|
+
grad_curl ? curl_store_1_t : NULL,
|
|
2092
|
+
grad_ca, grad_cb,
|
|
2093
|
+
ay, ayh, ax, axh, by, byh, bx, bxh,
|
|
2094
|
+
ky, kyh, kx, kxh,
|
|
2095
|
+
rdy, rdx,
|
|
2096
|
+
n_shots, ny, nx, shot_numel,
|
|
2097
|
+
pml_y0, pml_y1, pml_x0, pml_x1,
|
|
2098
|
+
ca_batched, cq_batched,
|
|
2099
|
+
grad_ey, grad_curl,
|
|
2100
|
+
step_ratio);
|
|
2101
|
+
}
|
|
2102
|
+
|
|
2103
|
+
// Backward λ_H fields update
|
|
2104
|
+
backward_kernel_lambda_h(
|
|
2105
|
+
cb, lambda_ey, lambda_hx, lambda_hz,
|
|
2106
|
+
m_lambda_ey_x, m_lambda_ey_z,
|
|
2107
|
+
ay, ayh, ax, axh, by, byh, bx, bxh,
|
|
2108
|
+
ky, kyh, kx, kxh,
|
|
2109
|
+
rdy, rdx,
|
|
2110
|
+
n_shots, ny, nx, shot_numel,
|
|
2111
|
+
pml_y0, pml_y1, pml_x0, pml_x1,
|
|
2112
|
+
cb_batched);
|
|
2113
|
+
}
|
|
2114
|
+
|
|
2115
|
+
if (fp_ey != NULL) {
|
|
2116
|
+
for (int64_t shot = 0; shot < n_shots; ++shot) fclose(fp_ey[shot]);
|
|
2117
|
+
free(fp_ey);
|
|
2118
|
+
}
|
|
2119
|
+
if (fp_curl != NULL) {
|
|
2120
|
+
for (int64_t shot = 0; shot < n_shots; ++shot) fclose(fp_curl[shot]);
|
|
2121
|
+
free(fp_curl);
|
|
2122
|
+
}
|
|
2123
|
+
|
|
2124
|
+
convert_grad_ca_cb_to_eps_sigma(
|
|
2125
|
+
ca, cb, grad_ca, grad_cb, grad_eps, grad_sigma,
|
|
2126
|
+
dt, n_shots, ny, nx, ca_batched, cb_batched,
|
|
2127
|
+
ca_requires_grad, cb_requires_grad);
|
|
2128
|
+
#ifdef _OPENMP
|
|
2129
|
+
if (n_threads > 0) {
|
|
2130
|
+
omp_set_num_threads(prev_threads);
|
|
2131
|
+
}
|
|
2132
|
+
#endif
|
|
2133
|
+
}
|