pilot-optimizer 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
pilot/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ from .optimizer import PILOT
2
+
3
+ __all__ = ["PILOT"]
4
+ __version__ = "0.1.0"
pilot/diagnostics.py ADDED
@@ -0,0 +1,54 @@
1
+ """
2
+ Diagnostics tracker for the PILOT optimizer.
3
+
4
+ Only records data when explicitly enabled. Zero overhead when disabled.
5
+ """
6
+
7
+
8
class DiagnosticsTracker:
    """Stores per-step optimizer internals for analysis.

    Each call to :meth:`record` appends exactly one value to every series, so
    all history lists always have the same length.

    Args:
        degree: polynomial degree of the response policy. The tracker keeps
            ``3 * (degree + 1)`` series named ``phi_0``, ``phi_1``, ...
    """

    def __init__(self, *, degree: int = 2):
        self._degree = degree
        n_phi = 3 * (degree + 1)
        self._phi_keys = [f"phi_{i}" for i in range(n_phi)]
        self._history = {
            "step": [],
            "r": [],
            "rho": [],
            "p_m": [],
            "p_v": [],
            "p_s": [],
        }
        for key in self._phi_keys:
            self._history[key] = []

    def record(self, step, r, rho, pm, pv, ps, phi):
        """Record diagnostics for one step.

        Args:
            step: step counter, stored as given.
            r, rho, pm, pv, ps: scalars, each converted with ``float()``.
            phi: object exposing ``detach().cpu().tolist()`` (the optimizer
                passes its phi tensor) with at least ``3 * (degree + 1)``
                entries.

        Raises:
            IndexError: if ``phi`` has fewer entries than tracked phi series.
        """
        # Convert and validate everything before touching the history so a
        # failure cannot leave the series with unequal lengths.
        phi_vals = phi.detach().cpu().tolist()
        if len(phi_vals) < len(self._phi_keys):
            raise IndexError(
                f"phi has {len(phi_vals)} entries, expected at least "
                f"{len(self._phi_keys)}"
            )
        scalars = {
            "r": float(r),
            "rho": float(rho),
            "p_m": float(pm),
            "p_v": float(pv),
            "p_s": float(ps),
        }

        self._history["step"].append(step)
        for key, value in scalars.items():
            self._history[key].append(value)
        for i, key in enumerate(self._phi_keys):
            self._history[key].append(phi_vals[i])

    def get_history(self):
        """Return the full history as a dict of lists.

        The lists are copies, so callers may modify the result without
        corrupting the tracker's own records.
        """
        return {key: list(values) for key, values in self._history.items()}

    def summary(self, last_n=5):
        """Return a text summary of the last ``last_n`` recorded steps.

        When fewer than ``last_n`` steps exist, all of them are listed. Returns
        ``"No data recorded."`` when the history is empty.
        """
        h = self._history
        total = len(h["step"])
        if not total:
            return "No data recorded."
        # Clamp to the available range: a raw negative offset would index
        # before the start of the lists.
        shown = max(0, min(last_n, total))
        lines = [f"Last {shown} steps:"]
        for idx in range(total - shown, total):
            lines.append(
                f" step={h['step'][idx]} r={h['r'][idx]:.4f} rho={h['rho'][idx]:.4f} "
                f"pm={h['p_m'][idx]:.4f} pv={h['p_v'][idx]:.4f} ps={h['p_s'][idx]:.4f}"
            )
        return "\n".join(lines)
