adv-optm 2.4.dev20__tar.gz → 2.4.dev21__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. {adv_optm-2.4.dev20 → adv_optm-2.4.dev21}/PKG-INFO +12 -59
  2. {adv_optm-2.4.dev20 → adv_optm-2.4.dev21}/README.md +10 -57
  3. {adv_optm-2.4.dev20 → adv_optm-2.4.dev21}/adv_optm/__init__.py +1 -5
  4. {adv_optm-2.4.dev20 → adv_optm-2.4.dev21}/adv_optm/optim/AdaMuon_adv.py +47 -66
  5. {adv_optm-2.4.dev20 → adv_optm-2.4.dev21}/adv_optm/optim/AdamW_adv.py +22 -98
  6. {adv_optm-2.4.dev20 → adv_optm-2.4.dev21}/adv_optm/optim/Adopt_adv.py +28 -103
  7. {adv_optm-2.4.dev20 → adv_optm-2.4.dev21}/adv_optm/optim/Lion_adv.py +4 -26
  8. {adv_optm-2.4.dev20 → adv_optm-2.4.dev21}/adv_optm/optim/Muon_adv.py +39 -58
  9. {adv_optm-2.4.dev20 → adv_optm-2.4.dev21}/adv_optm/optim/Prodigy_adv.py +29 -96
  10. {adv_optm-2.4.dev20 → adv_optm-2.4.dev21}/adv_optm/optim/SignSGD_adv.py +55 -21
  11. {adv_optm-2.4.dev20 → adv_optm-2.4.dev21}/adv_optm/optim/SinkSGD_adv.py +64 -52
  12. {adv_optm-2.4.dev20 → adv_optm-2.4.dev21}/adv_optm/optim/__init__.py +0 -4
  13. {adv_optm-2.4.dev20 → adv_optm-2.4.dev21}/adv_optm/util/Muon_AuxAdam.py +15 -60
  14. {adv_optm-2.4.dev20 → adv_optm-2.4.dev21}/adv_optm/util/Muon_util.py +22 -58
  15. {adv_optm-2.4.dev20 → adv_optm-2.4.dev21}/adv_optm/util/centered_decay.py +0 -6
  16. {adv_optm-2.4.dev20 → adv_optm-2.4.dev21}/adv_optm/util/factorization_util.py +9 -10
  17. {adv_optm-2.4.dev20 → adv_optm-2.4.dev21}/adv_optm/util/param_update.py +47 -10
  18. adv_optm-2.4.dev21/adv_optm/util/signed_util.py +56 -0
  19. {adv_optm-2.4.dev20 → adv_optm-2.4.dev21}/adv_optm/util/state_util.py +12 -14
  20. {adv_optm-2.4.dev20 → adv_optm-2.4.dev21}/adv_optm/util/update_util.py +0 -32
  21. {adv_optm-2.4.dev20 → adv_optm-2.4.dev21}/adv_optm.egg-info/PKG-INFO +12 -59
  22. {adv_optm-2.4.dev20 → adv_optm-2.4.dev21}/adv_optm.egg-info/SOURCES.txt +0 -2
  23. adv_optm-2.4.dev21/adv_optm.egg-info/requires.txt +1 -0
  24. {adv_optm-2.4.dev20 → adv_optm-2.4.dev21}/setup.py +2 -2
  25. adv_optm-2.4.dev20/adv_optm/optim/Lion_Prodigy_adv.py +0 -389
  26. adv_optm-2.4.dev20/adv_optm/optim/Simplified_AdEMAMix.py +0 -384
  27. adv_optm-2.4.dev20/adv_optm/util/signed_util.py +0 -56
  28. adv_optm-2.4.dev20/adv_optm.egg-info/requires.txt +0 -1
  29. {adv_optm-2.4.dev20 → adv_optm-2.4.dev21}/LICENSE +0 -0
  30. {adv_optm-2.4.dev20 → adv_optm-2.4.dev21}/adv_optm/util/Kourkoutas.py +0 -0
  31. {adv_optm-2.4.dev20 → adv_optm-2.4.dev21}/adv_optm/util/OrthoGrad.py +0 -0
  32. {adv_optm-2.4.dev20 → adv_optm-2.4.dev21}/adv_optm/util/__init__.py +0 -0
  33. {adv_optm-2.4.dev20 → adv_optm-2.4.dev21}/adv_optm/util/lion_k.py +0 -0
  34. {adv_optm-2.4.dev20 → adv_optm-2.4.dev21}/adv_optm/util/scaled_optm.py +0 -0
  35. {adv_optm-2.4.dev20 → adv_optm-2.4.dev21}/adv_optm/util/sinkhorn.py +0 -0
  36. {adv_optm-2.4.dev20 → adv_optm-2.4.dev21}/adv_optm.egg-info/dependency_links.txt +0 -0
  37. {adv_optm-2.4.dev20 → adv_optm-2.4.dev21}/adv_optm.egg-info/top_level.txt +0 -0
  38. {adv_optm-2.4.dev20 → adv_optm-2.4.dev21}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: adv_optm
3
- Version: 2.4.dev20
3
+ Version: 2.4.dev21
4
4
  Summary: A family of highly efficient, lightweight yet powerful optimizers.
5
5
  Home-page: https://github.com/Koratahiu/Advanced_Optimizers
6
6
  Author: Koratahiu
@@ -15,7 +15,7 @@ Classifier: Topic :: Software Development :: Libraries :: Python Modules
15
15
  Requires-Python: >=3.8
16
16
  Description-Content-Type: text/markdown
17
17
  License-File: LICENSE
18
- Requires-Dist: torch>=2.0
18
+ Requires-Dist: torch>=2.1
19
19
  Dynamic: author
20
20
  Dynamic: author-email
21
21
  Dynamic: classifier
