tide_GPR-0.0.9-py3-none-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
tide/__init__.py ADDED
@@ -0,0 +1,65 @@
+ """TIDE: Torch-based Inversion & Intelligence Engine.
+
+ A PyTorch-based library for electromagnetic wave propagation and inversion.
+ """
+
+ from . import callbacks
+ from . import cfl
+ from . import maxwell
+ from . import padding
+ from . import resampling
+ from . import staggered
+ from . import utils
+ from . import validation
+ from . import wavelets
+
+ from .callbacks import CallbackState, Callback, create_callback_state
+ from .cfl import cfl_condition
+ from .padding import create_or_pad, zero_interior, reverse_pad
+ from .resampling import upsample, downsample, downsample_and_movedim
+ from .validation import (
+     validate_model_gradient_sampling_interval,
+     validate_freq_taper_frac,
+     validate_time_pad_frac,
+ )
+ from .maxwell import MaxwellTM, maxwelltm
+ from .wavelets import ricker
+
+ __all__ = [
+     # Modules
+     "callbacks",
+     "cfl",
+     "maxwell",
+     "padding",
+     "resampling",
+     "staggered",
+     "validation",
+     "utils",
+     "wavelets",
+     # Classes
+     "MaxwellTM",
+     "CallbackState",
+     # Type aliases
+     "Callback",
+     # Functions
+     "maxwelltm",
+     "create_callback_state",
+     # Signal processing
+     "upsample",
+     "downsample",
+     "downsample_and_movedim",
+     "cfl_condition",
+     # Validation
+     "validate_model_gradient_sampling_interval",
+     "validate_freq_taper_frac",
+     "validate_time_pad_frac",
+     # Model padding utilities
+     "create_or_pad",
+     "zero_interior",
+     "reverse_pad",
+     # Wavelets
+     "ricker",
+ ]
+
+
+ __version__ = "0.0.9"
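Usage note (not part of the diff): the package re-exports its public API at the top level, so every name listed in __all__ can be imported directly from tide. A minimal sketch, using only names that the __init__ above re-exports (no argument signatures are assumed, since those live in the submodules):

    import tide

    print(tide.__version__)      # "0.0.9"
    print(sorted(tide.__all__))  # the modules, classes, and functions listed above

    # Re-exported names can be imported without touching the submodules directly.
    from tide import MaxwellTM, maxwelltm, ricker, cfl_condition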
tide/autograd_utils.py ADDED
@@ -0,0 +1,26 @@
+ import itertools
+ from typing import Any, Optional
+
+ import torch
+
+ _CTX_HANDLE_COUNTER = itertools.count()
+ _CTX_HANDLE_REGISTRY: dict[int, dict[str, Any]] = {}
+
+
+ def _register_ctx_handle(ctx_data: dict[str, Any]) -> torch.Tensor:
+     handle = next(_CTX_HANDLE_COUNTER)
+     _CTX_HANDLE_REGISTRY[handle] = ctx_data
+     return torch.tensor(handle, dtype=torch.int64)
+
+
+ def _get_ctx_handle(handle: int) -> dict[str, Any]:
+     try:
+         return _CTX_HANDLE_REGISTRY[handle]
+     except KeyError as exc:
+         raise RuntimeError(f"Unknown context handle: {handle}") from exc
+
+
+ def _release_ctx_handle(handle: Optional[int]) -> None:
+     if handle is None:
+         return
+     _CTX_HANDLE_REGISTRY.pop(handle, None)
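Usage note (not part of the diff): these three private helpers form a small process-local registry: a Python-side context dict is stored under an integer handle, the handle travels as an int64 tensor, and the dict is looked up and released later. A minimal sketch of that round trip, using only the functions above (the dict keys are illustrative, not part of the package):

    from tide.autograd_utils import (
        _register_ctx_handle,   # private helpers, imported here only for illustration
        _get_ctx_handle,
        _release_ctx_handle,
    )

    # Store arbitrary Python-side state and receive an int64 tensor handle.
    handle_tensor = _register_ctx_handle({"step_ratio": 2, "storage_mode": 0})

    # Later (e.g. in a backward pass), recover the state from the handle...
    handle = int(handle_tensor.item())
    ctx_data = _get_ctx_handle(handle)

    # ...and release it once it is no longer needed.
    _release_ctx_handle(handle)

Presumably the tensor-encoded handle lets non-tensor state cross autograd boundaries while the dict itself stays inside the Python process.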
tide/backend_utils.py ADDED
@@ -0,0 +1,536 @@
+ import ctypes
+ import pathlib
+ import platform
+ import site
+ import sys
+ from importlib import resources
+ from ctypes import c_bool, c_double, c_float, c_int64, c_void_p
+ from typing import Any, Callable, Optional, TypeAlias
+
+ import torch
+
+ CFunctionPointer: TypeAlias = Any
+
+ # Platform-specific shared library extension
+ SO_EXT = {"Linux": "so", "Darwin": "dylib", "Windows": "dll"}.get(platform.system())
+ if SO_EXT is None:
+     raise RuntimeError("Unsupported OS or platform type")
+
+ def _candidate_lib_paths() -> list[pathlib.Path]:
+     lib_name = f"libtide_C.{SO_EXT}"
+     lib_dir = pathlib.Path(__file__).resolve().parent
+     candidates: list[pathlib.Path] = [
+         lib_dir / lib_name,
+         lib_dir / "tide" / lib_name,
+         lib_dir.parent / "tide.libs" / lib_name,
+     ]
+
+     try:
+         pkg_root = resources.files(__package__ or "tide")
+         candidates.append(pathlib.Path(pkg_root / lib_name))
+         candidates.append(pathlib.Path(pkg_root / "tide" / lib_name))
+         for path in pkg_root.rglob(lib_name):
+             candidates.append(pathlib.Path(path))
+     except Exception:
+         pass
+
+     try:
+         site_paths = list(site.getsitepackages())
+     except Exception:
+         site_paths = []
+     site_paths.append(site.getusersitepackages())
+     for base in site_paths:
+         if not base:
+             continue
+         base_path = pathlib.Path(base)
+         for path in base_path.glob(f"tide-*.data/**/{lib_name}"):
+             candidates.append(path)
+
+     return candidates
+
+
+ _dll: Optional[ctypes.CDLL] = None
+ _lib_path: pathlib.Path = pathlib.Path(__file__).resolve().parent / f"libtide_C.{SO_EXT}"
+
+ for candidate in _candidate_lib_paths():
+     if not candidate.exists():
+         continue
+     try:
+         _dll = ctypes.CDLL(str(candidate))
+         _lib_path = candidate
+         break
+     except OSError:
+         continue
+
+
+ def is_backend_available() -> bool:
+     """Check if the C/CUDA backend is available."""
+     return _dll is not None
+
+
+ def get_dll() -> ctypes.CDLL:
+     """Get the loaded DLL, raising an error if not available."""
+     if _dll is None:
+         raise RuntimeError(
+             f"C/CUDA backend not available. Please compile the library first. "
+             f"Expected library at: {_lib_path}"
+         )
+     return _dll
+
+
+ # Check whether the library was compiled with OpenMP support
+ USE_OPENMP = _dll is not None and hasattr(_dll, "omp_get_num_threads")
+
+ # Define ctypes argument type templates to reduce repetition while preserving order.
+ # A placeholder will be replaced by the appropriate float type (c_float or c_double).
+ FLOAT_TYPE: type = c_float
+
+
+ def get_maxwell_tm_forward_template() -> list[Any]:
+     """Returns the argtype template for the Maxwell TM forward propagator."""
+     args: list[Any] = []
+     # Material parameters
+     args += [c_void_p] * 3 # ca, cb, cq
+     # Source
+     args += [c_void_p] # f (source amplitudes)
+     # Fields
+     args += [c_void_p] * 3 # ey, hx, hz
+     # PML memory variables
+     args += [c_void_p] * 4 # m_ey_x, m_ey_z, m_hx_z, m_hz_x
+     # Recorded data
+     args += [c_void_p] # r (receiver amplitudes)
+     # PML profiles
+     args += [c_void_p] * 8 # ay, by, ayh, byh, ax, bx, axh, bxh
+     # Kappa profiles
+     args += [c_void_p] * 4 # ky, kyh, kx, kxh
+     # Source and receiver indices
+     args += [c_void_p] * 2 # sources_i, receivers_i
+     # Grid spacing
+     args += [FLOAT_TYPE] * 2 # rdy, rdx
+     # Time step
+     args += [FLOAT_TYPE] # dt
+     # Sizes
+     args += [c_int64] # nt
+     args += [c_int64] # n_shots
+     args += [c_int64] * 2 # ny, nx
+     args += [c_int64] * 2 # n_sources_per_shot, n_receivers_per_shot
+     args += [c_int64] # step_ratio
+     # Batched flags
+     args += [c_bool] * 3 # ca_batched, cb_batched, cq_batched
+     # Start time
+     args += [c_int64] # start_t
+     # PML boundaries
+     args += [c_int64] * 4 # pml_y0, pml_x0, pml_y1, pml_x1
+     # OpenMP threads (CPU only)
+     args += [c_int64] # n_threads
+     # Device (for CUDA)
+     args += [c_int64] # device
+     return args
+
+
+ def get_maxwell_tm_backward_template() -> list[Any]:
+     """Returns the argtype template for the Maxwell TM backward propagator (v2 with ASM)."""
+     args: list[Any] = []
+     # Material parameters
+     args += [c_void_p] * 3 # ca, cb, cq
+     # Gradient of receiver data
+     args += [c_void_p] # grad_r
+     # Adjoint fields (lambda)
+     args += [c_void_p] * 3 # lambda_ey, lambda_hx, lambda_hz
+     # Adjoint PML memory variables
+     args += [c_void_p] * 4 # m_lambda_ey_x, m_lambda_ey_z, m_lambda_hx_z, m_lambda_hz_x
+     # Stored forward values for gradient (Ey and curl_H)
+     # For each: store_1, store_3, filenames (char**)
+     args += [c_void_p] * 6
+     # Gradient outputs
+     args += [c_void_p] # grad_f
+     args += [c_void_p] * 4 # grad_ca, grad_cb, grad_eps, grad_sigma
+     args += [c_void_p] * 2 # grad_ca_shot, grad_cb_shot (per-shot workspace)
+     # PML profiles
+     args += [c_void_p] * 8 # ay, by, ayh, byh, ax, bx, axh, bxh
+     # Kappa profiles
+     args += [c_void_p] * 4 # ky, kyh, kx, kxh
+     # Source and receiver indices
+     args += [c_void_p] * 2 # sources_i, receivers_i
+     # Grid spacing
+     args += [FLOAT_TYPE] * 2 # rdy, rdx
+     # Time step
+     args += [FLOAT_TYPE] # dt
+     # Sizes
+     args += [c_int64] # nt
+     args += [c_int64] # n_shots
+     args += [c_int64] * 2 # ny, nx
+     args += [c_int64] * 2 # n_sources_per_shot, n_receivers_per_shot
+     args += [c_int64] # step_ratio
+     # Storage mode
+     args += [c_int64] * 2 # storage_mode, shot_bytes_uncomp
+     # Requires grad flags
+     args += [c_bool] * 2 # ca_requires_grad, cb_requires_grad
+     # Batched flags
+     args += [c_bool] * 3 # ca_batched, cb_batched, cq_batched
+     # Start time
+     args += [c_int64] # start_t
+     # PML boundaries for adjoint propagation
+     args += [c_int64] * 4 # pml_y0, pml_x0, pml_y1, pml_x1
+     # OpenMP threads (CPU only)
+     args += [c_int64] # n_threads
+     # Device (for CUDA)
+     args += [c_int64] # device
+     return args
+
+
+ def get_maxwell_tm_forward_with_storage_template() -> list[Any]:
+     """Returns the argtype template for Maxwell TM forward with storage (for ASM backward)."""
+     args: list[Any] = []
+     # Material parameters
+     args += [c_void_p] * 3 # ca, cb, cq
+     # Source
+     args += [c_void_p] # f (source amplitudes)
+     # Fields
+     args += [c_void_p] * 3 # ey, hx, hz
+     # PML memory variables
+     args += [c_void_p] * 4 # m_ey_x, m_ey_z, m_hx_z, m_hz_x
+     # Recorded data
+     args += [c_void_p] # r (receiver amplitudes)
+     # Storage for backward (Ey and curl_H)
+     # For each: store_1, store_3, filenames (char**)
+     args += [c_void_p] * 6
+     # PML profiles
+     args += [c_void_p] * 8 # ay, by, ayh, byh, ax, bx, axh, bxh
+     # Kappa profiles
+     args += [c_void_p] * 4 # ky, kyh, kx, kxh
+     # Source and receiver indices
+     args += [c_void_p] * 2 # sources_i, receivers_i
+     # Grid spacing
+     args += [FLOAT_TYPE] * 2 # rdy, rdx
+     # Time step
+     args += [FLOAT_TYPE] # dt
+     # Sizes
+     args += [c_int64] # nt
+     args += [c_int64] # n_shots
+     args += [c_int64] * 2 # ny, nx
+     args += [c_int64] * 2 # n_sources_per_shot, n_receivers_per_shot
+     args += [c_int64] # step_ratio
+     # Storage mode
+     args += [c_int64] * 2 # storage_mode, shot_bytes_uncomp
+     # Requires grad flags
+     args += [c_bool] * 2 # ca_requires_grad, cb_requires_grad
+     # Batched flags
+     args += [c_bool] * 3 # ca_batched, cb_batched, cq_batched
+     # Start time
+     args += [c_int64] # start_t
+     # PML boundaries
+     args += [c_int64] * 4 # pml_y0, pml_x0, pml_y1, pml_x1
+     # OpenMP threads (CPU only)
+     args += [c_int64] # n_threads
+     # Device (for CUDA)
+     args += [c_int64] # device
+     return args
+
+
+ def get_maxwell_3d_forward_template() -> list[Any]:
+     """Returns the argtype template for the 3D Maxwell forward propagator."""
+     args: list[Any] = []
+     # Material parameters
+     args += [c_void_p] * 3 # ca, cb, cq
+     # Source
+     args += [c_void_p] # f (source amplitudes)
+     # Fields
+     args += [c_void_p] * 6 # ex, ey, ez, hx, hy, hz
+     # PML memory variables
+     args += (
+         [c_void_p] * 12
+     ) # m_hz_y, m_hy_z, m_hx_z, m_hz_x, m_hy_x, m_hx_y, m_ey_z, m_ez_y, m_ez_x, m_ex_z, m_ex_y, m_ey_x
+     # Recorded data
+     args += [c_void_p] # r
+     # PML profiles
+     args += [c_void_p] * 12 # az, bz, azh, bzh, ay, by, ayh, byh, ax, bx, axh, bxh
+     # Kappa profiles
+     args += [c_void_p] * 6 # kz, kzh, ky, kyh, kx, kxh
+     # Source and receiver indices
+     args += [c_void_p] * 2 # sources_i, receivers_i
+     # Grid spacing
+     args += [FLOAT_TYPE] * 3 # rdz, rdy, rdx
+     # Time step
+     args += [FLOAT_TYPE] # dt
+     # Sizes
+     args += [c_int64] # nt
+     args += [c_int64] # n_shots
+     args += [c_int64] * 3 # nz, ny, nx
+     args += [c_int64] * 2 # n_sources_per_shot, n_receivers_per_shot
+     args += [c_int64] # step_ratio
+     # Batched flags
+     args += [c_bool] * 3 # ca_batched, cb_batched, cq_batched
+     # Start time
+     args += [c_int64] # start_t
+     # PML boundaries
+     args += [c_int64] * 6 # pml_z0, pml_y0, pml_x0, pml_z1, pml_y1, pml_x1
+     # Source/receiver component
+     args += [c_int64] * 2 # source_component, receiver_component
+     # OpenMP threads (CPU only)
+     args += [c_int64] # n_threads
+     # Device (for CUDA)
+     args += [c_int64] # device
+     return args
+
+
+ def get_maxwell_3d_forward_with_storage_template() -> list[Any]:
+     """Returns the argtype template for 3D Maxwell forward with storage (ASM)."""
+     args: list[Any] = []
+     # Material parameters
+     args += [c_void_p] * 3 # ca, cb, cq
+     # Source
+     args += [c_void_p] # f (source amplitudes)
+     # Fields
+     args += [c_void_p] * 6 # ex, ey, ez, hx, hy, hz
+     # PML memory variables
+     args += (
+         [c_void_p] * 12
+     ) # m_hz_y, m_hy_z, m_hx_z, m_hz_x, m_hy_x, m_hx_y, m_ey_z, m_ez_y, m_ez_x, m_ex_z, m_ex_y, m_ey_x
+     # Recorded data
+     args += [c_void_p] # r
+     # Storage for backward (Ex/Ey/Ez and curl(H) components)
+     # For each: store_1, store_3, filenames (char**)
+     args += [c_void_p] * 18
+     # PML profiles
+     args += [c_void_p] * 12 # az, bz, azh, bzh, ay, by, ayh, byh, ax, bx, axh, bxh
+     # Kappa profiles
+     args += [c_void_p] * 6 # kz, kzh, ky, kyh, kx, kxh
+     # Source and receiver indices
+     args += [c_void_p] * 2 # sources_i, receivers_i
+     # Grid spacing
+     args += [FLOAT_TYPE] * 3 # rdz, rdy, rdx
+     # Time step
+     args += [FLOAT_TYPE] # dt
+     # Sizes
+     args += [c_int64] # nt
+     args += [c_int64] # n_shots
+     args += [c_int64] * 3 # nz, ny, nx
+     args += [c_int64] * 2 # n_sources_per_shot, n_receivers_per_shot
+     args += [c_int64] # step_ratio
+     # Storage mode
+     args += [c_int64] * 2 # storage_mode, shot_bytes_uncomp
+     # Requires grad flags
+     args += [c_bool] * 2 # ca_requires_grad, cb_requires_grad
+     # Batched flags
+     args += [c_bool] * 3 # ca_batched, cb_batched, cq_batched
+     # Start time
+     args += [c_int64] # start_t
+     # PML boundaries
+     args += [c_int64] * 6 # pml_z0, pml_y0, pml_x0, pml_z1, pml_y1, pml_x1
+     # Source/receiver component
+     args += [c_int64] * 2 # source_component, receiver_component
+     # OpenMP threads (CPU only)
+     args += [c_int64] # n_threads
+     # Device (for CUDA)
+     args += [c_int64] # device
+     return args
+
+
+ def get_maxwell_3d_backward_template() -> list[Any]:
+     """Returns the argtype template for 3D Maxwell backward propagator (ASM)."""
+     args: list[Any] = []
+     # Material parameters
+     args += [c_void_p] * 3 # ca, cb, cq
+     # Gradient of receiver data
+     args += [c_void_p] # grad_r
+     # Adjoint fields (lambda)
+     args += [
+         c_void_p
+     ] * 6 # lambda_ex, lambda_ey, lambda_ez, lambda_hx, lambda_hy, lambda_hz
+     # Adjoint PML memory variables
+     args += (
+         [c_void_p] * 12
+     ) # m_lambda_ey_z, m_lambda_ez_y, m_lambda_ez_x, m_lambda_ex_z, m_lambda_ex_y, m_lambda_ey_x, m_lambda_hz_y, m_lambda_hy_z, m_lambda_hx_z, m_lambda_hz_x, m_lambda_hy_x, m_lambda_hx_y
+     # Stored forward values (Ex/Ey/Ez and curl(H) components)
+     # For each: store_1, store_3, filenames (char**)
+     args += [c_void_p] * 18
+     # Gradient outputs
+     args += [c_void_p] # grad_f
+     args += [c_void_p] * 4 # grad_ca, grad_cb, grad_eps, grad_sigma
+     args += [c_void_p] * 2 # grad_ca_shot, grad_cb_shot
+     # PML profiles
+     args += [c_void_p] * 12 # az, bz, azh, bzh, ay, by, ayh, byh, ax, bx, axh, bxh
+     # Kappa profiles
+     args += [c_void_p] * 6 # kz, kzh, ky, kyh, kx, kxh
+     # Source and receiver indices
+     args += [c_void_p] * 2 # sources_i, receivers_i
+     # Grid spacing
+     args += [FLOAT_TYPE] * 3 # rdz, rdy, rdx
+     # Time step
+     args += [FLOAT_TYPE] # dt
+     # Sizes
+     args += [c_int64] # nt
+     args += [c_int64] # n_shots
+     args += [c_int64] * 3 # nz, ny, nx
+     args += [c_int64] * 2 # n_sources_per_shot, n_receivers_per_shot
+     args += [c_int64] # step_ratio
+     # Storage mode
+     args += [c_int64] * 2 # storage_mode, shot_bytes_uncomp
+     # Requires grad flags
+     args += [c_bool] * 2 # ca_requires_grad, cb_requires_grad
+     # Batched flags
+     args += [c_bool] * 3 # ca_batched, cb_batched, cq_batched
+     # Start time
+     args += [c_int64] # start_t
+     # PML boundaries
+     args += [c_int64] * 6 # pml_z0, pml_y0, pml_x0, pml_z1, pml_y1, pml_x1
+     # Source/receiver component
+     args += [c_int64] * 2 # source_component, receiver_component
+     # OpenMP threads (CPU only)
+     args += [c_int64] # n_threads
+     # Device (for CUDA)
+     args += [c_int64] # device
+     return args
+
+
+ # Template registry
+ templates: dict[str, Callable[[], list[Any]]] = {
+     "maxwell_tm_forward": get_maxwell_tm_forward_template,
+     "maxwell_tm_forward_with_storage": get_maxwell_tm_forward_with_storage_template,
+     "maxwell_tm_backward": get_maxwell_tm_backward_template,
+     "maxwell_3d_forward": get_maxwell_3d_forward_template,
+     "maxwell_3d_forward_with_storage": get_maxwell_3d_forward_with_storage_template,
+     "maxwell_3d_backward": get_maxwell_3d_backward_template,
+ }
+
+
+ def _get_argtypes(template_name: str, float_type: type) -> list[Any]:
+     """Generates a concrete argtype list from a template and a float type.
+
+     Args:
+         template_name: The name of the argtype template to use.
+         float_type: The `ctypes` float type (`c_float` or `c_double`)
+             to substitute into the template.
+
+     Returns:
+         list[Any]: A list of `ctypes` types representing the argument
+             signature for a C function.
+
+     """
+     template = templates[template_name]()
+     return [float_type if t is FLOAT_TYPE else t for t in template]
+
+
+ def _assign_argtypes(
+     propagator: str,
+     accuracy: int,
+     dtype: str,
+     direction: str,
+ ) -> None:
+     """Dynamically assigns ctypes argtypes to a given C function.
+
+     Args:
+         propagator: The name of the propagator (e.g., "maxwell_tm").
+         accuracy: The finite-difference accuracy order (e.g., 2, 4, 6, 8).
+         dtype: The data type as a string (e.g., "float", "double").
+         direction: The direction of propagation (e.g., "forward", "backward").
+
+     """
+     if _dll is None:
+         return
+
+     template_name = f"{propagator}_{direction}"
+     float_type = c_float if dtype == "float" else c_double
+     argtypes = _get_argtypes(template_name, float_type)
+
+     for device in ["cpu", "cuda"]:
+         func_name = f"{propagator}_{accuracy}_{dtype}_{direction}_{device}"
+         try:
+             func = getattr(_dll, func_name)
+             func.argtypes = argtypes
+             func.restype = None  # All C functions return void
+         except AttributeError:
+             continue
+
+
+ def get_backend_function(
+     propagator: str,
+     pass_name: str,
+     accuracy: int,
+     dtype: torch.dtype,
+     device: torch.device,
+ ) -> CFunctionPointer:
+     """Selects and returns the appropriate backend C/CUDA function.
+
+     Args:
+         propagator: The name of the propagator (e.g., "maxwell_tm").
+         pass_name: The name of the pass (e.g., "forward", "backward").
+         accuracy: The finite-difference accuracy order.
+         dtype: The torch.dtype of the tensors.
+         device: The torch.device the tensors are on.
+
+     Returns:
+         The backend function pointer.
+
+     Raises:
+         AttributeError: If the function is not found in the shared library.
+         TypeError: If the dtype is not torch.float32 or torch.float64.
+         RuntimeError: If the backend is not available.
+
+     """
+     dll = get_dll()
+
+     if dtype == torch.float32:
+         dtype_str = "float"
+     elif dtype == torch.float64:
+         dtype_str = "double"
+     else:
+         raise TypeError(f"Unsupported dtype {dtype}")
+
+     device_str = device.type
+
+     func_name = f"{propagator}_{accuracy}_{dtype_str}_{pass_name}_{device_str}"
+
+     try:
+         return getattr(dll, func_name)
+     except AttributeError as e:
+         raise AttributeError(f"Backend function {func_name} not found.") from e
+
+
+ def tensor_to_ptr(tensor: Optional[torch.Tensor]) -> int:
+     """Convert a PyTorch tensor to a C pointer (int).
+
+     Args:
+         tensor: The tensor to convert, or None.
+
+     Returns:
+         The data pointer as an integer, or 0 if tensor is None.
+
+     """
+     if tensor is None:
+         return 0
+     if torch._C._functorch.is_functorch_wrapped_tensor(tensor):
+         tensor = torch._C._functorch.get_unwrapped(tensor)
+     return tensor.data_ptr()
+
+
+ def ensure_contiguous(tensor: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
+     """Ensure a tensor is contiguous in memory.
+
+     Args:
+         tensor: The tensor to check, or None.
+
+     Returns:
+         A contiguous version of the tensor, or None.
+
+     """
+     if tensor is None:
+         return None
+     return tensor.contiguous()
+
+
+ # Initialize argtypes for all available functions when the module loads
+ if _dll is not None:
+     for current_accuracy in [2, 4, 6, 8]:
+         for current_dtype in ["float", "double"]:
+             _assign_argtypes("maxwell_tm", current_accuracy, current_dtype, "forward")
+             _assign_argtypes(
+                 "maxwell_tm", current_accuracy, current_dtype, "forward_with_storage"
+             )
+             _assign_argtypes("maxwell_tm", current_accuracy, current_dtype, "backward")
+             _assign_argtypes("maxwell_3d", current_accuracy, current_dtype, "forward")
+             _assign_argtypes(
+                 "maxwell_3d", current_accuracy, current_dtype, "forward_with_storage"
+             )
+             _assign_argtypes("maxwell_3d", current_accuracy, current_dtype, "backward")
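Usage note (not part of the diff): backend_utils locates the compiled libtide_C shared library, assigns ctypes argtypes to every exported propagator symbol at import time, and resolves concrete function pointers by propagator, pass, accuracy, dtype, and device. A minimal sketch of how a caller might resolve a kernel, assuming the backend library is present; the propagator name and accuracy shown are just one of the combinations registered above, and the real call sites live in the propagator modules:

    import torch
    from tide import backend_utils

    if backend_utils.is_backend_available():
        # Resolve e.g. the 4th-order float32 CPU forward kernel of the TM propagator.
        kernel = backend_utils.get_backend_function(
            propagator="maxwell_tm",
            pass_name="forward",
            accuracy=4,
            dtype=torch.float32,
            device=torch.device("cpu"),
        )
        # Tensors are handed to the C side as raw data pointers; None maps to a null pointer.
        field = backend_utils.ensure_contiguous(torch.zeros(1, 100, 100))
        ptr = backend_utils.tensor_to_ptr(field)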