adv-optm 0.1.7__tar.gz → 0.1.9__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of adv-optm might be problematic. Click here for more details.
- adv_optm-0.1.9/PKG-INFO +174 -0
- adv_optm-0.1.9/README.md +143 -0
- {adv_optm-0.1.7 → adv_optm-0.1.9}/adv_optm/__init__.py +1 -1
- {adv_optm-0.1.7 → adv_optm-0.1.9}/adv_optm/optim/AdamW_adv.py +13 -4
- {adv_optm-0.1.7 → adv_optm-0.1.9}/adv_optm/optim/Adopt_adv.py +52 -13
- {adv_optm-0.1.7 → adv_optm-0.1.9}/adv_optm/optim/Lion_Prodigy_adv.py +3 -37
- {adv_optm-0.1.7 → adv_optm-0.1.9}/adv_optm/optim/Lion_adv.py +6 -39
- {adv_optm-0.1.7 → adv_optm-0.1.9}/adv_optm/optim/Prodigy_adv.py +76 -39
- adv_optm-0.1.9/adv_optm.egg-info/PKG-INFO +174 -0
- {adv_optm-0.1.7 → adv_optm-0.1.9}/setup.py +1 -1
- adv_optm-0.1.7/PKG-INFO +0 -130
- adv_optm-0.1.7/README.md +0 -99
- adv_optm-0.1.7/adv_optm.egg-info/PKG-INFO +0 -130
- {adv_optm-0.1.7 → adv_optm-0.1.9}/LICENSE +0 -0
- {adv_optm-0.1.7 → adv_optm-0.1.9}/adv_optm/optim/Simplified_AdEMAMix.py +0 -0
- {adv_optm-0.1.7 → adv_optm-0.1.9}/adv_optm/optim/__init__.py +0 -0
- {adv_optm-0.1.7 → adv_optm-0.1.9}/adv_optm/util/BF16_Stochastic_Rounding.py +0 -0
- {adv_optm-0.1.7 → adv_optm-0.1.9}/adv_optm/util/Effective_Shape.py +0 -0
- {adv_optm-0.1.7 → adv_optm-0.1.9}/adv_optm/util/NNMF.py +0 -0
- {adv_optm-0.1.7 → adv_optm-0.1.9}/adv_optm/util/One_Bit_Boolean.py +0 -0
- {adv_optm-0.1.7 → adv_optm-0.1.9}/adv_optm/util/OrthoGrad.py +0 -0
- {adv_optm-0.1.7 → adv_optm-0.1.9}/adv_optm/util/__init__.py +0 -0
- {adv_optm-0.1.7 → adv_optm-0.1.9}/adv_optm.egg-info/SOURCES.txt +0 -0
- {adv_optm-0.1.7 → adv_optm-0.1.9}/adv_optm.egg-info/dependency_links.txt +0 -0
- {adv_optm-0.1.7 → adv_optm-0.1.9}/adv_optm.egg-info/requires.txt +0 -0
- {adv_optm-0.1.7 → adv_optm-0.1.9}/adv_optm.egg-info/top_level.txt +0 -0
- {adv_optm-0.1.7 → adv_optm-0.1.9}/setup.cfg +0 -0
adv_optm-0.1.9/PKG-INFO
ADDED
|
@@ -0,0 +1,174 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: adv_optm
|
|
3
|
+
Version: 0.1.9
|
|
4
|
+
Summary: A family of highly efficient, lightweight yet powerful optimizers.
|
|
5
|
+
Home-page: https://github.com/Koratahiu/Advanced_Optimizers
|
|
6
|
+
Author: Koratahiu
|
|
7
|
+
Author-email: hiuhonor@gmail.com
|
|
8
|
+
License: Apache 2.0
|
|
9
|
+
Keywords: llm,fine-tuning,memory-efficient,low-rank,compression,pytorch,optimizer,adam
|
|
10
|
+
Classifier: Programming Language :: Python :: 3
|
|
11
|
+
Classifier: License :: OSI Approved :: Apache Software License
|
|
12
|
+
Classifier: Operating System :: OS Independent
|
|
13
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
14
|
+
Classifier: Topic :: Software Development :: Libraries :: Python Modules
|
|
15
|
+
Requires-Python: >=3.8
|
|
16
|
+
Description-Content-Type: text/markdown
|
|
17
|
+
License-File: LICENSE
|
|
18
|
+
Requires-Dist: torch>=2.0
|
|
19
|
+
Dynamic: author
|
|
20
|
+
Dynamic: author-email
|
|
21
|
+
Dynamic: classifier
|
|
22
|
+
Dynamic: description
|
|
23
|
+
Dynamic: description-content-type
|
|
24
|
+
Dynamic: home-page
|
|
25
|
+
Dynamic: keywords
|
|
26
|
+
Dynamic: license
|
|
27
|
+
Dynamic: license-file
|
|
28
|
+
Dynamic: requires-dist
|
|
29
|
+
Dynamic: requires-python
|
|
30
|
+
Dynamic: summary
|
|
31
|
+
|
|
32
|
+
# Advanced Optimizers (AIO)
|
|
33
|
+
|
|
34
|
+
A comprehensive, all-in-one collection of optimization algorithms for deep learning, designed for maximum efficiency, minimal memory footprint, and superior performance across diverse model architectures and training scenarios.
|
|
35
|
+
|
|
36
|
+
[adv_optm on PyPI](https://pypi.org/project/adv_optm/)
|
|
37
|
+
|
|
38
|
+
---
|
|
39
|
+
|
|
40
|
+
## 📦 Installation
|
|
41
|
+
|
|
42
|
+
```bash
|
|
43
|
+
pip install adv_optm
|
|
44
|
+
```
|
|
45
|
+
|
|
46
|
+
---
|
|
47
|
+
|
|
48
|
+
## 🧠 Core Innovations
|
|
49
|
+
|
|
50
|
+
This library integrates multiple state-of-the-art optimization techniques validated through extensive research and practical training, with 1-bit compression for optimizer states:
|
|
51
|
+
|
|
52
|
+
### **Memory-Efficient Optimization (SMMF-inspired)**
|
|
53
|
+
- **Paper**: [SMMF: Square-Matricized Momentum Factorization](https://arxiv.org/abs/2412.08894)
|
|
54
|
+
- **Approach**: Uses rank-1 non-negative matrix factorization with reconstruction cycle (factor → reconstruct → update → factor)
|
|
55
|
+
- **Innovation**:
|
|
56
|
+
- First moment split into **1-bit sign + absolute value**
|
|
57
|
+
- Final storage: **four factored vectors + one 1-bit sign state**
|
|
58
|
+
- Preserves Adam-like update quality with drastically reduced memory
|
|
59
|
+
|
|
60
|
+
---
|
|
61
|
+
|
|
62
|
+
## ⚡ Performance Characteristics
|
|
63
|
+
|
|
64
|
+
### Memory Efficiency (SDXL Model - 6.5GB)
|
|
65
|
+
| Optimizer | Memory Usage | Description |
|
|
66
|
+
|-----------|--------------|-------------|
|
|
67
|
+
| `Adopt_Factored` | 328 MB | 4 small vectors + 1-bit state |
|
|
68
|
+
| `Adopt_Factored + AdEMAMix` | 625 MB | 6 small vectors + two 1-bit states |
|
|
69
|
+
| `Simplified_AdEMAMix` | 328 MB | Same as standard factored (no extra state) |
|
|
70
|
+
|
|
71
|
+
### Speed Comparison (SDXL, Batch Size 4)
|
|
72
|
+
| Optimizer | Speed | Notes |
|
|
73
|
+
|-----------|-------|-------|
|
|
74
|
+
| `Adafactor` | ~8.5s/it | Baseline |
|
|
75
|
+
| `Adopt_Factored` | ~10s/it | +18% overhead from compression |
|
|
76
|
+
| `Adopt_Factored + AdEMAMix` | ~12s/it | +41% overhead (3 factored states) |
|
|
77
|
+
|
|
78
|
+
---
|
|
79
|
+
|
|
80
|
+
## 🧪 Available Optimizers
|
|
81
|
+
|
|
82
|
+
### Standard Optimizers (All support `factored=True/False`)
|
|
83
|
+
| Optimizer | Description | Best For |
|
|
84
|
+
|-----------|-------------|----------|
|
|
85
|
+
| `Adam_Adv` | Advanced Adam implementation | General purpose |
|
|
86
|
+
| `Adopt_Adv` | Adam-variant with independent beta2 | Stable training for small batch size regimes |
|
|
87
|
+
| `Prodigy_Adv` | Prodigy with D-Adaptation | Adam with automatic LR tuning |
|
|
88
|
+
| `Simplified_AdEMAMix` | Adam variant with accumulator momentum | Small/large batch training when tuned correctly |
|
|
89
|
+
| `Lion_Adv` | Advanced Lion implementation | Memory-constrained environments |
|
|
90
|
+
| `Prodigy_Lion_Adv` | Prodigy + Lion combination | Lion with automatic LR tuning |
|
|
91
|
+
|
|
92
|
+
### Feature Matrix
|
|
93
|
+
| Feature | Adam_Adv | Adopt_Adv | Prodigy_Adv | Simplified_AdEMAMix | Lion_Adv |
|
|
94
|
+
|---------|----------|-----------|-------------|---------------------|----------|
|
|
95
|
+
| Factored | ✓ | ✓ | ✓ | ✓ | ✓ |
|
|
96
|
+
| AdEMAMix | ✓ | ✓ | ✓ | ✗ | ✗ |
|
|
97
|
+
| Simplified_AdEMAMix | ✗ | ✓ | ✓ | ✓ | ✗ |
|
|
98
|
+
| OrthoGrad | ✓ | ✓ | ✓ | ✓ | ✓ |
|
|
99
|
+
| Grams | ✓ | ✓ | ✓ | ✗ | ✗ |
|
|
100
|
+
| Cautious | ✓ | ✓ | ✓ | ✗ | ✓ |
|
|
101
|
+
| atan2 | ✓ | ✓ | ✓ | ✗ | ✗ |
|
|
102
|
+
| Stochastic Rounding | ✓ | ✓ | ✓ | ✓ | ✓ |
|
|
103
|
+
| Fused Backward Pass | ✓ | ✓ | ✓ | ✓ | ✓ |
|
|
104
|
+
|
|
105
|
+
---
|
|
106
|
+
|
|
107
|
+
## ⚙️ Key Features & Parameters
|
|
108
|
+
|
|
109
|
+
### Comprehensive Feature Guide
|
|
110
|
+
|
|
111
|
+
| Feature | Description | Recommended Usage | Performance Impact | Theoretical Basis | Compatibility |
|
|
112
|
+
|---------|-------------|-------------------|--------------------|-------------------|--------------|
|
|
113
|
+
| **Factored** | Memory-efficient optimization using rank-1 factorization | Enable for large models (>1B params) or limited VRAM | +12-41% time overhead, 1-bit memory usage | [SMMF](https://arxiv.org/abs/2412.08894) | All optimizers |
|
|
114
|
+
| **AdEMAMix** | Dual EMA system for momentum | Use for long training runs (10k+ steps) | One extra optimizer state in memory | [AdEMAMix](https://arxiv.org/abs/2409.03137) | Adam/Adopt/Prodigy |
|
|
115
|
+
| **Simplified_AdEMAMix** | Accumulator-based momentum | Small batch training (≤32) | Same memory as standard, no extra overhead | [Schedule-Free Connections](https://arxiv.org/abs/2502.02431) | Adopt/Prodigy/Simplified_AdEMAMix |
|
|
116
|
+
| **OrthoGrad** | Removes gradient component parallel to weights | Full finetuning without weight decay | +33% time overhead, no memory impact | [Grokking at Edge](https://github.com/LucasPrietoAl/grokking-at-the-edge-of-numerical-stability) | All optimizers |
|
|
117
|
+
| **Stochastic Rounding** | Improves precision for BF16 training | BF16 training | Minimal overhead (<5%) | [Revisiting BFloat16 Training](https://arxiv.org/abs/2010.06192) | All optimizers |
|
|
118
|
+
| **atan2** | Robust eps replacement + built-in clipping | Use with Adopt or unstable training | No overhead | [Adam-atan2](https://github.com/lucidrains/adam-atan2-pytorch) | Adam/Adopt/Prodigy |
|
|
119
|
+
| **Cautious** | Applies the update only where its direction aligns with the gradient | Should speed up convergence | No overhead | [C-Optim](https://github.com/kyleliang919/C-Optim) | Adam/Adopt/Prodigy/Lion |
|
|
120
|
+
| **Grams** | Takes the update direction from the gradient sign while keeping the momentum magnitude (mutually exclusive with Cautious) | Should have a stronger effect than Cautious | No overhead | [Grams](https://github.com/Gunale0926/Grams) | Adam/Adopt/Prodigy |
|
|
121
|
+
|
|
122
|
+
---
|
|
123
|
+
|
|
124
|
+
## Simplified_AdEMAMix Parameters
|
|
125
|
+
Simplified_AdEMAMix replaces the standard momentum EMA with an accumulator, which can improve performance at both small and large batch sizes when tuned correctly.
|
|
126
|
+
|
|
127
|
+
| Parameter | Recommended Values | Description |
|
|
128
|
+
|-----------|---------------------|-------------|
|
|
129
|
+
| `beta1` | 0.9 (large BS), 0.99-0.9999 (small BS) | Determines memory length of accumulator |
|
|
130
|
+
| `alpha` | 10–100 (small BS), 0–1 (large BS) | Weight of the current gradient in the update |
|
|
131
|
+
|
|
132
|
+
**Alpha Tuning Guide**:
|
|
133
|
+
| Batch Size | Recommended α | Rationale |
|
|
134
|
+
|------------|---------------|-----------|
|
|
135
|
+
| Small (≤32) | 100, 50, 20, 10 | Emphasizes recent gradients for quick adaptation |
|
|
136
|
+
| Medium (32-512) | 10, 5, 2, 1 | Balanced approach |
|
|
137
|
+
| Large (≥512) | 1, 0.5, 0 | Emphasizes historical gradients for stability |
|
|
138
|
+
|
|
139
|
+
⚠️ **Important**: Use **~100x smaller learning rate** with Simplified_AdEMAMix compared to AdamW (e.g., 1e-6 instead of 1e-4)
|
|
140
|
+
|
|
141
|
+
### 📊 Performance Validation
|
|
142
|
+
Small Batch Training (SDXL, BS=2, 1.8K steps)
|
|
143
|
+

|
|
144
|
+
|
|
145
|
+
- **🟢 Prodigy_adv** (beta1=0.9, d0=1e-5): Final LR=2.9e-4
|
|
146
|
+
- **🔵 Prodigy_adv + Simplified_AdEMAMix** (beta1=0.99, α=100, d0=1e-7): Final LR=5.8e-6
|
|
147
|
+
|
|
148
|
+
**Results**:
|
|
149
|
+
- Simplified_AdEMAMix shows faster convergence and better final performance
|
|
150
|
+
- D-Adaptation automatically compensates for the more aggressive updates by settling on a ~50x smaller learning rate (5.8e-6 vs 2.9e-4)
|
|
151
|
+
- Generated samples show significantly better quality with Simplified_AdEMAMix
|
|
152
|
+
|
|
153
|
+
---
|
|
154
|
+
|
|
155
|
+
## ⚠️ Known Limitations
|
|
156
|
+
|
|
157
|
+
### 1. Prodigy_Adv Sensitivity
|
|
158
|
+
- Highly sensitive to gradient modifications (Adopt normalization, low-rank factorization)
|
|
159
|
+
- May fail to increase learning rate in some LoRA scenarios
|
|
160
|
+
- **Fix**: Disable factorization or set beta1=0
|
|
161
|
+
|
|
162
|
+
### 2. Aggressive Learning Rates
|
|
163
|
+
- Can destabilize the factored first moment
|
|
164
|
+
- **Recommendation**: Use the learning rate that Prodigy settles on as a reference for a safe LR threshold
|
|
165
|
+
|
|
166
|
+
---
|
|
167
|
+
|
|
168
|
+
## 📚 References
|
|
169
|
+
|
|
170
|
+
1. [SMMF: Square-Matricized Momentum Factorization](https://arxiv.org/abs/2412.08894)
|
|
171
|
+
2. [The AdEMAMix Optimizer: Better, Faster, Older](https://arxiv.org/abs/2409.03137)
|
|
172
|
+
3. [Connections between Schedule-Free Optimizers, AdEMAMix, and Accelerated SGD Variants](https://arxiv.org/abs/2502.02431)
|
|
173
|
+
|
|
174
|
+
---
|
adv_optm-0.1.9/README.md
ADDED
|
@@ -0,0 +1,143 @@
|
|
|
1
|
+
# Advanced Optimizers (AIO)
|
|
2
|
+
|
|
3
|
+
A comprehensive, all-in-one collection of optimization algorithms for deep learning, designed for maximum efficiency, minimal memory footprint, and superior performance across diverse model architectures and training scenarios.
|
|
4
|
+
|
|
5
|
+
[adv_optm on PyPI](https://pypi.org/project/adv_optm/)
|
|
6
|
+
|
|
7
|
+
---
|
|
8
|
+
|
|
9
|
+
## 📦 Installation
|
|
10
|
+
|
|
11
|
+
```bash
|
|
12
|
+
pip install adv_optm
|
|
13
|
+
```
|
|
14
|
+
|
|
15
|
+
---
|
|
16
|
+
|
|
17
|
+
## 🧠 Core Innovations
|
|
18
|
+
|
|
19
|
+
This library integrates multiple state-of-the-art optimization techniques validated through extensive research and practical training, with 1-bit compression for optimizer states:
|
|
20
|
+
|
|
21
|
+
### **Memory-Efficient Optimization (SMMF-inspired)**
|
|
22
|
+
- **Paper**: [SMMF: Square-Matricized Momentum Factorization](https://arxiv.org/abs/2412.08894)
|
|
23
|
+
- **Approach**: Uses rank-1 non-negative matrix factorization with reconstruction cycle (factor → reconstruct → update → factor)
|
|
24
|
+
- **Innovation**:
|
|
25
|
+
- First moment split into **1-bit sign + absolute value**
|
|
26
|
+
- Final storage: **four factored vectors + one 1-bit sign state**
|
|
27
|
+
- Preserves Adam-like update quality with drastically reduced memory
|
|
28
|
+
|
|
29
|
+
---
|
|
30
|
+
|
|
31
|
+
## ⚡ Performance Characteristics
|
|
32
|
+
|
|
33
|
+
### Memory Efficiency (SDXL Model - 6.5GB)
|
|
34
|
+
| Optimizer | Memory Usage | Description |
|
|
35
|
+
|-----------|--------------|-------------|
|
|
36
|
+
| `Adopt_Factored` | 328 MB | 4 small vectors + 1-bit state |
|
|
37
|
+
| `Adopt_Factored + AdEMAMix` | 625 MB | 6 small vectors + two 1-bit states |
|
|
38
|
+
| `Simplified_AdEMAMix` | 328 MB | Same as standard factored (no extra state) |
|
|
39
|
+
|
|
40
|
+
### Speed Comparison (SDXL, Batch Size 4)
|
|
41
|
+
| Optimizer | Speed | Notes |
|
|
42
|
+
|-----------|-------|-------|
|
|
43
|
+
| `Adafactor` | ~8.5s/it | Baseline |
|
|
44
|
+
| `Adopt_Factored` | ~10s/it | +18% overhead from compression |
|
|
45
|
+
| `Adopt_Factored + AdEMAMix` | ~12s/it | +41% overhead (3 factored states) |
|
|
46
|
+
|
|
47
|
+
---
|
|
48
|
+
|
|
49
|
+
## 🧪 Available Optimizers
|
|
50
|
+
|
|
51
|
+
### Standard Optimizers (All support `factored=True/False`)
|
|
52
|
+
| Optimizer | Description | Best For |
|
|
53
|
+
|-----------|-------------|----------|
|
|
54
|
+
| `Adam_Adv` | Advanced Adam implementation | General purpose |
|
|
55
|
+
| `Adopt_Adv` | Adam-variant with independent beta2 | Stable training for small batch size regimes |
|
|
56
|
+
| `Prodigy_Adv` | Prodigy with D-Adaptation | Adam with automatic LR tuning |
|
|
57
|
+
| `Simplified_AdEMAMix` | Adam variant with accumulator momentum | Small/large batch training when tuned correctly |
|
|
58
|
+
| `Lion_Adv` | Advanced Lion implementation | Memory-constrained environments |
|
|
59
|
+
| `Prodigy_Lion_Adv` | Prodigy + Lion combination | Lion with automatic LR tuning |
|
|
60
|
+
|
|
61
|
+
### Feature Matrix
|
|
62
|
+
| Feature | Adam_Adv | Adopt_Adv | Prodigy_Adv | Simplified_AdEMAMix | Lion_Adv |
|
|
63
|
+
|---------|----------|-----------|-------------|---------------------|----------|
|
|
64
|
+
| Factored | ✓ | ✓ | ✓ | ✓ | ✓ |
|
|
65
|
+
| AdEMAMix | ✓ | ✓ | ✓ | ✗ | ✗ |
|
|
66
|
+
| Simplified_AdEMAMix | ✗ | ✓ | ✓ | ✓ | ✗ |
|
|
67
|
+
| OrthoGrad | ✓ | ✓ | ✓ | ✓ | ✓ |
|
|
68
|
+
| Grams | ✓ | ✓ | ✓ | ✗ | ✗ |
|
|
69
|
+
| Cautious | ✓ | ✓ | ✓ | ✗ | ✓ |
|
|
70
|
+
| atan2 | ✓ | ✓ | ✓ | ✗ | ✗ |
|
|
71
|
+
| Stochastic Rounding | ✓ | ✓ | ✓ | ✓ | ✓ |
|
|
72
|
+
| Fused Backward Pass | ✓ | ✓ | ✓ | ✓ | ✓ |
|
|
73
|
+
|
|
74
|
+
---
|
|
75
|
+
|
|
76
|
+
## ⚙️ Key Features & Parameters
|
|
77
|
+
|
|
78
|
+
### Comprehensive Feature Guide
|
|
79
|
+
|
|
80
|
+
| Feature | Description | Recommended Usage | Performance Impact | Theoretical Basis | Compatibility |
|
|
81
|
+
|---------|-------------|-------------------|--------------------|-------------------|--------------|
|
|
82
|
+
| **Factored** | Memory-efficient optimization using rank-1 factorization | Enable for large models (>1B params) or limited VRAM | +12-41% time overhead, 1-bit memory usage | [SMMF](https://arxiv.org/abs/2412.08894) | All optimizers |
|
|
83
|
+
| **AdEMAMix** | Dual EMA system for momentum | Use for long training runs (10k+ steps) | One extra optimizer state in memory | [AdEMAMix](https://arxiv.org/abs/2409.03137) | Adam/Adopt/Prodigy |
|
|
84
|
+
| **Simplified_AdEMAMix** | Accumulator-based momentum | Small batch training (≤32) | Same memory as standard, no extra overhead | [Schedule-Free Connections](https://arxiv.org/abs/2502.02431) | Adopt/Prodigy/Simplified_AdEMAMix |
|
|
85
|
+
| **OrthoGrad** | Removes gradient component parallel to weights | Full finetuning without weight decay | +33% time overhead, no memory impact | [Grokking at Edge](https://github.com/LucasPrietoAl/grokking-at-the-edge-of-numerical-stability) | All optimizers |
|
|
86
|
+
| **Stochastic Rounding** | Improves precision for BF16 training | BF16 training | Minimal overhead (<5%) | [Revisiting BFloat16 Training](https://arxiv.org/abs/2010.06192) | All optimizers |
|
|
87
|
+
| **atan2** | Robust eps replacement + built-in clipping | Use with Adopt or unstable training | No overhead | [Adam-atan2](https://github.com/lucidrains/adam-atan2-pytorch) | Adam/Adopt/Prodigy |
|
|
88
|
+
| **Cautious** | Applies the update only where its direction aligns with the gradient | Should speed up convergence | No overhead | [C-Optim](https://github.com/kyleliang919/C-Optim) | Adam/Adopt/Prodigy/Lion |
|
|
89
|
+
| **Grams** | Takes the update direction from the gradient sign while keeping the momentum magnitude (mutually exclusive with Cautious) | Should have a stronger effect than Cautious | No overhead | [Grams](https://github.com/Gunale0926/Grams) | Adam/Adopt/Prodigy |
|
|
90
|
+
|
|
91
|
+
---
|
|
92
|
+
|
|
93
|
+
## Simplified_AdEMAMix Parameters
|
|
94
|
+
Simplified_AdEMAMix replaces the standard momentum EMA with an accumulator, which can improve performance at both small and large batch sizes when tuned correctly.
|
|
95
|
+
|
|
96
|
+
| Parameter | Recommended Values | Description |
|
|
97
|
+
|-----------|---------------------|-------------|
|
|
98
|
+
| `beta1` | 0.9 (large BS), 0.99-0.9999 (small BS) | Determines memory length of accumulator |
|
|
99
|
+
| `alpha` | 10–100 (small BS), 0–1 (large BS) | Weight of the current gradient in the update |
|
|
100
|
+
|
|
101
|
+
**Alpha Tuning Guide**:
|
|
102
|
+
| Batch Size | Recommended α | Rationale |
|
|
103
|
+
|------------|---------------|-----------|
|
|
104
|
+
| Small (≤32) | 100, 50, 20, 10 | Emphasizes recent gradients for quick adaptation |
|
|
105
|
+
| Medium (32-512) | 10, 5, 2, 1 | Balanced approach |
|
|
106
|
+
| Large (≥512) | 1, 0.5, 0 | Emphasizes historical gradients for stability |
|
|
107
|
+
|
|
108
|
+
⚠️ **Important**: Use **~100x smaller learning rate** with Simplified_AdEMAMix compared to AdamW (e.g., 1e-6 instead of 1e-4)
|
|
109
|
+
|
|
110
|
+
### 📊 Performance Validation
|
|
111
|
+
Small Batch Training (SDXL, BS=2, 1.8K steps)
|
|
112
|
+

|
|
113
|
+
|
|
114
|
+
- **🟢 Prodigy_adv** (beta1=0.9, d0=1e-5): Final LR=2.9e-4
|
|
115
|
+
- **🔵 Prodigy_adv + Simplified_AdEMAMix** (beta1=0.99, α=100, d0=1e-7): Final LR=5.8e-6
|
|
116
|
+
|
|
117
|
+
**Results**:
|
|
118
|
+
- Simplified_AdEMAMix shows faster convergence and better final performance
|
|
119
|
+
- D-Adaptation automatically compensates for the more aggressive updates by settling on a ~50x smaller learning rate (5.8e-6 vs 2.9e-4)
|
|
120
|
+
- Generated samples show significantly better quality with Simplified_AdEMAMix
|
|
121
|
+
|
|
122
|
+
---
|
|
123
|
+
|
|
124
|
+
## ⚠️ Known Limitations
|
|
125
|
+
|
|
126
|
+
### 1. Prodigy_Adv Sensitivity
|
|
127
|
+
- Highly sensitive to gradient modifications (Adopt normalization, low-rank factorization)
|
|
128
|
+
- May fail to increase learning rate in some LoRA scenarios
|
|
129
|
+
- **Fix**: Disable factorization or set beta1=0
|
|
130
|
+
|
|
131
|
+
### 2. Aggressive Learning Rates
|
|
132
|
+
- Can destabilize the factored first moment
|
|
133
|
+
- **Recommendation**: Use the learning rate that Prodigy settles on as a reference for a safe LR threshold
|
|
134
|
+
|
|
135
|
+
---
|
|
136
|
+
|
|
137
|
+
## 📚 References
|
|
138
|
+
|
|
139
|
+
1. [SMMF: Square-Matricized Momentum Factorization](https://arxiv.org/abs/2412.08894)
|
|
140
|
+
2. [The AdEMAMix Optimizer: Better, Faster, Older](https://arxiv.org/abs/2409.03137)
|
|
141
|
+
3. [Connections between Schedule-Free Optimizers, AdEMAMix, and Accelerated SGD Variants](https://arxiv.org/abs/2502.02431)
|
|
142
|
+
|
|
143
|
+
---
|
|
@@ -55,7 +55,7 @@ class AdamW_adv(torch.optim.Optimizer):
|
|
|
55
55
|
the warmup, `alpha` ramps from 0 to its target value. If `None`,
|
|
56
56
|
the scheduler is disabled. (default: None)
|
|
57
57
|
factored (bool): whether to use the factorization or disable it to use
|
|
58
|
-
the uncompressed optimizer. (default:
|
|
58
|
+
the uncompressed optimizer. (default: False)
|
|
59
59
|
"""
|
|
60
60
|
|
|
61
61
|
def __init__(
|
|
@@ -76,7 +76,7 @@ class AdamW_adv(torch.optim.Optimizer):
|
|
|
76
76
|
beta3_ema: float = 0.9999,
|
|
77
77
|
alpha: float = 5.0,
|
|
78
78
|
t_alpha: int | None = None,
|
|
79
|
-
factored: bool =
|
|
79
|
+
factored: bool = False,
|
|
80
80
|
):
|
|
81
81
|
if not (lr >= 0.0):
|
|
82
82
|
raise ValueError(f"Learning-rate should be >= 0.0. Got {lr}")
|
|
@@ -86,6 +86,9 @@ class AdamW_adv(torch.optim.Optimizer):
|
|
|
86
86
|
raise ValueError(f"Epsilon should be >= 0.0. Got {eps}")
|
|
87
87
|
if not (weight_decay >= 0.0):
|
|
88
88
|
raise ValueError(f"Weight-decay should be >= 0.0. Got {weight_decay}")
|
|
89
|
+
if use_cautious and use_grams:
|
|
90
|
+
print("Warning: use_cautious is incompatible with use_grams, Disabling use_cautious.")
|
|
91
|
+
use_cautious = False
|
|
89
92
|
|
|
90
93
|
defaults = {
|
|
91
94
|
"lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay,
|
|
@@ -216,7 +219,10 @@ class AdamW_adv(torch.optim.Optimizer):
|
|
|
216
219
|
del unpacked_sign_slow
|
|
217
220
|
|
|
218
221
|
mt_slow.mul_(beta3_ema).add_(grad_reshaped, alpha=1.0 - beta3_ema)
|
|
219
|
-
|
|
222
|
+
if beta1 > 0:
|
|
223
|
+
update = torch.add(mt, mt_slow, alpha=alpha_t)
|
|
224
|
+
else:
|
|
225
|
+
update = torch.add(grad_reshaped, mt_slow, alpha=alpha_t)
|
|
220
226
|
else:
|
|
221
227
|
update = mt.clone() if beta1 > 0 else grad_reshaped.clone()
|
|
222
228
|
del grad_reshaped
|
|
@@ -262,7 +268,10 @@ class AdamW_adv(torch.optim.Optimizer):
|
|
|
262
268
|
if self.use_AdEMAMix:
|
|
263
269
|
exp_avg_slow = state['exp_avg_slow']
|
|
264
270
|
exp_avg_slow.mul_(beta3_ema).add_(grad, alpha=1 - beta3_ema)
|
|
265
|
-
|
|
271
|
+
if beta1 > 0:
|
|
272
|
+
update = torch.add(exp_avg, exp_avg_slow, alpha=alpha_t)
|
|
273
|
+
else:
|
|
274
|
+
update = torch.add(grad, exp_avg_slow, alpha=alpha_t)
|
|
266
275
|
else:
|
|
267
276
|
update = exp_avg.clone() if beta1 > 0 else grad.clone()
|
|
268
277
|
|
|
@@ -62,8 +62,18 @@ class Adopt_adv(torch.optim.Optimizer):
|
|
|
62
62
|
the warmup, `alpha` ramps from 0 to its target value. If `None`,
|
|
63
63
|
the scheduler is disabled and the full `alpha` value is used from
|
|
64
64
|
the start. (default: None)
|
|
65
|
+
Simplified_AdEMAMix (bool): whether to use the Simplified AdEMAMix update rule.
|
|
66
|
+
This changes the EMA to accumulator and the update numerator to `alpha_grad * grad + mt`, which can be
|
|
67
|
+
more responsive, especially for small batch sizes. Enabling this will
|
|
68
|
+
automatically disable `use_AdEMAMix`, `use_cautious`, `use_grams`,
|
|
69
|
+
and `use_atan2`. (default: False)
|
|
70
|
+
alpha_grad (float): Mixing coefficient for the Simplified AdEMAMix update rule
|
|
71
|
+
(only used when `Simplified_AdEMAMix` is `True`). Controls the weight of the
|
|
72
|
+
current gradient. For small batch sizes, use high values (e.g., 10-100) to be
|
|
73
|
+
more responsive. For large batch sizes, use low values (e.g., 0-1) for
|
|
74
|
+
stability. (default: 100.0)
|
|
65
75
|
factored (bool): whether to use the factorization or disable it to use
|
|
66
|
-
the uncompressed optimizer. (default:
|
|
76
|
+
the uncompressed optimizer. (default: False)
|
|
67
77
|
"""
|
|
68
78
|
|
|
69
79
|
def __init__(
|
|
@@ -77,14 +87,16 @@ class Adopt_adv(torch.optim.Optimizer):
|
|
|
77
87
|
vector_reshape: bool = True,
|
|
78
88
|
stochastic_rounding: bool = True,
|
|
79
89
|
use_atan2: bool = False,
|
|
80
|
-
use_cautious: bool =
|
|
90
|
+
use_cautious: bool = False,
|
|
81
91
|
use_grams: bool = False,
|
|
82
92
|
use_orthograd: bool = False,
|
|
83
93
|
use_AdEMAMix: bool = False,
|
|
84
94
|
beta3_ema: float = 0.9999,
|
|
85
95
|
alpha: float = 5.0,
|
|
86
96
|
t_alpha: int | None = None,
|
|
87
|
-
|
|
97
|
+
Simplified_AdEMAMix: bool = False,
|
|
98
|
+
alpha_grad: float = 100.0,
|
|
99
|
+
factored: bool = False,
|
|
88
100
|
):
|
|
89
101
|
if not (lr >= 0.0):
|
|
90
102
|
raise ValueError(f"Learning-rate should be >= 0.0. Got {lr}")
|
|
@@ -94,19 +106,34 @@ class Adopt_adv(torch.optim.Optimizer):
|
|
|
94
106
|
raise ValueError(f"Epsilon should be >= 0.0. Got {eps}")
|
|
95
107
|
if not (weight_decay >= 0.0):
|
|
96
108
|
raise ValueError(f"Weight-decay should be >= 0.0. Got {weight_decay}")
|
|
109
|
+
if use_cautious and use_grams:
|
|
110
|
+
print("Warning: use_cautious is incompatible with use_grams, Disabling use_cautious.")
|
|
111
|
+
use_cautious = False
|
|
112
|
+
if betas[0] == 0.0 and Simplified_AdEMAMix:
|
|
113
|
+
raise ValueError(f"Beta1 cannot be 0.0 when using Simplified_AdEMAMix. Got {betas[0]}")
|
|
114
|
+
if use_AdEMAMix and Simplified_AdEMAMix:
|
|
115
|
+
print("Warning: use_AdEMAMix is incompatible with Simplified_AdEMAMix, Disabling use_AdEMAMix.")
|
|
116
|
+
if use_grams and Simplified_AdEMAMix:
|
|
117
|
+
print("Warning: use_grams is incompatible with Simplified_AdEMAMix, Disabling use_grams.")
|
|
118
|
+
if use_cautious and Simplified_AdEMAMix:
|
|
119
|
+
print("Warning: use_cautious is incompatible with Simplified_AdEMAMix, Disabling use_cautious.")
|
|
120
|
+
if use_atan2 and Simplified_AdEMAMix:
|
|
121
|
+
print("Warning: use_atan2 is incompatible with Simplified_AdEMAMix. Disabling use_atan2.")
|
|
122
|
+
use_atan2 = False
|
|
97
123
|
|
|
98
124
|
defaults = {
|
|
99
125
|
"lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay,
|
|
100
126
|
"vector_reshape": vector_reshape, "beta3_ema": beta3_ema, "alpha": alpha,
|
|
101
|
-
"t_alpha": t_alpha,
|
|
127
|
+
"t_alpha": t_alpha, "alpha_grad": alpha_grad,
|
|
102
128
|
}
|
|
103
129
|
self.clip_lambda = clip_lambda
|
|
104
130
|
self.stochastic_rounding = stochastic_rounding
|
|
105
|
-
self.use_atan2 = use_atan2
|
|
106
|
-
self.use_cautious = use_cautious
|
|
107
|
-
self.use_grams = use_grams
|
|
131
|
+
self.use_atan2 = use_atan2 and not Simplified_AdEMAMix
|
|
132
|
+
self.use_cautious = use_cautious and not Simplified_AdEMAMix
|
|
133
|
+
self.use_grams = use_grams and not Simplified_AdEMAMix
|
|
108
134
|
self.use_orthograd = use_orthograd
|
|
109
|
-
self.use_AdEMAMix = use_AdEMAMix
|
|
135
|
+
self.use_AdEMAMix = use_AdEMAMix and not Simplified_AdEMAMix
|
|
136
|
+
self.Simplified_AdEMAMix = Simplified_AdEMAMix
|
|
110
137
|
self.factored = factored
|
|
111
138
|
super().__init__(params, defaults)
|
|
112
139
|
|
|
@@ -185,6 +212,8 @@ class Adopt_adv(torch.optim.Optimizer):
|
|
|
185
212
|
alpha_t = alpha
|
|
186
213
|
if t_alpha is not None and t_alpha > 0 and current_step < t_alpha:
|
|
187
214
|
alpha_t = min(current_step * alpha / t_alpha, alpha)
|
|
215
|
+
if self.Simplified_AdEMAMix:
|
|
216
|
+
alpha_grad = group["alpha_grad"]
|
|
188
217
|
|
|
189
218
|
if state['factored']:
|
|
190
219
|
d1, d2 = state['effective_shape']
|
|
@@ -224,7 +253,10 @@ class Adopt_adv(torch.optim.Optimizer):
|
|
|
224
253
|
del denom
|
|
225
254
|
|
|
226
255
|
# ADOPT Step B: Update momentum m_t using normalized gradient
|
|
227
|
-
|
|
256
|
+
if self.Simplified_AdEMAMix:
|
|
257
|
+
mt.mul_(beta1).add_(normalized_grad, alpha=1.0)
|
|
258
|
+
else:
|
|
259
|
+
mt.mul_(beta1).add_(normalized_grad, alpha=1.0 - beta1)
|
|
228
260
|
if self.use_grams:
|
|
229
261
|
mt = grad_reshaped.sign() * mt.abs()
|
|
230
262
|
elif self.use_cautious:
|
|
@@ -235,8 +267,10 @@ class Adopt_adv(torch.optim.Optimizer):
|
|
|
235
267
|
|
|
236
268
|
if self.use_AdEMAMix:
|
|
237
269
|
mt_slow.mul_(beta3_ema).add_(normalized_grad, alpha=1.0 - beta3_ema)
|
|
238
|
-
update = mt
|
|
270
|
+
update = torch.add(mt, m_slow, alpha=alpha_t)
|
|
239
271
|
update = update.view(p.shape)
|
|
272
|
+
elif self.Simplified_AdEMAMix:
|
|
273
|
+
update = torch.add(mt, grad_reshaped, alpha=alpha_grad)
|
|
240
274
|
else:
|
|
241
275
|
update = mt.view(p.shape)
|
|
242
276
|
|
|
@@ -283,7 +317,10 @@ class Adopt_adv(torch.optim.Optimizer):
|
|
|
283
317
|
del denom
|
|
284
318
|
|
|
285
319
|
# ADOPT Step B: Update momentum m_t
|
|
286
|
-
|
|
320
|
+
if self.Simplified_AdEMAMix:
|
|
321
|
+
m.mul_(beta1).add_(normalized_grad, alpha=1.0)
|
|
322
|
+
else:
|
|
323
|
+
m.mul_(beta1).add_(normalized_grad, alpha=1.0 - beta1)
|
|
287
324
|
|
|
288
325
|
if self.use_grams:
|
|
289
326
|
m = grad.sign() * m.abs()
|
|
@@ -295,9 +332,11 @@ class Adopt_adv(torch.optim.Optimizer):
|
|
|
295
332
|
|
|
296
333
|
if self.use_AdEMAMix:
|
|
297
334
|
m_slow.mul_(beta3_ema).add_(normalized_grad, alpha=1.0 - beta3_ema)
|
|
298
|
-
update = m
|
|
335
|
+
update = torch.add(m, m_slow, alpha=alpha_t)
|
|
336
|
+
elif self.Simplified_AdEMAMix:
|
|
337
|
+
update = torch.add(m, grad, alpha=alpha_grad)
|
|
299
338
|
else:
|
|
300
|
-
update = m
|
|
339
|
+
update = m.clone()
|
|
301
340
|
|
|
302
341
|
if self.use_atan2:
|
|
303
342
|
update.mul_(group['lr'] * 1.2732395447351628)
|
|
@@ -33,8 +33,6 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
|
|
|
33
33
|
(default: 0.0).
|
|
34
34
|
factored (bool): whether to use the factorization or use the
|
|
35
35
|
uncompressed optimizer. (default: True)
|
|
36
|
-
variance_reduction (bool): whether to use the variance reduction technique
|
|
37
|
-
from "Convergence Analysis of the Lion Optimizer" (arXiv:2508.12327v1). (default: False).
|
|
38
36
|
d0 (float):
|
|
39
37
|
Initial D estimate for D-adaptation (default 1e-6). Rarely needs changing.
|
|
40
38
|
d_coef (float):
|
|
@@ -66,7 +64,6 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
|
|
|
66
64
|
use_cautious: bool = False,
|
|
67
65
|
clip_threshold: float = 0.0,
|
|
68
66
|
factored: bool = True,
|
|
69
|
-
variance_reduction: bool = False,
|
|
70
67
|
# prodigy parameters
|
|
71
68
|
beta3: float = None,
|
|
72
69
|
d0: float = 1e-6,
|
|
@@ -97,7 +94,6 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
|
|
|
97
94
|
self.stochastic_rounding = stochastic_rounding
|
|
98
95
|
self.use_cautious = use_cautious
|
|
99
96
|
self.factored = factored
|
|
100
|
-
self.variance_reduction = variance_reduction
|
|
101
97
|
self.fsdp_in_use = fsdp_in_use
|
|
102
98
|
super().__init__(params, defaults)
|
|
103
99
|
# Global state for accumulating metrics across parameter updates within a single step.
|
|
@@ -183,12 +179,8 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
|
|
|
183
179
|
state['mv_m_nmf'] = torch.zeros(d2, device=p.device, dtype=dtype)
|
|
184
180
|
packed_d2 = (d2 + 7) // 8
|
|
185
181
|
state['sign'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=p.device)
|
|
186
|
-
if self.variance_reduction:
|
|
187
|
-
state['prev_grad'] = torch.zeros((d1, d2), device=p.device, dtype=dtype)
|
|
188
182
|
else: # Fallback to standard Lion
|
|
189
183
|
state['exp_avg'] = torch.zeros_like(p, device=p.device, dtype=dtype)
|
|
190
|
-
if self.variance_reduction:
|
|
191
|
-
state['prev_grad'] = torch.zeros_like(p, device=p.device, dtype=dtype)
|
|
192
184
|
|
|
193
185
|
if state['factored']:
|
|
194
186
|
# Factored Path
|
|
@@ -215,20 +207,7 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
|
|
|
215
207
|
update_for_param = signed_update.view(p.shape).mul(self.dlr)
|
|
216
208
|
|
|
217
209
|
# Update momentum m_t = β2*m_{t-1} + (1-β2)*lr*g_t
|
|
218
|
-
|
|
219
|
-
if state['step'] == 1:
|
|
220
|
-
exp_avg.copy_(grad_reshaped)
|
|
221
|
-
else:
|
|
222
|
-
# Heuristic Prodigy-STORM update
|
|
223
|
-
correction = exp_avg.sub(state['prev_grad'])
|
|
224
|
-
grad_alpha = self.d * (1 - self.beta2) + self.beta2
|
|
225
|
-
exp_avg.copy_(grad_reshaped).mul_(grad_alpha).add_(correction, alpha=self.beta2)
|
|
226
|
-
del correction, grad_alpha
|
|
227
|
-
state['prev_grad'].copy_(grad_reshaped)
|
|
228
|
-
else:
|
|
229
|
-
# Standard Prodigy-Lion
|
|
230
|
-
alpha = self.d * (1 - self.beta2)
|
|
231
|
-
exp_avg.mul_(self.beta2).add_(grad_reshaped, alpha=alpha)
|
|
210
|
+
exp_avg.mul_(self.beta2).add_(grad_reshaped, alpha=self.d * (1 - self.beta2))
|
|
232
211
|
del grad_reshaped
|
|
233
212
|
|
|
234
213
|
# Compress new momentum m_t and store factors
|
|
@@ -254,20 +233,7 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
|
|
|
254
233
|
update_for_param = signed_update.mul(self.dlr)
|
|
255
234
|
|
|
256
235
|
# Update momentum
|
|
257
|
-
|
|
258
|
-
if state['step'] == 1:
|
|
259
|
-
exp_avg.copy_(grad)
|
|
260
|
-
else:
|
|
261
|
-
# Heuristic Prodigy-STORM update
|
|
262
|
-
correction = exp_avg.sub(state['prev_grad'])
|
|
263
|
-
grad_alpha = self.d * (1 - self.beta2) + self.beta2
|
|
264
|
-
exp_avg.copy_(grad).mul_(grad_alpha).add_(correction, alpha=self.beta2)
|
|
265
|
-
del grad_alpha, correction
|
|
266
|
-
state['prev_grad'].copy_(grad)
|
|
267
|
-
else:
|
|
268
|
-
# Standard Prodigy-Lion
|
|
269
|
-
alpha = self.d * (1 - self.beta2)
|
|
270
|
-
exp_avg.mul_(self.beta2).add_(grad, alpha=alpha)
|
|
236
|
+
exp_avg.mul_(self.beta2).add_(grad, alpha=self.d * (1 - self.beta2))
|
|
271
237
|
|
|
272
238
|
# --- Accumulate Prodigy stats ---
|
|
273
239
|
d0, safeguard_warmup, slice_p = group['d0'], group['safeguard_warmup'], group['slice_p']
|
|
@@ -298,7 +264,7 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
|
|
|
298
264
|
else:
|
|
299
265
|
p.data.add_(-update_for_param)
|
|
300
266
|
|
|
301
|
-
|
|
267
|
+
del update_for_param
|
|
302
268
|
|
|
303
269
|
@torch.no_grad()
|
|
304
270
|
def step(self, closure: Optional[callable] = None):
|