@@ -37,10 +37,6 @@ A comprehensive, all-in-one collection of optimization algorithms for deep learn
37
37
 
38
38
  ## 🔥 What's New
39
39
 
40
- ### in 2.2.2
41
-
42
- - `Simplified_AdEMAMix` now uses the same LR as AdamW for all `beta1` and `alpha_grad` values!
43
-
44
40
  ### in 2.1.x
45
41
 
46
42
  - Added Signum (SignSGD with momentum): A new optimizer in the family (SignSGD_adv)
@@ -101,7 +97,6 @@ This library integrates multiple state-of-the-art optimization techniques valida
101
97
  |-----------|--------------|-------------|
102
98
  | `Adopt_Factored` | 328 MB | 4 small vectors + 1-bit state |
103
99
  | `Adopt_Factored + AdEMAMix` | 625 MB | 6 small vectors + two 1-bit states |
104
- | `Simplified_AdEMAMix` | 328 MB | Same as standard factored (no extra state) |
105
100
 
106
101
  ### Speed Comparison (SDXL, Batch Size 4)
107
102
  | Optimizer | Speed | Notes |
@@ -120,7 +115,6 @@ This library integrates multiple state-of-the-art optimization techniques valida
120
115
  | `Adam_Adv` | Advanced Adam implementation | General purpose |
121
116
  | `Adopt_Adv` | Adam-variant with independent beta2 | Stable training for small batch size regimes |
122
117
  | `Prodigy_Adv` | Prodigy with D-Adaptation | Adam with automatic LR tuning |
123
- | `Simplified_AdEMAMix` | Adam variant with accumulator momentum | Small/large batch training when tuned correctly |
124
118
  | `Lion_Adv` | Advanced Lion implementation | Memory-constrained environments |
125
119
  | `Prodigy_Lion_Adv` | Prodigy + Lion combination | Lion with automatic LR tuning |
126
120
 
@@ -128,18 +122,14 @@ This library integrates multiple state-of-the-art optimization techniques valida
128
122
 
129
123
  ## ⚙️ Feature Matrix
130
124
 
131
- | Feature | Adam_Adv | Adopt_Adv | Prodigy_Adv | Simplified_AdEMAMix | Lion_Adv |
132
- |---------|----------|-----------|-------------|---------------------|----------|
133
- | Factored | ✓ | ✓ | ✓ | ✓ | ✓ |
134
- | AdEMAMix | ✓ | ✓ | ✓ | | ✗ |
135
- | Simplified_AdEMAMix | ✗ | ✓ | ✓ | ✓ | |
136
- | OrthoGrad | | ✓ | ✓ | ✓ | |
137
- | Grams | | ✓ | ✓ | | |
138
- | Cautious | ✓ | ✓ | ✓ | ✗ | ✓ |
139
- | atan2 | ✓ | ✓ | ✓ | ✗ | ✗ |
140
- | Stochastic Rounding | ✓ | ✓ | ✓ | ✓ | ✓ |
141
- | Fused Backward Pass | ✓ | ✓ | ✓ | ✓ | ✓ |
142
- | **Kourkoutas-β** | ✓ | ✓ | ✓ | ✓ | ✗ |
125
+ | Feature | Adam_Adv | Adopt_Adv | Prodigy_Adv | Lion_Adv |
126
+ |---------|----------|-----------|-------------|----------|
127
+ | Factored | ✓ | ✓ | ✓ | ✓ |
128
+ | OrthoGrad | ✓ | ✓ | ✓ | |
129
+ | atan2 | ✓ | ✓ | ✓ |✗ |
130
+ | Stochastic Rounding | ✓ | ✓ | ✓ |✓ |
131
+ | Fused Backward Pass | ✓ | ✓ | | |
132
+ | **Kourkoutas-β** | ✓ | ✓ | ✓ | ✗ |
143
133
 
144
134
  ---
145
135
 
@@ -159,48 +149,13 @@ This library integrates multiple state-of-the-art optimization techniques valida
159
149
 
160
150
  | Feature | Description | Recommended Usage | Performance Impact | Theoretical Basis | Compatibility |
161
151
  |--------|-------------|-------------------|--------------------|-------------------|--------------|
