adv-optm 0.1.7__tar.gz → 0.1.9__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


Files changed (27)
  1. adv_optm-0.1.9/PKG-INFO +174 -0
  2. adv_optm-0.1.9/README.md +143 -0
  3. {adv_optm-0.1.7 → adv_optm-0.1.9}/adv_optm/__init__.py +1 -1
  4. {adv_optm-0.1.7 → adv_optm-0.1.9}/adv_optm/optim/AdamW_adv.py +13 -4
  5. {adv_optm-0.1.7 → adv_optm-0.1.9}/adv_optm/optim/Adopt_adv.py +52 -13
  6. {adv_optm-0.1.7 → adv_optm-0.1.9}/adv_optm/optim/Lion_Prodigy_adv.py +3 -37
  7. {adv_optm-0.1.7 → adv_optm-0.1.9}/adv_optm/optim/Lion_adv.py +6 -39
  8. {adv_optm-0.1.7 → adv_optm-0.1.9}/adv_optm/optim/Prodigy_adv.py +76 -39
  9. adv_optm-0.1.9/adv_optm.egg-info/PKG-INFO +174 -0
  10. {adv_optm-0.1.7 → adv_optm-0.1.9}/setup.py +1 -1
  11. adv_optm-0.1.7/PKG-INFO +0 -130
  12. adv_optm-0.1.7/README.md +0 -99
  13. adv_optm-0.1.7/adv_optm.egg-info/PKG-INFO +0 -130
  14. {adv_optm-0.1.7 → adv_optm-0.1.9}/LICENSE +0 -0
  15. {adv_optm-0.1.7 → adv_optm-0.1.9}/adv_optm/optim/Simplified_AdEMAMix.py +0 -0
  16. {adv_optm-0.1.7 → adv_optm-0.1.9}/adv_optm/optim/__init__.py +0 -0
  17. {adv_optm-0.1.7 → adv_optm-0.1.9}/adv_optm/util/BF16_Stochastic_Rounding.py +0 -0
  18. {adv_optm-0.1.7 → adv_optm-0.1.9}/adv_optm/util/Effective_Shape.py +0 -0
  19. {adv_optm-0.1.7 → adv_optm-0.1.9}/adv_optm/util/NNMF.py +0 -0
  20. {adv_optm-0.1.7 → adv_optm-0.1.9}/adv_optm/util/One_Bit_Boolean.py +0 -0
  21. {adv_optm-0.1.7 → adv_optm-0.1.9}/adv_optm/util/OrthoGrad.py +0 -0
  22. {adv_optm-0.1.7 → adv_optm-0.1.9}/adv_optm/util/__init__.py +0 -0
  23. {adv_optm-0.1.7 → adv_optm-0.1.9}/adv_optm.egg-info/SOURCES.txt +0 -0
  24. {adv_optm-0.1.7 → adv_optm-0.1.9}/adv_optm.egg-info/dependency_links.txt +0 -0
  25. {adv_optm-0.1.7 → adv_optm-0.1.9}/adv_optm.egg-info/requires.txt +0 -0
  26. {adv_optm-0.1.7 → adv_optm-0.1.9}/adv_optm.egg-info/top_level.txt +0 -0
  27. {adv_optm-0.1.7 → adv_optm-0.1.9}/setup.cfg +0 -0
@@ -0,0 +1,174 @@
+ Metadata-Version: 2.4
+ Name: adv_optm
+ Version: 0.1.9
+ Summary: A family of highly efficient, lightweight yet powerful optimizers.
+ Home-page: https://github.com/Koratahiu/Advanced_Optimizers
+ Author: Koratahiu
+ Author-email: hiuhonor@gmail.com
+ License: Apache 2.0
+ Keywords: llm,fine-tuning,memory-efficient,low-rank,compression,pytorch,optimizer,adam
+ Classifier: Programming Language :: Python :: 3
+ Classifier: License :: OSI Approved :: Apache Software License
+ Classifier: Operating System :: OS Independent
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
+ Classifier: Topic :: Software Development :: Libraries :: Python Modules
+ Requires-Python: >=3.8
+ Description-Content-Type: text/markdown
+ License-File: LICENSE
+ Requires-Dist: torch>=2.0
+ Dynamic: author
+ Dynamic: author-email
+ Dynamic: classifier
+ Dynamic: description
+ Dynamic: description-content-type
+ Dynamic: home-page
+ Dynamic: keywords
+ Dynamic: license
+ Dynamic: license-file
+ Dynamic: requires-dist
+ Dynamic: requires-python
+ Dynamic: summary
+
+ # Advanced Optimizers (AIO)
+
+ A comprehensive, all-in-one collection of optimization algorithms for deep learning, designed for maximum efficiency, minimal memory footprint, and superior performance across diverse model architectures and training scenarios.
+
+ [![PyPI](https://img.shields.io/pypi/v/adv_optm)](https://pypi.org/project/adv_optm/)
+
+ ---
+
+ ## 📦 Installation
+
+ ```bash
+ pip install adv_optm
+ ```
+
+ ---
+
+ ## 🧠 Core Innovations
+
+ This library integrates multiple state-of-the-art optimization techniques validated through extensive research and practical training, with 1-bit compression for optimizer states:
+
+ ### **Memory-Efficient Optimization (SMMF-inspired)**
+ - **Paper**: [SMMF: Square-Matricized Momentum Factorization](https://arxiv.org/abs/2412.08894)
+ - **Approach**: Uses rank-1 non-negative matrix factorization with a reconstruction cycle (factor → reconstruct → update → factor)
+ - **Innovation**:
+   - First moment split into **1-bit sign + absolute value**
+   - Final storage: **four factored vectors + one 1-bit sign state**
+   - Preserves Adam-like update quality with drastically reduced memory
+
+ ---
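To make the cycle above concrete, here is a minimal sketch of the idea. It is not the package's implementation (the real code lives in `adv_optm/util/NNMF.py` and `adv_optm/util/One_Bit_Boolean.py`, both unchanged in this release); the Adafactor-style row/column-sum factorization and the helper names are assumptions for illustration only.

```python
import torch

def factor_rank1(m: torch.Tensor):
    """Compress a non-negative matrix into two factor vectors (rank-1, Adafactor-style)."""
    return m.sum(dim=1), m.sum(dim=0)            # row sums (d1,), column sums (d2,)

def reconstruct_rank1(row: torch.Tensor, col: torch.Tensor) -> torch.Tensor:
    """Approximate the original matrix from its two factor vectors."""
    return torch.outer(row, col) / row.sum().clamp_min(1e-30)

# One sketched pass of factor -> reconstruct -> update -> factor for the signed
# first moment m_t (the real optimizer also factors the second moment).
d1, d2 = 4, 3
grad = torch.randn(d1, d2)
row, col = factor_rank1(torch.zeros(d1, d2))     # freshly initialized state
sign = torch.ones(d1, d2)                        # stands in for the packed 1-bit sign state
m_prev = reconstruct_rank1(row, col) * sign      # reconstruct the dense moment
m_t = 0.9 * m_prev + 0.1 * grad                  # ordinary EMA update on the dense tensor
row, col = factor_rank1(m_t.abs())               # keep only two vectors...
sign = torch.sign(m_t)                           # ...plus the sign bits, then drop m_t
```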
+
+ ## ⚡ Performance Characteristics
+
+ ### Memory Efficiency (SDXL Model - 6.5GB)
+ | Optimizer | Memory Usage | Description |
+ |-----------|--------------|-------------|
+ | `Adopt_Factored` | 328 MB | 4 small vectors + 1-bit state |
+ | `Adopt_Factored + AdEMAMix` | 625 MB | 6 small vectors + two 1-bit states |
+ | `Simplified_AdEMAMix` | 328 MB | Same as standard factored (no extra state) |
+
+ ### Speed Comparison (SDXL, Batch Size 4)
+ | Optimizer | Speed | Notes |
+ |-----------|-------|-------|
+ | `Adafactor` | ~8.5s/it | Baseline |
+ | `Adopt_Factored` | ~10s/it | +18% overhead from compression |
+ | `Adopt_Factored + AdEMAMix` | ~12s/it | +41% overhead (3 factored states) |
+
+ ---
+
+ ## 🧪 Available Optimizers
+
+ ### Standard Optimizers (All support `factored=True/False`)
+ | Optimizer | Description | Best For |
+ |-----------|-------------|----------|
+ | `Adam_Adv` | Advanced Adam implementation | General purpose |
+ | `Adopt_Adv` | Adam-variant with independent beta2 | Stable training for small batch size regimes |
+ | `Prodigy_Adv` | Prodigy with D-Adaptation | Adam with automatic LR tuning |
+ | `Simplified_AdEMAMix` | Adam variant with accumulator momentum | Small/large batch training when tuned correctly |
+ | `Lion_Adv` | Advanced Lion implementation | Memory-constrained environments |
+ | `Prodigy_Lion_Adv` | Prodigy + Lion combination | Lion with automatic LR tuning |
+
+ ### Feature Matrix
+ | Feature | Adam_Adv | Adopt_Adv | Prodigy_Adv | Simplified_AdEMAMix | Lion_Adv |
+ |---------|----------|-----------|-------------|---------------------|----------|
+ | Factored | ✓ | ✓ | ✓ | ✓ | ✓ |
+ | AdEMAMix | ✓ | ✓ | ✓ | ✗ | ✗ |
+ | Simplified_AdEMAMix | ✗ | ✗ | ✓ | ✓ | ✗ |
+ | OrthoGrad | ✓ | ✓ | ✓ | ✓ | ✓ |
+ | Grams | ✓ | ✓ | ✓ | ✗ | ✗ |
+ | Cautious | ✓ | ✓ | ✓ | ✗ | ✓ |
+ | atan2 | ✓ | ✓ | ✓ | ✗ | ✗ |
+ | Stochastic Rounding | ✓ | ✓ | ✓ | ✓ | ✓ |
+ | Fused Backward Pass | ✓ | ✓ | ✓ | ✓ | ✓ |
+
+ ---
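For orientation, a minimal usage sketch against this release is shown below. Only the `factored` flag and the class name `AdamW_adv` are confirmed by the source diffs further down; the top-level import and the remaining defaults are assumptions.

```python
import torch
from adv_optm import AdamW_adv  # class name as it appears in adv_optm/optim/AdamW_adv.py

model = torch.nn.Linear(128, 10)

# factored=True switches on the rank-1 compressed state described above
# (the AdamW/Adopt constructors in this release now default to factored=False).
opt = AdamW_adv(model.parameters(), lr=1e-4, factored=True)

for _ in range(3):
    loss = model(torch.randn(8, 128)).pow(2).mean()
    loss.backward()
    opt.step()
    opt.zero_grad()
```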
+
+ ## ⚙️ Key Features & Parameters
+
+ ### Comprehensive Feature Guide
+
+ | Feature | Description | Recommended Usage | Performance Impact | Theoretical Basis | Compatibility |
+ |---------|-------------|-------------------|--------------------|-------------------|--------------|
+ | **Factored** | Memory-efficient optimization using rank-1 factorization | Enable for large models (>1B params) or limited VRAM | +12-41% time overhead; state memory reduced to factored vectors plus 1-bit signs | [SMMF](https://arxiv.org/abs/2412.08894) | All optimizers |
+ | **AdEMAMix** | Dual EMA system for momentum | Use for long training runs (10k+ steps) | +1 state memory | [AdEMAMix](https://arxiv.org/abs/2409.03137) | Adam/Adopt/Prodigy |
+ | **Simplified_AdEMAMix** | Accumulator-based momentum | Small batch training (≤32) | Same memory as standard, no extra overhead | [Schedule-Free Connections](https://arxiv.org/abs/2502.02431) | Adam/Prodigy |
+ | **OrthoGrad** | Removes the gradient component parallel to the weights | Full finetuning without weight decay | +33% time overhead, no memory impact | [Grokking at Edge](https://github.com/LucasPrietoAl/grokking-at-the-edge-of-numerical-stability) | All optimizers |
+ | **Stochastic Rounding** | Improves precision for BF16 training | BF16 training | Minimal overhead (<5%) | [Revisiting BFloat16 Training](https://arxiv.org/abs/2010.06192) | All optimizers |
+ | **atan2** | Robust eps replacement + built-in clipping | Use with Adopt or unstable training | No overhead | [Adam-atan2](https://github.com/lucidrains/adam-atan2-pytorch) | Adam/Adopt/Prodigy |
+ | **Cautious** | Applies updates only where the update direction agrees with the gradient | Should speed up convergence | No overhead | [C-Optim](https://github.com/kyleliang919/C-Optim) | Adam/Adopt/Prodigy |
+ | **Grams** | Takes the update direction from the gradient sign | Should have a stronger effect than Cautious | No overhead | [Grams](https://github.com/Gunale0926/Grams) | Adam/Adopt/Prodigy |
+
+ ---
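The OrthoGrad row above removes the component of the gradient that is parallel to the weights. The package ships its own implementation in `adv_optm/util/OrthoGrad.py` (unchanged in this diff); the snippet below is only a generic sketch of that projection.

```python
import torch

def orthograd_(p: torch.Tensor, eps: float = 1e-30) -> None:
    """Project p.grad onto the subspace orthogonal to p, in place (generic sketch)."""
    if p.grad is None:
        return
    w, g = p.view(-1), p.grad.view(-1)
    # g <- g - (<w, g> / <w, w>) * w  removes the component of g parallel to w.
    proj = torch.dot(w, g) / (torch.dot(w, w) + eps)
    g.sub_(proj * w)
    # The reference formulation also rescales g back to its original norm; omitted here.
```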
+
+ ## Simplified_AdEMAMix Parameters
+ Simplified_AdEMAMix replaces standard momentum with an accumulator for better performance in both small- and large-batch regimes.
+
+ | Parameter | Recommended Values | Description |
+ |-----------|---------------------|-------------|
+ | `beta1` | 0.9 (large BS), 0.99-0.9999 (small BS) | Determines the memory length of the accumulator |
+ | `alpha` | 100-10 (small BS), 1-0 (large BS) | Weight of the current gradient in the update |
+
+ **Alpha Tuning Guide**:
+ | Batch Size | Recommended α | Rationale |
+ |------------|---------------|-----------|
+ | Small (≤32) | 100, 50, 20, 10 | Emphasizes recent gradients for quick adaptation |
+ | Medium (32-512) | 10, 5, 2, 1 | Balanced approach |
+ | Large (≥512) | 1, 0.5, 0 | Emphasizes historical gradients for stability |
+
+ ⚠️ **Important**: Use a **~100x smaller learning rate** with Simplified_AdEMAMix than with AdamW (e.g., 1e-6 instead of 1e-4)
+
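In code terms, the rule above matches the `Adopt_adv.py` changes shown later in this diff: the first moment becomes an accumulator (no `1 - beta1` factor) and the update mixes the raw gradient back in, scaled by `alpha` (`alpha_grad` in the code). A stripped-down, single-tensor sketch that ignores the second moment, weight decay, and factorization:

```python
import torch

def simplified_ademamix_step_(p: torch.Tensor, m: torch.Tensor,
                              lr: float, beta1: float = 0.99, alpha: float = 100.0) -> None:
    """One toy step; m is the accumulator state for parameter p (sketch only)."""
    g = p.grad
    # Accumulator instead of an EMA: note there is no (1 - beta1) factor on the gradient.
    m.mul_(beta1).add_(g)
    # The update numerator mixes the raw gradient back in, weighted by alpha (alpha_grad).
    update = torch.add(m, g, alpha=alpha)
    p.add_(update, alpha=-lr)
```

In the real optimizers the gradient is additionally normalized by a second moment, but the accumulator still grows the effective step size, which is why the guide above recommends a roughly 100x smaller learning rate.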
+ ### 📊 Performance Validation
+ Small Batch Training (SDXL, BS=2, 1.8K steps)
+ ![Training Comparison](https://github.com/user-attachments/assets/7eff0671-cc59-47fc-8b63-d5205456d649)
+
+ - **🟢 Prodigy_adv** (beta1=0.9, d0=1e-5): Final LR=2.9e-4
+ - **🔵 Prodigy_adv + Simplified_AdEMAMix** (beta1=0.99, α=100, d0=1e-7): Final LR=5.8e-6
+
+ **Results**:
+ - Simplified_AdEMAMix shows faster convergence and better final performance
+ - D-Adaptation automatically handles the aggressive updates (settling on a 50x smaller LR)
+ - Generated samples show significantly better quality with Simplified_AdEMAMix
+
+ ---
+
+ ## ⚠️ Known Limitations
+
+ ### 1. Prodigy_Adv Sensitivity
+ - Highly sensitive to gradient modifications (Adopt normalization, low-rank factorization)
+ - May fail to increase the learning rate in some LoRA scenarios
+ - **Fix**: Disable factorization or set beta1=0
+
+ ### 2. Aggressive Learning Rates
+ - Can destabilize the factored first moment
+ - **Recommendation**: Use the learning rate found by Prodigy as a reference for a safe LR threshold
+
+ ---
+
+ ## 📚 References
+
+ 1. [SMMF: Square-Matricized Momentum Factorization](https://arxiv.org/abs/2412.08894)
+ 2. [The AdEMAMix Optimizer: Better, Faster, Older](https://arxiv.org/abs/2409.03137)
+ 3. [Connections between Schedule-Free Optimizers, AdEMAMix, and Accelerated SGD Variants](https://arxiv.org/abs/2502.02431)
+
+ ---
@@ -0,0 +1,143 @@
+ # Advanced Optimizers (AIO)
+
+ A comprehensive, all-in-one collection of optimization algorithms for deep learning, designed for maximum efficiency, minimal memory footprint, and superior performance across diverse model architectures and training scenarios.
+
+ [![PyPI](https://img.shields.io/pypi/v/adv_optm)](https://pypi.org/project/adv_optm/)
+
+ ---
+
+ ## 📦 Installation
+
+ ```bash
+ pip install adv_optm
+ ```
+
+ ---
+
+ ## 🧠 Core Innovations
+
+ This library integrates multiple state-of-the-art optimization techniques validated through extensive research and practical training, with 1-bit compression for optimizer states:
+
+ ### **Memory-Efficient Optimization (SMMF-inspired)**
+ - **Paper**: [SMMF: Square-Matricized Momentum Factorization](https://arxiv.org/abs/2412.08894)
+ - **Approach**: Uses rank-1 non-negative matrix factorization with a reconstruction cycle (factor → reconstruct → update → factor)
+ - **Innovation**:
+   - First moment split into **1-bit sign + absolute value**
+   - Final storage: **four factored vectors + one 1-bit sign state**
+   - Preserves Adam-like update quality with drastically reduced memory
+
+ ---
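The 1-bit sign state mentioned above is stored bit-packed in `uint8` (the package has a `One_Bit_Boolean.py` utility for this, unchanged in this diff, and the `Lion_Prodigy_adv.py` hunks below allocate a `(d1, ceil(d2/8))` buffer for it). The following is a generic pack/unpack sketch, not the package's exact code.

```python
import torch
import torch.nn.functional as F

_BIT_WEIGHTS = torch.tensor([1, 2, 4, 8, 16, 32, 64, 128], dtype=torch.uint8)

def pack_signs(m: torch.Tensor) -> torch.Tensor:
    """Pack sign(m) >= 0 into bits along the last dim: (d1, d2) -> (d1, ceil(d2/8)) uint8."""
    bits = (m >= 0).to(torch.uint8)
    d1, d2 = bits.shape
    bits = F.pad(bits, (0, (-d2) % 8))                       # pad columns to a multiple of 8
    return (bits.view(d1, -1, 8) * _BIT_WEIGHTS).sum(dim=-1).to(torch.uint8)

def unpack_signs(packed: torch.Tensor, d2: int) -> torch.Tensor:
    """Recover a ±1 float matrix of width d2 from its packed representation."""
    bits = (packed.unsqueeze(-1) & _BIT_WEIGHTS) != 0        # (d1, ceil(d2/8), 8) booleans
    return bits.view(packed.shape[0], -1)[:, :d2].float() * 2 - 1

packed = pack_signs(torch.randn(4, 10))   # 4 x 2 uint8 instead of 4 x 10 floats
signs = unpack_signs(packed, d2=10)       # values in {-1.0, +1.0}
```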
30
+
31
+ ## ⚡ Performance Characteristics
32
+
33
+ ### Memory Efficiency (SDXL Model - 6.5GB)
34
+ | Optimizer | Memory Usage | Description |
35
+ |-----------|--------------|-------------|
36
+ | `Adopt_Factored` | 328 MB | 4 small vectors + 1-bit state |
37
+ | `Adopt_Factored + AdEMAMix` | 625 MB | 6 small vectors + two 1-bit states |
38
+ | `Simplified_AdEMAMix` | 328 MB | Same as standard factored (no extra state) |
39
+
40
+ ### Speed Comparison (SDXL, Batch Size 4)
41
+ | Optimizer | Speed | Notes |
42
+ |-----------|-------|-------|
43
+ | `Adafactor` | ~8.5s/it | Baseline |
44
+ | `Adopt_Factored` | ~10s/it | +18% overhead from compression |
45
+ | `Adopt_Factored + AdEMAMix` | ~12s/it | +41% overhead (3 factored states) |
46
+
47
+ ---
48
+
49
+ ## 🧪 Available Optimizers
50
+
51
+ ### Standard Optimizers (All support `factored=True/False`)
52
+ | Optimizer | Description | Best For |
53
+ |-----------|-------------|----------|
54
+ | `Adam_Adv` | Advanced Adam implementation | General purpose |
55
+ | `Adopt_Adv` | Adam-variant with independent beta2 | Stable training for small batch size regimes |
56
+ | `Prodigy_Adv` | Prodigy with D-Adaptation | Adam with automatic LR tuning |
57
+ | `Simplified_AdEMAMix` | Adam variant with accumulator momentum | Small/large batch training when tuned correctly |
58
+ | `Lion_Adv` | Advanced Lion implementation | Memory-constrained environments |
59
+ | `Prodigy_Lion_Adv` | Prodigy + Lion combination | Lion with automatic LR tuning |
60
+
61
+ ### Feature Matrix
62
+ | Feature | Adam_Adv | Adopt_Adv | Prodigy_Adv | Simplified_AdEMAMix | Lion_Adv |
63
+ |---------|----------|-----------|-------------|---------------------|----------|
64
+ | Factored | ✓ | ✓ | ✓ | ✓ | ✓ |
65
+ | AdEMAMix | ✓ | ✓ | ✓ | ✗ | ✗ |
66
+ | Simplified_AdEMAMix | ✗ | ✗ | ✓ | ✓ | ✗ |
67
+ | OrthoGrad | ✓ | ✓ | ✓ | ✓ | ✓ |
68
+ | Grams | ✓ | ✓ | ✓ | ✗ | ✗ |
69
+ | Cautious | ✓ | ✓ | ✓ | ✗ | ✓ |
70
+ | atan2 | ✓ | ✓ | ✓ | ✗ | ✗ |
71
+ | Stochastic Rounding | ✓ | ✓ | ✓ | ✓ | ✓ |
72
+ | Fused Backward Pass | ✓ | ✓ | ✓ | ✓ | ✓ |
73
+
74
+ ---
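The feature matrix above lists stochastic rounding for BF16 training; the package implements it in `adv_optm/util/BF16_Stochastic_Rounding.py`, which is unchanged in this diff. Below is a generic sketch of the technique, assuming the usual add-noise-then-truncate formulation (edge cases such as overflow near the maximum exponent are ignored).

```python
import torch

def bf16_stochastic_round(x: torch.Tensor) -> torch.Tensor:
    """Round FP32 -> BF16 stochastically instead of round-to-nearest (generic sketch)."""
    assert x.dtype == torch.float32
    bits = x.view(torch.int32)                       # reinterpret the raw IEEE-754 bits
    noise = torch.randint_like(bits, 0, 1 << 16)     # random value below the retained bits
    rounded = (bits + noise) & -65536                # zero the low 16 bits after the carry
    return rounded.view(torch.float32).to(torch.bfloat16)

# In expectation the BF16 result equals the FP32 input, which is what lets tiny
# parameter updates survive when weights are kept in BF16.
```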
+
+ ## ⚙️ Key Features & Parameters
+
+ ### Comprehensive Feature Guide
+
+ | Feature | Description | Recommended Usage | Performance Impact | Theoretical Basis | Compatibility |
+ |---------|-------------|-------------------|--------------------|-------------------|--------------|
+ | **Factored** | Memory-efficient optimization using rank-1 factorization | Enable for large models (>1B params) or limited VRAM | +12-41% time overhead; state memory reduced to factored vectors plus 1-bit signs | [SMMF](https://arxiv.org/abs/2412.08894) | All optimizers |
+ | **AdEMAMix** | Dual EMA system for momentum | Use for long training runs (10k+ steps) | +1 state memory | [AdEMAMix](https://arxiv.org/abs/2409.03137) | Adam/Adopt/Prodigy |
+ | **Simplified_AdEMAMix** | Accumulator-based momentum | Small batch training (≤32) | Same memory as standard, no extra overhead | [Schedule-Free Connections](https://arxiv.org/abs/2502.02431) | Adam/Prodigy |
+ | **OrthoGrad** | Removes the gradient component parallel to the weights | Full finetuning without weight decay | +33% time overhead, no memory impact | [Grokking at Edge](https://github.com/LucasPrietoAl/grokking-at-the-edge-of-numerical-stability) | All optimizers |
+ | **Stochastic Rounding** | Improves precision for BF16 training | BF16 training | Minimal overhead (<5%) | [Revisiting BFloat16 Training](https://arxiv.org/abs/2010.06192) | All optimizers |
+ | **atan2** | Robust eps replacement + built-in clipping | Use with Adopt or unstable training | No overhead | [Adam-atan2](https://github.com/lucidrains/adam-atan2-pytorch) | Adam/Adopt/Prodigy |
+ | **Cautious** | Applies updates only where the update direction agrees with the gradient | Should speed up convergence | No overhead | [C-Optim](https://github.com/kyleliang919/C-Optim) | Adam/Adopt/Prodigy |
+ | **Grams** | Takes the update direction from the gradient sign | Should have a stronger effect than Cautious | No overhead | [Grams](https://github.com/Gunale0926/Grams) | Adam/Adopt/Prodigy |
+
+ ---
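The Cautious and Grams rows above both post-process the momentum-based update using the current gradient. The Grams line matches the `grad.sign() * m.abs()` expression visible in the `Adopt_adv.py` hunks below, while the Cautious mask-and-rescale follows the C-Optim formulation and is only a sketch here.

```python
import torch

def apply_cautious_(update: torch.Tensor, grad: torch.Tensor) -> None:
    """Cautious: zero the update where its sign disagrees with the gradient (C-Optim style)."""
    mask = (update * grad > 0).to(update.dtype)
    # Rescale so the surviving entries keep the overall update scale.
    mask.mul_(mask.numel() / mask.sum().clamp_min(1.0))
    update.mul_(mask)

def apply_grams(update: torch.Tensor, grad: torch.Tensor) -> torch.Tensor:
    """Grams: keep the update magnitude but take the direction from the gradient."""
    return grad.sign() * update.abs()
```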
+
+ ## Simplified_AdEMAMix Parameters
+ Simplified_AdEMAMix replaces standard momentum with an accumulator for better performance in both small- and large-batch regimes.
+
+ | Parameter | Recommended Values | Description |
+ |-----------|---------------------|-------------|
+ | `beta1` | 0.9 (large BS), 0.99-0.9999 (small BS) | Determines the memory length of the accumulator |
+ | `alpha` | 100-10 (small BS), 1-0 (large BS) | Weight of the current gradient in the update |
+
+ **Alpha Tuning Guide**:
+ | Batch Size | Recommended α | Rationale |
+ |------------|---------------|-----------|
+ | Small (≤32) | 100, 50, 20, 10 | Emphasizes recent gradients for quick adaptation |
+ | Medium (32-512) | 10, 5, 2, 1 | Balanced approach |
+ | Large (≥512) | 1, 0.5, 0 | Emphasizes historical gradients for stability |
+
+ ⚠️ **Important**: Use a **~100x smaller learning rate** with Simplified_AdEMAMix than with AdamW (e.g., 1e-6 instead of 1e-4)
+
+ ### 📊 Performance Validation
+ Small Batch Training (SDXL, BS=2, 1.8K steps)
+ ![Training Comparison](https://github.com/user-attachments/assets/7eff0671-cc59-47fc-8b63-d5205456d649)
+
+ - **🟢 Prodigy_adv** (beta1=0.9, d0=1e-5): Final LR=2.9e-4
+ - **🔵 Prodigy_adv + Simplified_AdEMAMix** (beta1=0.99, α=100, d0=1e-7): Final LR=5.8e-6
+
+ **Results**:
+ - Simplified_AdEMAMix shows faster convergence and better final performance
+ - D-Adaptation automatically handles the aggressive updates (settling on a 50x smaller LR)
+ - Generated samples show significantly better quality with Simplified_AdEMAMix
+
+ ---
+
+ ## ⚠️ Known Limitations
+
+ ### 1. Prodigy_Adv Sensitivity
+ - Highly sensitive to gradient modifications (Adopt normalization, low-rank factorization)
+ - May fail to increase the learning rate in some LoRA scenarios
+ - **Fix**: Disable factorization or set beta1=0
+
+ ### 2. Aggressive Learning Rates
+ - Can destabilize the factored first moment
+ - **Recommendation**: Use the learning rate found by Prodigy as a reference for a safe LR threshold
+
+ ---
+
+ ## 📚 References
+
+ 1. [SMMF: Square-Matricized Momentum Factorization](https://arxiv.org/abs/2412.08894)
+ 2. [The AdEMAMix Optimizer: Better, Faster, Older](https://arxiv.org/abs/2409.03137)
+ 3. [Connections between Schedule-Free Optimizers, AdEMAMix, and Accelerated SGD Variants](https://arxiv.org/abs/2502.02431)
+
+ ---
@@ -16,4 +16,4 @@ __all__ = [
  "Lion_Prodigy_adv",
  ]

- __version__ = "0.1.7"
+ __version__ = "0.1.9"
@@ -55,7 +55,7 @@ class AdamW_adv(torch.optim.Optimizer):
  the warmup, `alpha` ramps from 0 to its target value. If `None`,
  the scheduler is disabled. (default: None)
  factored (bool): whether to use the factorization or disable it to use
- the uncompressed optimizer. (default: True)
+ the uncompressed optimizer. (default: False)
  """

  def __init__(
@@ -76,7 +76,7 @@ class AdamW_adv(torch.optim.Optimizer):
  beta3_ema: float = 0.9999,
  alpha: float = 5.0,
  t_alpha: int | None = None,
- factored: bool = True,
+ factored: bool = False,
  ):
  if not (lr >= 0.0):
  raise ValueError(f"Learning-rate should be >= 0.0. Got {lr}")
@@ -86,6 +86,9 @@ class AdamW_adv(torch.optim.Optimizer):
  raise ValueError(f"Epsilon should be >= 0.0. Got {eps}")
  if not (weight_decay >= 0.0):
  raise ValueError(f"Weight-decay should be >= 0.0. Got {weight_decay}")
+ if use_cautious and use_grams:
+ print("Warning: use_cautious is incompatible with use_grams, Disabling use_cautious.")
+ use_cautious = False

  defaults = {
  "lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay,
@@ -216,7 +219,10 @@ class AdamW_adv(torch.optim.Optimizer):
  del unpacked_sign_slow

  mt_slow.mul_(beta3_ema).add_(grad_reshaped, alpha=1.0 - beta3_ema)
- update = mt + (alpha_t * mt_slow) if beta1 > 0 else grad_reshaped + (alpha_t * mt_slow)
+ if beta1 > 0:
+ update = torch.add(mt, mt_slow, alpha=alpha_t)
+ else:
+ update = torch.add(grad_reshaped, mt_slow, alpha=alpha_t)
  else:
  update = mt.clone() if beta1 > 0 else grad_reshaped.clone()
  del grad_reshaped
@@ -262,7 +268,10 @@ class AdamW_adv(torch.optim.Optimizer):
  if self.use_AdEMAMix:
  exp_avg_slow = state['exp_avg_slow']
  exp_avg_slow.mul_(beta3_ema).add_(grad, alpha=1 - beta3_ema)
- update = exp_avg + (alpha_t * exp_avg_slow) if beta1 > 0 else grad + (alpha_t * exp_avg_slow)
+ if beta1 > 0:
+ update = torch.add(exp_avg, exp_avg_slow, alpha=alpha_t)
+ else:
+ update = torch.add(grad, exp_avg_slow, alpha=alpha_t)
  else:
  update = exp_avg.clone() if beta1 > 0 else grad.clone()

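Both hunks above replace a one-line conditional expression with an explicit branch built on `torch.add(..., alpha=...)`. The resulting tensor is the same mixed AdEMAMix update; the throwaway snippet below (names local to the example) just demonstrates the equivalence.

```python
import torch

exp_avg = torch.randn(4)        # fast EMA (beta1)
exp_avg_slow = torch.randn(4)   # slow EMA (beta3_ema)
alpha_t = 5.0

old = exp_avg + (alpha_t * exp_avg_slow)                # removed expression
new = torch.add(exp_avg, exp_avg_slow, alpha=alpha_t)   # added formulation
assert torch.allclose(old, new)
```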
@@ -62,8 +62,18 @@ class Adopt_adv(torch.optim.Optimizer):
  the warmup, `alpha` ramps from 0 to its target value. If `None`,
  the scheduler is disabled and the full `alpha` value is used from
  the start. (default: None)
+ Simplified_AdEMAMix (bool): whether to use the Simplified AdEMAMix update rule.
+ This changes the EMA into an accumulator and the update numerator to `alpha_grad * grad + mt`, which can be
+ more responsive, especially for small batch sizes. Enabling this will
+ automatically disable `use_AdEMAMix`, `use_cautious`, `use_grams`,
+ and `use_atan2`. (default: False)
+ alpha_grad (float): Mixing coefficient for the Simplified AdEMAMix update rule
+ (only used when `Simplified_AdEMAMix` is `True`). Controls the weight of the
+ current gradient. For small batch sizes, use high values (e.g., 10-100) to be
+ more responsive. For large batch sizes, use low values (e.g., 0-1) for
+ stability. (default: 100.0)
  factored (bool): whether to use the factorization or disable it to use
- the uncompressed optimizer. (default: True)
+ the uncompressed optimizer. (default: False)
  """

  def __init__(
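A sketch of how the new parameters documented above combine, assuming `Adopt_adv` is re-exported at the package top level like the other optimizers; the learning rate follows the README's ~100x-smaller guidance rather than anything in the code.

```python
import torch
from adv_optm import Adopt_adv

model = torch.nn.Linear(64, 64)

opt = Adopt_adv(
    model.parameters(),
    lr=1e-6,                   # README guidance: ~100x smaller than a typical AdamW LR
    Simplified_AdEMAMix=True,  # switches the first moment to an accumulator
    alpha_grad=100.0,          # high values suit small batch sizes
    factored=True,             # compressed state still works with the new rule
)

loss = model(torch.randn(4, 64)).pow(2).mean()
loss.backward()
opt.step()
opt.zero_grad()
```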
@@ -77,14 +87,16 @@ class Adopt_adv(torch.optim.Optimizer):
  vector_reshape: bool = True,
  stochastic_rounding: bool = True,
  use_atan2: bool = False,
- use_cautious: bool = True,
+ use_cautious: bool = False,
  use_grams: bool = False,
  use_orthograd: bool = False,
  use_AdEMAMix: bool = False,
  beta3_ema: float = 0.9999,
  alpha: float = 5.0,
  t_alpha: int | None = None,
- factored: bool = True,
+ Simplified_AdEMAMix: bool = False,
+ alpha_grad: float = 100.0,
+ factored: bool = False,
  ):
  if not (lr >= 0.0):
  raise ValueError(f"Learning-rate should be >= 0.0. Got {lr}")
@@ -94,19 +106,34 @@ class Adopt_adv(torch.optim.Optimizer):
  raise ValueError(f"Epsilon should be >= 0.0. Got {eps}")
  if not (weight_decay >= 0.0):
  raise ValueError(f"Weight-decay should be >= 0.0. Got {weight_decay}")
+ if use_cautious and use_grams:
+ print("Warning: use_cautious is incompatible with use_grams, Disabling use_cautious.")
+ use_cautious = False
+ if betas[0] == 0.0 and Simplified_AdEMAMix:
+ raise ValueError(f"Beta1 cannot be 0.0 when using Simplified_AdEMAMix. Got {betas[0]}")
+ if use_AdEMAMix and Simplified_AdEMAMix:
+ print("Warning: use_AdEMAMix is incompatible with Simplified_AdEMAMix, Disabling use_AdEMAMix.")
+ if use_grams and Simplified_AdEMAMix:
+ print("Warning: use_grams is incompatible with Simplified_AdEMAMix, Disabling use_grams.")
+ if use_cautious and Simplified_AdEMAMix:
+ print("Warning: use_cautious is incompatible with Simplified_AdEMAMix, Disabling use_cautious.")
+ if use_atan2 and Simplified_AdEMAMix:
+ print("Warning: use_atan2 is incompatible with Simplified_AdEMAMix. Disabling use_atan2.")
+ use_atan2 = False

  defaults = {
  "lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay,
  "vector_reshape": vector_reshape, "beta3_ema": beta3_ema, "alpha": alpha,
- "t_alpha": t_alpha,
+ "t_alpha": t_alpha, "alpha_grad": alpha_grad,
  }
  self.clip_lambda = clip_lambda
  self.stochastic_rounding = stochastic_rounding
- self.use_atan2 = use_atan2
- self.use_cautious = use_cautious
- self.use_grams = use_grams
+ self.use_atan2 = use_atan2 and not Simplified_AdEMAMix
+ self.use_cautious = use_cautious and not Simplified_AdEMAMix
+ self.use_grams = use_grams and not Simplified_AdEMAMix
  self.use_orthograd = use_orthograd
- self.use_AdEMAMix = use_AdEMAMix
+ self.use_AdEMAMix = use_AdEMAMix and not Simplified_AdEMAMix
+ self.Simplified_AdEMAMix = Simplified_AdEMAMix
  self.factored = factored
  super().__init__(params, defaults)

@@ -185,6 +212,8 @@ class Adopt_adv(torch.optim.Optimizer):
  alpha_t = alpha
  if t_alpha is not None and t_alpha > 0 and current_step < t_alpha:
  alpha_t = min(current_step * alpha / t_alpha, alpha)
+ if self.Simplified_AdEMAMix:
+ alpha_grad = group["alpha_grad"]

  if state['factored']:
  d1, d2 = state['effective_shape']
@@ -224,7 +253,10 @@ class Adopt_adv(torch.optim.Optimizer):
  del denom

  # ADOPT Step B: Update momentum m_t using normalized gradient
- mt.mul_(beta1).add_(normalized_grad, alpha=1.0 - beta1)
+ if self.Simplified_AdEMAMix:
+ mt.mul_(beta1).add_(normalized_grad, alpha=1.0)
+ else:
+ mt.mul_(beta1).add_(normalized_grad, alpha=1.0 - beta1)
  if self.use_grams:
  mt = grad_reshaped.sign() * mt.abs()
  elif self.use_cautious:
@@ -235,8 +267,10 @@ class Adopt_adv(torch.optim.Optimizer):

  if self.use_AdEMAMix:
  mt_slow.mul_(beta3_ema).add_(normalized_grad, alpha=1.0 - beta3_ema)
- update = mt + (alpha_t * mt_slow)
+ update = torch.add(mt, mt_slow, alpha=alpha_t)
  update = update.view(p.shape)
+ elif self.Simplified_AdEMAMix:
+ update = torch.add(mt, grad_reshaped, alpha=alpha_grad)
  else:
  update = mt.view(p.shape)

@@ -283,7 +317,10 @@ class Adopt_adv(torch.optim.Optimizer):
  del denom

  # ADOPT Step B: Update momentum m_t
- m.mul_(beta1).add_(normalized_grad, alpha=1.0 - beta1)
+ if self.Simplified_AdEMAMix:
+ m.mul_(beta1).add_(normalized_grad, alpha=1.0)
+ else:
+ m.mul_(beta1).add_(normalized_grad, alpha=1.0 - beta1)

  if self.use_grams:
  m = grad.sign() * m.abs()
@@ -295,9 +332,11 @@ class Adopt_adv(torch.optim.Optimizer):

  if self.use_AdEMAMix:
  m_slow.mul_(beta3_ema).add_(normalized_grad, alpha=1.0 - beta3_ema)
- update = m + (alpha_t * m_slow)
+ update = torch.add(m, m_slow, alpha=alpha_t)
+ elif self.Simplified_AdEMAMix:
+ update = torch.add(m, grad, alpha=alpha_grad)
  else:
- update = m
+ update = m.clone()

  if self.use_atan2:
  update.mul_(group['lr'] * 1.2732395447351628)
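Taken together, the hunks above make the non-factored ADOPT path look roughly like the condensed paraphrase below. This is not the file's actual code: the second-moment update, clipping, cautious/Grams handling, and the atan2 branch are omitted, and `v_prev` stands for the second-moment state from the previous step.

```python
import torch

def adopt_step_b(m: torch.Tensor, grad: torch.Tensor, v_prev: torch.Tensor,
                 beta1: float, eps: float = 1e-6,
                 simplified_ademamix: bool = False, alpha_grad: float = 100.0) -> torch.Tensor:
    """Condensed paraphrase of ADOPT Steps A/B as modified in this release (sketch)."""
    # Step A: normalize the gradient by the *previous* second moment (ADOPT's key idea).
    normalized_grad = grad / v_prev.sqrt().clamp_min(eps)
    # Step B: update the first moment.
    if simplified_ademamix:
        # Accumulator: no (1 - beta1) factor on the incoming normalized gradient.
        m.mul_(beta1).add_(normalized_grad)
        return torch.add(m, grad, alpha=alpha_grad)   # mix the raw gradient back in
    m.mul_(beta1).add_(normalized_grad, alpha=1.0 - beta1)
    return m.clone()
```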
@@ -33,8 +33,6 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
  (default: 0.0).
  factored (bool): whether to use the factorization or use the
  uncompressed optimizer. (default: True)
- variance_reduction (bool): whether to use the variance reduction technique
- from "Convergence Analysis of the Lion Optimizer" (arXiv:2508.12327v1). (default: False).
  d0 (float):
  Initial D estimate for D-adaptation (default 1e-6). Rarely needs changing.
  d_coef (float):
@@ -66,7 +64,6 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
  use_cautious: bool = False,
  clip_threshold: float = 0.0,
  factored: bool = True,
- variance_reduction: bool = False,
  # prodigy parameters
  beta3: float = None,
  d0: float = 1e-6,
@@ -97,7 +94,6 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
  self.stochastic_rounding = stochastic_rounding
  self.use_cautious = use_cautious
  self.factored = factored
- self.variance_reduction = variance_reduction
  self.fsdp_in_use = fsdp_in_use
  super().__init__(params, defaults)
  # Global state for accumulating metrics across parameter updates within a single step.
@@ -183,12 +179,8 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
  state['mv_m_nmf'] = torch.zeros(d2, device=p.device, dtype=dtype)
  packed_d2 = (d2 + 7) // 8
  state['sign'] = torch.zeros((d1, packed_d2), dtype=torch.uint8, device=p.device)
- if self.variance_reduction:
- state['prev_grad'] = torch.zeros((d1, d2), device=p.device, dtype=dtype)
  else: # Fallback to standard Lion
  state['exp_avg'] = torch.zeros_like(p, device=p.device, dtype=dtype)
- if self.variance_reduction:
- state['prev_grad'] = torch.zeros_like(p, device=p.device, dtype=dtype)

  if state['factored']:
  # Factored Path
@@ -215,20 +207,7 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
  update_for_param = signed_update.view(p.shape).mul(self.dlr)

  # Update momentum m_t = β2*m_{t-1} + (1-β2)*lr*g_t
- if self.variance_reduction:
- if state['step'] == 1:
- exp_avg.copy_(grad_reshaped)
- else:
- # Heuristic Prodigy-STORM update
- correction = exp_avg.sub(state['prev_grad'])
- grad_alpha = self.d * (1 - self.beta2) + self.beta2
- exp_avg.copy_(grad_reshaped).mul_(grad_alpha).add_(correction, alpha=self.beta2)
- del correction, grad_alpha
- state['prev_grad'].copy_(grad_reshaped)
- else:
- # Standard Prodigy-Lion
- alpha = self.d * (1 - self.beta2)
- exp_avg.mul_(self.beta2).add_(grad_reshaped, alpha=alpha)
+ exp_avg.mul_(self.beta2).add_(grad_reshaped, alpha=self.d * (1 - self.beta2))
  del grad_reshaped

  # Compress new momentum m_t and store factors
@@ -254,20 +233,7 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
  update_for_param = signed_update.mul(self.dlr)

  # Update momentum
- if self.variance_reduction:
- if state['step'] == 1:
- exp_avg.copy_(grad)
- else:
- # Heuristic Prodigy-STORM update
- correction = exp_avg.sub(state['prev_grad'])
- grad_alpha = self.d * (1 - self.beta2) + self.beta2
- exp_avg.copy_(grad).mul_(grad_alpha).add_(correction, alpha=self.beta2)
- del grad_alpha, correction
- state['prev_grad'].copy_(grad)
- else:
- # Standard Prodigy-Lion
- alpha = self.d * (1 - self.beta2)
- exp_avg.mul_(self.beta2).add_(grad, alpha=alpha)
+ exp_avg.mul_(self.beta2).add_(grad, alpha=self.d * (1 - self.beta2))

  # --- Accumulate Prodigy stats ---
  d0, safeguard_warmup, slice_p = group['d0'], group['safeguard_warmup'], group['slice_p']
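With the variance-reduction branch gone, the momentum update in both paths reduces to the single standard Prodigy-Lion line shown above. Below is a self-contained sketch of one such step; the signed-update interpolation and the Prodigy statistics live in unchanged code not shown in this diff, so their exact form here is an assumption following the standard Lion formulation.

```python
import torch

def prodigy_lion_step_(p: torch.Tensor, exp_avg: torch.Tensor, grad: torch.Tensor,
                       d: float, lr: float, beta1: float = 0.9, beta2: float = 0.99) -> None:
    """One simplified Prodigy-Lion step for a single parameter tensor (sketch only)."""
    dlr = d * lr
    # Lion-style update direction: sign of an interpolation between momentum and gradient,
    # with the gradient term scaled by the Prodigy estimate d (assumed formulation).
    signed_update = torch.sign(exp_avg.mul(beta1).add(grad, alpha=d * (1 - beta1)))
    p.add_(signed_update, alpha=-dlr)
    # Momentum update kept by this release: m_t = beta2 * m_{t-1} + d * (1 - beta2) * g_t
    exp_avg.mul_(beta2).add_(grad, alpha=d * (1 - beta2))
```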
@@ -298,7 +264,7 @@ class Lion_Prodigy_adv(torch.optim.Optimizer):
  else:
  p.data.add_(-update_for_param)

- del update_for_param
+ del update_for_param

  @torch.no_grad()
  def step(self, closure: Optional[callable] = None):