ladam 0.2.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ladam-0.2.0/.gitignore ADDED
@@ -0,0 +1,16 @@
1
+ __pycache__/
2
+ *.py[cod]
3
+ *$py.class
4
+ *.egg-info/
5
+ *.egg
6
+ dist/
7
+ build/
8
+ .eggs/
9
+ *.whl
10
+ .pytest_cache/
11
+ .tox/
12
+ .coverage
13
+ htmlcov/
14
+ *.so
15
+ .env
16
+ .venv
ladam-0.2.0/LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2026 Greg Partin
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
ladam-0.2.0/PKG-INFO ADDED
@@ -0,0 +1,246 @@
1
+ Metadata-Version: 2.4
2
+ Name: ladam
3
+ Version: 0.2.0
4
+ Summary: LAdam: Laplacian Adam — Adam with spatially-coupled variance estimates via discrete Laplacian
5
+ Project-URL: Homepage, https://github.com/gpartin/ladam
6
+ Project-URL: Documentation, https://github.com/gpartin/ladam#usage
7
+ Project-URL: Issues, https://github.com/gpartin/ladam/issues
8
+ Author: Greg Partin
9
+ License: MIT
10
+ License-File: LICENSE
11
+ Keywords: adam,deep-learning,laplacian,optimizer,pinn,pytorch,scientific-ml,spatial-regularization
12
+ Classifier: Development Status :: 4 - Beta
13
+ Classifier: Intended Audience :: Developers
14
+ Classifier: Intended Audience :: Science/Research
15
+ Classifier: License :: OSI Approved :: MIT License
16
+ Classifier: Programming Language :: Python :: 3
17
+ Classifier: Programming Language :: Python :: 3.8
18
+ Classifier: Programming Language :: Python :: 3.9
19
+ Classifier: Programming Language :: Python :: 3.10
20
+ Classifier: Programming Language :: Python :: 3.11
21
+ Classifier: Programming Language :: Python :: 3.12
22
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
23
+ Classifier: Topic :: Scientific/Engineering :: Mathematics
24
+ Requires-Python: >=3.8
25
+ Requires-Dist: torch>=1.10.0
26
+ Provides-Extra: dev
27
+ Requires-Dist: pytest-cov>=4.0; extra == 'dev'
28
+ Requires-Dist: pytest>=7.0; extra == 'dev'
29
+ Description-Content-Type: text/markdown
30
+
31
+ # LAdam
32
+
33
+ **Laplacian Adam — spatially-aware adaptive optimizer for PyTorch**
34
+
35
+ [![PyPI](https://img.shields.io/pypi/v/ladam.svg)](https://pypi.org/project/ladam/)
36
+ [![License](https://img.shields.io/badge/license-MIT-blue.svg)](LICENSE)
37
+ [![Python 3.8+](https://img.shields.io/badge/python-3.8+-blue.svg)](https://python.org)
38
+
39
+ LAdam is a drop-in Adam replacement that applies **discrete Laplacian regularization** to Adam's second-moment estimate (v_t). This couples neighboring weight learning rates, producing spatially-smoothed adaptive optimization.
40
+
41
+ ## Why LAdam?
42
+
43
+ Adam computes independent per-parameter learning rates. But adjacent weights in trained networks are often functionally correlated — the per-parameter variance estimates should reflect this structure.
44
+
45
+ LAdam adds **one operation** to Adam: a Laplacian diffusion step on v_t, controlled by a single scalar `c2`. The Laplacian allows each weight's learning rate to be informed by its neighbors, smoothing the optimization landscape.
46
+
47
+ ## Results
48
+
49
+ | Task | Architecture | Metric | Adam | LAdam | Improvement |
50
+ |------|-------------|--------|------|-------|-------------|
51
+ | **Wave Equation PINN** | 5×128 MLP | L2 Error | 0.0310 | **0.0172** | **-44.6%** |
52
+ | **FashionMNIST** | Transformer | Accuracy | 89.46% | **89.66%** | **+0.20%** (p=0.0005) |
53
+ | FashionMNIST | MLP | Accuracy | 89.10% | 89.12% | +0.02% (n.s.) |
54
+ | FashionMNIST | CNN | Accuracy | 91.15% | 91.14% | -0.01% (tie) |
55
+
56
+ > **LAdam excels on architectures with spatially-correlated weight structure** — particularly PINNs and transformers. For CNNs (whose conv filters are already spatial detectors), the Laplacian is redundant.
57
+
58
+ ## Installation
59
+
60
+ ```bash
61
+ pip install ladam
62
+ ```
63
+
64
+ ## Optimizers
65
+
66
+ LAdam ships three Laplacian-enhanced optimizers:
67
+
68
+ | Optimizer | Base | Laplacian target | Best for |
69
+ |-----------|------|------------------|----------|
70
+ | **LAdam** | Adam | Second moment v_t | PINNs, transformers, CNNs |
71
+ | **LAdaGrad** | AdaGrad | Cumulative sum G_t | Sparse features, NLP |
72
+ | **LRMSProp** | RMSProp | Running average v_t | RNNs, non-stationary losses |
73
+
74
+ All three share the same Laplacian kernel infrastructure and `c2` parameter.
75
+
76
+ ## Usage
77
+
78
+ ### Basic — Drop-in Adam replacement
79
+
80
+ ```python
81
+ from ladam import LAdam
82
+
83
+ optimizer = LAdam(model.parameters(), lr=1e-3, c2=1e-4)
84
+
85
+ # Training loop is identical to Adam
86
+ for batch in dataloader:
87
+ loss = criterion(model(batch))
88
+ loss.backward()
89
+ optimizer.step()
90
+ optimizer.zero_grad()
91
+ ```
92
+
93
+ ### LAdaGrad and LRMSProp
94
+
95
+ ```python
96
+ from ladam import LAdaGrad, LRMSProp
97
+
98
+ # AdaGrad with Laplacian smoothing on cumulative squared gradients
99
+ optimizer = LAdaGrad(model.parameters(), lr=1e-2, c2=1e-4)
100
+
101
+ # RMSProp with Laplacian smoothing on running variance
102
+ optimizer = LRMSProp(model.parameters(), lr=1e-2, alpha=0.99, c2=1e-4)
103
+ ```
104
+
105
+ ### Per-layer c2 with parameter groups
106
+
107
+ ```python
108
+ optimizer = LAdam([
109
+ {'params': model.attention.parameters(), 'c2': 1e-4}, # Transformer attention
110
+ {'params': model.ffn.parameters(), 'c2': 1e-5}, # Feed-forward
111
+ {'params': model.norm.parameters(), 'c2': 0.0}, # Skip for norms
112
+ ], lr=3e-4)
113
+ ```
114
+
115
+ ### Architecture-aware defaults
116
+
117
+ ```python
118
+ from ladam import LAdam, suggest_c2
119
+
120
+ c2 = suggest_c2('pinn') # Returns 1e-5
121
+ c2 = suggest_c2('transformer') # Returns 1e-4
122
+
123
+ optimizer = LAdam(model.parameters(), lr=1e-3, c2=c2)
124
+ ```
125
+
126
+ ## Parameters
127
+
128
+ | Parameter | Default | Description |
129
+ |-----------|---------|-------------|
130
+ | `lr` | 1e-3 | Learning rate |
131
+ | `betas` | (0.9, 0.999) | EMA coefficients (same as Adam) |
132
+ | `eps` | 1e-8 | Numerical stability (same as Adam) |
133
+ | `weight_decay` | 0 | L2 penalty added to the gradient (coupled, Adam-style; not decoupled AdamW) |
134
+ | `c2` | 1e-4 | **Laplacian coupling strength.** Controls how much neighboring variance estimates influence each other. |
135
+ | `mode` | 'variance_lap' | Which quantity to smooth: `'variance_lap'` (default, best overall), `'update_lap'`, or `'weight_lap'`. |
136
+ | `stencil` | '9point' | **Discrete Laplacian stencil.** `'9point'` (isotropic, 0.46% anisotropy) or `'5point'` (legacy, 12.3% anisotropy). |
137
+ | `min_spatial_size` | 16 | Skip Laplacian for params with fewer elements (biases, LayerNorm). |
138
+
139
+ ### Stencil Selection
140
+
141
+ The `stencil` parameter controls the discrete Laplacian kernel used for spatial coupling:
142
+
143
+ - **`'9point'` (default)**: Isotropic stencil over face + diagonal neighbors, weighting diagonals 1/6 vs 4/6 for faces.
144
+ - **`'5point'`**: Standard cross-pattern stencil (faces only). Slightly faster, but roughly 27× more anisotropic (12.3% vs 0.46%).
145
+
146
+ At typical `c2` values (1e-5 to 1e-3), the effective learning rate difference between stencils is <0.3%. The 9-point default is recommended for correctness.
147
+
148
+ ### Choosing c2
149
+
150
+ `c2` is the only new hyperparameter. It's robust across 3 orders of magnitude:
151
+
152
+ | c2 | Best For | Notes |
153
+ |----|----------|-------|
154
+ | `1e-5` | PINNs, scientific ML | Gentle coupling, biggest error reduction |
155
+ | `1e-4` | Transformers, general | **Safe default** |
156
+ | `1e-3` | Aggressive smoothing | Works but slightly less stable |
157
+ | `0` | Disable | Reduces to standard Adam |
158
+
159
+ All 7 values tested in [1e-6, 1e-3] outperformed Adam on transformers (B12 sweep).
160
+
161
+ ## How It Works
162
+
163
+ Standard Adam computes per-parameter adaptive learning rates from the second moment:
164
+
165
+ ```
166
+ v_t = β₂·v_{t-1} + (1-β₂)·g_t² # Variance estimate
167
+ lr_effective = lr / (√v_t + ε) # Per-parameter learning rate
168
+ ```
169
+
170
+ LAdam adds a Laplacian coupling step:
171
+
172
+ ```
173
+ v_smooth = v_t + c2 · ∇²v_t # Spatial smoothing
174
+ lr_effective = lr / (√v_smooth + ε) # Coupled learning rate
175
+ ```
176
+
177
+ Here ∇² is the discrete Laplacian, computed with a single `F.conv2d` kernel (9-point isotropic by default), which keeps the step efficient and GPU-friendly. The Laplacian treats weight matrices as 2D fields, coupling each weight's learning rate with its spatial neighbors.
178
+
179
+ **Overhead**: ~2-5% wall-clock time increase per step. The Laplacian is a single fused convolution kernel, not point-wise iteration.
180
+
181
+ ## Benchmarks
182
+
183
+ ### PINN: Wave Equation (u_tt = c^2 u_xx)
184
+
185
+ 5-layer, 128-unit tanh MLP trained for 5000 steps on the 1D wave equation.
186
+
187
+ | Optimizer | L2 Error | vs Adam |
188
+ |-----------|----------|---------|
189
+ | Adam (lr=1e-3) | 0.0310 | — |
190
+ | LAdam c²=1e-4 | 0.0240 | -22.8% |
191
+ | **LAdam c²=1e-5** | **0.0172** | **-44.6%** |
192
+ | LAdam c²=1e-3 | 0.0185 | -40.3% |
193
+
194
+ ### Transformer: FashionMNIST Classification
195
+
196
+ 4-head, 128-dim, 2-layer transformer, 30 epochs, 5 independent seeds.
197
+
198
+ | Optimizer | Accuracy (mean ± std) | p-value (vs Adam) |
199
+ |-----------|----------------------|-------------------|
200
+ | Adam | 89.46 ± 0.10% | — |
201
+ | **LAdam c²=1e-4** | **89.66 ± 0.06%** | **0.0005** |
202
+
203
+ ### c² Robustness Sweep
204
+
205
+ 7 c² values on the same transformer task. **All 7 beat Adam:**
206
+
207
+ | c² | Accuracy | Δ vs Adam |
208
+ |----|----------|-----------|
209
+ | 1e-6 | 89.62% | +0.16% |
210
+ | 5e-6 | 89.73% | +0.27% |
211
+ | 1e-5 | 89.79% | +0.33% |
212
+ | 5e-5 | 89.75% | +0.29% |
213
+ | 1e-4 | 89.67% | +0.21% |
214
+ | 5e-4 | 89.64% | +0.18% |
215
+ | 1e-3 | 89.66% | +0.20% |
216
+
217
+ ## FAQ
218
+
219
+ **Q: Does this work for LLMs / GPT-scale models?**
220
+ A: No. LAdam **hurts** LLM training (tested on GPT-2/WikiText-2). Attention weight matrices encode semantic structure, not spatial structure — the Laplacian destroys per-feature specialization. Use standard Adam/AdamW for LLMs.
221
+
222
+ **Q: Why not smooth the gradient instead of the variance?**
223
+ A: [Osher et al. (2018)](https://arxiv.org/abs/1806.06317) explored Laplacian smoothing of gradients. We found that smoothing the *variance estimate* is more effective because it smooths the *learning rate landscape* rather than the *descent direction*. These are mathematically distinct: ∇²(EMA(g²)) ≠ (∇²g)².
224
+
225
+ **Q: Why does this help PINNs so much?**
226
+ A: PDE-based loss landscapes have inherent spatial structure from the differential operators in the loss function. The Laplacian on v_t aligns the optimizer's internal representation with this structure.
227
+
228
+ **Q: Can I use this with learning rate schedulers?**
229
+ A: Yes. LAdam is fully compatible with any `torch.optim.lr_scheduler`.
230
+
231
+ ## Citation
232
+
233
+ If you use LAdam in your research, please cite:
234
+
235
+ ```bibtex
236
+ @software{partin2026ladam,
237
+ author = {Partin, Greg},
238
+ title = {LAdam: Spatially-Aware Adaptive Optimization via Laplacian-Regularized Variance Estimates},
239
+ year = {2026},
240
+ url = {https://github.com/gpartin/ladam}
241
+ }
242
+ ```
243
+
244
+ ## License
245
+
246
+ MIT. See [LICENSE](LICENSE) for details.
ladam-0.2.0/README.md ADDED
@@ -0,0 +1,216 @@
1
+ # LAdam
2
+
3
+ **Laplacian Adam — spatially-aware adaptive optimizer for PyTorch**
4
+
5
+ [![PyPI](https://img.shields.io/pypi/v/ladam.svg)](https://pypi.org/project/ladam/)
6
+ [![License](https://img.shields.io/badge/license-MIT-blue.svg)](LICENSE)
7
+ [![Python 3.8+](https://img.shields.io/badge/python-3.8+-blue.svg)](https://python.org)
8
+
9
+ LAdam is a drop-in Adam replacement that applies **discrete Laplacian regularization** to Adam's second-moment estimate (v_t). This couples neighboring weight learning rates, producing spatially-smoothed adaptive optimization.
10
+
11
+ ## Why LAdam?
12
+
13
+ Adam computes independent per-parameter learning rates. But adjacent weights in trained networks are often functionally correlated — the per-parameter variance estimates should reflect this structure.
14
+
15
+ LAdam adds **one operation** to Adam: a Laplacian diffusion step on v_t, controlled by a single scalar `c2`. The Laplacian allows each weight's learning rate to be informed by its neighbors, smoothing the optimization landscape.
16
+
17
+ ## Results
18
+
19
+ | Task | Architecture | Metric | Adam | LAdam | Improvement |
20
+ |------|-------------|--------|------|-------|-------------|
21
+ | **Wave Equation PINN** | 5×128 MLP | L2 Error | 0.0310 | **0.0172** | **-44.6%** |
22
+ | **FashionMNIST** | Transformer | Accuracy | 89.46% | **89.66%** | **+0.20%** (p=0.0005) |
23
+ | FashionMNIST | MLP | Accuracy | 89.10% | 89.12% | +0.02% (n.s.) |
24
+ | FashionMNIST | CNN | Accuracy | 91.15% | 91.14% | -0.01% (tie) |
25
+
26
+ > **LAdam excels on architectures with spatially-correlated weight structure** — particularly PINNs and transformers. For CNNs (whose conv filters are already spatial detectors), the Laplacian is redundant.
27
+
28
+ ## Installation
29
+
30
+ ```bash
31
+ pip install ladam
32
+ ```
33
+
34
+ ## Optimizers
35
+
36
+ LAdam ships three Laplacian-enhanced optimizers:
37
+
38
+ | Optimizer | Base | Laplacian target | Best for |
39
+ |-----------|------|------------------|----------|
40
+ | **LAdam** | Adam | Second moment v_t | PINNs, transformers, CNNs |
41
+ | **LAdaGrad** | AdaGrad | Cumulative sum G_t | Sparse features, NLP |
42
+ | **LRMSProp** | RMSProp | Running average v_t | RNNs, non-stationary losses |
43
+
44
+ All three share the same Laplacian kernel infrastructure and `c2` parameter.
45
+
46
+ ## Usage
47
+
48
+ ### Basic — Drop-in Adam replacement
49
+
50
+ ```python
51
+ from ladam import LAdam
52
+
53
+ optimizer = LAdam(model.parameters(), lr=1e-3, c2=1e-4)
54
+
55
+ # Training loop is identical to Adam
56
+ for batch in dataloader:
57
+ loss = criterion(model(batch))
58
+ loss.backward()
59
+ optimizer.step()
60
+ optimizer.zero_grad()
61
+ ```
62
+
63
+ ### LAdaGrad and LRMSProp
64
+
65
+ ```python
66
+ from ladam import LAdaGrad, LRMSProp
67
+
68
+ # AdaGrad with Laplacian smoothing on cumulative squared gradients
69
+ optimizer = LAdaGrad(model.parameters(), lr=1e-2, c2=1e-4)
70
+
71
+ # RMSProp with Laplacian smoothing on running variance
72
+ optimizer = LRMSProp(model.parameters(), lr=1e-2, alpha=0.99, c2=1e-4)
73
+ ```
74
+
75
+ ### Per-layer c2 with parameter groups
76
+
77
+ ```python
78
+ optimizer = LAdam([
79
+ {'params': model.attention.parameters(), 'c2': 1e-4}, # Transformer attention
80
+ {'params': model.ffn.parameters(), 'c2': 1e-5}, # Feed-forward
81
+ {'params': model.norm.parameters(), 'c2': 0.0}, # Skip for norms
82
+ ], lr=3e-4)
83
+ ```
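For larger models, groups like these can be generated programmatically. The sketch below uses a hypothetical helper (`make_groups` is not part of the ladam API) that assigns a `c2` per module kind by substring-matching parameter names:

```python
# Hypothetical helper (not part of the ladam API): build LAdam-style param
# groups with a per-module-kind c2, falling back to a safe default.
def make_groups(named_params, c2_by_kind, default_c2=1e-4):
    groups = []
    for name, params in named_params:
        # First kind whose key appears in the module name wins
        kind = next((k for k in c2_by_kind if k in name), None)
        groups.append({'params': params,
                       'c2': c2_by_kind.get(kind, default_c2)})
    return groups

# Toy stand-ins for buckets of model parameters
groups = make_groups(
    [('encoder.attention', ['w_q', 'w_k']),
     ('encoder.ffn', ['w1', 'w2']),
     ('encoder.norm', ['gamma'])],
    {'attention': 1e-4, 'ffn': 1e-5, 'norm': 0.0},
)
assert [g['c2'] for g in groups] == [1e-4, 1e-5, 0.0]
```

The resulting `groups` list can be passed straight to `LAdam(groups, lr=3e-4)`, just like the explicit version above.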
84
+
85
+ ### Architecture-aware defaults
86
+
87
+ ```python
88
+ from ladam import LAdam, suggest_c2
89
+
90
+ c2 = suggest_c2('pinn') # Returns 1e-5
91
+ c2 = suggest_c2('transformer') # Returns 1e-4
92
+
93
+ optimizer = LAdam(model.parameters(), lr=1e-3, c2=c2)
94
+ ```
95
+
96
+ ## Parameters
97
+
98
+ | Parameter | Default | Description |
99
+ |-----------|---------|-------------|
100
+ | `lr` | 1e-3 | Learning rate |
101
+ | `betas` | (0.9, 0.999) | EMA coefficients (same as Adam) |
102
+ | `eps` | 1e-8 | Numerical stability (same as Adam) |
103
+ | `weight_decay` | 0 | L2 penalty added to the gradient (coupled, Adam-style; not decoupled AdamW) |
104
+ | `c2` | 1e-4 | **Laplacian coupling strength.** Controls how much neighboring variance estimates influence each other. |
105
+ | `mode` | 'variance_lap' | Which quantity to smooth: `'variance_lap'` (default, best overall), `'update_lap'`, or `'weight_lap'`. |
106
+ | `stencil` | '9point' | **Discrete Laplacian stencil.** `'9point'` (isotropic, 0.46% anisotropy) or `'5point'` (legacy, 12.3% anisotropy). |
107
+ | `min_spatial_size` | 16 | Skip Laplacian for params with fewer elements (biases, LayerNorm). |
108
+
109
+ ### Stencil Selection
110
+
111
+ The `stencil` parameter controls the discrete Laplacian kernel used for spatial coupling:
112
+
113
+ - **`'9point'` (default)**: Isotropic stencil over face + diagonal neighbors, weighting diagonals 1/6 vs 4/6 for faces.
114
+ - **`'5point'`**: Standard cross-pattern stencil (faces only). Slightly faster, but roughly 27× more anisotropic (12.3% vs 0.46%).
115
+
116
+ At typical `c2` values (1e-5 to 1e-3), the effective learning rate difference between stencils is <0.3%. The 9-point default is recommended for correctness.
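A quick property check (a sketch, not library code): both stencils' weights sum to zero, so a constant v_t field picks up no spurious correction regardless of `c2`:

```python
# The two stencils described above, written out as plain 3x3 weight grids.
nine_point = [[1/6,  4/6,  1/6],
              [4/6, -20/6, 4/6],
              [1/6,  4/6,  1/6]]
five_point = [[0,  1, 0],
              [1, -4, 1],
              [0,  1, 0]]

for name, kernel in (('9point', nine_point), ('5point', five_point)):
    total = sum(sum(row) for row in kernel)
    # Zero-sum weights: the Laplacian of a constant field is exactly zero
    assert abs(total) < 1e-12, name
```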
117
+
118
+ ### Choosing c2
119
+
120
+ `c2` is the only new hyperparameter. It's robust across 3 orders of magnitude:
121
+
122
+ | c2 | Best For | Notes |
123
+ |----|----------|-------|
124
+ | `1e-5` | PINNs, scientific ML | Gentle coupling, biggest error reduction |
125
+ | `1e-4` | Transformers, general | **Safe default** |
126
+ | `1e-3` | Aggressive smoothing | Works but slightly less stable |
127
+ | `0` | Disable | Reduces to standard Adam |
128
+
129
+ All 7 values tested in [1e-6, 1e-3] outperformed Adam on transformers (B12 sweep).
130
+
131
+ ## How It Works
132
+
133
+ Standard Adam computes per-parameter adaptive learning rates from the second moment:
134
+
135
+ ```
136
+ v_t = β₂·v_{t-1} + (1-β₂)·g_t² # Variance estimate
137
+ lr_effective = lr / (√v_t + ε) # Per-parameter learning rate
138
+ ```
139
+
140
+ LAdam adds a Laplacian coupling step:
141
+
142
+ ```
143
+ v_smooth = v_t + c2 · ∇²v_t # Spatial smoothing
144
+ lr_effective = lr / (√v_smooth + ε) # Coupled learning rate
145
+ ```
146
+
147
+ Here ∇² is the discrete Laplacian, computed with a single `F.conv2d` kernel (9-point isotropic by default), which keeps the step efficient and GPU-friendly. The Laplacian treats weight matrices as 2D fields, coupling each weight's learning rate with its spatial neighbors.
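As a minimal NumPy sketch (illustrative only; the package itself uses a fused `F.conv2d`, and `laplacian_9pt` below is not the library API), the 9-point smoothing step and two of its invariants look like this:

```python
import numpy as np

def laplacian_9pt(v):
    """Isotropic 9-point Laplacian with periodic boundaries (illustrative)."""
    r = np.roll
    faces = r(v, 1, 0) + r(v, -1, 0) + r(v, 1, 1) + r(v, -1, 1)
    diags = (r(r(v, 1, 0), 1, 1) + r(r(v, 1, 0), -1, 1) +
             r(r(v, -1, 0), 1, 1) + r(r(v, -1, 0), -1, 1))
    return (4 * faces + diags - 20 * v) / 6

rng = np.random.default_rng(0)
v_hat = rng.random((8, 8)) ** 2            # stand-in for Adam's v_t
c2 = 1e-4
v_smooth = v_hat + c2 * laplacian_9pt(v_hat)

# Constant fields are untouched, and smoothing preserves the mean,
# because the stencil weights sum to zero.
assert np.allclose(laplacian_9pt(np.ones((8, 8))), 0.0)
assert np.isclose(v_smooth.mean(), v_hat.mean())
```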
148
+
149
+ **Overhead**: ~2-5% wall-clock time increase per step. The Laplacian is a single fused convolution kernel, not point-wise iteration.
150
+
151
+ ## Benchmarks
152
+
153
+ ### PINN: Wave Equation (u_tt = c^2 u_xx)
154
+
155
+ 5-layer, 128-unit tanh MLP trained for 5000 steps on the 1D wave equation.
156
+
157
+ | Optimizer | L2 Error | vs Adam |
158
+ |-----------|----------|---------|
159
+ | Adam (lr=1e-3) | 0.0310 | — |
160
+ | LAdam c²=1e-4 | 0.0240 | -22.8% |
161
+ | **LAdam c²=1e-5** | **0.0172** | **-44.6%** |
162
+ | LAdam c²=1e-3 | 0.0185 | -40.3% |
163
+
164
+ ### Transformer: FashionMNIST Classification
165
+
166
+ 4-head, 128-dim, 2-layer transformer, 30 epochs, 5 independent seeds.
167
+
168
+ | Optimizer | Accuracy (mean ± std) | p-value (vs Adam) |
169
+ |-----------|----------------------|-------------------|
170
+ | Adam | 89.46 ± 0.10% | — |
171
+ | **LAdam c²=1e-4** | **89.66 ± 0.06%** | **0.0005** |
172
+
173
+ ### c² Robustness Sweep
174
+
175
+ 7 c² values on the same transformer task. **All 7 beat Adam:**
176
+
177
+ | c² | Accuracy | Δ vs Adam |
178
+ |----|----------|-----------|
179
+ | 1e-6 | 89.62% | +0.16% |
180
+ | 5e-6 | 89.73% | +0.27% |
181
+ | 1e-5 | 89.79% | +0.33% |
182
+ | 5e-5 | 89.75% | +0.29% |
183
+ | 1e-4 | 89.67% | +0.21% |
184
+ | 5e-4 | 89.64% | +0.18% |
185
+ | 1e-3 | 89.66% | +0.20% |
186
+
187
+ ## FAQ
188
+
189
+ **Q: Does this work for LLMs / GPT-scale models?**
190
+ A: No. LAdam **hurts** LLM training (tested on GPT-2/WikiText-2). Attention weight matrices encode semantic structure, not spatial structure — the Laplacian destroys per-feature specialization. Use standard Adam/AdamW for LLMs.
191
+
192
+ **Q: Why not smooth the gradient instead of the variance?**
193
+ A: [Osher et al. (2018)](https://arxiv.org/abs/1806.06317) explored Laplacian smoothing of gradients. We found that smoothing the *variance estimate* is more effective because it smooths the *learning rate landscape* rather than the *descent direction*. These are mathematically distinct: ∇²(EMA(g²)) ≠ (∇²g)².
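A toy 1-D check of that distinction (ignoring the EMA for simplicity; plain Python, not ladam code):

```python
def lap1d(x):
    """Circular 1-D Laplacian: x[i-1] - 2*x[i] + x[i+1]."""
    n = len(x)
    return [x[(i - 1) % n] - 2 * x[i] + x[(i + 1) % n] for i in range(n)]

g = [1.0, 2.0, 4.0]
smoothed_variance = lap1d([gi * gi for gi in g])    # Laplacian acts on g**2
squared_smoothing = [li * li for li in lap1d(g)]    # Laplacian acts on g, then square
assert smoothed_variance != squared_smoothing       # [18, 9, -27] vs [16, 1, 25]
```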
194
+
195
+ **Q: Why does this help PINNs so much?**
196
+ A: PDE-based loss landscapes have inherent spatial structure from the differential operators in the loss function. The Laplacian on v_t aligns the optimizer's internal representation with this structure.
197
+
198
+ **Q: Can I use this with learning rate schedulers?**
199
+ A: Yes. LAdam is fully compatible with any `torch.optim.lr_scheduler`.
200
+
201
+ ## Citation
202
+
203
+ If you use LAdam in your research, please cite:
204
+
205
+ ```bibtex
206
+ @software{partin2026ladam,
207
+ author = {Partin, Greg},
208
+ title = {LAdam: Spatially-Aware Adaptive Optimization via Laplacian-Regularized Variance Estimates},
209
+ year = {2026},
210
+ url = {https://github.com/gpartin/ladam}
211
+ }
212
+ ```
213
+
214
+ ## License
215
+
216
+ MIT. See [LICENSE](LICENSE) for details.
ladam-0.2.0/pyproject.toml ADDED
@@ -0,0 +1,49 @@
1
+ [build-system]
2
+ requires = ["hatchling"]
3
+ build-backend = "hatchling.build"
4
+
5
+ [project]
6
+ name = "ladam"
7
+ version = "0.2.0"
8
+ description = "LAdam: Laplacian Adam — Adam with spatially-coupled variance estimates via discrete Laplacian"
9
+ readme = "README.md"
10
+ license = {text = "MIT"}
11
+ requires-python = ">=3.8"
12
+ authors = [
13
+ {name = "Greg Partin"},
14
+ ]
15
+ keywords = [
16
+ "optimizer", "adam", "deep-learning", "pytorch",
17
+ "laplacian", "pinn", "scientific-ml", "spatial-regularization",
18
+ ]
19
+ classifiers = [
20
+ "Development Status :: 4 - Beta",
21
+ "Intended Audience :: Science/Research",
22
+ "Intended Audience :: Developers",
23
+ "License :: OSI Approved :: MIT License",
24
+ "Programming Language :: Python :: 3",
25
+ "Programming Language :: Python :: 3.8",
26
+ "Programming Language :: Python :: 3.9",
27
+ "Programming Language :: Python :: 3.10",
28
+ "Programming Language :: Python :: 3.11",
29
+ "Programming Language :: Python :: 3.12",
30
+ "Topic :: Scientific/Engineering :: Artificial Intelligence",
31
+ "Topic :: Scientific/Engineering :: Mathematics",
32
+ ]
33
+ dependencies = [
34
+ "torch>=1.10.0",
35
+ ]
36
+
37
+ [project.optional-dependencies]
38
+ dev = [
39
+ "pytest>=7.0",
40
+ "pytest-cov>=4.0",
41
+ ]
42
+
43
+ [project.urls]
44
+ Homepage = "https://github.com/gpartin/ladam"
45
+ Documentation = "https://github.com/gpartin/ladam#usage"
46
+ Issues = "https://github.com/gpartin/ladam/issues"
47
+
48
+ [tool.hatch.build.targets.wheel]
49
+ packages = ["src/ladam"]
ladam-0.2.0/src/ladam/__init__.py ADDED
@@ -0,0 +1,12 @@
1
+ """
2
+ LAdam — Laplacian Adam Optimizer
3
+ """
4
+
5
+ __version__ = "0.2.0"
6
+
7
+ from .optimizer import LAdam, LAdaGrad, LRMSProp, suggest_c2
8
+
9
+ # Backward compatibility alias
10
+ WaveAdam = LAdam
11
+
12
+ __all__ = ["LAdam", "LAdaGrad", "LRMSProp", "WaveAdam", "suggest_c2"]
ladam-0.2.0/src/ladam/optimizer.py ADDED
@@ -0,0 +1,442 @@
1
+ """
2
+ LAdam — Laplacian Adam Optimizer
3
+ ================================
4
+
5
+ A drop-in Adam replacement that applies discrete Laplacian regularization
6
+ to the second moment estimate (v_t), producing spatially-smoothed adaptive
7
+ learning rates.
8
+
9
+ Key insight: adjacent weights in neural networks are often functionally
10
+ correlated, but Adam computes independent per-parameter learning rates.
11
+ LAdam restores spatial coherence by applying a Laplacian diffusion
12
+ operator to Adam's variance estimate, allowing neighboring weights to
13
+ share curvature information.
14
+
15
+ This is particularly effective for:
16
+ - **PINNs** (physics-informed neural networks): -44.6% L2 error vs Adam
17
+ - **Transformers**: +0.20% accuracy, p=0.0005 across 5 seeds
18
+ - Any architecture with spatially-correlated weight structure
19
+
20
+ Usage:
21
+ from ladam import LAdam
22
+
23
+ optimizer = LAdam(model.parameters(), lr=1e-3, c2=1e-4)
24
+
25
+ for batch in dataloader:
26
+ loss = model(batch)
27
+ loss.backward()
28
+ optimizer.step()
29
+ optimizer.zero_grad()
30
+ """
31
+
32
+ import torch
33
+ import torch.nn.functional as F
34
+ from torch.optim import Optimizer
35
+
36
+
37
+ class LAdam(Optimizer):
38
+ """
39
+ LAdam: Laplacian Adam — Adam with Laplacian-regularized variance estimates.
40
+
41
+ Applies a discrete Laplacian to Adam's second-moment estimate v_t,
42
+ coupling adjacent weight learning rates. This smooths the adaptive
43
+ learning rate landscape, improving convergence for architectures with
44
+ spatially-structured parameters.
45
+
46
+ Modes:
47
+ 'variance_lap': Laplacian on v_t (default, best overall)
48
+ 'update_lap': Laplacian on the update direction
49
+ 'weight_lap': Standard Adam step + weight-space diffusion
50
+
51
+ Args:
52
+ params: Model parameters or param groups.
53
+ lr (float): Learning rate. Default: 1e-3.
54
+ betas (Tuple[float, float]): EMA coefficients for (m_t, v_t).
55
+ Default: (0.9, 0.999).
56
+ eps (float): Numerical stability. Default: 1e-8.
57
+ weight_decay (float): L2 weight decay. Default: 0.
58
+ c2 (float): Laplacian coupling strength. Controls how much
59
+ neighboring variance estimates influence each other.
60
+ Recommended: 1e-5 for PINNs, 1e-4 for transformers.
61
+ Default: 1e-4.
62
+ mode (str): Which internal quantity to apply the Laplacian to.
63
+ Default: 'variance_lap'.
64
+ stencil (str): Laplacian stencil type for 2D parameters.
65
+ '9point': Isotropic stencil [1,4,1;4,-20,4;1,4,1]/6 — 0.46%
66
+ anisotropy. Default. Based on LFM lattice geometry
67
+ (face weight 2/3, diagonal weight 1/6).
68
+ '5point': Standard stencil [0,1,0;1,-4,1;0,1,0] — 12.3%
69
+ anisotropy. Legacy option for reproducing prior results.
70
+ min_spatial_size (int): Skip Laplacian for parameters with fewer
71
+ elements than this threshold. Small params (biases, norms)
72
+ lack meaningful spatial structure. Default: 16.
73
+ """
74
+
75
+ # Pre-built Laplacian kernels (cached on first use per device/stencil)
76
+ _lap_kernels_2d = {}
77
+ _lap_kernel_1d = None
78
+
79
+ def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
80
+ weight_decay=0, c2=1e-4, mode='variance_lap',
81
+ stencil='9point', min_spatial_size=16):
82
+ if lr < 0.0:
83
+ raise ValueError(f"Invalid learning rate: {lr}")
84
+ if eps < 0.0:
85
+ raise ValueError(f"Invalid epsilon value: {eps}")
86
+ if not 0.0 <= betas[0] < 1.0:
87
+ raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
88
+ if not 0.0 <= betas[1] < 1.0:
89
+ raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
90
+ if c2 < 0.0:
91
+ raise ValueError(f"Invalid coupling strength: {c2}")
92
+ if mode not in ('variance_lap', 'update_lap', 'weight_lap'):
93
+ raise ValueError(f"Invalid mode: {mode}. Choose from: "
94
+ "'variance_lap', 'update_lap', 'weight_lap'")
95
+ if stencil not in ('5point', '9point'):
96
+ raise ValueError(f"Invalid stencil: {stencil}. Choose from: "
97
+ "'5point', '9point'")
98
+
99
+ defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay,
100
+ c2=c2, mode=mode, stencil=stencil,
101
+ min_spatial_size=min_spatial_size)
102
+ super().__init__(params, defaults)
103
+
104
+ @classmethod
105
+ def _get_lap_kernel_2d(cls, device, dtype, stencil='9point'):
106
+ """Lazily create and cache the 2D discrete Laplacian kernel.
107
+
108
+ stencil='9point' (default): Isotropic [1,4,1;4,-20,4;1,4,1]/6
109
+ Faces weighted 2/3, diagonals 1/6.
110
+ Only 0.46% anisotropy vs 12.3% for 5-point.
111
+ stencil='5point': Standard [0,1,0;1,-4,1;0,1,0].
112
+ Cross-pattern (faces only); legacy option.
113
+ """
114
+ key = (stencil, device, dtype)
115
+ if key not in cls._lap_kernels_2d:
116
+ if stencil == '9point':
117
+ # Isotropic 2D Laplacian: face weight 4/6, diagonal weight 1/6
118
+ # Minimizes angular anisotropy (0.46% vs 12.3% for 5-point)
119
+ k = torch.tensor([[1/6., 4/6., 1/6.],
120
+ [4/6., -20/6., 4/6.],
121
+ [1/6., 4/6., 1/6.]], device=device, dtype=dtype)
122
+ else: # 5point
123
+ k = torch.tensor([[0., 1., 0.],
124
+ [1., -4., 1.],
125
+ [0., 1., 0.]], device=device, dtype=dtype)
126
+ cls._lap_kernels_2d[key] = k.reshape(1, 1, 3, 3)
127
+ return cls._lap_kernels_2d[key]
128
+
129
+ @classmethod
130
+ def _get_lap_kernel_1d(cls, device, dtype):
131
+ """Lazily create and cache the 1D discrete Laplacian kernel."""
132
+ # Re-create if device or dtype changed (cache holds one kernel)
+ if (cls._lap_kernel_1d is None
+ or cls._lap_kernel_1d.device != device
+ or cls._lap_kernel_1d.dtype != dtype):
133
+ k = torch.tensor([1., -2., 1.], device=device, dtype=dtype)
134
+ cls._lap_kernel_1d = k.reshape(1, 1, 3)
135
+ return cls._lap_kernel_1d
136
+
137
+ @staticmethod
138
+ def _laplacian(W, kernel_2d, kernel_1d):
139
+ """
140
+ Apply discrete Laplacian via fused conv2d — single GPU kernel launch.
141
+
142
+ Uses circular padding to treat weight matrices as periodic fields.
143
+ High-dimensional tensors are reshaped to 2D while preserving spatial
144
+ structure.
145
+ """
146
+ if W.dim() == 1:
147
+ n = W.numel()
148
+ if n < 4:
149
+ return torch.zeros_like(W)
150
+ x = W.reshape(1, 1, n)
151
+ x = F.pad(x, (1, 1), mode='circular')
152
+ return F.conv1d(x, kernel_1d).reshape(W.shape)
153
+ else:
154
+ shape = W.shape
155
+ W2 = W.reshape(shape[0], -1) if W.dim() > 2 else W
156
+ h, w = W2.shape
157
+ if h < 3 or w < 3:
158
+ return torch.zeros_like(W)
159
+ x = W2.reshape(1, 1, h, w)
160
+ x = F.pad(x, (1, 1, 1, 1), mode='circular')
161
+ result = F.conv2d(x, kernel_2d).reshape(W2.shape)
162
+ if W.dim() > 2:
163
+ result = result.reshape(shape)
164
+ return result
165
+
166
+ @torch.no_grad()
167
+ def step(self, closure=None):
168
+ """Performs a single optimization step.
169
+
170
+ Args:
171
+ closure (Callable, optional): A closure that reevaluates the model
172
+ and returns the loss.
173
+ """
174
+ loss = None
175
+ if closure is not None:
176
+ with torch.enable_grad():
177
+ loss = closure()
178
+
179
+ for group in self.param_groups:
180
+ lr = group['lr']
181
+ beta1, beta2 = group['betas']
182
+ eps = group['eps']
183
+ c2 = group['c2']
184
+ mode = group['mode']
185
+ wd = group['weight_decay']
186
+ min_sp = group['min_spatial_size']
187
+
188
+ stencil = group['stencil']
189
+
190
+ for p in group['params']:
191
+ if p.grad is None:
192
+ continue
193
+
194
+ grad = p.grad.data
195
+ if wd != 0:
196
+ grad = grad.add(p.data, alpha=wd)
197
+
198
+ state = self.state[p]
199
+
200
+ # State initialization
201
+ if len(state) == 0:
202
+ state['step'] = 0
203
+ state['m'] = torch.zeros_like(p.data)
204
+ state['v'] = torch.zeros_like(p.data)
205
+ state['use_lap'] = (c2 > 0 and p.data.numel() >= min_sp)
206
+
207
+ state['step'] += 1
208
+ m, v = state['m'], state['v']
209
+ step = state['step']
210
+ use_lap = state['use_lap']
211
+
212
+ # Standard Adam momentum updates
213
+ m.mul_(beta1).add_(grad, alpha=1 - beta1)
214
+ v.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
215
+
216
+ # Bias correction
217
+ m_hat = m / (1 - beta1 ** step)
218
+ v_hat = v / (1 - beta2 ** step)
219
+
220
+ if mode == 'variance_lap':
221
+ # Core innovation: couple neighboring variance estimates
222
+ if use_lap:
223
+ k2d = self._get_lap_kernel_2d(p.device, v_hat.dtype, stencil)
224
+ k1d = self._get_lap_kernel_1d(p.device, v_hat.dtype)
225
+ v_smooth = v_hat + c2 * self._laplacian(v_hat, k2d, k1d)
226
+ v_smooth = v_smooth.clamp(min=eps)
227
+ else:
228
+ v_smooth = v_hat
229
+ p.data.addcdiv_(m_hat, v_smooth.sqrt().add_(eps), value=-lr)
230
+
231
+ elif mode == 'update_lap':
232
+ update = m_hat / (v_hat.sqrt() + eps)
233
+ if use_lap:
234
+ k2d = self._get_lap_kernel_2d(p.device, update.dtype, stencil)
235
+ k1d = self._get_lap_kernel_1d(p.device, update.dtype)
236
+ update = update + c2 * self._laplacian(update, k2d, k1d)
237
+ p.data.add_(update, alpha=-lr)
238
+
239
+ elif mode == 'weight_lap':
240
+ p.data.addcdiv_(m_hat, v_hat.sqrt().add_(eps), value=-lr)
241
+ if use_lap:
242
+ k2d = self._get_lap_kernel_2d(p.device, p.data.dtype, stencil)
243
+ k1d = self._get_lap_kernel_1d(p.device, p.data.dtype)
244
+ p.data.add_(self._laplacian(p.data, k2d, k1d), alpha=c2)
245
+
246
+ return loss
247
+
248
+
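As a worked example of the `variance_lap` mode, here is one bias-corrected step on a tiny 1-D parameter in plain Python: smooth `v_hat` with the Laplacian, clamp at `eps`, then take the Adam-style step. This is a sketch under the assumption of the `[1, -2, 1]` circular stencil, not the package API:

```python
import math

def one_variance_lap_step(w, m_hat, v_hat, lr=1e-3, c2=1e-4, eps=1e-8):
    """w <- w - lr * m_hat / (sqrt(v_hat + c2 * Lap(v_hat)) + eps)."""
    n = len(v_hat)
    lap = [v_hat[(i - 1) % n] - 2 * v_hat[i] + v_hat[(i + 1) % n] for i in range(n)]
    v_smooth = [max(v + c2 * l, eps) for v, l in zip(v_hat, lap)]  # clamp(min=eps)
    return [wi - lr * mi / (math.sqrt(vs) + eps)
            for wi, mi, vs in zip(w, m_hat, v_smooth)]

w = [0.5, -0.2, 0.1]
m_hat = [0.1, 0.1, 0.1]
v_hat = [0.01, 0.04, 0.01]
print(one_variance_lap_step(w, m_hat, v_hat))
```

Setting `c2=0` recovers the plain Adam denominator `sqrt(v_hat) + eps`, which is the equivalence the test suite checks against `torch.optim.Adam`.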
249
+ def suggest_c2(architecture: str = 'auto') -> float:
250
+ """
251
+ Suggest a c2 value based on architecture type.
252
+
253
+ Based on systematic sweeps across 7 c2 values and 4 architecture types.
254
+
255
+ Args:
256
+ architecture: One of 'pinn', 'transformer', 'mlp', 'cnn', or 'auto'.
257
+ 'auto' returns a safe default.
258
+
259
+ Returns:
260
+ Recommended c2 value.
261
+ """
262
+ recommendations = {
263
+ 'pinn': 1e-5, # -44.6% L2 error vs Adam
264
+ 'transformer': 1e-4, # +0.20% acc, p=0.0005
265
+ 'mlp': 1e-5, # Marginal improvement, safe value
266
+ 'cnn': 0.0, # No benefit; skip Laplacian
267
+ 'auto': 1e-4, # Safe general-purpose default
268
+ }
269
+ arch = architecture.lower()
270
+ if arch not in recommendations:
271
+ raise ValueError(f"Unknown architecture: {arch}. "
272
+ f"Choose from: {list(recommendations.keys())}")
273
+ return recommendations[arch]
274
+
275
+
276
+ class LAdaGrad(Optimizer):
277
+ """
278
+ LAdaGrad: AdaGrad with Laplacian-regularized accumulator.
279
+
280
+ Standard AdaGrad accumulates squared gradients independently per parameter.
281
+ LAdaGrad applies discrete Laplacian diffusion to the accumulator, letting
282
+ neighboring weights share curvature information.
283
+
284
+ Args:
285
+ params: Model parameters or param groups.
286
+ lr (float): Learning rate. Default: 1e-2.
287
+ lr_decay (float): Learning rate decay. Default: 0.
288
+ eps (float): Numerical stability. Default: 1e-10.
289
+ weight_decay (float): L2 weight decay. Default: 0.
290
+ c2 (float): Laplacian coupling strength. Default: 1e-4.
291
+ stencil (str): '9point' or '5point'. Default: '9point'.
292
+ min_spatial_size (int): Skip Laplacian below this size. Default: 16.
293
+ """
294
+
295
+ def __init__(self, params, lr=1e-2, lr_decay=0, eps=1e-10,
296
+ weight_decay=0, c2=1e-4, stencil='9point',
297
+ min_spatial_size=16):
298
+ if lr < 0.0:
299
+ raise ValueError(f"Invalid learning rate: {lr}")
300
+ if c2 < 0.0:
301
+ raise ValueError(f"Invalid coupling strength: {c2}")
302
+ defaults = dict(lr=lr, lr_decay=lr_decay, eps=eps,
303
+ weight_decay=weight_decay, c2=c2,
304
+ stencil=stencil, min_spatial_size=min_spatial_size)
305
+ super().__init__(params, defaults)
306
+
307
+ @torch.no_grad()
308
+ def step(self, closure=None):
309
+ loss = None
310
+ if closure is not None:
311
+ with torch.enable_grad():
312
+ loss = closure()
313
+
314
+ for group in self.param_groups:
315
+ lr = group['lr']
316
+ lr_decay = group['lr_decay']
317
+ eps = group['eps']
318
+ c2 = group['c2']
319
+ wd = group['weight_decay']
320
+ min_sp = group['min_spatial_size']
321
+ stencil = group['stencil']
322
+
323
+ for p in group['params']:
324
+ if p.grad is None:
325
+ continue
326
+
327
+ grad = p.grad.data
328
+ if wd != 0:
329
+ grad = grad.add(p.data, alpha=wd)
330
+
331
+ state = self.state[p]
332
+ if len(state) == 0:
333
+ state['step'] = 0
334
+ state['sum'] = torch.zeros_like(p.data)
335
+ state['use_lap'] = (c2 > 0 and p.data.numel() >= min_sp)
336
+
337
+ state['step'] += 1
338
+ state['sum'].addcmul_(grad, grad, value=1)
339
+
340
+ clr = lr / (1 + (state['step'] - 1) * lr_decay)
341
+
342
+ acc = state['sum']
343
+ if state['use_lap']:
344
+ k2d = LAdam._get_lap_kernel_2d(p.device, acc.dtype, stencil)
345
+ k1d = LAdam._get_lap_kernel_1d(p.device, acc.dtype)
346
+ acc_smooth = acc + c2 * LAdam._laplacian(acc, k2d, k1d)
347
+ acc_smooth = acc_smooth.clamp(min=0)
348
+ else:
349
+ acc_smooth = acc
350
+
351
+ p.data.addcdiv_(grad, acc_smooth.sqrt().add_(eps), value=-clr)
352
+
353
+ return loss
354
+
355
+
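The same idea in AdaGrad form: the accumulator grows monotonically, and the Laplacian diffuses it across neighbors before the square root. A plain-Python sketch of one LAdaGrad-style step (hypothetical names, 3-point circular stencil assumed):

```python
import math

def ladagrad_step(w, grad, acc, lr=1e-2, c2=1e-4, eps=1e-10):
    """One LAdaGrad-style step: acc += g*g, then smooth acc before dividing."""
    acc = [a + g * g for a, g in zip(acc, grad)]  # monotone accumulator
    n = len(acc)
    lap = [acc[(i - 1) % n] - 2 * acc[i] + acc[(i + 1) % n] for i in range(n)]
    smooth = [max(a + c2 * l, 0.0) for a, l in zip(acc, lap)]  # clamp(min=0)
    w = [wi - lr * g / (math.sqrt(s) + eps) for wi, g, s in zip(w, grad, smooth)]
    return w, acc

w, acc = [1.0, 1.0, 1.0], [0.0, 0.0, 0.0]
for _ in range(3):
    w, acc = ladagrad_step(w, [0.1, 0.2, 0.1], acc)
print(w, acc)
```

Note the stored accumulator stays the raw sum of squared gradients; only the denominator sees the smoothed copy, matching how `state['sum']` is kept unsmoothed above.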
356
+ class LRMSProp(Optimizer):
357
+ """
358
+ LRMSProp: RMSProp with Laplacian-regularized running average.
359
+
360
+ Standard RMSProp maintains a decaying average of squared gradients per
361
+ parameter. LRMSProp applies discrete Laplacian diffusion to this average,
362
+ coupling neighboring weights.
363
+
364
+ Args:
365
+ params: Model parameters or param groups.
366
+ lr (float): Learning rate. Default: 1e-2.
367
+ alpha (float): Smoothing constant (decay). Default: 0.99.
368
+ eps (float): Numerical stability. Default: 1e-8.
369
+ weight_decay (float): L2 weight decay. Default: 0.
370
+ momentum (float): Momentum factor. Default: 0.
371
+ c2 (float): Laplacian coupling strength. Default: 1e-4.
372
+ stencil (str): '9point' or '5point'. Default: '9point'.
373
+ min_spatial_size (int): Skip Laplacian below this size. Default: 16.
374
+ """
375
+
376
+ def __init__(self, params, lr=1e-2, alpha=0.99, eps=1e-8,
377
+ weight_decay=0, momentum=0, c2=1e-4, stencil='9point',
378
+ min_spatial_size=16):
379
+ if lr < 0.0:
380
+ raise ValueError(f"Invalid learning rate: {lr}")
381
+ if alpha < 0.0 or alpha >= 1.0:
382
+ raise ValueError(f"Invalid alpha: {alpha}")
383
+ if c2 < 0.0:
384
+ raise ValueError(f"Invalid coupling strength: {c2}")
385
+ defaults = dict(lr=lr, alpha=alpha, eps=eps, weight_decay=weight_decay,
386
+ momentum=momentum, c2=c2, stencil=stencil,
387
+ min_spatial_size=min_spatial_size)
388
+ super().__init__(params, defaults)
389
+
390
+ @torch.no_grad()
391
+ def step(self, closure=None):
392
+ loss = None
393
+ if closure is not None:
394
+ with torch.enable_grad():
395
+ loss = closure()
396
+
397
+ for group in self.param_groups:
398
+ lr = group['lr']
399
+ alpha = group['alpha']
400
+ eps = group['eps']
401
+ c2 = group['c2']
402
+ wd = group['weight_decay']
403
+ mom = group['momentum']
404
+ min_sp = group['min_spatial_size']
405
+ stencil = group['stencil']
406
+
407
+ for p in group['params']:
408
+ if p.grad is None:
409
+ continue
410
+
411
+ grad = p.grad.data
412
+ if wd != 0:
413
+ grad = grad.add(p.data, alpha=wd)
414
+
415
+ state = self.state[p]
416
+ if len(state) == 0:
417
+ state['step'] = 0
418
+ state['square_avg'] = torch.zeros_like(p.data)
419
+ if mom > 0:
420
+ state['momentum_buffer'] = torch.zeros_like(p.data)
421
+ state['use_lap'] = (c2 > 0 and p.data.numel() >= min_sp)
422
+
423
+ state['step'] += 1
424
+ sq_avg = state['square_avg']
425
+ sq_avg.mul_(alpha).addcmul_(grad, grad, value=1 - alpha)
426
+
427
+ if state['use_lap']:
428
+ k2d = LAdam._get_lap_kernel_2d(p.device, sq_avg.dtype, stencil)
429
+ k1d = LAdam._get_lap_kernel_1d(p.device, sq_avg.dtype)
430
+ avg_smooth = sq_avg + c2 * LAdam._laplacian(sq_avg, k2d, k1d)
431
+ avg_smooth = avg_smooth.clamp(min=0)
432
+ else:
433
+ avg_smooth = sq_avg
434
+
435
+ if mom > 0:
436
+ buf = state['momentum_buffer']
437
+ buf.mul_(mom).addcdiv_(grad, avg_smooth.sqrt().add_(eps))
438
+ p.data.add_(buf, alpha=-lr)
439
+ else:
440
+ p.data.addcdiv_(grad, avg_smooth.sqrt().add_(eps), value=-lr)
441
+
442
+ return loss
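A quick sanity check of the reduction property (the test suite's `test_c2_zero_equals_adam` checks the analogous fact for LAdam): with `c2 = 0` the smoothed average equals the plain EMA, so LRMSProp collapses to vanilla RMSProp. Plain-Python sketch, momentum-free case, 3-point circular stencil assumed:

```python
import math

def lrmsprop_step(w, grad, sq_avg, lr=1e-2, alpha=0.99, c2=1e-4, eps=1e-8):
    """One LRMSProp-style step (no momentum): EMA of g^2, smoothed, then divide."""
    sq_avg = [a * alpha + (1 - alpha) * g * g for a, g in zip(sq_avg, grad)]
    n = len(sq_avg)
    lap = [sq_avg[(i - 1) % n] - 2 * sq_avg[i] + sq_avg[(i + 1) % n] for i in range(n)]
    smooth = [max(a + c2 * l, 0.0) for a, l in zip(sq_avg, lap)]
    w = [wi - lr * g / (math.sqrt(s) + eps) for wi, g, s in zip(w, grad, smooth)]
    return w, sq_avg

w0, g, v0 = [1.0, 1.0, 1.0], [0.3, -0.1, 0.2], [0.0, 0.0, 0.0]
coupled, _ = lrmsprop_step(w0, g, v0)          # c2 = 1e-4: neighbors coupled
vanilla, _ = lrmsprop_step(w0, g, v0, c2=0.0)  # reduces to plain RMSProp
print(coupled, vanilla)
```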
@@ -0,0 +1,182 @@
1
+ """Tests for LAdam optimizer."""
2
+
3
+ import pytest
4
+ import torch
5
+ import torch.nn as nn
6
+ from ladam import LAdam, suggest_c2
7
+
8
+
9
+ class SimpleMLP(nn.Module):
10
+ def __init__(self):
11
+ super().__init__()
12
+ self.fc1 = nn.Linear(20, 64)
13
+ self.fc2 = nn.Linear(64, 64)
14
+ self.fc3 = nn.Linear(64, 1)
15
+
16
+ def forward(self, x):
17
+ x = torch.relu(self.fc1(x))
18
+ x = torch.relu(self.fc2(x))
19
+ return self.fc3(x)
20
+
21
+
22
+ @pytest.fixture
23
+ def model():
24
+ return SimpleMLP()
25
+
26
+
27
+ @pytest.fixture
28
+ def data():
29
+ torch.manual_seed(42)
30
+ X = torch.randn(100, 20)
31
+ y = torch.randn(100, 1)
32
+ return X, y
33
+
34
+
35
+ def test_basic_step(model, data):
36
+ """LAdam can perform a basic optimization step."""
37
+ X, y = data
38
+ opt = LAdam(model.parameters(), lr=1e-3, c2=1e-4)
39
+ loss_fn = nn.MSELoss()
40
+
41
+ loss_before = loss_fn(model(X), y).item()
42
+ for _ in range(10):
43
+ opt.zero_grad()
44
+ loss = loss_fn(model(X), y)
45
+ loss.backward()
46
+ opt.step()
47
+ loss_after = loss_fn(model(X), y).item()
48
+
49
+ assert loss_after < loss_before, "Loss should decrease after optimization"
50
+
51
+
52
+ def test_c2_zero_equals_adam(model, data):
53
+ """With c2=0, LAdam should behave identically to Adam."""
54
+ X, y = data
55
+ loss_fn = nn.MSELoss()
56
+
57
+ # LAdam with c2=0
58
+ torch.manual_seed(0)
59
+ m1 = SimpleMLP()
60
+ opt1 = LAdam(m1.parameters(), lr=1e-3, c2=0.0)
61
+ for _ in range(5):
62
+ opt1.zero_grad()
63
+ loss_fn(m1(X), y).backward()
64
+ opt1.step()
65
+
66
+ # Standard Adam
67
+ torch.manual_seed(0)
68
+ m2 = SimpleMLP()
69
+ opt2 = torch.optim.Adam(m2.parameters(), lr=1e-3)
70
+ for _ in range(5):
71
+ opt2.zero_grad()
72
+ loss_fn(m2(X), y).backward()
73
+ opt2.step()
74
+
75
+ for p1, p2 in zip(m1.parameters(), m2.parameters()):
76
+ assert torch.allclose(p1, p2, atol=1e-6), \
77
+ "c2=0 LAdam should match standard Adam"
78
+
79
+
80
+ def test_all_modes(model, data):
81
+ """All three modes should run without error and reduce loss."""
82
+ X, y = data
83
+ loss_fn = nn.MSELoss()
84
+
85
+ for mode in ['variance_lap', 'update_lap', 'weight_lap']:
86
+ torch.manual_seed(42)
87
+ m = SimpleMLP()
88
+ opt = LAdam(m.parameters(), lr=1e-3, c2=1e-4, mode=mode)
89
+
90
+ loss_before = loss_fn(m(X), y).item()
91
+ for _ in range(20):
92
+ opt.zero_grad()
93
+ loss_fn(m(X), y).backward()
94
+ opt.step()
95
+ loss_after = loss_fn(m(X), y).item()
96
+
97
+ assert loss_after < loss_before, f"Mode '{mode}' should reduce loss"
98
+
99
+
100
+ def test_invalid_params():
101
+ """Invalid parameters should raise ValueError."""
102
+ m = SimpleMLP()
103
+ with pytest.raises(ValueError):
104
+ LAdam(m.parameters(), lr=-1)
105
+ with pytest.raises(ValueError):
106
+ LAdam(m.parameters(), c2=-1)
107
+ with pytest.raises(ValueError):
108
+ LAdam(m.parameters(), mode='invalid')
109
+ with pytest.raises(ValueError):
110
+ LAdam(m.parameters(), betas=(1.5, 0.999))
111
+
112
+
113
+ def test_suggest_c2():
114
+ """Architecture-specific c2 suggestions."""
115
+ assert suggest_c2('pinn') == 1e-5
116
+ assert suggest_c2('transformer') == 1e-4
117
+ assert suggest_c2('cnn') == 0.0
118
+ assert suggest_c2('auto') == 1e-4
119
+ with pytest.raises(ValueError):
120
+ suggest_c2('unknown_arch')
121
+
122
+
123
+ def test_param_groups(model, data):
124
+ """Per-group c2 values should work."""
125
+ X, y = data
126
+ loss_fn = nn.MSELoss()
127
+
128
+ opt = LAdam([
129
+ {'params': model.fc1.parameters(), 'c2': 1e-4},
130
+ {'params': model.fc2.parameters(), 'c2': 1e-5},
131
+ {'params': model.fc3.parameters(), 'c2': 0.0},
132
+ ], lr=1e-3)
133
+
134
+ for _ in range(10):
135
+ opt.zero_grad()
136
+ loss_fn(model(X), y).backward()
137
+ opt.step()
138
+
139
+
140
+ def test_gpu_if_available(data):
141
+ """LAdam should work on GPU if available."""
142
+ if not torch.cuda.is_available():
143
+ pytest.skip("CUDA not available")
144
+
145
+ X, y = data
146
+ X, y = X.cuda(), y.cuda()
147
+ m = SimpleMLP().cuda()
148
+ opt = LAdam(m.parameters(), lr=1e-3, c2=1e-4)
149
+ loss_fn = nn.MSELoss()
150
+
151
+ for _ in range(5):
152
+ opt.zero_grad()
153
+ loss_fn(m(X), y).backward()
154
+ opt.step()
155
+
156
+
157
+ def test_with_weight_decay(model, data):
158
+ """Weight decay should work alongside Laplacian."""
159
+ X, y = data
160
+ opt = LAdam(model.parameters(), lr=1e-3, c2=1e-4, weight_decay=0.01)
161
+ loss_fn = nn.MSELoss()
162
+
163
+ for _ in range(10):
164
+ opt.zero_grad()
165
+ loss_fn(model(X), y).backward()
166
+ opt.step()
167
+
168
+
169
+ def test_closure(model, data):
170
+ """Closure-based step should work (for LBFGS-style usage)."""
171
+ X, y = data
172
+ opt = LAdam(model.parameters(), lr=1e-3, c2=1e-4)
173
+ loss_fn = nn.MSELoss()
174
+
175
+ def closure():
176
+ opt.zero_grad()
177
+ loss = loss_fn(model(X), y)
178
+ loss.backward()
179
+ return loss
180
+
181
+ loss = opt.step(closure)
182
+ assert loss is not None