pilot-optimizer 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pilot/__init__.py +4 -0
- pilot/diagnostics.py +54 -0
- pilot/meta_grads.py +104 -0
- pilot/optimizer.py +287 -0
- pilot_optimizer-0.1.0.dist-info/METADATA +182 -0
- pilot_optimizer-0.1.0.dist-info/RECORD +9 -0
- pilot_optimizer-0.1.0.dist-info/WHEEL +5 -0
- pilot_optimizer-0.1.0.dist-info/licenses/LICENSE +21 -0
- pilot_optimizer-0.1.0.dist-info/top_level.txt +1 -0
pilot/__init__.py
ADDED
pilot/diagnostics.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Diagnostics tracker for the PILOT optimizer.
|
|
3
|
+
|
|
4
|
+
Only records data when explicitly enabled. Zero overhead when disabled.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class DiagnosticsTracker:
    """Stores per-step optimizer internals for analysis.

    The history is a dict of parallel lists keyed by "step", "r", "rho",
    "p_m", "p_v", "p_s" and "phi_0" ... "phi_{3*(degree+1)-1}". Entry ``i`` of
    every list comes from the ``i``-th call to :meth:`record`.
    """

    def __init__(self, *, degree: int = 2):
        """Create an empty history sized for a policy of the given degree.

        Args:
            degree: polynomial degree of the response policy; determines how
                many ``phi_i`` columns are tracked (3 * (degree + 1)).
        """
        self._degree = degree
        n_phi = 3 * (degree + 1)
        self._phi_keys = [f"phi_{i}" for i in range(n_phi)]
        self._history = {
            "step": [],
            "r": [],
            "rho": [],
            "p_m": [],
            "p_v": [],
            "p_s": [],
        }
        for key in self._phi_keys:
            self._history[key] = []

    def record(self, step, r, rho, pm, pv, ps, phi):
        """Record diagnostics for one step.

        Args:
            step: step counter, stored as given.
            r, rho, pm, pv, ps: scalars (or 0-dim tensors), stored as floats.
            phi: tensor with at least 3 * (degree + 1) entries; it is detached
                and copied to the CPU before being split into columns.
        """
        self._history["step"].append(step)
        self._history["r"].append(float(r))
        self._history["rho"].append(float(rho))
        self._history["p_m"].append(float(pm))
        self._history["p_v"].append(float(pv))
        self._history["p_s"].append(float(ps))
        phi_vals = phi.detach().cpu().tolist()
        for i, key in enumerate(self._phi_keys):
            self._history[key].append(phi_vals[i])

    def get_history(self):
        """Return full history as dict of lists.

        The dict itself is a copy, but the lists are the tracker's own lists,
        so they keep growing with later :meth:`record` calls.
        """
        return dict(self._history)

    def summary(self, last_n=5):
        """Return a text summary of the last N recorded steps.

        Args:
            last_n: maximum number of most-recent steps to include. If fewer
                steps have been recorded, all of them are listed.

        Returns:
            A newline-joined string, or "No data recorded." when empty.
        """
        h = self._history
        total = len(h["step"])
        if not total:
            return "No data recorded."
        # Never reach back past the first record: a negative index would wrap
        # around (or raise) when fewer than last_n steps exist.
        count = max(0, min(last_n, total))
        lines = [f"Last {count} steps:"]
        for idx in range(total - count, total):
            lines.append(
                f" step={h['step'][idx]} r={h['r'][idx]:.4f} rho={h['rho'][idx]:.4f} "
                f"pm={h['p_m'][idx]:.4f} pv={h['p_v'][idx]:.4f} ps={h['p_s'][idx]:.4f}"
            )
        return "\n".join(lines)
