adv-optm 2.4.dev20__tar.gz → 2.4.dev22__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {adv_optm-2.4.dev20 → adv_optm-2.4.dev22}/PKG-INFO +12 -59
- {adv_optm-2.4.dev20 → adv_optm-2.4.dev22}/README.md +10 -57
- {adv_optm-2.4.dev20 → adv_optm-2.4.dev22}/adv_optm/__init__.py +1 -5
- {adv_optm-2.4.dev20 → adv_optm-2.4.dev22}/adv_optm/optim/AdaMuon_adv.py +47 -66
- {adv_optm-2.4.dev20 → adv_optm-2.4.dev22}/adv_optm/optim/AdamW_adv.py +22 -98
- {adv_optm-2.4.dev20 → adv_optm-2.4.dev22}/adv_optm/optim/Adopt_adv.py +28 -103
- {adv_optm-2.4.dev20 → adv_optm-2.4.dev22}/adv_optm/optim/Lion_adv.py +4 -26
- {adv_optm-2.4.dev20 → adv_optm-2.4.dev22}/adv_optm/optim/Muon_adv.py +39 -58
- {adv_optm-2.4.dev20 → adv_optm-2.4.dev22}/adv_optm/optim/Prodigy_adv.py +29 -96
- {adv_optm-2.4.dev20 → adv_optm-2.4.dev22}/adv_optm/optim/SignSGD_adv.py +55 -21
- {adv_optm-2.4.dev20 → adv_optm-2.4.dev22}/adv_optm/optim/SinkSGD_adv.py +69 -57
- {adv_optm-2.4.dev20 → adv_optm-2.4.dev22}/adv_optm/optim/__init__.py +0 -4
- {adv_optm-2.4.dev20 → adv_optm-2.4.dev22}/adv_optm/util/Muon_AuxAdam.py +15 -60
- {adv_optm-2.4.dev20 → adv_optm-2.4.dev22}/adv_optm/util/Muon_util.py +22 -58
- {adv_optm-2.4.dev20 → adv_optm-2.4.dev22}/adv_optm/util/centered_decay.py +0 -6
- {adv_optm-2.4.dev20 → adv_optm-2.4.dev22}/adv_optm/util/factorization_util.py +9 -10
- {adv_optm-2.4.dev20 → adv_optm-2.4.dev22}/adv_optm/util/param_update.py +47 -10
- adv_optm-2.4.dev22/adv_optm/util/signed_util.py +56 -0
- {adv_optm-2.4.dev20 → adv_optm-2.4.dev22}/adv_optm/util/state_util.py +12 -14
- {adv_optm-2.4.dev20 → adv_optm-2.4.dev22}/adv_optm/util/update_util.py +0 -32
- {adv_optm-2.4.dev20 → adv_optm-2.4.dev22}/adv_optm.egg-info/PKG-INFO +12 -59
- {adv_optm-2.4.dev20 → adv_optm-2.4.dev22}/adv_optm.egg-info/SOURCES.txt +0 -2
- adv_optm-2.4.dev22/adv_optm.egg-info/requires.txt +1 -0
- {adv_optm-2.4.dev20 → adv_optm-2.4.dev22}/setup.py +2 -2
- adv_optm-2.4.dev20/adv_optm/optim/Lion_Prodigy_adv.py +0 -389
- adv_optm-2.4.dev20/adv_optm/optim/Simplified_AdEMAMix.py +0 -384
- adv_optm-2.4.dev20/adv_optm/util/signed_util.py +0 -56
- adv_optm-2.4.dev20/adv_optm.egg-info/requires.txt +0 -1
- {adv_optm-2.4.dev20 → adv_optm-2.4.dev22}/LICENSE +0 -0
- {adv_optm-2.4.dev20 → adv_optm-2.4.dev22}/adv_optm/util/Kourkoutas.py +0 -0
- {adv_optm-2.4.dev20 → adv_optm-2.4.dev22}/adv_optm/util/OrthoGrad.py +0 -0
- {adv_optm-2.4.dev20 → adv_optm-2.4.dev22}/adv_optm/util/__init__.py +0 -0
- {adv_optm-2.4.dev20 → adv_optm-2.4.dev22}/adv_optm/util/lion_k.py +0 -0
- {adv_optm-2.4.dev20 → adv_optm-2.4.dev22}/adv_optm/util/scaled_optm.py +0 -0
- {adv_optm-2.4.dev20 → adv_optm-2.4.dev22}/adv_optm/util/sinkhorn.py +0 -0
- {adv_optm-2.4.dev20 → adv_optm-2.4.dev22}/adv_optm.egg-info/dependency_links.txt +0 -0
- {adv_optm-2.4.dev20 → adv_optm-2.4.dev22}/adv_optm.egg-info/top_level.txt +0 -0
- {adv_optm-2.4.dev20 → adv_optm-2.4.dev22}/setup.cfg +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: adv_optm
|
|
3
|
-
Version: 2.4.dev20
|
|
3
|
+
Version: 2.4.dev22
|
|
4
4
|
Summary: A family of highly efficient, lightweight yet powerful optimizers.
|
|
5
5
|
Home-page: https://github.com/Koratahiu/Advanced_Optimizers
|
|
6
6
|
Author: Koratahiu
|
|
@@ -15,7 +15,7 @@ Classifier: Topic :: Software Development :: Libraries :: Python Modules
|
|
|
15
15
|
Requires-Python: >=3.8
|
|
16
16
|
Description-Content-Type: text/markdown
|
|
17
17
|
License-File: LICENSE
|
|
18
|
-
Requires-Dist: torch>=2.
|
|
18
|
+
Requires-Dist: torch>=2.1
|
|
19
19
|
Dynamic: author
|
|
20
20
|
Dynamic: author-email
|
|
21
21
|
Dynamic: classifier
|
|
@@ -37,10 +37,6 @@ A comprehensive, all-in-one collection of optimization algorithms for deep learn
|
|
|
37
37
|
|
|
38
38
|
## 🔥 What's New
|
|
39
39
|
|
|
40
|
-
### in 2.2.2
|
|
41
|
-
|
|
42
|
-
- `Simplified_AdEMAMix` now uses the same LR as AdamW for all `beta1` and `alpha_grad` values!
|
|
43
|
-
|
|
44
40
|
### in 2.1.x
|
|
45
41
|
|
|
46
42
|
- Added Signum (SignSGD with momentum): A new optimizer in the family (SignSGD_adv)
|
|
@@ -101,7 +97,6 @@ This library integrates multiple state-of-the-art optimization techniques valida
|
|
|
101
97
|
|-----------|--------------|-------------|
|
|
102
98
|
| `Adopt_Factored` | 328 MB | 4 small vectors + 1-bit state |
|
|
103
99
|
| `Adopt_Factored + AdEMAMix` | 625 MB | 6 small vectors + two 1-bit states |
|
|
104
|
-
| `Simplified_AdEMAMix` | 328 MB | Same as standard factored (no extra state) |
|
|
105
100
|
|
|
106
101
|
### Speed Comparison (SDXL, Batch Size 4)
|
|
107
102
|
| Optimizer | Speed | Notes |
|
|
@@ -120,7 +115,6 @@ This library integrates multiple state-of-the-art optimization techniques valida
|
|
|
120
115
|
| `Adam_Adv` | Advanced Adam implementation | General purpose |
|
|
121
116
|
| `Adopt_Adv` | Adam-variant with independent beta2 | Stable training for small batch size regimes |
|
|
122
117
|
| `Prodigy_Adv` | Prodigy with D-Adaptation | Adam with automatic LR tuning |
|
|
123
|
-
| `Simplified_AdEMAMix` | Adam variant with accumulator momentum | Small/large batch training when tuned correctly |
|
|
124
118
|
| `Lion_Adv` | Advanced Lion implementation | Memory-constrained environments |
|
|
125
119
|
| `Prodigy_Lion_Adv` | Prodigy + Lion combination | Lion with automatic LR tuning |
|
|
126
120
|
|
|
@@ -128,18 +122,14 @@ This library integrates multiple state-of-the-art optimization techniques valida
|
|
|
128
122
|
|
|
129
123
|
## ⚙️ Feature Matrix
|
|
130
124
|
|
|
131
|
-
| Feature | Adam_Adv | Adopt_Adv | Prodigy_Adv |
|
|
132
|
-
|
|
133
|
-
| Factored | ✓ | ✓ | ✓
|
|
134
|
-
|
|
|
135
|
-
|
|
|
136
|
-
|
|
|
137
|
-
|
|
|
138
|
-
|
|
|
139
|
-
| atan2 | ✓ | ✓ | ✓ | ✗ | ✗ |
|
|
140
|
-
| Stochastic Rounding | ✓ | ✓ | ✓ | ✓ | ✓ |
|
|
141
|
-
| Fused Backward Pass | ✓ | ✓ | ✓ | ✓ | ✓ |
|
|
142
|
-
| **Kourkoutas-β** | ✓ | ✓ | ✓ | ✓ | ✗ |
|
|
125
|
+
| Feature | Adam_Adv | Adopt_Adv | Prodigy_Adv | Lion_Adv |
|
|
126
|
+
|---------|----------|-----------|-------------|----------|
|
|
127
|
+
| Factored | ✓ | ✓ | ✓ | ✓ |
|
|
128
|
+
| OrthoGrad | ✓ | ✓ | ✓ | ✓ |
|
|
129
|
+
| atan2 | ✓ | ✓ | ✓ |✗ |
|
|
130
|
+
| Stochastic Rounding | ✓ | ✓ | ✓ |✓ |
|
|
131
|
+
| Fused Backward Pass | ✓ | ✓ | ✓ | ✓ |
|
|
132
|
+
| **Kourkoutas-β** | ✓ | ✓ | ✓ | ✗ |
|
|
143
133
|
|
|
144
134
|
---
|
|
145
135
|
|
|
@@ -159,48 +149,13 @@ This library integrates multiple state-of-the-art optimization techniques valida
|
|
|
159
149
|
|
|
160
150
|
| Feature | Description | Recommended Usage | Performance Impact | Theoretical Basis | Compatibility |
|
|
161
151
|
|--------|-------------|-------------------|--------------------|-------------------|--------------|
|
|
162
|
-
| **Cautious** | Only applies update if gradient direction aligns with momentum direction | Accelerating convergence | No overhead | [C-Optim](https://github.com/kyleliang919/C-Optim) | Adam/Adopt/Prodigy/Lion |
|
|
163
|
-
| **Grams** | Update direction derived purely from current gradient | When Cautious is insufficient | No overhead | [Grams](https://github.com/Gunale0926/Grams) | Adam/Adopt/Prodigy |
|
|
164
|
-
| **AdEMAMix** | Dual EMA system that retains relevance of gradients over tens of thousands of steps | Long training runs, especially where model forgetting is a concern | +1 state memory | [AdEMAMix](https://arxiv.org/abs/2409.03137) | Adam/Adopt/Prodigy |
|
|
165
|
-
| **Simplified_AdEMAMix** | Accumulator-based momentum, single EMA variant of AdEMAMix | All scenarios when tuned correctly | No overhead | [Connections](https://arxiv.org/abs/2502.02431) | Adam/Adopt/Prodigy |
|
|
166
152
|
| **atan2** | Robust epsilon replacement with built-in gradient clipping | Use for stable bounded updates (or for Adopt as it needs that) | No overhead | [Adam-atan2](https://github.com/lucidrains/adam-atan2-pytorch) | Adam/Adopt/Prodigy |
|
|
167
|
-
| **Kourkoutas-β** | Layer-wise adaptive β₂ based on gradient “sunspike” ratio | Noisy/small/large-batch/high-LR training | No overhead | [Kourkoutas-β]() | Adam/Adopt/Prodigy
|
|
168
|
-
|
|
169
|
-
> **Note**: If both **Cautious** and **Grams** are enabled, **Grams takes precedence** and Cautious is disabled.
|
|
153
|
+
| **Kourkoutas-β** | Layer-wise adaptive β₂ based on gradient “sunspike” ratio | Noisy/small/large-batch/high-LR training | No overhead | [Kourkoutas-β](https://arxiv.org/abs/2508.12996) | Adam/Adopt/Prodigy |
|
|
170
154
|
|
|
171
155
|
---
|
|
172
156
|
|
|
173
157
|
## 🔍 Feature Deep Dives
|
|
174
158
|
|
|
175
|
-
### AdEMAMix
|
|
176
|
-
|
|
177
|
-
- Adds a **slow-decaying second EMA** (`beta3`) that retains gradient memory over tens of thousands of steps.
|
|
178
|
-
- Particularly effective for **small batch sizes**, where Adam’s standard first moment is nearly useless.
|
|
179
|
-
|
|
180
|
-
#### Tunable Hyperparameters
|
|
181
|
-
| Parameter | Default | Tuning Guide |
|
|
182
|
-
|-----------|---------|--------------|
|
|
183
|
-
| `beta3` | 0.9999 | • Runs >120k steps: **0.9999**<br>• Runs ≤120k steps: **0.999** |
|
|
184
|
-
| `alpha` | 5 | • Reduce to **2–3** if diverging<br>• Increase to strengthen long-term memory |
|
|
185
|
-
|
|
186
|
-
> ✅ **Pro Tip**: Set `beta1=0` in Adam/Adopt/Prodigy to skip standard EMA entirely and rely solely on AdEMAMix’s slow EMA, ideal for small-batch regimes.
|
|
187
|
-
|
|
188
|
-
---
|
|
189
|
-
|
|
190
|
-
### Simplified_AdEMAMix
|
|
191
|
-
|
|
192
|
-
- Introduced in [Connections between Schedule-Free Optimizers, AdEMAMix, and Accelerated SGD Variants (arXiv:2502.02431)](https://arxiv.org/abs/2502.02431).
|
|
193
|
-
- Replaces Adam’s first moment with a **theory-based momentum** that emphasizes the raw gradient, combining the stability of long memory with responsiveness to recent gradients.
|
|
194
|
-
- **Key insight**: Classical momentum **does not accelerate** in noisy (small-batch) regimes; this accumulator does.
|
|
195
|
-
|
|
196
|
-
#### Tunable Hyperparameters
|
|
197
|
-
| Parameter | Default | Tuning Guide |
|
|
198
|
-
|----------|---------|--------------|
|
|
199
|
-
| `beta1` | 0.99 | Controls accumulator memory length:<br>• Small BS: **0.99–0.9999**<br>• Large BS: **0.9** |
|
|
200
|
-
| `Grad α` | 100 | Most critical parameter:<br>• Inversely scales with batch size<br>• **100–10** for small BS (≤32)<br>• **1–0.1** for large BS (≥512) |
|
|
201
|
-
|
|
202
|
-
---
|
|
203
|
-
|
|
204
159
|
### atan2
|
|
205
160
|
|
|
206
161
|
- Replaces `eps` in Adam-family optimizers with a **scale-invariant**, bounded update rule.
|
|
@@ -215,7 +170,7 @@ This library integrates multiple state-of-the-art optimization techniques valida
|
|
|
215
170
|
|
|
216
171
|
### **Kourkoutas-β**
|
|
217
172
|
|
|
218
|
-
**Kourkoutas-β** introduces a **sunspike-driven, layer-wise adaptive second-moment decay (β₂)** as an optional enhancement for `Adam_Adv`, `Adopt_Adv`, `Prodigy_Adv
|
|
173
|
+
**Kourkoutas-β** introduces a **sunspike-driven, layer-wise adaptive second-moment decay (β₂)** as an optional enhancement for `Adam_Adv`, `Adopt_Adv`, `Prodigy_Adv`.
|
|
219
174
|
|
|
220
175
|
Instead of using a fixed β₂ (e.g., 0.999 or 0.95), it **dynamically modulates β₂ per layer** based on a bounded *sunspike ratio*:
|
|
221
176
|
|
|
@@ -243,7 +198,5 @@ This is especially effective for **noisy training, small batch sizes, and high l
|
|
|
243
198
|
|
|
244
199
|
1. [Revisiting BFloat16 Training](https://arxiv.org/abs/2010.06192)
|
|
245
200
|
2. [SMMF: Square-Matricized Momentum Factorization](https://arxiv.org/abs/2412.08894)
|
|
246
|
-
3. [The AdEMAMix Optimizer](https://arxiv.org/abs/2409.03137)
|
|
247
|
-
4. [Connections between Schedule-Free Optimizers, AdEMAMix, and Accelerated SGD](https://arxiv.org/abs/2502.02431)
|
|
248
201
|
6. [Kourkoutas-β: A Sunspike-Driven Adam Optimizer with Desert Flair](https://arxiv.org/abs/2508.12996)
|
|
249
202
|
7. [Scaling Exponents Across Parameterizations and Optimizers](https://arxiv.org/abs/2407.05872)
|
|
@@ -6,10 +6,6 @@ A comprehensive, all-in-one collection of optimization algorithms for deep learn
|
|
|
6
6
|
|
|
7
7
|
## 🔥 What's New
|
|
8
8
|
|
|
9
|
-
### in 2.2.2
|
|
10
|
-
|
|
11
|
-
- `Simplified_AdEMAMix` now uses the same LR as AdamW for all `beta1` and `alpha_grad` values!
|
|
12
|
-
|
|
13
9
|
### in 2.1.x
|
|
14
10
|
|
|
15
11
|
- Added Signum (SignSGD with momentum): A new optimizer in the family (SignSGD_adv)
|
|
@@ -70,7 +66,6 @@ This library integrates multiple state-of-the-art optimization techniques valida
|
|
|
70
66
|
|-----------|--------------|-------------|
|
|
71
67
|
| `Adopt_Factored` | 328 MB | 4 small vectors + 1-bit state |
|
|
72
68
|
| `Adopt_Factored + AdEMAMix` | 625 MB | 6 small vectors + two 1-bit states |
|
|
73
|
-
| `Simplified_AdEMAMix` | 328 MB | Same as standard factored (no extra state) |
|
|
74
69
|
|
|
75
70
|
### Speed Comparison (SDXL, Batch Size 4)
|
|
76
71
|
| Optimizer | Speed | Notes |
|
|
@@ -89,7 +84,6 @@ This library integrates multiple state-of-the-art optimization techniques valida
|
|
|
89
84
|
| `Adam_Adv` | Advanced Adam implementation | General purpose |
|
|
90
85
|
| `Adopt_Adv` | Adam-variant with independent beta2 | Stable training for small batch size regimes |
|
|
91
86
|
| `Prodigy_Adv` | Prodigy with D-Adaptation | Adam with automatic LR tuning |
|
|
92
|
-
| `Simplified_AdEMAMix` | Adam variant with accumulator momentum | Small/large batch training when tuned correctly |
|
|
93
87
|
| `Lion_Adv` | Advanced Lion implementation | Memory-constrained environments |
|
|
94
88
|
| `Prodigy_Lion_Adv` | Prodigy + Lion combination | Lion with automatic LR tuning |
|
|
95
89
|
|
|
@@ -97,18 +91,14 @@ This library integrates multiple state-of-the-art optimization techniques valida
|
|
|
97
91
|
|
|
98
92
|
## ⚙️ Feature Matrix
|
|
99
93
|
|
|
100
|
-
| Feature | Adam_Adv | Adopt_Adv | Prodigy_Adv |
|
|
101
|
-
|
|
102
|
-
| Factored | ✓ | ✓ | ✓
|
|
103
|
-
|
|
|
104
|
-
|
|
|
105
|
-
|
|
|
106
|
-
|
|
|
107
|
-
|
|
|
108
|
-
| atan2 | ✓ | ✓ | ✓ | ✗ | ✗ |
|
|
109
|
-
| Stochastic Rounding | ✓ | ✓ | ✓ | ✓ | ✓ |
|
|
110
|
-
| Fused Backward Pass | ✓ | ✓ | ✓ | ✓ | ✓ |
|
|
111
|
-
| **Kourkoutas-β** | ✓ | ✓ | ✓ | ✓ | ✗ |
|
|
94
|
+
| Feature | Adam_Adv | Adopt_Adv | Prodigy_Adv | Lion_Adv |
|
|
95
|
+
|---------|----------|-----------|-------------|----------|
|
|
96
|
+
| Factored | ✓ | ✓ | ✓ | ✓ |
|
|
97
|
+
| OrthoGrad | ✓ | ✓ | ✓ | ✓ |
|
|
98
|
+
| atan2 | ✓ | ✓ | ✓ |✗ |
|
|
99
|
+
| Stochastic Rounding | ✓ | ✓ | ✓ |✓ |
|
|
100
|
+
| Fused Backward Pass | ✓ | ✓ | ✓ | ✓ |
|
|
101
|
+
| **Kourkoutas-β** | ✓ | ✓ | ✓ | ✗ |
|
|
112
102
|
|
|
113
103
|
---
|
|
114
104
|
|
|
@@ -128,48 +118,13 @@ This library integrates multiple state-of-the-art optimization techniques valida
|
|
|
128
118
|
|
|
129
119
|
| Feature | Description | Recommended Usage | Performance Impact | Theoretical Basis | Compatibility |
|
|
130
120
|
|--------|-------------|-------------------|--------------------|-------------------|--------------|
|
|
131
|
-
| **Cautious** | Only applies update if gradient direction aligns with momentum direction | Accelerating convergence | No overhead | [C-Optim](https://github.com/kyleliang919/C-Optim) | Adam/Adopt/Prodigy/Lion |
|
|
132
|
-
| **Grams** | Update direction derived purely from current gradient | When Cautious is insufficient | No overhead | [Grams](https://github.com/Gunale0926/Grams) | Adam/Adopt/Prodigy |
|
|
133
|
-
| **AdEMAMix** | Dual EMA system that retains relevance of gradients over tens of thousands of steps | Long training runs, especially where model forgetting is a concern | +1 state memory | [AdEMAMix](https://arxiv.org/abs/2409.03137) | Adam/Adopt/Prodigy |
|
|
134
|
-
| **Simplified_AdEMAMix** | Accumulator-based momentum, single EMA variant of AdEMAMix | All scenarios when tuned correctly | No overhead | [Connections](https://arxiv.org/abs/2502.02431) | Adam/Adopt/Prodigy |
|
|
135
121
|
| **atan2** | Robust epsilon replacement with built-in gradient clipping | Use for stable bounded updates (or for Adopt as it needs that) | No overhead | [Adam-atan2](https://github.com/lucidrains/adam-atan2-pytorch) | Adam/Adopt/Prodigy |
|
|
136
|
-
| **Kourkoutas-β** | Layer-wise adaptive β₂ based on gradient “sunspike” ratio | Noisy/small/large-batch/high-LR training | No overhead | [Kourkoutas-β]() | Adam/Adopt/Prodigy
|
|
137
|
-
|
|
138
|
-
> **Note**: If both **Cautious** and **Grams** are enabled, **Grams takes precedence** and Cautious is disabled.
|
|
122
|
+
| **Kourkoutas-β** | Layer-wise adaptive β₂ based on gradient “sunspike” ratio | Noisy/small/large-batch/high-LR training | No overhead | [Kourkoutas-β](https://arxiv.org/abs/2508.12996) | Adam/Adopt/Prodigy |
|
|
139
123
|
|
|
140
124
|
---
|
|
141
125
|
|
|
142
126
|
## 🔍 Feature Deep Dives
|
|
143
127
|
|
|
144
|
-
### AdEMAMix
|
|
145
|
-
|
|
146
|
-
- Adds a **slow-decaying second EMA** (`beta3`) that retains gradient memory over tens of thousands of steps.
|
|
147
|
-
- Particularly effective for **small batch sizes**, where Adam’s standard first moment is nearly useless.
|
|
148
|
-
|
|
149
|
-
#### Tunable Hyperparameters
|
|
150
|
-
| Parameter | Default | Tuning Guide |
|
|
151
|
-
|-----------|---------|--------------|
|
|
152
|
-
| `beta3` | 0.9999 | • Runs >120k steps: **0.9999**<br>• Runs ≤120k steps: **0.999** |
|
|
153
|
-
| `alpha` | 5 | • Reduce to **2–3** if diverging<br>• Increase to strengthen long-term memory |
|
|
154
|
-
|
|
155
|
-
> ✅ **Pro Tip**: Set `beta1=0` in Adam/Adopt/Prodigy to skip standard EMA entirely and rely solely on AdEMAMix’s slow EMA, ideal for small-batch regimes.
|
|
156
|
-
|
|
157
|
-
---
|
|
158
|
-
|
|
159
|
-
### Simplified_AdEMAMix
|
|
160
|
-
|
|
161
|
-
- Introduced in [Connections between Schedule-Free Optimizers, AdEMAMix, and Accelerated SGD Variants (arXiv:2502.02431)](https://arxiv.org/abs/2502.02431).
|
|
162
|
-
- Replaces Adam’s first moment with a **theory-based momentum** that emphasizes the raw gradient, combining the stability of long memory with responsiveness to recent gradients.
|
|
163
|
-
- **Key insight**: Classical momentum **does not accelerate** in noisy (small-batch) regimes; this accumulator does.
|
|
164
|
-
|
|
165
|
-
#### Tunable Hyperparameters
|
|
166
|
-
| Parameter | Default | Tuning Guide |
|
|
167
|
-
|----------|---------|--------------|
|
|
168
|
-
| `beta1` | 0.99 | Controls accumulator memory length:<br>• Small BS: **0.99–0.9999**<br>• Large BS: **0.9** |
|
|
169
|
-
| `Grad α` | 100 | Most critical parameter:<br>• Inversely scales with batch size<br>• **100–10** for small BS (≤32)<br>• **1–0.1** for large BS (≥512) |
|
|
170
|
-
|
|
171
|
-
---
|
|
172
|
-
|
|
173
128
|
### atan2
|
|
174
129
|
|
|
175
130
|
- Replaces `eps` in Adam-family optimizers with a **scale-invariant**, bounded update rule.
|
|
@@ -184,7 +139,7 @@ This library integrates multiple state-of-the-art optimization techniques valida
|
|
|
184
139
|
|
|
185
140
|
### **Kourkoutas-β**
|
|
186
141
|
|
|
187
|
-
**Kourkoutas-β** introduces a **sunspike-driven, layer-wise adaptive second-moment decay (β₂)** as an optional enhancement for `Adam_Adv`, `Adopt_Adv`, `Prodigy_Adv
|
|
142
|
+
**Kourkoutas-β** introduces a **sunspike-driven, layer-wise adaptive second-moment decay (β₂)** as an optional enhancement for `Adam_Adv`, `Adopt_Adv`, `Prodigy_Adv`.
|
|
188
143
|
|
|
189
144
|
Instead of using a fixed β₂ (e.g., 0.999 or 0.95), it **dynamically modulates β₂ per layer** based on a bounded *sunspike ratio*:
|
|
190
145
|
|
|
@@ -212,7 +167,5 @@ This is especially effective for **noisy training, small batch sizes, and high l
|
|
|
212
167
|
|
|
213
168
|
1. [Revisiting BFloat16 Training](https://arxiv.org/abs/2010.06192)
|
|
214
169
|
2. [SMMF: Square-Matricized Momentum Factorization](https://arxiv.org/abs/2412.08894)
|
|
215
|
-
3. [The AdEMAMix Optimizer](https://arxiv.org/abs/2409.03137)
|
|
216
|
-
4. [Connections between Schedule-Free Optimizers, AdEMAMix, and Accelerated SGD](https://arxiv.org/abs/2502.02431)
|
|
217
170
|
6. [Kourkoutas-β: A Sunspike-Driven Adam Optimizer with Desert Flair](https://arxiv.org/abs/2508.12996)
|
|
218
171
|
7. [Scaling Exponents Across Parameterizations and Optimizers](https://arxiv.org/abs/2407.05872)
|
|
@@ -2,9 +2,7 @@ from .optim import (
|
|
|
2
2
|
AdamW_adv,
|
|
3
3
|
Prodigy_adv,
|
|
4
4
|
Adopt_adv,
|
|
5
|
-
Simplified_AdEMAMix,
|
|
6
5
|
Lion_adv,
|
|
7
|
-
Lion_Prodigy_adv,
|
|
8
6
|
Muon_adv,
|
|
9
7
|
AdaMuon_adv,
|
|
10
8
|
SignSGD_adv,
|
|
@@ -15,13 +13,11 @@ __all__ = [
|
|
|
15
13
|
"AdamW_adv",
|
|
16
14
|
"Prodigy_adv",
|
|
17
15
|
"Adopt_adv",
|
|
18
|
-
"Simplified_AdEMAMix",
|
|
19
16
|
"Lion_adv",
|
|
20
|
-
"Lion_Prodigy_adv",
|
|
21
17
|
"Muon_adv",
|
|
22
18
|
"AdaMuon_adv",
|
|
23
19
|
"SignSGD_adv",
|
|
24
20
|
"SinkSGD_adv",
|
|
25
21
|
]
|
|
26
22
|
|
|
27
|
-
__version__ = "2.4.
|
|
23
|
+
__version__ = "2.4.dev22"
|
|
@@ -3,8 +3,8 @@ import torch
|
|
|
3
3
|
import math
|
|
4
4
|
|
|
5
5
|
from ..util import param_update
|
|
6
|
-
from ..util.Muon_util import newton_schulz, _is_suitable_for_muon, rms_adjustment, normuon_update, approx_mars, _auto_projection_for_adamuon
|
|
7
|
-
from ..util.scaled_optm import spectral_normalization, init_spectral_norm
|
|
6
|
+
from ..util.Muon_util import newton_schulz, _is_suitable_for_muon, rms_adjustment, normuon_update, approx_mars, _auto_projection_for_adamuon
|
|
7
|
+
from ..util.scaled_optm import spectral_normalization, init_spectral_norm, scale_eps
|
|
8
8
|
from ..util.factorization_util import _get_effective_shape, _factorize_state, _reconstruct_state
|
|
9
9
|
from ..util.OrthoGrad import _orthogonalize_gradient
|
|
10
10
|
from ..util.Kourkoutas import KourkoutasHelper
|
|
@@ -50,7 +50,8 @@ class AdaMuon_adv(torch.optim.Optimizer):
|
|
|
50
50
|
vector, used for RMS-aligned rescaling. Allows for the reuse of existing Adam
|
|
51
51
|
learning rate schedules. (default: True).
|
|
52
52
|
ns_steps (int): number of Newton-Schulz iterations to perform (default: 5).
|
|
53
|
-
ns_eps (float): epsilon for Newton-Schulz normalization stability
|
|
53
|
+
ns_eps (float): epsilon for Newton-Schulz normalization stability. When None
|
|
54
|
+
it's derived from scale invariant rule (default: 1e-7).
|
|
54
55
|
ns_coeffs (tuple[float, float, float]): The (a, b, c) coefficients for the
|
|
55
56
|
quintic polynomial in the Newton-Schulz iteration.
|
|
56
57
|
(default: (3.4445, -4.7750, 2.0315)).
|
|
@@ -77,7 +78,8 @@ class AdaMuon_adv(torch.optim.Optimizer):
|
|
|
77
78
|
(default: 128)
|
|
78
79
|
accelerated_ns (bool): If True, enables Chebyshev-accelerated Newton-Schulz, which
|
|
79
80
|
dynamically calculates optimal 3rd-order polynomial coefficients. (default: False)
|
|
80
|
-
cns_a_bound (float): Initial lower bound for singular values for CANS.
|
|
81
|
+
cns_a_bound (float): Initial lower bound for singular values for CANS. When None
|
|
82
|
+
it's derived from scale invariant rule (default: None).
|
|
81
83
|
approx_mars (bool): If True, enables Approximated MARS-M variance reduction.
|
|
82
84
|
from the paper "MARS-M: When Variance Reduction Meets Matrices"
|
|
83
85
|
(default: False)
|
|
@@ -112,12 +114,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
|
|
|
112
114
|
adam_fisher_wd (bool): Fisher Adam (FAdam) weight decay for the AdamW part. (default: False)
|
|
113
115
|
adam_use_bias_correction (bool): Bias correction for AdamW.
|
|
114
116
|
adam_use_atan2 (bool): Atan2 update rule for AdamW.
|
|
115
|
-
adam_cautious_mask (bool): Cautious masking for AdamW.
|
|
116
|
-
adam_grams_moment (bool): Grams-style updates for AdamW.
|
|
117
117
|
adam_orthogonal_gradient (bool): OrthoGrad for AdamW.
|
|
118
|
-
adam_use_AdEMAMix (bool): AdEMAMix for AdamW.
|
|
119
|
-
adam_beta3_ema (float): Beta3 for AdEMAMix.
|
|
120
|
-
adam_alpha (float): Alpha for AdEMAMix.
|
|
121
118
|
adam_nesterov (bool): Nesterov momentum for AdamW. (default: False)
|
|
122
119
|
adam_nesterov_coef (float, optional): Nesterov coefficient for AdamW. (default: None)
|
|
123
120
|
adam_kourkoutas_beta (bool): Kourkoutas-β for AdamW.
|
|
@@ -126,7 +123,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
|
|
|
126
123
|
adam_tiny_spike (float): Tiny spike for Kourkoutas-β. (default: 1e-9)
|
|
127
124
|
adam_k_warmup_steps (int): Warmup steps for Kourkoutas-β. (default: 0)
|
|
128
125
|
adam_spectral_normalization (bool): Enable explicit spectral normalization for AdamW. (default: False)
|
|
129
|
-
adam_state_precision (str): Precision for AuxAdam states. Options: 'auto', 'fp32', 'bf16_sr', 'fp8_sr', 'int8_sr', 'factored'. (default: 'auto')
|
|
126
|
+
adam_state_precision (str): Precision for AuxAdam states. Options: 'auto', 'fp32', 'bf16_sr', 'fp16', 'fp8_sr', 'int8_sr', 'factored'. (default: 'auto')
|
|
130
127
|
adam_nnmf_factor (bool): 1-bit factored for AdamW.
|
|
131
128
|
adam_factored_2nd (bool): Factorize only the second moment (v_t) for AuxAdam. (default: False)
|
|
132
129
|
"""
|
|
@@ -147,7 +144,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
|
|
|
147
144
|
rms_rescaling: bool = True,
|
|
148
145
|
# Newton Schulz
|
|
149
146
|
ns_steps: int = 5,
|
|
150
|
-
ns_eps: float = 1e-7,
|
|
147
|
+
ns_eps: float | None = 1e-7,
|
|
151
148
|
ns_coeffs: tuple[float, float, float] = (3.4445, -4.7750, 2.0315),
|
|
152
149
|
# Stochastic Rounding for BF16
|
|
153
150
|
stochastic_rounding: bool = True,
|
|
@@ -174,7 +171,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
|
|
|
174
171
|
nnmf_factor: bool = False,
|
|
175
172
|
# CANS
|
|
176
173
|
accelerated_ns: bool = False,
|
|
177
|
-
cns_a_bound: float =
|
|
174
|
+
cns_a_bound: float | None = None,
|
|
178
175
|
# MARS-M
|
|
179
176
|
approx_mars: bool = False,
|
|
180
177
|
mars_gamma: float = 0.025,
|
|
@@ -193,12 +190,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
|
|
|
193
190
|
adam_fisher_wd: bool = False,
|
|
194
191
|
adam_use_bias_correction: bool = True,
|
|
195
192
|
adam_use_atan2: bool = False,
|
|
196
|
-
adam_cautious_mask: bool = False,
|
|
197
|
-
adam_grams_moment: bool = False,
|
|
198
193
|
adam_orthogonal_gradient: bool = False,
|
|
199
|
-
adam_use_AdEMAMix: bool = False,
|
|
200
|
-
adam_beta3_ema: float = 0.9999,
|
|
201
|
-
adam_alpha: float = 5.0,
|
|
202
194
|
adam_nesterov: bool = False,
|
|
203
195
|
adam_nesterov_coef: float | None = None,
|
|
204
196
|
adam_kourkoutas_beta: bool = False,
|
|
@@ -223,8 +215,12 @@ class AdaMuon_adv(torch.optim.Optimizer):
|
|
|
223
215
|
if spectral_normalization and accelerated_ns:
|
|
224
216
|
ValueError("spectral_normalization violates accelerated Newton-Schulz assumptions. Pick one of them.")
|
|
225
217
|
|
|
218
|
+
# Legacy backwards compatibility support for `nnmf_factor=True`
|
|
219
|
+
if nnmf_factor:
|
|
220
|
+
state_precision = "factored"
|
|
221
|
+
|
|
226
222
|
state_precision = state_precision.lower()
|
|
227
|
-
valid_precisions = {"auto", "fp32", "factored", "bf16_sr", "fp8_sr", "int8_sr"}
|
|
223
|
+
valid_precisions = {"auto", "fp32", "factored", "bf16_sr", "fp16", "fp8_sr", "int8_sr"}
|
|
228
224
|
if state_precision not in valid_precisions:
|
|
229
225
|
raise ValueError(f"state_precision must be one of {valid_precisions}. Got {state_precision}")
|
|
230
226
|
|
|
@@ -262,9 +258,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
|
|
|
262
258
|
"adam_betas": adam_betas, "adam_eps": adam_eps, "adam_weight_decay": adam_weight_decay,
|
|
263
259
|
"adam_fisher_wd": adam_fisher_wd,
|
|
264
260
|
"adam_use_bias_correction": adam_use_bias_correction, "adam_use_atan2": adam_use_atan2,
|
|
265
|
-
"adam_cautious_mask": adam_cautious_mask, "adam_grams_moment": adam_grams_moment,
|
|
266
261
|
"adam_orthogonal_gradient": adam_orthogonal_gradient,
|
|
267
|
-
"adam_use_AdEMAMix": adam_use_AdEMAMix, "adam_beta3_ema": adam_beta3_ema, "adam_alpha": adam_alpha,
|
|
268
262
|
"adam_nesterov": adam_nesterov, "adam_nesterov_coef": adam_nesterov_coef,
|
|
269
263
|
"adam_kourkoutas_beta": adam_kourkoutas_beta, "adam_beta2_min": adam_beta2_min,
|
|
270
264
|
"adam_ema_alpha": adam_ema_alpha, "adam_tiny_spike": adam_tiny_spike,
|
|
@@ -274,25 +268,10 @@ class AdaMuon_adv(torch.optim.Optimizer):
|
|
|
274
268
|
"adam_nnmf_factor": adam_nnmf_factor, "adam_factored_2nd": adam_factored_2nd,
|
|
275
269
|
}
|
|
276
270
|
self.stochastic_rounding = stochastic_rounding
|
|
277
|
-
self._init_lr = lr
|
|
271
|
+
self._init_lr = lr if lr > 0 else 1
|
|
278
272
|
|
|
279
273
|
super().__init__(params, defaults)
|
|
280
274
|
|
|
281
|
-
# Validate that every group has a determined optimizer type
|
|
282
|
-
for i, group in enumerate(self.param_groups):
|
|
283
|
-
if group.get('use_muon') is None and group.get('optim_type') is None:
|
|
284
|
-
# Automatic shape-based detection if not explicit
|
|
285
|
-
has_muon_shape = False
|
|
286
|
-
for p in group['params']:
|
|
287
|
-
has_muon_shape = _is_suitable_for_muon(p)
|
|
288
|
-
if has_muon_shape:
|
|
289
|
-
group['use_muon'] = True
|
|
290
|
-
else:
|
|
291
|
-
group['use_muon'] = False
|
|
292
|
-
|
|
293
|
-
if group.get('use_muon') is None: # Fallback
|
|
294
|
-
group['use_muon'] = group.get('optim_type') == 'muon'
|
|
295
|
-
|
|
296
275
|
self.init_step()
|
|
297
276
|
|
|
298
277
|
self.kourkoutas_helper = None
|
|
@@ -346,12 +325,19 @@ class AdaMuon_adv(torch.optim.Optimizer):
|
|
|
346
325
|
if 'is_muon' in state:
|
|
347
326
|
return
|
|
348
327
|
|
|
349
|
-
if group
|
|
328
|
+
if group.get('use_muon') is not None:
|
|
329
|
+
state['is_muon'] = group['use_muon']
|
|
330
|
+
elif group.get('optim_type') is not None:
|
|
331
|
+
state['is_muon'] = group['optim_type'] == 'muon'
|
|
332
|
+
else: # Auto-detect per parameter
|
|
333
|
+
state['is_muon'] = _is_suitable_for_muon(p)
|
|
350
334
|
|
|
351
|
-
|
|
352
|
-
|
|
353
|
-
|
|
354
|
-
)
|
|
335
|
+
if state['is_muon']:
|
|
336
|
+
|
|
337
|
+
req_precision = group['state_precision']
|
|
338
|
+
is_vector = len(p.shape) == 1 and not group['vector_reshape']
|
|
339
|
+
|
|
340
|
+
state['factored'] = req_precision == 'factored' and not is_vector
|
|
355
341
|
dtype = torch.float32 if state['factored'] else p.dtype
|
|
356
342
|
device = p.device
|
|
357
343
|
|
|
@@ -362,23 +348,21 @@ class AdaMuon_adv(torch.optim.Optimizer):
|
|
|
362
348
|
state['mv_mbuf_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
|
|
363
349
|
packed_d2 = (d2 + 7) // 8
|
|
364
350
|
state['sign_buf'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=device)
|
|
351
|
+
state['shifter'] = torch.tensor([1, 2, 4, 8, 16, 32, 64, 128], device=device, dtype=torch.uint8)
|
|
365
352
|
if not group['normuon_variant']:
|
|
366
353
|
state['mu_vbuf_nmf'] = torch.zeros(d1, device=device, dtype=dtype)
|
|
367
354
|
state['mv_vbuf_nmf'] = torch.zeros(d2, device=device, dtype=dtype)
|
|
368
355
|
else:
|
|
369
356
|
# Determine effective state precision (small tensors always use fp32)
|
|
370
357
|
req_precision = group.get('state_precision', 'auto')
|
|
371
|
-
actual_precision = req_precision
|
|
372
|
-
if actual_precision != 'auto' and (p.numel() < 10000 or p.ndim == 1):
|
|
373
|
-
actual_precision = 'fp32'
|
|
358
|
+
actual_precision = 'auto' if req_precision == 'factored' else req_precision
|
|
374
359
|
group['actual_state_precision'] = actual_precision
|
|
375
360
|
|
|
376
361
|
# factored_2nd: factorize v_t only; ignored for NorMuon (no v_t) and tiny params
|
|
377
362
|
use_factored_2nd = (
|
|
378
363
|
group.get('factored_2nd', False)
|
|
379
364
|
and not group['normuon_variant']
|
|
380
|
-
and p.
|
|
381
|
-
and p.ndim > 1
|
|
365
|
+
and not (len(p.shape) == 1 and not group['vector_reshape'])
|
|
382
366
|
)
|
|
383
367
|
state['factored_2nd'] = use_factored_2nd
|
|
384
368
|
|
|
@@ -493,19 +477,22 @@ class AdaMuon_adv(torch.optim.Optimizer):
|
|
|
493
477
|
random_int_state_tensor = param_update._get_random_int_for_8bit_sr(p)
|
|
494
478
|
elif actual_precision == 'fp8_sr':
|
|
495
479
|
random_int_state_tensor = param_update._get_random_int_for_fp8_sr(p)
|
|
480
|
+
if group['low_rank_ortho']:
|
|
481
|
+
random_G_sketch = param_update._get_random_noise_for_low_rank_ortho(p, group['ortho_rank'])
|
|
496
482
|
else:
|
|
497
483
|
lr = group['lr']
|
|
498
|
-
muon_step_param = self._muon_step_parameter
|
|
499
484
|
random_int_state_tensor = None
|
|
485
|
+
random_G_sketch = None
|
|
486
|
+
muon_step_param = self._muon_step_parameter
|
|
500
487
|
|
|
501
|
-
muon_step_param(p, grad, state, group, lr, random_int_tensor, random_int_state_tensor)
|
|
488
|
+
muon_step_param(p, grad, state, group, lr, random_int_tensor, random_int_state_tensor, random_G_sketch)
|
|
502
489
|
|
|
503
490
|
def compile(self, *args, **kwargs):
|
|
504
491
|
self._compiled_muon_step_parameter = torch.compile(self._muon_step_parameter, *args, **kwargs)
|
|
505
492
|
self._compiled_adam_step_parameter = torch.compile(Muon_AuxAdam._adam_step_parameter, *args, **kwargs)
|
|
506
493
|
|
|
507
494
|
@torch.no_grad()
|
|
508
|
-
def _muon_step_parameter(self, p, grad, state, group, lr, random_int_tensor, random_int_state_tensor
|
|
495
|
+
def _muon_step_parameter(self, p, grad, state, group, lr, random_int_tensor, random_int_state_tensor, random_G_sketch):
|
|
509
496
|
# Upcast grad for low-precision state modes (non-factored path)
|
|
510
497
|
grad = upcast_grad_for_precision(grad, state, group.get('state_precision', 'auto'))
|
|
511
498
|
beta1, beta2 = group['betas']
|
|
@@ -523,14 +510,8 @@ class AdaMuon_adv(torch.optim.Optimizer):
|
|
|
523
510
|
else:
|
|
524
511
|
kappa_p = 1.0
|
|
525
512
|
|
|
526
|
-
|
|
527
|
-
|
|
528
|
-
ns_eps, adaptive_eps, _, _ = get_spectral_scaling(p, p.shape, group.get('n_layers', 1))
|
|
529
|
-
decoupled_wd = True
|
|
530
|
-
else:
|
|
531
|
-
decoupled_wd = False
|
|
532
|
-
ns_eps = group['ns_eps']
|
|
533
|
-
adaptive_eps = group['eps']
|
|
513
|
+
ns_eps = group['ns_eps']
|
|
514
|
+
adaptive_eps = scale_eps(group['eps'], p)
|
|
534
515
|
|
|
535
516
|
# MARS-M Approximated (Variance Reduction)
|
|
536
517
|
if group.get('approx_mars', False):
|
|
@@ -545,7 +526,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
|
|
|
545
526
|
grad_reshaped = grad.view(d1, d2)
|
|
546
527
|
|
|
547
528
|
# Reconstruct momentum from previous step's factors & sign
|
|
548
|
-
mt_buf = _reconstruct_state((state['mu_mbuf_nmf'], state['mv_mbuf_nmf'], state['sign_buf'], d2), signed=True)
|
|
529
|
+
mt_buf = _reconstruct_state((state['mu_mbuf_nmf'], state['mv_mbuf_nmf'], state['sign_buf'], d2), signed=True, shifter=state['shifter'])
|
|
549
530
|
|
|
550
531
|
# Update momentum in full-size
|
|
551
532
|
mt_buf.lerp_(grad_reshaped, 1 - beta1)
|
|
@@ -557,7 +538,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
|
|
|
557
538
|
update = mt_buf.clone()
|
|
558
539
|
|
|
559
540
|
# Factorize
|
|
560
|
-
state['mu_mbuf_nmf'], state['mv_mbuf_nmf'], state['sign_buf'] = _factorize_state(mt_buf, signed=True)
|
|
541
|
+
state['mu_mbuf_nmf'], state['mv_mbuf_nmf'], state['sign_buf'] = _factorize_state(mt_buf, signed=True, shifter=state['shifter'])
|
|
561
542
|
del mt_buf
|
|
562
543
|
|
|
563
544
|
# Apply update projection
|
|
@@ -573,7 +554,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
|
|
|
573
554
|
cns_a_bound=group['cns_a_bound'],
|
|
574
555
|
low_rank_ortho=group['low_rank_ortho'],
|
|
575
556
|
ortho_rank=group['ortho_rank'],
|
|
576
|
-
|
|
557
|
+
G_sketch=random_G_sketch,
|
|
577
558
|
compiled=group.get('compiled_optimizer', False)
|
|
578
559
|
)
|
|
579
560
|
|
|
@@ -581,10 +562,10 @@ class AdaMuon_adv(torch.optim.Optimizer):
|
|
|
581
562
|
normuon_update(update, state['normuon_v'], beta2, group['eps'])
|
|
582
563
|
else:
|
|
583
564
|
# Reconstruct second momentum from previous step's factors
|
|
584
|
-
vt_buf = _reconstruct_state((state['mu_vbuf_nmf'], state['mv_vbuf_nmf']), signed=False)
|
|
565
|
+
vt_buf = _reconstruct_state((state['mu_vbuf_nmf'], state['mv_vbuf_nmf']), signed=False, shifter=state['shifter'])
|
|
585
566
|
# Update second momentum in full-size
|
|
586
567
|
vt_buf.mul_(beta2).addcmul_(update, update, value=1 - beta2)
|
|
587
|
-
state['mu_vbuf_nmf'], state['mv_vbuf_nmf'] = _factorize_state(vt_buf, signed=False)
|
|
568
|
+
state['mu_vbuf_nmf'], state['mv_vbuf_nmf'] = _factorize_state(vt_buf, signed=False, shifter=state['shifter'])
|
|
588
569
|
# Apply second momentum update (adaptive scaling)
|
|
589
570
|
if group['use_atan2']:
|
|
590
571
|
denom = vt_buf.sqrt_()
|
|
@@ -629,7 +610,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
|
|
|
629
610
|
cns_a_bound=group['cns_a_bound'],
|
|
630
611
|
low_rank_ortho=group['low_rank_ortho'],
|
|
631
612
|
ortho_rank=group['ortho_rank'],
|
|
632
|
-
|
|
613
|
+
G_sketch=random_G_sketch,
|
|
633
614
|
compiled=group.get('compiled_optimizer', False)
|
|
634
615
|
)
|
|
635
616
|
|
|
@@ -641,9 +622,9 @@ class AdaMuon_adv(torch.optim.Optimizer):
|
|
|
641
622
|
d1, d2 = state['effective_shape']
|
|
642
623
|
update = update.view(original_shape)
|
|
643
624
|
update_f32 = update.float()
|
|
644
|
-
vt_buf = _reconstruct_state((state['mu_vbuf_nmf'], state['mv_vbuf_nmf']), signed=False)
|
|
625
|
+
vt_buf = _reconstruct_state((state['mu_vbuf_nmf'], state['mv_vbuf_nmf']), signed=False, shifter=state['shifter'])
|
|
645
626
|
vt_buf.mul_(beta2).addcmul_(update_f32.view(d1, d2), update_f32.view(d1, d2), value=1 - beta2)
|
|
646
|
-
state['mu_vbuf_nmf'], state['mv_vbuf_nmf'] = _factorize_state(vt_buf, signed=False)
|
|
627
|
+
state['mu_vbuf_nmf'], state['mv_vbuf_nmf'] = _factorize_state(vt_buf, signed=False, shifter=state['shifter'])
|
|
647
628
|
# Apply second moment scaling
|
|
648
629
|
if group['use_atan2']:
|
|
649
630
|
denom = vt_buf.sqrt_().view(original_shape)
|
|
@@ -678,7 +659,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
|
|
|
678
659
|
|
|
679
660
|
update = update.reshape(original_shape)
|
|
680
661
|
|
|
681
|
-
param_update.apply_parameter_update(self, p, group, update, lr, random_int_tensor=random_int_tensor
|
|
662
|
+
param_update.apply_parameter_update(self, p, group, update, lr, random_int_tensor=random_int_tensor)
|
|
682
663
|
|
|
683
664
|
@torch.no_grad()
|
|
684
665
|
def step(self, closure=None):
|