162
- | **Cautious** | Only applies update if gradient direction aligns with momentum direction | Accelerating convergence | No overhead | [C-Optim](https://github.com/kyleliang919/C-Optim) | Adam/Adopt/Prodigy/Lion |
163
- | **Grams** | Update direction derived purely from current gradient | When Cautious is insufficient | No overhead | [Grams](https://github.com/Gunale0926/Grams) | Adam/Adopt/Prodigy |
164
- | **AdEMAMix** | Dual EMA system that retains relevance of gradients over tens of thousands of steps | Long training runs, especially where model forgetting is a concern | +1 state memory | [AdEMAMix](https://arxiv.org/abs/2409.03137) | Adam/Adopt/Prodigy |
165
- | **Simplified_AdEMAMix** | Accumulator-based momentum, single EMA variant of AdEMAMix | All scenarios when tuned correctly | No overhead | [Connections](https://arxiv.org/abs/2502.02431) | Adam/Adopt/Prodigy |
166
152
  | **atan2** | Robust epsilon replacement with built-in gradient clipping | Use for stable bounded updates (or for Adopt as it needs that) | No overhead | [Adam-atan2](https://github.com/lucidrains/adam-atan2-pytorch) | Adam/Adopt/Prodigy |
167
- | **Kourkoutas-β** | Layer-wise adaptive β₂ based on gradient “sunspike” ratio | Noisy/small/large-batch/high-LR training | No overhead | [Kourkoutas-β]() | Adam/Adopt/Prodigy/Simplified_AdEMAMix |
168
-
169
- > **Note**: If both **Cautious** and **Grams** are enabled, **Grams takes precedence** and Cautious is disabled.
153
+ | **Kourkoutas-β** | Layer-wise adaptive β₂ based on gradient “sunspike” ratio | Noisy/small/large-batch/high-LR training | No overhead | [Kourkoutas-β](https://arxiv.org/abs/2508.12996) | Adam/Adopt/Prodigy |
170
154
 
171
155
  ---
172
156
 
173
157
  ## 🔍 Feature Deep Dives
174
158
 
175
- ### AdEMAMix
176
-
177
- - Adds a **slow-decaying second EMA** (`beta3`) that retains gradient memory over tens of thousands of steps.
178
- - Particularly effective for **small batch sizes**, where Adam’s standard first moment is nearly useless.
179
-
180
- #### Tunable Hyperparameters
181
- | Parameter | Default | Tuning Guide |
182
- |-----------|---------|--------------|
183
- | `beta3` | 0.9999 | • Runs >120k steps: **0.9999**<br>• Runs ≤120k steps: **0.999** |
184
- | `alpha` | 5 | • Reduce to **2–3** if diverging<br>• Increase to strengthen long-term memory |
185
-
186
- > ✅ **Pro Tip**: Set `beta1=0` in Adam/Adopt/Prodigy to skip standard EMA entirely and rely solely on AdEMAMix’s slow EMA, ideal for small-batch regimes.
187
-
188
- ---
189
-
190
- ### Simplified_AdEMAMix
191
-
192
- - Introduced in [Connections between Schedule-Free Optimizers, AdEMAMix, and Accelerated SGD Variants (arXiv:2502.02431)](https://arxiv.org/abs/2502.02431).
193
- - Replaces Adam’s first moment with a **theory-based momentum** with emphasize on raw gradient, combining the stability of long memory with responsiveness to recent gradients.
194
- - **Key insight**: Classical momentum **does not accelerate** in noisy (small-batch) regimes; this accumulator do.
195
-
196
- #### Tunable Hyperparameters
197
- | Parameter | Default | Tuning Guide |
198
- |----------|---------|--------------|
199
- | `beta1` | 0.99 | Controls accumulator memory length:<br>• Small BS: **0.99–0.9999**<br>• Large BS: **0.9** |
200
- | `Grad α` | 100 | Most critical parameter:<br>• Inversely scales with batch size<br>• **100–10** for small BS (≤32)<br>• **1–0.1** for large BS (≥512) |
201
-
202
- ---
203
-
204
159
  ### atan2
205
160
 
206
161
  - Replaces `eps` in Adam-family optimizers with a **scale-invariant**, bounded update rule.
@@ -215,7 +170,7 @@ This library integrates multiple state-of-the-art optimization techniques valida
215
170
 
216
171
  ### **Kourkoutas-β**
217
172
 
218
- **Kourkoutas-β** introduces a **sunspike-driven, layer-wise adaptive second-moment decay (β₂)** as an optional enhancement for `Adam_Adv`, `Adopt_Adv`, `Prodigy_Adv`, and `Simplified_AdEMAMix`.
173
+ **Kourkoutas-β** introduces a **sunspike-driven, layer-wise adaptive second-moment decay (β₂)** as an optional enhancement for `Adam_Adv`, `Adopt_Adv`, and `Prodigy_Adv`.
219
174
 
220
175
  Instead of using a fixed β₂ (e.g., 0.999 or 0.95), it **dynamically modulates β₂ per layer** based on a bounded *sunspike ratio*:
221
176
 
@@ -243,7 +198,5 @@ This is especially effective for **noisy training, small batch sizes, and high l
243
198
 
244
199
  1. [Revisiting BFloat16 Training](https://arxiv.org/abs/2010.06192)
245
200
  2. [SMMF: Square-Matricized Momentum Factorization](https://arxiv.org/abs/2412.08894)
246
- 3. [The AdEMAMix Optimizer](https://arxiv.org/abs/2409.03137)
247
- 4. [Connections between Schedule-Free Optimizers, AdEMAMix, and Accelerated SGD](https://arxiv.org/abs/2502.02431)
248
201
  6. [Kourkoutas-β: A Sunspike-Driven Adam Optimizer with Desert Flair](https://arxiv.org/abs/2508.12996)
249
202
  7. [Scaling Exponents Across Parameterizations and Optimizers](https://arxiv.org/abs/2407.05872)
@@ -6,10 +6,6 @@ A comprehensive, all-in-one collection of optimization algorithms for deep learn
6
6
 
7
7
  ## 🔥 What's New
8
8
 
9
- ### in 2.2.2
10
-
11
- - `Simplified_AdEMAMix` now uses the same LR as AdamW for all `beta1` and `alpha_grad` values!
12
-
13
9
  ### in 2.1.x
14
10
 
15
11
  - Added Signum (SignSGD with momentum): A new optimizer in the family (SignSGD_adv)
@@ -70,7 +66,6 @@ This library integrates multiple state-of-the-art optimization techniques valida
70
66
  |-----------|--------------|-------------|
71
67
  | `Adopt_Factored` | 328 MB | 4 small vectors + 1-bit state |
72
68
  | `Adopt_Factored + AdEMAMix` | 625 MB | 6 small vectors + two 1-bit states |
73
- | `Simplified_AdEMAMix` | 328 MB | Same as standard factored (no extra state) |
74
69
 
75
70
  ### Speed Comparison (SDXL, Batch Size 4)
76
71
  | Optimizer | Speed | Notes |
@@ -89,7 +84,6 @@ This library integrates multiple state-of-the-art optimization techniques valida
89
84
  | `Adam_Adv` | Advanced Adam implementation | General purpose |
90
85
  | `Adopt_Adv` | Adam-variant with independent beta2 | Stable training for small batch size regimes |
91
86
  | `Prodigy_Adv` | Prodigy with D-Adaptation | Adam with automatic LR tuning |
92
- | `Simplified_AdEMAMix` | Adam variant with accumulator momentum | Small/large batch training when tuned correctly |
93
87
  | `Lion_Adv` | Advanced Lion implementation | Memory-constrained environments |
94
88
  | `Prodigy_Lion_Adv` | Prodigy + Lion combination | Lion with automatic LR tuning |
95
89
 
@@ -97,18 +91,14 @@ This library integrates multiple state-of-the-art optimization techniques valida
97
91
 
98
92
  ## ⚙️ Feature Matrix
99
93
 
100
- | Feature | Adam_Adv | Adopt_Adv | Prodigy_Adv | Simplified_AdEMAMix | Lion_Adv |
101
- |---------|----------|-----------|-------------|---------------------|----------|
102
- | Factored | ✓ | ✓ | ✓ | ✓ | ✓ |
103
- | AdEMAMix | ✓ | ✓ | ✓ | | ✗ |
104
- | Simplified_AdEMAMix | ✗ | ✓ | ✓ | ✓ | |
105
- | OrthoGrad | | ✓ | ✓ | ✓ | |
106
- | Grams | | ✓ | ✓ | | |
107
- | Cautious | ✓ | ✓ | ✓ | ✗ | ✓ |
108
- | atan2 | ✓ | ✓ | ✓ | ✗ | ✗ |
109
- | Stochastic Rounding | ✓ | ✓ | ✓ | ✓ | ✓ |
110
- | Fused Backward Pass | ✓ | ✓ | ✓ | ✓ | ✓ |
111
- | **Kourkoutas-β** | ✓ | ✓ | ✓ | ✓ | ✗ |
94
+ | Feature | Adam_Adv | Adopt_Adv | Prodigy_Adv | Lion_Adv |
95
+ |---------|----------|-----------|-------------|----------|
96
+ | Factored | ✓ | ✓ | ✓ | ✓ |
97
+ | OrthoGrad | ✓ | ✓ | ✓ | |
98
+ | atan2 | ✓ | ✓ | ✓ |✗ |
99
+ | Stochastic Rounding | ✓ | ✓ | ✓ |✓ |
100
+ | Fused Backward Pass | ✓ | ✓ | | |
101
+ | **Kourkoutas-β** | ✓ | ✓ | ✓ | ✗ |
112
102
 
113
103
  ---
114
104
 
@@ -128,48 +118,13 @@ This library integrates multiple state-of-the-art optimization techniques valida
128
118
 
129
119
  | Feature | Description | Recommended Usage | Performance Impact | Theoretical Basis | Compatibility |
130
120
  |--------|-------------|-------------------|--------------------|-------------------|--------------|
131
- | **Cautious** | Only applies update if gradient direction aligns with momentum direction | Accelerating convergence | No overhead | [C-Optim](https://github.com/kyleliang919/C-Optim) | Adam/Adopt/Prodigy/Lion |
132
- | **Grams** | Update direction derived purely from current gradient | When Cautious is insufficient | No overhead | [Grams](https://github.com/Gunale0926/Grams) | Adam/Adopt/Prodigy |
133
- | **AdEMAMix** | Dual EMA system that retains relevance of gradients over tens of thousands of steps | Long training runs, especially where model forgetting is a concern | +1 state memory | [AdEMAMix](https://arxiv.org/abs/2409.03137) | Adam/Adopt/Prodigy |
134
- | **Simplified_AdEMAMix** | Accumulator-based momentum, single EMA variant of AdEMAMix | All scenarios when tuned correctly | No overhead | [Connections](https://arxiv.org/abs/2502.02431) | Adam/Adopt/Prodigy |
135
121
  | **atan2** | Robust epsilon replacement with built-in gradient clipping | Use for stable bounded updates (or for Adopt as it needs that) | No overhead | [Adam-atan2](https://github.com/lucidrains/adam-atan2-pytorch) | Adam/Adopt/Prodigy |
136
- | **Kourkoutas-β** | Layer-wise adaptive β₂ based on gradient “sunspike” ratio | Noisy/small/large-batch/high-LR training | No overhead | [Kourkoutas-β]() | Adam/Adopt/Prodigy/Simplified_AdEMAMix |
137
-
138
- > **Note**: If both **Cautious** and **Grams** are enabled, **Grams takes precedence** and Cautious is disabled.
122
+ | **Kourkoutas-β** | Layer-wise adaptive β₂ based on gradient “sunspike” ratio | Noisy/small/large-batch/high-LR training | No overhead | [Kourkoutas-β](https://arxiv.org/abs/2508.12996) | Adam/Adopt/Prodigy |
139
123
 
140
124
  ---
141
125
 
142
126
  ## 🔍 Feature Deep Dives
143
127
 
144
- ### AdEMAMix
145
-
146
- - Adds a **slow-decaying second EMA** (`beta3`) that retains gradient memory over tens of thousands of steps.
147
- - Particularly effective for **small batch sizes**, where Adam’s standard first moment is nearly useless.
148
-
149
- #### Tunable Hyperparameters
150
- | Parameter | Default | Tuning Guide |
151
- |-----------|---------|--------------|
152
- | `beta3` | 0.9999 | • Runs >120k steps: **0.9999**<br>• Runs ≤120k steps: **0.999** |
153
- | `alpha` | 5 | • Reduce to **2–3** if diverging<br>• Increase to strengthen long-term memory |
154
-
155
- > ✅ **Pro Tip**: Set `beta1=0` in Adam/Adopt/Prodigy to skip standard EMA entirely and rely solely on AdEMAMix’s slow EMA, ideal for small-batch regimes.
156
-
157
- ---
158
-
159
- ### Simplified_AdEMAMix
160
-
161
- - Introduced in [Connections between Schedule-Free Optimizers, AdEMAMix, and Accelerated SGD Variants (arXiv:2502.02431)](https://arxiv.org/abs/2502.02431).
162
- - Replaces Adam’s first moment with a **theory-based momentum** with emphasize on raw gradient, combining the stability of long memory with responsiveness to recent gradients.
163
- - **Key insight**: Classical momentum **does not accelerate** in noisy (small-batch) regimes; this accumulator do.
164
-
165
- #### Tunable Hyperparameters
166
- | Parameter | Default | Tuning Guide |
167
- |----------|---------|--------------|
168
- | `beta1` | 0.99 | Controls accumulator memory length:<br>• Small BS: **0.99–0.9999**<br>• Large BS: **0.9** |
169
- | `Grad α` | 100 | Most critical parameter:<br>• Inversely scales with batch size<br>• **100–10** for small BS (≤32)<br>• **1–0.1** for large BS (≥512) |
170
-
171
- ---
172
-
173
128
  ### atan2
174
129
 
175
130
  - Replaces `eps` in Adam-family optimizers with a **scale-invariant**, bounded update rule.
@@ -184,7 +139,7 @@ This library integrates multiple state-of-the-art optimization techniques valida
184
139
 
185
140
  ### **Kourkoutas-β**
186
141
 
187
- **Kourkoutas-β** introduces a **sunspike-driven, layer-wise adaptive second-moment decay (β₂)** as an optional enhancement for `Adam_Adv`, `Adopt_Adv`, `Prodigy_Adv`, and `Simplified_AdEMAMix`.
142
+ **Kourkoutas-β** introduces a **sunspike-driven, layer-wise adaptive second-moment decay (β₂)** as an optional enhancement for `Adam_Adv`, `Adopt_Adv`, and `Prodigy_Adv`.
188
143
 
189
144
  Instead of using a fixed β₂ (e.g., 0.999 or 0.95), it **dynamically modulates β₂ per layer** based on a bounded *sunspike ratio*:
190
145
 
@@ -212,7 +167,5 @@ This is especially effective for **noisy training, small batch sizes, and high l
212
167
 
213
168
  1. [Revisiting BFloat16 Training](https://arxiv.org/abs/2010.06192)
214
169
  2. [SMMF: Square-Matricized Momentum Factorization](https://arxiv.org/abs/2412.08894)
215
- 3. [The AdEMAMix Optimizer](https://arxiv.org/abs/2409.03137)
216
- 4. [Connections between Schedule-Free Optimizers, AdEMAMix, and Accelerated SGD](https://arxiv.org/abs/2502.02431)
217
170
  6. [Kourkoutas-β: A Sunspike-Driven Adam Optimizer with Desert Flair](https://arxiv.org/abs/2508.12996)
218
171
  7. [Scaling Exponents Across Parameterizations and Optimizers](https://arxiv.org/abs/2407.05872)
@@ -2,9 +2,7 @@ from .optim import (
2
2
  AdamW_adv,
3
3
  Prodigy_adv,
4
4
  Adopt_adv,
5
- Simplified_AdEMAMix,
6
5
  Lion_adv,
7
- Lion_Prodigy_adv,
8
6
  Muon_adv,
9
7
  AdaMuon_adv,
10
8
  SignSGD_adv,
@@ -15,13 +13,11 @@ __all__ = [
15
13
  "AdamW_adv",
16
14
  "Prodigy_adv",
17
15
  "Adopt_adv",
18
- "Simplified_AdEMAMix",
19
16
  "Lion_adv",
20
- "Lion_Prodigy_adv",
21
17
  "Muon_adv",
22
18
  "AdaMuon_adv",
23
19
  "SignSGD_adv",
24
20
  "SinkSGD_adv",
25
21
  ]
26
22
 
27
- __version__ = "2.4.dev20"
23
+ __version__ = "2.4.dev21"
@@ -3,8 +3,8 @@ import torch
3
3
  import math
4
4
 
5
5
  from ..util import param_update
6
- from ..util.Muon_util import newton_schulz, _is_suitable_for_muon, rms_adjustment, normuon_update, approx_mars, _auto_projection_for_adamuon, get_spectral_scaling
7
- from ..util.scaled_optm import spectral_normalization, init_spectral_norm
6
+ from ..util.Muon_util import newton_schulz, _is_suitable_for_muon, rms_adjustment, normuon_update, approx_mars, _auto_projection_for_adamuon
7
+ from ..util.scaled_optm import spectral_normalization, init_spectral_norm, scale_eps
8
8
  from ..util.factorization_util import _get_effective_shape, _factorize_state, _reconstruct_state
9
9
  from ..util.OrthoGrad import _orthogonalize_gradient
10
10
  from ..util.Kourkoutas import KourkoutasHelper
@@ -50,7 +50,8 @@ class AdaMuon_adv(torch.optim.Optimizer):
50
50
  vector, used for RMS-aligned rescaling. Allows for the reuse of existing Adam
51
51
  learning rate schedules. (default: True).
52
52
  ns_steps (int): number of Newton-Schulz iterations to perform (default: 5).
53
- ns_eps (float): epsilon for Newton-Schulz normalization stability (default: 1e-7).
53
+ ns_eps (float): epsilon for Newton-Schulz normalization stability. When None
54
+ it is derived from a scale-invariant rule (default: 1e-7).
54
55
  ns_coeffs (tuple[float, float, float]): The (a, b, c) coefficients for the
55
56
  quintic polynomial in the Newton-Schulz iteration.
56
57
  (default: (3.4445, -4.7750, 2.0315)).
@@ -77,7 +78,8 @@ class AdaMuon_adv(torch.optim.Optimizer):
77
78
  (default: 128)
78
79
  accelerated_ns (bool): If True, enables Chebyshev-accelerated Newton-Schulz, which
79
80
  dynamically calculates optimal 3rd-order polynomial coefficients. (default: False)
80
- cns_a_bound (float): Initial lower bound for singular values for CANS. (default: 1e-4)
81
+ cns_a_bound (float): Initial lower bound for singular values for CANS. When None
82
+ it is derived from a scale-invariant rule (default: None).
81
83
  approx_mars (bool): If True, enables Approximated MARS-M variance reduction.
82
84
  from the paper "MARS-M: When Variance Reduction Meets Matrices"
83
85
  (default: False)
@@ -112,12 +114,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
112
114
  adam_fisher_wd (bool): Fisher Adam (FAdam) weight decay for the AdamW part. (default: False)
113
115
  adam_use_bias_correction (bool): Bias correction for AdamW.
114
116
  adam_use_atan2 (bool): Atan2 update rule for AdamW.
115
- adam_cautious_mask (bool): Cautious masking for AdamW.
116
- adam_grams_moment (bool): Grams-style updates for AdamW.
117
117
  adam_orthogonal_gradient (bool): OrthoGrad for AdamW.
118
- adam_use_AdEMAMix (bool): AdEMAMix for AdamW.
119
- adam_beta3_ema (float): Beta3 for AdEMAMix.
120
- adam_alpha (float): Alpha for AdEMAMix.
121
118
  adam_nesterov (bool): Nesterov momentum for AdamW. (default: False)
122
119
  adam_nesterov_coef (float, optional): Nesterov coefficient for AdamW. (default: None)
123
120
  adam_kourkoutas_beta (bool): Kourkoutas-β for AdamW.
@@ -126,7 +123,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
126
123
  adam_tiny_spike (float): Tiny spike for Kourkoutas-β. (default: 1e-9)
127
124
  adam_k_warmup_steps (int): Warmup steps for Kourkoutas-β. (default: 0)
128
125
  adam_spectral_normalization (bool): Enable explicit spectral normalization for AdamW. (default: False)
129
- adam_state_precision (str): Precision for AuxAdam states. Options: 'auto', 'fp32', 'bf16_sr', 'fp8_sr', 'int8_sr', 'factored'. (default: 'auto')
126
+ adam_state_precision (str): Precision for AuxAdam states. Options: 'auto', 'fp32', 'bf16_sr', 'fp16', 'fp8_sr', 'int8_sr', 'factored'. (default: 'auto')
130
127
  adam_nnmf_factor (bool): 1-bit factored for AdamW.
131
128
  adam_factored_2nd (bool): Factorize only the second moment (v_t) for AuxAdam. (default: False)
132
129
  """
@@ -147,7 +144,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
147
144
  rms_rescaling: bool = True,
148
145
  # Newton Schulz
149
146
  ns_steps: int = 5,
150
- ns_eps: float = 1e-7,
147
+ ns_eps: float | None = 1e-7,
151
148
  ns_coeffs: tuple[float, float, float] = (3.4445, -4.7750, 2.0315),
152
149
  # Stochastic Rounding for BF16
153
150
  stochastic_rounding: bool = True,
@@ -174,7 +171,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
174
171
  nnmf_factor: bool = False,
175
172
  # CANS
176
173
  accelerated_ns: bool = False,
177
- cns_a_bound: float = 1e-4,
174
+ cns_a_bound: float | None = None,
178
175
  # MARS-M
179
176
  approx_mars: bool = False,
180
177
  mars_gamma: float = 0.025,
@@ -193,12 +190,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
193
190
  adam_fisher_wd: bool = False,
194
191
  adam_use_bias_correction: bool = True,
195
192
  adam_use_atan2: bool = False,
196
- adam_cautious_mask: bool = False,
197
- adam_grams_moment: bool = False,
198
193
  adam_orthogonal_gradient: bool = False,
199
- adam_use_AdEMAMix: bool = False,
200
- adam_beta3_ema: float = 0.9999,
201
- adam_alpha: float = 5.0,
202
194
  adam_nesterov: bool = False,
203
195
  adam_nesterov_coef: float | None = None,
204
196
  adam_kourkoutas_beta: bool = False,
@@ -223,8 +215,12 @@ class AdaMuon_adv(torch.optim.Optimizer):
223
215
  if spectral_normalization and accelerated_ns:
224
216
  ValueError("spectral_normalization violates accelerated Newton-Schulz assumptions. Pick one of them.")
225
217
 
218
+ # Legacy backwards compatibility support for `nnmf_factor=True`
219
+ if nnmf_factor:
220
+ state_precision = "factored"
221
+
226
222
  state_precision = state_precision.lower()
227
- valid_precisions = {"auto", "fp32", "factored", "bf16_sr", "fp8_sr", "int8_sr"}
223
+ valid_precisions = {"auto", "fp32", "factored", "bf16_sr", "fp16", "fp8_sr", "int8_sr"}
228
224
  if state_precision not in valid_precisions:
229
225
  raise ValueError(f"state_precision must be one of {valid_precisions}. Got {state_precision}")
230
226
 
@@ -262,9 +258,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
262
258
  "adam_betas": adam_betas, "adam_eps": adam_eps, "adam_weight_decay": adam_weight_decay,
263
259
  "adam_fisher_wd": adam_fisher_wd,
264
260
  "adam_use_bias_correction": adam_use_bias_correction, "adam_use_atan2": adam_use_atan2,
265
- "adam_cautious_mask": adam_cautious_mask, "adam_grams_moment": adam_grams_moment,
266
261
  "adam_orthogonal_gradient": adam_orthogonal_gradient,
267
- "adam_use_AdEMAMix": adam_use_AdEMAMix, "adam_beta3_ema": adam_beta3_ema, "adam_alpha": adam_alpha,
268
262
  "adam_nesterov": adam_nesterov, "adam_nesterov_coef": adam_nesterov_coef,
269
263
  "adam_kourkoutas_beta": adam_kourkoutas_beta, "adam_beta2_min": adam_beta2_min,
270
264
  "adam_ema_alpha": adam_ema_alpha, "adam_tiny_spike": adam_tiny_spike,
@@ -274,25 +268,10 @@ class AdaMuon_adv(torch.optim.Optimizer):
274
268
  "adam_nnmf_factor": adam_nnmf_factor, "adam_factored_2nd": adam_factored_2nd,
275
269
  }
276
270
  self.stochastic_rounding = stochastic_rounding
277
- self._init_lr = lr
271
+ self._init_lr = lr if lr > 0 else 1
278
272
 
279
273
  super().__init__(params, defaults)
280
274
 
281
- # Validate that every group has a determined optimizer type
282
- for i, group in enumerate(self.param_groups):
283
- if group.get('use_muon') is None and group.get('optim_type') is None:
284
- # Automatic shape-based detection if not explicit
285
- has_muon_shape = False
286
- for p in group['params']:
287
- has_muon_shape = _is_suitable_for_muon(p)
288
- if has_muon_shape:
289
- group['use_muon'] = True
290
- else:
291
- group['use_muon'] = False
292
-
293
- if group.get('use_muon') is None: # Fallback
294
- group['use_muon'] = group.get('optim_type') == 'muon'
295
-
296
275
  self.init_step()
297
276
 
298
277
  self.kourkoutas_helper = None
@@ -346,12 +325,19 @@ class AdaMuon_adv(torch.optim.Optimizer):
346
325
  if 'is_muon' in state:
347
326
  return
348
327
 
349
- if group['use_muon']:
328
+ if group.get('use_muon') is not None:
329
+ state['is_muon'] = group['use_muon']
330
+ elif group.get('optim_type') is not None:
331
+ state['is_muon'] = group['optim_type'] == 'muon'
332
+ else: # Auto-detect per parameter
333
+ state['is_muon'] = _is_suitable_for_muon(p)
350
334
 
351
- state['factored'] = (
352
- group['nnmf_factor'] and
353
- not (len(p.shape) == 1 and not group['vector_reshape'])
354
- )
335
+ if state['is_muon']:
336
+
337
+ req_precision = group['state_precision']
338
+ is_vector = len(p.shape) == 1 and not group['vector_reshape']
339
+
340
+ state['factored'] = req_precision == 'factored' and not is_vector
355
341
  dtype = torch.float32 if state['factored'] else p.dtype
356
342
  device = p.device
357
343
 
@@ -362,23 +348,21 @@ class AdaMuon_adv(torch.optim.Optimizer):
362
348
  state['mv_mbuf_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
363
349
  packed_d2 = (d2 + 7) // 8
364
350
  state['sign_buf'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=device)
351
+ state['shifter'] = torch.tensor([1, 2, 4, 8, 16, 32, 64, 128], device=device, dtype=torch.uint8)
365
352
  if not group['normuon_variant']:
366
353
  state['mu_vbuf_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
367
354
  state['mv_vbuf_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
368
355
  else:
369
356
  # Determine effective state precision (small tensors always use fp32)
370
357
  req_precision = group.get('state_precision', 'auto')
371
- actual_precision = req_precision
372
- if actual_precision != 'auto' and (p.numel() < 10000 or p.ndim == 1):
373
- actual_precision = 'fp32'
358
+ actual_precision = 'auto' if req_precision == 'factored' else req_precision
374
359
  group['actual_state_precision'] = actual_precision
375
360
 
376
361
  # factored_2nd: factorize v_t only; ignored for NorMuon (no v_t) and tiny params
377
362
  use_factored_2nd = (
378
363
  group.get('factored_2nd', False)
379
364
  and not group['normuon_variant']
380
- and p.numel() >= 10000
381
- and p.ndim > 1
365
+ and not (len(p.shape) == 1 and not group['vector_reshape'])
382
366
  )
383
367
  state['factored_2nd'] = use_factored_2nd
384
368
 
@@ -493,19 +477,22 @@ class AdaMuon_adv(torch.optim.Optimizer):
493
477
  random_int_state_tensor = param_update._get_random_int_for_8bit_sr(p)
494
478
  elif actual_precision == 'fp8_sr':
495
479
  random_int_state_tensor = param_update._get_random_int_for_fp8_sr(p)
480
+ if group['low_rank_ortho']:
481
+ random_G_sketch = param_update._get_random_noise_for_low_rank_ortho(p, group['ortho_rank'])
496
482
  else:
497
483
  lr = group['lr']
498
- muon_step_param = self._muon_step_parameter
499
484
  random_int_state_tensor = None
485
+ random_G_sketch = None
486
+ muon_step_param = self._muon_step_parameter
500
487
 
501
- muon_step_param(p, grad, state, group, lr, random_int_tensor, random_int_state_tensor)
488
+ muon_step_param(p, grad, state, group, lr, random_int_tensor, random_int_state_tensor, random_G_sketch)
502
489
 
503
490
  def compile(self, *args, **kwargs):
504
491
  self._compiled_muon_step_parameter = torch.compile(self._muon_step_parameter, *args, **kwargs)
505
492
  self._compiled_adam_step_parameter = torch.compile(Muon_AuxAdam._adam_step_parameter, *args, **kwargs)
506
493
 
507
494
  @torch.no_grad()
508
- def _muon_step_parameter(self, p, grad, state, group, lr, random_int_tensor, random_int_state_tensor=None):
495
+ def _muon_step_parameter(self, p, grad, state, group, lr, random_int_tensor, random_int_state_tensor, random_G_sketch):
509
496
  # Upcast grad for low-precision state modes (non-factored path)
510
497
  grad = upcast_grad_for_precision(grad, state, group.get('state_precision', 'auto'))
511
498
  beta1, beta2 = group['betas']
@@ -523,14 +510,8 @@ class AdaMuon_adv(torch.optim.Optimizer):
523
510
  else:
524
511
  kappa_p = 1.0
525
512
 
526
- if group.get('spectral_normalization', False):
527
-
528
- ns_eps, adaptive_eps, _, _ = get_spectral_scaling(p, p.shape, group.get('n_layers', 1))
529
- decoupled_wd = True
530
- else:
531
- decoupled_wd = False
532
- ns_eps = group['ns_eps']
533
- adaptive_eps = group['eps']
513
+ ns_eps = group['ns_eps']
514
+ adaptive_eps = scale_eps(group['eps'], p)
534
515
 
535
516
  # MARS-M Approximated (Variance Reduction)
536
517
  if group.get('approx_mars', False):
@@ -545,7 +526,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
545
526
  grad_reshaped = grad.view(d1, d2)
546
527
 
547
528
  # Reconstruct momentum from previous step's factors & sign
548
- mt_buf = _reconstruct_state((state['mu_mbuf_nmf'], state['mv_mbuf_nmf'], state['sign_buf'], d2), signed=True)
529
+ mt_buf = _reconstruct_state((state['mu_mbuf_nmf'], state['mv_mbuf_nmf'], state['sign_buf'], d2), signed=True, shifter=state['shifter'])
549
530
 
550
531
  # Update momentum in full-size
551
532
  mt_buf.lerp_(grad_reshaped, 1 - beta1)
@@ -557,7 +538,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
557
538
  update = mt_buf.clone()
558
539
 
559
540
  # Factorize
560
- state['mu_mbuf_nmf'], state['mv_mbuf_nmf'], state['sign_buf'] = _factorize_state(mt_buf, signed=True)
541
+ state['mu_mbuf_nmf'], state['mv_mbuf_nmf'], state['sign_buf'] = _factorize_state(mt_buf, signed=True, shifter=state['shifter'])
561
542
  del mt_buf
562
543
 
563
544
  # Apply update projection
@@ -573,7 +554,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
573
554
  cns_a_bound=group['cns_a_bound'],
574
555
  low_rank_ortho=group['low_rank_ortho'],
575
556
  ortho_rank=group['ortho_rank'],
576
- spectral_normalization=group.get('spectral_normalization', False),
557
+ G_sketch=random_G_sketch,
577
558
  compiled=group.get('compiled_optimizer', False)
578
559
  )
579
560
 
@@ -581,10 +562,10 @@ class AdaMuon_adv(torch.optim.Optimizer):
581
562
  normuon_update(update, state['normuon_v'], beta2, group['eps'])
582
563
  else:
583
564
  # Reconstruct second momentum from previous step's factors
584
- vt_buf = _reconstruct_state((state['mu_vbuf_nmf'], state['mv_vbuf_nmf']), signed=False)
565
+ vt_buf = _reconstruct_state((state['mu_vbuf_nmf'], state['mv_vbuf_nmf']), signed=False, shifter=state['shifter'])
585
566
  # Update second momentum in full-size
586
567
  vt_buf.mul_(beta2).addcmul_(update, update, value=1 - beta2)
587
- state['mu_vbuf_nmf'], state['mv_vbuf_nmf'] = _factorize_state(vt_buf, signed=False)
568
+ state['mu_vbuf_nmf'], state['mv_vbuf_nmf'] = _factorize_state(vt_buf, signed=False, shifter=state['shifter'])
588
569
  # Apply second momentum update (adaptive scaling)
589
570
  if group['use_atan2']:
590
571
  denom = vt_buf.sqrt_()
@@ -629,7 +610,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
629
610
  cns_a_bound=group['cns_a_bound'],
630
611
  low_rank_ortho=group['low_rank_ortho'],
631
612
  ortho_rank=group['ortho_rank'],
632
- spectral_normalization=group.get('spectral_normalization', False),
613
+ G_sketch=random_G_sketch,
633
614
  compiled=group.get('compiled_optimizer', False)
634
615
  )
635
616
 
@@ -641,9 +622,9 @@ class AdaMuon_adv(torch.optim.Optimizer):
641
622
  d1, d2 = state['effective_shape']
642
623
  update = update.view(original_shape)
643
624
  update_f32 = update.float()
644
- vt_buf = _reconstruct_state((state['mu_vbuf_nmf'], state['mv_vbuf_nmf']), signed=False)
625
+ vt_buf = _reconstruct_state((state['mu_vbuf_nmf'], state['mv_vbuf_nmf']), signed=False, shifter=state['shifter'])
645
626
  vt_buf.mul_(beta2).addcmul_(update_f32.view(d1, d2), update_f32.view(d1, d2), value=1 - beta2)
646
- state['mu_vbuf_nmf'], state['mv_vbuf_nmf'] = _factorize_state(vt_buf, signed=False)
627
+ state['mu_vbuf_nmf'], state['mv_vbuf_nmf'] = _factorize_state(vt_buf, signed=False, shifter=state['shifter'])
647
628
  # Apply second moment scaling
648
629
  if group['use_atan2']:
649
630
  denom = vt_buf.sqrt_().view(original_shape)
@@ -678,7 +659,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
678
659
 
679
660
  update = update.reshape(original_shape)
680
661
 
681
- param_update.apply_parameter_update(self, p, group, update, lr, random_int_tensor=random_int_tensor, decoupled=decoupled_wd)
662
+ param_update.apply_parameter_update(self, p, group, update, lr, random_int_tensor=random_int_tensor)
682
663
 
683
664
  @torch.no_grad()
684
665
  def step(self, closure=None):