adv-optm 1.2.dev19.tar.gz → 2.dev2.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of adv-optm might be problematic.

Files changed (28)
  1. {adv_optm-1.2.dev19 → adv_optm-2.dev2}/PKG-INFO +20 -20
  2. {adv_optm-1.2.dev19 → adv_optm-2.dev2}/README.md +19 -19
  3. {adv_optm-1.2.dev19 → adv_optm-2.dev2}/adv_optm/__init__.py +1 -1
  4. {adv_optm-1.2.dev19 → adv_optm-2.dev2}/adv_optm/optim/AdaMuon_adv.py +11 -9
  5. {adv_optm-1.2.dev19 → adv_optm-2.dev2}/adv_optm/optim/AdamW_adv.py +91 -61
  6. {adv_optm-1.2.dev19 → adv_optm-2.dev2}/adv_optm/optim/Adopt_adv.py +113 -68
  7. {adv_optm-1.2.dev19 → adv_optm-2.dev2}/adv_optm/optim/Lion_Prodigy_adv.py +79 -81
  8. {adv_optm-1.2.dev19 → adv_optm-2.dev2}/adv_optm/optim/Lion_adv.py +59 -43
  9. {adv_optm-1.2.dev19 → adv_optm-2.dev2}/adv_optm/optim/Muon_adv.py +13 -12
  10. {adv_optm-1.2.dev19 → adv_optm-2.dev2}/adv_optm/optim/Prodigy_adv.py +108 -86
  11. {adv_optm-1.2.dev19 → adv_optm-2.dev2}/adv_optm/optim/Simplified_AdEMAMix.py +93 -52
  12. {adv_optm-1.2.dev19 → adv_optm-2.dev2}/adv_optm/optim/__init__.py +1 -1
  13. {adv_optm-1.2.dev19 → adv_optm-2.dev2}/adv_optm/util/BF16_Stochastic_Rounding.py +1 -1
  14. {adv_optm-1.2.dev19 → adv_optm-2.dev2}/adv_optm/util/Effective_Shape.py +1 -1
  15. {adv_optm-1.2.dev19 → adv_optm-2.dev2}/adv_optm/util/Kourkoutas.py +10 -12
  16. {adv_optm-1.2.dev19 → adv_optm-2.dev2}/adv_optm/util/NNMF.py +7 -2
  17. {adv_optm-1.2.dev19 → adv_optm-2.dev2}/adv_optm/util/One_Bit_Boolean.py +1 -1
  18. {adv_optm-1.2.dev19 → adv_optm-2.dev2}/adv_optm/util/OrthoGrad.py +4 -3
  19. {adv_optm-1.2.dev19 → adv_optm-2.dev2}/adv_optm/util/__init__.py +1 -1
  20. {adv_optm-1.2.dev19 → adv_optm-2.dev2}/adv_optm.egg-info/PKG-INFO +20 -20
  21. {adv_optm-1.2.dev19 → adv_optm-2.dev2}/setup.py +1 -1
  22. {adv_optm-1.2.dev19 → adv_optm-2.dev2}/LICENSE +0 -0
  23. {adv_optm-1.2.dev19 → adv_optm-2.dev2}/adv_optm/util/Newton_Schulz.py +0 -0
  24. {adv_optm-1.2.dev19 → adv_optm-2.dev2}/adv_optm.egg-info/SOURCES.txt +0 -0
  25. {adv_optm-1.2.dev19 → adv_optm-2.dev2}/adv_optm.egg-info/dependency_links.txt +0 -0
  26. {adv_optm-1.2.dev19 → adv_optm-2.dev2}/adv_optm.egg-info/requires.txt +0 -0
  27. {adv_optm-1.2.dev19 → adv_optm-2.dev2}/adv_optm.egg-info/top_level.txt +0 -0
  28. {adv_optm-1.2.dev19 → adv_optm-2.dev2}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: adv_optm
- Version: 1.2.dev19
+ Version: 2.dev2
  Summary: A family of highly efficient, lightweight yet powerful optimizers.
  Home-page: https://github.com/Koratahiu/Advanced_Optimizers
  Author: Koratahiu
@@ -52,7 +52,7 @@ This library integrates multiple state-of-the-art optimization techniques valida
  ### **Memory-Efficient Optimization (SMMF-inspired)**
  - **Paper**: [SMMF: Square-Matricized Momentum Factorization](https://arxiv.org/abs/2412.08894)
  - **Approach**: Uses rank-1 non-negative matrix factorization with reconstruction cycle (factor → reconstruct → update → factor)
- - **Innovation**:
+ - **Innovation**:
  - First moment split into **1-bit sign + absolute value**
  - Final storage: **four factored vectors + one 1-bit sign state**
  - Preserves Adam-like update quality with drastically reduced memory
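> Editor's note: the factor → reconstruct → update → factor cycle and the 1-bit sign state above can be pictured with the sketch below. The Adafactor-style row/column factorization and the helper names (`factor_rank1`, `reconstruct`, `pack_signs`) are illustrative assumptions, not the package's `NNMF.py` / `One_Bit_Boolean.py` implementation.

```python
import torch

def factor_rank1(m_abs: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Compress a non-negative (d1, d2) matrix into two vectors (rank-1 factorization)."""
    u = m_abs.mean(dim=1)                                       # (d1,) row statistics
    v = m_abs.mean(dim=0) / m_abs.mean().clamp_min(1e-30)       # (d2,) normalized column statistics
    return u, v

def reconstruct(u: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """Rebuild the full matrix before the update step; exact when m_abs is rank-1."""
    return torch.outer(u, v)

def pack_signs(m: torch.Tensor) -> torch.Tensor:
    """1-bit sign state: one uint8 stores the signs of 8 consecutive columns."""
    bits = (m < 0).to(torch.uint8)
    d1, d2 = bits.shape
    bits = torch.nn.functional.pad(bits, (0, (-d2) % 8))        # pad columns to a multiple of 8
    weights = torch.tensor([1, 2, 4, 8, 16, 32, 64, 128], dtype=torch.uint8)
    return (bits.view(d1, -1, 8) * weights).sum(dim=-1, dtype=torch.uint8)

# One cycle for the first moment m: store sign(m) packed to 1 bit plus factors of |m|,
# reconstruct |m| at the next step, reapply the signs, update, then re-factor.
m = torch.randn(4, 16)
u, v = factor_rank1(m.abs())
signs = pack_signs(m)
m_hat = reconstruct(u, v)   # approximate |m|; signs come back from the packed state
```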
@@ -110,7 +110,7 @@ This library integrates multiple state-of-the-art optimization techniques valida

  ## 🛠️ Comprehensive Feature Guide

- ### A. Universal Safe Features
+ ### A. Universal Safe Features
  *These features work with all optimizers and are generally safe to enable.*

  | Feature | Description | Recommended Usage | Performance Impact | Theoretical Basis | Compatibility |
@@ -165,7 +165,7 @@ This library integrates multiple state-of-the-art optimization techniques valida
  | `beta1` | 0.99 | Controls accumulator memory length:<br>• Small BS: **0.99–0.9999**<br>• Large BS: **0.9** |
  | `Grad α` | 100 | Most critical parameter:<br>• Inversely scales with batch size<br>• **100–10** for small BS (≤32)<br>• **1–0.1** for large BS (≥512) |

- > ⚠️ **Critical**: Requires **~100x smaller learning rate** than AdamW (e.g., 1e-6 vs 1e-4).
+ > ⚠️ **Critical**: Requires **~100x smaller learning rate** than AdamW (e.g., 1e-6 vs 1e-4).
  > For `Prodigy_Adv`, set `initial_d` to:
  > - **LoRA**: `1e-8`
  > - **Full FT**: `1e-10`
@@ -175,10 +175,10 @@ This library integrates multiple state-of-the-art optimization techniques valida

  #### Performance Validation

- **Small Batch Training (SDXL, BS=2, 1.8K steps)**
+ **Small Batch Training (SDXL, BS=2, 1.8K steps)**
  ![Training Comparison](https://github.com/user-attachments/assets/7eff0671-cc59-47fc-8b63-d5205456d649)

- - **🟢 Prodigy_Adv** (beta1=0.9, d0=1e-5): Final LR = 2.9e-4
+ - **🟢 Prodigy_Adv** (beta1=0.9, d0=1e-5): Final LR = 2.9e-4
  - **🔵 Prodigy_Adv + Simplified_AdEMAMix** (beta1=0.99, α=100, d0=1e-7): Final LR = 5.8e-6

  **Results**:
@@ -202,8 +202,8 @@ This library integrates multiple state-of-the-art optimization techniques valida

  Instead of using a fixed β₂ (e.g., 0.999 or 0.95), it **dynamically modulates β₂ per layer** based on a bounded *sunspike ratio*:

- - **During gradient bursts** → β₂ ↓ toward `Lower β₂` → faster reaction
- - **During calm phases** → β₂ ↑ toward `The Selected β₂` → stronger smoothing
+ - **During gradient bursts** → β₂ ↓ toward `Lower β₂` → faster reaction
+ - **During calm phases** → β₂ ↑ toward `The Selected β₂` → stronger smoothing

  This is especially effective for **noisy training, small batch sizes, and high learning rates**, where gradient norms shift abruptly due to noise or aggressive LR schedules.
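> Editor's note: the burst/calm behaviour in the two bullets above can be written out in a few lines. This is a minimal sketch of the idea (EMA of a pooled per-layer gradient norm, a bounded sunspike ratio, linear interpolation between `beta2_min` and the configured β₂); the package's `Kourkoutas.py` helper may differ in details such as warmup, norm pooling, and logging.

```python
import torch

def dynamic_beta2(grad_norm: torch.Tensor, state: dict, beta2_max: float = 0.999,
                  beta2_min: float = 0.88, ema_alpha: float = 0.95,
                  tiny_spike: float = 1e-9) -> float:
    """Per-layer beta2: spikes in the pooled gradient norm pull beta2 toward beta2_min,
    calm phases let it relax back toward the configured beta2 (beta2_max here)."""
    ema = state.get("norm_ema", torch.zeros_like(grad_norm))
    ema = ema_alpha * ema + (1.0 - ema_alpha) * grad_norm       # smoothed norm history
    state["norm_ema"] = ema
    raw = grad_norm / (ema + tiny_spike)                        # "sunspike" ratio
    sun = raw / (1.0 + raw)                                     # bounded to [0, 1)
    return (beta2_max - (beta2_max - beta2_min) * sun).item()

# Example: a layer whose gradient norm suddenly quadruples sees beta2 dip for that step.
layer_state: dict = {}
for norm in [1.0, 1.0, 4.0, 1.0]:
    print(dynamic_beta2(torch.tensor(norm), layer_state))
```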

@@ -220,17 +220,17 @@ This is especially effective for **noisy training, small batch sizes, and high l

  #### 📊 Performance Validation

- **ADAMW_ADV - full SDXL finetuning (aggressive LR: 3e-5) (BS=4, 2.5K steps)**
+ **ADAMW_ADV - full SDXL finetuning (aggressive LR: 3e-5) (BS=4, 2.5K steps)**
  <img width="1460" height="382" alt="image" src="https://github.com/user-attachments/assets/007f278a-fbac-4f3d-9cc7-274c3b959cdd" />

- - 🟣 Fixed `beta2=0.999`
- - 🟠 Auto K-beta
+ - 🟣 Fixed `beta2=0.999`
+ - 🟠 Auto K-beta

- **Observations:**
+ **Observations:**
  - K-beta is clearly better and more robust/stable for high LRs.

- > 📚 **Reference**:
- > - Paper: [Kourkoutas-β: A Sunspike-Driven Adam Optimizer with Desert Flair](https://arxiv.org/abs/2508.12996)
+ > 📚 **Reference**:
+ > - Paper: [Kourkoutas-β: A Sunspike-Driven Adam Optimizer with Desert Flair](https://arxiv.org/abs/2508.12996)
  > - Code: [kbeta](https://github.com/sck-at-ucy/kbeta)

  ---
@@ -258,7 +258,7 @@ settings:
  - factored: False # Can be true or false, quality should not degrade due to Simplified_AdEMAMix’s high tolerance to 1-bit factorization.
  ```

- > ✅ **Why it works**:
+ > ✅ **Why it works**:
  > - `Kourkoutas-β` handles beta2 values
  > - `Simplified_AdEMAMix` ensures responsiveness in small-batch noise
  > - `OrthoGrad` prevents overfitting without weight decay
@@ -267,9 +267,9 @@ settings:

  ## 📚 References

- 1. [Revisiting BFloat16 Training](https://arxiv.org/abs/2010.06192)
- 2. [SMMF: Square-Matricized Momentum Factorization](https://arxiv.org/abs/2412.08894)
- 3. [The AdEMAMix Optimizer](https://arxiv.org/abs/2409.03137)
- 4. [Connections between Schedule-Free Optimizers, AdEMAMix, and Accelerated SGD](https://arxiv.org/abs/2502.02431)
- 5. [AdaMeM: Memory Efficient Momentum for Adafactor](https://openreview.net/forum?id=fZqMVTz7K5)
+ 1. [Revisiting BFloat16 Training](https://arxiv.org/abs/2010.06192)
+ 2. [SMMF: Square-Matricized Momentum Factorization](https://arxiv.org/abs/2412.08894)
+ 3. [The AdEMAMix Optimizer](https://arxiv.org/abs/2409.03137)
+ 4. [Connections between Schedule-Free Optimizers, AdEMAMix, and Accelerated SGD](https://arxiv.org/abs/2502.02431)
+ 5. [AdaMeM: Memory Efficient Momentum for Adafactor](https://openreview.net/forum?id=fZqMVTz7K5)
  6. [Kourkoutas-β: A Sunspike-Driven Adam Optimizer with Desert Flair](https://arxiv.org/abs/2508.12996)
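> Editor's note: the recipe in the "Why it works" block maps onto the `AdamW_adv` constructor changed in this release (see the diff further below). A hedged usage sketch using the keyword names visible in that diff (`kourkoutas_beta`, `orthogonal_gradient`, `nnmf_factor`, `stochastic_rounding`); the top-level import and the chosen values are assumptions, not documented defaults.

```python
import torch
from adv_optm import AdamW_adv  # assumed top-level re-export, as in adv_optm/__init__.py

model = torch.nn.Linear(1024, 1024, dtype=torch.bfloat16)

optimizer = AdamW_adv(
    model.parameters(),
    lr=1e-4,
    betas=(0.9, 0.999),        # betas[1] acts as the upper beta2 when kourkoutas_beta=True
    weight_decay=0.0,          # the recipe leans on OrthoGrad instead of decay
    orthogonal_gradient=True,  # OrthoGrad
    kourkoutas_beta=True,      # layer-wise dynamic beta2
    beta2_min=0.9,
    nnmf_factor=True,          # factored (rank-1 + 1-bit sign) state; optional per the recipe
    stochastic_rounding=True,  # recommended for BF16 parameters
)

x = torch.randn(8, 1024, dtype=torch.bfloat16)
loss = model(x).float().pow(2).mean()
loss.backward()
optimizer.step()
optimizer.zero_grad()
```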
@@ -21,7 +21,7 @@ This library integrates multiple state-of-the-art optimization techniques valida
  ### **Memory-Efficient Optimization (SMMF-inspired)**
  - **Paper**: [SMMF: Square-Matricized Momentum Factorization](https://arxiv.org/abs/2412.08894)
  - **Approach**: Uses rank-1 non-negative matrix factorization with reconstruction cycle (factor → reconstruct → update → factor)
- - **Innovation**:
+ - **Innovation**:
  - First moment split into **1-bit sign + absolute value**
  - Final storage: **four factored vectors + one 1-bit sign state**
  - Preserves Adam-like update quality with drastically reduced memory
@@ -79,7 +79,7 @@ This library integrates multiple state-of-the-art optimization techniques valida

  ## 🛠️ Comprehensive Feature Guide

- ### A. Universal Safe Features
+ ### A. Universal Safe Features
  *These features work with all optimizers and are generally safe to enable.*

  | Feature | Description | Recommended Usage | Performance Impact | Theoretical Basis | Compatibility |
@@ -134,7 +134,7 @@ This library integrates multiple state-of-the-art optimization techniques valida
  | `beta1` | 0.99 | Controls accumulator memory length:<br>• Small BS: **0.99–0.9999**<br>• Large BS: **0.9** |
  | `Grad α` | 100 | Most critical parameter:<br>• Inversely scales with batch size<br>• **100–10** for small BS (≤32)<br>• **1–0.1** for large BS (≥512) |

- > ⚠️ **Critical**: Requires **~100x smaller learning rate** than AdamW (e.g., 1e-6 vs 1e-4).
+ > ⚠️ **Critical**: Requires **~100x smaller learning rate** than AdamW (e.g., 1e-6 vs 1e-4).
  > For `Prodigy_Adv`, set `initial_d` to:
  > - **LoRA**: `1e-8`
  > - **Full FT**: `1e-10`
@@ -144,10 +144,10 @@ This library integrates multiple state-of-the-art optimization techniques valida

  #### Performance Validation

- **Small Batch Training (SDXL, BS=2, 1.8K steps)**
+ **Small Batch Training (SDXL, BS=2, 1.8K steps)**
  ![Training Comparison](https://github.com/user-attachments/assets/7eff0671-cc59-47fc-8b63-d5205456d649)

- - **🟢 Prodigy_Adv** (beta1=0.9, d0=1e-5): Final LR = 2.9e-4
+ - **🟢 Prodigy_Adv** (beta1=0.9, d0=1e-5): Final LR = 2.9e-4
  - **🔵 Prodigy_Adv + Simplified_AdEMAMix** (beta1=0.99, α=100, d0=1e-7): Final LR = 5.8e-6

  **Results**:
@@ -171,8 +171,8 @@ This library integrates multiple state-of-the-art optimization techniques valida

  Instead of using a fixed β₂ (e.g., 0.999 or 0.95), it **dynamically modulates β₂ per layer** based on a bounded *sunspike ratio*:

- - **During gradient bursts** → β₂ ↓ toward `Lower β₂` → faster reaction
- - **During calm phases** → β₂ ↑ toward `The Selected β₂` → stronger smoothing
+ - **During gradient bursts** → β₂ ↓ toward `Lower β₂` → faster reaction
+ - **During calm phases** → β₂ ↑ toward `The Selected β₂` → stronger smoothing

  This is especially effective for **noisy training, small batch sizes, and high learning rates**, where gradient norms shift abruptly due to noise or aggressive LR schedules.

@@ -189,17 +189,17 @@ This is especially effective for **noisy training, small batch sizes, and high l

  #### 📊 Performance Validation

- **ADAMW_ADV - full SDXL finetuning (aggressive LR: 3e-5) (BS=4, 2.5K steps)**
+ **ADAMW_ADV - full SDXL finetuning (aggressive LR: 3e-5) (BS=4, 2.5K steps)**
  <img width="1460" height="382" alt="image" src="https://github.com/user-attachments/assets/007f278a-fbac-4f3d-9cc7-274c3b959cdd" />

- - 🟣 Fixed `beta2=0.999`
- - 🟠 Auto K-beta
+ - 🟣 Fixed `beta2=0.999`
+ - 🟠 Auto K-beta

- **Observations:**
+ **Observations:**
  - K-beta is clearly better and more robust/stable for high LRs.

- > 📚 **Reference**:
- > - Paper: [Kourkoutas-β: A Sunspike-Driven Adam Optimizer with Desert Flair](https://arxiv.org/abs/2508.12996)
+ > 📚 **Reference**:
+ > - Paper: [Kourkoutas-β: A Sunspike-Driven Adam Optimizer with Desert Flair](https://arxiv.org/abs/2508.12996)
  > - Code: [kbeta](https://github.com/sck-at-ucy/kbeta)

  ---
@@ -227,7 +227,7 @@ settings:
  - factored: False # Can be true or false, quality should not degrade due to Simplified_AdEMAMix’s high tolerance to 1-bit factorization.
  ```

- > ✅ **Why it works**:
+ > ✅ **Why it works**:
  > - `Kourkoutas-β` handles beta2 values
  > - `Simplified_AdEMAMix` ensures responsiveness in small-batch noise
  > - `OrthoGrad` prevents overfitting without weight decay
@@ -236,9 +236,9 @@ settings:

  ## 📚 References

- 1. [Revisiting BFloat16 Training](https://arxiv.org/abs/2010.06192)
- 2. [SMMF: Square-Matricized Momentum Factorization](https://arxiv.org/abs/2412.08894)
- 3. [The AdEMAMix Optimizer](https://arxiv.org/abs/2409.03137)
- 4. [Connections between Schedule-Free Optimizers, AdEMAMix, and Accelerated SGD](https://arxiv.org/abs/2502.02431)
- 5. [AdaMeM: Memory Efficient Momentum for Adafactor](https://openreview.net/forum?id=fZqMVTz7K5)
+ 1. [Revisiting BFloat16 Training](https://arxiv.org/abs/2010.06192)
+ 2. [SMMF: Square-Matricized Momentum Factorization](https://arxiv.org/abs/2412.08894)
+ 3. [The AdEMAMix Optimizer](https://arxiv.org/abs/2409.03137)
+ 4. [Connections between Schedule-Free Optimizers, AdEMAMix, and Accelerated SGD](https://arxiv.org/abs/2502.02431)
+ 5. [AdaMeM: Memory Efficient Momentum for Adafactor](https://openreview.net/forum?id=fZqMVTz7K5)
  6. [Kourkoutas-β: A Sunspike-Driven Adam Optimizer with Desert Flair](https://arxiv.org/abs/2508.12996)
@@ -20,4 +20,4 @@ __all__ = [
  "AdaMuon_adv",
  ]

- __version__ = "1.2.dev19"
+ __version__ = "2.dev2"
@@ -46,7 +46,6 @@ class AdaMuon_adv(torch.optim.Optimizer):
  (default: (3.4445, -4.7750, 2.0315)).
  stochastic_rounding (bool): whether to use stochastic rounding for
  BF16 parameter updates (default: True).
- orthogonal_gradient (bool): whether to use OrthoGrad. (default: False)
  nesterov (bool): enables Nesterov momentum (default: False).
  use_atan2 (bool): whether to use the atan2 update rule. (default: False)
  Simplified_AdEMAMix (bool): whether to use the Simplified AdEMAMix update rule.
@@ -96,7 +95,6 @@ class AdaMuon_adv(torch.optim.Optimizer):
  ns_eps: float = 1e-7,
  ns_coeffs: tuple[float, float, float] = (3.4445, -4.7750, 2.0315),
  stochastic_rounding: bool = False,
- orthogonal_gradient: bool = False,
  use_atan2: bool = False,
  nesterov: bool = False,
  Simplified_AdEMAMix: bool = False,
@@ -149,7 +147,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
  "vector_reshape": vector_reshape,
  "nesterov":nesterov, "use_atan2":use_atan2,
  "Simplified_AdEMAMix": Simplified_AdEMAMix, "alpha_grad": alpha_grad,
- "normuon_variant": normuon_variant, "orthogonal_gradient": orthogonal_gradient,
+ "normuon_variant": normuon_variant,
  # Low-rank Ortho
  "low_rank_ortho": low_rank_ortho, "ortho_rank": ortho_rank,
  "compiled_optimizer":compiled_optimizer,
@@ -284,10 +282,6 @@ class AdaMuon_adv(torch.optim.Optimizer):
  nesterov = group['nesterov']
  Simplified_AdEMAMix = group['Simplified_AdEMAMix']
  alpha_grad = group['alpha_grad']
- if grad.dtype != torch.float32 and state.get('factored', False):
- grad = grad.float()
- if group.get("orthogonal_gradient"):
- grad = _orthogonalize_gradient(p, grad)

  if state['factored']: # Factored AdaMuon

@@ -351,7 +345,11 @@ class AdaMuon_adv(torch.optim.Optimizer):
  mean_squared_update = torch.mean(update.square(), dim=1)
  v_t.mul_(beta2).add_(mean_squared_update, alpha=1 - beta2)
  # Normalize update
- update.div_(v_t.sqrt().unsqueeze(1).add_(group['eps']))
+ if group['use_atan2']:
+ a = 1.2732395
+ update.atan2_(v_t.sqrt().unsqueeze(1)).mul_(a)
+ else:
+ update.div_(v_t.sqrt().unsqueeze(1).add_(group['eps']))
  # Scale learning rate
  update_norm = torch.linalg.vector_norm(update)
  scaled_lr = group['rms_target'] * lr * (p.numel()**0.5) / update_norm.add_(group['eps'])
@@ -456,7 +454,11 @@ class AdaMuon_adv(torch.optim.Optimizer):
  mean_squared_update = torch.mean(update.square(), dim=1)
  v_t.mul_(beta2).add_(mean_squared_update, alpha=1 - beta2)
  # Normalize update
- update.div_(v_t.sqrt().unsqueeze(1).add_(group['eps']))
+ if group['use_atan2']:
+ a = 1.2732395
+ update.atan2_(v_t.sqrt().unsqueeze(1)).mul_(a)
+ else:
+ update.div_(v_t.sqrt().unsqueeze(1).add_(group['eps']))
  # Scale learning rate
  update_norm = torch.linalg.vector_norm(update)
  scaled_lr = group['rms_target'] * lr * (p.numel()**0.5) / update_norm.add_(group['eps'])
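> Editor's note: the new `use_atan2` branch above replaces the eps-guarded division with a bounded atan2 form (the constant 1.2732395 is ≈ 4/π). A standalone comparison of the two normalizations, written against plain tensors rather than the optimizer's internal state:

```python
import torch

def normalize_eps(m: torch.Tensor, v: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    return m / (v.sqrt() + eps)

def normalize_atan2(m: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # Bounded in (-2, 2): proportional to m / sqrt(v) when |m| << sqrt(v),
    # and saturating instead of dividing by eps when the second moment collapses.
    return torch.atan2(m, v.sqrt()) * 1.2732395

m = torch.tensor([1e-3, 1e-3])
v = torch.tensor([1e-2, 1e-20])     # second entry: near-zero second moment
print(normalize_eps(m, v))          # ~[1.0e-2, 9.9e+4]  -> the second entry spikes
print(normalize_atan2(m, v))        # ~[1.3e-2, 2.0]     -> the second entry saturates
```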
@@ -49,14 +49,12 @@ class AdamW_adv(torch.optim.Optimizer):
  before it is added to the fast momentum term (`update = mt + alpha * mt_slow`).
  A higher value increases the stabilizing influence of the slow
  momentum. (default: 5.0)
- t_alpha (Optional[int]): The number of steps for a linear warmup of the
- `alpha` parameter (only used when `use_AdEMAMix` is `True`). This is
- highly recommended to prevent instability at the beginning of training,
- as it gradually introduces the stabilizing slow momentum term. During
- the warmup, `alpha` ramps from 0 to its target value. If `None`,
- the scheduler is disabled. (default: None)
  kourkoutas_beta (bool): whether to enable the layer-wise dynamic β₂ logic.
  If `False`, the optimizer behaves as standard AdamW. (default: False)
+ layer_key_fn (Optional[Callable]): A function that takes a parameter `p`
+ and returns a unique, hashable key representing its "layer" or "bucket".
+ If `None`, parameters are bucketed by their shape.
+ (default: None)
  beta2_min (float): The minimum value for dynamic β₂, used during periods of
  high gradient variance ("sunspikes"). Must be less than `betas[1]`.
  (default: 0.88)
@@ -72,11 +70,7 @@ class AdamW_adv(torch.optim.Optimizer):
  k_logging (int): if > 0 and kourkoutas_beta=True, enables periodic console
  logging of Kourkoutas-β statistics (min, max, mean of `β₂` across layers)
  every logging steps. Useful for debugging and tuning. Set to 0 to disable
- logging (default: 0).
- layer_key_fn (Optional[Callable]): A function that takes a parameter `p`
- and returns a unique, hashable key representing its "layer" or "bucket".
- If `None`, parameters are bucketed by their memory ID (tensor-wise).
- (default: None)
+ logging (default: 0).
  nnmf_factor (bool): whether to use the factorization or disable it to use
  the uncompressed optimizer. (default: False)
  """
@@ -89,7 +83,7 @@ class AdamW_adv(torch.optim.Optimizer):
  eps: float = 1e-8,
  weight_decay: float = 0.0,
  use_bias_correction: bool = True,
- vector_reshape: bool = True,
+ vector_reshape: bool = False,
  stochastic_rounding: bool = True,
  use_atan2: bool = False,
  cautious_mask: bool = False,
@@ -98,15 +92,16 @@ class AdamW_adv(torch.optim.Optimizer):
  use_AdEMAMix: bool = False,
  beta3_ema: float = 0.9999,
  alpha: float = 5.0,
- t_alpha: int | None = None,
  kourkoutas_beta: bool = False,
+ layer_key_fn: Optional[Callable] = None,
  beta2_min: float = 0.9,
  ema_alpha: float = 0.95,
  tiny_spike: float = 1e-9,
  k_warmup_steps: int = 0,
  k_logging: int = 0,
- layer_key_fn: Optional[Callable] = None,
  nnmf_factor: bool = False,
+ # Compiled
+ compiled_optimizer: bool = False,
  ):
  if not (lr >= 0.0):
  raise ValueError(f"Learning-rate should be >= 0.0. Got {lr}")
@@ -116,7 +111,8 @@ class AdamW_adv(torch.optim.Optimizer):
  raise ValueError(f"Epsilon should be >= 0.0. Got {eps}")
  if not (weight_decay >= 0.0):
  raise ValueError(f"Weight-decay should be >= 0.0. Got {weight_decay}")
- if kourkoutas_beta and not (betas[1] > beta2_min): raise ValueError(f"For Kourkoutas-β, betas[1] (as beta2_max) must be > beta2_min. Got {betas[1]} and {beta2_min}")
+ if kourkoutas_beta and not (betas[1] > beta2_min):
+ raise ValueError(f"For Kourkoutas-β, betas[1] (as beta2_max) must be > beta2_min. Got {betas[1]} and {beta2_min}")

  if cautious_mask and grams_moment:
  print("Warning: cautious is incompatible with grams, Disabling cautious.")
@@ -126,22 +122,30 @@ class AdamW_adv(torch.optim.Optimizer):
  "lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay,
  "vector_reshape": vector_reshape, "use_atan2": use_atan2,
  "orthogonal_gradient": orthogonal_gradient, "use_bias_correction": use_bias_correction,
- "beta3_ema": beta3_ema, "alpha": alpha, "t_alpha": t_alpha,
+ "beta3_ema": beta3_ema, "alpha": alpha,
  "kourkoutas_beta": kourkoutas_beta, "beta2_min": beta2_min, "ema_alpha": ema_alpha,
  "tiny_spike": tiny_spike, "k_warmup_steps": k_warmup_steps, "k_logging": k_logging,
+ "compiled_optimizer": compiled_optimizer,
  }
  self.stochastic_rounding = stochastic_rounding
  self.cautious_mask = cautious_mask
  self.grams_moment = grams_moment
  self.use_AdEMAMix = use_AdEMAMix
  self.factored = nnmf_factor
- self.kourkoutas_beta = kourkoutas_beta
  self.layer_key_fn = layer_key_fn
+ self.kourkoutas_beta = kourkoutas_beta
+
  super().__init__(params, defaults)

+ self.init_step()
+
  if self.kourkoutas_beta:
  self.kourkoutas_helper = KourkoutasHelper(self)

+ if compiled_optimizer:
+ torch._dynamo.config.cache_size_limit = 8192
+ self.compile(fullgraph=True)
+
  @property
  def supports_fused_back_pass(self):
  return True
@@ -154,29 +158,24 @@ class AdamW_adv(torch.optim.Optimizer):
  def supports_flat_params(self):
  return False

- @torch.no_grad()
- def step_parameter(self, p: torch.Tensor, group: dict, i: int | None = None):
- if p.grad is None:
- return
+ def init_step(self):
+ for group in self.param_groups:
+ for p in group['params']:
+ self.__init_state(p, group)

- grad = p.grad
- if grad.dtype != torch.float32 and self.factored:
- grad = grad.float()
- if group["orthogonal_gradient"]:
- grad = _orthogonalize_gradient(p, grad)
+ @torch.no_grad()
+ def __init_state(self, p, group):
  state = self.state[p]

- # State Initialization
- if 'step' not in state:
+ if len(state) == 0:
+
  state['step'] = 0

- should_factor = (
+ state['factored'] = (
  self.factored and
  not (len(p.shape) == 1 and not group['vector_reshape'])
  )

- state['factored'] = should_factor
-
  dtype = torch.float32 if self.factored else p.dtype
  device = p.device

@@ -186,18 +185,18 @@ class AdamW_adv(torch.optim.Optimizer):

  # First moment (m)
  if group['betas'][0] > 0:
- state['mu_m_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
+ state['mu_m_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
  state['mv_m_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
  if not self.grams_moment:
  packed_d2 = (d2 + 7) // 8
  state['sign'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=device)
  if self.use_AdEMAMix:
- state['mu_m_slow_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
+ state['mu_m_slow_nmf'] = torch.zeros(d1, device=p.device, dtype=dtype)
  state['mv_m_slow_nmf'] = torch.zeros(d2, device=p.device, dtype=dtype)
  packed_d2 = (d2 + 7) // 8
  state['sign_slow'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=p.device)
  # Second moment (v)
- state['mu_v_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
+ state['mu_v_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
  state['mv_v_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
  else: # Fallback to standard AdamW for non-factored tensors
  if group['betas'][0] > 0:
@@ -206,37 +205,32 @@ class AdamW_adv(torch.optim.Optimizer):
  state['exp_avg_slow'] = torch.zeros_like(p, device=device, dtype=dtype)
  state['exp_avg_sq'] = torch.zeros_like(p, device=device, dtype=dtype)

+ @torch.no_grad()
+ def __step_parameter(self, p: torch.Tensor, group: dict, lr: torch.Tensor | float, bias_correction1: torch.Tensor | float, bias_correction2: torch.Tensor | float):
+ if p.grad is None:
+ return
+
+ grad = p.grad
+ if grad.dtype != torch.float32 and self.factored:
+ grad = grad.float()
+ if group["orthogonal_gradient"]:
+ grad = _orthogonalize_gradient(p, grad)
+ state = self.state[p]
+
+
  beta1, beta2 = group['betas']

- current_step = state['step']
  if group.get('kourkoutas_beta', False):
- # Call prepare_step() once at the beginning of the step for all params
- self.kourkoutas_helper.maybe_prepare_step(current_step)
  # Accumulate current grad's norm for the *next* step
  self.kourkoutas_helper.accumulate_gradient_sq_norm(p, grad)
  # Get the dynamic beta2 calculated in prepare_step()
- beta2 = self.kourkoutas_helper.get_beta2(p, group, current_step)
+ beta2 = self.kourkoutas_helper.get_beta2(p, group)

- step = state['step'] + 1
- if group['use_bias_correction']:
- bias_correction1 = 1.0 - beta1 ** step
- if group.get('kourkoutas_beta', False):
- bias_correction2 = 1.0 - group['betas'][1] ** step
- # Use beta2_max for bias correction
- else:
- bias_correction2 = 1.0 - beta2 ** step
- else:
- bias_correction1 = 1
- bias_correction2 = 1
- step_size = group['lr'] / bias_correction1
+ step_size = lr / bias_correction1

  if self.use_AdEMAMix:
  beta3_ema = group['beta3_ema']
  alpha = group['alpha']
- t_alpha = group['t_alpha']
- alpha_t = alpha
- if t_alpha is not None and t_alpha > 0 and step < t_alpha:
- alpha_t = min(step * alpha / t_alpha, alpha)

  if state['factored']:
  d1, d2 = state['effective_shape']
@@ -272,9 +266,9 @@ class AdamW_adv(torch.optim.Optimizer):

  mt_slow.mul_(beta3_ema).add_(grad_reshaped, alpha=1.0 - beta3_ema)
  if beta1 > 0:
- update = torch.add(mt, mt_slow, alpha=alpha_t)
+ update = torch.add(mt, mt_slow, alpha=alpha)
  else:
- update = torch.add(grad_reshaped, mt_slow, alpha=alpha_t)
+ update = torch.add(grad_reshaped, mt_slow, alpha=alpha)
  else:
  update = mt.clone() if beta1 > 0 else grad_reshaped.clone()
  del grad_reshaped
@@ -321,9 +315,9 @@ class AdamW_adv(torch.optim.Optimizer):
  exp_avg_slow = state['exp_avg_slow']
  exp_avg_slow.mul_(beta3_ema).add_(grad, alpha=1 - beta3_ema)
  if beta1 > 0:
- update = torch.add(exp_avg, exp_avg_slow, alpha=alpha_t)
+ update = torch.add(exp_avg, exp_avg_slow, alpha=alpha)
  else:
- update = torch.add(grad, exp_avg_slow, alpha=alpha_t)
+ update = torch.add(grad, exp_avg_slow, alpha=alpha)
  else:
  update = exp_avg.clone() if beta1 > 0 else grad.clone()

@@ -343,9 +337,9 @@ class AdamW_adv(torch.optim.Optimizer):
  # Decoupled weight decay
  if group["weight_decay"] != 0:
  if p.dtype == torch.bfloat16 and self.stochastic_rounding:
- add_stochastic_(p.data, p.data, alpha=-group["weight_decay"] * group["lr"])
+ add_stochastic_(p.data, p.data, alpha=-group["weight_decay"] * lr)
  else:
- p.data.add_(p.data, alpha=-group["weight_decay"] * group["lr"])
+ p.data.add_(p.data, alpha=-group["weight_decay"] * lr)

  if p.dtype == torch.bfloat16 and self.stochastic_rounding:
  add_stochastic_(p.data, -update)
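> Editor's note: both branches above funnel BF16 parameter updates through `add_stochastic_`. For context, here is the widely used bit-manipulation form of stochastically rounded accumulation; it is a conceptual stand-in, not necessarily what `adv_optm/util/BF16_Stochastic_Rounding.py` does.

```python
import torch

def add_stochastic_sketch_(target: torch.Tensor, other: torch.Tensor, alpha: float = 1.0) -> None:
    """target (bf16) += alpha * other, rounded stochastically instead of to-nearest.

    Accumulate in fp32, add random noise below one bf16 ulp to the raw bit pattern,
    then truncate the low 16 mantissa bits: tiny updates survive in expectation
    instead of being systematically rounded away.
    """
    acc = target.float().add_(other.float(), alpha=alpha)        # fp32 accumulation
    bits = acc.view(torch.int32)                                 # reinterpret the bit pattern
    bits.add_(torch.randint_like(bits, 0, 1 << 16))              # random low 16 bits
    bits.bitwise_and_(-65536)                                    # keep only the upper 16 bits
    target.copy_(bits.view(torch.float32))                       # exact cast back to bf16

p = torch.full((4,), 1.0, dtype=torch.bfloat16)
add_stochastic_sketch_(p, torch.full((4,), 1e-3))  # each entry rounds up with prob ~1e-3 / 2**-7
```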
@@ -353,7 +347,40 @@ class AdamW_adv(torch.optim.Optimizer):
  p.data.add_(-update)
  del update

- state['step'] += 1
+ @torch.no_grad()
+ def step_parameter(self, p: torch.Tensor, group: dict, i: int | None = None):
+
+ state = self.state[p]
+
+ step = state['step']
+
+ if group['use_bias_correction']:
+ current_step = step + 1
+ beta1, beta2 = group['betas']
+ bias_correction1 = 1.0 - beta1 ** current_step
+ bias_correction2 = 1.0 - beta2 ** current_step
+ else:
+ bias_correction1 = 1.0
+ bias_correction2 = 1.0
+
+ if group.get('kourkoutas_beta', False):
+ # Prepare Kourkoutas-β once per step using the global step counter.
+ self.kourkoutas_helper.maybe_prepare_step(step)
+
+ self.state[p]['step'] += 1
+
+ if not group.get('compiled_optimizer', False):
+ self.__step_parameter(p, group, group['lr'], bias_correction1, bias_correction2)
+ else:
+ if not hasattr(self, 'lr_tensor') or self.lr_tensor is None:
+ # convert to tensors for compiled path once a step
+ self.lr_tensor = torch.tensor(group['lr'], device=p.device)
+ self.bc1_tensor = torch.tensor(bias_correction1, device=p.device)
+ self.bc2_tensor = torch.tensor(bias_correction2, device=p.device)
+ self._compiled_step_parameter(p, group, self.lr_tensor, self.bc1_tensor, self.bc2_tensor)
+
+ def compile(self, *args, **kwargs):
+ self._compiled_step_parameter = torch.compile(self.__step_parameter, *args, **kwargs)

  @torch.no_grad()
  def step(self, closure=None):
@@ -367,4 +394,7 @@ class AdamW_adv(torch.optim.Optimizer):
  for i, p in enumerate(group['params']):
  self.step_parameter(p, group, i)

+ if self.param_groups[0].get('compiled_optimizer', False):
+ # Reset compile tensors once a step
+ self.lr_tensor = None
  return loss
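> Editor's note: the new compiled path above passes `lr` and the bias corrections as tensors so that, per the inline comment, changing their values does not retrigger compilation. A minimal, self-contained sketch of that pattern (toy update and hypothetical names, not the package's `__step_parameter`):

```python
import torch

def _toy_step(p: torch.Tensor, grad: torch.Tensor, exp_avg: torch.Tensor,
              lr: torch.Tensor, bc1: torch.Tensor) -> None:
    # lr and bc1 arrive as 0-d tensors; a Python float here would be baked into the
    # compiled graph and force a recompile whenever its value changed.
    exp_avg.mul_(0.9).add_(grad, alpha=0.1)
    p.add_(exp_avg * (-lr / bc1))

compiled_step = torch.compile(_toy_step, fullgraph=True)

p = torch.randn(4, 4)
exp_avg = torch.zeros_like(p)
for step in range(1, 4):
    grad = torch.randn_like(p)
    lr = torch.tensor(1e-3)                   # could follow any schedule
    bc1 = torch.tensor(1.0 - 0.9 ** step)     # bias correction, recomputed every step
    compiled_step(p, grad, exp_avg, lr, bc1)  # compiled once, reused for every step
```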