tide-GPR 0.0.9__py3-none-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
tide/csrc/maxwell.c ADDED
@@ -0,0 +1,2133 @@
+ /*
+  * Maxwell wave equation propagator (CPU implementation)
+  *
+  * This file contains the CPU implementation of the 2D TM Maxwell equations
+  * propagator with complete Adjoint State Method (ASM) support for gradient
+  * computation.
+  *
+  * TM mode fields: Ey (electric), Hx, Hz (magnetic)
+  *
+  * Adjoint State Method for Maxwell TM equations:
+  * ================================================
+  * Forward equations (discrete):
+  *   E_y^{n+1}   = C_a * E_y^n + C_b * (∂H_z/∂x - ∂H_x/∂z)
+  *   H_x^{n+1/2} = H_x^{n-1/2} - C_q * ∂E_y/∂z
+  *   H_z^{n+1/2} = H_z^{n-1/2} + C_q * ∂E_y/∂x
+  *
+  * Adjoint equations (time-reversed):
+  *   λ_Ey^n       = C_a * λ_Ey^{n+1} + C_q * (∂λ_Hz/∂x - ∂λ_Hx/∂z) + residual_injection
+  *   λ_Hx^{n-1/2} = λ_Hx^{n+1/2} - C_b * ∂λ_Ey/∂z
+  *   λ_Hz^{n-1/2} = λ_Hz^{n+1/2} + C_b * ∂λ_Ey/∂x
+  *
+  * Model gradients:
+  *   ∂J/∂C_a = Σ_n E_y^n * λ_Ey^{n+1}
+  *   ∂J/∂C_b = Σ_n curl_H^n * λ_Ey^{n+1}
+  *
+  * Storage requirements:
+  * - ey_store: E_y field at each step_ratio time step [nt/step_ratio, n_shots, ny, nx]
+  * - curl_h_store: (∂H_z/∂x - ∂H_x/∂z) at each step_ratio time step [nt/step_ratio, n_shots, ny, nx]
+  */
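// A minimal single-cell sketch of the discrete updates above (not part of the
// package source): ca/cb/cq and the derivative values are illustrative
// stand-ins for the CA/CB/CQ tables and DIFF* stencils used by the kernels
// below, so the leapfrog structure is explicit.
//
//   hx -= cq * dey_dz;            /* H_x^{n+1/2} = H_x^{n-1/2} - C_q ∂E_y/∂z */
//   hz += cq * dey_dx;            /* H_z^{n+1/2} = H_z^{n-1/2} + C_q ∂E_y/∂x */
//   curl_h = dhz_dx - dhx_dz;
//   ey = ca * ey + cb * curl_h;   /* E_y^{n+1} = C_a E_y^n + C_b curl_H      */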
+
+ #include <stdio.h>
+ #include <stdint.h>
+ #include <stdbool.h>
+ #include <stdlib.h>
+ #include <string.h>
+ #include <math.h>
+
+ #ifdef _OPENMP
+ #include <omp.h>
+ #endif
+
+ #include "common_cpu.h"
+ #include "staggered_grid.h"
+ #include "storage_utils.h"
+
+ #define CAT_I(name, accuracy, dtype, device) \
+   maxwell_tm_##accuracy##_##dtype##_##name##_##device
+ #define CAT(name, accuracy, dtype, device) \
+   CAT_I(name, accuracy, dtype, device)
+ #define FUNC(name) CAT(name, TIDE_STENCIL, TIDE_DTYPE, cpu)
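// Expansion sketch: assuming a build that defines TIDE_STENCIL=4 and
// TIDE_DTYPE=float (hypothetical values, chosen only for illustration),
//   FUNC(forward) -> CAT(forward, 4, float, cpu)
//                 -> CAT_I(forward, 4, float, cpu)
//                 -> maxwell_tm_4_float_forward_cpu
// The two-level CAT/CAT_I indirection forces TIDE_STENCIL and TIDE_DTYPE to
// be macro-expanded before ## token pasting.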
+
+ // 2D indexing macros
+ #define IDX(y, x) ((y) * nx + (x))
+ #define IDX_SHOT(shot, y, x) ((shot) * shot_numel + (y) * nx + (x))
+
+ #define MAX(a, b) ((a) > (b) ? (a) : (b))
+ #define MIN(a, b) ((a) < (b) ? (a) : (b))
+
+ // Vacuum permittivity (F/m) to convert dL/d(epsilon_abs) -> dL/d(epsilon_r)
+ #define EP0 ((TIDE_DTYPE)8.8541878128e-12)
+
+ typedef uint16_t tide_bfloat16;
+
+ static inline tide_bfloat16 tide_float_to_bf16(float value) {
+   union {
+     float f;
+     uint32_t u;
+   } tmp;
+   tmp.f = value;
+   uint32_t lsb = (tmp.u >> 16) & 1u;
+   tmp.u += 0x7FFFu + lsb;
+   return (tide_bfloat16)(tmp.u >> 16);
+ }
+
+ static inline float tide_bf16_to_float(tide_bfloat16 value) {
+   union {
+     uint32_t u;
+     float f;
+   } tmp;
+   tmp.u = ((uint32_t)value) << 16;
+   return tmp.f;
+ }
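// Round-trip sketch for the bfloat16 helpers above: the encoder rounds to
// nearest-even by adding 0x7FFF plus the kept LSB before truncating to the
// top 16 bits. The values below follow from that bit manipulation:
//
//   tide_bfloat16 b = tide_float_to_bf16(3.14159265f);  /* bits 0x40490FDB */
//   /* lsb = 1, so 0x40490FDB + 0x8000 truncates to b == 0x4049 */
//   float back = tide_bf16_to_float(b);                 /* 3.140625f      */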
+
+ // FP8 E4M3 format (1 sign bit, 4 exponent bits, 3 mantissa bits)
+ // Dynamic range: ~10^-5 to ~448
+ typedef uint8_t tide_fp8_e4m3;
+
+ static inline tide_fp8_e4m3 tide_float_to_fp8_e4m3(float value) {
+   if (value == 0.0f) {
+     return 0;
+   }
+
+   union {
+     float f;
+     uint32_t u;
+   } tmp;
+   tmp.f = value;
+
+   uint8_t sign = (tmp.u >> 31) ? 0x80 : 0;
+   float ax = fabsf(value);
+
+   // Handle infinity/NaN - saturate to max representable
+   if (!isfinite(ax)) {
+     return (uint8_t)(sign | 0x7F);
+   }
+
+   // Use frexpf to extract mantissa and exponent
+   int exp;
+   float m = frexpf(ax, &exp); // ax = m * 2^exp, m in [0.5, 1)
+   int e = exp - 1;
+   int exp_field = e + 7; // Bias of 7 for E4M3
+   int mant = 0;
+
+   if (exp_field <= 0) {
+     // Subnormal - compute mantissa directly
+     // Smallest normal: 2^-6, subnormal range: 2^-9 to 2^-6
+     mant = (int)roundf(ax * 512.0f);
+     if (mant <= 0) {
+       return sign; // Underflow to zero
+     }
+     if (mant > 7) {
+       mant = 7; // Clamp to max subnormal
+     }
+     exp_field = 0;
+   } else if (exp_field >= 0xF) {
+     // Overflow - saturate to max normal
+     exp_field = 0xE;
+     mant = 7;
+   } else {
+     // Normal number
+     float frac = m * 2.0f - 1.0f; // Extract fractional part
+     mant = (int)roundf(frac * 8.0f);
+     if (mant == 8) {
+       // Rounding overflow - increment exponent
+       mant = 0;
+       exp_field += 1;
+       if (exp_field >= 0xF) {
+         exp_field = 0xE;
+         mant = 7;
+       }
+     }
+   }
+
+   return (uint8_t)(sign | ((uint8_t)exp_field << 3) | (uint8_t)(mant & 0x7));
+ }
+
+ static inline float tide_fp8_e4m3_to_float(tide_fp8_e4m3 value) {
+   if (value == 0) {
+     return 0.0f;
+   }
+
+   int sign = value & 0x80;
+   int exp_field = (value >> 3) & 0xF;
+   int mant = value & 0x7;
+   float val;
+
+   if (exp_field == 0) {
+     // Subnormal
+     float frac = (float)mant / 8.0f;
+     val = ldexpf(frac, -6); // 2^-6 * (mant/8)
+   } else {
+     // Normal
+     float frac = 1.0f + (float)mant / 8.0f;
+     val = ldexpf(frac, exp_field - 7);
+   }
+
+   return sign ? -val : val;
+ }
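// Quantization sketch for the E4M3 helpers above (illustrative values that
// follow from the encode/decode logic, shown for orientation):
//   tide_float_to_fp8_e4m3(1.5f)   == 0x3C -> 1.5f       (exactly representable)
//   tide_float_to_fp8_e4m3(300.0f) == 0x77 -> 240.0f     (saturated at exp=0xE, mant=7)
//   tide_float_to_fp8_e4m3(0.001f) == 0x01 -> 0.001953f  (subnormal, one ulp = 2^-9)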
+
+ // Field access macros for stencil operations
+ #define EY(dy, dx) ey[IDX_SHOT(shot_idx, y + (dy), x + (dx))]
+ #define HX(dy, dx) hx[IDX_SHOT(shot_idx, y + (dy), x + (dx))]
+ #define HZ(dy, dx) hz[IDX_SHOT(shot_idx, y + (dy), x + (dx))]
+
+ // Adjoint field access macros
+ #define LAMBDA_EY(dy, dx) lambda_ey[IDX_SHOT(shot_idx, y + (dy), x + (dx))]
+ #define LAMBDA_HX(dy, dx) lambda_hx[IDX_SHOT(shot_idx, y + (dy), x + (dx))]
+ #define LAMBDA_HZ(dy, dx) lambda_hz[IDX_SHOT(shot_idx, y + (dy), x + (dx))]
+
+ // Material parameter access macros
+ #define CA(dy, dx) (ca_batched ? ca[IDX_SHOT(shot_idx, y + (dy), x + (dx))] : ca[IDX(y + (dy), x + (dx))])
+ #define CB(dy, dx) (cb_batched ? cb[IDX_SHOT(shot_idx, y + (dy), x + (dx))] : cb[IDX(y + (dy), x + (dx))])
+ #define CQ(dy, dx) (cq_batched ? cq[IDX_SHOT(shot_idx, y + (dy), x + (dx))] : cq[IDX(y + (dy), x + (dx))])
+
+ // PML memory variable macros
+ #define M_EY_X(dy, dx) m_ey_x[IDX_SHOT(shot_idx, y + (dy), x + (dx))]
+ #define M_EY_Z(dy, dx) m_ey_z[IDX_SHOT(shot_idx, y + (dy), x + (dx))]
+ #define M_HX_Z(dy, dx) m_hx_z[IDX_SHOT(shot_idx, y + (dy), x + (dx))]
+ #define M_HZ_X(dy, dx) m_hz_x[IDX_SHOT(shot_idx, y + (dy), x + (dx))]
+
+ // Adjoint PML memory variable macros
+ #define M_LAMBDA_EY_X(dy, dx) m_lambda_ey_x[IDX_SHOT(shot_idx, y + (dy), x + (dx))]
+ #define M_LAMBDA_EY_Z(dy, dx) m_lambda_ey_z[IDX_SHOT(shot_idx, y + (dy), x + (dx))]
+ #define M_LAMBDA_HX_Z(dy, dx) m_lambda_hx_z[IDX_SHOT(shot_idx, y + (dy), x + (dx))]
+ #define M_LAMBDA_HZ_X(dy, dx) m_lambda_hz_x[IDX_SHOT(shot_idx, y + (dy), x + (dx))]
+
+ static void convert_grad_ca_cb_to_eps_sigma(
+     TIDE_DTYPE const *__restrict const ca,
+     TIDE_DTYPE const *__restrict const cb,
+     TIDE_DTYPE const *__restrict const grad_ca,
+     TIDE_DTYPE const *__restrict const grad_cb,
+     TIDE_DTYPE *__restrict const grad_eps,
+     TIDE_DTYPE *__restrict const grad_sigma,
+     TIDE_DTYPE const dt,
+     int64_t const n_shots,
+     int64_t const ny,
+     int64_t const nx,
+     bool const ca_batched,
+     bool const cb_batched,
+     bool const ca_requires_grad,
+     bool const cb_requires_grad) {
+   if ((grad_eps == NULL && grad_sigma == NULL) || (!ca_requires_grad && !cb_requires_grad)) {
+     return;
+   }
+
+   int64_t const shot_numel = ny * nx;
+   int64_t const out_shots = ca_batched ? n_shots : 1;
+   TIDE_DTYPE const inv_dt = (TIDE_DTYPE)1 / dt;
+
+   TIDE_OMP_INDEX shot_idx;
+   TIDE_OMP_PARALLEL_FOR
+   for (shot_idx = 0; shot_idx < out_shots; ++shot_idx) {
+     int64_t const shot_offset = shot_idx * shot_numel;
+     int64_t const out_offset = ca_batched ? shot_offset : 0;
+     int64_t const ca_offset = ca_batched ? shot_offset : 0;
+     int64_t const cb_offset = cb_batched ? shot_offset : 0;
+     TIDE_DTYPE const *__restrict const ca_ptr = ca + ca_offset;
+     TIDE_DTYPE const *__restrict const cb_ptr = cb + cb_offset;
+     TIDE_OMP_SIMD_COLLAPSE2
+     for (int64_t y = 0; y < ny; ++y) {
+       for (int64_t x = 0; x < nx; ++x) {
+         int64_t const idx = IDX(y, x);
+         int64_t const out_idx = out_offset + idx;
+
+         TIDE_DTYPE const ca_val = ca_ptr[idx];
+         TIDE_DTYPE const cb_val = cb_ptr[idx];
+
+         TIDE_DTYPE const grad_ca_val =
+             (ca_requires_grad && grad_ca != NULL) ? grad_ca[out_idx] : (TIDE_DTYPE)0;
+         TIDE_DTYPE const grad_cb_val =
+             (cb_requires_grad && grad_cb != NULL) ? grad_cb[out_idx] : (TIDE_DTYPE)0;
+
+         TIDE_DTYPE const cb_sq = cb_val * cb_val;
+         TIDE_DTYPE const dca_de = ((TIDE_DTYPE)1 - ca_val) * cb_val * inv_dt;
+         TIDE_DTYPE const dcb_de = -cb_sq * inv_dt;
+         TIDE_DTYPE const dca_ds = -((TIDE_DTYPE)0.5) * ((TIDE_DTYPE)1 + ca_val) * cb_val;
+         TIDE_DTYPE const dcb_ds = -((TIDE_DTYPE)0.5) * cb_sq;
+
+         if (grad_eps != NULL) {
+           TIDE_DTYPE const grad_e = grad_ca_val * dca_de + grad_cb_val * dcb_de;
+           grad_eps[out_idx] = grad_e * EP0;
+         }
+         if (grad_sigma != NULL) {
+           grad_sigma[out_idx] = grad_ca_val * dca_ds + grad_cb_val * dcb_ds;
+         }
+       }
+     }
+   }
+ }
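// Reading of the chain rule above (an inference from the derivative terms,
// not stated in this file): with the standard lossy-FDTD coefficients
//   C_a = (ε/Δt - σ/2) / (ε/Δt + σ/2),   C_b = 1 / (ε/Δt + σ/2),
// where ε is the absolute permittivity, differentiation gives
//   ∂C_a/∂ε = (1 - C_a)·C_b/Δt     ∂C_b/∂ε = -C_b²/Δt
//   ∂C_a/∂σ = -(1 + C_a)·C_b/2     ∂C_b/∂σ = -C_b²/2
// which match dca_de, dcb_de, dca_ds, and dcb_ds. The trailing EP0 factor
// then converts ∂L/∂ε (absolute) into ∂L/∂ε_r (relative permittivity).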
+
+
+ static void add_sources_ey(
+     TIDE_DTYPE *__restrict const ey,
+     TIDE_DTYPE const *__restrict const f,
+     int64_t const *__restrict const sources_i,
+     int64_t const n_shots,
+     int64_t const shot_numel,
+     int64_t const n_sources_per_shot) {
+
+   TIDE_OMP_INDEX shot_idx;
+   TIDE_OMP_PARALLEL_FOR
+   for (shot_idx = 0; shot_idx < n_shots; ++shot_idx) {
+     TIDE_OMP_SIMD
+     for (int64_t source_idx = 0; source_idx < n_sources_per_shot; ++source_idx) {
+       int64_t k = shot_idx * n_sources_per_shot + source_idx;
+       if (sources_i[k] >= 0) {
+         ey[shot_idx * shot_numel + sources_i[k]] += f[k];
+       }
+     }
+   }
+ }
+
+ static void subtract_sources_ey(
+     TIDE_DTYPE *__restrict const ey,
+     TIDE_DTYPE const *__restrict const f,
+     int64_t const *__restrict const sources_i,
+     int64_t const n_shots,
+     int64_t const shot_numel,
+     int64_t const n_sources_per_shot) {
+
+   TIDE_OMP_INDEX shot_idx;
+   TIDE_OMP_PARALLEL_FOR
+   for (shot_idx = 0; shot_idx < n_shots; ++shot_idx) {
+     TIDE_OMP_SIMD
+     for (int64_t source_idx = 0; source_idx < n_sources_per_shot; ++source_idx) {
+       int64_t k = shot_idx * n_sources_per_shot + source_idx;
+       if (sources_i[k] >= 0) {
+         ey[shot_idx * shot_numel + sources_i[k]] -= f[k];
+       }
+     }
+   }
+ }
+
+
+ static void record_receivers_ey(
+     TIDE_DTYPE *__restrict const r,
+     TIDE_DTYPE const *__restrict const ey,
+     int64_t const *__restrict const receivers_i,
+     int64_t const n_shots,
+     int64_t const shot_numel,
+     int64_t const n_receivers_per_shot) {
+
+   TIDE_OMP_INDEX shot_idx;
+   TIDE_OMP_PARALLEL_FOR
+   for (shot_idx = 0; shot_idx < n_shots; ++shot_idx) {
+     TIDE_OMP_SIMD
+     for (int64_t receiver_idx = 0; receiver_idx < n_receivers_per_shot; ++receiver_idx) {
+       int64_t k = shot_idx * n_receivers_per_shot + receiver_idx;
+       if (receivers_i[k] >= 0) {
+         r[k] = ey[shot_idx * shot_numel + receivers_i[k]];
+       }
+     }
+   }
+ }
+
+ static void gather_boundary_3_cpu(
+     TIDE_DTYPE const *__restrict const ey,
+     TIDE_DTYPE const *__restrict const hx,
+     TIDE_DTYPE const *__restrict const hz,
+     TIDE_DTYPE *__restrict const bey,
+     TIDE_DTYPE *__restrict const bhx,
+     TIDE_DTYPE *__restrict const bhz,
+     int64_t const *__restrict const boundary_indices,
+     int64_t const boundary_numel,
+     int64_t const n_shots,
+     int64_t const shot_numel) {
+
+   TIDE_OMP_INDEX shot_idx;
+   TIDE_OMP_PARALLEL_FOR
+   for (shot_idx = 0; shot_idx < n_shots; ++shot_idx) {
+     TIDE_OMP_SIMD
+     for (int64_t bi = 0; bi < boundary_numel; ++bi) {
+       int64_t const grid_idx = boundary_indices[bi];
+       int64_t const field_offset = shot_idx * shot_numel + grid_idx;
+       int64_t const store_offset = shot_idx * boundary_numel + bi;
+       bey[store_offset] = ey[field_offset];
+       bhx[store_offset] = hx[field_offset];
+       bhz[store_offset] = hz[field_offset];
+     }
+   }
+ }
+
+ static void scatter_boundary_cpu(
+     TIDE_DTYPE *__restrict const field,
+     TIDE_DTYPE const *__restrict const store,
+     int64_t const *__restrict const boundary_indices,
+     int64_t const boundary_numel,
+     int64_t const n_shots,
+     int64_t const shot_numel) {
+
+   TIDE_OMP_INDEX shot_idx;
+   TIDE_OMP_PARALLEL_FOR
+   for (shot_idx = 0; shot_idx < n_shots; ++shot_idx) {
+     TIDE_OMP_SIMD
+     for (int64_t bi = 0; bi < boundary_numel; ++bi) {
+       int64_t const grid_idx = boundary_indices[bi];
+       int64_t const field_offset = shot_idx * shot_numel + grid_idx;
+       int64_t const store_offset = shot_idx * boundary_numel + bi;
+       field[field_offset] = store[store_offset];
+     }
+   }
+ }
+
+ static void scatter_boundary_2_cpu(
+     TIDE_DTYPE *__restrict const hx,
+     TIDE_DTYPE *__restrict const hz,
+     TIDE_DTYPE const *__restrict const bhx,
+     TIDE_DTYPE const *__restrict const bhz,
+     int64_t const *__restrict const boundary_indices,
+     int64_t const boundary_numel,
+     int64_t const n_shots,
+     int64_t const shot_numel) {
+
+   TIDE_OMP_INDEX shot_idx;
+   TIDE_OMP_PARALLEL_FOR
+   for (shot_idx = 0; shot_idx < n_shots; ++shot_idx) {
+     TIDE_OMP_SIMD
+     for (int64_t bi = 0; bi < boundary_numel; ++bi) {
+       int64_t const grid_idx = boundary_indices[bi];
+       int64_t const field_offset = shot_idx * shot_numel + grid_idx;
+       int64_t const store_offset = shot_idx * boundary_numel + bi;
+       hx[field_offset] = bhx[store_offset];
+       hz[field_offset] = bhz[store_offset];
+     }
+   }
+ }
+
+ static void gather_boundary_3_cpu_bf16(
+     TIDE_DTYPE const *__restrict const ey,
+     TIDE_DTYPE const *__restrict const hx,
+     TIDE_DTYPE const *__restrict const hz,
+     tide_bfloat16 *__restrict const bey,
+     tide_bfloat16 *__restrict const bhx,
+     tide_bfloat16 *__restrict const bhz,
+     int64_t const *__restrict const boundary_indices,
+     int64_t const boundary_numel,
+     int64_t const n_shots,
+     int64_t const shot_numel) {
+
+   TIDE_OMP_INDEX shot_idx;
+   TIDE_OMP_PARALLEL_FOR
+   for (shot_idx = 0; shot_idx < n_shots; ++shot_idx) {
+     TIDE_OMP_SIMD
+     for (int64_t bi = 0; bi < boundary_numel; ++bi) {
+       int64_t const grid_idx = boundary_indices[bi];
+       int64_t const field_offset = shot_idx * shot_numel + grid_idx;
+       int64_t const store_offset = shot_idx * boundary_numel + bi;
+       bey[store_offset] = tide_float_to_bf16((float)ey[field_offset]);
+       bhx[store_offset] = tide_float_to_bf16((float)hx[field_offset]);
+       bhz[store_offset] = tide_float_to_bf16((float)hz[field_offset]);
+     }
+   }
+ }
+
+ static void scatter_boundary_cpu_bf16(
+     TIDE_DTYPE *__restrict const field,
+     tide_bfloat16 const *__restrict const store,
+     int64_t const *__restrict const boundary_indices,
+     int64_t const boundary_numel,
+     int64_t const n_shots,
+     int64_t const shot_numel) {
+
+   TIDE_OMP_INDEX shot_idx;
+   TIDE_OMP_PARALLEL_FOR
+   for (shot_idx = 0; shot_idx < n_shots; ++shot_idx) {
+     TIDE_OMP_SIMD
+     for (int64_t bi = 0; bi < boundary_numel; ++bi) {
+       int64_t const grid_idx = boundary_indices[bi];
+       int64_t const field_offset = shot_idx * shot_numel + grid_idx;
+       int64_t const store_offset = shot_idx * boundary_numel + bi;
+       field[field_offset] = (TIDE_DTYPE)tide_bf16_to_float(store[store_offset]);
+     }
+   }
+ }
+
+ static void scatter_boundary_2_cpu_bf16(
+     TIDE_DTYPE *__restrict const hx,
+     TIDE_DTYPE *__restrict const hz,
+     tide_bfloat16 const *__restrict const bhx,
+     tide_bfloat16 const *__restrict const bhz,
+     int64_t const *__restrict const boundary_indices,
+     int64_t const boundary_numel,
+     int64_t const n_shots,
+     int64_t const shot_numel) {
+
+   TIDE_OMP_INDEX shot_idx;
+   TIDE_OMP_PARALLEL_FOR
+   for (shot_idx = 0; shot_idx < n_shots; ++shot_idx) {
+     TIDE_OMP_SIMD
+     for (int64_t bi = 0; bi < boundary_numel; ++bi) {
+       int64_t const grid_idx = boundary_indices[bi];
+       int64_t const field_offset = shot_idx * shot_numel + grid_idx;
+       int64_t const store_offset = shot_idx * boundary_numel + bi;
+       hx[field_offset] = (TIDE_DTYPE)tide_bf16_to_float(bhx[store_offset]);
+       hz[field_offset] = (TIDE_DTYPE)tide_bf16_to_float(bhz[store_offset]);
+     }
+   }
+ }
+
+ // FP8 boundary storage functions
+ static void gather_boundary_3_cpu_fp8(
+     TIDE_DTYPE const *__restrict const ey,
+     TIDE_DTYPE const *__restrict const hx,
+     TIDE_DTYPE const *__restrict const hz,
+     tide_fp8_e4m3 *__restrict const bey,
+     tide_fp8_e4m3 *__restrict const bhx,
+     tide_fp8_e4m3 *__restrict const bhz,
+     int64_t const *__restrict const boundary_indices,
+     int64_t const boundary_numel,
+     int64_t const n_shots,
+     int64_t const shot_numel) {
+
+   TIDE_OMP_INDEX shot_idx;
+   TIDE_OMP_PARALLEL_FOR
+   for (shot_idx = 0; shot_idx < n_shots; ++shot_idx) {
+     TIDE_OMP_SIMD
+     for (int64_t bi = 0; bi < boundary_numel; ++bi) {
+       int64_t const grid_idx = boundary_indices[bi];
+       int64_t const field_offset = shot_idx * shot_numel + grid_idx;
+       int64_t const store_offset = shot_idx * boundary_numel + bi;
+       bey[store_offset] = tide_float_to_fp8_e4m3((float)ey[field_offset]);
+       bhx[store_offset] = tide_float_to_fp8_e4m3((float)hx[field_offset]);
+       bhz[store_offset] = tide_float_to_fp8_e4m3((float)hz[field_offset]);
+     }
+   }
+ }
+
+ static void scatter_boundary_cpu_fp8(
+     TIDE_DTYPE *__restrict const field,
+     tide_fp8_e4m3 const *__restrict const store,
+     int64_t const *__restrict const boundary_indices,
+     int64_t const boundary_numel,
+     int64_t const n_shots,
+     int64_t const shot_numel) {
+
+   TIDE_OMP_INDEX shot_idx;
+   TIDE_OMP_PARALLEL_FOR
+   for (shot_idx = 0; shot_idx < n_shots; ++shot_idx) {
+     TIDE_OMP_SIMD
+     for (int64_t bi = 0; bi < boundary_numel; ++bi) {
+       int64_t const grid_idx = boundary_indices[bi];
+       int64_t const field_offset = shot_idx * shot_numel + grid_idx;
+       int64_t const store_offset = shot_idx * boundary_numel + bi;
+       field[field_offset] = (TIDE_DTYPE)tide_fp8_e4m3_to_float(store[store_offset]);
+     }
+   }
+ }
+
+ static void scatter_boundary_2_cpu_fp8(
+     TIDE_DTYPE *__restrict const hx,
+     TIDE_DTYPE *__restrict const hz,
+     tide_fp8_e4m3 const *__restrict const bhx,
+     tide_fp8_e4m3 const *__restrict const bhz,
+     int64_t const *__restrict const boundary_indices,
+     int64_t const boundary_numel,
+     int64_t const n_shots,
+     int64_t const shot_numel) {
+
+   TIDE_OMP_INDEX shot_idx;
+   TIDE_OMP_PARALLEL_FOR
+   for (shot_idx = 0; shot_idx < n_shots; ++shot_idx) {
+     TIDE_OMP_SIMD
+     for (int64_t bi = 0; bi < boundary_numel; ++bi) {
+       int64_t const grid_idx = boundary_indices[bi];
+       int64_t const field_offset = shot_idx * shot_numel + grid_idx;
+       int64_t const store_offset = shot_idx * boundary_numel + bi;
+       hx[field_offset] = (TIDE_DTYPE)tide_fp8_e4m3_to_float(bhx[store_offset]);
+       hz[field_offset] = (TIDE_DTYPE)tide_fp8_e4m3_to_float(bhz[store_offset]);
+     }
+   }
+ }
+
+ static inline void *boundary_store_ptr(
+     void *store_1,
+     void *store_3,
+     int64_t storage_mode,
+     int64_t step_idx,
+     int64_t step_elems,
+     size_t elem_size) {
+   size_t const offset_bytes =
+       (size_t)step_idx * (size_t)step_elems * elem_size;
+   if (storage_mode == STORAGE_DEVICE) {
+     return (uint8_t *)store_1 + offset_bytes;
+   }
+   if (storage_mode == STORAGE_CPU && store_3 != NULL) {
+     return (uint8_t *)store_3 + offset_bytes;
+   }
+   return (uint8_t *)store_1;
+ }
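// Usage sketch (an assumption based on the three modes handled above): for
// STORAGE_DEVICE the whole time history lives in store_1, so the returned
// pointer advances by step_idx; for STORAGE_CPU the history lives in the
// separate host buffer store_3; any other mode (e.g. STORAGE_DISK) reuses
// store_1 as a single-step staging buffer that the caller flushes each step.
//
//   void *slot = boundary_store_ptr(store_1, store_3, storage_mode,
//                                   step_idx, n_shots * boundary_numel,
//                                   sizeof(tide_bfloat16));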
+
+
+ static void forward_kernel_h(
+     TIDE_DTYPE const *__restrict const cq,
+     TIDE_DTYPE const *__restrict const ey,
+     TIDE_DTYPE *__restrict const hx,
+     TIDE_DTYPE *__restrict const hz,
+     TIDE_DTYPE *__restrict const m_ey_x,
+     TIDE_DTYPE *__restrict const m_ey_z,
+     TIDE_DTYPE const *__restrict const ay,
+     TIDE_DTYPE const *__restrict const ayh,
+     TIDE_DTYPE const *__restrict const ax,
+     TIDE_DTYPE const *__restrict const axh,
+     TIDE_DTYPE const *__restrict const by,
+     TIDE_DTYPE const *__restrict const byh,
+     TIDE_DTYPE const *__restrict const bx,
+     TIDE_DTYPE const *__restrict const bxh,
+     TIDE_DTYPE const *__restrict const ky,
+     TIDE_DTYPE const *__restrict const kyh,
+     TIDE_DTYPE const *__restrict const kx,
+     TIDE_DTYPE const *__restrict const kxh,
+     TIDE_DTYPE const rdy,
+     TIDE_DTYPE const rdx,
+     int64_t const n_shots,
+     int64_t const ny,
+     int64_t const nx,
+     int64_t const shot_numel,
+     int64_t const pml_y0,
+     int64_t const pml_y1,
+     int64_t const pml_x0,
+     int64_t const pml_x1,
+     bool const cq_batched) {
+
+   int64_t const pml_y0h = pml_y0;
+   int64_t const pml_y1h = MAX(pml_y0, pml_y1 - 1);
+   int64_t const pml_x0h = pml_x0;
+   int64_t const pml_x1h = MAX(pml_x0, pml_x1 - 1);
+   int64_t const pml_bounds_yh[] = {FD_PAD, pml_y0h, pml_y1h, ny - FD_PAD + 1};
+   int64_t const pml_bounds_xh[] = {FD_PAD, pml_x0h, pml_x1h, nx - FD_PAD + 1};
+
+   TIDE_OMP_INDEX shot_idx;
+   TIDE_OMP_PARALLEL_FOR
+   for (shot_idx = 0; shot_idx < n_shots; ++shot_idx) {
+     int64_t const shot_offset = shot_idx * shot_numel;
+     TIDE_DTYPE const *__restrict const cq_ptr =
+         cq_batched ? (cq + shot_offset) : cq;
+     for (int pml_y = 0; pml_y < 3; ++pml_y) {
+       for (int pml_x = 0; pml_x < 3; ++pml_x) {
+         TIDE_OMP_SIMD_COLLAPSE2
+         for (int64_t y = pml_bounds_yh[pml_y]; y < pml_bounds_yh[pml_y + 1]; ++y) {
+           for (int64_t x = pml_bounds_xh[pml_x]; x < pml_bounds_xh[pml_x + 1]; ++x) {
+             int64_t const idx = IDX(y, x);
+             TIDE_DTYPE const cq_val = cq_ptr[idx];
+
+             if (y < ny - FD_PAD) {
+               TIDE_DTYPE dey_dz = DIFFYH1(EY);
+
+               if (pml_y != 1) {
+                 M_EY_Z(0, 0) = byh[y] * M_EY_Z(0, 0) + ayh[y] * dey_dz;
+                 dey_dz = dey_dz / kyh[y] + M_EY_Z(0, 0);
+               }
+
+               HX(0, 0) -= cq_val * dey_dz;
+             }
+
+             if (x < nx - FD_PAD) {
+               TIDE_DTYPE dey_dx = DIFFXH1(EY);
+
+               if (pml_x != 1) {
+                 M_EY_X(0, 0) = bxh[x] * M_EY_X(0, 0) + axh[x] * dey_dx;
+                 dey_dx = dey_dx / kxh[x] + M_EY_X(0, 0);
+               }
+
+               HZ(0, 0) += cq_val * dey_dx;
+             }
+           }
+         }
+       }
+     }
+   }
+ }
+
+
+ /*
+  * Forward E kernel with optional storage for gradient computation
+  *
+  * When ca_requires_grad or cb_requires_grad is true, stores:
+  * - ey_store: E_y field before update (needed for grad_ca)
+  * - curl_h_store: (dHz/dx - dHx/dz) (needed for grad_cb)
+  */
+ static void forward_kernel_e_with_storage(
+     TIDE_DTYPE const *__restrict const ca,
+     TIDE_DTYPE const *__restrict const cb,
+     TIDE_DTYPE const *__restrict const hx,
+     TIDE_DTYPE const *__restrict const hz,
+     TIDE_DTYPE *__restrict const ey,
+     TIDE_DTYPE *__restrict const m_hx_z,
+     TIDE_DTYPE *__restrict const m_hz_x,
+     TIDE_DTYPE const *__restrict const ay,
+     TIDE_DTYPE const *__restrict const ayh,
+     TIDE_DTYPE const *__restrict const ax,
+     TIDE_DTYPE const *__restrict const axh,
+     TIDE_DTYPE const *__restrict const by,
+     TIDE_DTYPE const *__restrict const byh,
+     TIDE_DTYPE const *__restrict const bx,
+     TIDE_DTYPE const *__restrict const bxh,
+     TIDE_DTYPE const *__restrict const ky,
+     TIDE_DTYPE const *__restrict const kyh,
+     TIDE_DTYPE const *__restrict const kx,
+     TIDE_DTYPE const *__restrict const kxh,
+     TIDE_DTYPE const rdy,
+     TIDE_DTYPE const rdx,
+     int64_t const n_shots,
+     int64_t const ny,
+     int64_t const nx,
+     int64_t const shot_numel,
+     int64_t const pml_y0,
+     int64_t const pml_y1,
+     int64_t const pml_x0,
+     int64_t const pml_x1,
+     bool const ca_batched,
+     bool const cb_batched,
+     bool const ca_requires_grad,
+     bool const cb_requires_grad,
+     TIDE_DTYPE *__restrict const ey_store,
+     TIDE_DTYPE *__restrict const curl_h_store) {
+
+   int64_t const pml_bounds_y[] = {FD_PAD, pml_y0, pml_y1, ny - FD_PAD + 1};
+   int64_t const pml_bounds_x[] = {FD_PAD, pml_x0, pml_x1, nx - FD_PAD + 1};
+
+   TIDE_OMP_INDEX shot_idx;
+   TIDE_OMP_PARALLEL_FOR
+   for (shot_idx = 0; shot_idx < n_shots; ++shot_idx) {
+     int64_t const shot_offset = shot_idx * shot_numel;
+     TIDE_DTYPE const *__restrict const ca_ptr =
+         ca_batched ? (ca + shot_offset) : ca;
+     TIDE_DTYPE const *__restrict const cb_ptr =
+         cb_batched ? (cb + shot_offset) : cb;
+     for (int pml_y = 0; pml_y < 3; ++pml_y) {
+       for (int pml_x = 0; pml_x < 3; ++pml_x) {
+         TIDE_OMP_SIMD_COLLAPSE2
+         for (int64_t y = pml_bounds_y[pml_y]; y < pml_bounds_y[pml_y + 1]; ++y) {
+           for (int64_t x = pml_bounds_x[pml_x]; x < pml_bounds_x[pml_x + 1]; ++x) {
+             int64_t const idx = IDX(y, x);
+             int64_t const store_idx = shot_offset + idx;
+             TIDE_DTYPE const ca_val = ca_ptr[idx];
+             TIDE_DTYPE const cb_val = cb_ptr[idx];
+
+             TIDE_DTYPE dhz_dx = DIFFX1(HZ);
+             TIDE_DTYPE dhx_dz = DIFFY1(HX);
+
+             if (pml_x != 1) {
+               M_HZ_X(0, 0) = bx[x] * M_HZ_X(0, 0) + ax[x] * dhz_dx;
+               dhz_dx = dhz_dx / kx[x] + M_HZ_X(0, 0);
+             }
+
+             if (pml_y != 1) {
+               M_HX_Z(0, 0) = by[y] * M_HX_Z(0, 0) + ay[y] * dhx_dz;
+               dhx_dz = dhx_dz / ky[y] + M_HX_Z(0, 0);
+             }
+
+             TIDE_DTYPE curl_h = dhz_dx - dhx_dz;
+
+             if (ca_requires_grad && ey_store != NULL) {
+               ey_store[store_idx] = EY(0, 0);
+             }
+             if (cb_requires_grad && curl_h_store != NULL) {
+               curl_h_store[store_idx] = curl_h;
+             }
+
+             EY(0, 0) = ca_val * EY(0, 0) + cb_val * curl_h;
+           }
+         }
+       }
+     }
+   }
+ }
+
+ static void forward_kernel_e_with_storage_bf16(
+     TIDE_DTYPE const *__restrict const ca,
+     TIDE_DTYPE const *__restrict const cb,
+     TIDE_DTYPE const *__restrict const hx,
+     TIDE_DTYPE const *__restrict const hz,
+     TIDE_DTYPE *__restrict const ey,
+     TIDE_DTYPE *__restrict const m_hx_z,
+     TIDE_DTYPE *__restrict const m_hz_x,
+     TIDE_DTYPE const *__restrict const ay,
+     TIDE_DTYPE const *__restrict const ayh,
+     TIDE_DTYPE const *__restrict const ax,
+     TIDE_DTYPE const *__restrict const axh,
+     TIDE_DTYPE const *__restrict const by,
+     TIDE_DTYPE const *__restrict const byh,
+     TIDE_DTYPE const *__restrict const bx,
+     TIDE_DTYPE const *__restrict const bxh,
+     TIDE_DTYPE const *__restrict const ky,
+     TIDE_DTYPE const *__restrict const kyh,
+     TIDE_DTYPE const *__restrict const kx,
+     TIDE_DTYPE const *__restrict const kxh,
+     TIDE_DTYPE const rdy,
+     TIDE_DTYPE const rdx,
+     int64_t const n_shots,
+     int64_t const ny,
+     int64_t const nx,
+     int64_t const shot_numel,
+     int64_t const pml_y0,
+     int64_t const pml_y1,
+     int64_t const pml_x0,
+     int64_t const pml_x1,
+     bool const ca_batched,
+     bool const cb_batched,
+     bool const ca_requires_grad,
+     bool const cb_requires_grad,
+     tide_bfloat16 *__restrict const ey_store,
+     tide_bfloat16 *__restrict const curl_h_store) {
+
+   int64_t const pml_bounds_y[] = {FD_PAD, pml_y0, pml_y1, ny - FD_PAD + 1};
+   int64_t const pml_bounds_x[] = {FD_PAD, pml_x0, pml_x1, nx - FD_PAD + 1};
+
+   TIDE_OMP_INDEX shot_idx;
+   TIDE_OMP_PARALLEL_FOR
+   for (shot_idx = 0; shot_idx < n_shots; ++shot_idx) {
+     int64_t const shot_offset = shot_idx * shot_numel;
+     TIDE_DTYPE const *__restrict const ca_ptr =
+         ca_batched ? (ca + shot_offset) : ca;
+     TIDE_DTYPE const *__restrict const cb_ptr =
+         cb_batched ? (cb + shot_offset) : cb;
+     for (int pml_y = 0; pml_y < 3; ++pml_y) {
+       for (int pml_x = 0; pml_x < 3; ++pml_x) {
+         TIDE_OMP_SIMD_COLLAPSE2
+         for (int64_t y = pml_bounds_y[pml_y]; y < pml_bounds_y[pml_y + 1]; ++y) {
+           for (int64_t x = pml_bounds_x[pml_x]; x < pml_bounds_x[pml_x + 1]; ++x) {
+             int64_t const idx = IDX(y, x);
+             int64_t const store_idx = shot_offset + idx;
+             TIDE_DTYPE const ca_val = ca_ptr[idx];
+             TIDE_DTYPE const cb_val = cb_ptr[idx];
+
+             TIDE_DTYPE dhz_dx = DIFFX1(HZ);
+             TIDE_DTYPE dhx_dz = DIFFY1(HX);
+
+             if (pml_x != 1) {
+               M_HZ_X(0, 0) = bx[x] * M_HZ_X(0, 0) + ax[x] * dhz_dx;
+               dhz_dx = dhz_dx / kx[x] + M_HZ_X(0, 0);
+             }
+
+             if (pml_y != 1) {
+               M_HX_Z(0, 0) = by[y] * M_HX_Z(0, 0) + ay[y] * dhx_dz;
+               dhx_dz = dhx_dz / ky[y] + M_HX_Z(0, 0);
+             }
+
+             TIDE_DTYPE curl_h = dhz_dx - dhx_dz;
+             if (ca_requires_grad && ey_store != NULL) {
+               ey_store[store_idx] = tide_float_to_bf16((float)EY(0, 0));
+             }
+             if (cb_requires_grad && curl_h_store != NULL) {
+               curl_h_store[store_idx] = tide_float_to_bf16((float)curl_h);
+             }
+
+             EY(0, 0) = ca_val * EY(0, 0) + cb_val * curl_h;
+           }
+         }
+       }
+     }
+   }
+ }
+
+ static void forward_kernel_e_with_storage_fp8(
+     TIDE_DTYPE const *__restrict const ca,
+     TIDE_DTYPE const *__restrict const cb,
+     TIDE_DTYPE const *__restrict const hx,
+     TIDE_DTYPE const *__restrict const hz,
+     TIDE_DTYPE *__restrict const ey,
+     TIDE_DTYPE *__restrict const m_hx_z,
+     TIDE_DTYPE *__restrict const m_hz_x,
+     TIDE_DTYPE const *__restrict const ay,
+     TIDE_DTYPE const *__restrict const ayh,
+     TIDE_DTYPE const *__restrict const ax,
+     TIDE_DTYPE const *__restrict const axh,
+     TIDE_DTYPE const *__restrict const by,
+     TIDE_DTYPE const *__restrict const byh,
+     TIDE_DTYPE const *__restrict const bx,
+     TIDE_DTYPE const *__restrict const bxh,
+     TIDE_DTYPE const *__restrict const ky,
+     TIDE_DTYPE const *__restrict const kyh,
+     TIDE_DTYPE const *__restrict const kx,
+     TIDE_DTYPE const *__restrict const kxh,
+     TIDE_DTYPE const rdy,
+     TIDE_DTYPE const rdx,
+     int64_t const n_shots,
+     int64_t const ny,
+     int64_t const nx,
+     int64_t const shot_numel,
+     int64_t const pml_y0,
+     int64_t const pml_y1,
+     int64_t const pml_x0,
+     int64_t const pml_x1,
+     bool const ca_batched,
+     bool const cb_batched,
+     bool const ca_requires_grad,
+     bool const cb_requires_grad,
+     tide_fp8_e4m3 *__restrict const ey_store,
+     tide_fp8_e4m3 *__restrict const curl_h_store) {
+
+   int64_t const pml_bounds_y[] = {FD_PAD, pml_y0, pml_y1, ny - FD_PAD + 1};
+   int64_t const pml_bounds_x[] = {FD_PAD, pml_x0, pml_x1, nx - FD_PAD + 1};
+
+   TIDE_OMP_INDEX shot_idx;
+   TIDE_OMP_PARALLEL_FOR
+   for (shot_idx = 0; shot_idx < n_shots; ++shot_idx) {
+     int64_t const shot_offset = shot_idx * shot_numel;
+     TIDE_DTYPE const *__restrict const ca_ptr =
+         ca_batched ? (ca + shot_offset) : ca;
+     TIDE_DTYPE const *__restrict const cb_ptr =
+         cb_batched ? (cb + shot_offset) : cb;
+     for (int pml_y = 0; pml_y < 3; ++pml_y) {
+       for (int pml_x = 0; pml_x < 3; ++pml_x) {
+         TIDE_OMP_SIMD_COLLAPSE2
+         for (int64_t y = pml_bounds_y[pml_y]; y < pml_bounds_y[pml_y + 1]; ++y) {
+           for (int64_t x = pml_bounds_x[pml_x]; x < pml_bounds_x[pml_x + 1]; ++x) {
+             int64_t const idx = IDX(y, x);
+             int64_t const store_idx = shot_offset + idx;
+             TIDE_DTYPE const ca_val = ca_ptr[idx];
+             TIDE_DTYPE const cb_val = cb_ptr[idx];
+
+             TIDE_DTYPE dhz_dx = DIFFX1(HZ);
+             TIDE_DTYPE dhx_dz = DIFFY1(HX);
+
+             if (pml_x != 1) {
+               M_HZ_X(0, 0) = bx[x] * M_HZ_X(0, 0) + ax[x] * dhz_dx;
+               dhz_dx = dhz_dx / kx[x] + M_HZ_X(0, 0);
+             }
+
+             if (pml_y != 1) {
+               M_HX_Z(0, 0) = by[y] * M_HX_Z(0, 0) + ay[y] * dhx_dz;
+               dhx_dz = dhx_dz / ky[y] + M_HX_Z(0, 0);
+             }
+
+             TIDE_DTYPE curl_h = dhz_dx - dhx_dz;
+             if (ca_requires_grad && ey_store != NULL) {
+               ey_store[store_idx] = tide_float_to_fp8_e4m3((float)EY(0, 0));
+             }
+             if (cb_requires_grad && curl_h_store != NULL) {
+               curl_h_store[store_idx] = tide_float_to_fp8_e4m3((float)curl_h);
+             }
+
+             EY(0, 0) = ca_val * EY(0, 0) + cb_val * curl_h;
+           }
+         }
+       }
+     }
+   }
+ }
+
+
+ #ifdef __cplusplus
+ extern "C"
+ #endif
+ #ifdef _WIN32
+ __declspec(dllexport)
+ #endif
+ void FUNC(forward)(
+     TIDE_DTYPE const *const ca,
+     TIDE_DTYPE const *const cb,
+     TIDE_DTYPE const *const cq,
+     TIDE_DTYPE const *const f,
+     TIDE_DTYPE *const ey,
+     TIDE_DTYPE *const hx,
+     TIDE_DTYPE *const hz,
+     TIDE_DTYPE *const m_ey_x,
+     TIDE_DTYPE *const m_ey_z,
+     TIDE_DTYPE *const m_hx_z,
+     TIDE_DTYPE *const m_hz_x,
+     TIDE_DTYPE *const r,
+     TIDE_DTYPE const *const ay,
+     TIDE_DTYPE const *const by,
+     TIDE_DTYPE const *const ayh,
+     TIDE_DTYPE const *const byh,
+     TIDE_DTYPE const *const ax,
+     TIDE_DTYPE const *const bx,
+     TIDE_DTYPE const *const axh,
+     TIDE_DTYPE const *const bxh,
+     TIDE_DTYPE const *const ky,
+     TIDE_DTYPE const *const kyh,
+     TIDE_DTYPE const *const kx,
+     TIDE_DTYPE const *const kxh,
+     int64_t const *const sources_i,
+     int64_t const *const receivers_i,
+     TIDE_DTYPE const rdy,
+     TIDE_DTYPE const rdx,
+     TIDE_DTYPE const dt,
+     int64_t const nt,
+     int64_t const n_shots,
+     int64_t const ny,
+     int64_t const nx,
+     int64_t const n_sources_per_shot,
+     int64_t const n_receivers_per_shot,
+     int64_t const step_ratio,
+     bool const ca_batched,
+     bool const cb_batched,
+     bool const cq_batched,
+     int64_t const start_t,
+     int64_t const pml_y0,
+     int64_t const pml_x0,
+     int64_t const pml_y1,
+     int64_t const pml_x1,
+     int64_t const n_threads,
+     int64_t const device /* unused for CPU */) {
+
+   (void)device;
+   (void)dt;
+   (void)step_ratio;
+ #ifdef _OPENMP
+   int const prev_threads = omp_get_max_threads();
+   if (n_threads > 0) {
+     omp_set_num_threads((int)n_threads);
+   }
+ #else
+   (void)n_threads;
+ #endif
+
+   int64_t const shot_numel = ny * nx;
+
+   for (int64_t t = start_t; t < start_t + nt; ++t) {
+     forward_kernel_h(
+         cq, ey, hx, hz, m_ey_x, m_ey_z,
+         ay, ayh, ax, axh, by, byh, bx, bxh,
+         ky, kyh, kx, kxh,
+         rdy, rdx,
+         n_shots, ny, nx, shot_numel,
+         pml_y0, pml_y1, pml_x0, pml_x1,
+         cq_batched);
+
+     forward_kernel_e_with_storage(
+         ca, cb, hx, hz, ey, m_hx_z, m_hz_x,
+         ay, ayh, ax, axh, by, byh, bx, bxh,
+         ky, kyh, kx, kxh,
+         rdy, rdx,
+         n_shots, ny, nx, shot_numel,
+         pml_y0, pml_y1, pml_x0, pml_x1,
+         ca_batched, cb_batched,
+         false, false, // No storage for standard forward
+         NULL, NULL);
+
+     if (n_sources_per_shot > 0) {
+       add_sources_ey(
+           ey, f + t * n_shots * n_sources_per_shot, sources_i,
+           n_shots, shot_numel, n_sources_per_shot);
+     }
+
+     if (n_receivers_per_shot > 0) {
+       record_receivers_ey(
+           r + t * n_shots * n_receivers_per_shot,
+           ey, receivers_i,
+           n_shots, shot_numel, n_receivers_per_shot);
+     }
+   }
+ #ifdef _OPENMP
+   if (n_threads > 0) {
+     omp_set_num_threads(prev_threads);
+   }
+ #endif
+ }
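// Because the entry point has C linkage (and dllexport on Windows), it can be
// resolved at run time. A minimal POSIX sketch; the library name and the
// 4/float suffix are illustrative assumptions, not taken from this file:
//
//   void *lib = dlopen("libtide_cpu.so", RTLD_NOW);   /* hypothetical name */
//   void *fwd = dlsym(lib, "maxwell_tm_4_float_forward_cpu");
//   /* cast fwd to a pointer matching FUNC(forward)'s signature before calling */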
+
+
+ /*
+  * Forward with storage for backward pass
+  *
+  * This function performs forward propagation while storing the values
+  * needed for gradient computation in the backward pass.
+  */
+ #ifdef __cplusplus
+ extern "C"
+ #endif
+ #ifdef _WIN32
+ __declspec(dllexport)
+ #endif
+ void FUNC(forward_with_storage)(
+     TIDE_DTYPE const *const ca,
+     TIDE_DTYPE const *const cb,
+     TIDE_DTYPE const *const cq,
+     TIDE_DTYPE const *const f,
+     TIDE_DTYPE *const ey,
+     TIDE_DTYPE *const hx,
+     TIDE_DTYPE *const hz,
+     TIDE_DTYPE *const m_ey_x,
+     TIDE_DTYPE *const m_ey_z,
+     TIDE_DTYPE *const m_hx_z,
+     TIDE_DTYPE *const m_hz_x,
+     TIDE_DTYPE *const r,
+     TIDE_DTYPE *const ey_store_1,
+     void *const ey_store_3,
+     char const *const *const ey_filenames,
+     TIDE_DTYPE *const curl_store_1,
+     void *const curl_store_3,
+     char const *const *const curl_filenames,
+     TIDE_DTYPE const *const ay,
+     TIDE_DTYPE const *const by,
+     TIDE_DTYPE const *const ayh,
+     TIDE_DTYPE const *const byh,
+     TIDE_DTYPE const *const ax,
+     TIDE_DTYPE const *const bx,
+     TIDE_DTYPE const *const axh,
+     TIDE_DTYPE const *const bxh,
+     TIDE_DTYPE const *const ky,
+     TIDE_DTYPE const *const kyh,
+     TIDE_DTYPE const *const kx,
+     TIDE_DTYPE const *const kxh,
+     int64_t const *const sources_i,
+     int64_t const *const receivers_i,
+     TIDE_DTYPE const rdy,
+     TIDE_DTYPE const rdx,
+     TIDE_DTYPE const dt,
+     int64_t const nt,
+     int64_t const n_shots,
+     int64_t const ny,
+     int64_t const nx,
+     int64_t const n_sources_per_shot,
+     int64_t const n_receivers_per_shot,
+     int64_t const step_ratio,
+     int64_t const storage_mode,
+     int64_t const shot_bytes_uncomp,
+     bool const ca_requires_grad,
+     bool const cb_requires_grad,
+     bool const ca_batched,
+     bool const cb_batched,
+     bool const cq_batched,
+     int64_t const start_t,
+     int64_t const pml_y0,
+     int64_t const pml_x0,
+     int64_t const pml_y1,
+     int64_t const pml_x1,
+     int64_t const n_threads,
+     int64_t const device /* unused for CPU */) {
+
+   (void)device;
+   (void)dt;
+ #ifdef _OPENMP
+   int const prev_threads = omp_get_max_threads();
+   if (n_threads > 0) {
+     omp_set_num_threads((int)n_threads);
+   }
+ #else
+   (void)n_threads;
+ #endif
+
+   int64_t const shot_numel = ny * nx;
+   int64_t const store_size = n_shots * shot_numel;
+   bool const storage_fp8 = (shot_bytes_uncomp == shot_numel * 1);
+   bool const storage_bf16 = (shot_bytes_uncomp == shot_numel * 2);
+
+   FILE **fp_ey = NULL;
+   FILE **fp_curl = NULL;
+   if (storage_mode == STORAGE_DISK) {
+     if (ca_requires_grad) {
+       fp_ey = (FILE **)malloc((size_t)n_shots * sizeof(FILE *));
+       for (int64_t shot = 0; shot < n_shots; ++shot) {
+         fp_ey[shot] = fopen(ey_filenames[shot], "wb");
+       }
+     }
+     if (cb_requires_grad) {
+       fp_curl = (FILE **)malloc((size_t)n_shots * sizeof(FILE *));
+       for (int64_t shot = 0; shot < n_shots; ++shot) {
+         fp_curl[shot] = fopen(curl_filenames[shot], "wb");
+       }
+     }
+   }
+
+   for (int64_t t = start_t; t < start_t + nt; ++t) {
+     forward_kernel_h(
+         cq, ey, hx, hz, m_ey_x, m_ey_z,
+         ay, ayh, ax, axh, by, byh, bx, bxh,
+         ky, kyh, kx, kxh,
+         rdy, rdx,
+         n_shots, ny, nx, shot_numel,
+         pml_y0, pml_y1, pml_x0, pml_x1,
+         cq_batched);
+
+     bool const store_step = ((t % step_ratio) == 0);
+     bool const store_ey = store_step && ca_requires_grad;
+     bool const store_curl = store_step && cb_requires_grad;
+     int64_t const step_idx = t / step_ratio;
+
+     int64_t const store_offset =
+         (storage_mode == STORAGE_DEVICE ? step_idx * store_size : 0);
+
+     if (storage_fp8) {
+       tide_fp8_e4m3 *const ey_store_1_t =
+           (tide_fp8_e4m3 *)ey_store_1 + store_offset;
+       tide_fp8_e4m3 *const curl_store_1_t =
+           (tide_fp8_e4m3 *)curl_store_1 + store_offset;
+
+       forward_kernel_e_with_storage_fp8(
+           ca, cb, hx, hz, ey, m_hx_z, m_hz_x,
+           ay, ayh, ax, axh, by, byh, bx, bxh,
+           ky, kyh, kx, kxh,
+           rdy, rdx,
+           n_shots, ny, nx, shot_numel,
+           pml_y0, pml_y1, pml_x0, pml_x1,
+           ca_batched, cb_batched,
+           store_ey,
+           store_curl,
+           store_ey ? ey_store_1_t : NULL,
+           store_curl ? curl_store_1_t : NULL);
+
+       if (store_ey && storage_mode == STORAGE_DISK) {
+         for (int64_t shot = 0; shot < n_shots; ++shot) {
+           storage_save_snapshot_cpu(
+               (void *)(ey_store_1_t + shot * shot_numel), fp_ey[shot],
+               storage_mode, step_idx, (size_t)shot_bytes_uncomp);
+         }
+       }
+       if (store_curl && storage_mode == STORAGE_DISK) {
+         for (int64_t shot = 0; shot < n_shots; ++shot) {
+           storage_save_snapshot_cpu(
+               (void *)(curl_store_1_t + shot * shot_numel), fp_curl[shot],
+               storage_mode, step_idx, (size_t)shot_bytes_uncomp);
+         }
+       }
+     } else if (storage_bf16) {
+       tide_bfloat16 *const ey_store_1_t =
+           (tide_bfloat16 *)ey_store_1 + store_offset;
+       tide_bfloat16 *const curl_store_1_t =
+           (tide_bfloat16 *)curl_store_1 + store_offset;
+
+       forward_kernel_e_with_storage_bf16(
+           ca, cb, hx, hz, ey, m_hx_z, m_hz_x,
+           ay, ayh, ax, axh, by, byh, bx, bxh,
+           ky, kyh, kx, kxh,
+           rdy, rdx,
+           n_shots, ny, nx, shot_numel,
+           pml_y0, pml_y1, pml_x0, pml_x1,
+           ca_batched, cb_batched,
+           store_ey,
+           store_curl,
+           store_ey ? ey_store_1_t : NULL,
+           store_curl ? curl_store_1_t : NULL);
+
+       if (store_ey && storage_mode == STORAGE_DISK) {
+         for (int64_t shot = 0; shot < n_shots; ++shot) {
+           storage_save_snapshot_cpu(
+               (void *)(ey_store_1_t + shot * shot_numel), fp_ey[shot],
+               storage_mode, step_idx, (size_t)shot_bytes_uncomp);
+         }
+       }
+       if (store_curl && storage_mode == STORAGE_DISK) {
+         for (int64_t shot = 0; shot < n_shots; ++shot) {
+           storage_save_snapshot_cpu(
+               (void *)(curl_store_1_t + shot * shot_numel), fp_curl[shot],
+               storage_mode, step_idx, (size_t)shot_bytes_uncomp);
+         }
+       }
+     } else {
+       TIDE_DTYPE *const ey_store_1_t =
+           ey_store_1 + store_offset;
+       TIDE_DTYPE *const curl_store_1_t =
+           curl_store_1 + store_offset;
+
+       forward_kernel_e_with_storage(
+           ca, cb, hx, hz, ey, m_hx_z, m_hz_x,
+           ay, ayh, ax, axh, by, byh, bx, bxh,
+           ky, kyh, kx, kxh,
+           rdy, rdx,
+           n_shots, ny, nx, shot_numel,
+           pml_y0, pml_y1, pml_x0, pml_x1,
+           ca_batched, cb_batched,
+           store_ey,
+           store_curl,
+           store_ey ? ey_store_1_t : NULL,
+           store_curl ? curl_store_1_t : NULL);
+
+       if (store_ey && storage_mode == STORAGE_DISK) {
+         for (int64_t shot = 0; shot < n_shots; ++shot) {
+           storage_save_snapshot_cpu(
+               (void *)(ey_store_1_t + shot * shot_numel), fp_ey[shot],
+               storage_mode, step_idx, (size_t)shot_bytes_uncomp);
+         }
+       }
+       if (store_curl && storage_mode == STORAGE_DISK) {
+         for (int64_t shot = 0; shot < n_shots; ++shot) {
+           storage_save_snapshot_cpu(
+               (void *)(curl_store_1_t + shot * shot_numel), fp_curl[shot],
+               storage_mode, step_idx, (size_t)shot_bytes_uncomp);
+         }
+       }
+     }
+
+     if (n_sources_per_shot > 0) {
+       add_sources_ey(
+           ey, f + t * n_shots * n_sources_per_shot, sources_i,
+           n_shots, shot_numel, n_sources_per_shot);
+     }
+
+     if (n_receivers_per_shot > 0) {
+       record_receivers_ey(
+           r + t * n_shots * n_receivers_per_shot,
+           ey, receivers_i,
+           n_shots, shot_numel, n_receivers_per_shot);
+     }
+   }
+
+   if (fp_ey != NULL) {
+     for (int64_t shot = 0; shot < n_shots; ++shot) fclose(fp_ey[shot]);
+     free(fp_ey);
+   }
+   if (fp_curl != NULL) {
+     for (int64_t shot = 0; shot < n_shots; ++shot) fclose(fp_curl[shot]);
+     free(fp_curl);
+   }
+ #ifdef _OPENMP
+   if (n_threads > 0) {
+     omp_set_num_threads(prev_threads);
+   }
+ #endif
+ }
+
+ /*
+  * Backward kernel for adjoint λ_H fields update
+  *
+  * Adjoint equations for H fields (time reversed, swap Cb and Cq roles):
+  *   λ_Hx^{n-1/2} = λ_Hx^{n+1/2} - C_b * ∂λ_Ey/∂z
+  *   λ_Hz^{n-1/2} = λ_Hz^{n+1/2} + C_b * ∂λ_Ey/∂x
+  */
+ static void backward_kernel_lambda_h(
+     TIDE_DTYPE const *__restrict const cb,
+     TIDE_DTYPE const *__restrict const lambda_ey,
+     TIDE_DTYPE *__restrict const lambda_hx,
+     TIDE_DTYPE *__restrict const lambda_hz,
+     TIDE_DTYPE *__restrict const m_lambda_ey_x,
+     TIDE_DTYPE *__restrict const m_lambda_ey_z,
+     TIDE_DTYPE const *__restrict const ay,
+     TIDE_DTYPE const *__restrict const ayh,
+     TIDE_DTYPE const *__restrict const ax,
+     TIDE_DTYPE const *__restrict const axh,
+     TIDE_DTYPE const *__restrict const by,
+     TIDE_DTYPE const *__restrict const byh,
+     TIDE_DTYPE const *__restrict const bx,
+     TIDE_DTYPE const *__restrict const bxh,
+     TIDE_DTYPE const *__restrict const ky,
+     TIDE_DTYPE const *__restrict const kyh,
+     TIDE_DTYPE const *__restrict const kx,
+     TIDE_DTYPE const *__restrict const kxh,
+     TIDE_DTYPE const rdy,
+     TIDE_DTYPE const rdx,
+     int64_t const n_shots,
+     int64_t const ny,
+     int64_t const nx,
+     int64_t const shot_numel,
+     int64_t const pml_y0,
+     int64_t const pml_y1,
+     int64_t const pml_x0,
+     int64_t const pml_x1,
+     bool const cb_batched) {
+
+   int64_t const pml_y0h = pml_y0;
+   int64_t const pml_y1h = MAX(pml_y0, pml_y1 - 1);
+   int64_t const pml_x0h = pml_x0;
+   int64_t const pml_x1h = MAX(pml_x0, pml_x1 - 1);
+   int64_t const pml_bounds_yh[] = {FD_PAD, pml_y0h, pml_y1h, ny - FD_PAD + 1};
+   int64_t const pml_bounds_xh[] = {FD_PAD, pml_x0h, pml_x1h, nx - FD_PAD + 1};
+
+   TIDE_OMP_INDEX shot_idx;
+   TIDE_OMP_PARALLEL_FOR
+   for (shot_idx = 0; shot_idx < n_shots; ++shot_idx) {
+     int64_t const shot_offset = shot_idx * shot_numel;
+     TIDE_DTYPE const *__restrict const cb_ptr =
+         cb_batched ? (cb + shot_offset) : cb;
+     for (int pml_y = 0; pml_y < 3; ++pml_y) {
+       for (int pml_x = 0; pml_x < 3; ++pml_x) {
+         TIDE_OMP_SIMD_COLLAPSE2
+         for (int64_t y = pml_bounds_yh[pml_y]; y < pml_bounds_yh[pml_y + 1]; ++y) {
+           for (int64_t x = pml_bounds_xh[pml_x]; x < pml_bounds_xh[pml_x + 1]; ++x) {
+             int64_t const idx = IDX(y, x);
+             TIDE_DTYPE const cb_val = cb_ptr[idx];
+
+             if (y < ny - FD_PAD) {
+               TIDE_DTYPE d_lambda_ey_dz = DIFFYH1(LAMBDA_EY);
+
+               if (pml_y != 1) {
+                 M_LAMBDA_EY_Z(0, 0) = byh[y] * M_LAMBDA_EY_Z(0, 0) + ayh[y] * d_lambda_ey_dz;
+                 d_lambda_ey_dz = d_lambda_ey_dz / kyh[y] + M_LAMBDA_EY_Z(0, 0);
+               }
+
+               LAMBDA_HX(0, 0) -= cb_val * d_lambda_ey_dz;
+             }
+
+             if (x < nx - FD_PAD) {
+               TIDE_DTYPE d_lambda_ey_dx = DIFFXH1(LAMBDA_EY);
+
+               if (pml_x != 1) {
+                 M_LAMBDA_EY_X(0, 0) = bxh[x] * M_LAMBDA_EY_X(0, 0) + axh[x] * d_lambda_ey_dx;
+                 d_lambda_ey_dx = d_lambda_ey_dx / kxh[x] + M_LAMBDA_EY_X(0, 0);
+               }
+
+               LAMBDA_HZ(0, 0) += cb_val * d_lambda_ey_dx;
+             }
+           }
+         }
+       }
+     }
+   }
+ }
+
+
+ /*
+  * Backward kernel for adjoint λ_Ey field update with gradient accumulation
+  *
+  * Adjoint equation for E field (time reversed, swap Cb and Cq roles):
+  *   λ_Ey^n = C_a * λ_Ey^{n+1} + C_q * (∂λ_Hz/∂x - ∂λ_Hx/∂z)
+  *
+  * Gradient accumulation:
+  *   grad_ca += λ_Ey^{n+1} * E_y^n
+  *   grad_cb += λ_Ey^{n+1} * curl_H^n
+  *
+  * Uses pml_bounds arrays to divide domain into 9 regions (3x3 grid):
+  *   pml_y/pml_x == 0: Left/Top PML region
+  *   pml_y/pml_x == 1: Interior region (where gradients are accumulated)
+  *   pml_y/pml_x == 2: Right/Bottom PML region
+  */
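// The 3x3 region decomposition referenced above, as a sketch:
//
//                              x: [FD_PAD, pml_x0) | [pml_x0, pml_x1) | [pml_x1, nx-FD_PAD+1)
//   y: [FD_PAD, pml_y0)             corner PML     |     edge PML     |     corner PML
//   y: [pml_y0, pml_y1)             edge PML       |     INTERIOR     |     edge PML
//   y: [pml_y1, ny-FD_PAD+1)        corner PML     |     edge PML     |     corner PML
//
// Only the interior cell (pml_y == 1 && pml_x == 1) skips the CPML memory
// updates, and it is the only region where the model gradients accumulate.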
1377
+ static void backward_kernel_lambda_e_with_grad(
1378
+ TIDE_DTYPE const *__restrict const ca,
1379
+ TIDE_DTYPE const *__restrict const cq,
1380
+ TIDE_DTYPE const *__restrict const lambda_hx,
1381
+ TIDE_DTYPE const *__restrict const lambda_hz,
1382
+ TIDE_DTYPE *__restrict const lambda_ey,
1383
+ TIDE_DTYPE *__restrict const m_lambda_hx_z,
1384
+ TIDE_DTYPE *__restrict const m_lambda_hz_x,
1385
+ TIDE_DTYPE const *__restrict const ey_store,
1386
+ TIDE_DTYPE const *__restrict const curl_h_store,
1387
+ TIDE_DTYPE *__restrict const grad_ca,
1388
+ TIDE_DTYPE *__restrict const grad_cb,
1389
+ TIDE_DTYPE const *__restrict const ay,
1390
+ TIDE_DTYPE const *__restrict const ayh,
1391
+ TIDE_DTYPE const *__restrict const ax,
1392
+ TIDE_DTYPE const *__restrict const axh,
1393
+ TIDE_DTYPE const *__restrict const by,
1394
+ TIDE_DTYPE const *__restrict const byh,
1395
+ TIDE_DTYPE const *__restrict const bx,
1396
+ TIDE_DTYPE const *__restrict const bxh,
1397
+ TIDE_DTYPE const *__restrict const ky,
1398
+ TIDE_DTYPE const *__restrict const kyh,
1399
+ TIDE_DTYPE const *__restrict const kx,
1400
+ TIDE_DTYPE const *__restrict const kxh,
1401
+ TIDE_DTYPE const rdy,
1402
+ TIDE_DTYPE const rdx,
1403
+ int64_t const n_shots,
1404
+ int64_t const ny,
1405
+ int64_t const nx,
1406
+ int64_t const shot_numel,
1407
+ int64_t const pml_y0,
1408
+ int64_t const pml_y1,
1409
+ int64_t const pml_x0,
1410
+ int64_t const pml_x1,
1411
+ bool const ca_batched,
1412
+ bool const cq_batched,
1413
+ bool const ca_requires_grad,
1414
+ bool const cb_requires_grad,
1415
+ int64_t const step_ratio) {
1416
+
1417
+ // PML region bounds arrays
1418
+ // pml_bounds[0] = FD_PAD (start of computational domain)
1419
+ // pml_bounds[1] = pml_y0 (start of interior region)
1420
+ // pml_bounds[2] = pml_y1 (end of interior region)
1421
+ // pml_bounds[3] = ny - FD_PAD + 1 (end of computational domain)
1422
+ int64_t const pml_bounds_y[] = {FD_PAD, pml_y0, pml_y1, ny - FD_PAD + 1};
1423
+ int64_t const pml_bounds_x[] = {FD_PAD, pml_x0, pml_x1, nx - FD_PAD + 1};
1424
+
1425
+ TIDE_OMP_INDEX shot_idx;
1426
+ TIDE_OMP_PARALLEL_FOR
1427
+ for (shot_idx = 0; shot_idx < n_shots; ++shot_idx) {
1428
+ int64_t const shot_offset = shot_idx * shot_numel;
1429
+ TIDE_DTYPE const *__restrict const ca_ptr =
1430
+ ca_batched ? (ca + shot_offset) : ca;
1431
+ TIDE_DTYPE const *__restrict const cq_ptr =
1432
+ cq_batched ? (cq + shot_offset) : cq;
1433
+ // Loop over 3x3 grid of regions
1434
+ for (int pml_y = 0; pml_y < 3; ++pml_y) {
1435
+ for (int pml_x = 0; pml_x < 3; ++pml_x) {
1436
+ TIDE_OMP_SIMD_COLLAPSE2
1437
+ for (int64_t y = pml_bounds_y[pml_y]; y < pml_bounds_y[pml_y + 1]; ++y) {
1438
+ for (int64_t x = pml_bounds_x[pml_x]; x < pml_bounds_x[pml_x + 1]; ++x) {
1439
+ int64_t const idx = IDX(y, x);
1440
+ int64_t const store_idx = shot_offset + idx;
1441
+ TIDE_DTYPE const ca_val = ca_ptr[idx];
1442
+ TIDE_DTYPE const cq_val = cq_ptr[idx];
1443
+
1444
+ // Compute d(λ_Hz)/dx at integer grid points
1445
+ TIDE_DTYPE d_lambda_hz_dx = DIFFX1(LAMBDA_HZ);
1446
+ // Compute d(λ_Hx)/dz at integer grid points
1447
+ TIDE_DTYPE d_lambda_hx_dz = DIFFY1(LAMBDA_HX);
1448
+
1449
+ // Apply adjoint CPML for d(λ_Hz)/dx (only in PML regions)
1450
+ if (pml_x != 1) {
1451
+ M_LAMBDA_HZ_X(0, 0) = bx[x] * M_LAMBDA_HZ_X(0, 0) + ax[x] * d_lambda_hz_dx;
1452
+ d_lambda_hz_dx = d_lambda_hz_dx / kx[x] + M_LAMBDA_HZ_X(0, 0);
1453
+ }
1454
+
1455
+ // Apply adjoint CPML for d(λ_Hx)/dz (only in PML regions)
1456
+ if (pml_y != 1) {
1457
+ M_LAMBDA_HX_Z(0, 0) = by[y] * M_LAMBDA_HX_Z(0, 0) + ay[y] * d_lambda_hx_dz;
1458
+ d_lambda_hx_dz = d_lambda_hx_dz / ky[y] + M_LAMBDA_HX_Z(0, 0);
1459
+ }
1460
+
1461
+ // curl_λH = d(λ_Hz)/dx - d(λ_Hx)/dz
1462
+ TIDE_DTYPE curl_lambda_h = d_lambda_hz_dx - d_lambda_hx_dz;
1463
+
1464
+ // Store current λ_Ey before update (this is λ_Ey^{n+1})
1465
+ TIDE_DTYPE lambda_ey_curr = LAMBDA_EY(0, 0);
1466
+
1467
+ // Update λ_Ey: λ_Ey^n = C_a * λ_Ey^{n+1} + C_q * curl_λH
1468
+ LAMBDA_EY(0, 0) = ca_val * lambda_ey_curr + cq_val * curl_lambda_h;
1469
+
1470
+ // Accumulate gradients only in interior region (pml_y == 1 && pml_x == 1)
1471
+ if (pml_y == 1 && pml_x == 1) {
1472
+ // grad_ca += λ_Ey^{n+1} * E_y^n
1473
+ if (ca_requires_grad && ey_store != NULL) {
1474
+ TIDE_DTYPE ey_n = ey_store[store_idx];
1475
+ if (ca_batched) {
1476
+ grad_ca[store_idx] += lambda_ey_curr * ey_n * (TIDE_DTYPE)step_ratio;
1477
+ } else {
1478
+ #ifdef _OPENMP
1479
+ #pragma omp atomic
1480
+ #endif
1481
+ grad_ca[idx] += lambda_ey_curr * ey_n * (TIDE_DTYPE)step_ratio;
1482
+ }
1483
+ }
1484
+
1485
+ // grad_cb += λ_Ey^{n+1} * curl_H^n
1486
+ if (cb_requires_grad && curl_h_store != NULL) {
1487
+ TIDE_DTYPE curl_h_n = curl_h_store[store_idx];
1488
+ if (ca_batched) {
1489
+ grad_cb[store_idx] += lambda_ey_curr * curl_h_n * (TIDE_DTYPE)step_ratio;
1490
+ } else {
1491
+ #ifdef _OPENMP
1492
+ #pragma omp atomic
1493
+ #endif
1494
+ grad_cb[idx] += lambda_ey_curr * curl_h_n * (TIDE_DTYPE)step_ratio;
1495
+ }
1496
+ }
1497
+ }
1498
+ }
1499
+ }
1500
+ }
1501
+ }
1502
+ }
1503
+ }
1504
+
1505
+ static void backward_kernel_lambda_e_with_grad_bf16(
1506
+ TIDE_DTYPE const *__restrict const ca,
1507
+ TIDE_DTYPE const *__restrict const cq,
1508
+ TIDE_DTYPE const *__restrict const lambda_hx,
1509
+ TIDE_DTYPE const *__restrict const lambda_hz,
1510
+ TIDE_DTYPE *__restrict const lambda_ey,
1511
+ TIDE_DTYPE *__restrict const m_lambda_hx_z,
1512
+ TIDE_DTYPE *__restrict const m_lambda_hz_x,
1513
+ tide_bfloat16 const *__restrict const ey_store,
1514
+ tide_bfloat16 const *__restrict const curl_h_store,
1515
+ TIDE_DTYPE *__restrict const grad_ca,
1516
+ TIDE_DTYPE *__restrict const grad_cb,
1517
+ TIDE_DTYPE const *__restrict const ay,
1518
+ TIDE_DTYPE const *__restrict const ayh,
1519
+ TIDE_DTYPE const *__restrict const ax,
1520
+ TIDE_DTYPE const *__restrict const axh,
1521
+ TIDE_DTYPE const *__restrict const by,
1522
+ TIDE_DTYPE const *__restrict const byh,
1523
+ TIDE_DTYPE const *__restrict const bx,
1524
+ TIDE_DTYPE const *__restrict const bxh,
1525
+ TIDE_DTYPE const *__restrict const ky,
1526
+ TIDE_DTYPE const *__restrict const kyh,
1527
+ TIDE_DTYPE const *__restrict const kx,
1528
+ TIDE_DTYPE const *__restrict const kxh,
1529
+ TIDE_DTYPE const rdy,
1530
+ TIDE_DTYPE const rdx,
1531
+ int64_t const n_shots,
1532
+ int64_t const ny,
1533
+ int64_t const nx,
1534
+ int64_t const shot_numel,
1535
+ int64_t const pml_y0,
1536
+ int64_t const pml_y1,
1537
+ int64_t const pml_x0,
1538
+ int64_t const pml_x1,
1539
+ bool const ca_batched,
1540
+ bool const cq_batched,
1541
+ bool const ca_requires_grad,
1542
+ bool const cb_requires_grad,
1543
+ int64_t const step_ratio) {
1544
+
1545
+ int64_t const pml_bounds_y[] = {FD_PAD, pml_y0, pml_y1, ny - FD_PAD + 1};
1546
+ int64_t const pml_bounds_x[] = {FD_PAD, pml_x0, pml_x1, nx - FD_PAD + 1};
1547
+
1548
+ TIDE_OMP_INDEX shot_idx;
1549
+ TIDE_OMP_PARALLEL_FOR
1550
+ for (shot_idx = 0; shot_idx < n_shots; ++shot_idx) {
1551
+ int64_t const shot_offset = shot_idx * shot_numel;
1552
+ TIDE_DTYPE const *__restrict const ca_ptr =
1553
+ ca_batched ? (ca + shot_offset) : ca;
1554
+ TIDE_DTYPE const *__restrict const cq_ptr =
1555
+ cq_batched ? (cq + shot_offset) : cq;
1556
+ for (int pml_y = 0; pml_y < 3; ++pml_y) {
1557
+ for (int pml_x = 0; pml_x < 3; ++pml_x) {
1558
+ TIDE_OMP_SIMD_COLLAPSE2
1559
+ for (int64_t y = pml_bounds_y[pml_y]; y < pml_bounds_y[pml_y + 1]; ++y) {
1560
+ for (int64_t x = pml_bounds_x[pml_x]; x < pml_bounds_x[pml_x + 1]; ++x) {
1561
+ int64_t const idx = IDX(y, x);
1562
+ int64_t const store_idx = shot_offset + idx;
1563
+ TIDE_DTYPE const ca_val = ca_ptr[idx];
1564
+ TIDE_DTYPE const cq_val = cq_ptr[idx];
1565
+
1566
+ TIDE_DTYPE d_lambda_hz_dx = DIFFX1(LAMBDA_HZ);
1567
+ TIDE_DTYPE d_lambda_hx_dz = DIFFY1(LAMBDA_HX);
1568
+
1569
+ if (pml_x != 1) {
1570
+ M_LAMBDA_HZ_X(0, 0) = bx[x] * M_LAMBDA_HZ_X(0, 0) + ax[x] * d_lambda_hz_dx;
1571
+ d_lambda_hz_dx = d_lambda_hz_dx / kx[x] + M_LAMBDA_HZ_X(0, 0);
1572
+ }
1573
+ if (pml_y != 1) {
1574
+ M_LAMBDA_HX_Z(0, 0) = by[y] * M_LAMBDA_HX_Z(0, 0) + ay[y] * d_lambda_hx_dz;
1575
+ d_lambda_hx_dz = d_lambda_hx_dz / ky[y] + M_LAMBDA_HX_Z(0, 0);
1576
+ }
1577
+
1578
+ TIDE_DTYPE curl_lambda_h = d_lambda_hz_dx - d_lambda_hx_dz;
1579
+ TIDE_DTYPE lambda_ey_curr = LAMBDA_EY(0, 0);
1580
+ LAMBDA_EY(0, 0) = ca_val * lambda_ey_curr + cq_val * curl_lambda_h;
1581
+
1582
+ if (pml_y == 1 && pml_x == 1) {
1583
+ if (ca_requires_grad && ey_store != NULL) {
1584
+ TIDE_DTYPE ey_n =
1585
+ (TIDE_DTYPE)tide_bf16_to_float(ey_store[store_idx]);
1586
+ if (ca_batched) {
1587
+ grad_ca[store_idx] += lambda_ey_curr * ey_n * (TIDE_DTYPE)step_ratio;
1588
+ } else {
1589
+ #ifdef _OPENMP
1590
+ #pragma omp atomic
1591
+ #endif
1592
+ grad_ca[idx] += lambda_ey_curr * ey_n * (TIDE_DTYPE)step_ratio;
1593
+ }
1594
+ }
1595
+
1596
+ if (cb_requires_grad && curl_h_store != NULL) {
1597
+ TIDE_DTYPE curl_h_n =
1598
+ (TIDE_DTYPE)tide_bf16_to_float(curl_h_store[store_idx]);
1599
+ if (ca_batched) {
1600
+ grad_cb[store_idx] += lambda_ey_curr * curl_h_n * (TIDE_DTYPE)step_ratio;
1601
+ } else {
1602
+ #ifdef _OPENMP
1603
+ #pragma omp atomic
1604
+ #endif
1605
+ grad_cb[idx] += lambda_ey_curr * curl_h_n * (TIDE_DTYPE)step_ratio;
1606
+ }
1607
+ }
1608
+ }
1609
+ }
1610
+ }
1611
+ }
1612
+ }
1613
+ }
1614
+ }
1615
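+ /*
+ * [Editor's note] Why bf16 snapshots are usable here: bfloat16 keeps
+ * float32's 8-bit exponent and truncates the mantissa to 7 bits, so the
+ * stored fields retain full dynamic range at roughly 2-3 significant
+ * decimal digits. A standalone round-trip check (a sketch reusing the
+ * conversion helpers defined near the top of this file):
+ */
+ static int sketch_bf16_roundtrip_ok(float const value) {
+   float const back = tide_bf16_to_float(tide_float_to_bf16(value));
+   /* Round-to-nearest-even bf16 has relative error of at most 2^-8. */
+   float const tol = fabsf(value) / 256.0f + 1e-30f;
+   return fabsf(back - value) <= tol;
+ }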
+
1616
+ static void backward_kernel_lambda_e_with_grad_fp8(
1617
+ TIDE_DTYPE const *__restrict const ca,
1618
+ TIDE_DTYPE const *__restrict const cq,
1619
+ TIDE_DTYPE const *__restrict const lambda_hx,
1620
+ TIDE_DTYPE const *__restrict const lambda_hz,
1621
+ TIDE_DTYPE *__restrict const lambda_ey,
1622
+ TIDE_DTYPE *__restrict const m_lambda_hx_z,
1623
+ TIDE_DTYPE *__restrict const m_lambda_hz_x,
1624
+ tide_fp8_e4m3 const *__restrict const ey_store,
1625
+ tide_fp8_e4m3 const *__restrict const curl_h_store,
1626
+ TIDE_DTYPE *__restrict const grad_ca,
1627
+ TIDE_DTYPE *__restrict const grad_cb,
1628
+ TIDE_DTYPE const *__restrict const ay,
1629
+ TIDE_DTYPE const *__restrict const ayh,
1630
+ TIDE_DTYPE const *__restrict const ax,
1631
+ TIDE_DTYPE const *__restrict const axh,
1632
+ TIDE_DTYPE const *__restrict const by,
1633
+ TIDE_DTYPE const *__restrict const byh,
1634
+ TIDE_DTYPE const *__restrict const bx,
1635
+ TIDE_DTYPE const *__restrict const bxh,
1636
+ TIDE_DTYPE const *__restrict const ky,
1637
+ TIDE_DTYPE const *__restrict const kyh,
1638
+ TIDE_DTYPE const *__restrict const kx,
1639
+ TIDE_DTYPE const *__restrict const kxh,
1640
+ TIDE_DTYPE const rdy,
1641
+ TIDE_DTYPE const rdx,
1642
+ int64_t const n_shots,
1643
+ int64_t const ny,
1644
+ int64_t const nx,
1645
+ int64_t const shot_numel,
1646
+ int64_t const pml_y0,
1647
+ int64_t const pml_y1,
1648
+ int64_t const pml_x0,
1649
+ int64_t const pml_x1,
1650
+ bool const ca_batched,
1651
+ bool const cq_batched,
1652
+ bool const ca_requires_grad,
1653
+ bool const cb_requires_grad,
1654
+ int64_t const step_ratio) {
1655
+
1656
+ int64_t const pml_bounds_y[] = {FD_PAD, pml_y0, pml_y1, ny - FD_PAD + 1};
1657
+ int64_t const pml_bounds_x[] = {FD_PAD, pml_x0, pml_x1, nx - FD_PAD + 1};
1658
+
1659
+ TIDE_OMP_INDEX shot_idx;
1660
+ TIDE_OMP_PARALLEL_FOR
1661
+ for (shot_idx = 0; shot_idx < n_shots; ++shot_idx) {
1662
+ int64_t const shot_offset = shot_idx * shot_numel;
1663
+ TIDE_DTYPE const *__restrict const ca_ptr =
1664
+ ca_batched ? (ca + shot_offset) : ca;
1665
+ TIDE_DTYPE const *__restrict const cq_ptr =
1666
+ cq_batched ? (cq + shot_offset) : cq;
1667
+ for (int pml_y = 0; pml_y < 3; ++pml_y) {
1668
+ for (int pml_x = 0; pml_x < 3; ++pml_x) {
1669
+ TIDE_OMP_SIMD_COLLAPSE2
1670
+ for (int64_t y = pml_bounds_y[pml_y]; y < pml_bounds_y[pml_y + 1]; ++y) {
1671
+ for (int64_t x = pml_bounds_x[pml_x]; x < pml_bounds_x[pml_x + 1]; ++x) {
1672
+ int64_t const idx = IDX(y, x);
1673
+ int64_t const store_idx = shot_offset + idx;
1674
+ TIDE_DTYPE const ca_val = ca_ptr[idx];
1675
+ TIDE_DTYPE const cq_val = cq_ptr[idx];
1676
+
1677
+ TIDE_DTYPE d_lambda_hz_dx = DIFFX1(LAMBDA_HZ);
1678
+ TIDE_DTYPE d_lambda_hx_dz = DIFFY1(LAMBDA_HX);
1679
+
1680
+ if (pml_x != 1) {
1681
+ M_LAMBDA_HZ_X(0, 0) = bx[x] * M_LAMBDA_HZ_X(0, 0) + ax[x] * d_lambda_hz_dx;
1682
+ d_lambda_hz_dx = d_lambda_hz_dx / kx[x] + M_LAMBDA_HZ_X(0, 0);
1683
+ }
1684
+ if (pml_y != 1) {
1685
+ M_LAMBDA_HX_Z(0, 0) = by[y] * M_LAMBDA_HX_Z(0, 0) + ay[y] * d_lambda_hx_dz;
1686
+ d_lambda_hx_dz = d_lambda_hx_dz / ky[y] + M_LAMBDA_HX_Z(0, 0);
1687
+ }
1688
+
1689
+ TIDE_DTYPE curl_lambda_h = d_lambda_hz_dx - d_lambda_hx_dz;
1690
+ TIDE_DTYPE lambda_ey_curr = LAMBDA_EY(0, 0);
1691
+ LAMBDA_EY(0, 0) = ca_val * lambda_ey_curr + cq_val * curl_lambda_h;
1692
+
1693
+ if (pml_y == 1 && pml_x == 1) {
1694
+ if (ca_requires_grad && ey_store != NULL) {
1695
+ TIDE_DTYPE ey_n =
1696
+ (TIDE_DTYPE)tide_fp8_e4m3_to_float(ey_store[store_idx]);
1697
+ if (ca_batched) {
1698
+ grad_ca[store_idx] += lambda_ey_curr * ey_n * (TIDE_DTYPE)step_ratio;
1699
+ } else {
1700
+ #ifdef _OPENMP
1701
+ #pragma omp atomic
1702
+ #endif
1703
+ grad_ca[idx] += lambda_ey_curr * ey_n * (TIDE_DTYPE)step_ratio;
1704
+ }
1705
+ }
1706
+
1707
+ if (cb_requires_grad && curl_h_store != NULL) {
1708
+ TIDE_DTYPE curl_h_n =
1709
+ (TIDE_DTYPE)tide_fp8_e4m3_to_float(curl_h_store[store_idx]);
1710
+ if (ca_batched) {
1711
+ grad_cb[store_idx] += lambda_ey_curr * curl_h_n * (TIDE_DTYPE)step_ratio;
1712
+ } else {
1713
+ #ifdef _OPENMP
1714
+ #pragma omp atomic
1715
+ #endif
1716
+ grad_cb[idx] += lambda_ey_curr * curl_h_n * (TIDE_DTYPE)step_ratio;
1717
+ }
1718
+ }
1719
+ }
1720
+ }
1721
+ }
1722
+ }
1723
+ }
1724
+ }
1725
+ }
1726
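+ /*
+ * [Editor's note] Illustrative FP8 decoder, consistent with the OCP E4M3
+ * layout sketched at the top of this file (bias-7 exponent, 3-bit
+ * mantissa, no infinities, a single NaN pattern). This is a reference
+ * sketch only; the file's actual tide_fp8_e4m3_to_float may handle
+ * rounding and edge cases differently.
+ */
+ static float sketch_fp8_e4m3_decode(uint8_t const v) {
+   int const sign = (v >> 7) & 1;
+   int const exp = (v >> 3) & 0xF; /* 4 exponent bits, bias 7 */
+   int const man = v & 0x7;        /* 3 mantissa bits */
+   float mag;
+   if (exp == 0) {
+     mag = ldexpf((float)man / 8.0f, -6); /* subnormals: 2^-6 * m/8 */
+   } else if (exp == 15 && man == 7) {
+     mag = NAN; /* E4M3 reserves only this NaN encoding */
+   } else {
+     mag = ldexpf(1.0f + (float)man / 8.0f, exp - 7); /* max = 448 */
+   }
+   return sign ? -mag : mag;
+ }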
+
1727
+ static void inverse_kernel_e_and_curl(
1728
+ TIDE_DTYPE const *__restrict const ca,
1729
+ TIDE_DTYPE const *__restrict const cb,
1730
+ TIDE_DTYPE const *__restrict const hx,
1731
+ TIDE_DTYPE const *__restrict const hz,
1732
+ TIDE_DTYPE *__restrict const ey,
1733
+ TIDE_DTYPE *__restrict const curl_h_out,
1734
+ TIDE_DTYPE const rdy,
1735
+ TIDE_DTYPE const rdx,
1736
+ int64_t const n_shots,
1737
+ int64_t const ny,
1738
+ int64_t const nx,
1739
+ int64_t const shot_numel,
1740
+ int64_t const pml_y0,
1741
+ int64_t const pml_y1,
1742
+ int64_t const pml_x0,
1743
+ int64_t const pml_x1,
1744
+ bool const ca_batched,
1745
+ bool const cb_batched) {
1746
+
1747
+ int64_t const y0 = MAX(FD_PAD, pml_y0);
1748
+ int64_t const y1 = MIN(ny - FD_PAD + 1, pml_y1);
1749
+ int64_t const x0 = MAX(FD_PAD, pml_x0);
1750
+ int64_t const x1 = MIN(nx - FD_PAD + 1, pml_x1);
1751
+
1752
+ TIDE_OMP_INDEX shot_idx;
1753
+ TIDE_OMP_PARALLEL_FOR
1754
+ for (shot_idx = 0; shot_idx < n_shots; ++shot_idx) {
1755
+ int64_t const shot_offset = shot_idx * shot_numel;
1756
+ TIDE_DTYPE const *__restrict const ca_ptr =
1757
+ ca_batched ? (ca + shot_offset) : ca;
1758
+ TIDE_DTYPE const *__restrict const cb_ptr =
1759
+ cb_batched ? (cb + shot_offset) : cb;
1760
+ TIDE_OMP_SIMD_COLLAPSE2
1761
+ for (int64_t y = y0; y < y1; ++y) {
1762
+ for (int64_t x = x0; x < x1; ++x) {
1763
+ int64_t const idx = IDX(y, x);
1764
+ int64_t const store_idx = shot_offset + idx;
1765
+ TIDE_DTYPE const ca_val = ca_ptr[idx];
1766
+ TIDE_DTYPE const cb_val = cb_ptr[idx];
1767
+
1768
+ TIDE_DTYPE const dhz_dx = DIFFX1(HZ);
1769
+ TIDE_DTYPE const dhx_dz = DIFFY1(HX);
1770
+ TIDE_DTYPE const curl_h = dhz_dx - dhx_dz;
1771
+
1772
+ curl_h_out[store_idx] = curl_h;
1773
+ ey[store_idx] = (ey[store_idx] - cb_val * curl_h) / ca_val;
1774
+ }
1775
+ }
1776
+ }
1777
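+ /*
+ * [Editor's note] inverse_kernel_e_and_curl recovers E_y^n from E_y^{n+1}
+ * by algebraically inverting the forward update
+ *     E_y^{n+1} = C_a * E_y^n + C_b * curl_H
+ * as E_y^n = (E_y^{n+1} - C_b * curl_H) / C_a, well defined wherever
+ * C_a != 0. A one-cell sanity check of that identity (sketch):
+ */
+ static int sketch_e_update_is_invertible(
+     float const ca, float const cb, float const ey_n, float const curl_h) {
+   float const ey_np1 = ca * ey_n + cb * curl_h;     /* forward step */
+   float const ey_rec = (ey_np1 - cb * curl_h) / ca; /* inverse step */
+   return fabsf(ey_rec - ey_n) <= 1e-5f * fabsf(ey_n) + 1e-12f;
+ }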
+ }
1778
+
1779
+ static void inverse_kernel_h(
1780
+ TIDE_DTYPE const *__restrict const cq,
1781
+ TIDE_DTYPE const *__restrict const ey,
1782
+ TIDE_DTYPE *__restrict const hx,
1783
+ TIDE_DTYPE *__restrict const hz,
1784
+ TIDE_DTYPE const rdy,
1785
+ TIDE_DTYPE const rdx,
1786
+ int64_t const n_shots,
1787
+ int64_t const ny,
1788
+ int64_t const nx,
1789
+ int64_t const shot_numel,
1790
+ int64_t const pml_y0,
1791
+ int64_t const pml_y1,
1792
+ int64_t const pml_x0,
1793
+ int64_t const pml_x1,
1794
+ bool const cq_batched) {
1795
+
1796
+ int64_t const y0 = MAX(FD_PAD, pml_y0);
1797
+ int64_t const y1 = MIN(ny - FD_PAD + 1, pml_y1);
1798
+ int64_t const x0 = MAX(FD_PAD, pml_x0);
1799
+ int64_t const x1 = MIN(nx - FD_PAD + 1, pml_x1);
1800
+
1801
+ TIDE_OMP_INDEX shot_idx;
1802
+ TIDE_OMP_PARALLEL_FOR
1803
+ for (shot_idx = 0; shot_idx < n_shots; ++shot_idx) {
1804
+ int64_t const shot_offset = shot_idx * shot_numel;
1805
+ TIDE_DTYPE const *__restrict const cq_ptr =
1806
+ cq_batched ? (cq + shot_offset) : cq;
1807
+ TIDE_OMP_SIMD_COLLAPSE2
1808
+ for (int64_t y = y0; y < y1; ++y) {
1809
+ for (int64_t x = x0; x < x1; ++x) {
1810
+ int64_t const idx = IDX(y, x);
1811
+ TIDE_DTYPE const cq_val = cq_ptr[idx];
1812
+
1813
+ if (y < ny - FD_PAD) {
1814
+ TIDE_DTYPE const dey_dz = DIFFYH1(EY);
1815
+ HX(0, 0) += cq_val * dey_dz;
1816
+ }
1817
+ if (x < nx - FD_PAD) {
1818
+ TIDE_DTYPE const dey_dx = DIFFXH1(EY);
1819
+ HZ(0, 0) -= cq_val * dey_dx;
1820
+ }
1821
+ }
1822
+ }
1823
+ }
1824
+ }
1825
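+ /*
+ * [Editor's note] inverse_kernel_h undoes the forward magnetic-field
+ * updates by applying their algebraic inverses with the same spatial
+ * derivatives (the forward pass subtracts cq * dEy/dz from Hx and adds
+ * cq * dEy/dx to Hz; this kernel does the opposite). Together with
+ * inverse_kernel_e_and_curl, this lets the backward pass re-create the
+ * forward snapshots by stepping the fields backward in time instead of
+ * storing every step. The reconstruction is only exact in the interior:
+ * both kernels clip their loop bounds to the pml_* limits via MAX/MIN,
+ * because the PML region is lossy and cannot be time-reversed stably.
+ */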
+
1826
+ /*
1827
+ * Full backward pass for Maxwell TM equations
1828
+ *
1829
+ * Implements the Adjoint State Method to compute:
1830
+ * - grad_ca: gradient w.r.t. C_a coefficient
1831
+ * - grad_cb: gradient w.r.t. C_b coefficient
1832
+ * - grad_eps: gradient w.r.t. epsilon_r
1833
+ * - grad_sigma: gradient w.r.t. conductivity
1834
+ * - grad_f: gradient w.r.t. source amplitudes
1835
+ */
1836
+ #ifdef __cplusplus
1837
+ extern "C"
1838
+ #endif
1839
+ #ifdef _WIN32
1840
+ __declspec(dllexport)
1841
+ #endif
1842
+ void FUNC(backward)(
1843
+ TIDE_DTYPE const *const ca,
1844
+ TIDE_DTYPE const *const cb,
1845
+ TIDE_DTYPE const *const cq,
1846
+ TIDE_DTYPE const *const grad_r,
1847
+ TIDE_DTYPE *const lambda_ey,
1848
+ TIDE_DTYPE *const lambda_hx,
1849
+ TIDE_DTYPE *const lambda_hz,
1850
+ TIDE_DTYPE *const m_lambda_ey_x,
1851
+ TIDE_DTYPE *const m_lambda_ey_z,
1852
+ TIDE_DTYPE *const m_lambda_hx_z,
1853
+ TIDE_DTYPE *const m_lambda_hz_x,
1854
+ TIDE_DTYPE *const ey_store_1,
1855
+ void *const ey_store_3,
1856
+ char const *const *const ey_filenames,
1857
+ TIDE_DTYPE *const curl_store_1,
1858
+ void *const curl_store_3,
1859
+ char const *const *const curl_filenames,
1860
+ TIDE_DTYPE *const grad_f,
1861
+ TIDE_DTYPE *const grad_ca,
1862
+ TIDE_DTYPE *const grad_cb,
1863
+ TIDE_DTYPE *const grad_eps,
1864
+ TIDE_DTYPE *const grad_sigma,
1865
+ TIDE_DTYPE *const grad_ca_shot, /* unused in CPU - for API compatibility */
1866
+ TIDE_DTYPE *const grad_cb_shot, /* unused in CPU - for API compatibility */
1867
+ TIDE_DTYPE const *const ay,
1868
+ TIDE_DTYPE const *const by,
1869
+ TIDE_DTYPE const *const ayh,
1870
+ TIDE_DTYPE const *const byh,
1871
+ TIDE_DTYPE const *const ax,
1872
+ TIDE_DTYPE const *const bx,
1873
+ TIDE_DTYPE const *const axh,
1874
+ TIDE_DTYPE const *const bxh,
1875
+ TIDE_DTYPE const *const ky,
1876
+ TIDE_DTYPE const *const kyh,
1877
+ TIDE_DTYPE const *const kx,
1878
+ TIDE_DTYPE const *const kxh,
1879
+ int64_t const *const sources_i,
1880
+ int64_t const *const receivers_i,
1881
+ TIDE_DTYPE const rdy,
1882
+ TIDE_DTYPE const rdx,
1883
+ TIDE_DTYPE const dt,
1884
+ int64_t const nt,
1885
+ int64_t const n_shots,
1886
+ int64_t const ny,
1887
+ int64_t const nx,
1888
+ int64_t const n_sources_per_shot,
1889
+ int64_t const n_receivers_per_shot,
1890
+ int64_t const step_ratio,
1891
+ int64_t const storage_mode,
1892
+ int64_t const shot_bytes_uncomp,
1893
+ bool const ca_requires_grad,
1894
+ bool const cb_requires_grad,
1895
+ bool const ca_batched,
1896
+ bool const cb_batched,
1897
+ bool const cq_batched,
1898
+ int64_t const start_t,
1899
+ int64_t const pml_y0,
1900
+ int64_t const pml_x0,
1901
+ int64_t const pml_y1,
1902
+ int64_t const pml_x1,
1903
+ int64_t const n_threads,
1904
+ int64_t const device /* unused for CPU */) {
1905
+
1906
+ (void)device;
1907
+ (void)grad_ca_shot; // Not needed in CPU version
1908
+ (void)grad_cb_shot; // Not needed in CPU version
1909
+ (void)ey_store_3;
1910
+ (void)curl_store_3;
1911
+ #ifdef _OPENMP
1912
+ int const prev_threads = omp_get_max_threads();
1913
+ if (n_threads > 0) {
1914
+ omp_set_num_threads((int)n_threads);
1915
+ }
1916
+ #else
1917
+ (void)n_threads;
1918
+ #endif
1919
+
1920
+ int64_t const shot_numel = ny * nx;
1921
+ int64_t const store_size = n_shots * shot_numel;
1922
+ bool const storage_fp8 = (shot_bytes_uncomp == shot_numel * 1);
1923
+ bool const storage_bf16 = (shot_bytes_uncomp == shot_numel * 2);
1924
+
1925
+ FILE **fp_ey = NULL;
1926
+ FILE **fp_curl = NULL;
1927
+ if (storage_mode == STORAGE_DISK) {
1928
+ if (ca_requires_grad) {
1929
+ fp_ey = (FILE **)malloc((size_t)n_shots * sizeof(FILE *));
1930
+ for (int64_t shot = 0; shot < n_shots; ++shot) {
1931
+ fp_ey[shot] = fopen(ey_filenames[shot], "rb");
1932
+ }
1933
+ }
1934
+ if (cb_requires_grad) {
1935
+ fp_curl = (FILE **)malloc((size_t)n_shots * sizeof(FILE *));
1936
+ for (int64_t shot = 0; shot < n_shots; ++shot) {
1937
+ fp_curl[shot] = fopen(curl_filenames[shot], "rb");
1938
+ }
1939
+ }
1940
+ }
1941
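+ // [Editor's note] The fopen return values are not checked here; the
+ // caller is presumably expected to have written these per-shot snapshot
+ // files during the forward pass before invoking the backward pass.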
+
1942
+ // Time-reversed loop: t runs from start_t - 1 down to start_t - nt
1943
+ //
1944
+ // Forward order was: H_update -> E_update(store) -> source_inject -> record
1945
+ // Backward order is: record(adjoint) -> source_inject(adjoint) -> E_update(adjoint) -> H_update(adjoint)
1946
+ // Which translates to: grad_r_inject -> grad_f_record -> λ_E_update(grad_accum) -> λ_H_update
1947
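+ // [Editor's note] Worked example of the snapshot bookkeeping below,
+ // assuming step_ratio = 4 and start_t = nt: t counts down nt-1, ..., 0;
+ // only steps with t % 4 == 0 have a stored snapshot (store_idx = t / 4),
+ // and those steps accumulate gradients scaled by step_ratio so that the
+ // sum stands in for the skipped intermediate steps.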
+
1948
+ for (int64_t t = start_t - 1; t >= start_t - nt; --t) {
1949
+ // Determine storage index for this time step
1950
+ int64_t const store_idx = t / step_ratio;
1951
+ bool const do_grad = (t % step_ratio) == 0;
1952
+ bool const grad_ey = do_grad && ca_requires_grad;
1953
+ bool const grad_curl = do_grad && cb_requires_grad;
1954
+
1955
+ int64_t const store_offset =
1956
+ (storage_mode == STORAGE_DEVICE ? store_idx * store_size : 0);
1957
+
1958
+ if (storage_fp8) {
1959
+ tide_fp8_e4m3 *const ey_store_1_t =
1960
+ (tide_fp8_e4m3 *)ey_store_1 + store_offset;
1961
+ tide_fp8_e4m3 *const curl_store_1_t =
1962
+ (tide_fp8_e4m3 *)curl_store_1 + store_offset;
1963
+
1964
+ if (storage_mode == STORAGE_DISK) {
1965
+ if (grad_ey) {
1966
+ for (int64_t shot = 0; shot < n_shots; ++shot) {
1967
+ storage_load_snapshot_cpu(
1968
+ (void *)(ey_store_1_t + shot * shot_numel), fp_ey[shot],
1969
+ storage_mode, store_idx, (size_t)shot_bytes_uncomp);
1970
+ }
1971
+ }
1972
+ if (grad_curl) {
1973
+ for (int64_t shot = 0; shot < n_shots; ++shot) {
1974
+ storage_load_snapshot_cpu(
1975
+ (void *)(curl_store_1_t + shot * shot_numel), fp_curl[shot],
1976
+ storage_mode, store_idx, (size_t)shot_bytes_uncomp);
1977
+ }
1978
+ }
1979
+ }
1980
+ } else if (storage_bf16) {
1981
+ tide_bfloat16 *const ey_store_1_t =
1982
+ (tide_bfloat16 *)ey_store_1 + store_offset;
1983
+ tide_bfloat16 *const curl_store_1_t =
1984
+ (tide_bfloat16 *)curl_store_1 + store_offset;
1985
+
1986
+ if (storage_mode == STORAGE_DISK) {
1987
+ if (grad_ey) {
1988
+ for (int64_t shot = 0; shot < n_shots; ++shot) {
1989
+ storage_load_snapshot_cpu(
1990
+ (void *)(ey_store_1_t + shot * shot_numel), fp_ey[shot],
1991
+ storage_mode, store_idx, (size_t)shot_bytes_uncomp);
1992
+ }
1993
+ }
1994
+ if (grad_curl) {
1995
+ for (int64_t shot = 0; shot < n_shots; ++shot) {
1996
+ storage_load_snapshot_cpu(
1997
+ (void *)(curl_store_1_t + shot * shot_numel), fp_curl[shot],
1998
+ storage_mode, store_idx, (size_t)shot_bytes_uncomp);
1999
+ }
2000
+ }
2001
+ }
2002
+ } else {
2003
+ TIDE_DTYPE *const ey_store_1_t =
2004
+ ey_store_1 + store_offset;
2005
+ TIDE_DTYPE *const curl_store_1_t =
2006
+ curl_store_1 + store_offset;
2007
+
2008
+ if (storage_mode == STORAGE_DISK) {
2009
+ if (grad_ey) {
2010
+ for (int64_t shot = 0; shot < n_shots; ++shot) {
2011
+ storage_load_snapshot_cpu(
2012
+ (void *)(ey_store_1_t + shot * shot_numel), fp_ey[shot],
2013
+ storage_mode, store_idx, (size_t)shot_bytes_uncomp);
2014
+ }
2015
+ }
2016
+ if (grad_curl) {
2017
+ for (int64_t shot = 0; shot < n_shots; ++shot) {
2018
+ storage_load_snapshot_cpu(
2019
+ (void *)(curl_store_1_t + shot * shot_numel), fp_curl[shot],
2020
+ storage_mode, store_idx, (size_t)shot_bytes_uncomp);
2021
+ }
2022
+ }
2023
+ }
2024
+ }
2025
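+ // [Editor's note] The snapshot element width is inferred from
+ // shot_bytes_uncomp: shot_numel bytes per shot selects FP8 E4M3,
+ // 2 * shot_numel selects bfloat16, and anything else falls through to
+ // full-precision TIDE_DTYPE. The three load branches above are identical
+ // apart from the pointer type used to address the per-shot slabs.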
+
2026
+ // Inject adjoint residuals into λ_Ey^{t+1} (adjoint of receiver recording)
2027
+ if (n_receivers_per_shot > 0) {
2028
+ add_sources_ey(
2029
+ lambda_ey, grad_r + t * n_shots * n_receivers_per_shot, receivers_i,
2030
+ n_shots, shot_numel, n_receivers_per_shot);
2031
+ }
2032
+
2033
+ // Record adjoint source gradient using λ_Ey^{t+1} (adjoint of source injection)
2034
+ if (n_sources_per_shot > 0) {
2035
+ record_receivers_ey(
2036
+ grad_f + t * n_shots * n_sources_per_shot,
2037
+ lambda_ey, sources_i,
2038
+ n_shots, shot_numel, n_sources_per_shot);
2039
+ }
2040
+
2041
+ // Backward λ_Ey update with gradient accumulation
2042
+ // This computes: λ_Ey^n = C_a * λ_Ey^{n+1} + C_q * curl_λH
2043
+ // And accumulates: grad_ca += λ_Ey^{n+1} * E_y^n, grad_cb += λ_Ey^{n+1} * curl_H^n
2044
+ if (storage_fp8 && (grad_ey || grad_curl)) {
2045
+ tide_fp8_e4m3 *const ey_store_1_t =
2046
+ (tide_fp8_e4m3 *)ey_store_1 + store_offset;
2047
+ tide_fp8_e4m3 *const curl_store_1_t =
2048
+ (tide_fp8_e4m3 *)curl_store_1 + store_offset;
2049
+ backward_kernel_lambda_e_with_grad_fp8(
2050
+ ca, cq, lambda_hx, lambda_hz, lambda_ey,
2051
+ m_lambda_hx_z, m_lambda_hz_x,
2052
+ grad_ey ? ey_store_1_t : NULL,
2053
+ grad_curl ? curl_store_1_t : NULL,
2054
+ grad_ca, grad_cb,
2055
+ ay, ayh, ax, axh, by, byh, bx, bxh,
2056
+ ky, kyh, kx, kxh,
2057
+ rdy, rdx,
2058
+ n_shots, ny, nx, shot_numel,
2059
+ pml_y0, pml_y1, pml_x0, pml_x1,
2060
+ ca_batched, cq_batched,
2061
+ grad_ey, grad_curl,
2062
+ step_ratio);
2063
+ } else if (storage_bf16 && (grad_ey || grad_curl)) {
2064
+ tide_bfloat16 *const ey_store_1_t =
2065
+ (tide_bfloat16 *)ey_store_1 + store_offset;
2066
+ tide_bfloat16 *const curl_store_1_t =
2067
+ (tide_bfloat16 *)curl_store_1 + store_offset;
2068
+ backward_kernel_lambda_e_with_grad_bf16(
2069
+ ca, cq, lambda_hx, lambda_hz, lambda_ey,
2070
+ m_lambda_hx_z, m_lambda_hz_x,
2071
+ grad_ey ? ey_store_1_t : NULL,
2072
+ grad_curl ? curl_store_1_t : NULL,
2073
+ grad_ca, grad_cb,
2074
+ ay, ayh, ax, axh, by, byh, bx, bxh,
2075
+ ky, kyh, kx, kxh,
2076
+ rdy, rdx,
2077
+ n_shots, ny, nx, shot_numel,
2078
+ pml_y0, pml_y1, pml_x0, pml_x1,
2079
+ ca_batched, cq_batched,
2080
+ grad_ey, grad_curl,
2081
+ step_ratio);
2082
+ } else {
2083
+ TIDE_DTYPE *const ey_store_1_t =
2084
+ (storage_fp8 || storage_bf16) ? NULL : (ey_store_1 + store_offset);
2085
+ TIDE_DTYPE *const curl_store_1_t =
2086
+ (storage_fp8 || storage_bf16) ? NULL : (curl_store_1 + store_offset);
2087
+ backward_kernel_lambda_e_with_grad(
2088
+ ca, cq, lambda_hx, lambda_hz, lambda_ey,
2089
+ m_lambda_hx_z, m_lambda_hz_x,
2090
+ grad_ey ? ey_store_1_t : NULL,
2091
+ grad_curl ? curl_store_1_t : NULL,
2092
+ grad_ca, grad_cb,
2093
+ ay, ayh, ax, axh, by, byh, bx, bxh,
2094
+ ky, kyh, kx, kxh,
2095
+ rdy, rdx,
2096
+ n_shots, ny, nx, shot_numel,
2097
+ pml_y0, pml_y1, pml_x0, pml_x1,
2098
+ ca_batched, cq_batched,
2099
+ grad_ey, grad_curl,
2100
+ step_ratio);
2101
+ }
2102
+
2103
+ // Backward λ_H fields update
2104
+ backward_kernel_lambda_h(
2105
+ cb, lambda_ey, lambda_hx, lambda_hz,
2106
+ m_lambda_ey_x, m_lambda_ey_z,
2107
+ ay, ayh, ax, axh, by, byh, bx, bxh,
2108
+ ky, kyh, kx, kxh,
2109
+ rdy, rdx,
2110
+ n_shots, ny, nx, shot_numel,
2111
+ pml_y0, pml_y1, pml_x0, pml_x1,
2112
+ cb_batched);
2113
+ }
2114
+
2115
+ if (fp_ey != NULL) {
2116
+ for (int64_t shot = 0; shot < n_shots; ++shot) fclose(fp_ey[shot]);
2117
+ free(fp_ey);
2118
+ }
2119
+ if (fp_curl != NULL) {
2120
+ for (int64_t shot = 0; shot < n_shots; ++shot) fclose(fp_curl[shot]);
2121
+ free(fp_curl);
2122
+ }
2123
+
2124
+ convert_grad_ca_cb_to_eps_sigma(
2125
+ ca, cb, grad_ca, grad_cb, grad_eps, grad_sigma,
2126
+ dt, n_shots, ny, nx, ca_batched, cb_batched,
2127
+ ca_requires_grad, cb_requires_grad);
2128
+ #ifdef _OPENMP
2129
+ if (n_threads > 0) {
2130
+ omp_set_num_threads(prev_threads);
2131
+ }
2132
+ #endif
2133
+ }
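+ /*
+ * [Editor's note] A hedged sketch of what the final
+ * convert_grad_ca_cb_to_eps_sigma call has to implement. Assuming the
+ * standard lossy-medium FDTD update coefficients (the exact definitions
+ * used by tide may differ),
+ *     C_a = (1 - sigma*dt/(2*eps)) / (1 + sigma*dt/(2*eps)),
+ *     C_b = (dt/eps) / (1 + sigma*dt/(2*eps)),
+ * the chain rule gives
+ *     dL/d(eps)   = (dC_a/d(eps))   * dL/dC_a + (dC_b/d(eps))   * dL/dC_b,
+ *     dL/d(sigma) = (dC_a/d(sigma)) * dL/dC_a + (dC_b/d(sigma)) * dL/dC_b,
+ * with a further factor of the vacuum permittivity when the result is
+ * reported against the relative permittivity rather than the absolute one.
+ */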