tide-GPR 0.0.9__py3-none-manylinux_2_28_x86_64.whl
- tide/__init__.py +65 -0
- tide/autograd_utils.py +26 -0
- tide/backend_utils.py +536 -0
- tide/callbacks.py +348 -0
- tide/cfl.py +64 -0
- tide/csrc/CMakeLists.txt +263 -0
- tide/csrc/common_cpu.h +31 -0
- tide/csrc/common_gpu.h +56 -0
- tide/csrc/maxwell.c +2133 -0
- tide/csrc/maxwell.cu +2297 -0
- tide/csrc/maxwell_born.cu +0 -0
- tide/csrc/staggered_grid.h +175 -0
- tide/csrc/staggered_grid_3d.h +124 -0
- tide/csrc/storage_utils.c +78 -0
- tide/csrc/storage_utils.cu +135 -0
- tide/csrc/storage_utils.h +36 -0
- tide/grid_utils.py +31 -0
- tide/maxwell.py +2651 -0
- tide/padding.py +139 -0
- tide/resampling.py +246 -0
- tide/staggered.py +567 -0
- tide/storage.py +131 -0
- tide/tide/libtide_C.so +0 -0
- tide/utils.py +274 -0
- tide/validation.py +71 -0
- tide/wavelets.py +72 -0
- tide_gpr-0.0.9.dist-info/METADATA +256 -0
- tide_gpr-0.0.9.dist-info/RECORD +31 -0
- tide_gpr-0.0.9.dist-info/WHEEL +5 -0
- tide_gpr-0.0.9.dist-info/licenses/LICENSE +46 -0
- tide_gpr.libs/libgomp-24e2ab19.so.1.0.0 +0 -0
tide/maxwell.py
ADDED
|
@@ -0,0 +1,2651 @@
|
|
|
1
|
+
from typing import Any, Callable, Optional, Sequence, Union
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
|
|
5
|
+
from . import staggered
|
|
6
|
+
from .autograd_utils import (
|
|
7
|
+
_get_ctx_handle,
|
|
8
|
+
_register_ctx_handle,
|
|
9
|
+
_release_ctx_handle,
|
|
10
|
+
)
|
|
11
|
+
from .callbacks import Callback, CallbackState
|
|
12
|
+
from .cfl import cfl_condition
|
|
13
|
+
from .grid_utils import (
|
|
14
|
+
_normalize_grid_spacing_2d,
|
|
15
|
+
_normalize_pml_width_2d,
|
|
16
|
+
)
|
|
17
|
+
from .resampling import downsample_and_movedim, upsample
|
|
18
|
+
from .storage import (
|
|
19
|
+
_CPU_STORAGE_BUFFERS,
|
|
20
|
+
STORAGE_CPU,
|
|
21
|
+
STORAGE_DEVICE,
|
|
22
|
+
STORAGE_DISK,
|
|
23
|
+
STORAGE_NONE,
|
|
24
|
+
TemporaryStorage,
|
|
25
|
+
_normalize_storage_compression,
|
|
26
|
+
_resolve_storage_compression,
|
|
27
|
+
storage_mode_to_int,
|
|
28
|
+
)
|
|
29
|
+
from .utils import C0, prepare_parameters
|
|
30
|
+
from .validation import (
|
|
31
|
+
validate_freq_taper_frac,
|
|
32
|
+
validate_model_gradient_sampling_interval,
|
|
33
|
+
validate_time_pad_frac,
|
|
34
|
+
)
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class MaxwellTM(torch.nn.Module):
|
|
38
|
+
"""2D TM mode Maxwell equations solver using FDTD method.
|
|
39
|
+
|
|
40
|
+
This module solves the TM (Transverse Magnetic) mode Maxwell equations
|
|
41
|
+
in 2D with fields (Ey, Hx, Hz) using the FDTD method with CPML absorbing
|
|
42
|
+
boundary conditions.
|
|
43
|
+
|
|
44
|
+
Args:
|
|
45
|
+
epsilon: Relative permittivity tensor [ny, nx].
|
|
46
|
+
For vacuum/air, use 1.0. For common materials:
|
|
47
|
+
- Water: ~80
|
|
48
|
+
- Glass: ~4-7
|
|
49
|
+
- Soil (dry): ~3-5
|
|
50
|
+
- Concrete: ~4-8
|
|
51
|
+
sigma: Electrical conductivity tensor [ny, nx] in S/m.
|
|
52
|
+
For lossless media, use 0.0.
|
|
53
|
+
mu: Relative permeability tensor [ny, nx].
|
|
54
|
+
For most non-magnetic materials, use 1.0.
|
|
55
|
+
grid_spacing: Grid spacing in meters. Can be a single value (same for
|
|
56
|
+
both directions) or a sequence [dy, dx].
|
|
57
|
+
epsilon_requires_grad: Whether to compute gradients for permittivity.
|
|
58
|
+
sigma_requires_grad: Whether to compute gradients for conductivity.
|
|
59
|
+
|
|
60
|
+
Note:
|
|
61
|
+
The input parameters are RELATIVE values (dimensionless). They will be
|
|
62
|
+
multiplied internally by the vacuum permittivity (ε₀ = 8.854e-12 F/m)
|
|
63
|
+
and vacuum permeability (μ₀ = 1.257e-6 H/m) respectively.
|
|
64
|
+
"""
|
|
65
|
+
|
|
66
|
+
def __init__(
|
|
67
|
+
self,
|
|
68
|
+
epsilon: torch.Tensor,
|
|
69
|
+
sigma: torch.Tensor,
|
|
70
|
+
mu: torch.Tensor,
|
|
71
|
+
grid_spacing: Union[float, Sequence[float]],
|
|
72
|
+
epsilon_requires_grad: Optional[bool] = None,
|
|
73
|
+
sigma_requires_grad: Optional[bool] = None,
|
|
74
|
+
) -> None:
|
|
75
|
+
super().__init__()
|
|
76
|
+
if epsilon_requires_grad is not None and not isinstance(
|
|
77
|
+
epsilon_requires_grad, bool
|
|
78
|
+
):
|
|
79
|
+
raise TypeError(
|
|
80
|
+
f"epsilon_requires_grad must be bool or None, "
|
|
81
|
+
f"got {type(epsilon_requires_grad).__name__}",
|
|
82
|
+
)
|
|
83
|
+
if not isinstance(epsilon, torch.Tensor):
|
|
84
|
+
raise TypeError(
|
|
85
|
+
f"epsilon must be torch.Tensor, got {type(epsilon).__name__}",
|
|
86
|
+
)
|
|
87
|
+
if sigma_requires_grad is not None and not isinstance(
|
|
88
|
+
sigma_requires_grad, bool
|
|
89
|
+
):
|
|
90
|
+
raise TypeError(
|
|
91
|
+
f"sigma_requires_grad must be bool or None, "
|
|
92
|
+
f"got {type(sigma_requires_grad).__name__}",
|
|
93
|
+
)
|
|
94
|
+
if not isinstance(sigma, torch.Tensor):
|
|
95
|
+
raise TypeError(
|
|
96
|
+
f"sigma must be torch.Tensor, got {type(sigma).__name__}",
|
|
97
|
+
)
|
|
98
|
+
if not isinstance(mu, torch.Tensor):
|
|
99
|
+
raise TypeError(
|
|
100
|
+
f"mu must be torch.Tensor, got {type(mu).__name__}",
|
|
101
|
+
)
|
|
102
|
+
|
|
103
|
+
# If requires_grad not specified, preserve the input tensor's setting
|
|
104
|
+
if epsilon_requires_grad is None:
|
|
105
|
+
epsilon_requires_grad = epsilon.requires_grad
|
|
106
|
+
if sigma_requires_grad is None:
|
|
107
|
+
sigma_requires_grad = sigma.requires_grad
|
|
108
|
+
|
|
109
|
+
self.epsilon = torch.nn.Parameter(epsilon, requires_grad=epsilon_requires_grad)
|
|
110
|
+
self.sigma = torch.nn.Parameter(sigma, requires_grad=sigma_requires_grad)
|
|
111
|
+
self.register_buffer("mu", mu) # In normal we don't optimize mu
|
|
112
|
+
self.grid_spacing = grid_spacing
|
|
113
|
+
|
|
114
|
+
def forward(
|
|
115
|
+
self,
|
|
116
|
+
dt: float,
|
|
117
|
+
source_amplitude: Optional[torch.Tensor], # [shot,source,time]
|
|
118
|
+
source_location: Optional[torch.Tensor], # [shot,source,2]
|
|
119
|
+
receiver_location: Optional[torch.Tensor], # [shot,receiver,2]
|
|
120
|
+
stencil: int = 2,
|
|
121
|
+
pml_width: Union[int, Sequence[int]] = 20,
|
|
122
|
+
max_vel: Optional[float] = None,
|
|
123
|
+
Ey_0: Optional[torch.Tensor] = None,
|
|
124
|
+
Hx_0: Optional[torch.Tensor] = None,
|
|
125
|
+
Hz_0: Optional[torch.Tensor] = None,
|
|
126
|
+
m_Ey_x: Optional[torch.Tensor] = None,
|
|
127
|
+
m_Ey_z: Optional[torch.Tensor] = None,
|
|
128
|
+
m_Hx_z: Optional[torch.Tensor] = None,
|
|
129
|
+
m_Hz_x: Optional[torch.Tensor] = None,
|
|
130
|
+
nt: Optional[int] = None,
|
|
131
|
+
model_gradient_sampling_interval: int = 1,
|
|
132
|
+
freq_taper_frac: float = 0.0,
|
|
133
|
+
time_pad_frac: float = 0.0,
|
|
134
|
+
time_taper: bool = False,
|
|
135
|
+
save_snapshots: Optional[bool] = None,
|
|
136
|
+
forward_callback: Optional[Callback] = None,
|
|
137
|
+
backward_callback: Optional[Callback] = None,
|
|
138
|
+
callback_frequency: int = 1,
|
|
139
|
+
python_backend: Union[bool, str] = False,
|
|
140
|
+
storage_mode: str = "device",
|
|
141
|
+
storage_path: str = ".",
|
|
142
|
+
storage_compression: Union[bool, str] = False,
|
|
143
|
+
storage_bytes_limit_device: Optional[int] = None,
|
|
144
|
+
storage_bytes_limit_host: Optional[int] = None,
|
|
145
|
+
storage_chunk_steps: int = 0,
|
|
146
|
+
):
|
|
147
|
+
# Type assertions for buffer and parameter tensors
|
|
148
|
+
assert isinstance(self.epsilon, torch.Tensor)
|
|
149
|
+
assert isinstance(self.sigma, torch.Tensor)
|
|
150
|
+
assert isinstance(self.mu, torch.Tensor)
|
|
151
|
+
return maxwelltm(
|
|
152
|
+
self.epsilon,
|
|
153
|
+
self.sigma,
|
|
154
|
+
self.mu,
|
|
155
|
+
self.grid_spacing,
|
|
156
|
+
dt,
|
|
157
|
+
source_amplitude,
|
|
158
|
+
source_location,
|
|
159
|
+
receiver_location,
|
|
160
|
+
stencil,
|
|
161
|
+
pml_width,
|
|
162
|
+
max_vel,
|
|
163
|
+
Ey_0,
|
|
164
|
+
Hx_0,
|
|
165
|
+
Hz_0,
|
|
166
|
+
m_Ey_x,
|
|
167
|
+
m_Ey_z,
|
|
168
|
+
m_Hx_z,
|
|
169
|
+
m_Hz_x,
|
|
170
|
+
nt,
|
|
171
|
+
model_gradient_sampling_interval,
|
|
172
|
+
freq_taper_frac,
|
|
173
|
+
time_pad_frac,
|
|
174
|
+
time_taper,
|
|
175
|
+
save_snapshots,
|
|
176
|
+
forward_callback,
|
|
177
|
+
backward_callback,
|
|
178
|
+
callback_frequency,
|
|
179
|
+
python_backend,
|
|
180
|
+
storage_mode,
|
|
181
|
+
storage_path,
|
|
182
|
+
storage_compression,
|
|
183
|
+
storage_bytes_limit_device,
|
|
184
|
+
storage_bytes_limit_host,
|
|
185
|
+
storage_chunk_steps,
|
|
186
|
+
)
|
|
187
|
+
|
|
188
|
+
|
|
189
|
+
def maxwelltm(
|
|
190
|
+
epsilon: torch.Tensor,
|
|
191
|
+
sigma: torch.Tensor,
|
|
192
|
+
mu: torch.Tensor,
|
|
193
|
+
grid_spacing: Union[float, Sequence[float]],
|
|
194
|
+
dt: float,
|
|
195
|
+
source_amplitude: Optional[torch.Tensor],
|
|
196
|
+
source_location: Optional[torch.Tensor],
|
|
197
|
+
receiver_location: Optional[torch.Tensor],
|
|
198
|
+
stencil: int = 2,
|
|
199
|
+
pml_width: Union[int, Sequence[int]] = 20,
|
|
200
|
+
max_vel: Optional[float] = None,
|
|
201
|
+
Ey_0: Optional[torch.Tensor] = None,
|
|
202
|
+
Hx_0: Optional[torch.Tensor] = None,
|
|
203
|
+
Hz_0: Optional[torch.Tensor] = None,
|
|
204
|
+
m_Ey_x: Optional[torch.Tensor] = None,
|
|
205
|
+
m_Ey_z: Optional[torch.Tensor] = None,
|
|
206
|
+
m_Hx_z: Optional[torch.Tensor] = None,
|
|
207
|
+
m_Hz_x: Optional[torch.Tensor] = None,
|
|
208
|
+
nt: Optional[int] = None,
|
|
209
|
+
model_gradient_sampling_interval: int = 1,
|
|
210
|
+
freq_taper_frac: float = 0.0,
|
|
211
|
+
time_pad_frac: float = 0.0,
|
|
212
|
+
time_taper: bool = False,
|
|
213
|
+
save_snapshots: Optional[bool] = None,
|
|
214
|
+
forward_callback: Optional[Callback] = None,
|
|
215
|
+
backward_callback: Optional[Callback] = None,
|
|
216
|
+
callback_frequency: int = 1,
|
|
217
|
+
python_backend: Union[bool, str] = False,
|
|
218
|
+
storage_mode: str = "device",
|
|
219
|
+
storage_path: str = ".",
|
|
220
|
+
storage_compression: Union[bool, str] = False,
|
|
221
|
+
storage_bytes_limit_device: Optional[int] = None,
|
|
222
|
+
storage_bytes_limit_host: Optional[int] = None,
|
|
223
|
+
storage_chunk_steps: int = 0,
|
|
224
|
+
n_threads: Optional[int] = None,
|
|
225
|
+
):
|
|
226
|
+
"""2D TM mode Maxwell equations solver.
|
|
227
|
+
|
|
228
|
+
This is the main entry point for Maxwell TM propagation. It automatically
|
|
229
|
+
handles CFL condition checking and time step resampling when needed.
|
|
230
|
+
|
|
231
|
+
If the user-provided time step (dt) is too large for numerical stability,
|
|
232
|
+
the source signal will be upsampled internally and receiver data will be
|
|
233
|
+
downsampled back to the original sampling rate.
|
|
234
|
+
|
|
235
|
+
Args:
|
|
236
|
+
epsilon: Relative permittivity tensor [ny, nx].
|
|
237
|
+
sigma: Electrical conductivity tensor [ny, nx] in S/m.
|
|
238
|
+
mu: Relative permeability tensor [ny, nx].
|
|
239
|
+
grid_spacing: Grid spacing in meters. Single value or [dy, dx].
|
|
240
|
+
dt: Time step in seconds.
|
|
241
|
+
source_amplitude: Source waveform [n_shots, n_sources, nt].
|
|
242
|
+
source_location: Source locations [n_shots, n_sources, 2].
|
|
243
|
+
receiver_location: Receiver locations [n_shots, n_receivers, 2].
|
|
244
|
+
stencil: FD stencil order (2, 4, 6, or 8).
|
|
245
|
+
pml_width: PML width (single int or [top, bottom, left, right]).
|
|
246
|
+
max_vel: Maximum wave velocity. If None, computed from model.
|
|
247
|
+
Ey_0, Hx_0, Hz_0: Initial field values.
|
|
248
|
+
m_Ey_x, m_Ey_z, m_Hx_z, m_Hz_x: Initial CPML memory variables.
|
|
249
|
+
nt: Number of time steps (required if source_amplitude is None).
|
|
250
|
+
model_gradient_sampling_interval: Interval for storing gradient snapshots.
|
|
251
|
+
Values > 1 reduce memory usage during backpropagation.
|
|
252
|
+
freq_taper_frac: Fraction of frequency spectrum to taper (0.0-1.0).
|
|
253
|
+
Helps reduce ringing artifacts during resampling.
|
|
254
|
+
time_pad_frac: Fraction for zero padding before FFT (0.0-1.0).
|
|
255
|
+
Helps reduce wraparound artifacts during resampling.
|
|
256
|
+
time_taper: Whether to apply Hann window (mainly for testing).
|
|
257
|
+
save_snapshots: Whether to save wavefield snapshots for gradient computation.
|
|
258
|
+
If None (default), snapshots are saved only when model parameters
|
|
259
|
+
require gradients. Set to False to disable snapshot saving even
|
|
260
|
+
when gradients are needed. Set to True to force snapshot saving
|
|
261
|
+
even without gradients.
|
|
262
|
+
forward_callback: Callback function called during forward propagation.
|
|
263
|
+
backward_callback: Callback function called during backward (adjoint)
|
|
264
|
+
propagation. Receives the same CallbackState as forward_callback,
|
|
265
|
+
but with is_backward=True and gradients available.
|
|
266
|
+
callback_frequency: How often to call the callback.
|
|
267
|
+
python_backend: False for C/CUDA, True or 'eager'/'jit'/'compile' for Python.
|
|
268
|
+
        storage_mode: Where to store intermediate snapshots for the adjoint-state (ASM)
|
|
269
|
+
backward pass. One of "device", "cpu", "disk", "none", or "auto".
|
|
270
|
+
storage_path: Base path for disk storage when storage_mode="disk".
|
|
271
|
+
storage_compression: Compression for stored snapshots. Use False/True
|
|
272
|
+
(True == BF16), or one of "bf16" / "fp8".
|
|
273
|
+
storage_bytes_limit_device: Soft limit in bytes for device snapshot
|
|
274
|
+
storage when storage_mode="auto".
|
|
275
|
+
storage_bytes_limit_host: Soft limit in bytes for host snapshot
|
|
276
|
+
storage when storage_mode="auto".
|
|
277
|
+
storage_chunk_steps: Optional chunk size (in stored steps) for
|
|
278
|
+
CPU/disk modes. Currently unused.
|
|
279
|
+
n_threads: OpenMP thread count for CPU backend. None uses the OpenMP default.
|
|
280
|
+
|
|
281
|
+
Returns:
|
|
282
|
+
Tuple of (Ey, Hx, Hz, m_Ey_x, m_Ey_z, m_Hx_z, m_Hz_x, receiver_amplitudes).
|
|
283
|
+
"""
|
|
284
|
+
# Validate resampling parameters
|
|
285
|
+
model_gradient_sampling_interval = validate_model_gradient_sampling_interval(
|
|
286
|
+
model_gradient_sampling_interval
|
|
287
|
+
)
|
|
288
|
+
freq_taper_frac = validate_freq_taper_frac(freq_taper_frac)
|
|
289
|
+
time_pad_frac = validate_time_pad_frac(time_pad_frac)
|
|
290
|
+
|
|
291
|
+
# Check inputs
|
|
292
|
+
if source_location is not None and source_location.numel() > 0:
|
|
293
|
+
if source_location[..., 0].max() >= epsilon.shape[-2]:
|
|
294
|
+
raise RuntimeError(
|
|
295
|
+
f"Source location dim 0 must be less than {epsilon.shape[-2]}"
|
|
296
|
+
)
|
|
297
|
+
if source_location[..., 1].max() >= epsilon.shape[-1]:
|
|
298
|
+
raise RuntimeError(
|
|
299
|
+
f"Source location dim 1 must be less than {epsilon.shape[-1]}"
|
|
300
|
+
)
|
|
301
|
+
|
|
302
|
+
if receiver_location is not None and receiver_location.numel() > 0:
|
|
303
|
+
if receiver_location[..., 0].max() >= epsilon.shape[-2]:
|
|
304
|
+
raise RuntimeError(
|
|
305
|
+
f"Receiver location dim 0 must be less than {epsilon.shape[-2]}"
|
|
306
|
+
)
|
|
307
|
+
if receiver_location[..., 1].max() >= epsilon.shape[-1]:
|
|
308
|
+
raise RuntimeError(
|
|
309
|
+
f"Receiver location dim 1 must be less than {epsilon.shape[-1]}"
|
|
310
|
+
)
|
|
311
|
+
|
|
312
|
+
if not isinstance(callback_frequency, int):
|
|
313
|
+
raise TypeError("callback_frequency must be an int.")
|
|
314
|
+
if callback_frequency <= 0:
|
|
315
|
+
raise ValueError("callback_frequency must be positive.")
|
|
316
|
+
|
|
317
|
+
# Normalize grid_spacing to list
|
|
318
|
+
grid_spacing_list = _normalize_grid_spacing_2d(grid_spacing)
|
|
319
|
+
|
|
320
|
+
# Compute maximum velocity if not provided
|
|
321
|
+
if max_vel is None:
|
|
322
|
+
# For EM waves: v = c0 / sqrt(epsilon_r * mu_r)
|
|
323
|
+
max_vel_computed = float((1.0 / torch.sqrt(epsilon * mu)).max().item()) * C0
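        # For example, a model with epsilon_r = 4 and mu_r = 1 everywhere gives
        # max_vel_computed = C0 / 2, i.e. about 1.5e8 m/s.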
|
|
324
|
+
else:
|
|
325
|
+
max_vel_computed = max_vel
|
|
326
|
+
|
|
327
|
+
# Check CFL condition and compute step_ratio
|
|
328
|
+
inner_dt, step_ratio = cfl_condition(grid_spacing_list, dt, max_vel_computed)
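    # Sketch of the intended behaviour (the exact rounding policy is implemented
    # by cfl_condition in tide/cfl.py): if the requested dt is too large for
    # stability, step_ratio becomes an integer greater than 1 and
    # inner_dt = dt / step_ratio, so step_ratio internal steps cover one
    # user-level time step.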
|
|
329
|
+
|
|
330
|
+
# Upsample source if needed for CFL
|
|
331
|
+
source_amplitude_internal = source_amplitude
|
|
332
|
+
if step_ratio > 1 and source_amplitude is not None and source_amplitude.numel() > 0:
|
|
333
|
+
source_amplitude_internal = upsample(
|
|
334
|
+
source_amplitude,
|
|
335
|
+
step_ratio,
|
|
336
|
+
freq_taper_frac=freq_taper_frac,
|
|
337
|
+
time_pad_frac=time_pad_frac,
|
|
338
|
+
time_taper=time_taper,
|
|
339
|
+
)
|
|
340
|
+
|
|
341
|
+
# Compute internal number of time steps
|
|
342
|
+
nt_internal = None
|
|
343
|
+
if nt is not None:
|
|
344
|
+
nt_internal = nt * step_ratio
|
|
345
|
+
elif source_amplitude_internal is not None:
|
|
346
|
+
nt_internal = source_amplitude_internal.shape[-1]
|
|
347
|
+
|
|
348
|
+
# Call the propagation function with internal dt and upsampled source
|
|
349
|
+
result = maxwell_func(
|
|
350
|
+
python_backend,
|
|
351
|
+
epsilon,
|
|
352
|
+
sigma,
|
|
353
|
+
mu,
|
|
354
|
+
grid_spacing,
|
|
355
|
+
inner_dt, # Use internal time step for CFL compliance
|
|
356
|
+
source_amplitude_internal,
|
|
357
|
+
source_location,
|
|
358
|
+
receiver_location,
|
|
359
|
+
stencil,
|
|
360
|
+
pml_width,
|
|
361
|
+
max_vel_computed, # Pass computed max_vel so it's not recomputed
|
|
362
|
+
Ey_0,
|
|
363
|
+
Hx_0,
|
|
364
|
+
Hz_0,
|
|
365
|
+
m_Ey_x,
|
|
366
|
+
m_Ey_z,
|
|
367
|
+
m_Hx_z,
|
|
368
|
+
m_Hz_x,
|
|
369
|
+
nt_internal,
|
|
370
|
+
model_gradient_sampling_interval,
|
|
371
|
+
freq_taper_frac,
|
|
372
|
+
time_pad_frac,
|
|
373
|
+
time_taper,
|
|
374
|
+
save_snapshots,
|
|
375
|
+
forward_callback,
|
|
376
|
+
backward_callback,
|
|
377
|
+
callback_frequency,
|
|
378
|
+
storage_mode,
|
|
379
|
+
storage_path,
|
|
380
|
+
storage_compression,
|
|
381
|
+
storage_bytes_limit_device,
|
|
382
|
+
storage_bytes_limit_host,
|
|
383
|
+
storage_chunk_steps,
|
|
384
|
+
n_threads,
|
|
385
|
+
)
|
|
386
|
+
|
|
387
|
+
# Unpack result
|
|
388
|
+
(
|
|
389
|
+
Ey_out,
|
|
390
|
+
Hx_out,
|
|
391
|
+
Hz_out,
|
|
392
|
+
m_Ey_x_out,
|
|
393
|
+
m_Ey_z_out,
|
|
394
|
+
m_Hx_z_out,
|
|
395
|
+
m_Hz_x_out,
|
|
396
|
+
receiver_amplitudes,
|
|
397
|
+
) = result
|
|
398
|
+
|
|
399
|
+
# Downsample receiver data if we upsampled
|
|
400
|
+
if step_ratio > 1 and receiver_amplitudes.numel() > 0:
|
|
401
|
+
receiver_amplitudes = downsample_and_movedim(
|
|
402
|
+
receiver_amplitudes,
|
|
403
|
+
step_ratio,
|
|
404
|
+
freq_taper_frac=freq_taper_frac,
|
|
405
|
+
time_pad_frac=time_pad_frac,
|
|
406
|
+
time_taper=time_taper,
|
|
407
|
+
)
|
|
408
|
+
# Move time back to first dimension to match expected output format
|
|
409
|
+
receiver_amplitudes = torch.movedim(receiver_amplitudes, -1, 0)
|
|
410
|
+
|
|
411
|
+
return (
|
|
412
|
+
Ey_out,
|
|
413
|
+
Hx_out,
|
|
414
|
+
Hz_out,
|
|
415
|
+
m_Ey_x_out,
|
|
416
|
+
m_Ey_z_out,
|
|
417
|
+
m_Hx_z_out,
|
|
418
|
+
m_Hz_x_out,
|
|
419
|
+
receiver_amplitudes,
|
|
420
|
+
)
|
|
421
|
+
|
|
422
|
+
|
|
423
|
+
_update_E_jit: Optional[Callable] = None
|
|
424
|
+
_update_E_compile: Optional[Callable] = None
|
|
425
|
+
_update_H_jit: Optional[Callable] = None
|
|
426
|
+
_update_H_compile: Optional[Callable] = None
|
|
427
|
+
|
|
428
|
+
# These will be set after the functions are defined
|
|
429
|
+
_update_E_opt: Optional[Callable] = None
|
|
430
|
+
_update_H_opt: Optional[Callable] = None
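# Module-level caches: maxwell_func fills these in so that torch.jit.script /
# torch.compile are applied to update_E and update_H at most once per process,
# rather than on every call.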
|
|
431
|
+
|
|
432
|
+
|
|
433
|
+
def maxwell_func(
|
|
434
|
+
python_backend: Union[bool, str],
|
|
435
|
+
*args,
|
|
436
|
+
) -> tuple[
|
|
437
|
+
torch.Tensor, # Ey
|
|
438
|
+
torch.Tensor, # Hx
|
|
439
|
+
torch.Tensor, # Hz
|
|
440
|
+
torch.Tensor, # m_Ey_x
|
|
441
|
+
torch.Tensor, # m_Ey_z
|
|
442
|
+
torch.Tensor, # m_Hx_z
|
|
443
|
+
torch.Tensor, # m_Hz_x
|
|
444
|
+
torch.Tensor, # receiver_amplitudes
|
|
445
|
+
]:
|
|
446
|
+
"""Dispatch to Python or C/CUDA backend for Maxwell propagation."""
|
|
447
|
+
global _update_E_jit, _update_E_compile, _update_E_opt
|
|
448
|
+
global _update_H_jit, _update_H_compile, _update_H_opt
|
|
449
|
+
|
|
450
|
+
# Check if we should use Python backend or C/CUDA backend
|
|
451
|
+
use_python = python_backend
|
|
452
|
+
if not use_python:
|
|
453
|
+
# Try to use C/CUDA backend
|
|
454
|
+
try:
|
|
455
|
+
from . import backend_utils
|
|
456
|
+
|
|
457
|
+
if not backend_utils.is_backend_available():
|
|
458
|
+
import warnings
|
|
459
|
+
|
|
460
|
+
warnings.warn(
|
|
461
|
+
"C/CUDA backend not available, falling back to Python backend. "
|
|
462
|
+
"To use the C/CUDA backend, compile the library first.",
|
|
463
|
+
RuntimeWarning,
|
|
464
|
+
)
|
|
465
|
+
use_python = True
|
|
466
|
+
except ImportError:
|
|
467
|
+
import warnings
|
|
468
|
+
|
|
469
|
+
warnings.warn(
|
|
470
|
+
"backend_utils not available, falling back to Python backend.",
|
|
471
|
+
RuntimeWarning,
|
|
472
|
+
)
|
|
473
|
+
use_python = True
|
|
474
|
+
|
|
475
|
+
if use_python:
|
|
476
|
+
if python_backend is True or python_backend is False:
|
|
477
|
+
mode = "eager" # Default to eager
|
|
478
|
+
elif isinstance(python_backend, str):
|
|
479
|
+
mode = python_backend.lower()
|
|
480
|
+
else:
|
|
481
|
+
raise TypeError(
|
|
482
|
+
f"python_backend must be bool or str, but got {type(python_backend)}"
|
|
483
|
+
)
|
|
484
|
+
|
|
485
|
+
if mode == "jit":
|
|
486
|
+
if _update_E_jit is None:
|
|
487
|
+
_update_E_jit = torch.jit.script(update_E)
|
|
488
|
+
_update_E_opt = _update_E_jit
|
|
489
|
+
if _update_H_jit is None:
|
|
490
|
+
_update_H_jit = torch.jit.script(update_H)
|
|
491
|
+
_update_H_opt = _update_H_jit
|
|
492
|
+
elif mode == "compile":
|
|
493
|
+
if _update_E_compile is None:
|
|
494
|
+
_update_E_compile = torch.compile(update_E, fullgraph=True)
|
|
495
|
+
_update_E_opt = _update_E_compile
|
|
496
|
+
if _update_H_compile is None:
|
|
497
|
+
_update_H_compile = torch.compile(update_H, fullgraph=True)
|
|
498
|
+
_update_H_opt = _update_H_compile
|
|
499
|
+
elif mode == "eager":
|
|
500
|
+
_update_E_opt = update_E
|
|
501
|
+
_update_H_opt = update_H
|
|
502
|
+
else:
|
|
503
|
+
raise ValueError(f"Unknown python_backend value {mode!r}.")
|
|
504
|
+
|
|
505
|
+
return maxwell_python(*args)
|
|
506
|
+
else:
|
|
507
|
+
# Use C/CUDA backend
|
|
508
|
+
return maxwell_c_cuda(*args)
|
|
509
|
+
|
|
510
|
+
|
|
511
|
+
def maxwell_python(
|
|
512
|
+
epsilon: torch.Tensor,
|
|
513
|
+
sigma: torch.Tensor,
|
|
514
|
+
mu: torch.Tensor,
|
|
515
|
+
grid_spacing: Union[float, Sequence[float]],
|
|
516
|
+
dt: float,
|
|
517
|
+
source_amplitude: Optional[torch.Tensor],
|
|
518
|
+
source_location: Optional[torch.Tensor],
|
|
519
|
+
receiver_location: Optional[torch.Tensor],
|
|
520
|
+
stencil: int,
|
|
521
|
+
pml_width: Union[int, Sequence[int]],
|
|
522
|
+
max_vel: Optional[float],
|
|
523
|
+
Ey_0: Optional[torch.Tensor],
|
|
524
|
+
Hx_0: Optional[torch.Tensor],
|
|
525
|
+
Hz_0: Optional[torch.Tensor],
|
|
526
|
+
m_Ey_x_0: Optional[torch.Tensor],
|
|
527
|
+
m_Ey_z_0: Optional[torch.Tensor],
|
|
528
|
+
m_Hx_z_0: Optional[torch.Tensor],
|
|
529
|
+
m_Hz_x_0: Optional[torch.Tensor],
|
|
530
|
+
nt: Optional[int],
|
|
531
|
+
model_gradient_sampling_interval: int,
|
|
532
|
+
freq_taper_frac: float,
|
|
533
|
+
time_pad_frac: float,
|
|
534
|
+
time_taper: bool,
|
|
535
|
+
save_snapshots: Optional[bool],
|
|
536
|
+
forward_callback: Optional[Callback],
|
|
537
|
+
backward_callback: Optional[Callback],
|
|
538
|
+
callback_frequency: int,
|
|
539
|
+
storage_mode: str = "device",
|
|
540
|
+
storage_path: str = ".",
|
|
541
|
+
storage_compression: Union[bool, str] = False,
|
|
542
|
+
storage_bytes_limit_device: Optional[int] = None,
|
|
543
|
+
storage_bytes_limit_host: Optional[int] = None,
|
|
544
|
+
storage_chunk_steps: int = 0,
|
|
545
|
+
n_threads: Optional[int] = None,
|
|
546
|
+
):
|
|
547
|
+
"""Performs the forward propagation of the 2D TM Maxwell equations.
|
|
548
|
+
|
|
549
|
+
This function implements the FDTD time-stepping loop for the TM mode
|
|
550
|
+
(Ey, Hx, Hz) with CPML absorbing boundary conditions.
|
|
551
|
+
|
|
552
|
+
- Models are padded by fd_pad + pml_width with replicate mode
|
|
553
|
+
- Wavefields are padded by fd_pad only with zero padding
|
|
554
|
+
- Output wavefields are cropped by fd_pad only (PML region is preserved)
|
|
555
|
+
|
|
556
|
+
Args:
|
|
557
|
+
epsilon: Permittivity model [ny, nx].
|
|
558
|
+
sigma: Conductivity model [ny, nx].
|
|
559
|
+
mu: Permeability model [ny, nx].
|
|
560
|
+
grid_spacing: Grid spacing (dy, dx) or single value for both.
|
|
561
|
+
dt: Time step.
|
|
562
|
+
source_amplitude: Source amplitudes [n_shots, n_sources, nt].
|
|
563
|
+
source_location: Source locations [n_shots, n_sources, 2].
|
|
564
|
+
receiver_location: Receiver locations [n_shots, n_receivers, 2].
|
|
565
|
+
stencil: Finite difference stencil order (2, 4, 6, or 8).
|
|
566
|
+
pml_width: PML width on each side [top, bottom, left, right] or single value.
|
|
567
|
+
max_vel: Maximum velocity for PML (if None, computed from model).
|
|
568
|
+
Ey_0, Hx_0, Hz_0: Initial field values.
|
|
569
|
+
m_Ey_x_0, m_Ey_z_0, m_Hx_z_0, m_Hz_x_0: Initial CPML memory variables.
|
|
570
|
+
nt: Number of time steps (required if source_amplitude is None).
|
|
571
|
+
model_gradient_sampling_interval: Interval for storing gradients.
|
|
572
|
+
freq_taper_frac: Frequency taper fraction.
|
|
573
|
+
time_pad_frac: Time padding fraction.
|
|
574
|
+
time_taper: Whether to apply time taper.
|
|
575
|
+
save_snapshots: Whether to save wavefield snapshots for backward pass.
|
|
576
|
+
If None, determined by requires_grad on model parameters.
|
|
577
|
+
forward_callback: Callback function called during propagation.
|
|
578
|
+
        callback_frequency: Frequency of callback calls.
        backward_callback: Callback called during the backward (adjoint) pass.
        storage_mode, storage_path, storage_compression, storage_bytes_limit_device,
            storage_bytes_limit_host, storage_chunk_steps: Snapshot storage
            options, as described for maxwelltm.
        n_threads: OpenMP thread count for the CPU backend.

|
|
579
|
+
Returns:
|
|
580
|
+
Tuple containing:
|
|
581
|
+
- Ey: Final electric field [n_shots, ny + pml, nx + pml]
|
|
582
|
+
- Hx, Hz: Final magnetic fields
|
|
583
|
+
- m_Ey_x, m_Ey_z, m_Hx_z, m_Hz_x: Final CPML memory variables
|
|
584
|
+
- receiver_amplitudes: Recorded data at receivers [nt, n_shots, n_receivers]
|
|
585
|
+
"""
|
|
586
|
+
|
|
587
|
+
from .padding import create_or_pad, zero_interior
|
|
588
|
+
|
|
589
|
+
# These should be set by maxwell_func before calling this function
|
|
590
|
+
assert _update_E_opt is not None, "_update_E_opt must be set by maxwell_func"
|
|
591
|
+
assert _update_H_opt is not None, "_update_H_opt must be set by maxwell_func"
|
|
592
|
+
|
|
593
|
+
# Validate inputs
|
|
594
|
+
if epsilon.ndim != 2:
|
|
595
|
+
raise RuntimeError("epsilon must be 2D")
|
|
596
|
+
if sigma.shape != epsilon.shape:
|
|
597
|
+
raise RuntimeError("sigma must have same shape as epsilon")
|
|
598
|
+
if mu.shape != epsilon.shape:
|
|
599
|
+
raise RuntimeError("mu must have same shape as epsilon")
|
|
600
|
+
|
|
601
|
+
device = epsilon.device
|
|
602
|
+
dtype = epsilon.dtype
|
|
603
|
+
model_ny, model_nx = epsilon.shape # Original model dimensions
|
|
604
|
+
|
|
605
|
+
storage_mode_str = storage_mode.lower()
|
|
606
|
+
if storage_mode_str in {"cpu", "disk"}:
|
|
607
|
+
raise ValueError(
|
|
608
|
+
"python_backend does not support storage_mode='cpu' or 'disk'. "
|
|
609
|
+
"Use the C/CUDA backend or storage_mode='device'/'none'."
|
|
610
|
+
)
|
|
611
|
+
storage_kind = _normalize_storage_compression(storage_compression)
|
|
612
|
+
if storage_kind != "none":
|
|
613
|
+
raise NotImplementedError(
|
|
614
|
+
"storage_compression is not implemented yet; set storage_compression=False."
|
|
615
|
+
)
|
|
616
|
+
|
|
617
|
+
# Normalize grid_spacing to list
|
|
618
|
+
grid_spacing = _normalize_grid_spacing_2d(grid_spacing)
|
|
619
|
+
dy, dx = grid_spacing
|
|
620
|
+
|
|
621
|
+
# Normalize pml_width to list [top, bottom, left, right]
|
|
622
|
+
pml_width_list = _normalize_pml_width_2d(pml_width)
|
|
623
|
+
|
|
624
|
+
# Determine number of time steps
|
|
625
|
+
if nt is None:
|
|
626
|
+
if source_amplitude is None:
|
|
627
|
+
raise ValueError("Either nt or source_amplitude must be provided")
|
|
628
|
+
nt = source_amplitude.shape[-1]
|
|
629
|
+
|
|
630
|
+
# Type cast to ensure nt is int for type checker
|
|
631
|
+
nt_steps: int = int(nt)
|
|
632
|
+
|
|
633
|
+
# Determine number of shots
|
|
634
|
+
if source_amplitude is not None and source_amplitude.numel() > 0:
|
|
635
|
+
n_shots = source_amplitude.shape[0]
|
|
636
|
+
elif source_location is not None and source_location.numel() > 0:
|
|
637
|
+
n_shots = source_location.shape[0]
|
|
638
|
+
elif receiver_location is not None and receiver_location.numel() > 0:
|
|
639
|
+
n_shots = receiver_location.shape[0]
|
|
640
|
+
else:
|
|
641
|
+
n_shots = 1
|
|
642
|
+
|
|
643
|
+
# Compute maximum velocity for PML if not provided
|
|
644
|
+
if max_vel is None:
|
|
645
|
+
# For EM waves: v = c0 / sqrt(epsilon_r * mu_r)
|
|
646
|
+
max_vel = float((1.0 / torch.sqrt(epsilon * mu)).max().item()) * C0
|
|
647
|
+
|
|
648
|
+
# Compute PML frequency (dominant frequency estimate)
|
|
649
|
+
pml_freq = 0.5 / dt # Nyquist as default
|
|
650
|
+
|
|
651
|
+
# =========================================================================
|
|
652
|
+
# Padding strategy:
|
|
653
|
+
# - fd_pad: padding for finite difference stencil accuracy
|
|
654
|
+
# - pml_width: padding for PML absorbing layers
|
|
655
|
+
# - Total model padding = fd_pad + pml_width
|
|
656
|
+
# - Wavefield padding = fd_pad only (wavefields include PML region)
|
|
657
|
+
# =========================================================================
|
|
658
|
+
|
|
659
|
+
# FD padding based on stencil: accuracy // 2
|
|
660
|
+
fd_pad = stencil // 2
|
|
661
|
+
# fd_pad_list: [y0, y1, x0, x1] - for 2D staggered grid, asymmetric because
|
|
662
|
+
# staggered diff a[1:] - a[:-1] reduces array size by 1, so we need fd_pad-1 at end
|
|
663
|
+
fd_pad_list = [fd_pad, fd_pad - 1, fd_pad, fd_pad - 1]
|
|
664
|
+
|
|
665
|
+
# Total padding for models = fd_pad + pml_width
|
|
666
|
+
total_pad = [fd + pml for fd, pml in zip(fd_pad_list, pml_width_list)]
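    # Worked example: with stencil=4 and pml_width=20 on every side,
    # fd_pad = 2, fd_pad_list = [2, 1, 2, 1] and total_pad = [22, 21, 22, 21],
    # so a 100 x 200 model becomes a 143 x 243 padded grid.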
|
|
667
|
+
|
|
668
|
+
# Calculate padded dimensions
|
|
669
|
+
# Model is padded by total_pad on each side
|
|
670
|
+
padded_ny = model_ny + total_pad[0] + total_pad[1]
|
|
671
|
+
padded_nx = model_nx + total_pad[2] + total_pad[3]
|
|
672
|
+
|
|
673
|
+
# Pad model tensors with replicate mode (extend boundary values)
|
|
674
|
+
padded_size = (padded_ny, padded_nx)
|
|
675
|
+
epsilon_padded = create_or_pad(
|
|
676
|
+
epsilon, total_pad, device, dtype, padded_size, mode="replicate"
|
|
677
|
+
)
|
|
678
|
+
sigma_padded = create_or_pad(
|
|
679
|
+
sigma, total_pad, device, dtype, padded_size, mode="replicate"
|
|
680
|
+
)
|
|
681
|
+
mu_padded = create_or_pad(
|
|
682
|
+
mu, total_pad, device, dtype, padded_size, mode="replicate"
|
|
683
|
+
)
|
|
684
|
+
|
|
685
|
+
# Prepare update coefficients using padded models
|
|
686
|
+
ca, cb, cq = prepare_parameters(epsilon_padded, sigma_padded, mu_padded, dt)
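    # ca, cb, cq are the usual lossy-medium FDTD update coefficients. As a rough
    # sketch (the authoritative definitions live in prepare_parameters in
    # tide/utils.py):
    #   ca ~ (1 - sigma*dt/(2*eps)) / (1 + sigma*dt/(2*eps))
    #   cb ~ (dt/eps) / (1 + sigma*dt/(2*eps))
    #   cq ~ dt/mu
    # with eps and mu the absolute (not relative) material parameters.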
|
|
687
|
+
|
|
688
|
+
# Expand coefficients for batch dimension
|
|
689
|
+
ca = ca[None, :, :] # [1, padded_ny, padded_nx]
|
|
690
|
+
cb = cb[None, :, :]
|
|
691
|
+
cq = cq[None, :, :]
|
|
692
|
+
|
|
693
|
+
# =========================================================================
|
|
694
|
+
# Initialize wavefields
|
|
695
|
+
# Wavefields are padded by fd_pad only (they include the PML region)
|
|
696
|
+
# Size = [n_shots, model_ny + pml_width*2 + fd_pad*2, model_nx + ...]
|
|
697
|
+
# Which equals [n_shots, padded_ny, padded_nx]
|
|
698
|
+
# =========================================================================
|
|
699
|
+
size_with_batch = (n_shots, padded_ny, padded_nx)
|
|
700
|
+
|
|
701
|
+
# Helper function to initialize wavefields with fd_pad padding
|
|
702
|
+
def init_wavefield(field_0: Optional[torch.Tensor]) -> torch.Tensor:
|
|
703
|
+
"""Initialize wavefield with fd_pad zero padding.
|
|
704
|
+
|
|
705
|
+
Zero padding is used for wavefields because the fd_pad region should
|
|
706
|
+
always be zero after output cropping and re-padding. The staggered grid
|
|
707
|
+
        operators read from this region, but correct propagation does not
        require non-zero values there.
|
|
709
|
+
"""
|
|
710
|
+
if field_0 is not None:
|
|
711
|
+
# User provides [n_shots, ny, nx] or [ny, nx]
|
|
712
|
+
if field_0.ndim == 2:
|
|
713
|
+
field_0 = field_0[None, :, :].expand(n_shots, -1, -1)
|
|
714
|
+
# Pad with asymmetric fd_pad_list for staggered grid (zero padding)
|
|
715
|
+
return create_or_pad(
|
|
716
|
+
field_0, fd_pad_list, device, dtype, size_with_batch, mode="constant"
|
|
717
|
+
)
|
|
718
|
+
return torch.zeros(size_with_batch, device=device, dtype=dtype)
|
|
719
|
+
|
|
720
|
+
Ey = init_wavefield(Ey_0)
|
|
721
|
+
Hx = init_wavefield(Hx_0)
|
|
722
|
+
Hz = init_wavefield(Hz_0)
|
|
723
|
+
m_Ey_x = init_wavefield(m_Ey_x_0)
|
|
724
|
+
m_Ey_z = init_wavefield(m_Ey_z_0)
|
|
725
|
+
m_Hx_z = init_wavefield(m_Hx_z_0)
|
|
726
|
+
m_Hz_x = init_wavefield(m_Hz_x_0)
|
|
727
|
+
|
|
728
|
+
# Zero out interior of PML auxiliary variables (optimization)
|
|
729
|
+
# PML memory variables should only be non-zero in PML regions.
|
|
730
|
+
# This works correctly even with user-provided initial states because:
|
|
731
|
+
# 1. The output preserves PML region (only fd_pad is cropped)
|
|
732
|
+
# 2. zero_interior only zeros the interior, preserving PML boundary values
|
|
733
|
+
# 3. Interior values are already zero in correctly propagated wavefields
|
|
734
|
+
# Dimension mapping for zero_interior:
|
|
735
|
+
# - m_Ey_x: x-direction auxiliary -> dim=1 (zero y-interior, keep x-boundaries)
|
|
736
|
+
# - m_Ey_z: y/z-direction auxiliary -> dim=0 (zero x-interior, keep y-boundaries)
|
|
737
|
+
# - m_Hx_z: y/z-direction auxiliary -> dim=0
|
|
738
|
+
# - m_Hz_x: x-direction auxiliary -> dim=1
|
|
739
|
+
pml_aux_dims = [1, 0, 0, 1] # [m_Ey_x, m_Ey_z, m_Hx_z, m_Hz_x]
|
|
740
|
+
for wf, dim in zip([m_Ey_x, m_Ey_z, m_Hx_z, m_Hz_x], pml_aux_dims):
|
|
741
|
+
zero_interior(wf, fd_pad_list, pml_width_list, dim)
|
|
742
|
+
|
|
743
|
+
# Set up PML profiles for the padded domain
|
|
744
|
+
pml_profiles_list = staggered.set_pml_profiles(
|
|
745
|
+
pml_width=pml_width_list,
|
|
746
|
+
accuracy=stencil,
|
|
747
|
+
fd_pad=fd_pad_list,
|
|
748
|
+
dt=dt,
|
|
749
|
+
grid_spacing=grid_spacing,
|
|
750
|
+
max_vel=max_vel,
|
|
751
|
+
dtype=dtype,
|
|
752
|
+
device=device,
|
|
753
|
+
pml_freq=pml_freq,
|
|
754
|
+
ny=padded_ny,
|
|
755
|
+
nx=padded_nx,
|
|
756
|
+
)
|
|
757
|
+
# pml_profiles_list = [ay, ayh, ax, axh, by, byh, bx, bxh, ky, kyh, kx, kxh]
|
|
758
|
+
(
|
|
759
|
+
ay,
|
|
760
|
+
ay_h,
|
|
761
|
+
ax,
|
|
762
|
+
ax_h,
|
|
763
|
+
by,
|
|
764
|
+
by_h,
|
|
765
|
+
bx,
|
|
766
|
+
bx_h,
|
|
767
|
+
kappa_y,
|
|
768
|
+
kappa_y_h,
|
|
769
|
+
kappa_x,
|
|
770
|
+
kappa_x_h,
|
|
771
|
+
) = pml_profiles_list
|
|
772
|
+
|
|
773
|
+
# Reciprocal grid spacing
|
|
774
|
+
rdy = torch.tensor(1.0 / dy, device=device, dtype=dtype)
|
|
775
|
+
rdx = torch.tensor(1.0 / dx, device=device, dtype=dtype)
|
|
776
|
+
dt_tensor = torch.tensor(dt, device=device, dtype=dtype)
|
|
777
|
+
|
|
778
|
+
# =========================================================================
|
|
779
|
+
# Prepare source and receiver indices
|
|
780
|
+
# Original positions are in the un-padded model coordinate system.
|
|
781
|
+
# We need to offset by total_pad (fd_pad + pml_width) to get padded coords.
|
|
782
|
+
# =========================================================================
|
|
783
|
+
flat_model_shape = padded_ny * padded_nx
|
|
784
|
+
|
|
785
|
+
if source_location is not None and source_location.numel() > 0:
|
|
786
|
+
# Adjust source positions by total padding offset
|
|
787
|
+
source_y = source_location[..., 0] + total_pad[0] # Add top offset
|
|
788
|
+
source_x = source_location[..., 1] + total_pad[2] # Add left offset
|
|
789
|
+
sources_i = (source_y * padded_nx + source_x).long() # [n_shots, n_sources]
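        # Example: a source at model cell (y=10, x=20) with total_pad = [22, 21, 22, 21]
        # maps to padded cell (32, 42), i.e. flat index 32 * padded_nx + 42.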
|
|
790
|
+
n_sources = source_location.shape[1]
|
|
791
|
+
else:
|
|
792
|
+
sources_i = torch.empty(0, device=device, dtype=torch.long)
|
|
793
|
+
n_sources = 0
|
|
794
|
+
|
|
795
|
+
if receiver_location is not None and receiver_location.numel() > 0:
|
|
796
|
+
# Adjust receiver positions by total padding offset
|
|
797
|
+
receiver_y = receiver_location[..., 0] + total_pad[0] # Add top offset
|
|
798
|
+
receiver_x = receiver_location[..., 1] + total_pad[2] # Add left offset
|
|
799
|
+
receivers_i = (receiver_y * padded_nx + receiver_x).long()
|
|
800
|
+
n_receivers = receiver_location.shape[1]
|
|
801
|
+
else:
|
|
802
|
+
receivers_i = torch.empty(0, device=device, dtype=torch.long)
|
|
803
|
+
n_receivers = 0
|
|
804
|
+
|
|
805
|
+
# Initialize receiver amplitudes
|
|
806
|
+
if n_receivers > 0:
|
|
807
|
+
receiver_amplitudes = torch.zeros(
|
|
808
|
+
nt_steps, n_shots, n_receivers, device=device, dtype=dtype
|
|
809
|
+
)
|
|
810
|
+
else:
|
|
811
|
+
receiver_amplitudes = torch.empty(0, device=device, dtype=dtype)
|
|
812
|
+
|
|
813
|
+
# Prepare callback data - models dict uses the padded models
|
|
814
|
+
callback_models = {
|
|
815
|
+
"epsilon": epsilon_padded,
|
|
816
|
+
"sigma": sigma_padded,
|
|
817
|
+
"mu": mu_padded,
|
|
818
|
+
"ca": ca,
|
|
819
|
+
"cb": cb,
|
|
820
|
+
"cq": cq,
|
|
821
|
+
}
|
|
822
|
+
|
|
823
|
+
# Callback fd_pad is the actual fd_pad used for wavefields
|
|
824
|
+
callback_fd_pad = fd_pad_list
|
|
825
|
+
|
|
826
|
+
# Source injection coefficient: -cb * dt / (dx * dy)
|
|
827
|
+
# Since our cb already contains dt/epsilon, we need: -cb / (dx * dy)
|
|
828
|
+
    # This normalizes the source by the cell area (dx * dy) for the correct amplitude
|
|
829
|
+
source_coeff = -1.0 / (dx * dy)
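    # Together with cb (which already includes dt/epsilon), the value added at a
    # source cell each step is cb_at_src * src_amp * source_coeff
    # = -cb * f / (dx * dy), i.e. roughly -f * dt / (epsilon * dx * dy).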
|
|
830
|
+
|
|
831
|
+
# Time stepping loop
|
|
832
|
+
for step in range(nt_steps):
|
|
833
|
+
# Callback at specified frequency
|
|
834
|
+
if forward_callback is not None and step % callback_frequency == 0:
|
|
835
|
+
callback_wavefields = {
|
|
836
|
+
"Ey": Ey,
|
|
837
|
+
"Hx": Hx,
|
|
838
|
+
"Hz": Hz,
|
|
839
|
+
"m_Ey_x": m_Ey_x,
|
|
840
|
+
"m_Ey_z": m_Ey_z,
|
|
841
|
+
"m_Hx_z": m_Hx_z,
|
|
842
|
+
"m_Hz_x": m_Hz_x,
|
|
843
|
+
}
|
|
844
|
+
# Create CallbackState for standardized interface
|
|
845
|
+
callback_state = CallbackState(
|
|
846
|
+
dt=dt,
|
|
847
|
+
step=step,
|
|
848
|
+
nt=nt_steps,
|
|
849
|
+
wavefields=callback_wavefields,
|
|
850
|
+
models=callback_models,
|
|
851
|
+
gradients=None, # No gradients during forward pass
|
|
852
|
+
fd_pad=callback_fd_pad,
|
|
853
|
+
pml_width=pml_width_list,
|
|
854
|
+
is_backward=False,
|
|
855
|
+
grid_spacing=[dy, dx],
|
|
856
|
+
)
|
|
857
|
+
forward_callback(callback_state)
|
|
858
|
+
|
|
859
|
+
# Update H fields: H^{n+1/2} = H^{n-1/2} + ...
|
|
860
|
+
Hx, Hz, m_Ey_x, m_Ey_z = _update_H_opt(
|
|
861
|
+
cq,
|
|
862
|
+
Hx,
|
|
863
|
+
Hz,
|
|
864
|
+
Ey,
|
|
865
|
+
m_Ey_x,
|
|
866
|
+
m_Ey_z,
|
|
867
|
+
kappa_y,
|
|
868
|
+
kappa_y_h,
|
|
869
|
+
kappa_x,
|
|
870
|
+
kappa_x_h,
|
|
871
|
+
ay,
|
|
872
|
+
ay_h,
|
|
873
|
+
ax,
|
|
874
|
+
ax_h,
|
|
875
|
+
by,
|
|
876
|
+
by_h,
|
|
877
|
+
bx,
|
|
878
|
+
bx_h,
|
|
879
|
+
rdy,
|
|
880
|
+
rdx,
|
|
881
|
+
dt_tensor,
|
|
882
|
+
stencil,
|
|
883
|
+
)
|
|
884
|
+
|
|
885
|
+
# Update E field: E^{n+1} = E^n + ...
|
|
886
|
+
Ey, m_Hx_z, m_Hz_x = _update_E_opt(
|
|
887
|
+
ca,
|
|
888
|
+
cb,
|
|
889
|
+
Hx,
|
|
890
|
+
Hz,
|
|
891
|
+
Ey,
|
|
892
|
+
m_Hx_z,
|
|
893
|
+
m_Hz_x,
|
|
894
|
+
kappa_y,
|
|
895
|
+
kappa_y_h,
|
|
896
|
+
kappa_x,
|
|
897
|
+
kappa_x_h,
|
|
898
|
+
ay,
|
|
899
|
+
ay_h,
|
|
900
|
+
ax,
|
|
901
|
+
ax_h,
|
|
902
|
+
by,
|
|
903
|
+
by_h,
|
|
904
|
+
bx,
|
|
905
|
+
bx_h,
|
|
906
|
+
rdy,
|
|
907
|
+
rdx,
|
|
908
|
+
dt_tensor,
|
|
909
|
+
stencil,
|
|
910
|
+
)
|
|
911
|
+
|
|
912
|
+
# Inject source into Ey field (after E update, following reference implementation)
|
|
913
|
+
        # Source term: Ey += -(dt / eps) * f / (dx * dy). Since cb already
        # contains dt / eps, this is implemented as -cb * f / (dx * dy).
|
|
914
|
+
if (
|
|
915
|
+
source_amplitude is not None
|
|
916
|
+
and source_amplitude.numel() > 0
|
|
917
|
+
and n_sources > 0
|
|
918
|
+
):
|
|
919
|
+
# source_amplitude: [n_shots, n_sources, nt]
|
|
920
|
+
src_amp = source_amplitude[:, :, step] # [n_shots, n_sources]
|
|
921
|
+
# Get cb at source locations for proper scaling
|
|
922
|
+
cb_flat = cb.reshape(1, flat_model_shape).expand(n_shots, -1)
|
|
923
|
+
cb_at_src = cb_flat.gather(1, sources_i) # [n_shots, n_sources]
|
|
924
|
+
# Apply source with coefficient: -cb * f / (dx * dy)
|
|
925
|
+
scaled_src = cb_at_src * src_amp * source_coeff
|
|
926
|
+
Ey = (
|
|
927
|
+
Ey.reshape(n_shots, flat_model_shape)
|
|
928
|
+
.scatter_add(1, sources_i, scaled_src)
|
|
929
|
+
.reshape(size_with_batch)
|
|
930
|
+
)
|
|
931
|
+
|
|
932
|
+
# Record at receivers (after source injection)
|
|
933
|
+
if n_receivers > 0:
|
|
934
|
+
receiver_amplitudes[step] = Ey.reshape(n_shots, flat_model_shape).gather(
|
|
935
|
+
1, receivers_i
|
|
936
|
+
)
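            # The gather above selects a [n_shots, n_receivers] slice of Ey at
            # the receiver cells for this time step.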
|
|
937
|
+
|
|
938
|
+
# =========================================================================
|
|
939
|
+
# Output cropping:
|
|
940
|
+
# Only remove fd_pad, keep the PML region in the output wavefields.
|
|
941
|
+
    # Output shape: [n_shots, model_ny + pml_top + pml_bottom, model_nx + pml_left + pml_right]
|
|
942
|
+
# =========================================================================
|
|
943
|
+
s = (
|
|
944
|
+
slice(None), # batch dimension
|
|
945
|
+
slice(
|
|
946
|
+
fd_pad_list[0], padded_ny - fd_pad_list[1] if fd_pad_list[1] > 0 else None
|
|
947
|
+
),
|
|
948
|
+
slice(
|
|
949
|
+
fd_pad_list[2], padded_nx - fd_pad_list[3] if fd_pad_list[3] > 0 else None
|
|
950
|
+
),
|
|
951
|
+
)
|
|
952
|
+
|
|
953
|
+
return (
|
|
954
|
+
Ey[s],
|
|
955
|
+
Hx[s],
|
|
956
|
+
Hz[s],
|
|
957
|
+
m_Ey_x[s],
|
|
958
|
+
m_Ey_z[s],
|
|
959
|
+
m_Hx_z[s],
|
|
960
|
+
m_Hz_x[s],
|
|
961
|
+
receiver_amplitudes,
|
|
962
|
+
)
|
|
963
|
+
|
|
964
|
+
|
|
965
|
+
def update_E(
|
|
966
|
+
ca: torch.Tensor,
|
|
967
|
+
cb: torch.Tensor,
|
|
968
|
+
Hx: torch.Tensor,
|
|
969
|
+
Hz: torch.Tensor,
|
|
970
|
+
Ey: torch.Tensor,
|
|
971
|
+
m_Hx_z: torch.Tensor,
|
|
972
|
+
m_Hz_x: torch.Tensor,
|
|
973
|
+
kappa_y: torch.Tensor,
|
|
974
|
+
kappa_y_h: torch.Tensor,
|
|
975
|
+
kappa_x: torch.Tensor,
|
|
976
|
+
kappa_x_h: torch.Tensor,
|
|
977
|
+
ay: torch.Tensor,
|
|
978
|
+
ay_h: torch.Tensor,
|
|
979
|
+
ax: torch.Tensor,
|
|
980
|
+
ax_h: torch.Tensor,
|
|
981
|
+
by: torch.Tensor,
|
|
982
|
+
by_h: torch.Tensor,
|
|
983
|
+
bx: torch.Tensor,
|
|
984
|
+
bx_h: torch.Tensor,
|
|
985
|
+
rdy: torch.Tensor,
|
|
986
|
+
rdx: torch.Tensor,
|
|
987
|
+
dt: torch.Tensor,
|
|
988
|
+
stencil: int,
|
|
989
|
+
) -> tuple[
|
|
990
|
+
torch.Tensor,
|
|
991
|
+
torch.Tensor,
|
|
992
|
+
torch.Tensor,
|
|
993
|
+
]:
|
|
994
|
+
"""Update electric field Ey with CPML absorbing boundary conditions.
|
|
995
|
+
|
|
996
|
+
For TM mode, the update equation is:
|
|
997
|
+
Ey^{n+1} = Ca * Ey^n + Cb * (dHz/dx - dHx/dz)
|
|
998
|
+
|
|
999
|
+
With CPML, we split the derivatives and apply auxiliary variables:
|
|
1000
|
+
dHz/dx -> dHz/dx / kappa_x + m_Hz_x
|
|
1001
|
+
dHx/dz -> dHx/dz / kappa_y + m_Hx_z
|
|
1002
|
+
|
|
1003
|
+
Args:
|
|
1004
|
+
ca, cb: Update coefficients from material parameters
|
|
1005
|
+
Hx, Hz: Magnetic field components
|
|
1006
|
+
Ey: Electric field component to update
|
|
1007
|
+
m_Hx_z, m_Hz_x: CPML auxiliary memory variables
|
|
1008
|
+
kappa_y, kappa_y_h: CPML kappa profiles in y direction
|
|
1009
|
+
kappa_x, kappa_x_h: CPML kappa profiles in x direction
|
|
1010
|
+
ay, ay_h, ax, ax_h: CPML a coefficients
|
|
1011
|
+
by, by_h, bx, bx_h: CPML b coefficients
|
|
1012
|
+
rdy, rdx: Reciprocal of grid spacing (1/dy, 1/dx)
|
|
1013
|
+
dt: Time step
|
|
1014
|
+
stencil: Finite difference stencil order (2, 4, 6, or 8)
|
|
1015
|
+
|
|
1016
|
+
Returns:
|
|
1017
|
+
Updated Ey, m_Hx_z, m_Hz_x
|
|
1018
|
+
"""
|
|
1019
|
+
|
|
1020
|
+
# Compute spatial derivatives using staggered grid operators
|
|
1021
|
+
# dHz/dx at integer grid points (where Ey lives)
|
|
1022
|
+
dHz_dx = staggered.diffx1(Hz, stencil, rdx)
|
|
1023
|
+
# dHx/dz at integer grid points (where Ey lives)
|
|
1024
|
+
dHx_dz = staggered.diffy1(Hx, stencil, rdy)
|
|
1025
|
+
|
|
1026
|
+
# Update CPML auxiliary variables using standard CPML recursion:
|
|
1027
|
+
# psi_new = b * psi_old + a * derivative
|
|
1028
|
+
# m_Hz_x stores the x-direction PML memory for Hz derivative
|
|
1029
|
+
m_Hz_x = bx * m_Hz_x + ax * dHz_dx
|
|
1030
|
+
# m_Hx_z stores the z-direction PML memory for Hx derivative
|
|
1031
|
+
m_Hx_z = by * m_Hx_z + ay * dHx_dz
|
|
1032
|
+
|
|
1033
|
+
# Apply CPML correction to derivatives
|
|
1034
|
+
# In CPML: d/dx -> (1/kappa) * d/dx + m
|
|
1035
|
+
dHz_dx_pml = dHz_dx / kappa_x + m_Hz_x
|
|
1036
|
+
dHx_dz_pml = dHx_dz / kappa_y + m_Hx_z
|
|
1037
|
+
|
|
1038
|
+
# Update Ey using the FDTD update equation
|
|
1039
|
+
# Ey^{n+1} = Ca * Ey^n + Cb * (dHz/dx - dHx/dz)
|
|
1040
|
+
Ey = ca * Ey + cb * (dHz_dx_pml - dHx_dz_pml)
|
|
1041
|
+
|
|
1042
|
+
return Ey, m_Hx_z, m_Hz_x
|
|
1043
|
+
|
|
1044
|
+
|
|
1045
|
+
def update_H(
|
|
1046
|
+
cq: torch.Tensor,
|
|
1047
|
+
Hx: torch.Tensor,
|
|
1048
|
+
Hz: torch.Tensor,
|
|
1049
|
+
Ey: torch.Tensor,
|
|
1050
|
+
m_Ey_x: torch.Tensor,
|
|
1051
|
+
m_Ey_z: torch.Tensor,
|
|
1052
|
+
kappa_y: torch.Tensor,
|
|
1053
|
+
kappa_y_h: torch.Tensor,
|
|
1054
|
+
kappa_x: torch.Tensor,
|
|
1055
|
+
kappa_x_h: torch.Tensor,
|
|
1056
|
+
ay: torch.Tensor,
|
|
1057
|
+
ay_h: torch.Tensor,
|
|
1058
|
+
ax: torch.Tensor,
|
|
1059
|
+
ax_h: torch.Tensor,
|
|
1060
|
+
by: torch.Tensor,
|
|
1061
|
+
by_h: torch.Tensor,
|
|
1062
|
+
bx: torch.Tensor,
|
|
1063
|
+
bx_h: torch.Tensor,
|
|
1064
|
+
rdy: torch.Tensor,
|
|
1065
|
+
rdx: torch.Tensor,
|
|
1066
|
+
dt: torch.Tensor,
|
|
1067
|
+
stencil: int,
|
|
1068
|
+
) -> tuple[
|
|
1069
|
+
torch.Tensor,
|
|
1070
|
+
torch.Tensor,
|
|
1071
|
+
torch.Tensor,
|
|
1072
|
+
torch.Tensor,
|
|
1073
|
+
]:
|
|
1074
|
+
"""Update magnetic fields Hx and Hz with CPML absorbing boundary conditions.
|
|
1075
|
+
|
|
1076
|
+
For TM mode, the update equations are:
|
|
1077
|
+
Hx^{n+1} = Hx^n - Cq * dEy/dz
|
|
1078
|
+
Hz^{n+1} = Hz^n + Cq * dEy/dx
|
|
1079
|
+
|
|
1080
|
+
With CPML, we use half-grid derivatives and auxiliary variables:
|
|
1081
|
+
dEy/dz -> dEy/dz / kappa_y_h + m_Ey_z
|
|
1082
|
+
dEy/dx -> dEy/dx / kappa_x_h + m_Ey_x
|
|
1083
|
+
|
|
1084
|
+
Args:
|
|
1085
|
+
cq: Update coefficient (dt/mu)
|
|
1086
|
+
Hx, Hz: Magnetic field components to update
|
|
1087
|
+
Ey: Electric field component
|
|
1088
|
+
m_Ey_x, m_Ey_z: CPML auxiliary memory variables
|
|
1089
|
+
kappa_y, kappa_y_h: CPML kappa profiles in y direction (integer and half grid)
|
|
1090
|
+
kappa_x, kappa_x_h: CPML kappa profiles in x direction (integer and half grid)
|
|
1091
|
+
ay, ay_h, ax, ax_h: CPML a coefficients
|
|
1092
|
+
by, by_h, bx, bx_h: CPML b coefficients
|
|
1093
|
+
rdy, rdx: Reciprocal of grid spacing (1/dy, 1/dx)
|
|
1094
|
+
dt: Time step
|
|
1095
|
+
stencil: Finite difference stencil order (2, 4, 6, or 8)
|
|
1096
|
+
|
|
1097
|
+
Returns:
|
|
1098
|
+
Updated Hx, Hz, m_Ey_x, m_Ey_z
|
|
1099
|
+
"""
|
|
1100
|
+
|
|
1101
|
+
# Compute spatial derivatives at half grid points (where H fields live)
|
|
1102
|
+
# dEy/dz at half grid points in z (for Hx update)
|
|
1103
|
+
dEy_dz = staggered.diffyh1(Ey, stencil, rdy)
|
|
1104
|
+
# dEy/dx at half grid points in x (for Hz update)
|
|
1105
|
+
dEy_dx = staggered.diffxh1(Ey, stencil, rdx)
|
|
1106
|
+
|
|
1107
|
+
# Update CPML auxiliary variables using standard CPML recursion:
|
|
1108
|
+
# psi_new = b * psi_old + a * derivative
|
|
1109
|
+
# m_Ey_z stores the z-direction PML memory for Ey derivative (used in Hx update)
|
|
1110
|
+
m_Ey_z = by_h * m_Ey_z + ay_h * dEy_dz
|
|
1111
|
+
# m_Ey_x stores the x-direction PML memory for Ey derivative (used in Hz update)
|
|
1112
|
+
m_Ey_x = bx_h * m_Ey_x + ax_h * dEy_dx
|
|
1113
|
+
|
|
1114
|
+
# Apply CPML correction to derivatives
|
|
1115
|
+
# In CPML: d/dz -> (1/kappa_h) * d/dz + m
|
|
1116
|
+
dEy_dz_pml = dEy_dz / kappa_y_h + m_Ey_z
|
|
1117
|
+
dEy_dx_pml = dEy_dx / kappa_x_h + m_Ey_x
|
|
1118
|
+
|
|
1119
|
+
# Update Hx using the FDTD update equation
|
|
1120
|
+
# Hx^{n+1} = Hx^n - Cq * dEy/dz
|
|
1121
|
+
Hx = Hx - cq * dEy_dz_pml
|
|
1122
|
+
|
|
1123
|
+
# Update Hz using the FDTD update equation
|
|
1124
|
+
# Hz^{n+1} = Hz^n + Cq * dEy/dx
|
|
1125
|
+
Hz = Hz + cq * dEy_dx_pml
|
|
1126
|
+
|
|
1127
|
+
return Hx, Hz, m_Ey_x, m_Ey_z
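# Note on ordering: in the maxwell_python time loop, update_H is applied first
# (H^{n-1/2} -> H^{n+1/2}) and update_E second (E^n -> E^{n+1}); the source is
# then injected into Ey and the receivers are sampled from the updated Ey.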
|
|
1128
|
+
|
|
1129
|
+
|
|
1130
|
+
# Initialize the optimized function pointers to the default implementations
|
|
1131
|
+
_update_E_opt = update_E
|
|
1132
|
+
_update_H_opt = update_H
|
|
1133
|
+
|
|
1134
|
+
|
|
1135
|
+
def maxwell_c_cuda(
|
|
1136
|
+
epsilon: torch.Tensor,
|
|
1137
|
+
sigma: torch.Tensor,
|
|
1138
|
+
mu: torch.Tensor,
|
|
1139
|
+
grid_spacing: Union[float, Sequence[float]],
|
|
1140
|
+
dt: float,
|
|
1141
|
+
source_amplitude: Optional[torch.Tensor],
|
|
1142
|
+
source_location: Optional[torch.Tensor],
|
|
1143
|
+
receiver_location: Optional[torch.Tensor],
|
|
1144
|
+
stencil: int,
|
|
1145
|
+
pml_width: Union[int, Sequence[int]],
|
|
1146
|
+
max_vel: Optional[float],
|
|
1147
|
+
Ey_0: Optional[torch.Tensor],
|
|
1148
|
+
Hx_0: Optional[torch.Tensor],
|
|
1149
|
+
Hz_0: Optional[torch.Tensor],
|
|
1150
|
+
m_Ey_x_0: Optional[torch.Tensor],
|
|
1151
|
+
m_Ey_z_0: Optional[torch.Tensor],
|
|
1152
|
+
m_Hx_z_0: Optional[torch.Tensor],
|
|
1153
|
+
m_Hz_x_0: Optional[torch.Tensor],
|
|
1154
|
+
nt: Optional[int],
|
|
1155
|
+
model_gradient_sampling_interval: int,
|
|
1156
|
+
freq_taper_frac: float,
|
|
1157
|
+
time_pad_frac: float,
|
|
1158
|
+
time_taper: bool,
|
|
1159
|
+
save_snapshots: Optional[bool],
|
|
1160
|
+
forward_callback: Optional[Callback],
|
|
1161
|
+
backward_callback: Optional[Callback],
|
|
1162
|
+
callback_frequency: int,
|
|
1163
|
+
storage_mode: str = "device",
|
|
1164
|
+
storage_path: str = ".",
|
|
1165
|
+
storage_compression: Union[bool, str] = False,
|
|
1166
|
+
storage_bytes_limit_device: Optional[int] = None,
|
|
1167
|
+
storage_bytes_limit_host: Optional[int] = None,
|
|
1168
|
+
storage_chunk_steps: int = 0,
|
|
1169
|
+
n_threads: Optional[int] = None,
|
|
1170
|
+
):
|
|
1171
|
+
"""Performs Maxwell propagation using C/CUDA backend.
|
|
1172
|
+
|
|
1173
|
+
This function provides the interface to the compiled C/CUDA implementations
|
|
1174
|
+
for high-performance wave propagation.
|
|
1175
|
+
|
|
1176
|
+
Padding strategy:
|
|
1177
|
+
- Models are padded by fd_pad + pml_width with replicate mode
|
|
1178
|
+
- Wavefields are padded by fd_pad only with zero padding
|
|
1179
|
+
- Output wavefields are cropped by fd_pad only (PML region is preserved)
|
|
1180
|
+
|
|
1181
|
+
Args:
|
|
1182
|
+
Same as maxwell_python.
|
|
1183
|
+
|
|
1184
|
+
Returns:
|
|
1185
|
+
Same as maxwell_python.
|
|
1186
|
+
"""
|
|
1187
|
+
from . import backend_utils, staggered
|
|
1188
|
+
from .padding import create_or_pad, zero_interior
|
|
1189
|
+
|
|
1190
|
+
# Validate inputs
|
|
1191
|
+
if epsilon.ndim != 2:
|
|
1192
|
+
raise RuntimeError("epsilon must be 2D")
|
|
1193
|
+
if sigma.shape != epsilon.shape:
|
|
1194
|
+
raise RuntimeError("sigma must have same shape as epsilon")
|
|
1195
|
+
if mu.shape != epsilon.shape:
|
|
1196
|
+
raise RuntimeError("mu must have same shape as epsilon")
|
|
1197
|
+
|
|
1198
|
+
device = epsilon.device
|
|
1199
|
+
dtype = epsilon.dtype
|
|
1200
|
+
model_ny, model_nx = epsilon.shape # Original model dimensions
|
|
1201
|
+
|
|
1202
|
+
# Normalize grid_spacing to list
|
|
1203
|
+
grid_spacing = _normalize_grid_spacing_2d(grid_spacing)
|
|
1204
|
+
dy, dx = grid_spacing
|
|
1205
|
+
|
|
1206
|
+
n_threads_val = 0
|
|
1207
|
+
if n_threads is not None:
|
|
1208
|
+
n_threads_val = int(n_threads)
|
|
1209
|
+
if n_threads_val < 0:
|
|
1210
|
+
raise ValueError("n_threads must be >= 0 when provided.")
|
|
1211
|
+
|
|
1212
|
+
# Normalize pml_width to list [top, bottom, left, right]
|
|
1213
|
+
pml_width_list = _normalize_pml_width_2d(pml_width)
|
|
1214
|
+
|
|
1215
|
+
# Determine number of time steps
|
|
1216
|
+
if nt is None:
|
|
1217
|
+
if source_amplitude is None:
|
|
1218
|
+
raise ValueError("Either nt or source_amplitude must be provided")
|
|
1219
|
+
nt = source_amplitude.shape[-1]
|
|
1220
|
+
|
|
1221
|
+
# Ensure nt is an integer for iteration
|
|
1222
|
+
nt_steps: int = int(nt)
|
|
1223
|
+
# Clamp gradient sampling interval to a sensible range for storage/backprop
|
|
1224
|
+
gradient_sampling_interval = int(model_gradient_sampling_interval)
|
|
1225
|
+
if gradient_sampling_interval < 1:
|
|
1226
|
+
gradient_sampling_interval = 1
|
|
1227
|
+
if nt_steps > 0:
|
|
1228
|
+
gradient_sampling_interval = min(gradient_sampling_interval, nt_steps)
|
|
1229
|
+
|
|
1230
|
+
# Determine number of shots
|
|
1231
|
+
if source_amplitude is not None and source_amplitude.numel() > 0:
|
|
1232
|
+
n_shots = source_amplitude.shape[0]
|
|
1233
|
+
elif source_location is not None and source_location.numel() > 0:
|
|
1234
|
+
n_shots = source_location.shape[0]
|
|
1235
|
+
elif receiver_location is not None and receiver_location.numel() > 0:
|
|
1236
|
+
n_shots = receiver_location.shape[0]
|
|
1237
|
+
else:
|
|
1238
|
+
n_shots = 1
|
|
1239
|
+
|
|
1240
|
+
# Compute maximum velocity for PML if not provided
|
|
1241
|
+
if max_vel is None:
|
|
1242
|
+
max_vel = float((1.0 / torch.sqrt(epsilon * mu)).max().item()) * C0
|
|
1243
|
+
|
|
1244
|
+
# Compute PML frequency
|
|
1245
|
+
pml_freq = 0.5 / dt
|
|
1246
|
+
|
|
1247
|
+
# =========================================================================
|
|
1248
|
+
# Padding strategy:
|
|
1249
|
+
# - fd_pad: padding for finite difference stencil accuracy
|
|
1250
|
+
# - pml_width: padding for PML absorbing layers
|
|
1251
|
+
# - Total model padding = fd_pad + pml_width
|
|
1252
|
+
# - Wavefield padding = fd_pad only (wavefields include PML region)
|
|
1253
|
+
# =========================================================================
|
|
1254
|
+
|
|
1255
|
+
# FD padding based on stencil: accuracy // 2
|
|
1256
|
+
fd_pad = stencil // 2
|
|
1257
|
+
# fd_pad_list: [y0, y1, x0, x1] - asymmetric for staggered grid
|
|
1258
|
+
fd_pad_list = [fd_pad, fd_pad - 1, fd_pad, fd_pad - 1]
|
|
1259
|
+
|
|
1260
|
+
# Total padding for models = fd_pad + pml_width
|
|
1261
|
+
total_pad = [fd + pml for fd, pml in zip(fd_pad_list, pml_width_list)]
|
|
1262
|
+
|
|
1263
|
+
# Calculate padded dimensions
|
|
1264
|
+
padded_ny = model_ny + total_pad[0] + total_pad[1]
|
|
1265
|
+
padded_nx = model_nx + total_pad[2] + total_pad[3]
|
|
1266
|
+
|
|
1267
|
+
# Pad model tensors with replicate mode (extend boundary values)
|
|
1268
|
+
padded_size = (padded_ny, padded_nx)
|
|
1269
|
+
epsilon_padded = create_or_pad(
|
|
1270
|
+
epsilon, total_pad, device, dtype, padded_size, mode="replicate"
|
|
1271
|
+
)
|
|
1272
|
+
sigma_padded = create_or_pad(
|
|
1273
|
+
sigma, total_pad, device, dtype, padded_size, mode="replicate"
|
|
1274
|
+
)
|
|
1275
|
+
mu_padded = create_or_pad(
|
|
1276
|
+
mu, total_pad, device, dtype, padded_size, mode="replicate"
|
|
1277
|
+
)
|
|
1278
|
+
|
|
1279
|
+
# Prepare update coefficients using padded models
|
|
1280
|
+
ca, cb, cq = prepare_parameters(epsilon_padded, sigma_padded, mu_padded, dt)
|
|
1281
|
+
|
|
1282
|
+
# Initialize fields with padded dimensions
|
|
1283
|
+
size_with_batch = (n_shots, padded_ny, padded_nx)
|
|
1284
|
+
|
|
1285
|
+
def init_wavefield(field_0: Optional[torch.Tensor]) -> torch.Tensor:
|
|
1286
|
+
"""Initialize wavefield with fd_pad zero padding."""
|
|
1287
|
+
if field_0 is not None:
|
|
1288
|
+
if field_0.ndim == 2:
|
|
1289
|
+
field_0 = field_0[None, :, :].expand(n_shots, -1, -1)
|
|
1290
|
+
# Pad with asymmetric fd_pad_list for staggered grid
|
|
1291
|
+
return create_or_pad(
|
|
1292
|
+
field_0, fd_pad_list, device, dtype, size_with_batch
|
|
1293
|
+
).contiguous()
|
|
1294
|
+
return torch.zeros(size_with_batch, device=device, dtype=dtype)

    Ey = init_wavefield(Ey_0)
    Hx = init_wavefield(Hx_0)
    Hz = init_wavefield(Hz_0)
    m_Ey_x = init_wavefield(m_Ey_x_0)
    m_Ey_z = init_wavefield(m_Ey_z_0)
    m_Hx_z = init_wavefield(m_Hx_z_0)
    m_Hz_x = init_wavefield(m_Hz_x_0)

    # Zero out interior of PML auxiliary variables (optimization)
    # This works correctly with user-provided states (see forward pass comments)
    pml_aux_dims = [1, 0, 0, 1]  # [m_Ey_x, m_Ey_z, m_Hx_z, m_Hz_x]
    for wf, dim in zip([m_Ey_x, m_Ey_z, m_Hx_z, m_Hz_x], pml_aux_dims):
        zero_interior(wf, fd_pad_list, pml_width_list, dim)

    # Set up PML profiles for the padded domain
    pml_profiles_list = staggered.set_pml_profiles(
        pml_width=pml_width_list,
        accuracy=stencil,
        fd_pad=fd_pad_list,
        dt=dt,
        grid_spacing=grid_spacing,
        max_vel=max_vel,
        dtype=dtype,
        device=device,
        pml_freq=pml_freq,
        ny=padded_ny,
        nx=padded_nx,
    )
    (
        ay,
        ay_h,
        ax,
        ax_h,
        by,
        by_h,
        bx,
        bx_h,
        ky,
        ky_h,
        kx,
        kx_h,
    ) = pml_profiles_list
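    # ay/ax and by/bx are (per the CPML formulation used in staggered.py) the
    # recursion coefficients for the PML memory updates, ky/kx the kappa
    # grid-stretching profiles, and the *_h variants are presumably the same
    # profiles sampled at the half-cell (staggered) positions used for the
    # H-field updates.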

    # Flatten PML profiles for C backend (remove batch dimensions)
    ay_flat = ay.squeeze().contiguous()
    ay_h_flat = ay_h.squeeze().contiguous()
    ax_flat = ax.squeeze().contiguous()
    ax_h_flat = ax_h.squeeze().contiguous()
    by_flat = by.squeeze().contiguous()
    by_h_flat = by_h.squeeze().contiguous()
    bx_flat = bx.squeeze().contiguous()
    bx_h_flat = bx_h.squeeze().contiguous()

    # Flatten kappa profiles for C backend
    ky_flat = ky.squeeze().contiguous()
    ky_h_flat = ky_h.squeeze().contiguous()
    kx_flat = kx.squeeze().contiguous()
    kx_h_flat = kx_h.squeeze().contiguous()

    # =========================================================================
    # Prepare source and receiver indices
    # Original positions are in the un-padded model coordinate system.
    # We need to offset by total_pad (fd_pad + pml_width) to get padded coords.
    # =========================================================================
    flat_model_shape = padded_ny * padded_nx

    if source_location is not None and source_location.numel() > 0:
        # Adjust source positions by total padding offset
        source_y = source_location[..., 0] + total_pad[0]
        source_x = source_location[..., 1] + total_pad[2]
        sources_i = (source_y * padded_nx + source_x).long().contiguous()
        n_sources = source_location.shape[1]
    else:
        sources_i = torch.empty(0, device=device, dtype=torch.long)
        n_sources = 0

    if receiver_location is not None and receiver_location.numel() > 0:
        # Adjust receiver positions by total padding offset
        receiver_y = receiver_location[..., 0] + total_pad[0]
        receiver_x = receiver_location[..., 1] + total_pad[2]
        receivers_i = (receiver_y * padded_nx + receiver_x).long().contiguous()
        n_receivers = receiver_location.shape[1]
    else:
        receivers_i = torch.empty(0, device=device, dtype=torch.long)
        n_receivers = 0

    # Prepare source amplitudes with proper scaling
    if source_amplitude is not None and source_amplitude.numel() > 0:
        source_coeff = -1.0 / (dx * dy)
        # Expand cb to batch dimension for gather
        cb_expanded = cb[None, :, :].expand(n_shots, -1, -1)
        cb_flat = cb_expanded.reshape(n_shots, flat_model_shape)
        cb_at_src = cb_flat.gather(1, sources_i)
        # Reshape source amplitude: [shot, source, time] -> [time, shot, source]
        f = source_amplitude.permute(2, 0, 1).contiguous()
        # Scale by cb and source coefficient
        f = f * cb_at_src[None, :, :] * source_coeff
        f = f.reshape(nt_steps * n_shots * n_sources)
    else:
        f = torch.empty(0, device=device, dtype=dtype)
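    # A point (y, x) in the un-padded model therefore maps to flat index
    # (y + total_pad[0]) * padded_nx + (x + total_pad[2]) in the padded grid,
    # and the scaled source tensor f is stored [time, shot, source] and then
    # flattened to one contiguous 1D array for the C backend.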

    # Flatten fields for C backend
    Ey = Ey.contiguous()
    Hx = Hx.contiguous()
    Hz = Hz.contiguous()
    m_Ey_x = m_Ey_x.contiguous()
    m_Ey_z = m_Ey_z.contiguous()
    m_Hx_z = m_Hx_z.contiguous()
    m_Hz_x = m_Hz_x.contiguous()

    # Flatten coefficients (add batch dimension for consistency)
    ca = ca[None, :, :].contiguous()
    cb = cb[None, :, :].contiguous()
    cq = cq[None, :, :].contiguous()

    # PML boundaries (where PML starts in the padded domain)
    pml_y0 = fd_pad_list[0] + pml_width_list[0]
    pml_y1 = padded_ny - fd_pad_list[1] - pml_width_list[1]
    pml_x0 = fd_pad_list[2] + pml_width_list[2]
    pml_x1 = padded_nx - fd_pad_list[3] - pml_width_list[3]
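    # Continuing the illustrative numbers above (fd_pad_list = [2, 1, 2, 1],
    # pml_width 20, 143 x 243 padded grid): pml_y0 = 22, pml_y1 = 122,
    # pml_x0 = 22, pml_x1 = 222, i.e. the PML-free interior spans exactly the
    # original 100 x 200 cells.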

    # Determine if any input requires gradients
    requires_grad = epsilon.requires_grad or sigma.requires_grad

    functorch_active = torch._C._are_functorch_transforms_active()
    if functorch_active:
        raise NotImplementedError(
            "torch.func transforms are not supported for the C/CUDA backend."
        )

    storage_kind, _, storage_bytes_per_elem = _resolve_storage_compression(
        storage_compression,
        dtype,
        device,
        context="storage_compression",
    )

    # Determine if we should save snapshots for backward pass
    if save_snapshots is None:
        do_save_snapshots = requires_grad
    else:
        do_save_snapshots = save_snapshots

    # If save_snapshots is False but requires_grad is True, warn user
    if requires_grad and save_snapshots is False:
        import warnings

        warnings.warn(
            "save_snapshots=False but model parameters require gradients. "
            "Backward pass will fail.",
            UserWarning,
        )

    storage_mode_str = storage_mode.lower()
    if storage_mode_str not in {"device", "cpu", "disk", "none", "auto"}:
        raise ValueError(
            "storage_mode must be 'device', 'cpu', 'disk', 'none', or 'auto', "
            f"but got {storage_mode!r}"
        )
    if device.type == "cpu" and storage_mode_str == "cpu":
        storage_mode_str = "device"

    needs_storage = do_save_snapshots and requires_grad
    effective_storage_mode_str = storage_mode_str
    if not needs_storage:
        if effective_storage_mode_str == "auto":
            effective_storage_mode_str = "none"
    else:
        if effective_storage_mode_str == "none":
            raise ValueError(
                "storage_mode='none' is not compatible with gradient computation "
                "when model parameters require gradients."
            )
        if effective_storage_mode_str == "auto":
            dtype_size = storage_bytes_per_elem
            # Estimate required bytes for storing Ey and curl_H.
            num_elements_per_shot = padded_ny * padded_nx
            shot_bytes_uncomp = num_elements_per_shot * dtype_size
            n_stored = (
                nt_steps + gradient_sampling_interval - 1
            ) // gradient_sampling_interval
            total_bytes = n_stored * n_shots * shot_bytes_uncomp * 2  # Ey + curl_H

            limit_device = (
                storage_bytes_limit_device
                if storage_bytes_limit_device is not None
                else float("inf")
            )
            limit_host = (
                storage_bytes_limit_host
                if storage_bytes_limit_host is not None
                else float("inf")
            )
            import warnings

            if device.type == "cuda" and total_bytes <= limit_device:
                effective_storage_mode_str = "device"
            elif total_bytes <= limit_host:
                effective_storage_mode_str = "cpu"
            else:
                effective_storage_mode_str = "disk"

            warnings.warn(
                f"storage_mode='auto' selected storage_mode='{effective_storage_mode_str}' "
                f"for estimated storage size {total_bytes / 1e9:.2f} GB.",
                RuntimeWarning,
            )
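            # Rough illustration (assumed sizes): a 143 x 243 padded grid in
            # float32 is about 139 kB per shot per field, so 1000 stored steps
            # x 8 shots x 2 fields is roughly 2.2 GB; this estimate is what
            # gets compared against the device and host byte limits to choose
            # between 'device', 'cpu', and 'disk'.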

    # Callback fd_pad is the actual fd_pad used for wavefields
    callback_fd_pad = fd_pad_list

    # Callback models dict
    callback_models = {
        "epsilon": epsilon_padded,
        "sigma": sigma_padded,
        "mu": mu_padded,
        "ca": ca,
        "cb": cb,
        "cq": cq,
    }

    use_autograd_fn = (
        requires_grad
        and do_save_snapshots
    ) or functorch_active
    if use_autograd_fn:
        # Use autograd Function for gradient computation
        result = MaxwellTMForwardFunc.apply(
            ca,
            cb,
            cq,
            f,
            ay_flat,
            by_flat,
            ay_h_flat,
            by_h_flat,
            ax_flat,
            bx_flat,
            ax_h_flat,
            bx_h_flat,
            ky_flat,
            ky_h_flat,
            kx_flat,
            kx_h_flat,
            sources_i,
            receivers_i,
            1.0 / dy,  # rdy
            1.0 / dx,  # rdx
            dt,
            nt_steps,
            n_shots,
            padded_ny,
            padded_nx,
            n_sources,
            n_receivers,
            gradient_sampling_interval,  # step_ratio
            stencil,  # accuracy
            False,  # ca_batched
            False,  # cb_batched
            False,  # cq_batched
            pml_y0,
            pml_x0,
            pml_y1,
            pml_x1,
            tuple(fd_pad_list),  # fd_pad for callback
            tuple(pml_width_list),  # pml_width for callback
            callback_models,  # models dict for callback
            forward_callback,
            backward_callback,
            callback_frequency,
            effective_storage_mode_str,
            storage_path,
            storage_compression,
            Ey,
            Hx,
            Hz,
            m_Ey_x,
            m_Ey_z,
            m_Hx_z,
            m_Hz_x,
            n_threads_val,
        )
        # Unpack result (drop context handle if present)
        if len(result) == 9:
            (
                Ey_out,
                Hx_out,
                Hz_out,
                m_Ey_x_out,
                m_Ey_z_out,
                m_Hx_z_out,
                m_Hz_x_out,
                receiver_amplitudes,
                _ctx_handle,
            ) = result
        else:
            (
                Ey_out,
                Hx_out,
                Hz_out,
                m_Ey_x_out,
                m_Ey_z_out,
                m_Hx_z_out,
                m_Hz_x_out,
                receiver_amplitudes,
            ) = result
        # Output cropping: only remove fd_pad, keep PML region
        s = (
            slice(None),  # batch dimension
            slice(
                fd_pad_list[0],
                padded_ny - fd_pad_list[1] if fd_pad_list[1] > 0 else None,
            ),
            slice(
                fd_pad_list[2],
                padded_nx - fd_pad_list[3] if fd_pad_list[3] > 0 else None,
            ),
        )

        return (
            Ey_out[s],
            Hx_out[s],
            Hz_out[s],
            m_Ey_x_out[s],
            m_Ey_z_out[s],
            m_Hx_z_out[s],
            m_Hz_x_out[s],
            receiver_amplitudes,
        )
    else:
        # Direct call without autograd for inference
        # Get the backend function
        try:
            forward_func = backend_utils.get_backend_function(
                "maxwell_tm", "forward", stencil, dtype, device
            )
        except AttributeError as e:
            raise RuntimeError(
                f"C/CUDA backend function not available for accuracy={stencil}, "
                f"dtype={dtype}, device={device}. Error: {e}"
            )

        # Get device index for CUDA
        device_idx = (
            device.index if device.type == "cuda" and device.index is not None else 0
        )

        # Initialize receiver amplitudes
        if n_receivers > 0:
            receiver_amplitudes = torch.zeros(
                nt_steps, n_shots, n_receivers, device=device, dtype=dtype
            )
        else:
            receiver_amplitudes = torch.empty(0, device=device, dtype=dtype)

        # If no callback is provided, run entire propagation in single call
        # Otherwise, chunk into callback_frequency steps
        if forward_callback is None:
            effective_callback_freq = nt_steps
        else:
            effective_callback_freq = callback_frequency

        # Main time-stepping loop with chunked calls for callback support
        for step in range(0, nt_steps, effective_callback_freq):
            # Call callback at the start of each chunk
            if forward_callback is not None:
                callback_wavefields = {
                    "Ey": Ey,
                    "Hx": Hx,
                    "Hz": Hz,
                    "m_Ey_x": m_Ey_x,
                    "m_Ey_z": m_Ey_z,
                    "m_Hx_z": m_Hx_z,
                    "m_Hz_x": m_Hz_x,
                }
                callback_state = CallbackState(
                    dt=dt,
                    step=step,
                    nt=nt_steps,
                    wavefields=callback_wavefields,
                    models=callback_models,
                    gradients=None,
                    fd_pad=callback_fd_pad,
                    pml_width=pml_width_list,
                    is_backward=False,
                    grid_spacing=[dy, dx],
                )
                forward_callback(callback_state)

            # Number of steps to propagate in this chunk
            step_nt = min(nt_steps - step, effective_callback_freq)

            # Call the C/CUDA function for this chunk
            forward_func(
                backend_utils.tensor_to_ptr(ca),
                backend_utils.tensor_to_ptr(cb),
                backend_utils.tensor_to_ptr(cq),
                backend_utils.tensor_to_ptr(f),
                backend_utils.tensor_to_ptr(Ey),
                backend_utils.tensor_to_ptr(Hx),
                backend_utils.tensor_to_ptr(Hz),
                backend_utils.tensor_to_ptr(m_Ey_x),
                backend_utils.tensor_to_ptr(m_Ey_z),
                backend_utils.tensor_to_ptr(m_Hx_z),
                backend_utils.tensor_to_ptr(m_Hz_x),
                backend_utils.tensor_to_ptr(receiver_amplitudes),
                backend_utils.tensor_to_ptr(ay_flat),
                backend_utils.tensor_to_ptr(by_flat),
                backend_utils.tensor_to_ptr(ay_h_flat),
                backend_utils.tensor_to_ptr(by_h_flat),
                backend_utils.tensor_to_ptr(ax_flat),
                backend_utils.tensor_to_ptr(bx_flat),
                backend_utils.tensor_to_ptr(ax_h_flat),
                backend_utils.tensor_to_ptr(bx_h_flat),
                backend_utils.tensor_to_ptr(ky_flat),
                backend_utils.tensor_to_ptr(ky_h_flat),
                backend_utils.tensor_to_ptr(kx_flat),
                backend_utils.tensor_to_ptr(kx_h_flat),
                backend_utils.tensor_to_ptr(sources_i),
                backend_utils.tensor_to_ptr(receivers_i),
                1.0 / dy,  # rdy
                1.0 / dx,  # rdx
                dt,
                step_nt,  # nt for this chunk
                n_shots,
                padded_ny,
                padded_nx,
                n_sources,
                n_receivers,
                gradient_sampling_interval,  # step_ratio
                False,  # ca_batched
                False,  # cb_batched
                False,  # cq_batched
                step,  # start_t - crucial for correct source injection timing
                pml_y0,
                pml_x0,
                pml_y1,
                pml_x1,
                n_threads_val,
                device_idx,
            )

        # Output cropping: only remove fd_pad, keep PML region
        s = (
            slice(None),  # batch dimension
            slice(
                fd_pad_list[0],
                padded_ny - fd_pad_list[1] if fd_pad_list[1] > 0 else None,
            ),
            slice(
                fd_pad_list[2],
                padded_nx - fd_pad_list[3] if fd_pad_list[3] > 0 else None,
            ),
        )

        return (
            Ey[s],
            Hx[s],
            Hz[s],
            m_Ey_x[s],
            m_Ey_z[s],
            m_Hx_z[s],
            m_Hz_x[s],
            receiver_amplitudes,
        )
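
# Both branches above return the same 8-tuple: the final (Ey, Hx, Hz) fields
# and the four CPML memory variables, each cropped so that only the fd_pad
# halo is removed (the PML region is kept), plus receiver_amplitudes with
# shape [n_time_steps, n_shots, n_receivers].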


class MaxwellTMForwardFunc(torch.autograd.Function):
    """Autograd function for the forward pass of Maxwell TM wave propagation.

    This class defines the forward and backward passes for the 2D TM mode
    Maxwell equations, allowing PyTorch to compute gradients through the wave
    propagation operation. It interfaces directly with the C/CUDA backend.

    The backward pass uses the Adjoint State Method (ASM) which requires
    storing forward wavefield values at each step_ratio interval for
    gradient computation.
    """

    @staticmethod
    def forward(
        ca: torch.Tensor,
        cb: torch.Tensor,
        cq: torch.Tensor,
        source_amplitudes_scaled: torch.Tensor,
        ay: torch.Tensor,
        by: torch.Tensor,
        ay_h: torch.Tensor,
        by_h: torch.Tensor,
        ax: torch.Tensor,
        bx: torch.Tensor,
        ax_h: torch.Tensor,
        bx_h: torch.Tensor,
        ky: torch.Tensor,
        ky_h: torch.Tensor,
        kx: torch.Tensor,
        kx_h: torch.Tensor,
        sources_i: torch.Tensor,
        receivers_i: torch.Tensor,
        rdy: float,
        rdx: float,
        dt: float,
        nt: int,
        n_shots: int,
        ny: int,
        nx: int,
        n_sources: int,
        n_receivers: int,
        step_ratio: int,
        accuracy: int,
        ca_batched: bool,
        cb_batched: bool,
        cq_batched: bool,
        pml_y0: int,
        pml_x0: int,
        pml_y1: int,
        pml_x1: int,
        fd_pad: tuple[int, int, int, int],
        pml_width: tuple[int, int, int, int],
        models: dict,
        forward_callback: Optional[Callback],
        backward_callback: Optional[Callback],
        callback_frequency: int,
        storage_mode_str: str,
        storage_path: str,
        storage_compression: Union[bool, str],
        Ey: torch.Tensor,
        Hx: torch.Tensor,
        Hz: torch.Tensor,
        m_Ey_x: torch.Tensor,
        m_Ey_z: torch.Tensor,
        m_Hx_z: torch.Tensor,
        m_Hz_x: torch.Tensor,
        n_threads: int,
    ) -> tuple[Any, ...]:
        """Performs the forward propagation of the Maxwell TM equations."""
        from . import backend_utils

        device = Ey.device
        dtype = Ey.dtype

        ca_requires_grad = ca.requires_grad
        cb_requires_grad = cb.requires_grad
        needs_grad = ca_requires_grad or cb_requires_grad

        # Initialize receiver amplitudes
        if n_receivers > 0:
            receiver_amplitudes = torch.zeros(
                nt, n_shots, n_receivers, device=device, dtype=dtype
            )
        else:
            receiver_amplitudes = torch.empty(0, device=device, dtype=dtype)

        # Get device index for CUDA
        device_idx = (
            device.index if device.type == "cuda" and device.index is not None else 0
        )

        backward_storage_tensors: list[torch.Tensor] = []
        backward_storage_objects: list[Any] = []
        backward_storage_filename_arrays: list[Any] = []
        storage_mode = STORAGE_NONE
        shot_bytes_uncomp = 0

        if needs_grad:
            import ctypes

            # Resolve storage mode and sizes
            if str(device) == "cpu" and storage_mode_str == "cpu":
                storage_mode_str = "device"
            storage_mode = storage_mode_to_int(storage_mode_str)

            num_elements_per_shot = ny * nx
            _, store_dtype, _ = _resolve_storage_compression(
                storage_compression,
                dtype,
                device,
                context="storage_compression",
            )

            shot_bytes_uncomp = num_elements_per_shot * store_dtype.itemsize

            num_steps_stored = (nt + step_ratio - 1) // step_ratio

            # Storage buffers and filename arrays (mirrors Deepwave)
            char_ptr_type = ctypes.c_char_p
            is_cuda = device.type == "cuda"

            def alloc_storage(requires_grad_cond: bool):
                store_1 = torch.empty(0)
                store_3 = torch.empty(0)
                filenames_arr = (char_ptr_type * 0)()

                if requires_grad_cond and storage_mode != STORAGE_NONE:
                    if storage_mode == STORAGE_DEVICE:
                        store_1 = torch.empty(
                            num_steps_stored,
                            n_shots,
                            ny,
                            nx,
                            device=device,
                            dtype=store_dtype,
                        )
                    elif storage_mode == STORAGE_CPU:
                        # Multi-buffer device staging to overlap D2H copies.
                        store_1 = torch.empty(
                            _CPU_STORAGE_BUFFERS,
                            n_shots,
                            ny,
                            nx,
                            device=device,
                            dtype=store_dtype,
                        )
                        store_3 = torch.empty(
                            num_steps_stored,
                            n_shots,
                            shot_bytes_uncomp // store_dtype.itemsize,
                            device="cpu",
                            pin_memory=True,
                            dtype=store_dtype,
                        )
                    elif storage_mode == STORAGE_DISK:
                        storage_obj = TemporaryStorage(
                            storage_path, 1 if is_cuda else n_shots
                        )
                        backward_storage_objects.append(storage_obj)
                        filenames_list = [
                            f.encode("utf-8") for f in storage_obj.get_filenames()
                        ]
                        filenames_arr = (char_ptr_type * len(filenames_list))()
                        for i_file, f_name in enumerate(filenames_list):
                            filenames_arr[i_file] = ctypes.cast(
                                ctypes.create_string_buffer(f_name), char_ptr_type
                            )

                        store_1 = torch.empty(
                            n_shots, ny, nx, device=device, dtype=store_dtype
                        )
                        if is_cuda:
                            store_3 = torch.empty(
                                n_shots,
                                shot_bytes_uncomp // store_dtype.itemsize,
                                device="cpu",
                                pin_memory=True,
                                dtype=store_dtype,
                            )

                backward_storage_tensors.extend([store_1, store_3])
                backward_storage_filename_arrays.append(filenames_arr)

                filenames_ptr = (
                    ctypes.cast(filenames_arr, ctypes.c_void_p)
                    if storage_mode == STORAGE_DISK
                    else 0
                )

                return store_1, store_3, filenames_ptr

            ey_store_1, ey_store_3, ey_filenames_ptr = alloc_storage(ca_requires_grad)
            curl_store_1, curl_store_3, curl_filenames_ptr = alloc_storage(
                cb_requires_grad
            )

            # Get the backend function with storage
            forward_func = backend_utils.get_backend_function(
                "maxwell_tm", "forward_with_storage", accuracy, dtype, device
            )

            # Determine effective callback frequency
            if forward_callback is None:
                effective_callback_freq = nt // step_ratio
            else:
                effective_callback_freq = callback_frequency

            # Chunked forward propagation with callback support
            for step in range(0, nt // step_ratio, effective_callback_freq):
                step_nt = (
                    min(effective_callback_freq, nt // step_ratio - step) * step_ratio
                )
                start_t = step * step_ratio

                # Call the C/CUDA function with storage for this chunk
                forward_func(
                    backend_utils.tensor_to_ptr(ca),
                    backend_utils.tensor_to_ptr(cb),
                    backend_utils.tensor_to_ptr(cq),
                    backend_utils.tensor_to_ptr(source_amplitudes_scaled),
                    backend_utils.tensor_to_ptr(Ey),
                    backend_utils.tensor_to_ptr(Hx),
                    backend_utils.tensor_to_ptr(Hz),
                    backend_utils.tensor_to_ptr(m_Ey_x),
                    backend_utils.tensor_to_ptr(m_Ey_z),
                    backend_utils.tensor_to_ptr(m_Hx_z),
                    backend_utils.tensor_to_ptr(m_Hz_x),
                    backend_utils.tensor_to_ptr(receiver_amplitudes),
                    backend_utils.tensor_to_ptr(ey_store_1),
                    backend_utils.tensor_to_ptr(ey_store_3),
                    ey_filenames_ptr,
                    backend_utils.tensor_to_ptr(curl_store_1),
                    backend_utils.tensor_to_ptr(curl_store_3),
                    curl_filenames_ptr,
                    backend_utils.tensor_to_ptr(ay),
                    backend_utils.tensor_to_ptr(by),
                    backend_utils.tensor_to_ptr(ay_h),
                    backend_utils.tensor_to_ptr(by_h),
                    backend_utils.tensor_to_ptr(ax),
                    backend_utils.tensor_to_ptr(bx),
                    backend_utils.tensor_to_ptr(ax_h),
                    backend_utils.tensor_to_ptr(bx_h),
                    backend_utils.tensor_to_ptr(ky),
                    backend_utils.tensor_to_ptr(ky_h),
                    backend_utils.tensor_to_ptr(kx),
                    backend_utils.tensor_to_ptr(kx_h),
                    backend_utils.tensor_to_ptr(sources_i),
                    backend_utils.tensor_to_ptr(receivers_i),
                    rdy,
                    rdx,
                    dt,
                    step_nt,  # number of steps in this chunk
                    n_shots,
                    ny,
                    nx,
                    n_sources,
                    n_receivers,
                    step_ratio,
                    storage_mode,
                    shot_bytes_uncomp,
                    ca_requires_grad,
                    cb_requires_grad,
                    ca_batched,
                    cb_batched,
                    cq_batched,
                    start_t,  # starting time step
                    pml_y0,
                    pml_x0,
                    pml_y1,
                    pml_x1,
                    n_threads,
                    device_idx,
                )

                # Call forward callback after each chunk
                if forward_callback is not None:
                    callback_wavefields = {
                        "Ey": Ey,
                        "Hx": Hx,
                        "Hz": Hz,
                        "m_Ey_x": m_Ey_x,
                        "m_Ey_z": m_Ey_z,
                        "m_Hx_z": m_Hx_z,
                        "m_Hz_x": m_Hz_x,
                    }
                    forward_callback(
                        CallbackState(
                            dt=dt,
                            step=step + step_nt // step_ratio,
                            nt=nt // step_ratio,
                            wavefields=callback_wavefields,
                            models=models,
                            gradients={},
                            fd_pad=list(fd_pad),
                            pml_width=list(pml_width),
                            is_backward=False,
                        )
                    )
        else:
            # Use regular forward without storage
            forward_func = backend_utils.get_backend_function(
                "maxwell_tm", "forward", accuracy, dtype, device
            )

            # Call the C/CUDA function
            forward_func(
                backend_utils.tensor_to_ptr(ca),
                backend_utils.tensor_to_ptr(cb),
                backend_utils.tensor_to_ptr(cq),
                backend_utils.tensor_to_ptr(source_amplitudes_scaled),
                backend_utils.tensor_to_ptr(Ey),
                backend_utils.tensor_to_ptr(Hx),
                backend_utils.tensor_to_ptr(Hz),
                backend_utils.tensor_to_ptr(m_Ey_x),
                backend_utils.tensor_to_ptr(m_Ey_z),
                backend_utils.tensor_to_ptr(m_Hx_z),
                backend_utils.tensor_to_ptr(m_Hz_x),
                backend_utils.tensor_to_ptr(receiver_amplitudes),
                backend_utils.tensor_to_ptr(ay),
                backend_utils.tensor_to_ptr(by),
                backend_utils.tensor_to_ptr(ay_h),
                backend_utils.tensor_to_ptr(by_h),
                backend_utils.tensor_to_ptr(ax),
                backend_utils.tensor_to_ptr(bx),
                backend_utils.tensor_to_ptr(ax_h),
                backend_utils.tensor_to_ptr(bx_h),
                backend_utils.tensor_to_ptr(ky),
                backend_utils.tensor_to_ptr(ky_h),
                backend_utils.tensor_to_ptr(kx),
                backend_utils.tensor_to_ptr(kx_h),
                backend_utils.tensor_to_ptr(sources_i),
                backend_utils.tensor_to_ptr(receivers_i),
                rdy,
                rdx,
                dt,
                nt,
                n_shots,
                ny,
                nx,
                n_sources,
                n_receivers,
                step_ratio,
                ca_batched,
                cb_batched,
                cq_batched,
                0,  # start_t
                pml_y0,
                pml_x0,
                pml_y1,
                pml_x1,
                n_threads,
                device_idx,
            )

        ctx_data = {
            "backward_storage_tensors": backward_storage_tensors,
            "backward_storage_objects": backward_storage_objects,
            "backward_storage_filename_arrays": backward_storage_filename_arrays,
            "storage_mode": storage_mode,
            "shot_bytes_uncomp": shot_bytes_uncomp,
            "source_amplitudes_scaled": source_amplitudes_scaled,
            "ca_requires_grad": ca_requires_grad,
            "cb_requires_grad": cb_requires_grad,
        }
        ctx_handle = _register_ctx_handle(ctx_data)

        return (
            Ey,
            Hx,
            Hz,
            m_Ey_x,
            m_Ey_z,
            m_Hx_z,
            m_Hz_x,
            receiver_amplitudes,
            ctx_handle,
        )
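    # The trailing ctx_handle is a small tensor key into a registry maintained
    # by autograd_utils; setup_context() below looks it up to recover the
    # Python-side storage buffers, TemporaryStorage objects, and filename
    # arrays that cannot be passed through save_for_backward directly.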

    @staticmethod
    def setup_context(ctx: Any, inputs: tuple[Any, ...], output: Any) -> None:
        (
            ca,
            cb,
            cq,
            _source_amplitudes_scaled,
            ay,
            by,
            ay_h,
            by_h,
            ax,
            bx,
            ax_h,
            bx_h,
            ky,
            ky_h,
            kx,
            kx_h,
            sources_i,
            receivers_i,
            rdy,
            rdx,
            dt,
            nt,
            n_shots,
            ny,
            nx,
            n_sources,
            n_receivers,
            step_ratio,
            accuracy,
            ca_batched,
            cb_batched,
            cq_batched,
            pml_y0,
            pml_x0,
            pml_y1,
            pml_x1,
            fd_pad,
            pml_width,
            models,
            _forward_callback,
            backward_callback,
            callback_frequency,
            _storage_mode_str,
            _storage_path,
            _storage_compression,
            _Ey,
            _Hx,
            _Hz,
            _m_Ey_x,
            _m_Ey_z,
            _m_Hx_z,
            _m_Hz_x,
            n_threads,
        ) = inputs

        outputs = output if isinstance(output, tuple) else (output,)
        if len(outputs) != 9:
            raise RuntimeError(
                "MaxwellTMForwardFunc expected a context handle output for setup_context."
            )
        ctx_handle = outputs[-1]
        if not isinstance(ctx_handle, torch.Tensor):
            raise RuntimeError("MaxwellTMForwardFunc context handle must be a Tensor.")

        ctx_handle_id = int(ctx_handle.item())
        ctx_data = _get_ctx_handle(ctx_handle_id)
        ctx._ctx_handle_id = ctx_handle_id
        backward_storage_tensors = ctx_data["backward_storage_tensors"]

        ctx.save_for_backward(
            ca,
            cb,
            cq,
            ay,
            by,
            ay_h,
            by_h,
            ax,
            bx,
            ax_h,
            bx_h,
            ky,
            ky_h,
            kx,
            kx_h,
            sources_i,
            receivers_i,
            *backward_storage_tensors,
        )
        ctx.save_for_forward(
            ca,
            cb,
            cq,
            ay,
            by,
            ay_h,
            by_h,
            ax,
            bx,
            ax_h,
            bx_h,
            ky,
            ky_h,
            kx,
            kx_h,
            sources_i,
            receivers_i,
        )
        ctx.backward_storage_objects = ctx_data["backward_storage_objects"]
        ctx.backward_storage_filename_arrays = ctx_data[
            "backward_storage_filename_arrays"
        ]
        ctx.rdy = rdy
        ctx.rdx = rdx
        ctx.dt = dt
        ctx.nt = nt
        ctx.n_shots = n_shots
        ctx.ny = ny
        ctx.nx = nx
        ctx.n_sources = n_sources
        ctx.n_receivers = n_receivers
        ctx.step_ratio = step_ratio
        ctx.accuracy = accuracy
        ctx.ca_batched = ca_batched
        ctx.cb_batched = cb_batched
        ctx.cq_batched = cq_batched
        ctx.pml_y0 = pml_y0
        ctx.pml_x0 = pml_x0
        ctx.pml_y1 = pml_y1
        ctx.pml_x1 = pml_x1
        ctx.ca_requires_grad = ctx_data["ca_requires_grad"]
        ctx.cb_requires_grad = ctx_data["cb_requires_grad"]
        ctx.storage_mode = ctx_data["storage_mode"]
        ctx.shot_bytes_uncomp = ctx_data["shot_bytes_uncomp"]
        ctx.fd_pad = fd_pad
        ctx.pml_width = pml_width
        ctx.models = models
        ctx.backward_callback = backward_callback
        ctx.callback_frequency = callback_frequency
        ctx.source_amplitudes_scaled = ctx_data["source_amplitudes_scaled"]
        ctx.n_threads = n_threads

    @staticmethod
    def backward(
        ctx: Any, *grad_outputs: torch.Tensor
    ) -> tuple[Optional[torch.Tensor], ...]:
        """Computes the gradients during the backward pass using ASM.

        Uses the Adjoint State Method (ASM) to compute gradients:
        - grad_ca = sum_t (E_y^n * lambda_Ey^{n+1})
        - grad_cb = sum_t (curl_H^n * lambda_Ey^{n+1})

        Args:
            ctx: A context object containing information saved during forward.
            grad_outputs: Gradients of the loss with respect to the outputs.

        Returns:
            Gradients with respect to the inputs of the forward pass.
        """
        from . import backend_utils

        grad_outputs_list = list(grad_outputs)
        if len(grad_outputs_list) == 9:
            grad_outputs_list.pop()  # drop context handle grad

        # Unpack grad_outputs
        (
            grad_Ey,
            grad_Hx,
            grad_Hz,
            grad_m_Ey_x,
            grad_m_Ey_z,
            grad_m_Hx_z,
            grad_m_Hz_x,
            grad_r,
        ) = grad_outputs_list

        # Retrieve saved tensors
        saved = ctx.saved_tensors
        ca, cb, cq = saved[0], saved[1], saved[2]
        ay, by, ay_h, by_h = saved[3], saved[4], saved[5], saved[6]
        ax, bx, ax_h, bx_h = saved[7], saved[8], saved[9], saved[10]
        ky, ky_h, kx, kx_h = saved[11], saved[12], saved[13], saved[14]
        sources_i, receivers_i = saved[15], saved[16]
        ey_store_1, ey_store_3 = saved[17], saved[18]
        curl_store_1, curl_store_3 = saved[19], saved[20]

        device = ca.device
        dtype = ca.dtype

        rdy = ctx.rdy
        rdx = ctx.rdx
        dt = ctx.dt
        nt = ctx.nt
        n_shots = ctx.n_shots
        ny = ctx.ny
        nx = ctx.nx
        n_sources = ctx.n_sources
        n_receivers = ctx.n_receivers
        step_ratio = ctx.step_ratio
        accuracy = ctx.accuracy
        ca_batched = ctx.ca_batched
        cb_batched = ctx.cb_batched
        cq_batched = ctx.cq_batched
        pml_y0 = ctx.pml_y0
        pml_x0 = ctx.pml_x0
        pml_y1 = ctx.pml_y1
        pml_x1 = ctx.pml_x1
        ca_requires_grad = ctx.ca_requires_grad
        cb_requires_grad = ctx.cb_requires_grad
        pml_width = ctx.pml_width
        storage_mode = ctx.storage_mode
        shot_bytes_uncomp = ctx.shot_bytes_uncomp

        import ctypes

        if storage_mode == STORAGE_DISK:
            ey_filenames_ptr = ctypes.cast(
                ctx.backward_storage_filename_arrays[0], ctypes.c_void_p
            )
            curl_filenames_ptr = ctypes.cast(
                ctx.backward_storage_filename_arrays[1], ctypes.c_void_p
            )
        else:
            ey_filenames_ptr = 0
            curl_filenames_ptr = 0

        # Recalculate PML boundaries for gradient accumulation
        #
        # For staggered grid schemes, the backward pass uses an extended PML region
        # compared to forward. This is because backward calculations
        # involve spatial derivatives of terms that are themselves spatial derivatives.
        #
        # In tide, the padded domain includes both fd_pad and pml_width:
        # - pml_y0 = fd_pad + pml_width (start of interior, from forward)
        # - pml_y1 = ny - (fd_pad-1) - pml_width (end of interior, from forward)
        #
        # The gradient accumulation region is controlled by loop bounds in C/CUDA
        # with pml_bounds array and 3-region loop.

        # Ensure grad_r is contiguous
        if grad_r is None or grad_r.numel() == 0:
            grad_r = torch.zeros(nt, n_shots, n_receivers, device=device, dtype=dtype)
        else:
            grad_r = grad_r.contiguous()

        # Initialize adjoint fields (lambda fields)
        lambda_ey = torch.zeros(n_shots, ny, nx, device=device, dtype=dtype)
        lambda_hx = torch.zeros(n_shots, ny, nx, device=device, dtype=dtype)
        lambda_hz = torch.zeros(n_shots, ny, nx, device=device, dtype=dtype)

        # Initialize adjoint PML memory variables
        m_lambda_ey_x = torch.zeros(n_shots, ny, nx, device=device, dtype=dtype)
        m_lambda_ey_z = torch.zeros(n_shots, ny, nx, device=device, dtype=dtype)
        m_lambda_hx_z = torch.zeros(n_shots, ny, nx, device=device, dtype=dtype)
        m_lambda_hz_x = torch.zeros(n_shots, ny, nx, device=device, dtype=dtype)
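        # The lambda_* tensors are the adjoint counterparts of (Ey, Hx, Hz),
        # propagated backwards in time by the C/CUDA kernel with grad_r
        # injected at the receiver locations as the adjoint source; lambda_Ey
        # is correlated against the stored Ey and curl_H snapshots to
        # accumulate grad_ca and grad_cb as described in the docstring above.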

        # Allocate gradient outputs
        if n_sources > 0:
            grad_f = torch.zeros(nt, n_shots, n_sources, device=device, dtype=dtype)
        else:
            grad_f = torch.empty(0, device=device, dtype=dtype)

        if ca_requires_grad:
            if ca_batched:
                grad_ca = torch.zeros(n_shots, ny, nx, device=device, dtype=dtype)
            else:
                grad_ca = torch.zeros(ny, nx, device=device, dtype=dtype)
            # Per-shot workspace for gradient accumulation (needed for CUDA)
            grad_ca_shot = torch.zeros(n_shots, ny, nx, device=device, dtype=dtype)
        else:
            grad_ca = torch.empty(0, device=device, dtype=dtype)
            grad_ca_shot = torch.empty(0, device=device, dtype=dtype)

        if cb_requires_grad:
            if cb_batched:
                grad_cb = torch.zeros(n_shots, ny, nx, device=device, dtype=dtype)
            else:
                grad_cb = torch.zeros(ny, nx, device=device, dtype=dtype)
            # Per-shot workspace for gradient accumulation (needed for CUDA)
            grad_cb_shot = torch.zeros(n_shots, ny, nx, device=device, dtype=dtype)
        else:
            grad_cb = torch.empty(0, device=device, dtype=dtype)
            grad_cb_shot = torch.empty(0, device=device, dtype=dtype)

        if ca_requires_grad or cb_requires_grad:
            if ca_batched:
                grad_eps = torch.zeros(n_shots, ny, nx, device=device, dtype=dtype)
                grad_sigma = torch.zeros(n_shots, ny, nx, device=device, dtype=dtype)
            else:
                grad_eps = torch.zeros(ny, nx, device=device, dtype=dtype)
                grad_sigma = torch.zeros(ny, nx, device=device, dtype=dtype)
        else:
            grad_eps = torch.empty(0, device=device, dtype=dtype)
            grad_sigma = torch.empty(0, device=device, dtype=dtype)

        # Get device index for CUDA
        device_idx = (
            device.index if device.type == "cuda" and device.index is not None else 0
        )

        # Get callback-related context
        backward_callback = ctx.backward_callback
        callback_frequency = ctx.callback_frequency
        fd_pad_ctx = ctx.fd_pad
        models = ctx.models
        n_threads = ctx.n_threads

        # Get the backend function
        backward_func = backend_utils.get_backend_function(
            "maxwell_tm", "backward", accuracy, dtype, device
        )

        # Determine effective callback frequency
        if backward_callback is None:
            effective_callback_freq = nt // step_ratio
        else:
            effective_callback_freq = callback_frequency

        # Chunked backward propagation with callback support
        # Backward propagation goes from nt to 0
        for step in range(nt // step_ratio, 0, -effective_callback_freq):
            step_nt = min(step, effective_callback_freq) * step_ratio
            start_t = step * step_ratio

            # Call the C/CUDA backward function for this chunk
            backward_func(
                backend_utils.tensor_to_ptr(ca),
                backend_utils.tensor_to_ptr(cb),
                backend_utils.tensor_to_ptr(cq),
                backend_utils.tensor_to_ptr(grad_r),
                backend_utils.tensor_to_ptr(lambda_ey),
                backend_utils.tensor_to_ptr(lambda_hx),
                backend_utils.tensor_to_ptr(lambda_hz),
                backend_utils.tensor_to_ptr(m_lambda_ey_x),
                backend_utils.tensor_to_ptr(m_lambda_ey_z),
                backend_utils.tensor_to_ptr(m_lambda_hx_z),
                backend_utils.tensor_to_ptr(m_lambda_hz_x),
                backend_utils.tensor_to_ptr(ey_store_1),
                backend_utils.tensor_to_ptr(ey_store_3),
                ey_filenames_ptr,
                backend_utils.tensor_to_ptr(curl_store_1),
                backend_utils.tensor_to_ptr(curl_store_3),
                curl_filenames_ptr,
                backend_utils.tensor_to_ptr(grad_f),
                backend_utils.tensor_to_ptr(grad_ca),
                backend_utils.tensor_to_ptr(grad_cb),
                backend_utils.tensor_to_ptr(grad_eps),
                backend_utils.tensor_to_ptr(grad_sigma),
                backend_utils.tensor_to_ptr(grad_ca_shot),
                backend_utils.tensor_to_ptr(grad_cb_shot),
                backend_utils.tensor_to_ptr(ay),
                backend_utils.tensor_to_ptr(by),
                backend_utils.tensor_to_ptr(ay_h),
                backend_utils.tensor_to_ptr(by_h),
                backend_utils.tensor_to_ptr(ax),
                backend_utils.tensor_to_ptr(bx),
                backend_utils.tensor_to_ptr(ax_h),
                backend_utils.tensor_to_ptr(bx_h),
                backend_utils.tensor_to_ptr(ky),
                backend_utils.tensor_to_ptr(ky_h),
                backend_utils.tensor_to_ptr(kx),
                backend_utils.tensor_to_ptr(kx_h),
                backend_utils.tensor_to_ptr(sources_i),
                backend_utils.tensor_to_ptr(receivers_i),
                rdy,
                rdx,
                dt,
                step_nt,  # number of steps to run in this chunk
                n_shots,
                ny,
                nx,
                n_sources,
                n_receivers,
                step_ratio,
                storage_mode,
                shot_bytes_uncomp,
                ca_requires_grad,
                cb_requires_grad,
                ca_batched,
                cb_batched,
                cq_batched,
                start_t,  # starting time step for this chunk
                pml_y0,  # Use original PML boundaries for adjoint propagation
                pml_x0,
                pml_y1,
                pml_x1,
                n_threads,
                device_idx,
            )

            # Call backward callback after each chunk
            if backward_callback is not None:
                # The time step index is step - 1 because the callback is
                # executed after the calculations for the current backward
                # step are complete
                callback_wavefields = {
                    "lambda_Ey": lambda_ey,
                    "lambda_Hx": lambda_hx,
                    "lambda_Hz": lambda_hz,
                    "m_lambda_Ey_x": m_lambda_ey_x,
                    "m_lambda_Ey_z": m_lambda_ey_z,
                    "m_lambda_Hx_z": m_lambda_hx_z,
                    "m_lambda_Hz_x": m_lambda_hz_x,
                }
                callback_gradients = {}
                if ca_requires_grad:
                    callback_gradients["ca"] = grad_ca
                if cb_requires_grad:
                    callback_gradients["cb"] = grad_cb
                if ca_requires_grad or cb_requires_grad:
                    callback_gradients["epsilon"] = grad_eps
                    callback_gradients["sigma"] = grad_sigma

                backward_callback(
                    CallbackState(
                        dt=dt,
                        step=step - 1,
                        nt=nt // step_ratio,
                        wavefields=callback_wavefields,
                        models=models,
                        gradients=callback_gradients,
                        fd_pad=list(fd_pad_ctx),
                        pml_width=list(pml_width),
                        is_backward=True,
                    )
                )

        # Return gradients for all inputs
        # Order: ca, cb, cq, source_amplitudes_scaled,
        # ay, by, ay_h, by_h, ax, bx, ax_h, bx_h,
        # ky, ky_h, kx, kx_h,
        # sources_i, receivers_i,
        # rdy, rdx, dt, nt, n_shots, ny, nx, n_sources, n_receivers,
        # step_ratio, accuracy, ca_batched, cb_batched, cq_batched,
        # pml_y0, pml_x0, pml_y1, pml_x1,
        # fd_pad, pml_width, models, backward_callback, callback_frequency,
        # Ey, Hx, Hz, m_Ey_x, m_Ey_z, m_Hx_z, m_Hz_x

        # Flatten grad_f to match input shape [nt * n_shots * n_sources]
        if n_sources > 0:
            grad_f_flat = grad_f.reshape(nt * n_shots * n_sources)
        else:
            grad_f_flat = None

        # Match gradient shapes to input shapes
        # Input ca, cb are [1, ny, nx] but grad_ca, grad_cb are [ny, nx] when not batched
        if ca_requires_grad and not ca_batched:
            grad_ca = grad_ca.unsqueeze(0)  # [ny, nx] -> [1, ny, nx]
        if cb_requires_grad and not cb_batched:
            grad_cb = grad_cb.unsqueeze(0)  # [ny, nx] -> [1, ny, nx]

        _release_ctx_handle(getattr(ctx, "_ctx_handle_id", None))
        return (
            grad_ca if ca_requires_grad else None,  # ca
            grad_cb if cb_requires_grad else None,  # cb
            None,  # cq
            grad_f_flat,  # source_amplitudes_scaled
            None,
            None,
            None,
            None,  # ay, by, ay_h, by_h
            None,
            None,
            None,
            None,  # ax, bx, ax_h, bx_h
            None,
            None,
            None,
            None,  # ky, ky_h, kx, kx_h
            None,
            None,  # sources_i, receivers_i
            None,
            None,
            None,  # rdy, rdx, dt
            None,
            None,
            None,
            None,  # nt, n_shots, ny, nx
            None,
            None,  # n_sources, n_receivers
            None,  # step_ratio
            None,  # accuracy
            None,
            None,
            None,  # ca_batched, cb_batched, cq_batched
            None,
            None,
            None,
            None,  # pml_y0, pml_x0, pml_y1, pml_x1
            None,
            None,
            None,  # fd_pad, pml_width, models
            None,
            None,
            None,  # forward_callback, backward_callback, callback_frequency
            None,
            None,
            None,  # storage_mode_str, storage_path, storage_compression
            None,
            None,
            None,  # Ey, Hx, Hz
            None,
            None,
            None,
            None,  # m_Ey_x, m_Ey_z, m_Hx_z, m_Hz_x
            None,  # n_threads
        )