pilot-optimizer 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pilot_optimizer-0.1.0/LICENSE +21 -0
- pilot_optimizer-0.1.0/PKG-INFO +182 -0
- pilot_optimizer-0.1.0/README.md +154 -0
- pilot_optimizer-0.1.0/pilot/__init__.py +4 -0
- pilot_optimizer-0.1.0/pilot/diagnostics.py +54 -0
- pilot_optimizer-0.1.0/pilot/meta_grads.py +104 -0
- pilot_optimizer-0.1.0/pilot/optimizer.py +287 -0
- pilot_optimizer-0.1.0/pilot_optimizer.egg-info/PKG-INFO +182 -0
- pilot_optimizer-0.1.0/pilot_optimizer.egg-info/SOURCES.txt +14 -0
- pilot_optimizer-0.1.0/pilot_optimizer.egg-info/dependency_links.txt +1 -0
- pilot_optimizer-0.1.0/pilot_optimizer.egg-info/requires.txt +5 -0
- pilot_optimizer-0.1.0/pilot_optimizer.egg-info/top_level.txt +1 -0
- pilot_optimizer-0.1.0/pyproject.toml +46 -0
- pilot_optimizer-0.1.0/setup.cfg +4 -0
- pilot_optimizer-0.1.0/tests/test_meta_grads.py +301 -0
- pilot_optimizer-0.1.0/tests/test_pilot.py +290 -0
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2026 Sattam Altwaim
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
|
@@ -0,0 +1,182 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: pilot-optimizer
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: PILOT: Policy-Informed Learned Optimization for Adaptive Deep Network Training
|
|
5
|
+
Author-email: Sattam Altuuaim <sattam.tuuaim@kaust.edu.sa>, Lama Ayash <lama.ayash@kaust.edu.sa>, Muhammad Mubashar <muhammad.mubashar@strath.ac.uk>, Naeemullah Khan <naeemullah.khan@kaust.edu.sa>
|
|
6
|
+
License: MIT
|
|
7
|
+
Project-URL: Homepage, https://sattamaltwaim.github.io/PILOT/
|
|
8
|
+
Project-URL: Repository, https://github.com/SattamAltwaim/PILOT
|
|
9
|
+
Project-URL: Paper, https://arxiv.org/abs/submit/7629402
|
|
10
|
+
Keywords: optimizer,deep-learning,pytorch,meta-learning,adaptive-optimization
|
|
11
|
+
Classifier: Development Status :: 4 - Beta
|
|
12
|
+
Classifier: Intended Audience :: Science/Research
|
|
13
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
14
|
+
Classifier: Programming Language :: Python :: 3
|
|
15
|
+
Classifier: Programming Language :: Python :: 3.9
|
|
16
|
+
Classifier: Programming Language :: Python :: 3.10
|
|
17
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
18
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
19
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
20
|
+
Requires-Python: >=3.9
|
|
21
|
+
Description-Content-Type: text/markdown
|
|
22
|
+
License-File: LICENSE
|
|
23
|
+
Requires-Dist: torch>=2.0.0
|
|
24
|
+
Provides-Extra: dev
|
|
25
|
+
Requires-Dist: pytest; extra == "dev"
|
|
26
|
+
Requires-Dist: numpy; extra == "dev"
|
|
27
|
+
Dynamic: license-file
|
|
28
|
+
|
|
29
|
+
# PILOT: Policy-Informed Learned Optimization for Adaptive Deep Network Training
|
|
30
|
+
|
|
31
|
+
> **PILOT** is an online adaptive optimizer that adjusts its update behavior during training using gradient-direction agreement as a signal of local optimization stability.
|
|
32
|
+
|
|
33
|
+
---
|
|
34
|
+
|
|
35
|
+
## Overview
|
|
36
|
+
|
|
37
|
+
Most optimizers use a fixed update structure throughout training — a static balance between momentum, normalization, and sign-based updates that cannot respond to how the loss landscape evolves.
|
|
38
|
+
|
|
39
|
+
**PILOT** introduces a learnable policy that continuously modulates three core update primitives:
|
|
40
|
+
- **Momentum reliance** — how much to rely on accumulated gradient history vs. the current gradient
|
|
41
|
+
- **Variance-normalization strength** — how aggressively to apply adaptive scaling
|
|
42
|
+
- **Sign-based behavior** — how much to compress gradient magnitudes toward ±1
|
|
43
|
+
|
|
44
|
+
The policy is conditioned on a smoothed gradient-direction agreement signal, which serves as a compact online descriptor of local update consistency. It is updated online during training using a one-step meta-gradient estimate — no offline search, no meta-training phase, no second-order estimation.
|
|
45
|
+
|
|
46
|
+

