tide_gpr-0.0.9-py3-none-manylinux_2_28_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tide/__init__.py +65 -0
- tide/autograd_utils.py +26 -0
- tide/backend_utils.py +536 -0
- tide/callbacks.py +348 -0
- tide/cfl.py +64 -0
- tide/csrc/CMakeLists.txt +263 -0
- tide/csrc/common_cpu.h +31 -0
- tide/csrc/common_gpu.h +56 -0
- tide/csrc/maxwell.c +2133 -0
- tide/csrc/maxwell.cu +2297 -0
- tide/csrc/maxwell_born.cu +0 -0
- tide/csrc/staggered_grid.h +175 -0
- tide/csrc/staggered_grid_3d.h +124 -0
- tide/csrc/storage_utils.c +78 -0
- tide/csrc/storage_utils.cu +135 -0
- tide/csrc/storage_utils.h +36 -0
- tide/grid_utils.py +31 -0
- tide/maxwell.py +2651 -0
- tide/padding.py +139 -0
- tide/resampling.py +246 -0
- tide/staggered.py +567 -0
- tide/storage.py +131 -0
- tide/tide/libtide_C.so +0 -0
- tide/utils.py +274 -0
- tide/validation.py +71 -0
- tide/wavelets.py +72 -0
- tide_gpr-0.0.9.dist-info/METADATA +256 -0
- tide_gpr-0.0.9.dist-info/RECORD +31 -0
- tide_gpr-0.0.9.dist-info/WHEEL +5 -0
- tide_gpr-0.0.9.dist-info/licenses/LICENSE +46 -0
- tide_gpr.libs/libgomp-24e2ab19.so.1.0.0 +0 -0
tide/__init__.py
ADDED
@@ -0,0 +1,65 @@
```python
"""TIDE: Torch-based Inversion & Intelligence Engine.

A PyTorch-based library for electromagnetic wave propagation and inversion.
"""

from . import callbacks
from . import cfl
from . import maxwell
from . import padding
from . import resampling
from . import staggered
from . import utils
from . import validation
from . import wavelets

from .callbacks import CallbackState, Callback, create_callback_state
from .cfl import cfl_condition
from .padding import create_or_pad, zero_interior, reverse_pad
from .resampling import upsample, downsample, downsample_and_movedim
from .validation import (
    validate_model_gradient_sampling_interval,
    validate_freq_taper_frac,
    validate_time_pad_frac,
)
from .maxwell import MaxwellTM, maxwelltm
from .wavelets import ricker

__all__ = [
    # Modules
    "callbacks",
    "cfl",
    "maxwell",
    "padding",
    "resampling",
    "staggered",
    "validation",
    "utils",
    "wavelets",
    # Classes
    "MaxwellTM",
    "CallbackState",
    # Type aliases
    "Callback",
    # Functions
    "maxwelltm",
    "create_callback_state",
    # Signal processing
    "upsample",
    "downsample",
    "downsample_and_movedim",
    "cfl_condition",
    # Validation
    "validate_model_gradient_sampling_interval",
    "validate_freq_taper_frac",
    "validate_time_pad_frac",
    # Model padding utilities
    "create_or_pad",
    "zero_interior",
    "reverse_pad",
    # Wavelets
    "ricker",
]


__version__ = "0.0.9"
```
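The top-level package re-exports the user-facing API, so everything in `__all__` resolves directly on `tide`. A minimal sanity sketch (assuming the wheel is installed):

```python
# Sketch: everything listed in __all__ resolves on the top-level package.
import tide

assert tide.__version__ == "0.0.9"
assert all(hasattr(tide, name) for name in tide.__all__)
print(tide.cfl_condition, tide.ricker, tide.MaxwellTM)
```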
tide/autograd_utils.py
ADDED
@@ -0,0 +1,26 @@
```python
import itertools
from typing import Any, Optional

import torch

_CTX_HANDLE_COUNTER = itertools.count()
_CTX_HANDLE_REGISTRY: dict[int, dict[str, Any]] = {}


def _register_ctx_handle(ctx_data: dict[str, Any]) -> torch.Tensor:
    handle = next(_CTX_HANDLE_COUNTER)
    _CTX_HANDLE_REGISTRY[handle] = ctx_data
    return torch.tensor(handle, dtype=torch.int64)


def _get_ctx_handle(handle: int) -> dict[str, Any]:
    try:
        return _CTX_HANDLE_REGISTRY[handle]
    except KeyError as exc:
        raise RuntimeError(f"Unknown context handle: {handle}") from exc


def _release_ctx_handle(handle: Optional[int]) -> None:
    if handle is None:
        return
    _CTX_HANDLE_REGISTRY.pop(handle, None)
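```

The registry lets arbitrary Python state ride through interfaces that only accept tensors: the state is parked in a module-level dict and represented by an int64 handle tensor, presumably so a custom autograd function can recover it later. A sketch of the round-trip, using only the helpers above:

```python
# Sketch: stash non-tensor state behind an int64 handle tensor,
# recover it, and release it when it is no longer needed.
from tide.autograd_utils import (
    _get_ctx_handle,
    _register_ctx_handle,
    _release_ctx_handle,
)

handle_tensor = _register_ctx_handle({"nt": 1000, "step_ratio": 2})
state = _get_ctx_handle(int(handle_tensor.item()))
assert state["nt"] == 1000
_release_ctx_handle(int(handle_tensor.item()))
```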
tide/backend_utils.py
ADDED
@@ -0,0 +1,536 @@
```python
import ctypes
import pathlib
import platform
import site
import sys
from importlib import resources
from ctypes import c_bool, c_double, c_float, c_int64, c_void_p
from typing import Any, Callable, Optional, TypeAlias

import torch

CFunctionPointer: TypeAlias = Any

# Platform-specific shared library extension
SO_EXT = {"Linux": "so", "Darwin": "dylib", "Windows": "dll"}.get(platform.system())
if SO_EXT is None:
    raise RuntimeError("Unsupported OS or platform type")

def _candidate_lib_paths() -> list[pathlib.Path]:
    lib_name = f"libtide_C.{SO_EXT}"
    lib_dir = pathlib.Path(__file__).resolve().parent
    candidates: list[pathlib.Path] = [
        lib_dir / lib_name,
        lib_dir / "tide" / lib_name,
        lib_dir.parent / "tide.libs" / lib_name,
    ]

    try:
        pkg_root = resources.files(__package__ or "tide")
        candidates.append(pathlib.Path(pkg_root / lib_name))
        candidates.append(pathlib.Path(pkg_root / "tide" / lib_name))
        for path in pkg_root.rglob(lib_name):
            candidates.append(pathlib.Path(path))
    except Exception:
        pass

    try:
        site_paths = list(site.getsitepackages())
    except Exception:
        site_paths = []
    site_paths.append(site.getusersitepackages())
    for base in site_paths:
        if not base:
            continue
        base_path = pathlib.Path(base)
        for path in base_path.glob(f"tide-*.data/**/{lib_name}"):
            candidates.append(path)

    return candidates


_dll: Optional[ctypes.CDLL] = None
_lib_path: pathlib.Path = pathlib.Path(__file__).resolve().parent / f"libtide_C.{SO_EXT}"

for candidate in _candidate_lib_paths():
    if not candidate.exists():
        continue
    try:
        _dll = ctypes.CDLL(str(candidate))
        _lib_path = candidate
        break
    except OSError:
        continue
```
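The loader walks a fixed list of candidate locations (next to the module, under `tide/tide/` as seen in this wheel's `tide/tide/libtide_C.so`, in `tide.libs/`, and in `tide-*.data` directories inside site-packages) and keeps the first library that `CDLL` accepts. A hypothetical inspection snippet (not part of the package) to see what was probed and what won:

```python
# Hypothetical inspection: list the locations backend_utils probes
# and report which shared library was actually loaded.
from tide import backend_utils

for path in backend_utils._candidate_lib_paths():
    print(path, path.exists())
print("loaded from:", backend_utils._lib_path)
```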
(tide/backend_utils.py, continued)

```python
def is_backend_available() -> bool:
    """Check if the C/CUDA backend is available."""
    return _dll is not None


def get_dll() -> ctypes.CDLL:
    """Get the loaded DLL, raising an error if not available."""
    if _dll is None:
        raise RuntimeError(
            f"C/CUDA backend not available. Please compile the library first. "
            f"Expected library at: {_lib_path}"
        )
    return _dll


# Check if the library was compiled with OpenMP support
USE_OPENMP = _dll is not None and hasattr(_dll, "omp_get_num_threads")

# Define ctypes argument type templates to reduce repetition while preserving order.
# A placeholder will be replaced by the appropriate float type (c_float or c_double).
FLOAT_TYPE: type = c_float
```
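Callers are expected to gate on backend availability before dispatching to the native propagators. A minimal sketch of that pattern:

```python
# Sketch: check availability before handing tensors to native code.
from tide.backend_utils import USE_OPENMP, get_dll, is_backend_available

if is_backend_available():
    dll = get_dll()  # the loaded ctypes.CDLL
    print("backend loaded; OpenMP support:", USE_OPENMP)
else:
    raise RuntimeError("tide was installed without its compiled backend")
```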
(tide/backend_utils.py, continued)

```python
def get_maxwell_tm_forward_template() -> list[Any]:
    """Returns the argtype template for the Maxwell TM forward propagator."""
    args: list[Any] = []
    # Material parameters
    args += [c_void_p] * 3  # ca, cb, cq
    # Source
    args += [c_void_p]  # f (source amplitudes)
    # Fields
    args += [c_void_p] * 3  # ey, hx, hz
    # PML memory variables
    args += [c_void_p] * 4  # m_ey_x, m_ey_z, m_hx_z, m_hz_x
    # Recorded data
    args += [c_void_p]  # r (receiver amplitudes)
    # PML profiles
    args += [c_void_p] * 8  # ay, by, ayh, byh, ax, bx, axh, bxh
    # Kappa profiles
    args += [c_void_p] * 4  # ky, kyh, kx, kxh
    # Source and receiver indices
    args += [c_void_p] * 2  # sources_i, receivers_i
    # Grid spacing
    args += [FLOAT_TYPE] * 2  # rdy, rdx
    # Time step
    args += [FLOAT_TYPE]  # dt
    # Sizes
    args += [c_int64]  # nt
    args += [c_int64]  # n_shots
    args += [c_int64] * 2  # ny, nx
    args += [c_int64] * 2  # n_sources_per_shot, n_receivers_per_shot
    args += [c_int64]  # step_ratio
    # Batched flags
    args += [c_bool] * 3  # ca_batched, cb_batched, cq_batched
    # Start time
    args += [c_int64]  # start_t
    # PML boundaries
    args += [c_int64] * 4  # pml_y0, pml_x0, pml_y1, pml_x1
    # OpenMP threads (CPU only)
    args += [c_int64]  # n_threads
    # Device (for CUDA)
    args += [c_int64]  # device
    return args


def get_maxwell_tm_backward_template() -> list[Any]:
    """Returns the argtype template for the Maxwell TM backward propagator (v2 with ASM)."""
    args: list[Any] = []
    # Material parameters
    args += [c_void_p] * 3  # ca, cb, cq
    # Gradient of receiver data
    args += [c_void_p]  # grad_r
    # Adjoint fields (lambda)
    args += [c_void_p] * 3  # lambda_ey, lambda_hx, lambda_hz
    # Adjoint PML memory variables
    args += [c_void_p] * 4  # m_lambda_ey_x, m_lambda_ey_z, m_lambda_hx_z, m_lambda_hz_x
    # Stored forward values for gradient (Ey and curl_H)
    # For each: store_1, store_3, filenames (char**)
    args += [c_void_p] * 6
    # Gradient outputs
    args += [c_void_p]  # grad_f
    args += [c_void_p] * 4  # grad_ca, grad_cb, grad_eps, grad_sigma
    args += [c_void_p] * 2  # grad_ca_shot, grad_cb_shot (per-shot workspace)
    # PML profiles
    args += [c_void_p] * 8  # ay, by, ayh, byh, ax, bx, axh, bxh
    # Kappa profiles
    args += [c_void_p] * 4  # ky, kyh, kx, kxh
    # Source and receiver indices
    args += [c_void_p] * 2  # sources_i, receivers_i
    # Grid spacing
    args += [FLOAT_TYPE] * 2  # rdy, rdx
    # Time step
    args += [FLOAT_TYPE]  # dt
    # Sizes
    args += [c_int64]  # nt
    args += [c_int64]  # n_shots
    args += [c_int64] * 2  # ny, nx
    args += [c_int64] * 2  # n_sources_per_shot, n_receivers_per_shot
    args += [c_int64]  # step_ratio
    # Storage mode
    args += [c_int64] * 2  # storage_mode, shot_bytes_uncomp
    # Requires grad flags
    args += [c_bool] * 2  # ca_requires_grad, cb_requires_grad
    # Batched flags
    args += [c_bool] * 3  # ca_batched, cb_batched, cq_batched
    # Start time
    args += [c_int64]  # start_t
    # PML boundaries for adjoint propagation
    args += [c_int64] * 4  # pml_y0, pml_x0, pml_y1, pml_x1
    # OpenMP threads (CPU only)
    args += [c_int64]  # n_threads
    # Device (for CUDA)
    args += [c_int64]  # device
    return args


def get_maxwell_tm_forward_with_storage_template() -> list[Any]:
    """Returns the argtype template for Maxwell TM forward with storage (for ASM backward)."""
    args: list[Any] = []
    # Material parameters
    args += [c_void_p] * 3  # ca, cb, cq
    # Source
    args += [c_void_p]  # f (source amplitudes)
    # Fields
    args += [c_void_p] * 3  # ey, hx, hz
    # PML memory variables
    args += [c_void_p] * 4  # m_ey_x, m_ey_z, m_hx_z, m_hz_x
    # Recorded data
    args += [c_void_p]  # r (receiver amplitudes)
    # Storage for backward (Ey and curl_H)
    # For each: store_1, store_3, filenames (char**)
    args += [c_void_p] * 6
    # PML profiles
    args += [c_void_p] * 8  # ay, by, ayh, byh, ax, bx, axh, bxh
    # Kappa profiles
    args += [c_void_p] * 4  # ky, kyh, kx, kxh
    # Source and receiver indices
    args += [c_void_p] * 2  # sources_i, receivers_i
    # Grid spacing
    args += [FLOAT_TYPE] * 2  # rdy, rdx
    # Time step
    args += [FLOAT_TYPE]  # dt
    # Sizes
    args += [c_int64]  # nt
    args += [c_int64]  # n_shots
    args += [c_int64] * 2  # ny, nx
    args += [c_int64] * 2  # n_sources_per_shot, n_receivers_per_shot
    args += [c_int64]  # step_ratio
    # Storage mode
    args += [c_int64] * 2  # storage_mode, shot_bytes_uncomp
    # Requires grad flags
    args += [c_bool] * 2  # ca_requires_grad, cb_requires_grad
    # Batched flags
    args += [c_bool] * 3  # ca_batched, cb_batched, cq_batched
    # Start time
    args += [c_int64]  # start_t
    # PML boundaries
    args += [c_int64] * 4  # pml_y0, pml_x0, pml_y1, pml_x1
    # OpenMP threads (CPU only)
    args += [c_int64]  # n_threads
    # Device (for CUDA)
    args += [c_int64]  # device
    return args


def get_maxwell_3d_forward_template() -> list[Any]:
    """Returns the argtype template for the 3D Maxwell forward propagator."""
    args: list[Any] = []
    # Material parameters
    args += [c_void_p] * 3  # ca, cb, cq
    # Source
    args += [c_void_p]  # f (source amplitudes)
    # Fields
    args += [c_void_p] * 6  # ex, ey, ez, hx, hy, hz
    # PML memory variables
    args += (
        [c_void_p] * 12
    )  # m_hz_y, m_hy_z, m_hx_z, m_hz_x, m_hy_x, m_hx_y, m_ey_z, m_ez_y, m_ez_x, m_ex_z, m_ex_y, m_ey_x
    # Recorded data
    args += [c_void_p]  # r
    # PML profiles
    args += [c_void_p] * 12  # az, bz, azh, bzh, ay, by, ayh, byh, ax, bx, axh, bxh
    # Kappa profiles
    args += [c_void_p] * 6  # kz, kzh, ky, kyh, kx, kxh
    # Source and receiver indices
    args += [c_void_p] * 2  # sources_i, receivers_i
    # Grid spacing
    args += [FLOAT_TYPE] * 3  # rdz, rdy, rdx
    # Time step
    args += [FLOAT_TYPE]  # dt
    # Sizes
    args += [c_int64]  # nt
    args += [c_int64]  # n_shots
    args += [c_int64] * 3  # nz, ny, nx
    args += [c_int64] * 2  # n_sources_per_shot, n_receivers_per_shot
    args += [c_int64]  # step_ratio
    # Batched flags
    args += [c_bool] * 3  # ca_batched, cb_batched, cq_batched
    # Start time
    args += [c_int64]  # start_t
    # PML boundaries
    args += [c_int64] * 6  # pml_z0, pml_y0, pml_x0, pml_z1, pml_y1, pml_x1
    # Source/receiver component
    args += [c_int64] * 2  # source_component, receiver_component
    # OpenMP threads (CPU only)
    args += [c_int64]  # n_threads
    # Device (for CUDA)
    args += [c_int64]  # device
    return args


def get_maxwell_3d_forward_with_storage_template() -> list[Any]:
    """Returns the argtype template for 3D Maxwell forward with storage (ASM)."""
    args: list[Any] = []
    # Material parameters
    args += [c_void_p] * 3  # ca, cb, cq
    # Source
    args += [c_void_p]  # f (source amplitudes)
    # Fields
    args += [c_void_p] * 6  # ex, ey, ez, hx, hy, hz
    # PML memory variables
    args += (
        [c_void_p] * 12
    )  # m_hz_y, m_hy_z, m_hx_z, m_hz_x, m_hy_x, m_hx_y, m_ey_z, m_ez_y, m_ez_x, m_ex_z, m_ex_y, m_ey_x
    # Recorded data
    args += [c_void_p]  # r
    # Storage for backward (Ex/Ey/Ez and curl(H) components)
    # For each: store_1, store_3, filenames (char**)
    args += [c_void_p] * 18
    # PML profiles
    args += [c_void_p] * 12  # az, bz, azh, bzh, ay, by, ayh, byh, ax, bx, axh, bxh
    # Kappa profiles
    args += [c_void_p] * 6  # kz, kzh, ky, kyh, kx, kxh
    # Source and receiver indices
    args += [c_void_p] * 2  # sources_i, receivers_i
    # Grid spacing
    args += [FLOAT_TYPE] * 3  # rdz, rdy, rdx
    # Time step
    args += [FLOAT_TYPE]  # dt
    # Sizes
    args += [c_int64]  # nt
    args += [c_int64]  # n_shots
    args += [c_int64] * 3  # nz, ny, nx
    args += [c_int64] * 2  # n_sources_per_shot, n_receivers_per_shot
    args += [c_int64]  # step_ratio
    # Storage mode
    args += [c_int64] * 2  # storage_mode, shot_bytes_uncomp
    # Requires grad flags
    args += [c_bool] * 2  # ca_requires_grad, cb_requires_grad
    # Batched flags
    args += [c_bool] * 3  # ca_batched, cb_batched, cq_batched
    # Start time
    args += [c_int64]  # start_t
    # PML boundaries
    args += [c_int64] * 6  # pml_z0, pml_y0, pml_x0, pml_z1, pml_y1, pml_x1
    # Source/receiver component
    args += [c_int64] * 2  # source_component, receiver_component
    # OpenMP threads (CPU only)
    args += [c_int64]  # n_threads
    # Device (for CUDA)
    args += [c_int64]  # device
    return args


def get_maxwell_3d_backward_template() -> list[Any]:
    """Returns the argtype template for 3D Maxwell backward propagator (ASM)."""
    args: list[Any] = []
    # Material parameters
    args += [c_void_p] * 3  # ca, cb, cq
    # Gradient of receiver data
    args += [c_void_p]  # grad_r
    # Adjoint fields (lambda)
    args += [
        c_void_p
    ] * 6  # lambda_ex, lambda_ey, lambda_ez, lambda_hx, lambda_hy, lambda_hz
    # Adjoint PML memory variables
    args += (
        [c_void_p] * 12
    )  # m_lambda_ey_z, m_lambda_ez_y, m_lambda_ez_x, m_lambda_ex_z, m_lambda_ex_y, m_lambda_ey_x, m_lambda_hz_y, m_lambda_hy_z, m_lambda_hx_z, m_lambda_hz_x, m_lambda_hy_x, m_lambda_hx_y
    # Stored forward values (Ex/Ey/Ez and curl(H) components)
    # For each: store_1, store_3, filenames (char**)
    args += [c_void_p] * 18
    # Gradient outputs
    args += [c_void_p]  # grad_f
    args += [c_void_p] * 4  # grad_ca, grad_cb, grad_eps, grad_sigma
    args += [c_void_p] * 2  # grad_ca_shot, grad_cb_shot
    # PML profiles
    args += [c_void_p] * 12  # az, bz, azh, bzh, ay, by, ayh, byh, ax, bx, axh, bxh
    # Kappa profiles
    args += [c_void_p] * 6  # kz, kzh, ky, kyh, kx, kxh
    # Source and receiver indices
    args += [c_void_p] * 2  # sources_i, receivers_i
    # Grid spacing
    args += [FLOAT_TYPE] * 3  # rdz, rdy, rdx
    # Time step
    args += [FLOAT_TYPE]  # dt
    # Sizes
    args += [c_int64]  # nt
    args += [c_int64]  # n_shots
    args += [c_int64] * 3  # nz, ny, nx
    args += [c_int64] * 2  # n_sources_per_shot, n_receivers_per_shot
    args += [c_int64]  # step_ratio
    # Storage mode
    args += [c_int64] * 2  # storage_mode, shot_bytes_uncomp
    # Requires grad flags
    args += [c_bool] * 2  # ca_requires_grad, cb_requires_grad
    # Batched flags
    args += [c_bool] * 3  # ca_batched, cb_batched, cq_batched
    # Start time
    args += [c_int64]  # start_t
    # PML boundaries
    args += [c_int64] * 6  # pml_z0, pml_y0, pml_x0, pml_z1, pml_y1, pml_x1
    # Source/receiver component
    args += [c_int64] * 2  # source_component, receiver_component
    # OpenMP threads (CPU only)
    args += [c_int64]  # n_threads
    # Device (for CUDA)
    args += [c_int64]  # device
    return args
```
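Each template mirrors, in order, the parameter list of the corresponding C entry point, and the storage-enabled variant extends the plain forward signature by a fixed tail (6 store pointers, 2 storage integers, 2 requires-grad flags). A small consistency check, assuming the module is importable:

```python
# Sketch: the storage-enabled forward signature is the plain forward
# signature plus 6 store pointers, 2 storage ints, and 2 grad flags.
from tide.backend_utils import (
    get_maxwell_tm_forward_template,
    get_maxwell_tm_forward_with_storage_template,
)

plain = get_maxwell_tm_forward_template()
stored = get_maxwell_tm_forward_with_storage_template()
assert len(stored) == len(plain) + 10  # 46 vs. 56 entries
```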
(tide/backend_utils.py, continued)

```python
# Template registry
templates: dict[str, Callable[[], list[Any]]] = {
    "maxwell_tm_forward": get_maxwell_tm_forward_template,
    "maxwell_tm_forward_with_storage": get_maxwell_tm_forward_with_storage_template,
    "maxwell_tm_backward": get_maxwell_tm_backward_template,
    "maxwell_3d_forward": get_maxwell_3d_forward_template,
    "maxwell_3d_forward_with_storage": get_maxwell_3d_forward_with_storage_template,
    "maxwell_3d_backward": get_maxwell_3d_backward_template,
}


def _get_argtypes(template_name: str, float_type: type) -> list[Any]:
    """Generates a concrete argtype list from a template and a float type.

    Args:
        template_name: The name of the argtype template to use.
        float_type: The `ctypes` float type (`c_float` or `c_double`)
            to substitute into the template.

    Returns:
        list[Any]: A list of `ctypes` types representing the argument
            signature for a C function.

    """
    template = templates[template_name]()
    return [float_type if t is FLOAT_TYPE else t for t in template]


def _assign_argtypes(
    propagator: str,
    accuracy: int,
    dtype: str,
    direction: str,
) -> None:
    """Dynamically assigns ctypes argtypes to a given C function.

    Args:
        propagator: The name of the propagator (e.g., "maxwell_tm").
        accuracy: The finite-difference accuracy order (e.g., 2, 4, 6, 8).
        dtype: The data type as a string (e.g., "float", "double").
        direction: The direction of propagation (e.g., "forward", "backward").

    """
    if _dll is None:
        return

    template_name = f"{propagator}_{direction}"
    float_type = c_float if dtype == "float" else c_double
    argtypes = _get_argtypes(template_name, float_type)

    for device in ["cpu", "cuda"]:
        func_name = f"{propagator}_{accuracy}_{dtype}_{direction}_{device}"
        try:
            func = getattr(_dll, func_name)
            func.argtypes = argtypes
            func.restype = None  # All C functions return void
        except AttributeError:
            continue
```
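Because `FLOAT_TYPE` is only a placeholder, one template serves both precisions: `_get_argtypes` swaps it for `c_float` or `c_double`, and `_assign_argtypes` attaches the result to every `{propagator}_{accuracy}_{dtype}_{direction}_{device}` symbol the shared library exports (silently skipping absent ones, e.g. CUDA symbols in a CPU-only build). A sketch of the substitution:

```python
# Sketch: the c_float placeholders become c_double for the double-precision
# entry points; pointer and integer slots are left untouched.
from ctypes import c_double, c_float
from tide.backend_utils import _get_argtypes

float_args = _get_argtypes("maxwell_tm_forward", c_float)
double_args = _get_argtypes("maxwell_tm_forward", c_double)
assert len(float_args) == len(double_args)
assert double_args.count(c_double) == 3  # rdy, rdx, dt
```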
(tide/backend_utils.py, continued)

```python
def get_backend_function(
    propagator: str,
    pass_name: str,
    accuracy: int,
    dtype: torch.dtype,
    device: torch.device,
) -> CFunctionPointer:
    """Selects and returns the appropriate backend C/CUDA function.

    Args:
        propagator: The name of the propagator (e.g., "maxwell_tm").
        pass_name: The name of the pass (e.g., "forward", "backward").
        accuracy: The finite-difference accuracy order.
        dtype: The torch.dtype of the tensors.
        device: The torch.device the tensors are on.

    Returns:
        The backend function pointer.

    Raises:
        AttributeError: If the function is not found in the shared library.
        TypeError: If the dtype is not torch.float32 or torch.float64.
        RuntimeError: If the backend is not available.

    """
    dll = get_dll()

    if dtype == torch.float32:
        dtype_str = "float"
    elif dtype == torch.float64:
        dtype_str = "double"
    else:
        raise TypeError(f"Unsupported dtype {dtype}")

    device_str = device.type

    func_name = f"{propagator}_{accuracy}_{dtype_str}_{pass_name}_{device_str}"

    try:
        return getattr(dll, func_name)
    except AttributeError as e:
        raise AttributeError(f"Backend function {func_name} not found.") from e
```
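Function selection is purely name-based: propagator, accuracy order, dtype string, pass name, and device type are joined into the exported symbol name. A hypothetical lookup (it requires the compiled backend to load):

```python
# Sketch: fetch the single-precision, 4th-order TM forward kernel for CPU.
import torch
from tide.backend_utils import get_backend_function

forward = get_backend_function(
    "maxwell_tm", "forward", 4, torch.float32, torch.device("cpu")
)
# `forward` is the ctypes entry point named maxwell_tm_4_float_forward_cpu.
```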
(tide/backend_utils.py, continued)

```python
def tensor_to_ptr(tensor: Optional[torch.Tensor]) -> int:
    """Convert a PyTorch tensor to a C pointer (int).

    Args:
        tensor: The tensor to convert, or None.

    Returns:
        The data pointer as an integer, or 0 if tensor is None.

    """
    if tensor is None:
        return 0
    if torch._C._functorch.is_functorch_wrapped_tensor(tensor):
        tensor = torch._C._functorch.get_unwrapped(tensor)
    return tensor.data_ptr()


def ensure_contiguous(tensor: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
    """Ensure a tensor is contiguous in memory.

    Args:
        tensor: The tensor to check, or None.

    Returns:
        A contiguous version of the tensor, or None.

    """
    if tensor is None:
        return None
    return tensor.contiguous()


# Initialize argtypes for all available functions when the module loads
if _dll is not None:
    for current_accuracy in [2, 4, 6, 8]:
        for current_dtype in ["float", "double"]:
            _assign_argtypes("maxwell_tm", current_accuracy, current_dtype, "forward")
            _assign_argtypes(
                "maxwell_tm", current_accuracy, current_dtype, "forward_with_storage"
            )
            _assign_argtypes("maxwell_tm", current_accuracy, current_dtype, "backward")
            _assign_argtypes("maxwell_3d", current_accuracy, current_dtype, "forward")
            _assign_argtypes(
                "maxwell_3d", current_accuracy, current_dtype, "forward_with_storage"
            )
            _assign_argtypes("maxwell_3d", current_accuracy, current_dtype, "backward")
```
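`tensor_to_ptr` and `ensure_contiguous` are the glue between torch tensors and the `c_void_p` slots in the templates above: buffers must be contiguous before `data_ptr()` is handed to C, and `None` maps to a null pointer. A short sketch:

```python
# Sketch: prepare a tensor for a c_void_p argument slot.
import torch
from tide.backend_utils import ensure_contiguous, tensor_to_ptr

field = torch.zeros(4, 8).t()     # a non-contiguous view
field = ensure_contiguous(field)  # copy into contiguous memory
assert tensor_to_ptr(field) == field.data_ptr()
assert tensor_to_ptr(None) == 0   # absent optional buffers become NULL
```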