adv-optm 2.4.dev22__tar.gz → 2.4.dev24__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38) hide show
  1. adv_optm-2.4.dev24/PKG-INFO +109 -0
  2. adv_optm-2.4.dev24/README.md +78 -0
  3. {adv_optm-2.4.dev22 → adv_optm-2.4.dev24}/adv_optm/__init__.py +1 -1
  4. {adv_optm-2.4.dev22 → adv_optm-2.4.dev24}/adv_optm/optim/AdaMuon_adv.py +5 -8
  5. {adv_optm-2.4.dev22 → adv_optm-2.4.dev24}/adv_optm/optim/AdamW_adv.py +4 -5
  6. {adv_optm-2.4.dev22 → adv_optm-2.4.dev24}/adv_optm/optim/Adopt_adv.py +5 -6
  7. {adv_optm-2.4.dev22 → adv_optm-2.4.dev24}/adv_optm/optim/Lion_adv.py +0 -2
  8. {adv_optm-2.4.dev22 → adv_optm-2.4.dev24}/adv_optm/optim/Muon_adv.py +4 -8
  9. {adv_optm-2.4.dev22 → adv_optm-2.4.dev24}/adv_optm/optim/Prodigy_adv.py +3 -4
  10. {adv_optm-2.4.dev22 → adv_optm-2.4.dev24}/adv_optm/optim/SignSGD_adv.py +19 -10
  11. {adv_optm-2.4.dev22 → adv_optm-2.4.dev24}/adv_optm/optim/SinkSGD_adv.py +20 -11
  12. {adv_optm-2.4.dev22 → adv_optm-2.4.dev24}/adv_optm/util/Muon_AuxAdam.py +1 -0
  13. adv_optm-2.4.dev24/adv_optm/util/OrthoGrad.py +80 -0
  14. {adv_optm-2.4.dev22 → adv_optm-2.4.dev24}/adv_optm/util/centered_decay.py +22 -15
  15. {adv_optm-2.4.dev22 → adv_optm-2.4.dev24}/adv_optm/util/param_update.py +0 -68
  16. {adv_optm-2.4.dev22 → adv_optm-2.4.dev24}/adv_optm/util/scaled_optm.py +46 -27
  17. {adv_optm-2.4.dev22 → adv_optm-2.4.dev24}/adv_optm/util/state_util.py +6 -30
  18. adv_optm-2.4.dev24/adv_optm.egg-info/PKG-INFO +109 -0
  19. {adv_optm-2.4.dev22 → adv_optm-2.4.dev24}/setup.py +1 -1
  20. adv_optm-2.4.dev22/PKG-INFO +0 -202
  21. adv_optm-2.4.dev22/README.md +0 -171
  22. adv_optm-2.4.dev22/adv_optm/util/OrthoGrad.py +0 -19
  23. adv_optm-2.4.dev22/adv_optm.egg-info/PKG-INFO +0 -202
  24. {adv_optm-2.4.dev22 → adv_optm-2.4.dev24}/LICENSE +0 -0
  25. {adv_optm-2.4.dev22 → adv_optm-2.4.dev24}/adv_optm/optim/__init__.py +0 -0
  26. {adv_optm-2.4.dev22 → adv_optm-2.4.dev24}/adv_optm/util/Kourkoutas.py +0 -0
  27. {adv_optm-2.4.dev22 → adv_optm-2.4.dev24}/adv_optm/util/Muon_util.py +0 -0
  28. {adv_optm-2.4.dev22 → adv_optm-2.4.dev24}/adv_optm/util/__init__.py +0 -0
  29. {adv_optm-2.4.dev22 → adv_optm-2.4.dev24}/adv_optm/util/factorization_util.py +0 -0
  30. {adv_optm-2.4.dev22 → adv_optm-2.4.dev24}/adv_optm/util/lion_k.py +0 -0
  31. {adv_optm-2.4.dev22 → adv_optm-2.4.dev24}/adv_optm/util/signed_util.py +0 -0
  32. {adv_optm-2.4.dev22 → adv_optm-2.4.dev24}/adv_optm/util/sinkhorn.py +0 -0
  33. {adv_optm-2.4.dev22 → adv_optm-2.4.dev24}/adv_optm/util/update_util.py +0 -0
  34. {adv_optm-2.4.dev22 → adv_optm-2.4.dev24}/adv_optm.egg-info/SOURCES.txt +0 -0
  35. {adv_optm-2.4.dev22 → adv_optm-2.4.dev24}/adv_optm.egg-info/dependency_links.txt +0 -0
  36. {adv_optm-2.4.dev22 → adv_optm-2.4.dev24}/adv_optm.egg-info/requires.txt +0 -0
  37. {adv_optm-2.4.dev22 → adv_optm-2.4.dev24}/adv_optm.egg-info/top_level.txt +0 -0
  38. {adv_optm-2.4.dev22 → adv_optm-2.4.dev24}/setup.cfg +0 -0
