adv-optm 1.2.dev19.tar.gz → 2.dev2.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release. This version of adv-optm might be problematic.
- {adv_optm-1.2.dev19 → adv_optm-2.dev2}/PKG-INFO +20 -20
- {adv_optm-1.2.dev19 → adv_optm-2.dev2}/README.md +19 -19
- {adv_optm-1.2.dev19 → adv_optm-2.dev2}/adv_optm/__init__.py +1 -1
- {adv_optm-1.2.dev19 → adv_optm-2.dev2}/adv_optm/optim/AdaMuon_adv.py +11 -9
- {adv_optm-1.2.dev19 → adv_optm-2.dev2}/adv_optm/optim/AdamW_adv.py +91 -61
- {adv_optm-1.2.dev19 → adv_optm-2.dev2}/adv_optm/optim/Adopt_adv.py +113 -68
- {adv_optm-1.2.dev19 → adv_optm-2.dev2}/adv_optm/optim/Lion_Prodigy_adv.py +79 -81
- {adv_optm-1.2.dev19 → adv_optm-2.dev2}/adv_optm/optim/Lion_adv.py +59 -43
- {adv_optm-1.2.dev19 → adv_optm-2.dev2}/adv_optm/optim/Muon_adv.py +13 -12
- {adv_optm-1.2.dev19 → adv_optm-2.dev2}/adv_optm/optim/Prodigy_adv.py +108 -86
- {adv_optm-1.2.dev19 → adv_optm-2.dev2}/adv_optm/optim/Simplified_AdEMAMix.py +93 -52
- {adv_optm-1.2.dev19 → adv_optm-2.dev2}/adv_optm/optim/__init__.py +1 -1
- {adv_optm-1.2.dev19 → adv_optm-2.dev2}/adv_optm/util/BF16_Stochastic_Rounding.py +1 -1
- {adv_optm-1.2.dev19 → adv_optm-2.dev2}/adv_optm/util/Effective_Shape.py +1 -1
- {adv_optm-1.2.dev19 → adv_optm-2.dev2}/adv_optm/util/Kourkoutas.py +10 -12
- {adv_optm-1.2.dev19 → adv_optm-2.dev2}/adv_optm/util/NNMF.py +7 -2
- {adv_optm-1.2.dev19 → adv_optm-2.dev2}/adv_optm/util/One_Bit_Boolean.py +1 -1
- {adv_optm-1.2.dev19 → adv_optm-2.dev2}/adv_optm/util/OrthoGrad.py +4 -3
- {adv_optm-1.2.dev19 → adv_optm-2.dev2}/adv_optm/util/__init__.py +1 -1
- {adv_optm-1.2.dev19 → adv_optm-2.dev2}/adv_optm.egg-info/PKG-INFO +20 -20
- {adv_optm-1.2.dev19 → adv_optm-2.dev2}/setup.py +1 -1
- {adv_optm-1.2.dev19 → adv_optm-2.dev2}/LICENSE +0 -0
- {adv_optm-1.2.dev19 → adv_optm-2.dev2}/adv_optm/util/Newton_Schulz.py +0 -0
- {adv_optm-1.2.dev19 → adv_optm-2.dev2}/adv_optm.egg-info/SOURCES.txt +0 -0
- {adv_optm-1.2.dev19 → adv_optm-2.dev2}/adv_optm.egg-info/dependency_links.txt +0 -0
- {adv_optm-1.2.dev19 → adv_optm-2.dev2}/adv_optm.egg-info/requires.txt +0 -0
- {adv_optm-1.2.dev19 → adv_optm-2.dev2}/adv_optm.egg-info/top_level.txt +0 -0
- {adv_optm-1.2.dev19 → adv_optm-2.dev2}/setup.cfg +0 -0
{adv_optm-1.2.dev19 → adv_optm-2.dev2}/PKG-INFO

````diff
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: adv_optm
-Version: 1.2.dev19
+Version: 2.dev2
 Summary: A family of highly efficient, lightweight yet powerful optimizers.
 Home-page: https://github.com/Koratahiu/Advanced_Optimizers
 Author: Koratahiu
@@ -52,7 +52,7 @@ This library integrates multiple state-of-the-art optimization techniques valida
 ### **Memory-Efficient Optimization (SMMF-inspired)**
 - **Paper**: [SMMF: Square-Matricized Momentum Factorization](https://arxiv.org/abs/2412.08894)
 - **Approach**: Uses rank-1 non-negative matrix factorization with reconstruction cycle (factor → reconstruct → update → factor)
-- **Innovation**:
+- **Innovation**:
   - First moment split into **1-bit sign + absolute value**
   - Final storage: **four factored vectors + one 1-bit sign state**
   - Preserves Adam-like update quality with drastically reduced memory
@@ -110,7 +110,7 @@ This library integrates multiple state-of-the-art optimization techniques valida
 
 ## 🛠️ Comprehensive Feature Guide
 
-### A. Universal Safe Features
+### A. Universal Safe Features
 *These features work with all optimizers and are generally safe to enable.*
 
 | Feature | Description | Recommended Usage | Performance Impact | Theoretical Basis | Compatibility |
@@ -165,7 +165,7 @@ This library integrates multiple state-of-the-art optimization techniques valida
 | `beta1` | 0.99 | Controls accumulator memory length:<br>• Small BS: **0.99–0.9999**<br>• Large BS: **0.9** |
 | `Grad α` | 100 | Most critical parameter:<br>• Inversely scales with batch size<br>• **100–10** for small BS (≤32)<br>• **1–0.1** for large BS (≥512) |
 
-> ⚠️ **Critical**: Requires **~100x smaller learning rate** than AdamW (e.g., 1e-6 vs 1e-4).
+> ⚠️ **Critical**: Requires **~100x smaller learning rate** than AdamW (e.g., 1e-6 vs 1e-4).
 > For `Prodigy_Adv`, set `initial_d` to:
 > - **LoRA**: `1e-8`
 > - **Full FT**: `1e-10`
@@ -175,10 +175,10 @@ This library integrates multiple state-of-the-art optimization techniques valida
 
 #### Performance Validation
 
-**Small Batch Training (SDXL, BS=2, 1.8K steps)**
+**Small Batch Training (SDXL, BS=2, 1.8K steps)**
 
 
-- **🟢 Prodigy_Adv** (beta1=0.9, d0=1e-5): Final LR = 2.9e-4
+- **🟢 Prodigy_Adv** (beta1=0.9, d0=1e-5): Final LR = 2.9e-4
 - **🔵 Prodigy_Adv + Simplified_AdEMAMix** (beta1=0.99, α=100, d0=1e-7): Final LR = 5.8e-6
 
 **Results**:
@@ -202,8 +202,8 @@ This library integrates multiple state-of-the-art optimization techniques valida
 
 Instead of using a fixed β₂ (e.g., 0.999 or 0.95), it **dynamically modulates β₂ per layer** based on a bounded *sunspike ratio*:
 
-- **During gradient bursts** → β₂ ↓ toward `Lower β₂` → faster reaction
-- **During calm phases** → β₂ ↑ toward `The Selected β₂` → stronger smoothing
+- **During gradient bursts** → β₂ ↓ toward `Lower β₂` → faster reaction
+- **During calm phases** → β₂ ↑ toward `The Selected β₂` → stronger smoothing
 
 This is especially effective for **noisy training, small batch sizes, and high learning rates**, where gradient norms shift abruptly due to noise or aggressive LR schedules.
 
@@ -220,17 +220,17 @@ This is especially effective for **noisy training, small batch sizes, and high l
 
 #### 📊 Performance Validation
 
-**ADAMW_ADV - full SDXL finetuning (aggressive LR: 3e-5) (BS=4, 2.5K steps)**
+**ADAMW_ADV - full SDXL finetuning (aggressive LR: 3e-5) (BS=4, 2.5K steps)**
 <img width="1460" height="382" alt="image" src="https://github.com/user-attachments/assets/007f278a-fbac-4f3d-9cc7-274c3b959cdd" />
 
-- 🟣 Fixed `beta2=0.999`
-- 🟠 Auto K-beta
+- 🟣 Fixed `beta2=0.999`
+- 🟠 Auto K-beta
 
-**Observations:**
+**Observations:**
 - K-beta is clearly better and more robust/stable for high LRs.
 
-> 📚 **Reference**:
-> - Paper: [Kourkoutas-β: A Sunspike-Driven Adam Optimizer with Desert Flair](https://arxiv.org/abs/2508.12996)
+> 📚 **Reference**:
+> - Paper: [Kourkoutas-β: A Sunspike-Driven Adam Optimizer with Desert Flair](https://arxiv.org/abs/2508.12996)
 > - Code: [kbeta](https://github.com/sck-at-ucy/kbeta)
 
 ---
@@ -258,7 +258,7 @@ settings:
 - factored: False # Can be true or false, quality should not degrade due to Simplified_AdEMAMix’s high tolerance to 1-bit factorization.
 ```
 
-> ✅ **Why it works**:
+> ✅ **Why it works**:
 > - `Kourkoutas-β` handles beta2 values
 > - `Simplified_AdEMAMix` ensures responsiveness in small-batch noise
 > - `OrthoGrad` prevents overfitting without weight decay
@@ -267,9 +267,9 @@ settings:
 
 ## 📚 References
 
-1. [Revisiting BFloat16 Training](https://arxiv.org/abs/2010.06192)
-2. [SMMF: Square-Matricized Momentum Factorization](https://arxiv.org/abs/2412.08894)
-3. [The AdEMAMix Optimizer](https://arxiv.org/abs/2409.03137)
-4. [Connections between Schedule-Free Optimizers, AdEMAMix, and Accelerated SGD](https://arxiv.org/abs/2502.02431)
-5. [AdaMeM: Memory Efficient Momentum for Adafactor](https://openreview.net/forum?id=fZqMVTz7K5)
+1. [Revisiting BFloat16 Training](https://arxiv.org/abs/2010.06192)
+2. [SMMF: Square-Matricized Momentum Factorization](https://arxiv.org/abs/2412.08894)
+3. [The AdEMAMix Optimizer](https://arxiv.org/abs/2409.03137)
+4. [Connections between Schedule-Free Optimizers, AdEMAMix, and Accelerated SGD](https://arxiv.org/abs/2502.02431)
+5. [AdaMeM: Memory Efficient Momentum for Adafactor](https://openreview.net/forum?id=fZqMVTz7K5)
 6. [Kourkoutas-β: A Sunspike-Driven Adam Optimizer with Desert Flair](https://arxiv.org/abs/2508.12996)
````
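The Kourkoutas-β behaviour described in the hunks above (β₂ dropping toward a lower bound during gradient bursts and relaxing back toward the selected β₂ in calm phases) can be pictured with a small standalone sketch. This is illustrative only and assumes a simple sunspike formula; the package's real logic lives in `KourkoutasHelper` (adv_optm/util/Kourkoutas.py) and the referenced paper, and `dynamic_beta2` below is a hypothetical helper, not part of the library.

```python
import torch

def dynamic_beta2(grad: torch.Tensor, ema_state: dict,
                  beta2_max: float = 0.999, beta2_min: float = 0.88,
                  ema_alpha: float = 0.95, tiny_spike: float = 1e-9) -> float:
    """Illustrative per-layer beta2: pulled toward beta2_min on gradient bursts,
    staying near beta2_max (the selected beta2) during calm phases."""
    g_norm = grad.float().norm().item()
    ema = ema_state.get("ema", g_norm)
    ema = ema_alpha * ema + (1.0 - ema_alpha) * g_norm   # smoothed norm history
    ema_state["ema"] = ema
    # Bounded "sunspike" ratio in [0, 1): ~0 when the current norm matches history,
    # approaching 1 when the gradient suddenly spikes far above it.
    spike = max(g_norm - ema, 0.0) / (ema + tiny_spike)
    sunspike = spike / (1.0 + spike)
    return beta2_max - (beta2_max - beta2_min) * sunspike
```

With a formula of this shape, a layer whose gradient norm stays steady keeps β₂ near `beta2_max`, while a layer that spikes is temporarily pushed toward `beta2_min` and reacts faster, which is the qualitative behaviour the README describes.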
{adv_optm-1.2.dev19 → adv_optm-2.dev2}/README.md

````diff
@@ -21,7 +21,7 @@ This library integrates multiple state-of-the-art optimization techniques valida
 ### **Memory-Efficient Optimization (SMMF-inspired)**
 - **Paper**: [SMMF: Square-Matricized Momentum Factorization](https://arxiv.org/abs/2412.08894)
 - **Approach**: Uses rank-1 non-negative matrix factorization with reconstruction cycle (factor → reconstruct → update → factor)
-- **Innovation**:
+- **Innovation**:
   - First moment split into **1-bit sign + absolute value**
   - Final storage: **four factored vectors + one 1-bit sign state**
   - Preserves Adam-like update quality with drastically reduced memory
@@ -79,7 +79,7 @@ This library integrates multiple state-of-the-art optimization techniques valida
 
 ## 🛠️ Comprehensive Feature Guide
 
-### A. Universal Safe Features
+### A. Universal Safe Features
 *These features work with all optimizers and are generally safe to enable.*
 
 | Feature | Description | Recommended Usage | Performance Impact | Theoretical Basis | Compatibility |
@@ -134,7 +134,7 @@ This library integrates multiple state-of-the-art optimization techniques valida
 | `beta1` | 0.99 | Controls accumulator memory length:<br>• Small BS: **0.99–0.9999**<br>• Large BS: **0.9** |
 | `Grad α` | 100 | Most critical parameter:<br>• Inversely scales with batch size<br>• **100–10** for small BS (≤32)<br>• **1–0.1** for large BS (≥512) |
 
-> ⚠️ **Critical**: Requires **~100x smaller learning rate** than AdamW (e.g., 1e-6 vs 1e-4).
+> ⚠️ **Critical**: Requires **~100x smaller learning rate** than AdamW (e.g., 1e-6 vs 1e-4).
 > For `Prodigy_Adv`, set `initial_d` to:
 > - **LoRA**: `1e-8`
 > - **Full FT**: `1e-10`
@@ -144,10 +144,10 @@ This library integrates multiple state-of-the-art optimization techniques valida
 
 #### Performance Validation
 
-**Small Batch Training (SDXL, BS=2, 1.8K steps)**
+**Small Batch Training (SDXL, BS=2, 1.8K steps)**
 
 
-- **🟢 Prodigy_Adv** (beta1=0.9, d0=1e-5): Final LR = 2.9e-4
+- **🟢 Prodigy_Adv** (beta1=0.9, d0=1e-5): Final LR = 2.9e-4
 - **🔵 Prodigy_Adv + Simplified_AdEMAMix** (beta1=0.99, α=100, d0=1e-7): Final LR = 5.8e-6
 
 **Results**:
@@ -171,8 +171,8 @@ This library integrates multiple state-of-the-art optimization techniques valida
 
 Instead of using a fixed β₂ (e.g., 0.999 or 0.95), it **dynamically modulates β₂ per layer** based on a bounded *sunspike ratio*:
 
-- **During gradient bursts** → β₂ ↓ toward `Lower β₂` → faster reaction
-- **During calm phases** → β₂ ↑ toward `The Selected β₂` → stronger smoothing
+- **During gradient bursts** → β₂ ↓ toward `Lower β₂` → faster reaction
+- **During calm phases** → β₂ ↑ toward `The Selected β₂` → stronger smoothing
 
 This is especially effective for **noisy training, small batch sizes, and high learning rates**, where gradient norms shift abruptly due to noise or aggressive LR schedules.
 
@@ -189,17 +189,17 @@ This is especially effective for **noisy training, small batch sizes, and high l
 
 #### 📊 Performance Validation
 
-**ADAMW_ADV - full SDXL finetuning (aggressive LR: 3e-5) (BS=4, 2.5K steps)**
+**ADAMW_ADV - full SDXL finetuning (aggressive LR: 3e-5) (BS=4, 2.5K steps)**
 <img width="1460" height="382" alt="image" src="https://github.com/user-attachments/assets/007f278a-fbac-4f3d-9cc7-274c3b959cdd" />
 
-- 🟣 Fixed `beta2=0.999`
-- 🟠 Auto K-beta
+- 🟣 Fixed `beta2=0.999`
+- 🟠 Auto K-beta
 
-**Observations:**
+**Observations:**
 - K-beta is clearly better and more robust/stable for high LRs.
 
-> 📚 **Reference**:
-> - Paper: [Kourkoutas-β: A Sunspike-Driven Adam Optimizer with Desert Flair](https://arxiv.org/abs/2508.12996)
+> 📚 **Reference**:
+> - Paper: [Kourkoutas-β: A Sunspike-Driven Adam Optimizer with Desert Flair](https://arxiv.org/abs/2508.12996)
 > - Code: [kbeta](https://github.com/sck-at-ucy/kbeta)
 
 ---
@@ -227,7 +227,7 @@ settings:
 - factored: False # Can be true or false, quality should not degrade due to Simplified_AdEMAMix’s high tolerance to 1-bit factorization.
 ```
 
-> ✅ **Why it works**:
+> ✅ **Why it works**:
 > - `Kourkoutas-β` handles beta2 values
 > - `Simplified_AdEMAMix` ensures responsiveness in small-batch noise
 > - `OrthoGrad` prevents overfitting without weight decay
@@ -236,9 +236,9 @@ settings:
 
 ## 📚 References
 
-1. [Revisiting BFloat16 Training](https://arxiv.org/abs/2010.06192)
-2. [SMMF: Square-Matricized Momentum Factorization](https://arxiv.org/abs/2412.08894)
-3. [The AdEMAMix Optimizer](https://arxiv.org/abs/2409.03137)
-4. [Connections between Schedule-Free Optimizers, AdEMAMix, and Accelerated SGD](https://arxiv.org/abs/2502.02431)
-5. [AdaMeM: Memory Efficient Momentum for Adafactor](https://openreview.net/forum?id=fZqMVTz7K5)
+1. [Revisiting BFloat16 Training](https://arxiv.org/abs/2010.06192)
+2. [SMMF: Square-Matricized Momentum Factorization](https://arxiv.org/abs/2412.08894)
+3. [The AdEMAMix Optimizer](https://arxiv.org/abs/2409.03137)
+4. [Connections between Schedule-Free Optimizers, AdEMAMix, and Accelerated SGD](https://arxiv.org/abs/2502.02431)
+5. [AdaMeM: Memory Efficient Momentum for Adafactor](https://openreview.net/forum?id=fZqMVTz7K5)
 6. [Kourkoutas-β: A Sunspike-Driven Adam Optimizer with Desert Flair](https://arxiv.org/abs/2508.12996)
````
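The factor → reconstruct → update → factor cycle named in the SMMF bullets above can be sketched for a single non-negative second-moment matrix. This is a minimal, generic illustration under assumed row/column-mean factors, not the package's NNMF.py code.

```python
import torch

def factor_rank1(V: torch.Tensor):
    """Compress a non-negative (d1, d2) matrix into two vectors plus a scale."""
    row = V.mean(dim=1)                   # (d1,) row factor
    col = V.mean(dim=0)                   # (d2,) column factor
    scale = V.mean().clamp_min(1e-30)     # keeps outer(row, col)/scale close to V
    return row, col, scale

def reconstruct(row, col, scale):
    return torch.outer(row, col) / scale  # (d1, d2) rank-1 approximation

# One step of maintaining the second moment while storing only the factors:
d1, d2, beta2 = 64, 128, 0.999
row, col, scale = factor_rank1(torch.full((d1, d2), 1e-12))
grad = torch.randn(d1, d2)

V = reconstruct(row, col, scale)                        # 1) reconstruct
V.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)     # 2) EMA update with grad²
row, col, scale = factor_rank1(V)                       # 3) re-factor, drop the full matrix
```

Only the two vectors and the scale survive between steps, which is the memory saving the README attributes to the SMMF-inspired design (plus the separate 1-bit sign state for the first moment).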
{adv_optm-1.2.dev19 → adv_optm-2.dev2}/adv_optm/optim/AdaMuon_adv.py

```diff
@@ -46,7 +46,6 @@ class AdaMuon_adv(torch.optim.Optimizer):
            (default: (3.4445, -4.7750, 2.0315)).
        stochastic_rounding (bool): whether to use stochastic rounding for
            BF16 parameter updates (default: True).
-       orthogonal_gradient (bool): whether to use OrthoGrad. (default: False)
        nesterov (bool): enables Nesterov momentum (default: False).
        use_atan2 (bool): whether to use the atan2 update rule. (default: False)
        Simplified_AdEMAMix (bool): whether to use the Simplified AdEMAMix update rule.
@@ -96,7 +95,6 @@ class AdaMuon_adv(torch.optim.Optimizer):
         ns_eps: float = 1e-7,
         ns_coeffs: tuple[float, float, float] = (3.4445, -4.7750, 2.0315),
         stochastic_rounding: bool = False,
-        orthogonal_gradient: bool = False,
         use_atan2: bool = False,
         nesterov: bool = False,
         Simplified_AdEMAMix: bool = False,
@@ -149,7 +147,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
             "vector_reshape": vector_reshape,
             "nesterov":nesterov, "use_atan2":use_atan2,
             "Simplified_AdEMAMix": Simplified_AdEMAMix, "alpha_grad": alpha_grad,
-            "normuon_variant": normuon_variant,
+            "normuon_variant": normuon_variant,
             # Low-rank Ortho
             "low_rank_ortho": low_rank_ortho, "ortho_rank": ortho_rank,
             "compiled_optimizer":compiled_optimizer,
@@ -284,10 +282,6 @@ class AdaMuon_adv(torch.optim.Optimizer):
         nesterov = group['nesterov']
         Simplified_AdEMAMix = group['Simplified_AdEMAMix']
         alpha_grad = group['alpha_grad']
-        if grad.dtype != torch.float32 and state.get('factored', False):
-            grad = grad.float()
-        if group.get("orthogonal_gradient"):
-            grad = _orthogonalize_gradient(p, grad)
 
         if state['factored']: # Factored AdaMuon
 
@@ -351,7 +345,11 @@ class AdaMuon_adv(torch.optim.Optimizer):
             mean_squared_update = torch.mean(update.square(), dim=1)
             v_t.mul_(beta2).add_(mean_squared_update, alpha=1 - beta2)
             # Normalize update
-            update.div_(v_t.sqrt().unsqueeze(1).add_(group['eps']))
+            if group['use_atan2']:
+                a = 1.2732395
+                update.atan2_(v_t.sqrt().unsqueeze(1)).mul_(a)
+            else:
+                update.div_(v_t.sqrt().unsqueeze(1).add_(group['eps']))
             # Scale learning rate
             update_norm = torch.linalg.vector_norm(update)
             scaled_lr = group['rms_target'] * lr * (p.numel()**0.5) / update_norm.add_(group['eps'])
@@ -456,7 +454,11 @@ class AdaMuon_adv(torch.optim.Optimizer):
             mean_squared_update = torch.mean(update.square(), dim=1)
             v_t.mul_(beta2).add_(mean_squared_update, alpha=1 - beta2)
             # Normalize update
-            update.div_(v_t.sqrt().unsqueeze(1).add_(group['eps']))
+            if group['use_atan2']:
+                a = 1.2732395
+                update.atan2_(v_t.sqrt().unsqueeze(1)).mul_(a)
+            else:
+                update.div_(v_t.sqrt().unsqueeze(1).add_(group['eps']))
             # Scale learning rate
             update_norm = torch.linalg.vector_norm(update)
             scaled_lr = group['rms_target'] * lr * (p.numel()**0.5) / update_norm.add_(group['eps'])
```
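The `use_atan2` branch added in both AdaMuon hunks swaps the usual division by `sqrt(v) + eps` for an `atan2`-based normalization scaled by 1.2732395 (≈ 4/π), which is bounded and needs no `eps`. A minimal side-by-side sketch on a generic Adam-style update (the tensor names here are illustrative, not the optimizer's internals):

```python
import torch

m = torch.randn(4, 8)           # first-moment-like update
v = torch.rand(4, 8) + 1e-4     # per-row second-moment estimate
eps = 1e-8

# Classic rule: divide by sqrt(v) + eps (can blow up when v is tiny).
classic = m / (v.sqrt() + eps)

# atan2 rule from the diff above: bounded output, no eps needed,
# rescaled by the constant 1.2732395 (≈ 4/pi) used in the hunks.
atan2_update = torch.atan2(m, v.sqrt()) * 1.2732395

print(classic.abs().max(), atan2_update.abs().max())
```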
{adv_optm-1.2.dev19 → adv_optm-2.dev2}/adv_optm/optim/AdamW_adv.py

```diff
@@ -49,14 +49,12 @@ class AdamW_adv(torch.optim.Optimizer):
            before it is added to the fast momentum term (`update = mt + alpha * mt_slow`).
            A higher value increases the stabilizing influence of the slow
            momentum. (default: 5.0)
-        t_alpha (Optional[int]): The number of steps for a linear warmup of the
-            `alpha` parameter (only used when `use_AdEMAMix` is `True`). This is
-            highly recommended to prevent instability at the beginning of training,
-            as it gradually introduces the stabilizing slow momentum term. During
-            the warmup, `alpha` ramps from 0 to its target value. If `None`,
-            the scheduler is disabled. (default: None)
        kourkoutas_beta (bool): whether to enable the layer-wise dynamic β₂ logic.
            If `False`, the optimizer behaves as standard AdamW. (default: False)
+        layer_key_fn (Optional[Callable]): A function that takes a parameter `p`
+            and returns a unique, hashable key representing its "layer" or "bucket".
+            If `None`, parameters are bucketed by their shape.
+            (default: None)
        beta2_min (float): The minimum value for dynamic β₂, used during periods of
            high gradient variance ("sunspikes"). Must be less than `betas[1]`.
            (default: 0.88)
@@ -72,11 +70,7 @@ class AdamW_adv(torch.optim.Optimizer):
        k_logging (int): if > 0 and kourkoutas_beta=True, enables periodic console
            logging of Kourkoutas-β statistics (min, max, mean of `β₂` across layers)
            every logging steps. Useful for debugging and tuning. Set to 0 to disable
-           logging (default: 0).
-        layer_key_fn (Optional[Callable]): A function that takes a parameter `p`
-            and returns a unique, hashable key representing its "layer" or "bucket".
-            If `None`, parameters are bucketed by their memory ID (tensor-wise).
-            (default: None)
+           logging (default: 0).
        nnmf_factor (bool): whether to use the factorization or disable it to use
            the uncompressed optimizer. (default: False)
     """
@@ -89,7 +83,7 @@ class AdamW_adv(torch.optim.Optimizer):
         eps: float = 1e-8,
         weight_decay: float = 0.0,
         use_bias_correction: bool = True,
-        vector_reshape: bool =
+        vector_reshape: bool = False,
         stochastic_rounding: bool = True,
         use_atan2: bool = False,
         cautious_mask: bool = False,
@@ -98,15 +92,16 @@ class AdamW_adv(torch.optim.Optimizer):
         use_AdEMAMix: bool = False,
         beta3_ema: float = 0.9999,
         alpha: float = 5.0,
-        t_alpha: int | None = None,
         kourkoutas_beta: bool = False,
+        layer_key_fn: Optional[Callable] = None,
         beta2_min: float = 0.9,
         ema_alpha: float = 0.95,
         tiny_spike: float = 1e-9,
         k_warmup_steps: int = 0,
         k_logging: int = 0,
-        layer_key_fn: Optional[Callable] = None,
         nnmf_factor: bool = False,
+        # Compiled
+        compiled_optimizer: bool = False,
     ):
         if not (lr >= 0.0):
             raise ValueError(f"Learning-rate should be >= 0.0. Got {lr}")
@@ -116,7 +111,8 @@ class AdamW_adv(torch.optim.Optimizer):
             raise ValueError(f"Epsilon should be >= 0.0. Got {eps}")
         if not (weight_decay >= 0.0):
             raise ValueError(f"Weight-decay should be >= 0.0. Got {weight_decay}")
-        if kourkoutas_beta and not (betas[1] > beta2_min):
+        if kourkoutas_beta and not (betas[1] > beta2_min):
+            raise ValueError(f"For Kourkoutas-β, betas[1] (as beta2_max) must be > beta2_min. Got {betas[1]} and {beta2_min}")
 
         if cautious_mask and grams_moment:
             print("Warning: cautious is incompatible with grams, Disabling cautious.")
@@ -126,22 +122,30 @@ class AdamW_adv(torch.optim.Optimizer):
             "lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay,
             "vector_reshape": vector_reshape, "use_atan2": use_atan2,
             "orthogonal_gradient": orthogonal_gradient, "use_bias_correction": use_bias_correction,
-            "beta3_ema": beta3_ema, "alpha": alpha,
+            "beta3_ema": beta3_ema, "alpha": alpha,
             "kourkoutas_beta": kourkoutas_beta, "beta2_min": beta2_min, "ema_alpha": ema_alpha,
             "tiny_spike": tiny_spike, "k_warmup_steps": k_warmup_steps, "k_logging": k_logging,
+            "compiled_optimizer": compiled_optimizer,
         }
         self.stochastic_rounding = stochastic_rounding
         self.cautious_mask = cautious_mask
         self.grams_moment = grams_moment
         self.use_AdEMAMix = use_AdEMAMix
         self.factored = nnmf_factor
-        self.kourkoutas_beta = kourkoutas_beta
         self.layer_key_fn = layer_key_fn
+        self.kourkoutas_beta = kourkoutas_beta
+
         super().__init__(params, defaults)
 
+        self.init_step()
+
         if self.kourkoutas_beta:
             self.kourkoutas_helper = KourkoutasHelper(self)
 
+        if compiled_optimizer:
+            torch._dynamo.config.cache_size_limit = 8192
+            self.compile(fullgraph=True)
+
     @property
     def supports_fused_back_pass(self):
         return True
```
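The `compiled_optimizer` path added in the hunk above wraps the per-parameter update in `torch.compile` and raises `torch._dynamo.config.cache_size_limit` so many differently shaped parameters can keep their compiled graphs cached. A minimal standalone sketch of the same pattern; `update_param` and its arguments are illustrative stand-ins, not the library's API:

```python
import torch

torch._dynamo.config.cache_size_limit = 8192   # allow many per-shape graph variants

def update_param(p: torch.Tensor, grad: torch.Tensor, lr: torch.Tensor) -> None:
    # Plain SGD body, standing in for the real per-parameter update.
    p.sub_(lr * grad)

compiled_update = torch.compile(update_param, fullgraph=True)

params = [torch.randn(16, 16), torch.randn(32)]
grads = [torch.randn_like(p) for p in params]

# Passing lr as a tensor (as the optimizer does with self.lr_tensor further below)
# lets the scheduler change its value without the float being baked into the graph.
lr = torch.tensor(1e-3)
for p, g in zip(params, grads):
    compiled_update(p, g, lr)
```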
```diff
@@ -154,29 +158,24 @@ class AdamW_adv(torch.optim.Optimizer):
     def supports_flat_params(self):
         return False
 
-
-
-
-
+    def init_step(self):
+        for group in self.param_groups:
+            for p in group['params']:
+                self.__init_state(p, group)
 
-
-
-        grad = grad.float()
-        if group["orthogonal_gradient"]:
-            grad = _orthogonalize_gradient(p, grad)
+    @torch.no_grad()
+    def __init_state(self, p, group):
         state = self.state[p]
 
-
-
+        if len(state) == 0:
+
             state['step'] = 0
 
-
+            state['factored'] = (
                 self.factored and
                 not (len(p.shape) == 1 and not group['vector_reshape'])
             )
 
-            state['factored'] = should_factor
-
             dtype = torch.float32 if self.factored else p.dtype
             device = p.device
 
@@ -186,18 +185,18 @@ class AdamW_adv(torch.optim.Optimizer):
 
             # First moment (m)
             if group['betas'][0] > 0:
-                state['mu_m_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
+                state['mu_m_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
                 state['mv_m_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
                 if not self.grams_moment:
                     packed_d2 = (d2 + 7) // 8
                     state['sign'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=device)
             if self.use_AdEMAMix:
-                state['mu_m_slow_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
+                state['mu_m_slow_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
                 state['mv_m_slow_nmf'] = torch.zeros(d2, device=p.device, dtype=dtype)
                 packed_d2 = (d2 + 7) // 8
                 state['sign_slow'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=p.device)
             # Second moment (v)
-            state['mu_v_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
+            state['mu_v_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
             state['mv_v_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
         else: # Fallback to standard AdamW for non-factored tensors
             if group['betas'][0] > 0:
@@ -206,37 +205,32 @@ class AdamW_adv(torch.optim.Optimizer):
                 state['exp_avg_slow'] = torch.zeros_like(p, device=device, dtype=dtype)
             state['exp_avg_sq'] = torch.zeros_like(p, device=device, dtype=dtype)
 
+    @torch.no_grad()
+    def __step_parameter(self, p: torch.Tensor, group: dict, lr: torch.Tensor | float, bias_correction1: torch.Tensor | float, bias_correction2: torch.Tensor | float):
+        if p.grad is None:
+            return
+
+        grad = p.grad
+        if grad.dtype != torch.float32 and self.factored:
+            grad = grad.float()
+        if group["orthogonal_gradient"]:
+            grad = _orthogonalize_gradient(p, grad)
+        state = self.state[p]
+
+
         beta1, beta2 = group['betas']
 
-        current_step = state['step']
         if group.get('kourkoutas_beta', False):
-            # Call prepare_step() once at the beginning of the step for all params
-            self.kourkoutas_helper.maybe_prepare_step(current_step)
             # Accumulate current grad's norm for the *next* step
             self.kourkoutas_helper.accumulate_gradient_sq_norm(p, grad)
             # Get the dynamic beta2 calculated in prepare_step()
-            beta2 = self.kourkoutas_helper.get_beta2(p, group
+            beta2 = self.kourkoutas_helper.get_beta2(p, group)
 
-
-        if group['use_bias_correction']:
-            bias_correction1 = 1.0 - beta1 ** step
-            if group.get('kourkoutas_beta', False):
-                bias_correction2 = 1.0 - group['betas'][1] ** step
-                # Use beta2_max for bias correction
-            else:
-                bias_correction2 = 1.0 - beta2 ** step
-        else:
-            bias_correction1 = 1
-            bias_correction2 = 1
-        step_size = group['lr'] / bias_correction1
+        step_size = lr / bias_correction1
 
         if self.use_AdEMAMix:
             beta3_ema = group['beta3_ema']
             alpha = group['alpha']
-            t_alpha = group['t_alpha']
-            alpha_t = alpha
-            if t_alpha is not None and t_alpha > 0 and step < t_alpha:
-                alpha_t = min(step * alpha / t_alpha, alpha)
 
         if state['factored']:
             d1, d2 = state['effective_shape']
@@ -272,9 +266,9 @@ class AdamW_adv(torch.optim.Optimizer):
 
                 mt_slow.mul_(beta3_ema).add_(grad_reshaped, alpha=1.0 - beta3_ema)
                 if beta1 > 0:
-                    update = torch.add(mt, mt_slow, alpha=
+                    update = torch.add(mt, mt_slow, alpha=alpha)
                 else:
-                    update = torch.add(grad_reshaped, mt_slow, alpha=
+                    update = torch.add(grad_reshaped, mt_slow, alpha=alpha)
             else:
                 update = mt.clone() if beta1 > 0 else grad_reshaped.clone()
             del grad_reshaped
@@ -321,9 +315,9 @@ class AdamW_adv(torch.optim.Optimizer):
                 exp_avg_slow = state['exp_avg_slow']
                 exp_avg_slow.mul_(beta3_ema).add_(grad, alpha=1 - beta3_ema)
                 if beta1 > 0:
-                    update = torch.add(exp_avg, exp_avg_slow, alpha=
+                    update = torch.add(exp_avg, exp_avg_slow, alpha=alpha)
                 else:
-                    update = torch.add(grad, exp_avg_slow, alpha=
+                    update = torch.add(grad, exp_avg_slow, alpha=alpha)
             else:
                 update = exp_avg.clone() if beta1 > 0 else grad.clone()
 
```
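The two hunks above change the AdEMAMix mixing line so the slow EMA is scaled by the fixed `alpha` from the defaults (the `t_alpha` warmup having been removed earlier in this file). A compact sketch of the two-EMA combination being applied here, with hypothetical tensor names:

```python
import torch

beta1, beta3_ema, alpha = 0.9, 0.9999, 5.0

grad = torch.randn(10)
exp_avg = torch.zeros(10)       # fast EMA, decayed by beta1
exp_avg_slow = torch.zeros(10)  # slow EMA, decayed by beta3_ema

exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
exp_avg_slow.mul_(beta3_ema).add_(grad, alpha=1 - beta3_ema)

# Mixing rule from the diff: update = fast + alpha * slow
# (falls back to grad + alpha * slow when beta1 == 0).
update = torch.add(exp_avg, exp_avg_slow, alpha=alpha)
```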
```diff
@@ -343,9 +337,9 @@ class AdamW_adv(torch.optim.Optimizer):
         # Decoupled weight decay
         if group["weight_decay"] != 0:
             if p.dtype == torch.bfloat16 and self.stochastic_rounding:
-                add_stochastic_(p.data, p.data, alpha=-group["weight_decay"] *
+                add_stochastic_(p.data, p.data, alpha=-group["weight_decay"] * lr)
             else:
-                p.data.add_(p.data, alpha=-group["weight_decay"] *
+                p.data.add_(p.data, alpha=-group["weight_decay"] * lr)
 
         if p.dtype == torch.bfloat16 and self.stochastic_rounding:
             add_stochastic_(p.data, -update)
```
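The weight-decay hunk keeps routing BF16 parameter updates through `add_stochastic_` (adv_optm/util/BF16_Stochastic_Rounding.py). As background, stochastic rounding to bfloat16 can be sketched with the usual bit trick below; this is a generic illustration, not the library's implementation:

```python
import torch

def stochastic_round_to_bf16(x: torch.Tensor) -> torch.Tensor:
    """Round float32 to bfloat16, rounding up or down at random in proportion to
    how close the value sits to each neighbour, so the rounding error is zero-mean."""
    assert x.dtype == torch.float32
    bits = x.view(torch.int32)
    # Add random noise to the 16 low bits that bfloat16 discards, then truncate them.
    noise = torch.randint(0, 1 << 16, bits.shape, device=x.device, dtype=torch.int32)
    rounded = (bits + noise) & -65536          # -65536 == 0xFFFF0000 as int32 mask
    return rounded.view(torch.float32).to(torch.bfloat16)

w = torch.full((100_000,), 1.0) + 1e-4         # 1.0001 is not representable in bf16
print(stochastic_round_to_bf16(w).float().mean())   # ~1.0001 on average
```

This zero-mean behaviour is what keeps tiny BF16 updates from being systematically lost, which is the motivation behind the "Revisiting BFloat16 Training" reference in the README.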
```diff
@@ -353,7 +347,40 @@ class AdamW_adv(torch.optim.Optimizer):
             p.data.add_(-update)
         del update
 
-
+    @torch.no_grad()
+    def step_parameter(self, p: torch.Tensor, group: dict, i: int | None = None):
+
+        state = self.state[p]
+
+        step = state['step']
+
+        if group['use_bias_correction']:
+            current_step = step + 1
+            beta1, beta2 = group['betas']
+            bias_correction1 = 1.0 - beta1 ** current_step
+            bias_correction2 = 1.0 - beta2 ** current_step
+        else:
+            bias_correction1 = 1.0
+            bias_correction2 = 1.0
+
+        if group.get('kourkoutas_beta', False):
+            # Prepare Kourkoutas-β once per step using the global step counter.
+            self.kourkoutas_helper.maybe_prepare_step(step)
+
+        self.state[p]['step'] += 1
+
+        if not group.get('compiled_optimizer', False):
+            self.__step_parameter(p, group, group['lr'], bias_correction1, bias_correction2)
+        else:
+            if not hasattr(self, 'lr_tensor') or self.lr_tensor is None:
+                # convert to tensors for compiled path once a step
+                self.lr_tensor = torch.tensor(group['lr'], device=p.device)
+                self.bc1_tensor = torch.tensor(bias_correction1, device=p.device)
+                self.bc2_tensor = torch.tensor(bias_correction2, device=p.device)
+            self._compiled_step_parameter(p, group, self.lr_tensor, self.bc1_tensor, self.bc2_tensor)
+
+    def compile(self, *args, **kwargs):
+        self._compiled_step_parameter = torch.compile(self.__step_parameter, *args, **kwargs)
 
     @torch.no_grad()
     def step(self, closure=None):
@@ -367,4 +394,7 @@ class AdamW_adv(torch.optim.Optimizer):
             for i, p in enumerate(group['params']):
                 self.step_parameter(p, group, i)
 
+        if self.param_groups[0].get('compiled_optimizer', False):
+            # Reset compile tensors once a step
+            self.lr_tensor = None
         return loss
```
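Since `supports_fused_back_pass` returns `True` and the new `step_parameter(p, group, i)` entry point updates one parameter at a time, a trainer can apply each update as soon as its gradient is accumulated instead of calling `step()` once after backward. A hypothetical usage sketch; the hook wiring below is an assumption (PyTorch ≥ 2.1 for `register_post_accumulate_grad_hook`), not code from the package:

```python
import torch

model = torch.nn.Linear(128, 128)
optimizer = AdamW_adv(model.parameters(), lr=1e-4)   # assumes the class is imported

for group in optimizer.param_groups:
    for i, p in enumerate(group["params"]):
        def _fused_step(param, group=group, i=i):
            # Runs right after param.grad has been accumulated for this backward pass.
            optimizer.step_parameter(param, group, i)
            param.grad = None            # free the gradient immediately
        p.register_post_accumulate_grad_hook(_fused_step)

loss = model(torch.randn(4, 128)).sum()
loss.backward()                          # parameters are updated during backward
```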