tide-GPR 0.0.9__py3-none-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
tide/maxwell.py ADDED
@@ -0,0 +1,2651 @@
+ from typing import Any, Callable, Optional, Sequence, Union
+
+ import torch
+
+ from . import staggered
+ from .autograd_utils import (
+     _get_ctx_handle,
+     _register_ctx_handle,
+     _release_ctx_handle,
+ )
+ from .callbacks import Callback, CallbackState
+ from .cfl import cfl_condition
+ from .grid_utils import (
+     _normalize_grid_spacing_2d,
+     _normalize_pml_width_2d,
+ )
+ from .resampling import downsample_and_movedim, upsample
+ from .storage import (
+     _CPU_STORAGE_BUFFERS,
+     STORAGE_CPU,
+     STORAGE_DEVICE,
+     STORAGE_DISK,
+     STORAGE_NONE,
+     TemporaryStorage,
+     _normalize_storage_compression,
+     _resolve_storage_compression,
+     storage_mode_to_int,
+ )
+ from .utils import C0, prepare_parameters
+ from .validation import (
+     validate_freq_taper_frac,
+     validate_model_gradient_sampling_interval,
+     validate_time_pad_frac,
+ )
+
+
+ class MaxwellTM(torch.nn.Module):
+     """2D TM mode Maxwell equations solver using the FDTD method.
+
+     This module solves the TM (Transverse Magnetic) mode Maxwell equations
+     in 2D with fields (Ey, Hx, Hz) using the FDTD method with CPML absorbing
+     boundary conditions.
+
+     Args:
+         epsilon: Relative permittivity tensor [ny, nx].
+             For vacuum/air, use 1.0. For common materials:
+             - Water: ~80
+             - Glass: ~4-7
+             - Soil (dry): ~3-5
+             - Concrete: ~4-8
+         sigma: Electrical conductivity tensor [ny, nx] in S/m.
+             For lossless media, use 0.0.
+         mu: Relative permeability tensor [ny, nx].
+             For most non-magnetic materials, use 1.0.
+         grid_spacing: Grid spacing in meters. Can be a single value (same for
+             both directions) or a sequence [dy, dx].
+         epsilon_requires_grad: Whether to compute gradients for permittivity.
+         sigma_requires_grad: Whether to compute gradients for conductivity.
+
+     Note:
+         The input parameters are RELATIVE values (dimensionless). They will be
+         multiplied internally by the vacuum permittivity (ε₀ = 8.854e-12 F/m)
+         and vacuum permeability (μ₀ = 1.257e-6 H/m) respectively.
+     """
+
+     def __init__(
+         self,
+         epsilon: torch.Tensor,
+         sigma: torch.Tensor,
+         mu: torch.Tensor,
+         grid_spacing: Union[float, Sequence[float]],
+         epsilon_requires_grad: Optional[bool] = None,
+         sigma_requires_grad: Optional[bool] = None,
+     ) -> None:
+         super().__init__()
+         if epsilon_requires_grad is not None and not isinstance(
+             epsilon_requires_grad, bool
+         ):
+             raise TypeError(
+                 f"epsilon_requires_grad must be bool or None, "
+                 f"got {type(epsilon_requires_grad).__name__}",
+             )
+         if not isinstance(epsilon, torch.Tensor):
+             raise TypeError(
+                 f"epsilon must be torch.Tensor, got {type(epsilon).__name__}",
+             )
+         if sigma_requires_grad is not None and not isinstance(
+             sigma_requires_grad, bool
+         ):
+             raise TypeError(
+                 f"sigma_requires_grad must be bool or None, "
+                 f"got {type(sigma_requires_grad).__name__}",
+             )
+         if not isinstance(sigma, torch.Tensor):
+             raise TypeError(
+                 f"sigma must be torch.Tensor, got {type(sigma).__name__}",
+             )
+         if not isinstance(mu, torch.Tensor):
+             raise TypeError(
+                 f"mu must be torch.Tensor, got {type(mu).__name__}",
+             )
+
+         # If requires_grad is not specified, preserve the input tensor's setting
+         if epsilon_requires_grad is None:
+             epsilon_requires_grad = epsilon.requires_grad
+         if sigma_requires_grad is None:
+             sigma_requires_grad = sigma.requires_grad
+
+         self.epsilon = torch.nn.Parameter(epsilon, requires_grad=epsilon_requires_grad)
+         self.sigma = torch.nn.Parameter(sigma, requires_grad=sigma_requires_grad)
+         self.register_buffer("mu", mu)  # We normally do not optimize mu
+         self.grid_spacing = grid_spacing
+
+     def forward(
+         self,
+         dt: float,
+         source_amplitude: Optional[torch.Tensor],  # [shot, source, time]
+         source_location: Optional[torch.Tensor],  # [shot, source, 2]
+         receiver_location: Optional[torch.Tensor],  # [shot, receiver, 2]
+         stencil: int = 2,
+         pml_width: Union[int, Sequence[int]] = 20,
+         max_vel: Optional[float] = None,
+         Ey_0: Optional[torch.Tensor] = None,
+         Hx_0: Optional[torch.Tensor] = None,
+         Hz_0: Optional[torch.Tensor] = None,
+         m_Ey_x: Optional[torch.Tensor] = None,
+         m_Ey_z: Optional[torch.Tensor] = None,
+         m_Hx_z: Optional[torch.Tensor] = None,
+         m_Hz_x: Optional[torch.Tensor] = None,
+         nt: Optional[int] = None,
+         model_gradient_sampling_interval: int = 1,
+         freq_taper_frac: float = 0.0,
+         time_pad_frac: float = 0.0,
+         time_taper: bool = False,
+         save_snapshots: Optional[bool] = None,
+         forward_callback: Optional[Callback] = None,
+         backward_callback: Optional[Callback] = None,
+         callback_frequency: int = 1,
+         python_backend: Union[bool, str] = False,
+         storage_mode: str = "device",
+         storage_path: str = ".",
+         storage_compression: Union[bool, str] = False,
+         storage_bytes_limit_device: Optional[int] = None,
+         storage_bytes_limit_host: Optional[int] = None,
+         storage_chunk_steps: int = 0,
+     ):
+         # Type assertions for buffer and parameter tensors
+         assert isinstance(self.epsilon, torch.Tensor)
+         assert isinstance(self.sigma, torch.Tensor)
+         assert isinstance(self.mu, torch.Tensor)
+         return maxwelltm(
+             self.epsilon,
+             self.sigma,
+             self.mu,
+             self.grid_spacing,
+             dt,
+             source_amplitude,
+             source_location,
+             receiver_location,
+             stencil,
+             pml_width,
+             max_vel,
+             Ey_0,
+             Hx_0,
+             Hz_0,
+             m_Ey_x,
+             m_Ey_z,
+             m_Hx_z,
+             m_Hz_x,
+             nt,
+             model_gradient_sampling_interval,
+             freq_taper_frac,
+             time_pad_frac,
+             time_taper,
+             save_snapshots,
+             forward_callback,
+             backward_callback,
+             callback_frequency,
+             python_backend,
+             storage_mode,
+             storage_path,
+             storage_compression,
+             storage_bytes_limit_device,
+             storage_bytes_limit_host,
+             storage_chunk_steps,
+         )
+
+
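+ # A minimal usage sketch of the MaxwellTM module above (illustrative only;
+ # the shapes, values, and the src_amp/src_loc/rec_loc tensors are assumptions,
+ # not part of this package's documented examples):
+ #
+ #     ny, nx = 100, 100
+ #     epsilon = torch.ones(ny, nx) * 4.0   # e.g. a glass-like medium
+ #     sigma = torch.zeros(ny, nx)          # lossless
+ #     mu = torch.ones(ny, nx)              # non-magnetic
+ #     model = MaxwellTM(epsilon, sigma, mu, grid_spacing=0.01,
+ #                       epsilon_requires_grad=True)
+ #     out = model(dt=1e-11, source_amplitude=src_amp,
+ #                 source_location=src_loc, receiver_location=rec_loc)
+ #     receiver_amplitudes = out[-1]        # [nt, n_shots, n_receivers]
+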
+ def maxwelltm(
+     epsilon: torch.Tensor,
+     sigma: torch.Tensor,
+     mu: torch.Tensor,
+     grid_spacing: Union[float, Sequence[float]],
+     dt: float,
+     source_amplitude: Optional[torch.Tensor],
+     source_location: Optional[torch.Tensor],
+     receiver_location: Optional[torch.Tensor],
+     stencil: int = 2,
+     pml_width: Union[int, Sequence[int]] = 20,
+     max_vel: Optional[float] = None,
+     Ey_0: Optional[torch.Tensor] = None,
+     Hx_0: Optional[torch.Tensor] = None,
+     Hz_0: Optional[torch.Tensor] = None,
+     m_Ey_x: Optional[torch.Tensor] = None,
+     m_Ey_z: Optional[torch.Tensor] = None,
+     m_Hx_z: Optional[torch.Tensor] = None,
+     m_Hz_x: Optional[torch.Tensor] = None,
+     nt: Optional[int] = None,
+     model_gradient_sampling_interval: int = 1,
+     freq_taper_frac: float = 0.0,
+     time_pad_frac: float = 0.0,
+     time_taper: bool = False,
+     save_snapshots: Optional[bool] = None,
+     forward_callback: Optional[Callback] = None,
+     backward_callback: Optional[Callback] = None,
+     callback_frequency: int = 1,
+     python_backend: Union[bool, str] = False,
+     storage_mode: str = "device",
+     storage_path: str = ".",
+     storage_compression: Union[bool, str] = False,
+     storage_bytes_limit_device: Optional[int] = None,
+     storage_bytes_limit_host: Optional[int] = None,
+     storage_chunk_steps: int = 0,
+     n_threads: Optional[int] = None,
+ ):
+     """2D TM mode Maxwell equations solver.
+
+     This is the main entry point for Maxwell TM propagation. It automatically
+     handles CFL condition checking and time step resampling when needed.
+
+     If the user-provided time step (dt) is too large for numerical stability,
+     the source signal will be upsampled internally and receiver data will be
+     downsampled back to the original sampling rate.
+
+     Args:
+         epsilon: Relative permittivity tensor [ny, nx].
+         sigma: Electrical conductivity tensor [ny, nx] in S/m.
+         mu: Relative permeability tensor [ny, nx].
+         grid_spacing: Grid spacing in meters. Single value or [dy, dx].
+         dt: Time step in seconds.
+         source_amplitude: Source waveform [n_shots, n_sources, nt].
+         source_location: Source locations [n_shots, n_sources, 2].
+         receiver_location: Receiver locations [n_shots, n_receivers, 2].
+         stencil: FD stencil order (2, 4, 6, or 8).
+         pml_width: PML width (single int or [top, bottom, left, right]).
+         max_vel: Maximum wave velocity. If None, computed from the model.
+         Ey_0, Hx_0, Hz_0: Initial field values.
+         m_Ey_x, m_Ey_z, m_Hx_z, m_Hz_x: Initial CPML memory variables.
+         nt: Number of time steps (required if source_amplitude is None).
+         model_gradient_sampling_interval: Interval for storing gradient snapshots.
+             Values > 1 reduce memory usage during backpropagation.
+         freq_taper_frac: Fraction of the frequency spectrum to taper (0.0-1.0).
+             Helps reduce ringing artifacts during resampling.
+         time_pad_frac: Fraction for zero padding before the FFT (0.0-1.0).
+             Helps reduce wraparound artifacts during resampling.
+         time_taper: Whether to apply a Hann window (mainly for testing).
+         save_snapshots: Whether to save wavefield snapshots for gradient
+             computation. If None (default), snapshots are saved only when model
+             parameters require gradients. Set to False to disable snapshot
+             saving even when gradients are needed. Set to True to force
+             snapshot saving even without gradients.
+         forward_callback: Callback function called during forward propagation.
+         backward_callback: Callback function called during backward (adjoint)
+             propagation. Receives the same CallbackState as forward_callback,
+             but with is_backward=True and gradients available.
+         callback_frequency: How often to call the callback.
+         python_backend: False for C/CUDA, True or 'eager'/'jit'/'compile' for Python.
+         storage_mode: Where to store intermediate snapshots for the ASM
+             backward pass. One of "device", "cpu", "disk", "none", or "auto".
+         storage_path: Base path for disk storage when storage_mode="disk".
+         storage_compression: Compression for stored snapshots. Use False/True
+             (True means BF16), or one of "bf16" / "fp8".
+         storage_bytes_limit_device: Soft limit in bytes for device snapshot
+             storage when storage_mode="auto".
+         storage_bytes_limit_host: Soft limit in bytes for host snapshot
+             storage when storage_mode="auto".
+         storage_chunk_steps: Optional chunk size (in stored steps) for
+             CPU/disk modes. Currently unused.
+         n_threads: OpenMP thread count for the CPU backend. None uses the
+             OpenMP default.
+
+     Returns:
+         Tuple of (Ey, Hx, Hz, m_Ey_x, m_Ey_z, m_Hx_z, m_Hz_x, receiver_amplitudes).
+     """
+     # Validate resampling parameters
+     model_gradient_sampling_interval = validate_model_gradient_sampling_interval(
+         model_gradient_sampling_interval
+     )
+     freq_taper_frac = validate_freq_taper_frac(freq_taper_frac)
+     time_pad_frac = validate_time_pad_frac(time_pad_frac)
+
+     # Check inputs
+     if source_location is not None and source_location.numel() > 0:
+         if source_location[..., 0].max() >= epsilon.shape[-2]:
+             raise RuntimeError(
+                 f"Source location dim 0 must be less than {epsilon.shape[-2]}"
+             )
+         if source_location[..., 1].max() >= epsilon.shape[-1]:
+             raise RuntimeError(
+                 f"Source location dim 1 must be less than {epsilon.shape[-1]}"
+             )
+
+     if receiver_location is not None and receiver_location.numel() > 0:
+         if receiver_location[..., 0].max() >= epsilon.shape[-2]:
+             raise RuntimeError(
+                 f"Receiver location dim 0 must be less than {epsilon.shape[-2]}"
+             )
+         if receiver_location[..., 1].max() >= epsilon.shape[-1]:
+             raise RuntimeError(
+                 f"Receiver location dim 1 must be less than {epsilon.shape[-1]}"
+             )
+
+     if not isinstance(callback_frequency, int):
+         raise TypeError("callback_frequency must be an int.")
+     if callback_frequency <= 0:
+         raise ValueError("callback_frequency must be positive.")
+
+     # Normalize grid_spacing to a list
+     grid_spacing_list = _normalize_grid_spacing_2d(grid_spacing)
+
+     # Compute the maximum velocity if not provided
+     if max_vel is None:
+         # For EM waves: v = c0 / sqrt(epsilon_r * mu_r)
+         max_vel_computed = float((1.0 / torch.sqrt(epsilon * mu)).max().item()) * C0
+     else:
+         max_vel_computed = max_vel
+
+     # Check the CFL condition and compute step_ratio
+     inner_dt, step_ratio = cfl_condition(grid_spacing_list, dt, max_vel_computed)
+
+     # Upsample the source if needed for CFL
+     source_amplitude_internal = source_amplitude
+     if step_ratio > 1 and source_amplitude is not None and source_amplitude.numel() > 0:
+         source_amplitude_internal = upsample(
+             source_amplitude,
+             step_ratio,
+             freq_taper_frac=freq_taper_frac,
+             time_pad_frac=time_pad_frac,
+             time_taper=time_taper,
+         )
+
+     # Compute the internal number of time steps
+     nt_internal = None
+     if nt is not None:
+         nt_internal = nt * step_ratio
+     elif source_amplitude_internal is not None:
+         nt_internal = source_amplitude_internal.shape[-1]
+
+     # Call the propagation function with the internal dt and upsampled source
+     result = maxwell_func(
+         python_backend,
+         epsilon,
+         sigma,
+         mu,
+         grid_spacing,
+         inner_dt,  # Use the internal time step for CFL compliance
+         source_amplitude_internal,
+         source_location,
+         receiver_location,
+         stencil,
+         pml_width,
+         max_vel_computed,  # Pass the computed max_vel so it is not recomputed
+         Ey_0,
+         Hx_0,
+         Hz_0,
+         m_Ey_x,
+         m_Ey_z,
+         m_Hx_z,
+         m_Hz_x,
+         nt_internal,
+         model_gradient_sampling_interval,
+         freq_taper_frac,
+         time_pad_frac,
+         time_taper,
+         save_snapshots,
+         forward_callback,
+         backward_callback,
+         callback_frequency,
+         storage_mode,
+         storage_path,
+         storage_compression,
+         storage_bytes_limit_device,
+         storage_bytes_limit_host,
+         storage_chunk_steps,
+         n_threads,
+     )
+
+     # Unpack the result
+     (
+         Ey_out,
+         Hx_out,
+         Hz_out,
+         m_Ey_x_out,
+         m_Ey_z_out,
+         m_Hx_z_out,
+         m_Hz_x_out,
+         receiver_amplitudes,
+     ) = result
+
+     # Downsample receiver data if we upsampled
+     if step_ratio > 1 and receiver_amplitudes.numel() > 0:
+         receiver_amplitudes = downsample_and_movedim(
+             receiver_amplitudes,
+             step_ratio,
+             freq_taper_frac=freq_taper_frac,
+             time_pad_frac=time_pad_frac,
+             time_taper=time_taper,
+         )
+         # Move time back to the first dimension to match the expected output format
+         receiver_amplitudes = torch.movedim(receiver_amplitudes, -1, 0)
+
+     return (
+         Ey_out,
+         Hx_out,
+         Hz_out,
+         m_Ey_x_out,
+         m_Ey_z_out,
+         m_Hx_z_out,
+         m_Hz_x_out,
+         receiver_amplitudes,
+     )
+
+
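+ # Worked example of the CFL handling above (assumed numbers, for illustration
+ # only): if stability requires the user's dt to be reduced by a factor of 3,
+ # cfl_condition returns step_ratio = 3 and inner_dt = dt / 3. A source with
+ # nt = 1000 samples is then upsampled to 3000 internal steps, propagation
+ # runs at inner_dt, and the recorded receiver data are downsampled back to
+ # 1000 samples at the original dt.
+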
+ _update_E_jit: Optional[Callable] = None
+ _update_E_compile: Optional[Callable] = None
+ _update_H_jit: Optional[Callable] = None
+ _update_H_compile: Optional[Callable] = None
+
+ # These will be set after the functions are defined
+ _update_E_opt: Optional[Callable] = None
+ _update_H_opt: Optional[Callable] = None
+
+
+ def maxwell_func(
+     python_backend: Union[bool, str],
+     *args,
+ ) -> tuple[
+     torch.Tensor,  # Ey
+     torch.Tensor,  # Hx
+     torch.Tensor,  # Hz
+     torch.Tensor,  # m_Ey_x
+     torch.Tensor,  # m_Ey_z
+     torch.Tensor,  # m_Hx_z
+     torch.Tensor,  # m_Hz_x
+     torch.Tensor,  # receiver_amplitudes
+ ]:
+     """Dispatch to Python or C/CUDA backend for Maxwell propagation."""
+     global _update_E_jit, _update_E_compile, _update_E_opt
+     global _update_H_jit, _update_H_compile, _update_H_opt
+
+     # Check if we should use Python backend or C/CUDA backend
+     use_python = python_backend
+     if not use_python:
+         # Try to use C/CUDA backend
+         try:
+             from . import backend_utils
+
+             if not backend_utils.is_backend_available():
+                 import warnings
+
+                 warnings.warn(
+                     "C/CUDA backend not available, falling back to Python backend. "
+                     "To use the C/CUDA backend, compile the library first.",
+                     RuntimeWarning,
+                 )
+                 use_python = True
+         except ImportError:
+             import warnings
+
+             warnings.warn(
+                 "backend_utils not available, falling back to Python backend.",
+                 RuntimeWarning,
+             )
+             use_python = True
+
+     if use_python:
+         if python_backend is True or python_backend is False:
+             mode = "eager"  # Default to eager
+         elif isinstance(python_backend, str):
+             mode = python_backend.lower()
+         else:
+             raise TypeError(
+                 f"python_backend must be bool or str, but got {type(python_backend)}"
+             )
+
+         if mode == "jit":
+             if _update_E_jit is None:
+                 _update_E_jit = torch.jit.script(update_E)
+             _update_E_opt = _update_E_jit
+             if _update_H_jit is None:
+                 _update_H_jit = torch.jit.script(update_H)
+             _update_H_opt = _update_H_jit
+         elif mode == "compile":
+             if _update_E_compile is None:
+                 _update_E_compile = torch.compile(update_E, fullgraph=True)
+             _update_E_opt = _update_E_compile
+             if _update_H_compile is None:
+                 _update_H_compile = torch.compile(update_H, fullgraph=True)
+             _update_H_opt = _update_H_compile
+         elif mode == "eager":
+             _update_E_opt = update_E
+             _update_H_opt = update_H
+         else:
+             raise ValueError(f"Unknown python_backend value {mode!r}.")
+
+         return maxwell_python(*args)
+     else:
+         # Use C/CUDA backend
+         return maxwell_c_cuda(*args)
+
+
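+ # Note on the dispatch above: the jit/compile variants are cached in module
+ # globals, so torch.jit.script / torch.compile run at most once per process.
+ # A hedged usage sketch (positional arguments follow the maxwelltm signature
+ # defined earlier; the tensor names are assumed):
+ #
+ #     out = maxwelltm(epsilon, sigma, mu, 0.01, 1e-11, src_amp, src_loc,
+ #                     rec_loc, python_backend="compile")
+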
+ def maxwell_python(
+     epsilon: torch.Tensor,
+     sigma: torch.Tensor,
+     mu: torch.Tensor,
+     grid_spacing: Union[float, Sequence[float]],
+     dt: float,
+     source_amplitude: Optional[torch.Tensor],
+     source_location: Optional[torch.Tensor],
+     receiver_location: Optional[torch.Tensor],
+     stencil: int,
+     pml_width: Union[int, Sequence[int]],
+     max_vel: Optional[float],
+     Ey_0: Optional[torch.Tensor],
+     Hx_0: Optional[torch.Tensor],
+     Hz_0: Optional[torch.Tensor],
+     m_Ey_x_0: Optional[torch.Tensor],
+     m_Ey_z_0: Optional[torch.Tensor],
+     m_Hx_z_0: Optional[torch.Tensor],
+     m_Hz_x_0: Optional[torch.Tensor],
+     nt: Optional[int],
+     model_gradient_sampling_interval: int,
+     freq_taper_frac: float,
+     time_pad_frac: float,
+     time_taper: bool,
+     save_snapshots: Optional[bool],
+     forward_callback: Optional[Callback],
+     backward_callback: Optional[Callback],
+     callback_frequency: int,
+     storage_mode: str = "device",
+     storage_path: str = ".",
+     storage_compression: Union[bool, str] = False,
+     storage_bytes_limit_device: Optional[int] = None,
+     storage_bytes_limit_host: Optional[int] = None,
+     storage_chunk_steps: int = 0,
+     n_threads: Optional[int] = None,
+ ):
+     """Performs the forward propagation of the 2D TM Maxwell equations.
+
+     This function implements the FDTD time-stepping loop for the TM mode
+     (Ey, Hx, Hz) with CPML absorbing boundary conditions.
+
+     Padding strategy:
+     - Models are padded by fd_pad + pml_width with replicate mode
+     - Wavefields are padded by fd_pad only with zero padding
+     - Output wavefields are cropped by fd_pad only (PML region is preserved)
+
+     Args:
+         epsilon: Permittivity model [ny, nx].
+         sigma: Conductivity model [ny, nx].
+         mu: Permeability model [ny, nx].
+         grid_spacing: Grid spacing (dy, dx) or single value for both.
+         dt: Time step.
+         source_amplitude: Source amplitudes [n_shots, n_sources, nt].
+         source_location: Source locations [n_shots, n_sources, 2].
+         receiver_location: Receiver locations [n_shots, n_receivers, 2].
+         stencil: Finite difference stencil order (2, 4, 6, or 8).
+         pml_width: PML width on each side [top, bottom, left, right] or single value.
+         max_vel: Maximum velocity for PML (if None, computed from model).
+         Ey_0, Hx_0, Hz_0: Initial field values.
+         m_Ey_x_0, m_Ey_z_0, m_Hx_z_0, m_Hz_x_0: Initial CPML memory variables.
+         nt: Number of time steps (required if source_amplitude is None).
+         model_gradient_sampling_interval: Interval for storing gradients.
+         freq_taper_frac: Frequency taper fraction.
+         time_pad_frac: Time padding fraction.
+         time_taper: Whether to apply time taper.
+         save_snapshots: Whether to save wavefield snapshots for backward pass.
+             If None, determined by requires_grad on model parameters.
+         forward_callback: Callback function called during propagation.
+         backward_callback: Callback function called during backward (adjoint)
+             propagation.
+         callback_frequency: Frequency of callback calls.
+
+     Returns:
+         Tuple containing:
+         - Ey: Final electric field [n_shots, ny + pml, nx + pml]
+         - Hx, Hz: Final magnetic fields
+         - m_Ey_x, m_Ey_z, m_Hx_z, m_Hz_x: Final CPML memory variables
+         - receiver_amplitudes: Recorded data at receivers [nt, n_shots, n_receivers]
+     """
+
+     from .padding import create_or_pad, zero_interior
+
+     # These should be set by maxwell_func before calling this function
+     assert _update_E_opt is not None, "_update_E_opt must be set by maxwell_func"
+     assert _update_H_opt is not None, "_update_H_opt must be set by maxwell_func"
+
+     # Validate inputs
+     if epsilon.ndim != 2:
+         raise RuntimeError("epsilon must be 2D")
+     if sigma.shape != epsilon.shape:
+         raise RuntimeError("sigma must have same shape as epsilon")
+     if mu.shape != epsilon.shape:
+         raise RuntimeError("mu must have same shape as epsilon")
+
+     device = epsilon.device
+     dtype = epsilon.dtype
+     model_ny, model_nx = epsilon.shape  # Original model dimensions
+
+     storage_mode_str = storage_mode.lower()
+     if storage_mode_str in {"cpu", "disk"}:
+         raise ValueError(
+             "python_backend does not support storage_mode='cpu' or 'disk'. "
+             "Use the C/CUDA backend or storage_mode='device'/'none'."
+         )
+     storage_kind = _normalize_storage_compression(storage_compression)
+     if storage_kind != "none":
+         raise NotImplementedError(
+             "storage_compression is not implemented yet; set storage_compression=False."
+         )
+
+     # Normalize grid_spacing to a list
+     grid_spacing = _normalize_grid_spacing_2d(grid_spacing)
+     dy, dx = grid_spacing
+
+     # Normalize pml_width to a list [top, bottom, left, right]
+     pml_width_list = _normalize_pml_width_2d(pml_width)
+
+     # Determine the number of time steps
+     if nt is None:
+         if source_amplitude is None:
+             raise ValueError("Either nt or source_amplitude must be provided")
+         nt = source_amplitude.shape[-1]
+
+     # Type cast to ensure nt is int for the type checker
+     nt_steps: int = int(nt)
+
+     # Determine the number of shots
+     if source_amplitude is not None and source_amplitude.numel() > 0:
+         n_shots = source_amplitude.shape[0]
+     elif source_location is not None and source_location.numel() > 0:
+         n_shots = source_location.shape[0]
+     elif receiver_location is not None and receiver_location.numel() > 0:
+         n_shots = receiver_location.shape[0]
+     else:
+         n_shots = 1
+
+     # Compute the maximum velocity for the PML if not provided
+     if max_vel is None:
+         # For EM waves: v = c0 / sqrt(epsilon_r * mu_r)
+         max_vel = float((1.0 / torch.sqrt(epsilon * mu)).max().item()) * C0
+
+     # Compute the PML frequency (dominant frequency estimate)
+     pml_freq = 0.5 / dt  # Nyquist as default
+
+     # =========================================================================
+     # Padding strategy:
+     # - fd_pad: padding for finite difference stencil accuracy
+     # - pml_width: padding for PML absorbing layers
+     # - Total model padding = fd_pad + pml_width
+     # - Wavefield padding = fd_pad only (wavefields include PML region)
+     # =========================================================================
+
+     # FD padding based on stencil: accuracy // 2
+     fd_pad = stencil // 2
+     # fd_pad_list: [y0, y1, x0, x1] - for a 2D staggered grid, asymmetric because
+     # the staggered diff a[1:] - a[:-1] reduces the array size by 1, so we need
+     # fd_pad - 1 at the end
+     fd_pad_list = [fd_pad, fd_pad - 1, fd_pad, fd_pad - 1]
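+     # Worked example (illustrative): stencil=4 gives fd_pad=2 and
+     # fd_pad_list=[2, 1, 2, 1]; combined with pml_width=20 on every side,
+     # total_pad below becomes [22, 21, 22, 21].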
+
+     # Total padding for models = fd_pad + pml_width
+     total_pad = [fd + pml for fd, pml in zip(fd_pad_list, pml_width_list)]
+
+     # Calculate padded dimensions
+     # The model is padded by total_pad on each side
+     padded_ny = model_ny + total_pad[0] + total_pad[1]
+     padded_nx = model_nx + total_pad[2] + total_pad[3]
+
+     # Pad model tensors with replicate mode (extend boundary values)
+     padded_size = (padded_ny, padded_nx)
+     epsilon_padded = create_or_pad(
+         epsilon, total_pad, device, dtype, padded_size, mode="replicate"
+     )
+     sigma_padded = create_or_pad(
+         sigma, total_pad, device, dtype, padded_size, mode="replicate"
+     )
+     mu_padded = create_or_pad(
+         mu, total_pad, device, dtype, padded_size, mode="replicate"
+     )
+
+     # Prepare update coefficients using the padded models
+     ca, cb, cq = prepare_parameters(epsilon_padded, sigma_padded, mu_padded, dt)
+
+     # Expand coefficients for the batch dimension
+     ca = ca[None, :, :]  # [1, padded_ny, padded_nx]
+     cb = cb[None, :, :]
+     cq = cq[None, :, :]
+
+     # =========================================================================
+     # Initialize wavefields
+     # Wavefields are padded by fd_pad only (they include the PML region)
+     # Size = [n_shots, model_ny + pml_width*2 + fd_pad*2, model_nx + ...]
+     # Which equals [n_shots, padded_ny, padded_nx]
+     # =========================================================================
+     size_with_batch = (n_shots, padded_ny, padded_nx)
+
+     # Helper function to initialize wavefields with fd_pad padding
+     def init_wavefield(field_0: Optional[torch.Tensor]) -> torch.Tensor:
+         """Initialize a wavefield with fd_pad zero padding.
+
+         Zero padding is used for wavefields because the fd_pad region should
+         always be zero after output cropping and re-padding. The staggered grid
+         operators only read from this region but don't need non-zero values
+         there for correct propagation.
+         """
+         if field_0 is not None:
+             # The user provides [n_shots, ny, nx] or [ny, nx]
+             if field_0.ndim == 2:
+                 field_0 = field_0[None, :, :].expand(n_shots, -1, -1)
+             # Pad with the asymmetric fd_pad_list for the staggered grid (zero padding)
+             return create_or_pad(
+                 field_0, fd_pad_list, device, dtype, size_with_batch, mode="constant"
+             )
+         return torch.zeros(size_with_batch, device=device, dtype=dtype)
+
+     Ey = init_wavefield(Ey_0)
+     Hx = init_wavefield(Hx_0)
+     Hz = init_wavefield(Hz_0)
+     m_Ey_x = init_wavefield(m_Ey_x_0)
+     m_Ey_z = init_wavefield(m_Ey_z_0)
+     m_Hx_z = init_wavefield(m_Hx_z_0)
+     m_Hz_x = init_wavefield(m_Hz_x_0)
+
+     # Zero out the interior of the PML auxiliary variables (optimization).
+     # PML memory variables should only be non-zero in PML regions.
+     # This works correctly even with user-provided initial states because:
+     # 1. The output preserves the PML region (only fd_pad is cropped)
+     # 2. zero_interior only zeros the interior, preserving PML boundary values
+     # 3. Interior values are already zero in correctly propagated wavefields
+     # Dimension mapping for zero_interior:
+     # - m_Ey_x: x-direction auxiliary -> dim=1 (zero y-interior, keep x-boundaries)
+     # - m_Ey_z: y/z-direction auxiliary -> dim=0 (zero x-interior, keep y-boundaries)
+     # - m_Hx_z: y/z-direction auxiliary -> dim=0
+     # - m_Hz_x: x-direction auxiliary -> dim=1
+     pml_aux_dims = [1, 0, 0, 1]  # [m_Ey_x, m_Ey_z, m_Hx_z, m_Hz_x]
+     for wf, dim in zip([m_Ey_x, m_Ey_z, m_Hx_z, m_Hz_x], pml_aux_dims):
+         zero_interior(wf, fd_pad_list, pml_width_list, dim)
+
+     # Set up PML profiles for the padded domain
+     pml_profiles_list = staggered.set_pml_profiles(
+         pml_width=pml_width_list,
+         accuracy=stencil,
+         fd_pad=fd_pad_list,
+         dt=dt,
+         grid_spacing=grid_spacing,
+         max_vel=max_vel,
+         dtype=dtype,
+         device=device,
+         pml_freq=pml_freq,
+         ny=padded_ny,
+         nx=padded_nx,
+     )
+     # pml_profiles_list = [ay, ayh, ax, axh, by, byh, bx, bxh, ky, kyh, kx, kxh]
+     (
+         ay,
+         ay_h,
+         ax,
+         ax_h,
+         by,
+         by_h,
+         bx,
+         bx_h,
+         kappa_y,
+         kappa_y_h,
+         kappa_x,
+         kappa_x_h,
+     ) = pml_profiles_list
+
+     # Reciprocal grid spacing
+     rdy = torch.tensor(1.0 / dy, device=device, dtype=dtype)
+     rdx = torch.tensor(1.0 / dx, device=device, dtype=dtype)
+     dt_tensor = torch.tensor(dt, device=device, dtype=dtype)
+
+     # =========================================================================
+     # Prepare source and receiver indices
+     # Original positions are in the un-padded model coordinate system.
+     # We need to offset by total_pad (fd_pad + pml_width) to get padded coords.
+     # =========================================================================
+     flat_model_shape = padded_ny * padded_nx
+
+     if source_location is not None and source_location.numel() > 0:
+         # Adjust source positions by the total padding offset
+         source_y = source_location[..., 0] + total_pad[0]  # Add top offset
+         source_x = source_location[..., 1] + total_pad[2]  # Add left offset
+         sources_i = (source_y * padded_nx + source_x).long()  # [n_shots, n_sources]
+         n_sources = source_location.shape[1]
+     else:
+         sources_i = torch.empty(0, device=device, dtype=torch.long)
+         n_sources = 0
+
+     if receiver_location is not None and receiver_location.numel() > 0:
+         # Adjust receiver positions by the total padding offset
+         receiver_y = receiver_location[..., 0] + total_pad[0]  # Add top offset
+         receiver_x = receiver_location[..., 1] + total_pad[2]  # Add left offset
+         receivers_i = (receiver_y * padded_nx + receiver_x).long()
+         n_receivers = receiver_location.shape[1]
+     else:
+         receivers_i = torch.empty(0, device=device, dtype=torch.long)
+         n_receivers = 0
+
+     # Initialize receiver amplitudes
+     if n_receivers > 0:
+         receiver_amplitudes = torch.zeros(
+             nt_steps, n_shots, n_receivers, device=device, dtype=dtype
+         )
+     else:
+         receiver_amplitudes = torch.empty(0, device=device, dtype=dtype)
+
+     # Prepare callback data - the models dict uses the padded models
+     callback_models = {
+         "epsilon": epsilon_padded,
+         "sigma": sigma_padded,
+         "mu": mu_padded,
+         "ca": ca,
+         "cb": cb,
+         "cq": cq,
+     }
+
+     # Callback fd_pad is the actual fd_pad used for wavefields
+     callback_fd_pad = fd_pad_list
+
+     # Source injection coefficient: -cb * dt / (dx * dy)
+     # Since our cb already contains dt/epsilon, we need: -cb / (dx * dy)
+     # This normalizes the source by the cell area for correct amplitude
+     source_coeff = -1.0 / (dx * dy)
+
+     # Time stepping loop
+     for step in range(nt_steps):
+         # Callback at the specified frequency
+         if forward_callback is not None and step % callback_frequency == 0:
+             callback_wavefields = {
+                 "Ey": Ey,
+                 "Hx": Hx,
+                 "Hz": Hz,
+                 "m_Ey_x": m_Ey_x,
+                 "m_Ey_z": m_Ey_z,
+                 "m_Hx_z": m_Hx_z,
+                 "m_Hz_x": m_Hz_x,
+             }
+             # Create a CallbackState for the standardized interface
+             callback_state = CallbackState(
+                 dt=dt,
+                 step=step,
+                 nt=nt_steps,
+                 wavefields=callback_wavefields,
+                 models=callback_models,
+                 gradients=None,  # No gradients during the forward pass
+                 fd_pad=callback_fd_pad,
+                 pml_width=pml_width_list,
+                 is_backward=False,
+                 grid_spacing=[dy, dx],
+             )
+             forward_callback(callback_state)
+
+         # Update the H fields: H^{n+1/2} = H^{n-1/2} + ...
+         Hx, Hz, m_Ey_x, m_Ey_z = _update_H_opt(
+             cq,
+             Hx,
+             Hz,
+             Ey,
+             m_Ey_x,
+             m_Ey_z,
+             kappa_y,
+             kappa_y_h,
+             kappa_x,
+             kappa_x_h,
+             ay,
+             ay_h,
+             ax,
+             ax_h,
+             by,
+             by_h,
+             bx,
+             bx_h,
+             rdy,
+             rdx,
+             dt_tensor,
+             stencil,
+         )
+
+         # Update the E field: E^{n+1} = E^n + ...
+         Ey, m_Hx_z, m_Hz_x = _update_E_opt(
+             ca,
+             cb,
+             Hx,
+             Hz,
+             Ey,
+             m_Hx_z,
+             m_Hz_x,
+             kappa_y,
+             kappa_y_h,
+             kappa_x,
+             kappa_x_h,
+             ay,
+             ay_h,
+             ax,
+             ax_h,
+             by,
+             by_h,
+             bx,
+             bx_h,
+             rdy,
+             rdx,
+             dt_tensor,
+             stencil,
+         )
+
+         # Inject the source into the Ey field (after the E update, following the
+         # reference implementation)
+         # Source term: Ey += -cb * f * dt / (dx * dy) = -cb * f / (dx * dy)
+         # since cb contains dt
+         if (
+             source_amplitude is not None
+             and source_amplitude.numel() > 0
+             and n_sources > 0
+         ):
+             # source_amplitude: [n_shots, n_sources, nt]
+             src_amp = source_amplitude[:, :, step]  # [n_shots, n_sources]
+             # Get cb at the source locations for proper scaling
+             cb_flat = cb.reshape(1, flat_model_shape).expand(n_shots, -1)
+             cb_at_src = cb_flat.gather(1, sources_i)  # [n_shots, n_sources]
+             # Apply the source with coefficient: -cb * f / (dx * dy)
+             scaled_src = cb_at_src * src_amp * source_coeff
+             Ey = (
+                 Ey.reshape(n_shots, flat_model_shape)
+                 .scatter_add(1, sources_i, scaled_src)
+                 .reshape(size_with_batch)
+             )
+
+         # Record at the receivers (after source injection)
+         if n_receivers > 0:
+             receiver_amplitudes[step] = Ey.reshape(n_shots, flat_model_shape).gather(
+                 1, receivers_i
+             )
+
+     # =========================================================================
+     # Output cropping:
+     # Only remove fd_pad, keep the PML region in the output wavefields.
+     # Output shape: [n_shots, model_ny + pml_width_y, model_nx + pml_width_x]
+     # =========================================================================
+     s = (
+         slice(None),  # batch dimension
+         slice(
+             fd_pad_list[0], padded_ny - fd_pad_list[1] if fd_pad_list[1] > 0 else None
+         ),
+         slice(
+             fd_pad_list[2], padded_nx - fd_pad_list[3] if fd_pad_list[3] > 0 else None
+         ),
+     )
+
+     return (
+         Ey[s],
+         Hx[s],
+         Hz[s],
+         m_Ey_x[s],
+         m_Ey_z[s],
+         m_Hx_z[s],
+         m_Hz_x[s],
+         receiver_amplitudes,
+     )
+
+
+ def update_E(
+     ca: torch.Tensor,
+     cb: torch.Tensor,
+     Hx: torch.Tensor,
+     Hz: torch.Tensor,
+     Ey: torch.Tensor,
+     m_Hx_z: torch.Tensor,
+     m_Hz_x: torch.Tensor,
+     kappa_y: torch.Tensor,
+     kappa_y_h: torch.Tensor,
+     kappa_x: torch.Tensor,
+     kappa_x_h: torch.Tensor,
+     ay: torch.Tensor,
+     ay_h: torch.Tensor,
+     ax: torch.Tensor,
+     ax_h: torch.Tensor,
+     by: torch.Tensor,
+     by_h: torch.Tensor,
+     bx: torch.Tensor,
+     bx_h: torch.Tensor,
+     rdy: torch.Tensor,
+     rdx: torch.Tensor,
+     dt: torch.Tensor,
+     stencil: int,
+ ) -> tuple[
+     torch.Tensor,
+     torch.Tensor,
+     torch.Tensor,
+ ]:
+     """Update electric field Ey with CPML absorbing boundary conditions.
+
+     For TM mode, the update equation is:
+         Ey^{n+1} = Ca * Ey^n + Cb * (dHz/dx - dHx/dz)
+
+     With CPML, we split the derivatives and apply auxiliary variables:
+         dHz/dx -> dHz/dx / kappa_x + m_Hz_x
+         dHx/dz -> dHx/dz / kappa_y + m_Hx_z
+
+     Args:
+         ca, cb: Update coefficients from material parameters
+         Hx, Hz: Magnetic field components
+         Ey: Electric field component to update
+         m_Hx_z, m_Hz_x: CPML auxiliary memory variables
+         kappa_y, kappa_y_h: CPML kappa profiles in y direction
+         kappa_x, kappa_x_h: CPML kappa profiles in x direction
+         ay, ay_h, ax, ax_h: CPML a coefficients
+         by, by_h, bx, bx_h: CPML b coefficients
+         rdy, rdx: Reciprocal of grid spacing (1/dy, 1/dx)
+         dt: Time step
+         stencil: Finite difference stencil order (2, 4, 6, or 8)
+
+     Returns:
+         Updated Ey, m_Hx_z, m_Hz_x
+     """
+
+     # Compute spatial derivatives using staggered grid operators
+     # dHz/dx at integer grid points (where Ey lives)
+     dHz_dx = staggered.diffx1(Hz, stencil, rdx)
+     # dHx/dz at integer grid points (where Ey lives)
+     dHx_dz = staggered.diffy1(Hx, stencil, rdy)
+
+     # Update CPML auxiliary variables using the standard CPML recursion:
+     #     psi_new = b * psi_old + a * derivative
+     # m_Hz_x stores the x-direction PML memory for the Hz derivative
+     m_Hz_x = bx * m_Hz_x + ax * dHz_dx
+     # m_Hx_z stores the z-direction PML memory for the Hx derivative
+     m_Hx_z = by * m_Hx_z + ay * dHx_dz
+
+     # Apply the CPML correction to the derivatives
+     # In CPML: d/dx -> (1/kappa) * d/dx + m
+     dHz_dx_pml = dHz_dx / kappa_x + m_Hz_x
+     dHx_dz_pml = dHx_dz / kappa_y + m_Hx_z
+
+     # Update Ey using the FDTD update equation
+     #     Ey^{n+1} = Ca * Ey^n + Cb * (dHz/dx - dHx/dz)
+     Ey = ca * Ey + cb * (dHz_dx_pml - dHx_dz_pml)
+
+     return Ey, m_Hx_z, m_Hz_x
+
+
+ def update_H(
+     cq: torch.Tensor,
+     Hx: torch.Tensor,
+     Hz: torch.Tensor,
+     Ey: torch.Tensor,
+     m_Ey_x: torch.Tensor,
+     m_Ey_z: torch.Tensor,
+     kappa_y: torch.Tensor,
+     kappa_y_h: torch.Tensor,
+     kappa_x: torch.Tensor,
+     kappa_x_h: torch.Tensor,
+     ay: torch.Tensor,
+     ay_h: torch.Tensor,
+     ax: torch.Tensor,
+     ax_h: torch.Tensor,
+     by: torch.Tensor,
+     by_h: torch.Tensor,
+     bx: torch.Tensor,
+     bx_h: torch.Tensor,
+     rdy: torch.Tensor,
+     rdx: torch.Tensor,
+     dt: torch.Tensor,
+     stencil: int,
+ ) -> tuple[
+     torch.Tensor,
+     torch.Tensor,
+     torch.Tensor,
+     torch.Tensor,
+ ]:
+     """Update magnetic fields Hx and Hz with CPML absorbing boundary conditions.
+
+     For TM mode, the update equations are:
+         Hx^{n+1} = Hx^n - Cq * dEy/dz
+         Hz^{n+1} = Hz^n + Cq * dEy/dx
+
+     With CPML, we use half-grid derivatives and auxiliary variables:
+         dEy/dz -> dEy/dz / kappa_y_h + m_Ey_z
+         dEy/dx -> dEy/dx / kappa_x_h + m_Ey_x
+
+     Args:
+         cq: Update coefficient (dt/mu)
+         Hx, Hz: Magnetic field components to update
+         Ey: Electric field component
+         m_Ey_x, m_Ey_z: CPML auxiliary memory variables
+         kappa_y, kappa_y_h: CPML kappa profiles in y direction (integer and half grid)
+         kappa_x, kappa_x_h: CPML kappa profiles in x direction (integer and half grid)
+         ay, ay_h, ax, ax_h: CPML a coefficients
+         by, by_h, bx, bx_h: CPML b coefficients
+         rdy, rdx: Reciprocal of grid spacing (1/dy, 1/dx)
+         dt: Time step
+         stencil: Finite difference stencil order (2, 4, 6, or 8)
+
+     Returns:
+         Updated Hx, Hz, m_Ey_x, m_Ey_z
+     """
+
+     # Compute spatial derivatives at half grid points (where the H fields live)
+     # dEy/dz at half grid points in z (for the Hx update)
+     dEy_dz = staggered.diffyh1(Ey, stencil, rdy)
+     # dEy/dx at half grid points in x (for the Hz update)
+     dEy_dx = staggered.diffxh1(Ey, stencil, rdx)
+
+     # Update CPML auxiliary variables using the standard CPML recursion:
+     #     psi_new = b * psi_old + a * derivative
+     # m_Ey_z stores the z-direction PML memory for the Ey derivative (used in the Hx update)
+     m_Ey_z = by_h * m_Ey_z + ay_h * dEy_dz
+     # m_Ey_x stores the x-direction PML memory for the Ey derivative (used in the Hz update)
+     m_Ey_x = bx_h * m_Ey_x + ax_h * dEy_dx
+
+     # Apply the CPML correction to the derivatives
+     # In CPML: d/dz -> (1/kappa_h) * d/dz + m
+     dEy_dz_pml = dEy_dz / kappa_y_h + m_Ey_z
+     dEy_dx_pml = dEy_dx / kappa_x_h + m_Ey_x
+
+     # Update Hx using the FDTD update equation
+     #     Hx^{n+1} = Hx^n - Cq * dEy/dz
+     Hx = Hx - cq * dEy_dz_pml
+
+     # Update Hz using the FDTD update equation
+     #     Hz^{n+1} = Hz^n + Cq * dEy/dx
+     Hz = Hz + cq * dEy_dx_pml
+
+     return Hx, Hz, m_Ey_x, m_Ey_z
+
+
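+ # The two update functions above form one leapfrog step. A sketch of the
+ # ordering used by the time loop in maxwell_python (not runnable on its own,
+ # since the coefficients and PML profiles must be prepared first):
+ #
+ #     for step in range(nt):
+ #         Hx, Hz, m_Ey_x, m_Ey_z = update_H(cq, Hx, Hz, Ey, ...)  # H at n+1/2
+ #         Ey, m_Hx_z, m_Hz_x = update_E(ca, cb, Hx, Hz, Ey, ...)  # E at n+1
+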
+ # Initialize the optimized function pointers to the default implementations
+ _update_E_opt = update_E
+ _update_H_opt = update_H
+
+
+ def maxwell_c_cuda(
+     epsilon: torch.Tensor,
+     sigma: torch.Tensor,
+     mu: torch.Tensor,
+     grid_spacing: Union[float, Sequence[float]],
+     dt: float,
+     source_amplitude: Optional[torch.Tensor],
+     source_location: Optional[torch.Tensor],
+     receiver_location: Optional[torch.Tensor],
+     stencil: int,
+     pml_width: Union[int, Sequence[int]],
+     max_vel: Optional[float],
+     Ey_0: Optional[torch.Tensor],
+     Hx_0: Optional[torch.Tensor],
+     Hz_0: Optional[torch.Tensor],
+     m_Ey_x_0: Optional[torch.Tensor],
+     m_Ey_z_0: Optional[torch.Tensor],
+     m_Hx_z_0: Optional[torch.Tensor],
+     m_Hz_x_0: Optional[torch.Tensor],
+     nt: Optional[int],
+     model_gradient_sampling_interval: int,
+     freq_taper_frac: float,
+     time_pad_frac: float,
+     time_taper: bool,
+     save_snapshots: Optional[bool],
+     forward_callback: Optional[Callback],
+     backward_callback: Optional[Callback],
+     callback_frequency: int,
+     storage_mode: str = "device",
+     storage_path: str = ".",
+     storage_compression: Union[bool, str] = False,
+     storage_bytes_limit_device: Optional[int] = None,
+     storage_bytes_limit_host: Optional[int] = None,
+     storage_chunk_steps: int = 0,
+     n_threads: Optional[int] = None,
+ ):
+     """Performs Maxwell propagation using the C/CUDA backend.
+
+     This function provides the interface to the compiled C/CUDA implementations
+     for high-performance wave propagation.
+
+     Padding strategy:
+     - Models are padded by fd_pad + pml_width with replicate mode
+     - Wavefields are padded by fd_pad only with zero padding
+     - Output wavefields are cropped by fd_pad only (PML region is preserved)
+
+     Args:
+         Same as maxwell_python.
+
+     Returns:
+         Same as maxwell_python.
+     """
+     from . import backend_utils, staggered
+     from .padding import create_or_pad, zero_interior
+
+     # Validate inputs
+     if epsilon.ndim != 2:
+         raise RuntimeError("epsilon must be 2D")
+     if sigma.shape != epsilon.shape:
+         raise RuntimeError("sigma must have same shape as epsilon")
+     if mu.shape != epsilon.shape:
+         raise RuntimeError("mu must have same shape as epsilon")
+
+     device = epsilon.device
+     dtype = epsilon.dtype
+     model_ny, model_nx = epsilon.shape  # Original model dimensions
+
+     # Normalize grid_spacing to a list
+     grid_spacing = _normalize_grid_spacing_2d(grid_spacing)
+     dy, dx = grid_spacing
+
+     n_threads_val = 0
+     if n_threads is not None:
+         n_threads_val = int(n_threads)
+         if n_threads_val < 0:
+             raise ValueError("n_threads must be >= 0 when provided.")
+
+     # Normalize pml_width to a list [top, bottom, left, right]
+     pml_width_list = _normalize_pml_width_2d(pml_width)
+
+     # Determine the number of time steps
+     if nt is None:
+         if source_amplitude is None:
+             raise ValueError("Either nt or source_amplitude must be provided")
+         nt = source_amplitude.shape[-1]
+
+     # Ensure nt is an integer for iteration
+     nt_steps: int = int(nt)
+     # Clamp the gradient sampling interval to a sensible range for storage/backprop
+     gradient_sampling_interval = int(model_gradient_sampling_interval)
+     if gradient_sampling_interval < 1:
+         gradient_sampling_interval = 1
+     if nt_steps > 0:
+         gradient_sampling_interval = min(gradient_sampling_interval, nt_steps)
+
+     # Determine the number of shots
+     if source_amplitude is not None and source_amplitude.numel() > 0:
+         n_shots = source_amplitude.shape[0]
+     elif source_location is not None and source_location.numel() > 0:
+         n_shots = source_location.shape[0]
+     elif receiver_location is not None and receiver_location.numel() > 0:
+         n_shots = receiver_location.shape[0]
+     else:
+         n_shots = 1
+
+     # Compute the maximum velocity for the PML if not provided
+     if max_vel is None:
+         max_vel = float((1.0 / torch.sqrt(epsilon * mu)).max().item()) * C0
+
+     # Compute the PML frequency
+     pml_freq = 0.5 / dt
+
+     # =========================================================================
+     # Padding strategy:
+     # - fd_pad: padding for finite difference stencil accuracy
+     # - pml_width: padding for PML absorbing layers
+     # - Total model padding = fd_pad + pml_width
+     # - Wavefield padding = fd_pad only (wavefields include PML region)
+     # =========================================================================
+
+     # FD padding based on stencil: accuracy // 2
+     fd_pad = stencil // 2
+     # fd_pad_list: [y0, y1, x0, x1] - asymmetric for the staggered grid
+     fd_pad_list = [fd_pad, fd_pad - 1, fd_pad, fd_pad - 1]
+
+     # Total padding for models = fd_pad + pml_width
+     total_pad = [fd + pml for fd, pml in zip(fd_pad_list, pml_width_list)]
+
+     # Calculate padded dimensions
+     padded_ny = model_ny + total_pad[0] + total_pad[1]
+     padded_nx = model_nx + total_pad[2] + total_pad[3]
+
+     # Pad model tensors with replicate mode (extend boundary values)
+     padded_size = (padded_ny, padded_nx)
+     epsilon_padded = create_or_pad(
+         epsilon, total_pad, device, dtype, padded_size, mode="replicate"
+     )
+     sigma_padded = create_or_pad(
+         sigma, total_pad, device, dtype, padded_size, mode="replicate"
+     )
+     mu_padded = create_or_pad(
+         mu, total_pad, device, dtype, padded_size, mode="replicate"
+     )
+
+     # Prepare update coefficients using the padded models
+     ca, cb, cq = prepare_parameters(epsilon_padded, sigma_padded, mu_padded, dt)
+
+     # Initialize fields with padded dimensions
+     size_with_batch = (n_shots, padded_ny, padded_nx)
+
+     def init_wavefield(field_0: Optional[torch.Tensor]) -> torch.Tensor:
+         """Initialize a wavefield with fd_pad zero padding."""
+         if field_0 is not None:
+             if field_0.ndim == 2:
+                 field_0 = field_0[None, :, :].expand(n_shots, -1, -1)
+             # Pad with the asymmetric fd_pad_list for the staggered grid
+             return create_or_pad(
+                 field_0, fd_pad_list, device, dtype, size_with_batch
+             ).contiguous()
+         return torch.zeros(size_with_batch, device=device, dtype=dtype)
+
+     Ey = init_wavefield(Ey_0)
+     Hx = init_wavefield(Hx_0)
+     Hz = init_wavefield(Hz_0)
+     m_Ey_x = init_wavefield(m_Ey_x_0)
+     m_Ey_z = init_wavefield(m_Ey_z_0)
+     m_Hx_z = init_wavefield(m_Hx_z_0)
+     m_Hz_x = init_wavefield(m_Hz_x_0)
+
+     # Zero out the interior of the PML auxiliary variables (optimization).
+     # This works correctly with user-provided states (see the forward pass comments)
+     pml_aux_dims = [1, 0, 0, 1]  # [m_Ey_x, m_Ey_z, m_Hx_z, m_Hz_x]
+     for wf, dim in zip([m_Ey_x, m_Ey_z, m_Hx_z, m_Hz_x], pml_aux_dims):
+         zero_interior(wf, fd_pad_list, pml_width_list, dim)
+
+     # Set up PML profiles for the padded domain
+     pml_profiles_list = staggered.set_pml_profiles(
+         pml_width=pml_width_list,
+         accuracy=stencil,
+         fd_pad=fd_pad_list,
+         dt=dt,
+         grid_spacing=grid_spacing,
+         max_vel=max_vel,
+         dtype=dtype,
+         device=device,
+         pml_freq=pml_freq,
+         ny=padded_ny,
+         nx=padded_nx,
+     )
+     (
+         ay,
+         ay_h,
+         ax,
+         ax_h,
+         by,
+         by_h,
+         bx,
+         bx_h,
+         ky,
+         ky_h,
+         kx,
+         kx_h,
+     ) = pml_profiles_list
+
+     # Flatten the PML profiles for the C backend (remove batch dimensions)
+     ay_flat = ay.squeeze().contiguous()
+     ay_h_flat = ay_h.squeeze().contiguous()
+     ax_flat = ax.squeeze().contiguous()
+     ax_h_flat = ax_h.squeeze().contiguous()
+     by_flat = by.squeeze().contiguous()
+     by_h_flat = by_h.squeeze().contiguous()
+     bx_flat = bx.squeeze().contiguous()
+     bx_h_flat = bx_h.squeeze().contiguous()
+
+     # Flatten the kappa profiles for the C backend
+     ky_flat = ky.squeeze().contiguous()
+     ky_h_flat = ky_h.squeeze().contiguous()
+     kx_flat = kx.squeeze().contiguous()
+     kx_h_flat = kx_h.squeeze().contiguous()
+
+     # =========================================================================
+     # Prepare source and receiver indices
+     # Original positions are in the un-padded model coordinate system.
+     # We need to offset by total_pad (fd_pad + pml_width) to get padded coords.
+     # =========================================================================
+     flat_model_shape = padded_ny * padded_nx
+
+     if source_location is not None and source_location.numel() > 0:
+         # Adjust source positions by the total padding offset
+         source_y = source_location[..., 0] + total_pad[0]
+         source_x = source_location[..., 1] + total_pad[2]
+         sources_i = (source_y * padded_nx + source_x).long().contiguous()
+         n_sources = source_location.shape[1]
+     else:
+         sources_i = torch.empty(0, device=device, dtype=torch.long)
+         n_sources = 0
+
+     if receiver_location is not None and receiver_location.numel() > 0:
+         # Adjust receiver positions by the total padding offset
+         receiver_y = receiver_location[..., 0] + total_pad[0]
+         receiver_x = receiver_location[..., 1] + total_pad[2]
+         receivers_i = (receiver_y * padded_nx + receiver_x).long().contiguous()
+         n_receivers = receiver_location.shape[1]
+     else:
+         receivers_i = torch.empty(0, device=device, dtype=torch.long)
+         n_receivers = 0
+
+     # Prepare source amplitudes with proper scaling
+     if source_amplitude is not None and source_amplitude.numel() > 0:
+         source_coeff = -1.0 / (dx * dy)
+         # Expand cb to the batch dimension for gather
+         cb_expanded = cb[None, :, :].expand(n_shots, -1, -1)
+         cb_flat = cb_expanded.reshape(n_shots, flat_model_shape)
+         cb_at_src = cb_flat.gather(1, sources_i)
+         # Reshape the source amplitude: [shot, source, time] -> [time, shot, source]
+         f = source_amplitude.permute(2, 0, 1).contiguous()
+         # Scale by cb and the source coefficient
+         f = f * cb_at_src[None, :, :] * source_coeff
+         f = f.reshape(nt_steps * n_shots * n_sources)
+     else:
+         f = torch.empty(0, device=device, dtype=dtype)
+
+     # Flatten fields for the C backend
+     Ey = Ey.contiguous()
+     Hx = Hx.contiguous()
+     Hz = Hz.contiguous()
+     m_Ey_x = m_Ey_x.contiguous()
+     m_Ey_z = m_Ey_z.contiguous()
+     m_Hx_z = m_Hx_z.contiguous()
+     m_Hz_x = m_Hz_x.contiguous()
+
+     # Flatten coefficients (add a batch dimension for consistency)
+     ca = ca[None, :, :].contiguous()
+     cb = cb[None, :, :].contiguous()
+     cq = cq[None, :, :].contiguous()
+
+     # PML boundaries (where the PML starts in the padded domain)
+     pml_y0 = fd_pad_list[0] + pml_width_list[0]
+     pml_y1 = padded_ny - fd_pad_list[1] - pml_width_list[1]
+     pml_x0 = fd_pad_list[2] + pml_width_list[2]
+     pml_x1 = padded_nx - fd_pad_list[3] - pml_width_list[3]
+
+     # Determine if any input requires gradients
+     requires_grad = epsilon.requires_grad or sigma.requires_grad
+
+     functorch_active = torch._C._are_functorch_transforms_active()
+     if functorch_active:
+         raise NotImplementedError(
+             "torch.func transforms are not supported for the C/CUDA backend."
+         )
+
+     storage_kind, _, storage_bytes_per_elem = _resolve_storage_compression(
+         storage_compression,
+         dtype,
+         device,
+         context="storage_compression",
+     )
+
+     # Determine if we should save snapshots for the backward pass
+     if save_snapshots is None:
+         do_save_snapshots = requires_grad
+     else:
+         do_save_snapshots = save_snapshots
+
+     # If save_snapshots is False but requires_grad is True, warn the user
+     if requires_grad and save_snapshots is False:
+         import warnings
+
+         warnings.warn(
+             "save_snapshots=False but model parameters require gradients. "
+             "Backward pass will fail.",
+             UserWarning,
+         )
+
+     storage_mode_str = storage_mode.lower()
+     if storage_mode_str not in {"device", "cpu", "disk", "none", "auto"}:
+         raise ValueError(
+             "storage_mode must be 'device', 'cpu', 'disk', 'none', or 'auto', "
+             f"but got {storage_mode!r}"
+         )
+     if device.type == "cpu" and storage_mode_str == "cpu":
+         storage_mode_str = "device"
+
+     needs_storage = do_save_snapshots and requires_grad
+     effective_storage_mode_str = storage_mode_str
+     if not needs_storage:
+         if effective_storage_mode_str == "auto":
+             effective_storage_mode_str = "none"
+     else:
+         if effective_storage_mode_str == "none":
+             raise ValueError(
+                 "storage_mode='none' is not compatible with gradient computation "
+                 "when model parameters require gradients."
+             )
+         if effective_storage_mode_str == "auto":
+             dtype_size = storage_bytes_per_elem
+             # Estimate the required bytes for storing Ey and curl_H.
+             num_elements_per_shot = padded_ny * padded_nx
+             shot_bytes_uncomp = num_elements_per_shot * dtype_size
+             n_stored = (
+                 nt_steps + gradient_sampling_interval - 1
+             ) // gradient_sampling_interval
+             total_bytes = n_stored * n_shots * shot_bytes_uncomp * 2  # Ey + curl_H
+
+             limit_device = (
+                 storage_bytes_limit_device
+                 if storage_bytes_limit_device is not None
+                 else float("inf")
+             )
+             limit_host = (
+                 storage_bytes_limit_host
+                 if storage_bytes_limit_host is not None
+                 else float("inf")
+             )
+             import warnings
+
+             if device.type == "cuda" and total_bytes <= limit_device:
+                 effective_storage_mode_str = "device"
+             elif total_bytes <= limit_host:
+                 effective_storage_mode_str = "cpu"
+             else:
+                 effective_storage_mode_str = "disk"
+
+             warnings.warn(
+                 f"storage_mode='auto' selected storage_mode='{effective_storage_mode_str}' "
+                 f"for estimated storage size {total_bytes / 1e9:.2f} GB.",
+                 RuntimeWarning,
+             )
+
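+     # Worked example of the estimate above (assumed numbers, for illustration):
+     # with nt_steps=1000, gradient_sampling_interval=1, n_shots=1, a padded
+     # grid of 500 x 500, and float32 (4 bytes), total_bytes is roughly
+     # 1000 * 1 * (500 * 500 * 4) * 2 = 2.0 GB. 'auto' then picks 'device' if
+     # that fits within storage_bytes_limit_device (or the limit is unset),
+     # else 'cpu' if within storage_bytes_limit_host, else 'disk'.
+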
1504
+ # Callback fd_pad is the actual fd_pad used for wavefields
1505
+ callback_fd_pad = fd_pad_list
1506
+
1507
+ # Callback models dict
1508
+ callback_models = {
1509
+ "epsilon": epsilon_padded,
1510
+ "sigma": sigma_padded,
1511
+ "mu": mu_padded,
1512
+ "ca": ca,
1513
+ "cb": cb,
1514
+ "cq": cq,
1515
+ }
1516
+
1517
+ use_autograd_fn = (
1518
+ requires_grad
1519
+ and do_save_snapshots
1520
+ ) or functorch_active
1521
+ if use_autograd_fn:
1522
+ # Use autograd Function for gradient computation
1523
+ result = MaxwellTMForwardFunc.apply(
1524
+ ca,
1525
+ cb,
1526
+ cq,
1527
+ f,
1528
+ ay_flat,
1529
+ by_flat,
1530
+ ay_h_flat,
1531
+ by_h_flat,
1532
+ ax_flat,
1533
+ bx_flat,
1534
+ ax_h_flat,
1535
+ bx_h_flat,
1536
+ ky_flat,
1537
+ ky_h_flat,
1538
+ kx_flat,
1539
+ kx_h_flat,
1540
+ sources_i,
1541
+ receivers_i,
1542
+ 1.0 / dy, # rdy
1543
+ 1.0 / dx, # rdx
1544
+ dt,
1545
+ nt_steps,
1546
+ n_shots,
1547
+ padded_ny,
1548
+ padded_nx,
1549
+ n_sources,
1550
+ n_receivers,
1551
+ gradient_sampling_interval, # step_ratio
1552
+ stencil, # accuracy
1553
+ False, # ca_batched
1554
+ False, # cb_batched
1555
+ False, # cq_batched
1556
+ pml_y0,
1557
+ pml_x0,
1558
+ pml_y1,
1559
+ pml_x1,
1560
+ tuple(fd_pad_list), # fd_pad for callback
1561
+ tuple(pml_width_list), # pml_width for callback
1562
+ callback_models, # models dict for callback
1563
+ forward_callback,
1564
+ backward_callback,
1565
+ callback_frequency,
1566
+ effective_storage_mode_str,
1567
+ storage_path,
1568
+ storage_compression,
1569
+ Ey,
1570
+ Hx,
1571
+ Hz,
1572
+ m_Ey_x,
1573
+ m_Ey_z,
1574
+ m_Hx_z,
1575
+ m_Hz_x,
1576
+ n_threads_val,
1577
+ )
1578
+ # Unpack result (drop context handle if present)
1579
+ if len(result) == 9:
1580
+ (
1581
+ Ey_out,
1582
+ Hx_out,
1583
+ Hz_out,
1584
+ m_Ey_x_out,
1585
+ m_Ey_z_out,
1586
+ m_Hx_z_out,
1587
+ m_Hz_x_out,
1588
+ receiver_amplitudes,
1589
+ _ctx_handle,
1590
+ ) = result
1591
+ else:
1592
+ (
1593
+ Ey_out,
1594
+ Hx_out,
1595
+ Hz_out,
1596
+ m_Ey_x_out,
1597
+ m_Ey_z_out,
1598
+ m_Hx_z_out,
1599
+ m_Hz_x_out,
1600
+ receiver_amplitudes,
1601
+ ) = result
1602
+ # Output cropping: only remove fd_pad, keep PML region
1603
+ s = (
1604
+ slice(None), # batch dimension
1605
+ slice(
1606
+ fd_pad_list[0],
1607
+ padded_ny - fd_pad_list[1] if fd_pad_list[1] > 0 else None,
1608
+ ),
1609
+ slice(
1610
+ fd_pad_list[2],
1611
+ padded_nx - fd_pad_list[3] if fd_pad_list[3] > 0 else None,
1612
+ ),
1613
+ )
1614
+
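+ # Example (hypothetical sizes): with fd_pad_list = [2, 2, 2, 2] and a
+ # 104 x 104 padded grid, each field below is returned as a
+ # [n_shots, 100, 100] view; the PML layers stay inside that window.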
1615
+ return (
1616
+ Ey_out[s],
1617
+ Hx_out[s],
1618
+ Hz_out[s],
1619
+ m_Ey_x_out[s],
1620
+ m_Ey_z_out[s],
1621
+ m_Hx_z_out[s],
1622
+ m_Hz_x_out[s],
1623
+ receiver_amplitudes,
1624
+ )
1625
+ else:
1626
+ # Direct call without autograd for inference
1627
+ # Get the backend function
1628
+ try:
1629
+ forward_func = backend_utils.get_backend_function(
1630
+ "maxwell_tm", "forward", stencil, dtype, device
1631
+ )
1632
+ except AttributeError as e:
1633
+ raise RuntimeError(
1634
+ f"C/CUDA backend function not available for accuracy={stencil}, "
1635
+ f"dtype={dtype}, device={device}. Error: {e}"
1636
+ ) from e
1637
+
1638
+ # Get device index for CUDA
1639
+ device_idx = (
1640
+ device.index if device.type == "cuda" and device.index is not None else 0
1641
+ )
1642
+
1643
+ # Initialize receiver amplitudes
1644
+ if n_receivers > 0:
1645
+ receiver_amplitudes = torch.zeros(
1646
+ nt_steps, n_shots, n_receivers, device=device, dtype=dtype
1647
+ )
1648
+ else:
1649
+ receiver_amplitudes = torch.empty(0, device=device, dtype=dtype)
1650
+
1651
+ # If no callback is provided, run the entire propagation in a single call;
1652
+ # otherwise, chunk it into segments of callback_frequency steps.
1653
+ if forward_callback is None:
1654
+ effective_callback_freq = nt_steps
1655
+ else:
1656
+ effective_callback_freq = callback_frequency
1657
+
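+ # Chunking example (hypothetical numbers): nt_steps = 1000 with
+ # callback_frequency = 250 produces four backend calls covering steps
+ # [0, 250), [250, 500), [500, 750) and [750, 1000), with the callback
+ # firing before each chunk; without a callback there is one call
+ # covering all 1000 steps.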
1658
+ # Main time-stepping loop with chunked calls for callback support
1659
+ for step in range(0, nt_steps, effective_callback_freq):
1660
+ # Call callback at the start of each chunk
1661
+ if forward_callback is not None:
1662
+ callback_wavefields = {
1663
+ "Ey": Ey,
1664
+ "Hx": Hx,
1665
+ "Hz": Hz,
1666
+ "m_Ey_x": m_Ey_x,
1667
+ "m_Ey_z": m_Ey_z,
1668
+ "m_Hx_z": m_Hx_z,
1669
+ "m_Hz_x": m_Hz_x,
1670
+ }
1671
+ callback_state = CallbackState(
1672
+ dt=dt,
1673
+ step=step,
1674
+ nt=nt_steps,
1675
+ wavefields=callback_wavefields,
1676
+ models=callback_models,
1677
+ gradients=None,
1678
+ fd_pad=callback_fd_pad,
1679
+ pml_width=pml_width_list,
1680
+ is_backward=False,
1681
+ grid_spacing=[dy, dx],
1682
+ )
1683
+ forward_callback(callback_state)
1684
+
1685
+ # Number of steps to propagate in this chunk
1686
+ step_nt = min(nt_steps - step, effective_callback_freq)
1687
+
1688
+ # Call the C/CUDA function for this chunk
1689
+ forward_func(
1690
+ backend_utils.tensor_to_ptr(ca),
1691
+ backend_utils.tensor_to_ptr(cb),
1692
+ backend_utils.tensor_to_ptr(cq),
1693
+ backend_utils.tensor_to_ptr(f),
1694
+ backend_utils.tensor_to_ptr(Ey),
1695
+ backend_utils.tensor_to_ptr(Hx),
1696
+ backend_utils.tensor_to_ptr(Hz),
1697
+ backend_utils.tensor_to_ptr(m_Ey_x),
1698
+ backend_utils.tensor_to_ptr(m_Ey_z),
1699
+ backend_utils.tensor_to_ptr(m_Hx_z),
1700
+ backend_utils.tensor_to_ptr(m_Hz_x),
1701
+ backend_utils.tensor_to_ptr(receiver_amplitudes),
1702
+ backend_utils.tensor_to_ptr(ay_flat),
1703
+ backend_utils.tensor_to_ptr(by_flat),
1704
+ backend_utils.tensor_to_ptr(ay_h_flat),
1705
+ backend_utils.tensor_to_ptr(by_h_flat),
1706
+ backend_utils.tensor_to_ptr(ax_flat),
1707
+ backend_utils.tensor_to_ptr(bx_flat),
1708
+ backend_utils.tensor_to_ptr(ax_h_flat),
1709
+ backend_utils.tensor_to_ptr(bx_h_flat),
1710
+ backend_utils.tensor_to_ptr(ky_flat),
1711
+ backend_utils.tensor_to_ptr(ky_h_flat),
1712
+ backend_utils.tensor_to_ptr(kx_flat),
1713
+ backend_utils.tensor_to_ptr(kx_h_flat),
1714
+ backend_utils.tensor_to_ptr(sources_i),
1715
+ backend_utils.tensor_to_ptr(receivers_i),
1716
+ 1.0 / dy, # rdy
1717
+ 1.0 / dx, # rdx
1718
+ dt,
1719
+ step_nt, # nt for this chunk
1720
+ n_shots,
1721
+ padded_ny,
1722
+ padded_nx,
1723
+ n_sources,
1724
+ n_receivers,
1725
+ gradient_sampling_interval, # step_ratio
1726
+ False, # ca_batched
1727
+ False, # cb_batched
1728
+ False, # cq_batched
1729
+ step, # start_t - crucial for correct source injection timing
1730
+ pml_y0,
1731
+ pml_x0,
1732
+ pml_y1,
1733
+ pml_x1,
1734
+ n_threads_val,
1735
+ device_idx,
1736
+ )
1737
+
1738
+ # Output cropping: only remove fd_pad, keep PML region
1739
+ s = (
1740
+ slice(None), # batch dimension
1741
+ slice(
1742
+ fd_pad_list[0],
1743
+ padded_ny - fd_pad_list[1] if fd_pad_list[1] > 0 else None,
1744
+ ),
1745
+ slice(
1746
+ fd_pad_list[2],
1747
+ padded_nx - fd_pad_list[3] if fd_pad_list[3] > 0 else None,
1748
+ ),
1749
+ )
1750
+
1751
+ return (
1752
+ Ey[s],
1753
+ Hx[s],
1754
+ Hz[s],
1755
+ m_Ey_x[s],
1756
+ m_Ey_z[s],
1757
+ m_Hx_z[s],
1758
+ m_Hz_x[s],
1759
+ receiver_amplitudes,
1760
+ )
1761
+
1762
+
1763
+ class MaxwellTMForwardFunc(torch.autograd.Function):
1764
+ """Autograd function for the forward pass of Maxwell TM wave propagation.
1765
+
1766
+ This class defines the forward and backward passes for the 2D TM mode
1767
+ Maxwell equations, allowing PyTorch to compute gradients through the wave
1768
+ propagation operation. It interfaces directly with the C/CUDA backend.
1769
+
1770
+ The backward pass uses the Adjoint State Method (ASM), which requires
1771
+ storing forward wavefield values every step_ratio time steps for
1772
+ gradient computation.
1773
+ """
1774
+
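+ # Minimal usage sketch (illustrative only; in practice this Function is
+ # applied by the solver above with fully prepared, padded tensors):
+ #
+ # outputs = MaxwellTMForwardFunc.apply(ca, cb, cq, f, ...)
+ # receiver_amplitudes = outputs[7]
+ # receiver_amplitudes.sum().backward() # runs the ASM backward pass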
1775
+ @staticmethod
1776
+ def forward(
1777
+ ca: torch.Tensor,
1778
+ cb: torch.Tensor,
1779
+ cq: torch.Tensor,
1780
+ source_amplitudes_scaled: torch.Tensor,
1781
+ ay: torch.Tensor,
1782
+ by: torch.Tensor,
1783
+ ay_h: torch.Tensor,
1784
+ by_h: torch.Tensor,
1785
+ ax: torch.Tensor,
1786
+ bx: torch.Tensor,
1787
+ ax_h: torch.Tensor,
1788
+ bx_h: torch.Tensor,
1789
+ ky: torch.Tensor,
1790
+ ky_h: torch.Tensor,
1791
+ kx: torch.Tensor,
1792
+ kx_h: torch.Tensor,
1793
+ sources_i: torch.Tensor,
1794
+ receivers_i: torch.Tensor,
1795
+ rdy: float,
1796
+ rdx: float,
1797
+ dt: float,
1798
+ nt: int,
1799
+ n_shots: int,
1800
+ ny: int,
1801
+ nx: int,
1802
+ n_sources: int,
1803
+ n_receivers: int,
1804
+ step_ratio: int,
1805
+ accuracy: int,
1806
+ ca_batched: bool,
1807
+ cb_batched: bool,
1808
+ cq_batched: bool,
1809
+ pml_y0: int,
1810
+ pml_x0: int,
1811
+ pml_y1: int,
1812
+ pml_x1: int,
1813
+ fd_pad: tuple[int, int, int, int],
1814
+ pml_width: tuple[int, int, int, int],
1815
+ models: dict,
1816
+ forward_callback: Optional[Callback],
1817
+ backward_callback: Optional[Callback],
1818
+ callback_frequency: int,
1819
+ storage_mode_str: str,
1820
+ storage_path: str,
1821
+ storage_compression: Union[bool, str],
1822
+ Ey: torch.Tensor,
1823
+ Hx: torch.Tensor,
1824
+ Hz: torch.Tensor,
1825
+ m_Ey_x: torch.Tensor,
1826
+ m_Ey_z: torch.Tensor,
1827
+ m_Hx_z: torch.Tensor,
1828
+ m_Hz_x: torch.Tensor,
1829
+ n_threads: int,
1830
+ ) -> tuple[Any, ...]:
1831
+ """Performs the forward propagation of the Maxwell TM equations."""
1832
+ from . import backend_utils
1833
+
1834
+ device = Ey.device
1835
+ dtype = Ey.dtype
1836
+
1837
+ ca_requires_grad = ca.requires_grad
1838
+ cb_requires_grad = cb.requires_grad
1839
+ needs_grad = ca_requires_grad or cb_requires_grad
1840
+
1841
+ # Initialize receiver amplitudes
1842
+ if n_receivers > 0:
1843
+ receiver_amplitudes = torch.zeros(
1844
+ nt, n_shots, n_receivers, device=device, dtype=dtype
1845
+ )
1846
+ else:
1847
+ receiver_amplitudes = torch.empty(0, device=device, dtype=dtype)
1848
+
1849
+ # Get device index for CUDA
1850
+ device_idx = (
1851
+ device.index if device.type == "cuda" and device.index is not None else 0
1852
+ )
1853
+
1854
+ backward_storage_tensors: list[torch.Tensor] = []
1855
+ backward_storage_objects: list[Any] = []
1856
+ backward_storage_filename_arrays: list[Any] = []
1857
+ storage_mode = STORAGE_NONE
1858
+ shot_bytes_uncomp = 0
1859
+
1860
+ if needs_grad:
1861
+ import ctypes
1862
+
1863
+ # Resolve storage mode and sizes
1864
+ if device.type == "cpu" and storage_mode_str == "cpu":
1865
+ storage_mode_str = "device"
1866
+ storage_mode = storage_mode_to_int(storage_mode_str)
1867
+
1868
+ num_elements_per_shot = ny * nx
1869
+ _, store_dtype, _ = _resolve_storage_compression(
1870
+ storage_compression,
1871
+ dtype,
1872
+ device,
1873
+ context="storage_compression",
1874
+ )
1875
+
1876
+ shot_bytes_uncomp = num_elements_per_shot * store_dtype.itemsize
1877
+
1878
+ num_steps_stored = (nt + step_ratio - 1) // step_ratio
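+ # (ceiling division: e.g. nt = 10 with step_ratio = 4 stores 3 steps)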
1879
+
1880
+ # Storage buffers and filename arrays (mirrors Deepwave)
1881
+ char_ptr_type = ctypes.c_char_p
1882
+ is_cuda = device.type == "cuda"
1883
+
1884
+ def alloc_storage(requires_grad_cond: bool):
1885
+ store_1 = torch.empty(0)
1886
+ store_3 = torch.empty(0)
1887
+ filenames_arr = (char_ptr_type * 0)()
1888
+
1889
+ if requires_grad_cond and storage_mode != STORAGE_NONE:
1890
+ if storage_mode == STORAGE_DEVICE:
1891
+ store_1 = torch.empty(
1892
+ num_steps_stored,
1893
+ n_shots,
1894
+ ny,
1895
+ nx,
1896
+ device=device,
1897
+ dtype=store_dtype,
1898
+ )
1899
+ elif storage_mode == STORAGE_CPU:
1900
+ # Multi-buffer device staging to overlap D2H copies.
1901
+ store_1 = torch.empty(
1902
+ _CPU_STORAGE_BUFFERS,
1903
+ n_shots,
1904
+ ny,
1905
+ nx,
1906
+ device=device,
1907
+ dtype=store_dtype,
1908
+ )
1909
+ store_3 = torch.empty(
1910
+ num_steps_stored,
1911
+ n_shots,
1912
+ shot_bytes_uncomp // store_dtype.itemsize,
1913
+ device="cpu",
1914
+ pin_memory=True,
1915
+ dtype=store_dtype,
1916
+ )
1917
+ elif storage_mode == STORAGE_DISK:
1918
+ storage_obj = TemporaryStorage(
1919
+ storage_path, 1 if is_cuda else n_shots
1920
+ )
1921
+ backward_storage_objects.append(storage_obj)
1922
+ filenames_list = [
1923
+ f.encode("utf-8") for f in storage_obj.get_filenames()
1924
+ ]
1925
+ filenames_arr = (char_ptr_type * len(filenames_list))()
1926
+ for i_file, f_name in enumerate(filenames_list):
1927
+ filenames_arr[i_file] = ctypes.cast(
1928
+ ctypes.create_string_buffer(f_name), char_ptr_type
1929
+ )
1930
+
1931
+ store_1 = torch.empty(
1932
+ n_shots, ny, nx, device=device, dtype=store_dtype
1933
+ )
1934
+ if is_cuda:
1935
+ store_3 = torch.empty(
1936
+ n_shots,
1937
+ shot_bytes_uncomp // store_dtype.itemsize,
1938
+ device="cpu",
1939
+ pin_memory=True,
1940
+ dtype=store_dtype,
1941
+ )
1942
+
1943
+ backward_storage_tensors.extend([store_1, store_3])
1944
+ backward_storage_filename_arrays.append(filenames_arr)
1945
+
1946
+ filenames_ptr = (
1947
+ ctypes.cast(filenames_arr, ctypes.c_void_p)
1948
+ if storage_mode == STORAGE_DISK
1949
+ else 0
1950
+ )
1951
+
1952
+ return store_1, store_3, filenames_ptr
1953
+
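+ # Buffer layout, as allocated above: 'device' holds the full
+ # [num_steps_stored, n_shots, ny, nx] history on the compute device;
+ # 'cpu' holds a [_CPU_STORAGE_BUFFERS, n_shots, ny, nx] staging window
+ # on the device plus the full pinned history on the host; 'disk' holds
+ # a single [n_shots, ny, nx] staging tensor (with a pinned host bounce
+ # buffer on CUDA) and streams snapshots to the TemporaryStorage files.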
1954
+ ey_store_1, ey_store_3, ey_filenames_ptr = alloc_storage(ca_requires_grad)
1955
+ curl_store_1, curl_store_3, curl_filenames_ptr = alloc_storage(
1956
+ cb_requires_grad
1957
+ )
1958
+
1959
+ # Get the backend function with storage
1960
+ forward_func = backend_utils.get_backend_function(
1961
+ "maxwell_tm", "forward_with_storage", accuracy, dtype, device
1962
+ )
1963
+
1964
+ # Determine effective callback frequency
1965
+ if forward_callback is None:
1966
+ effective_callback_freq = nt // step_ratio
1967
+ else:
1968
+ effective_callback_freq = callback_frequency
1969
+
1970
+ # Chunked forward propagation with callback support
1971
+ for step in range(0, nt // step_ratio, effective_callback_freq):
1972
+ step_nt = (
1973
+ min(effective_callback_freq, nt // step_ratio - step) * step_ratio
1974
+ )
1975
+ start_t = step * step_ratio
1976
+
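+ # Chunk arithmetic (hypothetical numbers): nt = 1000 with
+ # step_ratio = 4 gives 250 stored steps; effective_callback_freq = 100
+ # yields chunks of step_nt = 400, 400 and 200 raw time steps with
+ # start_t = 0, 400 and 800.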
1977
+ # Call the C/CUDA function with storage for this chunk
1978
+ forward_func(
1979
+ backend_utils.tensor_to_ptr(ca),
1980
+ backend_utils.tensor_to_ptr(cb),
1981
+ backend_utils.tensor_to_ptr(cq),
1982
+ backend_utils.tensor_to_ptr(source_amplitudes_scaled),
1983
+ backend_utils.tensor_to_ptr(Ey),
1984
+ backend_utils.tensor_to_ptr(Hx),
1985
+ backend_utils.tensor_to_ptr(Hz),
1986
+ backend_utils.tensor_to_ptr(m_Ey_x),
1987
+ backend_utils.tensor_to_ptr(m_Ey_z),
1988
+ backend_utils.tensor_to_ptr(m_Hx_z),
1989
+ backend_utils.tensor_to_ptr(m_Hz_x),
1990
+ backend_utils.tensor_to_ptr(receiver_amplitudes),
1991
+ backend_utils.tensor_to_ptr(ey_store_1),
1992
+ backend_utils.tensor_to_ptr(ey_store_3),
1993
+ ey_filenames_ptr,
1994
+ backend_utils.tensor_to_ptr(curl_store_1),
1995
+ backend_utils.tensor_to_ptr(curl_store_3),
1996
+ curl_filenames_ptr,
1997
+ backend_utils.tensor_to_ptr(ay),
1998
+ backend_utils.tensor_to_ptr(by),
1999
+ backend_utils.tensor_to_ptr(ay_h),
2000
+ backend_utils.tensor_to_ptr(by_h),
2001
+ backend_utils.tensor_to_ptr(ax),
2002
+ backend_utils.tensor_to_ptr(bx),
2003
+ backend_utils.tensor_to_ptr(ax_h),
2004
+ backend_utils.tensor_to_ptr(bx_h),
2005
+ backend_utils.tensor_to_ptr(ky),
2006
+ backend_utils.tensor_to_ptr(ky_h),
2007
+ backend_utils.tensor_to_ptr(kx),
2008
+ backend_utils.tensor_to_ptr(kx_h),
2009
+ backend_utils.tensor_to_ptr(sources_i),
2010
+ backend_utils.tensor_to_ptr(receivers_i),
2011
+ rdy,
2012
+ rdx,
2013
+ dt,
2014
+ step_nt, # number of steps in this chunk
2015
+ n_shots,
2016
+ ny,
2017
+ nx,
2018
+ n_sources,
2019
+ n_receivers,
2020
+ step_ratio,
2021
+ storage_mode,
2022
+ shot_bytes_uncomp,
2023
+ ca_requires_grad,
2024
+ cb_requires_grad,
2025
+ ca_batched,
2026
+ cb_batched,
2027
+ cq_batched,
2028
+ start_t, # starting time step
2029
+ pml_y0,
2030
+ pml_x0,
2031
+ pml_y1,
2032
+ pml_x1,
2033
+ n_threads,
2034
+ device_idx,
2035
+ )
2036
+
2037
+ # Call forward callback after each chunk
2038
+ if forward_callback is not None:
2039
+ callback_wavefields = {
2040
+ "Ey": Ey,
2041
+ "Hx": Hx,
2042
+ "Hz": Hz,
2043
+ "m_Ey_x": m_Ey_x,
2044
+ "m_Ey_z": m_Ey_z,
2045
+ "m_Hx_z": m_Hx_z,
2046
+ "m_Hz_x": m_Hz_x,
2047
+ }
2048
+ forward_callback(
2049
+ CallbackState(
2050
+ dt=dt,
2051
+ step=step + step_nt // step_ratio,
2052
+ nt=nt // step_ratio,
2053
+ wavefields=callback_wavefields,
2054
+ models=models,
2055
+ gradients={},
2056
+ fd_pad=list(fd_pad),
2057
+ pml_width=list(pml_width),
2058
+ is_backward=False,
2059
+ )
2060
+ )
2061
+ else:
2062
+ # Use regular forward without storage
2063
+ forward_func = backend_utils.get_backend_function(
2064
+ "maxwell_tm", "forward", accuracy, dtype, device
2065
+ )
2066
+
2067
+ # Call the C/CUDA function
2068
+ forward_func(
2069
+ backend_utils.tensor_to_ptr(ca),
2070
+ backend_utils.tensor_to_ptr(cb),
2071
+ backend_utils.tensor_to_ptr(cq),
2072
+ backend_utils.tensor_to_ptr(source_amplitudes_scaled),
2073
+ backend_utils.tensor_to_ptr(Ey),
2074
+ backend_utils.tensor_to_ptr(Hx),
2075
+ backend_utils.tensor_to_ptr(Hz),
2076
+ backend_utils.tensor_to_ptr(m_Ey_x),
2077
+ backend_utils.tensor_to_ptr(m_Ey_z),
2078
+ backend_utils.tensor_to_ptr(m_Hx_z),
2079
+ backend_utils.tensor_to_ptr(m_Hz_x),
2080
+ backend_utils.tensor_to_ptr(receiver_amplitudes),
2081
+ backend_utils.tensor_to_ptr(ay),
2082
+ backend_utils.tensor_to_ptr(by),
2083
+ backend_utils.tensor_to_ptr(ay_h),
2084
+ backend_utils.tensor_to_ptr(by_h),
2085
+ backend_utils.tensor_to_ptr(ax),
2086
+ backend_utils.tensor_to_ptr(bx),
2087
+ backend_utils.tensor_to_ptr(ax_h),
2088
+ backend_utils.tensor_to_ptr(bx_h),
2089
+ backend_utils.tensor_to_ptr(ky),
2090
+ backend_utils.tensor_to_ptr(ky_h),
2091
+ backend_utils.tensor_to_ptr(kx),
2092
+ backend_utils.tensor_to_ptr(kx_h),
2093
+ backend_utils.tensor_to_ptr(sources_i),
2094
+ backend_utils.tensor_to_ptr(receivers_i),
2095
+ rdy,
2096
+ rdx,
2097
+ dt,
2098
+ nt,
2099
+ n_shots,
2100
+ ny,
2101
+ nx,
2102
+ n_sources,
2103
+ n_receivers,
2104
+ step_ratio,
2105
+ ca_batched,
2106
+ cb_batched,
2107
+ cq_batched,
2108
+ 0, # start_t
2109
+ pml_y0,
2110
+ pml_x0,
2111
+ pml_y1,
2112
+ pml_x1,
2113
+ n_threads,
2114
+ device_idx,
2115
+ )
2116
+
2117
+ ctx_data = {
2118
+ "backward_storage_tensors": backward_storage_tensors,
2119
+ "backward_storage_objects": backward_storage_objects,
2120
+ "backward_storage_filename_arrays": backward_storage_filename_arrays,
2121
+ "storage_mode": storage_mode,
2122
+ "shot_bytes_uncomp": shot_bytes_uncomp,
2123
+ "source_amplitudes_scaled": source_amplitudes_scaled,
2124
+ "ca_requires_grad": ca_requires_grad,
2125
+ "cb_requires_grad": cb_requires_grad,
2126
+ }
2127
+ ctx_handle = _register_ctx_handle(ctx_data)
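+ # Handle lifecycle: the dict above is parked in a module-level registry
+ # keyed by this handle; setup_context() fetches it to populate ctx, and
+ # backward() releases it once the gradients have been produced.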
2128
+
2129
+ return (
2130
+ Ey,
2131
+ Hx,
2132
+ Hz,
2133
+ m_Ey_x,
2134
+ m_Ey_z,
2135
+ m_Hx_z,
2136
+ m_Hz_x,
2137
+ receiver_amplitudes,
2138
+ ctx_handle,
2139
+ )
2140
+
2141
+ @staticmethod
2142
+ def setup_context(ctx: Any, inputs: tuple[Any, ...], output: Any) -> None:
2143
+ (
2144
+ ca,
2145
+ cb,
2146
+ cq,
2147
+ _source_amplitudes_scaled,
2148
+ ay,
2149
+ by,
2150
+ ay_h,
2151
+ by_h,
2152
+ ax,
2153
+ bx,
2154
+ ax_h,
2155
+ bx_h,
2156
+ ky,
2157
+ ky_h,
2158
+ kx,
2159
+ kx_h,
2160
+ sources_i,
2161
+ receivers_i,
2162
+ rdy,
2163
+ rdx,
2164
+ dt,
2165
+ nt,
2166
+ n_shots,
2167
+ ny,
2168
+ nx,
2169
+ n_sources,
2170
+ n_receivers,
2171
+ step_ratio,
2172
+ accuracy,
2173
+ ca_batched,
2174
+ cb_batched,
2175
+ cq_batched,
2176
+ pml_y0,
2177
+ pml_x0,
2178
+ pml_y1,
2179
+ pml_x1,
2180
+ fd_pad,
2181
+ pml_width,
2182
+ models,
2183
+ _forward_callback,
2184
+ backward_callback,
2185
+ callback_frequency,
2186
+ _storage_mode_str,
2187
+ _storage_path,
2188
+ _storage_compression,
2189
+ _Ey,
2190
+ _Hx,
2191
+ _Hz,
2192
+ _m_Ey_x,
2193
+ _m_Ey_z,
2194
+ _m_Hx_z,
2195
+ _m_Hz_x,
2196
+ n_threads,
2197
+ ) = inputs
2198
+
2199
+ outputs = output if isinstance(output, tuple) else (output,)
2200
+ if len(outputs) != 9:
2201
+ raise RuntimeError(
2202
+ "MaxwellTMForwardFunc expected a context handle output for setup_context."
2203
+ )
2204
+ ctx_handle = outputs[-1]
2205
+ if not isinstance(ctx_handle, torch.Tensor):
2206
+ raise RuntimeError("MaxwellTMForwardFunc context handle must be a Tensor.")
2207
+
2208
+ ctx_handle_id = int(ctx_handle.item())
2209
+ ctx_data = _get_ctx_handle(ctx_handle_id)
2210
+ ctx._ctx_handle_id = ctx_handle_id
2211
+ backward_storage_tensors = ctx_data["backward_storage_tensors"]
2212
+
2213
+ ctx.save_for_backward(
2214
+ ca,
2215
+ cb,
2216
+ cq,
2217
+ ay,
2218
+ by,
2219
+ ay_h,
2220
+ by_h,
2221
+ ax,
2222
+ bx,
2223
+ ax_h,
2224
+ bx_h,
2225
+ ky,
2226
+ ky_h,
2227
+ kx,
2228
+ kx_h,
2229
+ sources_i,
2230
+ receivers_i,
2231
+ *backward_storage_tensors,
2232
+ )
2233
+ ctx.save_for_forward(
2234
+ ca,
2235
+ cb,
2236
+ cq,
2237
+ ay,
2238
+ by,
2239
+ ay_h,
2240
+ by_h,
2241
+ ax,
2242
+ bx,
2243
+ ax_h,
2244
+ bx_h,
2245
+ ky,
2246
+ ky_h,
2247
+ kx,
2248
+ kx_h,
2249
+ sources_i,
2250
+ receivers_i,
2251
+ )
2252
+ ctx.backward_storage_objects = ctx_data["backward_storage_objects"]
2253
+ ctx.backward_storage_filename_arrays = ctx_data[
2254
+ "backward_storage_filename_arrays"
2255
+ ]
2256
+ ctx.rdy = rdy
2257
+ ctx.rdx = rdx
2258
+ ctx.dt = dt
2259
+ ctx.nt = nt
2260
+ ctx.n_shots = n_shots
2261
+ ctx.ny = ny
2262
+ ctx.nx = nx
2263
+ ctx.n_sources = n_sources
2264
+ ctx.n_receivers = n_receivers
2265
+ ctx.step_ratio = step_ratio
2266
+ ctx.accuracy = accuracy
2267
+ ctx.ca_batched = ca_batched
2268
+ ctx.cb_batched = cb_batched
2269
+ ctx.cq_batched = cq_batched
2270
+ ctx.pml_y0 = pml_y0
2271
+ ctx.pml_x0 = pml_x0
2272
+ ctx.pml_y1 = pml_y1
2273
+ ctx.pml_x1 = pml_x1
2274
+ ctx.ca_requires_grad = ctx_data["ca_requires_grad"]
2275
+ ctx.cb_requires_grad = ctx_data["cb_requires_grad"]
2276
+ ctx.storage_mode = ctx_data["storage_mode"]
2277
+ ctx.shot_bytes_uncomp = ctx_data["shot_bytes_uncomp"]
2278
+ ctx.fd_pad = fd_pad
2279
+ ctx.pml_width = pml_width
2280
+ ctx.models = models
2281
+ ctx.backward_callback = backward_callback
2282
+ ctx.callback_frequency = callback_frequency
2283
+ ctx.source_amplitudes_scaled = ctx_data["source_amplitudes_scaled"]
2284
+ ctx.n_threads = n_threads
2285
+
2286
+ @staticmethod
2287
+ def backward(
2288
+ ctx: Any, *grad_outputs: torch.Tensor
2289
+ ) -> tuple[Optional[torch.Tensor], ...]:
2290
+ """Computes the gradients during the backward pass using ASM.
2291
+
2292
+ Uses the Adjoint State Method (ASM) to compute gradients:
2293
+ - grad_ca = sum_n (Ey^n * lambda_Ey^{n+1})
2294
+ - grad_cb = sum_n (curl_H^n * lambda_Ey^{n+1})
2295
+
2296
+ Args:
2297
+ ctx: A context object containing information saved during forward.
2298
+ grad_outputs: Gradients of the loss with respect to the outputs.
2299
+
2300
+ Returns:
2301
+ Gradients with respect to the inputs of the forward pass.
2302
+ """
2303
+ from . import backend_utils
2304
+
2305
+ grad_outputs_list = list(grad_outputs)
2306
+ if len(grad_outputs_list) == 9:
2307
+ grad_outputs_list.pop() # drop context handle grad
2308
+
2309
+ # Unpack grad_outputs
2310
+ (
2311
+ grad_Ey,
2312
+ grad_Hx,
2313
+ grad_Hz,
2314
+ grad_m_Ey_x,
2315
+ grad_m_Ey_z,
2316
+ grad_m_Hx_z,
2317
+ grad_m_Hz_x,
2318
+ grad_r,
2319
+ ) = grad_outputs_list
2320
+
2321
+ # Retrieve saved tensors
2322
+ saved = ctx.saved_tensors
2323
+ ca, cb, cq = saved[0], saved[1], saved[2]
2324
+ ay, by, ay_h, by_h = saved[3], saved[4], saved[5], saved[6]
2325
+ ax, bx, ax_h, bx_h = saved[7], saved[8], saved[9], saved[10]
2326
+ ky, ky_h, kx, kx_h = saved[11], saved[12], saved[13], saved[14]
2327
+ sources_i, receivers_i = saved[15], saved[16]
2328
+ ey_store_1, ey_store_3 = saved[17], saved[18]
2329
+ curl_store_1, curl_store_3 = saved[19], saved[20]
2330
+
2331
+ device = ca.device
2332
+ dtype = ca.dtype
2333
+
2334
+ rdy = ctx.rdy
2335
+ rdx = ctx.rdx
2336
+ dt = ctx.dt
2337
+ nt = ctx.nt
2338
+ n_shots = ctx.n_shots
2339
+ ny = ctx.ny
2340
+ nx = ctx.nx
2341
+ n_sources = ctx.n_sources
2342
+ n_receivers = ctx.n_receivers
2343
+ step_ratio = ctx.step_ratio
2344
+ accuracy = ctx.accuracy
2345
+ ca_batched = ctx.ca_batched
2346
+ cb_batched = ctx.cb_batched
2347
+ cq_batched = ctx.cq_batched
2348
+ pml_y0 = ctx.pml_y0
2349
+ pml_x0 = ctx.pml_x0
2350
+ pml_y1 = ctx.pml_y1
2351
+ pml_x1 = ctx.pml_x1
2352
+ ca_requires_grad = ctx.ca_requires_grad
2353
+ cb_requires_grad = ctx.cb_requires_grad
2354
+ pml_width = ctx.pml_width
2355
+ storage_mode = ctx.storage_mode
2356
+ shot_bytes_uncomp = ctx.shot_bytes_uncomp
2357
+
2358
+ import ctypes
2359
+
2360
+ if storage_mode == STORAGE_DISK:
2361
+ ey_filenames_ptr = ctypes.cast(
2362
+ ctx.backward_storage_filename_arrays[0], ctypes.c_void_p
2363
+ )
2364
+ curl_filenames_ptr = ctypes.cast(
2365
+ ctx.backward_storage_filename_arrays[1], ctypes.c_void_p
2366
+ )
2367
+ else:
2368
+ ey_filenames_ptr = 0
2369
+ curl_filenames_ptr = 0
2370
+
2371
+ # Recalculate PML boundaries for gradient accumulation
2372
+ #
2373
+ # For staggered-grid schemes, the backward pass uses an extended PML region
2374
+ # compared to the forward pass, because the adjoint updates take spatial
2375
+ # derivatives of terms that are themselves spatial derivatives.
2376
+ #
2377
+ # In tide, the padded domain includes both fd_pad and pml_width:
2378
+ # - pml_y0 = fd_pad + pml_width (start of interior, from forward)
2379
+ # - pml_y1 = ny - (fd_pad-1) - pml_width (end of interior, from forward)
2380
+ #
2381
+ # The gradient accumulation region is controlled by loop bounds in C/CUDA
2382
+ # with pml_bounds array and 3-region loop.
2383
+
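+ # Worked example of those bounds (hypothetical numbers): fd_pad = 2,
+ # pml_width = 20 and ny = 144 give pml_y0 = 2 + 20 = 22 and
+ # pml_y1 = 144 - 1 - 20 = 123, so rows [22, 123) form the interior.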
2384
+ # Ensure grad_r is contiguous
2385
+ if grad_r is None or grad_r.numel() == 0:
2386
+ grad_r = torch.zeros(nt, n_shots, n_receivers, device=device, dtype=dtype)
2387
+ else:
2388
+ grad_r = grad_r.contiguous()
2389
+
2390
+ # Initialize adjoint fields (lambda fields)
2391
+ lambda_ey = torch.zeros(n_shots, ny, nx, device=device, dtype=dtype)
2392
+ lambda_hx = torch.zeros(n_shots, ny, nx, device=device, dtype=dtype)
2393
+ lambda_hz = torch.zeros(n_shots, ny, nx, device=device, dtype=dtype)
2394
+
2395
+ # Initialize adjoint PML memory variables
2396
+ m_lambda_ey_x = torch.zeros(n_shots, ny, nx, device=device, dtype=dtype)
2397
+ m_lambda_ey_z = torch.zeros(n_shots, ny, nx, device=device, dtype=dtype)
2398
+ m_lambda_hx_z = torch.zeros(n_shots, ny, nx, device=device, dtype=dtype)
2399
+ m_lambda_hz_x = torch.zeros(n_shots, ny, nx, device=device, dtype=dtype)
2400
+
2401
+ # Allocate gradient outputs
2402
+ if n_sources > 0:
2403
+ grad_f = torch.zeros(nt, n_shots, n_sources, device=device, dtype=dtype)
2404
+ else:
2405
+ grad_f = torch.empty(0, device=device, dtype=dtype)
2406
+
2407
+ if ca_requires_grad:
2408
+ if ca_batched:
2409
+ grad_ca = torch.zeros(n_shots, ny, nx, device=device, dtype=dtype)
2410
+ else:
2411
+ grad_ca = torch.zeros(ny, nx, device=device, dtype=dtype)
2412
+ # Per-shot workspace for gradient accumulation (needed for CUDA)
2413
+ grad_ca_shot = torch.zeros(n_shots, ny, nx, device=device, dtype=dtype)
2414
+ else:
2415
+ grad_ca = torch.empty(0, device=device, dtype=dtype)
2416
+ grad_ca_shot = torch.empty(0, device=device, dtype=dtype)
2417
+
2418
+ if cb_requires_grad:
2419
+ if cb_batched:
2420
+ grad_cb = torch.zeros(n_shots, ny, nx, device=device, dtype=dtype)
2421
+ else:
2422
+ grad_cb = torch.zeros(ny, nx, device=device, dtype=dtype)
2423
+ # Per-shot workspace for gradient accumulation (needed for CUDA)
2424
+ grad_cb_shot = torch.zeros(n_shots, ny, nx, device=device, dtype=dtype)
2425
+ else:
2426
+ grad_cb = torch.empty(0, device=device, dtype=dtype)
2427
+ grad_cb_shot = torch.empty(0, device=device, dtype=dtype)
2428
+
2429
+ if ca_requires_grad or cb_requires_grad:
2430
+ if ca_batched:
2431
+ grad_eps = torch.zeros(n_shots, ny, nx, device=device, dtype=dtype)
2432
+ grad_sigma = torch.zeros(n_shots, ny, nx, device=device, dtype=dtype)
2433
+ else:
2434
+ grad_eps = torch.zeros(ny, nx, device=device, dtype=dtype)
2435
+ grad_sigma = torch.zeros(ny, nx, device=device, dtype=dtype)
2436
+ else:
2437
+ grad_eps = torch.empty(0, device=device, dtype=dtype)
2438
+ grad_sigma = torch.empty(0, device=device, dtype=dtype)
2439
+
2440
+ # Get device index for CUDA
2441
+ device_idx = (
2442
+ device.index if device.type == "cuda" and device.index is not None else 0
2443
+ )
2444
+
2445
+ # Get callback-related context
2446
+ backward_callback = ctx.backward_callback
2447
+ callback_frequency = ctx.callback_frequency
2448
+ fd_pad_ctx = ctx.fd_pad
2449
+ models = ctx.models
2450
+ n_threads = ctx.n_threads
2451
+
2452
+ # Get the backend function
2453
+ backward_func = backend_utils.get_backend_function(
2454
+ "maxwell_tm", "backward", accuracy, dtype, device
2455
+ )
2456
+
2457
+ # Determine effective callback frequency
2458
+ if backward_callback is None:
2459
+ effective_callback_freq = nt // step_ratio
2460
+ else:
2461
+ effective_callback_freq = callback_frequency
2462
+
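+ # Reverse chunking example (hypothetical numbers): with
+ # nt // step_ratio = 250 and effective_callback_freq = 100, the loop
+ # visits step = 250, 150, 50 and runs chunks of 100, 100 and 50 stored
+ # steps, walking the adjoint fields back toward t = 0.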
2463
+ # Chunked backward propagation with callback support
2464
+ # Backward propagation goes from nt to 0
2465
+ for step in range(nt // step_ratio, 0, -effective_callback_freq):
2466
+ step_nt = min(step, effective_callback_freq) * step_ratio
2467
+ start_t = step * step_ratio
2468
+
2469
+ # Call the C/CUDA backward function for this chunk
2470
+ backward_func(
2471
+ backend_utils.tensor_to_ptr(ca),
2472
+ backend_utils.tensor_to_ptr(cb),
2473
+ backend_utils.tensor_to_ptr(cq),
2474
+ backend_utils.tensor_to_ptr(grad_r),
2475
+ backend_utils.tensor_to_ptr(lambda_ey),
2476
+ backend_utils.tensor_to_ptr(lambda_hx),
2477
+ backend_utils.tensor_to_ptr(lambda_hz),
2478
+ backend_utils.tensor_to_ptr(m_lambda_ey_x),
2479
+ backend_utils.tensor_to_ptr(m_lambda_ey_z),
2480
+ backend_utils.tensor_to_ptr(m_lambda_hx_z),
2481
+ backend_utils.tensor_to_ptr(m_lambda_hz_x),
2482
+ backend_utils.tensor_to_ptr(ey_store_1),
2483
+ backend_utils.tensor_to_ptr(ey_store_3),
2484
+ ey_filenames_ptr,
2485
+ backend_utils.tensor_to_ptr(curl_store_1),
2486
+ backend_utils.tensor_to_ptr(curl_store_3),
2487
+ curl_filenames_ptr,
2488
+ backend_utils.tensor_to_ptr(grad_f),
2489
+ backend_utils.tensor_to_ptr(grad_ca),
2490
+ backend_utils.tensor_to_ptr(grad_cb),
2491
+ backend_utils.tensor_to_ptr(grad_eps),
2492
+ backend_utils.tensor_to_ptr(grad_sigma),
2493
+ backend_utils.tensor_to_ptr(grad_ca_shot),
2494
+ backend_utils.tensor_to_ptr(grad_cb_shot),
2495
+ backend_utils.tensor_to_ptr(ay),
2496
+ backend_utils.tensor_to_ptr(by),
2497
+ backend_utils.tensor_to_ptr(ay_h),
2498
+ backend_utils.tensor_to_ptr(by_h),
2499
+ backend_utils.tensor_to_ptr(ax),
2500
+ backend_utils.tensor_to_ptr(bx),
2501
+ backend_utils.tensor_to_ptr(ax_h),
2502
+ backend_utils.tensor_to_ptr(bx_h),
2503
+ backend_utils.tensor_to_ptr(ky),
2504
+ backend_utils.tensor_to_ptr(ky_h),
2505
+ backend_utils.tensor_to_ptr(kx),
2506
+ backend_utils.tensor_to_ptr(kx_h),
2507
+ backend_utils.tensor_to_ptr(sources_i),
2508
+ backend_utils.tensor_to_ptr(receivers_i),
2509
+ rdy,
2510
+ rdx,
2511
+ dt,
2512
+ step_nt, # number of steps to run in this chunk
2513
+ n_shots,
2514
+ ny,
2515
+ nx,
2516
+ n_sources,
2517
+ n_receivers,
2518
+ step_ratio,
2519
+ storage_mode,
2520
+ shot_bytes_uncomp,
2521
+ ca_requires_grad,
2522
+ cb_requires_grad,
2523
+ ca_batched,
2524
+ cb_batched,
2525
+ cq_batched,
2526
+ start_t, # upper time step of this chunk (backward runs from here toward 0)
2527
+ pml_y0, # Use original PML boundaries for adjoint propagation
2528
+ pml_x0,
2529
+ pml_y1,
2530
+ pml_x1,
2531
+ n_threads,
2532
+ device_idx,
2533
+ )
2534
+
2535
+ # Call backward callback after each chunk
2536
+ if backward_callback is not None:
2537
+ # The time step index is step - 1 because the callback is
2538
+ # executed after the calculations for the current backward
2539
+ # step are complete
2540
+ callback_wavefields = {
2541
+ "lambda_Ey": lambda_ey,
2542
+ "lambda_Hx": lambda_hx,
2543
+ "lambda_Hz": lambda_hz,
2544
+ "m_lambda_Ey_x": m_lambda_ey_x,
2545
+ "m_lambda_Ey_z": m_lambda_ey_z,
2546
+ "m_lambda_Hx_z": m_lambda_hx_z,
2547
+ "m_lambda_Hz_x": m_lambda_hz_x,
2548
+ }
2549
+ callback_gradients = {}
2550
+ if ca_requires_grad:
2551
+ callback_gradients["ca"] = grad_ca
2552
+ if cb_requires_grad:
2553
+ callback_gradients["cb"] = grad_cb
2554
+ if ca_requires_grad or cb_requires_grad:
2555
+ callback_gradients["epsilon"] = grad_eps
2556
+ callback_gradients["sigma"] = grad_sigma
2557
+
2558
+ backward_callback(
2559
+ CallbackState(
2560
+ dt=dt,
2561
+ step=step - 1,
2562
+ nt=nt // step_ratio,
2563
+ wavefields=callback_wavefields,
2564
+ models=models,
2565
+ gradients=callback_gradients,
2566
+ fd_pad=list(fd_pad_ctx),
2567
+ pml_width=list(pml_width),
2568
+ is_backward=True,
2569
+ )
2570
+ )
2571
+
2572
+ # Return gradients for all inputs
2573
+ # Order: ca, cb, cq, source_amplitudes_scaled,
2574
+ # ay, by, ay_h, by_h, ax, bx, ax_h, bx_h,
2575
+ # ky, ky_h, kx, kx_h,
2576
+ # sources_i, receivers_i,
2577
+ # rdy, rdx, dt, nt, n_shots, ny, nx, n_sources, n_receivers,
2578
+ # step_ratio, accuracy, ca_batched, cb_batched, cq_batched,
2579
+ # pml_y0, pml_x0, pml_y1, pml_x1,
2580
+ # fd_pad, pml_width, models, backward_callback, callback_frequency,
2581
+ # Ey, Hx, Hz, m_Ey_x, m_Ey_z, m_Hx_z, m_Hz_x
2582
+
2583
+ # Flatten grad_f to match input shape [nt * n_shots * n_sources]
2584
+ if n_sources > 0:
2585
+ grad_f_flat = grad_f.reshape(nt * n_shots * n_sources)
2586
+ else:
2587
+ grad_f_flat = None
2588
+
2589
+ # Match gradient shapes to input shapes
2590
+ # Input ca, cb are [1, ny, nx] but grad_ca, grad_cb are [ny, nx] when not batched
2591
+ if ca_requires_grad and not ca_batched:
2592
+ grad_ca = grad_ca.unsqueeze(0) # [ny, nx] -> [1, ny, nx]
2593
+ if cb_requires_grad and not cb_batched:
2594
+ grad_cb = grad_cb.unsqueeze(0) # [ny, nx] -> [1, ny, nx]
2595
+
2596
+ _release_ctx_handle(getattr(ctx, "_ctx_handle_id", None))
2597
+ return (
2598
+ grad_ca if ca_requires_grad else None, # ca
2599
+ grad_cb if cb_requires_grad else None, # cb
2600
+ None, # cq
2601
+ grad_f_flat, # source_amplitudes_scaled
2602
+ None,
2603
+ None,
2604
+ None,
2605
+ None, # ay, by, ay_h, by_h
2606
+ None,
2607
+ None,
2608
+ None,
2609
+ None, # ax, bx, ax_h, bx_h
2610
+ None,
2611
+ None,
2612
+ None,
2613
+ None, # ky, ky_h, kx, kx_h
2614
+ None,
2615
+ None, # sources_i, receivers_i
2616
+ None,
2617
+ None,
2618
+ None, # rdy, rdx, dt
2619
+ None,
2620
+ None,
2621
+ None,
2622
+ None, # nt, n_shots, ny, nx
2623
+ None,
2624
+ None, # n_sources, n_receivers
2625
+ None, # step_ratio
2626
+ None, # accuracy
2627
+ None,
2628
+ None,
2629
+ None, # ca_batched, cb_batched, cq_batched
2630
+ None,
2631
+ None,
2632
+ None,
2633
+ None, # pml_y0, pml_x0, pml_y1, pml_x1
2634
+ None,
2635
+ None,
2636
+ None, # fd_pad, pml_width, models
2637
+ None,
2638
+ None,
2639
+ None, # forward_callback, backward_callback, callback_frequency
2640
+ None,
2641
+ None,
2642
+ None, # storage_mode_str, storage_path, storage_compression
2643
+ None,
2644
+ None,
2645
+ None, # Ey, Hx, Hz
2646
+ None,
2647
+ None,
2648
+ None,
2649
+ None, # m_Ey_x, m_Ey_z, m_Hx_z, m_Hz_x
2650
+ None, # n_threads
2651
+ )