|
|
47
|
+
*PILOT follows a distinct trajectory through the loss surface and converges to a lower-loss region compared to Adam, AdamW, Lion, and Sophia.*
|
|
48
|
+
|
|
49
|
+
---
|
|
50
|
+
|
|
51
|
+
## Key Results
|
|
52
|
+
|
|
53
|
+
### CNN Architecture
|
|
54
|
+
|
|
55
|
+
| Dataset | Optimizer | Accuracy (%) ↑ | Val Loss ↓ | Loss Var. ↓ |
|
|
56
|
+
|---|---|---|---|---|
|
|
57
|
+
| FashionMNIST | Adam | 93.28 | 0.1957 | **0.0033** |
|
|
58
|
+
| FashionMNIST | AdamW | 93.22 | 0.1944 | 0.0034 |
|
|
59
|
+
| FashionMNIST | Lion | 92.91 | 0.2091 | 0.0041 |
|
|
60
|
+
| FashionMNIST | AdaBelief | 93.66 | 0.1822 | 0.0046 |
|
|
61
|
+
| FashionMNIST | **PILOT (Ours)** | **94.13** | **0.1719** | 0.0045 |
|
|
62
|
+
| CIFAR-10 | Adam | 79.91 | 0.5794 | 0.0103 |
|
|
63
|
+
| CIFAR-10 | Lion | 80.87 | 0.5487 | 0.0105 |
|
|
64
|
+
| CIFAR-10 | **PILOT (Ours)** | **81.94** | **0.5302** | **0.0073** |
|
|
65
|
+
|
|
66
|
+
### ResNet-18 Architecture
|
|
67
|
+
|
|
68
|
+
| Dataset | Optimizer | Accuracy (%) ↑ | Val Loss ↓ | Loss Var. ↓ |
|
|
69
|
+
|---|---|---|---|---|
|
|
70
|
+
| FashionMNIST | AdaBelief | 95.33 | 0.1711 | 0.0056 |
|
|
71
|
+
| FashionMNIST | **PILOT (Ours)** | **95.71** | 0.2690 | 0.0030 |
|
|
72
|
+
| CIFAR-10 | Adam | 93.18 | **0.2140** | 0.0073 |
|
|
73
|
+
| CIFAR-10 | AdamW | 92.90 | 0.2514 | 0.0066 |
|
|
74
|
+
| CIFAR-10 | **PILOT (Ours)** | **93.42** | 0.2496 | **0.0001** |
|
|
75
|
+
|
|
76
|
+
---
|
|
77
|
+
|
|
78
|
+
## Method
|
|
79
|
+
|
|
80
|
+
### Gradient-Direction Agreement
|
|
81
|
+
|
|
82
|
+
At each step, PILOT computes the cosine similarity between successive gradients:
|
|
83
|
+
|
|
84
|
+
$$r_t = \frac{g_t^\top g_{t-1}}{\|g_t\|_2 \, \|g_{t-1}\|_2 + \epsilon}$$
|
|
85
|
+
|
|
86
|
+
This is smoothed via an exponential moving average:
|
|
87
|
+
|
|
88
|
+
$$\rho_t = \gamma \rho_{t-1} + (1 - \gamma) r_t$$
|
|
89
|
+
|
|
90
|
+
Positive values indicate stable, aligned gradients. Values near zero indicate noise. Negative values indicate directional disagreement.
|
|
91
|
+
|
|
92
|
+
### Learnable Policy
|
|
93
|
+
|
|
94
|
+
The smoothed signal $\rho_t$ is fed through polynomial functions followed by sigmoid activations to produce three scalar control variables:
|
|
95
|
+
|
|
96
|
+
$$p_{m,t} = \sigma(f(\rho_t; \phi_m)), \quad p_{v,t} = \tfrac{1}{2}\sigma(f(\rho_t; \phi_v)), \quad p_{s,t} = \sigma(f(\rho_t; \phi_s))$$
|
|
97
|
+
|
|
98
|
+
The total number of learnable policy parameters is $3(d+1)$, where $d$ is the polynomial degree.
|
|
99
|
+
|
|
100
|
+
### Update Rule
|
|
101
|
+
|
|
102
|
+
$$\theta_{t+1} = \theta_t - \eta \frac{(|n_t| + \epsilon_n)^{1 - p_{s,t}} \odot \text{sign}(n_t)}{\hat{v}_t^{\,p_{v,t}} + \epsilon}$$
|
|
103
|
+
|
|
104
|
+
where $n_t = p_{m,t} \hat{m}_t + (1 - p_{m,t}) g_t$ is the policy-controlled blend of momentum and current gradient.
|
|
105
|
+
|
|
106
|
+
This formulation recovers Adam ($p_m=1, p_v=0.5, p_s=0$) and sign-based updates ($p_s=1, p_v=0$) as special cases.
|
|
107
|
+
|
|
108
|
+
---
|
|
109
|
+
|
|
110
|
+
## Installation
|
|
111
|
+
|
|
112
|
+
```bash
|
|
113
|
+
pip install pilot-optimizer
|
|
114
|
+
```
|
|
115
|
+
|
|
116
|
+
Or install from source:
|
|
117
|
+
|
|
118
|
+
```bash
|
|
119
|
+
git clone https://github.com/SattamAltwaim/PILOT.git
|
|
120
|
+
cd PILOT
|
|
121
|
+
pip install -e .
|
|
122
|
+
```
|
|
123
|
+
|
|
124
|
+
---
|
|
125
|
+
|
|
126
|
+
## Usage
|
|
127
|
+
|
|
128
|
+
```python
|
|
129
|
+
from pilot import PILOT
|
|
130
|
+
|
|
131
|
+
optimizer = PILOT(
|
|
132
|
+
model.parameters(),
|
|
133
|
+
lr=1e-3,
|
|
134
|
+
betas=(0.9, 0.999),
|
|
135
|
+
weight_decay=1e-4,
|
|
136
|
+
gamma=0.95, # smoothing coefficient for agreement signal
|
|
137
|
+
lr_phi=0.01, # policy learning rate
|
|
138
|
+
degree=2 # polynomial degree
|
|
139
|
+
)
|
|
140
|
+
|
|
141
|
+
for x, y in dataloader:
|
|
142
|
+
loss = criterion(model(x), y)
|
|
143
|
+
optimizer.zero_grad()
|
|
144
|
+
loss.backward()
|
|
145
|
+
optimizer.step()
|
|
146
|
+
```
|
|
147
|
+
|
|
148
|
+
---
|
|
149
|
+
|
|
150
|
+
## Hyperparameters
|
|
151
|
+
|
|
152
|
+
| Parameter | Description | Typical Range |
|
|
153
|
+
|---|---|---|
|
|
154
|
+
| `lr` | Model learning rate | `1e-4` – `1e-3` |
|
|
155
|
+
| `betas` | Moment coefficients | `(0.9, 0.999)` |
|
|
156
|
+
| `gamma` | Agreement signal smoothing | `0.85` – `0.99` |
|
|
157
|
+
| `lr_phi` | Policy learning rate | `5e-4` – `5e-2` |
|
|
158
|
+
| `degree` | Polynomial degree | `1` – `4` |
|
|
159
|
+
|
|
160
|
+
### Configuration-Specific Selections
|
|
161
|
+
|
|
162
|
+
| Dataset | Architecture | γ | η_φ | Degree |
|
|
163
|
+
|---|---|---|---|---|
|
|
164
|
+
| CIFAR-10 | CNN | 0.882 | 0.00312 | 1 |
|
|
165
|
+
| CIFAR-10 | ResNet-18 | 0.950 | 0.00500 | 2 |
|
|
166
|
+
| FashionMNIST | CNN | 0.950 | 0.01000 | 2 |
|
|
167
|
+
| FashionMNIST | ResNet-18 | 0.957 | 0.00273 | 3 |
|
|
168
|
+
|
|
169
|
+
---
|
|
170
|
+
|
|
171
|
+
## Experiments
|
|
172
|
+
|
|
173
|
+
Experiments use 30 epochs, cross-entropy loss, cosine annealing LR schedule, batch size 128, and AMP. ResNet-18 configurations include a 3-epoch linear warmup.
|
|
174
|
+
|
|
175
|
+
```bash
|
|
176
|
+
# CNN on CIFAR-10
|
|
177
|
+
python train.py --dataset cifar10 --arch cnn --optimizer pilot
|
|
178
|
+
|
|
179
|
+
# ResNet-18 on FashionMNIST
|
|
180
|
+
python train.py --dataset fashionmnist --arch resnet18 --optimizer pilot
|
|
181
|
+
```
|
|
182
|
+
|
|
@@ -0,0 +1,154 @@
|
|
|
1
|
+
# PILOT: Policy-Informed Learned Optimization for Adaptive Deep Network Training
|
|
2
|
+
|
|
3
|
+
> **PILOT** is an online adaptive optimizer that adjusts its update behavior during training using gradient-direction agreement as a signal of local optimization stability.
|
|
4
|
+
|
|
5
|
+
---
|
|
6
|
+
|
|
7
|
+
## Overview
|
|
8
|
+
|
|
9
|
+
Most optimizers use a fixed update structure throughout training — a static balance between momentum, normalization, and sign-based updates that cannot respond to how the loss landscape evolves.
|
|
10
|
+
|
|
11
|
+
**PILOT** introduces a learnable policy that continuously modulates three core update primitives:
|
|
12
|
+
- **Momentum reliance** — how much to rely on accumulated gradient history vs. the current gradient
|
|
13
|
+
- **Variance-normalization strength** — how aggressively to apply adaptive scaling
|
|
14
|
+
- **Sign-based behavior** — how much to compress gradient magnitudes toward ±1
|
|
15
|
+
|
|
16
|
+
The policy is conditioned on a smoothed gradient-direction agreement signal, which serves as a compact online descriptor of local update consistency. It is updated online during training using a one-step meta-gradient estimate — no offline search, no meta-training phase, no second-order estimation.
|
|
17
|
+
|
|
18
|
+

