tide-GPR 0.0.9__py3-none-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
tide/utils.py ADDED
@@ -0,0 +1,274 @@
1
+ import math
2
+ from typing import Sequence
3
+
4
+ import torch
5
+
6
# Physical constants (SI units). EP0 and MU0 are both taken from CODATA 2018
# so that EP0 * MU0 * C0**2 == 1 up to rounding of the published digits,
# i.e. C0 reproduces the exact SI speed of light (299792458 m/s). Mixing the
# CODATA 2018 EP0 with the pre-2019 MU0 = 4*pi*1e-7 would bias C0 by ~3e-10.
EP0 = 8.8541878128e-12  # vacuum permittivity (F/m), CODATA 2018
MU0 = 1.25663706212e-06  # vacuum permeability (N/A^2), CODATA 2018
C0 = 1.0 / math.sqrt(EP0 * MU0)  # speed of light in vacuum (m/s)
10
+
11
+
12
def prepare_parameters(
    epsilon_r: torch.Tensor, sigma: torch.Tensor, mu_r: torch.Tensor, dt: float
) -> tuple:
    """Compute the FDTD update coefficients for Maxwell's equations.

    The coefficients correspond to the updates

        E^{n+1}   = ca * E^n + cb * curl(H)
        H^{n+1/2} = H^{n-1/2} - cq * curl(E)

    with ``s = sigma * dt / (2 * epsilon)``,
    ``ca = (1 - s) / (1 + s)``, ``cb = (dt / epsilon) / (1 + s)`` and
    ``cq = dt / mu``, where ``epsilon = epsilon_r * EP0`` and
    ``mu = mu_r * MU0``.

    Args:
        epsilon_r: Relative permittivity.
        sigma: Conductivity (S/m).
        mu_r: Relative permeability.
        dt: Time step (s).

    Returns:
        Tuple of (ca, cb, cq) coefficients.
    """
    # Absolute material properties.
    permittivity = epsilon_r * EP0
    permeability = mu_r * MU0

    # Half-step loss term shared by both E-field coefficients.
    loss = sigma * dt / (2.0 * permittivity)
    one_plus_loss = 1.0 + loss

    return (
        (1.0 - loss) / one_plus_loss,
        (dt / permittivity) / one_plus_loss,
        dt / permeability,
    )
