tide-GPR 0.0.9__py3-none-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
tide/staggered.py ADDED
@@ -0,0 +1,567 @@
+ import torch
+
+ from . import utils
+
+
+ def set_pml_profiles(
+     pml_width: list[int],
+     accuracy: int,
+     fd_pad: list[int],
+     dt: float,
+     grid_spacing: list[float],
+     max_vel: float,
+     dtype: torch.dtype,
+     device: torch.device,
+     pml_freq: float,
+     ny: int,
+     nx: int,
+ ) -> list[torch.Tensor]:
+     """Sets up PML profiles for a staggered grid.
+
+     Args:
+         pml_width: A list of integers specifying the width of the PML
+             on each side (top, bottom, left, right).
+         accuracy: The finite-difference accuracy order.
+         fd_pad: A list of integers specifying the finite-difference padding
+             on each side.
+         dt: The time step.
+         grid_spacing: A list of floats specifying the grid spacing in the
+             y and x directions.
+         max_vel: The maximum velocity in the model.
+         dtype: The data type of the tensors (e.g., torch.float32).
+         device: The device on which the tensors will be created
+             (e.g., 'cuda', 'cpu').
+         pml_freq: The PML frequency.
+         ny: The number of grid points in the y direction.
+         nx: The number of grid points in the x direction.
+
+     Returns:
+         A list of 12 tensors:
+             - a and b profiles: [ay, ayh, ax, axh, by, byh, bx, bxh]
+             - k profiles: [ky, kyh, kx, kxh]
+
+     """
+     pml_start: list[float] = [
+         fd_pad[0] + pml_width[0],
+         ny - 1 - fd_pad[1] - pml_width[1],
+         fd_pad[2] + pml_width[2],
+         nx - 1 - fd_pad[3] - pml_width[3],
+     ]
+     max_pml = max(
+         [
+             pml_width[0] * grid_spacing[0],
+             pml_width[1] * grid_spacing[0],
+             pml_width[2] * grid_spacing[1],
+             pml_width[3] * grid_spacing[1],
+         ],
+     )
+
+     # Integer grid PML profiles
+     ay, by, ky = utils.setup_pml(
+         pml_width[:2],
+         pml_start[:2],
+         max_pml,
+         dt,
+         ny,
+         max_vel,
+         dtype,
+         device,
+         pml_freq,
+         start=0.0,
+         grid_spacing=grid_spacing[0],
+     )
+     ax, bx, kx = utils.setup_pml(
+         pml_width[2:],
+         pml_start[2:],
+         max_pml,
+         dt,
+         nx,
+         max_vel,
+         dtype,
+         device,
+         pml_freq,
+         start=0.0,
+         grid_spacing=grid_spacing[1],
+     )
+
+     # Half grid PML profiles
+     ayh, byh, kyh = utils.setup_pml_half(
+         pml_width[:2],
+         pml_start[:2],
+         max_pml,
+         dt,
+         ny,
+         max_vel,
+         dtype,
+         device,
+         pml_freq,
+         start=0.0,
+         grid_spacing=grid_spacing[0],
+     )
+     axh, bxh, kxh = utils.setup_pml_half(
+         pml_width[2:],
+         pml_start[2:],
+         max_pml,
+         dt,
+         nx,
+         max_vel,
+         dtype,
+         device,
+         pml_freq,
+         start=0.0,
+         grid_spacing=grid_spacing[1],
+     )
+
+     # Reshape for broadcasting: [batch, ny, nx]
+     ay = ay[None, :, None]
+     ayh = ayh[None, :, None]
+     ax = ax[None, None, :]
+     axh = axh[None, None, :]
+     by = by[None, :, None]
+     byh = byh[None, :, None]
+     bx = bx[None, None, :]
+     bxh = bxh[None, None, :]
+
+     ky = ky[None, :, None]
+     kyh = kyh[None, :, None]
+     kx = kx[None, None, :]
+     kxh = kxh[None, None, :]
+
+     return [ay, ayh, ax, axh, by, byh, bx, bxh, ky, kyh, kx, kxh]
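A minimal usage sketch of set_pml_profiles (the tide.staggered import path and the grid sizes below are assumptions for illustration; the call itself relies on the package's utils.setup_pml helpers):

    import torch
    from tide.staggered import set_pml_profiles

    profiles = set_pml_profiles(
        pml_width=[20, 20, 20, 20],
        accuracy=4,
        fd_pad=[2, 2, 2, 2],
        dt=0.001,
        grid_spacing=[5.0, 5.0],
        max_vel=1500.0,
        dtype=torch.float32,
        device=torch.device("cpu"),
        pml_freq=25.0,
        ny=200,
        nx=300,
    )
    ay, ayh, ax, axh, by, byh, bx, bxh, ky, kyh, kx, kxh = profiles
    # y profiles broadcast along rows, x profiles along columns of [batch, ny, nx]
    assert ay.shape == (1, 200, 1) and ax.shape == (1, 1, 300)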
+
+
+ def setup_pml_profiles_1d(
+     n: int,
+     pml_width0: int,
+     pml_width1: int,
+     sigma_max: float,
+     dt: float,
+     device: torch.device,
+     dtype: torch.dtype,
+ ) -> tuple[
+     torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor
+ ]:
+     """Create 1D CPML profiles (a, b, k) for integer and half-grid points."""
+     eps = 1e-9
+     n_power = 2
+
+     if pml_width0 == 0 and pml_width1 == 0:
+         zeros = torch.zeros(n, device=device, dtype=dtype)
+         ones = torch.ones(n, device=device, dtype=dtype)
+         return zeros, zeros, zeros, zeros, ones, ones
+
+     def _profiles(start: float) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+         x = torch.arange(n, device=device, dtype=dtype) + start
+         left_start = float(pml_width0)
+         right_start = float(n - 1 - pml_width1)
+
+         if pml_width0 == 0:
+             frac_left = torch.zeros_like(x)
+         else:
+             frac_left = (left_start - x) / float(pml_width0)
+         if pml_width1 == 0:
+             frac_right = torch.zeros_like(x)
+         else:
+             frac_right = (x - right_start) / float(pml_width1)
+
+         pml_frac = torch.clamp(torch.maximum(frac_left, frac_right), 0.0, 1.0)
+         sigma = sigma_max * pml_frac.pow(n_power)
+         kappa = torch.ones_like(sigma)
+
+         sigma_alpha = sigma
+         b = torch.exp(-sigma_alpha * abs(dt))
+         denom = sigma_alpha + eps
+         a = torch.where(
+             sigma_alpha > 0.0, sigma * (b - 1.0) / denom, torch.zeros_like(b)
+         )
+         return a, b, kappa
+
+     ay, by, ky = _profiles(0.0)
+     ayh, byh, kyh = _profiles(0.5)
+     return ay, ayh, by, byh, ky, kyh
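setup_pml_profiles_1d depends only on torch, so its endpoints are easy to check against the formulas above (a sketch, assuming the tide.staggered import path):

    import torch
    from tide.staggered import setup_pml_profiles_1d

    a, ah, b, bh, k, kh = setup_pml_profiles_1d(
        n=100, pml_width0=10, pml_width1=10, sigma_max=500.0,
        dt=0.001, device=torch.device("cpu"), dtype=torch.float32,
    )
    assert b[50] == 1.0 and a[50] == 0.0  # interior: sigma = 0
    assert torch.isclose(b[0], torch.exp(torch.tensor(-0.5)))  # edge: exp(-sigma_max * dt)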
+
+
+ def set_pml_profiles_3d(
+     pml_width: list[int],
+     accuracy: int,
+     fd_pad: list[int],
+     dt: float,
+     grid_spacing: list[float],
+     max_vel: float,
+     dtype: torch.dtype,
+     device: torch.device,
+     pml_freq: float,
+     nz: int,
+     ny: int,
+     nx: int,
+ ) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
+     """Sets up 3D PML profiles for a staggered grid.
+
+     Args:
+         pml_width: Widths [z0, z1, y0, y1, x0, x1].
+         accuracy: Finite-difference accuracy order (unused, kept for API parity).
+         fd_pad: FD padding [z0, z1, y0, y1, x0, x1].
+         dt: Time step.
+         grid_spacing: Grid spacing [dz, dy, dx].
+         max_vel: Maximum velocity (unused in the EM formulation, kept for API
+             parity).
+         dtype: Tensor dtype.
+         device: Tensor device.
+         pml_freq: PML frequency.
+         nz, ny, nx: Padded grid sizes.
+
+     Returns:
+         - PML a/b profiles: [az, azh, ay, ayh, ax, axh, bz, bzh, by, byh, bx, bxh]
+         - PML kappa profiles: [kz, kzh, ky, kyh, kx, kxh]
+     """
+     _ = accuracy
+     dz, dy, dx = grid_spacing
+
+     pml_start: list[float] = [
+         fd_pad[0] + pml_width[0],
+         nz - 1 - fd_pad[1] - pml_width[1],
+         fd_pad[2] + pml_width[2],
+         ny - 1 - fd_pad[3] - pml_width[3],
+         fd_pad[4] + pml_width[4],
+         nx - 1 - fd_pad[5] - pml_width[5],
+     ]
+
+     max_pml = max(
+         [
+             pml_width[0] * dz,
+             pml_width[1] * dz,
+             pml_width[2] * dy,
+             pml_width[3] * dy,
+             pml_width[4] * dx,
+             pml_width[5] * dx,
+         ]
+     )
+
+     az, bz, kz = utils.setup_pml(
+         pml_width[:2],
+         pml_start[:2],
+         max_pml,
+         dt,
+         nz,
+         max_vel,
+         dtype,
+         device,
+         pml_freq,
+         start=0.0,
+         grid_spacing=dz,
+     )
+     ay, by, ky = utils.setup_pml(
+         pml_width[2:4],
+         pml_start[2:4],
+         max_pml,
+         dt,
+         ny,
+         max_vel,
+         dtype,
+         device,
+         pml_freq,
+         start=0.0,
+         grid_spacing=dy,
+     )
+     ax, bx, kx = utils.setup_pml(
+         pml_width[4:],
+         pml_start[4:],
+         max_pml,
+         dt,
+         nx,
+         max_vel,
+         dtype,
+         device,
+         pml_freq,
+         start=0.0,
+         grid_spacing=dx,
+     )
+
+     azh, bzh, kzh = utils.setup_pml_half(
+         pml_width[:2],
+         pml_start[:2],
+         max_pml,
+         dt,
+         nz,
+         max_vel,
+         dtype,
+         device,
+         pml_freq,
+         start=0.0,
+         grid_spacing=dz,
+     )
+     ayh, byh, kyh = utils.setup_pml_half(
+         pml_width[2:4],
+         pml_start[2:4],
+         max_pml,
+         dt,
+         ny,
+         max_vel,
+         dtype,
+         device,
+         pml_freq,
+         start=0.0,
+         grid_spacing=dy,
+     )
+     axh, bxh, kxh = utils.setup_pml_half(
+         pml_width[4:],
+         pml_start[4:],
+         max_pml,
+         dt,
+         nx,
+         max_vel,
+         dtype,
+         device,
+         pml_freq,
+         start=0.0,
+         grid_spacing=dx,
+     )
+
+     az = az[None, :, None, None]
+     azh = azh[None, :, None, None]
+     bz = bz[None, :, None, None]
+     bzh = bzh[None, :, None, None]
+     kz = kz[None, :, None, None]
+     kzh = kzh[None, :, None, None]
+
+     ay = ay[None, None, :, None]
+     ayh = ayh[None, None, :, None]
+     by = by[None, None, :, None]
+     byh = byh[None, None, :, None]
+     ky = ky[None, None, :, None]
+     kyh = kyh[None, None, :, None]
+
+     ax = ax[None, None, None, :]
+     axh = axh[None, None, None, :]
+     bx = bx[None, None, None, :]
+     bxh = bxh[None, None, None, :]
+     kx = kx[None, None, None, :]
+     kxh = kxh[None, None, None, :]
+
+     return (
+         [az, azh, ay, ayh, ax, axh, bz, bzh, by, byh, bx, bxh],
+         [kz, kzh, ky, kyh, kx, kxh],
+     )
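The 3D variant returns two lists; a sketch of the unpacking and broadcast shapes (sizes here are illustrative):

    ab, kappa = set_pml_profiles_3d(
        pml_width=[10] * 6, accuracy=4, fd_pad=[2] * 6, dt=0.001,
        grid_spacing=[5.0, 5.0, 5.0], max_vel=1500.0, dtype=torch.float32,
        device=torch.device("cpu"), pml_freq=25.0, nz=64, ny=64, nx=64,
    )
    az, azh, ay, ayh, ax, axh, bz, bzh, by, byh, bx, bxh = ab
    kz, kzh, ky, kyh, kx, kxh = kappa
    assert az.shape == (1, 64, 1, 1)  # broadcasts over [batch, nz, ny, nx]
    assert ay.shape == (1, 1, 64, 1) and ax.shape == (1, 1, 1, 64)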
+
+
+ def diffy1(a: torch.Tensor, stencil: int, rdy: torch.Tensor) -> torch.Tensor:
+     """Calculates the first y derivative at integer grid points."""
+     if stencil == 2:
+         return torch.nn.functional.pad(
+             (a[..., 1:, :] - a[..., :-1, :]) * rdy, (0, 0, 1, 0)
+         )
+     if stencil == 4:
+         return torch.nn.functional.pad(
+             (
+                 9 / 8 * (a[..., 2:-1, :] - a[..., 1:-2, :])
+                 + -1 / 24 * (a[..., 3:, :] - a[..., :-3, :])
+             )
+             * rdy,
+             (0, 0, 2, 1),
+         )
+     if stencil == 6:
+         return torch.nn.functional.pad(
+             (
+                 75 / 64 * (a[..., 3:-2, :] - a[..., 2:-3, :])
+                 + -25 / 384 * (a[..., 4:-1, :] - a[..., 1:-4, :])
+                 + 3 / 640 * (a[..., 5:, :] - a[..., :-5, :])
+             )
+             * rdy,
+             (0, 0, 3, 2),
+         )
+     # Default: eighth-order stencil
+     return torch.nn.functional.pad(
+         (
+             1225 / 1024 * (a[..., 4:-3, :] - a[..., 3:-4, :])
+             + -245 / 3072 * (a[..., 5:-2, :] - a[..., 2:-5, :])
+             + 49 / 5120 * (a[..., 6:-1, :] - a[..., 1:-6, :])
+             + -5 / 7168 * (a[..., 7:, :] - a[..., :-7, :])
+         )
+         * rdy,
+         (0, 0, 4, 3),
+     )
+
+
+ def diffx1(a: torch.Tensor, stencil: int, rdx: torch.Tensor) -> torch.Tensor:
+     """Calculates the first x derivative at integer grid points."""
+     if stencil == 2:
+         return torch.nn.functional.pad((a[..., 1:] - a[..., :-1]) * rdx, (1, 0))
+     if stencil == 4:
+         return torch.nn.functional.pad(
+             (
+                 9 / 8 * (a[..., 2:-1] - a[..., 1:-2])
+                 + -1 / 24 * (a[..., 3:] - a[..., :-3])
+             )
+             * rdx,
+             (2, 1),
+         )
+     if stencil == 6:
+         return torch.nn.functional.pad(
+             (
+                 75 / 64 * (a[..., 3:-2] - a[..., 2:-3])
+                 + -25 / 384 * (a[..., 4:-1] - a[..., 1:-4])
+                 + 3 / 640 * (a[..., 5:] - a[..., :-5])
+             )
+             * rdx,
+             (3, 2),
+         )
+     # Default: eighth-order stencil
+     return torch.nn.functional.pad(
+         (
+             1225 / 1024 * (a[..., 4:-3] - a[..., 3:-4])
+             + -245 / 3072 * (a[..., 5:-2] - a[..., 2:-5])
+             + 49 / 5120 * (a[..., 6:-1] - a[..., 1:-6])
+             + -5 / 7168 * (a[..., 7:] - a[..., :-7])
+         )
+         * rdx,
+         (4, 3),
+     )
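A quick sanity check on the integer-grid stencils: the derivative of a linear ramp is recovered exactly away from the zero-padded edges, since the fourth-order weights satisfy 9/8 - 3/24 = 1 (a sketch, assuming the tide.staggered import path):

    import torch
    from tide.staggered import diffx1

    dx = 0.5
    a = (torch.arange(16, dtype=torch.float32) * dx).repeat(4, 1)  # [4, 16], linear in x
    d = diffx1(a, stencil=4, rdx=torch.tensor(1.0 / dx))
    assert torch.allclose(d[:, 3:-3], torch.ones(4, 10))  # interior slope == 1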
+
+
+ def diffz1(a: torch.Tensor, stencil: int, rdz: torch.Tensor) -> torch.Tensor:
+     """Calculates the first z derivative at integer grid points."""
+     if stencil == 2:
+         return torch.nn.functional.pad(
+             (a[..., 1:, :, :] - a[..., :-1, :, :]) * rdz,
+             (0, 0, 0, 0, 1, 0),
+         )
+     if stencil == 4:
+         return torch.nn.functional.pad(
+             (
+                 9 / 8 * (a[..., 2:-1, :, :] - a[..., 1:-2, :, :])
+                 + -1 / 24 * (a[..., 3:, :, :] - a[..., :-3, :, :])
+             )
+             * rdz,
+             (0, 0, 0, 0, 2, 1),
+         )
+     if stencil == 6:
+         return torch.nn.functional.pad(
+             (
+                 75 / 64 * (a[..., 3:-2, :, :] - a[..., 2:-3, :, :])
+                 + -25 / 384 * (a[..., 4:-1, :, :] - a[..., 1:-4, :, :])
+                 + 3 / 640 * (a[..., 5:, :, :] - a[..., :-5, :, :])
+             )
+             * rdz,
+             (0, 0, 0, 0, 3, 2),
+         )
+     # Default: eighth-order stencil
+     return torch.nn.functional.pad(
+         (
+             1225 / 1024 * (a[..., 4:-3, :, :] - a[..., 3:-4, :, :])
+             + -245 / 3072 * (a[..., 5:-2, :, :] - a[..., 2:-5, :, :])
+             + 49 / 5120 * (a[..., 6:-1, :, :] - a[..., 1:-6, :, :])
+             + -5 / 7168 * (a[..., 7:, :, :] - a[..., :-7, :, :])
+         )
+         * rdz,
+         (0, 0, 0, 0, 4, 3),
+     )
+
+
+ def diffyh1(a: torch.Tensor, stencil: int, rdy: torch.Tensor) -> torch.Tensor:
+     """Calculates the first y derivative at half-integer grid points."""
+     if stencil == 2:
+         return torch.nn.functional.pad(
+             (a[..., 2:, :] - a[..., 1:-1, :]) * rdy, (0, 0, 1, 1)
+         )
+     if stencil == 4:
+         return torch.nn.functional.pad(
+             (
+                 9 / 8 * (a[..., 3:-1, :] - a[..., 2:-2, :])
+                 + -1 / 24 * (a[..., 4:, :] - a[..., 1:-3, :])
+             )
+             * rdy,
+             (0, 0, 2, 2),
+         )
+     if stencil == 6:
+         return torch.nn.functional.pad(
+             (
+                 75 / 64 * (a[..., 4:-2, :] - a[..., 3:-3, :])
+                 + -25 / 384 * (a[..., 5:-1, :] - a[..., 2:-4, :])
+                 + 3 / 640 * (a[..., 6:, :] - a[..., 1:-5, :])
+             )
+             * rdy,
+             (0, 0, 3, 3),
+         )
+     # Default: eighth-order stencil
+     return torch.nn.functional.pad(
+         (
+             1225 / 1024 * (a[..., 5:-3, :] - a[..., 4:-4, :])
+             + -245 / 3072 * (a[..., 6:-2, :] - a[..., 3:-5, :])
+             + 49 / 5120 * (a[..., 7:-1, :] - a[..., 2:-6, :])
+             + -5 / 7168 * (a[..., 8:, :] - a[..., 1:-7, :])
+         )
+         * rdy,
+         (0, 0, 4, 4),
+     )
+
+
+ def diffzh1(a: torch.Tensor, stencil: int, rdz: torch.Tensor) -> torch.Tensor:
+     """Calculates the first z derivative at half-integer grid points.
+
+     For a tensor with shape [..., nz, ny, nx], the derivative is taken along
+     the z dimension at half-grid locations.
+     """
+     if stencil == 2:
+         return torch.nn.functional.pad(
+             (a[..., 2:, :, :] - a[..., 1:-1, :, :]) * rdz, (0, 0, 0, 0, 1, 1)
+         )
+     if stencil == 4:
+         return torch.nn.functional.pad(
+             (
+                 9 / 8 * (a[..., 3:-1, :, :] - a[..., 2:-2, :, :])
+                 + -1 / 24 * (a[..., 4:, :, :] - a[..., 1:-3, :, :])
+             )
+             * rdz,
+             (0, 0, 0, 0, 2, 2),
+         )
+     if stencil == 6:
+         return torch.nn.functional.pad(
+             (
+                 75 / 64 * (a[..., 4:-2, :, :] - a[..., 3:-3, :, :])
+                 + -25 / 384 * (a[..., 5:-1, :, :] - a[..., 2:-4, :, :])
+                 + 3 / 640 * (a[..., 6:, :, :] - a[..., 1:-5, :, :])
+             )
+             * rdz,
+             (0, 0, 0, 0, 3, 3),
+         )
+     # Default: eighth-order stencil
+     return torch.nn.functional.pad(
+         (
+             1225 / 1024 * (a[..., 5:-3, :, :] - a[..., 4:-4, :, :])
+             + -245 / 3072 * (a[..., 6:-2, :, :] - a[..., 3:-5, :, :])
+             + 49 / 5120 * (a[..., 7:-1, :, :] - a[..., 2:-6, :, :])
+             + -5 / 7168 * (a[..., 8:, :, :] - a[..., 1:-7, :, :])
+         )
+         * rdz,
+         (0, 0, 0, 0, 4, 4),
+     )
+
+
+ def diffxh1(a: torch.Tensor, stencil: int, rdx: torch.Tensor) -> torch.Tensor:
+     """Calculates the first x derivative at half-integer grid points."""
+     if stencil == 2:
+         return torch.nn.functional.pad((a[..., 2:] - a[..., 1:-1]) * rdx, (1, 1))
+     if stencil == 4:
+         return torch.nn.functional.pad(
+             (
+                 9 / 8 * (a[..., 3:-1] - a[..., 2:-2])
+                 + -1 / 24 * (a[..., 4:] - a[..., 1:-3])
+             )
+             * rdx,
+             (2, 2),
+         )
+     if stencil == 6:
+         return torch.nn.functional.pad(
+             (
+                 75 / 64 * (a[..., 4:-2] - a[..., 3:-3])
+                 + -25 / 384 * (a[..., 5:-1] - a[..., 2:-4])
+                 + 3 / 640 * (a[..., 6:] - a[..., 1:-5])
+             )
+             * rdx,
+             (3, 3),
+         )
+     # Default: eighth-order stencil
+     return torch.nn.functional.pad(
+         (
+             1225 / 1024 * (a[..., 5:-3] - a[..., 4:-4])
+             + -245 / 3072 * (a[..., 6:-2] - a[..., 3:-5])
+             + 49 / 5120 * (a[..., 7:-1] - a[..., 2:-6])
+             + -5 / 7168 * (a[..., 8:] - a[..., 1:-7])
+         )
+         * rdx,
+         (4, 4),
+     )
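These profiles and staggered derivatives are the ingredients of the standard CPML memory-variable recursion; a hedged sketch of how they are typically combined (names are illustrative, not taken verbatim from this package):

    def apply_cpml_x(du_dx, psi_x, ax, bx, kx):
        # Update the memory variable, then form the PML-corrected derivative.
        psi_x = bx * psi_x + ax * du_dx
        return du_dx / kx + psi_x, psi_x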
tide/storage.py ADDED
@@ -0,0 +1,131 @@
+ """Storage helpers for wavefield snapshots.
+
+ This mirrors Deepwave's snapshot storage abstraction for use in the Maxwell
+ propagator. Stage 1 supports snapshot storage on device/CPU/disk.
+ """
+
+ from __future__ import annotations
+
+ import contextlib
+ import os
+ import shutil
+ from pathlib import Path
+ from typing import Union
+ from uuid import uuid4
+
+ import torch
+
+ # Snapshot storage modes: prefer DEVICE, fall back to CPU or DISK; NONE
+ # disables snapshotting.
+ STORAGE_DEVICE = 0  # Keep snapshots on the accelerator (fastest, uses device memory)
+ STORAGE_CPU = 1  # Stage snapshots in host memory (slower, avoids GPU OOM)
+ STORAGE_DISK = 2  # Spill snapshots to disk (slowest, preserves host/GPU memory)
+ STORAGE_NONE = 3  # Do not store snapshots
+
+ # Number of ring buffers for CPU-stage ping-pong: allows overlapping reads and
+ # writes (write to one, read from another, keep one ready). MUST match the
+ # csrc NUM_BUFFERS.
+ _CPU_STORAGE_BUFFERS = 3
+
+
+ def _normalize_storage_compression(storage_compression: Union[bool, str, None]) -> str:
+     """Normalize the storage compression setting to a standard string.
+
+     Args:
+         storage_compression: The input storage compression setting, which can
+             be a boolean, a string, or None.
+
+     Returns:
+         A normalized string representing the storage compression mode:
+             - "none" for no compression
+             - "bf16" for bfloat16 compression
+             - "fp8" for float8 compression
+
+     Raises:
+         ValueError: If the input value is not recognized.
+     """
+     if storage_compression is True:
+         return "bf16"
+     if storage_compression is False or storage_compression is None:
+         return "none"
+     if isinstance(storage_compression, str):
+         value = storage_compression.strip().lower()
+         if value in {"none", "false", "off", "0"}:
+             return "none"
+         if value in {"bf16", "bfloat16"}:
+             return "bf16"
+         if value in {"fp8", "float8", "e4m3", "e4m3fn", "fp8_e4m3"}:
+             return "fp8"
+     raise ValueError(
+         "storage_compression must be False/True or one of 'none', 'bf16', or 'fp8'."
+     )
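All accepted spellings collapse to one of three strings, which follows directly from the branches above:

    assert _normalize_storage_compression(True) == "bf16"
    assert _normalize_storage_compression(None) == "none"
    assert _normalize_storage_compression("bfloat16") == "bf16"
    assert _normalize_storage_compression("E4M3FN") == "fp8"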
+
+
+ def _resolve_storage_compression(
+     storage_compression: Union[bool, str, None],
+     dtype: torch.dtype,
+     device: torch.device,
+     *,
+     context: str,
+     allow_fp8: bool = True,
+ ) -> tuple[str, torch.dtype, int]:
+     """Resolve a compression setting to (kind, storage dtype, bytes per element)."""
+     storage_kind = _normalize_storage_compression(storage_compression)
+     if storage_kind == "none":
+         return storage_kind, dtype, dtype.itemsize
+     if storage_kind == "bf16":
+         if dtype != torch.float32:
+             raise NotImplementedError(
+                 f"{context} (BF16 storage) is only supported for float32."
+             )
+         return storage_kind, torch.bfloat16, 2
+     if storage_kind == "fp8":
+         if not allow_fp8:
+             raise NotImplementedError(
+                 f"{context} (FP8 storage) is not supported in this path."
+             )
+         # FP8 is supported on both CUDA and CPU
+         if dtype != torch.float32:
+             raise NotImplementedError(
+                 f"{context} (FP8 storage) is only supported for float32."
+             )
+         return storage_kind, torch.uint8, 1
+     raise RuntimeError(f"Unsupported storage compression mode: {storage_kind}")
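For example, BF16 compression of a float32 wavefield halves the bytes per element:

    import torch

    kind, sdtype, nbytes = _resolve_storage_compression(
        True, torch.float32, torch.device("cpu"), context="Snapshot storage"
    )
    assert (kind, sdtype, nbytes) == ("bf16", torch.bfloat16, 2)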
+
+
+ def storage_mode_to_int(storage_mode_str: str) -> int:
+     """Map a storage-mode string to its integer constant."""
+     mode = storage_mode_str.lower()
+     if mode == "device":
+         return STORAGE_DEVICE
+     if mode == "cpu":
+         return STORAGE_CPU
+     if mode == "disk":
+         return STORAGE_DISK
+     if mode == "none":
+         return STORAGE_NONE
+     raise ValueError(
+         "storage_mode must be 'device', 'cpu', 'disk', 'none', or 'auto', "
+         f"but got {storage_mode_str!r}"
+     )
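Matching is case-insensitive; note that the error message also names 'auto', which is presumably resolved to a concrete mode by the caller before reaching this function:

    assert storage_mode_to_int("device") == STORAGE_DEVICE
    assert storage_mode_to_int("DISK") == STORAGE_DISK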
+
+
+ class TemporaryStorage:
+     """Manages temporary files for disk storage.
+
+     Creates a unique subdirectory for each instantiation to prevent collisions.
+     """
+
+     def __init__(self, base_path: str, num_files: int) -> None:
+         self.base_dir = Path(base_path) / f"tide_tmp_{os.getpid()}_{uuid4().hex}"
+         self.base_dir.mkdir(parents=True, exist_ok=True)
+         self.filenames: list[str] = [
+             str(self.base_dir / f"shot_{i}.bin") for i in range(num_files)
+         ]
+
+     def get_filenames(self) -> list[str]:
+         return self.filenames
+
+     def close(self) -> None:
+         if self.base_dir.exists():
+             with contextlib.suppress(OSError):
+                 shutil.rmtree(self.base_dir)
+
+     def __del__(self) -> None:
+         self.close()
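A usage sketch (the base path is illustrative; close() is also called from __del__, but invoking it explicitly is safer than relying on garbage collection):

    import torch
    from tide.storage import TemporaryStorage  # assumed import path

    storage = TemporaryStorage(base_path="/tmp", num_files=2)
    try:
        torch.randn(8, 8).numpy().tofile(storage.get_filenames()[0])
    finally:
        storage.close()  # removes the unique per-instance subdirectory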
tide/tide/libtide_C.so ADDED
Binary file