|
|
19
|
+
*PILOT follows a distinct trajectory through the loss surface and converges to a lower-loss region compared to Adam, AdamW, Lion, and Sophia.*
|
|
20
|
+
|
|
21
|
+
---
|
|
22
|
+
|
|
23
|
+
## Key Results
|
|
24
|
+
|
|
25
|
+
### CNN Architecture
|
|
26
|
+
|
|
27
|
+
| Dataset | Optimizer | Accuracy (%) ↑ | Val Loss ↓ | Loss Var. ↓ |
|
|
28
|
+
|---|---|---|---|---|
|
|
29
|
+
| FashionMNIST | Adam | 93.28 | 0.1957 | **0.0033** |
|
|
30
|
+
| FashionMNIST | AdamW | 93.22 | 0.1944 | 0.0034 |
|
|
31
|
+
| FashionMNIST | Lion | 92.91 | 0.2091 | 0.0041 |
|
|
32
|
+
| FashionMNIST | AdaBelief | 93.66 | 0.1822 | 0.0046 |
|
|
33
|
+
| FashionMNIST | **PILOT (Ours)** | **94.13** | **0.1719** | 0.0045 |
|
|
34
|
+
| CIFAR-10 | Adam | 79.91 | 0.5794 | 0.0103 |
|
|
35
|
+
| CIFAR-10 | Lion | 80.87 | 0.5487 | 0.0105 |
|
|
36
|
+
| CIFAR-10 | **PILOT (Ours)** | **81.94** | **0.5302** | **0.0073** |
|
|
37
|
+
|
|
38
|
+
### ResNet-18 Architecture
|
|
39
|
+
|
|
40
|
+
| Dataset | Optimizer | Accuracy (%) ↑ | Val Loss ↓ | Loss Var. ↓ |
|
|
41
|
+
|---|---|---|---|---|
|
|
42
|
+
| FashionMNIST | AdaBelief | 95.33 | 0.1711 | 0.0056 |
|
|
43
|
+
| FashionMNIST | **PILOT (Ours)** | **95.71** | 0.2690 | 0.0030 |
|
|
44
|
+
| CIFAR-10 | Adam | 93.18 | **0.2140** | 0.0073 |
|
|
45
|
+
| CIFAR-10 | AdamW | 92.90 | 0.2514 | 0.0066 |
|
|
46
|
+
| CIFAR-10 | **PILOT (Ours)** | **93.42** | 0.2496 | **0.0001** |
|
|
47
|
+
|
|
48
|
+
---
|
|
49
|
+
|
|
50
|
+
## Method
|
|
51
|
+
|
|
52
|
+
### Gradient-Direction Agreement
|
|
53
|
+
|
|
54
|
+
At each step, PILOT computes the cosine similarity between successive gradients:
|
|
55
|
+
|
|
56
|
+
$$r_t = \frac{g_t^\top g_{t-1}}{\|g_t\|_2 \, \|g_{t-1}\|_2 + \epsilon}$$
|
|
57
|
+
|
|
58
|
+
This is smoothed via an exponential moving average:
|
|
59
|
+
|
|
60
|
+
$$\rho_t = \gamma \rho_{t-1} + (1 - \gamma) r_t$$
|
|
61
|
+
|
|
62
|
+
Positive values indicate stable, aligned gradients. Values near zero indicate noise. Negative values indicate directional disagreement.
|
|
63
|
+
|
|
64
|
+
### Learnable Policy
|
|
65
|
+
|
|
66
|
+
The smoothed signal $\rho_t$ is fed through polynomial functions followed by sigmoid activations to produce three scalar control variables:
|
|
67
|
+
|
|
68
|
+
$$p_{m,t} = \sigma(f(\rho_t; \phi_m)), \quad p_{v,t} = \tfrac{1}{2}\sigma(f(\rho_t; \phi_v)), \quad p_{s,t} = \sigma(f(\rho_t; \phi_s))$$
|
|
69
|
+
|
|
70
|
+
The total number of learnable policy parameters is $3(d+1)$, where $d$ is the polynomial degree.
|
|
71
|
+
|
|
72
|
+
### Update Rule
|
|
73
|
+
|
|
74
|
+
$$\theta_{t+1} = \theta_t - \eta \frac{(|n_t| + \epsilon_n)^{1 - p_{s,t}} \odot \text{sign}(n_t)}{\hat{v}_t^{\,p_{v,t}} + \epsilon}$$
|
|
75
|
+
|
|
76
|
+
where $n_t = p_{m,t} \hat{m}_t + (1 - p_{m,t}) g_t$ is the policy-controlled blend of momentum and current gradient.
|
|
77
|
+
|
|
78
|
+
This formulation recovers Adam ($p_m=1, p_v=0.5, p_s=0$) and sign-based updates ($p_s=1, p_v=0$) as special cases.
|
|
79
|
+
|
|
80
|
+
---
|
|
81
|
+
|
|
82
|
+
## Installation
|
|
83
|
+
|
|
84
|
+
```bash
|
|
85
|
+
pip install pilot-optimizer
|
|
86
|
+
```
|
|
87
|
+
|
|
88
|
+
Or install from source:
|
|
89
|
+
|
|
90
|
+
```bash
|
|
91
|
+
git clone https://github.com/SattamAltwaim/PILOT.git
|
|
92
|
+
cd PILOT
|
|
93
|
+
pip install -e .
|
|
94
|
+
```
|
|
95
|
+
|
|
96
|
+
---
|
|
97
|
+
|
|
98
|
+
## Usage
|
|
99
|
+
|
|
100
|
+
```python
|
|
101
|
+
from pilot import PILOT
|
|
102
|
+
|
|
103
|
+
optimizer = PILOT(
|
|
104
|
+
model.parameters(),
|
|
105
|
+
lr=1e-3,
|
|
106
|
+
betas=(0.9, 0.999),
|
|
107
|
+
weight_decay=1e-4,
|
|
108
|
+
gamma=0.95, # smoothing coefficient for agreement signal
|
|
109
|
+
lr_phi=0.01, # policy learning rate
|
|
110
|
+
degree=2 # polynomial degree
|
|
111
|
+
)
|
|
112
|
+
|
|
113
|
+
for x, y in dataloader:
|
|
114
|
+
loss = criterion(model(x), y)
|
|
115
|
+
optimizer.zero_grad()
|
|
116
|
+
loss.backward()
|
|
117
|
+
optimizer.step()
|
|
118
|
+
```
|
|
119
|
+
|
|
120
|
+
---
|
|
121
|
+
|
|
122
|
+
## Hyperparameters
|
|
123
|
+
|
|
124
|
+
| Parameter | Description | Typical Range |
|
|
125
|
+
|---|---|---|
|
|
126
|
+
| `lr` | Model learning rate | `1e-4` – `1e-3` |
|
|
127
|
+
| `betas` | Moment coefficients | `(0.9, 0.999)` |
|
|
128
|
+
| `gamma` | Agreement signal smoothing | `0.85` – `0.99` |
|
|
129
|
+
| `lr_phi` | Policy learning rate | `5e-4` – `5e-2` |
|
|
130
|
+
| `degree` | Polynomial degree | `1` – `4` |
|
|
131
|
+
|
|
132
|
+
### Configuration-Specific Selections
|
|
133
|
+
|
|
134
|
+
| Dataset | Architecture | γ | η_φ | Degree |
|
|
135
|
+
|---|---|---|---|---|
|
|
136
|
+
| CIFAR-10 | CNN | 0.882 | 0.00312 | 1 |
|
|
137
|
+
| CIFAR-10 | ResNet-18 | 0.950 | 0.00500 | 2 |
|
|
138
|
+
| FashionMNIST | CNN | 0.950 | 0.01000 | 2 |
|
|
139
|
+
| FashionMNIST | ResNet-18 | 0.957 | 0.00273 | 3 |
|
|
140
|
+
|
|
141
|
+
---
|
|
142
|
+
|
|
143
|
+
## Experiments
|
|
144
|
+
|
|
145
|
+
Experiments use 30 epochs, cross-entropy loss, cosine annealing LR schedule, batch size 128, and AMP. ResNet-18 configurations include a 3-epoch linear warmup.
|
|
146
|
+
|
|
147
|
+
```bash
|
|
148
|
+
# CNN on CIFAR-10
|
|
149
|
+
python train.py --dataset cifar10 --arch cnn --optimizer pilot
|
|
150
|
+
|
|
151
|
+
# ResNet-18 on FashionMNIST
|
|
152
|
+
python train.py --dataset fashionmnist --arch resnet18 --optimizer pilot
|
|
153
|
+
```
|
|
154
|
+
|
|
@@ -0,0 +1,54 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Diagnostics tracker for the PILOT optimizer.
|
|
3
|
+
|
|
4
|
+
Only records data when explicitly enabled. Zero overhead when disabled.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class DiagnosticsTracker:
    """Stores per-step optimizer internals for analysis.

    The history is a dict of parallel lists keyed by ``step``, ``r``,
    ``rho``, ``p_m``, ``p_v``, ``p_s`` and one ``phi_<i>`` entry for each of
    the ``3 * (degree + 1)`` policy coefficients.
    """

    def __init__(self, *, degree: int = 2):
        """Create an empty history sized for a policy of the given degree."""
        self._degree = degree
        n_phi = 3 * (degree + 1)
        self._phi_keys = [f"phi_{i}" for i in range(n_phi)]
        self._history = {
            "step": [],
            "r": [],
            "rho": [],
            "p_m": [],
            "p_v": [],
            "p_s": [],
        }
        for key in self._phi_keys:
            self._history[key] = []

    def record(self, step, r, rho, pm, pv, ps, phi):
        """Record diagnostics for one step.

        ``r``, ``rho``, ``pm``, ``pv`` and ``ps`` are converted with
        ``float()``; ``phi`` must support ``.detach().cpu().tolist()`` and
        provide at least ``3 * (degree + 1)`` values.
        """
        self._history["step"].append(step)
        self._history["r"].append(float(r))
        self._history["rho"].append(float(rho))
        self._history["p_m"].append(float(pm))
        self._history["p_v"].append(float(pv))
        self._history["p_s"].append(float(ps))
        # Move phi off the autograd graph / accelerator once, then fan the
        # plain Python floats out to their per-coefficient lists.
        phi_vals = phi.detach().cpu().tolist()
        for i, key in enumerate(self._phi_keys):
            self._history[key].append(phi_vals[i])

    def get_history(self):
        """Return full history as dict of lists.

        The returned dict is a shallow copy: the value lists are the same
        objects the tracker keeps appending to.
        """
        return dict(self._history)

    def summary(self, last_n=5):
        """Return a text summary of the last N recorded steps.

        When fewer than ``last_n`` steps have been recorded, every recorded
        step is listed. Returns ``"No data recorded."`` for an empty history.
        """
        h = self._history
        total = len(h["step"])
        if not total:
            return "No data recorded."
        count = min(last_n, total)
        lines = [f"Last {count} steps:"]
        # Start at total - count so the index can never run past the start of
        # the history; a non-positive count yields an empty range.
        for idx in range(total - count, total):
            lines.append(
                f" step={h['step'][idx]} r={h['r'][idx]:.4f} rho={h['rho'][idx]:.4f} "
                f"pm={h['p_m'][idx]:.4f} pv={h['p_v'][idx]:.4f} ps={h['p_s'][idx]:.4f}"
            )
        return "\n".join(lines)