41
+
42
+
43
def setup_pml(
    pml_width: Sequence[int],
    pml_start: Sequence[float],
    max_pml: float,
    dt: float,
    n: int,
    max_vel: float,
    dtype: torch.dtype,
    device: torch.device,
    pml_freq: float,
    start: float = 0.0,
    r_val: float = 1e-8,
    n_power: int = 4,
    eps: float = 1e-9,
    grid_spacing: float = 1.0,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Create a, b and k profiles for an electromagnetic C-PML.

    With ``d`` the normalized depth into a PML region (0 at the inner edge,
    clamped to 1 at ``pml_width`` cells or deeper), the profiles are graded as

        sigma = sigma0 * d**n_power,  sigma0 = (n_power + 1) / (150*pi*dx)
        k     = 1 + (k_max - 1) * d**n_power
        alpha = alpha_max * (1 - d) + 0.1 * alpha_max

    and the recursive-convolution coefficients are

        b = exp(-(sigma / k + alpha) * dt / EP0)
        a = sigma * (b - 1) / (k * (sigma + k * alpha))

    Outside both PML regions ``a`` and ``b`` are 0 and ``k`` is 1.

    Args:
        pml_width: Two PML widths in grid cells, [left/bottom, right/top].
            A width of 0 disables that side.
        pml_start: Two coordinates (in grid cells) of the inner PML edges.
            Points with coordinate <= pml_start[0] belong to the
            left/bottom PML; points with coordinate >= pml_start[1] belong
            to the right/top PML.
        max_pml: Length of the longest PML over all sides and dimensions.
            It is only compared with zero: if it is 0, no PML is created.
        dt: Time step interval.
        n: Length of the returned profiles.
        max_vel: Unused; kept for API compatibility.
        dtype: PyTorch dtype of the returned profiles.
        device: PyTorch device of the returned profiles.
        pml_freq: Unused; kept for API compatibility.
        start: Coordinate (in grid cells) of the first profile element.
            Optional, default 0.
        r_val: Unused (the reflection coefficient does not enter this
            formulation); kept for API compatibility. Default 1e-8.
        n_power: Polynomial grading order of sigma and k. Default 4.
        eps: Small number added to the denominator of ``a`` to prevent
            division by zero. Default 1e-9.
        grid_spacing: The grid spacing (dx or dy) used to scale sigma0.
            Default 1.0.

    Returns:
        A tuple ``(a, b, k)`` of 1D tensors of length ``n``:
            - a: CPML 'a' coefficient for recursive convolution
            - b: CPML 'b' coefficient for recursive convolution
            - k: CPML stretching factor
    """
    # CPML parameters for electromagnetic waves
    k_max_cpml = 5.0  # maximum stretching factor
    alpha_max_cpml = 0.008  # maximum frequency shift

    a = torch.zeros(n, device=device, dtype=dtype)
    b = torch.zeros(n, device=device, dtype=dtype)
    k = torch.ones(n, device=device, dtype=dtype)

    if max_pml == 0 or (pml_width[0] == 0 and pml_width[1] == 0):
        return a, b, k

    # Standard CPML sigma_max: sig0 = (Npower+1) / (150 * pi * dx)
    sigma0 = (n_power + 1) / (150.0 * math.pi * grid_spacing)

    # Build coordinates as start + [0, n) instead of arange(start, start + n):
    # for some non-integer `start`, rounding of `start + n` can make the
    # latter return n + 1 samples, which would not match a, b and k.
    x = torch.arange(n, device=device, dtype=dtype) + start

    def apply_side(distance, width, a, b, k):
        """Overwrite a, b, k wherever `distance` (cells into the PML) >= 0."""
        inside = distance >= 0
        # Normalized depth: 0 at the inner edge, 1 at the outer edge.
        depth = torch.clamp(distance / width, 0, 1)
        graded = depth**n_power
        sigma = sigma0 * graded
        k_side = 1.0 + (k_max_cpml - 1.0) * graded
        # alpha is largest (1.1 * alpha_max) at the inner edge and decreases
        # linearly to 0.1 * alpha_max at the outer edge.
        alpha = alpha_max_cpml * (1.0 - depth) + 0.1 * alpha_max_cpml
        b_side = torch.exp(-(sigma / k_side + alpha) * dt / EP0)
        a_side = sigma * (b_side - 1.0) / (k_side * (sigma + k_side * alpha) + eps)
        # Only keep 'a' where sigma is significant.
        a_side = torch.where(sigma > 1e-6, a_side, torch.zeros_like(a_side))
        return (
            torch.where(inside, a_side, a),
            torch.where(inside, b_side, b),
            torch.where(inside, k_side, k),
        )

    # Left/bottom PML: depth measured from the inner edge towards lower x.
    if pml_width[0] > 0:
        a, b, k = apply_side(pml_start[0] - x, pml_width[0], a, b, k)

    # Right/top PML: depth measured from the inner edge towards higher x.
    if pml_width[1] > 0:
        a, b, k = apply_side(x - pml_start[1], pml_width[1], a, b, k)

    return a, b, k
175
+
176
+
177
def setup_pml_half(
    pml_width: Sequence[int],
    pml_start: Sequence[float],
    max_pml: float,
    dt: float,
    n: int,
    max_vel: float,
    dtype: torch.dtype,
    device: torch.device,
    pml_freq: float,
    start: float = 0.0,
    r_val: float = 1e-8,
    n_power: int = 4,
    eps: float = 1e-9,
    grid_spacing: float = 1.0,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Create a, b and k C-PML profiles at half grid points.

    This is used for staggered grid implementations where some field
    components are located at half grid points (e.g., H fields in a Yee
    grid). The profiles are exactly those of ``setup_pml`` evaluated at
    coordinates shifted by +0.5 grid cells. Delegating to ``setup_pml``
    keeps the integer- and half-grid profiles from diverging.

    Args:
        Same as setup_pml.

    Returns:
        A tuple containing the (a_half, b_half, k_half) profiles as Tensors.
    """
    return setup_pml(
        pml_width,
        pml_start,
        max_pml,
        dt,
        n,
        max_vel,
        dtype,
        device,
        pml_freq,
        start=start + 0.5,  # half-grid offset
        r_val=r_val,
        n_power=n_power,
        eps=eps,
        grid_spacing=grid_spacing,
    )
tide/validation.py ADDED
@@ -0,0 +1,71 @@
1
+ """Validation helpers for user-facing parameters."""
2
+
3
+
4
def validate_model_gradient_sampling_interval(
    model_gradient_sampling_interval: int,
) -> int:
    """Validate the model gradient sampling interval parameter.

    The gradient sampling interval controls memory usage during
    backpropagation. Setting it > 1 reduces memory by storing fewer
    snapshots.

    Args:
        model_gradient_sampling_interval: Number of time steps between
            gradient snapshots.

    Returns:
        Validated interval value.

    Raises:
        TypeError: If not an integer (booleans are rejected too).
        ValueError: If negative.
    """
    # bool is a subclass of int, so isinstance(True, int) is True; reject it
    # explicitly so a flag passed by mistake is not read as an interval of 0/1.
    if isinstance(model_gradient_sampling_interval, bool) or not isinstance(
        model_gradient_sampling_interval, int
    ):
        raise TypeError("model_gradient_sampling_interval must be an int")
    if model_gradient_sampling_interval < 0:
        raise ValueError("model_gradient_sampling_interval must be >= 0")
    return model_gradient_sampling_interval
28
+
29
+
30
def validate_freq_taper_frac(freq_taper_frac: float) -> float:
    """Coerce the frequency taper fraction to float and check its range.

    Args:
        freq_taper_frac: Fraction of frequencies to taper; must lie in the
            closed interval [0, 1].

    Returns:
        The fraction as a Python float.

    Raises:
        TypeError: If ``float()`` cannot convert the value.
        ValueError: If the converted value is outside [0, 1] (including NaN).
    """
    try:
        value = float(freq_taper_frac)
    except (TypeError, ValueError) as err:
        raise TypeError("freq_taper_frac must be convertible to float") from err
    if 0.0 <= value <= 1.0:
        return value
    raise ValueError(f"freq_taper_frac must be in [0, 1], got {value}")
50
+
51
+
52
def validate_time_pad_frac(time_pad_frac: float) -> float:
    """Return ``time_pad_frac`` as a float after checking it lies in [0, 1].

    Args:
        time_pad_frac: Fraction of the time axis used for zero padding.

    Returns:
        The fraction converted with ``float()``.

    Raises:
        TypeError: If ``float()`` cannot convert the value.
        ValueError: If the converted value is outside [0, 1] (including NaN).
    """
    try:
        as_float = float(time_pad_frac)
    except (TypeError, ValueError) as exc:
        raise TypeError("time_pad_frac must be convertible to float") from exc
    in_range = 0.0 <= as_float <= 1.0
    if not in_range:
        raise ValueError(f"time_pad_frac must be in [0, 1], got {as_float}")
    return as_float
tide/wavelets.py ADDED
@@ -0,0 +1,72 @@
1
+ """Common electromagnetic wavelets for TIDE simulations.
2
+
3
+ This module provides various source wavelet functions commonly used in
4
+ electromagnetic wave simulations, particularly for Ground Penetrating Radar (GPR)
5
+ and other time-domain electromagnetic methods.
6
+
7
+ All wavelets return PyTorch tensors and support optional dtype specification.
8
+ """
9
+
10
+ import math
11
+ from typing import Optional
12
+
13
+ import torch
14
+
15
+
16
def ricker(
    freq: float,
    length: int,
    dt: float,
    peak_time: Optional[float] = None,
    dtype: Optional[torch.dtype] = None,
    device: Optional[torch.device] = None,
) -> torch.Tensor:
    """Return a Ricker wavelet (Mexican hat wavelet).

    The Ricker wavelet is the negative, normalized second derivative of a
    Gaussian function. It is commonly used in seismic and GPR simulations.

    The formula used is::

        w(t) = (1 - 2*pi^2*f^2*t'^2) * exp(-pi^2*f^2*t'^2)

    where t' = t - peak_time and t = i*dt for i = 0, ..., length - 1.
    The wavelet reaches its maximum value of 1 at t = peak_time.

    Args:
        freq: The central (dominant) frequency in Hz.
        length: The number of time samples.
        dt: The time sample spacing in seconds; must be positive.
        peak_time: The time (in seconds) of the peak amplitude. If None,
            defaults to 1/freq (one period after start).
        dtype: The PyTorch datatype to use. Optional; if None, torch's
            default floating dtype is used (torch.get_default_dtype(),
            normally float32).
        device: The PyTorch device to use. Optional, defaults to CPU.

    Returns:
        A 1D PyTorch tensor of ``length`` samples holding the Ricker wavelet.

    Raises:
        ValueError: If dt is zero or negative, or freq or length is not
            positive.

    Example:
        >>> # Create a 100 MHz Ricker wavelet for GPR
        >>> freq = 100e6  # 100 MHz
        >>> dt = 1e-10  # 0.1 ns time step
        >>> length = 500  # 500 time samples
        >>> wavelet = ricker(freq, length, dt)

    """
    if dt == 0:
        raise ValueError("dt cannot be zero.")
    if dt < 0:
        # A negative dt would produce a reversed time axis on which the
        # (positive) peak_time is never reached.
        raise ValueError("dt must be positive.")
    if freq <= 0:
        raise ValueError("freq must be positive.")
    if length <= 0:
        raise ValueError("length must be positive.")

    if peak_time is None:
        peak_time = 1.0 / freq  # Default: one period

    t = torch.arange(float(length), dtype=dtype, device=device) * dt
    t_shifted = t - peak_time

    # Ricker wavelet formula
    pi2_f2_t2 = (math.pi * freq * t_shifted) ** 2
    y = (1 - 2 * pi2_f2_t2) * torch.exp(-pi2_f2_t2)

    return y
72
+
@@ -0,0 +1,256 @@
1
+ Metadata-Version: 2.2
2
+ Name: tide-GPR
3
+ Version: 0.0.9
4
+ Summary: Torch-based Inversion & Development Engine for electromagnetic wave propagation
5
+ Keywords: pytorch,electromagnetic,wave-propagation,maxwell-equations,fdtd,full-waveform-inversion,fwi,geophysics,inverse-problems,cuda
6
+ Author-Email: "V.cholerae" <v.cholerae1@gmail.com>
7
+ Maintainer-Email: "V.cholerae" <v.cholerae1@gmail.com>
8
+ License: MIT
9
+ Classifier: Development Status :: 4 - Beta
10
+ Classifier: Intended Audience :: Science/Research
11
+ Classifier: License :: OSI Approved :: MIT License
12
+ Classifier: Programming Language :: Python :: 3
13
+ Classifier: Programming Language :: Python :: 3.12
14
+ Classifier: Programming Language :: C
15
+ Classifier: Programming Language :: C++
16
+ Classifier: Topic :: Scientific/Engineering :: Physics
17
+ Classifier: Topic :: Scientific/Engineering :: Mathematics
18
+ Classifier: Operating System :: POSIX :: Linux
19
+ Classifier: Operating System :: Microsoft :: Windows
20
+ Classifier: Operating System :: MacOS
21
+ Requires-Python: >=3.12
22
+ Requires-Dist: matplotlib>=3.10.7
23
+ Requires-Dist: numpy>=2.3.5
24
+ Requires-Dist: scipy>=1.16.3
25
+ Requires-Dist: torch>=2.9.1
26
+ Provides-Extra: dev
27
+ Requires-Dist: pytest>=7.0; extra == "dev"
28
+ Requires-Dist: pytest-cov>=4.0; extra == "dev"
29
+ Description-Content-Type: text/markdown
30
+
31
+ # TIDE
32
+
33
+ **T**orch-based **I**nversion & **D**evelopment **E**ngine
34
+
35
+ TIDE is a PyTorch-based library for high-frequency electromagnetic wave propagation and inversion, built on Maxwell's equations. It provides efficient CPU and CUDA implementations for forward modeling, gradient computation, and full waveform inversion (FWI).
36
+
37
+ [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
38
+
39
+ ## Features
40
+
41
+ - **Maxwell Equation Solvers**:
42
+ - 2D TM mode propagation (`MaxwellTM`)
43
+ - Other propagation is on the way
44
+ - **Automatic Differentiation**: Gradient support through PyTorch's autograd
45
+ - **High Performance**: Optimized C/CUDA kernels for critical operations
46
+ - **Flexible Storage**: Multiple storage modes for gradient computation (memory/disk/optional BF16 compressed)
47
+ - **Staggered Grid**: Industry-standard FDTD staggered grid implementation
48
+ - **PML Boundaries**: Perfectly Matched Layer absorbing boundaries
49
+
50
+ ## Installation
51
+
52
+ ### From PyPI
53
+
54
+
55
+ ```bash
56
+ uv pip install tide-GPR
57
+ ```
58
+
59
+ or
60
+
61
+ ```bash
62
+ pip install tide-GPR
63
+ ```
64
+
65
+ ### From Source
66
+
67
+ We recommend using [uv](https://github.com/astral-sh/uv) for building:
68
+
69
+ ```bash
70
+ git clone https://github.com/vcholerae1/tide.git
71
+ cd tide
72
+ uv build
73
+ ```
74
+
75
+ **Requirements:**
76
+ - Python >= 3.12
77
+ - PyTorch >= 2.9.1
78
+ - CUDA Toolkit (optional, for GPU support)
79
+ - CMake >= 3.28 (optional, for building from source)
80
+
81
+ ## Quick Start
82
+
83
+ ```python
84
+ import torch
85
+ import tide
86
+
87
+ # Create a simple model
88
+ nx, ny = 200, 100
89
+ epsilon = torch.ones(ny, nx) * 4.0 # Relative permittivity
90
+ epsilon[50:, :] = 9.0 # Add a layer
91
+
92
+ # Set up source
93
+ source_amplitudes = tide.ricker(
94
+ freq=1e9, # 1 GHz
95
+ length=1000,
96
+ dt=1e-11,
97
+ peak_time=5e-10
98
+ ).reshape(1, 1, -1)
99
+
100
+ source_locations = torch.tensor([[[10, 100]]])
101
+ receiver_locations = torch.tensor([[[10, 150]]])
102
+
103
+ # Run forward simulation
104
+ receiver_data = tide.maxwelltm(
105
+ epsilon=epsilon,
106
+ dx=0.01,
107
+ dt=1e-11,
108
+ source_amplitudes=source_amplitudes,
109
+ source_locations=source_locations,
110
+ receiver_locations=receiver_locations,
111
+ pml_width=10
112
+ )
113
+
114
+ print(f"Recorded data shape: {receiver_data.shape}")
115
+ ```
116
+
117
+ ## Core Modules
118
+
119
+ - **`tide.maxwelltm`**: 2D TM mode Maxwell solver
120
+ - **`tide.wavelets`**: Source wavelet generation (Ricker, etc.)
121
+ - **`tide.staggered`**: Staggered grid finite difference operators
122
+ - **`tide.callbacks`**: Callback state and factories
123
+ - **`tide.resampling`**: Upsampling/downsampling utilities
124
+ - **`tide.cfl`**: CFL condition helpers
125
+ - **`tide.padding`**: Padding and interior masking helpers
126
+ - **`tide.validation`**: Input validation helpers
127
+ - **`tide.storage`**: Gradient checkpointing and storage management
128
+
129
+ ## Examples
130
+
131
+ See the [`examples/`](examples/) directory for complete workflows:
132
+
133
+ - [`example_multiscale_filtered.py`](examples/example_multiscale_filtered.py): Multi-scale FWI with frequency filtering
134
+ - [`example_multiscale_random_sources.py`](examples/example_multiscale_random_sources.py): FWI with random source encoding
135
+ - [`wavefield_animation.py`](examples/wavefield_animation.py): Visualize wave propagation
136
+
137
+ ## Documentation
138
+
139
+ Detailed API documentation and tutorials are *coming soon*.
140
+
141
+ ## Testing
142
+
143
+ Run the test suite:
144
+
145
+ ```bash
146
+ pytest tests/
147
+ ```
148
+
149
+ ## Contributing
150
+
151
+ Contributions are welcome! Please feel free to submit a Pull Request.
152
+
153
+ ## Acknowledgments
154
+
155
+ This project includes code derived from [Deepwave](https://github.com/ar4/deepwave) by Alan Richardson. We gratefully acknowledge the foundational work that made TIDE possible.
156
+
157
+ ## Citation
158
+
159
+ If you use TIDE in your research, please cite:
160
+
161
+ ```bibtex
162
+ @software{tide2025,
163
+ author = {Vcholerae1},
164
+ title = {TIDE: Torch-based Inversion \& Development Engine},
165
+ year = {2025},
166
+ url = {https://github.com/vcholerae1/tide}
167
+ }
168
+ ```
169
+
170
+ ## License
171
+
172
+ This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
173
+
174
+
175
+
176
+ ```mermaid
177
+ graph TD
178
+ %% --- Styling Definitions ---
179
+ classDef userLayer fill:#e1f5fe,stroke:#0288d1,stroke-width:2px,color:#01579b;
180
+ classDef apiLayer fill:#fff9c4,stroke:#fbc02d,stroke-width:2px,color:#f57f17;
181
+ classDef coreLayer fill:#ffe0b2,stroke:#f57c00,stroke-width:2px,color:#e65100;
182
+ classDef supportLayer fill:#e8f5e9,stroke:#388e3c,stroke-width:1px,color:#1b5e20;
183
+ classDef nativeLayer fill:#cfd8dc,stroke:#455a64,stroke-width:3px,color:#263238;
184
+ classDef torchFrame fill:none,stroke:#ee3333,stroke-width:2px,stroke-dasharray: 5 5,color:#ee3333;
185
+
186
+ %% --- Top Level ---
187
+ UserCode[用户代码 / 脚本<br>FWI, 建模任务]:::userLayer
188
+
189
+ subgraph PyTorch_Environment [PyTorch 生态系统 <br> Autograd Engine & Tensor Management]
190
+ style PyTorch_Environment fill:#fff3f3,stroke:#ffcdd2
191
+
192
+ %% --- Layer 1: User API ---
193
+ subgraph L1_UserAPI [第1层: 用户 API 层]
194
+ direction LR
195
+ MaxwellTM([MaxwellTM <br> nn.Module 子类]):::apiLayer
196
+ Utils[工具函数 <br> ricker, wavelets, validation]:::apiLayer
197
+ end
198
+
199
+ UserCode -->|初始化 & 调用| MaxwellTM
200
+ UserCode -.->|使用| Utils
201
+
202
+ %% --- Layer 3: Support (Side by Side) ---
203
+ subgraph L3_Support [第3层: 数值计算与辅助支持]
204
+ Staggered[staggered.py <br> 交错网格/PML定义]:::supportLayer
205
+ Callbacks[callbacks.py <br> 状态监控/观察者]:::supportLayer
206
+ Storage[storage.py <br> 波场快照存储策略<br>CPU/GPU/DISK/Compress]:::supportLayer
207
+ end
208
+
209
+ %% --- Layer 2: Core Solver ---
210
+ subgraph L2_Core [第2层: 核心求解器 maxwell.py]
211
+ MaxwellTMForwardFunc{{MaxwellTMForwardFunc <br> autograd.Function}}:::coreLayer
212
+ BackendDispatch[后端分发器 <br> maxwell_func]:::coreLayer
213
+ end
214
+
215
+ %% Connections within Python Levels
216
+ MaxwellTM -->|调用 forward| MaxwellTMForwardFunc
217
+ MaxwellTMForwardFunc -->|管理| BackendDispatch
218
+ MaxwellTMForwardFunc -.->|依赖| Storage
219
+ MaxwellTM -.->|配置| Staggered
220
+ MaxwellTM -->|触发| Callbacks
221
+
222
+ %% --- Data Flow Arrows ---
223
+ %% Forward Path
224
+ MaxwellTMForwardFunc == "正向传播 (Forward Pass)<br>FDTD 时间步进" ==> BackendDispatch
225
+ BackendDispatch -- 路径 A: Python (调试) --> PyImpl[纯 Python 实现<br>update_E/H]:::supportLayer
226
+
227
+ %% Backward Path
228
+ PyTorch_Environment -.- |"loss.backward() <br> 触发自动微分"| MaxwellTMForwardFunc
229
+ BackendDispatch <== "反向传播 (Backward Pass)<br>伴随状态法 / 梯度计算" == MaxwellTMForwardFunc
230
+ end
231
+
232
+ %% --- Layer 4: Native Backend ---
233
+ subgraph L4_Native [第4层: C/CUDA 内核层 libtide_C.so]
234
+ direction LR
235
+ CUDA_Kernels[maxwell.cu <br> GPU FDTD 内核]:::nativeLayer
236
+ Born_Kernels[maxwell_born.cu <br> Born 近似内核]:::nativeLayer
237
+ C_Staggered[staggered_grid.h <br> C数据结构]:::nativeLayer
238
+ end
239
+
240
+ %% Crossing the boundary
241
+ BackendDispatch -- "路径 B: C/CUDA (高性能)<br>ctypes 调用" --> CUDA_Kernels
242
+ BackendDispatch -- "反向路径调用" --> Born_Kernels
243
+ CUDA_Kernels -.-> C_Staggered
244
+ Born_Kernels -.-> C_Staggered
245
+ Storage -.->|C++实现| CUDA_Kernels
246
+
247
+ %% Legend
248
+ subgraph Legend [图例]
249
+ L_Py[Python 模块]:::apiLayer
250
+ L_Core[核心 Autograd]:::coreLayer
251
+ L_Sup[辅助支持]:::supportLayer
252
+ L_Nat[C/CUDA 原生库]:::nativeLayer
253
+ L_Flow_F[正向数据流] ==> L_Flow_F_End
254
+ L_Flow_B[反向数据流] <== L_Flow_B_End
255
+ end
256
+ ```