adv-optm 2.4.dev25__tar.gz → 2.5.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38) hide show
  1. adv_optm-2.5.1/PKG-INFO +113 -0
  2. adv_optm-2.5.1/README.md +82 -0
  3. {adv_optm-2.4.dev25 → adv_optm-2.5.1}/adv_optm/__init__.py +1 -1
  4. {adv_optm-2.4.dev25 → adv_optm-2.5.1}/adv_optm/optim/AdaMuon_adv.py +7 -7
  5. {adv_optm-2.4.dev25 → adv_optm-2.5.1}/adv_optm/optim/AdamW_adv.py +5 -5
  6. {adv_optm-2.4.dev25 → adv_optm-2.5.1}/adv_optm/optim/Adopt_adv.py +4 -6
  7. {adv_optm-2.4.dev25 → adv_optm-2.5.1}/adv_optm/optim/Lion_adv.py +6 -5
  8. {adv_optm-2.4.dev25 → adv_optm-2.5.1}/adv_optm/optim/Muon_adv.py +7 -7
  9. {adv_optm-2.4.dev25 → adv_optm-2.5.1}/adv_optm/optim/Prodigy_adv.py +4 -4
  10. {adv_optm-2.4.dev25 → adv_optm-2.5.1}/adv_optm/optim/SignSGD_adv.py +7 -8
  11. {adv_optm-2.4.dev25 → adv_optm-2.5.1}/adv_optm/optim/SinkSGD_adv.py +7 -8
  12. {adv_optm-2.4.dev25 → adv_optm-2.5.1}/adv_optm/util/Muon_AuxAdam.py +2 -3
  13. adv_optm-2.5.1/adv_optm/util/OrthoGrad.py +99 -0
  14. {adv_optm-2.4.dev25 → adv_optm-2.5.1}/adv_optm/util/param_update.py +3 -3
  15. {adv_optm-2.4.dev25 → adv_optm-2.5.1}/adv_optm/util/scaled_optm.py +2 -2
  16. {adv_optm-2.4.dev25 → adv_optm-2.5.1}/adv_optm/util/state_util.py +1 -1
  17. adv_optm-2.5.1/adv_optm.egg-info/PKG-INFO +113 -0
  18. {adv_optm-2.4.dev25 → adv_optm-2.5.1}/setup.py +1 -1
  19. adv_optm-2.4.dev25/PKG-INFO +0 -109
  20. adv_optm-2.4.dev25/README.md +0 -78
  21. adv_optm-2.4.dev25/adv_optm/util/OrthoGrad.py +0 -80
  22. adv_optm-2.4.dev25/adv_optm.egg-info/PKG-INFO +0 -109
  23. {adv_optm-2.4.dev25 → adv_optm-2.5.1}/LICENSE +0 -0
  24. {adv_optm-2.4.dev25 → adv_optm-2.5.1}/adv_optm/optim/__init__.py +0 -0
  25. {adv_optm-2.4.dev25 → adv_optm-2.5.1}/adv_optm/util/Kourkoutas.py +0 -0
  26. {adv_optm-2.4.dev25 → adv_optm-2.5.1}/adv_optm/util/Muon_util.py +0 -0
  27. {adv_optm-2.4.dev25 → adv_optm-2.5.1}/adv_optm/util/__init__.py +0 -0
  28. {adv_optm-2.4.dev25 → adv_optm-2.5.1}/adv_optm/util/centered_decay.py +0 -0
  29. {adv_optm-2.4.dev25 → adv_optm-2.5.1}/adv_optm/util/factorization_util.py +0 -0
  30. {adv_optm-2.4.dev25 → adv_optm-2.5.1}/adv_optm/util/lion_k.py +0 -0
  31. {adv_optm-2.4.dev25 → adv_optm-2.5.1}/adv_optm/util/signed_util.py +0 -0
  32. {adv_optm-2.4.dev25 → adv_optm-2.5.1}/adv_optm/util/sinkhorn.py +0 -0
  33. {adv_optm-2.4.dev25 → adv_optm-2.5.1}/adv_optm/util/update_util.py +0 -0
  34. {adv_optm-2.4.dev25 → adv_optm-2.5.1}/adv_optm.egg-info/SOURCES.txt +0 -0
  35. {adv_optm-2.4.dev25 → adv_optm-2.5.1}/adv_optm.egg-info/dependency_links.txt +0 -0
  36. {adv_optm-2.4.dev25 → adv_optm-2.5.1}/adv_optm.egg-info/requires.txt +0 -0
  37. {adv_optm-2.4.dev25 → adv_optm-2.5.1}/adv_optm.egg-info/top_level.txt +0 -0
  38. {adv_optm-2.4.dev25 → adv_optm-2.5.1}/setup.cfg +0 -0
@@ -0,0 +1,113 @@
1
+ Metadata-Version: 2.4
2
+ Name: adv_optm
3
+ Version: 2.5.1
4
+ Summary: A family of highly efficient, lightweight yet powerful optimizers.
5
+ Home-page: https://github.com/Koratahiu/Advanced_Optimizers
6
+ Author: Koratahiu
7
+ Author-email: hiuhonor@gmail.com
8
+ License: Apache 2.0
9
+ Keywords: llm,fine-tuning,memory-efficient,low-rank,compression,pytorch,optimizer,adam
10
+ Classifier: Programming Language :: Python :: 3
11
+ Classifier: License :: OSI Approved :: Apache Software License
12
+ Classifier: Operating System :: OS Independent
13
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
14
+ Classifier: Topic :: Software Development :: Libraries :: Python Modules
15
+ Requires-Python: >=3.8
16
+ Description-Content-Type: text/markdown
17
+ License-File: LICENSE
18
+ Requires-Dist: torch>=2.1
19
+ Dynamic: author
20
+ Dynamic: author-email
21
+ Dynamic: classifier
22
+ Dynamic: description
23
+ Dynamic: description-content-type
24
+ Dynamic: home-page
25
+ Dynamic: keywords
26
+ Dynamic: license
27
+ Dynamic: license-file
28
+ Dynamic: requires-dist
29
+ Dynamic: requires-python
30
+ Dynamic: summary
31
+
32
+ # Advanced Optimizers (AIO)
33
+
34
+ A comprehensive, all-in-one collection of state-of-the-art optimization algorithms for deep learning. Designed for **maximum efficiency**, **minimal memory footprint**, and **superior performance** across diverse model architectures and training scenarios.
35
+
36
+ [![PyPI version](https://img.shields.io/pypi/v/adv_optm.svg?color=blue&style=flat-square)](https://pypi.org/project/adv_optm/)
37
+ [![Python versions](https://img.shields.io/pypi/pyversions/adv_optm.svg?style=flat-square)](https://pypi.org/project/adv_optm/)
38
+ [![License](https://img.shields.io/badge/license-Apache-green?style=flat-square)](LICENSE)
39
+
40
+ ---
41
+
42
+ ## 📦 Installation
43
+
44
+ ```bash
45
+ pip install adv_optm
46
+ ```
47
+ *Requires PyTorch 2.1+ (PyTorch 2.3+ is required for `torch.compile` support).*
48
+
49
+ ---
50
+
51
+ ## What's New
52
+
53
+ ### 🌟 Version 2.5.x: The Massive Refactor
54
+ This major update introduces a complete architectural refactor of the library:
55
+
56
+ **🆕 New Optimizers & Scaling**
57
+ * **`SinkSGD_adv`:** Added a powerful new optimizer to the lineup.
58
+ * **Spectral Scaling:** Now available across *all* optimizers, achieving width/rank invariant updates for highly stable training.
59
+
60
+ **💾 Memory & State Precision Control**
61
+ * **Granular State Precision (`state_precision`):** Drastically reduce memory overhead with new optimizer state modes:
62
+ * `factored` (Rank-2 factored mode)
63
+ * `fp32` (Full precision)
64
+ * `bf16_sr` & `int8_sr` (BF16/Int8 with Stochastic Rounding)
65
+ * **Factored Second Moment (`factored_2nd`):** Available for all Adam variants. Works seamlessly alongside any `state_precision` setting to further slash memory usage.
66
+
67
+ **⚙️ Advanced Dynamics & Momentum**
68
+ * **Variance Normalized Momentum (`normed_momentum`):** Applies optimizer normalization *before* momentum (Normalization then Momentum/NtM). Available for `AdamW_adv`, `SignSGD_adv`, and `SinkSGD_adv`.
69
+ * **Universal Nesterov Momentum:** Replaced the hard-to-tune Simplified_AdEMAMix with Nesterov momentum (`nesterov`) and a dedicated coefficient (`nesterov_coef`) across all optimizers.
70
+ * **Preconditioning & Signs:**
71
+ * Added **Variance/Confidence Preconditioning (`snr_cond`)** for `SignSGD_adv` and `SinkSGD_adv` (requires `normed_momentum`). Read the technical reports: [AASS](https://koratahiu.github.io/aass/) & [sink-v](https://koratahiu.github.io/sink-v/).
72
+ * Added **Adaptive Stochastic Sign** with $L_\infty$ preconditioning (`stochastic_sign`) for `SignSGD_adv` and `Lion_adv`.
73
+ * **Improved CANS (`accelerated_ns`):** Enhanced for Muon variants by integrating a dynamic lower bound.
74
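+ * **New OrthoGrad modes (`orthogonal_gradient`):** The option is now a string instead of a boolean: `'flattened'` (standard vectorized OrthoGrad), `'iterative'` (new matrix-wise rank-2 OrthoGrad), or `'disabled'` (default).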
75
+
76
+ **⚓ Weight Decay Innovations**
77
+ * **Centered Weight Decay (`centered_wd`):** Pulls weights toward their pre-training state (the anchor). To save memory, the anchor's storage precision (`centered_wd_mode`) can be set to full, float8, int8, or int4.
78
+ * **Fisher Weight Decay (`fisher_wd`):** Now available for Adam variants, based on the [FAdam paper](https://arxiv.org/abs/2405.12807).
79
+ * **Geometric Weight Decay:** Added specifically for `SinkSGD_adv` and `SignSGD_adv`.
80
+
81
+ *(Note: `Lion_Prodigy_adv`, `Simplified_AdEMAMix`, and the heuristic cautious/grams modes have been deprecated in favor of these superior, theoretically grounded features.)*
82
+
83
+ <details>
84
+ <summary><b>Click to see older release notes (v1.2.x - v2.1.x)</b></summary>
85
+
86
+ ### Version 2.1.x
87
+ * **New Optimizer:** Added **Signum** (SignSGD with momentum) to the `SignSGD_adv` family.
88
+
89
+ ### Version 2.0.x
90
+ * ⚡ **`torch.compile` Support:** Fully implemented for all advanced optimizers. Enable via `compiled_optimizer=True` to heavily fuse and optimize the optimizer step path.
91
+ * 📉 **1-Bit Factored Mode:** Vastly improved implementation via `nnmf_factor=True`.
92
+ * 🛠️ Broad performance and stability improvements across all optimizers.
93
+
94
+ ### Version 1.2.x
95
+ * **Advanced Muon Variants:** Brought the groundbreaking [Muon optimizer](https://kellerjordan.github.io/posts/muon/) into the fold, enriched with features from recent literature.
96
+
97
+ | Optimizer | Description |
98
+ |---|---|
99
+ | `Muon_adv` | Advanced Muon implementation featuring CANS, NorMuon, Low-Rank Orthogonalization, and more. |
100
+ | `AdaMuon_adv` | Combines Muon's geometry with Adam-like adaptive scaling and sign-based orthogonalization. |
101
+
102
+ * **Prodigy Speedup:** Prodigy variants are now **50% faster** by eliminating unnecessary CUDA syncs (Shoutout to **@dxqb**!).
103
+ * **Stochastic Rounding for BF16:** Parameter updates and weight decay now accumulate in float32 and round once at the end.
104
+ * **Cautious Weight Decay:** Implemented for all advanced optimizers ([Paper](https://arxiv.org/abs/2510.12402)).
105
+ * **Fused Operations:** Transitioned to fused and in-place operations wherever possible.
106
+
107
+ </details>
108
+
109
+ ---
110
+
111
+ ## 💡 Core Innovations
112
+
113
+ *(Documentation expanding on the theory and usage of these features is coming soon!)*
@@ -0,0 +1,82 @@
1
+ # Advanced Optimizers (AIO)
2
+
3
+ A comprehensive, all-in-one collection of state-of-the-art optimization algorithms for deep learning. Designed for **maximum efficiency**, **minimal memory footprint**, and **superior performance** across diverse model architectures and training scenarios.
4
+
5
+ [![PyPI version](https://img.shields.io/pypi/v/adv_optm.svg?color=blue&style=flat-square)](https://pypi.org/project/adv_optm/)
6
+ [![Python versions](https://img.shields.io/pypi/pyversions/adv_optm.svg?style=flat-square)](https://pypi.org/project/adv_optm/)
7
+ [![License](https://img.shields.io/badge/license-Apache-green?style=flat-square)](LICENSE)
8
+
9
+ ---
10
+
11
+ ## 📦 Installation
12
+
13
+ ```bash
14
+ pip install adv_optm
15
+ ```
16
+ *Requires PyTorch 2.1+ (PyTorch 2.3+ is required for `torch.compile` support).*
17
+
18
+ ---
19
+
20
+ ## What's New
21
+
22
+ ### 🌟 Version 2.5.x: The Massive Refactor
23
+ This major update introduces a complete architectural refactor of the library:
24
+
25
+ **🆕 New Optimizers & Scaling**
26
+ * **`SinkSGD_adv`:** Added a powerful new optimizer to the lineup.
27
+ * **Spectral Scaling:** Now available across *all* optimizers, achieving width/rank-invariant updates for highly stable training.
28
+
29
+ **💾 Memory & State Precision Control**
30
+ * **Granular State Precision (`state_precision`):** Drastically reduce memory overhead with new optimizer state modes:
31
+ * `factored` (Rank-2 factored mode)
32
+ * `fp32` (Full precision)
33
+ * `bf16_sr` & `int8_sr` (BF16/Int8 with Stochastic Rounding)
34
+ * **Factored Second Moment (`factored_2nd`):** Available for all Adam variants. Works seamlessly alongside any `state_precision` setting to further slash memory usage.
35
+
36
+ **⚙️ Advanced Dynamics & Momentum**
37
+ * **Variance Normalized Momentum (`normed_momentum`):** Applies optimizer normalization *before* momentum (Normalization then Momentum/NtM). Available for `AdamW_adv`, `SignSGD_adv`, and `SinkSGD_adv`.
38
+ * **Universal Nesterov Momentum:** Replaced the hard-to-tune Simplified_AdEMAMix with Nesterov momentum (`nesterov`) and a dedicated coefficient (`nesterov_coef`) across all optimizers.
39
+ * **Preconditioning & Signs:**
40
+ * Added **Variance/Confidence Preconditioning (`snr_cond`)** for `SignSGD_adv` and `SinkSGD_adv` (requires `normed_momentum`). Read the technical reports: [AASS](https://koratahiu.github.io/aass/) & [sink-v](https://koratahiu.github.io/sink-v/).
41
+ * Added **Adaptive Stochastic Sign** with $L_\infty$ preconditioning (`stochastic_sign`) for `SignSGD_adv` and `Lion_adv`.
42
+ * **Improved CANS (`accelerated_ns`):** Enhanced for Muon variants by integrating a dynamic lower bound.
43
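+ * **New OrthoGrad modes (`orthogonal_gradient`):** The option is now a string instead of a boolean: `'flattened'` (standard vectorized OrthoGrad), `'iterative'` (new matrix-wise rank-2 OrthoGrad), or `'disabled'` (default).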
44
+
45
+ **⚓ Weight Decay Innovations**
46
+ * **Centered Weight Decay (`centered_wd`):** Pulls weights toward their pre-training state (the anchor). To save memory, the anchor's storage precision (`centered_wd_mode`) can be set to full, float8, int8, or int4.
47
+ * **Fisher Weight Decay (`fisher_wd`):** Now available for Adam variants, based on the [FAdam paper](https://arxiv.org/abs/2405.12807).
48
+ * **Geometric Weight Decay:** Added specifically for `SinkSGD_adv` and `SignSGD_adv`.
49
+
50
+ *(Note: `Lion_Prodigy_adv`, `Simplified_AdEMAMix`, and the heuristic cautious/grams modes have been deprecated in favor of these superior, theoretically grounded features.)*
51
+
52
+ <details>
53
+ <summary><b>Click to see older release notes (v1.2.x - v2.1.x)</b></summary>
54
+
55
+ ### Version 2.1.x
56
+ * **New Optimizer:** Added **Signum** (SignSGD with momentum) to the `SignSGD_adv` family.
57
+
58
+ ### Version 2.0.x
59
+ * ⚡ **`torch.compile` Support:** Fully implemented for all advanced optimizers. Enable via `compiled_optimizer=True` to heavily fuse and optimize the optimizer step path.
60
+ * 📉 **1-Bit Factored Mode:** Vastly improved implementation via `nnmf_factor=True`.
61
+ * 🛠️ Broad performance and stability improvements across all optimizers.
62
+
63
+ ### Version 1.2.x
64
+ * **Advanced Muon Variants:** Brought the groundbreaking [Muon optimizer](https://kellerjordan.github.io/posts/muon/) into the fold, enriched with features from recent literature.
65
+
66
+ | Optimizer | Description |
67
+ |---|---|
68
+ | `Muon_adv` | Advanced Muon implementation featuring CANS, NorMuon, Low-Rank Orthogonalization, and more. |
69
+ | `AdaMuon_adv` | Combines Muon's geometry with Adam-like adaptive scaling and sign-based orthogonalization. |
70
+
71
+ * **Prodigy Speedup:** Prodigy variants are now **50% faster** by eliminating unnecessary CUDA syncs (Shoutout to **@dxqb**!).
72
+ * **Stochastic Rounding for BF16:** Parameter updates and weight decay now accumulate in float32 and round once at the end.
73
+ * **Cautious Weight Decay:** Implemented for all advanced optimizers ([Paper](https://arxiv.org/abs/2510.12402)).
74
+ * **Fused Operations:** Transitioned to fused and in-place operations wherever possible.
75
+
76
+ </details>
77
+
78
+ ---
79
+
80
+ ## 💡 Core Innovations
81
+
82
+ *(Documentation expanding on the theory and usage of these features is coming soon!)*
@@ -20,4 +20,4 @@ __all__ = [
20
20
  "SinkSGD_adv",
21
21
  ]
22
22
 
23
- __version__ = "2.4.dev25"
23
+ __version__ = "2.5.1"
@@ -57,7 +57,8 @@ class AdaMuon_adv(torch.optim.Optimizer):
57
57
  (default: (3.4445, -4.7750, 2.0315)).
58
58
  stochastic_rounding (bool): whether to use stochastic rounding for
59
59
  BF16 parameter updates (default: True).
60
- orthogonal_gradient (bool): whether to use OrthoGrad. (default: False)
60
+ orthogonal_gradient (str): whether to use OrthoGrad variants. 'disabled': off.
61
+ 'flattened': Standard vectorized OrthoGrad. 'iterative': Matrix-wise rank-2 OrthoGrad. (default: disabled)
61
62
  nesterov (bool): enables Nesterov momentum (default: False).
62
63
  use_atan2 (bool): whether to use the atan2 update rule. (default: False)
63
64
  vector_reshape (bool): whether to reshape 1D vectors into 2D
@@ -114,7 +115,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
114
115
  adam_fisher_wd (bool): Fisher Adam (FAdam) weight decay for the AdamW part. (default: False)
115
116
  adam_use_bias_correction (bool): Bias correction for AdamW.
116
117
  adam_use_atan2 (bool): Atan2 update rule for AdamW.
117
- adam_orthogonal_gradient (bool): OrthoGrad for AdamW.
118
+ adam_orthogonal_gradient (str): OrthoGrad for AdamW.
118
119
  adam_nesterov (bool): Nesterov momentum for AdamW. (default: False)
119
120
  adam_nesterov_coef (float, optional): Nesterov coefficient for AdamW. (default: None)
120
121
  adam_kourkoutas_beta (bool): Kourkoutas-β for AdamW.
@@ -149,7 +150,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
149
150
  # Stochastic Rounding for BF16
150
151
  stochastic_rounding: bool = True,
151
152
  # OrthoGrad
152
- orthogonal_gradient: bool = False,
153
+ orthogonal_gradient: str = 'disabled', # 'flattened', 'iterative'
153
154
  # Adam_atan2 (scale invariant)
154
155
  use_atan2: bool = False,
155
156
  # NorMuon
@@ -190,7 +191,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
190
191
  adam_fisher_wd: bool = False,
191
192
  adam_use_bias_correction: bool = True,
192
193
  adam_use_atan2: bool = False,
193
- adam_orthogonal_gradient: bool = False,
194
+ adam_orthogonal_gradient: str = 'disabled', # 'flattened', 'iterative'
194
195
  adam_nesterov: bool = False,
195
196
  adam_nesterov_coef: float | None = None,
196
197
  adam_kourkoutas_beta: bool = False,
@@ -213,7 +214,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
213
214
  print("Warning: spectral_normalization is incompatible with rms_rescaling, Disabling rms_rescaling.")
214
215
  rms_rescaling = False
215
216
  if spectral_normalization and accelerated_ns:
216
- ValueError("spectral_normalization violates accelerated Newton-Schulz assumptions. Pick one of them.")
217
+ raise ValueError("spectral_normalization violates accelerated Newton-Schulz assumptions. Pick one of them.")
217
218
 
218
219
  # Legacy backwards compatibility support for `nnmf_factor=True`
219
220
  if nnmf_factor:
@@ -515,8 +516,7 @@ class AdaMuon_adv(torch.optim.Optimizer):
515
516
  grad = approx_mars(grad, state['last_grad'], group['mars_gamma'], beta1)
516
517
 
517
518
 
518
- if group.get("orthogonal_gradient"):
519
- grad = _orthogonalize_gradient(p, grad)
519
+ grad = _orthogonalize_gradient(p, grad, group.get("orthogonal_gradient"))
520
520
 
521
521
  if state['factored']: # Factored Muon
522
522
  d1, d2 = state['effective_shape']
@@ -45,7 +45,8 @@ class AdamW_adv(torch.optim.Optimizer):
45
45
  stochastic_rounding (bool): whether to use stochastic
46
46
  rounding for BF16 parameter updates (default: True).
47
47
  use_atan2 (bool): whether to use the atan2 update rule. (default: False)
48
- orthogonal_gradient (bool): whether to use OrthoGrad. (default: False)
48
+ orthogonal_gradient (str): whether to use OrthoGrad variants. 'disabled': off.
49
+ 'flattened': Standard vectorized OrthoGrad. 'iterative': Matrix-wise rank-2 OrthoGrad. (default: disabled)
49
50
  normed_momentum (bool): whether to compute the first moment on the normalized gradient. (default: False)
50
51
  kourkoutas_beta (bool): whether to enable the layer-wise dynamic β₂ logic.
51
52
  If `False`, the optimizer behaves as standard AdamW. (default: False)
@@ -104,7 +105,7 @@ class AdamW_adv(torch.optim.Optimizer):
104
105
  # Adam_atan2 (scale invariant)
105
106
  use_atan2: bool = False,
106
107
  # OrthoGrad
107
- orthogonal_gradient: bool = False,
108
+ orthogonal_gradient: str = 'disabled', # 'flattened', 'iterative'
108
109
  # Nesterov momentum
109
110
  nesterov: bool = False,
110
111
  nesterov_coef: float | None = None,
@@ -326,8 +327,7 @@ class AdamW_adv(torch.optim.Optimizer):
326
327
  def _step_parameter(self, p, grad, state, group, step_size, beta1, beta2, sqrt_bias_correction2, random_int_tensor, random_int_state_tensor):
327
328
  grad = upcast_grad_for_precision(grad, state, group['state_precision'])
328
329
 
329
- if group["orthogonal_gradient"]:
330
- grad = _orthogonalize_gradient(p, grad)
330
+ grad = _orthogonalize_gradient(p, grad, group["orthogonal_gradient"])
331
331
 
332
332
  nesterov = group.get('nesterov', False)
333
333
  nesterov_coef = group.get('nesterov_coef', None)
@@ -462,7 +462,7 @@ class AdamW_adv(torch.optim.Optimizer):
462
462
  else:
463
463
  update.mul_(update_scaling)
464
464
 
465
- param_update.apply_parameter_update(self, p, group, update, step_size, random_int_tensor=random_int_tensor, wd_scaler=wd_scaler)
465
+ param_update.apply_parameter_update(self, p, group, update, group['lr'], random_int_tensor=random_int_tensor, wd_scaler=wd_scaler)
466
466
 
467
467
  def compile(self, *args, **kwargs):
468
468
  self._compiled_step_parameter = torch.compile(self._step_parameter, *args, **kwargs)
@@ -108,7 +108,7 @@ class Adopt_adv(torch.optim.Optimizer):
108
108
  # Stochastic Rounding for BF16
109
109
  stochastic_rounding: bool = True,
110
110
  # OrthoGrad
111
- orthogonal_gradient: bool = False,
111
+ orthogonal_gradient: str = 'disabled', # 'flattened', 'iterative'
112
112
  # Nesterov momentum
113
113
  nesterov: bool = False,
114
114
  nesterov_coef: float | None = None,
@@ -158,7 +158,7 @@ class Adopt_adv(torch.optim.Optimizer):
158
158
 
159
159
  defaults = {
160
160
  "lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay,
161
- "fisher_wd": fisher_wd, "cautious_wd": cautious_wd,
161
+ "fisher_wd": fisher_wd, "cautious_wd": cautious_wd, "orthogonal_gradient": orthogonal_gradient,
162
162
  "nesterov": nesterov, "nesterov_coef": nesterov_coef,
163
163
  "kourkoutas_beta": kourkoutas_beta, "beta2_min": beta2_min, "ema_alpha": ema_alpha,
164
164
  "tiny_spike": tiny_spike, "k_warmup_steps": k_warmup_steps, "k_logging": k_logging,
@@ -172,7 +172,6 @@ class Adopt_adv(torch.optim.Optimizer):
172
172
  self.clip_lambda = clip_lambda
173
173
  self.stochastic_rounding = stochastic_rounding
174
174
  self.use_atan2 = use_atan2
175
- self.orthogonal_gradient = orthogonal_gradient
176
175
  self.kourkoutas_beta = kourkoutas_beta
177
176
  self.layer_key_fn = layer_key_fn
178
177
  self._init_lr = lr if lr > 0 else 1
@@ -237,7 +236,7 @@ class Adopt_adv(torch.optim.Optimizer):
237
236
  dtype = torch.float32 if (state['factored'] or req_precision == 'factored') else p.dtype
238
237
 
239
238
  vt_dtype = torch.float32 if (state['factored'] or state['factored_2nd'] or req_precision in ['factored', 'bf16_sr', 'int8_sr']) else dtype
240
- vt_init = grad.pow(2).to(vt_dtype) * (1 - group['betas'][1])
239
+ vt_init = grad.pow(2).to(vt_dtype)
241
240
 
242
241
  if state['factored']:
243
242
  state['effective_shape'] = _get_effective_shape(p.numel())
@@ -329,8 +328,7 @@ class Adopt_adv(torch.optim.Optimizer):
329
328
  def _step_parameter(self, p, grad, state, group, lr, beta1, beta2, random_int_tensor, random_int_state_tensor):
330
329
  grad = upcast_grad_for_precision(grad, state, group['state_precision'])
331
330
 
332
- if self.orthogonal_gradient:
333
- grad = _orthogonalize_gradient(p, grad)
331
+ grad = _orthogonalize_gradient(p, grad, group["orthogonal_gradient"])
334
332
 
335
333
  nesterov = group.get('nesterov', False)
336
334
  nesterov_coef = group.get('nesterov_coef', None)
@@ -67,7 +67,7 @@ class Lion_adv(torch.optim.Optimizer):
67
67
  # Stochastic Rounding for BF16
68
68
  stochastic_rounding: bool = True,
69
69
  # OrthoGrad
70
- orthogonal_gradient: bool = False,
70
+ orthogonal_gradient: str = 'disabled', # 'flattened', 'iterative'
71
71
  # Lion-k
72
72
  kappa_p: float = 1.0,
73
73
  auto_kappa_p: bool = False,
@@ -213,8 +213,9 @@ class Lion_adv(torch.optim.Optimizer):
213
213
  def _step_parameter(self, p, grad, state, group, lr, random_int_tensor, random_noise_tensor):
214
214
  if grad.dtype != torch.float32 and state['factored']:
215
215
  grad = grad.float()
216
- if group["orthogonal_gradient"]:
217
- grad = _orthogonalize_gradient(p, grad)
216
+ is_vector = p.ndim < 2 or getattr(p, '_is_dora_scale', False) or getattr(p, 'is_vector', False)
217
+
218
+ grad = _orthogonalize_gradient(p, grad, group["orthogonal_gradient"])
218
219
 
219
220
  # Lion-K Logic
220
221
  kappa_p = group.get("kappa_p", 1.0)
@@ -250,7 +251,7 @@ class Lion_adv(torch.optim.Optimizer):
250
251
  update = update.view(p.shape)
251
252
 
252
253
  if group.get('stochastic_sign', False):
253
- update = apply_stochastic_sign_(update, noise=random_noise_tensor)
254
+ update = apply_stochastic_sign_(update, noise=random_noise_tensor, is_vector=is_vector)
254
255
  else:
255
256
  update = _get_lion_k_update(update, kappa_p)
256
257
 
@@ -265,7 +266,7 @@ class Lion_adv(torch.optim.Optimizer):
265
266
  exp_avg.lerp_(grad, 1 - beta2)
266
267
 
267
268
  if group.get('stochastic_sign', False):
268
- update = apply_stochastic_sign_(update, noise=random_noise_tensor)
269
+ update = apply_stochastic_sign_(update, noise=random_noise_tensor, is_vector=is_vector)
269
270
  else:
270
271
  update = _get_lion_k_update(update, kappa_p)
271
272
 
@@ -39,7 +39,8 @@ class Muon_adv(torch.optim.Optimizer):
39
39
  (default: (3.4445, -4.7750, 2.0315)).
40
40
  stochastic_rounding (bool): whether to use stochastic rounding for
41
41
  BF16 parameter updates (default: True).
42
- orthogonal_gradient (bool): whether to use OrthoGrad. (default: False)
42
+ orthogonal_gradient (str): whether to use OrthoGrad variants. 'disabled': off.
43
+ 'flattened': Standard vectorized OrthoGrad. 'iterative': Matrix-wise rank-2 OrthoGrad. (default: disabled)
43
44
  vector_reshape (bool): whether to reshape 1D vectors into 2D
44
45
  matrices to apply low-rank compression (default: True).
45
46
  nnmf_factor (bool): whether to use the factorization or disable it to use
@@ -89,7 +90,7 @@ class Muon_adv(torch.optim.Optimizer):
89
90
  adam_fisher_wd (bool): Fisher Adam (FAdam) weight decay for the AdamW part. (default: False)
90
91
  adam_use_bias_correction (bool): Bias correction for AdamW.
91
92
  adam_use_atan2 (bool): Atan2 update rule for AdamW.
92
- adam_orthogonal_gradient (bool): OrthoGrad for AdamW.
93
+ adam_orthogonal_gradient (str): OrthoGrad for AdamW.
93
94
  adam_nesterov (bool): Nesterov momentum for AdamW. (default: False)
94
95
  adam_nesterov_coef (float, optional): Nesterov coefficient for AdamW. (default: None)
95
96
  adam_kourkoutas_beta (bool): Kourkoutas-β for AdamW.
@@ -121,7 +122,7 @@ class Muon_adv(torch.optim.Optimizer):
121
122
  # Stochastic Rounding for BF16
122
123
  stochastic_rounding: bool = True,
123
124
  # OrthoGrad
124
- orthogonal_gradient: bool = False,
125
+ orthogonal_gradient: str = 'disabled', # 'flattened', 'iterative'
125
126
  # RMS Rescaling
126
127
  rms_rescaling: bool = True,
127
128
  # SMMF factorization
@@ -159,7 +160,7 @@ class Muon_adv(torch.optim.Optimizer):
159
160
  adam_fisher_wd: bool = False,
160
161
  adam_use_bias_correction: bool = True,
161
162
  adam_use_atan2: bool = False,
162
- adam_orthogonal_gradient: bool = False,
163
+ adam_orthogonal_gradient: str = 'disabled', # 'flattened', 'iterative'
163
164
  adam_nesterov: bool = False,
164
165
  adam_nesterov_coef: float | None = None,
165
166
  adam_kourkoutas_beta: bool = False,
@@ -186,7 +187,7 @@ class Muon_adv(torch.optim.Optimizer):
186
187
  print("Warning: spectral_normalization is incompatible with rms_rescaling, Disabling rms_rescaling.")
187
188
  rms_rescaling = False
188
189
  if spectral_normalization and accelerated_ns:
189
- ValueError("spectral_normalization violates accelerated Newton-Schulz assumptions. Pick one of them.")
190
+ raise ValueError("spectral_normalization violates accelerated Newton-Schulz assumptions. Pick one of them.")
190
191
 
191
192
  # Legacy backwards compatibility support for `nnmf_factor=True`
192
193
  if nnmf_factor:
@@ -457,8 +458,7 @@ class Muon_adv(torch.optim.Optimizer):
457
458
  if grad.dtype != torch.float32 and state.get('factored', False):
458
459
  grad = grad.float()
459
460
 
460
- if group.get("orthogonal_gradient"):
461
- grad = _orthogonalize_gradient(p, grad)
461
+ grad = _orthogonalize_gradient(p, grad, group.get("orthogonal_gradient"))
462
462
 
463
463
  if state['factored']: # Factored Muon
464
464
  d1, d2 = state['effective_shape']
@@ -43,7 +43,8 @@ class Prodigy_adv(torch.optim.Optimizer):
43
43
  stochastic_rounding (bool): whether to use stochastic
44
44
  rounding for BF16 parameter updates (default: True).
45
45
  use_atan2 (bool): whether to use the atan2 update rule. (default: False)
46
- orthogonal_gradient (bool): whether to use OrthoGrad. (default: False)
46
+ orthogonal_gradient (str): whether to use OrthoGrad variants. 'disabled': off.
47
+ 'flattened': Standard vectorized OrthoGrad. 'iterative': Matrix-wise rank-2 OrthoGrad. (default: disabled)
47
48
  nnmf_factor (bool): whether to use the factorization or disable it to use
48
49
  the uncompressed optimizer. (default: False)
49
50
  factored_2nd (bool): whether to keep the first moment uncompressed (dense)
@@ -119,7 +120,7 @@ class Prodigy_adv(torch.optim.Optimizer):
119
120
  # Adam_atan2 (scale invariant)
120
121
  use_atan2: bool = False,
121
122
  # OrthoGrad
122
- orthogonal_gradient: bool = False,
123
+ orthogonal_gradient: str = 'disabled', # 'flattened', 'iterative'
123
124
  # Nesterov momentum
124
125
  nesterov: bool = False,
125
126
  nesterov_coef: float | None = None,
@@ -371,8 +372,7 @@ class Prodigy_adv(torch.optim.Optimizer):
371
372
  def _step_parameter(self, p, grad, state, group, beta2, d, dlr, random_int_tensor, random_int_state_tensor):
372
373
  grad = upcast_grad_for_precision(grad, state, group['state_precision'])
373
374
 
374
- if group["orthogonal_gradient"]:
375
- grad = _orthogonalize_gradient(p, grad)
375
+ grad = _orthogonalize_gradient(p, grad, group["orthogonal_gradient"])
376
376
 
377
377
  nesterov = group.get('nesterov', False)
378
378
  nesterov_coef = group.get('nesterov_coef', None)
@@ -62,7 +62,7 @@ class SignSGD_adv(torch.optim.Optimizer):
62
62
  # Stochastic Rounding for BF16
63
63
  stochastic_rounding: bool = True,
64
64
  # OrthoGrad
65
- orthogonal_gradient: bool = False,
65
+ orthogonal_gradient: str = 'disabled', # 'flattened', 'iterative'
66
66
  # Stochastic Sign Operator
67
67
  stochastic_sign: bool = False,
68
68
  # Nesterov momentum
@@ -171,7 +171,7 @@ class SignSGD_adv(torch.optim.Optimizer):
171
171
  def __init_state(self, p, group):
172
172
  state = self.state[p]
173
173
  # State Initialization
174
- if group["momentum"] > 0 and len(state) == 0:
174
+ if 'step' not in state:
175
175
  req_precision = group['state_precision']
176
176
  is_vector = len(p.shape) == 1 and not group['vector_reshape']
177
177
 
@@ -259,8 +259,7 @@ class SignSGD_adv(torch.optim.Optimizer):
259
259
  wd_target = None
260
260
  cwd_target = None
261
261
 
262
- if group["orthogonal_gradient"]:
263
- grad = _orthogonalize_gradient(p, grad)
262
+ grad = _orthogonalize_gradient(p, grad, group["orthogonal_gradient"])
264
263
 
265
264
  if normed_mt:
266
265
  if sso:
@@ -282,7 +281,7 @@ class SignSGD_adv(torch.optim.Optimizer):
282
281
 
283
282
  if nesterov and normed_mt:
284
283
  # Scale the normalized gradient using empirical buffer magnitude (SNR recovery)
285
- normed_grad = grad_reshaped * exp_avg.abs()
284
+ normed_grad = exp_avg.abs().mul_(grad_reshaped)
286
285
 
287
286
  exp_avg.lerp_(grad_reshaped, 1 - momentum)
288
287
 
@@ -313,7 +312,7 @@ class SignSGD_adv(torch.optim.Optimizer):
313
312
 
314
313
  if nesterov and normed_mt:
315
314
  # Scale the normalized gradient using empirical buffer magnitude (SNR recovery)
316
- normed_grad = grad * exp_avg.abs()
315
+ normed_grad = exp_avg.abs().mul_(grad)
317
316
 
318
317
  exp_avg.lerp_(grad, 1 - momentum)
319
318
 
@@ -344,7 +343,7 @@ class SignSGD_adv(torch.optim.Optimizer):
344
343
  if group.get('geometric_wd', False) and group["weight_decay"] > 0 :
345
344
  wd_target = get_signsgd_wd_target(p, denom=denom, stochastic_sign=sso, noise=random_noise_tensor, is_vector=is_vector)
346
345
 
347
- if group.get('centered_wd', 0.0) > 0 and 'anchor_type' in state:
346
+ if group.get('centered_wd', 0.0) > 0 and 'anchor_data' in state:
348
347
  anchor = dequantize_anchor(p, state, group, p.dtype)
349
348
  cwd_target = get_signsgd_wd_target(p.sub(anchor), denom=denom, stochastic_sign=sso, noise=random_noise_tensor, is_vector=is_vector)
350
349
  del anchor
@@ -355,7 +354,7 @@ class SignSGD_adv(torch.optim.Optimizer):
355
354
  update_scaling = lr * A if snr_cond else lr
356
355
  update.mul_(update_scaling)
357
356
 
358
- param_update.apply_parameter_update(self, p, group, update, lr, random_int_tensor=random_int_tensor, wd_target=wd_target, cwd_target=cwd_target, decoupled=True)
357
+ param_update.apply_parameter_update(self, p, group, update, lr, random_int_tensor=random_int_tensor, wd_target=wd_target, cwd_target=cwd_target)
359
358
 
360
359
  def compile(self, *args, **kwargs):
361
360
  self._compiled_step_parameter = torch.compile(self._step_parameter, *args, **kwargs)
@@ -69,7 +69,7 @@ class SinkSGD_adv(torch.optim.Optimizer):
69
69
  # Stochastic Rounding for BF16
70
70
  stochastic_rounding: bool = True,
71
71
  # OrthoGrad
72
- orthogonal_gradient: bool = False,
72
+ orthogonal_gradient: str = 'disabled', # 'flattened', 'iterative'
73
73
  # Spectral Normed Optimizer
74
74
  spectral_normalization: bool = False,
75
75
  # Centered WD
@@ -89,8 +89,8 @@ class SinkSGD_adv(torch.optim.Optimizer):
89
89
  raise ValueError(f"Momentum should be >= 0.0. Got {momentum}")
90
90
  if not (weight_decay >= 0.0):
91
91
  raise ValueError(f"Weight-decay should be >= 0.0. Got {weight_decay}")
92
- if snr_cond and not normed_momentum:
93
- raise NotImplementedError(f"snr_cond is intended to be used with normed_momentum")
92
+ if snr_cond and not normed_momentum and not momentum > 0:
93
+ raise NotImplementedError(f"snr_cond is intended to be used with normed_momentum.")
94
94
 
95
95
  state_precision = state_precision.lower()
96
96
  valid_precisions = {"auto", "fp32", "factored", "bf16_sr", "fp16", "int8_sr"}
@@ -237,8 +237,7 @@ class SinkSGD_adv(torch.optim.Optimizer):
237
237
  wd_target = None
238
238
  cwd_target = None
239
239
 
240
- if group["orthogonal_gradient"]:
241
- grad = _orthogonalize_gradient(p, grad)
240
+ grad = _orthogonalize_gradient(p, grad, group["orthogonal_gradient"])
242
241
 
243
242
  if normed_mt:
244
243
  if not is_vector:
@@ -266,7 +265,7 @@ class SinkSGD_adv(torch.optim.Optimizer):
266
265
 
267
266
  if nesterov and normed_mt:
268
267
  # Scale the normalized gradient using empirical buffer magnitude (SNR recovery)
269
- normed_grad = grad_reshaped * buf.abs()
268
+ normed_grad = buf.abs().mul_(grad_reshaped)
270
269
 
271
270
  buf.lerp_(grad_reshaped, 1 - momentum)
272
271
 
@@ -303,7 +302,7 @@ class SinkSGD_adv(torch.optim.Optimizer):
303
302
 
304
303
  if nesterov and normed_mt:
305
304
  # Scale the normalized gradient using empirical buffer magnitude (SNR recovery)
306
- normed_grad = grad * buf.abs()
305
+ normed_grad = buf.abs().mul_(grad)
307
306
 
308
307
  buf.lerp_(grad, 1 - momentum)
309
308
 
@@ -346,7 +345,7 @@ class SinkSGD_adv(torch.optim.Optimizer):
346
345
  wd_scaler = get_sinkhorn_wd_scaler(p, row_denom=vt_row, col_denom=vt_col)
347
346
  else:
348
347
  wd_target = get_signsgd_wd_target(p, denom=denom)
349
- if is_vector and group.get('centered_wd', 0.0) > 0 and 'anchor_type' in state:
348
+ if is_vector and group.get('centered_wd', 0.0) > 0 and 'anchor_data' in state:
350
349
  anchor = dequantize_anchor(p, state, group, p.dtype)
351
350
  cwd_target = get_signsgd_wd_target(p.sub(anchor), denom=denom)
352
351
  del anchor
@@ -71,8 +71,7 @@ def _init_auxadam_state(self, p, group):
71
71
  def _adam_step_parameter(self, p, grad, state, group, beta1_adam, beta2_adam, sqrt_bias_correction2, step_size, random_int_tensor, random_int_state_tensor=None):
72
72
  grad = upcast_grad_for_precision(grad, state, group.get('adam_state_precision', 'auto'))
73
73
 
74
- if group.get("adam_orthogonal_gradient"):
75
- grad = _orthogonalize_gradient(p, grad)
74
+ grad = _orthogonalize_gradient(p, grad, group.get("adam_orthogonal_gradient"))
76
75
 
77
76
  if hasattr(self, 'kourkoutas_helper') and self.kourkoutas_helper:
78
77
  # Accumulate current grad's norm for the *next* step
@@ -190,4 +189,4 @@ def _adam_step_parameter(self, p, grad, state, group, beta1_adam, beta2_adam, sq
190
189
  else:
191
190
  update.mul_(update_scaling)
192
191
 
193
- param_update.apply_parameter_update(self, p, group, update, step_size, group["adam_weight_decay"], random_int_tensor=random_int_tensor, wd_scaler=wd_scaler)
192
+ param_update.apply_parameter_update(self, p, group, update, group['lr'], group["adam_weight_decay"], random_int_tensor=random_int_tensor, wd_scaler=wd_scaler)