|
pilot/meta_grads.py
ADDED
|
@@ -0,0 +1,104 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Analytic meta-gradient computation for the PILOT optimizer.
|
|
3
|
+
|
|
4
|
+
Pure function that computes the 3*(degree+1) meta-gradients (dL/dphi_i)
|
|
5
|
+
from stored detached intermediates. No autograd graph, no state.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import torch
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def _safe_pow(base, exp_val):
    """Raise the magnitude of ``base`` to ``exp_val``.

    The magnitude is floored at 1e-12 first, so zeros in ``base`` do not
    produce infinities under a negative exponent.
    """
    magnitude = base.abs().clamp(min=1e-12)
    return torch.pow(magnitude, exp_val)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def _safe_log(x):
    """Natural log of the magnitude of ``x``.

    The magnitude is floored at 1e-12 so that zeros map to a large negative
    finite value instead of -inf.
    """
    floored = x.abs().clamp(min=1e-12)
    return floored.log()
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def _horner(coeffs: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """Evaluate a polynomial at ``x`` with Horner's scheme.

    ``coeffs`` is ordered highest power first, i.e. ``coeffs[-1]`` is the
    constant term.
    """
    acc = coeffs[0]
    for coeff in coeffs[1:]:
        acc = acc * x + coeff
    return acc
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def compute_meta_grads(g_next, intermediates, g_t, phi, eta, *, degree=2, eps=1e-8):
    """Compute the 3*(degree+1) meta-gradients dL/dphi_i for one param tensor.

    The chain: phi -> (pm, pv, ps) -> d_t -> theta_{t+1} -> g_{t+1} -> L

        dL/dphi_i = g_{t+1}^T * dd_t/dphi_i   (summed over all elements)

    The element-wise terms dd_t/dp are evaluated on the device of the stored
    intermediates. The three resulting scalars are then moved to
    ``phi.device``, where the sigmoid/polynomial chain rule is applied, so a
    parameter that lives on a different device than ``phi`` never mixes
    devices in one operation.

    Args:
        g_next: gradient at step t+1, shape matches param.
        intermediates: dict with keys n_t, m_hat, v_hat, p_m, p_v, p_s, rho.
        g_t: gradient at step t (stored separately as state["g_prev"]).
        phi: current group-level phi vector, length 3*(degree+1).
        eta: main optimizer learning rate.
        degree: polynomial degree of the response policy.
        eps: constant added to ``|v_hat|^pv`` in the denominator. It should
            equal the value used when the update d_t was formed; the default
            matches the previously hard-coded 1e-8.

    Returns:
        grads: tensor of shape (3*(degree+1),) on ``phi.device``, laid out as
        [pm block, pv block, ps block], each block highest power first.

    Raises:
        ValueError: if ``phi`` does not hold exactly 3*(degree+1) values.
    """
    n_phi = degree + 1
    if phi.numel() != 3 * n_phi:
        raise ValueError(
            f"phi has {phi.numel()} entries, expected {3 * n_phi} for degree={degree}"
        )

    n_t = intermediates["n_t"]
    m_hat = intermediates["m_hat"]
    v_hat = intermediates["v_hat"]
    pm = intermediates["p_m"]
    pv = intermediates["p_v"]
    ps = intermediates["p_s"]
    # The policy chain below runs on phi's device; rho may have been stored
    # on the parameter's device.
    rho = intermediates["rho"].to(phi.device)

    v_pow = _safe_pow(v_hat, pv)
    denom = v_pow + eps
    n_abs_neg_ps = _safe_pow(n_t, -ps)
    # |n|^(1-ps) * sign(n): the numerator of the update, shared by dp_v and dp_s.
    signed_n_pow = _safe_pow(n_t, 1 - ps) * torch.sign(n_t)

    # dd_t / dp_m = -eta * (1 - ps) * |n|^(-ps) * (m_hat - g_t) / denom
    dd_dpm = -eta * (1 - ps) * n_abs_neg_ps * (m_hat - g_t) / denom

    # dd_t / dp_v = eta * |n|^(1-ps) * sign(n) * v^pv * log|v| / denom^2
    dd_dpv = eta * signed_n_pow * v_pow * _safe_log(v_hat) / (denom * denom)

    # dd_t / dp_s = eta * |n|^(1-ps) * sign(n) * log|n| / denom
    dd_dps = eta * signed_n_pow * _safe_log(n_t) / denom

    # Scalar dL/dp for each policy variable, brought over to phi's device.
    dL_dpm = torch.sum(g_next * dd_dpm).to(phi.device)
    dL_dpv = torch.sum(g_next * dd_dpv).to(phi.device)
    dL_dps = torch.sum(g_next * dd_dps).to(phi.device)

    # --- Chain rule through sigmoid and polynomial ---
    coeffs_m = phi[:n_phi]
    coeffs_v = phi[n_phi : 2 * n_phi]
    coeffs_s = phi[2 * n_phi :]

    s_m = torch.sigmoid(_horner(coeffs_m, rho))
    s_v = torch.sigmoid(_horner(coeffs_v, rho))
    s_s = torch.sigmoid(_horner(coeffs_s, rho))

    # sigmoid'(z) = s * (1 - s)
    ds_m = s_m * (1 - s_m)
    ds_v = s_v * (1 - s_v)
    ds_s = s_s * (1 - s_s)

    # Build rho powers: [rho^d, rho^{d-1}, ..., rho^1, rho^0], in phi's dtype
    # and on phi's device.
    ascending = [phi.new_ones(())]
    for _ in range(degree):
        ascending.append(ascending[-1] * rho)
    rho_powers = torch.stack(ascending[::-1]).to(phi.dtype)

    # dL/d(coeffs_m[k]) = dL/dpm * ds_m * rho^(d-k)
    grads_m = dL_dpm * ds_m * rho_powers
    # pv = 0.5 * sigmoid(z_v), so dp_v/dz_v = 0.5 * sigmoid'(z_v)
    grads_v = dL_dpv * 0.5 * ds_v * rho_powers
    grads_s = dL_dps * ds_s * rho_powers

    return torch.cat([grads_m, grads_v, grads_s])
|
pilot/optimizer.py
ADDED
|
@@ -0,0 +1,287 @@
|
|
|
1
|
+
"""
|
|
2
|
+
PILOT -- Policy-Informed Learned Optimization for Training.
|
|
3
|
+
|
|
4
|
+
Adam with a learnable brain that reads the loss landscape and reshapes
|
|
5
|
+
the update rule every step. The brain is trained by gradient descent using
|
|
6
|
+
information the optimizer already computes.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
import torch
|
|
10
|
+
from torch.optim import Optimizer
|
|
11
|
+
|
|
12
|
+
from .diagnostics import DiagnosticsTracker
|
|
13
|
+
from .meta_grads import compute_meta_grads
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
# Constant-term (degree-0) coefficients for the three policy polynomials, in
# (pm, pv, ps) order. Every higher-order coefficient starts at zero, so each
# policy is flat in rho at init: pm = sigmoid(1.4) ~ 0.80,
# pv = 0.5 * sigmoid(3.0) ~ 0.48, ps = sigmoid(-2.0) ~ 0.12.
_PHI_BIAS = (1.4, 3.0, -2.0)

META_GRAD_CLIP = 1.0  # safety cap on ||meta_grad_acc|| before phi update
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def _build_phi_init(degree: int) -> torch.Tensor:
    """Return the starting phi vector for a policy of the given degree.

    The vector is three back-to-back blocks of (degree+1) coefficients, one
    block each for pm, pv and ps, with coefficients ordered from the highest
    power down to the constant term.

    Only the constant term of each block is non-zero (taken from _PHI_BIAS),
    so the initial policy output does not depend on rho and is the same for
    every degree.
    """
    block_len = degree + 1
    phi = torch.zeros(3 * block_len)
    # View as one row per policy variable; the last column is the constant term.
    phi.view(3, block_len)[:, -1] = torch.tensor(_PHI_BIAS)
    return phi
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def _horner(coeffs: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """Horner evaluation of a polynomial at a scalar point.

    coeffs: 1-D tensor ordered from the highest power down to the constant,
            i.e. [c_d, c_{d-1}, ..., c_1, c_0].
    x: scalar tensor.
    Returns: scalar tensor equal to c_d*x^d + ... + c_1*x + c_0.
    """
    leading, *remaining = coeffs.unbind(0)
    result = leading
    for coeff in remaining:
        result = result * x + coeff
    return result
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
class PILOT(Optimizer):
    r"""PILOT optimizer.

    Each param group carries its own policy state in the group dict:
    ``phi`` (policy coefficients), ``rho`` (smoothed gradient-agreement
    signal, a 0-dim tensor) and ``step`` (step counter). This state is created
    from the group's own ``degree`` and is also created on first use for
    groups added later through ``add_param_group``.

    Args:
        params: iterable of parameters to optimize.
        lr: learning rate (default: 1e-3).
        betas: coefficients for running averages of gradient and its square (default: (0.9, 0.999)).
        eps: term for numerical stability (default: 1e-8).
        weight_decay: weight decay coefficient (default: 0.01).
        gamma: smoothing factor for the landscape signal (default: 0.95).
        eta_phi: meta learning rate for the response policy (default: 0.01).
        degree: polynomial degree for the response policy (default: 2).
        diagnostics: if True, record internal state every step (default: False).
        policy_overrides: optional dict to fix policy variables, e.g. {'pm': 1.0}.
            Overridden variables bypass phi and their meta-gradient blocks are zeroed.

    Raises:
        ValueError: if lr, betas, eps, gamma or degree is out of range.
    """

    def __init__(
        self,
        params,
        lr=1e-3,
        betas=(0.9, 0.999),
        eps=1e-8,
        weight_decay=0.01,
        gamma=0.95,
        eta_phi=0.01,
        degree=2,
        diagnostics=False,
        policy_overrides=None,
    ):
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(f"Invalid beta1: {betas[0]}")
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(f"Invalid beta2: {betas[1]}")
        if not 0.0 <= eps:
            raise ValueError(f"Invalid eps: {eps}")
        if not 0.0 <= gamma < 1.0:
            raise ValueError(f"Invalid gamma: {gamma}")
        if not (isinstance(degree, int) and degree >= 1):
            raise ValueError(f"Invalid degree: {degree} (must be int >= 1)")

        defaults = dict(
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            gamma=gamma,
            eta_phi=eta_phi,
            degree=degree,
        )
        super().__init__(params, defaults)

        self.diagnostics = (
            DiagnosticsTracker(degree=degree) if diagnostics else None
        )
        self.policy_overrides = policy_overrides or {}

        for group in self.param_groups:
            self._init_group_state(group)

    @staticmethod
    def _init_group_state(group):
        """Attach phi, rho and step to ``group`` where they are missing.

        phi is sized from the group's own ``degree``. rho and phi are placed
        on the first param's device to avoid GPU->CPU syncs in the hot path.
        Keys that already exist (e.g. after ``load_state_dict``) are kept.
        """
        if "phi" in group and "rho" in group and "step" in group:
            return
        params = group["params"]
        device = params[0].device if params else torch.device("cpu")
        if "rho" not in group:
            group["rho"] = torch.zeros((), device=device)
        if "phi" not in group:
            group["phi"] = _build_phi_init(group["degree"]).to(device=device)
        if "step" not in group:
            group["step"] = 0

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Args:
            closure: optional callable that re-evaluates the model and
                returns the loss; it is run with gradients enabled.

        Returns:
            The closure's return value, or None if no closure was given.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            # Groups added through add_param_group() after construction have
            # no policy state yet.
            self._init_group_state(group)

            lr = group["lr"]
            beta1, beta2 = group["betas"]
            eps = group["eps"]
            weight_decay = group["weight_decay"]
            gamma = group["gamma"]
            eta_phi = group["eta_phi"]

            phi = group["phi"]
            rho = group["rho"]
            group["step"] += 1
            step = group["step"]

            bias_corr1 = 1 - beta1 ** step
            bias_corr2 = 1 - beta2 ** step

            # ---- Pass 1: single landscape signal from the full gradient ----
            # Accumulate dot(g, g_prev), ||g||^2, ||g_prev||^2 across all params.
            # Also handles lazy state init so pass 2 can rely on it.
            dot = torch.zeros((), device=phi.device)
            sq_g = torch.zeros((), device=phi.device)
            sq_prev = torch.zeros((), device=phi.device)

            for p in group["params"]:
                if p.grad is None:
                    continue
                state = self.state[p]
                if len(state) == 0:
                    state["m"] = torch.zeros_like(p)
                    state["v"] = torch.zeros_like(p)
                    state["g_prev"] = torch.zeros_like(p)
                    state["intermediates"] = None
                g = p.grad
                g_prev = state["g_prev"]
                dot = dot + (g * g_prev).sum().to(phi.device)
                sq_g = sq_g + g.pow(2).sum().to(phi.device)
                sq_prev = sq_prev + g_prev.pow(2).sum().to(phi.device)

            # Cosine similarity between g and g_prev; 0 when either norm vanishes
            # (in particular on the first step, where g_prev is all zeros).
            denom_cos = (sq_g * sq_prev).sqrt()
            r = torch.where(
                denom_cos > 1e-12,
                dot / denom_cos.clamp(min=1e-12),
                torch.zeros((), device=phi.device),
            )
            rho = gamma * rho + (1 - gamma) * r

            # ---- Single response policy for the whole group ----
            degree = group["degree"]
            n_phi = degree + 1
            coeffs_m = phi[:n_phi]
            coeffs_v = phi[n_phi : 2 * n_phi]
            coeffs_s = phi[2 * n_phi :]

            z_m = _horner(coeffs_m, rho)
            z_v = _horner(coeffs_v, rho)
            z_s = _horner(coeffs_s, rho)

            pm = torch.sigmoid(z_m)
            pv = 0.5 * torch.sigmoid(z_v)
            ps = torch.sigmoid(z_s)

            if self.policy_overrides:
                # Fixed values replace the learned policy; phi's dtype keeps
                # integer overrides from producing integer tensors.
                if 'pm' in self.policy_overrides:
                    pm = torch.tensor(
                        self.policy_overrides['pm'], device=phi.device, dtype=phi.dtype
                    )
                if 'pv' in self.policy_overrides:
                    pv = torch.tensor(
                        self.policy_overrides['pv'], device=phi.device, dtype=phi.dtype
                    )
                if 'ps' in self.policy_overrides:
                    ps = torch.tensor(
                        self.policy_overrides['ps'], device=phi.device, dtype=phi.dtype
                    )

            # ---- Pass 2: meta-gradient + Adam update for each param ----
            meta_grad_acc = torch.zeros(3 * n_phi, device=phi.device)

            for p in group["params"]:
                if p.grad is None:
                    continue

                g = p.grad
                state = self.state[p]
                m = state["m"]
                v = state["v"]
                g_prev = state["g_prev"]  # still holds g_{t-1} until we copy below

                # Meta-gradient from step t (intermediates) against g_{t+1} (current g).
                intermediates = state.get("intermediates")
                if intermediates is not None:
                    mg = compute_meta_grads(
                        g, intermediates, g_prev, phi, lr, degree=degree
                    )
                    meta_grad_acc += mg.to(phi.device)

                # Adam bookkeeping
                m.mul_(beta1).add_(g, alpha=1 - beta1)
                v.mul_(beta2).addcmul_(g, g, value=1 - beta2)
                m_hat = m / bias_corr1
                v_hat = v / bias_corr2

                # Move policy scalars to param device if needed (no-op on single device).
                pm_p = pm.to(g.device) if pm.device != g.device else pm
                pv_p = pv.to(g.device) if pv.device != g.device else pv
                ps_p = ps.to(g.device) if ps.device != g.device else ps

                # Adaptive step
                n = pm_p * m_hat + (1 - pm_p) * g
                n_abs = torch.clamp(n.abs(), min=1e-12)
                denom = v_hat.pow(pv_p) + eps
                d = -lr * n_abs.pow(1 - ps_p) * torch.sign(n) / denom

                # Decoupled weight decay (AdamW-style)
                if weight_decay != 0:
                    d = d - lr * weight_decay * p.data

                p.data.add_(d)

                # Store intermediates for next step's meta-gradient.
                rho_p = rho.to(g.device) if rho.device != g.device else rho
                state["intermediates"] = {
                    "n_t": n.detach(),
                    "m_hat": m_hat.detach(),
                    "v_hat": v_hat.detach(),
                    "p_m": pm_p.detach(),
                    "p_v": pv_p.detach(),
                    "p_s": ps_p.detach(),
                    "rho": rho_p.detach().clone(),
                }

                # Update g_prev to g_t for the next step.
                g_prev.copy_(g)

            # ---- Update phi with accumulated meta-gradients ----
            if self.policy_overrides:
                if 'pm' in self.policy_overrides:
                    meta_grad_acc[:n_phi] = 0
                if 'pv' in self.policy_overrides:
                    meta_grad_acc[n_phi:2 * n_phi] = 0
                if 'ps' in self.policy_overrides:
                    meta_grad_acc[2 * n_phi:] = 0

            # Clip by L2 norm to guard against log|n| blow-up in meta-gradients.
            mg_norm = meta_grad_acc.norm()
            scale = torch.clamp(META_GRAD_CLIP / (mg_norm + 1e-12), max=1.0)
            phi.add_(meta_grad_acc * scale, alpha=-eta_phi)

            # Persist updated rho (still a tensor, no sync).
            group["rho"] = rho

            # ---- Diagnostics ----
            # float() forces a device sync, so this only runs when enabled.
            if self.diagnostics is not None:
                self.diagnostics.record(
                    step,
                    float(r),
                    float(rho),
                    float(pm),
                    float(pv),
                    float(ps),
                    phi,
                )

        return loss
|
|
@@ -0,0 +1,182 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: pilot-optimizer
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: PILOT: Policy-Informed Learned Optimization for Adaptive Deep Network Training
|
|
5
|
+
Author-email: Sattam Altuuaim <sattam.tuuaim@kaust.edu.sa>, Lama Ayash <lama.ayash@kaust.edu.sa>, Muhammad Mubashar <muhammad.mubashar@strath.ac.uk>, Naeemullah Khan <naeemullah.khan@kaust.edu.sa>
|
|
6
|
+
License: MIT
|
|
7
|
+
Project-URL: Homepage, https://sattamaltwaim.github.io/PILOT/
|
|
8
|
+
Project-URL: Repository, https://github.com/SattamAltwaim/PILOT
|
|
9
|
+
Project-URL: Paper, https://arxiv.org/abs/submit/7629402
|
|
10
|
+
Keywords: optimizer,deep-learning,pytorch,meta-learning,adaptive-optimization
|
|
11
|
+
Classifier: Development Status :: 4 - Beta
|
|
12
|
+
Classifier: Intended Audience :: Science/Research
|
|
13
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
14
|
+
Classifier: Programming Language :: Python :: 3
|
|
15
|
+
Classifier: Programming Language :: Python :: 3.9
|
|
16
|
+
Classifier: Programming Language :: Python :: 3.10
|
|
17
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
18
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
19
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
20
|
+
Requires-Python: >=3.9
|
|
21
|
+
Description-Content-Type: text/markdown
|
|
22
|
+
License-File: LICENSE
|
|
23
|
+
Requires-Dist: torch>=2.0.0
|
|
24
|
+
Provides-Extra: dev
|
|
25
|
+
Requires-Dist: pytest; extra == "dev"
|
|
26
|
+
Requires-Dist: numpy; extra == "dev"
|
|
27
|
+
Dynamic: license-file
|
|
28
|
+
|
|
29
|
+
# PILOT: Policy-Informed Learned Optimization for Adaptive Deep Network Training
|
|
30
|
+
|
|
31
|
+
> **PILOT** is an online adaptive optimizer that adjusts its update behavior during training using gradient-direction agreement as a signal of local optimization stability.
|
|
32
|
+
|
|
33
|
+
---
|
|
34
|
+
|
|
35
|
+
## Overview
|
|
36
|
+
|
|
37
|
+
Most optimizers use a fixed update structure throughout training — a static balance between momentum, normalization, and sign-based updates that cannot respond to how the loss landscape evolves.
|
|
38
|
+
|
|
39
|
+
**PILOT** introduces a learnable policy that continuously modulates three core update primitives:
|
|
40
|
+
- **Momentum reliance** — how much to rely on accumulated gradient history vs. the current gradient
|
|
41
|
+
- **Variance-normalization strength** — how aggressively to apply adaptive scaling
|
|
42
|
+
- **Sign-based behavior** — how much to compress gradient magnitudes toward ±1
|
|
43
|
+
|
|
44
|
+
The policy is conditioned on a smoothed gradient-direction agreement signal, which serves as a compact online descriptor of local update consistency. It is updated online during training using a one-step meta-gradient estimate — no offline search, no meta-training phase, no second-order estimation.
|
|
45
|
+
|
|
46
|
+

|
|
47
|
+
*PILOT follows a distinct trajectory through the loss surface and converges to a lower-loss region compared to Adam, AdamW, Lion, and Sophia.*
|
|
48
|
+
|
|
49
|
+
---
|
|
50
|
+
|
|
51
|
+
## Key Results
|
|
52
|
+
|
|
53
|
+
### CNN Architecture
|
|
54
|
+
|
|
55
|
+
| Dataset | Optimizer | Accuracy (%) ↑ | Val Loss ↓ | Loss Var. ↓ |
|
|
56
|
+
|---|---|---|---|---|
|
|
57
|
+
| FashionMNIST | Adam | 93.28 | 0.1957 | **0.0033** |
|
|
58
|
+
| FashionMNIST | AdamW | 93.22 | 0.1944 | 0.0034 |
|
|
59
|
+
| FashionMNIST | Lion | 92.91 | 0.2091 | 0.0041 |
|
|
60
|
+
| FashionMNIST | AdaBelief | 93.66 | 0.1822 | 0.0046 |
|
|
61
|
+
| FashionMNIST | **PILOT (Ours)** | **94.13** | **0.1719** | 0.0045 |
|
|
62
|
+
| CIFAR-10 | Adam | 79.91 | 0.5794 | 0.0103 |
|
|
63
|
+
| CIFAR-10 | Lion | 80.87 | 0.5487 | 0.0105 |
|
|
64
|
+
| CIFAR-10 | **PILOT (Ours)** | **81.94** | **0.5302** | **0.0073** |
|
|
65
|
+
|
|
66
|
+
### ResNet-18 Architecture
|
|
67
|
+
|
|
68
|
+
| Dataset | Optimizer | Accuracy (%) ↑ | Val Loss ↓ | Loss Var. ↓ |
|
|
69
|
+
|---|---|---|---|---|
|
|
70
|
+
| FashionMNIST | AdaBelief | 95.33 | 0.1711 | 0.0056 |
|
|
71
|
+
| FashionMNIST | **PILOT (Ours)** | **95.71** | 0.2690 | 0.0030 |
|
|
72
|
+
| CIFAR-10 | Adam | 93.18 | **0.2140** | 0.0073 |
|
|
73
|
+
| CIFAR-10 | AdamW | 92.90 | 0.2514 | 0.0066 |
|
|
74
|
+
| CIFAR-10 | **PILOT (Ours)** | **93.42** | 0.2496 | **0.0001** |
|
|
75
|
+
|
|
76
|
+
---
|
|
77
|
+
|
|
78
|
+
## Method
|
|
79
|
+
|
|
80
|
+
### Gradient-Direction Agreement
|
|
81
|
+
|
|
82
|
+
At each step, PILOT computes the cosine similarity between successive gradients:
|
|
83
|
+
|
|
84
|
+
$$r_t = \frac{g_t^\top g_{t-1}}{\|g_t\|_2 \, \|g_{t-1}\|_2 + \epsilon}$$
|
|
85
|
+
|
|
86
|
+
This is smoothed via an exponential moving average:
|
|
87
|
+
|
|
88
|
+
$$\rho_t = \gamma \rho_{t-1} + (1 - \gamma) r_t$$
|
|
89
|
+
|
|
90
|
+
Positive values indicate stable, aligned gradients. Values near zero indicate noise. Negative values indicate directional disagreement.
|
|
91
|
+
|
|
92
|
+
### Learnable Policy
|
|
93
|
+
|
|
94
|
+
The smoothed signal $\rho_t$ is fed through polynomial functions followed by sigmoid activations to produce three scalar control variables:
|
|
95
|
+
|
|
96
|
+
$$p_{m,t} = \sigma(f(\rho_t; \phi_m)), \quad p_{v,t} = \tfrac{1}{2}\sigma(f(\rho_t; \phi_v)), \quad p_{s,t} = \sigma(f(\rho_t; \phi_s))$$
|
|
97
|
+
|
|
98
|
+
The total number of learnable policy parameters is $3(d+1)$, where $d$ is the polynomial degree.
|
|
99
|
+
|
|
100
|
+
### Update Rule
|
|
101
|
+
|
|
102
|
+
$$\theta_{t+1} = \theta_t - \eta \frac{(|n_t| + \epsilon_n)^{1 - p_{s,t}} \odot \text{sign}(n_t)}{\hat{v}_t^{\,p_{v,t}} + \epsilon}$$
|
|
103
|
+
|
|
104
|
+
where $n_t = p_{m,t} \hat{m}_t + (1 - p_{m,t}) g_t$ is the policy-controlled blend of momentum and current gradient.
|
|
105
|
+
|
|
106
|
+
This formulation recovers Adam ($p_m=1, p_v=0.5, p_s=0$) and sign-based updates ($p_s=1, p_v=0$) as special cases.
|
|
107
|
+
|
|
108
|
+
---
|
|
109
|
+
|
|
110
|
+
## Installation
|
|
111
|
+
|
|
112
|
+
```bash
|
|
113
|
+
pip install pilot-optimizer
|
|
114
|
+
```
|
|
115
|
+
|
|
116
|
+
Or install from source:
|
|
117
|
+
|
|
118
|
+
```bash
|
|
119
|
+
git clone https://github.com/SattamAltwaim/PILOT.git
|
|
120
|
+
cd PILOT
|
|
121
|
+
pip install -e .
|
|
122
|
+
```
|
|
123
|
+
|
|
124
|
+
---
|
|
125
|
+
|
|
126
|
+
## Usage
|
|
127
|
+
|
|
128
|
+
```python
|
|
129
|
+
from pilot import PILOT
|
|
130
|
+
|
|
131
|
+
optimizer = PILOT(
|
|
132
|
+
model.parameters(),
|
|
133
|
+
lr=1e-3,
|
|
134
|
+
betas=(0.9, 0.999),
|
|
135
|
+
weight_decay=1e-4,
|
|
136
|
+
gamma=0.95, # smoothing coefficient for agreement signal
|
|
137
|
+
eta_phi=0.01, # policy (meta) learning rate
|
|
138
|
+
degree=2 # polynomial degree
|
|
139
|
+
)
|
|
140
|
+
|
|
141
|
+
for x, y in dataloader:
|
|
142
|
+
loss = criterion(model(x), y)
|
|
143
|
+
optimizer.zero_grad()
|
|
144
|
+
loss.backward()
|
|
145
|
+
optimizer.step()
|
|
146
|
+
```
|
|
147
|
+
|
|
148
|
+
---
|
|
149
|
+
|
|
150
|
+
## Hyperparameters
|
|
151
|
+
|
|
152
|
+
| Parameter | Description | Typical Range |
|
|
153
|
+
|---|---|---|
|
|
154
|
+
| `lr` | Model learning rate | `1e-4` – `1e-3` |
|
|
155
|
+
| `betas` | Moment coefficients | `(0.9, 0.999)` |
|
|
156
|
+
| `gamma` | Agreement signal smoothing | `0.85` – `0.99` |
|
|
157
|
+
| `eta_phi` | Policy learning rate | `5e-4` – `5e-2` |
|
|
158
|
+
| `degree` | Polynomial degree | `1` – `4` |
|
|
159
|
+
|
|
160
|
+
### Configuration-Specific Selections
|
|
161
|
+
|
|
162
|
+
| Dataset | Architecture | γ | η_φ | Degree |
|
|
163
|
+
|---|---|---|---|---|
|
|
164
|
+
| CIFAR-10 | CNN | 0.882 | 0.00312 | 1 |
|
|
165
|
+
| CIFAR-10 | ResNet-18 | 0.950 | 0.00500 | 2 |
|
|
166
|
+
| FashionMNIST | CNN | 0.950 | 0.01000 | 2 |
|
|
167
|
+
| FashionMNIST | ResNet-18 | 0.957 | 0.00273 | 3 |
|
|
168
|
+
|
|
169
|
+
---
|
|
170
|
+
|
|
171
|
+
## Experiments
|
|
172
|
+
|
|
173
|
+
Experiments use 30 epochs, cross-entropy loss, cosine annealing LR schedule, batch size 128, and AMP. ResNet-18 configurations include a 3-epoch linear warmup.
|
|
174
|
+
|
|
175
|
+
```bash
|
|
176
|
+
# CNN on CIFAR-10
|
|
177
|
+
python train.py --dataset cifar10 --arch cnn --optimizer pilot
|
|
178
|
+
|
|
179
|
+
# ResNet-18 on FashionMNIST
|
|
180
|
+
python train.py --dataset fashionmnist --arch resnet18 --optimizer pilot
|
|
181
|
+
```
|
|
182
|
+
|
|
@@ -0,0 +1,9 @@
|
|
|
1
|
+
pilot/__init__.py,sha256=lkgpmUggaQ0t25jh5OWADfnhl5KpbX8CWXL-QWdUW_E,72
|
|
2
|
+
pilot/diagnostics.py,sha256=A1_gRprfSpXEPGOkeVDHxsMLxl2qtHyun0HO2tgaDk8,1834
|
|
3
|
+
pilot/meta_grads.py,sha256=lhIZjaggoDkeyukRNrhIT6WsY031M1zfGXplVry6kHk,3479
|
|
4
|
+
pilot/optimizer.py,sha256=hcLeXtM3HRVQK3B1gMXJdp-MudSefTDICF_E72d3-b4,10820
|
|
5
|
+
pilot_optimizer-0.1.0.dist-info/licenses/LICENSE,sha256=d50RzYxuMWIMIhnSEGoYkVNnbUCeae2WrpPoDQgFIYE,1071
|
|
6
|
+
pilot_optimizer-0.1.0.dist-info/METADATA,sha256=J8XOcu6WF1_BeUs9PjyRekbnk8ixyKkmaNiRYJdJBAo,6535
|
|
7
|
+
pilot_optimizer-0.1.0.dist-info/WHEEL,sha256=aeYiig01lYGDzBgS8HxWXOg3uV61G9ijOsup-k9o1sk,91
|
|
8
|
+
pilot_optimizer-0.1.0.dist-info/top_level.txt,sha256=BijnVJdXnIPxxx3s60M848seL4Z12gNUPod6KPJxK9c,6
|
|
9
|
+
pilot_optimizer-0.1.0.dist-info/RECORD,,
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2026 Sattam Altwaim
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
pilot
|