adv-optm 2.4.dev23__tar.gz → 2.4.dev25__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- adv_optm-2.4.dev25/PKG-INFO +109 -0
- adv_optm-2.4.dev25/README.md +78 -0
- {adv_optm-2.4.dev23 → adv_optm-2.4.dev25}/adv_optm/__init__.py +1 -1
- {adv_optm-2.4.dev23 → adv_optm-2.4.dev25}/adv_optm/optim/AdaMuon_adv.py +5 -8
- {adv_optm-2.4.dev23 → adv_optm-2.4.dev25}/adv_optm/optim/AdamW_adv.py +4 -5
- {adv_optm-2.4.dev23 → adv_optm-2.4.dev25}/adv_optm/optim/Adopt_adv.py +5 -6
- {adv_optm-2.4.dev23 → adv_optm-2.4.dev25}/adv_optm/optim/Lion_adv.py +0 -2
- {adv_optm-2.4.dev23 → adv_optm-2.4.dev25}/adv_optm/optim/Muon_adv.py +4 -8
- {adv_optm-2.4.dev23 → adv_optm-2.4.dev25}/adv_optm/optim/Prodigy_adv.py +3 -4
- {adv_optm-2.4.dev23 → adv_optm-2.4.dev25}/adv_optm/optim/SignSGD_adv.py +24 -11
- {adv_optm-2.4.dev23 → adv_optm-2.4.dev25}/adv_optm/optim/SinkSGD_adv.py +21 -8
- {adv_optm-2.4.dev23 → adv_optm-2.4.dev25}/adv_optm/util/Muon_AuxAdam.py +1 -0
- adv_optm-2.4.dev25/adv_optm/util/OrthoGrad.py +80 -0
- {adv_optm-2.4.dev23 → adv_optm-2.4.dev25}/adv_optm/util/centered_decay.py +22 -15
- {adv_optm-2.4.dev23 → adv_optm-2.4.dev25}/adv_optm/util/param_update.py +0 -68
- {adv_optm-2.4.dev23 → adv_optm-2.4.dev25}/adv_optm/util/scaled_optm.py +46 -27
- {adv_optm-2.4.dev23 → adv_optm-2.4.dev25}/adv_optm/util/state_util.py +6 -30
- adv_optm-2.4.dev25/adv_optm.egg-info/PKG-INFO +109 -0
- {adv_optm-2.4.dev23 → adv_optm-2.4.dev25}/setup.py +1 -1
- adv_optm-2.4.dev23/PKG-INFO +0 -202
- adv_optm-2.4.dev23/README.md +0 -171
- adv_optm-2.4.dev23/adv_optm/util/OrthoGrad.py +0 -19
- adv_optm-2.4.dev23/adv_optm.egg-info/PKG-INFO +0 -202
- {adv_optm-2.4.dev23 → adv_optm-2.4.dev25}/LICENSE +0 -0
- {adv_optm-2.4.dev23 → adv_optm-2.4.dev25}/adv_optm/optim/__init__.py +0 -0
- {adv_optm-2.4.dev23 → adv_optm-2.4.dev25}/adv_optm/util/Kourkoutas.py +0 -0
- {adv_optm-2.4.dev23 → adv_optm-2.4.dev25}/adv_optm/util/Muon_util.py +0 -0
- {adv_optm-2.4.dev23 → adv_optm-2.4.dev25}/adv_optm/util/__init__.py +0 -0
- {adv_optm-2.4.dev23 → adv_optm-2.4.dev25}/adv_optm/util/factorization_util.py +0 -0
- {adv_optm-2.4.dev23 → adv_optm-2.4.dev25}/adv_optm/util/lion_k.py +0 -0
- {adv_optm-2.4.dev23 → adv_optm-2.4.dev25}/adv_optm/util/signed_util.py +0 -0
- {adv_optm-2.4.dev23 → adv_optm-2.4.dev25}/adv_optm/util/sinkhorn.py +0 -0
- {adv_optm-2.4.dev23 → adv_optm-2.4.dev25}/adv_optm/util/update_util.py +0 -0
- {adv_optm-2.4.dev23 → adv_optm-2.4.dev25}/adv_optm.egg-info/SOURCES.txt +0 -0
- {adv_optm-2.4.dev23 → adv_optm-2.4.dev25}/adv_optm.egg-info/dependency_links.txt +0 -0
- {adv_optm-2.4.dev23 → adv_optm-2.4.dev25}/adv_optm.egg-info/requires.txt +0 -0
- {adv_optm-2.4.dev23 → adv_optm-2.4.dev25}/adv_optm.egg-info/top_level.txt +0 -0
- {adv_optm-2.4.dev23 → adv_optm-2.4.dev25}/setup.cfg +0 -0
|
@@ -0,0 +1,109 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: adv_optm
|
|
3
|
+
Version: 2.4.dev25
|
|
4
|
+
Summary: A family of highly efficient, lightweight yet powerful optimizers.
|
|
5
|
+
Home-page: https://github.com/Koratahiu/Advanced_Optimizers
|
|
6
|
+
Author: Koratahiu
|
|
7
|
+
Author-email: hiuhonor@gmail.com
|
|
8
|
+
License: Apache 2.0
|
|
9
|
+
Keywords: llm,fine-tuning,memory-efficient,low-rank,compression,pytorch,optimizer,adam
|
|
10
|
+
Classifier: Programming Language :: Python :: 3
|
|
11
|
+
Classifier: License :: OSI Approved :: Apache Software License
|
|
12
|
+
Classifier: Operating System :: OS Independent
|
|
13
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
14
|
+
Classifier: Topic :: Software Development :: Libraries :: Python Modules
|
|
15
|
+
Requires-Python: >=3.8
|
|
16
|
+
Description-Content-Type: text/markdown
|
|
17
|
+
License-File: LICENSE
|
|
18
|
+
Requires-Dist: torch>=2.1
|
|
19
|
+
Dynamic: author
|
|
20
|
+
Dynamic: author-email
|
|
21
|
+
Dynamic: classifier
|
|
22
|
+
Dynamic: description
|
|
23
|
+
Dynamic: description-content-type
|
|
24
|
+
Dynamic: home-page
|
|
25
|
+
Dynamic: keywords
|
|
26
|
+
Dynamic: license
|
|
27
|
+
Dynamic: license-file
|
|
28
|
+
Dynamic: requires-dist
|
|
29
|
+
Dynamic: requires-python
|
|
30
|
+
Dynamic: summary
|
|
31
|
+
|
|
32
|
+
# Advanced Optimizers (AIO)
|
|
33
|
+
|
|
34
|
+
A comprehensive, all-in-one collection of optimization algorithms for deep learning, designed for **maximum efficiency**, **minimal memory footprint**, and **superior performance** across diverse model architectures and training scenarios.
|
|
35
|
+
|
|
36
|
+
[PyPI: adv_optm](https://pypi.org/project/adv_optm/)
|
|
37
|
+
|
|
38
|
+
## 🔥 What's New
|
|
39
|
+
|
|
40
|
+
### In 2.4.x:
|
|
41
|
+
|
|
42
|
+
This update is a complete refactor of the library, with many new features and changes:
|
|
43
|
+
|
|
44
|
+
- New optimizer state precision option (`state_precision`) with several storage settings for the optimizer states: rank-2 factored mode (`factored`), full FP32 (`fp32`), BF16 with Stochastic Rounding (`bf16_sr`), int8/uint8 with Stochastic Rounding (`int8_sr`), and FP16 (`fp16`).
|
|
45
|
+
- Added a powerful new optimizer: SinkSGD_adv.
|
|
46
|
+
- Added spectral scaling option to all optimizers, achieving width/rank invariant updates.
|
|
47
|
+
- Added Nesterov momentum (`nesterov`) and its coef (`nesterov_coef`) to all optimizers.
|
|
48
|
+
- Added centered weight decay (`centered_wd`), which pulls the weights toward their pre-trained state (the anchor).
|
|
49
|
+
- The anchor's precision can be reduced to save memory (`centered_wd_mode`): full, float8, int8, int4.
|
|
50
|
+
- Added Fisher Weight Decay option for Adam variants (`fisher_wd`).
|
|
51
|
+
- Paper: [FAdam...](https://arxiv.org/abs/2405.12807)
|
|
52
|
+
- Added Factored Second Moment option for Adam variants (`factored_2nd`). This works alongside any `state_precision` setting.
|
|
53
|
+
- Added Geometric Weight Decay for SinkSGD_adv and SignSGD_adv.
|
|
54
|
+
- Added a powerful new mode: variance-normalized momentum (`normed_momentum`), which applies the optimizer's normalization before the momentum (also called Normalization-then-Momentum, NtM).
|
|
55
|
+
- For: AdamW_adv, SignSGD_adv, SinkSGD_adv.
|
|
56
|
+
- Added Variance/Confidence Preconditioning (`snr_cond`) for SignSGD_adv, SinkSGD_adv.
|
|
57
|
+
- Only works with `normed_momentum`.
|
|
58
|
+
- Technical reports: [AASS](https://koratahiu.github.io/aass/), and [sink-v](https://koratahiu.github.io/sink-v/).
|
|
59
|
+
- Added Adaptive Stochastic Sign with L_inf preconditioning (`stochastic_sign`) for SignSGD_adv and Lion_adv.
|
|
60
|
+
- Improved CANS (`accelerated_ns`) for Muon variants by integrating a dynamic lower bound.
|
|
61
|
+
- Removed the Simplified_AdEMAMix optimizer and its settings in other optimizers; they are replaced by Nesterov momentum and its coefficient, which perform better and are easier to tune.
|
|
62
|
+
- Removed the cautious and grams modes, as they were heuristic and did not work well.
|
|
63
|
+
- Removed optimizers: Lion_Prodigy_adv, and Simplified_AdEMAMix.
|
|
64
|
+
|
|
65
|
+
### in 2.1.x
|
|
66
|
+
|
|
67
|
+
- Added Signum (SignSGD with momentum): A new optimizer in the family (SignSGD_adv)
|
|
68
|
+
- More info coming soon.
|
|
69
|
+
|
|
70
|
+
### in 2.0.x
|
|
71
|
+
|
|
72
|
+
* Implemented `torch.compile` support for all advanced optimizers, enabled via `compiled_optimizer=True`, to fuse and optimize the optimizer step path.
|
|
73
|
+
* Improved the 1-bit factored mode (`nnmf_factor=True`).
|
|
74
|
+
* Various improvements across the optimizers.
|
|
75
|
+
|
|
76
|
+
### in 1.2.x
|
|
77
|
+
* Added **advanced variants** of [Muon optimizer](https://kellerjordan.github.io/posts/muon/) with **features** and **settings** from recent papers.
|
|
78
|
+
|
|
79
|
+
| Optimizer | Description |
|
|
80
|
+
|---|---|
|
|
81
|
+
| `Muon_adv` | Advanced Muon implementation with CANS, NorMuon, Low-Rank ortho, etc. features. |
|
|
82
|
+
| `AdaMuon_adv` | Advanced AdaMuon implementation, which combines Muon's geometry with Adam-like adaptive scaling and sign-based orthogonalization. |
|
|
83
|
+
|
|
84
|
+
> *Documentation coming soon.*
|
|
85
|
+
|
|
86
|
+
* Implemented [Cautious Weight Decay](https://arxiv.org/abs/2510.12402) for all advanced optimizers.
|
|
87
|
+
|
|
88
|
+
* Improved parameter update and weight decay for **BF16** with **stochastic rounding**. The updates are now accumulated in **float32** and rounded once at the end.
|
|
89
|
+
|
|
90
|
+
* Use fused and in-place operations whenever possible for all advanced optimizers.
|
|
91
|
+
|
|
92
|
+
* **Prodigy variants** are now **50% faster** by [avoiding CUDA syncs](https://github.com/Koratahiu/Advanced_Optimizers/pull/5). Thanks to **@dxqb**!
|
|
93
|
+
|
|
94
|
+
---
|
|
95
|
+
|
|
96
|
+
## 📦 Installation
|
|
97
|
+
|
|
98
|
+
```bash
|
|
99
|
+
pip install adv_optm
|
|
100
|
+
```
|
|
101
|
+
|
|
102
|
+
---
|
|
103
|
+
|
|
104
|
+
## 🧠 Core Innovations
|
|
105
|
+
|
|
106
|
+
This library integrates multiple state-of-the-art optimization techniques validated through extensive research and practical training.
|
|
107
|
+
|
|
108
|
+
---
|
|
109
|
+
|
|
@@ -0,0 +1,78 @@
|
|
|
1
|
+
# Advanced Optimizers (AIO)
|
|
2
|
+
|
|
3
|
+
A comprehensive, all-in-one collection of optimization algorithms for deep learning, designed for **maximum efficiency**, **minimal memory footprint**, and **superior performance** across diverse model architectures and training scenarios.
|
|
4
|
+
|
|
5
|
+
[PyPI: adv_optm](https://pypi.org/project/adv_optm/)
|
|
6
|
+
|
|
7
|
+
## 🔥 What's New
|
|
8
|
+
|
|
9
|
+
### In 2.4.x:
|
|
10
|
+
|
|
11
|
+
This update is a complete refactor of the library, with many new features and changes:
|
|
12
|
+
|
|
13
|
+
- New optimizer state precision option (`state_precision`) with several storage settings for the optimizer states: rank-2 factored mode (`factored`), full FP32 (`fp32`), BF16 with Stochastic Rounding (`bf16_sr`), int8/uint8 with Stochastic Rounding (`int8_sr`), and FP16 (`fp16`).
|
|
14
|
+
- Added a powerful new optimizer: SinkSGD_adv.
|
|
15
|
+
- Added spectral scaling option to all optimizers, achieving width/rank invariant updates.
|
|
16
|
+
- Added Nesterov momentum (`nesterov`) and its coef (`nesterov_coef`) to all optimizers.
|
|
17
|
+
- Added centered weight decay (`centered_wd`), which pulls the weights toward their pre-trained state (the anchor).
|
|
18
|
+
- The anchor's precision can be reduced to save memory (`centered_wd_mode`): full, float8, int8, int4.
|
|
19
|
+
- Added Fisher Weight Decay option for Adam variants (`fisher_wd`).
|
|
20
|
+
- Paper: [FAdam...](https://arxiv.org/abs/2405.12807)
|
|
21
|
+
- Added Factored Second Moment option for Adam variants (`factored_2nd`). This works alongside any `state_precision` setting.
|
|
22
|
+
- Added Geometric Weight Decay for SinkSGD_adv and SignSGD_adv.
|
|
23
|
+
- Added a powerful new mode: variance-normalized momentum (`normed_momentum`), which applies the optimizer's normalization before the momentum (also called Normalization-then-Momentum, NtM).
|
|
24
|
+
- For: AdamW_adv, SignSGD_adv, SinkSGD_adv.
|
|
25
|
+
- Added Variance/Confidence Preconditioning (`snr_cond`) for SignSGD_adv, SinkSGD_adv.
|
|
26
|
+
- Only works with `normed_momentum`.
|
|
27
|
+
- Technical reports: [AASS](https://koratahiu.github.io/aass/), and [sink-v](https://koratahiu.github.io/sink-v/).
|
|
28
|
+
- Added Adaptive Stochastic Sign with L_inf preconditioning (`stochastic_sign`) for SignSGD_adv and Lion_adv.
|
|
29
|
+
- Improved CANS (`accelerated_ns`) for Muon variants by integrating a dynamic lower bound.
|
|
30
|
+
- Removed the Simplified_AdEMAMix optimizer and its settings in other optimizers; they are replaced by Nesterov momentum and its coefficient, which perform better and are easier to tune.
|
|
31
|
+
- Removed the cautious and grams modes, as they were heuristic and did not work well.
|
|
32
|
+
- Removed optimizers: Lion_Prodigy_adv, and Simplified_AdEMAMix.
|
|
33
|
+
|
|
34
|
+
### in 2.1.x
|
|
35
|
+
|
|
36
|
+
- Added Signum (SignSGD with momentum): A new optimizer in the family (SignSGD_adv)
|
|
37
|
+
- More info coming soon.
|
|
38
|
+
|
|
39
|
+
### in 2.0.x
|
|
40
|
+
|
|
41
|
+
* Implemented `torch.compile` support for all advanced optimizers, enabled via `compiled_optimizer=True`, to fuse and optimize the optimizer step path.
|
|
42
|
+
* Improved the 1-bit factored mode (`nnmf_factor=True`).
|
|
43
|
+
* Various improvements across the optimizers.
|
|
44
|
+
|
|
45
|
+
### in 1.2.x
|
|
46
|
+
* Added **advanced variants** of [Muon optimizer](https://kellerjordan.github.io/posts/muon/) with **features** and **settings** from recent papers.
|
|
47
|
+
|
|
48
|
+
| Optimizer | Description |
|
|
49
|
+
|---|---|
|
|
50
|
+
| `Muon_adv` | Advanced Muon implementation with CANS, NorMuon, Low-Rank ortho, etc. features. |
|
|
51
|
+
| `AdaMuon_adv` | Advanced AdaMuon implementation, which combines Muon's geometry with Adam-like adaptive scaling and sign-based orthogonalization. |
|
|
52
|
+
|
|
53
|
+
> *Documentation coming soon.*
|
|
54
|
+
|
|
55
|
+
* Implemented [Cautious Weight Decay](https://arxiv.org/abs/2510.12402) for all advanced optimizers.
|
|
56
|
+
|
|
57
|
+
* Improved parameter update and weight decay for **BF16** with **stochastic rounding**. The updates are now accumulated in **float32** and rounded once at the end.
|
|
58
|
+
|
|
59
|
+
* Use fused and in-place operations whenever possible for all advanced optimizers.
|
|
60
|
+
|
|
61
|
+
* **Prodigy variants** are now **50% faster** by [avoiding CUDA syncs](https://github.com/Koratahiu/Advanced_Optimizers/pull/5). Thanks to **@dxqb**!
|
|
62
|
+
|
|
63
|
+
---
|
|
64
|
+
|
|
65
|
+
## 📦 Installation
|
|
66
|
+
|
|
67
|
+
```bash
|
|
68
|
+
pip install adv_optm
|
|
69
|
+
```
|
|
70
|
+
|
|
71
|
+
---
|
|
72
|
+
|
|
73
|
+
## 🧠 Core Innovations
|
|
74
|
+
|
|
75
|
+
This library integrates multiple state-of-the-art optimization techniques validated through extensive research and practical training.
|
|
76
|
+
|
|
77
|
+
---
|
|
78
|
+
|
|
@@ -99,7 +99,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
|
|
|
99
99
|
use_muon (bool | None): whether to use Muon or AuxAdamW. MUST be provided
|
|
100
100
|
either here or via `optim_type` in parameter groups. (default: None)
|
|
101
101
|
state_precision (str): Precision for Muon optimizer states. Options: 'auto' (parameter dtype), 'fp32',
|
|
102
|
-
'bf16_sr' (BF16 with stochastic rounding), '
|
|
102
|
+
'bf16_sr' (BF16 with stochastic rounding), 'int8_sr'.
|
|
103
103
|
(default: 'auto')
|
|
104
104
|
factored_2nd (bool): Factorize only the second moment (v_t) using SMMF
|
|
105
105
|
low-rank compression while keeping the first moment (momentum_buffer)
|
|
@@ -123,7 +123,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
|
|
|
123
123
|
adam_tiny_spike (float): Tiny spike for Kourkoutas-β. (default: 1e-9)
|
|
124
124
|
adam_k_warmup_steps (int): Warmup steps for Kourkoutas-β. (default: 0)
|
|
125
125
|
adam_spectral_normalization (bool): Enable explicit spectral normalization for AdamW. (default: False)
|
|
126
|
-
adam_state_precision (str): Precision for AuxAdam states. Options: 'auto', 'fp32', 'bf16_sr', 'fp16', '
|
|
126
|
+
adam_state_precision (str): Precision for AuxAdam states. Options: 'auto', 'fp32', 'bf16_sr', 'fp16', 'int8_sr', 'factored'. (default: 'auto')
|
|
127
127
|
adam_nnmf_factor (bool): 1-bit factored for AdamW.
|
|
128
128
|
adam_factored_2nd (bool): Factorize only the second moment (v_t) for AuxAdam. (default: False)
|
|
129
129
|
"""
|
|
@@ -157,7 +157,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
|
|
|
157
157
|
# Boolean to spilt param
|
|
158
158
|
use_muon: bool | None = None,
|
|
159
159
|
# States precision (Muon path)
|
|
160
|
-
state_precision: str = "auto", # 'fp32', 'bf16_sr', '
|
|
160
|
+
state_precision: str = "auto", # 'fp32', 'bf16_sr', 'int8_sr'
|
|
161
161
|
# Factorized second moment only
|
|
162
162
|
factored_2nd: bool = False,
|
|
163
163
|
# Update geometry parameters
|
|
@@ -220,7 +220,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
|
|
|
220
220
|
state_precision = "factored"
|
|
221
221
|
|
|
222
222
|
state_precision = state_precision.lower()
|
|
223
|
-
valid_precisions = {"auto", "fp32", "factored", "bf16_sr", "fp16", "
|
|
223
|
+
valid_precisions = {"auto", "fp32", "factored", "bf16_sr", "fp16", "int8_sr"}
|
|
224
224
|
if state_precision not in valid_precisions:
|
|
225
225
|
raise ValueError(f"state_precision must be one of {valid_precisions}. Got {state_precision}")
|
|
226
226
|
|
|
@@ -374,6 +374,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
|
|
|
374
374
|
d1, d2 = state['effective_shape']
|
|
375
375
|
state['mu_vbuf_nmf'] = torch.zeros(d1, device=p.device, dtype=torch.float32)
|
|
376
376
|
state['mv_vbuf_nmf'] = torch.zeros(d2, device=p.device, dtype=torch.float32)
|
|
377
|
+
state['shifter'] = torch.tensor([1, 2, 4, 8, 16, 32, 64, 128], device=p.device, dtype=torch.uint8)
|
|
377
378
|
elif not group['normuon_variant']:
|
|
378
379
|
init_state_tensor(state, 'second_momentum_buffer', p.shape, actual_precision, p.device, default_dtype, non_neg=True)
|
|
379
380
|
|
|
@@ -454,8 +455,6 @@ class AdaMuon_adv(torch.optim.Optimizer):
|
|
|
454
455
|
random_int_state_tensor = param_update._get_random_int_for_sr(p)
|
|
455
456
|
elif actual_precision == 'int8_sr':
|
|
456
457
|
random_int_state_tensor = param_update._get_random_int_for_8bit_sr(p)
|
|
457
|
-
elif actual_precision == 'fp8_sr':
|
|
458
|
-
random_int_state_tensor = param_update._get_random_int_for_fp8_sr(p)
|
|
459
458
|
else:
|
|
460
459
|
adam_step_param = Muon_AuxAdam._adam_step_parameter
|
|
461
460
|
|
|
@@ -475,8 +474,6 @@ class AdaMuon_adv(torch.optim.Optimizer):
|
|
|
475
474
|
random_int_state_tensor = param_update._get_random_int_for_sr(p)
|
|
476
475
|
elif actual_precision == 'int8_sr':
|
|
477
476
|
random_int_state_tensor = param_update._get_random_int_for_8bit_sr(p)
|
|
478
|
-
elif actual_precision == 'fp8_sr':
|
|
479
|
-
random_int_state_tensor = param_update._get_random_int_for_fp8_sr(p)
|
|
480
477
|
if group['low_rank_ortho']:
|
|
481
478
|
random_G_sketch = param_update._get_random_noise_for_low_rank_ortho(p, group['ortho_rank'])
|
|
482
479
|
else:
|
|
@@ -84,7 +84,7 @@ class AdamW_adv(torch.optim.Optimizer):
|
|
|
84
84
|
while only factorizing the second moment. (default: False)
|
|
85
85
|
state_precision (str): Precision method for Adopt states. Options: 'auto'
|
|
86
86
|
(parameter precision), 'fp32', 'factored' (SMMF low-rank FP32), 'bf16_sr' (with
|
|
87
|
-
stochastic rounding), 'fp16' , '
|
|
87
|
+
stochastic rounding), 'fp16' , 'int8_sr'. (default: 'auto')
|
|
88
88
|
"""
|
|
89
89
|
|
|
90
90
|
def __init__(
|
|
@@ -124,7 +124,7 @@ class AdamW_adv(torch.optim.Optimizer):
|
|
|
124
124
|
centered_wd: float = 0.0,
|
|
125
125
|
centered_wd_mode: str = 'float8',
|
|
126
126
|
# States precision
|
|
127
|
-
state_precision: str = "auto", # 'fp32', 'factored', 'bf16_sr', '
|
|
127
|
+
state_precision: str = "auto", # 'fp32', 'factored', 'bf16_sr', 'int8_sr'.
|
|
128
128
|
# Factorized second moment only
|
|
129
129
|
factored_2nd: bool = False,
|
|
130
130
|
# SMMF factorization (legacy)
|
|
@@ -145,7 +145,7 @@ class AdamW_adv(torch.optim.Optimizer):
|
|
|
145
145
|
raise ValueError(f"For Kourkoutas-β, betas[1] (as beta2_max) must be > beta2_min. Got {betas[1]} and {beta2_min}")
|
|
146
146
|
|
|
147
147
|
state_precision = state_precision.lower()
|
|
148
|
-
valid_precisions = {"auto", "fp32", "factored", "bf16_sr", "fp16", "
|
|
148
|
+
valid_precisions = {"auto", "fp32", "factored", "bf16_sr", "fp16", "int8_sr"}
|
|
149
149
|
if state_precision not in valid_precisions:
|
|
150
150
|
raise ValueError(f"state_precision must be one of {valid_precisions}. Got {state_precision}")
|
|
151
151
|
|
|
@@ -264,6 +264,7 @@ class AdamW_adv(torch.optim.Optimizer):
|
|
|
264
264
|
d1, d2 = state['effective_shape']
|
|
265
265
|
state['mu_v_nmf'] = torch.zeros(d1, device=device, dtype=torch.float32)
|
|
266
266
|
state['mv_v_nmf'] = torch.zeros(d2, device=device, dtype=torch.float32)
|
|
267
|
+
state['shifter'] = torch.tensor([1, 2, 4, 8, 16, 32, 64, 128], device=device, dtype=torch.uint8)
|
|
267
268
|
else:
|
|
268
269
|
init_state_tensor(state, 'exp_avg_sq', p.shape, actual_precision, p.device, dtype, non_neg=True)
|
|
269
270
|
|
|
@@ -314,8 +315,6 @@ class AdamW_adv(torch.optim.Optimizer):
|
|
|
314
315
|
random_int_state_tensor = param_update._get_random_int_for_sr(p)
|
|
315
316
|
elif group['actual_state_precision'] == 'int8_sr':
|
|
316
317
|
random_int_state_tensor = param_update._get_random_int_for_8bit_sr(p)
|
|
317
|
-
elif group['actual_state_precision'] == 'fp8_sr':
|
|
318
|
-
random_int_state_tensor = param_update._get_random_int_for_fp8_sr(p)
|
|
319
318
|
step_param_fn = self._compiled_step_parameter
|
|
320
319
|
else:
|
|
321
320
|
step_param_fn = self._step_parameter
|
|
@@ -88,7 +88,7 @@ class Adopt_adv(torch.optim.Optimizer):
|
|
|
88
88
|
while only factorizing the second moment. (default: False)
|
|
89
89
|
state_precision (str): Precision method for Adopt states. Options: 'auto'
|
|
90
90
|
(parameter precision), 'fp32', 'factored' (SMMF low-rank FP32), 'bf16_sr' (with
|
|
91
|
-
stochastic rounding), 'fp16' , '
|
|
91
|
+
stochastic rounding), 'fp16' , 'int8_sr'. (default: 'auto')
|
|
92
92
|
"""
|
|
93
93
|
|
|
94
94
|
def __init__(
|
|
@@ -126,7 +126,7 @@ class Adopt_adv(torch.optim.Optimizer):
|
|
|
126
126
|
centered_wd: float = 0.0,
|
|
127
127
|
centered_wd_mode: str = 'float8',
|
|
128
128
|
# States precision
|
|
129
|
-
state_precision: str = "auto", # 'fp32', 'factored', 'bf16_sr', '
|
|
129
|
+
state_precision: str = "auto", # 'fp32', 'factored', 'bf16_sr', 'int8_sr'.
|
|
130
130
|
# Factorized second moment only
|
|
131
131
|
factored_2nd: bool = False,
|
|
132
132
|
# SMMF factorization (legacy)
|
|
@@ -148,7 +148,7 @@ class Adopt_adv(torch.optim.Optimizer):
|
|
|
148
148
|
|
|
149
149
|
|
|
150
150
|
state_precision = state_precision.lower()
|
|
151
|
-
valid_precisions = {"auto", "fp32", "factored", "bf16_sr", "fp16", "
|
|
151
|
+
valid_precisions = {"auto", "fp32", "factored", "bf16_sr", "fp16", "int8_sr"}
|
|
152
152
|
if state_precision not in valid_precisions:
|
|
153
153
|
raise ValueError(f"state_precision must be one of {valid_precisions}. Got {state_precision}")
|
|
154
154
|
|
|
@@ -236,7 +236,7 @@ class Adopt_adv(torch.optim.Optimizer):
|
|
|
236
236
|
|
|
237
237
|
dtype = torch.float32 if (state['factored'] or req_precision == 'factored') else p.dtype
|
|
238
238
|
|
|
239
|
-
vt_dtype = torch.float32 if (state['factored'] or state['factored_2nd'] or req_precision in ['factored', 'bf16_sr', '
|
|
239
|
+
vt_dtype = torch.float32 if (state['factored'] or state['factored_2nd'] or req_precision in ['factored', 'bf16_sr', 'int8_sr']) else dtype
|
|
240
240
|
vt_init = grad.pow(2).to(vt_dtype) * (1 - group['betas'][1])
|
|
241
241
|
|
|
242
242
|
if state['factored']:
|
|
@@ -262,6 +262,7 @@ class Adopt_adv(torch.optim.Optimizer):
|
|
|
262
262
|
state['effective_shape'] = _get_effective_shape(p.numel())
|
|
263
263
|
d1, d2 = state['effective_shape']
|
|
264
264
|
state['mu_v_nmf'], state['mv_v_nmf'] = _nnmf(vt_init.view(d1, d2))
|
|
265
|
+
state['shifter'] = torch.tensor([1, 2, 4, 8, 16, 32, 64, 128], device=p.device, dtype=torch.uint8)
|
|
265
266
|
else:
|
|
266
267
|
init_state_tensor(state, 'exp_avg_sq', p.shape, actual_precision, p.device, dtype)
|
|
267
268
|
set_state(state, 'exp_avg_sq', vt_init, actual_precision, None, non_neg=True)
|
|
@@ -316,8 +317,6 @@ class Adopt_adv(torch.optim.Optimizer):
|
|
|
316
317
|
random_int_state_tensor = param_update._get_random_int_for_sr(p)
|
|
317
318
|
elif group['actual_state_precision'] == 'int8_sr':
|
|
318
319
|
random_int_state_tensor = param_update._get_random_int_for_8bit_sr(p)
|
|
319
|
-
elif group['actual_state_precision'] == 'fp8_sr':
|
|
320
|
-
random_int_state_tensor = param_update._get_random_int_for_fp8_sr(p)
|
|
321
320
|
step_param_fn = self._compiled_step_parameter
|
|
322
321
|
else:
|
|
323
322
|
lr = group['lr']
|
|
@@ -33,8 +33,6 @@ class Lion_adv(torch.optim.Optimizer):
|
|
|
33
33
|
stochastic_rounding (bool, optional): whether to use stochastic
|
|
34
34
|
rounding for BF16 parameter updates (default: True).
|
|
35
35
|
orthogonal_gradient (bool): whether to orthogonalize the gradient (default: False).
|
|
36
|
-
clip_threshold (float, optional): whether to clip the gradients norm
|
|
37
|
-
per-parameter (default: 0.0).
|
|
38
36
|
kappa_p (float, optional): The p-value for the Lp-norm in Lion-K (domain [1.0, 2.0]).
|
|
39
37
|
- 1.0: Standard Lion (sign update).
|
|
40
38
|
- 2.0: Spherical Lion (normalized L2 update).
|
|
@@ -47,7 +47,7 @@ class Muon_adv(torch.optim.Optimizer):
|
|
|
47
47
|
use_muon (bool | None): whether to use Muon or AuxAdamW. MUST be provided
|
|
48
48
|
either here or via `optim_type` in parameter groups. (default: None)
|
|
49
49
|
state_precision (str): Precision for Muon optimizer states. Options: 'auto' (parameter dtype), 'fp32',
|
|
50
|
-
'bf16_sr' (BF16 with stochastic rounding), '
|
|
50
|
+
'bf16_sr' (BF16 with stochastic rounding), 'int8_sr'.
|
|
51
51
|
(default: 'auto')
|
|
52
52
|
low_rank_ortho (bool): If True, enables low-rank orthogonalization, which
|
|
53
53
|
projects the update to a lower rank before orthogonalization.
|
|
@@ -98,7 +98,7 @@ class Muon_adv(torch.optim.Optimizer):
|
|
|
98
98
|
adam_tiny_spike (float): Tiny spike for Kourkoutas-β. (default: 1e-9)
|
|
99
99
|
adam_k_warmup_steps (int): Warmup steps for Kourkoutas-β. (default: 0)
|
|
100
100
|
adam_spectral_normalization (bool): Enable explicit spectral normalization for AdamW. (default: False)
|
|
101
|
-
adam_state_precision (str): Precision for AuxAdam states. Options: 'auto', 'fp32', 'bf16_sr', 'fp16', '
|
|
101
|
+
adam_state_precision (str): Precision for AuxAdam states. Options: 'auto', 'fp32', 'bf16_sr', 'fp16', 'int8_sr', 'factored'. (default: 'auto')
|
|
102
102
|
adam_nnmf_factor (bool): 1-bit factored for AdamW.
|
|
103
103
|
adam_factored_2nd (bool): Factorize only the second moment (v_t) for AuxAdam. (default: False)
|
|
104
104
|
"""
|
|
@@ -130,7 +130,7 @@ class Muon_adv(torch.optim.Optimizer):
|
|
|
130
130
|
# Boolean to spilt param
|
|
131
131
|
use_muon: bool | None = None,
|
|
132
132
|
# States precision (Muon path)
|
|
133
|
-
state_precision: str = "auto", # 'fp32', 'bf16_sr', '
|
|
133
|
+
state_precision: str = "auto", # 'fp32', 'bf16_sr', 'int8_sr'
|
|
134
134
|
# Low-rank Muon
|
|
135
135
|
low_rank_ortho: bool = False,
|
|
136
136
|
ortho_rank: int = 128,
|
|
@@ -193,7 +193,7 @@ class Muon_adv(torch.optim.Optimizer):
|
|
|
193
193
|
state_precision = "factored"
|
|
194
194
|
|
|
195
195
|
state_precision = state_precision.lower()
|
|
196
|
-
valid_precisions = {"auto", "fp32", "factored", "bf16_sr", "fp16", "
|
|
196
|
+
valid_precisions = {"auto", "fp32", "factored", "bf16_sr", "fp16", "int8_sr"}
|
|
197
197
|
if state_precision not in valid_precisions:
|
|
198
198
|
raise ValueError(f"state_precision must be one of {valid_precisions}. Got {state_precision}")
|
|
199
199
|
|
|
@@ -406,8 +406,6 @@ class Muon_adv(torch.optim.Optimizer):
|
|
|
406
406
|
random_int_state_tensor = param_update._get_random_int_for_sr(p)
|
|
407
407
|
elif actual_precision == 'int8_sr':
|
|
408
408
|
random_int_state_tensor = param_update._get_random_int_for_8bit_sr(p)
|
|
409
|
-
elif actual_precision == 'fp8_sr':
|
|
410
|
-
random_int_state_tensor = param_update._get_random_int_for_fp8_sr(p)
|
|
411
409
|
else:
|
|
412
410
|
adam_step_param = Muon_AuxAdam._adam_step_parameter
|
|
413
411
|
|
|
@@ -427,8 +425,6 @@ class Muon_adv(torch.optim.Optimizer):
|
|
|
427
425
|
random_int_state_tensor = param_update._get_random_int_for_sr(p)
|
|
428
426
|
elif actual_precision == 'int8_sr':
|
|
429
427
|
random_int_state_tensor = param_update._get_random_int_for_8bit_sr(p)
|
|
430
|
-
elif actual_precision == 'fp8_sr':
|
|
431
|
-
random_int_state_tensor = param_update._get_random_int_for_fp8_sr(p)
|
|
432
428
|
if group['low_rank_ortho']:
|
|
433
429
|
random_G_sketch = param_update._get_random_noise_for_low_rank_ortho(p, group['ortho_rank'])
|
|
434
430
|
else:
|
|
@@ -124,7 +124,7 @@ class Prodigy_adv(torch.optim.Optimizer):
|
|
|
124
124
|
nesterov: bool = False,
|
|
125
125
|
nesterov_coef: float | None = None,
|
|
126
126
|
# States precision
|
|
127
|
-
state_precision: str = "auto", # 'fp32', 'factored', 'bf16_sr', '
|
|
127
|
+
state_precision: str = "auto", # 'fp32', 'factored', 'bf16_sr', 'int8_sr'.
|
|
128
128
|
# Factorized second moment only
|
|
129
129
|
factored_2nd: bool = False,
|
|
130
130
|
# SMMF factorization (legacy)
|
|
@@ -168,7 +168,7 @@ class Prodigy_adv(torch.optim.Optimizer):
|
|
|
168
168
|
raise ValueError(f"For Kourkoutas-β, betas[1] (as beta2_max) must be > beta2_min. Got {betas[1]} and {beta2_min}")
|
|
169
169
|
|
|
170
170
|
state_precision = state_precision.lower()
|
|
171
|
-
valid_precisions = {"auto", "fp32", "factored", "bf16_sr", "fp16", "
|
|
171
|
+
valid_precisions = {"auto", "fp32", "factored", "bf16_sr", "fp16", "int8_sr"}
|
|
172
172
|
if state_precision not in valid_precisions:
|
|
173
173
|
raise ValueError(f"state_precision must be one of {valid_precisions}. Got {state_precision}")
|
|
174
174
|
|
|
@@ -311,6 +311,7 @@ class Prodigy_adv(torch.optim.Optimizer):
|
|
|
311
311
|
d1, d2 = state['effective_shape']
|
|
312
312
|
state['mu_v_nmf'] = torch.zeros(d1, device=device, dtype=torch.float32)
|
|
313
313
|
state['mv_v_nmf'] = torch.zeros(d2, device=device, dtype=torch.float32)
|
|
314
|
+
state['shifter'] = torch.tensor([1, 2, 4, 8, 16, 32, 64, 128], device=p.device, dtype=torch.uint8)
|
|
314
315
|
else:
|
|
315
316
|
init_state_tensor(state, 'exp_avg_sq', p.shape, actual_precision, p.device, dtype, non_neg=True)
|
|
316
317
|
|
|
@@ -358,8 +359,6 @@ class Prodigy_adv(torch.optim.Optimizer):
|
|
|
358
359
|
random_int_state_tensor = param_update._get_random_int_for_sr(p)
|
|
359
360
|
elif group['actual_state_precision'] == 'int8_sr':
|
|
360
361
|
random_int_state_tensor = param_update._get_random_int_for_8bit_sr(p)
|
|
361
|
-
elif group['actual_state_precision'] == 'fp8_sr':
|
|
362
|
-
random_int_state_tensor = param_update._get_random_int_for_fp8_sr(p)
|
|
363
362
|
step_param_fn = self._compiled_step_parameter
|
|
364
363
|
else:
|
|
365
364
|
d = group['d']
|
|
@@ -44,7 +44,7 @@ class SignSGD_adv(torch.optim.Optimizer):
|
|
|
44
44
|
'int4': Uses 4-bit block-wise quantization (block size 32).
|
|
45
45
|
state_precision (str): Precision method for Adopt states. Options: 'auto'
|
|
46
46
|
(parameter precision), 'fp32', 'factored' (SMMF low-rank FP32), 'bf16_sr' (with
|
|
47
|
-
stochastic rounding), 'fp16' , '
|
|
47
|
+
stochastic rounding), 'fp16' , 'int8_sr'. (default: 'auto')
|
|
48
48
|
nnmf_factor (bool): whether to use the factorization or use the
|
|
49
49
|
uncompressed optimizer. (default: True)
|
|
50
50
|
"""
|
|
@@ -70,13 +70,13 @@ class SignSGD_adv(torch.optim.Optimizer):
|
|
|
70
70
|
nesterov_coef: float | None = None,
|
|
71
71
|
# Normalization then Momentum
|
|
72
72
|
normed_momentum: bool = False,
|
|
73
|
-
# SNR Precondition
|
|
73
|
+
# SNR Precondition (requires normed_momentum)
|
|
74
74
|
snr_cond: bool = False,
|
|
75
75
|
# Centered WD
|
|
76
76
|
centered_wd: float = 0.0,
|
|
77
77
|
centered_wd_mode: str = 'float8',
|
|
78
78
|
# States precision
|
|
79
|
-
state_precision: str = "auto", # 'fp32', 'factored', 'bf16_sr', '
|
|
79
|
+
state_precision: str = "auto", # 'fp32', 'factored', 'bf16_sr', 'int8_sr'.
|
|
80
80
|
# Spectral Normed Optimizer
|
|
81
81
|
spectral_normalization: bool = False,
|
|
82
82
|
# SMMF factorization
|
|
@@ -95,7 +95,7 @@ class SignSGD_adv(torch.optim.Optimizer):
|
|
|
95
95
|
raise NotImplementedError(f"snr_cond is intended to be used with normed_momentum")
|
|
96
96
|
|
|
97
97
|
state_precision = state_precision.lower()
|
|
98
|
-
valid_precisions = {"auto", "fp32", "factored", "bf16_sr", "fp16", "
|
|
98
|
+
valid_precisions = {"auto", "fp32", "factored", "bf16_sr", "fp16", "int8_sr"}
|
|
99
99
|
if state_precision not in valid_precisions:
|
|
100
100
|
raise ValueError(f"state_precision must be one of {valid_precisions}. Got {state_precision}")
|
|
101
101
|
|
|
@@ -230,8 +230,6 @@ class SignSGD_adv(torch.optim.Optimizer):
|
|
|
230
230
|
random_int_state_tensor = param_update._get_random_int_for_sr(p)
|
|
231
231
|
elif group['actual_state_precision'] == 'int8_sr':
|
|
232
232
|
random_int_state_tensor = param_update._get_random_int_for_8bit_sr(p)
|
|
233
|
-
elif group['actual_state_precision'] == 'fp8_sr':
|
|
234
|
-
random_int_state_tensor = param_update._get_random_int_for_fp8_sr(p)
|
|
235
233
|
|
|
236
234
|
if group.get('stochastic_sign', False) and not is_vector:
|
|
237
235
|
random_noise_tensor = param_update._get_random_noise_for_sso(p)
|
|
@@ -254,7 +252,8 @@ class SignSGD_adv(torch.optim.Optimizer):
|
|
|
254
252
|
nesterov = group.get('nesterov', False)
|
|
255
253
|
nesterov_coef = group.get('nesterov_coef', None)
|
|
256
254
|
sso = group.get('stochastic_sign', False)
|
|
257
|
-
|
|
255
|
+
normed_mt = group.get('normed_momentum', False)
|
|
256
|
+
snr_cond = group.get('snr_cond', False) and normed_mt and momentum > 0
|
|
258
257
|
|
|
259
258
|
denom = None
|
|
260
259
|
wd_target = None
|
|
@@ -263,7 +262,7 @@ class SignSGD_adv(torch.optim.Optimizer):
|
|
|
263
262
|
if group["orthogonal_gradient"]:
|
|
264
263
|
grad = _orthogonalize_gradient(p, grad)
|
|
265
264
|
|
|
266
|
-
if
|
|
265
|
+
if normed_mt:
|
|
267
266
|
if sso:
|
|
268
267
|
grad = apply_stochastic_sign_(grad, noise=random_noise_tensor, is_vector=is_vector)
|
|
269
268
|
else:
|
|
@@ -281,11 +280,18 @@ class SignSGD_adv(torch.optim.Optimizer):
|
|
|
281
280
|
if snr_cond:
|
|
282
281
|
denom = (1.0 - exp_avg.square()).clamp_min_(1e-30).sqrt_().view_as(p)
|
|
283
282
|
|
|
283
|
+
if nesterov and normed_mt:
|
|
284
|
+
# Scale the normalized gradient using empirical buffer magnitude (SNR recovery)
|
|
285
|
+
normed_grad = grad_reshaped * exp_avg.abs()
|
|
286
|
+
|
|
284
287
|
exp_avg.lerp_(grad_reshaped, 1 - momentum)
|
|
285
288
|
|
|
286
289
|
if nesterov:
|
|
287
290
|
nv_coef = momentum if nesterov_coef is None else nesterov_coef
|
|
288
|
-
|
|
291
|
+
if normed_mt:
|
|
292
|
+
raw_update = normed_grad.lerp_(exp_avg, nv_coef)
|
|
293
|
+
else:
|
|
294
|
+
raw_update = grad_reshaped.lerp(exp_avg, nv_coef)
|
|
289
295
|
else:
|
|
290
296
|
raw_update = exp_avg.clone()
|
|
291
297
|
|
|
@@ -305,11 +311,18 @@ class SignSGD_adv(torch.optim.Optimizer):
|
|
|
305
311
|
if snr_cond:
|
|
306
312
|
denom = (1.0 - exp_avg.square()).clamp_min_(1e-30).sqrt_()
|
|
307
313
|
|
|
314
|
+
if nesterov and normed_mt:
|
|
315
|
+
# Scale the normalized gradient using empirical buffer magnitude (SNR recovery)
|
|
316
|
+
normed_grad = grad * exp_avg.abs()
|
|
317
|
+
|
|
308
318
|
exp_avg.lerp_(grad, 1 - momentum)
|
|
309
319
|
|
|
310
320
|
if nesterov:
|
|
311
321
|
nv_coef = momentum if nesterov_coef is None else nesterov_coef
|
|
312
|
-
|
|
322
|
+
if normed_mt:
|
|
323
|
+
raw_update = normed_grad.lerp_(exp_avg, nv_coef)
|
|
324
|
+
else:
|
|
325
|
+
raw_update = grad.lerp(exp_avg, nv_coef)
|
|
313
326
|
else:
|
|
314
327
|
raw_update = exp_avg.clone()
|
|
315
328
|
|
|
@@ -342,7 +355,7 @@ class SignSGD_adv(torch.optim.Optimizer):
|
|
|
342
355
|
update_scaling = lr * A if snr_cond else lr
|
|
343
356
|
update.mul_(update_scaling)
|
|
344
357
|
|
|
345
|
-
param_update.apply_parameter_update(self, p, group, update, lr, random_int_tensor=random_int_tensor, wd_target=wd_target, cwd_target=cwd_target)
|
|
358
|
+
param_update.apply_parameter_update(self, p, group, update, lr, random_int_tensor=random_int_tensor, wd_target=wd_target, cwd_target=cwd_target, decoupled=True)
|
|
346
359
|
|
|
347
360
|
def compile(self, *args, **kwargs):
|
|
348
361
|
self._compiled_step_parameter = torch.compile(self._step_parameter, *args, **kwargs)
|