tide-GPR 0.0.9__py3-none-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
tide/padding.py ADDED
@@ -0,0 +1,139 @@
+ from typing import Union
+
+ import torch
+
+
+ def reverse_pad(pad: list[int]) -> list[int]:
+     """Reverse the padding order for use with torch.nn.functional.pad.
+
+     PyTorch's pad function expects padding in reverse order (last dim first).
+     This function converts [y0, y1, x0, x1] to [x0, x1, y0, y1].
+
+     Args:
+         pad: Padding values in [y0, y1, x0, x1] format.
+
+     Returns:
+         Padding values in PyTorch format [x0, x1, y0, y1].
+     """
+     # For 2D: [y0, y1, x0, x1] -> [x0, x1, y0, y1]
+     result = []
+     for i in range(len(pad) // 2 - 1, -1, -1):
+         result.extend([pad[i * 2], pad[i * 2 + 1]])
+     return result
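
To sanity-check the reordering, a minimal sketch (importing from `tide.padding`, matching the file path in this diff):

```python
from tide.padding import reverse_pad

# 2D: [y0, y1, x0, x1] -> [x0, x1, y0, y1]
assert reverse_pad([1, 2, 3, 4]) == [3, 4, 1, 2]

# 3D: [z0, z1, y0, y1, x0, x1] -> [x0, x1, y0, y1, z0, z1]
assert reverse_pad([1, 2, 3, 4, 5, 6]) == [5, 6, 3, 4, 1, 2]
```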
+
+
+ def create_or_pad(
+     tensor: torch.Tensor,
+     pad: Union[int, list[int]],
+     device: torch.device,
+     dtype: torch.dtype,
+     size: tuple[int, ...],
+     mode: str = "constant",
+ ) -> torch.Tensor:
+     """Creates a zero tensor of the specified size or pads an existing tensor.
+
+     If the input tensor is empty (numel == 0), a new zero tensor with the
+     given size is created. Otherwise, the tensor is padded according to
+     the specified mode.
+
+     This is a unified padding function that supports:
+     - Zero padding (mode='constant') for wavefields
+     - Replicate padding (mode='replicate') for models
+
+     Args:
+         tensor: The input tensor to be created or padded.
+         pad: The padding to apply. Can be an integer (for uniform padding)
+             or a list of integers [y0, y1, x0, x1] for per-side padding.
+         device: The PyTorch device for the tensor.
+         dtype: The PyTorch data type for the tensor.
+         size: The desired size of the tensor if it needs to be created.
+         mode: Padding mode ('constant', 'replicate', 'reflect', 'circular').
+             Default is 'constant' (zero padding).
+
+     Returns:
+         The created or padded tensor.
+
+     Example:
+         >>> # Create a zero tensor of size [2, 110, 110] (batch=2, with padding)
+         >>> wf = create_or_pad(torch.empty(0), 5, device, dtype, (2, 110, 110))
+         >>>
+         >>> # Pad a wavefield with zeros [2, 100, 100] -> [2, 110, 110]
+         >>> wf_padded = create_or_pad(wf, [5, 5, 5, 5], device, dtype, (2, 110, 110))
+         >>>
+         >>> # Pad a model with replicate mode [100, 100] -> [110, 110]
+         >>> eps_padded = create_or_pad(eps, [5, 5, 5, 5], device, dtype, (110, 110), mode='replicate')
+     """
+     if isinstance(pad, int):
+         # Convert a single int to [pad, pad, ...] with one pair per spatial dimension.
+         # size includes a batch dimension when len(size) > 2, so subtract it from ndim.
+         ndim = len(size) - 1 if len(size) > 2 else len(size)
+         pad = [pad] * ndim * 2
+
+     if tensor.numel() == 0:
+         return torch.zeros(size, device=device, dtype=dtype)
+
+     if max(pad) == 0:
+         return tensor.clone()
+
+     # Reverse padding for PyTorch's pad function
+     reversed_pad = reverse_pad(pad)
+
+     # For non-constant padding modes (replicate, reflect, circular),
+     # PyTorch requires:
+     # - 2D spatial padding: 3D or 4D input
+     # - 3D spatial padding: 4D or 5D input
+     original_ndim = tensor.ndim
+     needs_unsqueeze = original_ndim in {2, 3} and mode != "constant"
+
+     if needs_unsqueeze:
+         tensor = tensor.unsqueeze(0)
+
+     result = torch.nn.functional.pad(tensor, reversed_pad, mode=mode)
+
+     if needs_unsqueeze:
+         result = result.squeeze(0)
+
+     # PyTorch's autograd system automatically tracks gradients through operations.
+     # Explicitly calling requires_grad_() is incompatible with torch.func transforms.
+     # Simply return the result; gradient tracking is handled automatically.
+     return result
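
The docstring example above is schematic; the following runnable sketch (with assumed concrete device/dtype values) exercises all three code paths:

```python
import torch
from tide.padding import create_or_pad

device, dtype = torch.device("cpu"), torch.float32

# Empty input: a fresh zero tensor of the requested size is created
wf = create_or_pad(torch.empty(0), 5, device, dtype, (2, 110, 110))
assert wf.shape == (2, 110, 110) and torch.count_nonzero(wf) == 0

# Non-empty input: zero-pad a wavefield [2, 100, 100] -> [2, 110, 110]
wf = create_or_pad(torch.randn(2, 100, 100), [5, 5, 5, 5], device, dtype, (2, 110, 110))
assert wf.shape == (2, 110, 110)

# 2D model with replicate padding: edge values are extended outward
eps = create_or_pad(torch.randn(100, 100), [5, 5, 5, 5], device, dtype, (110, 110), mode="replicate")
assert eps.shape == (110, 110)
```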
+
+
+ def zero_interior(
+     tensor: torch.Tensor,
+     fd_pad: Union[int, list[int]],
+     pml_width: list[int],
+     dim: int,
+ ) -> torch.Tensor:
+     """Zero out the interior region of a tensor (keeping only PML regions).
+
+     This is used for CPML auxiliary variables, which should only be non-zero
+     in the PML regions. Setting the interior to zero allows the propagator
+     to skip unnecessary PML calculations in those regions.
+
+     Args:
+         tensor: The input tensor with shape [batch, ny, nx].
+         fd_pad: Finite difference padding. Can be an int or list [y0, y1, x0, x1].
+         pml_width: The width of PML regions [top, bottom, left, right].
+         dim: The spatial dimension to zero (0 for y, 1 for x).
+
+     Returns:
+         The tensor with its interior region zeroed out (modified in place).
+     """
+     shape = tensor.shape[1:]  # Spatial dimensions (without batch)
+     ndim = len(shape)
+
+     if isinstance(fd_pad, int):
+         fd_pad = [fd_pad] * 2 * ndim
+
+     # Calculate the interior slice for the specified dimension
+     interior_start = fd_pad[dim * 2] + pml_width[dim * 2]
+     interior_end = shape[dim] - pml_width[dim * 2 + 1] - fd_pad[dim * 2 + 1]
+
+     # Zero out the interior
+     if dim == 0:  # y dimension
+         tensor[:, interior_start:interior_end, :].fill_(0)
+     else:  # x dimension
+         tensor[:, :, interior_start:interior_end].fill_(0)
+
+     return tensor
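
A sketch of `zero_interior` on a small grid (illustrative sizes, not values from the package):

```python
import torch
from tide.padding import zero_interior

aux = torch.ones(1, 30, 30)      # CPML auxiliary variable, [batch, ny, nx]
fd_pad = 2                       # finite-difference halo on each side
pml_width = [10, 10, 10, 10]     # [top, bottom, left, right]

# Interior runs from fd_pad + pml_width to ny - pml_width - fd_pad, i.e. 12:18
aux = zero_interior(aux, fd_pad, pml_width, dim=0)  # zero interior rows (y)
aux = zero_interior(aux, fd_pad, pml_width, dim=1)  # zero interior cols (x)
assert torch.count_nonzero(aux[0, 12:18, :]) == 0
assert torch.count_nonzero(aux[0, :, 12:18]) == 0
```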
tide/resampling.py ADDED
@@ -0,0 +1,246 @@
+ """Signal resampling utilities for CFL handling."""
+
+ import math
+
+ import torch
+
+
+ def cosine_taper_end(signal: torch.Tensor, taper_len: int) -> torch.Tensor:
+     """Apply a cosine taper to the end of the signal in the last dimension.
+
+     Args:
+         signal: Input tensor to taper.
+         taper_len: Number of samples to taper at the end.
+
+     Returns:
+         Tapered signal.
+     """
+     if taper_len <= 0 or signal.shape[-1] <= taper_len:
+         return signal
+
+     # Create taper: 1 -> 0 over taper_len samples
+     taper = 0.5 * (
+         1 + torch.cos(torch.linspace(0, math.pi, taper_len, device=signal.device))
+     )
+     # Apply taper to the last taper_len elements
+     signal = signal.clone()
+     signal[..., -taper_len:] = signal[..., -taper_len:] * taper
+     return signal
+
+
+ def zero_last_element_of_final_dimension(signal: torch.Tensor) -> torch.Tensor:
+     """Zero the last element of the final dimension (Nyquist frequency).
+
+     This is used to avoid aliasing when resampling signals in the frequency domain.
+
+     Args:
+         signal: Input tensor.
+
+     Returns:
+         Signal with the last element of the final dimension set to zero.
+     """
+     signal = signal.clone()
+     signal[..., -1] = 0
+     return signal
+
+
+ def upsample(
+     signal: torch.Tensor,
+     step_ratio: int,
+     freq_taper_frac: float = 0.0,
+     time_pad_frac: float = 0.0,
+     time_taper: bool = False,
+ ) -> torch.Tensor:
+     """Upsample the final dimension of a tensor by an integer factor.
+
+     Low-pass upsampling is used to produce an upsampled signal without
+     introducing higher frequencies than were present in the input.
+
+     This is typically used when the CFL condition requires a smaller internal
+     time step than the user-provided time step.
+
+     Args:
+         signal: Tensor to upsample (time should be the last dimension).
+         step_ratio: Integer factor by which to upsample.
+         freq_taper_frac: Fraction of the frequency spectrum end to taper
+             (0.0-1.0). Helps reduce ringing artifacts.
+         time_pad_frac: Fraction of signal length for zero padding (0.0-1.0).
+             Helps reduce wraparound artifacts.
+         time_taper: Whether to apply a Hann window in time.
+             Useful for correctness tests to ensure signals taper to zero.
+
+     Returns:
+         Upsampled signal.
+
+     Example:
+         >>> # Source with 100 time samples, need 3x internal steps for CFL
+         >>> source = torch.randn(1, 1, 100)
+         >>> source_upsampled = upsample(source, step_ratio=3)
+         >>> print(source_upsampled.shape)  # [1, 1, 300]
+     """
+     if signal.numel() == 0 or step_ratio == 1:
+         return signal
+
+     # Optional zero padding to reduce wraparound artifacts
+     n_time_pad = int(time_pad_frac * signal.shape[-1]) if time_pad_frac > 0.0 else 0
+     if n_time_pad > 0:
+         signal = torch.nn.functional.pad(signal, (0, n_time_pad))
+
+     nt = signal.shape[-1]
+     up_nt = nt * step_ratio
+
+     # Transform to frequency domain
+     signal_f = torch.fft.rfft(signal, norm="ortho") * math.sqrt(step_ratio)
+
+     # Apply frequency taper or zero Nyquist
+     if freq_taper_frac > 0.0:
+         freq_taper_len = int(freq_taper_frac * signal_f.shape[-1])
+         signal_f = cosine_taper_end(signal_f, freq_taper_len)
+     elif signal_f.shape[-1] > 1:
+         signal_f = zero_last_element_of_final_dimension(signal_f)
+
+     # Zero-pad in frequency domain for upsampling
+     pad_len = up_nt // 2 + 1 - signal_f.shape[-1]
+     if pad_len > 0:
+         signal_f = torch.nn.functional.pad(signal_f, (0, pad_len))
+
+     # Back to time domain
+     signal = torch.fft.irfft(signal_f, n=up_nt, norm="ortho")
+
+     # Remove padding
+     if n_time_pad > 0:
+         signal = signal[..., : signal.shape[-1] - n_time_pad * step_ratio]
+
+     # Optional time taper
+     if time_taper:
+         signal = signal * torch.hann_window(
+             signal.shape[-1],
+             periodic=False,
+             device=signal.device,
+         )
+
+     return signal
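
A sketch of the sample-preserving property of this spectral upsampling (exact only for band-limited, periodic inputs like the sinusoid below):

```python
import math
import torch
from tide.resampling import upsample

# Band-limited source: 5 whole cycles over 100 samples
t = torch.arange(100, dtype=torch.float32)
source = torch.sin(2 * math.pi * 0.05 * t).reshape(1, 1, -1)

up = upsample(source, step_ratio=3)
assert up.shape == (1, 1, 300)

# Every third sample reproduces the original signal
assert torch.allclose(up[..., ::3], source, atol=1e-4)
```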
+
+
+ def downsample(
+     signal: torch.Tensor,
+     step_ratio: int,
+     freq_taper_frac: float = 0.0,
+     time_pad_frac: float = 0.0,
+     time_taper: bool = False,
+     shift: float = 0.0,
+ ) -> torch.Tensor:
+     """Downsample the final dimension of a tensor by an integer factor.
+
+     Frequencies higher than or equal to the Nyquist frequency of the
+     downsampled signal will be zeroed before downsampling.
+
+     This is typically used when the internal time step is smaller than the
+     user-provided time step due to CFL requirements.
+
+     Args:
+         signal: Tensor to downsample (time should be the last dimension).
+         step_ratio: Integer factor by which to downsample.
+         freq_taper_frac: Fraction of the frequency spectrum end to taper
+             (0.0-1.0). Helps reduce ringing artifacts.
+         time_pad_frac: Fraction of signal length for zero padding (0.0-1.0).
+             Helps reduce wraparound artifacts.
+         time_taper: Whether to apply a Hann window in time.
+             Useful for correctness tests.
+         shift: Amount to shift in time before downsampling (in time samples).
+
+     Returns:
+         Downsampled signal.
+
+     Example:
+         >>> # Receiver data at internal rate, downsample to user rate
+         >>> data = torch.randn(300, 1, 1)  # [nt_internal, shot, receiver]
+         >>> data_ds = downsample(data.movedim(0, -1), step_ratio=3).movedim(-1, 0)
+         >>> print(data_ds.shape)  # [100, 1, 1]
+     """
+     if signal.numel() == 0 or (step_ratio == 1 and shift == 0.0):
+         return signal
+
+     # Optional time taper
+     if time_taper:
+         signal = signal * torch.hann_window(
+             signal.shape[-1],
+             periodic=False,
+             device=signal.device,
+         )
+
+     # Optional zero padding
+     n_time_pad = (
+         int(time_pad_frac * (signal.shape[-1] // step_ratio))
+         if time_pad_frac > 0.0
+         else 0
+     )
+     if n_time_pad > 0:
+         signal = torch.nn.functional.pad(signal, (0, n_time_pad * step_ratio))
+
+     nt = signal.shape[-1]
+     down_nt = nt // step_ratio
+
+     # Transform to frequency domain, keeping only frequencies below new Nyquist
+     signal_f = torch.fft.rfft(signal, norm="ortho")[..., : down_nt // 2 + 1]
+
+     # Apply frequency taper or zero Nyquist
+     if freq_taper_frac > 0.0:
+         freq_taper_len = int(freq_taper_frac * signal_f.shape[-1])
+         signal_f = cosine_taper_end(signal_f, freq_taper_len)
+     elif signal_f.shape[-1] > 1:
+         signal_f = zero_last_element_of_final_dimension(signal_f)
+
+     # Apply time shift in frequency domain
+     if shift != 0.0:
+         freqs = torch.fft.rfftfreq(signal.shape[-1], device=signal.device)[
+             : down_nt // 2 + 1
+         ]
+         signal_f = signal_f * torch.exp(-1j * 2 * math.pi * freqs * shift)
+
+     # Back to time domain
+     signal = torch.fft.irfft(signal_f, n=down_nt, norm="ortho") / math.sqrt(step_ratio)
+
+     # Remove padding
+     if n_time_pad > 0:
+         signal = signal[..., : signal.shape[-1] - n_time_pad]
+
+     return signal
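
Since `downsample` applies the inverse scaling of `upsample`, the pair is close to a round-trip identity for band-limited signals; a sketch:

```python
import math
import torch
from tide.resampling import downsample, upsample

t = torch.arange(100, dtype=torch.float32)
x = torch.sin(2 * math.pi * 0.05 * t).reshape(1, 1, -1)

x_rt = downsample(upsample(x, step_ratio=3), step_ratio=3)
assert x_rt.shape == x.shape
assert torch.allclose(x_rt, x, atol=1e-4)
```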
+
+
+ def downsample_and_movedim(
+     receiver_amplitudes: torch.Tensor,
+     step_ratio: int,
+     freq_taper_frac: float = 0.0,
+     time_pad_frac: float = 0.0,
+     time_taper: bool = False,
+     shift: float = 0.0,
+ ) -> torch.Tensor:
+     """Downsample receiver data and move the time dimension to the last axis.
+
+     Convenience function that combines downsampling with moving the time
+     dimension to the expected output format [shot, receiver, time].
+
+     Args:
+         receiver_amplitudes: Receiver data [time, shot, receiver].
+         step_ratio: Integer factor by which to downsample.
+         freq_taper_frac: Fraction of frequency spectrum to taper.
+         time_pad_frac: Fraction for zero padding.
+         time_taper: Whether to apply Hann window.
+         shift: Time shift before downsampling.
+
+     Returns:
+         Processed receiver data [shot, receiver, time].
+     """
+     if receiver_amplitudes.numel() > 0:
+         # Move time to last dimension for processing
+         receiver_amplitudes = torch.movedim(receiver_amplitudes, 0, -1)
+         receiver_amplitudes = downsample(
+             receiver_amplitudes,
+             step_ratio,
+             freq_taper_frac=freq_taper_frac,
+             time_pad_frac=time_pad_frac,
+             time_taper=time_taper,
+             shift=shift,
+         )
+     return receiver_amplitudes
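
A usage sketch of the wrapper, with assumed shot/receiver counts, showing the dimension reordering:

```python
import torch
from tide.resampling import downsample_and_movedim

# Receiver data at the internal (CFL-reduced) rate: [time, shot, receiver]
receiver_amplitudes = torch.randn(300, 4, 16)

out = downsample_and_movedim(receiver_amplitudes, step_ratio=3)
assert out.shape == (4, 16, 100)  # [shot, receiver, time]
```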