|
|
@@ -0,0 +1,104 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Analytic meta-gradient computation for the PILOT optimizer.
|
|
3
|
+
|
|
4
|
+
Pure function that computes the 3*(degree+1) meta-gradients (dL/dphi_i)
|
|
5
|
+
from stored detached intermediates. No autograd graph, no state.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import torch
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def _safe_pow(base, exp_val):
|
|
12
|
+
"""|base|^exp_val with clamping for numerical stability."""
|
|
13
|
+
return torch.clamp(base.abs(), min=1e-12).pow(exp_val)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def _safe_log(x):
|
|
17
|
+
"""log(|x|) with clamping for numerical stability."""
|
|
18
|
+
return torch.log(torch.clamp(x.abs(), min=1e-12))
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def _horner(coeffs: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
|
|
22
|
+
"""Evaluate a polynomial via Horner's method (highest power first)."""
|
|
23
|
+
out = coeffs[0]
|
|
24
|
+
for k in range(1, coeffs.shape[0]):
|
|
25
|
+
out = out * x + coeffs[k]
|
|
26
|
+
return out
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def compute_meta_grads(g_next, intermediates, g_t, phi, eta, *, degree=2):
    """Compute the 3*(degree+1) meta-gradients dL/dphi_i for one param tensor.

    The chain: phi -> (pm, pv, ps) -> d_t -> theta_{t+1} -> g_{t+1} -> L

        dL/dphi_i = g_{t+1}^T * dd_t/dphi_i   (summed over all elements)

    Everything is computed in closed form from the supplied values; no
    autograd graph is built here.

    Args:
        g_next: gradient at step t+1, shape matches param.
        intermediates: dict with keys n_t, m_hat, v_hat, p_m, p_v, p_s, rho.
        g_t: gradient at step t (stored separately as state["g_prev"]).
        phi: current group-level phi vector, length 3*(degree+1), laid out as
            [momentum coeffs | variance coeffs | sign coeffs], each slice
            ordered highest power first.
        eta: main optimizer learning rate.
        degree: polynomial degree of the response policy.

    Returns:
        grads: tensor of shape (3*(degree+1),), in the same layout as ``phi``.
    """
    n_t = intermediates["n_t"]
    m_hat = intermediates["m_hat"]
    v_hat = intermediates["v_hat"]
    pm = intermediates["p_m"]
    pv = intermediates["p_v"]
    ps = intermediates["p_s"]
    rho = intermediates["rho"]

    # v^pv is needed both in the denominator and in dd_t/dp_v; compute it once.
    # NOTE(review): the 1e-8 is hard-coded here; presumably it mirrors the
    # epsilon the optimizer step uses -- confirm against optimizer.py.
    v_pow = _safe_pow(v_hat, pv)
    denom = v_pow + 1e-8
    n_abs_neg_ps = _safe_pow(n_t, -ps)
    n_abs_1ps = _safe_pow(n_t, 1 - ps)
    sign_n = torch.sign(n_t)

    # dd_t / dp_m = -eta * (1 - ps) * |n|^(-ps) * (m_hat - g_t) / denom
    dd_dpm = -eta * (1 - ps) * n_abs_neg_ps * (m_hat - g_t) / denom

    # dd_t / dp_v = eta * |n|^(1-ps) * sign(n) * v^pv * log|v| / denom^2
    dd_dpv = eta * n_abs_1ps * sign_n * v_pow * _safe_log(v_hat) / (denom * denom)

    # dd_t / dp_s = eta * |n|^(1-ps) * sign(n) * log|n| / denom
    dd_dps = eta * n_abs_1ps * sign_n * _safe_log(n_t) / denom

    # Scalar dL/dp for each policy variable
    dL_dpm = torch.sum(g_next * dd_dpm)
    dL_dpv = torch.sum(g_next * dd_dpv)
    dL_dps = torch.sum(g_next * dd_dps)

    # --- Chain rule through sigmoid and polynomial ---
    n_phi = degree + 1
    coeffs_m = phi[:n_phi]
    coeffs_v = phi[n_phi : 2 * n_phi]
    coeffs_s = phi[2 * n_phi :]

    z_m = _horner(coeffs_m, rho)
    z_v = _horner(coeffs_v, rho)
    z_s = _horner(coeffs_s, rho)

    s_m = torch.sigmoid(z_m)
    s_v = torch.sigmoid(z_v)
    s_s = torch.sigmoid(z_s)

    # sigmoid'(z) = sigmoid(z) * (1 - sigmoid(z))
    ds_m = s_m * (1 - s_m)
    ds_v = s_v * (1 - s_v)
    ds_s = s_s * (1 - s_s)

    # Build rho powers: [rho^d, rho^{d-1}, ..., rho^1, rho^0]
    # Allocate in phi's dtype so a float64 phi is not silently truncated to
    # the default float32 (which would also downcast the returned gradients).
    rho_powers = torch.empty(n_phi, device=phi.device, dtype=phi.dtype)
    rho_powers[n_phi - 1] = 1.0
    for k in range(n_phi - 2, -1, -1):
        rho_powers[k] = rho_powers[k + 1] * rho

    # dL/d(coeffs_m[k]) = dL/dpm * ds_m * rho^(d-k)
    grads_m = dL_dpm * ds_m * rho_powers
    # pv = 0.5 * sigmoid(z_v), so dp_v/dz_v = 0.5 * sigmoid'(z_v)
    grads_v = dL_dpv * 0.5 * ds_v * rho_powers
    grads_s = dL_dps * ds_s * rho_powers

    return torch.cat([grads_m, grads_v, grads_s])
|