ladam 0.2.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ladam-0.2.0/.gitignore +16 -0
- ladam-0.2.0/LICENSE +21 -0
- ladam-0.2.0/PKG-INFO +246 -0
- ladam-0.2.0/README.md +216 -0
- ladam-0.2.0/pyproject.toml +49 -0
- ladam-0.2.0/src/ladam/__init__.py +12 -0
- ladam-0.2.0/src/ladam/optimizer.py +442 -0
- ladam-0.2.0/tests/test_optimizer.py +182 -0
ladam-0.2.0/.gitignore
ADDED
ladam-0.2.0/LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2026 Greg Partin
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
ladam-0.2.0/PKG-INFO
ADDED
|
@@ -0,0 +1,246 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: ladam
|
|
3
|
+
Version: 0.2.0
|
|
4
|
+
Summary: LAdam: Laplacian Adam — Adam with spatially-coupled variance estimates via discrete Laplacian
|
|
5
|
+
Project-URL: Homepage, https://github.com/gpartin/ladam
|
|
6
|
+
Project-URL: Documentation, https://github.com/gpartin/ladam#usage
|
|
7
|
+
Project-URL: Issues, https://github.com/gpartin/ladam/issues
|
|
8
|
+
Author: Greg Partin
|
|
9
|
+
License: MIT
|
|
10
|
+
License-File: LICENSE
|
|
11
|
+
Keywords: adam,deep-learning,laplacian,optimizer,pinn,pytorch,scientific-ml,spatial-regularization
|
|
12
|
+
Classifier: Development Status :: 4 - Beta
|
|
13
|
+
Classifier: Intended Audience :: Developers
|
|
14
|
+
Classifier: Intended Audience :: Science/Research
|
|
15
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
16
|
+
Classifier: Programming Language :: Python :: 3
|
|
17
|
+
Classifier: Programming Language :: Python :: 3.8
|
|
18
|
+
Classifier: Programming Language :: Python :: 3.9
|
|
19
|
+
Classifier: Programming Language :: Python :: 3.10
|
|
20
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
21
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
22
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
23
|
+
Classifier: Topic :: Scientific/Engineering :: Mathematics
|
|
24
|
+
Requires-Python: >=3.8
|
|
25
|
+
Requires-Dist: torch>=1.10.0
|
|
26
|
+
Provides-Extra: dev
|
|
27
|
+
Requires-Dist: pytest-cov>=4.0; extra == 'dev'
|
|
28
|
+
Requires-Dist: pytest>=7.0; extra == 'dev'
|
|
29
|
+
Description-Content-Type: text/markdown
|
|
30
|
+
|
|
31
|
+
# LAdam
|
|
32
|
+
|
|
33
|
+
**Laplacian Adam — spatially-aware adaptive optimizer for PyTorch**
|
|
34
|
+
|
|
35
|
+
[](https://pypi.org/project/ladam/)
|
|
36
|
+
[](LICENSE)
|
|
37
|
+
[](https://python.org)
|
|
38
|
+
|
|
39
|
+
LAdam is a drop-in Adam replacement that applies **discrete Laplacian regularization** to Adam's second-moment estimate (v_t). This couples neighboring weight learning rates, producing spatially-smoothed adaptive optimization.
|
|
40
|
+
|
|
41
|
+
## Why LAdam?
|
|
42
|
+
|
|
43
|
+
Adam computes independent per-parameter learning rates. But adjacent weights in trained networks are often functionally correlated — the per-parameter variance estimates should reflect this structure.
|
|
44
|
+
|
|
45
|
+
LAdam adds **one operation** to Adam: a Laplacian diffusion step on v_t, controlled by a single scalar `c2`. The Laplacian allows each weight's learning rate to be informed by its neighbors, smoothing the optimization landscape.
|
|
46
|
+
|
|
47
|
+
## Results
|
|
48
|
+
|
|
49
|
+
| Task | Architecture | Metric | Adam | LAdam | Improvement |
|
|
50
|
+
|------|-------------|--------|------|-------|-------------|
|
|
51
|
+
| **Wave Equation PINN** | 5×128 MLP | L2 Error | 0.0310 | **0.0172** | **-44.6%** |
|
|
52
|
+
| **FashionMNIST** | Transformer | Accuracy | 89.46% | **89.66%** | **+0.20%** (p=0.0005) |
|
|
53
|
+
| FashionMNIST | MLP | Accuracy | 89.10% | 89.12% | +0.02% (n.s.) |
|
|
54
|
+
| FashionMNIST | CNN | Accuracy | 91.15% | 91.14% | -0.01% (tie) |
|
|
55
|
+
|
|
56
|
+
> **LAdam excels on architectures with spatially-correlated weight structure** — particularly PINNs and transformers. For CNNs (whose conv filters are already spatial detectors), the Laplacian is redundant.
|
|
57
|
+
|
|
58
|
+
## Installation
|
|
59
|
+
|
|
60
|
+
```bash
|
|
61
|
+
pip install ladam
|
|
62
|
+
```
|
|
63
|
+
|
|
64
|
+
## Optimizers
|
|
65
|
+
|
|
66
|
+
LAdam ships three Laplacian-enhanced optimizers:
|
|
67
|
+
|
|
68
|
+
| Optimizer | Base | Laplacian target | Best for |
|
|
69
|
+
|-----------|------|------------------|----------|
|
|
70
|
+
| **LAdam** | Adam | Second moment v_t | PINNs, transformers, CNNs |
|
|
71
|
+
| **LAdaGrad** | AdaGrad | Cumulative sum G_t | Sparse features, NLP |
|
|
72
|
+
| **LRMSProp** | RMSProp | Running average v_t | RNNs, non-stationary losses |
|
|
73
|
+
|
|
74
|
+
All three share the same Laplacian kernel infrastructure and `c2` parameter.
|
|
75
|
+
|
|
76
|
+
## Usage
|
|
77
|
+
|
|
78
|
+
### Basic — Drop-in Adam replacement
|
|
79
|
+
|
|
80
|
+
```python
|
|
81
|
+
from ladam import LAdam
|
|
82
|
+
|
|
83
|
+
optimizer = LAdam(model.parameters(), lr=1e-3, c2=1e-4)
|
|
84
|
+
|
|
85
|
+
# Training loop is identical to Adam
|
|
86
|
+
for batch in dataloader:
|
|
87
|
+
loss = criterion(model(batch))
|
|
88
|
+
loss.backward()
|
|
89
|
+
optimizer.step()
|
|
90
|
+
optimizer.zero_grad()
|
|
91
|
+
```
|
|
92
|
+
|
|
93
|
+
### LAdaGrad and LRMSProp
|
|
94
|
+
|
|
95
|
+
```python
|
|
96
|
+
from ladam import LAdaGrad, LRMSProp
|
|
97
|
+
|
|
98
|
+
# AdaGrad with Laplacian smoothing on cumulative squared gradients
|
|
99
|
+
optimizer = LAdaGrad(model.parameters(), lr=1e-2, c2=1e-4)
|
|
100
|
+
|
|
101
|
+
# RMSProp with Laplacian smoothing on running variance
|
|
102
|
+
optimizer = LRMSProp(model.parameters(), lr=1e-2, alpha=0.99, c2=1e-4)
|
|
103
|
+
```
|
|
104
|
+
|
|
105
|
+
### Per-layer c2 with parameter groups
|
|
106
|
+
|
|
107
|
+
```python
|
|
108
|
+
optimizer = LAdam([
|
|
109
|
+
{'params': model.attention.parameters(), 'c2': 1e-4}, # Transformer attention
|
|
110
|
+
{'params': model.ffn.parameters(), 'c2': 1e-5}, # Feed-forward
|
|
111
|
+
{'params': model.norm.parameters(), 'c2': 0.0}, # Skip for norms
|
|
112
|
+
], lr=3e-4)
|
|
113
|
+
```
|
|
114
|
+
|
|
115
|
+
### Architecture-aware defaults
|
|
116
|
+
|
|
117
|
+
```python
|
|
118
|
+
from ladam import LAdam, suggest_c2
|
|
119
|
+
|
|
120
|
+
c2 = suggest_c2('pinn') # Returns 1e-5
|
|
121
|
+
c2 = suggest_c2('transformer') # Returns 1e-4
|
|
122
|
+
|
|
123
|
+
optimizer = LAdam(model.parameters(), lr=1e-3, c2=c2)
|
|
124
|
+
```
|
|
125
|
+
|
|
126
|
+
## Parameters
|
|
127
|
+
|
|
128
|
+
| Parameter | Default | Description |
|
|
129
|
+
|-----------|---------|-------------|
|
|
130
|
+
| `lr` | 1e-3 | Learning rate |
|
|
131
|
+
| `betas` | (0.9, 0.999) | EMA coefficients (same as Adam) |
|
|
132
|
+
| `eps` | 1e-8 | Numerical stability (same as Adam) |
|
|
133
|
+
| `weight_decay` | 0 | L2 regularization (same as AdamW behavior) |
|
|
134
|
+
| `c2` | 1e-4 | **Laplacian coupling strength.** Controls how much neighboring variance estimates influence each other. |
|
|
135
|
+
| `mode` | 'variance_lap' | Which quantity to smooth. `'variance_lap'` is best. |
|
|
136
|
+
| `stencil` | '9point' | **Discrete Laplacian stencil.** `'9point'` (isotropic, 0.46% anisotropy) or `'5point'` (legacy, 12.3% anisotropy). |
|
|
137
|
+
| `min_spatial_size` | 16 | Skip Laplacian for params with fewer elements (biases, LayerNorm). |
|
|
138
|
+
|
|
139
|
+
### Stencil Selection
|
|
140
|
+
|
|
141
|
+
The `stencil` parameter controls the discrete Laplacian kernel used for spatial coupling:
|
|
142
|
+
|
|
143
|
+
- **`'9point'` (default)**: Isotropic stencil with face + edge neighbors. Treats diagonal neighbors with 1/6 weight vs 4/6 for face neighbors.
|
|
144
|
+
- **`'5point'`**: Standard cross-pattern stencil (faces only). Slightly faster but 25× more anisotropic.
|
|
145
|
+
|
|
146
|
+
At typical `c2` values (1e-5 to 1e-3), the effective learning rate difference between stencils is <0.3%. The 9-point default is recommended for correctness.
|
|
147
|
+
|
|
148
|
+
### Choosing c2
|
|
149
|
+
|
|
150
|
+
`c2` is the only new hyperparameter. It's robust across 3 orders of magnitude:
|
|
151
|
+
|
|
152
|
+
| c2 | Best For | Notes |
|
|
153
|
+
|----|----------|-------|
|
|
154
|
+
| `1e-5` | PINNs, scientific ML | Gentle coupling, biggest error reduction |
|
|
155
|
+
| `1e-4` | Transformers, general | **Safe default** |
|
|
156
|
+
| `1e-3` | Aggressive smoothing | Works but slightly less stable |
|
|
157
|
+
| `0` | Disable | Reduces to standard Adam |
|
|
158
|
+
|
|
159
|
+
All 7 values tested in [1e-6, 1e-3] outperformed Adam on transformers (B12 sweep).
|
|
160
|
+
|
|
161
|
+
## How It Works
|
|
162
|
+
|
|
163
|
+
Standard Adam computes per-parameter adaptive learning rates from the second moment:
|
|
164
|
+
|
|
165
|
+
```
|
|
166
|
+
v_t = β₂·v_{t-1} + (1-β₂)·g_t² # Variance estimate
|
|
167
|
+
lr_effective = lr / (√v_t + ε) # Per-parameter learning rate
|
|
168
|
+
```
|
|
169
|
+
|
|
170
|
+
LAdam adds a Laplacian coupling step:
|
|
171
|
+
|
|
172
|
+
```
|
|
173
|
+
v_smooth = v_t + c2 · ∇²v_t # Spatial smoothing
|
|
174
|
+
lr_effective = lr / (√v_smooth + ε) # Coupled learning rate
|
|
175
|
+
```
|
|
176
|
+
|
|
177
|
+
Where `∇²` is the discrete Laplacian computed via a single `F.conv2d` kernel (9-point isotropic by default) — efficient and GPU-friendly. The Laplacian treats weight matrices as 2D fields, coupling each weight's learning rate with its spatial neighbors.
|
|
178
|
+
|
|
179
|
+
**Overhead**: ~2-5% wall-clock time increase per step. The Laplacian is a single fused convolution kernel, not point-wise iteration.
|
|
180
|
+
|
|
181
|
+
## Benchmarks
|
|
182
|
+
|
|
183
|
+
### PINN: Wave Equation (u_tt = c^2 u_xx)
|
|
184
|
+
|
|
185
|
+
5-layer, 128-unit tanh MLP trained for 5000 steps on the 1D wave equation.
|
|
186
|
+
|
|
187
|
+
| Optimizer | L2 Error | vs Adam |
|
|
188
|
+
|-----------|----------|---------|
|
|
189
|
+
| Adam (lr=1e-3) | 0.0310 | — |
|
|
190
|
+
| LAdam c²=1e-4 | 0.0240 | -22.8% |
|
|
191
|
+
| **LAdam c²=1e-5** | **0.0172** | **-44.6%** |
|
|
192
|
+
| LAdam c²=1e-3 | 0.0185 | -40.3% |
|
|
193
|
+
|
|
194
|
+
### Transformer: FashionMNIST Classification
|
|
195
|
+
|
|
196
|
+
4-head, 128-dim, 2-layer transformer, 30 epochs, 5 independent seeds.
|
|
197
|
+
|
|
198
|
+
| Optimizer | Accuracy (mean ± std) | p-value (vs Adam) |
|
|
199
|
+
|-----------|----------------------|-------------------|
|
|
200
|
+
| Adam | 89.46 ± 0.10% | — |
|
|
201
|
+
| **LAdam c²=1e-4** | **89.66 ± 0.06%** | **0.0005** |
|
|
202
|
+
|
|
203
|
+
### c² Robustness Sweep
|
|
204
|
+
|
|
205
|
+
7 c² values on the same transformer task. **All 7 beat Adam:**
|
|
206
|
+
|
|
207
|
+
| c² | Accuracy | Δ vs Adam |
|
|
208
|
+
|----|----------|-----------|
|
|
209
|
+
| 1e-6 | 89.62% | +0.16% |
|
|
210
|
+
| 5e-6 | 89.73% | +0.27% |
|
|
211
|
+
| 1e-5 | 89.79% | +0.33% |
|
|
212
|
+
| 5e-5 | 89.75% | +0.29% |
|
|
213
|
+
| 1e-4 | 89.67% | +0.21% |
|
|
214
|
+
| 5e-4 | 89.64% | +0.18% |
|
|
215
|
+
| 1e-3 | 89.66% | +0.20% |
|
|
216
|
+
|
|
217
|
+
## FAQ
|
|
218
|
+
|
|
219
|
+
**Q: Does this work for LLMs / GPT-scale models?**
|
|
220
|
+
A: No. LAdam **hurts** LLM training (tested on GPT-2/WikiText-2). Attention weight matrices encode semantic structure, not spatial structure — the Laplacian destroys per-feature specialization. Use standard Adam/AdamW for LLMs.
|
|
221
|
+
|
|
222
|
+
**Q: Why not smooth the gradient instead of the variance?**
|
|
223
|
+
A: [Osher et al. (2018)](https://arxiv.org/abs/1806.06317) explored Laplacian smoothing of gradients. We found that smoothing the *variance estimate* is more effective because it smooths the *learning rate landscape* rather than the *descent direction*. These are mathematically distinct: ∇²(EMA(g²)) ≠ (∇²g)².
|
|
224
|
+
|
|
225
|
+
**Q: Why does this help PINNs so much?**
|
|
226
|
+
A: PDE-based loss landscapes have inherent spatial structure from the differential operators in the loss function. The Laplacian on v_t aligns the optimizer's internal representation with this structure.
|
|
227
|
+
|
|
228
|
+
**Q: Can I use this with learning rate schedulers?**
|
|
229
|
+
A: Yes. LAdam is fully compatible with any `torch.optim.lr_scheduler`.
|
|
230
|
+
|
|
231
|
+
## Citation
|
|
232
|
+
|
|
233
|
+
If you use LAdam in your research, please cite:
|
|
234
|
+
|
|
235
|
+
```bibtex
|
|
236
|
+
@software{partin2026ladam,
|
|
237
|
+
author = {Partin, Greg},
|
|
238
|
+
title = {LAdam: Spatially-Aware Adaptive Optimization via Laplacian-Regularized Variance Estimates},
|
|
239
|
+
year = {2026},
|
|
240
|
+
url = {https://github.com/gpartin/ladam}
|
|
241
|
+
}
|
|
242
|
+
```
|
|
243
|
+
|
|
244
|
+
## License
|
|
245
|
+
|
|
246
|
+
MIT. See [LICENSE](LICENSE) for details.
|
ladam-0.2.0/README.md
ADDED
|
@@ -0,0 +1,216 @@
|
|
|
1
|
+
# LAdam
|
|
2
|
+
|
|
3
|
+
**Laplacian Adam — spatially-aware adaptive optimizer for PyTorch**
|
|
4
|
+
|
|
5
|
+
[](https://pypi.org/project/ladam/)
|
|
6
|
+
[](LICENSE)
|
|
7
|
+
[](https://python.org)
|
|
8
|
+
|
|
9
|
+
LAdam is a drop-in Adam replacement that applies **discrete Laplacian regularization** to Adam's second-moment estimate (v_t). This couples neighboring weight learning rates, producing spatially-smoothed adaptive optimization.
|
|
10
|
+
|
|
11
|
+
## Why LAdam?
|
|
12
|
+
|
|
13
|
+
Adam computes independent per-parameter learning rates. But adjacent weights in trained networks are often functionally correlated — the per-parameter variance estimates should reflect this structure.
|
|
14
|
+
|
|
15
|
+
LAdam adds **one operation** to Adam: a Laplacian diffusion step on v_t, controlled by a single scalar `c2`. The Laplacian allows each weight's learning rate to be informed by its neighbors, smoothing the optimization landscape.
|
|
16
|
+
|
|
17
|
+
## Results
|
|
18
|
+
|
|
19
|
+
| Task | Architecture | Metric | Adam | LAdam | Improvement |
|
|
20
|
+
|------|-------------|--------|------|-------|-------------|
|
|
21
|
+
| **Wave Equation PINN** | 5×128 MLP | L2 Error | 0.0310 | **0.0172** | **-44.6%** |
|
|
22
|
+
| **FashionMNIST** | Transformer | Accuracy | 89.46% | **89.66%** | **+0.20%** (p=0.0005) |
|
|
23
|
+
| FashionMNIST | MLP | Accuracy | 89.10% | 89.12% | +0.02% (n.s.) |
|
|
24
|
+
| FashionMNIST | CNN | Accuracy | 91.15% | 91.14% | -0.01% (tie) |
|
|
25
|
+
|
|
26
|
+
> **LAdam excels on architectures with spatially-correlated weight structure** — particularly PINNs and transformers. For CNNs (whose conv filters are already spatial detectors), the Laplacian is redundant.
|
|
27
|
+
|
|
28
|
+
## Installation
|
|
29
|
+
|
|
30
|
+
```bash
|
|
31
|
+
pip install ladam
|
|
32
|
+
```
|
|
33
|
+
|
|
34
|
+
## Optimizers
|
|
35
|
+
|
|
36
|
+
LAdam ships three Laplacian-enhanced optimizers:
|
|
37
|
+
|
|
38
|
+
| Optimizer | Base | Laplacian target | Best for |
|
|
39
|
+
|-----------|------|------------------|----------|
|
|
40
|
+
| **LAdam** | Adam | Second moment v_t | PINNs, transformers, CNNs |
|
|
41
|
+
| **LAdaGrad** | AdaGrad | Cumulative sum G_t | Sparse features, NLP |
|
|
42
|
+
| **LRMSProp** | RMSProp | Running average v_t | RNNs, non-stationary losses |
|
|
43
|
+
|
|
44
|
+
All three share the same Laplacian kernel infrastructure and `c2` parameter.
|
|
45
|
+
|
|
46
|
+
## Usage
|
|
47
|
+
|
|
48
|
+
### Basic — Drop-in Adam replacement
|
|
49
|
+
|
|
50
|
+
```python
|
|
51
|
+
from ladam import LAdam
|
|
52
|
+
|
|
53
|
+
optimizer = LAdam(model.parameters(), lr=1e-3, c2=1e-4)
|
|
54
|
+
|
|
55
|
+
# Training loop is identical to Adam
|
|
56
|
+
for batch in dataloader:
|
|
57
|
+
loss = criterion(model(batch))
|
|
58
|
+
loss.backward()
|
|
59
|
+
optimizer.step()
|
|
60
|
+
optimizer.zero_grad()
|
|
61
|
+
```
|
|
62
|
+
|
|
63
|
+
### LAdaGrad and LRMSProp
|
|
64
|
+
|
|
65
|
+
```python
|
|
66
|
+
from ladam import LAdaGrad, LRMSProp
|
|
67
|
+
|
|
68
|
+
# AdaGrad with Laplacian smoothing on cumulative squared gradients
|
|
69
|
+
optimizer = LAdaGrad(model.parameters(), lr=1e-2, c2=1e-4)
|
|
70
|
+
|
|
71
|
+
# RMSProp with Laplacian smoothing on running variance
|
|
72
|
+
optimizer = LRMSProp(model.parameters(), lr=1e-2, alpha=0.99, c2=1e-4)
|
|
73
|
+
```
|
|
74
|
+
|
|
75
|
+
### Per-layer c2 with parameter groups
|
|
76
|
+
|
|
77
|
+
```python
|
|
78
|
+
optimizer = LAdam([
|
|
79
|
+
{'params': model.attention.parameters(), 'c2': 1e-4}, # Transformer attention
|
|
80
|
+
{'params': model.ffn.parameters(), 'c2': 1e-5}, # Feed-forward
|
|
81
|
+
{'params': model.norm.parameters(), 'c2': 0.0}, # Skip for norms
|
|
82
|
+
], lr=3e-4)
|
|
83
|
+
```
|
|
84
|
+
|
|
85
|
+
### Architecture-aware defaults
|
|
86
|
+
|
|
87
|
+
```python
|
|
88
|
+
from ladam import LAdam, suggest_c2
|
|
89
|
+
|
|
90
|
+
c2 = suggest_c2('pinn') # Returns 1e-5
|
|
91
|
+
c2 = suggest_c2('transformer') # Returns 1e-4
|
|
92
|
+
|
|
93
|
+
optimizer = LAdam(model.parameters(), lr=1e-3, c2=c2)
|
|
94
|
+
```
|
|
95
|
+
|
|
96
|
+
## Parameters
|
|
97
|
+
|
|
98
|
+
| Parameter | Default | Description |
|
|
99
|
+
|-----------|---------|-------------|
|
|
100
|
+
| `lr` | 1e-3 | Learning rate |
|
|
101
|
+
| `betas` | (0.9, 0.999) | EMA coefficients (same as Adam) |
|
|
102
|
+
| `eps` | 1e-8 | Numerical stability (same as Adam) |
|
|
103
|
+
| `weight_decay` | 0 | L2 regularization (same as AdamW behavior) |
|
|
104
|
+
| `c2` | 1e-4 | **Laplacian coupling strength.** Controls how much neighboring variance estimates influence each other. |
|
|
105
|
+
| `mode` | 'variance_lap' | Which quantity to smooth. `'variance_lap'` is best. |
|
|
106
|
+
| `stencil` | '9point' | **Discrete Laplacian stencil.** `'9point'` (isotropic, 0.46% anisotropy) or `'5point'` (legacy, 12.3% anisotropy). |
|
|
107
|
+
| `min_spatial_size` | 16 | Skip Laplacian for params with fewer elements (biases, LayerNorm). |
|
|
108
|
+
|
|
109
|
+
### Stencil Selection
|
|
110
|
+
|
|
111
|
+
The `stencil` parameter controls the discrete Laplacian kernel used for spatial coupling:
|
|
112
|
+
|
|
113
|
+
- **`'9point'` (default)**: Isotropic stencil with face + edge neighbors. Treats diagonal neighbors with 1/6 weight vs 4/6 for face neighbors.
|
|
114
|
+
- **`'5point'`**: Standard cross-pattern stencil (faces only). Slightly faster but 25× more anisotropic.
|
|
115
|
+
|
|
116
|
+
At typical `c2` values (1e-5 to 1e-3), the effective learning rate difference between stencils is <0.3%. The 9-point default is recommended for correctness.
|
|
117
|
+
|
|
118
|
+
### Choosing c2
|
|
119
|
+
|
|
120
|
+
`c2` is the only new hyperparameter. It's robust across 3 orders of magnitude:
|
|
121
|
+
|
|
122
|
+
| c2 | Best For | Notes |
|
|
123
|
+
|----|----------|-------|
|
|
124
|
+
| `1e-5` | PINNs, scientific ML | Gentle coupling, biggest error reduction |
|
|
125
|
+
| `1e-4` | Transformers, general | **Safe default** |
|
|
126
|
+
| `1e-3` | Aggressive smoothing | Works but slightly less stable |
|
|
127
|
+
| `0` | Disable | Reduces to standard Adam |
|
|
128
|
+
|
|
129
|
+
All 7 values tested in [1e-6, 1e-3] outperformed Adam on transformers (B12 sweep).
|
|
130
|
+
|
|
131
|
+
## How It Works
|
|
132
|
+
|
|
133
|
+
Standard Adam computes per-parameter adaptive learning rates from the second moment:
|
|
134
|
+
|
|
135
|
+
```
|
|
136
|
+
v_t = β₂·v_{t-1} + (1-β₂)·g_t² # Variance estimate
|
|
137
|
+
lr_effective = lr / (√v_t + ε) # Per-parameter learning rate
|
|
138
|
+
```
|
|
139
|
+
|
|
140
|
+
LAdam adds a Laplacian coupling step:
|
|
141
|
+
|
|
142
|
+
```
|
|
143
|
+
v_smooth = v_t + c2 · ∇²v_t # Spatial smoothing
|
|
144
|
+
lr_effective = lr / (√v_smooth + ε) # Coupled learning rate
|
|
145
|
+
```
|
|
146
|
+
|
|
147
|
+
Where `∇²` is the discrete Laplacian computed via a single `F.conv2d` kernel (9-point isotropic by default) — efficient and GPU-friendly. The Laplacian treats weight matrices as 2D fields, coupling each weight's learning rate with its spatial neighbors.
|
|
148
|
+
|
|
149
|
+
**Overhead**: ~2-5% wall-clock time increase per step. The Laplacian is a single fused convolution kernel, not point-wise iteration.
|
|
150
|
+
|
|
151
|
+
## Benchmarks
|
|
152
|
+
|
|
153
|
+
### PINN: Wave Equation (u_tt = c^2 u_xx)
|
|
154
|
+
|
|
155
|
+
5-layer, 128-unit tanh MLP trained for 5000 steps on the 1D wave equation.
|
|
156
|
+
|
|
157
|
+
| Optimizer | L2 Error | vs Adam |
|
|
158
|
+
|-----------|----------|---------|
|
|
159
|
+
| Adam (lr=1e-3) | 0.0310 | — |
|
|
160
|
+
| LAdam c²=1e-4 | 0.0240 | -22.8% |
|
|
161
|
+
| **LAdam c²=1e-5** | **0.0172** | **-44.6%** |
|
|
162
|
+
| LAdam c²=1e-3 | 0.0185 | -40.3% |
|
|
163
|
+
|
|
164
|
+
### Transformer: FashionMNIST Classification
|
|
165
|
+
|
|
166
|
+
4-head, 128-dim, 2-layer transformer, 30 epochs, 5 independent seeds.
|
|
167
|
+
|
|
168
|
+
| Optimizer | Accuracy (mean ± std) | p-value (vs Adam) |
|
|
169
|
+
|-----------|----------------------|-------------------|
|
|
170
|
+
| Adam | 89.46 ± 0.10% | — |
|
|
171
|
+
| **LAdam c²=1e-4** | **89.66 ± 0.06%** | **0.0005** |
|
|
172
|
+
|
|
173
|
+
### c² Robustness Sweep
|
|
174
|
+
|
|
175
|
+
7 c² values on the same transformer task. **All 7 beat Adam:**
|
|
176
|
+
|
|
177
|
+
| c² | Accuracy | Δ vs Adam |
|
|
178
|
+
|----|----------|-----------|
|
|
179
|
+
| 1e-6 | 89.62% | +0.16% |
|
|
180
|
+
| 5e-6 | 89.73% | +0.27% |
|
|
181
|
+
| 1e-5 | 89.79% | +0.33% |
|
|
182
|
+
| 5e-5 | 89.75% | +0.29% |
|
|
183
|
+
| 1e-4 | 89.67% | +0.21% |
|
|
184
|
+
| 5e-4 | 89.64% | +0.18% |
|
|
185
|
+
| 1e-3 | 89.66% | +0.20% |
|
|
186
|
+
|
|
187
|
+
## FAQ
|
|
188
|
+
|
|
189
|
+
**Q: Does this work for LLMs / GPT-scale models?**
|
|
190
|
+
A: No. LAdam **hurts** LLM training (tested on GPT-2/WikiText-2). Attention weight matrices encode semantic structure, not spatial structure — the Laplacian destroys per-feature specialization. Use standard Adam/AdamW for LLMs.
|
|
191
|
+
|
|
192
|
+
**Q: Why not smooth the gradient instead of the variance?**
|
|
193
|
+
A: [Osher et al. (2018)](https://arxiv.org/abs/1806.06317) explored Laplacian smoothing of gradients. We found that smoothing the *variance estimate* is more effective because it smooths the *learning rate landscape* rather than the *descent direction*. These are mathematically distinct: ∇²(EMA(g²)) ≠ (∇²g)².
|
|
194
|
+
|
|
195
|
+
**Q: Why does this help PINNs so much?**
|
|
196
|
+
A: PDE-based loss landscapes have inherent spatial structure from the differential operators in the loss function. The Laplacian on v_t aligns the optimizer's internal representation with this structure.
|
|
197
|
+
|
|
198
|
+
**Q: Can I use this with learning rate schedulers?**
|
|
199
|
+
A: Yes. LAdam is fully compatible with any `torch.optim.lr_scheduler`.
|
|
200
|
+
|
|
201
|
+
## Citation
|
|
202
|
+
|
|
203
|
+
If you use LAdam in your research, please cite:
|
|
204
|
+
|
|
205
|
+
```bibtex
|
|
206
|
+
@software{partin2026ladam,
|
|
207
|
+
author = {Partin, Greg},
|
|
208
|
+
title = {LAdam: Spatially-Aware Adaptive Optimization via Laplacian-Regularized Variance Estimates},
|
|
209
|
+
year = {2026},
|
|
210
|
+
url = {https://github.com/gpartin/ladam}
|
|
211
|
+
}
|
|
212
|
+
```
|
|
213
|
+
|
|
214
|
+
## License
|
|
215
|
+
|
|
216
|
+
MIT. See [LICENSE](LICENSE) for details.
|
|
@@ -0,0 +1,49 @@
|
|
|
1
|
+
[build-system]
|
|
2
|
+
requires = ["hatchling"]
|
|
3
|
+
build-backend = "hatchling.build"
|
|
4
|
+
|
|
5
|
+
[project]
|
|
6
|
+
name = "ladam"
|
|
7
|
+
version = "0.2.0"
|
|
8
|
+
description = "LAdam: Laplacian Adam — Adam with spatially-coupled variance estimates via discrete Laplacian"
|
|
9
|
+
readme = "README.md"
|
|
10
|
+
license = {text = "MIT"}
|
|
11
|
+
requires-python = ">=3.8"
|
|
12
|
+
authors = [
|
|
13
|
+
{name = "Greg Partin"},
|
|
14
|
+
]
|
|
15
|
+
keywords = [
|
|
16
|
+
"optimizer", "adam", "deep-learning", "pytorch",
|
|
17
|
+
"laplacian", "pinn", "scientific-ml", "spatial-regularization",
|
|
18
|
+
]
|
|
19
|
+
classifiers = [
|
|
20
|
+
"Development Status :: 4 - Beta",
|
|
21
|
+
"Intended Audience :: Science/Research",
|
|
22
|
+
"Intended Audience :: Developers",
|
|
23
|
+
"License :: OSI Approved :: MIT License",
|
|
24
|
+
"Programming Language :: Python :: 3",
|
|
25
|
+
"Programming Language :: Python :: 3.8",
|
|
26
|
+
"Programming Language :: Python :: 3.9",
|
|
27
|
+
"Programming Language :: Python :: 3.10",
|
|
28
|
+
"Programming Language :: Python :: 3.11",
|
|
29
|
+
"Programming Language :: Python :: 3.12",
|
|
30
|
+
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
|
31
|
+
"Topic :: Scientific/Engineering :: Mathematics",
|
|
32
|
+
]
|
|
33
|
+
dependencies = [
|
|
34
|
+
"torch>=1.10.0",
|
|
35
|
+
]
|
|
36
|
+
|
|
37
|
+
[project.optional-dependencies]
|
|
38
|
+
dev = [
|
|
39
|
+
"pytest>=7.0",
|
|
40
|
+
"pytest-cov>=4.0",
|
|
41
|
+
]
|
|
42
|
+
|
|
43
|
+
[project.urls]
|
|
44
|
+
Homepage = "https://github.com/gpartin/ladam"
|
|
45
|
+
Documentation = "https://github.com/gpartin/ladam#usage"
|
|
46
|
+
Issues = "https://github.com/gpartin/ladam/issues"
|
|
47
|
+
|
|
48
|
+
[tool.hatch.build.targets.wheel]
|
|
49
|
+
packages = ["src/ladam"]
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
"""
LAdam — Laplacian Adam Optimizer

Package entry point: re-exports the three Laplacian-enhanced optimizers
(LAdam, LAdaGrad, LRMSProp) and the `suggest_c2` helper from `.optimizer`.
"""

# Package version; kept in sync with pyproject.toml.
__version__ = "0.2.0"

from .optimizer import LAdam, LAdaGrad, LRMSProp, suggest_c2

# Backward compatibility alias
WaveAdam = LAdam

# Explicit public API for `from ladam import *`.
__all__ = ["LAdam", "LAdaGrad", "LRMSProp", "WaveAdam", "suggest_c2"]
|
|
@@ -0,0 +1,442 @@
|
|
|
1
|
+
"""
|
|
2
|
+
LAdam — Laplacian Adam Optimizer
|
|
3
|
+
================================
|
|
4
|
+
|
|
5
|
+
A drop-in Adam replacement that applies discrete Laplacian regularization
|
|
6
|
+
to the second moment estimate (v_t), producing spatially-smoothed adaptive
|
|
7
|
+
learning rates.
|
|
8
|
+
|
|
9
|
+
Key insight: adjacent weights in neural networks are often functionally
|
|
10
|
+
correlated, but Adam computes independent per-parameter learning rates.
|
|
11
|
+
LAdam restores spatial coherence by applying a Laplacian diffusion
|
|
12
|
+
operator to Adam's variance estimate, allowing neighboring weights to
|
|
13
|
+
share curvature information.
|
|
14
|
+
|
|
15
|
+
This is particularly effective for:
|
|
16
|
+
- **PINNs** (physics-informed neural networks): -44.6% L2 error vs Adam
|
|
17
|
+
- **Transformers**: +0.20% accuracy, p=0.0005 across 5 seeds
|
|
18
|
+
- Any architecture with spatially-correlated weight structure
|
|
19
|
+
|
|
20
|
+
Usage:
|
|
21
|
+
from ladam import LAdam
|
|
22
|
+
|
|
23
|
+
optimizer = LAdam(model.parameters(), lr=1e-3, c2=1e-4)
|
|
24
|
+
|
|
25
|
+
for batch in dataloader:
|
|
26
|
+
loss = model(batch)
|
|
27
|
+
loss.backward()
|
|
28
|
+
optimizer.step()
|
|
29
|
+
optimizer.zero_grad()
|
|
30
|
+
"""
|
|
31
|
+
|
|
32
|
+
import torch
|
|
33
|
+
import torch.nn.functional as F
|
|
34
|
+
from torch.optim import Optimizer
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class LAdam(Optimizer):
|
|
38
|
+
"""
|
|
39
|
+
LAdam: Laplacian Adam — Adam with Laplacian-regularized variance estimates.
|
|
40
|
+
|
|
41
|
+
Applies a discrete Laplacian to Adam's second-moment estimate v_t,
|
|
42
|
+
coupling adjacent weight learning rates. This smooths the adaptive
|
|
43
|
+
learning rate landscape, improving convergence for architectures with
|
|
44
|
+
spatially-structured parameters.
|
|
45
|
+
|
|
46
|
+
Modes:
|
|
47
|
+
'variance_lap': Laplacian on v_t (default, best overall)
|
|
48
|
+
'update_lap': Laplacian on the update direction
|
|
49
|
+
'weight_lap': Standard Adam step + weight-space diffusion
|
|
50
|
+
|
|
51
|
+
Args:
|
|
52
|
+
params: Model parameters or param groups.
|
|
53
|
+
lr (float): Learning rate. Default: 1e-3.
|
|
54
|
+
betas (Tuple[float, float]): EMA coefficients for (m_t, v_t).
|
|
55
|
+
Default: (0.9, 0.999).
|
|
56
|
+
eps (float): Numerical stability. Default: 1e-8.
|
|
57
|
+
weight_decay (float): L2 weight decay. Default: 0.
|
|
58
|
+
c2 (float): Laplacian coupling strength. Controls how much
|
|
59
|
+
neighboring variance estimates influence each other.
|
|
60
|
+
Recommended: 1e-5 for PINNs, 1e-4 for transformers.
|
|
61
|
+
Default: 1e-4.
|
|
62
|
+
mode (str): Which internal quantity to apply the Laplacian to.
|
|
63
|
+
Default: 'variance_lap'.
|
|
64
|
+
stencil (str): Laplacian stencil type for 2D parameters.
|
|
65
|
+
'9point': Isotropic stencil [1,4,1;4,-20,4;1,4,1]/6 — 0.46%
|
|
66
|
+
anisotropy. Default. Based on LFM lattice geometry
|
|
67
|
+
(face weight 2/3, diagonal weight 1/6).
|
|
68
|
+
'5point': Standard stencil [0,1,0;1,-4,1;0,1,0] — 12.3%
|
|
69
|
+
anisotropy. Legacy option for reproducing prior results.
|
|
70
|
+
min_spatial_size (int): Skip Laplacian for parameters with fewer
|
|
71
|
+
elements than this threshold. Small params (biases, norms)
|
|
72
|
+
lack meaningful spatial structure. Default: 16.
|
|
73
|
+
"""
|
|
74
|
+
|
|
75
|
+
# Pre-built Laplacian kernels (cached on first use per device/stencil)
|
|
76
|
+
_lap_kernels_2d = {}
|
|
77
|
+
_lap_kernel_1d = None
|
|
78
|
+
|
|
79
|
+
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
             weight_decay=0, c2=1e-4, mode='variance_lap',
             stencil='9point', min_spatial_size=16):
    """Validate hyperparameters and initialize the optimizer.

    Arguments mirror ``torch.optim.Adam`` plus the Laplacian-specific
    knobs (``c2``, ``mode``, ``stencil``, ``min_spatial_size``) described
    in the class docstring.

    Raises:
        ValueError: If any hyperparameter is outside its valid range.
    """
    if lr < 0.0:
        raise ValueError(f"Invalid learning rate: {lr}")
    if eps < 0.0:
        raise ValueError(f"Invalid epsilon value: {eps}")
    if not 0.0 <= betas[0] < 1.0:
        raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
    if not 0.0 <= betas[1] < 1.0:
        raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
    if weight_decay < 0.0:
        # Previously unvalidated; torch.optim.Adam rejects negative decay,
        # and a negative value would silently anti-regularize.
        raise ValueError(f"Invalid weight_decay value: {weight_decay}")
    if c2 < 0.0:
        raise ValueError(f"Invalid coupling strength: {c2}")
    if mode not in ('variance_lap', 'update_lap', 'weight_lap'):
        raise ValueError(f"Invalid mode: {mode}. Choose from: "
                         "'variance_lap', 'update_lap', 'weight_lap'")
    if stencil not in ('5point', '9point'):
        raise ValueError(f"Invalid stencil: {stencil}. Choose from: "
                         "'5point', '9point'")

    # Per-group defaults; any of these can be overridden per param group.
    defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay,
                    c2=c2, mode=mode, stencil=stencil,
                    min_spatial_size=min_spatial_size)
    super().__init__(params, defaults)
|
|
103
|
+
|
|
104
|
+
@classmethod
def _get_lap_kernel_2d(cls, device, dtype, stencil='9point'):
    """Return the 2D discrete Laplacian conv kernel, building it on demand.

    Kernels are memoized in a class-level dict keyed by
    (stencil, device, dtype), so each combination is built exactly once.

    stencil='9point' (default): Isotropic [1,4,1;4,-20,4;1,4,1]/6 —
        faces weighted 2/3, diagonals 1/6; only 0.46% anisotropy
        vs 12.3% for 5-point.
    stencil='5point': Standard cross-pattern [0,1,0;1,-4,1;0,1,0].
    """
    cache_key = (stencil, device, dtype)
    kernel = cls._lap_kernels_2d.get(cache_key)
    if kernel is None:
        if stencil == '9point':
            # Isotropic 2D Laplacian: face weight 4/6, diagonal weight 1/6
            # — minimizes angular anisotropy.
            flat = torch.tensor([[1/6., 4/6., 1/6.],
                                 [4/6., -20/6., 4/6.],
                                 [1/6., 4/6., 1/6.]],
                                device=device, dtype=dtype)
        else:  # '5point'
            flat = torch.tensor([[0., 1., 0.],
                                 [1., -4., 1.],
                                 [0., 1., 0.]],
                                device=device, dtype=dtype)
        # conv2d expects (out_channels, in_channels, kH, kW).
        kernel = flat.reshape(1, 1, 3, 3)
        cls._lap_kernels_2d[cache_key] = kernel
    return kernel
|
|
128
|
+
|
|
129
|
+
@classmethod
|
|
130
|
+
def _get_lap_kernel_1d(cls, device, dtype):
|
|
131
|
+
"""Lazily create and cache the 1D discrete Laplacian kernel."""
|
|
132
|
+
if cls._lap_kernel_1d is None or cls._lap_kernel_1d.device != device:
|
|
133
|
+
k = torch.tensor([1., -2., 1.], device=device, dtype=dtype)
|
|
134
|
+
cls._lap_kernel_1d = k.reshape(1, 1, 3)
|
|
135
|
+
return cls._lap_kernel_1d.to(dtype=dtype)
|
|
136
|
+
|
|
137
|
+
@staticmethod
|
|
138
|
+
def _laplacian(W, kernel_2d, kernel_1d):
|
|
139
|
+
"""
|
|
140
|
+
Apply discrete Laplacian via fused conv2d — single GPU kernel launch.
|
|
141
|
+
|
|
142
|
+
Uses circular padding to treat weight matrices as periodic fields.
|
|
143
|
+
High-dimensional tensors are reshaped to 2D while preserving spatial
|
|
144
|
+
structure.
|
|
145
|
+
"""
|
|
146
|
+
if W.dim() == 1:
|
|
147
|
+
n = W.numel()
|
|
148
|
+
if n < 4:
|
|
149
|
+
return torch.zeros_like(W)
|
|
150
|
+
x = W.reshape(1, 1, n)
|
|
151
|
+
x = F.pad(x, (1, 1), mode='circular')
|
|
152
|
+
return F.conv1d(x, kernel_1d).reshape(W.shape)
|
|
153
|
+
else:
|
|
154
|
+
shape = W.shape
|
|
155
|
+
W2 = W.reshape(shape[0], -1) if W.dim() > 2 else W
|
|
156
|
+
h, w = W2.shape
|
|
157
|
+
if h < 3 or w < 3:
|
|
158
|
+
return torch.zeros_like(W)
|
|
159
|
+
x = W2.reshape(1, 1, h, w)
|
|
160
|
+
x = F.pad(x, (1, 1, 1, 1), mode='circular')
|
|
161
|
+
result = F.conv2d(x, kernel_2d).reshape(W2.shape)
|
|
162
|
+
if W.dim() > 2:
|
|
163
|
+
result = result.reshape(shape)
|
|
164
|
+
return result
|
|
165
|
+
|
|
166
|
+
@torch.no_grad()
def step(self, closure=None):
    """Performs a single optimization step.

    Runs the standard Adam moment updates, then applies one of three
    Laplacian coupling modes per parameter group:
      - 'variance_lap': smooth the bias-corrected second moment before
        it enters the denominator (the core LAdam variant).
      - 'update_lap': smooth the Adam update direction itself.
      - 'weight_lap': take a plain Adam step, then apply a diffusion
        step directly to the weights.

    Args:
        closure (Callable, optional): A closure that reevaluates the model
            and returns the loss.
    """
    loss = None
    if closure is not None:
        # step() runs under no_grad; the closure needs grad re-enabled
        # so its backward() can populate p.grad.
        with torch.enable_grad():
            loss = closure()

    for group in self.param_groups:
        lr = group['lr']
        beta1, beta2 = group['betas']
        eps = group['eps']
        c2 = group['c2']
        mode = group['mode']
        wd = group['weight_decay']
        min_sp = group['min_spatial_size']

        stencil = group['stencil']

        for p in group['params']:
            if p.grad is None:
                continue

            grad = p.grad.data
            if wd != 0:
                # Classic (non-decoupled) L2 weight decay folded into grad.
                grad = grad.add(p.data, alpha=wd)

            state = self.state[p]

            # State initialization
            if len(state) == 0:
                state['step'] = 0
                state['m'] = torch.zeros_like(p.data)
                state['v'] = torch.zeros_like(p.data)
                # Decided once at first step: small params (biases, norms)
                # below min_spatial_size lack spatial structure to smooth.
                state['use_lap'] = (c2 > 0 and p.data.numel() >= min_sp)

            state['step'] += 1
            m, v = state['m'], state['v']
            step = state['step']
            use_lap = state['use_lap']

            # Standard Adam momentum updates (in-place on state tensors)
            m.mul_(beta1).add_(grad, alpha=1 - beta1)
            v.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

            # Bias correction
            m_hat = m / (1 - beta1 ** step)
            v_hat = v / (1 - beta2 ** step)

            if mode == 'variance_lap':
                # Core innovation: couple neighboring variance estimates
                if use_lap:
                    k2d = self._get_lap_kernel_2d(p.device, v_hat.dtype, stencil)
                    k1d = self._get_lap_kernel_1d(p.device, v_hat.dtype)
                    v_smooth = v_hat + c2 * self._laplacian(v_hat, k2d, k1d)
                    # Diffusion can push entries slightly negative; clamp
                    # so sqrt() below stays well-defined.
                    v_smooth = v_smooth.clamp(min=eps)
                else:
                    v_smooth = v_hat
                p.data.addcdiv_(m_hat, v_smooth.sqrt().add_(eps), value=-lr)

            elif mode == 'update_lap':
                # Smooth the Adam update direction instead of the variance.
                update = m_hat / (v_hat.sqrt() + eps)
                if use_lap:
                    k2d = self._get_lap_kernel_2d(p.device, update.dtype, stencil)
                    k1d = self._get_lap_kernel_1d(p.device, update.dtype)
                    update = update + c2 * self._laplacian(update, k2d, k1d)
                p.data.add_(update, alpha=-lr)

            elif mode == 'weight_lap':
                # Plain Adam step, then a diffusion step on the weights.
                p.data.addcdiv_(m_hat, v_hat.sqrt().add_(eps), value=-lr)
                if use_lap:
                    k2d = self._get_lap_kernel_2d(p.device, p.data.dtype, stencil)
                    k1d = self._get_lap_kernel_1d(p.device, p.data.dtype)
                    p.data.add_(self._laplacian(p.data, k2d, k1d), alpha=c2)

    return loss
|
|
247
|
+
|
|
248
|
+
|
|
249
|
+
def suggest_c2(architecture: str = 'auto') -> float:
    """
    Suggest a c2 value based on architecture type.

    Based on systematic sweeps across 7 c2 values and 4 architecture types.

    Args:
        architecture: One of 'pinn', 'transformer', 'mlp', 'cnn', or 'auto'.
            'auto' returns a safe default. Case-insensitive.

    Returns:
        Recommended c2 value.

    Raises:
        ValueError: If the architecture name is not recognized.
    """
    recommendations = {
        'pinn': 1e-5,         # -44.6% L2 error vs Adam
        'transformer': 1e-4,  # +0.20% acc, p=0.0005
        'mlp': 1e-5,          # Marginal improvement, safe value
        'cnn': 0.0,           # No benefit; skip Laplacian
        'auto': 1e-4,         # Safe general-purpose default
    }
    arch = architecture.lower()
    if arch not in recommendations:
        raise ValueError(f"Unknown architecture: {arch}. "
                         f"Choose from: {list(recommendations.keys())}")
    return recommendations[arch]
|
|
274
|
+
|
|
275
|
+
|
|
276
|
+
class LAdaGrad(Optimizer):
    """
    LAdaGrad: AdaGrad with Laplacian-regularized accumulator.

    Standard AdaGrad accumulates squared gradients independently per parameter.
    LAdaGrad applies discrete Laplacian diffusion to the accumulator, letting
    neighboring weights share curvature information.

    Args:
        params: Model parameters or param groups.
        lr (float): Learning rate. Default: 1e-2.
        lr_decay (float): Learning rate decay. Default: 0.
        eps (float): Numerical stability. Default: 1e-10.
        weight_decay (float): L2 weight decay. Default: 0.
        c2 (float): Laplacian coupling strength. Default: 1e-4.
        stencil (str): '9point' or '5point'. Default: '9point'.
        min_spatial_size (int): Skip Laplacian below this size. Default: 16.

    Raises:
        ValueError: If any hyperparameter is outside its valid range.
    """

    def __init__(self, params, lr=1e-2, lr_decay=0, eps=1e-10,
                 weight_decay=0, c2=1e-4, stencil='9point',
                 min_spatial_size=16):
        if lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")
        # Validate the remaining hyperparameters too, matching LAdam and
        # torch.optim.Adagrad (the original only checked lr and c2).
        if lr_decay < 0.0:
            raise ValueError(f"Invalid lr_decay value: {lr_decay}")
        if eps < 0.0:
            raise ValueError(f"Invalid epsilon value: {eps}")
        if weight_decay < 0.0:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
        if c2 < 0.0:
            raise ValueError(f"Invalid coupling strength: {c2}")
        if stencil not in ('5point', '9point'):
            raise ValueError(f"Invalid stencil: {stencil}. Choose from: "
                             "'5point', '9point'")
        defaults = dict(lr=lr, lr_decay=lr_decay, eps=eps,
                        weight_decay=weight_decay, c2=c2,
                        stencil=stencil, min_spatial_size=min_spatial_size)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Args:
            closure (Callable, optional): A closure that reevaluates the
                model and returns the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr = group['lr']
            lr_decay = group['lr_decay']
            eps = group['eps']
            c2 = group['c2']
            wd = group['weight_decay']
            min_sp = group['min_spatial_size']
            stencil = group['stencil']

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad.data
                if wd != 0:
                    # Classic L2 weight decay folded into the gradient.
                    grad = grad.add(p.data, alpha=wd)

                state = self.state[p]
                if len(state) == 0:
                    state['step'] = 0
                    state['sum'] = torch.zeros_like(p.data)
                    # Small params (biases, norms) lack spatial structure.
                    state['use_lap'] = (c2 > 0 and p.data.numel() >= min_sp)

                state['step'] += 1
                state['sum'].addcmul_(grad, grad, value=1)

                # AdaGrad's decayed learning-rate schedule.
                clr = lr / (1 + (state['step'] - 1) * lr_decay)

                acc = state['sum']
                if state['use_lap']:
                    k2d = LAdam._get_lap_kernel_2d(p.device, acc.dtype, stencil)
                    k1d = LAdam._get_lap_kernel_1d(p.device, acc.dtype)
                    acc_smooth = acc + c2 * LAdam._laplacian(acc, k2d, k1d)
                    # Diffusion can dip below zero; clamp before sqrt().
                    acc_smooth = acc_smooth.clamp(min=0)
                else:
                    acc_smooth = acc

                p.data.addcdiv_(grad, acc_smooth.sqrt().add_(eps), value=-clr)

        return loss
|
|
354
|
+
|
|
355
|
+
|
|
356
|
+
class LRMSProp(Optimizer):
    """
    LRMSProp: RMSProp with Laplacian-regularized running average.

    Standard RMSProp maintains a decaying average of squared gradients per
    parameter. LRMSProp applies discrete Laplacian diffusion to this average,
    coupling neighbors.

    Args:
        params: Model parameters or param groups.
        lr (float): Learning rate. Default: 1e-2.
        alpha (float): Smoothing constant (decay). Default: 0.99.
        eps (float): Numerical stability. Default: 1e-8.
        weight_decay (float): L2 weight decay. Default: 0.
        momentum (float): Momentum factor. Default: 0.
        c2 (float): Laplacian coupling strength. Default: 1e-4.
        stencil (str): '9point' or '5point'. Default: '9point'.
        min_spatial_size (int): Skip Laplacian below this size. Default: 16.

    Raises:
        ValueError: If any hyperparameter is outside its valid range.
    """

    def __init__(self, params, lr=1e-2, alpha=0.99, eps=1e-8,
                 weight_decay=0, momentum=0, c2=1e-4, stencil='9point',
                 min_spatial_size=16):
        if lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")
        if alpha < 0.0 or alpha >= 1.0:
            raise ValueError(f"Invalid alpha: {alpha}")
        # Validate the remaining hyperparameters too, matching LAdam and
        # torch.optim.RMSprop (the original only checked lr, alpha, c2).
        if eps < 0.0:
            raise ValueError(f"Invalid epsilon value: {eps}")
        if momentum < 0.0:
            raise ValueError(f"Invalid momentum value: {momentum}")
        if weight_decay < 0.0:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
        if c2 < 0.0:
            raise ValueError(f"Invalid coupling strength: {c2}")
        if stencil not in ('5point', '9point'):
            raise ValueError(f"Invalid stencil: {stencil}. Choose from: "
                             "'5point', '9point'")
        defaults = dict(lr=lr, alpha=alpha, eps=eps, weight_decay=weight_decay,
                        momentum=momentum, c2=c2, stencil=stencil,
                        min_spatial_size=min_spatial_size)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Args:
            closure (Callable, optional): A closure that reevaluates the
                model and returns the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr = group['lr']
            alpha = group['alpha']
            eps = group['eps']
            c2 = group['c2']
            wd = group['weight_decay']
            mom = group['momentum']
            min_sp = group['min_spatial_size']
            stencil = group['stencil']

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad.data
                if wd != 0:
                    # Classic L2 weight decay folded into the gradient.
                    grad = grad.add(p.data, alpha=wd)

                state = self.state[p]
                if len(state) == 0:
                    state['step'] = 0
                    state['square_avg'] = torch.zeros_like(p.data)
                    if mom > 0:
                        state['momentum_buffer'] = torch.zeros_like(p.data)
                    # Small params (biases, norms) lack spatial structure.
                    state['use_lap'] = (c2 > 0 and p.data.numel() >= min_sp)

                state['step'] += 1
                sq_avg = state['square_avg']
                sq_avg.mul_(alpha).addcmul_(grad, grad, value=1 - alpha)

                if state['use_lap']:
                    k2d = LAdam._get_lap_kernel_2d(p.device, sq_avg.dtype, stencil)
                    k1d = LAdam._get_lap_kernel_1d(p.device, sq_avg.dtype)
                    avg_smooth = sq_avg + c2 * LAdam._laplacian(sq_avg, k2d, k1d)
                    # Diffusion can dip below zero; clamp before sqrt().
                    avg_smooth = avg_smooth.clamp(min=0)
                else:
                    avg_smooth = sq_avg

                if mom > 0:
                    buf = state['momentum_buffer']
                    buf.mul_(mom).addcdiv_(grad, avg_smooth.sqrt().add_(eps))
                    p.data.add_(buf, alpha=-lr)
                else:
                    p.data.addcdiv_(grad, avg_smooth.sqrt().add_(eps), value=-lr)

        return loss
|
|
@@ -0,0 +1,182 @@
|
|
|
1
|
+
"""Tests for LAdam optimizer."""
|
|
2
|
+
|
|
3
|
+
import pytest
|
|
4
|
+
import torch
|
|
5
|
+
import torch.nn as nn
|
|
6
|
+
from ladam import LAdam, suggest_c2
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class SimpleMLP(nn.Module):
    """Small 3-layer MLP (20 -> 64 -> 64 -> 1) used as the test model."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(20, 64)
        self.fc2 = nn.Linear(64, 64)
        self.fc3 = nn.Linear(64, 1)

    def forward(self, x):
        hidden = self.fc1(x).relu()
        hidden = self.fc2(hidden).relu()
        return self.fc3(hidden)
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
@pytest.fixture
def model():
    """Fresh, unseeded SimpleMLP instance for each test."""
    return SimpleMLP()
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
@pytest.fixture
def data():
    """Deterministic random regression data: X of shape (100, 20), y of (100, 1)."""
    torch.manual_seed(42)
    X = torch.randn(100, 20)
    y = torch.randn(100, 1)
    return X, y
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def test_basic_step(model, data):
    """LAdam can perform a basic optimization step."""
    X, y = data
    loss_fn = nn.MSELoss()
    opt = LAdam(model.parameters(), lr=1e-3, c2=1e-4)

    initial_loss = loss_fn(model(X), y).item()
    for _ in range(10):
        opt.zero_grad()
        loss_fn(model(X), y).backward()
        opt.step()
    final_loss = loss_fn(model(X), y).item()

    assert final_loss < initial_loss, "Loss should decrease after optimization"
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def test_c2_zero_equals_adam(model, data):
    """With c2=0, LAdam should behave identically to Adam."""
    X, y = data
    loss_fn = nn.MSELoss()

    def train(make_opt):
        # Identical seed -> identical init -> comparable trajectories.
        torch.manual_seed(0)
        net = SimpleMLP()
        opt = make_opt(net.parameters())
        for _ in range(5):
            opt.zero_grad()
            loss_fn(net(X), y).backward()
            opt.step()
        return net

    m1 = train(lambda ps: LAdam(ps, lr=1e-3, c2=0.0))
    m2 = train(lambda ps: torch.optim.Adam(ps, lr=1e-3))

    for p1, p2 in zip(m1.parameters(), m2.parameters()):
        assert torch.allclose(p1, p2, atol=1e-6), \
            "c2=0 LAdam should match standard Adam"
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def test_all_modes(model, data):
    """All three modes should run without error and reduce loss."""
    X, y = data
    loss_fn = nn.MSELoss()

    for mode in ('variance_lap', 'update_lap', 'weight_lap'):
        torch.manual_seed(42)
        net = SimpleMLP()
        opt = LAdam(net.parameters(), lr=1e-3, c2=1e-4, mode=mode)

        start = loss_fn(net(X), y).item()
        for _ in range(20):
            opt.zero_grad()
            loss_fn(net(X), y).backward()
            opt.step()
        end = loss_fn(net(X), y).item()

        assert end < start, f"Mode '{mode}' should reduce loss"
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
def test_invalid_params():
    """Invalid parameters should raise ValueError."""
    m = SimpleMLP()
    bad_kwargs = (
        {'lr': -1},
        {'c2': -1},
        {'mode': 'invalid'},
        {'betas': (1.5, 0.999)},
    )
    for kwargs in bad_kwargs:
        with pytest.raises(ValueError):
            LAdam(m.parameters(), **kwargs)
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
def test_suggest_c2():
    """Architecture-specific c2 suggestions."""
    expected = {'pinn': 1e-5, 'transformer': 1e-4, 'cnn': 0.0, 'auto': 1e-4}
    for arch, value in expected.items():
        assert suggest_c2(arch) == value
    with pytest.raises(ValueError):
        suggest_c2('unknown_arch')
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
def test_param_groups(model, data):
    """Per-group c2 values should work and still reduce the loss."""
    X, y = data
    loss_fn = nn.MSELoss()

    opt = LAdam([
        {'params': model.fc1.parameters(), 'c2': 1e-4},
        {'params': model.fc2.parameters(), 'c2': 1e-5},
        {'params': model.fc3.parameters(), 'c2': 0.0},
    ], lr=1e-3)

    loss_before = loss_fn(model(X), y).item()
    for _ in range(10):
        opt.zero_grad()
        loss_fn(model(X), y).backward()
        opt.step()
    loss_after = loss_fn(model(X), y).item()

    # The test previously only checked for the absence of exceptions;
    # also verify the grouped configuration actually optimizes.
    assert loss_after < loss_before, "Per-group training should reduce loss"
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
def test_gpu_if_available(data):
    """LAdam should work on GPU if available."""
    if not torch.cuda.is_available():
        pytest.skip("CUDA not available")

    X, y = (t.cuda() for t in data)
    net = SimpleMLP().cuda()
    opt = LAdam(net.parameters(), lr=1e-3, c2=1e-4)
    loss_fn = nn.MSELoss()

    for _ in range(5):
        opt.zero_grad()
        loss_fn(net(X), y).backward()
        opt.step()
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
def test_with_weight_decay(model, data):
    """Weight decay should work alongside the Laplacian coupling."""
    X, y = data
    opt = LAdam(model.parameters(), lr=1e-3, c2=1e-4, weight_decay=0.01)
    loss_fn = nn.MSELoss()

    loss_before = loss_fn(model(X), y).item()
    for _ in range(10):
        opt.zero_grad()
        loss_fn(model(X), y).backward()
        opt.step()
    loss_after = loss_fn(model(X), y).item()

    # The test previously only checked for the absence of exceptions;
    # also verify training still makes progress with decay enabled.
    assert loss_after < loss_before, "Training with weight decay should reduce loss"
|
|
167
|
+
|
|
168
|
+
|
|
169
|
+
def test_closure(model, data):
    """Closure-based step should work (for LBFGS-style usage)."""
    X, y = data
    loss_fn = nn.MSELoss()
    opt = LAdam(model.parameters(), lr=1e-3, c2=1e-4)

    def closure():
        opt.zero_grad()
        out = loss_fn(model(X), y)
        out.backward()
        return out

    result = opt.step(closure)
    assert result is not None
|