@@ -0,0 +1,109 @@
1
+ Metadata-Version: 2.4
2
+ Name: adv_optm
3
+ Version: 2.4.dev24
4
+ Summary: A family of highly efficient, lightweight yet powerful optimizers.
5
+ Home-page: https://github.com/Koratahiu/Advanced_Optimizers
6
+ Author: Koratahiu
7
+ Author-email: hiuhonor@gmail.com
8
+ License: Apache 2.0
9
+ Keywords: llm,fine-tuning,memory-efficient,low-rank,compression,pytorch,optimizer,adam
10
+ Classifier: Programming Language :: Python :: 3
11
+ Classifier: License :: OSI Approved :: Apache Software License
12
+ Classifier: Operating System :: OS Independent
13
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
14
+ Classifier: Topic :: Software Development :: Libraries :: Python Modules
15
+ Requires-Python: >=3.8
16
+ Description-Content-Type: text/markdown
17
+ License-File: LICENSE
18
+ Requires-Dist: torch>=2.1
19
+ Dynamic: author
20
+ Dynamic: author-email
21
+ Dynamic: classifier
22
+ Dynamic: description
23
+ Dynamic: description-content-type
24
+ Dynamic: home-page
25
+ Dynamic: keywords
26
+ Dynamic: license
27
+ Dynamic: license-file
28
+ Dynamic: requires-dist
29
+ Dynamic: requires-python
30
+ Dynamic: summary
31
+
32
+ # Advanced Optimizers (AIO)
33
+
34
+ A comprehensive, all-in-one collection of optimization algorithms for deep learning, designed for **maximum efficiency**, **minimal memory footprint**, and **superior performance** across diverse model architectures and training scenarios.
35
+
36
+ [![PyPI](https://img.shields.io/pypi/v/adv_optm)](https://pypi.org/project/adv_optm/)
37
+
38
+ ## 🔥 What's New
39
+
40
+ ### In 2.4.x:
41
+
42
+ This update introduces a whole refactor of the library with many new features and changes:
43
+
44
+ - New optimizer-state precision option (`state_precision`) with several settings for the optimizer states: rank-2 factored mode (`factored`), full FP32 (`fp32`), BF16 with stochastic rounding (`bf16_sr`), int8/uint8 with stochastic rounding (`int8_sr`), and FP16 (`fp16`).
45
+ - Added new powerful optimizer: SinkSGD_adv.
46
+ - Added spectral scaling option to all optimizers, achieving width/rank invariant updates.
47
+ - Added Nesterov momentum (`nesterov`) and its coef (`nesterov_coef`) to all optimizers.
48
+ - Added centered weight decay (`centered_wd`), to pull the weights toward their pre-train state (anchor)
49
+ - anchor precision can be changed to save memory (`centered_wd_mode`): full, float8, int8, int4
50
+ - Added Fisher Weight Decay option for Adam variants (`fisher_wd`).
51
+ - Paper: [FAdam...](https://arxiv.org/abs/2405.12807)
52
+ - Added Factored Second Moment option for Adam variants (`factored_2nd`). This works alongside any `state_precision` setting.
53
+ - Added Geometric Weight Decay for SinkSGD_adv and SignSGD_adv.
54
+ - Added a new powerful mode: variance-normalized momentum (`normed_momentum`), which applies the optimizer's normalization before the momentum (also known as Normalization-then-Momentum, NtM).
55
+ - For: AdamW_adv, SignSGD_adv, SinkSGD_adv.
56
+ - Added Variance/Confidence Preconditioning (`snr_cond`) for SignSGD_adv, SinkSGD_adv.
57
+ - Only works with `normed_momentum`.
58
+ - Technical reports: [AASS](https://koratahiu.github.io/aass/), and [sink-v](https://koratahiu.github.io/sink-v/).
59
+ - Added Adaptive Stochastic Sign with L_inf preconditioning (`stochastic_sign`) for SignSGD_adv and Lion_adv.
60
+ - Improved CANS (`accelerated_ns`) for Muon variants, by integrating dynamic lower bound.
61
+ - Removed the Simplified_AdEMAMix optimizer and its settings in other optimizers; they are now replaced by Nesterov momentum and its coefficient, which is better and easier to tune.
62
+ - Removed cautious and grams modes, as they were heuristic and not working well.
63
+ - Removed optimizers: Lion_Prodigy_adv, and Simplified_AdEMAMix.
64
+
65
+ ### in 2.1.x
66
+
67
+ - Added Signum (SignSGD with momentum): A new optimizer in the family (SignSGD_adv)
68
+ - More info coming soon.
69
+
70
+ ### in 2.0.x
71
+
72
+ * Implemented torch.compile for all advanced optimizers. Enabled via (compiled_optimizer=True) to fuse and optimize the optimizer step path.
73
+ * Better and improved 1-bit factored mode via (nnmf_factor=True).
74
+ * Various improvements across the optimizers.
75
+
76
+ ### in 1.2.x
77
+ * Added **advanced variants** of [Muon optimizer](https://kellerjordan.github.io/posts/muon/) with **features** and **settings** from recent papers.
78
+
79
+ | Optimizer | Description |
80
+ |---|---|
81
+ | `Muon_adv` | Advanced Muon implementation with CANS, NorMuon, Low-Rank ortho, etc. features. |
82
+ | `AdaMuon_adv` | Advanced AdaMuon implementation, which combines Muon's geometry with Adam-like adaptive scaling and sign-based orthogonalization. |
83
+
84
+ > *Documentation coming soon.*
85
+
86
+ * Implemented [Cautious Weight Decay](https://arxiv.org/abs/2510.12402) for all advanced optimizers.
87
+
88
+ * Improved parameter update and weight decay for **BF16** with **stochastic rounding**. The updates are now accumulated in **float32** and rounded once at the end.
89
+
90
+ * Use fused and in-place operations whenever possible for all advanced optimizers.
91
+
92
+ * **Prodigy variants** are now **50% faster** by [avoiding CUDA syncs](https://github.com/Koratahiu/Advanced_Optimizers/pull/5). Thanks to **@dxqb**!
93
+
94
+ ---
95
+
96
+ ## 📦 Installation
97
+
98
+ ```bash
99
+ pip install adv_optm
100
+ ```
101
+
102
+ ---
103
+
104
+ ## 🧠 Core Innovations
105
+
106
+ This library integrates multiple state-of-the-art optimization techniques validated through extensive research and practical training.
107
+
108
+ ---
109
+
@@ -0,0 +1,78 @@
1
+ # Advanced Optimizers (AIO)
2
+
3
+ A comprehensive, all-in-one collection of optimization algorithms for deep learning, designed for **maximum efficiency**, **minimal memory footprint**, and **superior performance** across diverse model architectures and training scenarios.
4
+
5
+ [![PyPI](https://img.shields.io/pypi/v/adv_optm)](https://pypi.org/project/adv_optm/)
6
+
7
+ ## 🔥 What's New
8
+
9
+ ### In 2.4.x:
10
+
11
+ This update introduces a whole refactor of the library with many new features and changes:
12
+
13
+ - New optimizer-state precision option (`state_precision`) with several settings for the optimizer states: rank-2 factored mode (`factored`), full FP32 (`fp32`), BF16 with stochastic rounding (`bf16_sr`), int8/uint8 with stochastic rounding (`int8_sr`), and FP16 (`fp16`).
14
+ - Added new powerful optimizer: SinkSGD_adv.
15
+ - Added spectral scaling option to all optimizers, achieving width/rank invariant updates.
16
+ - Added Nesterov momentum (`nesterov`) and its coef (`nesterov_coef`) to all optimizers.
17
+ - Added centered weight decay (`centered_wd`), to pull the weights toward their pre-train state (anchor)
18
+ - anchor precision can be changed to save memory (`centered_wd_mode`): full, float8, int8, int4
19
+ - Added Fisher Weight Decay option for Adam variants (`fisher_wd`).
20
+ - Paper: [FAdam...](https://arxiv.org/abs/2405.12807)
21
+ - Added Factored Second Moment option for Adam variants (`factored_2nd`). This works alongside any `state_precision` setting.
22
+ - Added Geometric Weight Decay for SinkSGD_adv and SignSGD_adv.
23
+ - Added a new powerful mode: variance-normalized momentum (`normed_momentum`), which applies the optimizer's normalization before the momentum (also known as Normalization-then-Momentum, NtM).
24
+ - For: AdamW_adv, SignSGD_adv, SinkSGD_adv.
25
+ - Added Variance/Confidence Preconditioning (`snr_cond`) for SignSGD_adv, SinkSGD_adv.
26
+ - Only works with `normed_momentum`.
27
+ - Technical reports: [AASS](https://koratahiu.github.io/aass/), and [sink-v](https://koratahiu.github.io/sink-v/).
28
+ - Added Adaptive Stochastic Sign with L_inf preconditioning (`stochastic_sign`) for SignSGD_adv and Lion_adv.
29
+ - Improved CANS (`accelerated_ns`) for Muon variants, by integrating dynamic lower bound.
30
+ - Removed the Simplified_AdEMAMix optimizer and its settings in other optimizers; they are now replaced by Nesterov momentum and its coefficient, which is better and easier to tune.
31
+ - Removed cautious and grams modes, as they were heuristic and not working well.
32
+ - Removed optimizers: Lion_Prodigy_adv, and Simplified_AdEMAMix.
33
+
34
+ ### in 2.1.x
35
+
36
+ - Added Signum (SignSGD with momentum): A new optimizer in the family (SignSGD_adv)
37
+ - More info coming soon.
38
+
39
+ ### in 2.0.x
40
+
41
+ * Implemented torch.compile for all advanced optimizers. Enabled via (compiled_optimizer=True) to fuse and optimize the optimizer step path.
42
+ * Better and improved 1-bit factored mode via (nnmf_factor=True).
43
+ * Various improvements across the optimizers.
44
+
45
+ ### in 1.2.x
46
+ * Added **advanced variants** of [Muon optimizer](https://kellerjordan.github.io/posts/muon/) with **features** and **settings** from recent papers.
47
+
48
+ | Optimizer | Description |
49
+ |---|---|
50
+ | `Muon_adv` | Advanced Muon implementation with CANS, NorMuon, Low-Rank ortho, etc. features. |
51
+ | `AdaMuon_adv` | Advanced AdaMuon implementation, which combines Muon's geometry with Adam-like adaptive scaling and sign-based orthogonalization. |
52
+
53
+ > *Documentation coming soon.*
54
+
55
+ * Implemented [Cautious Weight Decay](https://arxiv.org/abs/2510.12402) for all advanced optimizers.
56
+
57
+ * Improved parameter update and weight decay for **BF16** with **stochastic rounding**. The updates are now accumulated in **float32** and rounded once at the end.
58
+
59
+ * Use fused and in-place operations whenever possible for all advanced optimizers.
60
+
61
+ * **Prodigy variants** are now **50% faster** by [avoiding CUDA syncs](https://github.com/Koratahiu/Advanced_Optimizers/pull/5). Thanks to **@dxqb**!
62
+
63
+ ---
64
+
65
+ ## 📦 Installation
66
+
67
+ ```bash
68
+ pip install adv_optm
69
+ ```
70
+
71
+ ---
72
+
73
+ ## 🧠 Core Innovations
74
+
75
+ This library integrates multiple state-of-the-art optimization techniques validated through extensive research and practical training.
76
+
77
+ ---
78
+
@@ -20,4 +20,4 @@ __all__ = [
20
20
  "SinkSGD_adv",
21
21
  ]
22
22
 
23
- __version__ = "2.4.dev22"
23
+ __version__ = "2.4.dev24"
@@ -99,7 +99,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
99
99
  use_muon (bool | None): whether to use Muon or AuxAdamW. MUST be provided
100
100
  either here or via `optim_type` in parameter groups. (default: None)
101
101
  state_precision (str): Precision for Muon optimizer states. Options: 'auto' (parameter dtype), 'fp32',
102
- 'bf16_sr' (BF16 with stochastic rounding), 'fp8_sr', 'int8_sr'.
102
+ 'bf16_sr' (BF16 with stochastic rounding), 'int8_sr'.
103
103
  (default: 'auto')
104
104
  factored_2nd (bool): Factorize only the second moment (v_t) using SMMF
105
105
  low-rank compression while keeping the first moment (momentum_buffer)
@@ -123,7 +123,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
123
123
  adam_tiny_spike (float): Tiny spike for Kourkoutas-β. (default: 1e-9)
124
124
  adam_k_warmup_steps (int): Warmup steps for Kourkoutas-β. (default: 0)
125
125
  adam_spectral_normalization (bool): Enable explicit spectral normalization for AdamW. (default: False)
126
- adam_state_precision (str): Precision for AuxAdam states. Options: 'auto', 'fp32', 'bf16_sr', 'fp16', 'fp8_sr', 'int8_sr', 'factored'. (default: 'auto')
126
+ adam_state_precision (str): Precision for AuxAdam states. Options: 'auto', 'fp32', 'bf16_sr', 'fp16', 'int8_sr', 'factored'. (default: 'auto')
127
127
  adam_nnmf_factor (bool): 1-bit factored for AdamW.
128
128
  adam_factored_2nd (bool): Factorize only the second moment (v_t) for AuxAdam. (default: False)
129
129
  """
@@ -157,7 +157,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
157
157
  # Boolean to split param
158
158
  use_muon: bool | None = None,
159
159
  # States precision (Muon path)
160
- state_precision: str = "auto", # 'fp32', 'bf16_sr', 'fp8_sr', 'int8_sr'
160
+ state_precision: str = "auto", # 'fp32', 'bf16_sr', 'int8_sr'
161
161
  # Factorized second moment only
162
162
  factored_2nd: bool = False,
163
163
  # Update geometry parameters
@@ -220,7 +220,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
220
220
  state_precision = "factored"
221
221
 
222
222
  state_precision = state_precision.lower()
223
- valid_precisions = {"auto", "fp32", "factored", "bf16_sr", "fp16", "fp8_sr", "int8_sr"}
223
+ valid_precisions = {"auto", "fp32", "factored", "bf16_sr", "fp16", "int8_sr"}
224
224
  if state_precision not in valid_precisions:
225
225
  raise ValueError(f"state_precision must be one of {valid_precisions}. Got {state_precision}")
226
226
 
@@ -374,6 +374,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
374
374
  d1, d2 = state['effective_shape']
375
375
  state['mu_vbuf_nmf'] = torch.zeros(d1, device=p.device, dtype=torch.float32)
376
376
  state['mv_vbuf_nmf'] = torch.zeros(d2, device=p.device, dtype=torch.float32)
377
+ state['shifter'] = torch.tensor([1, 2, 4, 8, 16, 32, 64, 128], device=p.device, dtype=torch.uint8)
377
378
  elif not group['normuon_variant']:
378
379
  init_state_tensor(state, 'second_momentum_buffer', p.shape, actual_precision, p.device, default_dtype, non_neg=True)
379
380
 
@@ -454,8 +455,6 @@ class AdaMuon_adv(torch.optim.Optimizer):
454
455
  random_int_state_tensor = param_update._get_random_int_for_sr(p)
455
456
  elif actual_precision == 'int8_sr':
456
457
  random_int_state_tensor = param_update._get_random_int_for_8bit_sr(p)
457
- elif actual_precision == 'fp8_sr':
458
- random_int_state_tensor = param_update._get_random_int_for_fp8_sr(p)
459
458
  else:
460
459
  adam_step_param = Muon_AuxAdam._adam_step_parameter
461
460
 
@@ -475,8 +474,6 @@ class AdaMuon_adv(torch.optim.Optimizer):
475
474
  random_int_state_tensor = param_update._get_random_int_for_sr(p)
476
475
  elif actual_precision == 'int8_sr':
477
476
  random_int_state_tensor = param_update._get_random_int_for_8bit_sr(p)
478
- elif actual_precision == 'fp8_sr':
479
- random_int_state_tensor = param_update._get_random_int_for_fp8_sr(p)
480
477
  if group['low_rank_ortho']:
481
478
  random_G_sketch = param_update._get_random_noise_for_low_rank_ortho(p, group['ortho_rank'])
482
479
  else:
@@ -84,7 +84,7 @@ class AdamW_adv(torch.optim.Optimizer):
84
84
  while only factorizing the second moment. (default: False)
85
85
  state_precision (str): Precision method for the optimizer states. Options: 'auto'
86
86
  (parameter precision), 'fp32', 'factored' (SMMF low-rank FP32), 'bf16_sr' (with
87
- stochastic rounding), 'fp16' , 'fp8_sr', 'int8_sr'. (default: 'auto')
87
+ stochastic rounding), 'fp16' , 'int8_sr'. (default: 'auto')
88
88
  """
89
89
 
90
90
  def __init__(
@@ -124,7 +124,7 @@ class AdamW_adv(torch.optim.Optimizer):
124
124
  centered_wd: float = 0.0,
125
125
  centered_wd_mode: str = 'float8',
126
126
  # States precision
127
- state_precision: str = "auto", # 'fp32', 'factored', 'bf16_sr', 'fp8_sr', 'int8_sr'.
127
+ state_precision: str = "auto", # 'fp32', 'factored', 'bf16_sr', 'int8_sr'.
128
128
  # Factorized second moment only
129
129
  factored_2nd: bool = False,
130
130
  # SMMF factorization (legacy)
@@ -145,7 +145,7 @@ class AdamW_adv(torch.optim.Optimizer):
145
145
  raise ValueError(f"For Kourkoutas-β, betas[1] (as beta2_max) must be > beta2_min. Got {betas[1]} and {beta2_min}")
146
146
 
147
147
  state_precision = state_precision.lower()
148
- valid_precisions = {"auto", "fp32", "factored", "bf16_sr", "fp16", "fp8_sr", "int8_sr"}
148
+ valid_precisions = {"auto", "fp32", "factored", "bf16_sr", "fp16", "int8_sr"}
149
149
  if state_precision not in valid_precisions:
150
150
  raise ValueError(f"state_precision must be one of {valid_precisions}. Got {state_precision}")
151
151
 
@@ -264,6 +264,7 @@ class AdamW_adv(torch.optim.Optimizer):
264
264
  d1, d2 = state['effective_shape']
265
265
  state['mu_v_nmf'] = torch.zeros(d1, device=device, dtype=torch.float32)
266
266
  state['mv_v_nmf'] = torch.zeros(d2, device=device, dtype=torch.float32)
267
+ state['shifter'] = torch.tensor([1, 2, 4, 8, 16, 32, 64, 128], device=device, dtype=torch.uint8)
267
268
  else:
268
269
  init_state_tensor(state, 'exp_avg_sq', p.shape, actual_precision, p.device, dtype, non_neg=True)
269
270
 
@@ -314,8 +315,6 @@ class AdamW_adv(torch.optim.Optimizer):
314
315
  random_int_state_tensor = param_update._get_random_int_for_sr(p)
315
316
  elif group['actual_state_precision'] == 'int8_sr':
316
317
  random_int_state_tensor = param_update._get_random_int_for_8bit_sr(p)
317
- elif group['actual_state_precision'] == 'fp8_sr':
318
- random_int_state_tensor = param_update._get_random_int_for_fp8_sr(p)
319
318
  step_param_fn = self._compiled_step_parameter
320
319
  else:
321
320
  step_param_fn = self._step_parameter
@@ -88,7 +88,7 @@ class Adopt_adv(torch.optim.Optimizer):
88
88
  while only factorizing the second moment. (default: False)
89
89
  state_precision (str): Precision method for Adopt states. Options: 'auto'
90
90
  (parameter precision), 'fp32', 'factored' (SMMF low-rank FP32), 'bf16_sr' (with
91
- stochastic rounding), 'fp16' , 'fp8_sr', 'int8_sr'. (default: 'auto')
91
+ stochastic rounding), 'fp16' , 'int8_sr'. (default: 'auto')
92
92
  """
93
93
 
94
94
  def __init__(
@@ -126,7 +126,7 @@ class Adopt_adv(torch.optim.Optimizer):
126
126
  centered_wd: float = 0.0,
127
127
  centered_wd_mode: str = 'float8',
128
128
  # States precision
129
- state_precision: str = "auto", # 'fp32', 'factored', 'bf16_sr', 'fp8_sr', 'int8_sr'.
129
+ state_precision: str = "auto", # 'fp32', 'factored', 'bf16_sr', 'int8_sr'.
130
130
  # Factorized second moment only
131
131
  factored_2nd: bool = False,
132
132
  # SMMF factorization (legacy)
@@ -148,7 +148,7 @@ class Adopt_adv(torch.optim.Optimizer):
148
148
 
149
149
 
150
150
  state_precision = state_precision.lower()
151
- valid_precisions = {"auto", "fp32", "factored", "bf16_sr", "fp16", "fp8_sr", "int8_sr"}
151
+ valid_precisions = {"auto", "fp32", "factored", "bf16_sr", "fp16", "int8_sr"}
152
152
  if state_precision not in valid_precisions:
153
153
  raise ValueError(f"state_precision must be one of {valid_precisions}. Got {state_precision}")
154
154
 
@@ -236,7 +236,7 @@ class Adopt_adv(torch.optim.Optimizer):
236
236
 
237
237
  dtype = torch.float32 if (state['factored'] or req_precision == 'factored') else p.dtype
238
238
 
239
- vt_dtype = torch.float32 if (state['factored'] or state['factored_2nd'] or req_precision in ['factored', 'bf16_sr', 'fp8_sr', 'int8_sr']) else dtype
239
+ vt_dtype = torch.float32 if (state['factored'] or state['factored_2nd'] or req_precision in ['factored', 'bf16_sr', 'int8_sr']) else dtype
240
240
  vt_init = grad.pow(2).to(vt_dtype) * (1 - group['betas'][1])
241
241
 
242
242
  if state['factored']:
@@ -262,6 +262,7 @@ class Adopt_adv(torch.optim.Optimizer):
262
262
  state['effective_shape'] = _get_effective_shape(p.numel())
263
263
  d1, d2 = state['effective_shape']
264
264
  state['mu_v_nmf'], state['mv_v_nmf'] = _nnmf(vt_init.view(d1, d2))
265
+ state['shifter'] = torch.tensor([1, 2, 4, 8, 16, 32, 64, 128], device=p.device, dtype=torch.uint8)
265
266
  else:
266
267
  init_state_tensor(state, 'exp_avg_sq', p.shape, actual_precision, p.device, dtype)
267
268
  set_state(state, 'exp_avg_sq', vt_init, actual_precision, None, non_neg=True)
@@ -316,8 +317,6 @@ class Adopt_adv(torch.optim.Optimizer):
316
317
  random_int_state_tensor = param_update._get_random_int_for_sr(p)
317
318
  elif group['actual_state_precision'] == 'int8_sr':
318
319
  random_int_state_tensor = param_update._get_random_int_for_8bit_sr(p)
319
- elif group['actual_state_precision'] == 'fp8_sr':
320
- random_int_state_tensor = param_update._get_random_int_for_fp8_sr(p)
321
320
  step_param_fn = self._compiled_step_parameter
322
321
  else:
323
322
  lr = group['lr']
@@ -33,8 +33,6 @@ class Lion_adv(torch.optim.Optimizer):
33
33
  stochastic_rounding (bool, optional): whether to use stochastic
34
34
  rounding for BF16 parameter updates (default: True).
35
35
  orthogonal_gradient (bool): whether to orthogonalize the gradient (default: False).
36
- clip_threshold (float, optional): whether to clip the gradients norm
37
- per-parameter (default: 0.0).
38
36
  kappa_p (float, optional): The p-value for the Lp-norm in Lion-K (domain [1.0, 2.0]).
39
37
  - 1.0: Standard Lion (sign update).
40
38
  - 2.0: Spherical Lion (normalized L2 update).
@@ -47,7 +47,7 @@ class Muon_adv(torch.optim.Optimizer):
47
47
  use_muon (bool | None): whether to use Muon or AuxAdamW. MUST be provided
48
48
  either here or via `optim_type` in parameter groups. (default: None)
49
49
  state_precision (str): Precision for Muon optimizer states. Options: 'auto' (parameter dtype), 'fp32',
50
- 'bf16_sr' (BF16 with stochastic rounding), 'fp8_sr', 'int8_sr'.
50
+ 'bf16_sr' (BF16 with stochastic rounding), 'int8_sr'.
51
51
  (default: 'auto')
52
52
  low_rank_ortho (bool): If True, enables low-rank orthogonalization, which
53
53
  projects the update to a lower rank before orthogonalization.
@@ -98,7 +98,7 @@ class Muon_adv(torch.optim.Optimizer):
98
98
  adam_tiny_spike (float): Tiny spike for Kourkoutas-β. (default: 1e-9)
99
99
  adam_k_warmup_steps (int): Warmup steps for Kourkoutas-β. (default: 0)
100
100
  adam_spectral_normalization (bool): Enable explicit spectral normalization for AdamW. (default: False)
101
- adam_state_precision (str): Precision for AuxAdam states. Options: 'auto', 'fp32', 'bf16_sr', 'fp16', 'fp8_sr', 'int8_sr', 'factored'. (default: 'auto')
101
+ adam_state_precision (str): Precision for AuxAdam states. Options: 'auto', 'fp32', 'bf16_sr', 'fp16', 'int8_sr', 'factored'. (default: 'auto')
102
102
  adam_nnmf_factor (bool): 1-bit factored for AdamW.
103
103
  adam_factored_2nd (bool): Factorize only the second moment (v_t) for AuxAdam. (default: False)
104
104
  """
@@ -130,7 +130,7 @@ class Muon_adv(torch.optim.Optimizer):
130
130
  # Boolean to split param
131
131
  use_muon: bool | None = None,
132
132
  # States precision (Muon path)
133
- state_precision: str = "auto", # 'fp32', 'bf16_sr', 'fp8_sr', 'int8_sr'
133
+ state_precision: str = "auto", # 'fp32', 'bf16_sr', 'int8_sr'
134
134
  # Low-rank Muon
135
135
  low_rank_ortho: bool = False,
136
136
  ortho_rank: int = 128,
@@ -193,7 +193,7 @@ class Muon_adv(torch.optim.Optimizer):
193
193
  state_precision = "factored"
194
194
 
195
195
  state_precision = state_precision.lower()
196
- valid_precisions = {"auto", "fp32", "factored", "bf16_sr", "fp16", "fp8_sr", "int8_sr"}
196
+ valid_precisions = {"auto", "fp32", "factored", "bf16_sr", "fp16", "int8_sr"}
197
197
  if state_precision not in valid_precisions:
198
198
  raise ValueError(f"state_precision must be one of {valid_precisions}. Got {state_precision}")
199
199
 
@@ -406,8 +406,6 @@ class Muon_adv(torch.optim.Optimizer):
406
406
  random_int_state_tensor = param_update._get_random_int_for_sr(p)
407
407
  elif actual_precision == 'int8_sr':
408
408
  random_int_state_tensor = param_update._get_random_int_for_8bit_sr(p)
409
- elif actual_precision == 'fp8_sr':
410
- random_int_state_tensor = param_update._get_random_int_for_fp8_sr(p)
411
409
  else:
412
410
  adam_step_param = Muon_AuxAdam._adam_step_parameter
413
411
 
@@ -427,8 +425,6 @@ class Muon_adv(torch.optim.Optimizer):
427
425
  random_int_state_tensor = param_update._get_random_int_for_sr(p)
428
426
  elif actual_precision == 'int8_sr':
429
427
  random_int_state_tensor = param_update._get_random_int_for_8bit_sr(p)
430
- elif actual_precision == 'fp8_sr':
431
- random_int_state_tensor = param_update._get_random_int_for_fp8_sr(p)
432
428
  if group['low_rank_ortho']:
433
429
  random_G_sketch = param_update._get_random_noise_for_low_rank_ortho(p, group['ortho_rank'])
434
430
  else:
@@ -124,7 +124,7 @@ class Prodigy_adv(torch.optim.Optimizer):
124
124
  nesterov: bool = False,
125
125
  nesterov_coef: float | None = None,
126
126
  # States precision
127
- state_precision: str = "auto", # 'fp32', 'factored', 'bf16_sr', 'fp8_sr', 'int8_sr'.
127
+ state_precision: str = "auto", # 'fp32', 'factored', 'bf16_sr', 'int8_sr'.
128
128
  # Factorized second moment only
129
129
  factored_2nd: bool = False,
130
130
  # SMMF factorization (legacy)
@@ -168,7 +168,7 @@ class Prodigy_adv(torch.optim.Optimizer):
168
168
  raise ValueError(f"For Kourkoutas-β, betas[1] (as beta2_max) must be > beta2_min. Got {betas[1]} and {beta2_min}")
169
169
 
170
170
  state_precision = state_precision.lower()
171
- valid_precisions = {"auto", "fp32", "factored", "bf16_sr", "fp16", "fp8_sr", "int8_sr"}
171
+ valid_precisions = {"auto", "fp32", "factored", "bf16_sr", "fp16", "int8_sr"}
172
172
  if state_precision not in valid_precisions:
173
173
  raise ValueError(f"state_precision must be one of {valid_precisions}. Got {state_precision}")
174
174
 
@@ -311,6 +311,7 @@ class Prodigy_adv(torch.optim.Optimizer):
311
311
  d1, d2 = state['effective_shape']
312
312
  state['mu_v_nmf'] = torch.zeros(d1, device=device, dtype=torch.float32)
313
313
  state['mv_v_nmf'] = torch.zeros(d2, device=device, dtype=torch.float32)
314
+ state['shifter'] = torch.tensor([1, 2, 4, 8, 16, 32, 64, 128], device=p.device, dtype=torch.uint8)
314
315
  else:
315
316
  init_state_tensor(state, 'exp_avg_sq', p.shape, actual_precision, p.device, dtype, non_neg=True)
316
317
 
@@ -358,8 +359,6 @@ class Prodigy_adv(torch.optim.Optimizer):
358
359
  random_int_state_tensor = param_update._get_random_int_for_sr(p)
359
360
  elif group['actual_state_precision'] == 'int8_sr':
360
361
  random_int_state_tensor = param_update._get_random_int_for_8bit_sr(p)
361
- elif group['actual_state_precision'] == 'fp8_sr':
362
- random_int_state_tensor = param_update._get_random_int_for_fp8_sr(p)
363
362
  step_param_fn = self._compiled_step_parameter
364
363
  else:
365
364
  d = group['d']
@@ -44,7 +44,7 @@ class SignSGD_adv(torch.optim.Optimizer):
44
44
  'int4': Uses 4-bit block-wise quantization (block size 32).
45
45
  state_precision (str): Precision method for the optimizer states. Options: 'auto'
46
46
  (parameter precision), 'fp32', 'factored' (SMMF low-rank FP32), 'bf16_sr' (with
47
- stochastic rounding), 'fp16' , 'fp8_sr', 'int8_sr'. (default: 'auto')
47
+ stochastic rounding), 'fp16' , 'int8_sr'. (default: 'auto')
48
48
  nnmf_factor (bool): whether to use the factorization or use the
49
49
  uncompressed optimizer. (default: True)
50
50
  """
@@ -70,13 +70,13 @@ class SignSGD_adv(torch.optim.Optimizer):
70
70
  nesterov_coef: float | None = None,
71
71
  # Normalization then Momentum
72
72
  normed_momentum: bool = False,
73
- # SNR Precondition
73
+ # SNR Precondition (requires normed_momentum)
74
74
  snr_cond: bool = False,
75
75
  # Centered WD
76
76
  centered_wd: float = 0.0,
77
77
  centered_wd_mode: str = 'float8',
78
78
  # States precision
79
- state_precision: str = "auto", # 'fp32', 'factored', 'bf16_sr', 'fp8_sr', 'int8_sr'.
79
+ state_precision: str = "auto", # 'fp32', 'factored', 'bf16_sr', 'int8_sr'.
80
80
  # Spectral Normed Optimizer
81
81
  spectral_normalization: bool = False,
82
82
  # SMMF factorization
@@ -95,7 +95,7 @@ class SignSGD_adv(torch.optim.Optimizer):
95
95
  raise NotImplementedError(f"snr_cond is intended to be used with normed_momentum")
96
96
 
97
97
  state_precision = state_precision.lower()
98
- valid_precisions = {"auto", "fp32", "factored", "bf16_sr", "fp16", "fp8_sr", "int8_sr"}
98
+ valid_precisions = {"auto", "fp32", "factored", "bf16_sr", "fp16", "int8_sr"}
99
99
  if state_precision not in valid_precisions:
100
100
  raise ValueError(f"state_precision must be one of {valid_precisions}. Got {state_precision}")
101
101
 
@@ -230,8 +230,6 @@ class SignSGD_adv(torch.optim.Optimizer):
230
230
  random_int_state_tensor = param_update._get_random_int_for_sr(p)
231
231
  elif group['actual_state_precision'] == 'int8_sr':
232
232
  random_int_state_tensor = param_update._get_random_int_for_8bit_sr(p)
233
- elif group['actual_state_precision'] == 'fp8_sr':
234
- random_int_state_tensor = param_update._get_random_int_for_fp8_sr(p)
235
233
 
236
234
  if group.get('stochastic_sign', False) and not is_vector:
237
235
  random_noise_tensor = param_update._get_random_noise_for_sso(p)
@@ -254,7 +252,8 @@ class SignSGD_adv(torch.optim.Optimizer):
254
252
  nesterov = group.get('nesterov', False)
255
253
  nesterov_coef = group.get('nesterov_coef', None)
256
254
  sso = group.get('stochastic_sign', False)
257
- snr_cond = group.get('snr_cond', False) and group.get('normed_momentum', False) and momentum > 0
255
+ normed_mt = group.get('normed_momentum', False)
256
+ snr_cond = group.get('snr_cond', False) and normed_mt and momentum > 0
258
257
 
259
258
  denom = None
260
259
  wd_target = None
@@ -263,7 +262,7 @@ class SignSGD_adv(torch.optim.Optimizer):
263
262
  if group["orthogonal_gradient"]:
264
263
  grad = _orthogonalize_gradient(p, grad)
265
264
 
266
- if group.get('normed_momentum', False):
265
+ if normed_mt:
267
266
  if sso:
268
267
  grad = apply_stochastic_sign_(grad, noise=random_noise_tensor, is_vector=is_vector)
269
268
  else:
@@ -285,7 +284,12 @@ class SignSGD_adv(torch.optim.Optimizer):
285
284
 
286
285
  if nesterov:
287
286
  nv_coef = momentum if nesterov_coef is None else nesterov_coef
288
- raw_update = grad_reshaped.lerp(exp_avg, nv_coef)
287
+ if normed_mt:
288
+ # Scale the normalized gradient down to match the buffer's variance
289
+ ema_std = math.sqrt((1 - momentum) / (1 + momentum))
290
+ raw_update = (grad_reshaped * ema_std).lerp_(exp_avg, nv_coef)
291
+ else:
292
+ raw_update = grad.lerp(exp_avg, nv_coef)
289
293
  else:
290
294
  raw_update = exp_avg.clone()
291
295
 
@@ -309,7 +313,12 @@ class SignSGD_adv(torch.optim.Optimizer):
309
313
 
310
314
  if nesterov:
311
315
  nv_coef = momentum if nesterov_coef is None else nesterov_coef
312
- raw_update = grad.lerp(exp_avg, nv_coef)
316
+ if normed_mt:
317
+ # Scale the normalized gradient down to match the buffer's variance
318
+ ema_std = math.sqrt((1 - momentum) / (1 + momentum))
319
+ raw_update = (grad * ema_std).lerp_(exp_avg, nv_coef)
320
+ else:
321
+ raw_update = grad.lerp(exp_avg, nv_coef)
313
322
  else:
314
323
  raw_update = exp_avg.clone()
315
324
 
@@ -42,7 +42,7 @@ class SinkSGD_adv(torch.optim.Optimizer):
42
42
  nnmf_factor (bool): whether to use factorization or disable it. (default: False)
43
43
  state_precision (str): Precision method for states. Options: 'auto'
44
44
  (parameter precision), 'fp32', 'factored' (SMMF low-rank FP32), 'bf16_sr',
45
- 'fp8_sr', 'int8_sr'. (default: 'auto')
45
+ 'fp16', 'int8_sr'. (default: 'auto')
46
46
  compiled_optimizer (bool): Compiles the core step function using torch.compile
47
47
  for faster execution. (default: False)
48
48
  """
@@ -58,7 +58,7 @@ class SinkSGD_adv(torch.optim.Optimizer):
58
58
  orthogonal_sinkhorn: bool = False,
59
59
  # Normalization then Momentum
60
60
  normed_momentum: bool = False,
61
- # SNR Precondition
61
+ # SNR Precondition (requires normed_momentum)
62
62
  snr_cond: bool = False,
63
63
  # Nesterov Momentum
64
64
  nesterov: bool = False,
@@ -93,7 +93,7 @@ class SinkSGD_adv(torch.optim.Optimizer):
93
93
  raise NotImplementedError(f"snr_cond is intended to be used with normed_momentum")
94
94
 
95
95
  state_precision = state_precision.lower()
96
- valid_precisions = {"auto", "fp32", "factored", "bf16_sr", "fp16", "fp8_sr", "int8_sr"}
96
+ valid_precisions = {"auto", "fp32", "factored", "bf16_sr", "fp16", "int8_sr"}
97
97
  if state_precision not in valid_precisions:
98
98
  raise ValueError(f"state_precision must be one of {valid_precisions}. Got {state_precision}")
99
99
 
@@ -209,8 +209,6 @@ class SinkSGD_adv(torch.optim.Optimizer):
209
209
  random_int_state_tensor = param_update._get_random_int_for_sr(p)
210
210
  elif group['actual_state_precision'] == 'int8_sr':
211
211
  random_int_state_tensor = param_update._get_random_int_for_8bit_sr(p)
212
- elif group['actual_state_precision'] == 'fp8_sr':
213
- random_int_state_tensor = param_update._get_random_int_for_fp8_sr(p)
214
212
  step_param_fn = self._compiled_step_parameter
215
213
  else:
216
214
  step_param_fn = self._step_parameter
@@ -226,6 +224,7 @@ class SinkSGD_adv(torch.optim.Optimizer):
226
224
  orthogonal_sinkhorn = group['orthogonal_sinkhorn']
227
225
 
228
226
  momentum = group['momentum']
227
+ normed_mt = group.get('normed_momentum', False)
229
228
  nesterov = group['nesterov']
230
229
  nesterov_coef = group.get('nesterov_coef', None)
231
230
  snr_cond = group.get('snr_cond', False)
@@ -238,7 +237,10 @@ class SinkSGD_adv(torch.optim.Optimizer):
238
237
  wd_target = None
239
238
  cwd_target = None
240
239
 
241
- if group.get('normed_momentum', False):
240
+ if group["orthogonal_gradient"]:
241
+ grad = _orthogonalize_gradient(p, grad)
242
+
243
+ if normed_mt:
242
244
  if not is_vector:
243
245
  # Sinkhorn iterative normalization
244
246
  grad = apply_sr_sinkhorn(grad, iters=sinkhorn_iterations, p=p, ortho_project=orthogonal_sinkhorn)
@@ -246,9 +248,6 @@ class SinkSGD_adv(torch.optim.Optimizer):
246
248
  # For vectors, apply sign operation
247
249
  grad = grad.sign_()
248
250
 
249
- if group["orthogonal_gradient"]:
250
- grad = _orthogonalize_gradient(p, grad)
251
-
252
251
  if state['factored']:
253
252
  d1, d2 = state['effective_shape']
254
253
  grad_reshaped = grad.view(d1, d2)
@@ -272,7 +271,12 @@ class SinkSGD_adv(torch.optim.Optimizer):
272
271
 
273
272
  if nesterov:
274
273
  nv_coef = momentum if nesterov_coef is None else nesterov_coef
275
- update = grad_reshaped.lerp(buf, nv_coef)
274
+ if normed_mt:
275
+ # Scale the normalized gradient down to match the buffer's variance
276
+ ema_std = math.sqrt((1 - momentum) / (1 + momentum))
277
+ update = (grad_reshaped * ema_std).lerp_(buf, nv_coef)
278
+ else:
279
+ update = grad_reshaped.lerp(buf, nv_coef)
276
280
  else:
277
281
  update = buf.clone()
278
282
  else:
@@ -301,7 +305,12 @@ class SinkSGD_adv(torch.optim.Optimizer):
301
305
 
302
306
  if nesterov:
303
307
  nv_coef = momentum if nesterov_coef is None else nesterov_coef
304
- update = grad.lerp(buf, nv_coef)
308
+ if normed_mt:
309
+ # Scale the normalized gradient down to match the buffer's variance
310
+ ema_std = math.sqrt((1 - momentum) / (1 + momentum))
311
+ update = (grad * ema_std).lerp_(buf, nv_coef)
312
+ else:
313
+ update = grad.lerp(buf, nv_coef)
305
314
  else:
306
315
  update = buf.clone()
307
316
  else:
@@ -56,6 +56,7 @@ def _init_auxadam_state(self, p, group):
56
56
  d1, d2 = state['effective_shape']
57
57
  state['mu_v_nmf'] = torch.zeros(d1, device=device, dtype=torch.float32)
58
58
  state['mv_v_nmf'] = torch.zeros(d2, device=device, dtype=torch.float32)
59
+ state['shifter'] = torch.tensor([1, 2, 4, 8, 16, 32, 64, 128], device=device, dtype=torch.uint8)
59
60
  else:
60
61
  init_state_tensor(state, 'exp_avg_sq', p.shape, actual_precision, p.device, dtype, non_neg=True)
61
62