pilot-optimizer 0.1.0__tar.gz

@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2026 Sattam Altwaim
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
@@ -0,0 +1,182 @@
1
+ Metadata-Version: 2.4
2
+ Name: pilot-optimizer
3
+ Version: 0.1.0
4
+ Summary: PILOT: Policy-Informed Learned Optimization for Adaptive Deep Network Training
5
+ Author-email: Sattam Altuuaim <sattam.tuuaim@kaust.edu.sa>, Lama Ayash <lama.ayash@kaust.edu.sa>, Muhammad Mubashar <muhammad.mubashar@strath.ac.uk>, Naeemullah Khan <naeemullah.khan@kaust.edu.sa>
6
+ License: MIT
7
+ Project-URL: Homepage, https://sattamaltwaim.github.io/PILOT/
8
+ Project-URL: Repository, https://github.com/SattamAltwaim/PILOT
9
+ Project-URL: Paper, https://arxiv.org/abs/submit/7629402
10
+ Keywords: optimizer,deep-learning,pytorch,meta-learning,adaptive-optimization
11
+ Classifier: Development Status :: 4 - Beta
12
+ Classifier: Intended Audience :: Science/Research
13
+ Classifier: License :: OSI Approved :: MIT License
14
+ Classifier: Programming Language :: Python :: 3
15
+ Classifier: Programming Language :: Python :: 3.9
16
+ Classifier: Programming Language :: Python :: 3.10
17
+ Classifier: Programming Language :: Python :: 3.11
18
+ Classifier: Programming Language :: Python :: 3.12
19
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
20
+ Requires-Python: >=3.9
21
+ Description-Content-Type: text/markdown
22
+ License-File: LICENSE
23
+ Requires-Dist: torch>=2.0.0
24
+ Provides-Extra: dev
25
+ Requires-Dist: pytest; extra == "dev"
26
+ Requires-Dist: numpy; extra == "dev"
27
+ Dynamic: license-file
28
+
29
+ # PILOT: Policy-Informed Learned Optimization for Adaptive Deep Network Training
30
+
31
+ > **PILOT** is an online adaptive optimizer that adjusts its update behavior during training using gradient-direction agreement as a signal of local optimization stability.
32
+
33
+ ---
34
+
35
+ ## Overview
36
+
37
+ Most optimizers use a fixed update structure throughout training — a static balance between momentum, normalization, and sign-based updates that cannot respond to how the loss landscape evolves.
38
+
39
+ **PILOT** introduces a learnable policy that continuously modulates three core update primitives:
40
+ - **Momentum reliance** — how much to rely on accumulated gradient history vs. the current gradient
41
+ - **Variance-normalization strength** — how aggressively to apply adaptive scaling
42
+ - **Sign-based behavior** — how much to compress gradient magnitudes toward ±1
43
+
44
+ The policy is conditioned on a smoothed gradient-direction agreement signal, which serves as a compact online descriptor of local update consistency. It is updated online during training using a one-step meta-gradient estimate — no offline search, no meta-training phase, no second-order estimation.
45
+
46
+ ![Loss Landscape — CIFAR-10 / SmallCNN](fig2_landscape.png)
47
+ *PILOT follows a distinct trajectory through the loss surface and converges to a lower-loss region compared to Adam, AdamW, Lion, and Sophia.*
48
+
49
+ ---
50
+
51
+ ## Key Results
52
+
53
+ ### CNN Architecture
54
+
55
+ | Dataset | Optimizer | Accuracy (%) ↑ | Val Loss ↓ | Loss Var. ↓ |
56
+ |---|---|---|---|---|
57
+ | FashionMNIST | Adam | 93.28 | 0.1957 | **0.0033** |
58
+ | FashionMNIST | AdamW | 93.22 | 0.1944 | 0.0034 |
59
+ | FashionMNIST | Lion | 92.91 | 0.2091 | 0.0041 |
60
+ | FashionMNIST | AdaBelief | 93.66 | 0.1822 | 0.0046 |
61
+ | FashionMNIST | **PILOT (Ours)** | **94.13** | **0.1719** | 0.0045 |
62
+ | CIFAR-10 | Adam | 79.91 | 0.5794 | 0.0103 |
63
+ | CIFAR-10 | Lion | 80.87 | 0.5487 | 0.0105 |
64
+ | CIFAR-10 | **PILOT (Ours)** | **81.94** | **0.5302** | **0.0073** |
65
+
66
+ ### ResNet-18 Architecture
67
+
68
+ | Dataset | Optimizer | Accuracy (%) ↑ | Val Loss ↓ | Loss Var. ↓ |
69
+ |---|---|---|---|---|
70
+ | FashionMNIST | AdaBelief | 95.33 | 0.1711 | 0.0056 |
71
+ | FashionMNIST | **PILOT (Ours)** | **95.71** | 0.2690 | 0.0030 |
72
+ | CIFAR-10 | Adam | 93.18 | **0.2140** | 0.0073 |
73
+ | CIFAR-10 | AdamW | 92.90 | 0.2514 | 0.0066 |
74
+ | CIFAR-10 | **PILOT (Ours)** | **93.42** | 0.2496 | **0.0001** |
75
+
76
+ ---
77
+
78
+ ## Method
79
+
80
+ ### Gradient-Direction Agreement
81
+
82
+ At each step, PILOT computes the cosine similarity between successive gradients:
83
+
84
+ $$r_t = \frac{g_t^\top g_{t-1}}{\|g_t\|_2 \, \|g_{t-1}\|_2 + \epsilon}$$
85
+
86
+ This is smoothed via an exponential moving average:
87
+
88
+ $$\rho_t = \gamma \rho_{t-1} + (1 - \gamma) r_t$$
89
+
90
+ Positive values indicate stable, aligned gradients. Values near zero indicate noise. Negative values indicate directional disagreement.
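The two formulas above can be sketched in plain Python for a single flattened gradient. This is a minimal illustration, not the package's implementation: the helper names `cosine_agreement` and `ema_update`, the initial value `rho = 0.0`, and the toy gradients are all assumptions made for the example.

```python
import math

def cosine_agreement(g_t, g_prev, eps=1e-8):
    """r_t = <g_t, g_prev> / (||g_t|| * ||g_prev|| + eps) for flat lists of floats."""
    dot = sum(a * b for a, b in zip(g_t, g_prev))
    norm_t = math.sqrt(sum(a * a for a in g_t))
    norm_prev = math.sqrt(sum(b * b for b in g_prev))
    return dot / (norm_t * norm_prev + eps)

def ema_update(rho_prev, r_t, gamma=0.95):
    """rho_t = gamma * rho_{t-1} + (1 - gamma) * r_t."""
    return gamma * rho_prev + (1.0 - gamma) * r_t

# Toy gradient sequence (hypothetical): two aligned steps, then a reversal.
grads = [[1.0, 0.0], [1.0, 0.1], [-1.0, 0.0]]
rho = 0.0  # assumed initial value of the smoothed signal
history = []
for g_prev, g_t in zip(grads, grads[1:]):
    r = cosine_agreement(g_t, g_prev)
    rho = ema_update(rho, r)
    history.append((r, rho))
```

The first pair of gradients gives `r` close to +1 and the reversal gives `r` close to -1, while `rho` moves only a fraction `1 - gamma` of the way toward each new value.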
91
+
92
+ ### Learnable Policy
93
+
94
+ The smoothed signal $\rho_t$ is fed through polynomial functions followed by sigmoid activations to produce three scalar control variables:
95
+
96
+ $$p_{m,t} = \sigma(f(\rho_t; \phi_m)), \quad p_{v,t} = \tfrac{1}{2}\sigma(f(\rho_t; \phi_v)), \quad p_{s,t} = \sigma(f(\rho_t; \phi_s))$$
97
+
98
+ The total number of learnable policy parameters is $3(d+1)$, where $d$ is the polynomial degree.
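As a rough sketch of the policy mapping, the snippet below evaluates three polynomials in $\rho_t$ and squashes them with sigmoids. The coefficient order (highest power first) and the `[phi_m, phi_v, phi_s]` layout follow the `_horner` helper and the slicing in the package's meta-gradient code; the function names and the all-zero initial `phi` are assumptions for illustration only.

```python
import math

def sigmoid(z):
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)

def horner(coeffs, x):
    """Evaluate a polynomial with coefficients given highest power first."""
    out = coeffs[0]
    for c in coeffs[1:]:
        out = out * x + c
    return out

def policy(rho, phi, degree=2):
    """Map the smoothed agreement rho to (p_m, p_v, p_s)."""
    n = degree + 1
    if len(phi) != 3 * n:
        raise ValueError(f"expected {3 * n} policy parameters, got {len(phi)}")
    phi_m, phi_v, phi_s = phi[:n], phi[n:2 * n], phi[2 * n:]
    p_m = sigmoid(horner(phi_m, rho))
    p_v = 0.5 * sigmoid(horner(phi_v, rho))  # scaled so p_v stays in (0, 0.5)
    p_s = sigmoid(horner(phi_s, rho))
    return p_m, p_v, p_s

# Hypothetical initialization: all coefficients zero, degree 2 -> 3 * (2 + 1) = 9 parameters.
p_m, p_v, p_s = policy(0.3, [0.0] * 9, degree=2)
```

With all coefficients at zero the policy ignores $\rho_t$ and returns the sigmoid midpoints, which makes the role of the $\tfrac{1}{2}$ factor on $p_v$ easy to see.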
99
+
100
+ ### Update Rule
101
+
102
+ $$\theta_{t+1} = \theta_t - \eta \frac{(|n_t| + \epsilon_n)^{1 - p_{s,t}} \odot \text{sign}(n_t)}{\hat{v}_t^{\,p_{v,t}} + \epsilon}$$
103
+
104
+ where $n_t = p_{m,t} \hat{m}_t + (1 - p_{m,t}) g_t$ is the policy-controlled blend of momentum and current gradient.
105
+
106
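+ This formulation recovers Adam ($p_m \to 1$, $p_v \to 0.5$, $p_s \to 0$) and sign-based updates ($p_s \to 1$, $p_v \to 0$) as limiting cases; because each control is a sigmoid output, these endpoints are approached but never reached exactly.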
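To make the two special cases concrete, here is a scalar sketch of the update direction (the fraction in the update rule, without the learning rate). The function names, the values `eps_n = eps = 1e-8`, and the sample numbers are assumptions chosen for the example; the endpoint values of the controls are plugged in directly rather than produced by the policy.

```python
def sign(x):
    """Return -1, 0, or 1 according to the sign of x."""
    return (x > 0) - (x < 0)

def blend(m_hat, g, p_m):
    """n_t = p_m * m_hat + (1 - p_m) * g."""
    return p_m * m_hat + (1.0 - p_m) * g

def pilot_direction(n, v_hat, p_v, p_s, eps_n=1e-8, eps=1e-8):
    """(|n| + eps_n)^(1 - p_s) * sign(n) / (v_hat^p_v + eps) for one coordinate."""
    return (abs(n) + eps_n) ** (1.0 - p_s) * sign(n) / (v_hat ** p_v + eps)

# Hypothetical bias-corrected moments and current gradient for one coordinate.
m_hat, g, v_hat = 0.02, -0.01, 4e-4

# Adam-like endpoint: pure momentum, square-root normalization, no sign compression.
adam_like = pilot_direction(blend(m_hat, g, 1.0), v_hat, p_v=0.5, p_s=0.0)

# Sign-like endpoint: magnitude fully compressed, no variance normalization.
sign_like = pilot_direction(blend(m_hat, g, 0.0), v_hat, p_v=0.0, p_s=1.0)
```

Here `adam_like` is approximately `m_hat / sqrt(v_hat)` and `sign_like` is approximately `sign(g)`, matching the two cases named above.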
107
+
108
+ ---
109
+
110
+ ## Installation
111
+
112
+ ```bash
113
+ pip install pilot-optimizer
114
+ ```
115
+
116
+ Or install from source:
117
+
118
+ ```bash
119
+ git clone https://github.com/SattamAltwaim/PILOT.git
120
+ cd PILOT
121
+ pip install -e .
122
+ ```
123
+
124
+ ---
125
+
126
+ ## Usage
127
+
128
+ ```python
129
+ from pilot import PILOT
130
+
131
+ optimizer = PILOT(
132
+     model.parameters(),
133
+     lr=1e-3,
134
+     betas=(0.9, 0.999),
135
+     weight_decay=1e-4,
136
+     gamma=0.95,    # smoothing coefficient for agreement signal
137
+     lr_phi=0.01,   # policy learning rate
138
+     degree=2,      # polynomial degree
139
+ )
140
+
141
+ for x, y in dataloader:  # assumes each batch unpacks to (inputs, targets)
142
+     loss = criterion(model(x), y)
143
+     optimizer.zero_grad()
144
+     loss.backward()
145
+     optimizer.step()
146
+ ```
147
+
148
+ ---
149
+
150
+ ## Hyperparameters
151
+
152
+ | Parameter | Description | Typical Range |
153
+ |---|---|---|
154
+ | `lr` | Model learning rate | `1e-4` – `1e-3` |
155
+ | `betas` | Moment coefficients | `(0.9, 0.999)` |
156
+ | `gamma` | Agreement signal smoothing | `0.85` – `0.99` |
157
+ | `lr_phi` | Policy learning rate | `5e-4` – `5e-2` |
158
+ | `degree` | Polynomial degree | `1` – `4` |
159
+
160
+ ### Configuration-Specific Selections
161
+
162
+ | Dataset | Architecture | γ (`gamma`) | η_φ (`lr_phi`) | Degree (`degree`) |
163
+ |---|---|---|---|---|
164
+ | CIFAR-10 | CNN | 0.882 | 0.00312 | 1 |
165
+ | CIFAR-10 | ResNet-18 | 0.950 | 0.00500 | 2 |
166
+ | FashionMNIST | CNN | 0.950 | 0.01000 | 2 |
167
+ | FashionMNIST | ResNet-18 | 0.957 | 0.00273 | 3 |
168
+
169
+ ---
170
+
171
+ ## Experiments
172
+
173
+ Experiments use 30 epochs, cross-entropy loss, cosine annealing LR schedule, batch size 128, and AMP. ResNet-18 configurations include a 3-epoch linear warmup.
174
+
175
+ ```bash
176
+ # CNN on CIFAR-10
177
+ python train.py --dataset cifar10 --arch cnn --optimizer pilot
178
+
179
+ # ResNet-18 on FashionMNIST
180
+ python train.py --dataset fashionmnist --arch resnet18 --optimizer pilot
181
+ ```
182
+
@@ -0,0 +1,154 @@
1
+ # PILOT: Policy-Informed Learned Optimization for Adaptive Deep Network Training
2
+
3
+ > **PILOT** is an online adaptive optimizer that adjusts its update behavior during training using gradient-direction agreement as a signal of local optimization stability.
4
+
5
+ ---
6
+
7
+ ## Overview
8
+
9
+ Most optimizers use a fixed update structure throughout training — a static balance between momentum, normalization, and sign-based updates that cannot respond to how the loss landscape evolves.
10
+
11
+ **PILOT** introduces a learnable policy that continuously modulates three core update primitives:
12
+ - **Momentum reliance** — how much to rely on accumulated gradient history vs. the current gradient
13
+ - **Variance-normalization strength** — how aggressively to apply adaptive scaling
14
+ - **Sign-based behavior** — how much to compress gradient magnitudes toward ±1
15
+
16
+ The policy is conditioned on a smoothed gradient-direction agreement signal, which serves as a compact online descriptor of local update consistency. It is updated online during training using a one-step meta-gradient estimate — no offline search, no meta-training phase, no second-order estimation.
17
+
18
+ ![Loss Landscape — CIFAR-10 / SmallCNN](fig2_landscape.png)
19
+ *PILOT follows a distinct trajectory through the loss surface and converges to a lower-loss region compared to Adam, AdamW, Lion, and Sophia.*
20
+
21
+ ---
22
+
23
+ ## Key Results
24
+
25
+ ### CNN Architecture
26
+
27
+ | Dataset | Optimizer | Accuracy (%) ↑ | Val Loss ↓ | Loss Var. ↓ |
28
+ |---|---|---|---|---|
29
+ | FashionMNIST | Adam | 93.28 | 0.1957 | **0.0033** |
30
+ | FashionMNIST | AdamW | 93.22 | 0.1944 | 0.0034 |
31
+ | FashionMNIST | Lion | 92.91 | 0.2091 | 0.0041 |
32
+ | FashionMNIST | AdaBelief | 93.66 | 0.1822 | 0.0046 |
33
+ | FashionMNIST | **PILOT (Ours)** | **94.13** | **0.1719** | 0.0045 |
34
+ | CIFAR-10 | Adam | 79.91 | 0.5794 | 0.0103 |
35
+ | CIFAR-10 | Lion | 80.87 | 0.5487 | 0.0105 |
36
+ | CIFAR-10 | **PILOT (Ours)** | **81.94** | **0.5302** | **0.0073** |
37
+
38
+ ### ResNet-18 Architecture
39
+
40
+ | Dataset | Optimizer | Accuracy (%) ↑ | Val Loss ↓ | Loss Var. ↓ |
41
+ |---|---|---|---|---|
42
+ | FashionMNIST | AdaBelief | 95.33 | 0.1711 | 0.0056 |
43
+ | FashionMNIST | **PILOT (Ours)** | **95.71** | 0.2690 | 0.0030 |
44
+ | CIFAR-10 | Adam | 93.18 | **0.2140** | 0.0073 |
45
+ | CIFAR-10 | AdamW | 92.90 | 0.2514 | 0.0066 |
46
+ | CIFAR-10 | **PILOT (Ours)** | **93.42** | 0.2496 | **0.0001** |
47
+
48
+ ---
49
+
50
+ ## Method
51
+
52
+ ### Gradient-Direction Agreement
53
+
54
+ At each step, PILOT computes the cosine similarity between successive gradients:
55
+
56
+ $$r_t = \frac{g_t^\top g_{t-1}}{\|g_t\|_2 \, \|g_{t-1}\|_2 + \epsilon}$$
57
+
58
+ This is smoothed via an exponential moving average:
59
+
60
+ $$\rho_t = \gamma \rho_{t-1} + (1 - \gamma) r_t$$
61
+
62
+ Positive values indicate stable, aligned gradients. Values near zero indicate noise. Negative values indicate directional disagreement.
63
+
64
+ ### Learnable Policy
65
+
66
+ The smoothed signal $\rho_t$ is fed through polynomial functions followed by sigmoid activations to produce three scalar control variables:
67
+
68
+ $$p_{m,t} = \sigma(f(\rho_t; \phi_m)), \quad p_{v,t} = \tfrac{1}{2}\sigma(f(\rho_t; \phi_v)), \quad p_{s,t} = \sigma(f(\rho_t; \phi_s))$$
69
+
70
+ The total number of learnable policy parameters is $3(d+1)$, where $d$ is the polynomial degree.
71
+
72
+ ### Update Rule
73
+
74
+ $$\theta_{t+1} = \theta_t - \eta \frac{(|n_t| + \epsilon_n)^{1 - p_{s,t}} \odot \text{sign}(n_t)}{\hat{v}_t^{\,p_{v,t}} + \epsilon}$$
75
+
76
+ where $n_t = p_{m,t} \hat{m}_t + (1 - p_{m,t}) g_t$ is the policy-controlled blend of momentum and current gradient.
77
+
78
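+ This formulation recovers Adam ($p_m \to 1$, $p_v \to 0.5$, $p_s \to 0$) and sign-based updates ($p_s \to 1$, $p_v \to 0$) as limiting cases; because each control is a sigmoid output, these endpoints are approached but never reached exactly.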
79
+
80
+ ---
81
+
82
+ ## Installation
83
+
84
+ ```bash
85
+ pip install pilot-optimizer
86
+ ```
87
+
88
+ Or install from source:
89
+
90
+ ```bash
91
+ git clone https://github.com/SattamAltwaim/PILOT.git
92
+ cd PILOT
93
+ pip install -e .
94
+ ```
95
+
96
+ ---
97
+
98
+ ## Usage
99
+
100
+ ```python
101
+ from pilot import PILOT
102
+
103
+ optimizer = PILOT(
104
+     model.parameters(),
105
+     lr=1e-3,
106
+     betas=(0.9, 0.999),
107
+     weight_decay=1e-4,
108
+     gamma=0.95,    # smoothing coefficient for agreement signal
109
+     lr_phi=0.01,   # policy learning rate
110
+     degree=2,      # polynomial degree
111
+ )
112
+
113
+ for x, y in dataloader:  # assumes each batch unpacks to (inputs, targets)
114
+     loss = criterion(model(x), y)
115
+     optimizer.zero_grad()
116
+     loss.backward()
117
+     optimizer.step()
118
+ ```
119
+
120
+ ---
121
+
122
+ ## Hyperparameters
123
+
124
+ | Parameter | Description | Typical Range |
125
+ |---|---|---|
126
+ | `lr` | Model learning rate | `1e-4` – `1e-3` |
127
+ | `betas` | Moment coefficients | `(0.9, 0.999)` |
128
+ | `gamma` | Agreement signal smoothing | `0.85` – `0.99` |
129
+ | `lr_phi` | Policy learning rate | `5e-4` – `5e-2` |
130
+ | `degree` | Polynomial degree | `1` – `4` |
131
+
132
+ ### Configuration-Specific Selections
133
+
134
+ | Dataset | Architecture | γ (`gamma`) | η_φ (`lr_phi`) | Degree (`degree`) |
135
+ |---|---|---|---|---|
136
+ | CIFAR-10 | CNN | 0.882 | 0.00312 | 1 |
137
+ | CIFAR-10 | ResNet-18 | 0.950 | 0.00500 | 2 |
138
+ | FashionMNIST | CNN | 0.950 | 0.01000 | 2 |
139
+ | FashionMNIST | ResNet-18 | 0.957 | 0.00273 | 3 |
140
+
141
+ ---
142
+
143
+ ## Experiments
144
+
145
+ Experiments use 30 epochs, cross-entropy loss, cosine annealing LR schedule, batch size 128, and AMP. ResNet-18 configurations include a 3-epoch linear warmup.
146
+
147
+ ```bash
148
+ # CNN on CIFAR-10
149
+ python train.py --dataset cifar10 --arch cnn --optimizer pilot
150
+
151
+ # ResNet-18 on FashionMNIST
152
+ python train.py --dataset fashionmnist --arch resnet18 --optimizer pilot
153
+ ```
154
+
@@ -0,0 +1,4 @@
1
+ from .optimizer import PILOT
2
+
3
+ __all__ = ["PILOT"]
4
+ __version__ = "0.1.0"
@@ -0,0 +1,54 @@
1
+ """
2
+ Diagnostics tracker for the PILOT optimizer.
3
+
4
+ Only records data when explicitly enabled. Zero overhead when disabled.
5
+ """
6
+
7
+
8
+ class DiagnosticsTracker:
9
+     """Stores per-step optimizer internals for analysis."""
10
+
11
+     def __init__(self, *, degree: int = 2):
12
+         self._degree = degree
13
+         n_phi = 3 * (degree + 1)
14
+         self._phi_keys = [f"phi_{i}" for i in range(n_phi)]
15
+         self._history = {
16
+             "step": [],
17
+             "r": [],
18
+             "rho": [],
19
+             "p_m": [],
20
+             "p_v": [],
21
+             "p_s": [],
22
+         }
23
+         for key in self._phi_keys:
24
+             self._history[key] = []
25
+
26
+     def record(self, step, r, rho, pm, pv, ps, phi):
27
+         """Record diagnostics for one step."""
28
+         self._history["step"].append(step)
29
+         self._history["r"].append(float(r))
30
+         self._history["rho"].append(float(rho))
31
+         self._history["p_m"].append(float(pm))
32
+         self._history["p_v"].append(float(pv))
33
+         self._history["p_s"].append(float(ps))
34
+         phi_vals = phi.detach().cpu().tolist()
35
+         for i, key in enumerate(self._phi_keys):
36
+             self._history[key].append(phi_vals[i])
37
+
38
+     def get_history(self):
39
+         """Return full history as dict of lists."""
40
+         return dict(self._history)
41
+
42
+     def summary(self, last_n=5):
43
+         """Return a summary of the last N recorded steps as a string."""
44
+         h = self._history
45
+         if not h["step"]:
46
+             return "No data recorded."
47
+         total = len(h["step"])
48
+         lines = [f"Last {min(last_n, total)} steps:"]
49
+         for idx in range(max(total - last_n, 0), total):  # never index before the first record
50
+             lines.append(
51
+                 f" step={h['step'][idx]} r={h['r'][idx]:.4f} rho={h['rho'][idx]:.4f} "
52
+                 f"pm={h['p_m'][idx]:.4f} pv={h['p_v'][idx]:.4f} ps={h['p_s'][idx]:.4f}"
53
+             )
54
+         return "\n".join(lines)
@@ -0,0 +1,104 @@
1
+ """
2
+ Analytic meta-gradient computation for the PILOT optimizer.
3
+
4
+ Pure function that computes the 3*(degree+1) meta-gradients (dL/dphi_i)
5
+ from stored detached intermediates. No autograd graph, no state.
6
+ """
7
+
8
+ import torch
9
+
10
+
11
+ def _safe_pow(base, exp_val):
12
+     """|base|^exp_val with clamping for numerical stability."""
13
+     return torch.clamp(base.abs(), min=1e-12).pow(exp_val)
14
+
15
+
16
+ def _safe_log(x):
17
+     """log(|x|) with clamping for numerical stability."""
18
+     return torch.log(torch.clamp(x.abs(), min=1e-12))
19
+
20
+
21
+ def _horner(coeffs: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
22
+     """Evaluate a polynomial via Horner's method (highest power first)."""
23
+     out = coeffs[0]
24
+     for k in range(1, coeffs.shape[0]):
25
+         out = out * x + coeffs[k]
26
+     return out
27
+
28
+
29
+ def compute_meta_grads(g_next, intermediates, g_t, phi, eta, *, degree=2):
30
+     """Compute the 3*(degree+1) meta-gradients dL/dphi_i for one param tensor.
31
+
32
+     The chain: phi -> (pm, pv, ps) -> d_t -> theta_{t+1} -> g_{t+1} -> L
33
+
34
+     dL/dphi_i = g_{t+1}^T * dd_t/dphi_i (summed over all elements)
35
+
36
+     Args:
37
+         g_next: gradient at step t+1, shape matches param.
38
+         intermediates: dict with keys n_t, m_hat, v_hat, p_m, p_v, p_s, rho.
39
+         g_t: gradient at step t (stored separately as state["g_prev"]).
40
+         phi: current group-level phi vector, length 3*(degree+1).
41
+         eta: main optimizer learning rate.
42
+         degree: polynomial degree of the response policy.
43
+
44
+     Returns:
45
+         grads: tensor of shape (3*(degree+1),).
46
+     """
47
+     n_t = intermediates["n_t"]
48
+     m_hat = intermediates["m_hat"]
49
+     v_hat = intermediates["v_hat"]
50
+     pm = intermediates["p_m"]
51
+     pv = intermediates["p_v"]
52
+     ps = intermediates["p_s"]
53
+     rho = intermediates["rho"]
54
+
55
+     denom = _safe_pow(v_hat, pv) + 1e-8
56
+     n_abs_neg_ps = _safe_pow(n_t, -ps)
57
+     n_abs_1ps = _safe_pow(n_t, 1 - ps)
58
+     sign_n = torch.sign(n_t)
59
+
60
+     # dd_t / dp_m = -eta * (1 - ps) * |n|^(-ps) * (m_hat - g_t) / denom
61
+     dd_dpm = -eta * (1 - ps) * n_abs_neg_ps * (m_hat - g_t) / denom
62
+
63
+     # dd_t / dp_v = eta * |n|^(1-ps) * sign(n) * v^pv * log|v| / denom^2
64
+     dd_dpv = eta * n_abs_1ps * sign_n * _safe_pow(v_hat, pv) * _safe_log(v_hat) / (denom * denom)
65
+
66
+     # dd_t / dp_s = eta * |n|^(1-ps) * sign(n) * log|n| / denom
67
+     dd_dps = eta * n_abs_1ps * sign_n * _safe_log(n_t) / denom
68
+
69
+     # Scalar dL/dp for each policy variable
70
+     dL_dpm = torch.sum(g_next * dd_dpm)
71
+     dL_dpv = torch.sum(g_next * dd_dpv)
72
+     dL_dps = torch.sum(g_next * dd_dps)
73
+
74
+     # --- Chain rule through sigmoid and polynomial ---
75
+     n_phi = degree + 1
76
+     coeffs_m = phi[:n_phi]
77
+     coeffs_v = phi[n_phi : 2 * n_phi]
78
+     coeffs_s = phi[2 * n_phi :]
79
+
80
+     z_m = _horner(coeffs_m, rho)
81
+     z_v = _horner(coeffs_v, rho)
82
+     z_s = _horner(coeffs_s, rho)
83
+
84
+     s_m = torch.sigmoid(z_m)
85
+     s_v = torch.sigmoid(z_v)
86
+     s_s = torch.sigmoid(z_s)
87
+
88
+     ds_m = s_m * (1 - s_m)
89
+     ds_v = s_v * (1 - s_v)
90
+     ds_s = s_s * (1 - s_s)
91
+
92
+     # Build rho powers: [rho^d, rho^{d-1}, ..., rho^1, rho^0]
93
+     rho_powers = torch.empty(n_phi, device=phi.device)
94
+     rho_powers[n_phi - 1] = 1.0
95
+     for k in range(n_phi - 2, -1, -1):
96
+         rho_powers[k] = rho_powers[k + 1] * rho
97
+
98
+     # dL/d(coeffs_m[k]) = dL/dpm * ds_m * rho^(d-k)
99
+     grads_m = dL_dpm * ds_m * rho_powers
100
+     # pv = 0.5 * sigmoid(z_v), so dp_v/dz_v = 0.5 * sigmoid'(z_v)
101
+     grads_v = dL_dpv * 0.5 * ds_v * rho_powers
102
+     grads_s = dL_dps * ds_s * rho_powers
103
+
104
+     return torch.cat([grads_m, grads_v, grads_s])
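The three analytic partials used in `compute_meta_grads` can be sanity-checked against central finite differences on a single scalar coordinate. The sketch below is a standalone check in plain Python, not part of the package: the constants, the helper names, and the omission of the `1e-12` clamps are assumptions made to keep the example small.

```python
import math

# Fixed illustrative values (assumptions, not taken from any experiment).
M_HAT, G_T, V_HAT, ETA, EPS = 0.02, -0.01, 2e-3, 1e-3, 1e-8

def step(p_m, p_v, p_s):
    """Scalar step d_t = -eta * |n|^(1 - p_s) * sign(n) / (v_hat^p_v + eps)."""
    n = p_m * M_HAT + (1.0 - p_m) * G_T
    return -ETA * abs(n) ** (1.0 - p_s) * math.copysign(1.0, n) / (V_HAT ** p_v + EPS)

def analytic_partials(p_m, p_v, p_s):
    """Same closed forms as the dd_dpm / dd_dpv / dd_dps comments, without clamping."""
    n = p_m * M_HAT + (1.0 - p_m) * G_T
    sign_n = math.copysign(1.0, n)
    denom = V_HAT ** p_v + EPS
    dd_dpm = -ETA * (1.0 - p_s) * abs(n) ** (-p_s) * (M_HAT - G_T) / denom
    dd_dpv = ETA * abs(n) ** (1.0 - p_s) * sign_n * V_HAT ** p_v * math.log(V_HAT) / denom ** 2
    dd_dps = ETA * abs(n) ** (1.0 - p_s) * sign_n * math.log(abs(n)) / denom
    return dd_dpm, dd_dpv, dd_dps

def central_diff(f, x, h=1e-6):
    """Second-order accurate finite-difference derivative of f at x."""
    return (f(x + h) - f(x - h)) / (2.0 * h)

p_m, p_v, p_s = 0.7, 0.3, 0.4
analytic = analytic_partials(p_m, p_v, p_s)
numeric = (
    central_diff(lambda x: step(x, p_v, p_s), p_m),
    central_diff(lambda x: step(p_m, x, p_s), p_v),
    central_diff(lambda x: step(p_m, p_v, x), p_s),
)
max_rel_err = max(abs(a - b) / abs(a) for a, b in zip(analytic, numeric))
```

The check is only valid while `n` stays away from zero, where the sign function is not differentiable; the clamps in `_safe_pow` and `_safe_log` guard that region in the tensor version.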