tide-GPR 0.0.9 (tide_gpr-0.0.9-py3-none-manylinux_2_28_x86_64.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tide/__init__.py +65 -0
- tide/autograd_utils.py +26 -0
- tide/backend_utils.py +536 -0
- tide/callbacks.py +348 -0
- tide/cfl.py +64 -0
- tide/csrc/CMakeLists.txt +263 -0
- tide/csrc/common_cpu.h +31 -0
- tide/csrc/common_gpu.h +56 -0
- tide/csrc/maxwell.c +2133 -0
- tide/csrc/maxwell.cu +2297 -0
- tide/csrc/maxwell_born.cu +0 -0
- tide/csrc/staggered_grid.h +175 -0
- tide/csrc/staggered_grid_3d.h +124 -0
- tide/csrc/storage_utils.c +78 -0
- tide/csrc/storage_utils.cu +135 -0
- tide/csrc/storage_utils.h +36 -0
- tide/grid_utils.py +31 -0
- tide/maxwell.py +2651 -0
- tide/padding.py +139 -0
- tide/resampling.py +246 -0
- tide/staggered.py +567 -0
- tide/storage.py +131 -0
- tide/tide/libtide_C.so +0 -0
- tide/utils.py +274 -0
- tide/validation.py +71 -0
- tide/wavelets.py +72 -0
- tide_gpr-0.0.9.dist-info/METADATA +256 -0
- tide_gpr-0.0.9.dist-info/RECORD +31 -0
- tide_gpr-0.0.9.dist-info/WHEEL +5 -0
- tide_gpr-0.0.9.dist-info/licenses/LICENSE +46 -0
- tide_gpr.libs/libgomp-24e2ab19.so.1.0.0 +0 -0
tide/staggered.py
ADDED
@@ -0,0 +1,567 @@
from typing import Tuple

import torch

from . import utils


def set_pml_profiles(
    pml_width: list[int],
    accuracy: int,
    fd_pad: list[int],
    dt: float,
    grid_spacing: list[float],
    max_vel: float,
    dtype: torch.dtype,
    device: torch.device,
    pml_freq: float,
    ny: int,
    nx: int,
) -> list[torch.Tensor]:
    """Sets up PML profiles for a staggered grid.

    Args:
        pml_width: A list of integers specifying the width of the PML
            on each side (top, bottom, left, right).
        accuracy: The finite-difference accuracy order.
        fd_pad: A list of integers specifying the padding used by the
            finite-difference stencil.
        dt: The time step.
        grid_spacing: A list of floats specifying the grid spacing in the
            y and x directions.
        max_vel: The maximum velocity in the model.
        dtype: The data type of the tensors (e.g., torch.float32).
        device: The device on which the tensors will be (e.g., 'cuda', 'cpu').
        pml_freq: The PML frequency.
        ny: The number of grid points in the y direction.
        nx: The number of grid points in the x direction.

    Returns:
        A list containing:
        - a, b profiles: [ay, ayh, ax, axh, by, byh, bx, bxh]
        - k profiles: [ky, kyh, kx, kxh]
        Total 12 tensors.

    """
    pml_start: list[float] = [
        fd_pad[0] + pml_width[0],
        ny - 1 - fd_pad[1] - pml_width[1],
        fd_pad[2] + pml_width[2],
        nx - 1 - fd_pad[3] - pml_width[3],
    ]
    max_pml = max(
        [
            pml_width[0] * grid_spacing[0],
            pml_width[1] * grid_spacing[0],
            pml_width[2] * grid_spacing[1],
            pml_width[3] * grid_spacing[1],
        ],
    )

    # Integer grid PML profiles
    ay, by, ky = utils.setup_pml(
        pml_width[:2],
        pml_start[:2],
        max_pml,
        dt,
        ny,
        max_vel,
        dtype,
        device,
        pml_freq,
        start=0.0,
        grid_spacing=grid_spacing[0],
    )
    ax, bx, kx = utils.setup_pml(
        pml_width[2:],
        pml_start[2:],
        max_pml,
        dt,
        nx,
        max_vel,
        dtype,
        device,
        pml_freq,
        start=0.0,
        grid_spacing=grid_spacing[1],
    )

    # Half grid PML profiles
    ayh, byh, kyh = utils.setup_pml_half(
        pml_width[:2],
        pml_start[:2],
        max_pml,
        dt,
        ny,
        max_vel,
        dtype,
        device,
        pml_freq,
        start=0.0,
        grid_spacing=grid_spacing[0],
    )
    axh, bxh, kxh = utils.setup_pml_half(
        pml_width[2:],
        pml_start[2:],
        max_pml,
        dt,
        nx,
        max_vel,
        dtype,
        device,
        pml_freq,
        start=0.0,
        grid_spacing=grid_spacing[1],
    )

    # Reshape for broadcasting: [batch, ny, nx]
    ay = ay[None, :, None]
    ayh = ayh[None, :, None]
    ax = ax[None, None, :]
    axh = axh[None, None, :]
    by = by[None, :, None]
    byh = byh[None, :, None]
    bx = bx[None, None, :]
    bxh = bxh[None, None, :]

    ky = ky[None, :, None]
    kyh = kyh[None, :, None]
    kx = kx[None, None, :]
    kxh = kxh[None, None, :]

    return [ay, ayh, ax, axh, by, byh, bx, bxh, ky, kyh, kx, kxh]

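A sketch of how these broadcast-ready profiles are typically consumed in one CPML-corrected derivative evaluation; the field and memory-variable names below are illustrative assumptions rather than code from this package (diffy1 is defined later in this file):

    import torch

    ny, nx = 100, 100
    field = torch.randn(1, ny, nx)   # [batch, ny, nx], matching the reshapes above
    ay = torch.zeros(1, ny, 1)       # a profile (zero outside the absorbing layer)
    by = torch.ones(1, ny, 1)        # b profile (one outside the absorbing layer)
    ky = torch.ones(1, ny, 1)        # kappa profile (one outside the absorbing layer)
    psi_y = torch.zeros_like(field)  # CPML memory variable, initialized to zero

    dy = diffy1(field, 4, torch.tensor(20.0))  # raw staggered derivative, rdy = 1/dy
    psi_y = by * psi_y + ay * dy               # recursive-convolution update
    dy_pml = dy / ky + psi_y                   # coordinate-stretched derivative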
def setup_pml_profiles_1d(
    n: int,
    pml_width0: int,
    pml_width1: int,
    sigma_max: float,
    dt: float,
    device: torch.device,
    dtype: torch.dtype,
) -> tuple[
    torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor
]:
    """Create 1D CPML profiles (a, b, k) for integer and half-grid points."""
    eps = 1e-9
    n_power = 2

    if pml_width0 == 0 and pml_width1 == 0:
        zeros = torch.zeros(n, device=device, dtype=dtype)
        ones = torch.ones(n, device=device, dtype=dtype)
        return zeros, zeros, zeros, zeros, ones, ones

    def _profiles(start: float) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        x = torch.arange(n, device=device, dtype=dtype) + start
        left_start = float(pml_width0)
        right_start = float(n - 1 - pml_width1)

        if pml_width0 == 0:
            frac_left = torch.zeros_like(x)
        else:
            frac_left = (left_start - x) / float(pml_width0)
        if pml_width1 == 0:
            frac_right = torch.zeros_like(x)
        else:
            frac_right = (x - right_start) / float(pml_width1)

        pml_frac = torch.clamp(torch.maximum(frac_left, frac_right), 0.0, 1.0)
        sigma = sigma_max * pml_frac.pow(n_power)
        kappa = torch.ones_like(sigma)

        sigma_alpha = sigma
        b = torch.exp(-sigma_alpha * abs(dt))
        denom = sigma_alpha + eps
        a = torch.where(
            sigma_alpha > 0.0, sigma * (b - 1.0) / denom, torch.zeros_like(b)
        )
        return a, b, kappa

    ay, by, ky = _profiles(0.0)
    ayh, byh, kyh = _profiles(0.5)
    return ay, ayh, by, byh, ky, kyh

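Because alpha is zero and kappa is one in this 1D variant, the coefficients reduce to b = exp(-sigma * |dt|) and a ≈ b - 1, which drive the recursion psi_new = b * psi + a * du/dx. A quick numerical check of those identities, with hypothetical parameter values:

    import torch

    a, ah, b, bh, k, kh = setup_pml_profiles_1d(
        n=50, pml_width0=10, pml_width1=10, sigma_max=1.5e3, dt=1e-3,
        device=torch.device("cpu"), dtype=torch.float32,
    )
    assert torch.all((b > 0.0) & (b <= 1.0))      # b = exp(-sigma * |dt|) is in (0, 1]
    assert torch.allclose(a, b - 1.0, atol=1e-5)  # a = sigma * (b - 1) / (sigma + eps)
    assert torch.all(k == 1.0)                    # kappa is fixed at one here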
def set_pml_profiles_3d(
    pml_width: list[int],
    accuracy: int,
    fd_pad: list[int],
    dt: float,
    grid_spacing: list[float],
    max_vel: float,
    dtype: torch.dtype,
    device: torch.device,
    pml_freq: float,
    nz: int,
    ny: int,
    nx: int,
) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
    """Sets up 3D PML profiles for a staggered grid.

    Args:
        pml_width: Widths [z0, z1, y0, y1, x0, x1].
        accuracy: Finite-difference accuracy order (unused, kept for API parity).
        fd_pad: FD padding [z0, z1, y0, y1, x0, x1].
        dt: Time step.
        grid_spacing: Grid spacing [dz, dy, dx].
        max_vel: Maximum velocity (unused in EM formulation, kept for API parity).
        dtype: Tensor dtype.
        device: Tensor device.
        pml_freq: PML frequency.
        nz, ny, nx: Padded grid sizes.

    Returns:
        - PML a/b profiles: [az, azh, ay, ayh, ax, axh, bz, bzh, by, byh, bx, bxh]
        - PML kappa profiles: [kz, kzh, ky, kyh, kx, kxh]
    """
    _ = accuracy
    dz, dy, dx = grid_spacing

    pml_start: list[float] = [
        fd_pad[0] + pml_width[0],
        nz - 1 - fd_pad[1] - pml_width[1],
        fd_pad[2] + pml_width[2],
        ny - 1 - fd_pad[3] - pml_width[3],
        fd_pad[4] + pml_width[4],
        nx - 1 - fd_pad[5] - pml_width[5],
    ]

    max_pml = max(
        [
            pml_width[0] * dz,
            pml_width[1] * dz,
            pml_width[2] * dy,
            pml_width[3] * dy,
            pml_width[4] * dx,
            pml_width[5] * dx,
        ]
    )

    az, bz, kz = utils.setup_pml(
        pml_width[:2],
        pml_start[:2],
        max_pml,
        dt,
        nz,
        max_vel,
        dtype,
        device,
        pml_freq,
        start=0.0,
        grid_spacing=dz,
    )
    ay, by, ky = utils.setup_pml(
        pml_width[2:4],
        pml_start[2:4],
        max_pml,
        dt,
        ny,
        max_vel,
        dtype,
        device,
        pml_freq,
        start=0.0,
        grid_spacing=dy,
    )
    ax, bx, kx = utils.setup_pml(
        pml_width[4:],
        pml_start[4:],
        max_pml,
        dt,
        nx,
        max_vel,
        dtype,
        device,
        pml_freq,
        start=0.0,
        grid_spacing=dx,
    )

    azh, bzh, kzh = utils.setup_pml_half(
        pml_width[:2],
        pml_start[:2],
        max_pml,
        dt,
        nz,
        max_vel,
        dtype,
        device,
        pml_freq,
        start=0.0,
        grid_spacing=dz,
    )
    ayh, byh, kyh = utils.setup_pml_half(
        pml_width[2:4],
        pml_start[2:4],
        max_pml,
        dt,
        ny,
        max_vel,
        dtype,
        device,
        pml_freq,
        start=0.0,
        grid_spacing=dy,
    )
    axh, bxh, kxh = utils.setup_pml_half(
        pml_width[4:],
        pml_start[4:],
        max_pml,
        dt,
        nx,
        max_vel,
        dtype,
        device,
        pml_freq,
        start=0.0,
        grid_spacing=dx,
    )

    az = az[None, :, None, None]
    azh = azh[None, :, None, None]
    bz = bz[None, :, None, None]
    bzh = bzh[None, :, None, None]
    kz = kz[None, :, None, None]
    kzh = kzh[None, :, None, None]

    ay = ay[None, None, :, None]
    ayh = ayh[None, None, :, None]
    by = by[None, None, :, None]
    byh = byh[None, None, :, None]
    ky = ky[None, None, :, None]
    kyh = kyh[None, None, :, None]

    ax = ax[None, None, None, :]
    axh = axh[None, None, None, :]
    bx = bx[None, None, None, :]
    bxh = bxh[None, None, None, :]
    kx = kx[None, None, None, :]
    kxh = kxh[None, None, None, :]

    return (
        [az, azh, ay, ayh, ax, axh, bz, bzh, by, byh, bx, bxh],
        [kz, kzh, ky, kyh, kx, kxh],
    )

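The trailing reshapes give each profile a singleton batch axis plus singleton axes on the two directions it does not vary along, so every profile broadcasts against a [batch, nz, ny, nx] field. A shape sketch with hypothetical sizes (it assumes utils.setup_pml returns 1D profiles of the requested length, so it needs the installed package):

    import torch

    ab, kappa = set_pml_profiles_3d(
        pml_width=[10] * 6, accuracy=4, fd_pad=[2] * 6, dt=1e-3,
        grid_spacing=[5.0, 5.0, 5.0], max_vel=3.0e8, dtype=torch.float32,
        device=torch.device("cpu"), pml_freq=100.0, nz=64, ny=64, nx=64,
    )
    az, ay, ax = ab[0], ab[2], ab[4]
    assert az.shape == (1, 64, 1, 1)  # varies along z only
    assert ay.shape == (1, 1, 64, 1)  # varies along y only
    assert ax.shape == (1, 1, 1, 64)  # varies along x only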
def diffy1(a: torch.Tensor, stencil: int, rdy: torch.Tensor) -> torch.Tensor:
    """Calculates the first y derivative at integer grid points."""
    if stencil == 2:
        return torch.nn.functional.pad(
            (a[..., 1:, :] - a[..., :-1, :]) * rdy, (0, 0, 1, 0)
        )
    if stencil == 4:
        return torch.nn.functional.pad(
            (
                9 / 8 * (a[..., 2:-1, :] - a[..., 1:-2, :])
                + -1 / 24 * (a[..., 3:, :] - a[..., :-3, :])
            )
            * rdy,
            (0, 0, 2, 1),
        )
    if stencil == 6:
        return torch.nn.functional.pad(
            (
                75 / 64 * (a[..., 3:-2, :] - a[..., 2:-3, :])
                + -25 / 384 * (a[..., 4:-1, :] - a[..., 1:-4, :])
                + 3 / 640 * (a[..., 5:, :] - a[..., :-5, :])
            )
            * rdy,
            (0, 0, 3, 2),
        )
    return torch.nn.functional.pad(
        (
            1225 / 1024 * (a[..., 4:-3, :] - a[..., 3:-4, :])
            + -245 / 3072 * (a[..., 5:-2, :] - a[..., 2:-5, :])
            + 49 / 5120 * (a[..., 6:-1, :] - a[..., 1:-6, :])
            + -5 / 7168 * (a[..., 7:, :] - a[..., :-7, :])
        )
        * rdy,
        (0, 0, 4, 3),
    )

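The weights (9/8 and -1/24 for 4th order; 75/64, -25/384, 3/640 for 6th; and so on) are the standard staggered-grid first-derivative coefficients: the input samples sit half a cell away from each output point. A quick accuracy check of the 4th-order branch against an analytic derivative (illustrative; the padding zeroes the edges, so only the interior is compared):

    import torch

    n, h = 200, 0.01
    y_half = (torch.arange(n, dtype=torch.float64) + 0.5) * h  # samples on half points
    a = torch.sin(y_half)[None, :, None]                       # shape [1, ny, 1]
    d = diffy1(a, 4, torch.tensor(1.0 / h, dtype=torch.float64))
    y_int = torch.arange(n, dtype=torch.float64) * h           # outputs on integer points
    err = (d[0, 50:150, 0] - torch.cos(y_int[50:150])).abs().max()
    assert err < 1e-8                                          # interior error is O(h^4)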
def diffx1(a: torch.Tensor, stencil: int, rdx: torch.Tensor) -> torch.Tensor:
    """Calculates the first x derivative at integer grid points."""
    if stencil == 2:
        return torch.nn.functional.pad((a[..., 1:] - a[..., :-1]) * rdx, (1, 0))
    if stencil == 4:
        return torch.nn.functional.pad(
            (
                9 / 8 * (a[..., 2:-1] - a[..., 1:-2])
                + -1 / 24 * (a[..., 3:] - a[..., :-3])
            )
            * rdx,
            (2, 1),
        )
    if stencil == 6:
        return torch.nn.functional.pad(
            (
                75 / 64 * (a[..., 3:-2] - a[..., 2:-3])
                + -25 / 384 * (a[..., 4:-1] - a[..., 1:-4])
                + 3 / 640 * (a[..., 5:] - a[..., :-5])
            )
            * rdx,
            (3, 2),
        )
    return torch.nn.functional.pad(
        (
            1225 / 1024 * (a[..., 4:-3] - a[..., 3:-4])
            + -245 / 3072 * (a[..., 5:-2] - a[..., 2:-5])
            + 49 / 5120 * (a[..., 6:-1] - a[..., 1:-6])
            + -5 / 7168 * (a[..., 7:] - a[..., :-7])
        )
        * rdx,
        (4, 3),
    )


def diffz1(a: torch.Tensor, stencil: int, rdz: torch.Tensor) -> torch.Tensor:
    """Calculates the first z derivative at integer grid points."""
    if stencil == 2:
        return torch.nn.functional.pad(
            (a[..., 1:, :, :] - a[..., :-1, :, :]) * rdz,
            (0, 0, 0, 0, 1, 0),
        )
    if stencil == 4:
        return torch.nn.functional.pad(
            (
                9 / 8 * (a[..., 2:-1, :, :] - a[..., 1:-2, :, :])
                + -1 / 24 * (a[..., 3:, :, :] - a[..., :-3, :, :])
            )
            * rdz,
            (0, 0, 0, 0, 2, 1),
        )
    if stencil == 6:
        return torch.nn.functional.pad(
            (
                75 / 64 * (a[..., 3:-2, :, :] - a[..., 2:-3, :, :])
                + -25 / 384 * (a[..., 4:-1, :, :] - a[..., 1:-4, :, :])
                + 3 / 640 * (a[..., 5:, :, :] - a[..., :-5, :, :])
            )
            * rdz,
            (0, 0, 0, 0, 3, 2),
        )
    return torch.nn.functional.pad(
        (
            1225 / 1024 * (a[..., 4:-3, :, :] - a[..., 3:-4, :, :])
            + -245 / 3072 * (a[..., 5:-2, :, :] - a[..., 2:-5, :, :])
            + 49 / 5120 * (a[..., 6:-1, :, :] - a[..., 1:-6, :, :])
            + -5 / 7168 * (a[..., 7:, :, :] - a[..., :-7, :, :])
        )
        * rdz,
        (0, 0, 0, 0, 4, 3),
    )


def diffyh1(a: torch.Tensor, stencil: int, rdy: torch.Tensor) -> torch.Tensor:
    """Calculates the first y derivative at half integer grid points."""
    if stencil == 2:
        return torch.nn.functional.pad(
            (a[..., 2:, :] - a[..., 1:-1, :]) * rdy, (0, 0, 1, 1)
        )
    if stencil == 4:
        return torch.nn.functional.pad(
            (
                9 / 8 * (a[..., 3:-1, :] - a[..., 2:-2, :])
                + -1 / 24 * (a[..., 4:, :] - a[..., 1:-3, :])
            )
            * rdy,
            (0, 0, 2, 2),
        )
    if stencil == 6:
        return torch.nn.functional.pad(
            (
                75 / 64 * (a[..., 4:-2, :] - a[..., 3:-3, :])
                + -25 / 384 * (a[..., 5:-1, :] - a[..., 2:-4, :])
                + 3 / 640 * (a[..., 6:, :] - a[..., 1:-5, :])
            )
            * rdy,
            (0, 0, 3, 3),
        )
    return torch.nn.functional.pad(
        (
            1225 / 1024 * (a[..., 5:-3, :] - a[..., 4:-4, :])
            + -245 / 3072 * (a[..., 6:-2, :] - a[..., 3:-5, :])
            + 49 / 5120 * (a[..., 7:-1, :] - a[..., 2:-6, :])
            + -5 / 7168 * (a[..., 8:, :] - a[..., 1:-7, :])
        )
        * rdy,
        (0, 0, 4, 4),
    )


def diffzh1(a: torch.Tensor, stencil: int, rdz: torch.Tensor) -> torch.Tensor:
    """Calculates the first z derivative at half integer grid points.

    For a tensor with shape [..., nz, ny, nx], the derivative is taken along
    the z dimension at half-grid locations.
    """
    if stencil == 2:
        return torch.nn.functional.pad(
            (a[..., 2:, :, :] - a[..., 1:-1, :, :]) * rdz, (0, 0, 0, 0, 1, 1)
        )
    if stencil == 4:
        return torch.nn.functional.pad(
            (
                9 / 8 * (a[..., 3:-1, :, :] - a[..., 2:-2, :, :])
                + -1 / 24 * (a[..., 4:, :, :] - a[..., 1:-3, :, :])
            )
            * rdz,
            (0, 0, 0, 0, 2, 2),
        )
    if stencil == 6:
        return torch.nn.functional.pad(
            (
                75 / 64 * (a[..., 4:-2, :, :] - a[..., 3:-3, :, :])
                + -25 / 384 * (a[..., 5:-1, :, :] - a[..., 2:-4, :, :])
                + 3 / 640 * (a[..., 6:, :, :] - a[..., 1:-5, :, :])
            )
            * rdz,
            (0, 0, 0, 0, 3, 3),
        )
    return torch.nn.functional.pad(
        (
            1225 / 1024 * (a[..., 5:-3, :, :] - a[..., 4:-4, :, :])
            + -245 / 3072 * (a[..., 6:-2, :, :] - a[..., 3:-5, :, :])
            + 49 / 5120 * (a[..., 7:-1, :, :] - a[..., 2:-6, :, :])
            + -5 / 7168 * (a[..., 8:, :, :] - a[..., 1:-7, :, :])
        )
        * rdz,
        (0, 0, 0, 0, 4, 4),
    )


def diffxh1(a: torch.Tensor, stencil: int, rdx: torch.Tensor) -> torch.Tensor:
    """Calculates the first x derivative at half integer grid points."""
    if stencil == 2:
        return torch.nn.functional.pad((a[..., 2:] - a[..., 1:-1]) * rdx, (1, 1))
    if stencil == 4:
        return torch.nn.functional.pad(
            (
                9 / 8 * (a[..., 3:-1] - a[..., 2:-2])
                + -1 / 24 * (a[..., 4:] - a[..., 1:-3])
            )
            * rdx,
            (2, 2),
        )
    if stencil == 6:
        return torch.nn.functional.pad(
            (
                75 / 64 * (a[..., 4:-2] - a[..., 3:-3])
                + -25 / 384 * (a[..., 5:-1] - a[..., 2:-4])
                + 3 / 640 * (a[..., 6:] - a[..., 1:-5])
            )
            * rdx,
            (3, 3),
        )
    return torch.nn.functional.pad(
        (
            1225 / 1024 * (a[..., 5:-3] - a[..., 4:-4])
            + -245 / 3072 * (a[..., 6:-2] - a[..., 3:-5])
            + 49 / 5120 * (a[..., 7:-1] - a[..., 2:-6])
            + -5 / 7168 * (a[..., 8:] - a[..., 1:-7])
        )
        * rdx,
        (4, 4),
    )
tide/storage.py
ADDED
@@ -0,0 +1,131 @@
"""Storage helpers for wavefield snapshots.

This mirrors Deepwave's snapshot storage abstraction for use in the Maxwell
propagator. Stage 1 supports snapshot storage on device/CPU/disk.
"""

from __future__ import annotations

import contextlib
import os
import shutil
from pathlib import Path
from typing import Union
from uuid import uuid4

import torch

# Snapshot storage modes: prefer DEVICE, fall back to CPU or DISK; NONE
# disables snapshotting.
STORAGE_DEVICE = 0  # Keep snapshots on the accelerator (fastest, uses device memory)
STORAGE_CPU = 1  # Stage snapshots in host memory (slower, avoids GPU OOM)
STORAGE_DISK = 2  # Spill snapshots to disk (slowest, preserves host/GPU memory)
STORAGE_NONE = 3  # Do not store snapshots

# Number of ring buffers for CPU-stage ping-pong: allows overlapping reads/writes
# (write to one, read from another, keep one ready). MUST match csrc NUM_BUFFERS.
_CPU_STORAGE_BUFFERS = 3

def _normalize_storage_compression(storage_compression: Union[bool, str, None]) -> str:
    """Normalize the storage compression setting to a standard string.

    Args:
        storage_compression: The input storage compression setting, which can be
            a boolean, a string, or None.

    Returns:
        A normalized string representing the storage compression mode:
        - "none" for no compression
        - "bf16" for bfloat16 compression
        - "fp8" for float8 compression

    Raises:
        ValueError: If the input value is not recognized.
    """
    if storage_compression is True:
        return "bf16"
    if storage_compression is False or storage_compression is None:
        return "none"
    if isinstance(storage_compression, str):
        value = storage_compression.strip().lower()
        if value in {"none", "false", "off", "0"}:
            return "none"
        if value in {"bf16", "bfloat16"}:
            return "bf16"
        if value in {"fp8", "float8", "e4m3", "e4m3fn", "fp8_e4m3"}:
            return "fp8"
    raise ValueError(
        "storage_compression must be False/True or one of 'none', 'bf16', or 'fp8'."
    )

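The accepted spellings collapse onto three canonical modes, so callers may pass booleans or any of the listed aliases. For example:

    assert _normalize_storage_compression(True) == "bf16"
    assert _normalize_storage_compression(None) == "none"
    assert _normalize_storage_compression(" OFF ") == "none"  # stripped and lowercased
    assert _normalize_storage_compression("e4m3fn") == "fp8"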
def _resolve_storage_compression(
    storage_compression: Union[bool, str, None],
    dtype: torch.dtype,
    device: torch.device,
    *,
    context: str,
    allow_fp8: bool = True,
) -> tuple[str, torch.dtype, int]:
    """Resolve a compression setting to (kind, storage dtype, bytes per element)."""
    storage_kind = _normalize_storage_compression(storage_compression)
    if storage_kind == "none":
        return storage_kind, dtype, dtype.itemsize
    if storage_kind == "bf16":
        if dtype != torch.float32:
            raise NotImplementedError(
                f"{context} (BF16 storage) is only supported for float32."
            )
        return storage_kind, torch.bfloat16, 2
    if storage_kind == "fp8":
        if not allow_fp8:
            raise NotImplementedError(
                f"{context} (FP8 storage) is not supported in this path."
            )
        # FP8 now supported on both CUDA and CPU
        if dtype != torch.float32:
            raise NotImplementedError(
                f"{context} (FP8 storage) is only supported for float32."
            )
        return storage_kind, torch.uint8, 1
    raise RuntimeError(f"Unsupported storage compression mode: {storage_kind}")

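The resolved triple gives the canonical mode, the dtype snapshots are stored in, and the bytes per element, which callers can use to size buffers. For example (context is only a label used in error messages):

    import torch

    cpu = torch.device("cpu")
    res = _resolve_storage_compression(False, torch.float64, cpu, context="demo")
    assert res == ("none", torch.float64, 8)
    res = _resolve_storage_compression("bf16", torch.float32, cpu, context="demo")
    assert res == ("bf16", torch.bfloat16, 2)
    res = _resolve_storage_compression("fp8", torch.float32, cpu, context="demo")
    assert res == ("fp8", torch.uint8, 1)  # FP8 payloads are stored as raw bytes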
def storage_mode_to_int(storage_mode_str: str) -> int:
    """Convert a storage mode string to its integer constant."""
    mode = storage_mode_str.lower()
    if mode == "device":
        return STORAGE_DEVICE
    if mode == "cpu":
        return STORAGE_CPU
    if mode == "disk":
        return STORAGE_DISK
    if mode == "none":
        return STORAGE_NONE
    raise ValueError(
        "storage_mode must be 'device', 'cpu', 'disk', 'none', or 'auto', "
        f"but got {storage_mode_str!r}"
    )

class TemporaryStorage:
    """Manages temporary files for disk storage.

    Creates a unique subdirectory for each instantiation to prevent collisions.
    """

    def __init__(self, base_path: str, num_files: int) -> None:
        self.base_dir = Path(base_path) / f"tide_tmp_{os.getpid()}_{uuid4().hex}"
        self.base_dir.mkdir(parents=True, exist_ok=True)
        self.filenames: list[str] = [
            str(self.base_dir / f"shot_{i}.bin") for i in range(num_files)
        ]

    def get_filenames(self) -> list[str]:
        return self.filenames

    def close(self) -> None:
        if self.base_dir.exists():
            with contextlib.suppress(OSError):
                shutil.rmtree(self.base_dir)

    def __del__(self) -> None:
        self.close()
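A typical lifecycle, sketched: create the storage, hand the filenames to whatever streams snapshots to disk, then remove everything. close() also runs from __del__, but calling it explicitly avoids relying on garbage collection:

    import tempfile

    storage = TemporaryStorage(tempfile.gettempdir(), num_files=4)
    try:
        paths = storage.get_filenames()  # .../tide_tmp_<pid>_<hex>/shot_0.bin, ...
        # write one snapshot stream per shot into these paths
    finally:
        storage.close()  # removes the unique subdirectory and its contents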
tide/tide/libtide_C.so
ADDED
Binary file