pilot/meta_grads.py ADDED
@@ -0,0 +1,104 @@
1
+ """
2
+ Analytic meta-gradient computation for the PILOT optimizer.
3
+
4
+ Pure function that computes the 3*(degree+1) meta-gradients (dL/dphi_i)
5
+ from stored detached intermediates. No autograd graph, no state.
6
+ """
7
+
8
+ import torch
9
+
10
+
11
+ def _safe_pow(base, exp_val):
12
+ """|base|^exp_val with clamping for numerical stability."""
13
+ return torch.clamp(base.abs(), min=1e-12).pow(exp_val)
14
+
15
+
16
+ def _safe_log(x):
17
+ """log(|x|) with clamping for numerical stability."""
18
+ return torch.log(torch.clamp(x.abs(), min=1e-12))
19
+
20
+
21
+ def _horner(coeffs: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
22
+ """Evaluate a polynomial via Horner's method (highest power first)."""
23
+ out = coeffs[0]
24
+ for k in range(1, coeffs.shape[0]):
25
+ out = out * x + coeffs[k]
26
+ return out
27
+
28
+
29
def compute_meta_grads(g_next, intermediates, g_t, phi, eta, *, degree=2, eps=1e-8):
    """Compute the 3*(degree+1) meta-gradients dL/dphi_i for one param tensor.

    The chain: phi -> (pm, pv, ps) -> d_t -> theta_{t+1} -> g_{t+1} -> L

        dL/dphi_i = g_{t+1}^T * dd_t/dphi_i   (summed over all elements)

    with the update d_t = -eta * |n|^(1-ps) * sign(n) / (|v_hat|^pv + eps) and
    n = pm * m_hat + (1 - pm) * g_t.

    Args:
        g_next: gradient at step t+1, shape matches param.
        intermediates: dict with keys n_t, m_hat, v_hat, p_m, p_v, p_s, rho.
        g_t: gradient at step t (stored separately as state["g_prev"]).
        phi: current group-level phi vector, length 3*(degree+1), laid out as
            three highest-power-first coefficient blocks (pm, pv, ps).
        eta: main optimizer learning rate.
        degree: polynomial degree of the response policy.
        eps: stabilizer added to the denominator. Defaults to 1e-8, the value
            this function previously hard-coded.

    Returns:
        grads: tensor of shape (3*(degree+1),) on ``phi``'s device and dtype.
    """
    n_t = intermediates["n_t"]
    m_hat = intermediates["m_hat"]
    v_hat = intermediates["v_hat"]
    pm = intermediates["p_m"]
    pv = intermediates["p_v"]
    ps = intermediates["p_s"]

    v_pow = _safe_pow(v_hat, pv)  # reused by denom and by dd/dp_v
    denom = v_pow + eps
    n_abs_neg_ps = _safe_pow(n_t, -ps)
    signed_n_1ps = _safe_pow(n_t, 1 - ps) * torch.sign(n_t)

    # dd_t / dp_m = -eta * (1 - ps) * |n|^(-ps) * (m_hat - g_t) / denom
    dd_dpm = -eta * (1 - ps) * n_abs_neg_ps * (m_hat - g_t) / denom

    # dd_t / dp_v = eta * |n|^(1-ps) * sign(n) * v^pv * log|v| / denom^2
    dd_dpv = eta * signed_n_1ps * v_pow * _safe_log(v_hat) / (denom * denom)

    # dd_t / dp_s = eta * |n|^(1-ps) * sign(n) * log|n| / denom
    dd_dps = eta * signed_n_1ps * _safe_log(n_t) / denom

    # Scalar dL/dp for each policy variable. From here on everything lives on
    # phi's device and dtype, so a parameter stored on another device cannot
    # cause a cross-device operation in the policy chain below.
    dL_dp = torch.stack(
        [
            torch.sum(g_next * dd_dpm),
            torch.sum(g_next * dd_dpv),
            torch.sum(g_next * dd_dps),
        ]
    ).to(device=phi.device, dtype=phi.dtype)

    # --- Chain rule through sigmoid and polynomial ---
    n_phi = degree + 1
    rho = torch.as_tensor(intermediates["rho"], device=phi.device, dtype=phi.dtype)

    s_m = torch.sigmoid(_horner(phi[:n_phi], rho))
    s_v = torch.sigmoid(_horner(phi[n_phi : 2 * n_phi], rho))
    s_s = torch.sigmoid(_horner(phi[2 * n_phi :], rho))

    ds_m = s_m * (1 - s_m)
    ds_v = s_v * (1 - s_v)
    ds_s = s_s * (1 - s_s)

    # Build rho powers: [rho^d, rho^{d-1}, ..., rho^1, rho^0]
    ascending = [torch.ones((), device=phi.device, dtype=phi.dtype)]
    for _ in range(degree):
        ascending.append(ascending[-1] * rho)
    rho_powers = torch.stack(ascending[::-1])

    # dL/d(coeffs_m[k]) = dL/dpm * ds_m * rho^(d-k)
    grads_m = dL_dp[0] * ds_m * rho_powers
    # pv = 0.5 * sigmoid(z_v), so dp_v/dz_v = 0.5 * sigmoid'(z_v)
    grads_v = dL_dp[1] * 0.5 * ds_v * rho_powers
    grads_s = dL_dp[2] * ds_s * rho_powers

    return torch.cat([grads_m, grads_v, grads_s])
pilot/optimizer.py ADDED
@@ -0,0 +1,287 @@
1
+ """
2
+ PILOT -- Policy-Informed Learned Optimization for Training.
3
+
4
+ Adam with a learnable brain that reads the loss landscape and reshapes
5
+ the update rule every step. The brain is trained by gradient descent using
6
+ information the optimizer already computes.
7
+ """
8
+
9
+ import torch
10
+ from torch.optim import Optimizer
11
+
12
+ from .diagnostics import DiagnosticsTracker
13
+ from .meta_grads import compute_meta_grads
14
+
15
+
16
+ # Degree-0 (constant-term) biases for the three policy sigmoids; every higher-order coefficient starts at 0, so the policy is flat in rho at init.
17
+ _PHI_BIAS = (1.4, 3.0, -2.0)
18
+
19
+ META_GRAD_CLIP = 1.0 # safety cap on ||meta_grad_acc|| before phi update
20
+
21
+
22
def _build_phi_init(degree: int) -> torch.Tensor:
    """Return the starting phi vector for a response policy of ``degree``.

    phi is the concatenation of three equally sized coefficient blocks, one
    per policy variable in the order (pm, pv, ps). Each block holds
    ``degree + 1`` coefficients, highest power first, so the constant term is
    the block's last entry.

    Only the constant terms are non-zero; they take the values in
    ``_PHI_BIAS``. With every other coefficient at zero, the initial policy
    output does not depend on ``degree``.
    """
    width = degree + 1
    phi = torch.zeros(3 * width)
    # View phi as one row per policy variable and fill the last column, which
    # holds the constant terms.
    phi.view(3, width)[:, -1] = torch.tensor(_PHI_BIAS, dtype=phi.dtype)
    return phi
38
+
39
+
40
+ def _horner(coeffs: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
41
+ """Evaluate a polynomial using Horner's method.
42
+
43
+ coeffs: 1-D tensor [c_d, c_{d-1}, ..., c_1, c_0] (highest power first).
44
+ x: scalar tensor.
45
+ Returns: scalar tensor = c_d*x^d + ... + c_1*x + c_0.
46
+ """
47
+ out = coeffs[0]
48
+ for k in range(1, coeffs.shape[0]):
49
+ out = out * x + coeffs[k]
50
+ return out
51
+
52
+
53
class PILOT(Optimizer):
    r"""PILOT optimizer.

    Each parameter group carries its own policy state: ``phi`` (polynomial
    coefficients), ``rho`` (smoothed cosine similarity between successive
    gradients) and ``step``. Every call to :meth:`step` evaluates the policy
    at ``rho`` to obtain three scalars (``pm``, ``pv``, ``ps``) that shape an
    Adam-style update, and then moves ``phi`` along the negative one-step
    meta-gradient.

    Args:
        params: iterable of parameters to optimize.
        lr: learning rate (default: 1e-3).
        betas: coefficients for running averages of gradient and its square (default: (0.9, 0.999)).
        eps: term for numerical stability (default: 1e-8).
        weight_decay: weight decay coefficient (default: 0.01).
        gamma: smoothing factor for the landscape signal (default: 0.95).
        eta_phi: meta learning rate for the response policy (default: 0.01).
        degree: polynomial degree for the response policy (default: 2). May be
            overridden per parameter group.
        diagnostics: if True, record internal state every step (default: False).
            The tracker is sized from this constructor's ``degree``.
        policy_overrides: optional dict to fix policy variables, e.g. {'pm': 1.0}.
            Valid keys are 'pm', 'pv' and 'ps'. Overridden variables bypass phi
            and their meta-gradient blocks are zeroed.

    Raises:
        ValueError: if a hyperparameter is out of range, ``degree`` is not an
            int >= 1, or ``policy_overrides`` contains an unknown key.
    """

    _POLICY_NAMES = ("pm", "pv", "ps")

    def __init__(
        self,
        params,
        lr=1e-3,
        betas=(0.9, 0.999),
        eps=1e-8,
        weight_decay=0.01,
        gamma=0.95,
        eta_phi=0.01,
        degree=2,
        diagnostics=False,
        policy_overrides=None,
    ):
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(f"Invalid beta1: {betas[0]}")
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(f"Invalid beta2: {betas[1]}")
        if not 0.0 <= eps:
            raise ValueError(f"Invalid eps: {eps}")
        if not 0.0 <= weight_decay:
            raise ValueError(f"Invalid weight_decay: {weight_decay}")
        if not 0.0 <= gamma < 1.0:
            raise ValueError(f"Invalid gamma: {gamma}")
        if not 0.0 <= eta_phi:
            raise ValueError(f"Invalid eta_phi: {eta_phi}")
        if not (isinstance(degree, int) and degree >= 1):
            raise ValueError(f"Invalid degree: {degree} (must be int >= 1)")

        # A misspelled key (e.g. 'p_m') would otherwise be ignored silently and
        # the variable would keep being learned.
        unknown = sorted(set(policy_overrides or {}) - set(self._POLICY_NAMES))
        if unknown:
            raise ValueError(
                f"Invalid policy_overrides keys: {unknown} "
                f"(allowed: {list(self._POLICY_NAMES)})"
            )

        defaults = dict(
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            gamma=gamma,
            eta_phi=eta_phi,
            degree=degree,
        )
        super().__init__(params, defaults)

        self.diagnostics = (
            DiagnosticsTracker(degree=degree) if diagnostics else None
        )
        self.policy_overrides = policy_overrides or {}

        for group in self.param_groups:
            self._init_group_policy(group)

    @staticmethod
    def _init_group_policy(group):
        """Attach fresh ``rho``, ``phi`` and ``step`` entries to ``group``.

        phi is sized from the group's own ``degree`` so that it matches the
        slicing done in :meth:`step`. Both tensors live on the device of the
        group's first parameter (CPU when the group has no parameters).
        """
        degree = group["degree"]
        if not (isinstance(degree, int) and degree >= 1):
            raise ValueError(f"Invalid degree: {degree} (must be int >= 1)")
        params = group["params"]
        device = params[0].device if params else torch.device("cpu")
        group["rho"] = torch.zeros((), device=device)
        group["phi"] = _build_phi_init(degree).to(device=device)
        group["step"] = 0

    @staticmethod
    def _override_tensor(value, like):
        """Return ``value`` as a 0-dim tensor with ``like``'s device and dtype."""
        return torch.tensor(float(value), device=like.device, dtype=like.dtype)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Args:
            closure: optional callable that re-evaluates the model and returns
                the loss; it is run with gradients enabled.

        Returns:
            The closure's return value, or None when no closure is given.

        Raises:
            RuntimeError: if a parameter has a sparse gradient.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        overrides = self.policy_overrides

        for group in self.param_groups:
            # Groups added through add_param_group() after construction have no
            # policy state yet.
            if "phi" not in group:
                self._init_group_policy(group)

            lr = group["lr"]
            beta1, beta2 = group["betas"]
            eps = group["eps"]
            weight_decay = group["weight_decay"]
            gamma = group["gamma"]
            eta_phi = group["eta_phi"]
            degree = group["degree"]
            n_phi = degree + 1

            phi = group["phi"]
            rho = group["rho"]
            group["step"] += 1
            step = group["step"]

            bias_corr1 = 1 - beta1 ** step
            bias_corr2 = 1 - beta2 ** step

            # ---- Pass 1: single landscape signal from the full gradient ----
            # Accumulate dot(g, g_prev), ||g||^2, ||g_prev||^2 across all params.
            # Also handles lazy state init so pass 2 can rely on it.
            dot = torch.zeros((), device=phi.device)
            sq_g = torch.zeros((), device=phi.device)
            sq_prev = torch.zeros((), device=phi.device)

            for p in group["params"]:
                if p.grad is None:
                    continue
                g = p.grad
                if g.is_sparse:
                    raise RuntimeError("PILOT does not support sparse gradients")
                state = self.state[p]
                if len(state) == 0:
                    state["m"] = torch.zeros_like(p)
                    state["v"] = torch.zeros_like(p)
                    state["g_prev"] = torch.zeros_like(p)
                    state["intermediates"] = None
                g_prev = state["g_prev"]
                dot = dot + (g * g_prev).sum().to(phi.device)
                sq_g = sq_g + g.pow(2).sum().to(phi.device)
                sq_prev = sq_prev + g_prev.pow(2).sum().to(phi.device)

            # Cosine similarity between g_t and g_{t-1}; defined as 0 when
            # either norm vanishes (e.g. on the very first step).
            denom_cos = (sq_g * sq_prev).sqrt()
            r = torch.where(
                denom_cos > 1e-12,
                dot / denom_cos.clamp(min=1e-12),
                torch.zeros((), device=phi.device),
            )
            rho = gamma * rho + (1 - gamma) * r

            # ---- Single response policy for the whole group ----
            pm = torch.sigmoid(_horner(phi[:n_phi], rho))
            pv = 0.5 * torch.sigmoid(_horner(phi[n_phi : 2 * n_phi], rho))
            ps = torch.sigmoid(_horner(phi[2 * n_phi :], rho))

            if "pm" in overrides:
                pm = self._override_tensor(overrides["pm"], phi)
            if "pv" in overrides:
                pv = self._override_tensor(overrides["pv"], phi)
            if "ps" in overrides:
                ps = self._override_tensor(overrides["ps"], phi)

            # ---- Pass 2: meta-gradient + Adam update for each param ----
            meta_grad_acc = torch.zeros(3 * n_phi, device=phi.device)

            for p in group["params"]:
                if p.grad is None:
                    continue

                g = p.grad
                state = self.state[p]
                m = state["m"]
                v = state["v"]
                g_prev = state["g_prev"]  # still holds g_{t-1} until we copy below

                # Meta-gradient from step t (intermediates) against g_{t+1} (current g).
                intermediates = state.get("intermediates")
                if intermediates is not None:
                    mg = compute_meta_grads(
                        g, intermediates, g_prev, phi, lr, degree=degree
                    )
                    meta_grad_acc += mg.to(phi.device)

                # Adam bookkeeping
                m.mul_(beta1).add_(g, alpha=1 - beta1)
                v.mul_(beta2).addcmul_(g, g, value=1 - beta2)
                m_hat = m / bias_corr1
                v_hat = v / bias_corr2

                # Move policy scalars to param device if needed (no-op on single device).
                pm_p = pm.to(g.device)
                pv_p = pv.to(g.device)
                ps_p = ps.to(g.device)

                # Adaptive step
                n = pm_p * m_hat + (1 - pm_p) * g
                n_abs = torch.clamp(n.abs(), min=1e-12)
                denom = v_hat.pow(pv_p) + eps
                d = -lr * n_abs.pow(1 - ps_p) * torch.sign(n) / denom

                # Decoupled weight decay (AdamW-style)
                if weight_decay != 0:
                    d = d - lr * weight_decay * p

                p.add_(d)

                # Store intermediates for next step's meta-gradient.
                state["intermediates"] = {
                    "n_t": n.detach(),
                    "m_hat": m_hat.detach(),
                    "v_hat": v_hat.detach(),
                    "p_m": pm_p.detach(),
                    "p_v": pv_p.detach(),
                    "p_s": ps_p.detach(),
                    "rho": rho.to(g.device).detach().clone(),
                }

                # Update g_prev to g_t for the next step.
                g_prev.copy_(g)

            # ---- Update phi with accumulated meta-gradients ----
            # Overridden variables do not depend on phi, so their blocks must
            # not move.
            if "pm" in overrides:
                meta_grad_acc[:n_phi] = 0
            if "pv" in overrides:
                meta_grad_acc[n_phi : 2 * n_phi] = 0
            if "ps" in overrides:
                meta_grad_acc[2 * n_phi :] = 0

            # Clip by L2 norm to guard against log|n| blow-up in meta-gradients.
            mg_norm = meta_grad_acc.norm()
            scale = torch.clamp(META_GRAD_CLIP / (mg_norm + 1e-12), max=1.0)
            phi.add_(meta_grad_acc * scale, alpha=-eta_phi)

            # Persist updated rho (still a tensor, no sync).
            group["rho"] = rho

            # ---- Diagnostics ----
            if self.diagnostics is not None:
                self.diagnostics.record(
                    step,
                    float(r),
                    float(rho),
                    float(pm),
                    float(pv),
                    float(ps),
                    phi,
                )

        return loss
@@ -0,0 +1,182 @@
1
+ Metadata-Version: 2.4
2
+ Name: pilot-optimizer
3
+ Version: 0.1.0
4
+ Summary: PILOT: Policy-Informed Learned Optimization for Adaptive Deep Network Training
5
+ Author-email: Sattam Altuuaim <sattam.tuuaim@kaust.edu.sa>, Lama Ayash <lama.ayash@kaust.edu.sa>, Muhammad Mubashar <muhammad.mubashar@strath.ac.uk>, Naeemullah Khan <naeemullah.khan@kaust.edu.sa>
6
+ License: MIT
7
+ Project-URL: Homepage, https://sattamaltwaim.github.io/PILOT/
8
+ Project-URL: Repository, https://github.com/SattamAltwaim/PILOT
9
+ Project-URL: Paper, https://arxiv.org/abs/submit/7629402
10
+ Keywords: optimizer,deep-learning,pytorch,meta-learning,adaptive-optimization
11
+ Classifier: Development Status :: 4 - Beta
12
+ Classifier: Intended Audience :: Science/Research
13
+ Classifier: License :: OSI Approved :: MIT License
14
+ Classifier: Programming Language :: Python :: 3
15
+ Classifier: Programming Language :: Python :: 3.9
16
+ Classifier: Programming Language :: Python :: 3.10
17
+ Classifier: Programming Language :: Python :: 3.11
18
+ Classifier: Programming Language :: Python :: 3.12
19
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
20
+ Requires-Python: >=3.9
21
+ Description-Content-Type: text/markdown
22
+ License-File: LICENSE
23
+ Requires-Dist: torch>=2.0.0
24
+ Provides-Extra: dev
25
+ Requires-Dist: pytest; extra == "dev"
26
+ Requires-Dist: numpy; extra == "dev"
27
+ Dynamic: license-file
28
+
29
+ # PILOT: Policy-Informed Learned Optimization for Adaptive Deep Network Training
30
+
31
+ > **PILOT** is an online adaptive optimizer that adjusts its update behavior during training using gradient-direction agreement as a signal of local optimization stability.
32
+
33
+ ---
34
+
35
+ ## Overview
36
+
37
+ Most optimizers use a fixed update structure throughout training — a static balance between momentum, normalization, and sign-based updates that cannot respond to how the loss landscape evolves.
38
+
39
+ **PILOT** introduces a learnable policy that continuously modulates three core update primitives:
40
+ - **Momentum reliance** — how much to rely on accumulated gradient history vs. the current gradient
41
+ - **Variance-normalization strength** — how aggressively to apply adaptive scaling
42
+ - **Sign-based behavior** — how much to compress gradient magnitudes toward ±1
43
+
44
+ The policy is conditioned on a smoothed gradient-direction agreement signal, which serves as a compact online descriptor of local update consistency. It is updated online during training using a one-step meta-gradient estimate — no offline search, no meta-training phase, no second-order estimation.
45
+
46
+ ![Loss Landscape — CIFAR-10 / SmallCNN](fig2_landscape.png)
47
+ *PILOT follows a distinct trajectory through the loss surface and converges to a lower-loss region compared to Adam, AdamW, Lion, and Sophia.*
48
+
49
+ ---
50
+
51
+ ## Key Results
52
+
53
+ ### CNN Architecture
54
+
55
+ | Dataset | Optimizer | Accuracy (%) ↑ | Val Loss ↓ | Loss Var. ↓ |
56
+ |---|---|---|---|---|
57
+ | FashionMNIST | Adam | 93.28 | 0.1957 | **0.0033** |
58
+ | FashionMNIST | AdamW | 93.22 | 0.1944 | 0.0034 |
59
+ | FashionMNIST | Lion | 92.91 | 0.2091 | 0.0041 |
60
+ | FashionMNIST | AdaBelief | 93.66 | 0.1822 | 0.0046 |
61
+ | FashionMNIST | **PILOT (Ours)** | **94.13** | **0.1719** | 0.0045 |
62
+ | CIFAR-10 | Adam | 79.91 | 0.5794 | 0.0103 |
63
+ | CIFAR-10 | Lion | 80.87 | 0.5487 | 0.0105 |
64
+ | CIFAR-10 | **PILOT (Ours)** | **81.94** | **0.5302** | **0.0073** |
65
+
66
+ ### ResNet-18 Architecture
67
+
68
+ | Dataset | Optimizer | Accuracy (%) ↑ | Val Loss ↓ | Loss Var. ↓ |
69
+ |---|---|---|---|---|
70
+ | FashionMNIST | AdaBelief | 95.33 | 0.1711 | 0.0056 |
71
+ | FashionMNIST | **PILOT (Ours)** | **95.71** | 0.2690 | 0.0030 |
72
+ | CIFAR-10 | Adam | 93.18 | **0.2140** | 0.0073 |
73
+ | CIFAR-10 | AdamW | 92.90 | 0.2514 | 0.0066 |
74
+ | CIFAR-10 | **PILOT (Ours)** | **93.42** | 0.2496 | **0.0001** |
75
+
76
+ ---
77
+
78
+ ## Method
79
+
80
+ ### Gradient-Direction Agreement
81
+
82
+ At each step, PILOT computes the cosine similarity between successive gradients:
83
+
84
+ $$r_t = \frac{g_t^\top g_{t-1}}{\|g_t\|_2 \, \|g_{t-1}\|_2 + \epsilon}$$
85
+
86
+ This is smoothed via an exponential moving average:
87
+
88
+ $$\rho_t = \gamma \rho_{t-1} + (1 - \gamma) r_t$$
89
+
90
+ Positive values indicate stable, aligned gradients. Values near zero indicate noise. Negative values indicate directional disagreement.
91
+
92
+ ### Learnable Policy
93
+
94
+ The smoothed signal $\rho_t$ is fed through polynomial functions followed by sigmoid activations to produce three scalar control variables:
95
+
96
+ $$p_{m,t} = \sigma(f(\rho_t; \phi_m)), \quad p_{v,t} = \tfrac{1}{2}\sigma(f(\rho_t; \phi_v)), \quad p_{s,t} = \sigma(f(\rho_t; \phi_s))$$
97
+
98
+ The total number of learnable policy parameters is $3(d+1)$, where $d$ is the polynomial degree.
99
+
100
+ ### Update Rule
101
+
102
+ $$\theta_{t+1} = \theta_t - \eta \frac{(|n_t| + \epsilon_n)^{1 - p_{s,t}} \odot \text{sign}(n_t)}{\hat{v}_t^{\,p_{v,t}} + \epsilon}$$
103
+
104
+ where $n_t = p_{m,t} \hat{m}_t + (1 - p_{m,t}) g_t$ is the policy-controlled blend of momentum and current gradient.
105
+
106
+ This formulation recovers Adam ($p_m=1, p_v=0.5, p_s=0$) and sign-based updates ($p_s=1, p_v=0$) as special cases.
107
+
108
+ ---
109
+
110
+ ## Installation
111
+
112
+ ```bash
113
+ pip install pilot-optimizer
114
+ ```
115
+
116
+ Or install from source:
117
+
118
+ ```bash
119
+ git clone https://github.com/SattamAltwaim/PILOT.git
120
+ cd PILOT
121
+ pip install -e .
122
+ ```
123
+
124
+ ---
125
+
126
+ ## Usage
127
+
128
+ ```python
129
+ from pilot import PILOT
130
+
131
+ optimizer = PILOT(
132
+ model.parameters(),
133
+ lr=1e-3,
134
+ betas=(0.9, 0.999),
135
+ weight_decay=1e-4,
136
+ gamma=0.95, # smoothing coefficient for agreement signal
137
+ eta_phi=0.01, # policy (meta) learning rate
138
+ degree=2 # polynomial degree
139
+ )
140
+
141
+ for x, y in dataloader:
142
+ loss = criterion(model(x), y)
143
+ optimizer.zero_grad()
144
+ loss.backward()
145
+ optimizer.step()
146
+ ```
147
+
148
+ ---
149
+
150
+ ## Hyperparameters
151
+
152
+ | Parameter | Description | Typical Range |
153
+ |---|---|---|
154
+ | `lr` | Model learning rate | `1e-4` – `1e-3` |
155
+ | `betas` | Moment coefficients | `(0.9, 0.999)` |
156
+ | `gamma` | Agreement signal smoothing | `0.85` – `0.99` |
157
+ | `eta_phi` | Policy (meta) learning rate | `5e-4` – `5e-2` |
158
+ | `degree` | Polynomial degree | `1` – `4` |
159
+
160
+ ### Configuration-Specific Selections
161
+
162
+ | Dataset | Architecture | γ | η_φ | Degree |
163
+ |---|---|---|---|---|
164
+ | CIFAR-10 | CNN | 0.882 | 0.00312 | 1 |
165
+ | CIFAR-10 | ResNet-18 | 0.950 | 0.00500 | 2 |
166
+ | FashionMNIST | CNN | 0.950 | 0.01000 | 2 |
167
+ | FashionMNIST | ResNet-18 | 0.957 | 0.00273 | 3 |
168
+
169
+ ---
170
+
171
+ ## Experiments
172
+
173
+ Experiments use 30 epochs, cross-entropy loss, cosine annealing LR schedule, batch size 128, and AMP. ResNet-18 configurations include a 3-epoch linear warmup.
174
+
175
+ ```bash
176
+ # CNN on CIFAR-10
177
+ python train.py --dataset cifar10 --arch cnn --optimizer pilot
178
+
179
+ # ResNet-18 on FashionMNIST
180
+ python train.py --dataset fashionmnist --arch resnet18 --optimizer pilot
181
+ ```
182
+
@@ -0,0 +1,9 @@
1
+ pilot/__init__.py,sha256=lkgpmUggaQ0t25jh5OWADfnhl5KpbX8CWXL-QWdUW_E,72
2
+ pilot/diagnostics.py,sha256=A1_gRprfSpXEPGOkeVDHxsMLxl2qtHyun0HO2tgaDk8,1834
3
+ pilot/meta_grads.py,sha256=lhIZjaggoDkeyukRNrhIT6WsY031M1zfGXplVry6kHk,3479
4
+ pilot/optimizer.py,sha256=hcLeXtM3HRVQK3B1gMXJdp-MudSefTDICF_E72d3-b4,10820
5
+ pilot_optimizer-0.1.0.dist-info/licenses/LICENSE,sha256=d50RzYxuMWIMIhnSEGoYkVNnbUCeae2WrpPoDQgFIYE,1071
6
+ pilot_optimizer-0.1.0.dist-info/METADATA,sha256=J8XOcu6WF1_BeUs9PjyRekbnk8ixyKkmaNiRYJdJBAo,6535
7
+ pilot_optimizer-0.1.0.dist-info/WHEEL,sha256=aeYiig01lYGDzBgS8HxWXOg3uV61G9ijOsup-k9o1sk,91
8
+ pilot_optimizer-0.1.0.dist-info/top_level.txt,sha256=BijnVJdXnIPxxx3s60M848seL4Z12gNUPod6KPJxK9c,6
9
+ pilot_optimizer-0.1.0.dist-info/RECORD,,
@@ -0,0 +1,5 @@
1
+ Wheel-Version: 1.0
2
+ Generator: setuptools (82.0.1)
3
+ Root-Is-Purelib: true
4
+ Tag: py3-none-any
5
+
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2026 Sattam Altwaim
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
@@ -0,0 +1 @@
1
+ pilot